
Deep Learning Illustrated, Part 5: Long Short-Term Memory (LSTM)

An illustrated and intuitive guide on the inner workings of an LSTM

Welcome to Part 5 in our illustrated journey through Deep Learning!


Today we’re going to talk about Long Short-Term Memory (LSTM) networks, an upgrade to the regular Recurrent Neural Networks (RNNs) we discussed in the previous article. We saw that RNNs are used to solve sequence-based problems but struggle to retain information over long distances, leading to short-term memory issues. Here’s where LSTMs come in to save the day. They use the same recurrent aspect of RNNs, but with a twist. So let’s see how they achieve this.

Sidenote – this is one of my favorite articles I’ve written, so I can’t wait to take you on this journey!


Let’s first see what was happening in our RNN previously. We had a neural network with an input x, one hidden layer that consists of one neuron with the tanh activation function, and one output neuron with the sigmoid activation function. So the first step of the RNN looks something like this:

Terminology segue: We’re going to call each step a hidden state. So the above is the first hidden state of our RNN.

Here, we first pass our first input, _x_₁, to the hidden neuron to get h₁.

h₁ = first hidden state output

From here we have two options:

(option 1) Pass this h₁​ to the output neuron to get a prediction using just this one input. Mathematically:

ŷ₁ = first hidden state prediction

(option 2) Pass this h₁​ to the next hidden state, by passing this value into the hidden neuron of the next network.

So the second hidden state will look like this:

first and second hidden states

Here we are taking the output from the hidden neuron in the first network and passing it to the hidden neuron in the current network alongside the second input, x₂. Doing so gives us our second hidden state output, h₂.

h₂ = second hidden state output

Again, from here, we can do two things with h₂​:

(option 1) Pass it to the output neuron to get a prediction that is a result of both the first input, _x_₁, and the second, x₂.

ŷ₂ = second hidden state prediction

(option 2) Or we simply pass it to the next network as is.

And this process continues, with each state taking the output from the hidden neuron of the previous network (alongside the new input) and feeding it to the hidden neuron of the current state, thereby generating the output for the current hidden state. We could then pass this output either to the next network or to the output neuron to produce a prediction.

This entire process can be captured by these key equations:
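A reconstruction of those equations for our one-neuron network (the original figure isn’t reproduced here, so the weight and bias labels below are my own):

```latex
h_t = \tanh(w_x x_t + w_h h_{t-1} + b_h) \\
\hat{y}_t = \sigma(w_y h_t + b_y)
```

The first line combines the current input with the previous hidden state output; the second passes the result through the output neuron when we want a prediction.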

Despite its simplicity, this approach has a limitation: as we progress to the final steps, the information from the initial steps starts to fade away because the network fails to retain a lot of information. The larger the input sequence, the more pronounced this problem becomes. Clearly, we need a strategy to enhance this memory.

Enter LSTMs.

They accomplish this by implementing a simple yet effective strategy: at each step, they discard unnecessary information from the input and past steps, effectively "forgetting" information that’s not important and only retaining information that’s crucial. It’s kind of like how our brain processes information – we don’t remember every single detail, but only hold on to the details that we find necessary, discarding the rest.

LSTM Architecture

Consider a hidden state of our basic RNN.

hidden state of an RNN

We know each state starts with two players: the previous hidden state value _h_ₜ₋₁,​ and ​the current input, _x_ₜ. And the end goal is to produce a hidden state output, _h_ₜ​, which can either be passed onto the next hidden state or passed to the output neuron to produce a prediction.

LSTMs have a similar structure, with a slight elevation in complexity:

hidden state of an LSTM

This diagram might seem daunting, but it’s actually intuitive. Let’s break it down slowly.

We had two players in an RNN with the end goal of producing a hidden state output. Now we have three players that are input to the LSTM at the beginning – the previous long-term memory Cₜ₋₁, the previous hidden state output hₜ₋₁, and the input xₜ:

And the end goal is to produce two outputs – new long-term memory Cₜ and new hidden state output hₜ:

The primary focus of the LSTM is to discard as much unnecessary information as possible, which it accomplishes in three sections –

i) the forget section

section 1 – forget section

ii) the input section

section 2 – input section

iii) and the output section

section 3 – output section

We notice that they all have a purple cell in common:

These cells are called gates. To decide what information is important and what is not, LSTMs employ these gates, which are essentially neurons with the sigmoid activation function.

These gates decide what proportion of information to retain in their respective sections, effectively acting as gatekeepers that only let a proportion of information pass through them.

The use of a sigmoid function in this context is strategic, as it outputs values ranging from 0 to 1, which correspond directly to the proportions of information we intend to retain. For instance, a value of 1 implies that all information will be preserved, a value of 0.5 means only half of the information will be kept, and a value of 0 denotes that all information will be discarded.
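A quick sketch of that behavior (the function name and the sample inputs are just for illustration):

```python
import math

def sigmoid(z):
    """Squash any real number into the open interval (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))

# A strongly positive input -> proportion near 1 (keep almost everything)
print(round(sigmoid(10), 4))   # 1.0 when rounded to 4 places
# An input of 0 -> proportion of exactly 0.5 (keep half)
print(sigmoid(0))              # 0.5
# A strongly negative input -> proportion near 0 (discard almost everything)
print(round(sigmoid(-10), 4))  # 0.0 when rounded to 4 places
```

Whatever the gate’s weighted input works out to, the proportion it outputs always stays strictly between 0 and 1.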

Now let’s turn to the formula for these gates. If you look closely at the hidden state diagram, we see that they all have the same inputs, _x_ₜ and _h_ₜ₋₁, but different weight and bias terms.

They all have the same mathematical formula, but we need to swap out the weight and bias values appropriately.
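In symbols, that shared form looks like this (the subscript labels below follow common convention; the original figure may label the weights differently):

```latex
f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) \quad \text{(forget gate)} \\
i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) \quad \text{(input gate)} \\
o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) \quad \text{(output gate)}
```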

Each of these will produce values between 0 and 1, since that’s how the sigmoid function works, which will determine what proportion of certain information in each section we want to retain.

Note: Here you’ll notice we’re just using a vector notation for the weights. This just means we multiply xₜ and hₜ₋₁ by their respective weights, represented by W.

Forget section

The main purpose of this section is to figure out what proportion of the long-term memory we want to forget. So all we’re doing here is taking this proportion (a value from 0–1) from the forget gate…

…and multiplying that with the previous long-term memory:

This product retains exactly the part of the previous long-term memory that the forget gate deems important and discards the rest. So the closer the forget gate proportion, fₜ, is to 1, the more of the previous long-term memory we retain.
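In symbols, the retained memory is the element-wise product (⊙ denotes element-wise multiplication):

```latex
f_t \odot C_{t-1}
```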

Note: The ‘x’ symbol within the blue bubble signifies a multiplication operation. This notation is used consistently throughout the diagrams. Essentially, these blue bubbles indicate that the inputs are subjected to the mathematical operation depicted in the bubble.

Input section

The main purpose of this section is to create a new long-term memory, which is done in 2 steps.

(step 1) create a candidate for the new long-term memory, C̃ₜ (pronounced “C tilde”). We get this candidate for the new long-term memory using this neuron with the tanh activation function:

We see here that the inputs for this neuron are _x_ₜ and _h_ₜ₋₁, similar to the gates. So, passing them through the neuron…

…we get the output, which is a candidate for the new long-term memory.

Now we only want to retain necessary information from the candidate. This is where the input gate comes into play. We use the proportion obtained from the input gate…

…to retain only the necessary information from the candidate by multiplying this input gate proportion with the candidate:

(step 2) now to get the final long-term memory, we take the old long-term memory that we decided to keep in the forget section…

…and add that to the amount of new candidate that we decided to keep in this input section:
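Putting the two steps together, the new long-term memory is (notation as above, with ⊙ denoting element-wise multiplication):

```latex
C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t
```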

And voilà, we completed mission 1 of the game: we created a new long-term memory! Next, we need to produce a new hidden state output.

Output section

The main purpose of this section is to create a new hidden state output. This is pretty straightforward. All we’re doing here is taking the new long-term memory, Cₜ, passing it through the tanh function…

…and then multiplying it with the output gate proportion…

new hidden state output

…which gives us the new hidden state output!
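In symbols, with oₜ the output gate proportion:

```latex
h_t = o_t \odot \tanh(C_t)
```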

And just like that we completed mission 2 – producing a new hidden state output!

And now we can pass these new outputs to the next hidden state to repeat the same process all over again.

We also see that each of the hidden states has an output neuron:

Just like in an RNN, each of these states can produce its own individual output. And similar to RNNs, we use the hidden state output, hₜ, to produce a prediction. So passing hₜ to the output neuron…

…we get a prediction for this hidden state!
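To tie the three sections together, here’s a minimal sketch of one LSTM step in NumPy. Everything about it is illustrative rather than from the article: the function name `lstm_step`, the `params` dictionary, and the weight shapes are my own assumptions, and the weights would normally be learned rather than random.

```python
import numpy as np

def sigmoid(z):
    """Squash any real number into the open interval (0, 1)."""
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, params):
    """One hidden state of an LSTM: returns (h_t, C_t)."""
    z = np.concatenate([h_prev, x_t])          # shared input [h_{t-1}, x_t]

    # Forget section: what proportion of the old long-term memory to keep
    f_t = sigmoid(params["Wf"] @ z + params["bf"])

    # Input section: gate a tanh candidate, then form the new long-term memory
    i_t = sigmoid(params["Wi"] @ z + params["bi"])
    c_tilde = np.tanh(params["Wc"] @ z + params["bc"])
    c_t = f_t * c_prev + i_t * c_tilde

    # Output section: squash the new memory and gate it to get h_t
    o_t = sigmoid(params["Wo"] @ z + params["bo"])
    h_t = o_t * np.tanh(c_t)
    return h_t, c_t
```

Because each gate outputs a proportion in (0, 1) and tanh stays in (−1, 1), the hidden state output hₜ is always strictly inside (−1, 1), no matter what the weights are.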

And that wraps this up. As we saw, LSTMs take RNNs to the next level by handling long-term dependencies in sequential data better. We saw how LSTMs cleverly manage to retain essential information and discard the irrelevant, much like our brains do. This ability to remember important details over extended sequences makes LSTMs particularly powerful for tasks such as natural language processing, speech recognition, and time series prediction.

Connect with me on LinkedIn or shoot me an email at [email protected] if you have any questions/comments!

NOTE: All illustrations are by the author unless specified otherwise

