
To Know Is Also to Remember

Understanding Long Short-Term Memory Networks

Image generated by the author using MidJourney

A man and a woman talk inside a quiet room in a clinical research center. The woman asks questions and then waits for the man to answer while taking notes. It might seem like a normal conversation, yet it is anything but ordinary. Inside the woman’s notebook, the man’s answers are always the same regardless of the date written on each page. Although the conversations take place in the 1980s, the answers refer to events that happened more than 10 years before. Jenni Ogden was one of the first researchers to talk to patient H.M. who, years later, became widely known for his impact on neuropsychology and would finally be called by his real name, Henry Molaison. After years of interviews and tests, researchers concluded that Henry had lost the ability to form new memories due to a lobectomy he underwent when he was 27 years old. Henry’s case helped to better understand the connection between brain function and memory and to coin the concepts of short- and long-term memory. These concepts paved the way for groundbreaking research in the field of Machine Learning, where scientists and developers look to the mysterious inner workings of the brain for better ways to build predictive models.

Introduction

Artificial Neural Networks (ANNs) were inspired by the real neural networks at work in our brains. In reality, ANNs are just an abstraction of how we think real neurons connect and how we explain situations like the one described in the paragraph above. Similar to methods like ant colony optimization, differential evolution and particle swarm optimization, ANNs capture the essence of a real-life process and use it to design the algorithms behind most of today’s AI solutions. The discussion of whether ANNs really learn, or whether what they do should be called intelligence, is broad and ongoing. However, their versatility and power are undeniable. New ANN configurations are developed every day and successfully applied to many problems. Most of these variations still draw inspiration from the behaviour of real neural networks.

Recurrent Neural Networks (RNNs) incorporate a "memory" component to process sequences and, years ago, represented an important approach to Natural Language Processing (NLP). RNNs paved the way toward Long Short-Term Memory (LSTM) networks, which further improved the performance of neural networks in NLP applications. LSTM networks would later be superseded by the transformer model and the Generative Pre-trained Transformer, which is the basis of ChatGPT. This article explores what LSTM networks are and what makes them so special.

Why RNNs?

To understand how LSTM networks work, it is important to have an idea of their purpose. RNNs and LSTM networks share a similar goal: they are used to model and predict data that is stored in sequential form. This means that this type of network reads a sequence of data and then tries to predict the value that comes next. Let’s say that you have a log of the average temperature in a particular city for the previous 30 days and you want to estimate the temperature on day 31. One approach is to correlate the temperature with other variables and estimate the new temperature based on the values of these variables on day 31. RNNs instead consider a number of previous days, which could be 30 or fewer, and forecast the temperature for day 31 based on the previous temperature values. In a way, RNNs try to memorize a sequence and then come up with the next value or group of values. In a previous article, I wrote about how RNNs compare to mnemonists and explained how an RNN works step-by-step.

Recurrent Neural Networks: A Very Special Kind of Mnemonist

The idea of predicting a new temperature based on the temperature of the previous days can be extrapolated to other applications. As mentioned in the introduction, RNNs were one of the first approaches in NLP. The idea is to train the RNN with a text and then use the RNN to predict the word or group of words that come next after an input. This idea can also be applied to other tasks such as automatic translation as well as speech and handwriting recognition. One of the problems that RNNs face when dealing with advanced tasks such as these is that they are prone to vanishing/exploding gradients. A solution to this problem is the application of an LSTM network.

Figure 1 shows the main architectural difference between fully-connected ANNs and RNNs. In this simplified example, the values of (X1,Y1) and (X2,Y2) are used to calculate the new value of Y3. In reality, ANNs and RNNs are trained with many input-output pairs. After the training process is completed, the network is used to predict new values. If you are unfamiliar with this process, I have added some references which I find useful at the end of this article. I have also made my own attempt to explain the whole process here. Although the general idea of training and then predicting is similar between ANNs and RNNs, their structures are considerably different. Note how in an RNN the values of Y1 and Y2 are used to train the network instead of X1 and X2. Also note how there is a weight connecting the first and second unit that multiplies the activation coming from the previous unit. This weight represents the "memory" component of the RNN. A bigger weight means that the RNN gives more importance to previous values, whereas a smaller weight means that the RNN does not remember past values as well. As this is a weight in the neural network structure, its value is learned during training, which means that the RNN can determine whether the sequence it is trying to reproduce needs more or less memory.
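The recurrence just described can be sketched in a few lines of NumPy. This is a minimal scalar illustration, not the article’s actual network: the weight names and values are my own, and w_h plays the role of the "memory" weight from Figure 1.

```python
import numpy as np

def rnn_step(x_t, h_prev, w_x, w_h, b):
    """One recurrent step: the new activation mixes the current input
    with the previous activation through the learned memory weight w_h."""
    return np.tanh(w_x * x_t + w_h * h_prev + b)

# Toy scalar example: a larger w_h makes the network "remember" more
# of the past, a smaller w_h makes it rely mostly on the current input.
h = 0.0
for x in [0.5, 0.8, 0.2]:
    h = rnn_step(x, h, w_x=1.0, w_h=0.5, b=0.0)
```

In a real RNN the weights are learned from many input-output pairs; here they are fixed just to show how each activation feeds into the next step.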

Figure 1. Fully-connected ANNs and RNNs (Image made by the author).

An important aspect of training and applying RNNs is the length of the sequence the network reads to train and forecast. Let’s say that the sequence length is 15. This means that the RNN will be trained to find the 16th value after reading 15 input values (note that this is not always the case, since there are also dynamic RNNs). A longer sequence is beneficial because it allows the RNN to incorporate more readings into its final output. However, longer sequences also lead to vanishing/exploding gradients. How can LSTM networks overcome this problem?

LSTM networks

LSTM networks are trained to decide which information they want to "remember" and which they want to "forget". In an RNN, a single weight applies to the memory component, which is the activation of the previous unit. In an LSTM, the idea of a memory component is replaced by a long-term memory component (cell state) and a short-term memory component (hidden state). Each of these components is associated with a set of biases and weights distributed across different gates. This means that each input goes through different gates (or stages) where the network decides how much of the information stays and how much is discarded. This decision is based on the values of the weights and biases, which are learned during the training process.

Figure 2 shows a very simple sketch of an LSTM network reading a single input. We will analyze this simplified diagram first before looking at a real LSTM network unit. Note how different this is from the previous figure that represented the RNN. The first difference to bear in mind is that, besides the input values, an LSTM network has a short and a long-term memory component. These components are formally known as the cell state (C) and hidden state (h). We will use this notation later. For now, let’s use the previous names. The input and the memory components go through three different gates: forget, input and output.

  • The forget gate is in charge of deciding how much of the long-term memory should be preserved. This gate takes into account the current input as well as the short-term memory component and calculates a value between 0 and 1 that multiplies the long-term memory component. A forget gate of zero means that the network is not preserving old information. Its answers are only based on the new inputs.
  • The input gate decides how much of the new information is preserved in the long-term memory component. The output of this gate is added to the long-term component that is preserved for the next input. Note how this component only connects to the forget and input gate. This means that at each iteration the long-term memory is updated according to what it should discard and what information it should add.
  • The output gate takes into account the input and the short-term memory component and calculates how much of the new long-term memory component will be preserved as the new short-term memory component. This means that the final value that comes out of an LSTM network is calculated using both the long and short-term memory components as well as the output gate.
Figure 2. Simplified LSTM network scheme. (Image made by the author)

Now that we know the structure of an LSTM network cell, we can take a closer look at the calculations that take place there. Figure 3 shows a more complete look at an LSTM network cell. The input, short-term memory component and long-term memory component are represented by the letters x, h and C respectively. Each of these letters carries a subscript that corresponds to the time period being analyzed. A subscript of t-1 means that the value belongs to the previous iteration. For example, the forget gate takes into account the current input (xt) and the short-term memory component of the previous iteration (ht-1). The gates also consider a set of weights and biases. The forget and output gates have 3 parameters each:

  • A bias (bxf, bxo)
  • A weight that multiplies the input (wxf,wxo)
  • A weight that multiplies the short-term memory component (whf, who)

The inputs are multiplied by the weights and summed with the biases before going into the activation function, similar to what happens in any other ANN. In this example, both the forget and output gates have sigmoid activation functions, and their final activations (af, ao) are shown on the right side of Figure 3. Unlike these gates, the input gate has two activation functions, a sigmoid and a tanh, each with its own set of weights and biases. This means that the total number of parameters to train in a single LSTM network unit is 12.

The last important aspect to understand about the LSTM network cell is how the new C and h are calculated. The previous long-term memory component is first multiplied by the output of the forget gate and then added to the result of the input gate. This means that the forget gate decides how much of C is transferred to the next iteration, and the input gate decides how much new information is added to C. For the short-term memory component, the result of the output gate is multiplied by the tanh of the new C. This means the output gate decides how much of the long-term memory component is passed to the next iteration. Since the value of C can lie outside the range [-1, 1], a tanh operation is applied to bound it between -1 and 1.
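These calculations can be written out as a small NumPy sketch. This is a minimal scalar-input version, not the notebook’s implementation: the parameter names (w_xf, b_f, and so on) mirror the article’s notation, and the toy weights below are arbitrary.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, p):
    """One LSTM cell step. p holds the 12 parameters: input weights,
    hidden-state weights and biases for the forget gate, the two
    branches of the input gate (sigmoid and tanh) and the output gate."""
    a_f = sigmoid(p["w_xf"] * x_t + p["w_hf"] * h_prev + p["b_f"])   # forget gate
    a_i = sigmoid(p["w_xi"] * x_t + p["w_hi"] * h_prev + p["b_i"])   # input gate (sigmoid)
    a_g = np.tanh(p["w_xg"] * x_t + p["w_hg"] * h_prev + p["b_g"])   # input gate (tanh)
    a_o = sigmoid(p["w_xo"] * x_t + p["w_ho"] * h_prev + p["b_o"])   # output gate
    c_new = a_f * c_prev + a_i * a_g      # forget part of old C, add new information
    h_new = a_o * np.tanh(c_new)          # output gate scales the bounded C
    return h_new, c_new

# Toy run over a short sequence with all weights 0.5 and biases 0.
params = {k: 0.5 for k in ["w_xf", "w_hf", "w_xi", "w_hi",
                           "w_xg", "w_hg", "w_xo", "w_ho"]}
params.update({k: 0.0 for k in ["b_f", "b_i", "b_g", "b_o"]})
h, c = 0.0, 0.0
for x in [0.2, 0.4, 0.6]:
    h, c = lstm_step(x, h, c, params)
```

Note how C is only touched by the forget and input gates, while h combines the output gate with the tanh-squashed C, exactly as described above.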

Figure 3. Operations inside an LSTM cell. (Image made by the author)

Once the input has gone through all the gates, the new C and h are passed on to the next iteration where they interact with a new input (Figure 4). This process is repeated for all the values in a sequence until we reach the final h which is the network output for that sequence.

Figure 4. Information passes from one cell to another in an LSTM network. (Image made by the author)

Going forward

LSTM networks can be trained using a methodology similar to that of other ANNs. Depending on the problem, we define a loss function which is evaluated after each forward propagation cycle. This loss function is then used to update the weights and biases through a backpropagation process. In the case of RNNs and LSTM networks, the backpropagation is usually called backpropagation through time (BPTT) since it involves accumulating gradients through all the recurrent steps. This is a Jupyter notebook that contains a simple implementation of an LSTM network and a detailed explanation of what goes on during the forward and backpropagation process.

Figure 5 shows an example of a forward propagation pass in an LSTM network cell that reads a single input. Note how the final value of h carries information that passes through all the gates in the network, whereas the final C does not interact with the output gate. Each of these gates works as a regulator of what information is preserved and what is forgotten. Similarly to a fully-connected ANN, where the weights and biases learn the optimal way to interact with the inputs, in an LSTM network the parameters are trained to learn the optimal amount of information to preserve or discard.

Figure 5. Example of forward propagation in an LSTM cell. (Image made by the author)

More units?

So far all the figures and examples have shown a single-unit LSTM network but, as is the case with other types of networks, LSTM networks can have multiple units. How does the number of parameters change with the number of units? In an LSTM network with two units, instead of 12 parameters to learn, we have 32. Where are the 20 additional parameters? Well, things are about to get messy. Let’s break this down in parts.

There are 12 parameters for the first unit. These are the same parameters explained previously: 4 weights that multiply the input, 4 weights that multiply the hidden state (h) and one bias for each gate (4 biases). The second unit has 12 related parameters of its own. This means that so far we have 24 parameters in total, as shown in Figure 6.

Figure 6. Some of the parameters (12) in an LSTM network with two units. (Image made by the author)

Since an LSTM network is a particular kind of RNN, there will be connections between different units. This means that the information processed by Unit 1 will be transferred to Unit 2, and so on. Each of these links carries its own weight. In an LSTM network with two units, besides the 24 parameters previously mentioned, there will be 8 parameters that correspond to the connections between the gates in Unit 2 and each of the gates in Unit 1, as well as the connections between the hidden state of Unit 1 and the gates in Unit 2. Figure 7 shows how to calculate the activation in the forget gate of Unit 2. Note how this gate is now connected to all the previous gates and the previous hidden state. Although the figure shows only 5 new weights, in reality there are eight, because the hidden state of the first unit connects to each gate in the second unit. For n units, the number of parameters is 12n+4n(n-1), which simplifies to 4n(n+2).
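This count can be checked in a couple of lines. The function below uses the standard LSTM parameter count, 4·n·(n_inputs + n + 1), which is also what libraries such as Keras report for a layer of n units, and compares it with the article’s expression for a single scalar input:

```python
def lstm_params(n_units, n_inputs=1):
    """Standard LSTM parameter count: each of the 4 gate blocks has
    n_units*(n_inputs + n_units) weights plus n_units biases."""
    return 4 * n_units * (n_inputs + n_units + 1)

def article_formula(n):
    """The article's expression for n units reading one scalar input."""
    return 12 * n + 4 * n * (n - 1)

# Both expressions agree: 12 parameters for one unit, 32 for two.
assert lstm_params(1) == article_formula(1) == 12
assert lstm_params(2) == article_formula(2) == 32
```

Expanding 12n + 4n(n-1) gives 4n² + 8n = 4n(n+2), the same quadratic growth as the standard count when there is a single input feature.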

Figure 7. The rest of the parameters (8) in an LSTM network with two units. (Image made by the author)

Going backwards

As is common with ANNs, the backpropagation process is usually the hardest to understand and implement. It is important to take into account that, in a single-unit LSTM network, each backpropagation pass needs to update the 4 biases and 8 weights and to propagate gradients to h(t-1) and C(t-1). Note how h depends on the output gate as well as on the current C, which in turn depends on the input and forget gates. It is important to take this into account when calculating the partial derivatives with respect to the loss. As mentioned previously, this Jupyter notebook contains all the equations you need to build a simple LSTM network, including the backpropagation process. Figure 8 shows the dependencies among the parameters that can guide you in calculating each partial derivative.

Figure 8. Dependencies in the calculation of the activation function in an LSTM network. (Image made by the author)

Applications

The following sections contain three examples of the application of an LSTM network, presented in order of increasing complexity. You can find Python code for each in this Jupyter Notebook.

Vanilla LSTM network to model a continuous function

This is a very simple example of how to implement an LSTM network from scratch. An important lesson here is the importance of preparing the data correctly before feeding it to the LSTM network. If we have a continuous function that we want to model with an LSTM network, we first need to create pairs of input-target data as shown in Figure 1. This data depends on the sequence length. For a sequence length of 15, for example, each input entry will contain 15 values and the 16th value will be the target for that input. It is also important to normalize the data before it goes into the network. In this example, a simple LSTM network is used to model a sin(x) function as well as the oil production behaviour of a well.
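As a sketch of this preparation step (the function and variable names below are my own, not the notebook’s), the windowing and normalization for a sin(x) series could look like this:

```python
import numpy as np

def make_windows(series, seq_len):
    """Split a 1-D series into (input window, next-value target) pairs."""
    X, y = [], []
    for i in range(len(series) - seq_len):
        X.append(series[i : i + seq_len])   # 15 consecutive values
        y.append(series[i + seq_len])       # the 16th value is the target
    return np.array(X), np.array(y)

# Normalize to [0, 1] before training, as recommended above.
series = np.sin(np.linspace(0, 4 * np.pi, 200))
series = (series - series.min()) / (series.max() - series.min())
X, y = make_windows(series, seq_len=15)
# X has shape (185, 15): 200 points yield 185 overlapping windows.
```

The same windowing applies to any sequential data, such as the oil production series, as long as it is normalized with the same statistics used at training time.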

Vanilla LSTM network to build a word predictor

In this example, the previous vanilla LSTM network is applied to a word prediction problem. After being trained on a short text, the model predicts which word comes next. In reality, word processing and word prediction problems are not approached like this. However, it is a simple and clear illustration of a possible application of LSTM networks.

Keras’ LSTM network to build a word predictor

This is a more realistic example in which a Keras LSTM network is used to predict the next word after being trained on a long text. In this example, we use an embedding layer, a common addition in NLP problems. The embedding layer converts integer-encoded representations of words (indexes) into fixed-size dense vectors, which helps the model better capture the connections between words.
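What an embedding layer does can be illustrated with a plain NumPy lookup table. The dimensions here are illustrative, not those of the actual example: a trained Keras embedding layer learns the table’s values, but the lookup itself is just row indexing.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, embed_dim = 1000, 16

# The embedding layer is essentially a trainable lookup table:
# row i is the dense vector for the word with integer index i.
embedding = rng.normal(size=(vocab_size, embed_dim))

token_ids = np.array([4, 27, 4])      # integer-encoded words
vectors = embedding[token_ids]        # dense vectors, shape (3, 16)

# Identical indices map to identical vectors, which is what lets the
# network learn shared representations for repeated words.
assert np.array_equal(vectors[0], vectors[2])
```

During training, gradients flow back into the rows of the table, so words that play similar roles end up with similar vectors.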

Conclusion

Using artificial neural networks, and in particular recurrent neural networks, to approach natural language processing problems or to model sequential data is not a recent development. This application has been around for a long time and is still being improved with new features. Understanding how an LSTM network works and what makes it a special kind of RNN can help us gain better insight into its results and into the reasons why it might not work as expected. This article contains a comprehensive explanation of LSTM networks and presents three examples of their application. Although most current NLP tools and solutions rely on different network architectures, a solid grasp of LSTM networks’ inner workings will always be beneficial in the world of machine learning. Remember to store this in your long-term memory cell! 😉

References

