Fine-tuning BERT with Keras and tf.Module

Denis Antyukhov
Towards Data Science
Nov 30, 2019

In this experiment we convert a pre-trained BERT model checkpoint into a trainable Keras layer, which we use to solve a text classification task.

We achieve this by using a tf.Module (in TF 1.x terms, a TensorFlow Hub module built with the tensorflow_hub library), a neat abstraction designed for packaging and reusing pre-trained TensorFlow models.

Exported modules can be easily integrated into other models, which makes it easier to experiment with powerful neural network architectures.

The plan for this experiment is:

  1. getting a pre-trained BERT model checkpoint
  2. defining the specification of the tf.Module
  3. exporting the module
  4. building a text preprocessing pipeline
  5. implementing a custom Keras layer
  6. training a Keras model to solve a sentence-pair classification task
  7. saving and restoring
  8. optimizing the trained model for inference

What is in this guide?

This guide is about integrating pre-trained TensorFlow models with Keras. It contains implementations of two things: a BERT tf.Module and a Keras layer built on top of it. It also includes examples of fine-tuning (see below) and inference.

What does it take?

For a reader familiar with TensorFlow, it should take around 60 minutes to finish this guide. The code was tested with tensorflow==1.15.0.

OK, show me the code.

The code for this experiment is available as a Colab notebook, and a standalone version can be found in the repository.

Step 1: getting the pre-trained model

We start with a pre-trained BERT-base checkpoint. For this experiment we will use an English model pre-trained by Google. Naturally, you can use a model more suitable for your use case when building your tf.Module.

Step 2: building a tf.Module

tf.Modules are designed to provide a simple way to manipulate reusable parts of pre-trained machine learning models in TensorFlow. Google maintains a curated library of such modules on TensorFlow Hub. In this guide, however, we will build one ourselves from scratch.

To that end, we will implement a module_fn containing the full specification of the module's inner workings.

We begin by defining input placeholders. The BERT model graph is created from a configuration file passed through config_path. Then the model outputs are extracted: the final encoder layer output is saved to seq_output, and the pooled ‘CLS’ token representation to pool_output.

Additionally, extra assets may be bundled with the module. In this example, we add a vocab_file containing the WordPiece vocabulary to the module assets. As a result, the vocabulary file will be exported with the module, which will make it self-contained.

Finally, we define signatures: particular transformations of inputs into outputs that are exposed to the module's consumers. One can think of them as the module's interface with the outside world.

Here two signatures are added. The first takes the BERT text features (input_ids, input_mask and segment_ids, see Step 4) as input and returns the computed text representations as output. The other takes no inputs and returns the path to the vocabulary file and the lowercase flag.

Step 3: exporting the module

Now that the module_fn is defined, we can use it to build and export the module. Passing the tags_and_args argument to create_module_spec adds two graph variants to the module: one for training, tagged {“train”}, and one for inference, with an empty set of tags. This lets us control dropout, which is enabled during training and disabled at inference time.
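The following is a TensorFlow-free sketch of the tags_and_args structure. It shows how each (tags, kwargs) pair yields one graph variant.

It rests on two assumptions:
- module_fn accepts an is_training keyword. This is our guess about its signature.
- Dropout is zeroed when is_training is False. Google's BERT modeling code does this, and 0.1 is BERT-base's default hidden dropout.

module_fn_stub and build_variants are hypothetical helpers. In the real code, the list is simply passed to hub.create_module_spec.

```python
# Sketch only: mimics how hub.create_module_spec(module_fn, tags_and_args=...)
# builds one graph variant per (tags, kwargs) pair. No TensorFlow involved.

HIDDEN_DROPOUT_PROB = 0.1  # BERT-base default; read from bert_config.json in practice


def module_fn_stub(is_training):
    """Hypothetical stand-in for module_fn: returns the dropout rate it would build.

    BERT's reference modeling code sets dropout to 0.0 when is_training is False.
    """
    return HIDDEN_DROPOUT_PROB if is_training else 0.0


# Same shape as the argument passed to hub.create_module_spec.
tags_and_args = [
    ({"train"}, {"is_training": True}),  # training variant, dropout on
    (set(), {"is_training": False}),     # default (inference) variant, dropout off
]


def build_variants(fn, variants):
    """Call fn once per variant; key the results by an immutable tag set."""
    return {frozenset(tags): fn(**kwargs) for tags, kwargs in variants}


variants = build_variants(module_fn_stub, tags_and_args)
for tags, dropout in variants.items():
    print(sorted(tags), dropout)
```

A consumer then selects a variant when loading the module: hub.Module(spec, trainable=True, tags={"train"}) for fine-tuning, or the default (no tags) for inference.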

Step 4: building the text preprocessing pipeline

The BERT model requires text to be represented as three matrices: input_ids, input_mask and segment_ids. In this step we build a pipeline that takes a list of strings and outputs these three matrices. As simple as that.
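To make the three matrices concrete, here is a minimal sketch of how one example becomes one row of each. The example may be a single sentence or a sentence pair, and we assume it is already WordPiece-tokenized into ids.

The ids are made up. We assume 101, 102 and 0 are the [CLS], [SEP] and [PAD] ids, as in Google's English uncased vocabulary. make_feature is a hypothetical helper; the pipeline in this guide delegates this work to convert_examples_to_features.

```python
# Sketch: build the three BERT input rows for one (optionally paired) example
# from WordPiece ids. The ids are made up; 101/102/0 are assumed to be the
# [CLS]/[SEP]/[PAD] ids of Google's English uncased vocabulary.
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0


def make_feature(ids_a, ids_b=None, max_seq_len=128):
    """Return (input_ids, input_mask, segment_ids), each of length max_seq_len."""
    is_pair = bool(ids_b)
    ids_a, ids_b = list(ids_a), list(ids_b or [])
    # Leave room for [CLS] and [SEP] (plus a second [SEP] for pairs).
    budget = max_seq_len - (3 if is_pair else 2)
    # Trim the longer sequence first, like BERT's _truncate_seq_pair.
    while len(ids_a) + len(ids_b) > budget:
        if len(ids_a) > len(ids_b):
            ids_a.pop()
        else:
            ids_b.pop()

    input_ids = [CLS_ID] + ids_a + [SEP_ID]
    segment_ids = [0] * len(input_ids)          # sentence A, incl. [CLS] and [SEP]
    if is_pair:
        input_ids += ids_b + [SEP_ID]
        segment_ids += [1] * (len(ids_b) + 1)   # sentence B and its [SEP]
    input_mask = [1] * len(input_ids)           # 1 = real token, 0 = padding

    padding = max_seq_len - len(input_ids)
    return (input_ids + [PAD_ID] * padding,
            input_mask + [0] * padding,
            segment_ids + [0] * padding)


ids, mask, segs = make_feature([7592, 2088], [2129, 2024, 2017], max_seq_len=10)
print(ids)
print(mask)
print(segs)
```

Stacking one such row per example gives the [batch_size, max_seq_len] matrices the model expects.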

First, raw input text is converted into InputExamples. If an input string is a sentence pair, with the two sentences separated by the special ‘|||’ sequence, it is split into its two sentences.
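A minimal sketch of the splitting step:
- InputExample here is a namedtuple stand-in for the class of the same name in the BERT repository's run_classifier.py.
- Stripping whitespace around '|||' is our assumption.
- texts_to_examples is a hypothetical name.

```python
from collections import namedtuple

# Hypothetical stand-in for InputExample from the BERT repository's
# run_classifier.py (the label field is omitted here).
InputExample = namedtuple("InputExample", ["guid", "text_a", "text_b"])


def texts_to_examples(texts, delimiter="|||"):
    """Turn raw strings into InputExamples, splitting sentence pairs on delimiter."""
    examples = []
    for i, text in enumerate(texts):
        if delimiter in text:
            # Split on the first delimiter only and strip the spaces around it.
            text_a, text_b = (part.strip() for part in text.split(delimiter, 1))
        else:
            text_a, text_b = text.strip(), None
        examples.append(InputExample(guid=str(i), text_a=text_a, text_b=text_b))
    return examples


for example in texts_to_examples([
    "How do I learn Python? ||| What is the best way to learn Python?",
    "A single sentence",
]):
    print(example)
```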

InputExamples are then tokenized and converted to InputFeatures using the convert_examples_to_features function from the original BERT repository. After that, the list of features is converted to matrices with features_to_arrays.
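features_to_arrays itself can be as simple as stacking the per-example lists. This sketch assumes each feature exposes input_ids, input_mask and segment_ids, already padded to max_seq_len (convert_examples_to_features does this padding).

Two simplifications apply:
- The InputFeatures namedtuple is a stand-in for the real class.
- The function returns nested lists rather than NumPy arrays.

```python
from collections import namedtuple

# Hypothetical stand-in for run_classifier.InputFeatures (label fields omitted).
InputFeatures = namedtuple("InputFeatures", ["input_ids", "input_mask", "segment_ids"])


def features_to_arrays(features):
    """Stack per-example features into three [batch_size, max_seq_len] matrices.

    Returns nested lists; in practice you would wrap each one in a NumPy
    integer array (e.g. np.asarray(..., dtype=np.int32)) before feeding it.
    """
    input_ids = [list(f.input_ids) for f in features]
    input_mask = [list(f.input_mask) for f in features]
    segment_ids = [list(f.segment_ids) for f in features]
    # Every example should already be padded to the same max_seq_len.
    if len({len(row) for row in input_ids + input_mask + segment_ids}) > 1:
        raise ValueError("features are not padded to a common max_seq_len")
    return input_ids, input_mask, segment_ids


batch = [
    InputFeatures([101, 5, 102, 0], [1, 1, 1, 0], [0, 0, 0, 0]),
    InputFeatures([101, 6, 102, 7], [1, 1, 1, 1], [0, 0, 0, 1]),
]
for matrix in features_to_arrays(batch):
    print(matrix)
```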

Finally, we put it all together in a single pipeline.

All done!

Step 5: implementing a BERT Keras layer

There are two ways to use tf.Modules with Keras.

The first way is to wrap a module with hub.KerasLayer. This approach is straightforward but not very flexible, because it doesn’t allow adding any custom logic to the module.

The second way is to implement a custom Keras layer containing the module. In that case, we have full control over the trainable variables, and can add pooling ops or even the whole text preprocessing pipeline to the computational graph! We will go the second way.

To design a custom Keras layer, we need to write a class that inherits from tf.keras.layers.Layer and overrides some of its methods, most importantly build and call.

The build method creates the module's assets. It begins by instantiating the BERT module from bert_path, which can be a path on disk or an HTTP address (e.g. for modules hosted on TensorFlow Hub). Then the list of trainable layers is built and the layer’s trainable weights are populated. Limiting the trainable weights to the last few layers significantly reduces the GPU memory footprint and speeds up training. It might also improve model accuracy, particularly on smaller datasets.
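Selecting which variables to train usually comes down to matching variable names. This minimal sketch assumes:
- the scope names used by Google's BERT code (bert/encoder/layer_<i>/..., bert/pooler/...);
- a hypothetical "module/" prefix in front of them.

select_trainable and its include_pooler flag are illustrative, not the layer's actual API. In a Keras layer, the matching variables would become trainable weights and the rest non-trainable ones.

```python
def select_trainable(var_names, n_fine_tune_layers=3, num_hidden_layers=12,
                     include_pooler=True):
    """Return the variable names that belong to the top n encoder layers.

    Assumes the scope names of Google's BERT code ("bert/encoder/layer_<i>/...",
    "bert/pooler/..."); BERT-base has 12 layers numbered 0..11.
    """
    top_layers = range(num_hidden_layers - n_fine_tune_layers, num_hidden_layers)
    # The trailing "/" stops the prefix for layer_1 from also matching layer_10/11.
    prefixes = ["encoder/layer_{}/".format(i) for i in top_layers]
    if include_pooler:
        prefixes.append("pooler/")
    return [name for name in var_names if any(p in name for p in prefixes)]


# Hypothetical variable names as they might appear inside the module.
names = [
    "module/bert/embeddings/word_embeddings:0",
    "module/bert/encoder/layer_1/output/dense/kernel:0",
    "module/bert/encoder/layer_9/output/dense/kernel:0",
    "module/bert/encoder/layer_11/attention/self/query/kernel:0",
    "module/bert/pooler/dense/kernel:0",
]
print(select_trainable(names))
```

Fewer selected variables means fewer gradients and optimizer slots to keep in memory, which is where the GPU savings come from.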

The build_preprocessor method retrieves the WordPiece vocabulary from the module assets to build the text preprocessing pipeline defined in Step 4.
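Roughly, building the preprocessor from the module assets looks like the sketch below. It assumes the second signature has already been evaluated, giving two values:
- the vocabulary path, as bytes (a TF1 session returns string tensors as bytes);
- the lowercase flag.

load_vocab and wordpiece follow the vocab.txt layout and the greedy longest-match-first algorithm of BERT's tokenization.py. build_preprocessor, however, is a hypothetical helper that splits on whitespace only. BERT's BasicTokenizer also splits on punctuation and strips accents.

```python
import tempfile


def load_vocab(vocab_file):
    """Map each WordPiece token to its line number, as in BERT's vocab.txt."""
    vocab = {}
    with open(vocab_file, encoding="utf-8") as f:
        for index, line in enumerate(f):
            vocab[line.strip()] = index
    return vocab


def wordpiece(word, vocab, unk="[UNK]"):
    """Greedy longest-match-first WordPiece split of a single word."""
    pieces, start = [], 0
    while start < len(word):
        end, piece = len(word), None
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # continuation pieces carry a ## prefix
            if candidate in vocab:
                piece = candidate
                break
            end -= 1
        if piece is None:
            return [unk]  # no valid split: the whole word becomes [UNK]
        pieces.append(piece)
        start = end
    return pieces


def build_preprocessor(vocab_path, do_lower_case):
    """Hypothetical helper: text -> WordPiece ids, using whitespace splitting only."""
    if isinstance(vocab_path, bytes):  # string tensors come back from a session as bytes
        vocab_path = vocab_path.decode("utf-8")
    vocab = load_vocab(vocab_path)

    def to_ids(text):
        if do_lower_case:
            text = text.lower()
        return [vocab[p] for word in text.split() for p in wordpiece(word, vocab)]

    return to_ids


# Usage with a tiny throwaway vocabulary (hypothetical tokens).
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False, encoding="utf-8") as tmp:
    tmp.write("\n".join(["[PAD]", "[UNK]", "[CLS]", "[SEP]", "un", "##aff", "##able", "hello"]) + "\n")

to_ids = build_preprocessor(tmp.name.encode("utf-8"), do_lower_case=True)
print(to_ids("Hello unaffable xyz"))
```

Bundling vocab.txt as a module asset is what lets this step work from the module alone, without a separate copy of the checkpoint directory.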

The initialize_module method loads the module variables into the current Keras session.

Most of the fun stuff happens inside the call method. As input, it accepts a tensor of strings (dtype tf.string), which are transformed into BERT features by our preprocessing pipeline. The Python code doing that is injected into the graph with tf.numpy_function. The features are then passed into the module and the output is retrieved.

Now, depending on the pooling parameter set in __init__, additional transformations are applied to the output tensor.

If pooling==None, no pooling is applied and the output tensor has shape [batch_size, seq_len, encoder_dim]. This mode is useful for solving token-level tasks.

If pooling=='cls', only the vector corresponding to the first ('CLS') token is retrieved and the output tensor has shape [batch_size, encoder_dim]. This pooling type is useful for solving sentence-pair classification tasks.

Finally, if pooling=='mean', the embeddings of all tokens are mean-pooled and the output tensor has shape [batch_size, encoder_dim]. This mode is particularly useful for sentence representation tasks. It was inspired by the REDUCE_MEAN pooling strategy from bert-as-service.
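The three modes can be illustrated on plain nested lists. The sketch makes two assumptions:
- 'mean' pooling excludes padding positions via input_mask, as bert-as-service's REDUCE_MEAN does.
- 'cls' takes the raw first-token vector from seq_output; the actual layer might use pool_output instead.

The pool function is hypothetical. In the Keras layer, the same logic would be written with TensorFlow ops.

```python
def pool(seq_output, input_mask, pooling=None):
    """Apply the pooling modes to nested lists.

    seq_output: [batch_size][seq_len][encoder_dim]; input_mask: [batch_size][seq_len] of 0/1.
    """
    if pooling is None:
        return seq_output  # token-level: [batch_size, seq_len, encoder_dim]
    if pooling == "cls":
        return [tokens[0] for tokens in seq_output]  # first ([CLS]) token per example
    if pooling == "mean":
        pooled = []
        for tokens, mask in zip(seq_output, input_mask):
            dim = len(tokens[0])
            totals = [0.0] * dim
            for vector, m in zip(tokens, mask):
                for j in range(dim):
                    totals[j] += vector[j] * m  # padding (m == 0) contributes nothing
            n_real = max(sum(mask), 1)  # guard against an all-padding row
            pooled.append([t / n_real for t in totals])
        return pooled
    raise ValueError("pooling must be None, 'cls' or 'mean'")


seq_output = [[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]]  # last position is padding
input_mask = [[1, 1, 0]]
print(pool(seq_output, input_mask, "cls"))
print(pool(seq_output, input_mask, "mean"))
```

Without the mask, the padding vector [100.0, 100.0] would dominate the mean, which is why masking matters for short sentences in long padded batches.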

The full listing for the BERT layer can be found in the repository.

Step 6: sentence pair classification

Now let us try the layer on a real-world dataset. For this part we will use the Quora Question Pairs dataset, which consists of over 400,000 potential duplicate question pairs labeled for semantic equivalence.

Building and training a sentence-pair classification model on top of this layer is straightforward (see the Colab notebook for the full code).

BTW, if you don’t like doing the preprocessing in-graph, you can disable it by setting do_preprocessing=False and build the model with three inputs (input_ids, input_mask and segment_ids) instead.

Fine-tuning just the last three layers yields 88.3% validation accuracy after five epochs:

Train on 323432 samples, validate on 80858 samples
Epoch 1/5
323432/323432 [==============================] - 3197s 10ms/sample - loss: 0.3659 - acc: 0.8255 - val_loss: 0.3198 - val_acc: 0.8551
Epoch 2/5
323432/323432 [==============================] - 3191s 10ms/sample - loss: 0.2898 - acc: 0.8704 - val_loss: 0.2896 - val_acc: 0.8723
Epoch 3/5
323432/323432 [==============================] - 3231s 10ms/sample - loss: 0.2480 - acc: 0.8920 - val_loss: 0.2833 - val_acc: 0.8765
Epoch 4/5
323432/323432 [==============================] - 3205s 10ms/sample - loss: 0.2083 - acc: 0.9123 - val_loss: 0.2839 - val_acc: 0.8814
Epoch 5/5
323432/323432 [==============================] - 3244s 10ms/sample - loss: 0.1671 - acc: 0.9325 - val_loss: 0.2957 - val_acc: 0.8831

Step 7: saving and restoring the model

The model weights can be saved and restored in the usual way. The model architecture can also be serialized to JSON.

Rebuilding the model from JSON will work, provided that the relative path to the BERT module does not change and the custom layer class is passed to Keras via custom_objects.

Step 8: optimizing for inference

In some cases (e.g. when serving), one might want to optimize the trained model for maximum inference throughput. In TensorFlow this can be achieved by “freezing” the model.

During “freezing”, the model variables are replaced by constants, and the nodes needed only for training are pruned from the computational graph. The resulting graph is more lightweight, needs less RAM and typically runs faster at inference time.

We freeze our trained model and write the serialized graph to file.

Now let’s restore the frozen graph and do some inference.

To run inference, we need handles to the graph's input and output tensors. This part is a little tricky: we retrieve a list of all operations in the restored graph and then manually pick the names of the relevant ops. In this graph the input op comes first and the output op comes last, so it is enough to take the first and the last operation.

To get the Tensor name we append “:0” to the op name.
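The naming convention is easy to capture in a helper. In the real code:
- op_names would come from [op.name for op in graph.get_operations()];
- the resulting strings would be passed to graph.get_tensor_by_name.

io_tensor_names and the op names below are hypothetical. "import/" is the default prefix that tf.import_graph_def adds.

```python
def io_tensor_names(op_names, output_index=0):
    """Turn the first and last op names into tensor names ("<op_name>:<output_index>")."""
    if not op_names:
        raise ValueError("the graph has no operations")
    first, last = op_names[0], op_names[-1]
    return "{}:{}".format(first, output_index), "{}:{}".format(last, output_index)


# Hypothetical op names from a restored frozen graph.
ops = ["import/input_1", "import/bert_layer/PyFunc", "import/dense/Sigmoid"]
input_name, output_name = io_tensor_names(ops)
print(input_name, output_name)
```

The ":0" suffix selects an op's first output. Ops with several outputs expose the others as ":1", ":2" and so on.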

The preprocessing function we injected into the Keras layer is not serializable, so it was not restored in the new graph. No worries though: we can simply define it again under the same name.

Finally, we get the predictions.

array([[9.8404515e-01]], dtype=float32)

Conclusion

In this experiment we created a trainable BERT module and fine-tuned it with Keras to solve a sentence-pair classification task.

By freezing the trained model, we have removed its dependency on the custom layer code and made it portable and lightweight.

Other guides in this series

  1. Pre-training BERT from scratch with cloud TPU
  2. Building a Search Engine with BERT and Tensorflow
  3. Fine-tuning BERT with Keras and tf.Module [you are here]
  4. Improving sentence embeddings with BERT and Representation Learning
