Linear Regression & Gradient Descent

Learn about the most fundamental machine learning algorithm

Vyacheslav Efimov
Towards Data Science


Linear regression is one of the most fundamental algorithms in machine learning. Understanding its internal workflow helps in grasping the main concepts behind many other algorithms in data science. Linear regression has a wide range of applications in which it is used to predict a continuous variable.

Before diving into the inner workings of linear regression, let us first understand what a regression problem is.

Introduction

Regression is a machine learning problem that aims to predict the value of a continuous variable given a feature vector, usually denoted as x = <x₁, x₂, x₃, …, xₙ>, where xᵢ represents the value of the i-th feature in the data. In order for a model to be able to make predictions, it has to be trained on a dataset containing mappings from feature vectors x to the corresponding values of a target variable y. The learning process depends on the type of algorithm used for a certain task.

A regression problem consists of finding a function f such that f(x) closely approximates y for all dataset objects

In the case of linear regression, the model learns a vector of weights w = <w₁, w₂, w₃, …, wₙ> and a bias parameter b that approximate a target value y as <w, x> + b = x₁ * w₁ + x₂ * w₂ + x₃ * w₃ + … + xₙ * wₙ + b in the best possible way for every dataset observation (x, y).

Formulation

When building a linear regression model, the ultimate goal is to find a vector of weights w and a bias term b that bring the predicted value ŷ as close as possible to the real target value y for all the inputs:

Regression equation: ŷ = <w, x> + b, where <w, x> represents the inner product between vectors w and x.
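
To make the notation concrete, here is a minimal NumPy sketch of computing a prediction as the inner product plus the bias (the feature values, weights and bias below are arbitrary placeholders):

    import numpy as np

    # Arbitrary example: 3 features, 3 weights and a bias term
    x = np.array([1.5, -2.0, 0.7])   # feature vector x
    w = np.array([0.4, 0.1, -1.2])   # weight vector w
    b = 0.5                          # bias term b

    # Prediction: y_hat = <w, x> + b
    y_hat = np.dot(w, x) + b
    print(y_hat)  # approximately 0.06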

To make things easier, the example we are going to look at uses a dataset with a single feature x. Therefore, x and w are one-dimensional vectors. For simplicity, let us get rid of the inner product notation and rewrite the equation above in the following way:

Regression equation for one predictor: ŷ = wx + b

Loss function

In order to train the algorithm, a loss function has to be chosen. The loss function measures how well or poorly the algorithm predicts a set of objects at a single training iteration. Based on its value, the algorithm adjusts the parameters of the model in the hope that the model will produce fewer errors in the future.

One of the most popular loss functions is Mean Squared Error (or simply MSE), which measures the average squared deviation between predicted and true values.

MSE formula: MSE = (1/n) * Σᵢ (yᵢ - ŷᵢ)², where yᵢ is a true value, ŷᵢ is a predicted value and n is the number of objects.
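
As a reference point, a minimal NumPy sketch of this computation (the arrays below are arbitrary example values):

    import numpy as np

    def mse(y_true, y_pred):
        # Average squared deviation between true and predicted values
        return np.mean((y_true - y_pred) ** 2)

    y_true = np.array([3.0, -0.5, 2.0, 7.0])
    y_pred = np.array([2.5,  0.0, 2.0, 8.0])
    print(mse(y_true, y_pred))  # 0.375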

Gradient descent

Gradient descent is an iterative algorithm that updates the weight vector in order to minimize a given loss function by moving towards a local minimum. Gradient descent uses the following formula on each iteration:

Gradient descent update formula: w' = w - α * ∇f(w)
  • w is the vector of model weights on the current iteration; the newly computed weights are assigned to w'. During the first iteration of the algorithm, the weights are usually initialized randomly, although other strategies exist as well.
  • alpha (α) is usually a small positive value, also known as the learning rate: a hyperparameter that controls how quickly the algorithm moves towards a local minimum.
  • the upside-down triangle (nabla, ∇) denotes the gradient: the vector of partial derivatives of the loss function. In the current example, the model has 2 parameters (the weight w and the bias b), so 2 partial derivatives need to be computed (f represents the loss function):
Gradient with respect to the parameters: ∇f = (∂f/∂w, ∂f/∂b)

The update formulas can be rewritten in the following way:

Update formulas for weight and bias: w' = w - α * ∂f/∂w, b' = b - α * ∂f/∂b

Right now the objective is to find the partial derivatives of f. Assuming that MSE is chosen as the loss function, let us compute it for a single observation (n = 1 in the MSE formula), so f = (y - ŷ)² = (y - wx - b)².

Partial derivative for variable w: ∂f/∂w = -2x(y - wx - b)
Partial derivative for variable b: ∂f/∂b = -2(y - wx - b)

The process of adjusting the model's weights based on a single object at a time is called stochastic gradient descent.
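
To make this concrete, here is a minimal sketch of a single stochastic gradient descent step for one observation (x, y), using the two partial derivatives above; the starting weights, learning rate and sample values are arbitrary:

    def sgd_step(w, b, x, y, alpha):
        # Prediction and error for a single observation
        y_hat = w * x + b
        error = y - y_hat
        # Partial derivatives of f = (y - wx - b)^2
        dw = -2 * x * error
        db = -2 * error
        # Gradient descent update
        return w - alpha * dw, b - alpha * db

    # Example usage with arbitrary values
    w, b = 0.0, 0.0
    w, b = sgd_step(w, b, x=2.0, y=5.0, alpha=0.01)
    print(w, b)  # 0.2, 0.1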

Batch Gradient Descent

In the section above, the model parameters were updated by calculating the MSE for a single object (n = 1). In fact, it is possible to perform a gradient descent step for several objects in a single iteration. This way of updating the weights is called batch gradient descent.

The formulas for updating the weights in this case can be obtained in a very similar manner to stochastic gradient descent in the previous section. The only difference is that the number of objects n has to be taken into consideration: the terms of all objects in a batch are summed and then divided by n, the batch size.

Update formulas for batch gradient descent: w' = w + (2α/n) * Σᵢ xᵢ(yᵢ - ŷᵢ), b' = b + (2α/n) * Σᵢ (yᵢ - ŷᵢ)
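
Putting the pieces together, here is a minimal NumPy sketch of batch gradient descent for a single-feature dataset; the synthetic data, the learning rate and the number of iterations are arbitrary choices:

    import numpy as np

    # Synthetic single-feature data roughly following y = 2x + 1
    rng = np.random.default_rng(0)
    x = rng.uniform(-5, 5, size=100)
    y = 2 * x + 1 + rng.normal(scale=0.5, size=100)

    w, b = 0.0, 0.0   # initial parameters
    alpha = 0.01      # learning rate

    for _ in range(1000):
        y_hat = w * x + b
        error = y - y_hat
        # Gradients averaged over the whole batch of n objects
        dw = -2 * np.mean(x * error)
        db = -2 * np.mean(error)
        w -= alpha * dw
        b -= alpha * db

    print(w, b)  # should end up close to 2 and 1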

Visualization

When dealing with a dataset consisting only of a single feature, the regression results can be easily visualized on a 2D-plot. The horizontal axis represents values of the feature while the vertical axis contains target values.

The quality of a linear regression model can be visually evaluated by how closely it fits the dataset points: the smaller the average distance between the dataset points and the line, the better the algorithm is.

Two regression lines built for the same dataset. The line on the left fits the data much more closely than the line on the right. Thus the regression model on the left is considered the better one.

If a dataset contains more features, visualization can be done by applying dimensionality reduction techniques like PCA or t-SNE to the features to represent them in a lower dimensionality. After that, the new features are plotted on 2D or 3D plots, as usual.
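
For the single-feature case, a minimal matplotlib sketch of such a plot could look like this (the synthetic data and the fitted values of w and b are arbitrary placeholders; substitute the parameters produced by your own training run):

    import numpy as np
    import matplotlib.pyplot as plt

    # Arbitrary synthetic data and a line y = w*x + b standing in for the fitted model
    rng = np.random.default_rng(0)
    x = rng.uniform(-5, 5, size=100)
    y = 2 * x + 1 + rng.normal(scale=0.5, size=100)
    w, b = 2.0, 1.0

    plt.scatter(x, y, s=10, label="dataset points")
    xs = np.linspace(x.min(), x.max(), 100)
    plt.plot(xs, w * xs + b, color="red", label="regression line")
    plt.xlabel("feature x")
    plt.ylabel("target y")
    plt.legend()
    plt.show()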

Analysis

Linear regression has a set of advantages:

  1. Training speed. Due to the simplicity of the algorithm, linear regression can be trained rapidly compared to more complex machine learning algorithms. Moreover, it can be fitted through the least squares method (LSM), which is also relatively fast and easy to understand (a short sketch of this closed-form solution follows this list).
  2. Interpretability. A linear regression equation built for several features can be easily interpreted in terms of feature importance: the higher the absolute value of a feature's coefficient, the more effect it has on the final prediction (assuming the features are on a comparable scale).
In a regression equation where the coefficient of feature x₃ has the largest absolute value, x₃ has the highest feature importance.
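
As a side note to the first point, here is a minimal NumPy sketch of the least squares solution mentioned there (the design matrix and target vector are arbitrary; a column of ones is appended so that the bias is estimated together with the weights):

    import numpy as np

    # Arbitrary dataset: 4 objects with 2 features each, plus a target vector
    X = np.array([[1.0, 2.0],
                  [2.0, 0.5],
                  [3.0, 1.5],
                  [4.0, 3.0]])
    y = np.array([5.0, 4.0, 8.0, 11.0])

    # Append a column of ones so the bias b is fitted together with the weights
    X_b = np.hstack([X, np.ones((X.shape[0], 1))])

    # Closed-form least squares solution of X_b @ params ≈ y
    params, *_ = np.linalg.lstsq(X_b, y, rcond=None)
    w, b = params[:-1], params[-1]
    print(w, b)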

On the other hand, it comes with several disadvantages:

  1. Data assumptions. Before fitting a linear regression model, it is important to check the type of dependency between the output and the input features. If it is linear, there should not be any issue with fitting it. Otherwise, the model is usually not able to fit the data well since the equation contains only linear terms. It is possible to add higher-degree terms to the equation to turn the algorithm into polynomial regression, for instance. However, in reality, without a lot of domain knowledge it is often difficult to correctly foresee the type of dependency. This is one of the reasons why linear regression might not adapt to the given data.
  2. Multicollinearity problem. Multicollinearity occurs when two or more predictors are highly correlated with each other. Imagine a situation in which a change in one variable influences another variable, while the trained model has no information about this relationship. When these changes are large, it is difficult for the model to remain stable during the inference phase on unseen data, which leads to overfitting. Furthermore, the final regression coefficients might also become unstable and unreliable to interpret because of this.
  3. Data normalisation. In order to use linear regression as a feature importance tool, the data has to be normalized or standardized. This makes sure that all of the final regression coefficients are on the same scale and can be correctly interpreted (a short sketch follows this list).
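
As a rough illustration of the last point, here is a small sketch that standardizes the features before fitting, assuming scikit-learn is available (the dataset is synthetic and the feature scales are chosen arbitrarily):

    import numpy as np
    from sklearn.linear_model import LinearRegression
    from sklearn.preprocessing import StandardScaler

    # Synthetic data: two features on very different scales
    rng = np.random.default_rng(42)
    X = np.column_stack([rng.uniform(0, 1, 200),       # small-scale feature
                         rng.uniform(0, 1000, 200)])   # large-scale feature
    y = 3 * X[:, 0] + 0.005 * X[:, 1] + rng.normal(scale=0.1, size=200)

    # Standardize the features so the fitted coefficients become comparable
    X_scaled = StandardScaler().fit_transform(X)
    model = LinearRegression().fit(X_scaled, y)
    print(model.coef_)  # coefficients on a comparable scale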

Conclusion

We have looked at linear regression, a simple but very popular algorithm in machine learning. Its core principles are used in more complex algorithms.

Though linear regression is rarely used in modern production systems, its simplicity allows it to be used as a standard baseline in regression problems, which is then compared to more sophisticated solutions.

The source code used in the article can be found here:

Resources

All images unless otherwise noted are by the author.
