Matrix multiplication on CPU - XLA

Accelerating Nx with XLA and just the CPU

Mix.install(
  [
    {:nx, "~> 0.4.0"},
    {:scidata, "~> 0.1.9"},
    {:axon, "~> 0.3.0"},
    {:exla, "~> 0.4"}
  ]
)
Resolving Hex dependencies...
Dependency resolution completed:
New:
  axon 0.3.0
  castore 0.1.18
  complex 0.4.2
  elixir_make 0.6.3
  exla 0.4.0
  jason 1.4.0
  nimble_csv 1.2.0
  nx 0.4.0
  scidata 0.1.9
  xla 0.3.0
* Getting nx (Hex package)
* Getting scidata (Hex package)
* Getting axon (Hex package)
* Getting exla (Hex package)
* Getting elixir_make (Hex package)
* Getting xla (Hex package)
* Getting castore (Hex package)
* Getting jason (Hex package)
* Getting nimble_csv (Hex package)
* Getting complex (Hex package)
==> jason
Compiling 10 files (.ex)
Generated jason app
==> nimble_csv
Compiling 1 file (.ex)
Generated nimble_csv app
==> complex
Compiling 2 files (.ex)
Generated complex app
==> nx
Compiling 27 files (.ex)
Generated nx app
==> axon
Compiling 24 files (.ex)
Generated axon app
==> elixir_make
Compiling 1 file (.ex)
Generated elixir_make app
==> xla
Compiling 2 files (.ex)
Generated xla app
==> exla
Unpacking /home/ml3/.cache/xla/0.3.0/cache/download/xla_extension-x86_64-linux-cpu.tar.gz into /home/ml3/.cache/mix/installs/elixir-1.14.1-erts-13.1/45e4038ac8aacd103fe2688496702add/deps/exla/cache
g++ -fPIC -I/home/ml3/.asdf/installs/erlang/25.1/erts-13.1/include -Icache/xla_extension/include -O3 -Wall -Wno-sign-compare -Wno-unused-parameter -Wno-missing-field-initializers -Wno-comment -shared -std=c++14 c_src/exla/exla.cc c_src/exla/exla_nif_util.cc c_src/exla/exla_client.cc -o cache/libexla.so -Lcache/xla_extension/lib -lxla_extension -Wl,-rpath,'$ORIGIN/lib'
Compiling 21 files (.ex)
Generated exla app
==> castore
Compiling 1 file (.ex)
Generated castore app
==> scidata
Compiling 13 files (.ex)
Generated scidata app
:ok

Before running the notebook

This notebook has a dependency on EXLA. XLA supports systems with direct access to an Nvidia GPU, AMD ROCm, or a Google TPU, as well as the host CPU. According to the documentation, https://github.com/elixir-nx/nx/tree/main/exla#readme, EXLA will try to find a precompiled version that matches your system. If it doesn't find a match, you will need to install CUDA and cuDNN for your system.

This notebook targets the CPU, so no extra configuration is required. A GPU run would instead be configured via

system_env: %{"XLA_TARGET" => "cuda111"}

Review the configuration documentation for more options: https://hexdocs.pm/exla/EXLA.html#module-configuration

We had to install CUDA and cuDNN for the GPU notebooks, but that was several months ago. Your experience may vary from ours.
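For reference, here is a sketch (not a cell from this notebook) of where that setting goes in Mix.install; the dependency list is abbreviated, and cuda111 assumes a CUDA 11.x install:

Mix.install(
  [
    {:exla, "~> 0.4"}
  ],
  # Hypothetical GPU configuration; this notebook runs on the CPU without it.
  system_env: %{"XLA_TARGET" => "cuda111"}
)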

Context

This Livebook is a transformation of a Python Jupyter Notebook from Fast.ai's From Deep Learning Foundations to Stable Diffusion, Practical Deep Learning for Coders part 2, 2022. Specifically, it mimics the CUDA portion of https://github.com/fastai/course22p2/blob/master/nbs/01_matmul.ipynb, but on the CPU.

The purpose of the transformation is to bring the Fast.ai concepts to Elixir-focused developers. The object-oriented Python/PyTorch implementation is transformed into a functional programming implementation using Nx and Axon.

Experimenting with backend control

In this notebook, we are going to experiment with swapping out backends within the same notebook. One of the strengths of Elixir's numerical processing approach is the concept of a backend: the same Nx code can run on several different backends, which allows Nx to adapt to changes in numerical libraries and technology. Currently, Nx has support for TensorFlow's XLA and PyTorch's TorchScript. Theoretically, backends for SoC-type devices should also be possible.

We chose not to set the backend globally throughout the notebook. At the beginning of the notebook we'll repeat the approach we used in 01a_matmul_using_CPU, starting from the default Elixir BinaryBackend. You'll see that it isn't quick at multiplying 10,000 rows of MNIST data by some arbitrary weights. We'll then repeat the same multiplication using the EXLA backend on the same CPU. Even without a GPU, XLA's compiled code is tremendously faster than the BinaryBackend on a "large" set of data. (For reference, on an old Nvidia 1080 Ti the same multiplication runs roughly 226,000 times faster than on the BinaryBackend.)
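As a small illustration of what swapping backends means, a tensor can be moved between backends with Nx.backend_transfer/2. This is a sketch, not a cell from the original notebook:

# Create a tensor on whatever backend is currently the default.
t = Nx.tensor([1.0, 2.0, 3.0])

# Move it to the pure-Elixir backend...
t = Nx.backend_transfer(t, Nx.BinaryBackend)

# ...and back to EXLA on the CPU. The numbers are unchanged; only where
# (and how fast) operations on them run changes.
t = Nx.backend_transfer(t, {EXLA.Backend, device: :host})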

Backends

# Without choosing a backend, Nx defaults to Nx.BinaryBackend
Nx.default_backend()
{Nx.BinaryBackend, []}

Let's change the default backend to EXLA on the CPU

Nx.default_backend({EXLA.Backend, device: :host})
Nx.default_backend()
{EXLA.Backend, [device: :host]}
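As a quick sanity check (a sketch, with the output abbreviated), any tensor created from this point on is allocated on the EXLA backend rather than the BinaryBackend:

Nx.tensor([1, 2, 3])
#=> #Nx.Tensor<
#=>   s64[3]
#=>   EXLA.Backend<host:0, ...>
#=>   [1, 2, 3]
#=> >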

We’ll pull down the MNIST data

{train_images, train_labels} = Scidata.MNIST.download()
{{<<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...>>, {:u, 8}, {60000, 1, 28, 28}},
 {<<5, 0, 4, 1, 9, 2, 1, 3, 1, 4, 3, 5, 3, 6, 1, 7, 2, 8, 6, 9, 4, 0, 9, 1, 1, 2, 4, 3, 2, 7, 3, 8,
    6, 9, 0, 5, 6, 0, 7, 6, 1, 8, 7, 9, 3, 9, 8, ...>>, {:u, 8}, {60000}}}
{train_images_binary, train_tensor_type, train_shape} = train_images
{<<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
   0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...>>, {:u, 8}, {60000, 1, 28, 28}}
train_tensor_type
{:u, 8}
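The labels follow the same {binary, type, shape} convention. A sketch of unpacking them the same way (the labels aren't used further in this notebook):

{train_labels_binary, train_labels_type, train_labels_shape} = train_labels
train_labels_shape
#=> {60000}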

Convert the raw bytes into a tensor, reshape it to one row per image, and normalize the values to between 0 and 1

train_tensors =
  train_images_binary
  |> Nx.from_binary(train_tensor_type)
  |> Nx.reshape({60000, 28 * 28})
  |> Nx.divide(255)

18:50:30.293 [info] XLA service 0x7fe6d40e2330 initialized for platform Host (this does not guarantee that XLA will be used). Devices:

18:50:30.295 [info]   StreamExecutor device (0): Host, Default Version
#Nx.Tensor<
  f32[60000][784]
  EXLA.Backend<host:0, 0.2851900150.286654488.81191>
  [
    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...],
    ...
  ]
>
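To confirm the normalization did what we expect, we can check the extremes. A minimal sketch using Nx.reduce_min/1 and Nx.reduce_max/1:

# Pixel values should now span [0.0, 1.0] instead of [0, 255].
{Nx.to_number(Nx.reduce_min(train_tensors)), Nx.to_number(Nx.reduce_max(train_tensors))}
#=> {0.0, 1.0}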

We'll separate the data into 50,000 training images and 10,000 validation images.

x_train_cpu = train_tensors[0..49_999]
x_valid_cpu = train_tensors[50_000..59_999]
{x_train_cpu.shape, x_valid_cpu.shape}
{{50000, 784}, {10000, 784}}

Training is more stable when random numbers are initialized with a mean of 0.0 and a variance of 1.0

mean = 0.0
variance = 1.0
weights_cpu = Nx.random_normal({784, 10}, mean, variance, type: {:f, 32})
#Nx.Tensor<
  f32[784][10]
  EXLA.Backend<host:0, 0.2851900150.286654488.81194>
  [
    [-0.973583996295929, 1.3404284715652466, 0.5889155268669128, -0.06439179182052612, -2.2255215644836426, -0.3939111828804016, -1.5497547388076782, -1.1714494228363037, 1.0855729579925537, -0.4689534306526184],
    [-0.31778475642204285, 0.07520100474357605, 0.053238045424222946, 0.42360711097717285, -2.253004312515259, -0.3818463981151581, -0.5468025803565979, 1.3460612297058105, 1.509813904762268, 0.10178464651107788],
    [2.7212319374084473, -0.6341637969017029, 1.9983967542648315, 0.4862823486328125, 0.951216459274292, -0.8570582270622253, 1.7834625244140625, -0.1596108078956604, -0.369051992893219, 0.7038326263427734],
    [-1.321571946144104, -0.573075532913208, -0.5281657576560974, -1.528030276298523, 0.5641341209411621, -0.13296610116958618, -0.20917919278144836, -0.5405102372169495, 0.13647650182247162, 1.0692965984344482],
    [1.1940683126449585, -1.0889204740524292, 0.26889121532440186, -0.8505605459213257, 0.31284958124160767, 0.8289848566055298, 0.23549814522266388, 0.5921769738197327, 0.506867527961731, 0.6787563562393188],
    ...
  ]
>
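We can sanity-check that initialization empirically. Here is a sketch computing the sample mean and variance of the generated weights; expect values near 0.0 and 1.0:

weights_mean = Nx.mean(weights_cpu)

# Variance computed by hand to avoid relying on version-specific helpers.
centered = Nx.subtract(weights_cpu, weights_mean)
variance_estimate = Nx.mean(Nx.multiply(centered, centered))

{Nx.to_number(weights_mean), Nx.to_number(variance_estimate)}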

To simplify timing the performance of the Nx.dot/2 function, we'll use a 0-arity anonymous function. Invoking the anonymous function will always use the same two tensors, x_valid_cpu and weights_cpu.

large_nx_mult_fn = fn -> Nx.dot(x_valid_cpu, weights_cpu) end
#Function<43.3316493/0 in :erl_eval.expr/6>

The following anonymous function takes a function and the number of times to call it.

repeat = fn timed_fn, times -> Enum.each(1..times, fn _x -> timed_fn.() end) end
#Function<41.3316493/2 in :erl_eval.expr/6>

Now we time the average duration of the dot-multiply function. The cell will output the average and the total elapsed time.

repeat_times = 5
{elapsed_time_micro, _} = :timer.tc(repeat, [large_nx_mult_fn, repeat_times])
avg_elapsed_time_ms = elapsed_time_micro / 1000 / repeat_times

{backend, device} = Nx.default_backend()

"#{backend} CPU avg time in #{avg_elapsed_time_ms} milliseconds, total_time #{elapsed_time_micro / 1000} milliseconds"
"Elixir.EXLA.Backend CPU avg time in 1.2837999999999998 milliseconds, total_time 6.419 milliseconds"

fastai, livebook, axon, foundations, xla, deep_learning