tl;dr: Dirichlet Process Gaussian Mixture Models made easy.

Clustering is the bane of a data scientist’s life. How many of us have spent many a Sunday evening looking upon a fresh dataset with increasing apprehension, rummaging through the algorithms in scikit-learn and fumbling over hyperparameters like distance metrics and cluster numbers, whilst praying to the bias gods for mercy? The solution? Bayesian non-parametrics: the hero we all deserve and the one we all very much need.

Leon Chlon
Towards Data Science


Image Credit: ESA/Hubble & NASA

First and foremost: Clustering is the inverse of Generating.

A generative model is one that gives us observations. We can use a Bernoulli distribution model to generate coin flip observations. We can use a Poisson distribution model to simulate radioactive decay. We can use a trained neural network to generate classical music. The more realistic the model, the closer these observations will align with reality.
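To make “a model that gives us observations” concrete, here is a minimal NumPy sketch; the parameter values (coin fairness, decay rate, sample counts) are arbitrary illustrative choices, not anything tied to a real dataset.

```python
import numpy as np

rng = np.random.default_rng(0)

# Bernoulli model: ten simulated coin flips with an assumed fair coin (p = 0.5)
coin_flips = rng.binomial(n=1, p=0.5, size=10)

# Poisson model: ten simulated decay counts with an assumed rate of 3 events per interval
decay_counts = rng.poisson(lam=3.0, size=10)

print(coin_flips)    # array of 0s and 1s
print(decay_counts)  # small non-negative integers
```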

Gaussian Mixture Model

Below is a mixture of 400 samples generated from four independent bi-variate normal distributions with distinct means and equal standard deviations. The name for this model of mixed Gaussian distributions is, surprise surprise, a Gaussian Mixture Model.
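As a rough sketch of how such a mixture might be generated (the specific means and standard deviation below are illustrative guesses, not the exact values behind the figure):

```python
import numpy as np

rng = np.random.default_rng(42)

# Four well-separated means and one shared, small standard deviation (illustrative values)
means = np.array([[0, 0], [5, 5], [0, 5], [5, 0]])
std = 0.5
cov = (std ** 2) * np.eye(2)

# 100 samples from each bivariate Gaussian -> 400 samples in total
X = np.vstack([rng.multivariate_normal(mu, cov, size=100) for mu in means])
labels = np.repeat(np.arange(4), 100)  # ground-truth component of each sample
```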

K-Means Clustering

Using the K-Means algorithm and incredible Sherlockesque reasoning for the cluster number (the kernel density plot literally tells you there are 4 clusters), I was able to recover the generative model and colour-code each cluster below. The recovery/clustering is clean because the data is squeaky clean. Each Gaussian contributed 100 samples, so we didn’t have to worry about mixing probabilities. I chose four very distinct Gaussian means to sample from and a small standard deviation, so I didn’t have to worry about distribution overlap.
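A minimal version of that recovery in scikit-learn, assuming the X array from the generation sketch above, might look like this:

```python
from sklearn.cluster import KMeans

# Fit K-Means with the "Sherlockesque" choice of four clusters
kmeans = KMeans(n_clusters=4, n_init=10, random_state=0).fit(X)

recovered = kmeans.labels_          # cluster assignment for each sample
centres = kmeans.cluster_centers_   # should land close to the four true means
```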

In reality, Nature will never be this kind to you: we cannot afford to ignore the uncertainty around our estimates of the cluster number, the mixing probabilities, and the moments of each Gaussian in the mixture. So how do we scale our clustering approach with data complexity?

Non-Parametric Models

Let’s distinguish between a parametric and non-parametric model. A parametric model is one we can write down on a piece of paper, using a fixed number of parameters. y=mx + c has two fixed parameters (m,c) and is therefore parametric.

A random forest classifier may take M parameters when fit to a dataset (y, X = {x_{1}, x_{2}, …, x_{n}}), but take K more if we introduce another feature x_{n+1}. As our data increases in complexity, so does the model. The number of parameters has no upper bound, meaning that for data with sufficient complexity, there isn’t enough paper in the world to write down the full model. This is an example of a non-parametric model.

Who cares, right? Well, it turns out that when we have limited prior belief over the cluster number or mixing probabilities, we can go non-parametric and consider infinitely many clusters. Naturally, many of these clusters will be redundant, with mixing probabilities so close to 0 that we can just ignore them. Incredibly, this framework lets the data determine the most likely number of clusters.

So cool, how do we get started? First we need a way of describing a mixture of infinitely many distributions, and this is where Dirichlet Processes come in.

Dirichlet Process

A Dirichlet Process prior can be described using enough mathematical jargon to send one fleeing back to K-Means, so I’ll spare you the migraine and give an intuitive overview instead.

We want to assign mixing probabilities 𝜋 = {𝜋_{1}, 𝜋_{2}, …, 𝜋_{n}, …} to an infinite number of clusters. We know that the sum of 𝜋 must be 1 and that each 𝜋_{i} is greater than or equal to 0. If we imagine the probability between 0 and 1 as a stick, we are looking for ways to break that stick up such that it accurately reflects the mixing contributions of each Gaussian in our GMM. This approach to defining a Dirichlet Process prior is called the stick-breaking process, in which each break proportion is drawn from a Beta distribution. I highly recommend reading up on the statistical details here.

An illustration of the stick breaking process. Credit to http://blog.shakirm.com/2015/12/machine-learning-trick-of-the-day-6-tricks-with-sticks/
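A truncated simulation of the stick-breaking construction might look like the sketch below; alpha is the concentration parameter of the Beta(1, alpha) draws, and the truncation level of 20 is an arbitrary choice.

```python
import numpy as np

def stick_breaking(alpha, n_components, rng):
    # Draw the proportion of the remaining stick to break off at each step
    betas = rng.beta(1.0, alpha, size=n_components)
    # Length of stick left before each break: product of (1 - beta_j) for j < k
    remaining = np.concatenate([[1.0], np.cumprod(1.0 - betas)[:-1]])
    return betas * remaining  # mixing probabilities pi_1, ..., pi_n

rng = np.random.default_rng(0)
pi = stick_breaking(alpha=1.0, n_components=20, rng=rng)
print(pi.sum())  # close to 1; the leftover mass lives in the untruncated tail
```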

In our earlier example we had equal contributions from four Gaussians, so 𝜋 = {𝜋_{1} = 0.25, 𝜋_{2} = 0.25, 𝜋_{3} = 0.25, 𝜋_{4} = 0.25, 𝜋_{5} = 0, 𝜋_{6} = 0, …, 𝜋_{n} =0,…}. We broke our stick an infinite number of times, but the probability density is exclusively centered around the first four components, allowing us to truncate the infinite distribution. So now that we have a non-parametric framework for mixing probability assignment, how do we fit this to the data?

Dirichlet Process Gaussian Mixture Models (DPGMMs)

Now for the big reveal: since 𝜋 tells us the relative contribution of each Gaussian in our GMM, it is effectively a distribution over distributions. Each 𝜋_{i} corresponds to a unique Gaussian N(μ_{i}, Σ_{i}) parameterised by a mean μ_{i} and covariance matrix Σ_{i}. If we let θ_{i} =(μ_{i}, Σ_{i}), our problem reduces to assigning a probability 𝜋_{i} to each θ_{i}, reflecting its degree of contribution to the data mixture.

Given some k-dimensional multivariate Gaussian data X, we start out with a prior belief that all values of 𝜋 are equally likely. We then want to use our data X to compute the likelihood p(X|𝜋) of this new everything-is-likely construction. Finally, we update our model using Bayes’ theorem for the model posterior p(𝜋|X) ∝ p(X|𝜋)p(𝜋). Typically p(𝜋|X) involves some pretty nasty intractable integrals, so we rely on something like Markov Chain Monte Carlo sampling or Variational Inference to approximate p(𝜋|X). Luckily, given the right choice of priors, we can use the Expectation Maximisation (EM) algorithm for inference instead! Let’s see a real application of this in scikit-learn.

Implementation

Firstly, we generate a really messy toy GMM dataset, rich with random underlying correlation structures. Using a 20-dimensional, 8-component GMM, 10,000 samples were drawn according to a weight schema initialised from a Dirichlet distribution with a uniform prior (link to code here).
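I won’t reproduce the linked code here, but a rough sketch of how such a dataset might be assembled is below; the dimension, component count, sample size and uniform Dirichlet prior follow the description above, while the scales of the means and covariances are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(1)
n_dim, n_components, n_samples = 20, 8, 10_000

# Mixing weights drawn from a Dirichlet distribution with a uniform prior
weights = rng.dirichlet(np.ones(n_components))

# Random means and random (symmetric positive-definite) covariances per component
means = rng.normal(scale=5.0, size=(n_components, n_dim))
covs = []
for _ in range(n_components):
    A = rng.normal(size=(n_dim, n_dim))
    covs.append(A @ A.T + n_dim * np.eye(n_dim))

# Draw the number of samples per component, then sample each component
counts = rng.multinomial(n_samples, weights)
X = np.vstack([
    rng.multivariate_normal(means[k], covs[k], size=counts[k])
    for k in range(n_components)
])
y_true = np.repeat(np.arange(n_components), counts)
```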

A kernel density estimate across the first two features tells us little about the potential number of clusters in the data, but does reveal striking heterogeneity.

A scatterplot of the clusters across the first two PCA components of the data shows us how difficult recovering the general model is likely to be:
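A sketch of that projection, assuming the X and y_true arrays from the generation sketch above:

```python
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

# Project onto the first two principal components, coloured by the (normally unknown) truth
X_2d = PCA(n_components=2).fit_transform(X)
plt.scatter(X_2d[:, 0], X_2d[:, 1], c=y_true, s=5, cmap="tab10")
plt.xlabel("PC 1")
plt.ylabel("PC 2")
plt.show()
```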

Fitting a DPGMM truncated to 15 components using the EM algorithm, we see that most of the probability mass is concentrated around the first 8 components:
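In scikit-learn this corresponds to BayesianGaussianMixture with a Dirichlet process prior on the weights; a minimal sketch, assuming the X array from above (the truncation level of 15 matches the description, the other settings are reasonable choices rather than anything from the original code):

```python
from sklearn.mixture import BayesianGaussianMixture

dpgmm = BayesianGaussianMixture(
    n_components=15,  # truncation level for the infinite mixture
    weight_concentration_prior_type="dirichlet_process",
    covariance_type="full",
    max_iter=500,
    random_state=0,
).fit(X)

# Most of the weight should concentrate on roughly 8 components
print(dpgmm.weights_.round(3))
```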

In fact, if we use a metric like the Kullback–Leibler divergence, we can compute an asymmetric measure of the distance between our estimated parameters and the ground truth parameters in our dataset:
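The KL divergence between two multivariate Gaussians has a closed form, so each estimated component can be compared against each ground truth component directly; a sketch, assuming the dpgmm fit and the ground truth means/covs from the sketches above:

```python
import numpy as np

def gaussian_kl(mu0, cov0, mu1, cov1):
    """KL( N(mu0, cov0) || N(mu1, cov1) ) for k-dimensional Gaussians."""
    k = mu0.shape[0]
    cov1_inv = np.linalg.inv(cov1)
    diff = mu1 - mu0
    return 0.5 * (
        np.trace(cov1_inv @ cov0)
        + diff @ cov1_inv @ diff
        - k
        + np.log(np.linalg.det(cov1) / np.linalg.det(cov0))
    )

# e.g. compare the first estimated component against every ground truth component
kl_row = [gaussian_kl(dpgmm.means_[0], dpgmm.covariances_[0], means[j], covs[j])
          for j in range(n_components)]
```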

Each estimated DPGMM cluster maps to its own unique ground truth cluster with a remarkably tiny KL-divergence relative to the remaining ground truth clusters it is compared against. With this kind of matchup, it’s no surprise that the DPGMM recovers the clustering almost perfectly, whilst a K-Means classifier, with the cluster number chosen according to the largest average silhouette score, struggles to recover the actual number of clusters.
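For comparison, a sketch of the silhouette-based K-Means baseline described above; the candidate range of k and the silhouette sample size are arbitrary choices.

```python
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score

# Pick the cluster number with the largest average silhouette score
scores = {}
for k in range(2, 16):
    labels_k = KMeans(n_clusters=k, n_init=10, random_state=0).fit_predict(X)
    scores[k] = silhouette_score(X, labels_k, sample_size=2000, random_state=0)

best_k = max(scores, key=scores.get)
print(best_k, scores[best_k])
```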

Please check out the code linked above and try implementing this yourself on your own data. Good luck and happy clustering!
