Data visualization is very important in data science. It is used not only for delivering results but also as an essential part of exploratory data analysis.
Matplotlib is a widely used Python data visualization library. In fact, other libraries such as Seaborn are built on top of it.
The syntax of Matplotlib is usually more verbose than that of other Python visualization libraries. In return, it offers flexibility: you can customize the plots freely.
This post can be considered a Matplotlib tutorial, but one heavily focused on the practical side. In each example, I will try to produce a different plot that points out an important feature of Matplotlib.
I will work through the examples on a customer churn dataset that is available on Kaggle. I use this dataset quite often because it is a good mixture of categorical and numerical variables. Besides, it has a purpose, so the examples together form an exploratory data analysis process.
Let’s first import the dependencies:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
Matplotlib consists of three layers: the Backend, Artist, and Scripting layers. The scripting layer is the matplotlib.pyplot interface.
The scripting layer makes it relatively easy to create plots because it automates the process of putting everything together. Thus, it is the layer most widely used by data scientists.
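To make the difference between the layers concrete, here is a minimal sketch that draws the same bar plot twice: once through the pyplot (scripting) interface and once by holding the Figure and Axes objects explicitly. The labels and counts are made up for illustration, and the Agg backend is selected only so that the snippet also runs outside a notebook.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; not needed inside a notebook
import matplotlib.pyplot as plt

# Made-up data, for illustration only
labels = ["A", "B", "C"]
counts = [5, 3, 8]

# Scripting layer: pyplot keeps track of the "current" figure and axes for us
plt.figure()
plt.bar(labels, counts)
plt.title("pyplot interface")
fig_pyplot = plt.gcf()

# Object-oriented style: we hold the Figure and Axes (Artist layer) explicitly
fig_oo, ax = plt.subplots()
bars = ax.bar(labels, counts)
ax.set_title("object-oriented interface")
```

Both figures look the same. The explicit style becomes more convenient once a figure contains several subplots, which is the case in the later examples.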
We will read the dataset into a Pandas dataframe.
cols = ['CreditScore', 'Geography', 'Gender', 'Age', 'Tenure', 'Balance', 'NumOfProducts', 'IsActiveMember', 'EstimatedSalary',
'Exited']
churn = pd.read_csv("/content/Churn_Modelling.csv", usecols=cols)
churn.head()

The dataset contains some features about the customers of a bank and their bank account. The "Exited" column indicates whether a customer churned (i.e. left the bank).
We are ready to start.
1. Number of customers in each country
This one is pretty simple but a good example of a bar plot.
plt.figure(figsize=(8,5))
plt.title("Number of Customers", fontsize=14)
plt.bar(x=churn['Geography'].value_counts().index,
height=churn.Geography.value_counts().values)

In the first line, we create a Figure object with a specific size. The next line adds a title to the plot (strictly speaking, to the current Axes of that Figure). The bar function plots the actual data.
2. Adjusting xticks and yticks
The default settings are usually appropriate, but minor adjustments might be necessary in some cases. For instance, we can increase the font size of the tick labels, rotate the x-axis labels, and adjust the tick values on the y-axis.
plt.xticks(fontsize=12, rotation=45)
plt.yticks(ticks=np.arange(0, 7000, 1000), fontsize=12)
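Adding these two lines of code to the previous plot produces the same bar chart with larger tick labels, x-axis labels rotated by 45 degrees, and y-axis ticks from 0 to 6000 in steps of 1000.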
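The same tick adjustments can also be made on a specific Axes object instead of through pyplot. The sketch below uses hypothetical round-number counts in place of churn['Geography'].value_counts(), so it runs without the dataset; the Agg backend line is only there to make it runnable outside a notebook.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; not needed inside a notebook
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical counts standing in for churn['Geography'].value_counts()
countries = ["France", "Germany", "Spain"]
counts = [5000, 2500, 2500]

fig, ax = plt.subplots(figsize=(8, 5))
ax.bar(countries, counts)
ax.set_title("Number of Customers", fontsize=14)

# Same adjustments as plt.xticks / plt.yticks, applied to this Axes only
ax.tick_params(axis="x", labelsize=12, labelrotation=45)
ax.set_yticks(np.arange(0, 7000, 1000))
ax.tick_params(axis="y", labelsize=12)
```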

3. Changing the default figure size
The default figure size in this notebook is (6,4), which I think is pretty small (the default can differ between environments and Matplotlib versions). If you don’t want to explicitly define the size for each figure, you may want to change the default setting.
The rcParams object of Matplotlib, a dictionary-like container, is used to store and change the default settings.
plt.rcParams.get('figure.figsize')
[6.0, 4.0]
As you can see, the default size is (6,4). Let’s change it to (8,5):
plt.rcParams['figure.figsize'] = (8,5)
We can also change the default settings for other parameters such as line style, line width, and so on.
I have also changed the font size of the xtick and ytick labels to 12.
plt.rc('xtick', labelsize=12)
plt.rc('ytick', labelsize=12)
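If you only want different defaults for one or two figures, a temporary override may be more convenient than a global change. The following sketch uses plt.rc_context for that; the chosen values simply mirror the ones used above, and the Agg backend line is only there so the snippet runs outside a notebook.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; not needed inside a notebook
import matplotlib.pyplot as plt

default_size = list(plt.rcParams["figure.figsize"])

# Settings inside the "with" block apply only while the block is active
with plt.rc_context({"figure.figsize": (8, 5), "xtick.labelsize": 12}):
    fig, ax = plt.subplots()
    size_inside = tuple(fig.get_size_inches())

# Outside the block the previous defaults are back
size_after = list(plt.rcParams["figure.figsize"])

# plt.rcdefaults() would reset every setting to Matplotlib's built-in defaults
```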
4. Creating a simple histogram
A histogram is used to visualize the distribution of a variable.
The following syntax will create a simple histogram of customer balances.
plt.hist(x=churn['Balance'])

Many of the customers have a zero balance. When the zero balances are excluded, the distribution is close to the normal (Gaussian) distribution.
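One way to exclude the zero balances is to filter the column with a boolean mask before plotting. The sketch below uses a synthetic dataframe called demo (not the real Kaggle data) in which 400 of 1000 made-up customers have an empty account; with the real data you would apply the same mask to churn.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; not needed inside a notebook
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Synthetic stand-in for the churn dataframe (not the real Kaggle data)
rng = np.random.default_rng(0)
balance = rng.normal(loc=120000, scale=30000, size=1000).clip(min=1)
balance[:400] = 0  # pretend 400 customers have an empty account
demo = pd.DataFrame({"Balance": balance})

# Boolean mask keeps only the rows with a positive balance
nonzero = demo.loc[demo["Balance"] > 0, "Balance"]

fig, ax = plt.subplots()
counts, edges, _ = ax.hist(nonzero)
ax.set_title("Balance (zero balances excluded)")
```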
5. Customizing the histogram
The two essential features that define a histogram are the number of bins and the value range.
The default number of bins is 10, so the value range is divided into 10 equal-width bins. For instance, the first bin in the previous histogram covers roughly 0–25000. Increasing the number of bins is like having more resolution: up to a point, we get a more accurate overview of the distribution.
The value range is defined by taking the minimum and maximum values of the column. We can adjust it to exclude the outliers or specific values.
plt.hist(x=churn['Balance'], bins=12, color='darkgrey',
range=(25000, 225000))
plt.title("Distribution of Balance (25000 - 225000)", fontsize=14)

The values that are lower than 25000 or higher than 225000 are excluded, and the number of bins increases from 10 to 12. The distribution now looks approximately normal.
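To see exactly what the bins and range parameters do, it can help to look at the bin edges directly. plt.hist relies on NumPy's histogram function for the binning, so the sketch below calls np.histogram on synthetic values (uniformly distributed, for illustration only) and compares the default binning with a restricted range.

```python
import numpy as np

# Synthetic values between 0 and 250000, for illustration only
rng = np.random.default_rng(1)
values = rng.uniform(0, 250000, size=500)

# Default-style histogram: 10 equal-width bins over the full data range
counts10, edges10 = np.histogram(values, bins=10)

# Restricted range: values outside (25000, 225000) are ignored
counts12, edges12 = np.histogram(values, bins=12, range=(25000, 225000))

print("bins:", len(counts10), "vs", len(counts12))
print("range of restricted histogram:", edges12[0], "to", edges12[-1])
print("values kept:", counts12.sum(), "of", counts10.sum())
```

A histogram with n bins always has n + 1 edges, and values outside the given range are dropped rather than piled into the outermost bins.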
6. Creating a simple scatter plot
Scatter plots are commonly used to map the relationship between numerical variables. We can visualize the correlation between variables using a scatter plot.
sample = churn.sample(n=200, random_state=42) #small sample
plt.scatter(x=sample['CreditScore'], y=sample['Age'])

There does not seem to be a correlation between age and credit score.
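A visual impression like this can be backed up with a number. The sketch below computes the Pearson correlation coefficient with pandas on a synthetic dataframe called demo, whose two columns are generated independently; the value ranges are assumptions chosen for illustration. With the real data, the same call would be sample['CreditScore'].corr(sample['Age']).

```python
import numpy as np
import pandas as pd

# Synthetic, independently generated columns (not the real Kaggle data)
rng = np.random.default_rng(42)
demo = pd.DataFrame({
    "CreditScore": rng.integers(350, 851, size=200),
    "Age": rng.integers(18, 93, size=200),
})

# Pearson correlation coefficient: values near 0 mean no linear relationship
r = demo["CreditScore"].corr(demo["Age"])
print(round(r, 3))
```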
7. Scatter plots with subplots
We can put multiple scatter plots on the same Figure object. Although the syntax is longer than some other libraries (e.g. Seaborn), Matplotlib is highly flexible in terms of subplots. We will do several examples that consist of subplots.
The subplots function creates a Figure and a set of subplots:
fig, ax = plt.subplots()
We can draw multiple plots on the same Axes and identify them with a legend.
plt.title("France vs Germany", fontsize=14)
ax.scatter(x=sample[sample.Geography == 'France']['CreditScore'], y=sample[sample.Geography == 'France']['Age'])
ax.scatter(x=sample[sample.Geography == 'Germany']['CreditScore'], y=sample[sample.Geography == 'Germany']['Age'])
ax.legend(labels=['France','Germany'], loc='lower left', fontsize=12)
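Passing the labels to ax.legend works, but it depends on the order in which the plots were drawn. A slightly more robust variant, sketched here on a synthetic dataframe called demo (made-up values, not the real data), attaches a label to each scatter call so that the legend picks the names up automatically.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; not needed inside a notebook
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Synthetic stand-in for the sampled churn dataframe
rng = np.random.default_rng(7)
demo = pd.DataFrame({
    "Geography": rng.choice(["France", "Germany"], size=100),
    "CreditScore": rng.integers(350, 851, size=100),
    "Age": rng.integers(18, 93, size=100),
})

fig, ax = plt.subplots()
for country, group in demo.groupby("Geography"):
    # label= ties the legend entry to this specific set of points
    ax.scatter(x=group["CreditScore"], y=group["Age"], label=country)
ax.set_title("France vs Germany", fontsize=14)
legend = ax.legend(loc="lower left", fontsize=12)
```

The groupby loop also avoids repeating the filter expression for each country.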

8. Grid of subplots
The plots do not have to be drawn on top of each other on a single Axes. The subplots function can create a grid of subplots by using the nrows and ncols parameters.
fig, (ax1, ax2, ax3) = plt.subplots(nrows=3, ncols=1)

We have an empty grid of subplots. In the following examples, we will see how to fill these subplots and make small adjustments to make them look nicer.
9. Rearranging and accessing the subplots
Before adding the titles, let’s put a little space between the subplots so that they look better. We will do that with the tight_layout function.
We can also remove the xticks in between and only have the ones at the bottom. This can be done with the sharex parameter.
fig, (ax1, ax2, ax3) = plt.subplots(nrows=3, ncols=1, figsize=(9,6), sharex=True)
fig.tight_layout(pad=2)

There are two ways to access the subplots. One way is to define them explicitly and the other way is to use indexing.
# 1
fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1)
# first subplot: ax1
# second subplot: ax2
# 2
fig, axs = plt.subplots(nrows=2, ncols=1)
# first subplot: axs[0]
# second subplot: axs[1]
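When the grid has more than one row and more than one column, subplots returns a 2D NumPy array of Axes, so indexing takes a row and a column. A minimal sketch (the titles are placeholders):

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; not needed inside a notebook
import matplotlib.pyplot as plt

fig, axs = plt.subplots(nrows=2, ncols=2)

# With several rows and columns, axs is a 2D array: axs[row, column]
axs[0, 1].set_title("top right")
axs[1, 0].set_title("bottom left")

# axs.flat iterates over all subplots, row by row
for i, ax in enumerate(axs.flat):
    ax.set_xlabel(f"subplot {i}")

fig.tight_layout(pad=2)
```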
10. Drawing the subplots
We will create a grid of 2 columns and add bar plots to each one.
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, sharey=True,
figsize=(8,5))
countries = churn.Geography.value_counts()
products = churn.NumOfProducts.value_counts()
ax1.bar(x=countries.index, height=countries.values)
ax1.set_title("Countries", fontsize=12)
ax2.bar(x=products.index, height=products.values)
ax2.set_title("Number of Products", fontsize=12)

11. Creating a 2-D histogram
2D histograms visualize the distributions of a pair of variables. We get an overview of how the values of two variables change together.
Let’s create a 2D histogram of the credit score and age.
plt.title("Credit Score vs Age", fontsize=15)
plt.hist2d(x=churn.CreditScore, y=churn.Age)

The most populated group consists of customers who are between 30 and 40 years old and have credit scores between 600 and 700.
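Reading the most populated cell off the colors is easier with a colorbar, and hist2d also returns the counts and bin edges, so the cell can be located in code. The sketch below uses synthetic, normally distributed values whose means and spreads are assumptions chosen for illustration, not statistics of the real dataset.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; not needed inside a notebook
import matplotlib.pyplot as plt
import numpy as np

# Synthetic stand-ins for churn.CreditScore and churn.Age
rng = np.random.default_rng(3)
credit_score = rng.normal(loc=650, scale=95, size=2000)
age = rng.normal(loc=39, scale=10, size=2000)

fig, ax = plt.subplots()
# bins=(20, 15): 20 bins along x and 15 along y
h, xedges, yedges, image = ax.hist2d(credit_score, age, bins=(20, 15))
ax.set_title("Credit Score vs Age", fontsize=15)
fig.colorbar(image, ax=ax, label="Number of customers")

# Locate the most populated cell and report its edges
i, j = np.unravel_index(h.argmax(), h.shape)
print("credit score:", xedges[i], "to", xedges[i + 1])
print("age:", yedges[j], "to", yedges[j + 1])
```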
Conclusion
What we have covered in this post is just a small part of Matplotlib’s capabilities. Some of what I shared can be considered details, while other parts are very basic. However, all of it is useful in making the most out of Matplotlib.
The best way to master Matplotlib, as with any other subject, is to practice. Once you are comfortable with the basic functionality, you can proceed to the more advanced features.
Thank you for reading. Please let me know if you have any feedback.