Jax — Numpy on GPUs and TPUs

Understanding the library and implementing an MLP from scratch

Tiago Toledo Jr.
Towards Data Science


Photo by Lucas Kepner on Unsplash

There are as many Python libraries and frameworks as there are stars in our sky. Okay, maybe not that many, but there are surely a lot of options to choose from when dealing with any given task.

Jax is one of these libraries. It has become really popular in the last few months as a base framework for developing Machine Learning solutions, especially after being widely used by researchers at DeepMind.

For any Data Scientist, a solid grasp of the task at hand matters more than the tool. However, knowing the tools available can save a lot of time and make us more productive researchers. After all, we need to be able to test our hypotheses quickly and reliably if we aim to achieve good business results.

So, in this post, I will talk about Jax: what it is, why I think you should be familiar with it, its advantages, and how to use it to implement a simple Multilayer Perceptron (MLP).

By the end of the post, I hope you’ll have one more tool in your toolbox that may be useful in your daily work.

All the code for this post is available on Kaggle and on my Github.

What is Jax

Jax is a numerical/mathematical library very similar to the good old Numpy. It was developed by Google with one objective in mind: to make Numpy faster and easier to use for typical Machine Learning tasks.

To achieve this, Jax provides built-in function transformations for high-scale processing (parallelization and vectorization), for faster execution (Just-in-Time compilation), and for easier Machine Learning math (automatic differentiation, or autograd).

With that, one can natively speed up Machine Learning pipelines without writing much extra code. If you have ever tried to parallelize an implementation using the multiprocessing library, you know how quickly that can become overwhelming.

But, more interesting than that, is Jax’s ability to compile your code for accelerators such as GPUs and TPUs without any modification. The process is seamless.

This means you can write your code once, using a Numpy-like syntax, test it on your CPU, and then run it on a GPU cluster with little or no change to the code.

Now, one may ask a very reasonable question: how is Jax different from TensorFlow or PyTorch? Let’s look into that.

Jax vs. TensorFlow vs. PyTorch

TensorFlow and PyTorch have been around for a long time. They are full-fledged deep learning libraries with everything needed to develop deep learning pipelines end to end.

Jax has no interest in being a full deep learning library. It aims to be the Numpy for accelerators. Therefore, you will not see Jax implementing data loaders or model validators, just as you wouldn’t expect Numpy to.

However, deep learning libraries are being built on top of Jax as we speak (Flax and Haiku are two examples). They aim to use Jax’s features to create faster and cleaner implementations. TensorFlow and PyTorch, after all, both carry some technical debt that can make certain things harder to do efficiently.

I do believe that frameworks based on Jax will become more prominent in our industry, so I think knowing the basics of Jax is a good way to stay on your feet when the standard libraries (inevitably) change.

Some Details about Jax

Jax aims to be as close as possible to Numpy. It also offers lower-level APIs, but these are way beyond the scope of this post. For now, we will use it exactly like a ‘Numpy for GPUs’.

To start, we will install Jax:

pip install --upgrade "jax[cpu]"

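This command will only install CPU support, which is enough for us to test our code. If you want GPU support on an NVIDIA card with CUDA 12, use:

pip install --upgrade "jax[cuda12]"

This variant installs CUDA and cuDNN as pip dependencies, so you mainly need a recent NVIDIA driver. The exact command depends on your hardware and CUDA version, so check the official Jax installation guide for the one that matches your setup.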

Then, we will import the Numpy interface and some important functions as follows:

import numpy as np  # plain Numpy, used in the random-state examples below
import jax.numpy as jnp
from jax import random
from jax import grad, jit, vmap
from jax.scipy.special import logsumexp

We must now get a hold of some of the different aspects of Jax. The first one is how random number generators are handled.

Random Numbers in Jax

Usually, when we are dealing with Numpy, we set a random state integer (say 42) and use it to generate our random numbers for our program. Traditional machine learning libraries such as scikit-learn use this paradigm.

This works well when executions are sequential. However, if we start to run our functions in parallel, it becomes an issue. Let me show you how:

  • Let’s define two functions, bar and baz. Each one will return a random number.
  • Let’s define a function foo that will make a computation on the results of the two previous functions.

In code:

import numpy as np

np.random.seed(0)

def bar(): return np.random.uniform()
def baz(): return np.random.uniform()

def foo(): return bar() + 2 * baz()

print(foo())

Here we set the random state with a seed equal to zero. If you run this code twenty times, you will get the same result all of those times, because the random state is set.

But what if we call baz before bar?

import numpy as np

np.random.seed(0)

def bar(): return np.random.uniform()
def baz(): return np.random.uniform()

def foo(): return 2 * baz() + bar()

print(foo())

Yes, the result is different. This is because the order in which the functions draw from the global random state has changed: baz now receives the first number of the sequence and bar the second. The seed only guarantees that the same sequence of calls yields the same result, and here the sequence is no longer the same.

Now you can see how this becomes a problem when trying to parallelize all of our functions. We cannot guarantee the order of the executions and therefore, there is no way of enforcing reproducibility of the results we are getting.

Jax solves this with explicit pseudo-random number generator (PRNG) keys, created as follows:

random_state = 42
key = random.PRNGKey(random_state)

Every random function in Jax receives a key explicitly, and a key should not be reused for two different random calls. Because the output depends only on the key passed in, and not on some hidden global state, the result stays the same even if the order of execution changes.

Does that mean we need to create a list of keys by hand, one for each function? That would be really cumbersome, so Jax provides a handy function called split, which receives a key and splits it into the requested number of subkeys that we can then pass to our functions:

# Here we split our original key into three subkeys
random.split(key, num=3)

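This returns an array holding three new keys, which you can index or unpack like a list (for example, key_a, key_b, key_c = random.split(key, num=3)). Passing a different subkey to each function is what keeps our results reproducible across executions.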
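To tie this back to the bar/baz/foo example, here is a minimal sketch of the same functions rewritten with explicit keys. The names mirror the Numpy version; foo_reordered is a hypothetical name for the variant that calls baz first. Each function gets its own subkey, so swapping the call order no longer changes the result.

```python
from jax import random

def bar(key):
    # The output depends only on the key passed in, not on global state
    return random.uniform(key)

def baz(key):
    return random.uniform(key)

def foo(key):
    # Split once and hand a dedicated subkey to each function
    bar_key, baz_key = random.split(key, num=2)
    return bar(bar_key) + 2 * baz(baz_key)

def foo_reordered(key):
    bar_key, baz_key = random.split(key, num=2)
    # Calling baz before bar does not matter: each still uses its own key
    return 2 * baz(baz_key) + bar(bar_key)

key = random.PRNGKey(42)
print(foo(key), foo_reordered(key))  # the two values are identical
```

Running the script again with the same seed prints the same numbers, whatever order the functions are called in.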

Automatic Differentiation

One of the main objectives of Jax is to automatically differentiate native Python and NumPy functions. When dealing with machine learning, this is required since most of our optimization algorithms use the gradients of our functions to minimize some loss.

Doing differentiation in Jax is very straightforward:

def func(x):
    return x**2

d_func = grad(func)

Just like that. Now d_func(x) returns the derivative of func evaluated at x. Note that grad expects floating-point inputs, so call it as d_func(3.0) rather than d_func(3).

Now, the neat thing is that you can apply grad repeatedly to get higher-order derivatives. To create a function that returns the second derivative of func, we would simply do:

d2_func = grad(d_func)
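As a quick sanity check, we can evaluate both derivatives at a point. Since func(x) = x², we expect f'(x) = 2x and f''(x) = 2, so at x = 3.0 the results should be 6.0 and 2.0:

```python
from jax import grad

def func(x):
    return x**2

d_func = grad(func)     # f'(x) = 2x
d2_func = grad(d_func)  # f''(x) = 2

print(d_func(3.0))   # 6.0
print(d2_func(3.0))  # 2.0
```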

Just-in-Time Compilation

Python is an interpreted language. There are other languages, such as C, that are called compiled languages. In a compiled language, the code is read by a compiler, which generates machine code. Then, this machine code is executed when you run the program.

An interpreted language offers some advantages to the developer, such as not needing to declare the data types of variables. However, since the code is not compiled ahead of time, it is usually slower than a compiled language.

Some Python libraries, such as Numba, implement what is called Just-in-Time (JIT) compilation: the first time a function is called, it is compiled to machine code, so subsequent calls run faster. Jax does the same using the XLA compiler, recompiling if the function is later called with inputs of a different shape or dtype.

If you have a method that will run 10 thousand times (such as a gradient update during a training loop), JIT Compilation can greatly improve the performance of your code. And doing it in Jax is very simple:

def funct(x):
    return x * (2 + x)

compiled_funct = jit(funct)

Notice that not every function can be compiled. I suggest you read the Jax documentation for a full understanding of the limitations of this method.
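To see the compile-then-reuse behavior in practice, the sketch below times compiled_funct against funct on a large array. Two details are assumptions about good benchmarking practice rather than part of the original code: the first call is treated as a warm-up, since that is when tracing and compilation happen, and block_until_ready() is used because Jax dispatches work asynchronously. The timings you get will depend on your hardware.

```python
import time

import jax.numpy as jnp
from jax import jit

def funct(x):
    return x * (2 + x)

compiled_funct = jit(funct)

x = jnp.arange(1_000_000, dtype=jnp.float32)

# First call: Jax traces funct and compiles it with XLA for this shape/dtype
compiled_funct(x).block_until_ready()

# block_until_ready() makes the timer wait for the computation to finish
start = time.perf_counter()
compiled_funct(x).block_until_ready()
print(f"jitted:     {time.perf_counter() - start:.6f} s")

start = time.perf_counter()
funct(x).block_until_ready()
print(f"not jitted: {time.perf_counter() - start:.6f} s")
```

Both versions compute the same values; the jitted one can fuse the addition and multiplication into a single compiled kernel instead of dispatching them one by one.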

We will see the performance improvement we get from this later in the post.

Vectorization

Vectorization is the process of taking operations that usually act on a single unit (an integer, for example) and applying them to whole vectors at once, which allows these computations to happen in parallel.

This process can yield great performance improvements in our pipelines, and Jax has a built-in function called vmap that receives a function and automatically vectorizes it for us.
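As a small illustration, here is a sketch with a hypothetical dot function written for a single pair of vectors. vmap turns it into a function over batches of vectors; in_axes=(0, 0) says both arguments are batched along their first axis.

```python
import jax.numpy as jnp
from jax import vmap

def dot(a, b):
    # Written for one pair of 1-D vectors
    return jnp.dot(a, b)

# Map over axis 0 of both arguments: one dot product per row
batched_dot = vmap(dot, in_axes=(0, 0))

a = jnp.arange(6.0).reshape(3, 2)  # rows: [0, 1], [2, 3], [4, 5]
b = jnp.ones((3, 2))

print(batched_dot(a, b))  # [1. 5. 9.]
```

We never wrote a loop or changed dot itself; vmap produced the batched version.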

We will see how to apply it to our prediction function in the next section, when we implement the MLP.

Implementing an MLP in Jax

Now, let’s implement an MLP in Jax to exercise what we learned about the library. To help us, we will load the MNIST dataset using TensorFlow Datasets (tfds). This dataset is free to use:

import tensorflow as tf
import tensorflow_datasets as tfds
data_dir = '/tmp/tfds'

mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True)
mnist_data = tfds.as_numpy(mnist_data)

train_data, test_data = mnist_data['train'], mnist_data['test']

num_labels = info.features['label'].num_classes
h, w, c = info.features['image'].shape
num_pixels = h * w * c

train_images, train_labels = train_data['image'], train_data['label']
test_images, test_labels = test_data['image'], test_data['label']

Here we are just downloading the data and splitting it into train and test for us to train our model later.

Let’s create a helper function to one-hot encode our targets:

def one_hot(x, k, dtype=jnp.float32):
    """
    Create a one-hot encoding of x of size k.

    x: array
        The array to be one hot encoded
    k: integer
        The number of classes
    dtype: jnp.dtype, optional(default=float32)
        The dtype to be used on the encoding
    """
    return jnp.array(x[:, None] == jnp.arange(k), dtype)
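The one-liner relies on broadcasting: x[:, None] turns the labels into a column, jnp.arange(k) is a row of class indices, and comparing them yields a (len(x), k) boolean matrix with a single True per row. A tiny example with three hypothetical labels makes this concrete:

```python
import jax.numpy as jnp

def one_hot(x, k, dtype=jnp.float32):
    # Column of labels vs. row [0, ..., k-1] -> (len(x), k) boolean matrix
    return jnp.array(x[:, None] == jnp.arange(k), dtype)

labels = jnp.array([0, 2, 1])
print(one_hot(labels, 3))
# [[1. 0. 0.]
#  [0. 0. 1.]
#  [0. 1. 0.]]
```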

Now, let’s one-hot encode our labels and flatten our images into jnp arrays of shape (number of images, num_pixels):

train_images = jnp.reshape(train_images, (len(train_images), num_pixels))
train_labels = one_hot(train_labels, num_labels)
test_images = jnp.reshape(test_images, (len(test_images), num_pixels))
test_labels = one_hot(test_labels, num_labels)

Finally, we will define a function that will yield batches of data for us to use on the training loop:

def get_train_batches(batch_size):
    """
    Loads the MNIST training split and returns an iterable of
    (images, labels) batches as Numpy arrays, given the batch size

    batch_size: integer
        The batch size, i.e., the number of images to be retrieved at each step
    """
    ds = tfds.load(name='mnist', split='train', as_supervised=True, data_dir=data_dir)
    # Group examples into batches and prepare the next batch in the background
    ds = ds.batch(batch_size).prefetch(1)
    return tfds.as_numpy(ds)

Now that our data is ready, let’s start our implementation.

Initializing the Parameters

First, we will initialize the weights for each layer of our MLP. For this post, we will initialize them randomly. To do that, we must use a PRNG key to guarantee that all of our executions will be reproducible.

def random_layer_params(m, n, key, scale=1e-2):
    """
    This function returns two arrays, a W matrix with shape (m, n) and a b vector with shape (n,)

    m: integer
        The number of neurons in the input layer (first dimension of W)
    n: integer
        The number of neurons in the output layer (second dimension of W and length of b)
    key: PRNGKey
        A Jax PRNGKey
    scale: float, optional(default=1e-2)
        The scale of the random numbers on the matrices
    """
    # Split our key into two new keys, one for W and one for b
    w_key, b_key = random.split(key, num=2)
    return scale * random.normal(w_key, (m, n)), scale * random.normal(b_key, (n,))

This method receives the number of neurons for that layer and the number of neurons for the layer after it. Also, we pass the key so we can split it.

Now, let’s create a function that receives a list with the sizes of the layers (the number of neurons) and uses that random layer generator to populate all the layers with random weights:

def init_network_params(layers_sizes, key):
    """
    Given a list of layer sizes for a neural network, initializes the weights of the network

    layers_sizes: list of integers
        The number of neurons on each layer of the network
    key: PRNGKey
        A Jax PRNGKey
    """
    # Generate one subkey for each layer in the network
    keys = random.split(key, len(layers_sizes))
    return [random_layer_params(m, n, k)
            for m, n, k in zip(layers_sizes[:-1], layers_sizes[1:], keys)]
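To check what these two functions produce, the sketch below re-declares them in compact form (so it runs on its own) and prints the parameter shapes for the layer sizes used later in the post. Each (W, b) pair connects one layer to the next, and the same key always yields the same parameters.

```python
from jax import random

def random_layer_params(m, n, key, scale=1e-2):
    # One subkey for the weights, one for the biases
    w_key, b_key = random.split(key, num=2)
    return scale * random.normal(w_key, (m, n)), scale * random.normal(b_key, (n,))

def init_network_params(layers_sizes, key):
    keys = random.split(key, len(layers_sizes))
    return [random_layer_params(m, n, k)
            for m, n, k in zip(layers_sizes[:-1], layers_sizes[1:], keys)]

params = init_network_params([784, 512, 512, 10], random.PRNGKey(0))
for w, b in params:
    print(w.shape, b.shape)
# (784, 512) (512,)
# (512, 512) (512,)
# (512, 10) (10,)
```

Note that split creates one more key than there are weight pairs; zip simply stops at the shortest sequence, so the extra key goes unused.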

Prediction Function

Now we will create a function that, given an image and the weights, will output a prediction. For that, we will first define a ReLU function:

def relu(x):
    return jnp.maximum(0, x)

Now, for every hidden layer we must apply the weights to the activations, add the bias, apply the ReLU to the result, and propagate that activation to the next layer. The last layer skips the ReLU, and its output is normalized into log-probabilities, so the function will look something like this:

def predict(params, x):
    """
    Function to generate a prediction given the weights and a single input

    params: list of (W, b) tuples
        The weights for every layer of the network, including the bias
    x: array of shape (num_pixels,)
        The activation, or the features, of a single flattened image
    """
    activations = x

    for w, b in params[:-1]:
        output = jnp.dot(w.T, activations) + b
        activations = relu(output)

    final_w, final_b = params[-1]
    logits = jnp.dot(final_w.T, activations) + final_b

    # Log-softmax: subtracting logsumexp turns the logits into log-probabilities
    return logits - logsumexp(logits)

Notice that we wrote this function to work on only one image at a time. We cannot pass it a batch of, say, 100 images with shape (100, 784): in the first layer, w.T has shape (512, 784), so jnp.dot would receive arrays of shapes (512, 784) and (100, 784), which do not line up, and the call would fail.

This is where the vmap function will come in handy. We can use it to automatically allow our predict function to work with batches of data. This is done with the following line of code:

batched_predict = vmap(predict, in_axes=(None, 0))

The first argument is the function we want to vectorize. The second, in_axes, is a tuple with one entry per input argument of the original function, stating which axis of that argument should be mapped over.

So, the tuple (None, 0) means that we should not batch the first argument (the weights, which are shared by every image), but should map over the rows (axis 0) of the second one, the images.
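A cheap way to convince yourself that vmap does what we want is to compare it with an explicit Python loop. The sketch below uses a tiny hypothetical network (4 inputs, 3 hidden neurons, 2 outputs) and a random batch of 5 "images", with the ReLU inlined as jnp.maximum; the batched and looped predictions should match, and each row should hold log-probabilities.

```python
import jax.numpy as jnp
from jax import random, vmap
from jax.scipy.special import logsumexp

def predict(params, x):
    activations = x
    for w, b in params[:-1]:
        activations = jnp.maximum(0, jnp.dot(w.T, activations) + b)  # ReLU layer
    final_w, final_b = params[-1]
    logits = jnp.dot(final_w.T, activations) + final_b
    return logits - logsumexp(logits)

# Tiny hypothetical network: 4 inputs -> 3 hidden -> 2 outputs
k1, k2, k3 = random.split(random.PRNGKey(0), 3)
params = [(random.normal(k1, (4, 3)), jnp.zeros(3)),
          (random.normal(k2, (3, 2)), jnp.zeros(2))]
images = random.normal(k3, (5, 4))  # a batch of 5 flattened "images"

# None: share params across the batch; 0: map over the rows of images
batched_predict = vmap(predict, in_axes=(None, 0))
batched = batched_predict(params, images)

# Reference: call predict once per image in a plain Python loop
looped = jnp.stack([predict(params, img) for img in images])

print(batched.shape)  # (5, 2)
print(jnp.allclose(batched, looped, atol=1e-5))
```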

Loss and Accuracy

Now let’s define two simple functions, one to calculate the accuracy of our model and another one to compute our loss:

def accuracy(params, images, targets):
    """
    Calculates the accuracy of the neural network on a set of images

    params: list of (W, b) tuples
        The weights for every layer of the network, including the bias
    images: array of shape (num_images, num_pixels)
        The flattened images to be used on the calculation
    targets: array of shape (num_images, num_labels)
        The one-hot encoded true labels for each image
    """
    target_class = jnp.argmax(targets, axis=1)

    # Predict the log-probabilities for each class and take the most likely one
    predicted_class = jnp.argmax(batched_predict(params, images), axis=1)

    return jnp.mean(predicted_class == target_class)

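def loss(params, images, targets):
    # batched_predict returns log-probabilities, so this is a cross-entropy
    # loss averaged over every (image, class) entry
    preds = batched_predict(params, images)
    return -jnp.mean(preds * targets)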
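Because predict returns log-probabilities and the targets are one-hot, this loss is a cross-entropy. The mean runs over every entry of the (batch, classes) matrix, though, so it equals the usual per-example mean cross-entropy divided by the number of classes. That constant factor only rescales the gradient, which acts like a smaller step size. The sketch below checks the identity on two hypothetical rows of logits, using jax.nn.log_softmax as a stand-in for logits - logsumexp(logits):

```python
import jax
import jax.numpy as jnp

logits = jnp.array([[2.0, 1.0, 0.1],
                    [0.5, 2.5, 0.3]])
targets = jnp.array([[1.0, 0.0, 0.0],
                     [0.0, 1.0, 0.0]])

# Row-wise log-softmax, equivalent to logits - logsumexp(logits) per example
log_probs = jax.nn.log_softmax(logits)

# The post's loss: mean over all batch x class entries
post_loss = -jnp.mean(log_probs * targets)

# Standard mean cross-entropy: sum over classes, then mean over the batch
cross_entropy = -jnp.mean(jnp.sum(log_probs * targets, axis=1))

print(jnp.allclose(post_loss, cross_entropy / logits.shape[1]))  # True
```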

Update Function

We’re getting there! Now, we must create the update function. In our MLP, the update changes the weights based on a step size and the gradient of our loss (plain gradient descent). As we saw, Jax lets us do that easily with the grad function:

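def update(params, x, y):
    # grad(loss) returns gradients with the same structure as params:
    # a list of (dW, db) tuples, one per layer
    grads = grad(loss)(params, x, y)
    # step_size is the global defined in the training section below
    return [(w - step_size * dw, b - step_size * db)
            for (w, b), (dw, db) in zip(params, grads)]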

Training Loop

Now we are ready to run our training loop. Let’s define the number of neurons for each layer, the step size of the gradient, the number of training epochs, and the batch size:

layer_sizes = [784, 512, 512, 10]

# Training parameters
step_size = 0.01
num_epochs = 10
batch_size = 128

# Number of labels
n_targets = 10

Now, let’s initialize our weights:

params = init_network_params(layer_sizes, random.PRNGKey(0))

Now, let’s loop over the epochs and train our network:

import time

for epoch in range(num_epochs):
    start_time = time.time()
    for x, y in get_train_batches(batch_size):
        x = jnp.reshape(x, (len(x), num_pixels))
        y = one_hot(y, num_labels)
        params = update(params, x, y)
    epoch_time = time.time() - start_time

    train_acc = accuracy(params, train_images, train_labels)
    test_acc = accuracy(params, test_images, test_labels)
    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
    print("Training set accuracy {}".format(train_acc))
    print("Test set accuracy {}\n".format(test_acc))

This is it. After the loop finishes, we will have successfully trained an MLP in Jax!

Performance Results

In the previous code, I didn’t use the jit function anywhere. But, as I said, it can greatly improve performance, especially for calculations that run many times, such as the update of our parameters.

If we were to apply the JIT to our code, we would create a jitted version of the update function as follows:

jit_update = jit(update)

And then, we would change the last line of the inner for loop to:

params = jit_update(params, x, y)
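The post’s timing comparison uses the full MNIST pipeline. As a self-contained sketch, the snippet below re-declares the MLP pieces and times update against jit_update on a single synthetic batch of random pixels and labels, so no dataset download is needed. The helpers wait and seconds_per_step and the synthetic x and y are hypothetical additions for this sketch, not part of the original code, and the numbers depend entirely on your hardware; they are not the author’s results.

```python
import time

import jax
import jax.numpy as jnp
from jax import grad, jit, random, vmap
from jax.scipy.special import logsumexp

step_size = 0.01

def init_network_params(sizes, key, scale=1e-2):
    keys = random.split(key, len(sizes))
    params = []
    for m, n, k in zip(sizes[:-1], sizes[1:], keys):
        w_key, b_key = random.split(k)
        params.append((scale * random.normal(w_key, (m, n)),
                       scale * random.normal(b_key, (n,))))
    return params

def predict(params, x):
    for w, b in params[:-1]:
        x = jnp.maximum(0, jnp.dot(w.T, x) + b)  # ReLU layer
    final_w, final_b = params[-1]
    logits = jnp.dot(final_w.T, x) + final_b
    return logits - logsumexp(logits)  # log-softmax

batched_predict = vmap(predict, in_axes=(None, 0))

def loss(params, images, targets):
    return -jnp.mean(batched_predict(params, images) * targets)

def update(params, x, y):
    grads = grad(loss)(params, x, y)
    return [(w - step_size * dw, b - step_size * db)
            for (w, b), (dw, db) in zip(params, grads)]

jit_update = jit(update)

def wait(tree):
    # Jax dispatches work asynchronously; block until every array is ready
    for leaf in jax.tree_util.tree_leaves(tree):
        leaf.block_until_ready()

def seconds_per_step(update_fn, params, x, y, n_steps=10):
    wait(update_fn(params, x, y))  # warm-up (triggers compilation for jit_update)
    start = time.perf_counter()
    for _ in range(n_steps):
        params = update_fn(params, x, y)
    wait(params)
    return (time.perf_counter() - start) / n_steps

# Synthetic stand-in for one MNIST batch: random pixels and random labels
x_key, y_key, p_key = random.split(random.PRNGKey(0), 3)
x = random.uniform(x_key, (128, 784))
y = jnp.eye(10)[random.randint(y_key, (128,), 0, 10)]
params = init_network_params([784, 512, 512, 10], p_key)

print(f"update:     {seconds_per_step(update, params, x, y):.4f} s/step")
print(f"jit_update: {seconds_per_step(jit_update, params, x, y):.4f} s/step")
```

One detail worth knowing: jit treats the global step_size as a constant when it traces update, so changing step_size after the first call to jit_update has no effect unless you pass it in as an argument instead.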

The following image shows the resulting time elapsed per epoch for each case:

Elapsed time per epoch with and without JIT. Developed by the author.

As we can see, on average, the JIT version takes about half the time per epoch!

Conclusion

Jax is a new kid on the block but already shows great potential. It is good to get to know the new possibilities our field brings us every day.

In future posts, I will introduce some of the Deep Learning libraries based on Jax, so we can get a better overview of the new ecosystem that is being created.

Hope that this has been useful!

This post is highly based on the Jax documentation.
