Matrix multiplication on GPU - XLA

Accelerating Nx with XLA and the GPU

Mix.install(
  [
    {:nx, "~> 0.4.0"},
    {:scidata, "~> 0.1.9"},
    {:axon, "~> 0.3.0"},
    {:exla, "~> 0.4"}
  ],
  system_env: %{"XLA_TARGET" => "cuda111"}
)

Before running the notebook

This notebook depends on EXLA. XLA supports systems with direct access to an NVIDIA GPU (CUDA), an AMD GPU (ROCm), or a Google TPU. According to the documentation, EXLA will try to find a precompiled version that matches your system. If it doesn’t find a match, you will need to install CUDA and cuDNN for your system.

The notebook is currently configured for an NVIDIA GPU via

system_env: %{"XLA_TARGET" => "cuda111"}

Review the configuration documentation for more options.

We had to install CUDA and cuDNN ourselves, but that was several months ago. Your experience may vary from ours.


This Livebook is a transformation of a Python Jupyter Notebook from fast.ai’s From Deep Learning Foundations to Stable Diffusion, Practical Deep Learning for Coders part 2, 2022. Specifically, it mimics the CUDA portion of

The purpose of the transformation is to bring the concepts to Elixir-focused developers. The object-oriented Python/PyTorch implementation is transformed into a functional programming implementation using Nx and Axon.

Experimenting with backend control

In this notebook, we are going to experiment with swapping backends within a single notebook. One of the strengths of Elixir’s approach to numerical processing is the concept of a backend: the same Nx code can run on several different backends. This allows Nx to adapt to changes in numerical libraries and technology. Currently, Nx has backends for TensorFlow’s XLA (through EXLA) and PyTorch’s LibTorch (through Torchx). Theoretically, backends for SoC-type devices should be possible.
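As a minimal sketch of what the backend abstraction means in practice, the snippet below assumes the notebook’s setup cell has already installed Nx. The tensor values and the names a, b and product are illustrative only and are not part of the original notebook.

```elixir
# Assumes Nx is already installed by the notebook's setup cell.
# The values and variable names below are illustrative only.
a = Nx.tensor([[1.0, 2.0], [3.0, 4.0]], backend: Nx.BinaryBackend)
b = Nx.tensor([[5.0], [6.0]], backend: Nx.BinaryBackend)

# Nx.dot/2 is backend-agnostic: it runs on whichever backend holds the tensors.
product = Nx.dot(a, b)

# A {2, 2} tensor times a {2, 1} tensor yields a {2, 1} tensor.
IO.inspect(Nx.shape(product))
```

Moving a and b to another backend with Nx.backend_transfer/2 should leave the Nx.dot call itself unchanged, which is the point of the abstraction.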

We chose not to set the backend globally throughout the notebook. At the beginning of the notebook we’ll repeat the approach we used in 01a_matmul_using_CPU and begin with the Elixir BinaryBackend. You’ll see that it isn’t quick at multiplying 10,000 rows of MNIST data by some arbitrary weights. We’ll then repeat the same multiplication using an NVIDIA 1080 Ti GPU. The 1080 Ti is not the fastest GPU, but it is tremendously faster than the BinaryBackend on a “large” set of data.

226,000 times faster on an old GPU

Default - BinaryBackend

# Without choosing a backend, Nx defaults to Nx.BinaryBackend.
# In case you rerun the notebook, explicitly reset the default backend to Nx.BinaryBackend.
Nx.default_backend(Nx.BinaryBackend)

We’ll pull down the MNIST data.

{train_images, train_labels} = Scidata.MNIST.download()
{train_images_binary, train_tensor_type, train_shape} = train_images

Convert the binary into a tensor and normalize the pixel values to between 0 and 1.

train_tensors =
  train_images_binary
  |> Nx.from_binary(train_tensor_type)
  |> Nx.reshape({60000, 28 * 28})
  |> Nx.divide(255)
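
To make the pipeline easier to inspect, here is a tiny sketch of the same three steps on four made-up pixel bytes instead of the MNIST binary. It assumes Nx is already installed by the notebook’s setup cell; the names pixels and tiny are hypothetical.

```elixir
# Assumes Nx is already installed by the notebook's setup cell.
# Four made-up pixel bytes stand in for the MNIST image binary.
pixels = <<0, 64, 128, 255>>

tiny =
  pixels
  |> Nx.from_binary({:u, 8})
  |> Nx.reshape({2, 2})
  |> Nx.divide(255)

# Division promotes the unsigned 8-bit integers to 32-bit floats
# in the range 0.0 to 1.0.
IO.inspect({Nx.shape(tiny), Nx.type(tiny)})
```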

We’ll separate the data into 50,000 train images and 10,000 validation images.

x_train_cpu = train_tensors[0..49_999]
x_valid_cpu = train_tensors[50_000..59_999]
{x_train_cpu.shape, x_valid_cpu.shape}

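Training is more stable when the weights are initialized with random numbers that have a mean of 0.0 and a variance of 1.0. Note that the third argument of Nx.random_normal is the standard deviation, which equals the variance when both are 1.0.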

mean = 0.0
variance = 1.0
weights_cpu = Nx.random_normal({784, 10}, mean, variance, type: {:f, 32})

To simplify timing the function, we’ll wrap the multiplication in a zero-parameter anonymous function. Invoking it always uses the same two tensors, x_valid_cpu and weights_cpu.

large_nx_mult_fn = fn -> Nx.dot(x_valid_cpu, weights_cpu) end

The following anonymous function takes a function and the number of times to call it.

repeat = fn timed_fn, times -> Enum.each(1..times, fn _x -> timed_fn.() end) end

Next we time the dot product and compute its average duration. The cell outputs the average and the total elapsed time.

repeat_times = 5
{elapsed_time_micro, _} = :timer.tc(repeat, [large_nx_mult_fn, repeat_times])
avg_elapsed_time_ms = elapsed_time_micro / 1000 / repeat_times

{backend, _device} = Nx.default_backend()

"#{backend} CPU avg time in #{avg_elapsed_time_ms} milliseconds, total_time #{elapsed_time_micro / 1000} milliseconds"

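The timing pattern above can be bundled into a small helper. This is only a sketch: the module name Timing and the function avg_ms/2 are hypothetical and not part of the original notebook, and a plain Enum.sum call stands in for the Nx.dot workload so that the example runs without Nx or a GPU.

```elixir
# Hypothetical helper; Timing.avg_ms/2 is not part of the original notebook.
defmodule Timing do
  # Runs a zero-arity function `times` times and returns
  # {average_ms, total_ms}, measured with Erlang's :timer.tc/1.
  def avg_ms(timed_fn, times) when is_function(timed_fn, 0) and times > 0 do
    {elapsed_micro, :ok} =
      :timer.tc(fn -> Enum.each(1..times, fn _ -> timed_fn.() end) end)

    # :timer.tc reports microseconds; convert to milliseconds.
    total_ms = elapsed_micro / 1000
    {total_ms / times, total_ms}
  end
end

# Stand-in workload so the sketch runs without Nx or a GPU.
work_fn = fn -> Enum.sum(1..100_000) end
{avg_ms, total_ms} = Timing.avg_ms(work_fn, 5)
IO.puts("avg #{avg_ms} ms, total #{total_ms} ms")
```

Returning both numbers from one function keeps the microsecond-to-millisecond conversion in a single place, so the CPU and GPU measurements are computed the same way.
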
XLA using GPU

We’ll switch to the XLA backend and use the cuda device. If you have a different device, replace all the :cuda specifications with your device.

Nx.default_backend({EXLA.Backend, device: :cuda})

In the following cell, we transfer the target data onto the GPU.

x_valid_cuda = Nx.backend_transfer(x_valid_cpu, {EXLA.Backend, client: :cuda})
weights_cuda = Nx.backend_transfer(weights_cpu, {EXLA.Backend, client: :cuda})

An anonymous function that calls Nx.dot with the data on the GPU:

exla_gpu_mult_fn = fn -> Nx.dot(x_valid_cuda, weights_cuda) end

We’ll warm up the GPU by looping through 5 function calls and then time the next 5 function calls.

repeat_times = 5
# Warm up with one round of calls; the measured time is discarded
{_warmup_time_micro, _} = :timer.tc(repeat, [exla_gpu_mult_fn, repeat_times])
# The real timing starts here
{elapsed_time_micro, _} = :timer.tc(repeat, [exla_gpu_mult_fn, repeat_times])
avg_elapsed_time_ms = elapsed_time_micro / 1000 / repeat_times

{backend, [device: device]} = Nx.default_backend()

"#{backend} #{device} avg time in #{avg_elapsed_time_ms} milliseconds total_time #{elapsed_time_micro / 1000} milliseconds"
x_valid_cpu = Nx.backend_transfer(x_valid_cuda, Nx.BinaryBackend)
weights_cpu = Nx.backend_transfer(weights_cuda, Nx.BinaryBackend)
