
Gaussian Mixture Models (GMMs): from Theory to Implementation

In-depth explanation of GMMs and the Expectation-Maximization algorithm used to train them

Gaussian Mixture Models (GMMs) are statistical models that represent the data as a mixture of Gaussian (normal) distributions. These models can be used to identify groups within the dataset, and to capture the complex, multi-modal structure of data distributions.

GMMs are used in a variety of machine learning applications, including clustering, density estimation, and pattern recognition.

In this article we will first explore mixture models, focusing on Gaussian mixture models and their underlying principles. Then, we will examine how to estimate the parameters of these models using a powerful technique known as Expectation-Maximization (EM), and provide a step-by-step guide to implementing it from scratch in Python. Finally, we will demonstrate how to perform clustering with GMMs using the Scikit-Learn library.

Image by Markéta Klimešová from Pixabay

Mixture Models

A mixture model is a probability model for representing data that may arise from several different sources or categories, each of which is modeled by a separate probability distribution. For example, financial returns typically behave differently under normal market conditions and during periods of crisis, and thus can be modeled as a mixture of two distinct distributions.

Formally, if X is a random variable whose distribution is a mixture of K component distributions, the probability density function (PDF) or probability mass function (PMF) of X can be written as:

p(x) = Σₖ₌₁ᴷ wₖ fₖ(x; θₖ)

where:

  • p(x) is the overall density or mass function of the mixture model.
  • K is the number of component distributions in the mixture.
  • fₖ(x; θₖ) is the density or mass function of the k-th component distribution, parametrized by θₖ.
  • wₖ is the mixing weight of the k-th component, with 0 ≤ wₖ ≤ 1 and the sum of the weights being 1. wₖ is also known as the prior probability of component k.
  • θₖ represents the parameters of the k-th component, such as the mean and standard deviation in the case of Gaussian distributions.

The mixture model assumes that each data point comes from one of the K component distributions, with the specific distribution being selected according to the mixing weights wₖ. The model does not require knowing which component each data point belongs to.
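
To make this generative view concrete, the following snippet (with arbitrary, illustrative parameters for K = 3 components) samples points from a univariate Gaussian mixture by first drawing a component index according to the weights wₖ and then sampling from the selected component:

import numpy as np

np.random.seed(0)

# Hypothetical mixture parameters (K = 3 components)
weights = [0.5, 0.3, 0.2]            # mixing weights, must sum to 1
means = np.array([-2.0, 1.0, 5.0])
stds = np.array([0.8, 1.0, 1.5])

# Draw a component index for each sample according to the mixing weights
components = np.random.choice(len(weights), size=1000, p=weights)

# Sample each point from its selected Gaussian component
samples = np.random.normal(means[components], stds[components])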

A Gaussian Mixture Model (GMM) is a common mixture model, where the probability density is given by a mixture of Gaussian distributions:

p(x) = Σₖ₌₁ᴷ wₖ N(x; μₖ, Σₖ)

where:

  • x is a d-dimensional vector.
  • μₖ is the mean vector of the k-th Gaussian component.
  • Σₖ is the covariance matrix of the k-th Gaussian component.
  • N(x; μₖ, Σₖ) is the multivariate normal density function for the k-th component:

N(x; μₖ, Σₖ) = (1 / ((2π)^(d/2) |Σₖ|^(1/2))) exp(−½ (x − μₖ)ᵀ Σₖ⁻¹ (x − μₖ))

In the case of univariate Gaussian distributions, the probability density can be simplified to:

p(x) = Σₖ₌₁ᴷ wₖ N(x; μₖ, σₖ)

where:

  • μₖ is the mean of the k-th Gaussian component.
  • σₖ is the standard deviation of the k-th Gaussian component.
  • N(x; μₖ, σₖ) is the univariate normal density function for the k-th component:

N(x; μₖ, σₖ) = (1 / (σₖ√(2π))) exp(−(x − μₖ)² / (2σₖ²))

For example, the following Python function plots a mixture distribution of two univariate Gaussian distributions:

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm

def plot_mixture(mean1, std1, mean2, std2, w1, w2):
    # Generate points for the x-axis
    x = np.linspace(-5, 10, 1000)

    # Calculate the individual normal distributions
    normal1 = norm.pdf(x, mean1, std1)
    normal2 = norm.pdf(x, mean2, std2)

    # Calculate the mixture
    mixture = w1 * normal1 + w2 * normal2

    # Plot the results
    plt.plot(x, normal1, label='Normal distribution 1', linestyle='--')
    plt.plot(x, normal2, label='Normal distribution 2', linestyle='--')
    plt.plot(x, mixture, label='Mixture model', color='black')
    plt.xlabel('$x$')
    plt.ylabel('$p(x)$')
    plt.legend()

Let’s use this function to plot a mixture of two Gaussian distributions with parameters μ₁ = -1, σ₁ = 1, μ₂ = 4, σ₂ = 1.5, and mixture weights of w₁ = 0.7 and w₂ = 0.3:

# Parameters for the two univariate normal distributions
mean1, std1 = -1, 1
mean2, std2 = 4, 1.5
w1, w2 = 0.7, 0.3

plot_mixture(mean1, std1, mean2, std2, w1, w2)
A mixture model of two univariate Gaussian distributions

The dashed lines represent the individual normal distributions, and the solid black line shows the resulting mixture. This plot illustrates how the mixture model combines the two distributions, each with its own mean, standard deviation, and weight in the overall mixture.

Learning the GMM Parameters

Our goal is to find the parameters of the GMM (means, covariances, and mixing coefficients) that will best explain the observed data. To that end, we first define the likelihood of the model given the input data.

For a GMM with K components and a dataset X = {x₁, …, xₙ} of n data points, the likelihood function L is given by the product of the probability densities of each data point, as defined by the GMM:

L(θ|X) = ∏ᵢ₌₁ⁿ p(xᵢ; θ) = ∏ᵢ₌₁ⁿ Σₖ₌₁ᴷ wₖ N(xᵢ; μₖ, Σₖ)

where θ represents all the parameters of the model (means, variances, and mixture weights).

In practice, it is easier to work with the log-likelihood, since the product of probabilities can lead to numerical underflow for large datasets. The log-likelihood is given by:

l(θ|X) = log L(θ|X) = Σᵢ₌₁ⁿ log(Σₖ₌₁ᴷ wₖ N(xᵢ; μₖ, Σₖ))
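
As a quick illustration, the log-likelihood of a univariate GMM can be computed in a numerically stable way by working with log-densities and the logsumexp trick. The following sketch (the function name and argument layout are illustrative) does exactly that:

import numpy as np
from scipy.stats import norm
from scipy.special import logsumexp

def gmm_log_likelihood(x, means, stds, weights):
    # log p(xᵢ) = logsumexp over k of [log wₖ + log N(xᵢ; μₖ, σₖ)]
    log_terms = np.array([
        np.log(w) + norm.logpdf(x, m, s)
        for w, m, s in zip(weights, means, stds)
    ])  # shape (K, n)
    return logsumexp(log_terms, axis=0).sum()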

The parameters of the GMM can be estimated by maximizing this log-likelihood function with respect to θ. However, we cannot directly apply Maximum Likelihood Estimation (MLE) to estimate the parameters of a GMM due to the following reasons:

  1. The log-likelihood function is highly non-linear and complex to maximize analytically.
  2. The model has latent variables (the unknown component memberships of the data points), which are not directly observable in the data.

To overcome these issues, the Expectation-Maximization (EM) algorithm is commonly used instead. This algorithm is described in the next section.

Expectation-Maximization (EM)

The EM algorithm is a powerful method for finding maximum likelihood estimates of parameters in statistical models that depend on unobserved latent variables.

The algorithm begins by randomly initializing the model parameters. Then it iterates between two steps:

  1. Expectation step (E-step): Compute the expected log-likelihood of the model with respect to the distribution of the latent variables, given the observed data and the current estimates of the model parameters. This step involves an estimation of the probabilities of the latent variables.
  2. Maximization step (M-step): Update the parameters of the model to maximize the log-likelihood of the observed data, given the estimated latent variables from the E-step.

These two steps are repeated until convergence, typically determined by a threshold on the change in the log-likelihood or a maximum number of iterations.
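
Schematically, an EM loop with a log-likelihood-based stopping rule looks like the following sketch, where the helper functions are placeholders (concrete GMM versions, with slightly different signatures, are implemented later in this article):

import numpy as np

def em(x, init_params, e_step, m_step, log_likelihood, tol=1e-6, max_iter=100):
    params = init_params(x)
    prev_ll = -np.inf
    for _ in range(max_iter):
        latent = e_step(x, params)      # E-step: estimate the latent variable distributions
        params = m_step(x, latent)      # M-step: update the parameters
        ll = log_likelihood(x, params)
        if abs(ll - prev_ll) < tol:     # stop when the log-likelihood barely improves
            break
        prev_ll = ll
    return params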

Let’s formulate the update equations used in the EM steps for estimating the parameters of a Gaussian Mixture Model. In GMMs, the latent variables represent the unknown component memberships of each data point. Let Zᵢ be the random variable indicating the component from which data point xᵢ was generated. Zᵢ can take one of the values {1, …, K}, corresponding to the K components.

E-Step

In the E-step, we compute the probability distributions of the latent variables Zᵢ given the current estimates of the model parameters. In other words, we calculate the membership probabilities for each data point in each Gaussian component.

The probability that Zᵢ = k, i.e., that xᵢ belongs to the k-th component, can be computed using Bayes’ rule:

P(Zᵢ = k | xᵢ) = P(xᵢ | Zᵢ = k) P(Zᵢ = k) / p(xᵢ)

Let’s denote this probability by the variable γ(zᵢₖ). Thus, we can write:

γ(zᵢₖ) = wₖ N(xᵢ; μₖ, Σₖ) / Σⱼ₌₁ᴷ wⱼ N(xᵢ; μⱼ, Σⱼ)

The variables γ(zᵢₖ) are often referred to as responsibilities, since they describe how responsible each component is for each observation. These responsibilities serve as proxies for the missing information about the latent variables.
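
For intuition, here is a small numeric example that computes the responsibilities of the two components from the earlier plot (μ₁ = -1, σ₁ = 1, w₁ = 0.7 and μ₂ = 4, σ₂ = 1.5, w₂ = 0.3) for a single point x = 1:

from scipy.stats import norm

x = 1.0

# Unnormalized responsibilities: prior weight times component density
unnorm1 = 0.7 * norm.pdf(x, -1, 1)
unnorm2 = 0.3 * norm.pdf(x, 4, 1.5)

# Normalize so that the responsibilities sum to 1
total = unnorm1 + unnorm2
gamma1, gamma2 = unnorm1 / total, unnorm2 / total
print(gamma1, gamma2)  # component 1 gets the larger responsibility (≈ 0.78)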

The expected log-likelihood with respect to the distribution of the latent variables can now be written as follows:

Q(θ) = Σᵢ₌₁ⁿ Σₖ₌₁ᴷ γ(zᵢₖ) [log wₖ + log N(xᵢ; μₖ, Σₖ)]

The function Q is a weighted sum of the log-likelihoods of all the data points under each Gaussian component, with the weights being the responsibilities. Note that Q is different from the log-likelihood function l(θ|X) shown earlier. The log-likelihood l(θ|X) expresses the likelihood of the observed data under the mixture model as a whole, without explicitly accounting for the latent variables, whereas Q represents an expected log-likelihood over both the observed data and the estimated latent variable distributions.

M-Step

In the M-step, we update the parameters θ of the GMM (means, covariances, and mixing weights) so as to maximize the expected likelihood Q(θ) using the responsibilities calculated in the E-step.

The parameter updates are as follows:

  1. Update the means for each component:

μₖ = Σᵢ₌₁ⁿ γ(zᵢₖ) xᵢ / Σᵢ₌₁ⁿ γ(zᵢₖ)

That is, the new mean of the k-th component is a weighted average of all the data points, with the weights being the probabilities that these points belong to component k.

This update formula can be derived from maximizing the expected log-likelihood function Q with respect to the means μₖ. I will show here the proof for the univariate Gaussian distributions case.

Proof:

The expected log-likelihood in the case of univariate Gaussian distributions is:

Q(θ) = Σᵢ₌₁ⁿ Σₖ₌₁ᴷ γ(zᵢₖ) [log wₖ − log σₖ − ½ log(2π) − (xᵢ − μₖ)² / (2σₖ²)]

Taking the derivative of this function with respect to μₖ and setting it to 0 gives us:

∂Q/∂μₖ = Σᵢ₌₁ⁿ γ(zᵢₖ)(xᵢ − μₖ) / σₖ² = 0  ⟹  μₖ = Σᵢ₌₁ⁿ γ(zᵢₖ) xᵢ / Σᵢ₌₁ⁿ γ(zᵢₖ)


  2. Update the covariances for each component:

Σₖ = Σᵢ₌₁ⁿ γ(zᵢₖ)(xᵢ − μₖ)(xᵢ − μₖ)ᵀ / Σᵢ₌₁ⁿ γ(zᵢₖ)

That is, the new covariance of the k-th component is a weighted average of the squared deviations of each data point from the component’s mean, where the weights are the probabilities of the points assigned to that component.

In the case of univariate normal distributions, this update simplifies to:

σₖ² = Σᵢ₌₁ⁿ γ(zᵢₖ)(xᵢ − μₖ)² / Σᵢ₌₁ⁿ γ(zᵢₖ)

  3. Update the mixing weights:

wₖ = (1/n) Σᵢ₌₁ⁿ γ(zᵢₖ)

That is, the new weight of the k-th component is the total probability of the points belonging to this component, normalized by the number of points n.

Repeating these two steps is guaranteed to converge to a local maximum of the likelihood function. Since the final optimum reached depends on the initial random parameter values, it is a common practice to run the EM algorithm several times with varied random initializations and keep the model that obtains the highest likelihood.

Implementation in Python

We will now implement the EM algorithm for estimating the parameters of a GMM of two univariate Gaussian distributions from a given dataset.

We start by importing the required libraries:

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from scipy.stats import norm

np.random.seed(0)  # for reproducibility

Next, let’s write a function to initialize the parameters of the GMM:

def init_params(x):    
    """Initialize the parameters for the GMM
    """    
    # Randomly initialize the means to points from the dataset
    mean1, mean2 = np.random.choice(x, 2, replace=False)

    # Initialize the standard deviations to 1
    std1, std2 = 1, 1

    # Initialize the mixing weights uniformly
    w1, w2 = 0.5, 0.5

    return mean1, mean2, std1, std2, w1, w2

The means are initialized from random data points in the dataset, the standard deviations are set to 1, and the mixing weights are set uniformly to be 0.5.

We now implement the E-step, in which we compute the responsibilities (probabilities) for each data point belonging to each Gaussian component:

def e_step(x, mean1, std1, mean2, std2, w1, w2):
    """E-Step: Compute the responsibilities
    """    
    # Compute the densities of the points under the two normal distributions  
    prob1 = norm(mean1, std1).pdf(x) * w1
    prob2 = norm(mean2, std2).pdf(x) * w2

    # Normalize the probabilities
    prob_sum = prob1 + prob2 
    prob1 /= prob_sum
    prob2 /= prob_sum

    return prob1, prob2

In the M-step, we update the model parameters based on the responsibilities calculated in the E-step:

def m_step(x, prob1, prob2):
    """M-Step: Update the GMM parameters
    """    
    # Update means
    mean1 = np.dot(prob1, x) / np.sum(prob1)
    mean2 = np.dot(prob2, x) / np.sum(prob2)

    # Update standard deviations
    std1 = np.sqrt(np.dot(prob1, (x - mean1)**2) / np.sum(prob1))
    std2 = np.sqrt(np.dot(prob2, (x - mean2)**2) / np.sum(prob2))

    # Update mixing weights
    w1 = np.sum(prob1) / len(x)
    w2 = 1 - w1

    return mean1, std1, mean2, std2, w1, w2

Finally, we write the main function that runs the EM algorithm, iterating between the E-step and M-step for a specified number of iterations:

def gmm_em(x, max_iter=100):
    """Gaussian mixture model estimation using Expectation-Maximization
    """    
    mean1, mean2, std1, std2, w1, w2 = init_params(x)

    for i in range(max_iter):
        print(f'Iteration {i}: μ1 = {mean1:.3f}, σ1 = {std1:.3f}, μ2 = {mean2:.3f}, σ2 = {std2:.3f}, ' 
              f'w1 = {w1:.3f}, w2 = {w2:.3f}')

        prob1, prob2 = e_step(x, mean1, std1, mean2, std2, w1, w2)
        mean1, std1, mean2, std2, w1, w2 = m_step(x, prob1, prob2)     

    return mean1, std1, mean2, std2, w1, w2

To test our implementation, we will create a synthetic dataset by sampling data from a known mixture distribution with predefined parameters. Then, we will use the EM algorithm to estimate the parameters of the distribution, and compare the estimated parameters with the original ones.

First, let’s write a function to sample data from a mixture of two univariate normal distributions:

def sample_data(mean1, std1, mean2, std2, w1, w2, n_samples):    
    """Sample random data from a mixture of two Gaussian distribution.
    """
    x = np.zeros(n_samples)
    for i in range(n_samples):
        # Choose distribution based on mixing weights
        if np.random.rand() < w1:
            # Sample from the first distribution
            x[i] = np.random.normal(mean1, std1)
        else:
            # Sample from the second distribution
            x[i] = np.random.normal(mean2, std2)

    return x

We will now use this function to sample 1,000 data points from the mixture distribution we have defined earlier:

# Parameters for the two univariate normal distributions
mean1, std1 = -1, 1
mean2, std2 = 4, 1.5
w1, w2 = 0.7, 0.3

x = sample_data(mean1, std1, mean2, std2, w1, w2, n_samples=1000)

We can now run the EM algorithm on this dataset:

final_dist_params = gmm_em(x, max_iter=30)

We get the following output:

Iteration 0: μ1 = -1.311, σ1 = 1.000, μ2 = 0.239, σ2 = 1.000, w1 = 0.500, w2 = 0.500
Iteration 1: μ1 = -1.442, σ1 = 0.898, μ2 = 2.232, σ2 = 2.521, w1 = 0.427, w2 = 0.573
Iteration 2: μ1 = -1.306, σ1 = 0.837, μ2 = 2.410, σ2 = 2.577, w1 = 0.470, w2 = 0.530
Iteration 3: μ1 = -1.254, σ1 = 0.835, μ2 = 2.572, σ2 = 2.559, w1 = 0.499, w2 = 0.501
...
Iteration 27: μ1 = -1.031, σ1 = 1.033, μ2 = 4.180, σ2 = 1.371, w1 = 0.675, w2 = 0.325
Iteration 28: μ1 = -1.031, σ1 = 1.033, μ2 = 4.181, σ2 = 1.370, w1 = 0.675, w2 = 0.325
Iteration 29: μ1 = -1.031, σ1 = 1.033, μ2 = 4.181, σ2 = 1.370, w1 = 0.675, w2 = 0.325

The algorithm has converged to parameters that are close to the original parameters of the mixture: μ₁ = -1.031, σ₁ = 1.033, μ₂ = 4.181, σ₂ = 1.370, and mixture weights of w₁ = 0.675 and w₂ = 0.325.
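
Since each run of the algorithm can end up in a different local optimum, a simple extension (not part of the code above) is to call gmm_em several times and keep the run that achieves the highest log-likelihood:

def log_likelihood(x, mean1, std1, mean2, std2, w1, w2):
    """Log-likelihood of the data under a two-component univariate GMM."""
    return np.sum(np.log(w1 * norm.pdf(x, mean1, std1) + w2 * norm.pdf(x, mean2, std2)))

best_params, best_ll = None, -np.inf
for _ in range(5):
    params = gmm_em(x, max_iter=30)
    ll = log_likelihood(x, *params)
    if ll > best_ll:
        best_params, best_ll = params, ll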

Let’s use the plot_mixture() function we have written earlier to plot the final distribution. We will update the function to plot a histogram of the sampled data as well:

def plot_mixture(x, mean1, std1, mean2, std2, w1, w2):
    # Plot a histogram of the input data
    sns.histplot(x, bins=20, kde=True, stat='density', linewidth=0.5, color='gray')

    # Generate points for the x-axis
    x_ = np.linspace(-5, 10, 1000)

    # Calculate the individual normal distributions
    normal1 = norm.pdf(x_, mean1, std1)
    normal2 = norm.pdf(x_, mean2, std2)

    # Calculate the mixture
    mixture = w1 * normal1 + w2 * normal2

    # Plot the results
    plt.plot(x_, normal1, label='Normal distribution 1', linestyle='--')
    plt.plot(x_, normal2, label='Normal distribution 2', linestyle='--')
    plt.plot(x_, mixture, label='Mixture model', color='black')
    plt.xlabel('$x$')
    plt.ylabel('$p(x)$')
    plt.legend()

plot_mixture(x, *final_dist_params)

The result is shown in the following graph:

The mixture distribution estimated from the dataset using the EM algorithm

As can be seen, the estimated distribution closely aligns with the histogram of the data points.

Exercise: Extend the code above to handle multivariate normal distributions and any number of distributions K. Hint: You can use the function scipy.stats.multivariate_normal to compute the PDF of a multivariate normal distribution.
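
As a starting point for this exercise, an E-step that handles K multivariate components might look like the following sketch (the array shapes are assumptions: means is K×d, covs is K×d×d, and weights has length K):

import numpy as np
from scipy.stats import multivariate_normal

def e_step_multivariate(X, means, covs, weights):
    """Compute the responsibilities of K multivariate Gaussian components."""
    n, K = X.shape[0], len(weights)
    resp = np.zeros((n, K))
    for k in range(K):
        # Weighted density of all the points under the k-th component
        resp[:, k] = weights[k] * multivariate_normal.pdf(X, means[k], covs[k])
    resp /= resp.sum(axis=1, keepdims=True)  # normalize over the components
    return resp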

GMM in Scikit-Learn

Scikit-Learn provides an implementation of Gaussian mixture models in the class [sklearn.mixture.GaussianMixture](https://scikit-learn.org/stable/modules/generated/sklearn.mixture.GaussianMixture.html). Important parameters of this class include:

  • n_components: The number of mixture components (defaults to 1).
  • covariance_type: The type of covariance parameters to use. Can be one of the following options:
      • 'full': Each component has its own covariance matrix.
      • 'tied': All components share the same covariance matrix.
      • 'diag': Each component has its own covariance matrix, which must be diagonal.
      • 'spherical': Each component has its own single variance.
  • tol: The convergence threshold. The EM algorithm will stop when the average improvement of the log-likelihood falls below this threshold (defaults to 0.001).
  • max_iter: The number of EM iterations to perform (defaults to 100).
  • n_init: The number of random initializations to perform (defaults to 1).
  • init_params: The method used to initialize the parameters of the model. Can take one of the following options:
      • 'kmeans': The parameters are initialized using k-means (the default).
      • 'k-means++': The parameters are initialized using k-means++.
      • 'random': The parameters are randomly initialized.
      • 'random_from_data': The initial means are randomly selected from the given data points.
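
For example, the following call (with illustrative parameter values) configures a model with three components, full covariance matrices, and five random initializations:

from sklearn.mixture import GaussianMixture

gmm = GaussianMixture(
    n_components=3,          # number of mixture components
    covariance_type='full',  # each component has its own covariance matrix
    tol=1e-3,                # convergence threshold on the log-likelihood improvement
    max_iter=100,            # maximum number of EM iterations
    n_init=5,                # number of random initializations; the best run is kept
    init_params='kmeans'     # initialize the parameters using k-means
)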

In addition, this class provides the following attributes:

  • weights_: The mixture weights.
  • means_: The means of each component.
  • covariances_: The covariance matrix of each component.
  • converged_: A Boolean indicating whether convergence was reached by the EM algorithm.
  • n_iter_: The number of steps used by the EM to reach convergence.

Note that unlike other clustering algorithms in Scikit-Learn, this class does not provide a labels_ attribute. Therefore, to get the cluster assignments of the data points, you need to call the predict() method on the fitted model (or call fit_predict()).

For example, let’s use this class to perform clustering on the following dataset, which consists of two elliptical blobs and a spherical one:

from sklearn.datasets import make_blobs

X, y = make_blobs(n_samples=500, centers=[(0, 0), (4, 4)], random_state=0)

# Apply a linear transformation to make the blobs elliptical
transformation = [[0.6, -0.6], [-0.2, 0.8]]
X = np.dot(X, transformation) 

# Add another spherical blob
X2, y2 = make_blobs(n_samples=150, centers=[(-2, -2)], cluster_std=0.5, random_state=0)
X = np.vstack((X, X2))

Let’s plot the dataset:

def plot_data(X):
    sns.scatterplot(x=X[:, 0], y=X[:, 1], edgecolor='k', legend=False)
    plt.xlabel('$x_1$')
    plt.ylabel('$x_2$')

plot_data(X)

Next, we instantiate the GaussianMixture class with n_components=3, and call its fit_predict() method to get the cluster assignments:

from sklearn.mixture import GaussianMixture

gmm = GaussianMixture(n_components=3)
labels = gmm.fit_predict(X)

We can check how many iterations it took for the EM algorithm to converge:

print(gmm.n_iter_)
2

It took only two iterations for the EM algorithm to converge in this case.

We can also examine the estimated GMM parameters:

print('Weights:', gmm.weights_)
print('Means:\n', gmm.means_)
print('Covariances:\n', gmm.covariances_)
Weights: [0.23077331 0.38468283 0.38454386]
Means:
 [[-2.01578902 -1.95662033]
 [-0.03230299  0.03527593]
 [ 1.56421574  0.80307925]]
Covariances:
 [[[ 0.254315   -0.01588303]
  [-0.01588303  0.24474151]]

 [[ 0.41202765 -0.53078979]
  [-0.53078979  0.99966631]]

 [[ 0.35577946 -0.48222654]
  [-0.48222654  0.98318187]]]

We can see that the estimated weights are very close to the original proportions of the three blobs (250/650 ≈ 0.385 for each of the two elliptical blobs and 150/650 ≈ 0.231 for the spherical one), and the mean and variance of the spherical blob are very close to its original parameters.

Let’s plot the clusters:

def plot_clusters(X, labels):    
    sns.scatterplot(x=X[:, 0], y=X[:, 1], hue=labels, palette='tab10', edgecolor='k', legend=False)
    plt.xlabel('$x_1$')
    plt.ylabel('$x_2$')

plot_clusters(X, labels)
The results of GMM clustering

GMM has correctly identified all three clusters.

In addition, we can use the method predict_proba() to get the membership probabilities for each data point in each cluster.

prob = gmm.predict_proba(X)

For example, the first point in the dataset has a very high probability of belonging to the green cluster:

print('x =', X[0])
print('prob =', prob[0])
x = [ 2.41692591 -0.07769481]
prob = [3.11052582e-21 8.85973054e-10 9.99999999e-01]

We can visualize these probabilities by making the size of each point proportional to its probability of belonging to the cluster it was assigned to:

sizes = prob.max(axis=1)
sns.scatterplot(x=X[:, 0], y=X[:, 1], hue=labels, size=sizes, palette='tab10', edgecolor='k', legend=False)
plt.xlabel('$x_1$')
plt.ylabel('$x_2$')
Probabilities of cluster assignments

We can see that the points that lie in the border between the two elliptical clusters have lower probability. Data points that have a significantly low probability density (e.g., falling below a predefined threshold) can be identified as anomalies or outliers.
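
For example, the score_samples() method of the fitted model returns the log-density of each point under the mixture, which can be thresholded to flag outliers (the 2nd-percentile threshold below is an arbitrary choice for illustration):

# Log-density of each point under the fitted GMM
log_density = gmm.score_samples(X)

# Flag the points whose log-density falls below the 2nd percentile
threshold = np.percentile(log_density, 2)
outliers = X[log_density < threshold]
print(f'Flagged {len(outliers)} points as potential outliers')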

For comparison, the following figure shows the results of other clustering algorithms applied to the same dataset:

As can be seen, other clustering algorithms fail to correctly identify the elliptical clusters.

Model Evaluation

The log-likelihood is a primary measure used to evaluate GMMs. It is also monitored during training to check for convergence of the EM algorithm. However, sometimes we need to compare models with different numbers of components or different covariance structures.

To that end, we have two additional measures, which balance the model complexity (number of parameters) against its goodness of fit (represented by the log-likelihood):

  1. Akaike Information Criterion (AIC):

AIC = 2p − 2log(L)

where:

  • p is the number of parameters in the model (including all the means, covariances, and mixing weights).
  • L is the maximum likelihood of the model (the likelihood of the model with the optimal parameter values).

Lower values of AIC indicate a better model. AIC rewards models that fit the data well, but also penalizes models with more parameters.

  2. Bayesian Information Criterion (BIC):

BIC = p·log(n) − 2log(L)

where p and L are defined as before, and n is the number of data points.

Similar to AIC, BIC balances model fit and complexity, but places a greater penalty on models with more parameters, as p is multiplied by log(n) instead of 2.

In Scikit-Learn, you can compute these measures using the methods aic() and bic() of the GaussianMixture class. For example, the AIC and BIC values of the GMM clustering of the blobs dataset are:

print(f'AIC = {gmm.aic(X):.3f}')
print(f'BIC = {gmm.bic(X):.3f}')
AIC = 4061.318
BIC = 4110.565

These measures can be used to find the optimal number of components by fitting GMMs with different numbers of components to the dataset and then selecting the model with the lowest AIC or BIC value.
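
A minimal sketch of such a search (reusing the blobs dataset X from the previous section) might look like this:

n_components_range = range(1, 7)
bics = []

for k in n_components_range:
    model = GaussianMixture(n_components=k, random_state=0).fit(X)
    bics.append(model.bic(X))

best_k = n_components_range[int(np.argmin(bics))]
print('Best number of components according to BIC:', best_k)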

Summary

Let’s summarize the pros and cons of GMMs as compared to other clustering algorithms:

Pros:

  • Unlike k-means, which assumes spherical clusters, GMMs can adapt to ellipsoidal shapes thanks to the covariance component. This allows GMMs to capture a wider variety of cluster shapes.
  • Can deal with clusters with varying sizes due to their use of covariance matrices and mixing coefficients, which account for the spread and proportion of each cluster.
  • GMMs provide probabilities (soft assignments) of each point belonging to each cluster, which can be more informative in understanding the data.
  • Can deal with overlapping clusters, since they assign data points to clusters based on probabilities rather than hard boundaries.
  • Easy to explain the clustering results, because each cluster is represented by a Gaussian distribution with specific parameters.
  • In addition to clustering, GMMs can also be used for density estimation and anomaly detection.

Cons:

  • Require specifying the number of components (clusters) in advance.
  • Assume that the data in each cluster follows a Gaussian distribution, which might not always be a valid assumption for real-world data.
  • May not work well when clusters contain only a few data points, as the model relies on sufficient data to accurately estimate the parameters of each component.
  • The clustering results can be sensitive to the initial choice of parameters.
  • The EM algorithm used in GMMs can get stuck in a local optimum, and its convergence can be slow.
  • Badly-conditioned covariance matrices (i.e., matrices that are near singular or have a very high condition number) can lead to numerical instabilities during the EM computations.
  • Computationally more intensive than simpler algorithms like k-means, especially for large datasets or when the number of components is high.

Thanks for reading!

All the images are by the author unless stated otherwise.

You can find the code examples of this article on my github: https://github.com/roiyeho/medium/tree/main/gmm

