Along the Axon

Exploring Elixir Machine Learning

Classifying Simple Fashion Types - Sean Moriarity

Mix.install([
  {:axon, "~> 0.5.1"},
  {:exla, "~> 0.5.3"},
  {:req, "~> 0.3.10"},
  {:scidata, "~> 0.1.10"}
])

Elixir FashionMNIST Challenge

I challenged the Elixir community to an Elixir Fashion MNIST challenge. The idea came from a Twitter post by Jeremy Howard, who was teaching his Deep Learning from the Foundations 2022 course and used the Fashion MNIST dataset as an accessible deep learning problem. Because the score is simply accuracy, anyone with patience and a CPU can compete for a better result. A GPU is faster, but a CPU works pretty well.

Sean Moriarity’s Blog Post

Sean created an excellent Dockyard blog post inviting Elixir folks to join the Elixir FashionMNIST Challenge. In his post, he presented his approach to improving upon my initial baseline and achieved 90.7% accuracy on the test set. However, he didn’t link to a Livebook implementation. This notebook attempts to match his blog post in a Livebook notebook that you can use to repeat his results. I also ran the other two epoch counts (20 and 50) and added Sean to the leaderboard.


Hyperparameters in machine learning are choices the developer makes that shape the training of a model. Which model to use is also one of those choices, but it isn’t a simple hyperparameter. Let’s create a map with our simple hyperparameter choices, so the key training decisions are easy to see and we can reference them later in the notebook.

hyperparams = %{
  epochs: 5,
  batch_size: 32
}

Retrieving and exploring the dataset

The Fashion MNIST dataset is available for free online. The Elixir SciData library provides an easy way to download the training and test datasets.

{train_images, train_labels} = Scidata.FashionMNIST.download()

# Normalize images to the range 0..1
{images_binary, images_type, _images_shape} = train_images

train_images =
  images_binary
  |> Nx.from_binary(images_type)
  |> Nx.reshape({:auto, 28, 28, 1})
  |> Nx.divide(255)

# One-hot-encode labels
{labels_binary, labels_type, _shape} = train_labels

train_labels =
  labels_binary
  |> Nx.from_binary(labels_type)
  |> Nx.new_axis(-1)
  |> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
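The one-hot step works by comparing each label against the classes 0 through 9, which yields one row per label with a single 1 in it. A dependency-free sketch of the same idea in plain Elixir follows. The `OneHotSketch` module is a hypothetical helper for illustration and is not part of the notebook, which does this with tensors.

```elixir
defmodule OneHotSketch do
  # Hypothetical helper for illustration only.
  # Each label becomes a list with a 1 at the label's index and 0 elsewhere.
  def encode(labels, num_classes \\ 10) do
    for label <- labels do
      for class <- 0..(num_classes - 1), do: if(class == label, do: 1, else: 0)
    end
  end
end

# Three example labels: class 0, class 3, and class 9
IO.inspect(OneHotSketch.encode([0, 3, 9]))
```

Each inner list has exactly one 1, so the position of that 1 recovers the original label.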


The next cell was one of the key improvements that Sean made. The pipeline provides a place for dynamically modifying the original dataset. Augmenting the input data adds variety to the training set without having to capture more data. When working on a machine learning problem, the business team often asks how much data is required. There isn’t a magic number, or rule of thumb, for that. However, useful augmentation can improve the quality of the data at a low acquisition cost.

seed = 42

{batched_images, _} =
  train_images
  |> Nx.to_batched(hyperparams[:batch_size])
  |> Enum.map_reduce(Nx.Random.key(seed), fn batch, key ->
    fun =
      Nx.Defn.jit(
        fn regular, key ->
          # Draw one uniform value per batch; use the reversed batch when it exceeds 0.5
          {mask, key} = Nx.Random.uniform(key)
          flipped = Nx.reverse(regular, axes: [1])
          augmented = Nx.select(Nx.greater(mask, 0.5), flipped, regular)
          {augmented, key}
        end,
        compiler: EXLA
      )

    fun.(batch, key)
  end)

batched_labels =
  train_labels
  |> Nx.to_batched(hyperparams[:batch_size])
  |> Enum.to_list()
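Which axis gets reversed decides what kind of flip the augmentation produces. With images laid out as {batch, height, width, channels}, reversing axis 1 should turn an image upside down, while reversing axis 2 would mirror it left to right. The plain-Elixir sketch below shows the difference on a tiny 2x3 "image". `FlipSketch` and its functions are hypothetical names used only for illustration.

```elixir
defmodule FlipSketch do
  # Hypothetical helpers; an "image" here is a list of rows, each a list of pixels.

  # Reverse the order of the rows (height axis): the image turns upside down.
  def flip_rows(image), do: Enum.reverse(image)

  # Reverse the pixels within each row (width axis): the image is mirrored.
  def flip_columns(image), do: Enum.map(image, &Enum.reverse/1)

  # Apply the given flip with probability 0.5, otherwise return the image unchanged.
  def maybe_flip(image, flip_fun) do
    if :rand.uniform() > 0.5, do: flip_fun.(image), else: image
  end
end

image = [[1, 2, 3], [4, 5, 6]]
IO.inspect(FlipSketch.flip_rows(image), label: "rows reversed")
IO.inspect(FlipSketch.flip_columns(image), label: "columns reversed")
```

Trying the other axis in the notebook is a cheap experiment, since a mirrored shoe still looks like a shoe while an upside-down one may not.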

Defining the model

Sean created a small custom convolutional neural network (CNN): two convolutional blocks, each followed by max pooling, and then two dense layers. This is an incremental improvement over the initial model. You can experiment with other model approaches to improve the score.

model =
  # The input layer was missing from this cell; the name "images" is an assumption
  Axon.input("images", shape: {nil, 28, 28, 1})
  |> Axon.conv(32, kernel_size: {3, 3}, padding: :same, activation: :relu)
  |> Axon.max_pool(kernel_size: 2)
  |> Axon.conv(64, kernel_size: {3, 3}, padding: :same, activation: :relu)
  |> Axon.max_pool(kernel_size: 2)
  |> Axon.flatten()
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(10, activation: :softmax)
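To get a feel for the size of this network, the shapes and parameter counts can be traced by hand. The sketch below does that arithmetic in plain Elixir. It assumes that each pooling layer uses a stride equal to its window of 2 and no padding, so it halves height and width. `ShapeSketch` is a hypothetical helper and is not part of Axon.

```elixir
defmodule ShapeSketch do
  # Hypothetical helpers for tracing shapes by hand; not part of Axon.

  # A :same-padded convolution keeps height and width and sets the channel count.
  def conv_same({h, w, _c}, filters), do: {h, w, filters}

  # Pooling with window 2 and stride 2 (assumed) halves height and width.
  def pool2({h, w, c}), do: {div(h, 2), div(w, 2), c}

  # Weights plus biases of a k x k convolution.
  def conv_params(k, in_c, out_c), do: k * k * in_c * out_c + out_c

  # Weights plus biases of a dense layer.
  def dense_params(in_units, out_units), do: in_units * out_units + out_units
end

{h, w, c} =
  {28, 28, 1}
  |> ShapeSketch.conv_same(32)
  |> ShapeSketch.pool2()
  |> ShapeSketch.conv_same(64)
  |> ShapeSketch.pool2()

flattened = h * w * c

total_params =
  ShapeSketch.conv_params(3, 1, 32) +
    ShapeSketch.conv_params(3, 32, 64) +
    ShapeSketch.dense_params(flattened, 128) +
    ShapeSketch.dense_params(128, 10)

IO.inspect({h, w, c}, label: "shape before flatten")
IO.inspect(flattened, label: "flattened features")
IO.inspect(total_params, label: "trainable parameters")
```

Under these assumptions, most of the parameters sit in the first dense layer, which is worth keeping in mind when experimenting with the architecture.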


Sean also improved the training process. He added a learning rate schedule instead of the constant default learning rate from the Adam optimizer. Learning rate schedules are a fantastic area for experimentation. Sean used 1850 transition steps, which is just under one epoch at a batch size of 32 (60,000 / 32 = 1,875 steps per epoch). With a decay rate of 0.5, the learning rate is cut roughly in half over each epoch.
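To see what such a schedule does over the course of training, here is a minimal plain-Elixir sketch. It assumes the usual continuous exponential decay formula, rate = init_value * decay_rate ^ (step / transition_steps), and uses the values from this notebook (initial rate 1.0e-3, decay rate 0.5, 1850 transition steps). `ScheduleSketch` is a hypothetical helper, not the schedule implementation used by Axon.

```elixir
defmodule ScheduleSketch do
  # Hypothetical helper: continuous exponential decay,
  # rate = init_value * decay_rate ^ (step / transition_steps).
  def exponential_decay(step, init_value, decay_rate, transition_steps) do
    init_value * :math.pow(decay_rate, step / transition_steps)
  end
end

# 60,000 training images at a batch size of 32
steps_per_epoch = div(60_000, 32)
IO.inspect(steps_per_epoch, label: "steps per epoch")

# Learning rate at the start of training and after each of 5 epochs
for epoch <- 0..5 do
  step = epoch * steps_per_epoch
  rate = ScheduleSketch.exponential_decay(step, 1.0e-3, 0.5, 1850)
  IO.puts("after epoch #{epoch} (step #{step}): #{Float.round(rate, 6)}")
end
```

Changing the arguments in the loop is a quick way to preview a schedule before spending time on a training run.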

What happens when the decay rate is changed? What is the best nominal learning rate? What happens when the transition steps are increased or decreased? Is a decay optimal, or should the learning rate increase? What about a one-cycle approach? These are all hyperparameter decisions that can be varied to improve the model’s test set accuracy.

Based on the results of the PyTorch challenge from last winter, each new leaderboard entry beat the previous ones at all three epoch levels. Five epochs is therefore enough to experiment with different model and training approaches. If a 5-epoch run is more accurate than the current leaderboard, then try 20 and 50 epochs for completeness. Sean only provided a 5-epoch accuracy. His accuracy of 90.7% improves upon my initial notebook. So let’s try the 20-epoch level and the 50-epoch level. We’ll credit Sean with all three numbers.

training_seed = 42
learning_rate = 1.0e-3

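# Reconstructed call: assumes the Axon 0.5 schedule API, which takes :init_value as an option
schedule =
  Axon.Schedules.exponential_decay(
    init_value: learning_rate,
    transition_steps: 1850,
    decay_rate: 0.5
  )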

optimizer = Axon.Optimizers.adam(schedule)

trained_model_state =
  model
  |> Axon.Loop.trainer(:categorical_cross_entropy, optimizer)
  |> Axon.Loop.metric(:accuracy)
  |> Axon.Loop.run(Stream.zip(batched_images, batched_labels), %{},
    epochs: hyperparams[:epochs],
    compiler: EXLA,
    seed: training_seed
  )

Comparison with the test data leaderboard

Now that we have the trained model parameters from the training effort, we can use them for calculating test data accuracy.

Let’s get the test data.

{test_images, test_labels} = Scidata.FashionMNIST.download_test()
{test_images_binary, _, _} = test_images

test_images =
  test_images_binary
  |> Nx.from_binary(images_type)
  |> Nx.reshape({:auto, 28, 28, 1})
  |> Nx.divide(255)

{test_labels_binary, _, _} = test_labels

test_labels =
  test_labels_binary
  |> Nx.from_binary(labels_type)
  |> Nx.new_axis(-1)
  |> Nx.equal(Nx.tensor(Enum.to_list(0..9)))

Instead of Axon.predict, we’ll use Axon.Loop.evaluator with an accuracy metric.

test_batched_images = Nx.to_batched(test_images, hyperparams[:batch_size])
test_batched_labels = Nx.to_batched(test_labels, hyperparams[:batch_size])

model
|> Axon.Loop.evaluator()
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.run(Stream.zip(test_batched_images, test_batched_labels), trained_model_state,
  compiler: EXLA
)
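For one-hot targets and softmax outputs, accuracy is the fraction of examples where the highest-scoring predicted class matches the target class. The plain-Elixir sketch below works through that on three made-up predictions over three classes. `AccuracySketch` is a hypothetical helper, and the numbers are illustrative rather than model output.

```elixir
defmodule AccuracySketch do
  # Hypothetical helper: index of the largest value in a list (first one on ties).
  def argmax(values) do
    values
    |> Enum.with_index()
    |> Enum.max_by(fn {value, _index} -> value end)
    |> elem(1)
  end

  # Fraction of rows whose predicted class matches the one-hot target.
  def accuracy(predictions, targets) do
    matches =
      predictions
      |> Enum.zip(targets)
      |> Enum.count(fn {pred, target} -> argmax(pred) == argmax(target) end)

    matches / length(predictions)
  end
end

# Made-up class probabilities and one-hot targets for three examples
predictions = [[0.1, 0.7, 0.2], [0.8, 0.1, 0.1], [0.3, 0.3, 0.4]]
targets = [[0, 1, 0], [0, 0, 1], [0, 0, 1]]

IO.inspect(AccuracySketch.accuracy(predictions, targets), label: "accuracy")
```

The first and third predictions pick the right class and the second does not, so two of three examples count as correct.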

Challenge: #ElixirFashionML

Sean’s #ElixirFashionMLChallenge Leaderboard (Accuracy) on 8/13/2023

  • 5 Epochs - 90.7%
  • 20 Epochs - 91.1%
  • 50 Epochs - 90.9%

We have a 5-epoch accuracy of 90.7% vs. Jeremy’s 12/15 accuracy of 92.7%. That still leaves plenty of opportunity for the community to leap to the top of the leaderboard.


We highly recommend purchasing Sean Moriarity’s book, Machine Learning in Elixir. He and José Valim started Elixir’s numerical computing effort. The book explains many important concepts about training models in Elixir.

