
GMM: Gaussian Mixture Models – How to Successfully Use It to Cluster Your Data?

An intuitive explanation of GMMs with helpful Python examples


GMM – Gaussian Mixture Models. Image by author.

This article is part of a series that explains how different Machine Learning algorithms work and provides a range of Python examples to help you get started on your own Data Science project.

The story covers the following topics:

  • The category of algorithms Gaussian Mixture Models (GMM) belongs to.
  • Description of how the GMM algorithm works.
  • Python examples of how to use GMM for Clustering.

What category of algorithms does GMM belong to?

While it is not always possible to categorize every algorithm perfectly, it is still beneficial to try and do so. The below interactive chart is my attempt to help you see the broader universe of Machine Learning.


Note, in many cases, the same algorithm can be used to solve multiple types of problems. E.g., one can use Neural Networks for classification, regression, and as part of reinforcement learning.


Since Gaussian Mixture Models (GMM) are used in clustering, they sit under the unsupervised branch of Machine Learning.

As you may already know, unsupervised techniques, in particular clustering, are often used for segmentation analysis or as a way to find similarities/differences between observations in your dataset. This is different from supervised learning models, which are typically used for making predictions.

Explanation of GMM algorithm

Types of clustering algorithms

Not all clustering algorithms are created equal. Different clustering algorithms implement different ideas on how to best cluster your data. There are 4 main categories:

  • Centroid-based – uses Euclidean distance to assign every point to the nearest cluster center. Example: K-Means
  • Connectivity-based – assumes that nearby objects (data points) are more related than far away objects. Example: Hierarchical Agglomerative Clustering (HAC).
  • Density-based – defines clusters as dense regions of space separated by low-density regions. They are good at finding arbitrarily shaped clusters. Example: Density-Based Spatial Clustering of Applications with Noise (DBSCAN).
  • Distribution-based – assumes the existence of a specified number of distributions within the data; each distribution has its own mean (μ) and variance (σ²) / covariance (Cov). Example: Gaussian Mixture Models (GMM).

Note, variance is used for single-variable analysis and covariance for multivariate analysis. The examples in this article use a multivariate setup (multiple distributions/clusters).

Brief Description of GMM

As you might have figured, Gaussian Mixture Models assume that your data follows Gaussian (a.k.a. Normal) distribution. Since there can be multiple such distributions within your data, you get to specify their number, which is essentially the number of clusters that you want to have.

Also, since separate distributions can overlap, the model output is not a hard assignment of points to specific clusters. Instead, each point gets a probability of belonging to each distribution. Say, if point A has a probability of 0.6 of belonging to "Cluster 0" and a probability of 0.4 of belonging to "Cluster 1," then the model would recommend "Cluster 0" as the label for that point (since 0.6 > 0.4).
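To make this concrete, below is a minimal sketch of how this looks in scikit-learn. It is not part of the original example and uses synthetic 2D data: predict() returns the hard label, while predict_proba() returns the per-cluster probabilities described above.

import numpy as np
from sklearn.mixture import GaussianMixture

# Two overlapping synthetic 2D clusters (illustrative data only)
rng = np.random.default_rng(0)
X_toy = np.vstack([rng.normal(loc=[0, 0], scale=1.0, size=(200, 2)),
                   rng.normal(loc=[3, 3], scale=1.5, size=(200, 2))])

gmm_toy = GaussianMixture(n_components=2, random_state=0).fit(X_toy)
print(gmm_toy.predict(X_toy[:5]))        # hard labels: the most probable cluster
print(gmm_toy.predict_proba(X_toy[:5]))  # soft assignments: one probability per cluster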

To aid the explanation further, let’s look at a few graphs.

  • The below image shows 4 clusters of Australian cities identified by GMM. Note how each cluster has its own mean (center), covariance (shape), and size. Also, there is a notable overlap between different clusters (purple, blue, and green).
Gaussian Mixture Models (GMM) – 4 clusters of Australian cities. Image by author.
  • As mentioned earlier, the cluster label assignment was based on picking the highest probability of a specific data point belonging to a specific cluster. However, that does not mean that the point is definitely part of that cluster (distribution).
  • See what happens when we ask the model to generate new data points (samples) for the distributions it has found above:
Gaussian Mixture Models (GMM) – 10,000 new samples generated for the 4 distributions. Image by author.
  • Due to the overlapping nature of the distributions, you get some purple points mixed in with blue points, as well as some green points mixed in with blue and purple points. Hence, GMM can be described as a soft clustering approach where no hard decision boundary exists to separate different clusters.

Expectation-Maximization (EM)

To understand how GMM works in practice, we need to look at the Expectation-Maximization (EM) algorithm. EM is an iterative procedure that calculates and recalculates the parameters of each cluster (distribution), i.e., its mean, variance/covariance, and size (weight).

I will not go into the complicated maths on what happens within each step. Instead, I will give you an intuitive explanation starting with the below chart for easy visualization:

Gaussian Mixture Models (GMM) – iterative process to find 4 clusters. Image by author.

At the outset, the model initializes a specified number of clusters with a set of parameters that can either be random or specified by the user. Smart initialization options are also available in some implementations (e.g., sklearn’s implementation of GMM by default uses kmeans to initialize clusters).

For the above graph, I have specified my own set of mean values (starting centers) to initialize the clusters, which helped me create a nicer visualization. It also sped up convergence compared to random initialization.

However, you have to be careful with initialization because GMM's final result tends to be quite sensitive to the initial starting parameters. Hence, it is recommended to either use smart initialization or to initialize randomly many times and then pick the best result.

So, with clusters initialized, we have the mean (μ), covariance (Cov), and size (𝜋) available to use.

  • Expectation (E-step) – for each data point, a "responsibility" r is calculated, which is, in simple terms, a probability of that data point belonging to a cluster c. This is done for each point with regard to each cluster.
  • Maximization (M-step) – then "responsibilities" are used to recalculate the mean, covariance, and size of each cluster (distribution). At this step, you can also think of "responsibility" as a weight. The less likely it is that the data point belongs to a cluster, the smaller the weight it will carry in the recalculation of μ, Cov, and 𝜋. In the GIF image above, you can see how the position, shape, and size of the clusters change with each iteration.

The process of E-step and M-step is repeated many times until no further improvements can be made, i.e., convergence is achieved.
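For intuition only, here is a tiny from-scratch sketch of a single E-step and M-step for a 1D, two-component mixture. The variable names mu, sigma and pi are mine, and a real implementation such as sklearn's also handles full covariances, log-space computations, and numerical stability.

import numpy as np
from scipy.stats import norm

# Toy 1D data drawn from two Gaussians
x = np.concatenate([np.random.normal(-2, 1, 300), np.random.normal(3, 1.5, 300)])

# Current parameter guesses: means (mu), standard deviations (sigma), sizes/weights (pi)
mu, sigma, pi = np.array([-1.0, 1.0]), np.array([1.0, 1.0]), np.array([0.5, 0.5])

# E-step: responsibility r[i, c] = probability that point i belongs to cluster c
dens = np.stack([p * norm.pdf(x, m, s) for p, m, s in zip(pi, mu, sigma)], axis=1)
r = dens / dens.sum(axis=1, keepdims=True)

# M-step: responsibilities act as weights when re-estimating mu, sigma and pi
Nc = r.sum(axis=0)
mu = (r * x[:, None]).sum(axis=0) / Nc
sigma = np.sqrt((r * (x[:, None] - mu) ** 2).sum(axis=0) / Nc)
pi = Nc / len(x)

# Repeating these two steps until the parameters stop changing gives convergence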

Python example of GMM clustering

Setup

We will use Australian weather data from Kaggle and the following Python libraries:

import pandas as pd # for data manipulation
import numpy as np # for data manipulation

from sklearn.mixture import GaussianMixture # for GMM clustering
from sklearn import metrics # for calculating Silhouette score

import matplotlib.pyplot as plt # for data visualization
import plotly.express as px  # for data visualization
import plotly.graph_objects as go # for data visualization

from geopy.geocoders import Nominatim # for getting city coordinates
from progressbar import ProgressBar # for displaying progress 
import time # for adding time delays

Then we get the Australian weather data from Kaggle, which you can download from this link: https://www.kaggle.com/jsphyg/weather-dataset-rattle-package.

We ingest the data and derive a new variable, "Location2", which has the right format to extract city coordinates using Geopy.

# Set Pandas options to display more columns
pd.options.display.max_columns=50

# Read in the weather data csv
df=pd.read_csv('weatherAUS.csv', encoding='utf-8')

# Drop records where target RainTomorrow=NaN
df=df[pd.isnull(df['RainTomorrow'])==False]

# For other columns with missing values, fill them in with column mean
df=df.fillna(df.mean(numeric_only=True))

# Add spaces between multiple words in location names
df['Location2']=df['Location'].str.replace(r"([A-Z])", r" \1", regex=True).str.strip()
# Update Location for Pearce RAAF so it can be found by geolocator
df['Location2']=df['Location2'].apply(lambda x: 'Pearce, Bullsbrook' if x=='Pearce R A A F' else x)

# Show a snapshot of the data
df
A snippet of Kaggle’s Australian weather data with some modifications. Image by author.

Since our original data only contains location (city) names and not coordinates, we will use Geopy’s Nominatim to get those coordinates. Note that we add a sleep time of 1 second between calls so as not to overload the server.

# Create a list of unique locations (cities)
loc_list=list(df.Location2.unique())

geolocator = Nominatim(user_agent="add-your-agent-name")
country ="Australia"
loc_res=[]

pbar=ProgressBar() # This will help us to show the progress of our iteration
for city in pbar(loc_list):
    loc = geolocator.geocode(city+','+ country)
    res = [city, loc.latitude, loc.longitude]
    loc_res = loc_res + [res]
    time.sleep(1) # sleep for 1 second before submitting the next query

# Add locations to a dataframe
df_loc=pd.DataFrame(loc_res, columns=['Loc', 'Latitude', 'Longitude'])

# Show data
df_loc

And this is the snippet of what we get in return:

Australian city coordinates. Image by author.

Next, let’s plot cities on a map:

# Create a figure
fig = go.Figure(data=go.Scattergeo(
        lat=df_loc['Latitude'],
        lon=df_loc['Longitude'],
        hovertext=df_loc['Loc'], 
        mode = 'markers',
        marker_color = 'black',
        ))

# Update layout so we can zoom in on Australia
fig.update_layout(
        width=980,
        height=720,
        margin={"r":0,"t":10,"l":0,"b":10},
        geo = dict(
            scope='world',
            projection_type='miller',
            landcolor = "rgb(250, 250, 250)",
            center=dict(lat=-25.69839, lon=139.8813), # focus point
            projection_scale=6 # zoom level
        ),
    )
fig.show()
Australian cities on a map. Image by author.

GMM clustering – picking the number of clusters

There is more than one way to select how many clusters you should have. It can be based on your knowledge of the data or something more data-driven like the Silhouette score. Here is a direct quote from sklearn:

The Silhouette Coefficient is defined for each sample and is composed of two scores:

a: The mean distance between a sample and all other points in the same class.

b: The mean distance between a sample and all other points in the next nearest cluster.

The Silhouette Coefficient s for a single sample is then given as:

s = (b - a) / max(a, b)

The Silhouette Coefficient for a set of samples is given as the mean of the Silhouette Coefficient for each sample.

Let’s create multiple GMM models using a different number of clusters and plot Silhouette scores.

# Create empty list
S=[]

# Range of clusters to try (2 to 10)
K=range(2,11)

# Select data for clustering model
X = df_loc[['Latitude', 'Longitude']]

for k in K:
    # Set the model and its parameters
    model = GaussianMixture(n_components=k, n_init=20, init_params='kmeans')
    # Fit the model 
    labels = model.fit_predict(X)
    # Calculate Silhouette Score and append to a list
    S.append(metrics.silhouette_score(X, labels, metric='euclidean'))

# Plot the resulting Silhouette scores on a graph
plt.figure(figsize=(16,8), dpi=300)
plt.plot(K, S, 'o-', color='black')
plt.xlabel('k')
plt.ylabel('Silhouette Score')
plt.title('Identify the number of clusters using Silhouette Score')
plt.show()
Deciding on the number of GMM clusters using Silhouette score. Image by author.

Generally, the higher the Silhouette score, the better defined your clusters are. In this example, I chose 4 clusters instead of 2, despite the score being slightly higher for the 2-cluster setup.

Note, if you are well familiar with your data, you may prefer to use the Silhouette score as a guide rather than a hard rule when deciding on the number of clusters.
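As an optional, additional check that is not part of the original workflow, GMMs also support likelihood-based criteria such as BIC through sklearn's bic() method; lower values indicate a better trade-off between fit and model complexity. The sketch below reuses X, K and the imports from the loop above, and random_state=1 is my own choice for reproducibility.

# Optional: compare candidate cluster counts using BIC (lower is better)
bics = []
for k in K:
    model = GaussianMixture(n_components=k, n_init=20, init_params='kmeans', random_state=1)
    model.fit(X)
    bics.append(model.bic(X))

plt.figure(figsize=(16,8), dpi=300)
plt.plot(K, bics, 'o-', color='black')
plt.xlabel('k')
plt.ylabel('BIC')
plt.title('Identify the number of clusters using BIC')
plt.show()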

GMM clustering – building a model

Let’s now build our GMM model:

# Select data for clustering model
X = df_loc[['Longitude', 'Latitude']]

# Set the model and its parameters - 4 clusters
model4 = GaussianMixture(n_components=4, # this is the number of clusters
                         covariance_type='full', # {'full', 'tied', 'diag', 'spherical'}, default='full'
                         max_iter=100, # the number of EM iterations to perform. default=100
                         n_init=1, # the number of initializations to perform. default = 1
                         init_params='kmeans', # the method used to initialize the weights, the means and the precisions. {'kmeans' (default), 'random'}
                         verbose=0, # default 0, {0,1,2}
                         random_state=1 # for reproducibility
                        )

# Fit the model and predict labels
clust4 = model4.fit(X)
labels4 = model4.predict(X)

# Generate 10,000 new samples based on the model
smpl=model4.sample(n_samples=10000)

# Print model summary
print('*************** 4 Cluster Model ***************')
#print('Weights: ', clust4.weights_)
print('Means: ', clust4.means_)
#print('Covariances: ', clust4.covariances_)
#print('Precisions: ', clust4.precisions_)
#print('Precisions Cholesky: ', clust4.precisions_cholesky_)
print('Converged: ', clust4.converged_)
print(' No. of Iterations: ', clust4.n_iter_)
#print('Lower Bound: ', clust4.lower_bound_)

Below is the summary printed by the above code. Note that convergence was achieved after 7 iterations, with the means (cluster centers) displayed:

Outputs from the GMM model. Image by author.

Now, let’s plot clusters on a map:

# Attach cluster labels to the main dataframe
df_loc['Clust4']=labels4

# Create a figure
fig = go.Figure(data=go.Scattergeo(
        lat=df_loc['Latitude'],
        lon=df_loc['Longitude'],
        hovertext=df_loc[['Loc', 'Clust4']], 
        mode = 'markers',
        marker=dict(colorscale=['#ae34eb', 'red', 'blue', '#34eb34']),
        marker_color = df_loc['Clust4'],
        ))

# Update layout so we can zoom in on Australia
fig.update_layout(
        showlegend=False,
        width=1000,
        height=760,
        margin={"r":0,"t":30,"l":0,"b":10},
        geo = dict(
            scope='world',
            projection_type='miller',
            landcolor = "rgb(250, 250, 250)",
            center=dict(lat=-25.69839, lon=139.8813), # focus point
            projection_scale=6 # zoom level
        ),
    )
fig.show()
Gaussian Mixture Models (GMM) – 4 clusters of Australian cities. Image by author.

Finally, you can also plot the sample of 10,000 new points generated by the model as seen in a graph earlier in the article:

# Create a figure
fig = go.Figure(data=go.Scattergeo(
        lat=smpl[0][:,1],
        lon=smpl[0][:,0],
        mode = 'markers',
        marker=dict(colorscale=['#ae34eb', 'red', 'blue', '#34eb34']),
        marker_color = smpl[1],
        marker_size=3
        ))

# Update layout so we can zoom in on Australia
fig.update_layout(
        showlegend=False,
        width=1000,
        height=760,
        margin={"r":0,"t":30,"l":0,"b":10},
        geo = dict(
            scope='world',
            projection_type='miller',
            landcolor = "rgb(250, 250, 250)",
            center=dict(lat=-25.69839, lon=139.8813), # focus point
            projection_scale=6 # zoom level
        ),
    )
fig.show()
Gaussian Mixture Models (GMM) – 10,000 new samples generated for the 4 distributions. Image by author.

Conclusion

Gaussian Mixture Models are useful in situations where clusters have an "elliptical" shape. While K-Means only uses means (centroids) to find clusters, GMM also takes variance/covariance into account. This is exactly what gives GMM an advantage over K-Means when identifying non-circular clusters.
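If you want to see this difference for yourself, here is a small, hedged comparison on synthetic, deliberately stretched ("elliptical") blobs; the data, stretching matrix, and variable names below are illustrative and not from the Australian weather example.

import numpy as np
from sklearn.datasets import make_blobs
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture
from sklearn.metrics import adjusted_rand_score

# Generate round blobs, then stretch them into elongated (elliptical) clusters
X_demo, y_true = make_blobs(n_samples=600, centers=3, random_state=7)
X_demo = X_demo @ np.array([[0.6, -0.6], [-0.3, 0.9]])

km_labels = KMeans(n_clusters=3, n_init=10, random_state=7).fit_predict(X_demo)
gmm_labels = GaussianMixture(n_components=3, random_state=7).fit(X_demo).predict(X_demo)

# Agreement with the true labels (1.0 = perfect); GMM typically scores higher here
print('K-Means ARI:', adjusted_rand_score(y_true, km_labels))
print('GMM ARI:', adjusted_rand_score(y_true, gmm_labels))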

In general, you can think of GMM as a more advanced version of K-Means, but keep in mind that it may not always give you superior results. This is because a lot depends on the shape of the data and what you are trying to achieve.

I hope this article helped you to understand Gaussian Mixture Models better. If you want to learn more about alternative clustering algorithms, you can refer to my articles below.

Cheers!👏 Saul Dobilas


K-Means Clustering – A Comprehensive Guide to Its Successful Use in Python

HAC: Hierarchical Agglomerative Clustering. Is It Better Than K-Means?

DBSCAN Clustering Algorithm – How to Build Powerful Density-Based Models

