Along the Axon
Exploring Elixir Machine Learning
Classifying Simple Fashion Types - Sean Moriarity
https://github.com/meanderingstream/dl_foundations_in_elixir/blob/main/ElixirFashionML_Challenge/fashion_mnist_sean_m.livemd
Mix.install([
  {:axon, "~> 0.5.1"},
  {:exla, "~> 0.5.3"},
  {:req, "~> 0.3.10"},
  {:scidata, "~> 0.1.10"}
])
Elixir FashionMNIST Challenge
I challenged the Elixir community to an Elixir Fashion MNIST challenge, https://alongtheaxon.com/blog/fashion_mnist_challenge. The idea was derived from a Twitter post by Jeremy Howard, who was teaching his Deep Learning from the Foundations 2022 course and used the Fashion MNIST dataset as an accessible deep learning problem. Because the metric is simple accuracy, anyone with patience and a CPU can compete for a better score. A GPU is faster, but a CPU works pretty well.
Sean Moriarity’s Blog Post
Sean created an excellent DockYard blog post inviting Elixir folks to join the Elixir FashionMNIST Challenge. In his post, he presented his initial approach to improving upon my baseline and achieved 90.7% accuracy on the test set. However, he didn't link to a Livebook implementation. This notebook recreates his blog post as a Livebook notebook you can use to repeat his results. I also tried the other two epoch sizes and added Sean to the leaderboard.
Hyperparameters
Hyperparameters in machine learning are choices the developer makes that shape how a model trains. The choice of model architecture is also one of those decisions, but it isn't a simple hyperparameter. Let's create a map with our simple parameter choices. It makes the key training choices easy to see, and we can reference them later in the notebook.
hyperparams = %{
  epochs: 5,
  batch_size: 32
}
Retrieving and exploring the dataset
The Fashion MNIST dataset is available for free online. The Elixir SciData library provides an easy technique to access the training and test datasets.
{train_images, train_labels} = Scidata.FashionMNIST.download()
# Normalize the images and scale pixel values to the range [0, 1]
{images_binary, images_type, _images_shape} = train_images

train_images =
  images_binary
  |> Nx.from_binary(images_type)
  |> Nx.reshape({:auto, 28, 28, 1})
  |> Nx.divide(255)
# One-hot encode the labels by comparing each label against the digits 0..9
{labels_binary, labels_type, _shape} = train_labels

train_labels =
  labels_binary
  |> Nx.from_binary(labels_type)
  |> Nx.new_axis(-1)
  |> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
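As a quick sanity check (my addition, not part of Sean's post), we can confirm the preprocessed tensors have the expected shapes: 60,000 images of 28x28x1 and a one-hot label matrix with 10 columns.

# Expect {60000, 28, 28, 1} for the images and {60000, 10} for the one-hot labels
IO.inspect(Nx.shape(train_images), label: "train_images")
IO.inspect(Nx.shape(train_labels), label: "train_labels")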
Pipeline
The next cell contains one of the key improvements Sean made. The pipeline concept provides a place to dynamically modify the original dataset. Augmenting the input data improves its quality without having to capture more data. When working on a machine learning problem, the business team often asks how much data is required. There isn't a magic number, or rule of thumb, for how much data is needed. However, useful augmentation can improve the quality of the data at a low acquisition cost.
seed = 42
{batched_images, _} =
  train_images
  |> Nx.to_batched(hyperparams[:batch_size])
  |> Enum.map_reduce(Nx.Random.key(seed), fn batch, key ->
    fun =
      Nx.Defn.jit(
        fn regular, key ->
          # Draw one random number per batch; when it exceeds 0.5,
          # replace the whole batch with a copy reversed along axis 1.
          {mask, key} = Nx.Random.uniform(key)
          flipped = Nx.reverse(regular, axes: [1])
          augmented = Nx.select(Nx.greater(mask, 0.5), flipped, regular)
          {augmented, key}
        end,
        compiler: EXLA
      )

    fun.(batch, key)
  end)
batched_labels =
  train_labels
  |> Nx.to_batched(hyperparams[:batch_size])
  |> Enum.to_list()
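To confirm the pipeline produced what we expect, here is a quick check (my addition) of the first batch. Each entry should match the batch_size from our hyperparameters.

# Expect a {32, 28, 28, 1} image batch paired with a {32, 10} label batch
batched_images |> hd() |> Nx.shape() |> IO.inspect(label: "first image batch")
batched_labels |> hd() |> Nx.shape() |> IO.inspect(label: "first label batch")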
Defining the model
Sean created a custom, but small, convolutional neural network (CNN): two convolutional blocks, each followed by max pooling. This is an incremental improvement over the initial model. You can experiment with other model approaches to improve the score.
model =
  Axon.input("features")
  |> Axon.conv(32, kernel_size: {3, 3}, padding: :same, activation: :relu)
  |> Axon.max_pool(kernel_size: 2)
  |> Axon.conv(64, kernel_size: {3, 3}, padding: :same, activation: :relu)
  |> Axon.max_pool(kernel_size: 2)
  |> Axon.flatten()
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(10, activation: :softmax)
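If you want to see the parameter shapes the model will create, you can build the model and initialize it from a template input. This inspection step is my addition, not part of Sean's post.

# Build the model and initialize its parameters from a template input
# so Livebook displays the shape of every layer's kernels and biases.
{init_fn, _predict_fn} = Axon.build(model)
init_fn.(Nx.template({1, 28, 28, 1}, :f32), %{})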
Training
Sean also improved the training process. He added a learning rate schedule instead of the constant default learning rate from the Adam optimizer. Learning rate schedules are a fantastic area for experimentation. Sean used 1,850 transition steps, which is just under one epoch at a batch size of 32 (60,000 / 32 = 1,875 steps per epoch).
What happens when the decay rate is changed? What is the best initial learning rate? What happens when the transition steps are increased or decreased? Is a decaying learning rate optimal, or would an increasing one be better? What about a one-cycle approach? These are all hyperparameter decisions that can be varied to improve the model's test set accuracy.
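As one sketch of the kind of experiment those questions suggest, you could swap in a cosine decay schedule in place of the exponential decay Sean used, assuming Axon.Schedules.cosine_decay is available in this version of Axon. The decay_steps value below (five epochs' worth of batches) is my own guess at a reasonable setting, not something from Sean's post.

# A possible alternative schedule to experiment with: cosine decay over 5 epochs
# (1,875 steps per epoch * 5 epochs = 9,375 decay steps).
alternative_schedule = Axon.Schedules.cosine_decay(5.0e-3, decay_steps: 9375)
alternative_optimizer = Axon.Optimizers.adam(alternative_schedule)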
Based upon the results of the PyTorch challenge from last winter, each new leaderboard entry overtook the previous entries at all 3 epoch levels. Five epochs is enough to experiment with different model and training approaches. If 5 epochs is more accurate than the current leaderboard, then try 20 and 50 epochs for completeness. Sean only provided a 5 epoch accuracy. His 90.7% accuracy improves upon my initial notebook. So let's also try the 20 epoch and 50 epoch levels and credit Sean with all three numbers.
training_seed = 42

# Adam's default is a constant learning rate of 1.0e-3; Sean replaced it
# with an exponential decay schedule starting at 5.0e-3.
schedule =
  Axon.Schedules.exponential_decay(
    5.0e-3,
    transition_steps: 1850,
    decay_rate: 0.5
  )

optimizer = Axon.Optimizers.adam(schedule)
trained_model_state =
  model
  |> Axon.Loop.trainer(:categorical_cross_entropy, optimizer)
  |> Axon.Loop.metric(:accuracy)
  |> Axon.Loop.run(Stream.zip(batched_images, batched_labels), %{},
    epochs: hyperparams[:epochs],
    compiler: EXLA,
    seed: training_seed
  )
Comparison with the test data leaderboard
Now that we have the trained model state from the training run, we can use it to calculate accuracy on the test data.
Let’s get the test data.
{test_images, test_labels} = Scidata.FashionMNIST.download_test()
{test_images_binary, _, _} = test_images
test_images =
  test_images_binary
  |> Nx.from_binary(images_type)
  |> Nx.reshape({:auto, 28, 28, 1})
  |> Nx.divide(255)
{test_labels_binary, _, _} = test_labels
test_labels =
  test_labels_binary
  |> Nx.from_binary(labels_type)
  |> Nx.new_axis(-1)
  |> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
Instead of Axon.predict, we'll use Axon.Loop.evaluator with an accuracy metric.
test_batched_images = Nx.to_batched(test_images, hyperparams[:batch_size])
test_batched_labels = Nx.to_batched(test_labels, hyperparams[:batch_size])
model
|> Axon.Loop.evaluator()
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.run(Stream.zip(test_batched_images, test_batched_labels), trained_model_state,
  compiler: EXLA
)
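For a single prediction, Axon.predict works the way the text above mentions. Here is a small sketch (my addition) that classifies the first test image; argmax over the softmax output gives the predicted class index.

# Predict the class of the first test image with the trained parameters
first_image = test_images[0..0]
probabilities = Axon.predict(model, trained_model_state, first_image, compiler: EXLA)
Nx.argmax(probabilities, axis: 1)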
Challenge: #ElixirFashionML
Sean’s #ElixirFashionMLChallenge Leaderboard (Accuracy) on 8/13/2023
- 5 Epochs - 90.7%
- 20 Epochs - 91.1%
- 50 Epochs - 90.9%
We have a 5 epoch accuracy of 90.7% vs Jeremy's 12/15 accuracy of 92.7%. That still leaves plenty of opportunities for the community to leap to the top of the leaderboard.
Resources
We highly recommend purchasing Sean Moriarity's book, Machine Learning in Elixir. He and José Valim started the Elixir numerical computing capability. The book explains many important concepts about training models in Elixir.