Backpropagation in RNN Explained

A step-by-step explanation of computational graphs and backpropagation in a recurrent neural network

Neeraj Krishna
Towards Data Science



Introduction

In the early days of machine learning, when there were no frameworks, most of the time in building a model was spent on coding backpropagation by hand. Today, with the evolution of frameworks and autograd, backpropagating through a deep neural network with dozens of layers takes just a call to loss.backward(), and that’s it. However, to gain a solid understanding of deep learning, it’s crucial that we understand the foundations of backpropagation and computational graphs.

The intuition behind backpropagation is that we compute the gradients of the final loss wrt the weights of the network; the gradient tells us how the loss changes with each weight, so during optimization we move the weights in the direction of decreasing loss, thereby minimizing it.
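In code, that update is the familiar gradient-descent step below; this is a minimal sketch with a made-up learning rate, not a full optimizer.

```python
# Plain gradient descent: nudge every weight against its gradient.
# The learning rate is a hyperparameter; 0.01 here is just an illustration.
def sgd_step(weights, grads, learning_rate=0.01):
    return [w - learning_rate * g for w, g in zip(weights, grads)]
```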

In this article, we’ll understand how backpropagation happens in a recurrent neural network.

Computational Graphs

At the heart of backpropagation are operations and functions, which can be elegantly represented as a computational graph. Let’s see an example: consider the function f = z(x + y). Its computational graph representation is shown below:

forward pass

A computational graph is essentially a directed graph with functions and operations as nodes. Computing the outputs from the inputs is called the forward pass, and it’s customary to show the forward pass above the edges of the graph.

In the backward pass, we compute the gradients of the output wrt the inputs and show them below the edges. Here, we start from the end and go to the beginning computing gradients along the way. Let’s do the backward pass for this example.

Notation: let’s represent the derivative of a wrt b as ∂a/∂b throughout the article.

The light gray arrows represent backward pass

First, we start from the end and compute ∂f/∂f, which is 1. Then, moving backward and writing the intermediate sum as q = x + y (so that f = zq), we compute ∂f/∂q, which is z, then ∂f/∂z, which is q, and finally we compute ∂f/∂x and ∂f/∂y.

Upstream, Downstream, and Local Gradients

If you observe, we cannot compute ∂f/∂x and ∂f/∂y directly, so we use the chain rule: we first compute ∂q/∂x and then multiply it with ∂f/∂q, which was already computed in the preceding step, to get ∂f/∂x. Here, ∂f/∂q is called the upstream gradient, ∂f/∂x is called the downstream gradient, and ∂q/∂x is called the local gradient.

downstream gradient = local gradient × upstream gradient
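To make this concrete, here is a minimal Python sketch of the forward and backward pass for f = z(x + y), with the intermediate sum named q as in the example above; the numeric values are arbitrary.

```python
# Forward pass for f = z * (x + y)
x, y, z = 2.0, 3.0, 4.0
q = x + y          # addition node
f = z * q          # multiplication node

# Backward pass: start from the end with df/df = 1
df_df = 1.0
df_dq = z * df_df  # multiplication node: local gradient is the other input
df_dz = q * df_df
# Addition node: local gradient is 1, so the upstream gradient flows through as is.
df_dx = 1.0 * df_dq   # downstream = local (1) x upstream (df/dq)
df_dy = 1.0 * df_dq

print(df_dx, df_dy, df_dz)  # 4.0 4.0 5.0
```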

Advantages of Computational Graphs

  1. Node-based approach: In backpropagation, we’re always interested in gradient flow, and when working with computational graphs, we can think of gradient flow in terms of nodes rather than functions or operations.

Consider a simple addition node as shown in the below figure:

gradient distributor

Given inputs x and y, the output z = x + y. The upstream gradient is ∂L/∂z, where L is the final loss. The local gradient is ∂z/∂x, but since z = x + y, ∂z/∂x = 1. Now, the downstream gradient ∂L/∂x is the product of the upstream gradient and the local gradient, but since the local gradient is unity, the downstream gradient is equal to the upstream gradient. In other words, the gradient flows through the addition node as is, and so the addition node is called the gradient distributor.

gradient swap multiplier

Similarly, for the multiplication node, if you do the calculations, the downstream gradient is the product of the upstream gradient and the other input as shown in the above figure. So, the multiplication node is called the gradient swap multiplier.

Again, the key takeaway here is to think of gradient flow in a computational graph in terms of nodes.

2. Modular approach: The downstream gradient at a particular node depends only on its local gradient and the upstream gradient. So if we want to change the architecture of the network in the future, we can simply plug in or pull out the appropriate nodes without affecting the other nodes. This approach is modular, which is especially useful when working with autograd.

3. Custom nodes: We could combine multiple operations into a single node like a sigmoid node or a softmax node as we’ll see next.

Gradients of Vectors and Matrices

When working with neural networks, we usually deal with high dimensional inputs and outputs which are represented as vectors and matrices. The derivative of a scalar wrt a vector is a vector that represents how the scalar is affected by a change in each element of the vector; the derivative of a vector wrt another vector is a Jacobian matrix that represents how each element of the vector is affected by a change in each element of the other vector. Without proof, the gradients are shown below:

matrix multiplication gradient

Here W is the weight matrix, x is the input vector, and y is the output product vector.
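As a quick sanity check on these gradients, here is a small NumPy sketch for y = Wx with an upstream gradient ∂L/∂y: the gradient wrt W is the outer product of the upstream gradient and x, and the gradient wrt x is Wᵀ times the upstream gradient. The shapes are chosen purely for illustration.

```python
import numpy as np

W = np.random.randn(3, 4)    # weight matrix
x = np.random.randn(4)       # input vector
y = W @ x                    # forward pass: y = Wx

dL_dy = np.random.randn(3)   # upstream gradient from the rest of the network

dL_dW = np.outer(dL_dy, x)   # gradient wrt the weight matrix, shape (3, 4)
dL_dx = W.T @ dL_dy          # gradient wrt the input vector, shape (4,)
```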

Gradient of Cross-Entropy loss (Optional)

Let’s do another example to reinforce our understanding. Let’s compute the gradient of a cross-entropy loss node which is a softmax node followed by a log loss node. This is the standard classification head used across many neural networks.

softmax + -ve log loss

Forward pass

In the forward pass, a vector 𝑦⃗ = [y1, y2, y3, ..., yn] passes through the softmax node to get the probability distribution S = [S1, S2, S3, ..., Sn] as the output. Then, say the ground truth index is m, we take the negative logarithm of Sm to compute the loss: l = -log(Sm).

The softmax function is given by Si = exp(yi) / Σj exp(yj), where the sum in the denominator runs over all elements of 𝑦⃗.
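In code, softmax is usually computed with the maximum subtracted from the logits first for numerical stability; here is a minimal NumPy sketch.

```python
import numpy as np

def softmax(y):
    # Subtracting the max does not change the result but avoids overflow in exp.
    exps = np.exp(y - np.max(y))
    return exps / np.sum(exps)

S = softmax(np.array([2.0, 1.0, 0.1]))
print(S, S.sum())  # a probability distribution that sums to 1
```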

Backward pass

The tricky part here is the dependence of the loss on a single element of the vector S. Since l = -log(Sm), we have ∂l/∂Sm = -1/Sm, where Sm is the mth element of S and m is the ground-truth index. Next, moving back, we ought to compute the gradients of the softmax node. Here the downstream gradient is ∂l/∂y, the local gradient is ∂S/∂y, and the upstream gradient is ∂l/∂Sm.

backward pass

First, let’s compute the local gradient ∂S/∂y. Here S is a vector and y is also a vector, so ∂S/∂y will be a matrix that represents how each element of S is affected by a change in each element of y. But the loss depends on only a single element of S, the one at the ground-truth index, so we’re only interested in how a single element of S is affected by a change in each element of y. Mathematically, we’re interested in finding ∂Si/∂y, where Si is the ith element of the vector S. There’s another catch: Si is the softmax function applied to the ith element of 𝑦⃗, so it depends on yi through the numerator and on every element of 𝑦⃗ through the denominator. So we cannot write down ∂Si/∂y in one go; instead, we’ll find ∂Si/∂yj, where yj is an arbitrary element of 𝑦⃗, and consider the two cases j = i and j ≠ i as shown below:

case 1: j = i: ∂Si/∂yi = Si(1 - Si)
case 2: j ≠ i: ∂Si/∂yj = -Si·Sj

So finally we can write:

local gradient of the softmax node: ∂Si/∂yj equals Si(1 - Si) when j = i, and -Si·Sj when j ≠ i

Next, let’s compute the downstream gradient ∂l/∂y. Since the downstream gradient is the product of the local gradient and the upstream gradient, i.e., ∂l/∂yj = ∂l/∂Sm · ∂Sm/∂yj, let’s again find ∂l/∂yj and consider the two cases j = m and j ≠ m as shown below:

case 1: j = m: ∂l/∂ym = (-1/Sm) · Sm(1 - Sm) = Sm - 1
case 2: j ≠ m: ∂l/∂yj = (-1/Sm) · (-Sm·Sj) = Sj

So finally we can write:

So if the vector 𝑦⃗ = [y1, y2, y3, ..., ym, ..., yn] passes through the softmax node to get the probability distribution S = [S1, S2, S3, ..., Sm, ..., Sn], then the downstream gradient ∂l/∂y is given by [S1, S2, S3, ..., Sm - 1, ..., Sn], i.e., we keep all the elements of the softmax vector as is and subtract 1 from the element at the ground-truth index. So the next time we want to backpropagate through a cross-entropy loss node, we can simply copy the softmax output and subtract 1 at index m.
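Putting the forward and backward pass of the cross-entropy node together in NumPy (the logits and the ground-truth index m below are arbitrary illustrations):

```python
import numpy as np

y = np.array([2.0, 1.0, 0.1])   # logits
m = 0                           # ground-truth index

# Forward pass: softmax followed by the negative log likelihood
S = np.exp(y - np.max(y))
S /= S.sum()
loss = -np.log(S[m])

# Backward pass: copy the softmax output and subtract 1 at the ground-truth index
dl_dy = S.copy()
dl_dy[m] -= 1.0

print(loss, dl_dy)
```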

Alright, now that we’ve got a solid foundation in backpropagation and computational graphs, let’s look at backpropagation in RNN.

Backpropagation in RNN

Forward pass

In the forward pass, at a particular time step, the input vector and the hidden state vector from the previous time step are multiplied by their respective weight matrices and summed up by the addition node. The sum then passes through a non-linear function (tanh), and the resulting hidden state is copied: one copy goes as input to the next time step, and the other goes into the classification head, where it’s multiplied by a weight matrix to obtain the logits vector before computing the cross-entropy loss. This is a typical generative RNN setup, where we model the network such that, given an input character, it predicts the probability distribution of the next character. If you’re interested in building a character-level RNN in PyTorch, please check my other article here. The forward pass equations (writing the input-to-hidden, hidden-to-hidden, and hidden-to-logits weight matrices as W_xh, W_hh, and W_hy, and omitting bias terms for brevity) are:

h_t = tanh(W_xh · x_t + W_hh · h_(t-1))
y_t = W_hy · h_t
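A minimal NumPy sketch of a single forward step under these equations (the sizes and weight names are illustrative assumptions):

```python
import numpy as np

hidden_size, vocab_size = 8, 5
W_xh = np.random.randn(hidden_size, vocab_size) * 0.01   # input  -> hidden
W_hh = np.random.randn(hidden_size, hidden_size) * 0.01  # hidden -> hidden
W_hy = np.random.randn(vocab_size, hidden_size) * 0.01   # hidden -> logits

def rnn_step(x, h_prev):
    # One time step of the forward pass described above.
    h = np.tanh(W_xh @ x + W_hh @ h_prev)  # addition node followed by tanh
    logits = W_hy @ h                      # classification head
    return h, logits

x = np.zeros(vocab_size); x[2] = 1.0       # one-hot encoded input character
h, logits = rnn_step(x, np.zeros(hidden_size))
```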

Backward pass

In the backward pass, we start from the end and compute the gradient of the classification loss wrt the logits vector, the details of which were discussed in the previous section. This gradient flows backward to the matrix multiplication node, where we compute the gradients wrt both the weight matrix and the hidden state. The gradient wrt the hidden state then flows backward to the copy node. You see, an RNN processes sequences one step at a time, so during backpropagation the gradients also flow backward across time steps; this is called backpropagation through time. So, at the copy node, the gradient wrt the hidden state meets the gradient flowing back from the next time step, and the two are summed up.

Next, they flow backwards to the tanh non-linearity node whose gradient can be computed as: ∂tanh(x)/∂x = 1−tanh²(x). Then this gradient passes through the addition node where it’s distributed to both the matrix multiplication nodes of the input vector and the previous hidden state vector. We usually don’t compute the gradient wrt the input vector unless there is a special requirement, but we do compute the gradient wrt the previous hidden state vector which then flows back to the previous time step. Please refer to the diagram for the detailed mathematical steps.
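Before looking at an actual implementation, here is a hedged NumPy sketch of this backward computation at one time step. It reuses the names from the forward-pass sketch above (an assumption about notation, not the code referenced below): dlogits is the cross-entropy gradient from the classification head, and dh_next is the gradient flowing back from the next time step.

```python
import numpy as np

def rnn_step_backward(dlogits, dh_next, x, h, h_prev, W_hh, W_hy):
    # Classification head: the matmul node gives gradients wrt W_hy and the hidden state.
    dW_hy = np.outer(dlogits, h)
    # Copy node: the gradient from the head and the one from the next time step are summed.
    dh = W_hy.T @ dlogits + dh_next
    # tanh node: the local gradient is 1 - tanh^2, and h already equals tanh(...).
    draw = (1.0 - h * h) * dh
    # Addition node distributes draw to both matrix multiplication branches.
    dW_xh = np.outer(draw, x)
    dW_hh = np.outer(draw, h_prev)
    # Gradient wrt the previous hidden state, which flows back to the previous time step.
    dh_prev = W_hh.T @ draw
    return dW_xh, dW_hh, dW_hy, dh_prev
```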

Let’s see how it looks in code.

Python Code

Andrej Karpathy has implemented a character-level RNN from scratch in Python/NumPy, and his code brilliantly captures the backpropagation step we’ve discussed, as shown below:

Code by Andrej Karpathy. Reused here under BSD license.

The full code can be found here, and the reader is highly encouraged to check it out. If you’re looking for a PyTorch implementation of an RNN with an example, please check my other article here.

Why backpropagation in RNN isn’t effective

If you observe, to compute the gradient wrt the previous hidden state, which is the downstream gradient, the upstream gradient flows through the tanh non-linearity and gets multiplied by the weight matrix. Now, since this downstream gradient flows back across time steps, it means the computation happens over and over again at every time step. There are a couple of problems with this:

  1. Since we’re multiplying over and over again by the weight matrix, the gradient will be scaled up or down depending on the largest singular value of the matrix: if the singular value is greater than 1, we’ll face an exploding gradient problem, and if it’s less than 1, we’ll face a vanishing gradient problem.
  2. The gradient also passes through the tanh non-linearity, whose local gradient 1 - tanh²(x) saturates at the extremes: whenever the pre-activation has a large positive or negative value, this local gradient is essentially zero, so the gradient gets squashed as it passes through the non-linearity and cannot propagate effectively across long sequences, leading to ineffective optimization. Both effects are illustrated in the short demo below.
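Here is a small NumPy demo of both effects under simplified assumptions (a random W_hh rescaled to a chosen largest singular value, ignoring the other terms in the backward pass):

```python
import numpy as np

np.random.seed(0)
hidden_size, steps = 8, 50
W_hh = np.random.randn(hidden_size, hidden_size)
grad = np.random.randn(hidden_size)

# Repeated multiplication by W_hh (as in backpropagation through time) scales the
# gradient norm by roughly the largest singular value at every step.
for target_sv, label in [(0.5, "largest singular value 0.5"), (1.5, "largest singular value 1.5")]:
    W = W_hh / np.linalg.svd(W_hh, compute_uv=False)[0] * target_sv
    g = grad.copy()
    for _ in range(steps):
        g = W.T @ g
    print(label, "-> gradient norm after", steps, "steps:", np.linalg.norm(g))

# The tanh local gradient 1 - tanh(x)^2 is nearly zero in the saturated regions.
for x in (0.0, 3.0, -5.0):
    print("x =", x, "-> 1 - tanh(x)^2 =", 1.0 - np.tanh(x) ** 2)
```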

There is a way to avoid the exploding gradient problem by essentially “clipping” the gradient if it crosses a certain threshold. However, clipping does nothing for vanishing gradients, so a plain RNN still cannot be used effectively for long sequences.
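Gradient clipping can be as simple as the NumPy sketch below, which caps every gradient entry at a chosen threshold (the threshold value here is an arbitrary illustration; clipping by the global gradient norm is a common alternative):

```python
import numpy as np

def clip_gradients(grads, threshold=5.0):
    # Element-wise clipping: cap every entry to the range [-threshold, threshold].
    return [np.clip(g, -threshold, threshold) for g in grads]

grads = [np.array([0.3, -12.0, 7.5]), np.array([[2.0, -80.0]])]
print(clip_gradients(grads))
```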

I hope you’ve got a clear idea of how backpropagation happens in an RNN. Let me know if you have any doubts. Let’s connect on Twitter and LinkedIn.

Image Credits

All the images used in this article are made by the author.
