Why do RNNs have Short-term Memory?

All you need is a straightforward explanation

Essam Wisam
Towards Data Science


As we illustrated in the last story, recurrent neural networks attempt to generalize feedforward neural networks to deal with sequence data. Their existence has revolutionized deep learning because they paved the way for numerous fascinating applications, ranging from language modeling to translation. Nevertheless, one problem that prevents RNNs from prevailing in many of these applications is their inability to deal with long-term dependencies, which naturally arise when the input sequences are long.

The inability of RNNs to deal with long-term dependencies can be ascribed to their vanishing gradient problem. If the gradients due to earlier words vanish, and gradients are what’s responsible for training (weight updates), then vanishing gradients are equivalent to the corresponding words not existing at all; in other words, earlier words are forgotten and the RNN has short-term memory. The fact that gradients vanish in the first place is a side effect of backpropagation through time, the algorithm used to train such networks.

In this story, we will set out on a quick journey through how backpropagation through time works in order to find out:

  1. Why do the gradients vanish?
  2. Why does it cause short-term memory?
  3. What can we do to stop them from vanishing?
Photo by Aron Visuals on Unsplash

Why do the gradients vanish?

In an RNN, the loss due to any example (sequence) is the sum of the losses due to each token in the sequence. This assumes the general setting where the RNN produces an output for each token in the input. Thus we can write
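(with Lₜ standing for the loss at step t and T for the length of the sequence):

```latex
L = \sum_{t=1}^{T} L_t
```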

Let’s try to find ∂L/∂Wₓₕ because we’re going to need that for gradient descent to update Wₓₕ.

Going onwards, Wₓₕ is the input weight matrix (for input to recurrent layer connections) and Wₕₕ is the hidden weight matrix (for connections from the recurrent layer to its own previous output). Meanwhile, Wₒₕ is the output weight matrix (for connections from the recurrent layer to the output).

Now clearly, we find our target ∂L/∂Wₓₕ once we find ∂Lₜ/∂Wₓₕ for each step t. We utilize the chain rule to write that as
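a chain that passes through the output yₜ and then the hidden state hₜ:

```latex
\frac{\partial L_t}{\partial W_{xh}} =
\frac{\partial L_t}{\partial y_t}\,
\frac{\partial y_t}{\partial h_t}\,
\frac{\partial h_t}{\partial W_{xh}}
```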

The first two factors can be easily computed. For the first, the loss is a direct function of the output, and for the second we know that yₜ = f(Wₒₕhₜ + bₕ), so in both cases it’s just a matter of taking the derivative.

Meanwhile, for the third term we need to call to mind how the hidden state is computed in the first place.
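In the notation above, with f the activation function and bₕ a bias vector, the recurrence is:

```latex
h_t = f\left( W_{xh}\, x_t + W_{hh}\, h_{t-1} + b_h \right)
```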

which implies that hₜ depends on Wₓₕ twice: once directly, through the term Wₓₕxₜ, and once indirectly, through hₜ₋₁, which is itself a function of Wₓₕ,

so by using the chain rule we can write
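(splitting off the direct derivative that treats hₜ₋₁ as a constant and marking it with a ∂⁺ superscript, purely as a notational convenience):

```latex
\frac{\partial h_t}{\partial W_{xh}} =
\frac{\partial^{+} h_t}{\partial W_{xh}}
+ \frac{\partial h_t}{\partial h_{t-1}}\,
  \frac{\partial h_{t-1}}{\partial W_{xh}}
```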

The first term can be easily computed because every hₜ is a direct function of Wₓₕ.

Note that if we wanted perfectly rigorous notation, we should perhaps write something like hₜ = f(…) and then differentiate f rather than hₜ on the right-hand side, but I’ll leave it like this because it looks much more cumbersome otherwise, among other reasons.

Meanwhile, for the other term it looks like we will need another round of the chain rule, because hₜ₋₁ is produced by the same recurrence and therefore also depends on Wₓₕ both directly and through hₜ₋₂,

so we have to write
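the same kind of split, one step further back:

```latex
\frac{\partial h_{t-1}}{\partial W_{xh}} =
\frac{\partial^{+} h_{t-1}}{\partial W_{xh}}
+ \frac{\partial h_{t-1}}{\partial h_{t-2}}\,
  \frac{\partial h_{t-2}}{\partial W_{xh}}
```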

Which, after plugging in ∂hₜ/∂Wₓₕ and continuing the recursion all the way back to the first step, yields
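the fully unrolled gradient, where the sum runs over every earlier step k and the product is empty (taken to be the identity) when k = t:

```latex
\frac{\partial L_t}{\partial W_{xh}} =
\frac{\partial L_t}{\partial y_t}\,
\frac{\partial y_t}{\partial h_t}
\sum_{k=1}^{t}
\left( \prod_{j=k+1}^{t} \frac{\partial h_j}{\partial h_{j-1}} \right)
\frac{\partial^{+} h_k}{\partial W_{xh}}
```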

Why does it cause short-term memory?

Now to see the vanishing gradient problem more closely let’s consider a sentence of 4 words. This means that t=4 and that the expression above evaluates to
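(keeping the terms ordered from the first word’s contribution to the fourth’s):

```latex
\frac{\partial L_4}{\partial W_{xh}} =
\frac{\partial L_4}{\partial y_4}\,
\frac{\partial y_4}{\partial h_4}
\left(
\frac{\partial h_4}{\partial h_3}
\frac{\partial h_3}{\partial h_2}
\frac{\partial h_2}{\partial h_1}
\frac{\partial^{+} h_1}{\partial W_{xh}}
+
\frac{\partial h_4}{\partial h_3}
\frac{\partial h_3}{\partial h_2}
\frac{\partial^{+} h_2}{\partial W_{xh}}
+
\frac{\partial h_4}{\partial h_3}
\frac{\partial^{+} h_3}{\partial W_{xh}}
+
\frac{\partial^{+} h_4}{\partial W_{xh}}
\right)
```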

which we can think of as the fourth word’s contribution to updating the weights in Wₓₕ. Notice how the only way for this to take the first word into account is via the first term, which (as we will show) is expected to be extremely small. Thus, it’s as if the first word wasn’t there, despite any dependence that relates it to the fourth word.

The result is that the parameters will be biased towards capturing only short-term dependencies during training, bringing about a neural network that won’t be sure whether to choose was or were in a sentence like “The cats that I saw earlier this morning was/were very hungry”.

To back up the fact that the first term is expected to be very small we need to observe
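that each factor ∂hⱼ/∂hⱼ₋₁ in that long product comes from differentiating the recurrence above, which gives:

```latex
\frac{\partial h_j}{\partial h_{j-1}} =
\mathrm{diag}\!\left( f'\!\left( W_{xh}\, x_j + W_{hh}\, h_{j-1} + b_h \right) \right) W_{hh}
```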

Weights are naturally small due to how the network is initialized (often from a normal distribution with a small standard deviation), and the same applies to the derivative of the activation function, because it’s often a Tanh, whose derivative is at most 1 and usually much smaller. Replacing the Tanh with something like a ReLU might only solve part of the problem, since the small weights remain.
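To make this concrete, here is a toy numerical sketch (my own illustration; the hidden size of 50, the 0.1 standard deviation and the 20 steps are arbitrary choices) that multiplies such Jacobians together for a Tanh recurrent layer with small random weights:

```python
# Toy sketch: multiply the Jacobians dh_t/dh_{t-1} = diag(1 - tanh^2) @ W_hh
# of a Tanh recurrent layer over many time steps and watch the accumulated
# product (the kind of factor that carries an early word's gradient) shrink.
import numpy as np

rng = np.random.default_rng(0)
hidden_size = 50
W_hh = rng.normal(0.0, 0.1, size=(hidden_size, hidden_size))  # small random init

h = np.zeros(hidden_size)
product = np.eye(hidden_size)  # running product of Jacobians

for step in range(1, 21):
    x_part = rng.normal(0.0, 1.0, size=hidden_size)  # stands in for W_xh @ x_t + b_h
    h = np.tanh(W_hh @ h + x_part)
    jacobian = np.diag(1.0 - h ** 2) @ W_hh          # dh_t / dh_{t-1}
    product = jacobian @ product
    if step % 5 == 0:
        print(f"after {step:2d} steps: ||product of Jacobians|| = "
              f"{np.linalg.norm(product):.2e}")
```

The exact numbers depend on the seed and the initialization scale, but the norm of the accumulated product keeps dropping by orders of magnitude as the number of steps grows, which is precisely the first word’s gradient being wiped out.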

What can we do to stop them from vanishing?

The source of the problem is obviously
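the long product of Jacobians in the first term, every factor of which satisfies (at least in practice)

```latex
\left\| \frac{\partial h_j}{\partial h_{j-1}} \right\| \le 1
```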

Can we redesign the recurrent layer so that this no longer necessarily holds? Suppose we instead let the hidden state hₜ be defined as
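something along these lines (a sketch; real gated architectures differ in the details):

```latex
h_t = G_1 * h_{t-1} + G_2 * \hat{h}_t
```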

* is an element-wise product

with each of G₁, G₂ and ĥ being the result of a different recurrent layer that takes hₜ₋₁ as its previous hidden state (they are all functions of it).

In this case, we have
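by the product rule, and keeping the same loose element-wise notation, four terms:

```latex
\frac{\partial h_t}{\partial h_{t-1}} =
G_1
+ \frac{\partial G_1}{\partial h_{t-1}} * h_{t-1}
+ \frac{\partial G_2}{\partial h_{t-1}} * \hat{h}_t
+ G_2 * \frac{\partial \hat{h}_t}{\partial h_{t-1}}
```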

and we can no longer claim that it will always be ≤ 1, even if that holds for each of the four terms. This fact is exactly what architectures like the LSTM and GRU exploit to overcome the vanishing gradient problem.

Photo by Jacqueline Flock on Unsplash

This brings our story to an end. We have explored how and why RNNs suffer from short-term memory and what could be done to alleviate that issue. Hope you found the story easy to read. Till next time, au revoir.
