An Introduction to ProtoDash — An Algorithm to Better Understand Datasets and Machine Learning Models

Stacey Ronaghan
Towards Data Science
5 min read · Nov 5, 2020


Photo by Pixabay on Pexels

This post has been written in the hope that you will discover a new way to explore and sample your data.

To understand our data, we start with descriptive analytics, slicing and dicing in search of expected and interesting patterns. We might then look for clusters; in text data, these might represent topics, whilst for customers, these might be personas or cohorts.

Alternatively, we may wish to find other examples within our data that are similar to a subset we are interested in.

ProtoDash can be used to support these activities by selecting representative examples within the data.

What is ProtoDash?

The ProtoDash algorithm was created through a collaboration between Amazon and IBM Research. It is a method for selecting prototypical examples that capture the underlying distribution of a dataset. It also weights each prototype to quantify how well it represents the data.

ProtoDash requires three key inputs: the dataset you want to explain, the dataset you want to select prototypical explanations from, and the number of prototypes to select.

One of the key benefits of ProtoDash is that it is designed to find diverse prototypes: examples that reflect the dataset in different ways and together give a more complete picture.

Once the initial prototype (the most typical data point) is found, the algorithm searches for a second one. It again looks for an example with common behaviors, but it also tries to capture new characteristics, so that the second prototype differs from the first. This search continues until the requested number of prototypes has been found.

Use Cases

ProtoDash for Subset Selection

By using ProtoDash as designed, you can discover examples in one dataset that best represent the distribution of another. This can be very useful when you want to target customers for a particular product based upon prior customers, or retrospectively understand the impact of an event.

For example, you may want to explore the effects of an action, such as “did receiving a marketing email impact sales?”. Since many other factors influence sales, we would want to ensure that the customers who received the marketing email (the treatment group) and those who did not (the control group) have similar characteristics. To do this, you would pass the treatment group as the dataset you wish to explain and the control group as the dataset to find prototypes from.

ProtoDash for Explainable AI

If you pass only one data point as the dataset you want to explain, you are asking for prototypes in the second dataset that are most similar to this one example.

In order to trust machine learning models, we wish to better understand how they work. One of the key questions is “do similar examples get the same outcome?”. For example, if a loan was rejected, we may want to make sure that similar applicants were also rejected. If this is not the case, it might indicate the machine learning model is not behaving as expected.

By identifying similar examples with ProtoDash, you can check whether the model's predictions for them are consistent in the way you would expect if the model were working as intended.
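With a single point x as the dataset to explain, each candidate's mean similarity to the target reduces to k(x, z_j). Because ProtoDash starts with all weights at zero, its first pick is the candidate with the largest such similarity, which under a Gaussian kernel is simply x's nearest neighbour by Euclidean distance; later picks trade similarity against diversity. The sketch below checks that first step on synthetic data. The kernel, the bandwidth, and the toy "model" are assumptions for illustration only.

```python
import numpy as np

rng = np.random.default_rng(2)
applicants = rng.normal(size=(1000, 3))  # hypothetical applicant features
query = rng.normal(size=(1, 3))          # hypothetical rejected application to explain

sigma = 1.0  # assumed Gaussian-kernel bandwidth
sq_dist = ((applicants - query) ** 2).sum(axis=1)
mu = np.exp(-sq_dist / (2 * sigma ** 2))  # kernel similarity of each applicant to the query

first_prototype = int(np.argmax(mu))  # largest gradient when all weights are still zero
nearest = int(np.argmin(sq_dist))     # plain Euclidean nearest neighbour
print(first_prototype == nearest)


def model(X):
    """Hypothetical stand-in model: reject when the feature sum is negative."""
    return np.where(X.sum(axis=1) < 0, "reject", "approve")


# Do the query and its most similar applicant receive the same outcome?
print(model(query)[0], model(applicants[[first_prototype]])[0])
```

In practice you would compare the outcomes across all the returned prototypes, not just the first.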

ProtoDash for Segmentation

If you pass only one dataset to ProtoDash, so that prototypes are selected from the same dataset you wish to explain, the selected prototypes will summarize its underlying distribution.

This can be used as an alternative to clustering for segmentation. Each prototype would be an example that represents a particular segment. For example, you might want to understand common topics within customer emails, and ProtoDash may provide examples that represent queries around payments, delivery, sign-up, etc.

To generate clusters from the prototypes, you would use each prototype as a cluster centroid and assign every other example in your dataset to its nearest prototype.
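A minimal sketch of that assignment step, assuming numeric feature vectors and Euclidean distance (text such as customer emails would first need to be embedded). The prototype rows are hand-picked here for illustration rather than produced by ProtoDash, and `assign_to_prototypes` is a hypothetical helper name.

```python
import numpy as np


def assign_to_prototypes(X, prototypes):
    """Return, for each row of X, the index of its nearest prototype (Euclidean)."""
    d2 = ((X[:, None, :] - prototypes[None, :, :]) ** 2).sum(axis=2)  # shape (n, k)
    return d2.argmin(axis=1)


X = np.array([[0.0, 0.0], [0.2, 0.1], [5.0, 5.0], [5.1, 4.8], [9.0, 0.0]])
prototypes = X[[0, 2, 4]]  # suppose these three rows were returned as prototypes
labels = assign_to_prototypes(X, prototypes)
print(labels)  # [0 0 1 1 2]
```

Any distance consistent with the kernel used for prototype selection could be swapped in for the Euclidean one.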

How does it work?

The goal of the ProtoDash algorithm is to approximate the distribution of one dataset with weighted samples from another. To do this, it selects the samples and weights that minimize the Maximum Mean Discrepancy (MMD), a measure of how far apart two distributions are. A low MMD indicates that the weighted sample represents the distribution of the target dataset well.
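For reference, one standard way to write this, in our own notation rather than the article's, uses a target set $x_1,\dots,x_n$, candidate prototypes $z_1,\dots,z_m$ with non-negative weights $w_j$, and a kernel $k$. Expanding the squared MMD between the target and the weighted candidates gives:

$$
\mathrm{MMD}^2(w) = \frac{1}{n^2}\sum_{i=1}^{n}\sum_{i'=1}^{n} k(x_i, x_{i'}) \;-\; 2\sum_{j=1}^{m} w_j\,\mu_j \;+\; \sum_{j=1}^{m}\sum_{l=1}^{m} w_j\,w_l\,K_{jl},
\qquad \mu_j = \frac{1}{n}\sum_{i=1}^{n} k(x_i, z_j),\quad K_{jl} = k(z_j, z_l).
$$

The first term does not depend on $w$, so $\mathrm{MMD}^2(w) = \text{const} - 2f(w)$ with $f(w) = w^\top \mu - \tfrac{1}{2}\, w^\top K w$. Minimizing MMD is therefore equivalent to maximizing $f(w)$ subject to $w \ge 0$, which is the form of objective the ProtoDash paper works with.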

ProtoDash requires the kernel used by MMD to be specified. If a “Linear” kernel is used, MMD compares only the mean of each feature. A “Gaussian” kernel, however, also takes variance, skew, and other higher-order properties of the distribution into account, capturing a much more detailed understanding of the data.
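To make the difference concrete, the sketch below compares two synthetic one-dimensional samples that share a mean but differ in spread. It assumes NumPy, uses a biased (V-statistic) estimate of squared MMD, and sets the Gaussian bandwidth to 1 arbitrarily; the helper names are ours, not AIX360's.

```python
import numpy as np


def linear_kernel(A, B):
    """Linear kernel: plain dot products between rows of A and rows of B."""
    return A @ B.T


def gaussian_kernel(A, B, sigma=1.0):
    """Gaussian (RBF) kernel between rows of A and rows of B."""
    sq = (A ** 2).sum(1)[:, None] + (B ** 2).sum(1)[None, :] - 2 * A @ B.T
    return np.exp(-np.maximum(sq, 0.0) / (2 * sigma ** 2))


def mmd2(X, Y, kernel):
    """Biased estimate of the squared MMD between samples X and Y."""
    return kernel(X, X).mean() - 2 * kernel(X, Y).mean() + kernel(Y, Y).mean()


rng = np.random.default_rng(0)
narrow = rng.normal(0.0, 0.5, size=(500, 1))  # std 0.5
wide = rng.normal(0.0, 2.0, size=(500, 1))    # std 2.0
narrow -= narrow.mean()  # center both samples so their means match exactly
wide -= wide.mean()

linear_mmd = mmd2(narrow, wide, linear_kernel)
gauss_mmd = mmd2(narrow, wide, gaussian_kernel)
print(linear_mmd, gauss_mmd)
```

The linear-kernel MMD comes out at (numerically) zero because the means match, while the Gaussian-kernel MMD is clearly positive because it also picks up the difference in variance.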

A simple reformulation of MMD becomes the objective function to be maximized when selecting prototypes and weights. The methodology is as follows:

1. Compute the value of the objective function given the current prototype set (required for step 2)

2. Calculate the gradient of the objective function with respect to the weight of each candidate example

3. Choose the example with the largest gradient value and add this to the prototype set

4. Compute the weights by optimizing the objective function using quadratic programming

5. Repeat until the requested number of prototypes has been selected
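The steps above can be sketched in NumPy as follows. This is a minimal illustration under our own assumptions, not the AIX360 implementation. We write the MMD reformulation as f(w) = wᵀμ − ½wᵀKw, where μ_j is candidate j's mean Gaussian-kernel similarity to the target set and K is the candidate-to-candidate kernel matrix, so the gradient is μ − Kw. The sketch evaluates that gradient directly instead of first computing the objective value, uses an arbitrary bandwidth, and swaps the quadratic-programming solver for projected gradient ascent, which solves the same non-negatively constrained problem, only more slowly.

```python
import numpy as np


def gaussian_kernel(A, B, sigma=1.0):
    """Gaussian (RBF) kernel matrix between the rows of A and the rows of B."""
    sq = (A ** 2).sum(1)[:, None] + (B ** 2).sum(1)[None, :] - 2 * A @ B.T
    return np.exp(-np.maximum(sq, 0.0) / (2 * sigma ** 2))


def nonneg_qp(K, mu, iters=2000):
    """Maximize w @ mu - 0.5 * w @ K @ w subject to w >= 0 (projected gradient ascent)."""
    step = 1.0 / np.linalg.eigvalsh(K).max()  # 1 / Lipschitz constant of the gradient
    w = np.zeros(len(mu))
    for _ in range(iters):
        w = np.maximum(w + step * (mu - K @ w), 0.0)  # ascent step, then project onto w >= 0
    return w


def protodash(X, Z, m, sigma=1.0):
    """Greedily select m prototypes from candidates Z to represent target X."""
    mu = gaussian_kernel(Z, X, sigma).mean(axis=1)  # mu_j: mean similarity of z_j to X
    K = gaussian_kernel(Z, Z, sigma)                # kernel among candidates
    selected = []
    w = np.zeros(len(Z))                            # weights over all candidates (0 if unselected)
    for _ in range(m):
        grad = mu - K @ w                           # step 2: gradient of f at current weights
        grad[selected] = -np.inf                    # never re-select an existing prototype
        selected.append(int(np.argmax(grad)))       # step 3: largest gradient joins the set
        idx = np.array(selected)
        w = np.zeros(len(Z))
        w[idx] = nonneg_qp(K[np.ix_(idx, idx)], mu[idx])  # step 4: re-fit weights on the set
    return np.array(selected), w[selected]


rng = np.random.default_rng(1)
target = rng.normal(0.0, 1.0, size=(200, 2))
pool = np.vstack([
    rng.normal(0.0, 1.0, size=(300, 2)),  # candidates resembling the target
    rng.normal(8.0, 0.5, size=(100, 2)),  # a distant cluster the target does not cover
])
protos, weights = protodash(target, pool, m=5)
print(protos, weights.round(3))
```

The distant cluster contributes almost nothing to μ, so its points have near-zero gradients and are not chosen, while the `- K @ w` term pushes each new pick away from regions the existing prototypes already cover.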

Example

The ProtoDash paper includes an example using MNIST, an image dataset of handwritten digits from 0 to 9. The experiments were designed to validate how well ProtoDash could select prototypes from a source dataset that match the distribution of a different target dataset. For example, instances of the digit 0 are passed as the dataset to be explained, and a dataset containing examples of all digits (0 to 9) is passed as the dataset to find prototypes from.

Below are examples of prototypes the algorithm discovered. For each digit, the algorithm has captured the underlying features: every returned prototype is clearly the correct digit. In addition, the examples themselves vary in width, thickness, and slant.

Image from Paper: Efficient Data Representation by Selecting Prototypes with Importance Weights

ProtoDash’s ability to select prototypes from one dataset that closely match the distribution of another has applications in transfer learning, multi-task learning, and covariate shift correction.

Summary

ProtoDash is available as part of the AI Explainability 360 Toolkit, an open-source library that supports the interpretability and explainability of datasets and machine learning models.

By finding prototypical examples, ProtoDash provides an intuitive method of understanding the underlying characteristics of a dataset. This can be a valuable solution for a range of scenarios, including subset selection, segmentation, and supporting machine learning explainability.

Links

Paper: Efficient Data Representation by Selecting Prototypes with Importance Weights

Video: ProtoDash: Fast Interpretable Prototype Selection by Karthik Gurumoorthy

AI Explainability 360 Toolkit (AIX360)

Example AIX360 ProtoDash Notebooks
