
The go-to data science tools for stratification are great at handling supervised classification datasets. However, they aren’t designed to handle numeric, unsupervised, and multi-dimensional datasets out-of-the-box. This is where AIQC excels.
Stratification Leads to Generalization
In previous blog posts, we extensively covered the importance of evaluating models with validation and test splits/folds:
The purpose of predictive analytics is to train a model that can be applied elsewhere – to be able to make predictions about external samples that the model has never seen before. In order to achieve this, the data distribution that the model is trained on must be representative of the broader population.
In this post, we’ll cover how to stratify different types of data and the pitfalls to watch out for along the way.
Categorical Stratification
Let’s have a go at stratifying the Iris dataset. First, we import the data:
from sklearn import datasets
iris = datasets.load_iris()
features, labels = iris['data'], iris['target']
Then we split the data into train, validation, & test splits using sklearn’s train_test_split function. Note the use of the stratify argument.
from sklearn.model_selection import train_test_split
train_features, eval_features, train_labels, eval_labels = train_test_split(
    features, labels, test_size=0.32, stratify=labels
)
# `eval_*` is further divided into validation & test.
val_features, test_features, val_labels, test_labels = train_test_split(
    eval_features, eval_labels, test_size=0.32, stratify=eval_labels
)
With a bit of wrangling, we can plot the distribution of our new splits.
import numpy as np
_, train_counts = np.unique(train_labels, return_counts=True)
_, val_counts = np.unique(val_labels, return_counts=True)
_, test_counts = np.unique(test_labels, return_counts=True)
from itertools import chain
counts = list(chain(train_counts,val_counts,test_counts))
labels = ['Setosa', 'Versicolor', 'Virginica']
labels = labels + labels + labels
splits = ['Train', 'Train', 'Train', 'Validation','Validation','Validation','Test','Test','Test']
import pandas as pd
df = pd.DataFrame()
df['Count'], df['Label'], df['Split'] = counts, labels, splits
import plotly.express as px
px.histogram(
    df, x='Split', y='Count', color='Label',
    barmode='group', height=400, title='Stratification'
).update_layout(yaxis_title='Count')

Notice that as we divide the data further and further, the distribution struggles to stay even. When there is a minority class in a dataset, it’s easy to run out of samples from this underrepresented class in later subsets. This can easily result in a split that has zero representation from a minority class, especially in scenarios like 10-fold cross-validation. Not only is this underrepresentation a source of bias, but it can also cause bugs when calculating performance metrics.
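As a quick illustration (a toy sketch, not part of the original walkthrough), a plain non-stratified 10-fold split of a dataset with a 5% minority class is guaranteed to leave some test folds without a single minority sample:
import numpy as np
from sklearn.model_selection import KFold
# Hypothetical toy labels: 95 majority samples and 5 minority samples.
toy_labels = np.array([0] * 95 + [1] * 5)
np.random.default_rng(0).shuffle(toy_labels)
kfold = KFold(n_splits=10, shuffle=True, random_state=0)
for i, (_, test_idx) in enumerate(kfold.split(toy_labels)):
    if 1 not in toy_labels[test_idx]:
        # Per-class metrics (e.g. recall for class 1) are undefined on this fold.
        print(f"Fold {i} contains no minority samples")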

If we run the same code without the stratify argument, our distributions become skewed. Although the classes are equally represented in the full dataset, the Test split ends up imbalanced when we leave stratification up to randomness.
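For reference, the non-stratified version is simply the same calls with the stratify argument removed (a minimal sketch of leaving it up to randomness):
# Same splits as before, just without the stratify argument.
train_features, eval_features, train_labels, eval_labels = train_test_split(
    features, labels, test_size=0.32
)
val_features, test_features, val_labels, test_labels = train_test_split(
    eval_features, eval_labels, test_size=0.32
)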
Numeric Stratification
However, when attempting to use stratify with a numeric (integer or float) label that has many unique values, it’s likely that you will immediately run into the error below:
from aiqc import datum
df = datum.to_pandas('houses.csv')
labels = df['price'].to_numpy()
features = df.drop(columns=['price']).to_numpy()
train_features, eval_features, train_labels, eval_labels = train_test_split(
    features, labels, test_size=0.32, stratify=labels
)
"""
ValueError: The least populated class in y has only 1 member, which is too few. The minimum number of groups for any class cannot be less than 2.
"""
Bins to the Rescue
We can reduce the number of unique label values by placing them into binned ranges. Then we can treat their respective bin numbers as if they were discretized categorical data.
First, segment the numeric data with Pandas’ quantile cut:
bin_count = 5
bin_numbers = pd.qcut(
    x=labels, q=bin_count, labels=False, duplicates='drop'
)
train_features, eval_features, train_labels, eval_labels = train_test_split(
    features, labels, test_size=0.32, stratify=bin_numbers
)
# 2nd round
bin_numbers = pd.qcut(
    x=eval_labels, q=bin_count, labels=False, duplicates='drop'
)
val_features, test_features, val_labels, test_labels = train_test_split(
    eval_features, eval_labels, test_size=0.32, stratify=bin_numbers
)
Then plot them to verify the distributions:
train_df = pd.DataFrame(train_labels, columns=['price'])
val_df = pd.DataFrame(val_labels, columns=['price'])
test_df = pd.DataFrame(test_labels, columns=['price'])
px.histogram(train_df, x='price', height=400, title='Train Labels', nbins=30)
px.histogram(val_df, x='price', height=400, title='Validation Labels', nbins=30)
px.histogram(test_df, x='price', height=400, title='Test Labels', nbins=30)
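If you’d rather not rely on eyeballing charts, comparing a few quantiles of each split gives a quick numeric sanity check (a small sketch, not part of the original post):
# Compare label quantiles across the three splits.
quantiles = [0.25, 0.50, 0.75]
print(pd.DataFrame({
    'Train': np.quantile(train_labels, quantiles),
    'Validation': np.quantile(val_labels, quantiles),
    'Test': np.quantile(test_labels, quantiles),
}, index=['25th', '50th', '75th']))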



Looking at these charts really makes me question how the metrics from extensively cross-validated models (5–10 folds) can be trusted.
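If you do need cross-validated metrics on a numeric target, the same binning trick can be applied to the folds themselves. Here is a hedged sketch (not something covered above) using scikit-learn’s StratifiedKFold with the bin numbers as the stratification array:
from sklearn.model_selection import StratifiedKFold
# Stratify the folds on binned prices rather than the raw values.
bin_numbers = pd.qcut(x=labels, q=bin_count, labels=False, duplicates='drop')
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
for fold_i, (train_idx, test_idx) in enumerate(skf.split(features, bin_numbers)):
    print(f"Fold {fold_i}: {len(train_idx)} train / {len(test_idx)} test samples")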

Multi-Dimensional & Unsupervised Data
Let’s get a bit more advanced. What if we don’t have labels? What if our data isn’t tabular?
The ingestion script below results in a 3D array: 178 EEG readings for each of 1,000 patients. Each patient’s set of readings is reshaped into a 2D (178, 1) array to make things more challenging.
df = datum.to_pandas('epilepsy.parquet')
features = df.drop(columns=['seizure']).to_numpy().reshape(1000, 178, 1)
It may surprise you that labels aren’t actually required to perform stratification. The function doesn’t care where the stratify array comes from; it just needs to have the same sample order as our features. However, if we attempt to stratify using our 3D features directly, we get an error:
train_features, eval_features = train_test_split(
    features, test_size=0.32, stratify=features
)
"""
ValueError: Found array with dim 3. Estimator expected <= 2.
"""
Taking a moment to think about it, there are 178 readings per sample in that array. So what would it even be stratifying by? Let’s take the median of each set of readings, and use that as the stratify array.
medians = [np.median(arr) for arr in features]
bin_count = 5
bin_numbers = pd.qcut(x=medians, q=bin_count, labels=False, duplicates='drop')
train_features, eval_features = train_test_split(
    features, test_size=0.32, stratify=bin_numbers
)
# 2nd round
medians = [np.median(arr) for arr in eval_features]
bin_numbers = pd.qcut(
    x=medians, q=bin_count, labels=False, duplicates='drop'
)
val_features, test_features = train_test_split(
    eval_features, test_size=0.32, stratify=bin_numbers
)
Easier with AIQC
If you’re feeling a bit overwhelmed by all of this, have a look at AIQC – an open source library for end-to-end MLOps. It provides high-level APIs for pre/post-processing, experiment tracking, and model evaluation.
For example, when ingesting data with aiqc.Pipeline, practitioners can use the stratification arguments to declare how they want their data processed:
size_test = 0.22
size_validation = 0.12
fold_count = 5
bin_count = 15
Since AIQC is data-aware, it knows how to handle the dtypes and supervision of the dataset.
AIQC is an open source library created by the author of this post.
Don’t forget to ⭐ github.com/aiqc/aiqc
Summary
Let’s recap what we’ve learned:
- Stratification ensures that training & evaluation data is representative of distributions found in the broader population, which helps us train generalizable models.
- Binning discretizes numeric data for use as a stratification array.
- The stratification array doesn’t have to be a label/target!
- Summary statistics like mean and median help us reduce the dimensionality of a feature for use as a stratification array.