Stratified K-Fold Cross-Validation on Grouped Datasets

This article explains how to use optimization to perform stratified K-Fold cross-validation on a grouped dataset.

João Paulo Figueira
Towards Data Science



Cross-validation is a common resampling technique to get more mileage from your dataset. The procedure involves taking repeated samples from the original dataset and fitting the desired model to each of them. Cross-validation is helpful for model selection, as it provides better generalization performance estimates than the holdout method. The resampling process ensures that we reap the benefits of the whole dataset's diversity, especially when its size is on the smaller end of the scale.

One of the most popular cross-validation methods works by evenly splitting the entire dataset into K groups or folds. It is aptly named "K-Fold Cross Validation."

K-Fold Cross-Validation

To use K-Fold cross-validation, we split the source dataset into K partitions. We use K-1 of them as the training set and the remaining one for validation. The process runs K times, rotating the validation partition on each cycle, and at the end we average the model performance metrics over the K runs. Figure 1 below provides a schema of how the process works.

Figure 1 — The image above depicts the process of five-fold cross-validation. When fitting a model, each row represents the split between train (blue) and validation datasets (orange). We typically estimate the final performance metrics through an average of the runs. (Image source: Author)

We can extend this process to the case of stratified datasets. In these datasets, we might want to keep the same proportion of classes in each fold so that the learning process is not skewed. Stratification can apply both to the classification targets and to any underlying categories in the training data itself. If we take a human sample, it is reasonable to expect that all folds should have the same demographic distribution (think histogram age classes). You can find support for stratified K-Fold cross-validation in the Scikit-Learn Python package.
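As a quick illustration of the stratified case, here is a minimal sketch using Scikit-Learn's StratifiedKFold on a synthetic, imbalanced dataset (the data below is invented for this example):

```python
# Minimal illustration of stratified K-Fold with Scikit-Learn.
# Synthetic data: 10 samples of class 0 and 5 of class 1 (a 2:1 ratio).
import numpy as np
from sklearn.model_selection import StratifiedKFold

X = np.arange(30).reshape(15, 2)
y = np.array([0] * 10 + [1] * 5)

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
for _, val_idx in skf.split(X, y):
    # Every validation fold keeps the dataset's 2:1 class ratio.
    print(np.bincount(y[val_idx]))  # → [2 1] on every fold
```

Because both class counts divide evenly by the number of splits, each fold reproduces the global class proportions exactly.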

This article is a follow-up to a previous one where I devised a means to perform a stratified partition of a grouped dataset into train and validation datasets using discrete optimization. I invite you to read it as an introduction to the material presented here, where I extend the concept to K-Fold cross-validation.

Grouped Datasets

A grouped dataset has an extra layer of structure, namely an objective grouping criterion. Each data group operates as a logical unit, and we cannot split them across the folds.

More concretely, let us imagine a predictive maintenance use case where the dataset consists of time series of telematics data acquired for each piece of equipment. When developing a model to predict impending failures, you may opt not to split each piece of equipment's time series among the folds but to keep each series whole within a single fold. This approach has the advantage of avoiding potential leakage issues but makes the splitting process a bit harder.

In the previously mentioned article, I dealt with a simple binary split. Here, the challenge is to generalize that process to a K-Fold division. Before proceeding, note that the Scikit-Learn Python package provides support for this feature through the StratifiedGroupKFold class. Still, according to the documentation, this class performs a greedy assignment. We can take it a step further using an optimization algorithm. As in the previous article, this is a typical case for discrete optimization. We must select the groups to ensure an even split among the folds and guarantee all classes' stratification. Concretely, for five-fold cross-validation, we must guarantee that each fold will have approximately twenty percent of the original data. Each fold must also have roughly the same stratification as the original dataset.

We can think of this problem as a generalization of the previous article's, which was effectively two-fold cross-validation. The first change we must make is a simple modification to the solution model. Instead of a binary array, where a true value indicates a validation item and a false value indicates a training item, we use an integer array containing each group's fold index. The second change lies in the search space representation.

Problem Model

The problem model for a K-way split of the dataset is very similar to the one I proposed for the previous article's problem. Here, instead of using a boolean array to represent the solution, I use an integer array with the index to which each group is assigned. Figure 2 below displays a typical problem model sample and a proposed solution. I drew this sample from a larger dataset with 500 groups.

Figure 2 — We model the problem as an array where each row contains group counts, and the columns reflect the class counts. The solution model is a simple integer array containing the fold indices. The table above is an illustrative sample from a problem with 500 groups. (Image source: Author)

The problem model corresponds to the first four columns, which hold the class counts. Each row represents a group, with counts broken down by class. Horizontal sums yield group counts, while column sums yield class counts. Note that the solution column contains each group's proposed fold index, between zero and four for a 5-Fold problem.

Solving this optimization problem requires finding a split that respects both the fold and the class proportions. We measure the solution fitness using a cost function.
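The two models are easy to sketch with NumPy; the array names below are mine, chosen for illustration rather than taken from the author's code:

```python
# Sketch of the problem and solution models described above.
import numpy as np

rng = np.random.default_rng(42)
n_groups, n_classes, n_folds = 500, 4, 5

# Problem model: one row per group, one column per class, cells hold item counts.
problem = rng.integers(0, 50, size=(n_groups, n_classes))

# Solution model: one fold index per group, between 0 and 4 for a 5-Fold split.
solution = rng.integers(0, n_folds, size=n_groups)
```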

The Cost Function

The cost function measures how far a solution is from the optimum, which is to have the same class proportions in each fold, guaranteeing that groups remain together and that folds have the required size. For a 5-Fold problem, each fold should contain 20% of the total data points. Figure 3 below shows the implementation of the cost function.

Figure 3 — The function above implements the cost calculator for a single solution. (Image source: Author)

The cost function iterates through all folds, computing the difference between each fold's proportion and the target (line 16). Then, it iterates through the fold's classes and calculates the difference between their proportions and the dataset's overall class proportions (line 20). The final cost is the summation of all these squared differences, and the perfect solution would have zero cost.
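Since Figure 3 ships as an image, here is a sketch of the same idea, not the author's exact code: the sum of the squared fold-proportion deviations and the squared class-proportion deviations.

```python
import numpy as np

def solution_cost(problem: np.ndarray, solution: np.ndarray, n_folds: int) -> float:
    """Sum of squared deviations from the target fold and class proportions.

    problem  -- (n_groups, n_classes) matrix of per-group class counts
    solution -- (n_groups,) array of fold indices
    """
    total = problem.sum()
    class_ratios = problem.sum(axis=0) / total  # dataset-wide class proportions
    target = 1.0 / n_folds                      # e.g. 0.2 for five folds
    cost = 0.0
    for fold in range(n_folds):
        fold_counts = problem[solution == fold].sum(axis=0)
        fold_total = fold_counts.sum()
        cost += (fold_total / total - target) ** 2
        fold_ratios = fold_counts / fold_total if fold_total else fold_counts
        cost += ((fold_ratios - class_ratios) ** 2).sum()
    return cost

# A perfectly balanced toy split has zero cost.
problem = np.tile([[10, 5]], (5, 1))       # five identical groups, two classes
perfect = np.arange(5)                     # one group per fold
print(solution_cost(problem, perfect, 5))  # → 0.0
```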

We can search for the optimal partition using a specialized data structure with the cost function.

The Search Space

Searching for an optimal solution to this problem involves iterating through a search space. Here, we implement the search space as a matrix with the problem's folds as columns and the groups as rows, directly corresponding to the problem matrix's rows. Each cell of the search space matrix contains either infinity, for the group's currently assigned fold, or the cost of moving that group to the corresponding fold.

Figure 4 — The image above shows a typical search space matrix. (Image source: Author)

The search space is straightforward to build. We start with the last generated solution and place the infinity value, for each group, in the column of its currently assigned fold. For every other fold, we calculate the solution cost as if the group belonged to that fold. Considering Figure 4 above, the value in the first row and first column results from calculating the cost as if the first group were assigned to the first fold.

Figure 5 below shows the simple code that generates the search space from the problem definition and the last known solution.

Figure 5 — The function above generates the search space by calculating the cost for all possible alternative solutions per group. Using the infinity value relates to how we will select the next solution. (Image source: Author)

As you can see, this is a very simple function that fills the whole search space matrix with the corresponding solution costs, except for the current solution, where it uses the infinity value (explained below).
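A sketch of that generator, using a cost function consistent with Figure 3's description (function names are mine, not the author's), could look like this:

```python
import numpy as np

def solution_cost(problem, solution, n_folds):
    # Squared deviations from the fold and class proportions (cf. Figure 3).
    total = problem.sum()
    class_ratios = problem.sum(axis=0) / total
    cost = 0.0
    for f in range(n_folds):
        counts = problem[solution == f].sum(axis=0)
        ft = counts.sum()
        cost += (ft / total - 1.0 / n_folds) ** 2
        ratios = counts / ft if ft else counts
        cost += ((ratios - class_ratios) ** 2).sum()
    return cost

def generate_search_space(problem, solution, n_folds):
    """Cost of every single-group move; the current assignment keeps infinity."""
    n_groups = solution.shape[0]
    space = np.full((n_groups, n_folds), np.inf)
    for g in range(n_groups):
        for f in range(n_folds):
            if f == solution[g]:
                continue  # infinity marks the current solution's entries
            candidate = solution.copy()
            candidate[g] = f
            space[g, f] = solution_cost(problem, candidate, n_folds)
    return space
```

Each cell thus answers the question "what would the solution cost if this group moved to this fold?", leaving the current assignments unselectable.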

Searching

We now turn our attention to the search process. Note that the search space structure lends itself to a simple search strategy — sorting. On each iteration of the search process, we look for the smallest entry in the search space matrix and use its coordinates (group and fold) as the next change to the current solution. We update the solution vector using the coordinates of that smallest entry. This is why the current solution's entries are marked with the infinity value — it ensures that the algorithm will not select them during the search. Figure 6 below shows the function that implements the main solution searching loop.

Figure 6 — We search for a solution using the above function. It starts by setting up a solution history and generating an initial solution using a greedy process. The search function generates a new space for each loop and determines the next move using a simple sorting procedure and the solution history. If the new solution cost is lower than the previous best, the function records the new solution as the incumbent (best so far) and uses its cost as the lowest mark. The search process ends when the cost is below a given value. (Image source: Author)

The function that searches for the best solution, depicted above, starts by generating an initial candidate. This process uses a greedy algorithm to satisfy the fold proportion requirements and uses this solution as the initial search step. Figure 7 below shows the code that generates the initial solution.

Figure 7 — There are three options to generate an initial solution, as shown above. The default option uses a greedy approach to satisfy the fold proportion restriction. The second option uses random assignment, while the last one merely assigns all groups to fold zero. (Image source: Author)
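Those three options can be sketched as follows; the function name, option names, and greedy heuristic (largest groups first, into the lightest fold) are my assumptions, not the author's exact implementation:

```python
import numpy as np

def initial_solution(problem, n_folds, method="greedy", seed=0):
    """Seed the search: 'greedy' (default), 'random', or all-zeros."""
    n_groups = problem.shape[0]
    if method == "random":
        return np.random.default_rng(seed).integers(0, n_folds, size=n_groups)
    if method == "zeros":
        return np.zeros(n_groups, dtype=int)
    # Greedy: assign the largest groups first, each to the currently lightest
    # fold, which approximately satisfies the fold proportion requirement.
    group_sizes = problem.sum(axis=1)
    solution = np.zeros(n_groups, dtype=int)
    loads = np.zeros(n_folds)
    for g in np.argsort(-group_sizes):
        f = int(np.argmin(loads))
        solution[g] = f
        loads[f] += group_sizes[g]
    return solution
```

The greedy seeding matters in practice: starting close to the fold-size constraint gives the optimizer far less work than starting from the all-zeros assignment.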

After generating a new search space corresponding to the last solution, the search algorithm selects the next move in the optimization process. The function depicted below in Figure 8 shows how this process works.

Figure 8 — We select the next move simply by collecting the sorted indices of the search space. The selection function then iterates through the move candidates, excluding the ones that generate visited solutions. (Image source: Author)

Note how we use the history set to weed out the already-generated solutions.
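That selection step can be sketched as below, with the history kept as a set of hashed solution vectors; the names are assumptions for illustration:

```python
import numpy as np

def select_move(space, solution, history):
    """Return the cheapest (group, fold) move not yet visited, or None."""
    for flat in np.argsort(space, axis=None):       # cheapest entries first
        g, f = np.unravel_index(flat, space.shape)
        if not np.isfinite(space[g, f]):
            return None  # only the infinity-marked current entries remain
        candidate = solution.copy()
        candidate[g] = f
        if candidate.tobytes() not in history:      # skip visited solutions
            return int(g), int(f)
    return None
```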

Termination Criteria

Finally, let us look at the search algorithm's termination criteria. We can control how far the algorithm goes in its search effort using two parameters: the minimum allowed cost and the maximum number of retries. As its name suggests, the minimum cost is a target value for the solution cost, below which the search stops.

Whenever the search algorithm loops and cannot produce a better solution, it increments a retry counter. The maximum number of retries parameter limits how many consecutive retries the algorithm will perform before failing and reporting back the last best solution, the incumbent.
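Putting the pieces together, here is an end-to-end sketch of the search with both termination criteria. All names and the greedy seeding are my assumptions, consistent with the descriptions above rather than copied from the author's code:

```python
import numpy as np

def cost(problem, solution, k):
    # Squared deviations from the fold and class proportions (cf. Figure 3).
    total = problem.sum()
    class_ratios = problem.sum(axis=0) / total
    c = 0.0
    for f in range(k):
        counts = problem[solution == f].sum(axis=0)
        ft = counts.sum()
        c += (ft / total - 1.0 / k) ** 2
        ratios = counts / ft if ft else counts
        c += ((ratios - class_ratios) ** 2).sum()
    return c

def search(problem, k, min_cost=1e-6, max_retries=20):
    """Search until the cost drops below min_cost or retries are exhausted."""
    n_groups = problem.shape[0]
    # Greedy initial solution: largest groups into the lightest fold.
    solution = np.zeros(n_groups, dtype=int)
    loads = np.zeros(k)
    for g in np.argsort(-problem.sum(axis=1)):
        f = int(np.argmin(loads))
        solution[g] = f
        loads[f] += problem[g].sum()

    history = {solution.tobytes()}
    incumbent, best_cost = solution.copy(), cost(problem, solution, k)
    retries = 0
    while best_cost > min_cost and retries < max_retries:
        # Search space: cost of each single-group move, inf on the current fold.
        space = np.full((n_groups, k), np.inf)
        for g in range(n_groups):
            for f in range(k):
                if f != solution[g]:
                    cand = solution.copy()
                    cand[g] = f
                    space[g, f] = cost(problem, cand, k)
        # Next move: the cheapest candidate not yet visited.
        for flat in np.argsort(space, axis=None):
            g, f = np.unravel_index(flat, space.shape)
            if not np.isfinite(space[g, f]):
                return incumbent, best_cost  # no unvisited moves remain
            cand = solution.copy()
            cand[g] = f
            if cand.tobytes() not in history:
                solution = cand
                history.add(cand.tobytes())
                break
        new_cost = cost(problem, solution, k)
        if new_cost < best_cost:   # record a new incumbent
            incumbent, best_cost = solution.copy(), new_cost
            retries = 0
        else:                      # uphill move: count it as a retry
            retries += 1
    return incumbent, best_cost
```

Because the cheapest unvisited move may still raise the cost, the loop can climb out of local minima; the incumbent protects the best result found along the way.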

Running

You can run this article's code from the companion GitHub repository, using either the Jupyter notebook or the standalone Python script. In the end, you will get a performance report showing how the optimizer performed in matching the fold and class proportions.

Conclusion

In this article, we have discussed an approach to solving the problem of splitting a stratified and grouped dataset into K folds. We have used a simple optimization algorithm that uses a cost function and an explicit search space. The cost function calculates the distance between any given solution state and the final desired state. The optimizer merely looks for the next configuration that reports the minimum cost value and is capable of hill climbing. The search terminates when the incumbent solution's cost drops below a preset threshold or when the number of retries exceeds a given maximum.

Resources

An Introduction to Statistical Learning (statlearning.com)

Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow, 2nd Edition [Book] (oreilly.com)

Machine Learning Engineering (mlebook.com)

joaofig/strat-group-split: Code to perform a stratified split of grouped datasets into train and validation sets using optimization (github.com)

João Paulo Figueira works as a Data Scientist at tb.lx by Daimler Trucks and Buses in Lisbon, Portugal.
