
K-Means, Hierarchical Clustering, Expectation-Maximization, and DBSCAN are probably the most famous clustering algorithms in the context of Machine Learning.
However, there is another density-based algorithm that has been solving a variety of problems, ranging from satellite image segmentation to object tracking, and its name is mean-shift. This algorithm is renowned for finding the modes of a dataset by building clusters around areas with higher data point densities.
Visualizing mean-shift is easy. Normally, I use the skyline analogy. Consider the following skyscraper skyline:

We can clearly see two different density clusters if we consider the skyscrapers as our data points:

But let’s expand the image to see more skyscraper areas. As we widen the view, we can see other density clusters, each with a different "peak" size:

The goal of the mean-shift algorithm is to find these different clusters (with peaks of varying size) within a dataset. It achieves this by shifting cluster means until they reach a density "peak" – an intuition it shares with DBSCAN, another density-based method.
It’s a powerful technique to leverage, particularly when our data contains non-spherical clusters.
In this blog post, we will:
- See a step-by-step implementation of the mean-shift algorithm.
- Check the sklearn implementation of the algorithm.
- Go through a practical use case of the algorithm using real data.
If you would like, you can also check the other blog posts of this Unsupervised Learning series, using the links below:
Compared to DBSCAN (another density-based algorithm), mean-shift only requires the tuning of one hyperparameter. This comes at a cost: it doesn’t provide the robustness to outliers that DBSCAN can offer.
Hope you’re ready, and let’s start!
Why K-Means Doesn’t Work for Non-Spherical Shapes
To introduce the mean-shift algorithm, we’ll start by understanding why traditional clustering methods fail on some datasets. Take the following dataset with generated customer data, where our two variables are annual income (in k euros) and average spending for a group of customers:

Visually, it’s hard to tell how many different clusters we have in this dataset. We could consider three, four, or five. For example, if we consider 5 different clusters, we’ll probably outline the following:

The image above is tricking us a bit, though. As with most clustering algorithms that deal with distances, it’s important to standardize our data – let’s do that:
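As a reference, here is a minimal sketch of this step (the customers DataFrame and its column names are stand-ins for the generated data, not the exact code behind these plots):

import numpy as np
import pandas as pd
from sklearn.preprocessing import MinMaxScaler

# Stand-in for the generated customer data shown in the plots
rng = np.random.default_rng(42)
customers = pd.DataFrame({
    'annual_income': rng.normal(50, 15, 200),  # in k euros
    'avg_spending': rng.normal(100, 30, 200),
})

# Scale both features so distances are comparable across axes
scaler = MinMaxScaler()
customers_scaled = scaler.fit_transform(customers)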

Now, our densities are a bit different:

Another hypothesis is to combine some of these data points into 3 or 4 clusters. Of course, doing this visually is a bit tricky, so how does K-Means behave on this dataset? First, let’s see the 5-cluster example:
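For reference, a fit like this follows the standard sklearn pattern; a minimal sketch (reusing the hypothetical customers_scaled array from the sketch above; the exact settings behind the plot may differ):

from sklearn.cluster import KMeans

# Fit K-Means with 5 clusters on the standardized data
kmeans = KMeans(n_clusters=5, n_init=10, random_state=42)
kmeans_labels = kmeans.fit_predict(customers_scaled)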

As you can see, K-Means has some trouble defining these clusters, mostly due to the non-spherical nature of the data. Let’s see what happens if I fit a K-Means solution with 4 clusters:

It’s an interesting solution, but it’s too influenced by outliers. Notice how these 3 data points impact one of the clusters in a major way:

Now let’s see what happens with mean-shift:

We see clusters that are more aligned with the density we’ve seen before. Although the solution is not perfect, it’s an interesting pattern discovered by the algorithm.
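For reference, producing a fit like this is just as simple with sklearn; a minimal sketch (the bandwidth value of 0.1 matches the radius used in the walkthrough below, but is otherwise an assumption):

from sklearn.cluster import MeanShift

# Mean-shift needs no cluster count - only a bandwidth (radius)
meanshift = MeanShift(bandwidth=0.1)
meanshift_labels = meanshift.fit_predict(customers_scaled)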
But why is mean-shift so different from K-Means? As usual, no question is left unanswered – let’s see!
Mean-Shift Walkthrough, Step-by-Step
First, we’ll start with a clean version of our dataset (without any clusters assigned). Mean-shift starts with a random data point (highlighted in orange below):

The next step is to draw a circle around this data point with a certain radius (a hyperparameter of the algorithm). I’ll define 0.1 as the value of the radius of our circle:

Note: from now on, I’ll show square plots so that we can correctly see the circles generated by the radius. Default matplotlib settings would show an ellipse, as the plot is stretched on the x-axis.
We now need to perform a crucial step of the mean-shift algorithm: the shifting operation!
This is actually very simple. Depending on the kernel, we will drag this orange point to its center of density. The most common kernels used within mean-shift are:
- Gaussian kernel
- Mexican-hat kernel
- Uniform kernel
The Gaussian and Mexican-hat kernels consider that points farther from the center should have less influence on the shifting operation. The uniform kernel is a bit less flexible: every data point within the radius has equal influence on the shifting operation.
We’ll use the uniform kernel as an example, due to its simplicity, although the Gaussian kernel is also often used.
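To make the difference concrete, here is a minimal sketch of the two weighting schemes (the function names are mine, for illustration only):

import numpy as np

def uniform_weights(distances, radius):
    # Flat kernel: every point inside the radius counts equally
    return (distances <= radius).astype(float)

def gaussian_weights(distances, bandwidth):
    # Gaussian kernel: influence decays smoothly with distance
    return np.exp(-(distances ** 2) / (2 * bandwidth ** 2))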
First, we need to assess which data points are within our radius:

I’ll consider that every data point that touches any part of the circle is part of it. Subsetting these data points (and the point itself):
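Under the hood, this subsetting is just a distance check; a minimal sketch (reusing the hypothetical customers_scaled array from the earlier sketches):

import numpy as np

# Start from one data point, as mean-shift does
centroid = customers_scaled[0].copy()
radius = 0.1

# Euclidean distance from every scaled data point to the centroid
distances = np.linalg.norm(customers_scaled - centroid, axis=1)

# Uniform kernel: keep only the points inside the circle
neighbors = customers_scaled[distances <= radius]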

Now, we can use the averages of both age and annual_income for this subset of customers to drag the "centroid":
- The age average is approximately 0.0679.
- The annual_income average is approximately 0.324.
These two values are the coordinates to which we are going to shift our mean:

Our centroid gets shifted northwest. Let’s reposition it:

In this case, our new circle touches the same data points. Our centroid has converged to this region and will stay there.
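Put together, the shifting step is just the mean of the neighbors, repeated until the centroid stops moving; a minimal sketch (reusing the hypothetical customers_scaled, centroid, and radius from above):

# Shift the centroid to the mean of its neighbors until it settles
for _ in range(300):
    distances = np.linalg.norm(customers_scaled - centroid, axis=1)
    neighbors = customers_scaled[distances <= radius]
    new_centroid = neighbors.mean(axis=0)
    if np.allclose(new_centroid, centroid):
        break  # converged: the circle keeps catching the same points
    centroid = new_centroid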
The next iterations do exactly the same for every data point. Let’s visualize what would happen with another data point in a different region:

In this case, the mean would shift southeast. If we start on an isolated data point:

In this case, the mean that started on this data point won’t change. This data point will form a cluster by itself.
During the algorithm run, different data points will land on the same "center of mass". These peaks will be combined to become a cluster (different implementations of mean-shift may do this last step in different ways). The key is that all implementations respect the principle: "same density peak, same cluster".
Depending on the implementation, the algorithm finishes when one of the following conditions is met (a from-scratch sketch of the full loop follows this list):
- Reaches a certain number of iterations given by the user. One iteration is equal to one pass through all the data points.
- The centroids no longer move after a few iterations.
- Some implementations define that when all cluster peaks are within a radius boundary, the algorithm should finish.
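Tying the whole walkthrough together, here is a minimal from-scratch sketch of the algorithm with a uniform kernel (the convergence tolerance and the peak-merging rule are my own assumptions; real implementations differ in these details):

import numpy as np

def mean_shift(data, radius, max_iter=300, tol=1e-4):
    # Start one centroid on top of every data point
    centroids = data.copy()
    for _ in range(max_iter):
        moved = False
        for i, c in enumerate(centroids):
            # Uniform kernel: average the points inside the circle
            neighbors = data[np.linalg.norm(data - c, axis=1) <= radius]
            new_c = neighbors.mean(axis=0)
            if np.linalg.norm(new_c - c) > tol:
                centroids[i] = new_c
                moved = True
        if not moved:
            break  # every centroid has reached its density peak
    # Merge centroids that landed on the same peak into one cluster
    peaks, labels = [], np.zeros(len(data), dtype=int)
    for i, c in enumerate(centroids):
        for j, p in enumerate(peaks):
            if np.linalg.norm(c - p) <= radius:
                labels[i] = j
                break
        else:
            peaks.append(c)
            labels[i] = len(peaks) - 1
    return np.array(peaks), labels

On the toy data from earlier, mean_shift(customers_scaled, radius=0.1) would return the density peaks and one cluster label per point.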
In sklearn’s implementation, the choice is the first route: by default, the algorithm stops after 300 iterations (the max_iter parameter).
Mean-shift has only one important hyperparameter – the radius. Compare the following two solutions:

We are revisiting our first example – in this case, our radius is 0.1. What happens when we raise this value?

With a radius of 0.5, we can see that more points tend to navigate to the same density peak. The downside of the (user-defined) radius hyperparameter is that it’s the same for every data point. Because of that, mean-shift may not deal very well with datasets that contain clusters of different densities.
Unsurprisingly, shifting is the most important operation of the algorithm! It’s time to see it in action, using the sklearn implementation!
Mean-shift in Action
As customer segmentation is one of the oldest (and most relevant) uses of clustering algorithms, we’ll play around with the Wholesale Customers Dataset, available on UCI’s Machine Learning Repository.
We’ll start with the usual load operation of the dataset:
import pandas as pd
wholesale_customers = pd.read_csv('data/Wholesale customers data.csv')
Here’s a preview of it:

According to the metadata available at the UCI repository, the description of the columns is the following:
- FRESH: annual spending (m.u.) on fresh products (Continuous)
- MILK: annual spending (m.u.) on milk products (Continuous)
- GROCERY: annual spending (m.u.) on grocery products (Continuous)
- FROZEN: annual spending (m.u.) on frozen products (Continuous)
- DETERGENTS_PAPER: annual spending (m.u.) on detergents and paper products (Continuous)
- DELICATESSEN: annual spending (m.u.) on delicatessen products (Continuous)
- CHANNEL: customer’s channel – Horeca (Hotel/Restaurant/Café) or Retail (Nominal)
- REGION: customer’s region – Lisbon, Oporto or Other (Nominal)
For this example, we’ll only use the monetary values.
The first step in our mean-shift solution is to determine the bandwidth of the algorithm. The bandwidth represents the radius we’ve seen in the step-by-step example.
Estimating the bandwidth manually is very hard. One cool function we have in sklearn is estimate_bandwidth, which uses the average nearest-neighbor distance to find a good baseline value.
But before checking the bandwidth, do you think there is something missing? We are dealing with an algorithm that relates to distances, so… we need standardization!
from sklearn.preprocessing import MinMaxScaler
mmc = MinMaxScaler()
scaled_dataset = mmc.fit_transform(
wholesale_customers[['Fresh','Milk','Grocery','Frozen','Detergents_Paper','Delicassen']]
)
With the scaled dataset, we can run estimate_bandwidth:
from sklearn.cluster import MeanShift, estimate_bandwidth
bandwidth_estimation = estimate_bandwidth(scaled_dataset)
The estimated bandwidth was approximately 0.183. We can now fit the algorithm to the scaled dataset:
meanshift_model = MeanShift(bandwidth=bandwidth_estimation)
wholesale_customers['meanshift_clustering'] = meanshift_model.fit_predict(
scaled_dataset
)
Let’s see how many data points we have per cluster:
wholesale_customers['meanshift_clustering'].value_counts()

We have some clusters that only contain one data point. These are probably outliers isolated by our radius. There are three alternatives we can use:
- We can remove univariate outliers manually before fitting the algorithm.
- We can remove these clusters from our dataset.
- We can raise the bandwidth we are using and re-run the clustering solution.
Let’s go through the last one!
meanshift_model = MeanShift(bandwidth=bandwidth_estimation + 0.1)
wholesale_customers['meanshift_clustering_extra_band'] = meanshift_model.fit_predict(
scaled_dataset
)
Checking the clusters again:

Oh no, too high! Tweaking the bandwidth parameter is difficult. Let’s settle for the first solution we have and remove the outlier clusters manually:
# Compute cluster sizes once and flag the tiny clusters as outliers
cluster_counts = wholesale_customers['meanshift_clustering'].value_counts()
outlier_clusters = cluster_counts.loc[cluster_counts <= 3].index

clustering_solution = wholesale_customers.loc[
    ~wholesale_customers['meanshift_clustering'].isin(outlier_clusters)
]
And checking the profiling of our clusters:
clustering_solution.groupby(['meanshift_clustering']).mean().iloc[:,2:]

Cool! So our first cluster (Cluster 0) is characterized by moderate fresh buyers with low values on the other categories.
Cluster 1 is characterized by high grocery and detergent shoppers.
Finally, cluster 4 is characterized by high grocery and frozen shoppers.
A word of caution about this type of profiling – using simple mean profiling on density clusters is a biased solution (particularly in mean-shift, where clusters may not be spherical). An alternative way to profile and check the quality of the clustering is to plot the clustering solution with t-SNE or UMAP embeddings.
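As a reference, a minimal sketch of that idea with t-SNE (the plot styling is my own choice):

from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

# Embed the six scaled spending features into 2D and color by cluster
embedding = TSNE(n_components=2, random_state=42).fit_transform(scaled_dataset)
plt.scatter(
    embedding[:, 0], embedding[:, 1],
    c=wholesale_customers['meanshift_clustering'], cmap='tab10', s=15,
)
plt.title('Mean-shift clusters in t-SNE space')
plt.show()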
Also, judging by the distribution of data points across the clusters, this is not a good dataset for a density-based clustering method. When you see a high concentration of data points in a single cluster, other methods should be tested (but don’t forget to play around with the bandwidth parameter first).
Hope you’ve enjoyed this blog post on mean-shift clustering! As mentioned, this is an interesting density-based algorithm that can uncover interesting patterns in your data.
On the upside, mean-shift can find non-spherical clusters in the data while uncovering the different distribution peaks. On the downside, it tends to isolate outliers as clusters. Also, the bandwidth hyperparameter is relatively difficult to tweak.
If you are curious about other unsupervised learning algorithms, feel free to check the other blog posts of this series:
If you want to read or watch more content related to AI and Data Science, subscribe to my YouTube channel, "The Data Journey":

The dataset used in this blog post is the Wholesale Customers Dataset, licensed under Creative Commons Attribution 4.0 (https://archive.ics.uci.edu/dataset/292/wholesale+customers).