Wrapping Your Head Around Gradient Descent (with comics!)

Why do we need another article on gradient descent?

Johnny Burns
Towards Data Science

--

Because math is hard, and some of us need a slower and more visual approach.

The tl;dr — There is no tl;dr. We’re predicting the future here. Describing the details without making you rage quit takes a few words.

The only prerequisite to understanding this post is to know what a linear function is. We see linear functions stated in many different ways, but today I’m going to stick with “ y = mx + b ” since it’s probably the most widely used.

If I asked “What is the variable of the function y = mx + b?”, you’d probably say “x”.

Normally, you’d be right, but the first “gotcha” about gradient descent is that it’s backwards.

Instead of “given this linear function, find these data points”, we’re saying “given these data points, find the linear function”.

Let that sink in, because it’s super weird.

Think of x and y more like a bunch of constants. We need to find values for “m” and “b” that fit those points. Then, we can use the line to predict new points.

If palm readers have taught us anything, predicting the future is not an exact science. Typically, no line fits perfectly through the points, so we want to figure out the line which comes closest to hitting the points.

What is a “close” fit?

This might seem like a stupid question. It seems like you should be able to measure how “far off the line” each point is, add those distances up, and divide by the number of points to figure out (on average) how close the line is to each point.

However, consider the following two lines, and ask yourself which one better represents the two data points:

The first line skews toward the bottom. The second line splits directly between the data points. The second line feels like a more accurate representation of the data, but the average error is the same:

For this reason and others, we usually measure the “closeness” of a datapoint by taking the distance from the line squared. This also ensures that there is only one “optimal” placement.

By using the average squared error, we can see that putting the line in the middle is better than putting the line at the top.
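If you want to see the numbers behind that, here’s a quick sketch. I’m using two made-up points with heights 7 and 8 (the same heights as the houses you’ll meet in a minute, so this is an assumption, not data from the figures above) and comparing a line that sits right on the bottom point against a line that splits the difference:

```python
# Two made-up data points (heights only): one line hugs the bottom point, one splits the middle.
points = [7, 8]

def avg_abs_error(line, pts):
    return sum(abs(p - line) for p in pts) / len(pts)

def avg_sq_error(line, pts):
    return sum((p - line) ** 2 for p in pts) / len(pts)

for line in [7.0, 7.5]:
    print(line, avg_abs_error(line, points), avg_sq_error(line, points))

# 7.0 0.5 0.5    <- same average error...
# 7.5 0.5 0.25   <- ...but the middle line has half the average SQUARED error
```

Both lines miss by 0.5 on average, but squaring the errors rewards the line in the middle.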

Great, how do I find the place with the smallest average squared error?

There are several methods for this. The simplest method is called ordinary least squares, but it doesn’t perform well on complex problems. To solve complex problems, we often use a method called “gradient descent”. Much of machine learning builds on this concept.

To start, let’s assume there is no “mx” in the equation, and just try to find a good value for “b”. It’s easier because you only have one variable to solve for.

y = b

This might sound absurd. It means “m” is locked at zero (a horizontal line). No matter what input you get, you always have to guess the same output.

Photo from Oleg Magni on pxhere

Bear with me a second. It’s gonna get crazy real fast.

For simplicity, let’s say we have just two data points. These data points represent recent home sales. We know the square footage of each house, and how much it sold for.

I’m going to call these data points “Mr. East’s House” and “Mr. West’s House” (because one is closer to the east side of the graph, and the other is closer to the west).

We’re trying to find a line y = b, that minimizes the average squared error of Mr. East’s House and Mr. West’s House.

I know what you’re thinking. “Oh… just take the average sales price”.

b = 7.5

You’re right. That method works, and it minimizes the squared error. However, this is not gradient descent. This cute little “average” trick isn’t going to hold up to more complex problems, so let’s instead use gradient descent to find the value.
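If you don’t trust me on the average, here’s a tiny brute-force sketch (assuming sale prices of 7 and 8, which is my read of the figures) that sweeps candidate “b” values and keeps the one with the smallest average squared error:

```python
# Brute-force check: which constant line b minimizes the average squared error?
prices = [7, 8]  # assumed sale prices for the two houses

def avg_sq_error(b, ys):
    return sum((y - b) ** 2 for y in ys) / len(ys)

candidates = [round(0.01 * i, 2) for i in range(500, 1001)]  # b from 5.00 to 10.00
best_b = min(candidates, key=lambda b: avg_sq_error(b, prices))
print(best_b, avg_sq_error(best_b, prices))  # 7.5 0.25
```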

In gradient descent, we start off by placing the line in a random spot. I’m going to place it here:

We want the line to be as close as possible to all houses. So we take a survey of the neighborhood. We knock on each of the doors and ask “which direction should we move the line?”. I’m going to call this the Move Survey:

Move Survey #1

*knock knock knock*

Left Photo by Brett Sayles on pexels. Right photo from pixnio

After the move survey, the votes are unanimous, so we move the line up. Let’s say we move it to 7.1.

Now we take another movement survey:

Move Survey #2

*knock knock knock*

Now it’s getting interesting. Mr. West wants to move the line DOWN, and Mr. East wants to move the line UP. Imagine these neighbors in a tug-of-war (each wants the line closer to their property).

Photo from pxhere

So how do we decide where to move the line?

Put yourself in the shoes of a homeowner and think about this:

  • If your property is exactly on the line, you have no squared error.
  • Now, if we pull the line 1 space away from your property, your squared error is 1 (1²).
  • If we pull the line another space away from your property, your squared error is 4 (2²).
  • If we pull the line another space away from your property, your squared error is 9 (3²).

As shown here, each move away from your house is more expensive than the last.

Errors are very cheap at close distance, and they get very expensive far away.

Looking at the two houses:

  • Mr. West is close to the line where errors are cheap.
  • Mr. East is far from the line, where errors are expensive.

When you think about the tradeoff, it’s worth adding some cheap errors to Mr. West, so we can subtract expensive errors from Mr. East.

Mr. East is very sensitive to changes. By that, I mean that moving the line away from Mr. East costs us a lot, and moving the line closer to Mr. East saves us a lot.

In contrast, Mr. West is not as sensitive to changes. We don’t get as much benefit by moving the line toward Mr. West. It also costs less to move the line away from Mr West.

Our survey did not take “sensitivity” into account. We should modify the survey. We need to ask not only “what direction should we move the line?” but also “how sensitive are you to line movement?”

We write the results down on our Move Survey:

Once we’ve surveyed all residents, we divide the total sensitivity by the number of residents to figure out an “average sensitivity”:

The positive average sensitivity here means that if we move the line UP, the benefits outweigh the drawbacks. The average sensitivity (0.4) tells us how strong the UP vs DOWN tradeoff is.

We take a small percentage of that number (say, 25%). We call this percentage the “learning rate”.

0.4 * 0.25 = 0.1

Then we move the line by that much.
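In code, one round of this looks roughly like the sketch below. The prices of 7 and 8 are my read of the figures, and the “sensitivity” is just the signed distance from the line (positive means the resident wants the line moved UP):

```python
prices = [7, 8]       # Mr. West's and Mr. East's sale prices (inferred from the figures)
b = 7.1               # where the line currently sits
learning_rate = 0.25  # the small percentage we take of the average sensitivity

# Move Survey: each resident reports a signed sensitivity (distance from the line).
sensitivities = [price - b for price in prices]            # roughly [-0.1, 0.9]
avg_sensitivity = sum(sensitivities) / len(sensitivities)  # roughly 0.4

# Move the line by the learning rate times the average sensitivity.
b = b + learning_rate * avg_sensitivity
print(round(b, 3))  # 7.2
```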

Going back to the tug-of-war analogy, you can imagine that Mr. East and Mr. West are pulling in opposite directions, but Mr. East is far stronger than Mr. West (because he’s more sensitive).

Photo from user falco on pixabay

With the line now at 7.2, we do another survey, again asking:

  • what direction should we move the line?
  • How sensitive are you to line movement?

We can see that Mr. West is more sensitive than before, and Mr. East is less sensitive than before.

Mr. East’s sensitivity still outweighs Mr. West, but not by as much as before.

Let’s take 25% of this number and update the line:

0.3 * 0.25 = 0.075

The line is now at 7.275. The move was even smaller than before.

If we continue to do more rounds of the survey, you will see that Mr. West gets more and more sensitive, while Mr. East gets less and less sensitive. The moves also get smaller and smaller.

As the rounds go on, the tug-of-war becomes more evenly matched. Mr. West wants the line DOWN almost as much as Mr. East wants it UP. At this point, the “average sensitivity” is very small (the two sensitivities are effectively cancelling each other out).

As the average sensitivity approaches zero (equilibrium), the line stops moving, since the line update is based on the average sensitivity.

I’m going to show you 3 different animations that all display the gradient descent as the line moves from our first survey (7.1) to equilibrium (7.5). These 3 views should help you to wrap your mind around the process.

View #1: Line Movement View

The following animation shows the line move after each survey. You can see the moves get smaller and smaller as the line approaches 7.5 (the minimum error point).

View #2: Individual Error View

This animation shows the errors of Mr. East and Mr. West. You can see here that we continue to trade expensive errors for cheaper errors until neither side is cheaper.

When you think about it, this makes sense. The only time you want to move the line is if you gain more than you give up. When the average sensitivity is zero, you get no benefit from moving the line in either direction (since both sides are equally sensitive).

View #3: Gradient Descent View

You’ll see this last type of graph referenced a lot when talking about gradient descent. This one can be difficult to wrap your head around.

The graph has the same shape as the previous graph, but don’t be deceived, it is very different. In the previous graph, the X axis was the error itself (each house’s error was a separate point on the curve), and the Y axis was the squared error.

In this new graph:

  • The X axis shows different values for ‘b’.
  • The Y axis is the average squared error when ‘b’ is set to that value.

We see here that the ideal ‘b’ value is 7.5, which produces a 0.25 average squared error.

Our task in gradient descent is finding that point at the bottom of this curve (minimum squared error).

When we started (at 7.1) and continued to do Move Surveys, we approached the bottom after several rounds. Watch the animation here:

Before we began our gradient descent, we knew the curve would be shaped this way (it always is), but we didn’t know where the minimum value was.

After the first Move Survey, we discovered that a guess of 7.1 results in an average squared error of 0.41. At first, it might seem like this is all we know:

However, we actually know more. We know the average sensitivity is 0.4, which tells us two things:

  1. Residents want the line moved UP (so we know we’re on the LEFT side of the curve).
  2. The average sensitivity is high, so we’re pretty far from center.

I’m going to mark a blue line showing what we know about where we are.

Then, we do a second move survey, and we discover that a guess of 7.2 results in an average squared error of 0.34.

Because of the average sensitivity (0.3), we know that we’re still on the LEFT side of the graph, but we’re getting closer to the center. I’m going to draw another line to represent this, but the line will not be as steep because we’re closer to the bottom.

As we continue to do Move Surveys, we make smaller steps, getting closer and closer to the bottom of the curve.

The “steepness” of the blue line represents the average sensitivity. When the ‘b’ value is far away from the minimum, the slope is very steep.

As the ‘b’ value approaches the minimum, the average sensitivity (aka, the slope), approaches zero.

“Gradient” is just another term for “slope”. As we move toward the bottom of the graph, the gradient gets smaller and smaller. This is why we call this method “gradient descent”.

Whether you have two data points (as shown here) or 1,000 data points, the logic for gradient descent is the same:

  • Survey each point, asking “which direction do you want me to move?” and “how sensitive are you?”
  • Take the average of all the sensitivities (gradient).
  • Multiply that by the “learning rate”.
  • Update the line.
  • Repeat until we find equilibrium.

In actuality, you’ll never quite hit equilibrium, since the update gets smaller and smaller as it approaches the bottom, but after enough rounds, the difference will be insignificantly small.
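Here’s the whole recipe as a short sketch, for the “b”-only case (again assuming prices of 7 and 8, and stopping once the update becomes negligibly small rather than waiting for a perfect equilibrium):

```python
prices = [7, 8]  # assumed sale prices
b = 7.1
learning_rate = 0.25

for round_number in range(1, 101):
    # Survey each point: its signed distance from the line is its sensitivity.
    avg_sensitivity = sum(price - b for price in prices) / len(prices)
    # Scale by the learning rate and update the line.
    update = learning_rate * avg_sensitivity
    b += update
    # Repeat until the moves are insignificantly small.
    if abs(update) < 1e-6:
        break

print(round(b, 4))  # ~7.5
```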

Okay, so we can solve for “b”. What about “m”?

Now that we’ve discovered how to find the best “b” value, let’s ignore “b” for a minute and think about how to find the best “m” value.

y = mx

This means that “b” is effectively stuck at a constant of zero. We can change the slope of the function using the “m” value, but the line will always pass through (0,0).

Let’s take another look at Mr. East and Mr. West, and take a guess at where the line might be:

At m=1, we see that Mr. West is 2 units above the line, and Mr. East is 2 units below the line. Average squared error is 4.

Intuitively, it seems like we’re in equilibrium and cannot improve on this. However, that is not the case.

The following animation shows what happens if we change the “m” value from 1 to 0.9

Notice how we got 1 full space closer to Mr. East, but we only got 1/2 space further from Mr. West. The average squared error is reduced to 3.625.

This makes sense when you think about the equation y = mx.

The “m” value is multiplied by “x”. Because Mr. East (x=10) is twice as far east as Mr. West (x=5), Mr. East is twice as affected by changes in “m”.

In general, the further East you are, the more sensitive you are to slope changes.

Since we get twice the benefit moving toward Mr. East, why not just move the line all the way to Mr. East? We can move a full 2 units closer to Mr. East, and only move 1 unit further from Mr. West.

Since we’re now 3 spaces away from Mr. West, the new average squared error is 4.5. This is worse than when we started… what gives?

Remember: As we move the line away from Mr. West, each error is more expensive than the last (due to squared error). At some point, Mr. West is so sensitive to movement that it’s not worth making the tradeoff, even though you can get twice as close to Mr. East by moving the line.

While the line above had the smallest RAW error, it does NOT have the smallest SQUARED error.
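Here are the numbers for the three slopes we just tried, assuming Mr. West sits at (5, 7) and Mr. East at (10, 8), which is my read of the figures:

```python
houses = [(5, 7), (10, 8)]  # (x, price) for Mr. West and Mr. East, inferred from the figures

for m in [1.0, 0.9, 0.8]:
    errors = [price - m * x for x, price in houses]
    avg_raw = sum(abs(e) for e in errors) / len(errors)
    avg_sq = sum(e ** 2 for e in errors) / len(errors)
    print(m, avg_raw, avg_sq)

# m=1.0 -> raw 2.00, squared 4.000
# m=0.9 -> raw 1.75, squared 3.625
# m=0.8 -> raw 1.50, squared 4.500  <- smallest RAW error, but the WORST squared error of the three
```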

So Mr. East and Mr. West are both sensitive, but for different reasons. It turns out that a point’s sensitivity to slope changes is determined by two factors:

  • How far away from the line is the point?
  • How far east is the point?

We need to take both of these into account when we calculate the sensitivity.

sensitivity = {distance from line} * {x value}
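As a one-line function, that looks like the sketch below (the sign convention is an assumption of mine: positive means the resident wants the slope tilted UP):

```python
def m_sensitivity(x, price, m):
    # Signed distance from the line (price - m*x), scaled by how far east the point is (x).
    return (price - m * x) * x
```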

Let’s put the line back at (m = 0.9)

For the first time in this post, you probably can’t tell which direction to move the line by simply looking at it.

Let’s take a move survey, calculate the sensitivity, and find out.

We can see here that even though Mr. West has half the “x” value, he is more sensitive due to his distance from the line.

We know the slope needs to move UP, but adding 1.25 is much too far and would place the line above both points. That’s why we only move it by a percentage of that amount (the “learning rate” I described previously). I’ll use a learning rate of 1% here.

1.25 * 0.01 = 0.0125

New “m” = 0.9 + 0.0125

The new line is at 0.9125. After another move survey, we can see that the squared error has decreased.

Multiply by the learning rate to get the update value:

0.46875 * 0.01 = 0.0046875

New “m” value = 0.9125 + 0.0046875

After the update, the line is at 0.917. If we continue rounds of the move survey, the line will get closer and closer to its equilibrium (0.92).
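Running that “m”-only descent as a sketch (same assumed points, same 1% learning rate) settles right around 0.92:

```python
houses = [(5, 7), (10, 8)]  # (x, price), inferred from the figures
m = 0.9
learning_rate = 0.01

for _ in range(1000):
    # Each house's sensitivity: signed distance from the line times its x value.
    avg_sensitivity = sum((price - m * x) * x for x, price in houses) / len(houses)
    m += learning_rate * avg_sensitivity

print(round(m, 4))  # ~0.92
```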

Here is what the move survey looks like at 0.92:

This gradient descent process looks very similar to what we saw earlier with the gradient descent for “b”.

Gradient Descent View

Let’s look at the gradient descent view of slope changes:

We see here that the squared error is minimized with an “m” value of about 0.92.

One thing to note about the “m” curve is that it’s much steeper than the curve we saw for “b”. This is because the “m” value is multiplied by each of our “x” values, so even a small update to “m” can result in a large change (for better or worse). This is why it’s important to keep the learning rate small.
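You can see why the small learning rate matters by trying the “b” learning rate (25%) on “m”. A quick sketch with the same assumed points: 1% settles at the minimum, while 25% overshoots by more and more each round and blows up:

```python
houses = [(5, 7), (10, 8)]  # (x, price), inferred from the figures

def descend_m(learning_rate, rounds=20, m=0.9):
    for _ in range(rounds):
        avg_sensitivity = sum((price - m * x) * x for x, price in houses) / len(houses)
        m += learning_rate * avg_sensitivity
    return m

print(descend_m(0.01))  # ~0.92 (small learning rate settles at the minimum)
print(descend_m(0.25))  # astronomically large (each update overshoots further than the last)
```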

Putting it all together

First, we figured out how to solve for “b” when “m” is held constant. Then, we figured out how to solve for “m” when “b” is held constant. How can we move both “m” and “b” to find the best overall fit?

When both “m” and “b” are moveable:

  1. Pretend that “m” is constant. Do a Move Survey asking all residents which direction they want the “b” value moved.
  2. Pretend that “b” is constant. Do a Move Survey asking all residents which direction they want the “m” value moved.
  3. Update the “b” value based on the average sensitivity from the “b” Move Survey.
  4. Update the “m” value based on the average sensitivity from the “m” Move Survey.
  5. There is only a single place where the sensitivity for both “m” and “b” are zero. The line will converge on that spot.

A Full Example

Let’s run through what happens when we start with a horizontal line sitting at “y = 0x + 7.5”, and do full gradient descent.

Round 1:

  1. During the “b” Move Survey, we discover Mr. East and Mr. West are equally sensitive to changes in “b” (they’re both 0.5 away). The average sensitivity for “b” is zero.
  2. During the “m” Move Survey, we discover that while they’re both the same distance away, Mr. East is twice as sensitive to “m” changes because he has twice the “x” value.
  3. There is no update to the “b” value, since there is no average “b” sensitivity.
  4. We update the “m” value by 1% of the average “m” sensitivity.

The line is now “y = 0.0125x + 7.5”

Round 2:

  1. This time, Mr. West is further away from the line than Mr. East. This means he is more sensitive to changes in “b”. Mr. West is stronger and wants to pull the “b” value DOWN.
  2. After the “m” Move Survey, we discover that Mr. East is still more sensitive to “m” changes and wants to pull the “m” value UP.
  3. We move the “b” value DOWN by 25% of the average “b” sensitivity.
  4. We move the “m” value UP by 1% of the average “m” sensitivity.

As we continue to move “m” and “b” toward equilibrium, we approach the only place where both the “m” and “b” sensitivities are zero (the minimum).
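Here’s the full descent as a sketch, using the same assumed points, the starting line from this example, and the same learning rates as before (25% for “b”, 1% for “m”). After enough rounds it settles right next to y = 0.2x + 6:

```python
houses = [(5, 7), (10, 8)]  # (x, price), inferred from the figures
m, b = 0.0, 7.5             # the horizontal starting line from the example
lr_b, lr_m = 0.25, 0.01     # learning rates used earlier in the post

for _ in range(5000):
    # "b" Move Survey: signed distance from the line (pretending "m" is constant).
    b_sensitivity = sum(price - (m * x + b) for x, price in houses) / len(houses)
    # "m" Move Survey: signed distance times the x value (pretending "b" is constant).
    m_sensitivity = sum((price - (m * x + b)) * x for x, price in houses) / len(houses)
    # Update both values based on their own surveys.
    b += lr_b * b_sensitivity
    m += lr_m * m_sensitivity

print(round(m, 3), round(b, 3))  # ~0.2 6.0
```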

Let’s look at two views of this to wrap things up.

Line Movement View

The following animation shows repeated rounds of gradient descent, updating both “b” and “m”. Watch how the “b” value drops and the “m” value increases.

Gradient Descent View

Let’s look at the “gradient descent” view of what’s happening. Until now, we’ve shown two “gradient descent” graphs.

  1. “b” value vs. average squared error.
  2. “m” value vs. average squared error.

To show both “m” AND “b” vs. average squared error, we need a third axis on the graph.

The two horizontal axes will represent values for “m” and “b”. The vertical axis represents the squared error at those points.

3-axis gradient descent graphs typically look something like this. This graph is not based on our data, but helps you wrap your head around the concept.

This graph helps show that there is only one place where both “m” and “b” are in equilibrium (the bottom).

Here is the gradient descent graph based on our actual data. It is a little bit harder to look at, but it has the same “downhill-to-minimum” property:

If you inspect this graph closely, you can see that the minimum point has an “m” value of 0.2, and a “b” value of 6.

The line “ y = 0.2x + 6 “ is in fact the ideal function and results in zero error (it crosses through both points).

This graph looks different than the previous graph for a couple reasons:

  • Changes in the “m” side are much more sensitive than changes in the “b” side, giving the graph this distinct “taco” shape.
  • The error-minimizing “m” value changes a little bit for each value of “b”. This is why you see the shallow channel that runs in both the “m” and “b” directions.

About the Author

I’m Johnny Burns, founder of FlyteHub.org, a repository of free open-source workflows to perform Machine Learning with no coding. I believe that collaborating on AI will lead to better products.

If you’re interested in seeing how the math backs up our “sensitivity” formulas, I will do a follow-up post to explain.
