Gradient descent is probably the best-known optimisation algorithm, and if you work in machine learning, you have almost certainly used it, directly or indirectly. You likely already know that it minimises a loss function by taking tiny steps in the direction of the negative gradient. But how exactly does it work? And rather than relying on the gradient methods bundled into a package, how can we implement our own and take a closer look at gradient descent?
Understanding Gradient
In this post, let’s break gradient descent down piece by piece and build a deeper understanding of it by implementing it ourselves. Let’s start with the simplest example:
Suppose we have the function f(x) = x^2 , where x ranges from -1 to 1. Given that x starts at a random point in this range, how do we find the minimum value of f(x) ?
Clearly, in this example the minimum is located at x = 0 , and we would like to
move x to the right if x < 0
move x to the left if x > 0
So how can we do that? You might notice that the gradient of the function satisfies
grad > 0 when x > 0
grad < 0 when x < 0
which is exactly the opposite of the direction we want x to move in! Gradient descent makes use of this and lets x move in the direction opposite to its gradient. In this scenario, no matter where x starts, it will move towards the minimum.
The next question is how far x should move at each step, and this leads to the learning rate, which appears as a parameter in numerous machine learning algorithms.
In fact, in the example above, tanα (the slope of the tangent line) is the gradient, and at each step we move x by tanα * learning_rate , where the learning rate is an adjustable parameter that controls the speed of descent.
First Implementation
With the above in mind, we can write our first implementation to find the minimum of the function x^2 .
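The original code is not shown here, but a minimal sketch might look like the following (names such as `grad_fn` and `lr`, and the number of iterations, are illustrative):

```python
# Gradient descent on f(x) = x^2 -- a minimal sketch.

def f(x):
    return x ** 2

def grad_fn(x):
    # Analytical gradient of f: d/dx x^2 = 2x
    return 2 * x

lr = 0.1          # learning rate
x = -1.0          # starting point in [-1, 1]
trajectory = [x]  # record every position so we can plot the path later

for _ in range(50):
    x -= lr * grad_fn(x)  # step in the direction opposite to the gradient
    trajectory.append(x)

print(x)  # very close to the minimum at 0
```

With lr = 0.1 each step multiplies x by 0.8, so the iterate shrinks geometrically towards zero.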
In each iteration, x -= lr * grad_fn(x) moves x towards the minimum, and we can also plot the trajectory of x :
Starting from -1, x descends gradually to 0. Notice that it moves faster at the beginning and slows down as it approaches the goal; this is because the absolute value of the gradient is larger at the start.
Optimising Parameters
Now let’s move on to an example of optimising parameters. Suppose we try to optimise the parameters of the function:
The objective is to minimise the loss (y - f(x))^2 , and the corresponding gradients with respect to parameters a and b would be:
Note that x and y here are treated as constants.
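The original formulas are not reproduced above; assuming, purely for illustration, a linear model f(x) = a·x + b (the post’s actual model may differ), the gradients of the squared loss would be:

```latex
L(a, b) = \bigl(y - f(x)\bigr)^2, \qquad f(x) = a x + b

\frac{\partial L}{\partial a} = -2x\,\bigl(y - f(x)\bigr), \qquad
\frac{\partial L}{\partial b} = -2\,\bigl(y - f(x)\bigr)
```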
Now let’s implement the parameter optimisation:
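The original implementation is not shown; here is a hedged sketch that updates the parameters on every single (x, y) pair, assuming the illustrative linear model f(x) = a·x + b with made-up true parameters a = 2, b = 1 and synthetic data:

```python
import random

# Per-sample gradient descent sketch, assuming a linear model
# f(x) = a * x + b with illustrative "true" parameters a = 2, b = 1.

random.seed(0)
xs = [random.uniform(-1, 1) for _ in range(100)]
data = [(x, 2 * x + 1) for x in xs]  # synthetic, noise-free data

a, b = 0.0, 0.0  # initial guesses
lr = 0.1
history = []     # record (a, b) after each update so we can plot them

for x, y in data:
    pred = a * x + b
    # Gradients of the squared loss (y - pred)^2 w.r.t. a and b
    grad_a = -2 * x * (y - pred)
    grad_b = -2 * (y - pred)
    a -= lr * grad_a  # update on every single data point
    b -= lr * grad_b
    history.append((a, b))

print(a, b)  # should approach a = 2, b = 1
```

Because every individual sample triggers an update, the recorded `history` traces out the jittery path discussed below.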
Here we update the parameters on each input (x, y) pair, and get the following result:
You can see that the updating process is volatile: parameter a first moves down (in the opposite direction) before heading towards its optimal value. This is because we update the parameters on every individual input, and each single data point can push the parameters in an arbitrary direction. Can we generate a smoother curve? The answer leads to batch gradient descent.
Batch Gradient Descent
In practical use cases, parameters are usually not updated on every single data point. Instead, batch updates are applied: in each iteration (epoch), the parameters are updated based on the average gradient over a batch of data points. In this case, our updating formula becomes:
The gradients over a batch are summed and averaged, and the average is used as the gradient for the update:
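Again as a sketch under the same illustrative assumptions (linear model f(x) = a·x + b, true parameters a = 2, b = 1), with the whole dataset treated as a single batch:

```python
import random

# Batch gradient descent sketch: one parameter update per epoch,
# using the average gradient over the whole batch.

random.seed(0)
xs = [random.uniform(-1, 1) for _ in range(100)]
data = [(x, 2 * x + 1) for x in xs]  # synthetic, noise-free data

a, b = 0.0, 0.0
lr = 0.1
n = len(data)

for epoch in range(500):  # more iterations needed: one update per epoch
    sum_grad_a = sum(-2 * x * (y - (a * x + b)) for x, y in data)
    sum_grad_b = sum(-2 * (y - (a * x + b)) for x, y in data)
    # Average the summed gradients over the batch before updating
    a -= lr * sum_grad_a / n
    b -= lr * sum_grad_b / n

print(a, b)  # converges smoothly towards a = 2, b = 1
```

Averaging over the batch cancels out the conflicting pulls of individual samples, which is why the resulting curve is smoother.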
This requires a higher number of iterations, since each batch contributes only one update (here we put all data points into a single batch; one could also try different batch sizes), and this time the update process looks like:
Way smoother and more stable, but the trade-off is that more computation is required.
Conclusions
I hope that by now you’ve acquired a slightly better understanding of vanilla gradient descent. Traditional gradient descent does not guarantee optimality; in fact, it can easily get stuck in a local minimum when the objective function has multiple basins, since the parameters only move slightly along the gradient at each step and no stochasticity is involved. If you are interested, please check out the full implementation here.
Next up, I will introduce SGD & Momentum, which add some variations to vanilla gradient descent and solve some of its problems.