Intuition (and maths!) behind multivariate gradient descent

Machine Learning Bit by Bit: bite-sized articles on machine learning

Misa Ogura
Towards Data Science


Hello again!

Machine Learning Bit by Bit is a series in which I share my own endeavours to explore and experiment with topics in machine learning.

In my last post, we discussed:

  1. What gradient descent is.
  2. How it may be useful in linear regression.
  3. How it actually works with a simple univariate function.

In this post, we’re going to extend our understanding of gradient descent and apply it to a multivariate function.

In my opinion, this offers a smooth transition to applying gradient descent to more complex functions, and it helps solidify your understanding of gradient descent, which will be essential for the next topic in the series: linear regression.

Alright, let’s get into it.

Multivariate gradient descent — intuition

First things first, let’s talk about the intuition. What does it actually mean to apply gradient descent to a multivariate function?

I will try to explain this by visualising:

  1. the target multivariate function
  2. how gradient descent works with it

Remember, gradient descent is an algorithm to find a minimum of a function. Therefore, our aim here is to find the minimum of a function with more than one variable.

In my last post, we used this univariate quadratic function as an example:

And here is the bivariate (two-variable) quadratic function, J(θ1, θ2), which we are going to look at today:

  J(θ1, θ2) = θ1² + θ2²

Fig.1 below visualises J(θ1, θ2) in various ways — 3D plots on the left (Fig.1a) and centre (Fig.1b) and a contour plot (Fig.1c) on the right. The contour plot is one way to represent a 3D function on a 2D plane. It’s as if you look down on the 3D graph from the top and squish it along the z-axis. Fig.1b, which is a rotated version of Fig.1a, should provide you with some visual intuition.

When applying gradient descent to this function, our objective remains the same, except that we now have two parameters, θ1 and θ2, to optimise:

  minimise J(θ1, θ2) over θ1 and θ2

So far so good…

Update rule

Another feature of gradient descent is that it is an iterative algorithm: it uses an update rule to systematically update the parameter values after each iteration.

This was the update rule for univariate gradient descent:

  θ := θ − α · dJ(θ)/dθ

where α is the learning rate and dJ(θ)/dθ is the derivative of J(θ), i.e. the slope of the tangent line that touches J(θ) at a given θ.

Now that we have two variables, we need to supply an update rule for each:

  θ1 := θ1 − α · ∂J(θ1, θ2)/∂θ1
  θ2 := θ2 − α · ∂J(θ1, θ2)/∂θ2

These equations look almost identical to the one for univariate functions. The only change is in the derivative terms, ∂J(θ1, θ2)/∂θ1 and ∂J(θ1, θ2)/∂θ2. But don’t be alarmed by them. The symbol ∂ instead of d simply means that it is a partial derivative.

Partial derivative

In partial derivatives, just as in normal derivatives, we are still interested in the slope of a tangent that touches J(θ1, θ2) at a given θ1 or θ2… but this “or” is crucial.

Essentially, we cannot move both θ1 and θ2 at the same time when looking at a tangent. Therefore, we focus on only one variable at a time, whilst holding the other constant. Hence, the name partial.

I’ll try to explain this better with the help of a graph (Fig.2). Let’s treat θ1 as the variable and keep θ2 constant; in other words, let’s take the partial derivative with respect to θ1.

Visually, keeping θ2 constant translates to a θ1-J(θ1, θ2) plane (the blue square in Fig.2) cutting through the graph at a particular value of θ2. The red line in Fig.2 represents the intersection between the θ1-J(θ1, θ2) plane and the J(θ1, θ2) plot, which becomes the function of interest in the partial derivative.

Now, if we extract the blue plane along with the red line, what we end up with is a good old univariate function with θ1 as its parameter, on a 2D plane, just like what we saw in the last post!

Therefore, we can calculate the partial derivative term in the update rule as follows, by letting the change in θ1 shrink towards zero:

  ∂J(θ1, θ2)/∂θ1 = ∂(θ1² + θ2²)/∂θ1 = 2θ1
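To make the "shrinking towards zero" idea concrete, here is a minimal Python sketch that estimates the slope along θ1 with a finite difference while θ2 is held constant. The function names `J` and `partial_theta1` and the step sizes are my own choices for illustration, not something from the original figures.

```python
def J(theta1, theta2):
    """The bivariate quadratic used in this post."""
    return theta1 ** 2 + theta2 ** 2


def partial_theta1(f, theta1, theta2, step):
    """Forward-difference estimate of the slope along theta1 (theta2 held constant)."""
    return (f(theta1 + step, theta2) - f(theta1, theta2)) / step


theta1, theta2 = 0.75, 0.75
for step in (1e-1, 1e-2, 1e-3, 1e-4):
    estimate = partial_theta1(J, theta1, theta2, step)
    # As the step shrinks, the estimate should approach the exact value 2 * theta1.
    print(f"step={step:g}  estimate={estimate:.5f}  exact={2 * theta1:.5f}")
```

As the step gets smaller, the estimate should settle on 2θ1, and changing θ2 should make no difference to it, which is exactly what holding θ2 constant implies.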

Proof of partial derivative formula (optional)

The equation above utilises a well-known formula for partial derivatives, so it omits the proof of how you actually calculate the partial derivative to reach 2θ1. If you are not interested in proving it, please skip this section.

Now, this is for you, if you are a bit like me and have an obsessive-compulsive urge to see what’s happening behind the scenes…
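One standard way to get there is from the limit definition of a derivative, applied to J(θ1, θ2) = θ1² + θ2² with θ2 held constant. The symbol Δθ1 for the small change in θ1 is my own notation, and this is a sketch of the usual textbook derivation rather than a reproduction of the original working:

```latex
\begin{aligned}
\frac{\partial J(\theta_1, \theta_2)}{\partial \theta_1}
  &= \lim_{\Delta\theta_1 \to 0}
     \frac{J(\theta_1 + \Delta\theta_1, \theta_2) - J(\theta_1, \theta_2)}{\Delta\theta_1} \\
  &= \lim_{\Delta\theta_1 \to 0}
     \frac{(\theta_1 + \Delta\theta_1)^2 + \theta_2^2 - \theta_1^2 - \theta_2^2}{\Delta\theta_1} \\
  &= \lim_{\Delta\theta_1 \to 0}
     \frac{2\theta_1 \Delta\theta_1 + (\Delta\theta_1)^2}{\Delta\theta_1} \\
  &= \lim_{\Delta\theta_1 \to 0} \left( 2\theta_1 + \Delta\theta_1 \right) \\
  &= 2\theta_1
\end{aligned}
```

Notice that the θ2² terms cancel in the second line, which is the algebraic counterpart of holding θ2 constant.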

Et voila!

Simultaneous update

Applying the same logic to the partial derivative with respect to θ2, we can simplify the update rules as follows:

  θ1 := θ1 − α · 2θ1
  θ2 := θ2 − α · 2θ2

Last but not least, there is the concept of simultaneous update: when gradient descent is applied to multivariate functions, the updates to all parameters have to happen at the same time, rather than sequentially.

I found a very intuitive description:

A simple analogy would be walking. You typically don’t walk east-west direction first, and then north-south. You walk the shortest direction, i.e., move in both coordinates simultaneously. (StackExchange)

What that means in practice is that, in every iteration, we have to assign each newly calculated parameter to a temporary variable until we have finished calculating all of them. Using our example, it looks like this:

  temp1 := θ1 − α · 2θ1
  temp2 := θ2 − α · 2θ2

Then,

  θ1 := temp1
  θ2 := temp2
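The temporary-variable pattern described above can be sketched in Python roughly as follows. The names `simultaneous_step`, `temp1` and `temp2` are hypothetical, and the gradients 2θ1 and 2θ2 are hard-coded for our example function.

```python
ALPHA = 0.2  # learning rate; the same value is used for the run later in this post


def simultaneous_step(theta1, theta2, alpha=ALPHA):
    """One gradient descent step on J = theta1**2 + theta2**2."""
    # Both temporaries are computed from the *old* parameter values...
    temp1 = theta1 - alpha * 2 * theta1
    temp2 = theta2 - alpha * 2 * theta2
    # ...and only then are the parameters replaced together.
    return temp1, temp2


theta1, theta2 = simultaneous_step(0.75, 0.75)
print(theta1, theta2)
```

For this particular J, each partial derivative depends only on its own parameter, so a sequential update would happen to give the same numbers. The temporaries start to matter as soon as one parameter appears in the partial derivative of another.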

Cool, we have every piece of the puzzle to proceed.

Gradient descent in action

The time has come!

We’re now ready to see multivariate gradient descent in action, using J(θ1, θ2) = θ1² + θ2². We’re going to use a learning rate of α = 0.2 and starting values of θ1 = 0.75 and θ2 = 0.75.

Fig.3a shows how gradient descent approaches the minimum of J(θ1, θ2) on a contour plot. Fig.3b is a plot of J(θ1, θ2) against the number of iterations and is used to monitor convergence.

As you can see in the figure legend, from the seventh to the eighth iteration, J(θ1, θ2) decreases by 0.00056, which is less than the threshold of 10^(-3), at which point we can declare convergence.

Therefore, we’ve found a combination of parameters θ1 = 0.013 and θ2 = 0.013 that satisfies our objective.
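The run described above can be reproduced with a short Python sketch. It assumes the stopping rule is "stop once J decreases by less than the threshold between two consecutive iterations"; the function name `gradient_descent` and the `max_iterations` safety cap are my own additions.

```python
def J(theta1, theta2):
    """The bivariate quadratic used in this post."""
    return theta1 ** 2 + theta2 ** 2


def gradient_descent(theta1, theta2, alpha=0.2, threshold=1e-3, max_iterations=1000):
    """Return (theta1, theta2, history), where history[i] is J after i iterations."""
    history = [J(theta1, theta2)]
    for _ in range(max_iterations):
        # Simultaneous update: both temporaries use the old parameter values.
        temp1 = theta1 - alpha * 2 * theta1
        temp2 = theta2 - alpha * 2 * theta2
        theta1, theta2 = temp1, temp2
        history.append(J(theta1, theta2))
        # Declare convergence once the decrease in J falls below the threshold.
        if history[-2] - history[-1] < threshold:
            break
    return theta1, theta2, history


theta1, theta2, history = gradient_descent(0.75, 0.75)
print(f"iterations: {len(history) - 1}")
print(f"theta1={theta1:.3f}, theta2={theta2:.3f}")
print(f"last decrease in J: {history[-2] - history[-1]:.5f}")
```

With the settings from this post, the sketch should stop after the eighth iteration with θ1 and θ2 both at roughly 0.013 and a final decrease in J of about 0.00056, in line with the figures quoted above.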

Wrap-up

Okay, so this time we looked at the application of gradient descent to a multivariate function. Next time, we will finally look at the use of gradient descent in linear regression. Stay tuned!

Please post any feedback, questions, or requests for topics. I would also appreciate 👏 if you like the post, so others can find this article too.

Thanks!

Senior Machine Learning Engineer @Healx | Creator of github.com/MisaOgura/flashtorch | Published Scientist | Co-founder of @womendrivendev