How do we deal with limited amounts of labeled training data in NLP? Semi-supervised learning to the rescue!

Semi-Supervised Learning is an actively researched field in the machine learning community. It is typically used to improve the generalizability of a supervised learning problem (i.e. training a model based on provided inputs and a ground-truth or actual output value per observation) by leveraging high volumes of unlabeled data (i.e. observations for which inputs or features are available but a ground-truth or actual output value is not known).
Experiment Setup
We have the following two different datasets (in increasing order of complexity from a classification perspective) at our disposal:
IMDB reviews data: Hosted by Stanford, this is a sentiment (binary: positive and negative) classification dataset of movie reviews. Please refer here for more details.
20Newsgroup dataset: This is a multi-class classification dataset where every observation is a news article and it is labeled with one news topic (politics, sports, religion and so on). This data is available through the open source library scikit-learn too. Further details on this dataset can be read here.
We have performed experiments to see how these datasets perform with unlabelled data at low, medium and high volumes of labeled data. We have been introduced to sklearn’s SelfTrainingClassifier, which works as a wrapper to enable self-training of any sklearn-based algorithm. We have seen a comparison of pre- and post-self-training results for a simple logistic regression model and an ANN (Artificial Neural Network) model.
In Part 1 of this series, some foundational methods used in this field are thoroughly discussed, followed by the more involved Part 2, where you can get a sense of how to apply these methods in practice.
In this part, we will go through some state-of-the-art Semi-Supervised Learning (SSL) methods and how they have performed in our experiments. We start by looking at how an advanced NLP model – like an LSTM/CNN (yes, you read that right: a CNN can be a good NLP model; more details on that in a later section) or a state-of-the-art (SOTA) model like BERT – performs in the SSL setting. We then move on to more recent methods in this space – UDA and MixText – which employ some novel ideas to crack the SSL problem. Finally, we conclude with a detailed comparison of these algorithms and further directions to explore in this space.
Let’s dive right in!
LSTM
We have seen LSTMs disrupting the NLP space since their inception, so we thought using an LSTM would be the next logical step. This article does a good job of providing a brief history of LSTMs and RNNs and how they are used.
Since an LSTM is a sequential model, TF-IDF features would be of little help. So, we have decided to use GloVe embeddings, which can be given to the model as a sequential input. We have used a fairly simple architecture for the LSTM model: it consists of three LSTM layers, each followed by a dropout layer, and finally a dense layer followed by the output layer.
The following code shows what the architecture looks like:
Note that all of the code in this series is available on this GitHub page.
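As a rough, self-contained sketch of such an architecture (our own reconstruction – the function name `build_lstm_model` and the `vocab_size`, `embedding_dim`, `max_len` and layer sizes below are illustrative assumptions, not our exact configuration):

```python
from tensorflow.keras import layers, models

# Hypothetical sizes – the real values depend on the tokenizer and the GloVe file used.
vocab_size, embedding_dim, max_len, num_classes = 20000, 100, 200, 2

def build_lstm_model(embedding_matrix):
    """Three stacked LSTM layers, each followed by dropout, then a dense head."""
    model = models.Sequential([
        layers.Embedding(vocab_size, embedding_dim, input_length=max_len,
                         weights=[embedding_matrix], trainable=False),  # frozen GloVe vectors
        layers.LSTM(128, return_sequences=True),
        layers.Dropout(0.3),
        layers.LSTM(64, return_sequences=True),
        layers.Dropout(0.3),
        layers.LSTM(32),
        layers.Dropout(0.3),
        layers.Dense(64, activation="relu"),
        layers.Dense(num_classes, activation="softmax"),
    ])
    model.compile(optimizer="adam",
                  loss="sparse_categorical_crossentropy",
                  metrics=["accuracy"])
    return model
```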
The official [sklearn.semi_supervised](https://scikit-learn.org/stable/modules/classes.html#module-sklearn.semi_supervised).SelfTrainingClassifier wouldn’t work directly with TensorFlow models that contain a Functional module, since it expects a scikit-learn-style estimator. In such scenarios, we wrap the Keras model in the KerasClassifier from scikeras.wrappers, which exposes the scikit-learn interface that the SelfTrainingClassifier expects and supports the Functional modules of a TensorFlow model.
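For instance, the wrapping could look like the sketch below, reusing the hypothetical `build_lstm_model` and `embedding_matrix` from the previous snippet; `X_train_padded` / `y_train_semi` are placeholder names for the padded sequences and the label vector in which unlabeled rows are marked with -1:

```python
from scikeras.wrappers import KerasClassifier
from sklearn.semi_supervised import SelfTrainingClassifier

# Wrap the Keras model so it exposes the scikit-learn estimator interface.
keras_clf = KerasClassifier(
    model=lambda: build_lstm_model(embedding_matrix),
    epochs=5, batch_size=64, verbose=0,
)

# sklearn's self-training expects unlabeled samples to carry the label -1.
self_training_clf = SelfTrainingClassifier(keras_clf, threshold=0.8)
self_training_clf.fit(X_train_padded, y_train_semi)
preds = self_training_clf.predict(X_test_padded)
```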
We are not including the LSTM results at this point because, even with a relatively shallow architecture, it took considerably longer to run and its performance was not on par with the CNN model – which is what we are going to see in the next section.
CNN
Now, what is a model like a CNN, which is predominantly used in Computer Vision, doing in our NLP experiments? People have tried using CNNs on text and have seen great success. See this paper for more details about one such work.
Coming back to our problem, we thought it would be worth trying out a CNN for two reasons:
- LSTMs are relatively slow to execute compared to CNNs.
- CNNs perform almost as well as LSTMs, if not better, in many use-cases.
We have used the same GloVe embeddings and trained our CNN. See this comprehensive article for more details on how to apply CNNs to text data.
Here is a code snippet which would help you understand our CNN architecture:
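The sketch below is a representative text-CNN in the spirit of Kim-style sentence classification, reusing the hypothetical `vocab_size`, `embedding_dim`, `max_len` and `num_classes` from the LSTM sketch; the filter count and kernel size are illustrative rather than our exact configuration:

```python
from tensorflow.keras import layers, models

def build_cnn_model(embedding_matrix):
    """1D-convolutional text classifier on top of frozen GloVe embeddings."""
    model = models.Sequential([
        layers.Embedding(vocab_size, embedding_dim, input_length=max_len,
                         weights=[embedding_matrix], trainable=False),
        layers.Conv1D(128, kernel_size=5, activation="relu"),  # n-gram-style filters
        layers.GlobalMaxPooling1D(),
        layers.Dense(64, activation="relu"),
        layers.Dropout(0.3),
        layers.Dense(num_classes, activation="softmax"),
    ])
    model.compile(optimizer="adam",
                  loss="sparse_categorical_crossentropy",
                  metrics=["accuracy"])
    return model
```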
We have used the same KerasClassifier from scikeras.wrappers that we have used in our LSTM experiments.
We have seen that the CNN’s performance is at least as good as the LSTM’s and sometimes even better! The runtime is also significantly lower.
Here are the results:


BERT
Any NLP model-building exercise these days wouldn’t be complete without trying out BERT. BERT is a pre-trained language model which was introduced in 2018 by Google and has disrupted the NLP space ever since. It uses a stack of Transformer encoders and is trained on a large corpus of data using a novel technique called Masked Language Modeling (MLM). If any of that doesn’t sound familiar, you can go through this article to get a deeper understanding of the ideas and mechanics behind BERT.
After BERT, there have been many variants of pre-trained models with different architectures and trained using different techniques. All in all, pre-trained models like BERT carry a plethora of information and knowledge about the corpora that they are trained on. The only downside to these models is that they do require a huge amount of computation power in the form of GPUs.
We need not use any word embeddings separately for BERT, as it produces its own BERT embeddings, which are superior to GloVe, word2vec and other embeddings in the sense that they are contextual embeddings – i.e. the same word can have different embeddings based on the context in which it appears. For example:
- Apple is my favourite fruit.
- I bought an Apple iPhone.
The word ‘Apple’ in the above two sentences would have different BERT embeddings as it appears in different contexts.
Since BERT is an already pre-trained model, we do not train it from scratch again, but we do fine-tune it. Fine-tuning BERT is a supervised step which involves passing a few labeled training examples, after which it’ll be ready to give out labels for our test data.
[For the purposes of this article, we use the words training and fine-tuning interchangeably, in the context of BERT.]
For our training/fine-tuning purposes, we have used a very user-friendly BERT API called sklearn-bert. This is a no-strings-attached API which we can use just like any other sklearn model; it can be installed by following the instructions on its GitHub page.
The below snippet of code shows how sklearn-bert can be trained:
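A minimal sketch of the fit/predict calls, assuming the bert-sklearn package’s BertClassifier with its standard constructor arguments; `X_train`/`y_train`/`X_test` are placeholders for the labeled texts and their labels, and the number of epochs is illustrative:

```python
# Assumes sklearn-bert / bert-sklearn has been installed from its GitHub repository.
from bert_sklearn import BertClassifier

model = BertClassifier(bert_model='bert-base-uncased',
                       max_seq_length=128,   # capped due to compute constraints
                       train_batch_size=16,
                       epochs=3)             # illustrative value

model.fit(X_train, y_train)           # fine-tunes BERT on the labeled texts
y_pred = model.predict(X_test)        # predicted labels
y_prob = model.predict_proba(X_test)  # class probabilities, used later for self-training
```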
Due to our computational restrictions, we limited our sequence length and batch size to 128 and 16 respectively. We have used the ‘bert-base-uncased’ model, which can handle a maximum sequence length of 512. We can just call model.fit() and voilà – we’re done with the fine-tuning. We can then call model.predict() to get the predicted labels.
The advantage of this kind of API is that we need not manually do any tokenization or add a classification layer on top of BERT – which we would have to do if we fine-tuned BERT using PyTorch or TensorFlow directly.
Now, how do we do SSL on this BERT model? Since this is an sklearn-style model, we figured we could reuse the same SelfTrainingClassifier setup with the scikeras wrapper that we have used previously. But unfortunately, that didn’t work. We think that it is because BertClassifier does not implement all of the sklearn methods that this setup needs. So, we had to implement the SSL logic ourselves – which was fairly straightforward. Have a look at the below code snippet to understand our BERT SSL implementation:
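The full implementation is on the GitHub page; the snippet below is a simplified sketch of the loop, with variable names (X_u, new_X, k_best, num_iter, n_to_select) matching the walkthrough that follows – the `self_train` function and its defaults are our own placeholders:

```python
import numpy as np

def self_train(model, X_l, y_l, X_u, k_best=100):
    """Simplified self-training: iteratively pseudo-label the k_best most confident
    unlabeled examples and fold them into the training set."""
    new_X, new_y = list(X_l), list(y_l)
    X_u = list(X_u)
    num_iter = int(np.ceil(len(X_u) / k_best))   # at most |X_u| / k_best iterations

    for _ in range(num_iter):
        model.fit(new_X, new_y)                  # re-fit on labeled + pseudo-labeled data
        preds = model.predict(X_u)               # pseudo-labels for the unlabeled pool
        confidence = model.predict_proba(X_u).max(axis=1)

        n_to_select = min(k_best, len(X_u))      # smaller of k_best and what is left
        top_idx = set(np.argsort(-confidence)[:n_to_select])

        # Move the most confident examples from the unlabeled pool into the training data.
        new_X += [x for i, x in enumerate(X_u) if i in top_idx]
        new_y += [p for i, p in enumerate(preds) if i in top_idx]
        X_u = [x for i, x in enumerate(X_u) if i not in top_idx]
        if not X_u:
            break

    model.fit(new_X, new_y)                      # final fit on the augmented training set
    return model
```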
Let’s quickly go over how we implemented the SSL algorithm here.
X_u represents the unlabelled dataset. The number of iterations, num_iter, is at most the ratio of the size of X_u to the k_best value.
In each iteration, we do a model.fit() and then predict with the model, taking the class with the maximum probability as the model’s output for each unlabelled example.
We then select the top predictions based on n_to_select, which is the smaller of the k_best value and the current size of X_u, and concatenate them with the already present new_X dataset. In the final step, we add whatever is left (which is done in the `if not selected_update == 'yes':` clause), irrespective of the k_best value, to the data.
Here are the results we obtained using SSL with BERT:


We couldn’t use BERT for higher amounts of labeled data on the NewsGroup dataset due to memory and compute resource constraints.
Unsupervised Data Augmentation (UDA)
Everything that we have done up to this point only involved changing the models and trying out different architectures in order to improve the performance. But now, we fundamentally change the way we deal with this problem.
Instead of plain self-training, UDA uses data augmentation and consistency training to further improve upon the existing performance. We will now briefly discuss how UDA works, specifically in the context of NLP.
Unsupervised Data Augmentation tries to effectively increase the amount of labeled data by using a set of data augmentations – like back translation, word replacement, etc.
Back Translations
Back translation translates an English sentence into a different language – say, Russian, German or Chinese – and then translates it back into English, so that we get a rephrased sentence with the same meaning and hence the same label. These back translations can be done by using any machine translation model as a service.
Sentence in English: S; rephrased sentence: S″.
S′ = EnglishToGerman(S);
S″ = GermanToEnglish(S′).
We can control the quality of back translations and get different kinds of rephrased sentences with slightly different structure and wording but with a similar meaning. We can see three different back translations for a sentence in the below image.
![An example of a back translation. [Image Source]](https://towardsdatascience.com/wp-content/uploads/2022/05/1t_rNNLHcr5pB2ZZrdbRRjg.jpeg)
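As a concrete illustration (not necessarily the translation service used in the UDA paper), a round trip through publicly available English↔German MarianMT checkpoints on Hugging Face could look like this:

```python
from transformers import MarianMTModel, MarianTokenizer

def back_translate(sentence,
                   en_to_de="Helsinki-NLP/opus-mt-en-de",
                   de_to_en="Helsinki-NLP/opus-mt-de-en"):
    """English -> German -> English round trip that yields a paraphrase."""
    def translate(text, model_name):
        # Loaded per call to keep the sketch short; cache the models in real use.
        tokenizer = MarianTokenizer.from_pretrained(model_name)
        model = MarianMTModel.from_pretrained(model_name)
        batch = tokenizer([text], return_tensors="pt", padding=True)
        generated = model.generate(**batch)
        return tokenizer.decode(generated[0], skip_special_tokens=True)

    s_prime = translate(sentence, en_to_de)   # S' = EnglishToGerman(S)
    return translate(s_prime, de_to_en)       # S" = GermanToEnglish(S')

print(back_translate("The movie was surprisingly good."))
```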
Word Replacement
Word replacement is another data augmentation method: we replace words that have low TF-IDF scores with words sampled from the whole vocabulary, while the informative, high-TF-IDF keywords are kept intact. This helps generate examples that are both diverse and valid.
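A rough sketch of this idea (our own simplified version, not the exact UDA implementation – in particular, the `replace_fraction` parameter and the whitespace tokenization are simplifying assumptions):

```python
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer

def replace_low_tfidf_words(sentences, replace_fraction=0.2, seed=0):
    """Replace the lowest-TF-IDF words of each sentence with random vocabulary words."""
    rng = np.random.default_rng(seed)
    vectorizer = TfidfVectorizer()
    tfidf = vectorizer.fit_transform(sentences)
    vocab = vectorizer.get_feature_names_out()

    augmented = []
    for row, sent in enumerate(sentences):
        tokens = sent.split()
        # Score each token by its TF-IDF weight in this sentence; unknown tokens are kept.
        cols = [vectorizer.vocabulary_.get(tok.lower()) for tok in tokens]
        scores = [tfidf[row, c] if c is not None else np.inf for c in cols]

        n_replace = max(1, int(replace_fraction * len(tokens)))
        for i in np.argsort(scores)[:n_replace]:     # least informative positions
            tokens[i] = rng.choice(vocab)            # sample from the whole vocabulary
        augmented.append(" ".join(tokens))
    return augmented

print(replace_low_tfidf_words(["this movie was a truly wonderful experience",
                               "the plot of this film made no sense at all"]))
```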
UDA uses both labeled data and the augmented unlabelled data for training. We can use any general NLP model for performing UDA, just by changing the loss function slightly. Here is an image which shows how the loss function is tweaked:
![Overview of UDA [Image source]](https://towardsdatascience.com/wp-content/uploads/2022/05/0K-yoV5wwG_EhTVoz.jpg)
The loss function has two components: the supervised cross-entropy loss for labeled data, which is the traditional one, and an unsupervised consistency loss for unlabelled data.
Intuitively, the unsupervised consistency loss does a very simple thing: it forces the model to produce the same prediction for an unlabelled example and its augmented versions. This seems reasonable, as the augmentations give us rephrased sentences with the same meaning and hence the same label.
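Schematically, the objective can be written as in the PyTorch-style sketch below (our own simplification; the real UDA implementation adds refinements such as prediction sharpening, confidence masking and training-signal annealing):

```python
import torch
import torch.nn.functional as F

def uda_loss(model, x_labeled, y_labeled, x_unlabeled, x_augmented, lam=1.0):
    """Supervised cross-entropy + unsupervised consistency (KL) loss."""
    # 1) Standard supervised loss on the labeled batch.
    sup_loss = F.cross_entropy(model(x_labeled), y_labeled)

    # 2) Consistency loss: the prediction on an augmented example should match
    #    the (fixed, no-gradient) prediction on the original unlabeled example.
    with torch.no_grad():
        target = F.softmax(model(x_unlabeled), dim=-1)
    log_pred_aug = F.log_softmax(model(x_augmented), dim=-1)
    consistency_loss = F.kl_div(log_pred_aug, target, reduction="batchmean")

    return sup_loss + lam * consistency_loss
```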
There is a slightly updated variation of UDA called MixText, and this is what we have implemented. Let’s look at it now.
MixText
MixText initially does exactly what we discussed in UDA – it takes in labeled and unlabelled data, performs augmentations and predicts labels for the unlabelled data. But before applying the cross-entropy and consistency losses, it performs another augmentation called TMix.
TMix
TMix provides a general approach to augmenting text data and hence can be applied independently to any downstream task. TMix performs interpolation in the textual hidden space: it has been shown that decoding from an interpolation of two hidden vectors generates a new sentence with the mixed meaning of the two original sentences.
For an encoder with L layers, we choose to mix up the hidden representations at the m-th layer, where m ∈ [0, L]. Simply put, TMix takes in two text samples x and x′ with labels y and y′, mixes their hidden state representations h and h′ at layer m into h̃, and then continues the forward pass to predict the mixed label ỹ. Mixing the hidden states (with a parameter λ) happens like this:
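This is the standard mixup interpolation used by TMix, with λ typically sampled from a Beta(α, α) distribution:

h̃ = λ · h + (1 − λ) · h′
ỹ = λ · y + (1 − λ) · y′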
Experimentally, some specific BERT layers – {3, 4, 6, 7, 9, 12} – have been shown to have the most representational power. Based on this, the authors of MixText experimented with a few subsets of these layers; {7, 9, 12} gave the best results, and hence this is what we have used.
Running MixText required us to follow the instructions on its GitHub page. We had to slightly modify the code in the code/read_data.py file so that it reads our format of the IMDB/NewsGroup data.
Here are the results:

Final Comparison
We can see that both BERT and CNN perform reasonably well in the SSL setting. As there is no clear winner, we can test and choose either of them depending on the task at hand.
At the sequence length that we have fixed, we couldn’t observe any significant difference between MixText and BERT+SSL. The stronger results reported by the authors might be due to higher sequence lengths and more efficient, diverse augmentations, enabled by better compute resources.
The three-part series, including this article, has been a joint effort between Naveen Rathani, an applied machine learning specialist and data science enthusiast with a master’s degree from BITS Pilani (India), and Sreepada Abhinivesh, a passionate NLP data scientist.