
3 neural network architectures you need to know for NLP!

In this article, I will discuss what I think are the three most important architectures to be aware of for NLP.

Recurrent Neural Network

Recurrent Neural Network (RNN). Image from Wikipedia under CC BY-SA 4.0 License.

Recurrent neural networks are special architectures that take into account temporal information. The hidden state of an RNN at time t takes in information from both the input at time t and activations from hidden units at time t-1, to calculate outputs for time t. This can be seen in the image above. This gives the RNN memory, or the ability to remember previous inputs and their outputs.
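As a minimal sketch of that hidden-state update (written in PyTorch here; the class and layer names are illustrative, not taken from any particular library):

```python
import torch
import torch.nn as nn

class VanillaRNNCell(nn.Module):
    """A single RNN step: h_t = tanh(W_xh · x_t + W_hh · h_(t-1) + b)."""
    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.input_to_hidden = nn.Linear(input_size, hidden_size)
        self.hidden_to_hidden = nn.Linear(hidden_size, hidden_size)

    def forward(self, x_t, h_prev):
        # The new hidden state mixes the current input with the previous hidden state,
        # which is what gives the network its memory of earlier time steps.
        return torch.tanh(self.input_to_hidden(x_t) + self.hidden_to_hidden(h_prev))
```

Running the cell over a sentence just means calling it once per word, feeding each new hidden state back in at the next step.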

This is extremely important for natural language processing, where the input data does not have a fixed size and the next word depends heavily on the previous words. Context matters in NLP: the length of each sentence varies, and so does the length of each output. The ability to take variable-length input and compute variable-length output is therefore highly beneficial, and RNNs are capable of this while also retaining contextual information in their memory.

The many relations of RNNs. Image from Andrej Karpathy’s Blog.

RNN architectures can facilitate multiple types of input and output shapes.

  • A one to one architecture can be used, for example, to predict the next word in a sentence: the input is the current word, the output is the next word, and both input and output are one word long. When this word predictor is chained multiple times, we can generate sentences and even longer text; this is called a language model.
  • An example of one input to many outputs is image captioning: generating a sentence-long caption for a single image requires a one to many model.
  • If we want to predict the sentiment of a sentence or a review, we can use a many to one model: we pass the many words of our review as input and get one output (sentiment: positive or negative). A short code sketch of this setup follows the list.
  • Many to many models can be used for machine translation. If we have a sentence in English and want to convert it to a sentence in Kannada (my mother tongue from Karnataka, India), we want multiple words in English to be converted to multiple words in Kannada; however, the number of words might not be the same, nor in the same order.
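As promised above, here is a rough sketch of the many to one case, a sentiment classifier built on an RNN (PyTorch again; the model and its dimensions are illustrative assumptions, not from any specific paper):

```python
import torch
import torch.nn as nn

class ManyToOneSentiment(nn.Module):
    """Many to one: a variable-length sequence of word indices in, one sentiment label out."""
    def __init__(self, vocab_size: int, embed_dim: int = 64, hidden_size: int = 128):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.RNN(embed_dim, hidden_size, batch_first=True)
        self.classifier = nn.Linear(hidden_size, 2)     # two classes: positive / negative

    def forward(self, token_ids):                       # token_ids: (batch, seq_len)
        embedded = self.embedding(token_ids)            # (batch, seq_len, embed_dim)
        _, last_hidden = self.rnn(embedded)             # last_hidden: (1, batch, hidden_size)
        return self.classifier(last_hidden.squeeze(0))  # (batch, 2) sentiment logits
```

Only the final hidden state is passed to the classifier, which is exactly what makes this a many to one mapping.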

Long short-term memory network (LSTM)

LSTM architecture. Image from Wikipedia under CC BY-SA 4.0 License.

LSTM was first introduced by Hochreiter and Schmidhuber in their 1997 paper "Long Short-Term Memory". You can read about its history [here](https://en.wikipedia.org/wiki/Long_short-term_memory#History). The original paper’s architecture was different from what is popular today, and after many iterations (which you can read about there), we got the version that is popular now.

Fall of the RNN

But why did we need this complicated architecture? Why were RNNs not enough?

RNNs had one glaring downside that made it nearly impossible to train larger versions: they were extremely prone to exploding and vanishing gradients. I have discussed this issue in other articles (here), and you can find other resources on this problem. But very briefly, the exploding and vanishing gradient problem is caused by repeated multiplication (in our case, repeated multiplication by the same weight matrix at every time step). Numbers just greater than 1 grow without bound when multiplied by themselves, and numbers smaller than 1 shrink towards zero (you can test this by calculating 1.01¹⁰⁰ in Python).
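The parenthetical above is easy to verify directly in Python; the exact numbers are only there to illustrate the trend:

```python
# Repeated multiplication by the same factor either explodes or vanishes.
print(1.01 ** 100)    # ~2.7      -> already noticeably larger than 1
print(1.01 ** 1000)   # ~20959    -> "explodes" as the sequence gets longer
print(0.99 ** 1000)   # ~4.3e-05  -> "vanishes" towards zero
```

In an unrolled RNN the role of this single factor is played by the recurrent weight matrix, so the same effect shows up in the gradients.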

In an RNN, the same weight matrix is multiplied with the inputs and previous hidden states at every step, and hence the gradients explode or vanish. Also, the path for gradient flow from the last time step of the RNN back to the first is very long. This means the amount of contextual information the RNN can maintain in memory before gradients explode or vanish is small. Hence the memory of vanilla RNNs is short, and the size of the reference window (the number of words before the current word from which the RNN can draw contextual information) is small.

Rise of LSTM

LSTM addresses these issues of the RNN. It does this by maintaining a cell state, which captures the state of the network at any given time. This cell state is updated with relevant information at each time step. The output at each time step is derived from the input, the previous output, and the updated cell state.

The cell state is updated using three gates: the input gate, the forget gate, and the candidate gate (also sometimes called the gate gate). The output is calculated from the input passed through the output gate, combined with the updated cell state. Each gate is simply a sigmoid or tanh of the input and the previous output, followed by an element-wise product or sum. The gates and their processes have been explained too many times, so I will not reinvent the wheel here. But if you want an amazing resource, you can find it here. Also, I will be happy to help clarify doubts in the comment section.
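For readers who prefer equations as code, here is a rough sketch of one LSTM step following the standard formulation (the weight dictionaries W, U, b are illustrative, not any library's actual internals):

```python
import torch

def lstm_step(x_t, h_prev, c_prev, W, U, b):
    """One LSTM time step; W, U, b hold the weights and biases for gates 'i', 'f', 'o', 'g'."""
    i = torch.sigmoid(x_t @ W['i'] + h_prev @ U['i'] + b['i'])  # input gate
    f = torch.sigmoid(x_t @ W['f'] + h_prev @ U['f'] + b['f'])  # forget gate
    o = torch.sigmoid(x_t @ W['o'] + h_prev @ U['o'] + b['o'])  # output gate
    g = torch.tanh(x_t @ W['g'] + h_prev @ U['g'] + b['g'])     # candidate ("gate gate")
    c_t = f * c_prev + i * g      # cell state update: forget some old information, add some new
    h_t = o * torch.tanh(c_t)     # output: output gate applied to the updated cell state
    return h_t, c_t
```

Note that the cell state update `c_t = f * c_prev + i * g` is mostly additive, which is what keeps the gradient path through the cell state short and well behaved.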

The main advantage of using a cell state is that it can be maintained and updated with minimal computation and processing, and it creates a gradient highway for gradients to backpropagate through, hence avoiding the vanishing and exploding gradient problem.

Transformer Network

RNNs had the issue of vanishing and exploding gradients, so their memory was severely short term. Think of it this way: the more vanilla RNN steps you chained together, the more likely the gradients were to explode or vanish, and the outputs did not get better. This means the long-term memory of RNNs was very small.

With the introduction of LSTMs and their ability to remember the cell state, long-term memory improved; however, it was still limited. Also, LSTMs had to be fed input data sequentially, as they needed the outputs from the previous time step to calculate the current outputs. This did not make full use of modern GPUs. LSTMs were also significantly slower than RNNs.

Transformer architecture. Image from the original paper.

The Transformer network was introduced to address these issues. It was introduced in 2017 by Ashish Vaswani et al. in their paper "Attention Is All You Need", which you can read here. They used an encoder-decoder network similar to RNN-based models, with a key difference: they introduced attention blocks.

Attention blocks calculate how each word in the input relates to every other word in the input. The higher the value, the more attention is paid to those words, and the more dependent those words are on each other. However, as most words have the highest relation to themselves, a single attention block might not be very helpful. This is why a multi-head attention block is used: attention is calculated multiple times in parallel (the heads), and the results are concatenated and combined through a learned projection to produce the overall attention output.
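At the core of each head is scaled dot-product attention, which can be sketched in a few lines (following the formula in the paper; the function name and shapes here are illustrative):

```python
import math
import torch

def scaled_dot_product_attention(Q, K, V):
    """Attention(Q, K, V) = softmax(Q Kᵀ / sqrt(d_k)) V, as defined in the Transformer paper."""
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)  # how strongly each word attends to every other word
    weights = torch.softmax(scores, dim=-1)            # each row is a probability distribution over words
    return weights @ V                                 # weighted sum of the value vectors
```

Because Q, K, and V contain the whole sequence at once, this is a handful of matrix multiplications rather than a step-by-step loop, which is what makes the parallelism described below possible.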

Attention increases the number of contextual connections the network can make, and the network can learn relations and context from large datasets. Also, since the attention for each word does not depend on the outputs of previous time steps, the attention for all words can be calculated and further processed in parallel, greatly reducing the computation time of these networks.

The details of the attention block and the nuances of the Transformer network are beyond the scope of this article. If you are interested in learning the details of the Transformer network, the original paper and this TDS article are very good resources.

Transformers have garnered a lot of attention (pun intended) in recent years thanks to two amazing models, BERT (by Google) and GPT (by OpenAI). I recommend further reading on these topics if you find Transformers interesting.

