Get started with Google Trax for NLP

How to use Google Brain’s newest library to easily train and build models for Natural Language Processing

Tiago Duque
Towards Data Science


Trax is the newest book to the already thriving library of NLP. Image by Alfons Morales at Unsplash.

After reading this article you’ll be able to:

  • Understand the idea behind the Trax library.
  • Understand how to write simple data preprocessing pipelines with Trax.
  • Use Trax pretrained models and build your own models from scratch.
  • Train a Sentiment Analysis model using a Deep Neural Network with Trax.
  • Learn how to make predictions on new inputs.
  • Load checkpoints, model weights and convert Trax models to Keras for deployment.
  • Have access to tutorials showing how to train a Neural Machine Translator and a Named Entity Recognition system with Trax.

Machine Learning Engineers and Data Scientists already have a handful of tools to build and deploy the most up-to-date models for Natural Language Processing.

Among these tools are Pytorch and Tensorflow 2.0, strong foundations that provide everything from the most basic building blocks of neural networks to useful high-level tools, such as the Keras library, which is now part of Tensorflow 2.0.

Aside from that, there’s 🤗 Transformers, which makes the lives of NLP practitioners a tad easier by providing high-level access to many pretrained models in either Pytorch or Tensorflow format.

With all these, it seems there’s no need for anything else, right? But the Google Brain team, which co-created the Transformer and the Reformer, thought otherwise. Enter Trax.

In the maintainers’ own words:

Trax is an end-to-end library for deep learning that focuses on clear code and speed. It is actively used and maintained in the Google Brain team.

Looking more carefully, you’ll notice that Trax is actually an “interface” that runs on top of more complex backends, such as Jax (from which Trax takes its name) or Tensorflow-numpy.

If you’re familiar with Keras, you’ll see that Trax looks a lot like its older cousin (which shows in the fact that Trax models can easily be converted to Keras models). So, again, what’s the benefit of Trax? The answer lies in the words quoted above: Trax has been written from scratch for better readability and speed.

Chances are that, if you’re an average Python developer, you’ll be able to understand Trax’s source code. Want an example? Look at the code of the init_weights_and_state method of the Serial layer:

Clear and comprehensible code is Trax’s greatest strength. You could even learn some theory just by reading the code!

Besides being clear and trying to use Python’s default structures whenever possible, the code is well documented; some docstrings even include examples.

Sure, you won’t find as much support or as many courses as there are for Tensorflow, Keras or Pytorch, but you’ll be able to figure most of the code out on your own (as I did while writing this article).

Rather than taking an in-depth approach, this article attempts to give you a shortcut to some of the most useful Trax resources for NLP, along with some examples.

As a disclaimer, I have no affiliation with the Google Brain team, and my first contact with Trax was in the DeepLearning.AI NLP Specialization on Coursera (which I highly recommend).

Before proceeding: all the code and samples are also available in a public Google Colab notebook, which can be accessed below:

Without further ado, let us get started with some interesting features that Trax brings to NLP.

Important Note: when I wrote this article, Trax’s latest release was version 1.3.6. Recent commits are adding some significant changes, especially to the metric layers and other useful resources. However, since these are experimental, I decided to cover only what is available in the current release (out in October 2020).

Basic Imports

First, let us understand the imports for the main modules in Trax. For preprocessing, basic model designing and training, we’ll need the following imports:

Also, as mentioned, Trax can use different backends for tensor computation. For now, there are two options: Jax and Tensorflow. Jax has been called “numpy on steroids” because it uses more modern techniques, such as just-in-time compilation, to speed up Numpy and Python computations. It is Trax’s default backend, but you can also set it explicitly (with a bonus example of how to set up a TPU for Jax in Colab):

Preprocessing with Trax

Trax provides us with easy-to-use resources to implement NLP pipelines (in fact, any preprocessing pipeline). This is where the data module comes in.

For example, Trax lets us use many Tensorflow utilities, such as Tensorflow Datasets (TFDS). We can very easily download any dataset available in TFDS (and use its utilities to prepare our own datasets).

We can, for example, get the ‘imdb_reviews’ dataset to be used for sentiment analysis.

To access its contents, we can:

Results:

The text, up to 100 characters, is: This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. The sentiment is: Negative

Trax makes preprocessing pipelines easy to use and configure.

To do so, we use a special type of object: the Serial layer (this one is for data; there’s another one for modeling). It is available as trax.data.Serial, and it lets us process data one function at a time, in a serial manner.

We can also make a preprocessing pipeline and then feed the data generator to it. This way, we can do many important tasks, such as:

Just remember that each of your functions must accept the same format that the previous function outputs (e.g., (1) takes a string and returns a list; (2) takes a list and returns a list, etc.).

(generator_sample)->(1)->(2)->…->(ml_output)
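The chaining rule above can be sketched in plain Python. This is a toy, framework-free illustration of the idea behind a serial data pipeline, not Trax’s implementation: each step is assumed to be a generator-to-generator function, and the hypothetical helper `serial` feeds the output of one step into the next. The step names (`lowercase`, `split_words`, `drop_long`) are made up for the example.

```python
from typing import Callable, Iterable, Iterator, List, Tuple

Example = Tuple[str, int]                 # (review text, label)
Step = Callable[[Iterable], Iterator]     # a generator-to-generator function

def serial(*steps: Step) -> Step:
    """Compose generator-to-generator steps left to right (hypothetical helper)."""
    def pipeline(stream: Iterable) -> Iterator:
        for step in steps:
            stream = step(stream)
        return iter(stream)
    return pipeline

def lowercase(stream: Iterable[Example]) -> Iterator[Example]:
    # (1) takes (string, label) pairs and yields (string, label) pairs
    for text, label in stream:
        yield text.lower(), label

def split_words(stream: Iterable[Example]) -> Iterator[Tuple[List[str], int]]:
    # (2) takes strings and yields lists of words, keeping the label alongside
    for text, label in stream:
        yield text.split(), label

def drop_long(max_words: int) -> Step:
    # (3) takes word lists and filters out examples longer than max_words
    def step(stream):
        for words, label in stream:
            if len(words) <= max_words:
                yield words, label
    return step

preprocess = serial(lowercase, split_words, drop_long(max_words=5))
raw = [("A Terrible movie", 0), ("Great cast and a great story overall", 1)]
print(list(preprocess(raw)))  # → [(['a', 'terrible', 'movie'], 0)]
```

Because every step consumes and produces a stream, steps can be reordered or swapped freely, as long as each one accepts the format the previous one emits.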

To ‘plug’ your pipeline into Trax’s training algorithms, each batch it yields must have the format (batch_input, batch_expected_output, mask weights), with the weights being optional.
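To make that (batch_input, batch_expected_output, mask weights) format concrete, here is a minimal sketch, assuming sequences of arbitrary token ids that are right-padded with 0 to a common length. The helper `make_batch` is hypothetical. For a classifier like ours, the weights simply mark every example as counting toward the loss; for sequence outputs, weights are typically used to mask out padding positions.

```python
from typing import List, Sequence, Tuple

PAD_ID = 0  # assumption: token id 0 is reserved for padding

def pad(seq: Sequence[int], length: int) -> List[int]:
    """Right-pad a token-id sequence with PAD_ID up to `length`."""
    return list(seq) + [PAD_ID] * (length - len(seq))

def make_batch(examples: Sequence[Tuple[Sequence[int], int]]):
    """Turn (token_ids, label) pairs into (inputs, targets, weights)."""
    max_len = max(len(tokens) for tokens, _ in examples)
    inputs = [pad(tokens, max_len) for tokens, _ in examples]
    targets = [label for _, label in examples]
    weights = [1.0 for _ in examples]  # every example counts equally in the loss
    return inputs, targets, weights

print(make_batch([([1863, 58, 5459], 0), ([292, 30], 1)]))
# → ([[1863, 58, 5459], [292, 30, 0]], [0, 1], [1.0, 1.0])
```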

In the following example, we show a real pipeline that makes use of some useful preprocessing steps provided by Trax itself.

Here’s how we test it:

Example sentence first 10 tokens: [1863   58 5459    6  292   30   43 6622   32 2172]... 
Example first 10 tokens after detokenization: What starts out as an interesting story...
Sentiment: Negative
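Tokenization maps text to integer ids, and detokenization maps the ids back to text. The real pipeline relies on a trained subword vocabulary; the sketch below swaps in a tiny, hypothetical word-level vocabulary built from a toy corpus, with an unknown-word id, just to show the round trip.

```python
from typing import Dict, List

PAD, UNK = "<pad>", "<unk>"

def build_vocab(corpus: List[str]) -> Dict[str, int]:
    """Assign ids in order of first appearance; 0 = padding, 1 = unknown (assumptions)."""
    vocab = {PAD: 0, UNK: 1}
    for sentence in corpus:
        for word in sentence.lower().split():
            vocab.setdefault(word, len(vocab))
    return vocab

def tokenize(text: str, vocab: Dict[str, int]) -> List[int]:
    """Map each word to its id, falling back to the unknown-word id."""
    return [vocab.get(word, vocab[UNK]) for word in text.lower().split()]

def detokenize(ids: List[int], vocab: Dict[str, int]) -> str:
    """Map ids back to words, skipping padding."""
    inverse = {i: w for w, i in vocab.items()}
    return " ".join(inverse[i] for i in ids if i != vocab[PAD])

vocab = build_vocab(["what starts out as an interesting story"])
ids = tokenize("What starts out as a boring story", vocab)
print(ids)                     # → [2, 3, 4, 5, 1, 1, 8]
print(detokenize(ids, vocab))  # → what starts out as <unk> <unk> story
```

Words never seen while building the vocabulary collapse to the unknown id, which is one reason real pipelines prefer subword vocabularies.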

If you want an example of how to include your own preprocessing steps in the pipeline, check the Colab notebook.

Creating Models

Now that we’ve seen preprocessing, it’s time to move on to modeling itself.

Trax allows the use of models in two ways:

  • Predefined models, such as:
    - Seq2Seq with Attention
    - BERT
    - Transformer
    - Reformer
  • Custom models built from a combination of layers.

Let’s peek into one of Trax’s predefined models, the LSTMSeq2SeqAttn:

Here are a few lines of the result:

Serial_in2_out2[
  Select[0,1,0,1]_in2_out4
  Parallel_in2_out2[
    Serial[
      Embedding_8183_512
      LSTM_512
      LSTM_512
    ]
    Serial[
      Serial[
        AssertShape
        ShiftRight(1)
        AssertShape
      ]
      Embedding_10000_512
      LSTM_512
    ]
  ]
  PrepareAttentionInputs_in3_out4
  Serial_in4_out2[
    Branch_in4_out3[
      None
      Serial_in4_out2[...

That’s a lot, right? Each of these ‘names’ is a network layer, accompanied by its description (Embedding_10000_512 is an Embedding layer with a vocabulary size of 10000 and an embedding dimension of 512; suffixes such as _in2_out4 give the layer’s number of inputs and outputs).

Layers are like Lego Blocks in Trax. Image by Xavi Cabrera at Unsplash.

Layers are the LEGO blocks of Trax!

With these layers we can create our own models as well.

Each layer is a function (or a whole bunch of functions bundled together) that takes some input and returns output in the promised format. Every layer has two important properties: the number of expected inputs (n_in) and the number of promised outputs (n_out).

Layers can be combined using what are called “combiner layers”.

The three most commonly used “combiner” layers are “Serial” (sequential model), “Branch” (parallel model) and “Residual” (a kind of merge between Serial and Branch: it runs its layers in sequence and adds the result back to the input).

Example:
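A toy way to picture these combiners, assuming single-input layers that are plain Python functions (real Trax layers are richer: they carry weights and handle n_in/n_out values). The hypothetical helpers `serial`, `branch` and `residual` below only mimic the data flow; they are not Trax’s implementation.

```python
from typing import Callable, Tuple

Layer = Callable[[float], float]

def serial(*layers: Layer) -> Layer:
    """Run layers one after another: the output of each feeds the next."""
    def apply(x: float) -> float:
        for layer in layers:
            x = layer(x)
        return x
    return apply

def branch(*layers: Layer) -> Callable[[float], Tuple[float, ...]]:
    """Give every layer the same input and collect all of their outputs."""
    def apply(x: float) -> Tuple[float, ...]:
        return tuple(layer(x) for layer in layers)
    return apply

def residual(*layers: Layer) -> Layer:
    """Run the layers serially, then add the original input back: x + f(x)."""
    body = serial(*layers)
    def apply(x: float) -> float:
        return x + body(x)
    return apply

def double(x: float) -> float:
    return 2 * x

def add_one(x: float) -> float:
    return x + 1

print(serial(double, add_one)(3))    # → 7
print(branch(double, add_one)(3))    # → (6, 4)
print(residual(double, add_one)(3))  # → 10
```

In this toy version, `residual(f)` gives the same result as `branch(serial(f), identity)` followed by adding the two outputs, which is one way to read “a merge between Serial and Branch”.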

We can use Trax’s predefined layers (such as LSTM cells, fully connected [Dense] layers, etc.), but we can also make our own. In the notebook, I’ve shown how to “implement” a modified (yet useless) version of the Dense layer as a simple example. This means you can easily implement novel algorithms on your own and integrate them with preexisting layers or with Trax’s training and preprocessing facilities.

Let us now see how to build an entire model with Trax layers (a simple neural network for Sentiment Analysis that we’ll train later). We’ll use a Serial layer as the combiner, along with the following layers:

  • Embedding: creates a layer that maps each word in the vocabulary to a vector of values (the word embeddings).
  • Mean: computes the average over one axis. In the case below, we average over the length of the sentence to crudely “summarize” the input words.
  • Dense: fully connected layer, the basic structure of neural networks. In our case, all outputs of the Mean layer are fed into two neurons (one per class), whose activations depend on those inputs.
  • LogSoftmax: applies the log-softmax function to the outputs of the Dense layer, turning them into log-probabilities for each of the two classes (in other words, it tells which of the two classes is more probable).

Output:

Serial[
  Embedding_8183_256
  Mean
  Dense_2
  LogSoftmax
]
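To see what each layer of this model computes, here is a plain-Python forward pass for a single tokenized sentence, with tiny sizes and random, untrained weights. It is a sketch of the math, not Trax’s implementation; in the batched Trax model, the Mean step corresponds to averaging over the length axis (axis 1 of a batch × length × embedding array). Sizes and helper names are hypothetical.

```python
import math
import random
from typing import List

Vector = List[float]

def embed(token_ids: List[int], table: List[Vector]) -> List[Vector]:
    """Embedding: look up one vector per token id."""
    return [table[t] for t in token_ids]

def mean_over_length(vectors: List[Vector]) -> Vector:
    """Mean: average the word vectors to crudely summarize the sentence."""
    n = len(vectors)
    return [sum(column) / n for column in zip(*vectors)]

def dense(x: Vector, weights: List[Vector], bias: Vector) -> Vector:
    """Dense: weights[i][k] connects input i to output unit k."""
    return [sum(x[i] * weights[i][k] for i in range(len(x))) + bias[k]
            for k in range(len(bias))]

def log_softmax(logits: Vector) -> Vector:
    """LogSoftmax: logits minus log(sum(exp(logits))), computed stably."""
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_norm for z in logits]

rng = random.Random(0)
vocab_size, d_model, n_classes = 10, 4, 2
table = [[rng.gauss(0, 1) for _ in range(d_model)] for _ in range(vocab_size)]
weights = [[rng.gauss(0, 1) for _ in range(n_classes)] for _ in range(d_model)]
bias = [0.0] * n_classes

log_probs = log_softmax(dense(mean_over_length(embed([3, 1, 4], table)), weights, bias))
print(log_probs)                              # two log-probabilities, one per class
print(sum(math.exp(lp) for lp in log_probs))  # ≈ 1.0
```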

Much simpler, right? But it’s enough for the basics. Now, how do we train it?

Training Models

So far, we’ve seen how pipelines are created and how to create and use models.

Now it’s time to put both of them to work together.

Let us put the machine to learn!

A machine, learning… Image by Brett Jordan at Unsplash.

Remember, there are two main methods of Machine Learning:

  • Supervised (you provide the inputs and expected outputs for the model to train)
  • Unsupervised (the model trains only using the input, evaluating the output from some inner metric)

We’re only covering supervised models here.

To do Supervised Training in Trax, we have to define three important ‘blocks’:

  1. The training task, which takes as input the labeled data, the loss layer, the optimizer and the number of steps between checkpoints.
  2. The evaluation task, which takes as input the labeled data, the metrics and the number of eval batches. This is important since it tells how good our model is at generalizing.
  3. The training loop, which puts all of this together.
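The division of labor between these three blocks can be sketched without any framework. The classes below are hypothetical stand-ins, not Trax’s own task and loop classes: the train task holds labeled batches, a loss gradient, a learning rate (plain SGD stands in for the optimizer) and the checkpoint interval; the eval task holds held-out batches, a metric and how many batches to average; and the loop runs the steps and evaluates at step 1 and at every checkpoint interval. To keep it tiny, the “model” is a single weight w fitted to y = 3x with mean squared error.

```python
import random
from dataclasses import dataclass
from typing import Callable, Iterator, List, Tuple

Batch = List[Tuple[float, float]]  # (input, expected output) pairs

def mse(w: float, batch: Batch) -> float:
    """Mean squared error of the one-weight model y_hat = w * x."""
    return sum((w * x - y) ** 2 for x, y in batch) / len(batch)

def mse_grad(w: float, batch: Batch) -> float:
    """Derivative of mse with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)

@dataclass
class ToyTrainTask:
    labeled_data: Iterator[Batch]
    loss_grad: Callable[[float, Batch], float]
    learning_rate: float          # plain SGD stands in for the optimizer
    n_steps_per_checkpoint: int

@dataclass
class ToyEvalTask:
    labeled_data: Iterator[Batch]
    metric: Callable[[float, Batch], float]
    n_eval_batches: int

def toy_loop(w: float, train: ToyTrainTask, evaluate: ToyEvalTask, n_steps: int):
    """Run SGD steps; evaluate at step 1 and every n_steps_per_checkpoint steps."""
    history = []
    for step in range(1, n_steps + 1):
        batch = next(train.labeled_data)
        w -= train.learning_rate * train.loss_grad(w, batch)
        if step == 1 or step % train.n_steps_per_checkpoint == 0:
            scores = [evaluate.metric(w, next(evaluate.labeled_data))
                      for _ in range(evaluate.n_eval_batches)]
            history.append((step, sum(scores) / len(scores)))
    return w, history

def batches(seed: int, batch_size: int = 8) -> Iterator[Batch]:
    """Endless stream of labeled batches for y = 3x."""
    rng = random.Random(seed)
    while True:
        xs = [rng.uniform(-1, 1) for _ in range(batch_size)]
        yield [(x, 3.0 * x) for x in xs]

train_task = ToyTrainTask(batches(seed=0), mse_grad, learning_rate=0.5, n_steps_per_checkpoint=20)
eval_task = ToyEvalTask(batches(seed=1), mse, n_eval_batches=4)
w, history = toy_loop(0.0, train_task, eval_task, n_steps=100)
for step, eval_loss in history:
    print(f"Step {step}: eval mse | {eval_loss:.6f}")
```

Keeping evaluation data in its own task, separate from the training data, is what lets the loop report how well the model generalizes rather than how well it memorized its training batches.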

We get all that from trax.supervised (which we imported as ts in the beginning).

First, though, let’s use what we’ve learned about pipelines to build the training and evaluation data pipelines.

Now we can define the training/eval tasks and the loop.

Just some important notes before proceeding.

About the training and eval tasks:

About the Training loop:

  • We create an output dir for the weights and checkpoints (this will save a checkpoint file that we can use to resume the training later).
  • We plug the model and the train/eval tasks.
  • We run it by using training_loop.run =)

The long-awaited output!

Step      1: Total number of trainable weights: 2095362 
Step 1: Ran 1 train steps in 2.52 secs
Step 1: train CrossEntropyLoss | 0.69192028
Step 1: eval CrossEntropyLoss | 0.67562409
Step 1: eval Accuracy | 0.57187500
Step 200: Ran 199 train steps in 39.13 secs
Step 200: train CrossEntropyLoss | 0.62404096
Step 200: eval CrossEntropyLoss | 0.51835893
Step 200: eval Accuracy | 0.71718750
Step 400: Ran 200 train steps in 37.59 secs
Step 400: train CrossEntropyLoss | 0.44796222
Step 400: eval CrossEntropyLoss | 0.42682799
Step 400: eval Accuracy | 0.82812500
[...]
Step 2000: Ran 200 train steps in 45.79 secs
Step 2000: train CrossEntropyLoss | 0.25747755
Step 2000: eval CrossEntropyLoss | 0.33661298
Step 2000: eval Accuracy | 0.84687500

First, stdout shows the size of the model: over 2 million weights is more than you might have expected, right? Then we get the metrics at every checkpoint interval set in the training task.
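The 2095362 figure can be checked by hand, assuming the model printed earlier (Embedding_8183_256, Mean, Dense_2, LogSoftmax). Only the Embedding and Dense layers have weights, and the embedding table accounts for almost all of them.

```python
vocab_size, d_embedding, n_classes = 8183, 256, 2

embedding_params = vocab_size * d_embedding          # one 256-d vector per vocabulary entry
dense_params = d_embedding * n_classes + n_classes   # weight matrix + bias vector
total = embedding_params + dense_params              # Mean and LogSoftmax have no weights

print(embedding_params, dense_params, total)  # → 2094848 514 2095362
```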

After about 10 minutes, it reaches roughly 85% accuracy on the evaluation data. Not too bad for a simple model, right?
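The two metrics in the log can be reproduced from the model’s log-probabilities. Below is a minimal sketch, assuming integer labels and one row of class log-probabilities per example; the function names are hypothetical, not Trax’s metric layers. It also explains where the log starts: a model that still assigns 50/50 to both classes has a cross-entropy of ln 2 ≈ 0.693, right where step 1 begins.

```python
import math
from typing import Sequence

def cross_entropy(log_probs: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Average negative log-probability assigned to the correct class."""
    return -sum(row[y] for row, y in zip(log_probs, labels)) / len(labels)

def accuracy(log_probs: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Fraction of examples whose highest-scoring class is the correct one."""
    hits = sum(max(range(len(row)), key=row.__getitem__) == y
               for row, y in zip(log_probs, labels))
    return hits / len(labels)

uniform = [[math.log(0.5), math.log(0.5)]] * 4
print(cross_entropy(uniform, [0, 1, 1, 0]))  # ≈ 0.693 (ln 2)

confident = [[math.log(0.9), math.log(0.1)], [math.log(0.2), math.log(0.8)]]
print(accuracy(confident, [0, 0]))           # → 0.5
```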

Okay, now how to use it?

Predicting from new inputs

So we have a trained model. How do we use it?

Simple! Just feed a tokenized input to the model!

But first, a word of caution: Trax models (like models in all current deep learning frameworks) expect the input to have a batch dimension in addition to the expected input dimensions. So we have to wrap our sample in one.
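In shape terms, a single tokenized review of length n has to become a batch of shape (1, n). The model’s output for it is then a batch of shape (1, 2), from which we take each row’s highest-scoring class. Below is a minimal list-based sketch, assuming class 0 means Negative and class 1 means Positive. The helper names are hypothetical, and the sample scores are the activations the converted Keras model prints later in this article.

```python
from typing import List, Sequence

CLASS_NAMES = ("Negative", "Positive")  # assumption: label 0 = negative, 1 = positive

def add_batch_dim(token_ids: Sequence[int]) -> List[List[int]]:
    """Wrap one example into a batch of size 1: shape (n,) -> (1, n)."""
    return [list(token_ids)]

def predicted_sentiments(batch_scores: Sequence[Sequence[float]]) -> List[str]:
    """Pick the highest-scoring class for every example in the batch."""
    return [CLASS_NAMES[max(range(len(row)), key=row.__getitem__)] for row in batch_scores]

batch = add_batch_dim([1863, 58, 5459])
print(len(batch), len(batch[0]))                          # → 1 3
print(predicted_sentiments([[-2.0911603, -0.13186848]]))  # → ['Positive']
```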

Let us try it out.

The output:

Input review: "I loved the way that the actors were cast, also, It is clear that they've put a huge effort in post-production." 
The sentiment is: Positive

It got it right! And it often will, but this is a simple model, way below the current state of the art. For that, you will need other models and, preferably, fine-tuned pretrained models!

Resuming training and loading models from file

First, let us see how to restore a checkpoint. This lets us resume training from a given point, which is useful if you want to train for a really long time, if the running session crashes for some reason, or if you want to test new training parameters without losing all the work.

Remember that we told Trax to save a checkpoint every 200 steps? Now we can restore the model to continue training:

Output:

Step   2200: Ran 200 train steps in 47.54 secs
Step 2200: train CrossEntropyLoss | 0.28513467
Step 2200: eval CrossEntropyLoss | 0.33272305
Step 2200: eval Accuracy | 0.87812500

(Be careful, this can cause the model to overfit)
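Conceptually, resuming only requires persisting the current step and the weights, then reading them back and continuing the step counter from there. The sketch below shows that idea with pickle and a hypothetical file name, `checkpoint.pkl`. It is not Trax’s checkpoint format, which the training loop manages for you inside its output directory.

```python
import os
import pickle
import tempfile
from typing import Any, Optional, Tuple

CHECKPOINT_NAME = "checkpoint.pkl"  # hypothetical file name

def save_checkpoint(output_dir: str, step: int, weights: Any) -> str:
    """Persist the step counter and weights so training can resume later."""
    os.makedirs(output_dir, exist_ok=True)
    path = os.path.join(output_dir, CHECKPOINT_NAME)
    with open(path, "wb") as f:
        pickle.dump({"step": step, "weights": weights}, f)
    return path

def restore_checkpoint(output_dir: str) -> Tuple[int, Optional[Any]]:
    """Return (step, weights); start from scratch if no checkpoint exists."""
    path = os.path.join(output_dir, CHECKPOINT_NAME)
    if not os.path.exists(path):
        return 0, None
    with open(path, "rb") as f:
        state = pickle.load(f)
    return state["step"], state["weights"]

with tempfile.TemporaryDirectory() as output_dir:
    save_checkpoint(output_dir, step=2000, weights=[0.25, -1.5])
    step, weights = restore_checkpoint(output_dir)
    print(step, weights)  # → 2000 [0.25, -1.5]
    # Training would continue from step + 1, e.g. for 200 more steps up to 2200.
```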

We can also load a pretrained model. This lets us use models that were trained before, for example in production:

The sentiment is:  Negative

Great! So this is how we can save our training to use in production, right?

Deploying to Production

Well, the thing is: Trax is currently geared toward developing and training fast, state-of-the-art (SOTA) models. It’s a young project (its first commit was in November 2019) and lacks some perks of its older cousins, such as a native set of deploy-to-production tools.

But while there’s currently no native support for deploying to Tensorflow Serving, for example, there are native tools to turn Trax models into Keras models, which in turn can be deployed to Tensorflow Serving! So, all in all, you get Trax’s clean code and performance combined with the maturity of Keras and Tensorflow.

That’s another easy step to implement. But keep in mind that (currently) a Trax model can only be converted to Keras if its backend is set to tensorflow-numpy (there is no direct Jax conversion). That is very easy to do, though.

Start off by changing the backend:

Next, train the model for a single step over the new backend (this will implicitly convert the backend structures to the new format — expect this step to be unnecessary soon):

The conversion itself is very simple:

<trax.trax2keras.AsKeras object at 0x7fa52c3c0dd8> 
Keras returned sentiment activations: [[-2.0911603 -0.13186848]]

Works!
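Those two activations are log-probabilities, since the model ends in a LogSoftmax layer. Exponentiating them recovers class probabilities that sum to about 1, which is a quick sanity check on the converted model’s output (assuming index 1 is the positive class).

```python
import math

activations = [-2.0911603, -0.13186848]  # values printed by the converted Keras model above
probabilities = [math.exp(a) for a in activations]

print([round(p, 3) for p in probabilities])  # → [0.124, 0.876]
print(round(sum(probabilities), 3))          # → 1.0
```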

Now, you can do what you’re used to with Keras and add this layer to a full model:

With this, saving our model for use with Tensorflow Serving is as easy as:

That’s it for this article, but that’s not all!

In the Colab notebook I provided, there are two other examples that might be useful: how to train a Neural Machine Translator with the Transformer and how to train a Named Entity Recognition system with the Reformer.

Be sure to check these examples, since they cover other options that are not in the text. I also put a lot of work into making them easier to understand than Trax’s original examples. Maybe you could keep the notebook as a companion cheat sheet for learning Trax 😉.


A Data Scientist passionate about data and text. Trying to understand and clearly explain all important nuances of Natural Language Processing.