Matrix multiplication from foundations

A CPU-focused introduction to matrix multiplication for deep learning

Matrix multiplication from foundations - CPU only

Mix.install([
  {:nx, "~> 0.4.0"},
  {:binary, "~> 0.0.5"},
  {:stb_image, "~> 0.5.2"},
  {:scidata, "~> 0.1.9"},
  {:kino, "~> 0.7.0"},
  {:axon, "~> 0.3.0"}
])



This Livebook is a transformation of a Python Jupyter Notebook from fast.ai’s From Deep Learning Foundations to Stable Diffusion, Practical Deep Learning for Coders part 2, 2022. Specifically, it mimics the course’s 01_matmul.ipynb notebook.

The purpose of the transformation is to bring the concepts to Elixir-focused developers. The object-oriented Python/PyTorch implementation is transformed into a functional programming implementation using Nx and Axon.

About fast.ai’s Teaching Philosophy

We’ll be leveraging the best available research on teaching methods to try to fix these problems with technical teaching, including:


In other words, focus on student success from the beginning. Help students become confident in their growing skills. Use a code-first approach to teaching Deep Learning. Provide plenty of examples of functioning neural networks that can be applied by the students.

This part 2 course is not exactly like the above description

The part 1 course fits the above description really well. When we went through each of the past 4 years of the part 1 course, we felt like the course spoke to our needs. In the best world, Elixir developers could learn from the part 1 course and come away with several kinds of near state-of-the-art neural net models running in Elixir. However, an Elixir version of the part 1 course doesn’t exist yet.

The part 2 course has a different focus. Jeremy Howard likes to call part 2 (im)practical deep learning for coders. Part 2 goes under the hood and helps students understand the pieces of a neural network and of a best-practice-focused training library. The foundations are taught with examples that help students understand how the pieces really work. It’s “impractical” because the examples are simpler, already-solved problems that don’t translate directly to a real-world problem. The examples used in part 2 are smaller, well-known problems, but the focus is on understanding how the software skills you use daily are transformed into neural network concepts that can use the GPU for time-efficient training. The confidence gained from the part 2 course is knowing how to modify and change a model to fit your domain.

Foundation notebooks and previous videos

The 2022 Python/PyTorch “from the foundations” notebooks are available on GitHub. This notebook is being written while the live course is happening, and fast.ai course work is restricted to paid participants until the course is completed. The notebooks are public, but the videos and forum conversations are restricted. After the course completes, the videos and forums are opened to everyone in the form of a massive open online course. However, fundamentals don’t change that much, so for now the 2019 course videos would be a fine video companion for these Elixir notebooks.

The first two lesson videos from the 2022 course were released early. In the second lesson, Jeremy covers the first portion of this notebook.

To Stable Diffusion

The 2022 part 2 course is From Deep Learning Foundations to Stable Diffusion. In 2022, fast.ai is focusing on understanding the pieces of Stable Diffusion and on discussing the latest research papers that improve upon Stable Diffusion. As a taste of what is coming, fast.ai has released the videos from the first 2 weeks.

At the current time, Stable Diffusion doesn’t run in Elixir. fast.ai’s part 2 is split into two types of notebooks: a set focused on Stable Diffusion and another set focused on the foundations. For now, we are focused only on the foundation notebooks.

fast.ai’s book

There was a recent Twitter discussion expressing a desire to see examples from Deep Learning for Coders with Fastai and PyTorch: AI Applications Without a PhD in Elixir/Livebook. The meanderingstream/dl_foundations_in_elixir notebooks correspond to chapters 17, 18 and 19 in the book. Further resources related to the book can be found on the book page.

Part 2 Foundations approach

We’ll start with standard Elixir examples of the fundamentals. An Elixir-focused developer should recognize the standard Elixir code. The part 2 “Game” is:

Because we are transforming Python/PyTorch into Elixir, some concepts don’t perfectly match back to the original Python code. There are some library differences, and some of the tooling for Elixir and Livebook doesn’t perfectly match. Nx, Axon and Livebook are very recent technologies, and their capabilities are growing each month.

Because we are mapping from Python/PyTorch to Elixir, and the vast majority of machine learning examples are written in Python, we are often going to show the original Python from the notebook above the Elixir code. Hopefully this will help Elixir developers transform other PyTorch code into Elixir code.

# PyTorch
# some python from a Jupyter notebook
# --> The result
#     from executing the cell goes here

some elixir code here

Brief Introduction to Elixir and Numerical Elixir

Elixir’s primary numerical datatypes and structures are not optimized for numerical programming. Nx is a library built to bridge that gap.

Elixir Nx is a numerical computing library that smoothly integrates with typed, multidimensional data (called tensors) implemented on other platforms. This support extends to the compilers and libraries that support those tensors. Nx has three primary capabilities:

Note that the linked URL is really a Livebook notebook. When you click the Run in Livebook button, it navigates to an intermediate page where you can choose the location of your Livebook application. It then opens the page in your Livebook application.

Course Start: From the foundations

Jeremy’s introduction: This part of the course will require some serious tenacity and a certain amount of patience. We think you are going to learn a lot. A lot of people have given Jeremy feedback that the previous iteration of this course is the best course they’ve ever done. This course will be dramatically better than any previous version. Hopefully you’ll find that the hard work and patience pay off.

Our goal in this course is to get to Stable Diffusion from the foundations. We have to define what the foundations are. Jeremy restricted the Python foundations to:

In Elixir we’ll have our own foundation.

To be clear, we are allowed to use other libraries once we have reimplemented them correctly. If we reimplement something from NumPy or PyTorch, we are then allowed to use those libraries. Sometimes we are going to implement things that haven’t been created before. Those things will become part of our own library. We are going to be calling that library miniai. We are going to be building our own little framework as we go.

One challenge we have is that the models used in Stable Diffusion were trained for months on millions of dollars of equipment. We don’t have the time or money for those compute resources. Another trick we are going to use is to create identical but smaller versions of them. Once we have them working, we’ll be allowed to use the big pre-trained versions.

So we are going to have to end up with our own variational auto-encoder, our own U-Net, our own CLIP encoder, and so forth.

To a certain extent, Jeremy assumes that you’ve gone through part 1. If you find something that doesn’t make sense to you, go back to the part 1 course or Google what you don’t understand. Anything that wasn’t covered in part 1, we’ll go over thoroughly and carefully.

Reference: Jeremy’s discussion in the Lesson 10 video.

Elixir foundations

In our foundations version, we’ll make the following assumptions throughout these Elixir versions of fast.ai’s notebooks:

The documentation for Nx and Axon is available on HexDocs.

To run these notebooks, you will need to install a local version of Livebook or get access to a cloud server. Many of our foundation notebooks don’t need a GPU. Nx comes with an Elixir only BinaryBackend that runs on any CPU that supports Livebook. If EXLA or Torchx aren’t in the Mix.install at the top of a notebook, it can be run on any computer. Please give it a try.

We’ll follow roughly the same approach as the PyTorch version of the course. We’ll start with standard Elixir, with some additional libraries. Once we’ve implemented a capability, we’ll move on to using the Nx and Axon libraries. We’ll invent our own libraries as needed.

Getting the Data

We are going to need some input data. fast.ai uses MNIST for this part of the course. Elixir has the SciData library that contains small standard datasets, including MNIST.

We are diverging from the cell-by-cell transformation of 01_matmul.ipynb because SciData works differently from the .pth files used in the original notebook.

{train_images, train_labels} = Scidata.MNIST.download()
{test_images, test_labels} = Scidata.MNIST.download_test()
# Each image set is a {binary, type, shape} tuple, so let's unpack them
{train_images_binary, tensor_type, train_shape} = train_images
{test_images_binary, tensor_type, test_shape} = test_images

The source for the MNIST training data used in the Python notebook returns normalized data with a shape of (50000, 784): 50,000 items, each 784 numbers long. The numbers are all between 0 and 1. SciData gives us raw bytes instead, so we’ll need to change our binary into numbers and divide the numbers by 255 to normalize the values.

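# Normalize the values first.
train_normalized_long_list =
  train_images_binary
  # Turn the raw bytes into a list of integers 0..255
  |> :binary.bin_to_list()
  |> Enum.map(fn value -> value / 255 end)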

The data source used by the Python notebook split the 60,000-image MNIST training data into 50,000 training images and 10,000 validation images. We’ll do a similar split after the first 50,000 images.

{train_list_784, valid_list_784} =
  Enum.chunk_every(train_normalized_long_list, 784)
  |> Enum.split(50_000)

train_imgs_28_28 =
  Enum.map(train_list_784, fn img ->
    Enum.chunk_every(img, 28)
  end)
Let’s check that we still have 50,000 images, that the first image has 28 rows, and that the first row of the first image has 28 columns.

{Enum.count(train_imgs_28_28), Enum.count(Enum.at(train_imgs_28_28, 0)),
 Enum.count(Enum.at(Enum.at(train_imgs_28_28, 0), 0))}

Visualizing Normalized Data

We have a normalized image in memory, how do we check that it really represents an image?

first_img_28_28 = Enum.at(train_imgs_28_28, 0)

We don’t know of a convenient method to convert a normalized list of lists into an image. However, if we convert to a tensor, we can load the tensor into StbImage. We are going to cheat and look ahead at some concepts described below, but we’ll be able to show the image.

first_img =
  first_img_28_28
  |> Enum.map(fn row ->
    Enum.map(row, fn column ->
      round(column * 255)
    end)
  end)
  |> Nx.tensor(type: :u8)
  |> Nx.reshape({28, 28, 1})
  |> StbImage.from_nx()
  |> StbImage.to_binary(:png)
# Python
# mpl.rcParams['image.cmap'] = 'gray'
# plt.imshow(list(chunks(lst1, 28)));

# Kino currently assumes the image is larger than the box
image = Kino.Image.new(first_img, :png)
label = Kino.Markdown.new("**MNIST Image**")

images = [
  Kino.Layout.grid([image, label], boxed: true)
]

Kino.Layout.grid(images, columns: 3)

Matrix and tensor

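Let’s pull an individual value from a list of lists.

# Find a row with some non-zero values: the 8th row
Enum.at(first_img_28_28, 8)

first_non_zero_in_row =
  Enum.at(first_img_28_28, 8)
  |> Enum.find_index(fn x -> x != 0.0 end)

# Let's find a value somewhere in that list of lists
first_img_28_28 |> Enum.at(8) |> Enum.at(first_non_zero_in_row)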

A convenience module to make it easier to access an element in a list of lists:

defmodule Matrix do
  def at(matrix, row, column) do
    Enum.at(matrix, row) |> Enum.at(column)
  end
end

Matrix.at(first_img_28_28, 8, 10)

Now that we’ve demonstrated how to load SciData into a normal Elixir list of lists, access elements within the list of lists, and show the in-memory image data, let’s start using Nx tensors instead of lists of lists.

x_tensors =
  train_images_binary
  |> Nx.from_binary(tensor_type)
  |> Nx.reshape({60000, 28 * 28})
  |> Nx.divide(255)

Again, we’ll split the SciData training dataset into train and valid, using the names from the Python notebook.

x_train = x_tensors[0..49_999]
x_valid = x_tensors[50_000..59_999]
{x_train.shape, x_valid.shape}

CAUTION: Even though it kind of looks like we called a function on a data object, all we really did was access the shape field of a struct. That is simple data field access. The human-readable representation of the struct is simplified to make it easier to read, because tensors can hold a lot of data in their struct fields. Type is another field. See how the type and shape are scrunched together in the printed view.


Let’s load an Nx normalized tensor and visualize it with Kino.Image.

img_tensor =
  x_train[0]
  |> Nx.reshape({28, 28})

Let’s visualize the image like we did earlier, except this time the source is an Nx.Tensor

first_img_from_tensor =
  img_tensor
  |> Nx.reshape({28, 28, 1})
  |> Nx.multiply(255)
  |> Nx.round()
  |> Nx.as_type({:u, 8})
  |> StbImage.from_nx()
  |> StbImage.to_binary(:png)
# Python
# plt.imshow(imgs[0]);

image = Kino.Image.new(first_img_from_tensor, :png)
label = Kino.Markdown.new("**MNIST Image from tensor**")

images = [
  Kino.Layout.grid([image, label], boxed: true)
]

Kino.Layout.grid(images, columns: 3)

Let’s parse out the classification labels of each image. Each element identifies the digit each handwritten image represents.

{train_y_binary, y_tensor_type, y_shape} = train_labels
y_tensors =
  train_y_binary
  |> Nx.from_binary(y_tensor_type)
  |> Nx.reshape(y_shape)

We’ll split into train and valid like the data source.

y_train = y_tensors[0..49_999]
y_valid = y_tensors[50_000..59_999]
{y_train.shape, y_valid.shape}

PyTorch’s min() and max() on a tensor correspond to Nx.reduce_min/1 and Nx.reduce_max/1 in Nx. Another way is to convert to a flat, normal Elixir list, use Enum to find the min and max, and then convert back to Nx tensor scalars, which is what we do here.

# PyTorch
# y_train.min(), y_train.max()
# --> (tensor(0), tensor(9))

{Nx.tensor(Enum.min(Nx.to_flat_list(y_train))), Nx.tensor(Enum.max(Nx.to_flat_list(y_train)))}
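In plain Elixir, Enum.min_max/1 finds both values in a single pass over the list. Here is a small sketch on a hypothetical list of labels; in the notebook, the list would come from Nx.to_flat_list(y_train).

```elixir
# Hypothetical labels standing in for Nx.to_flat_list(y_train)
labels = [5, 0, 4, 1, 9, 2]

# Enum.min_max/1 returns {min, max} in one traversal of the list
{min_label, max_label} = Enum.min_max(labels)
```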

Random Numbers

For now, we are going to treat the random number section of the notebook as a problem specific to PyTorch. The problematic situation comes from using os.fork() to parallelize some work that calls the rand() function. In Python, the fork creates a copy of the current process. The particular problem is that the fork includes the global rnd_state of the parent process, so each process that calls rand() receives the same sequence of pseudo-random numbers.

The video’s discussion of how pseudo-random numbers work is well worth watching.

TODO: How does Elixir handle pseudo-random number sequences in two Elixir processes?
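Erlang’s :rand documentation suggests a partial answer. The :rand state is stored in each process’s own process dictionary. A process that draws a number without seeding first is seeded automatically with a non-constant seed. Identical explicit seeds still give identical sequences, and that is the situation the Python fork creates by accident. The sketch below checks both cases. The :exsss algorithm and the seed {1, 2, 3} are arbitrary choices made for illustration.

```elixir
# Draw five uniform floats inside a fresh process (a Task).
draw_five = fn seed ->
  Task.async(fn ->
    # Seed only when asked; otherwise :rand seeds this process itself on first use
    if seed, do: :rand.seed(:exsss, seed)
    for _ <- 1..5, do: :rand.uniform()
  end)
  |> Task.await()
end

# Two processes with the same explicit seed produce the same sequence
seeded_a = draw_five.({1, 2, 3})
seeded_b = draw_five.({1, 2, 3})

# Two unseeded processes each get their own automatic seed
unseeded_a = draw_five.(nil)
unseeded_b = draw_five.(nil)

{seeded_a == seeded_b, unseeded_a == unseeded_b}
```

The first comparison is always true. The second is false with overwhelming probability, but that is a property of independent seeds rather than a guarantee.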

Tensor rank

The rank of a tensor is the number of indices required to uniquely select each element of the tensor. Rank is also known as “order”, “degree”, or “ndims.”


In Livebook/Nx, the rank can be observed from the number of square bracket pairs after the type label, e.g. s64[2][2] for a rank 2 tensor.

# Rank 1 tensor
Nx.tensor([1, 2, 3])
# Rank 2 tensor
Nx.tensor([[1, 2], [2, 3]])
# Rank 3 tensor
Nx.tensor([[[1, 2], [2, 3]], [[4, 5], [5, 6]]])
# Rank 0 tensor
Nx.tensor(5)
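The same idea works on plain lists of lists. Below is a sketch with a hypothetical ListShape module (it is not part of Nx) that measures the length at each nesting level. The rank is then the number of levels. It only inspects the first element at each level, so it assumes rectangular data.

```elixir
defmodule ListShape do
  # Shape as a tuple, like Nx's tensor.shape, e.g. {2, 2} for [[1, 2], [2, 3]]
  def shape(nested), do: nested |> dims() |> List.to_tuple()

  # Rank = number of indices needed to select one element = size of the shape
  def rank(nested), do: tuple_size(shape(nested))

  # Only the first element at each level is inspected, so ragged lists go undetected
  defp dims([first | _] = list), do: [length(list) | dims(first)]
  defp dims([]), do: [0]
  defp dims(_scalar), do: []
end

ListShape.shape([1, 2, 3])
ListShape.rank([[[1, 2], [2, 3]], [[4, 5], [5, 6]]])
ListShape.rank(5)
```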

Online Matrix multiplication

We are working on the start of a forward pass of a very simple linear model, a multi-layer perceptron, for MNIST. We now need to multiply tensors together.

There are several websites that can provide visual examples of matrix multiplication.

Matrix multiplication is a fundamental capability of deep learning. We are going to look at how to do matrix multiplication in standard Elixir and then use Nx to perform the multiplication as a tensor.

Mutable data approaches vs Immutable data

Many programming languages have mutable data, and Python certainly does. Let’s go into detail about how working with immutable data in Elixir differs from working with mutable data.

In Python, Jeremy uses this approach to multiply two tensors.

for i in range(ar):         # 5
    for j in range(bc):     # 10
        for k in range(ac): # 784
            t1[i,j] += m1[i,k] * m2[k,j]

Turning it into a function would look like:

def py_multiply(m1, m2, t1):
    ar,ac = m1.shape # n_rows * n_cols
    br,bc = m2.shape
    for i in range(ar):         # 5
        for j in range(bc):     # 10
            for k in range(ac): # 784
                t1[i,j] += m1[i,k] * m2[k,j]

t1 is the resulting matrix. In the Python notebook, it is set to zeros via

t1 = torch.zeros(ar, bc)

t1 is mutable. The value at t1[i,j] is replaced with a new value via +=. When the function has completed, the variable passed as the third argument (call it t_init) has the new values.

We’ve mentioned before that Elixir has immutable data. Here is the same kind of approach in Elixir:

defmodule DoesntWork do
  def add(m1, m2, t) do
    t = m1 + m2
  end
end

Let’s try it

defmodule DoesntWork do
  def add(m1, m2, t) do
    # This only rebinds the local name t; the caller's t is untouched
    t = m1 + m2
  end
end

t = 5
DoesntWork.add(3, 6, t)
"t is #{t} not 9, because Elixir's data is immutable"

The point of this long-winded example is that matrix operations in pure Elixir can’t use the mutate-in-place for-loop approach of languages with mutable data.
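Elixir does have a for construct, but it is a comprehension that returns a new list rather than a loop that mutates one. Here is a sketch of the Python triple loop rewritten with comprehensions. ForMatMul is a hypothetical module name. The innermost k loop becomes an Enum.reduce/3 whose accumulator plays the role of +=.

```elixir
defmodule ForMatMul do
  # Hypothetical module: the Python i/j/k loops as comprehensions.
  # Nothing is updated in place; every level builds and returns a new list.
  def multiply(m1, m2) do
    ar = length(m1)
    ac = length(hd(m1))
    bc = length(hd(m2))

    for i <- 0..(ar - 1) do
      row = Enum.at(m1, i)

      for j <- 0..(bc - 1) do
        # The k loop: the accumulator replaces t1[i, j] += ...
        Enum.reduce(0..(ac - 1), 0, fn k, acc ->
          acc + Enum.at(row, k) * (m2 |> Enum.at(k) |> Enum.at(j))
        end)
      end
    end
  end
end

ForMatMul.multiply([[1, 2, 1], [0, 1, 0], [2, 3, 4]], [[2, 5], [6, 7], [1, 8]])
# → [[15, 27], [6, 7], [26, 63]]
```

Enum.at/2 walks a list from its head, so each indexed lookup is O(n). This version is only for illustration; the Enum.map and Enum.zip approach below avoids indexing altogether.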

Matrix multiplication in plain Elixir

We’ll diverge from fast.ai’s notebook to dig into how we can do matrix multiplication without using for loops. We’ll also work with Elixir lists of lists rather than with tensors, which is what the Python notebook uses.

Let’s look at a matrix multiplication in traditional Elixir. This example was modified from

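defmodule Matrix do
  def mult(m1, m2) do
    Enum.map(m1, fn x ->
      Enum.map(transpose(m2), fn y ->
        Enum.zip(x, y)
        |> Enum.map(fn {x, y} -> x * y end)
        |> Enum.sum()
      end)
    end)
  end

  # transpose
  def transpose(m) do
    List.zip(m) |> Enum.map(&Tuple.to_list(&1))
  end
end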
# Let's set up an example multiplication using an example 
# from
m_3x3 = [[1, 2, 1], [0, 1, 0], [2, 3, 4]]

m_3x2 = [[2, 5], [6, 7], [1, 8]]
# Let's check that matrix multiplication works.  We should get
# [[15,27],
#  [ 6, 7],
#  [26,63]]

Matrix.mult(m_3x3, m_3x2)

Go to and put in your own matrix values to try it out.

Let’s dig into the Elixir code in our module

As Elixir developers, we understand how Enum.map works, but not everyone may have a good understanding of it, so let’s explore. Enum.map takes as its first argument some Enumerable, like a list. Its second argument must be a function, generally an anonymous function. For each element in the Enumerable, it calls the function with the current element and collects the result, in order. It then returns the resulting list.

some_list = [1, 2, 3]
another_list = [4, 5, 6]

Enum.map(some_list, fn x ->
  # The inner map() can see the outer x
  Enum.map(another_list, fn y ->
    IO.puts("x is #{x} y is #{y}")
  end)
end)
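The result’s shape may be easier to see in a smaller Enum.map example that returns values instead of printing them. The names doubled and pairs are just for illustration.

```elixir
some_list = [1, 2, 3]

# Enum.map/2 calls the function once per element and collects the results in order
doubled = Enum.map(some_list, fn x -> x * 2 end)

# Nesting Enum.map/2 gives one inner list per outer element
pairs = Enum.map(some_list, fn x -> Enum.map([4, 5, 6], fn y -> {x, y} end) end)

# some_list itself is unchanged: Enum.map/2 always returns a new list
{some_list, doubled, pairs}
```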

Next we’ll dig into the transpose function.

transpose = fn m ->
  List.zip(m)
  |> IO.inspect(label: "one of the rows in zip")

  # |> Enum.map(&Tuple.to_list(&1))
end

# In matrix multiplication, another_list needs to be vertical.
another_list = [[4, 5, 6]]
transpose.(another_list)

transpose = fn m ->
  # List.zip/1 turns [[4, 5, 6]] into [{4}, {5}, {6}],
  # which means that three items are in the Enumerable passed to Enum.map:
  #  the first is {4}
  #  the first is then transformed from a tuple, i.e. {something}, into
  #    a list of [something]
  List.zip(m)
  |> Enum.map(&Tuple.to_list(&1))
end

# In matrix multiplication, another_list needs to be vertical.
another_list = [[4, 5, 6]]
transpose.(another_list)


There is a really funky piece of code above: &Tuple.to_list(&1). The & is Elixir’s capture operator. Here is a blog post explaining the capture operator.

Personally, we are more comfortable with the slightly more verbose form of creating an anonymous function; our minds grok this form more easily. Both forms give the same answer.

transpose = fn m ->
  # List.zip/1 turns [[4, 5, 6]] into [{4}, {5}, {6}],
  # which means that three items are in the Enumerable passed to Enum.map:
  #  the first is {4}
  #  the first is then transformed from a tuple, i.e. {something}, into
  #    a list of [something]
  List.zip(m)
  |> Enum.map(fn x -> Tuple.to_list(x) end)
end

# In matrix multiplication, another_list needs to be vertical.
another_list = [[4, 5, 6]]
transpose.(another_list)
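All three spellings of the function argument are interchangeable. This quick sketch compares them on the zipped example.

```elixir
zipped = List.zip([[4, 5, 6]])

# Three equivalent ways to hand "convert a tuple to a list" to Enum.map/2
with_name_arity = Enum.map(zipped, &Tuple.to_list/1)
with_capture = Enum.map(zipped, &Tuple.to_list(&1))
with_fn = Enum.map(zipped, fn x -> Tuple.to_list(x) end)

{zipped, with_name_arity, with_capture, with_fn}
```

&Tuple.to_list/1 names the function by arity. &Tuple.to_list(&1) builds a one-argument function whose argument is &1. The fn form spells that out.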


The next set of code takes the first matrix and the transpose of the second matrix and zips their rows together. As before, we end up with lists of tuples.

Enum.map(m_3x3, fn x ->
  Enum.map(transpose.(m_3x2), fn y ->
    Enum.zip(x, y)
  end)
end)

The next step is to take the lists of lists of lists of tuples, run each tuple through a multiply function, and return a list.

Enum.map(m_3x3, fn x ->
  Enum.map(transpose.(m_3x2), fn y ->
    Enum.zip(x, y)
    |> Enum.map(fn {x, y} -> x * y end)
  end)
end)

Finally, we sum up the elements in the innermost lists.

Enum.map(m_3x3, fn x ->
  Enum.map(transpose.(m_3x2), fn y ->
    Enum.zip(x, y)
    |> Enum.map(fn {x, y} -> x * y end)
    |> Enum.sum()
  end)
end)

And we get the same answer from calling the Matrix.mult function above.

Whew. I hope you could follow along and we didn’t lose you. Let’s convert our example lists of lists into Nx tensors.

t_3x3 = Nx.tensor(m_3x3)
t_3x2 = Nx.tensor(m_3x2)
{t_3x3.shape, t_3x2.shape}

We’ve now implemented matrix multiplication using standard Elixir, so we can use Nx.dot() knowing what it does. Still the same answer, just now with tensors:

Nx.dot(t_3x3, t_3x2)

Let’s measure how fast (really, how slow) the BinaryBackend is. Remember, we aren’t using the GPU in this notebook, so don’t compare with the PyTorch timings where Jeremy is using a GPU.

Timing operations

In Elixir, Erlang’s :timer.tc function can be used to time function calls. So that we can call the same function multiple times, we’ll bind an anonymous repeat function to a variable. We’ll also create a zero-arity function that calls our target function, with its arguments hard-coded.

repeat = fn timed_fn, times -> Enum.each(1..times, fn _x -> timed_fn.() end) end
matrix_mult_w_dot_fn = fn -> Nx.dot(t_3x3, t_3x2) end
repeat_times = 50
{elapsed_time_micro, _} = :timer.tc(repeat, [matrix_mult_w_dot_fn, repeat_times])
avg_elapsed_time_ms = elapsed_time_micro / 1000 / repeat_times

"avg time in milliseconds #{avg_elapsed_time_ms} total_time #{elapsed_time_micro / 1000} milliseconds"

Not too bad, but the tensors are small.
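A larger workload makes a better comparison point. This sketch times the plain-list approach (the same zip, multiply and sum steps as Matrix.mult) on random lists shaped like the 5x784 images and 784x10 weights used below. The random stand-in data, list_mult and the other names are assumptions for illustration. Compare the average with the Nx BinaryBackend timing on your own machine rather than expecting a particular number.

```elixir
# Hypothetical stand-ins with the same shapes as m1 (5x784) and m2 (784x10)
random_matrix = fn rows, cols ->
  for _ <- 1..rows, do: for(_ <- 1..cols, do: :rand.normal())
end

list_m1 = random_matrix.(5, 784)
list_m2 = random_matrix.(784, 10)

# The same zip/multiply/sum approach as Matrix.mult, as an anonymous function
list_mult = fn a, b ->
  b_columns = b |> List.zip() |> Enum.map(&Tuple.to_list/1)

  Enum.map(a, fn row ->
    Enum.map(b_columns, fn col ->
      Enum.zip(row, col) |> Enum.map(fn {x, y} -> x * y end) |> Enum.sum()
    end)
  end)
end

repeat_times = 50

{list_elapsed_micro, list_result} =
  :timer.tc(fn ->
    # Run the multiplication repeat_times times and keep the last result
    Enum.reduce(1..repeat_times, nil, fn _, _ -> list_mult.(list_m1, list_m2) end)
  end)

"avg time in milliseconds #{list_elapsed_micro / 1000 / repeat_times}"
```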

Matrix multiplication

Let’s create some tensor random weights with a mean of about 0.0 and variance of about 1.0

# PyTorch
# weights = torch.randn(784,10)
# bias = torch.zeros(10)
# weights, weights.max(), weights.mean(), weights.var()

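mean = 0.0
# Nx.random_normal/4 takes a standard deviation, not a variance;
# a standard deviation of 1.0 gives a variance of 1.0 as well
standard_deviation = 1.0
weights = Nx.random_normal({784, 10}, mean, standard_deviation, type: {:f, 32})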

# Nx doesn't have a dedicated zeros constructor like torch.zeros, so we use
# Axon's initializers (broadcasting a scalar, e.g. Nx.broadcast(0.0, {10}), also works)
init_zeros = Axon.Initializers.zeros()
bias = init_zeros.({10}, {:f, 32})
{bias, weights}
{Nx.mean(weights), Nx.variance(weights)}
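Here is what “a mean of about 0.0 and variance of about 1.0” looks like without Nx. This sketch draws the same number of samples with Erlang’s :rand.normal/0 and computes both statistics by hand. The variable names are illustrative, and the exact values change on every run.

```elixir
# 784 * 10 draws from a standard normal distribution (mean 0.0, standard deviation 1.0)
samples = for _ <- 1..(784 * 10), do: :rand.normal()
count = length(samples)

sample_mean = Enum.sum(samples) / count

# Population variance (divide by count). PyTorch's var() divides by count - 1
# by default, which makes no visible difference at this sample size.
squared_deviations = Enum.map(samples, fn x -> (x - sample_mean) * (x - sample_mean) end)
sample_variance = Enum.sum(squared_deviations) / count

{sample_mean, sample_variance}
```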

Let’s take the first 5 rows, m1, of the validation dataset: 5x784, images x pixels. For every one of the 784 pixels in each row of the tensor, we need a weight multiplication factor for each of the 10 potential digits in the y_valid data, so the weights are 784x10. The first column of weights holds all of the weights used to figure out whether the pixels represent a 0. The second column holds the weights that tell us the probability that the pixels represent a 1, and so on up to 9.

# PyTorch
# x_valid[:5]

m1 = x_valid[0..4]
m2 = weights
{m1.shape, m2.shape}
# PyTorch
# ar,ac = m1.shape # n_rows * n_cols
# br,bc = m2.shape
# (ar,ac),(br,bc)

{ar, ac} = m1.shape
{br, bc} = m2.shape
{{ar, ac}, {br, bc}}
# PyTorch
# t1 = torch.zeros(ar, bc)
# t1.shape

t1 = init_zeros.({ar, bc}, {:f, 32})

When we multiply matrices together, we take row 1 of the first matrix and column 1 of the second matrix. We multiply the row 1 elements by the column 1 elements in turn (r1[1] times c1[1], r1[2] times c1[2], …) and sum the products. The sum gives the value of the very first cell in the resulting 5x10 matrix.
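Here is that single-cell computation as a plain-Elixir sketch, using the small 3x3 and 3x2 example matrices from earlier. They are renamed left and right so that the m1 and m2 tensors are not rebound.

```elixir
# The 3x3 and 3x2 example matrices from earlier, under new names
left = [[1, 2, 1], [0, 1, 0], [2, 3, 4]]
right = [[2, 5], [6, 7], [1, 8]]

# Row i of the left matrix and column j of the right matrix (Elixir counts from 0)
i = 0
j = 0
row_i = Enum.at(left, i)
column_j = Enum.map(right, fn row -> Enum.at(row, j) end)

# 1*2 + 2*6 + 1*1
cell =
  Enum.zip(row_i, column_j)
  |> Enum.map(fn {r, c} -> r * c end)
  |> Enum.sum()
# → 15, the top-left entry of [[15, 27], [6, 7], [26, 63]]
```

Repeating this for every (i, j) pair fills in the whole result matrix, which is exactly what the nested Enum.map calls in Matrix.mult do.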

Let’s compare the time to multiply two standard Elixir matrices with the time to multiply using Nx tensors with the BinaryBackend. First, the Nx result:

Nx.dot(m1, m2)

Let’s time our Nx matrix multiplication.

dot_m1_m2_fn = fn -> Nx.dot(m1, m2) end

repeat_times = 50
{elapsed_time_micro, _} = :timer.tc(repeat, [dot_m1_m2_fn, repeat_times])
avg_elapsed_time_ms = elapsed_time_micro / 1000 / repeat_times

"avg time in milliseconds #{avg_elapsed_time_ms} total_time #{elapsed_time_micro / 1000} milliseconds"

Let’s return to closely following the notebook.

Elementwise ops

The point of this section is to apply an operation to each element of a tensor. In Nx, these elementwise operations are functions such as Nx.add/2 and Nx.less/2.

# PyTorch
# a = tensor([10., 6, -4])
# b = tensor([2., 8, 7])
# a,b
# --> (tensor([10.,  6., -4.]), tensor([2., 8., 7.]))

a = Nx.tensor([10.0, 6, -4])
b = Nx.tensor([2.0, 8, 7])
{a, b}
# PyTorch
# a + b
# --> tensor([12., 14.,  3.])

Nx.add(a, b)
# PyTorch
# (a < b).float().mean()
# --> tensor(0.67)

Nx.less(a, b)
|> Nx.as_type({:f, 32})
|> Nx.mean()
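To make “elementwise” concrete, here are the same two PyTorch cells written with plain Elixir lists. The names a_list and b_list are new, so the a and b tensors are left alone. Nx.add/2 and Nx.less/2 take care of this position-by-position pairing for us when the shapes match.

```elixir
a_list = [10.0, 6.0, -4.0]
b_list = [2.0, 8.0, 7.0]

# a + b: pair up elements by position, then add each pair
elementwise_sum = Enum.zip(a_list, b_list) |> Enum.map(fn {x, y} -> x + y end)

# (a < b).float().mean(): each comparison becomes 1.0 or 0.0, then we average
flags = Enum.zip(a_list, b_list) |> Enum.map(fn {x, y} -> if x < y, do: 1.0, else: 0.0 end)
fraction_less = Enum.sum(flags) / length(flags)
```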
# PyTorch
# m = tensor([[1., 2, 3], [4,5,6], [7,8,9]]); m
# --> 
# tensor([[1., 2., 3.],
#         [4., 5., 6.],
#         [7., 8., 9.]])

# In Livebook, the value of the last expression in a cell is shown,
# so we don't need the ;m at the end that the Python version uses
m = Nx.tensor([[1.0, 2, 3], [4, 5, 6], [7, 8, 9]])

Frobenius norm:

We’ll use the Frobenius norm from time to time as we do generative modeling

We take every element over all of the rows and columns of the matrix, square it, add the squares up, and take the square root.

$$\| A \|_F = \left( \sum_{i,j=1}^n | a_{ij} |^2 \right)^{1/2}$$

Hint: you don’t normally need to write equations in LaTeX (really KaTeX) yourself. Instead, you can click ‘edit’ in Wikipedia and copy the LaTeX from there (which is what Jeremy did for the above equation). Or, on arXiv, click “Download: Other formats” in the top right, then “Download source”; rename the downloaded file to end in .tgz if it doesn’t already, and you should find the source there, including the equations to copy and paste. This is the source LaTeX that Jeremy pasted to render the equation above:

$$\| A \|_F = \left( \sum_{i,j=1}^n | a_{ij} |^2 \right)^{1/2}$$

In my case, I went to the notebook’s .ipynb file to copy the KaTeX from Jeremy’s code.

To implement the Frobenius norm in Elixir, we multiply m by itself elementwise, sum the results, and take the square root.

# PyTorch
# (m*m).sum().sqrt()
# --> tensor(16.88)

Nx.multiply(m, m)
|> Nx.sum()
|> Nx.sqrt()
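For comparison, here are the same three steps on a plain list of lists. This is a minimal sketch, and frobenius is a hypothetical helper name.

```elixir
# Plain-Elixir Frobenius norm: square every element, add them up, take the square root
frobenius = fn matrix ->
  matrix
  |> List.flatten()
  |> Enum.map(fn x -> x * x end)
  |> Enum.sum()
  |> :math.sqrt()
end

frob_m = frobenius.([[1.0, 2, 3], [4, 5, 6], [7, 8, 9]])
# → about 16.88, the square root of 1 + 4 + ... + 81 = 285
```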

This looked like a complicated math function when you first saw it, with a whole bunch of squiggly things. But when you look at the code, it’s just multiplying the matrix by itself, summing, and taking the square root.

A lot of machine learning papers have complicated looking math notations for simple or relatively simple functions in code.


Broadcasting

The term broadcasting describes how arrays with different shapes are treated during arithmetic operations. The term broadcasting was first used by NumPy.

From the Numpy Documentation:

The term broadcasting describes how numpy treats arrays with 
different shapes during arithmetic operations. Subject to certain 
constraints, the smaller array is “broadcast” across the larger 
array so that they have compatible shapes. Broadcasting provides a 
means of vectorizing array operations so that looping occurs in C
instead of Python. It does this without making needless copies of 
data and usually leads to efficient algorithm implementations.

In addition to the efficiency of broadcasting, it allows developers to write less code, which typically leads to fewer errors.

This section was adapted from Chapter 4 of the Computational Linear Algebra course.

In turn, it was copied from the 01_matmul.ipynb code

# PyTorch
# a
# --> tensor([10.,  6., -4.])

# PyTorch
# a > 0
# -> tensor([ True,  True, False])
Nx.greater(a, 0)

How are we able to do a > 0? 0 is being broadcast to have the same dimensions as a.

For instance, we can normalize our dataset by subtracting the mean (a scalar) from the entire dataset (a matrix) and dividing by the standard deviation (another scalar), using broadcasting.
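Writing the normalization out in plain Elixir shows what broadcasting saves us. This sketch works on a hypothetical 2x2 dataset and uses the population standard deviation. With Nx, the same thing would collapse to something like Nx.divide(Nx.subtract(t, mean), std), with both scalars broadcast over the tensor.

```elixir
# Hypothetical small "dataset"
data = [[1.0, 2.0], [3.0, 4.0]]
values = List.flatten(data)
n = length(values)

data_mean = Enum.sum(values) / n
squared = Enum.map(values, fn x -> (x - data_mean) * (x - data_mean) end)
data_std = :math.sqrt(Enum.sum(squared) / n)

# Without broadcasting, both scalars have to be threaded through every element by hand
normalized = Enum.map(data, fn row -> Enum.map(row, fn x -> (x - data_mean) / data_std end) end)
```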

Other examples of broadcasting with a scalar:

# PyTorch
# a + 1
# --> tensor([11.,  7., -3.])

Nx.add(a, 1)
# The scalar can be in either position
Nx.add(1, a)
m = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
# PyTorch
# 2*m
# --> 
# tensor([[ 2.,  4.,  6.],
#         [ 8., 10., 12.],
#         [14., 16., 18.]])
Nx.multiply(m, 2)
Nx.multiply(2, m)

Broadcasting a vector to a matrix

Although broadcasting a scalar is an idea that dates back to APL, the more powerful idea of broadcasting across higher rank tensors comes from a little known language called Yorick.

We can also broadcast a vector to a matrix:

# PyTorch
# c = tensor([10.,20,30]); c
# --> tensor([10., 20., 30.])

# the vector
c = Nx.tensor([10.0, 20.0, 30.0])
# the matrix
# PyTorch
# m.shape,c.shape
# --> (torch.Size([3, 3]), torch.Size([3]))

{m.shape, c.shape}
# PyTorch
# m + c
# --> 
# tensor([[11., 22., 33.],
#         [14., 25., 36.],
#         [17., 28., 39.]])

# The vector is broadcast across the matrix shape and added
Nx.add(c, m)
# reverse the order and still the same answer
Nx.add(m, c)

Here is the trick that allows the matrix and vector to be added: PyTorch’s expand_as method expands the vector to be the same shape as m.

We don’t really copy the rows, but it looks as if we did. In fact, the rows are given a stride of 0.

Elixir: We aren’t sure whether Nx.broadcast actually copies the rows or only looks like it does, à la PyTorch.

# PyTorch
# t = c.expand_as(m); t
# --> 
# tensor([[10., 20., 30.],
#         [10., 20., 30.],
#         [10., 20., 30.]])

t = Nx.broadcast(c, m.shape)

I don’t believe the following tensor code has an Nx equivalent

# PyTorch
#  # Not sure there is an Nx equivalent
# t.stride(), # Not sure there is an Nx equivalent

In PyTorch, you can index with the special value [None] or use unsqueeze() to convert a 1-dimensional array into a 2-dimensional array (although one of those dimensions has value 1).

The Nx equivalent is the Nx.reshape.

# PyTorch
# c
# -->
# tensor([10., 20., 30.])

Both unsqueeze and c[something, something_else] map to Nx.reshape. We’ll just show the Nx.reshape once.

This is how we create a matrix with one row

# PyTorch
# c.unsqueeze(0), c[None, :]
# --> (tensor([[10., 20., 30.]]), tensor([[10., 20., 30.]]))

Nx.reshape(c, {1, :auto})
# PyTorch
# c.shape, c.unsqueeze(0).shape
# --> (torch.Size([3]), torch.Size([1, 3]))

{c.shape, Nx.reshape(c, {1, :auto}).shape}
# PyTorch
# c.unsqueeze(1), c[:, None]
# --> (tensor([[10.],
#          [20.],
#          [30.]]),
#  tensor([[10.],
#          [20.],
#          [30.]]))

# This is how we create a matrix with one column.
Nx.reshape(c, {:auto, 1})
# PyTorch
# c.shape, c.unsqueeze(1).shape
# --> (torch.Size([3]), torch.Size([3, 1]))

{c.shape, Nx.reshape(c, {:auto, 1}).shape}

In PyTorch, you can skip trailing ‘:’s, and ‘...’ means ‘all preceding dimensions’.

# PyTorch
# c[None].shape,c[...,None].shape
# --> (torch.Size([1, 3]), torch.Size([3, 1]))

{Nx.reshape(c, {1, :auto}).shape, Nx.reshape(c, {:auto, 1}).shape}

Below, we take the vector, reshape it into a matrix with one column, and then broadcast the result to the shape of m.

# PyTorch
# c[:,None].expand_as(m)
# --> tensor([[10., 10., 10.],
# [20., 20., 20.],
# [30., 30., 30.]])

Nx.reshape(c, {:auto, 1})
|> Nx.broadcast(m.shape)

As a reminder, in this case we are adding the vector to each row.

# PyTorch
# m + c
# -->
# tensor([[11., 22., 33.],
#         [14., 25., 36.],
#         [17., 28., 39.]])

Nx.add(m, c)

Here we are transforming the vector c into a matrix with one column and then broadcasting into the shape of m. Then we add the two matrices together.

# PyTorch
# m + c[:,None]
# --> tensor([[11., 12., 13.],
# [24., 25., 26.],
# [37., 38., 39.]])

Nx.add(m, Nx.reshape(c, {:auto, 1}) |> Nx.broadcast(m.shape))

Here we are transforming the vector c into a matrix with one row and then broadcasting into the shape of m. Then we add the two matrices together.

# PyTorch
# m + c[None,:]
# tensor([[11., 22., 33.],
#         [14., 25., 36.],
#         [17., 28., 39.]])

Nx.add(m, Nx.reshape(c, {1, :auto}) |> Nx.broadcast(m.shape))

Broadcasting Rules

# PyTorch
# c[None,:]
# --> tensor([[10., 20., 30.]])

Nx.reshape(c, {1, :auto})
# PyTorch
# c[None,:].shape
# --> torch.Size([1, 3])

Nx.reshape(c, {1, :auto}).shape
# PyTorch
# c[:,None]
# --> tensor([[10.],
#         [20.],
#         [30.]])

Nx.reshape(c, {:auto, 1})
# PyTorch
# c[:,None].shape
# --> torch.Size([3, 1])

Nx.reshape(c, {:auto, 1}).shape

Here we take the vector c and reshape it into a matrix with one column, and we also reshape c into a matrix with one row. The multiply function then expands the one column into 3 columns with the same values. The same thing happens for the matrix of one row, which expands into 3 rows.

We end up with 3 rows of 10, 20, 30 and 3 columns of 10, 20, 30.

When we multiply them together, we get the answer below. This is an outer product without any special function, just broadcasting. And it’s not just products: we can do outer boolean operations, etc.

# PyTorch
# c[None,:] * c[:,None]
# --> tensor([[100., 200., 300.],
#         [200., 400., 600.],
#         [300., 600., 900.]])

Nx.multiply(Nx.reshape(c, {1, :auto}), Nx.reshape(c, {:auto, 1}))

Here is an example of an outer boolean operation.

# PyTorch
# c[None] > c[:,None]
# --> tensor([[False,  True,  True],
#         [False, False,  True],
#         [False, False, False]])

Nx.greater(Nx.reshape(c, {:auto}), Nx.reshape(c, {:auto, 1}))
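The same outer operations can be sketched in plain Elixir with nested comprehensions. The hypothetical outer function below takes the column values (like c[:, None]), the row values (like c[None, :]), and the operation to apply to each pair. Entry [i][j] is fun.(column value i, row value j).

```elixir
# Hypothetical helper: apply fun to every (column value, row value) pair,
# like broadcasting c[:, None] against c[None, :]
outer = fn column_values, row_values, fun ->
  for x <- column_values do
    for y <- row_values, do: fun.(x, y)
  end
end

c_list = [10.0, 20.0, 30.0]

# c[None,:] * c[:,None]
outer_product = outer.(c_list, c_list, fn x, y -> x * y end)

# c[None] > c[:,None]: entry [i][j] is c[j] > c[i]
outer_greater = outer.(c_list, c_list, fn x, y -> y > x end)
```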

When operating on two arrays/tensors, NumPy/PyTorch compares their shapes element-wise. It starts with the trailing dimensions and works its way forward. Two dimensions are compatible when they are equal, or when one of them is 1.

Arrays do not need to have the same number of dimensions. For example, if you have a 256*256*3 array of RGB values, and you want to scale each color in the image by a different value, you can multiply the image by a one-dimensional array with 3 values. Lining up the sizes of the trailing axes of these arrays according to the broadcast rules shows that they are compatible:

Image  (3d array): 256 x 256 x 3
Scale  (1d array):             3
Result (3d array): 256 x 256 x 3
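The rule can also be written as a small function. This is a sketch with a hypothetical BroadcastShape module, not an Nx API. It pads the shorter shape with leading 1s and compares trailing dimensions. It returns either the result shape or the first incompatible pair.

```elixir
defmodule BroadcastShape do
  # Two dimensions are compatible when they are equal or one of them is 1.
  def result(shape_a, shape_b) do
    # Work from the trailing dimension forward
    a = shape_a |> Tuple.to_list() |> Enum.reverse()
    b = shape_b |> Tuple.to_list() |> Enum.reverse()
    len = max(length(a), length(b))
    # Missing leading dimensions behave like 1
    a = a ++ List.duplicate(1, len - length(a))
    b = b ++ List.duplicate(1, len - length(b))

    Enum.zip(a, b)
    |> Enum.reduce_while({:ok, []}, fn
      {x, x}, {:ok, acc} -> {:cont, {:ok, [x | acc]}}
      {1, y}, {:ok, acc} -> {:cont, {:ok, [y | acc]}}
      {x, 1}, {:ok, acc} -> {:cont, {:ok, [x | acc]}}
      {x, y}, _ -> {:halt, {:error, {x, y}}}
    end)
    |> case do
      {:ok, dims} -> {:ok, List.to_tuple(dims)}
      error -> error
    end
  end
end

BroadcastShape.result({256, 256, 3}, {3})
# → {:ok, {256, 256, 3}}
```

Prepending each dimension while walking from the trailing end puts the result back in the original, leading-first order.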

The numpy documentation includes several examples of what dimensions can and can not be broadcast together.

Matmul using Nx

As a reminder, we defined these tensors further back in the notebook.

tr = Nx.dot(x_train, weights)

Using the default BinaryBackend, the above dot() function returns in about 28 seconds on our Linux computer. Not nearly as quick as the broadcast example in the course.

Let’s explore how this same matrix multiplication works for different backends next. To keep things simple and focused, we’ll stop this notebook here and create different notebooks to focus on Nx on XLA using the CPU and XLA using the GPU.

