
Multi-class Text Classification using BERT and TensorFlow

A step-by-step tutorial from data loading to prediction

Photo by Alfons Morales on Unsplash

Table of contents

  1. Introduction
  2. Data preparation
     2.1 Load the dataset
     2.2 [Optional] Observe random samples
     2.3 Split into train and test sets

  3. Data modeling
     3.1 Load BERT with TensorFlow Hub
     3.2 [Optional] Observe semantic textual similarities
     3.3 Create and train the classification model
     3.4 Predict
     3.5 Blind set evaluation

  4. [Optional] Save and load the model for future use
  5. References

1. Introduction

In this post, we will develop a multi-class text classifier.

The task of classification refers to predicting a class for a given observation. Therefore, the only input needed to train such a model is a dataset composed of:

  • Text samples
  • Associated labels

For the purpose of this post, we need to know that BERT¹ (Bidirectional Encoder Representations from Transformers) is a Machine Learning model based on transformers², i.e. attention components able to learn contextual relations between words. More details are available in the referenced papers.

Interestingly, we will develop a classifier for non-English text, and we will show how to handle different languages by importing different BERT models from TensorFlow Hub.

2. Data preparation

2.1 Load the dataset

In a previous post³, we analyzed the readers’ comments open dataset from Rome’s libraries, made publicly available by "Istituzione Biblioteche di Roma"⁴. In the absence of labels, we used topic modeling (an unsupervised technique) to find recurring themes among readers’ comments, and thus determine, by inference, the subjects of the borrowed books and the interests of the readers.

We will now use the topics that emerged from the previous analysis as labels to classify users’ comments.

Notably, the steps that follow may be applied to any dataset containing at least two columns, for text samples and their labels, respectively:

Image by author.

As the dataset contains users’ comments from the libraries of Rome, their language is Italian. We randomly sampled five topics from the previous analysis³, each corresponding to a label. Their distribution is as follows:

  1. Reviews about the condition of women in society, or novels with strong female protagonists (n=205, 25.5%)
  2. Reviews of albums and concerts, or biographies of musicians (n=182, 22.64%)
  3. Reviews of books and essays about economics and socio-political conditions (n=161, 20.02%)
  4. Reviews related to Japan or the Japanese culture (n=134, 16.67%)
  5. Reviews of popular-science and technical essays (n=122, 15.17%)

    Image by author.
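The original snippet is not reproduced here; the distribution above can be computed with a minimal sketch like the following, where the dataframe name df and the column name Topic are assumptions:

```python
import pandas as pd

# Count the reviews per topic and compute their percentage share.
# "Topic" is an assumed column holding each review's topic description.
counts = df["Topic"].value_counts()
shares = df["Topic"].value_counts(normalize=True).mul(100).round(2)
print(pd.concat([counts, shares], axis=1, keys=["n", "%"]))
```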

For the purpose of classification, we need numeric labels. Therefore, we map the topic descriptions to integers as follows:

Image by author.
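The article’s gist is not shown here; a minimal sketch of such a mapping, assuming the dataframe is called df and the columns are named Topic and Labels (hypothetical names), could look like this:

```python
# Map each topic description to an integer label.
# Note: the article assigns specific ids (e.g. 1 = women's condition, 3 = Japan);
# here we simply enumerate the topics for illustration.
topic_to_id = {topic: idx for idx, topic in enumerate(sorted(df["Topic"].unique()))}
df["Labels"] = df["Topic"].map(topic_to_id)
```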

2.2 [Optional] Observe random samples

Although not strictly necessary for text classification, we might want to inspect some random samples from different topics to develop a better understanding of the data.

Therefore we define a function that takes as input a filtering condition on a column, and prints a random review that satisfies the condition, together with an English translation for readability:
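The function, named print_rand_example in the article, might look like the following sketch; the text column name is an assumption, and the English translations shown below are added separately for readability:

```python
def print_rand_example(df, column, value):
    """Print a random review whose `column` matches `value`."""
    subset = df[df[column] == value]
    review = subset.sample(n=1)["Text"].values[0]  # "Text" column name is an assumption
    print(review)
```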

Now, let us observe some samples. We can see some readers’ reviews about women and their condition in society by calling print_rand_example(df, "Labels", 1):

Output of print_rand_example(df, "Labels", 1). Image by author.

When reading the first sentences of the review, the topic description previously derived through the unsupervised approach seems reasonable. We can look at another random sample from the same topic with the same call to print_rand_example(df, "Labels", 1):

Image by author.

What about the Japanese-related topic? Let us find out through print_rand_example(df, "Labels", 3):

Image by author.

2.3 Split into train and test sets
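The split itself is a single call; a minimal sketch with scikit-learn follows, where the 80/20 ratio, the stratification on the labels, and the column name Labels are assumptions:

```python
from sklearn.model_selection import train_test_split

# Hold out a test set, stratifying on the labels to preserve the topic distribution.
train_df, test_df = train_test_split(
    df, test_size=0.2, stratify=df["Labels"], random_state=42
)
```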

3. Data modeling

3.1 Load BERT with TensorFlow Hub

TensorFlow Hub is a repository of trained Machine Learning models⁵.

A data scientist might conveniently load large and complex pre-trained models from TensorFlow Hub and re-use them as needed.

Interestingly, as we search for "bert" on TensorFlow Hub, we may also apply filters such as the problem domain (classification, embeddings, …), architecture, language – and more, to ease the retrieval of the model that better suits our needs:

Search results for "bert" on TensorFlow Hub⁵. Image by author.

In this example, we make use of the universal-sentence-encoder-cmlm/multilingual-base⁶ model, a universal sentence encoder that supports more than 100 languages. It is trained using a conditional masked language model, as described in the reference paper⁷.

What we want to achieve is to turn text into high-dimensional vectors that capture sentence-level semantics. Therefore, we proceed by loading the preprocessor and the encoder layers from the endpoints provided by TensorFlow Hub, and define a simple function to get the embeddings from input text.
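A minimal sketch of this step follows; the base-model handle corresponds to reference [6], while the preprocessor handle version and the helper’s name are assumptions:

```python
import tensorflow as tf
import tensorflow_hub as hub

# Preprocessor and encoder layers from TensorFlow Hub (handle versions may differ).
preprocessor = hub.KerasLayer(
    "https://tfhub.dev/google/universal-sentence-encoder-cmlm/multilingual-preprocess/2"
)
encoder = hub.KerasLayer(
    "https://tfhub.dev/google/universal-sentence-encoder-cmlm/multilingual-base/1"
)

def get_embeddings(sentences):
    """Return sentence-level embeddings for a list of strings."""
    # Depending on the model version, the output key may be "default" instead.
    return encoder(preprocessor(tf.constant(sentences)))["pooled_output"]
```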

The important takeaway is that one may choose to import any model of preference depending on the task and the input language.

As the model is based on the BERT transformer architecture, it will generate a pooled_output (output embedding of the entire sequence) of shape [batch size, 768], as displayed in the following example:

Image by author.

3.2 [Optional] Observe semantic textual similarities

As the embeddings provide a vector representation of sentence-level semantics, we might want to observe the similarities between different text sequences.

To this end, we plot the semantic textual similarity over different text samples, calculated through the cosine similarity⁸, as follows⁹:
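Following the TensorFlow Hub tutorial referenced above⁹, the plotting helper might look like this sketch (the function name and styling are assumptions):

```python
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

def plot_similarity(sentences, embeddings):
    """Heatmap of pairwise cosine similarities between sentence embeddings."""
    embeddings = np.asarray(embeddings)
    normalized = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    similarity = np.inner(normalized, normalized)  # cosine similarity after normalization
    sns.heatmap(similarity, xticklabels=sentences, yticklabels=sentences,
                vmin=0, vmax=1, cmap="YlOrRd")
    plt.title("Semantic textual similarity")
    plt.show()
```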

For example, we expect that two almost identical sentences such as:

  • "Il libro è interessante" (The book is interesting)
  • "Il romanzo è interessante" (The novel is interesting)

will display a high semantic similarity. We also expect both sentences to share a comparable similarity measure with a third sentence that has a different meaning, and this is exactly the case:

Image by author.

Due to the nature of the model, we can effectively estimate the semantic similarity between sentences in different languages:

Image by author.

3.3 Create and train the classification model

As we are facing a multi-class classification problem and we previously noticed that our topic distribution is slightly imbalanced, we might want to observe several metrics during model training.

For this reason, we define functions to calculate, respectively, precision, recall, and F1 score for each class during training, and then return the average value over the classes:

Definition of precision, recall and F1 score. Image modified from Wikipedia¹⁰.
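The article’s implementation is not reproduced here; one way to compute macro-averaged metrics from integer labels and softmax outputs is sketched below (the function names, the integer-label assumption, and NUM_CLASSES are assumptions):

```python
import tensorflow as tf
from tensorflow.keras import backend as K

NUM_CLASSES = 5  # assumption: five topics

def _to_one_hot(y_true, y_pred):
    # y_true: integer labels, y_pred: softmax probabilities (assumptions).
    y_true_oh = tf.one_hot(tf.cast(tf.reshape(y_true, [-1]), tf.int32), NUM_CLASSES)
    y_pred_oh = tf.one_hot(tf.argmax(y_pred, axis=-1), NUM_CLASSES)
    return tf.cast(y_true_oh, tf.float32), tf.cast(y_pred_oh, tf.float32)

def macro_precision(y_true, y_pred):
    yt, yp = _to_one_hot(y_true, y_pred)
    tp = K.sum(yt * yp, axis=0)                       # true positives per class
    return K.mean(tp / (K.sum(yp, axis=0) + K.epsilon()))

def macro_recall(y_true, y_pred):
    yt, yp = _to_one_hot(y_true, y_pred)
    tp = K.sum(yt * yp, axis=0)
    return K.mean(tp / (K.sum(yt, axis=0) + K.epsilon()))

def macro_f1(y_true, y_pred):
    p, r = macro_precision(y_true, y_pred), macro_recall(y_true, y_pred)
    return 2 * p * r / (p + r + K.epsilon())
```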

We now define the model as the preprocessor and encoder layers, followed by a dropout layer and a dense layer with a softmax activation function, whose output dimensionality equals the number of classes we want to predict:
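A sketch of such a model using the Keras functional API, reusing the preprocessor and encoder layers loaded above; the dropout rate and layer names are assumptions:

```python
import tensorflow as tf
from tensorflow.keras import layers, Model

def build_classifier(n_classes=5):
    text_input = layers.Input(shape=(), dtype=tf.string, name="text")
    encoded = encoder(preprocessor(text_input))
    x = layers.Dropout(0.1)(encoded["pooled_output"])  # dropout rate is an assumption
    outputs = layers.Dense(n_classes, activation="softmax", name="classifier")(x)
    return Model(text_input, outputs)

model = build_classifier()
```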

Once we have defined the model’s structure, we can compile and fit it. We choose to train the model for 20 epochs, but we also use the EarlyStopping callback to monitor the validation loss during training: if the metric does not improve for at least 3 epochs (patience = 3), training is interrupted and the weights from the epoch with the best (i.e. lowest) validation loss are restored (restore_best_weights = True):
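A sketch of the compile-and-fit step: the optimizer, the loss, and the column and variable names are assumptions, and the custom metrics are the hypothetical ones defined above:

```python
from tensorflow.keras.callbacks import EarlyStopping

model.compile(
    optimizer="adam",                        # optimizer is an assumption
    loss="sparse_categorical_crossentropy",  # integer labels are an assumption
    metrics=["accuracy", macro_precision, macro_recall, macro_f1],
)

early_stopping = EarlyStopping(monitor="val_loss", patience=3, restore_best_weights=True)

history = model.fit(
    train_df["Text"].values, train_df["Labels"].values,   # column names are assumptions
    validation_data=(test_df["Text"].values, test_df["Labels"].values),
    epochs=20,
    callbacks=[early_stopping],
)
```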

Training logs. Image by author.

We can finally plot the values assumed by each monitored metric during the training procedure, and compare the training and validation curves:
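The plotting code is not reproduced here; a minimal sketch that compares training and validation curves for every metric logged in the Keras history object:

```python
import matplotlib.pyplot as plt

# One subplot per logged metric, training vs. validation.
metrics = [m for m in history.history if not m.startswith("val_")]
fig, axes = plt.subplots(1, len(metrics), figsize=(5 * len(metrics), 4))
for ax, metric in zip(axes, metrics):
    ax.plot(history.history[metric], label="train")
    ax.plot(history.history["val_" + metric], label="validation")
    ax.set_title(metric)
    ax.legend()
plt.show()
```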

Training history. Image by author.

3.4 Predict

The model is now ready, and it is time to test some predictions:

Image by author.

Given these three input sentences, we would expect the model to predict, respectively, the topic ids 3 (Japan/Japanese/..), 1 (Woman/Women/..) and 0 (Economy/Politics/..).

To test our assumption, we define a simple function wrapping model.predict. Since we are classifying among five possible labels, model.predict returns a numpy.ndarray of size five for each input: the softmax activation function used in the last layer of the model provides the discrete probability distribution over the target classes. Therefore, we can simply take the index associated with the highest probability (np.argmax) to infer the predicted label:
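A sketch of such a wrapper; the function name is an assumption, and sample_reviews stands for the three sentences shown in the image above:

```python
import numpy as np

def predict_class(reviews):
    """Return the predicted label id for each input text."""
    probabilities = model.predict(reviews)         # softmax output, shape (n_samples, 5)
    return list(np.argmax(probabilities, axis=1))  # index of the highest probability

predict_class(sample_reviews)  # sample_reviews: the three sentences shown above
```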

The above snippet produces the expected result: [3,1,0].

3.5 Blind set evaluation

We initially split the dataset into train and test sets, but we used both during the training and validation procedure.

To fairly estimate the model’s performance, we evaluate the quality of the predictions on a new dataset containing observations that were not "seen" by the model during training (the blind set):

Test set for blind evaluation. Image by author.

We can observe the topic distribution using the same code as in the data preparation section:

Image by author.

By following the same steps used to prepare the training and validation sets, we can map topic descriptions to numeric labels and inspect some random samples.

For example, let us check a review from topic with id 1: print_rand_example(test_set, "Labels", 1)

Image by author.

Let us also observe a review from topic with id 3: print_rand_example(test_set, "Labels", 3)

Image by author.

We can now evaluate the performance on the blind set:
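A sketch of the evaluation call, assuming the blind set is stored in a dataframe named test_set with Text and Labels columns (hypothetical names):

```python
# Returns the loss, accuracy, and the custom macro metrics defined earlier.
model.evaluate(test_set["Text"].values, test_set["Labels"].values)
```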

Image by author.

4. [Optional] Save and load the model for future use

This task is not essential to the development of a text classification model, but it is still related to the Machine Learning problem, as we might want to save the model and load it as needed for future predictions.

By calling model.save, one may save the model in the SavedModel¹¹ format, which comprises the model architecture, the weights, and the traced TensorFlow subgraphs of the call functions. This enables Keras to restore both built-in layers and custom objects:
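A minimal sketch of the save call, where the target path is an assumption:

```python
# SavedModel format is used by default when no file extension is given.
model.save("reviews_classifier")
```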

We can now load the model as needed for future use:
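A sketch of the load step; the path and the custom metric names refer to the hypothetical code above, and passing compile=False is an alternative if the custom metrics are not needed:

```python
import tensorflow as tf

# Reload the model; custom metric functions must be provided to restore the compile step.
reloaded_model = tf.keras.models.load_model(
    "reviews_classifier",
    custom_objects={
        "macro_precision": macro_precision,
        "macro_recall": macro_recall,
        "macro_f1": macro_f1,
    },
)
```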

That’s it!

5. References

[1] Devlin, Jacob; Chang, Ming-Wei; Lee, Kenton; Toutanova, Kristina, "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding", 2018, arXiv:1810.04805v2

[2] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin, "Attention Is All You Need", 2017, arXiv:1706.03762

[3] https://towardsdatascience.com/romes-libraries-readers-comments-analysis-with-deep-learning-989d72bb680c

[4] https://www.bibliotechediroma.it/it/open-data-commenti-lettori

[5] https://tfhub.dev/

[6] https://tfhub.dev/google/universal-sentence-encoder-cmlm/multilingual-base/1

[7] Ziyi Yang, Yinfei Yang, Daniel Cer, Jax Law, Eric Darve, "Universal Sentence Representations Learning with Conditional Masked Language Model", 2021, arXiv:2012.14388

[8] https://en.wikipedia.org/wiki/Cosine_similarity

[9] https://www.tensorflow.org/hub/tutorials/bert_experts

[10] https://en.wikipedia.org/wiki/Precision_and_recall

[11] https://www.tensorflow.org/guide/saved_model

