
Creating a Gradient Descent Animation in Python

How to plot the trajectory of a point over a complex surface

Photo by Todd Diemer on Unsplash

Let me tell you how I created an animation of Gradient Descent just to illustrate a point in a blog post. It was worth it since I learned more Python by doing it and unlocked a new skill: making animated plots.

Gradient descent animation created in Python. Image by the author.

I’ll walk you through the steps of the process I followed.

A bit of background

A few days ago, I published a blog post about gradient descent as an optimization algorithm used for training Artificial Neural Networks.

I wanted to include an animated figure to show how choosing different initialization points for a gradient descent optimization can produce different results.

That’s when I stumbled upon these amazing animations created by Alec Radford years ago and shared in a Reddit comment, illustrating the difference between some advanced gradient descent algorithms, like Adagrad, Adadelta, and RMSprop.

Since I’ve been pushing myself to replace MATLAB with Python, I decided to give it a go and try to code a similar animation myself, using a "vanilla" gradient descent algorithm to start with.


Let’s go, step by step.

Plot the surface used for optimization

The first thing we do is import the libraries we’ll need and define the mathematical function we’ll want to represent.

I wanted to use a saddle point surface, so I defined the following equation:

z = x² − y²

Saddle surface equation. Image by the author.

We also create a grid of points for plotting our surface. [np.mgrid](https://numpy.org/doc/stable/reference/generated/numpy.mgrid.html) is perfect for this. The complex number 81j passed as step length indicates how many points to create between the start and stop values (81 points).

import numpy as np
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt

# Create a function to compute the surface
def f(theta):
  x = theta[0]
  y = theta[1]
  return x**2 - y**2

# Make a grid of points for plotting
x, y = np.mgrid[-1:1:81j, -1:1:81j]
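
As a quick sanity check, we can confirm that each coordinate array has the expected 81 × 81 shape:

print(x.shape, y.shape)  # (81, 81) (81, 81)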

The next thing to do is create a figure and axes, plot the surface, and format it.

# Create a figure with axes in 3d projection
fig1 = plt.figure(figsize = (4,8), dpi = 300)
ax1 = fig1.add_subplot(111, projection='3d')

plot_args = {'rstride': 2, 'cstride': 2, 'cmap':"coolwarm",
             'linewidth': 0.01, 'antialiased': False,
             'vmin': -1, 'vmax': 1, 'edgecolors':'k'}

ax1.plot_surface(x, y, f([x,y]), **plot_args)

ax1.view_init(azim= 120, elev= 37)
ax1.set_xlabel('x')
ax1.set_ylabel('y')

plt.show()
Saddle surface. Image by the author.

Compute the Gradient Descent trajectories

We’ll need to implement the gradient descent algorithm and apply it to our surface. That way, we will compute the evolution of the x and y coordinates throughout the iterations of the optimization.

As I mentioned in my post about gradient descent, the simplest implementation is an iterative update of the x and y values in proportion to the gradient, until we arrive at a local minimum.

Once you initialize x and y at an arbitrary point to start the optimization, the algorithm repeats the following steps:

  1. Compute the gradient of the surface (partial derivatives) at the current point (x, y).
  2. Update the coordinates in proportion to the gradient.
  3. Evaluate the surface function at the new coordinates.

This is done for a predefined number of iterations, or until the change between successive iterations falls below some threshold.
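
In symbols, each iteration applies the update θ(k) = θ(k−1) − η ∇f(θ(k−1)), where η is the step size. For our saddle surface, ∇f = (2x, −2y), so each step multiplies x by (1 − 2η) and y by (1 + 2η): x contracts toward zero, while any nonzero y grows geometrically away from the saddle point. This is exactly the behavior we want the animation to show.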

To achieve this, I created a function taking as arguments the initialization values, step size parameter, and the desired number of iterations.

It returns an array with the x, y, z coordinates for each iteration.

# Define a function to compute the gradient (partial derivatives)
def grad_f(theta):
    x = theta[0]
    y = theta[1]
    return np.array([2*x, -2*y])

# Gradient descent routine
def gradient_descent(x_init, y_init, step_size, n_iters):
    eta = step_size

    # Use floats so the iterative updates are not truncated to integers
    theta = np.tile([float(x_init), float(y_init)], (n_iters, 1))
    g_t = np.zeros((n_iters, 2))
    z = np.tile([f(theta[0])], n_iters)

    for k in range(1, n_iters):
        g_t[k] = grad_f(theta[k-1])
        theta[k] = theta[k-1] - eta * g_t[k]
        z[k] = f(theta[k])

    # Stack the x, y, z coordinates of each iteration for the animation
    dataSet = np.stack((theta[:, 0], theta[:, 1], z), 1)

    return dataSet
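
As an optional sanity check, we can compare grad_f against a central finite-difference approximation at an arbitrary point:

# Sanity check: analytical gradient vs. finite differences
eps = 1e-6
theta0 = np.array([0.5, -0.3])
fd = np.array([
    (f(theta0 + [eps, 0]) - f(theta0 - [eps, 0])) / (2 * eps),
    (f(theta0 + [0, eps]) - f(theta0 - [0, eps])) / (2 * eps),
])
print(np.allclose(grad_f(theta0), fd))  # True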

I tested it with two different initialization points, one on the symmetry line y = 0, and the other at an offset y value.

I used the same step size and number of iterations for both in order to compare them. It also makes it easier to animate later on since the two results will have the same number of frames.

n_iters = 200
step_size = 0.01

x_sym, y_sym = 1, 0
x_off, y_off = 1, 0.25

data_sym = gradient_descent(x_sym, y_sym, step_size , n_iters)
data_off = gradient_descent(x_off, y_off, step_size , n_iters)

To verify the results, I plotted the trajectory for each coordinate separately:

fig2, ax2 = plt.subplots(1,3, figsize = (6,2), dpi = 300)

ax2[0].plot(data_sym[:,0])
ax2[0].plot(data_off[:,0])
ax2[0].set_xlabel('Iteration')
ax2[0].set_ylabel('x')

ax2[1].plot(data_sym[:,1])
ax2[1].plot(data_off[:,1])
ax2[1].set_xlabel('Iteration')
ax2[1].set_ylabel('y')
ax2[1].set_ylim(0, 10)

ax2[2].plot(data_sym[:,2], label = 'Symmetry')
ax2[2].plot(data_off[:,2], label = 'Offset')
ax2[2].set_xlabel('Iteration')
ax2[2].set_ylabel('z')
ax2[2].set_ylim(-5, 2)

handles, labels = ax2[2].get_legend_handles_labels()
fig2.legend(handles, labels, loc=(0.3, 0.75), fancybox = False, 
            frameon = False, ncols = 2)
fig2.suptitle('Evolution of coordinates during gradient descent', y = 0.99)

fig2.tight_layout()
Evolution of x, y, and z coordinates during optimization. Image by the author.

For the symmetry case, y and z converge to 0, while the optimization initialized at an offset moves down the z-axis quickly. All as expected.

Now we can visualize the trajectories on top of the surface.

Plot the trajectories of gradient descent on top of the optimization surface

Initially, I thought this part was just a matter of plotting the results on the same coordinate axes we used for the surface, ax1. Unfortunately, that won’t work: Matplotlib’s 3D engine draws whole artists in a fixed order instead of depth-sorting them, so lines plotted on the same axes can end up hidden behind the surface. We’ll have to do a bit more.

I added new axes to the figure and set them to have the same view angles and position as ax1. Then I plotted the trajectory lines.

The most important part here is using set_zorder to manually place the new axes (and the lines on them) on top of the surface. We also need set_axis_off to make the new axes invisible.

newax = fig1.add_axes(ax1.get_position(), projection='3d',
                     xlim = ax1.get_xlim(),
                     ylim = ax1.get_ylim(),
                     zlim = ax1.get_zlim(),
                     facecolor = 'none',)
newax.view_init(azim= 120, elev= 37)
newax.set_zorder(1)
ax1.set_zorder(0)

newax.plot3D(data_sym[:,0], data_sym[:, 1], data_sym[:, 2], 
             c='blue', alpha = 0.7)
newax.plot3D(data_off[:, 0], data_off[:, 1], data_off[:, 2],
             c='red', alpha = 0.7)

newax.plot3D(data_sym[0, 0], data_sym[0, 1], data_sym[0, 2],
               ms = 2.5, c='black', marker='o')
newax.plot3D(data_off[0, 0], data_off[0, 1], data_off[0, 2],
               ms = 2.5, c='black', marker='s')
newax.set_axis_off()
sq = u'■'
cir = u'●'
ax1.text2D(0.05, 0.95, f'Initialization\n'
                       f'{sq}: x = {x_off}, y = {y_off}\n'
                       f'{cir}: x = {x_sym}, y = {y_sym}', transform=ax1.transAxes)
fig1
Static representation of gradient descent trajectory. Image by the author.

Great! Our plot makes sense, so we can start creating the animated version.

Create the animated figure

Remember how movies were done in the old days with sequences of static images?

Eadweard Muybridge, Waugsberg – Unknown source, available on Wikipedia. Public Domain

Well, they are still done the same way.

That’s why I stored the updated x, y, and z coordinates at each iteration in an array.

We need to find a way to sequentially plot those multiple "pictures", or frames, of the optimization process.

I had no clue about how to do this, so I started to look for code that did similar things to what I wanted to accomplish.

I found a brilliant article by Zack Fizell here on Medium where he showed how to animate plots in Python.

Luckily, Matplotlib has a dedicated module for this: animation.

Zack explained all the details of the code, and you can go check them out in his article. For me, the most important part was using

from matplotlib import animation

# Define animation function
def animate_func(num):
  ... # Plot trajectories, indexing the arrays with the incremental frame number num

# Plotting the Animation
fig = plt.figure()
ax = plt.axes(projection='3d')
line_ani = animation.FuncAnimation(fig, animate_func, interval=100,
                                   frames=numDataPoints)  # numDataPoints: total frame count
plt.show()

To plot all the steps of the trajectory sequentially.

Easy enough, right?

What animation.FuncAnimation does is update the figure we pass as input, using whatever function we define (in Zack’s case, animate_func). That function receives the current frame number as its argument and is called frames times.

The sequence of images is stored as a [TimedAnimation](https://matplotlib.org/stable/api/_as_gen/matplotlib.animation.TimedAnimation.html#matplotlib.animation.TimedAnimation) object, with a time interval (in milliseconds) between frames.

All of this is well explained in the official matplotlib documentation.

In our case, the function also needs to include the workaround that prevents the new axes from overlapping with the surface every time we update the trajectory lines.

Here’s how I defined the update function:

def descent_animation(num):
    # Clear the axes where we are plotting the trajectories
    newax.clear()

    # Manually adjust the order of the axes
    newax.set_zorder(1)
    ax.set_zorder(0)

    # Hide the axes in the front plane
    newax.set_axis_off()

    # Plot new frame of trajectory line for the symmetry case
    newax.plot3D(data_sym[:num+1, 0], data_sym[:num+1, 1],
                 data_sym[:num+1, 2], c='blue', alpha = 0.7)
    # Updating Point Location
    newax.scatter(data_sym[num, 0], data_sym[num, 1], data_sym[num, 2],
               s = 10, c='blue', marker='o', edgecolor = 'k', linewidth = 0.5)
    # Adding Constant Origin
    newax.plot3D(data_sym[0, 0], data_sym[0, 1], data_sym[0, 2],
               ms = 2.5, c='black', marker='o')

    # Plot new frame of trajectory line for the offset case
    newax.plot3D(data_off[:num+1, 0], data_off[:num+1, 1],
                 data_off[:num+1, 2], c='red', alpha = 0.7)
    # Updating Point Location
    newax.scatter(data_off[num, 0], data_off[num, 1], data_off[num, 2],
               s = 10, c='red', marker='o', edgecolor = 'k', linewidth = 0.5)
    # Adding Constant Origin
    newax.plot3D(data_off[0, 0], data_off[0, 1], data_off[0, 2],
               ms = 2.5, c='black', marker='s')

    # Setting Axes Limits and view angles
    newax.set_xlim3d([-1, 1])
    newax.set_ylim3d([-1, 1])
    newax.set_zlim3d([-1, 1])
    newax.view_init(azim= 120, elev= 37)

Finally, to create the animation, I created a new figure and repeated the process of adding the second pair of axes. This time I only plotted the surface, since the trajectories will be handled by the descent_animation function.

fig = plt.figure(figsize = (4, 3), dpi = 300)
ax = fig.add_subplot(111, projection='3d')

plot_args = {'rstride': 2, 'cstride': 2, 'cmap':"coolwarm",
            'linewidth': 0.01, 'antialiased': False,
            'vmin': -1, 'vmax': 1, 'edgecolors':'k'}
x, y = np.mgrid[-1:1:81j, -1:1:81j]

# Plot surface
ax.plot_surface(x, y, f([x,y]), **plot_args)
ax.view_init(azim= 120, elev= 37)
ax.set_xlim3d([-1, 1])
ax.set_ylim3d([-1, 1])
ax.set_zlim3d([-1, 1])

# Make the axis panes transparent
ax.xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
ax.yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
ax.zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))

# Add second pair of axes
newax = fig.add_axes(ax.get_position(), projection='3d',
                    xlim = ax.get_xlim(),
                    ylim = ax.get_ylim(),
                    zlim = ax.get_zlim(),
                    facecolor = 'none',)
newax.view_init(azim= 120, elev= 37)

# Manually adjust the order of the axes
newax.set_zorder(1)
ax.set_zorder(0)

# Hide axes in the front plane
newax.set_axis_off()

# Add some text to distinguish the two initialization points
sq = u'■'
cir = u'●'
ax.text2D(0.05, 0.95, f'Initialization\n'
                      f'{sq}: x = {x_off}, y = {y_off}\n'
                      f'{cir}: x = {x_sym}, y = {y_sym}', 
          transform=ax.transAxes,
          fontsize = 8)

# Plotting the Animation
line_ani = animation.FuncAnimation(fig, descent_animation, interval=100,
                                   frames=n_iters)
plt.show()
Final result: gradient descent animation created in Python. Image by the author.

Saving and displaying the animation

We can save the animation using a few lines of code.

filename = "gradient_func.gif"
writergif = animation.PillowWriter(fps=10)  # 10 fps matches the 100 ms interval used above
line_ani.save(filename, writer=writergif)

The PillowWriter class works well for saving GIF files.

The problem with GIF files is that you can end up with an unnecessarily large file if you have many frames and want to preserve good image quality.

To optimize the file size, we can play around with the interval passed to FuncAnimation and specify the FPS and DPI used when creating and saving the figure. Otherwise, we could use EZGif’s optimization tool to reduce the file size.
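
For example, the save method accepts a dpi argument, and PillowWriter takes the frame rate, so something like this (with illustrative values) trades quality for size:

# Smaller file: lower DPI and a modest frame rate
line_ani.save("gradient_func_small.gif",
              writer=animation.PillowWriter(fps=15), dpi=100)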

Another option would be storing it as an MP4 or, even better, a WebM file, using [FFMpegWriter](https://matplotlib.org/stable/api/_as_gen/matplotlib.animation.FFMpegWriter.html#matplotlib.animation.FFMpegWriter).

As I have discussed before, using WebM instead of GIF files can greatly improve loading speed for web display and reduce file size. Unfortunately, WebM is still not supported on Medium.
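
For reference, here is a minimal sketch of both options, assuming FFmpeg is installed and available on your PATH (the VP9 encoder name may vary between FFmpeg builds):

# MP4 output via FFmpeg
writer_mp4 = animation.FFMpegWriter(fps=30)
line_ani.save("gradient_func.mp4", writer=writer_mp4)

# WebM output, explicitly selecting the VP9 encoder
writer_webm = animation.FFMpegWriter(fps=30, codec="libvpx-vp9")
line_ani.save("gradient_func.webm", writer=writer_webm)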

Finally, if you are using Jupyter Notebooks or working on Google Colab, you can display the animated figure using:

from IPython.display import Image
Image(open(filename,'rb').read())
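
Alternatively, the animation can render itself to an interactive HTML player, without saving a file first:

from IPython.display import HTML
HTML(line_ani.to_jshtml())

Keep in mind that to_jshtml embeds every frame in the notebook, so it can get heavy for long animations.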

Additional

This type of animated figure is great for comparing the trajectory of different types of optimization algorithms.

Below, you can see a comparison of the performance of Stochastic Gradient Descent (SGD) – the one I’ve shown in this post – and a more advanced method called Adam (derived from Adaptive Moment Estimation), which is currently one of the industry standards for training Artificial Neural Networks because of its robustness and speed.

Both were initialized at the point x = 0.2, y = 1e-8, to provide a challenging situation: there is very little offset from the exact symmetry line (y = 0), and the small x coordinate leaves little room for building momentum when "descending". You can see how Adam manages to break the symmetry after very few iterations compared to SGD.

If you are interested in learning more about the differences between these two algorithms, check out my article about advanced methods for gradient descent optimization.

Comparison of SGD vs Adam. Image by the author.

So that’s all for this one.

Thank you for reading! I hope you found some useful information in what I shared, and perhaps you can use it for your own projects.

If you are interested in learning more about gradient descent and neural networks, I recommend you read my other articles on those topics.

DL Notes: Gradient descent

DL Notes: Feedforward Artificial Neural Networks

I’m writing a series of articles about Deep Learning and Machine Learning, but I’ll also be publishing meta content about other things I learn while writing those articles – this post is an example. So expect not only the theory and applications of DL, but also Python programming tips like this one.

Be sure to follow me and subscribe to my emails to get notified whenever I publish new posts!

