The importance of k-fold cross-validation for model prediction in machine learning

Jaime Dantas
Towards Data Science
8 min read · Nov 4, 2020


Image by Author — Thunder Bay, Canada

This article will discuss and analyze the importance of k-fold cross-validation for model prediction in machine learning using the least-squares algorithm for Empirical Risk Minimization (ERM).

We’ll use a polynomial curve-fitting problem to find the polynomial that best fits the sample dataset. We’ll also go step by step through an implementation of 10-fold cross-validation in MATLAB.

By the end of this post, you’ll know how to implement the k-fold cross-validation method and understand the benefits and drawbacks that come with it.

Overview

To better see the benefits of applying k-fold cross-validation in machine learning, we’ll also analyze some problems we may face when we select a model without any type of cross-validation.

We’ll begin by importing our dataset into MATLAB and plotting the data points to visualize them. Then, we’ll implement the 10-fold cross-validation algorithm and calculate the empirical square loss for ERM to find the best model for our data. Finally, we’ll repeat this process without any type of cross-validation and compare the results.

Dataset

We’ll use a dataset with just 100 data points; the reason for using such a small dataset is explained later on. This dataset was originally proposed by Dr. Ruth Urner in one of her assignments for a machine learning course. In the repository accompanying this post, you’ll find two TXT files: dataset1_inputs.txt and dataset1_outputs.txt.

These files contain the input and output vectors. After importing them through Home > Import Data in MATLAB, we can plot the data points as shown below.

Dataset

K-fold Cross-Validation

Cross-validation is commonly used in machine learning to select and evaluate models more reliably when there isn’t enough data for computationally cheaper methods, such as a 3-way split (train, validation and test) or a separate holdout dataset. This is why our dataset has only 100 data points. If you want to know more about the math behind this approach, I recommend reading this article.

In k-fold cross-validation, we first shuffle our dataset so that the order of the data points is random. Inputs and outputs are shuffled together, so each input stays paired with its output. This step ensures that any ordering in the original files does not bias the folds. Then, we split the dataset into k parts of equal size. In this analysis, we’ll use 10-fold cross-validation, so the first step is to shuffle and split our dataset into 10 folds.

Splitting the data into 10 folds
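Here is a minimal sketch of this shuffle-and-split step. It is written in Python rather than the article’s MATLAB, and the function name `shuffle_and_split` is hypothetical. The idea is to shuffle a list of indices once and then deal them out into k folds. Shuffling indices, rather than shuffling the inputs and outputs separately, keeps every input paired with its output.

```python
import random

# Sketch only: hypothetical helper, not the author's MATLAB implementation.
def shuffle_and_split(n_points, k, seed=None):
    """Shuffle the indices 0..n_points-1 once and deal them into k folds.

    When n_points is not divisible by k, fold sizes differ by at most one.
    """
    rng = random.Random(seed)
    indices = list(range(n_points))
    rng.shuffle(indices)  # one permutation shared by inputs and outputs
    # Fold j takes every k-th shuffled index, starting at position j.
    return [indices[j::k] for j in range(k)]

folds = shuffle_and_split(100, 10, seed=0)
print([len(fold) for fold in folds])  # [10, 10, 10, 10, 10, 10, 10, 10, 10, 10]
```

Each fold can then index both vectors. For example, `[x[i] for i in folds[0]]` and `[t[i] for i in folds[0]]` give the inputs and outputs of the first fold.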

Then, in each of the k iterations, we use one fold for testing (computing the empirical square loss) and the remaining 9 folds for training our model. Each new iteration uses a different fold for testing, which guarantees that every one of the k parts is used exactly once for testing.

Algorithm for the 10-fold cross-validation
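The rotation over folds can be sketched as a small Python generator. The name `kfold_splits` is hypothetical, and the sketch assumes the folds are already lists of indices. The toy example shows that each fold is held out exactly once while the others form the training set.

```python
# Sketch only: hypothetical helper illustrating the fold rotation.
def kfold_splits(folds):
    """Yield (train_indices, test_indices), holding out each fold once.

    In iteration j, fold j is the test set and the remaining k - 1
    folds are concatenated into the training set.
    """
    for j, test in enumerate(folds):
        train = [i for other, fold in enumerate(folds) if other != j for i in fold]
        yield train, test

# Toy example: 6 data points already split into 3 folds.
for train, test in kfold_splits([[0, 3], [1, 4], [2, 5]]):
    print("train:", train, "test:", test)
# train: [1, 4, 2, 5] test: [0, 3]
# train: [0, 3, 2, 5] test: [1, 4]
# train: [0, 3, 1, 4] test: [2, 5]
```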

In the end, we’ll have 10 values for the empirical square loss, one for each iteration. The final empirical square loss is the average of these values. With that, let’s move on to the implementation itself.
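Here is the same averaging written as a formula. S_j is the j-th test fold, and w^{(j)} holds the coefficients fitted on the other nine folds. L_{S_j} is the empirical square loss measured on S_j. This notation is ours, not the original’s. The cross-validated score is then:

```latex
\bar{L}_{\mathrm{CV}} = \frac{1}{k} \sum_{j=1}^{k} L_{S_j}\!\left(\mathbf{w}^{(j)}\right), \qquad k = 10
```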

Implementation

First of all, we need to randomly divide our dataset into 10 equal parts. To do this, I created two nested loops that iterate over the vectors t and x, splitting them into 10 equal parts. These parts are stored in the cell arrays split_t and split_x. I also created a secondary cell array to store the indexes selected for each part. Finally, I created a vector to store all the indexes that had already been assigned to a part.

In the inner loop, I generate a uniformly distributed random integer from 1 to 100 [1]. Then, I check whether this index is already on the list of visited indexes. If it is not, I add the corresponding data point to the split cell arrays and advance the inner loop; otherwise, I generate a new random index. The outer loop repeats this process 10 times, once per fold.

This may not be the most efficient way of shuffling and splitting the dataset. Draws that hit an already visited index are wasted, and this happens more often as the folds fill up. Still, it works and it guarantees a random split. To visualize the 10 folds we created, I plotted them in the figure below.

Dataset randomly split into 10 folds
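The index-drawing loops described above can be mirrored in Python as follows. This is a sketch, not the author’s MATLAB code. The name `rejection_split` is hypothetical, and indices are 1-based, as with `randi`. Python’s `random.randint(1, n)`, which is inclusive on both ends, stands in for `randi(n)`. The list of visited indexes becomes a set so that membership checks are cheap.

```python
import random

# Sketch only: hypothetical Python mirror of the two loops described above.
def rejection_split(n_points=100, k=10, seed=None):
    """Split 1-based indices into k folds by drawing random indices and
    rejecting any index that was already used."""
    rng = random.Random(seed)
    fold_size = n_points // k  # assumes n_points is divisible by k
    visited = set()            # plays the role of the "already split" vector
    split_idx = []             # plays the role of the index cell array
    for _ in range(k):                       # outer loop: one fold per pass
        fold = []
        while len(fold) < fold_size:         # inner loop: fill this fold
            idx = rng.randint(1, n_points)   # uniform integer in 1..n_points
            if idx not in visited:           # keep only unseen indices
                visited.add(idx)
                fold.append(idx)
        split_idx.append(fold)
    return split_idx

# With 1-based indices, fold j of the inputs is [x[i - 1] for i in split_idx[j]].
split_idx = rejection_split(seed=7)
print([len(fold) for fold in split_idx])
```

Rejected draws are simply discarded. As a result, the loop needs on average about n·H_n draws, where H_n is the n-th harmonic number. For n = 100 that is roughly 519 draws instead of 100. In MATLAB, `randperm(100)` returns a random permutation in a single call and avoids the rejections.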

Now, we need to implement the 10-fold cross-validation for the empirical square loss for ERM. Before going any further, I recommend reading this explanation about ERM so you understand some of the main concepts behind it. For the square loss, the ERM solution is given by the normal equations:

$$\mathbf{w} = (X^\top X)^{-1} X^\top \mathbf{t}$$

The vector w holds our polynomial coefficients, X is the design matrix and t is the vector of outputs.

To keep our calculations simple, we’ll solve for w manually in MATLAB by multiplying and inverting the matrices. Be aware that this is neither the most efficient nor the most numerically stable way of solving linear equations: for high polynomial orders, X^T X becomes badly conditioned, and MATLAB’s backslash operator (X \ t) is usually the better choice. We’ll analyze polynomials of order W = 1, 2, …, 30.
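Below is a minimal, dependency-free Python sketch of this ERM step. It assumes the design matrix uses the basis 1, x, x^2, ..., x^W, so that row i is [1, x_i, ..., x_i^W]. The helper names are hypothetical. Instead of forming the inverse explicitly, the sketch solves the normal equations (X^T X) w = X^T t by Gaussian elimination with partial pivoting. Even so, the normal equations themselves still become ill-conditioned at high orders.

```python
# Sketch only: hypothetical helpers; assumes the monomial basis 1, x, ..., x^order.
def design_matrix(xs, order):
    """Row i is [1, x_i, x_i**2, ..., x_i**order] (assumed polynomial basis)."""
    return [[x ** p for p in range(order + 1)] for x in xs]

def solve_linear(A, b):
    """Solve A w = b by Gaussian elimination with partial pivoting."""
    n = len(A)
    M = [row[:] + [bi] for row, bi in zip(A, b)]  # augmented matrix [A | b]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[pivot] = M[pivot], M[col]
        for r in range(col + 1, n):
            factor = M[r][col] / M[col][col]
            for c in range(col, n + 1):
                M[r][c] -= factor * M[col][c]
    w = [0.0] * n
    for r in range(n - 1, -1, -1):  # back substitution
        s = sum(M[r][c] * w[c] for c in range(r + 1, n))
        w[r] = (M[r][n] - s) / M[r][r]
    return w

def erm_fit(xs, ts, order):
    """Least-squares ERM: solve the normal equations (X^T X) w = X^T t."""
    X = design_matrix(xs, order)
    cols = range(order + 1)
    XtX = [[sum(row[a] * row[b] for row in X) for b in cols] for a in cols]
    Xtt = [sum(row[a] * t for row, t in zip(X, ts)) for a in cols]
    return solve_linear(XtX, Xtt)

# Noise-free data from 1 + 2x - 0.5x^2: the order-2 fit should recover it.
xs = [0, 1, 2, 3, 4]
ts = [1 + 2 * x - 0.5 * x ** 2 for x in xs]
w = erm_fit(xs, ts, 2)
print([round(c, 6) for c in w])  # [1.0, 2.0, -0.5]
```

`w[p]` is the coefficient of x^p. With exact data, the fitted coefficients match the generating polynomial up to floating-point error.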

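So, for 10-fold cross-validation, we execute ERM 10 times in a loop and store the test score of each execution in the cell array E. On a test fold S with m data points, the empirical square loss is

$$L_S(\mathbf{w}) = \frac{1}{m} \sum_{i=1}^{m} \left( \langle \mathbf{w}, \mathbf{x}_i \rangle - t_i \right)^2$$

where x_i is the i-th row of the design matrix X built from that fold and t_i is the corresponding output.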

In this process, we are using one fold for testing and the remaining 9 folds for training. Inside this loop, I also compute both the training and the testing design matrices.
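Putting the pieces together, here is an end-to-end Python sketch of this loop. It is not the author’s Cross_Validation.m, and several choices are assumptions:

- the helper names are hypothetical;
- the folds come from a single shuffle of indices;
- the loss is the mean squared error on each test fold;
- the data is a synthetic noisy quadratic, because the article’s dataset is not reproduced here.

Orders are kept at 6 or below, since this pure-Python normal-equation solver would likely be numerically unreliable near W = 30.

```python
import random

# Sketch only: hypothetical helpers and synthetic data, not Cross_Validation.m.
def fit_poly(xs, ts, order):
    """Least-squares ERM: solve (X^T X) w = X^T t for the basis 1, x, ..., x^order."""
    X = [[x ** p for p in range(order + 1)] for x in xs]
    n = order + 1
    # Augmented matrix [X^T X | X^T t].
    M = [[sum(r[a] * r[b] for r in X) for b in range(n)]
         + [sum(r[a] * t for r, t in zip(X, ts))] for a in range(n)]
    for col in range(n):  # Gaussian elimination with partial pivoting
        piv = max(range(col, n), key=lambda i: abs(M[i][col]))
        M[col], M[piv] = M[piv], M[col]
        for i in range(col + 1, n):
            f = M[i][col] / M[col][col]
            M[i] = [a - f * b for a, b in zip(M[i], M[col])]
    w = [0.0] * n
    for i in reversed(range(n)):  # back substitution
        w[i] = (M[i][n] - sum(M[i][c] * w[c] for c in range(i + 1, n))) / M[i][i]
    return w

def square_loss(w, xs, ts):
    """Empirical square loss: mean of (p(x_i) - t_i)^2 over the given points."""
    preds = [sum(c * x ** p for p, c in enumerate(w)) for x in xs]
    return sum((y - t) ** 2 for y, t in zip(preds, ts)) / len(ts)

def cv_losses(xs, ts, max_order, k=10, seed=0):
    """k-fold cross-validated square loss for each order 1..max_order."""
    idx = list(range(len(xs)))
    random.Random(seed).shuffle(idx)
    folds = [idx[j::k] for j in range(k)]
    scores = {}
    for order in range(1, max_order + 1):
        E = []  # one test-fold loss per iteration, like the cell array E
        for j, test in enumerate(folds):
            train = [i for m, fold in enumerate(folds) if m != j for i in fold]
            w = fit_poly([xs[i] for i in train], [ts[i] for i in train], order)
            E.append(square_loss(w, [xs[i] for i in test], [ts[i] for i in test]))
        scores[order] = sum(E) / k
    return scores

# Hypothetical data: 100 points of a noisy quadratic (not the article's dataset).
rng = random.Random(1)
xs = [i / 50 - 1 for i in range(100)]
ts = [x ** 2 + rng.gauss(0, 0.05) for x in xs]
scores = cv_losses(xs, ts, max_order=6)
best = min(scores, key=scores.get)
print("best order:", best)
```

The selected order is the one with the smallest averaged test-fold loss. This is the same criterion used to pick a polynomial order from the cross-validated curve.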

The output of our algorithm is shown in the figure below.

The output of the Cross_Validation.m script

So, the polynomial of order W = 6 is the best fit for this data. Now, let’s analyze the plot of the ERM for all polynomials in the figure below.

Empirical Square Loss with cross-validation in reduced scale

Note that I reduced the scale of the graph in the figure above to better show the trend. We can see that as the order of the polynomial increases, the empirical square loss initially decreases. However, when we increase the order to very large values, we start to see overfitting (W > 21 in this case). Also, note that from W = 12 to W = 15 the square loss increases slightly. This is most likely an artifact of the single random partition of the dataset used in the k-fold cross-validation. One way to reduce this effect is to perform k-fold cross-validation several times, each with a different random partition, and average the resulting square losses.
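Repeated k-fold cross-validation can be sketched in Python as below. The callback `fold_loss(train_idx, test_idx)` is a hypothetical hook. It should fit the model on the training indices and return its square loss on the test indices, for example by wrapping the ERM fit for a given polynomial order. The toy usage plugs in a constant predictor on synthetic data, only to show the calling convention.

```python
import random
from statistics import mean

# Sketch only: hypothetical helper for repeated k-fold cross-validation.
def repeated_kfold_loss(n_points, k, repeats, fold_loss, seed=0):
    """Average the k-fold CV loss over `repeats` independent random partitions.

    fold_loss(train_idx, test_idx) is a hypothetical callback that fits a
    model on train_idx and returns its square loss on test_idx.
    """
    rng = random.Random(seed)
    per_repeat = []
    for _ in range(repeats):
        idx = list(range(n_points))
        rng.shuffle(idx)                        # a fresh partition each repeat
        folds = [idx[j::k] for j in range(k)]
        losses = []
        for j, test in enumerate(folds):
            train = [i for m, fold in enumerate(folds) if m != j for i in fold]
            losses.append(fold_loss(train, test))
        per_repeat.append(mean(losses))         # one ordinary k-fold score
    return mean(per_repeat)                     # averaged over partitions

# Toy usage on hypothetical data with a constant (training-mean) predictor.
data_rng = random.Random(3)
ts = [data_rng.gauss(0, 1) for _ in range(100)]

def mean_predictor_loss(train, test):
    c = mean(ts[i] for i in train)  # "fit": predict the training mean
    return mean((ts[i] - c) ** 2 for i in test)

print(repeated_kfold_loss(100, 10, repeats=5, fold_loss=mean_predictor_loss))
```

The cost grows linearly with the number of repeats. Five repeats of 10-fold cross-validation need 50 fits per candidate model.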

The figure below shows the plot of the polynomial of order W = 6 against the dataset.

The polynomial of order 6 and dataset

Note that we didn’t overfit our data with the model we chose.

What if we had not used 10-fold cross-validation?

Image by Author

To answer this question, let’s see what outcome we would get if we had not used 10-fold cross-validation in our problem. To do so, we simply fit each polynomial with ERM on the entire dataset and compute its empirical square loss on that same data.

The output of the algorithm above is presented below.

The output of the ERM.m script

As we can see, this time the polynomial of order 21 has the smallest empirical square loss. Even if we analyze the empirical square loss curve, shown below, we would still end up choosing a model of very high order.

Empirical Square Loss in reduced scale
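A small Python sketch can check why the training-loss curve keeps pointing toward high orders. It uses synthetic noisy-quadratic data and hypothetical helper names, and it is not the author’s ERM.m. Without cross-validation, each polynomial is scored on the same points it was fitted to. Every order-W polynomial is also an order-(W + 1) polynomial with a zero leading coefficient, so in exact arithmetic this training loss can never increase with W.

```python
import random

# Sketch only: hypothetical helpers and synthetic data, not ERM.m.
def fit_poly(xs, ts, order):
    """Least-squares ERM: solve (X^T X) w = X^T t for the basis 1, x, ..., x^order."""
    X = [[x ** p for p in range(order + 1)] for x in xs]
    n = order + 1
    M = [[sum(r[a] * r[b] for r in X) for b in range(n)]
         + [sum(r[a] * t for r, t in zip(X, ts))] for a in range(n)]
    for col in range(n):  # Gaussian elimination with partial pivoting
        piv = max(range(col, n), key=lambda i: abs(M[i][col]))
        M[col], M[piv] = M[piv], M[col]
        for i in range(col + 1, n):
            f = M[i][col] / M[col][col]
            M[i] = [a - f * b for a, b in zip(M[i], M[col])]
    w = [0.0] * n
    for i in reversed(range(n)):  # back substitution
        w[i] = (M[i][n] - sum(M[i][c] * w[c] for c in range(i + 1, n))) / M[i][i]
    return w

def square_loss(w, xs, ts):
    """Empirical square loss: mean of (p(x_i) - t_i)^2 over the given points."""
    preds = [sum(c * x ** p for p, c in enumerate(w)) for x in xs]
    return sum((y - t) ** 2 for y, t in zip(preds, ts)) / len(ts)

# Hypothetical data: 100 points of a noisy quadratic (not the article's dataset).
rng = random.Random(1)
xs = [i / 50 - 1 for i in range(100)]
ts = [x ** 2 + rng.gauss(0, 0.05) for x in xs]

# No cross-validation: fit and score each order on the same 100 points.
train_loss = {W: square_loss(fit_poly(xs, ts, W), xs, ts) for W in range(1, 7)}
for W, loss in train_loss.items():
    print(f"W = {W}: training loss {loss:.6f}")
```

Picking the smallest training loss therefore tends to favor the largest order tried, even though the data here comes from a quadratic.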

Selecting a higher-order polynomial comes at a cost: not only does it overfit the data, it also increases the model’s complexity. Let’s take a look at the polynomial of order 21 below.

The polynomial of order 21 and dataset

We can see that although the polynomial of order W = 21 has the smallest empirical loss, it severely overfits the data.

Conclusion

We saw that cross-validation allowed us to choose a better model with a smaller order for our dataset (W = 6 compared with W = 21). On top of that, k-fold cross-validation avoided the overfitting problem we encountered without any type of cross-validation, a problem that is especially pronounced with small datasets.

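This improvement, however, comes at a computational cost: every candidate model has to be trained k times instead of once. In our case, evaluating polynomials of order 1 to 30 with 10-fold cross-validation takes 300 least-squares fits instead of 30.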

When we analyze the curves for the models with and without cross-validation, we can clearly see that 10-fold cross-validation was paramount in choosing the best model for this data.

We also went through the algorithm for 10-fold cross-validation, detailing every step needed to implement it in MATLAB.

About Me

I’m an M.A.Sc. student at York University and a Software Engineer at heart. Over the past decade, I’ve worked in several industries in areas such as software development, cloud computing and systems engineering. Currently, I’m doing research on cloud computing and distributed systems.

You can check out my work on my website.

Thanks for reading it!

References

[1] MathWorks. randi: Uniformly distributed pseudorandom integers. MATLAB documentation. URL: https://www.mathworks.com/help/matlab/ref/randi.html#d122e1072277

[2] Shai Shalev-Shwartz and Shai Ben-David. Understanding Machine Learning: From Theory to Algorithms. Cambridge University Press, 2014. DOI: 10.1017/CBO9781107298019. URL: https://www.cs.huji.ac.il/~shais/UnderstandingMachineLearning/understanding-machine-learning-theory-algorithms.pdf
