The Six Key Things You Need to Know About Scikit-plot

An intuitive library to add plotting functionality to scikit-learn objects.

Davis David
Towards Data Science

Photo by Marko Blazevic from Pexels

“Visualization gives you answers to questions you didn’t know you had.” — Ben Shneiderman

Visualization is one of the most crucial components of data science and machine learning. If you ever need to present your results to someone, you show them visualizations, not a bunch of numbers in Excel. In machine learning experiments, visualizing the model during the training phase and visualizing the results are both very important, because they can help you decide what the next step should be. By visualizing your results, you can notice what you never expected to see. Scikit-plot can help you with this task in your machine learning project.

What is Scikit-plot?

Scikit-plot is a Python package that can help you visualize your data, your model (during training), and your experiment results at different stages of your machine learning project. Scikit-plot is a humble attempt to provide the opportunity to generate quick and beautiful graphs and plots with as little boilerplate as possible.

In this article, you will learn about the following modules provided in the Scikit-plot library: Estimators, Metrics, and Decomposition.

“The greatest value of a picture is when it forces us to notice what we never expected to see” — John Tukey

Installation

Installation is simple! First, make sure you have the dependencies Scikit-learn and Matplotlib installed.

Then just run:

pip install scikit-plot

If using conda, you can install Scikit-plot by running:

conda install -c conda-forge scikit-plot

The following simple machine learning project will show you how you can implement these scikit-plot modules in your next machine learning project.

We will use the insurance dataset from Lagos and major cities in Nigeria. We will build a predictive model to determine if a building will have an insurance claim during a certain period or not. We will have to predict the probability of having at least one claim over the insured period of the building. You can download the dataset here.

We will start by importing important packages for this ML project.

Load the Insurance dataset.

top five rows

In this dataset, we have both categorical and numerical values.

Let’s see the shape of the dataset


data.shape
(7160, 14)

We have 14 variables and 7160 rows.

List of columns in the dataset.

list of columns

You can read here the meaning of each variable name in this dataset. Our target variable is Claim, which has two unique values:

  • 1 if the building has at least one claim over the insured period.
  • 0 if the building doesn’t have a claim over the insured period.

We will drop two columns: Customer Id and Geo_Code.

Check for missing values in the dataset.

Missing values in each variable

We have 3 columns with missing values: Garden, Building Dimension, and Date_of_Occupancy.

We will handle all missing values by using the most_frequent and mean imputation strategies.
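A minimal sketch of this imputation step, assuming scikit-learn's SimpleImputer is the tool used. The sample values are made up, and the choice of which column gets which strategy (most_frequent for Garden and Date_of_Occupancy, mean for Building Dimension) is an assumption, since the article does not say.

```python
import numpy as np
import pandas as pd
from sklearn.impute import SimpleImputer

# Tiny hypothetical sample with the three columns that contain missing values.
data = pd.DataFrame({
    "Garden": ["V", "O", np.nan, "V"],
    "Building Dimension": [290.0, np.nan, 490.0, 600.0],
    "Date_of_Occupancy": [1960.0, 1850.0, 1960.0, np.nan],
})

# Assumption: the categorical column and the year column take the most
# frequent value, the continuous column takes the column mean.
strategies = {
    "Garden": "most_frequent",
    "Date_of_Occupancy": "most_frequent",
    "Building Dimension": "mean",
}

for column, strategy in strategies.items():
    imputer = SimpleImputer(strategy=strategy)
    # SimpleImputer expects 2D input, so pass a one-column frame and flatten the result.
    data[column] = imputer.fit_transform(data[[column]]).ravel()

# After imputation no column should report missing values.
print(data.isnull().sum())
```

Fitting one imputer per column keeps each strategy explicit and avoids mixing string and numeric columns in a single array.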

In the dataset, we have some categorical variables that will be converted into numbers by using the LabelEncoder class from scikit-learn.

The Residential, Building_Painted, Building_Fenced, Garden, and NumberOfWindows variables have been converted into numerical values.
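The encoding step could look roughly like the sketch below. The category codes in the sample frame are invented for illustration, and only some of the columns named above are included.

```python
import pandas as pd
from sklearn.preprocessing import LabelEncoder

# Hypothetical sample rows; the real dataset's category codes may differ.
data = pd.DataFrame({
    "Building_Painted": ["N", "V", "N"],
    "Building_Fenced": ["V", "N", "V"],
    "Garden": ["V", "O", "V"],
    "NumberOfWindows": ["3", "4", "without"],
})

categorical_columns = list(data.columns)

for column in categorical_columns:
    # Use a fresh encoder per column; LabelEncoder numbers the classes
    # in sorted order, starting from 0.
    data[column] = LabelEncoder().fit_transform(data[column])

print(data)
```

Keep in mind that LabelEncoder assigns integers by sorted class order, so the resulting numbers carry no meaning beyond identifying the category.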

Let’s separate our target variable from independent variables.

We will scale the independent variables by using MinMaxScaler from scikit-learn.
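As a quick illustration of what MinMaxScaler does, here is a sketch on a small made-up array. Each column is rescaled with x' = (x - min) / (max - min), so every feature ends up in the range 0 to 1.

```python
import numpy as np
from sklearn.preprocessing import MinMaxScaler

# Hypothetical values standing in for two independent variables.
X = np.array([
    [290.0, 1960.0],
    [490.0, 1850.0],
    [690.0, 1980.0],
])

scaler = MinMaxScaler()
# fit_transform learns the per-column min and max, then rescales each column.
X_scaled = scaler.fit_transform(X)

print(X_scaled.min(axis=0), X_scaled.max(axis=0))
```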

Split the dataset into the train and test set.

We will use 10% of the dataset as a test set. Before we move on to training the model, the list of independent variable names will be saved in feature_columns.
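A sketch of the split, using random data with the same number of rows as the insurance dataset. The feature names and random_state are placeholders, not values taken from the original project.

```python
import numpy as np
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)

n_rows = 7160
# Hypothetical feature names: 14 columns minus the 2 dropped ones and the target.
feature_columns = [f"feature_{i}" for i in range(11)]

X = rng.random((n_rows, len(feature_columns)))
y = rng.integers(0, 2, size=n_rows)

# Hold out 10% of the rows as the test set.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.1, random_state=42
)

print(X_train.shape, X_test.shape)
```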

Now let's learn what scikit-plot can offer.

Estimators Module

The Estimators module in scikit-plot includes plots built specifically for scikit-learn estimator (classifier/regressor) instances, e.g. Random Forest. You can use your own estimators, but these plots assume specific properties shared by scikit-learn estimators.

Now let’s create our classifier from scikit-learn. We will use the GradientBoosting algorithm as our classifier.

1. Learning Curve Plot

Scikit-plot provides a plot_learning_curve method that can generate a plot of the train and test learning curves for a classifier. The curve can help you understand more about your model behavior during training.

We will pass the classifier, the independent variables, and the target variable to the plot_learning_curve method. This means the classifier is trained and the learning curve is generated in the same call.
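The sketch below computes the numbers behind a learning curve with scikit-learn's learning_curve on synthetic data, and wraps the scikit-plot call in a small helper. The synthetic dataset, the reduced n_estimators, and the helper name plot_learning_curve_figure are assumptions made for illustration.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import learning_curve

# Synthetic stand-in for the insurance data (assumption).
X, y = make_classification(n_samples=300, n_features=11, random_state=0)
clf = GradientBoostingClassifier(n_estimators=20, random_state=0)

# Train on growing subsets and score each with 3-fold cross-validation.
train_sizes, train_scores, cv_scores = learning_curve(
    clf, X, y, cv=3, train_sizes=np.linspace(0.2, 1.0, 4)
)
mean_train = train_scores.mean(axis=1)
mean_cv = cv_scores.mean(axis=1)

for size, tr, cv in zip(train_sizes, mean_train, mean_cv):
    print(f"{size:4d} examples: train={tr:.3f}  cv={cv:.3f}")


def plot_learning_curve_figure(clf, X, y):
    """Hypothetical helper: draw the curve with scikit-plot (imported lazily)."""
    import matplotlib.pyplot as plt
    import scikitplot as skplt

    skplt.estimators.plot_learning_curve(clf, X, y)
    plt.show()
```

Calling plot_learning_curve_figure(clf, X, y) should draw the train and cross-validation curves in one figure, provided scikit-plot and Matplotlib are installed.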

learning curve

As you can see, the learning curve can help you learn more about your model’s behavior. The plot shows that the training score decreases as the number of training examples increases, while the cross-validation score roughly maintains its performance.

2. Feature Importances Plot

Having a lot of features in your dataset does not mean you can create a model with good performance. Sometimes a few important features are what you need to create a model with good performance.

Scikit-plot can generate a plot of a classifier’s feature importance by using the plot_feature_importances method.

We will pass the trained classifier and the list of feature names to the plot_feature_importances method.
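Here is a sketch of this step on synthetic data. It ranks the features from the fitted model's feature_importances_ attribute and keeps the scikit-plot call in a helper. The feature names and the helper name plot_importances_figure are hypothetical.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier

# Synthetic stand-in for the insurance data, with placeholder feature names.
X, y = make_classification(n_samples=300, n_features=11, random_state=0)
feature_columns = [f"feature_{i}" for i in range(X.shape[1])]

# The classifier must be fitted before its importances can be read or plotted.
clf = GradientBoostingClassifier(n_estimators=20, random_state=0).fit(X, y)

importances = clf.feature_importances_
order = np.argsort(importances)[::-1]  # most important first
ranked = [(feature_columns[i], float(importances[i])) for i in order]

for name, score in ranked:
    print(f"{name}: {score:.3f}")


def plot_importances_figure(clf, feature_columns):
    """Hypothetical helper: draw the bar chart with scikit-plot (imported lazily)."""
    import matplotlib.pyplot as plt
    import scikitplot as skplt

    skplt.estimators.plot_feature_importances(
        clf, feature_names=feature_columns, x_tick_rotation=90
    )
    plt.show()
```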

importance features

The plot above shows a list of important features starting with the most important feature to the least important feature. The most important feature in our dataset is Building Dimension and the least important feature is Garden.

This suggests we can remove the Garden feature from our dataset, because it does not contribute much when it comes to model predictions.

Metrics Module

The Metrics module includes plots for machine learning evaluation metrics e.g. confusion matrix, silhouette scores, ROC, etc.

3. Confusion Matrix Plot

The confusion matrix is well known to most data scientists and machine learning engineers. It is a technique for summarizing the performance of a classification algorithm. You can use the plot_confusion_matrix method to generate a confusion matrix plot from predictions and true labels.

We will pass important parameters in the plot_confusion_matrix method which are true labels and predicted labels.
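To make the idea concrete, the sketch below builds a row-normalized confusion matrix from a handful of made-up labels. Using normalize=True in the scikit-plot call is an assumption, suggested by the fact that the article reports per-class percentages for this plot. The helper name plot_confusion_figure is hypothetical.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Made-up true and predicted labels for a binary problem.
y_true = np.array([0, 0, 0, 0, 1, 1, 1, 1])
y_pred = np.array([0, 0, 0, 1, 0, 0, 0, 1])

# Rows are true classes, columns are predicted classes.
cm = confusion_matrix(y_true, y_pred)

# Divide each row by its total so the diagonal holds the fraction of
# each class that was predicted correctly.
cm_normalized = cm / cm.sum(axis=1, keepdims=True)
per_class_correct = np.diag(cm_normalized)

print(cm_normalized)


def plot_confusion_figure(y_true, y_pred):
    """Hypothetical helper: draw the matrix with scikit-plot (imported lazily)."""
    import matplotlib.pyplot as plt
    import scikitplot as skplt

    skplt.metrics.plot_confusion_matrix(y_true, y_pred, normalize=True)
    plt.show()
```

In a real project, y_pred would come from the classifier's predict method on the test set.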

Confusion Matrix

The above plot shows that 96% of class 0 are predicted correctly and only 21% of class 1 are predicted correctly.

4. ROC Curve

The ROC curve is a useful tool when predicting the probability of a binary outcome. It is a plot of the false positive rate (x-axis) versus the true positive rate (y-axis) for a number of different candidate threshold values between 0.0 and 1.0.

Scikit-plot can generate the ROC curves from labels and predicted scores/probabilities by using the plot_roc method.

We will pass true labels and predicted probabilities to the plot_roc method to generate the plot.
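The following sketch shows the quantities behind the per-class ROC curves on four made-up predictions. The probability array has one column per class, which is the shape predict_proba returns. The helper name plot_roc_figure is hypothetical.

```python
import numpy as np
from sklearn.metrics import auc, roc_curve

# Made-up labels and predicted probabilities of class 1.
y_true = np.array([0, 0, 1, 1])
p_class_1 = np.array([0.1, 0.4, 0.35, 0.8])

# One column per class, like the output of clf.predict_proba(X_test).
y_probas = np.column_stack([1.0 - p_class_1, p_class_1])

# ROC curve of class 1: class 1 is the positive label.
fpr_1, tpr_1, _ = roc_curve(y_true, y_probas[:, 1])
auc_class_1 = auc(fpr_1, tpr_1)

# ROC curve of class 0: treat class 0 as the positive label instead.
fpr_0, tpr_0, _ = roc_curve((y_true == 0).astype(int), y_probas[:, 0])
auc_class_0 = auc(fpr_0, tpr_0)

print(auc_class_0, auc_class_1)


def plot_roc_figure(y_true, y_probas):
    """Hypothetical helper: draw the ROC curves with scikit-plot (imported lazily)."""
    import matplotlib.pyplot as plt
    import scikitplot as skplt

    skplt.metrics.plot_roc(y_true, y_probas)
    plt.show()
```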

Roc curve

The plot shows 4 different curves, each with its area under the curve:

  • ROC curve of class 0.
  • ROC curve of class 1.
  • Micro-average ROC curve.
  • Macro-average ROC curve.

5. Precision-Recall Curve

A precision-recall curve is a plot of the precision (y-axis) and the recall (x-axis) for different thresholds to evaluate classifier output quality. Precision-Recall is a useful measure of success of prediction when the classes are very imbalanced.

The precision-recall curve shows the tradeoff between precision and recall for different thresholds. A high area under the curve represents both high recall and high precision, where high precision relates to a low false-positive rate, and high recall relates to a low false-negative rate. High scores for both show that the classifier is returning accurate results (high precision), as well as returning a majority of all positive results (high recall).

Scikit-plot can generate the Precision-Recall Curve from labels and probabilities by using the plot_precision_recall method.

We will pass true labels and predicted probabilities in the plot_precision_recall method to generate the plot.
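A sketch of this step on an imbalanced synthetic dataset. It computes the precision-recall points for the minority class with scikit-learn and leaves the scikit-plot call in a helper. The dataset, the 80/20 class weights, and the helper name plot_precision_recall_figure are assumptions.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import average_precision_score, precision_recall_curve
from sklearn.model_selection import train_test_split

# Imbalanced synthetic data: roughly 80% class 0 and 20% class 1 (assumption).
X, y = make_classification(
    n_samples=1000, n_features=11, weights=[0.8, 0.2], random_state=0
)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.1, random_state=0, stratify=y
)

clf = GradientBoostingClassifier(n_estimators=20, random_state=0)
clf.fit(X_train, y_train)

# One probability column per class; column 1 is the minority class.
y_probas = clf.predict_proba(X_test)

precision, recall, thresholds = precision_recall_curve(y_test, y_probas[:, 1])
average_precision = average_precision_score(y_test, y_probas[:, 1])

print(f"average precision for class 1: {average_precision:.3f}")


def plot_precision_recall_figure(y_true, y_probas):
    """Hypothetical helper: draw the curves with scikit-plot (imported lazily)."""
    import matplotlib.pyplot as plt
    import scikitplot as skplt

    skplt.metrics.plot_precision_recall(y_true, y_probas)
    plt.show()
```

Calling plot_precision_recall_figure(y_test, y_probas) should produce the per-class and micro-average curves, provided scikit-plot and Matplotlib are installed.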

Precision-Recall

The plot shows 3 different curves, each with its area under the curve:

  • Precision-recall Curve of class 0
  • Precision-recall Curve of class 1
  • Micro-average Precision-recall Curve

Decomposition Module

The Decomposition module includes plots built specifically for scikit-learn estimators that are used for dimensionality reduction e.g. PCA. You can use your own estimators, but these plots assume specific properties shared by scikit-learn estimators.

6. PCA Component Variance

Principal component analysis (PCA) is one of the earliest multivariate techniques. Not only has it survived, it is arguably the most common way of reducing the dimension of multivariate data to reveal the sometimes hidden, simplified structure that often underlies it. Others define it as a mathematical algorithm that reduces the dimensionality of the data while retaining most of the variation in the data set.

A vital part of using PCA in practice is the ability to estimate how many components are needed to describe the data. This can be determined by looking at the cumulative explained variance ratio as a function of the number of components. The explained variance ratio in PCA is the percentage of variance explained by each of the selected components.

Scikit-plot can generate PCA components’ explained variance ratios by using the plot_pca_component_variance method.

We will pass the PCA instance, which has the explained_variance_ratio_ attribute, to the plot_pca_component_variance method to generate the curve.
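The sketch below fits PCA on random data and computes the cumulative explained variance ratio, which is the quantity the plot displays. The random data, the 75% target, and the helper name plot_pca_variance_figure are assumptions made for illustration.

```python
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)

# Random stand-in data: 11 features with different scales, so some
# components explain much more variance than others.
X = rng.normal(size=(200, 11)) * np.arange(1, 12)

# Keep all components so the cumulative curve reaches 100%.
pca = PCA().fit(X)

cumulative = np.cumsum(pca.explained_variance_ratio_)

# Smallest number of components whose cumulative ratio reaches the target.
target = 0.75
n_needed = int(np.searchsorted(cumulative, target) + 1)

print(f"{n_needed} components explain at least {target:.0%} of the variance")


def plot_pca_variance_figure(pca, target=0.75):
    """Hypothetical helper: draw the curve with scikit-plot (imported lazily)."""
    import matplotlib.pyplot as plt
    import scikitplot as skplt

    skplt.decomposition.plot_pca_component_variance(
        pca, target_explained_variance=target
    )
    plt.show()
```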

Variance Ratio

The curve above quantifies how much of the total variance is contained within the first N components. For example, we see that the first 3 components contain approximately 81.4% of the variance, while you need around 8 components to describe close to 100% of the variance.

You can read the article written by Matt Brems to learn more about PCA.

Wrap Up

In this article, you learned how to add plotting functionality to scikit-learn objects and gain insights simply by looking at a colored plot. You can start using scikit-plot in your next machine learning project.

If you are interested in plotting clusters, scikit-plot has a Clusterer module that includes plots built specifically for scikit-learn clusterer instances. You can check the documentation to learn more about the Clusterer module.

The dataset and source code for this article are available on GitHub.

If you learned something new or enjoyed reading this article, please share it so that others can see it. Feel free to leave a comment too. Till then, see you in the next post! I can also be reached on Twitter @Davis_McDavid.
