Elixir FashionMNIST Challenge

Elixir State of the Art Challenge using FashionMNIST


Mix.install([
  {:axon, "~> 0.5.1"},
  {:exla, "~> 0.5.3"},
  {:req, "~> 0.3.10"},
  {:scidata, "~> 0.1.10"}
])

Introduction

This livebook is inspired by the Classifying handwritten digits notebook in the Axon documentation. FashionMNIST was designed as a drop-in replacement for the MNIST dataset. Instead of digits, it contains grayscale images of clothing types. Like MNIST, it has 10 classes of images, but it was designed to be a harder problem than the digits dataset. You can check the difficulty by running this notebook for 10 epochs: the training accuracy will be lower than in the corresponding MNIST notebook when using the exact same model and number of epochs.

State of the Art

In a December 2022 tweet, Jeremy Howard created a challenge for the machine learning community: can anyone beat his accuracy in 5, 20, or 50 epochs? Because the challenge measures accuracy after a fixed number of epochs rather than training time, it is open to people with a broad range of hardware. It doesn’t matter whether you are running on an NVIDIA 1060, a 4080, or some GPU in the cloud. In fact, because the problem is small enough, you can even use your CPU and patience. You can also use a CPU cloud resource such as a free Hugging Face Space or Fly.io. If you only have a CPU, be sure to use the EXLA or Torchx backend because they are faster than the default pure Elixir backend.
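One way to switch a notebook to the EXLA backend is sketched below. It assumes the `:exla` dependency from the setup cell is already installed; this cell is an illustration and is not part of the original notebook.

```elixir
# Assumes the Mix.install setup cell has already fetched :exla.
# Make EXLA the default backend so tensors are not created on the
# slower pure Elixir backend (Nx.BinaryBackend).
Nx.global_default_backend(EXLA.Backend)
```

The training and evaluation loops in this notebook additionally pass `compiler: EXLA`, which compiles the loop steps themselves.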

One implied rule isn’t written in Jeremy’s challenge: the model must be trained using only the original FashionMNIST training dataset. Participants can’t add any extra images to the training set. For example, you can’t use generative AI to create new fashion training images.

Leaderboard (Accuracy) on 12/15/2022

Using Axon, we should be able to match those mid-December numbers. The techniques that Jeremy used can be built with the Nx family of libraries; the foundations for the necessary tools are in the Axon, Nx, Kino, and NxImage libraries. Working through the training resources and the hints I’ll provide should allow participants to improve the score. Try implementing one technique and share your results. If you improve the accuracy, I’ll add you to the leaderboard. I’ll also keep track of everyone who has been on the leaderboard.

By competing with each other and sharing, we’ll all learn the best techniques for building a State of the Art model in Elixir. Also, I strongly recommend sharing techniques that you try that don’t improve the leaderboard. If you try something, you learn something. When you share, everyone learns something.

If we can match the numbers, then we might be able to get close to the current leaderboard. But let’s try the 12/15 leaderboard first.

Hyperparameters

Hyperparameters in machine learning are choices the developer makes that shape the training of a model. Which model to use is one of those choices, but it isn’t a simple hyperparameter. Let’s create a Map with our simple hyperparameter choices. It makes the key choices easier to see, and we can reference them later in the notebook. When you add a new technique, you will probably make some hyperparameter choices of your own. Please add them to this data structure. When we get further along, I plan to share the reasoning for keeping a separate hyperparameter data structure.

hyperparams = %{
  epochs: 5,
  batch_size: 32
}
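As a minimal sketch of adding a choice for a new technique, the cell below extends the map with a `:learning_rate` key. That key and the name `example_hyperparams` are hypothetical and do not appear in the original notebook; they stand in for whatever your technique needs.

```elixir
# Baseline choices, copied from the notebook.
hyperparams = %{
  epochs: 5,
  batch_size: 32
}

# Hypothetical addition: :learning_rate is an illustrative key only.
example_hyperparams = Map.put(hyperparams, :learning_rate, 1.0e-3)

# Read values back the same way the later cells do.
IO.inspect(example_hyperparams[:batch_size], label: "batch size")
IO.inspect(example_hyperparams[:learning_rate], label: "learning rate")
```

Keeping every choice in one map means a shared notebook shows all of its settings in a single place.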

Retrieving and exploring the dataset

The Fashion MNIST dataset is available for free online. The Elixir Scidata library provides an easy way to access the training and test datasets.

# Download the training set; images and labels each arrive as a {binary, type, shape} tuple
{train_images, train_labels} = Scidata.FashionMNIST.download()

# Normalize and batch images
{images_binary, images_type, images_shape} = train_images

batched_images =
  images_binary
  |> Nx.from_binary(images_type)
  |> Nx.reshape(images_shape)
  |> Nx.divide(255)
  |> Nx.to_batched(hyperparams[:batch_size])
# One-hot-encode and batch labels
{labels_binary, labels_type, _shape} = train_labels

batched_labels =
  labels_binary
  |> Nx.from_binary(labels_type)
  |> Nx.new_axis(-1)
  |> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
  |> Nx.to_batched(hyperparams[:batch_size])
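The one-hot step above relies on broadcasting, which is easier to see on a tiny input. The sketch below uses three made-up labels instead of the real label tensor and assumes Nx is available from the notebook's setup cell.

```elixir
# Assumes Nx is loaded by the notebook's setup cell.
# Three made-up labels standing in for the real label tensor.
labels = Nx.tensor([0, 3, 9])

one_hot =
  labels
  # {3} -> {3, 1}: each label gets its own row
  |> Nx.new_axis(-1)
  # {3, 1} compared with {10} broadcasts to {3, 10}
  |> Nx.equal(Nx.tensor(Enum.to_list(0..9)))

IO.inspect(Nx.shape(one_hot), label: "shape")
IO.inspect(Nx.to_flat_list(one_hot[1]), label: "row for label 3")
```

Each row contains a single 1 in the column that matches its label, so the row for label 3 is `[0, 0, 0, 1, 0, 0, 0, 0, 0, 0]`.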

Defining the model

We’ll use the same model as the MNIST example. By starting with an extremely simple model, I’ve left room for challenge participants to try different models. Remember, models have to start from random weights: pre-trained models can’t be used on the leaderboard. However, you can still learn from trying a pre-trained model. Check out Sean Moriarity’s Machine Learning in Elixir book for an example.

model =
  Axon.input("input", shape: {nil, 1, 28, 28})
  |> Axon.flatten()
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(10, activation: :softmax)

All Axon models start with an input layer to tell subsequent layers what shapes to expect. We then use Axon.flatten/2 which flattens the previous layer by squeezing all dimensions but the first dimension into a single dimension. Our model consists of 2 fully connected layers with 128 and 10 units respectively. The first layer uses :relu activation which returns max(0, input) element-wise. The final layer uses :softmax activation to return a probability distribution over the 10 labels.
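To get a feel for how small this model is, the trainable parameters of the two dense layers can be counted by hand. This is a plain Elixir sketch; `dense_params` is a helper made up for this illustration, and it assumes each dense layer has a bias term, which is Axon's default.

```elixir
# Flattened input size: 1 channel * 28 * 28 pixels.
inputs = 1 * 28 * 28

# A dense layer has one weight per (input, unit) pair plus one bias per unit.
# dense_params is a hypothetical helper, not an Axon function.
dense_params = fn in_units, out_units -> in_units * out_units + out_units end

hidden = dense_params.(inputs, 128)
output = dense_params.(128, 10)

IO.inspect(hidden, label: "hidden layer parameters")
IO.inspect(output, label: "output layer parameters")
IO.inspect(hidden + output, label: "total parameters")
```

The total is 101,770 parameters, almost all of them in the first dense layer.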

Training

In Axon we express the task of training using a declarative loop API. First, we need to specify a loss function and an optimizer; there are many built-in variants to choose from. In this example, we’ll use categorical cross-entropy and the Adam optimizer. We will also keep track of the accuracy metric. Finally, we run the training loop, passing our batched images and labels. We’ll train for the number of epochs set in hyperparams (5 here) using the EXLA compiler.

In the PyTorch challenge from last winter, every change at the top of the leaderboard beat the previous entries at all 3 epoch levels. Five epochs is therefore enough to experiment with different model and training approaches. If your 5-epoch result is more accurate than the current leaderboard, then run the 20- and 50-epoch versions for completeness.

trained_model_params =
  model
  |> Axon.Loop.trainer(:categorical_cross_entropy, :adam)
  |> Axon.Loop.metric(:accuracy, "Accuracy")
  |> Axon.Loop.run(Stream.zip(batched_images, batched_labels), %{},
    epochs: hyperparams[:epochs],
    compiler: EXLA
  )
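As a rough guide to how much work the loop above does, the number of optimizer steps follows from the dataset size and the hyperparameters. The sketch below is plain arithmetic and relies on the FashionMNIST training set containing 60,000 images.

```elixir
# FashionMNIST ships 60,000 training images.
train_examples = 60_000
batch_size = 32
epochs = 5

# Each batch produces one optimizer step.
batches_per_epoch = div(train_examples, batch_size)
leftover_examples = rem(train_examples, batch_size)
total_steps = batches_per_epoch * epochs

IO.inspect(batches_per_epoch, label: "batches per epoch")
IO.inspect(leftover_examples, label: "examples left over")
IO.inspect(total_steps, label: "optimizer steps in 5 epochs")
```

With a batch size of 32 the training set divides evenly into 1,875 batches per epoch, so the 5-epoch run performs 9,375 optimizer steps.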

Comparison with the test data leaderboard

Now that we have the trained model parameters from the training effort, we can use them for calculating test data accuracy.

Let’s get the test data.

{test_images, test_labels} = Scidata.FashionMNIST.download_test()

# Normalize and batch images, using the type reported for the test set
{test_images_binary, test_images_type, test_images_shape} = test_images

test_batched_images =
  test_images_binary
  |> Nx.from_binary(test_images_type)
  |> Nx.reshape(test_images_shape)
  |> Nx.divide(255)
  |> Nx.to_batched(hyperparams[:batch_size])

# One-hot-encode and batch labels
{test_labels_binary, test_labels_type, _shape} = test_labels

test_batched_labels =
  test_labels_binary
  |> Nx.from_binary(test_labels_type)
  |> Nx.new_axis(-1)
  |> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
  |> Nx.to_batched(hyperparams[:batch_size])

Instead of Axon.predict, we’ll use Axon.Loop.evaluator with an accuracy metric.

Axon.Loop.evaluator(model)
|> Axon.Loop.metric(:accuracy, "Accuracy")
|> Axon.Loop.run(
  Stream.zip(test_batched_images, test_batched_labels),
  trained_model_params,
  compiler: EXLA
)
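For intuition, accuracy on one-hot labels amounts to comparing the most likely predicted class with the true class and averaging the matches. The sketch below does this by hand on made-up values; it approximates what the metric reports and is not Axon's own implementation. It assumes Nx is available from the notebook's setup cell.

```elixir
# Assumes Nx is loaded by the notebook's setup cell.
# Made-up softmax outputs for 3 images and 4 classes (the real model has 10).
predictions =
  Nx.tensor([
    [0.7, 0.1, 0.1, 0.1],
    [0.2, 0.5, 0.2, 0.1],
    [0.1, 0.1, 0.2, 0.6]
  ])

# Made-up one-hot labels for classes 0, 2 and 3.
labels =
  Nx.tensor([
    [1, 0, 0, 0],
    [0, 0, 1, 0],
    [0, 0, 0, 1]
  ])

# Pick the most likely class per row, then compare with the true class.
predicted_class = Nx.argmax(predictions, axis: -1)
true_class = Nx.argmax(labels, axis: -1)

accuracy =
  predicted_class
  |> Nx.equal(true_class)
  |> Nx.mean()

IO.inspect(Nx.to_flat_list(predicted_class), label: "predicted")
IO.inspect(Nx.to_number(accuracy), label: "accuracy")
```

Two of the three made-up predictions match their labels, so the accuracy here is about 0.667.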

Challenge: #ElixirFashionML

#ElixirFashionMLChallenge Leaderboard (Accuracy) on 7/30/2023

We have a 5-epoch accuracy of 87.4% vs Jeremy’s 12/15 accuracy of 92.7%. That should leave plenty of opportunities for the community to leap to the top of the leaderboard.

How can you beat this initial result?

I’ll provide a quick set of resources and expand upon important resources at a later time. For now, start reading, try various Livebook notebooks, and watch some videos.

Resources

We highly recommend purchasing Sean Moriarity’s book, Machine Learning in Elixir. He and José Valim started the Elixir numerical computing effort. The book explains many important concepts about training models in Elixir.

Nicolò G created a batch of Livebook notebooks that translate Python book examples into Nx. The notebooks can be found on his GitHub account.

The techniques to achieve the SOTA are taught in the Fast.ai Part 2 course. The course has three parts: Stable Diffusion, Deep Learning Foundations, and Stable Diffusion from scratch. Deep Learning Foundations focuses on the skills for this challenge and runs from the second half of Lesson 10 through the first half of Lesson 19. That is about 18 hours of videos on the PyTorch implementation used to reach the SOTA numbers.

I have some Livebook notebooks for Lessons 10-11. I struggled for a while with translating the object-oriented concepts into a similar approach in Elixir before I decided that the object-oriented abstractions probably weren’t worth translating. The calculations and tools are important though. Axon and Kino have elements that can provide some of the same ease of use as Fast.ai. Kino can be used, with Axon, to create the visualizations and tools that are built in the course. Axon and NxImage have elements that can be combined to create other capabilities taught in the course. I’ll have more thoughts and hints to share soon.

Why is it important for Elixir folks to try to beat Jeremy’s 12/15 SOTA values?

By implementing the techniques from Fast.ai’s Lessons 10-18, we will be learning how to train a very accurate model at a lower compute cost. A business putting a model into production normally wants the best-performing model that fits the problem constraints. By learning techniques that improve model performance while also reducing the compute needed for training, we help reduce costs and have a better chance of meeting business goals.

While it may seem that FashionMNIST is a simple problem, all of the techniques used to reach SOTA were originally combined in the 2018 DAWNBench competition. Fast.ai students in a study group teamed up to compete against well-funded companies and came in second place (ImageNet) and first place (CIFAR). Unlike this challenge, the DAWNBench competition was time-based.

fastai, livebook, axon, foundations, xla, deep_learning