Data Visualization

Making Plots in Jupyter Notebook Beautiful & More Meaningful

Customizing matplotlib.pyplot to make plots better

Bipin P.
Towards Data Science
10 min read · Mar 11, 2020


Photo by ©iambipin

As our world becomes more and more data-driven, the important decisions of people who can make a tremendous impact on the world we live in (governments, big corporations, politicians, business tycoons, you name it) are influenced by data to an unprecedented degree. Consequently, data visualization plays a more pivotal role in day-to-day affairs than ever before, and acquiring skills in this area is gaining prominence.

In the world of data science, Python is the programming language of choice (the undisputed leader in data science), so we will use Matplotlib for plotting. Matplotlib is a comprehensive library for creating static, animated, and interactive visualizations in Python. matplotlib.pyplot is a collection of command-style functions that make Matplotlib work like MATLAB. Each pyplot function makes some change to a figure: it can create a figure, create a plotting area in a figure, plot some lines in a plotting area, decorate the plot with labels, etc. Visualizations can be generated quickly using pyplot.

Let’s write some code in a Jupyter notebook to plot a normal distribution. For the uninitiated, the normal distribution is a continuous probability distribution for a real-valued random variable. It is easily identified by its symmetric, bell-shaped curve (the probability density function, or PDF).

import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as stats
mu = 0
std = 1
x = np.linspace(start=-4, stop=4, num=100)
y = stats.norm.pdf(x, mu, std)
plt.plot(x, y)
plt.show()
Plot 1: Normal Distribution | Photo by ©iambipin

The code creates a simple plot of the normal distribution with mean = 0 and standard deviation = 1. As our primary concern is making plots more beautiful, the explanation of the mathematical aspects of the code will be brief. np.linspace() returns evenly spaced samples (num of them) over the closed interval [start, stop]. scipy.stats.norm() returns a normal continuous random variable. scipy.stats.norm.pdf() computes the PDF at the given points for a given mean (mu) and standard deviation (std).
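For readers who want to see what the PDF call evaluates, here is a minimal sketch that computes the same bell curve from the closed-form normal density using only NumPy. The helper name normal_pdf is hypothetical and introduced just for this illustration.

```python
import numpy as np

def normal_pdf(x, mu=0.0, std=1.0):
    """Closed-form normal density; hypothetical helper for illustration."""
    z = (x - mu) / std
    return np.exp(-0.5 * z**2) / (std * np.sqrt(2.0 * np.pi))

# 100 evenly spaced samples over [-4, 4], both endpoints included
x = np.linspace(start=-4, stop=4, num=100)
y = normal_pdf(x)

print(x[0], x[-1], len(x))   # endpoints and sample count
print(y.max())               # highest sampled value, reached near x = 0
```

Because the curve is symmetric about the mean, reversing y gives back the same values, which is a quick sanity check on the sampling.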

The graph looks rather ordinary and bland. There are neither labels nor a title to give valuable information to a third person. There’s no grid to help identify and correlate values. The figure is also a bit small for my liking.

Let’s make the plots beautiful by harnessing the various features of pyplot.

Adding Grid Lines

Grids help to easily identify and correlate values in the plot. plt.grid() configures the grid lines. Called without arguments on a fresh plot, it shows the grid with the default settings.

plt.grid()

A simple snippet that creates a figure with a grid is as follows:

import matplotlib.pyplot as plt
fig = plt.figure()
ax = plt.axes()
plt.grid()
plt.show()
Photo by ©iambipin

plt.figure() creates a new figure. plt.axes() adds an axes to the new figure and makes it the current axes (axes is the plural of axis). A Matplotlib figure can be considered a single container that holds all the information about axes, graphics, text, and labels. The axes can be seen as a bounding box with ticks and labels that will contain the plot elements of the visualization. plt.show() displays all figures and blocks until the figures have been closed.
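The container relationship between a figure and its axes can be sketched with plt.subplots(), which creates both objects in one call. This is an alternative to the plt.figure() / plt.axes() pair used above, shown only to make the two objects explicit.

```python
import matplotlib.pyplot as plt

# One call creates the container (Figure) and a single plotting area (Axes)
fig, ax = plt.subplots()

# Methods on the Axes object mirror the pyplot functions
ax.grid(True)

print(type(fig).__name__)   # Figure
print(ax in fig.axes)       # True
```

Holding on to fig and ax is handy once a notebook contains several plots, since each call then names the plot it acts on instead of relying on the "current" axes.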

Let’s play around with various aspects of grid(), starting with the color of the grid.

fig = plt.figure()
ax = plt.axes()
# first argument: grid visibility (keyword b before Matplotlib 3.5, visible since)
plt.grid(True, color='r')
Photo by ©iambipin

color is a keyword argument that sets the color of the grid. The optional first parameter b takes a Boolean value (True or False); Matplotlib 3.5 renamed it to visible, and passing the value positionally works under either name. If it is set to False, the grid disappears. However, if any keyword argument (like alpha, color, linewidth, etc.) is present, the grid is switched on irrespective of the value given. For instance, if the above code snippet is changed:

fig = plt.figure()
ax = plt.axes()
plt.grid(False, color='r')

The output will be the same plot with the red-colored grid as shown above.

For color, you can use any of the following as values:

  • The common names of colors like red, blue, brown, magenta, etc.
    color='purple'
  • The color hex code #RRGGBB with values from 00 to FF. Here R=Red, G=Green, B=Blue.
    color='#e3e3e3'
  • An RGB tuple with values from 0 to 1.
    color=(0.6, 0.9, 0.2)
  • A grayscale string with a value from 0.0 to 1.0. As the value increases, the grid lines fade from black toward white.
    color='0.9'
  • Short color codes for RGB and CMYK. Here r: red, g: green, b: blue, c: cyan, m: magenta, y: yellow, k: black.
    color='c'
  • HTML color names.
    color='tomato'
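To see how Matplotlib interprets each of the formats listed above, a small sketch with the matplotlib.colors helpers to_rgba() and is_color_like() can be used. The sample values are the ones from the list; nothing here is specific to grid().

```python
from matplotlib import colors as mcolors

# One example per format listed above
samples = ['purple', '#e3e3e3', (0.6, 0.9, 0.2), '0.9', 'c', 'tomato']

for value in samples:
    # to_rgba() converts any valid color spec to an (r, g, b, a) tuple
    print(repr(value), '->', mcolors.to_rgba(value))

# is_color_like() is a quick validity check before passing a value to color=
print(mcolors.is_color_like('not-a-color'))   # False
```

Checking a value this way before plotting gives a clearer failure than an error raised in the middle of a long plotting cell.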
fig = plt.figure()
ax = plt.axes()
plt.grid(True, color='aqua', alpha=0.3, linestyle='-.', linewidth=2)
Photo by ©iambipin

The keyword arguments used in the code are as follows:

  • alpha changes the opacity of the grid. A value of zero makes the grid disappear.
  • linestyle or ls sets the style of the grid lines. Suitable values for linestyle are '-', '--', '-.', ':', '', 'solid', 'dotted', 'dashed', 'dashdot', etc.
  • linewidth or lw determines the width of the grid lines.

Note: The curves or lines of the plot can be styled in the same way as the grid using the same keyword arguments.

import numpy as np
import matplotlib.pyplot as plt
fig = plt.figure()
ax = plt.axes()
x = np.linspace(0, 5, 100)
plt.plot(x, np.sin(x), color='Indigo', linestyle='--', linewidth=3)
plt.grid(True, color='aqua', alpha=0.3, linestyle='-.', linewidth=2)
plt.show()
Photo by ©iambipin

The color and linestyle arguments can be combined into a single non-keyword format string. For example, 'c--' is equivalent to setting color='c', linestyle='--', as shown in the code below:

import numpy as np
import matplotlib.pyplot as plt
fig = plt.figure()
ax = plt.axes()
x = np.linspace(0, 5, 100)
plt.plot(x, np.sin(x), 'c--', linewidth=3)
plt.grid(True, color='aqua', alpha=0.3, linestyle='-.', linewidth=2)
plt.show()
Photo by ©iambipin

This combined format string does not work with grid().

Major and Minor Grid Lines

The major and minor grid lines can be shown by passing one of three values to the which parameter of plt.grid(): 'major', 'minor', or 'both'. The names are self-explanatory: 'major' displays the major grid lines, 'minor' displays the minor grid lines, and 'both' displays both simultaneously.

import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(0, 7, 100)
line1, = plt.plot(x, np.sin(x), label='sin')
line2, = plt.plot(x, np.cos(x), label='cos')
plt.legend(handles=[line1, line2], loc='lower right')
# major grid lines
plt.grid(True, which='major', color='gray', alpha=0.6, linestyle='dashdot', lw=1.5)
# minor grid lines
plt.minorticks_on()
plt.grid(True, which='minor', color='beige', alpha=0.8, ls='-', lw=1)
plt.show()
Photo by ©iambipin

plt.minorticks_on() displays minor ticks on the axes. The disadvantage is that it may reduce performance.

Adding Labels for Axes

Start by adding labels to the x-axis and y-axis. matplotlib.pyplot.xlabel() adds a label to the x-axis. Similarly, ylabel() assigns a label to the y-axis. fontsize is a Matplotlib text keyword argument (**kwargs) that controls the font size of the labels.

plt.xlabel('x', fontsize=15)
plt.ylabel('PDF', fontsize=15)
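The opening critique of Plot 1 also mentioned the missing title. As a minimal sketch, plt.title() can be used alongside the two label functions; the title text is an arbitrary choice for illustration, and the PDF is computed with NumPy only to keep the snippet self-contained.

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(-4, 4, 100)
y = np.exp(-0.5 * x**2) / np.sqrt(2 * np.pi)   # standard normal PDF

plt.figure()
plt.plot(x, y)
plt.xlabel('x', fontsize=15)
plt.ylabel('PDF', fontsize=15)
# plt.title() accepts the same text keyword arguments as the label functions
plt.title('Normal Distribution', fontsize=18)
```

A slightly larger font for the title than for the axis labels usually keeps the visual hierarchy clear.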

Setting Axes Limit

Matplotlib sets default axes limits if none are specified. You can set the x and y limits using plt.xlim() and plt.ylim() respectively.

Without Axes Limits | Photo by ©iambipin
import numpy as np
import matplotlib.pyplot as plt
fig = plt.figure()
ax = plt.axes()
x = np.linspace(0, 5, 100)
plt.plot(x, np.sin(x), 'c--', linewidth=3)
plt.xlim(-1, 6)
plt.ylim(-1.25, 1.25)
plt.grid(True, color='aqua', alpha=0.3, linestyle='-.', linewidth=2)
plt.show()
With Axes Limits | Photo by ©iambipin
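The difference between the automatic and the explicit limits can also be inspected in code. The sketch below uses the Axes methods set_xlim() and set_ylim(), which correspond to plt.xlim() and plt.ylim(), and reads the limits back before and after.

```python
import numpy as np
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
x = np.linspace(0, 5, 100)
ax.plot(x, np.sin(x), 'c--', linewidth=3)

print(ax.get_xlim())   # limits chosen automatically from the data, plus a margin

ax.set_xlim(-1, 6)
ax.set_ylim(-1.25, 1.25)

# Called without arguments, plt.xlim() and plt.ylim() return the current limits
print(plt.xlim(), plt.ylim())
```

Reading the limits back is useful when several plots in a notebook need to share the same range.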

Tick Parameters

We have seen how to customize grids. In addition, the plt.tick_params() method can change the appearance of ticks, tick labels, and grid lines in one line of code. Isn’t that convenient and powerful?

import numpy as np
import matplotlib.pyplot as plt
fig = plt.figure()
ax = plt.axes()
x = np.linspace(0, 5, 100)
plt.plot(x, np.sin(x), 'c--', linewidth=3)
plt.xlim(-1, 6)
plt.ylim(-1.25, 1.25)
plt.tick_params(axis='both', direction='out', length=6, width=2, labelcolor='b', colors='r', grid_color='gray', grid_alpha=0.5)
plt.grid()
plt.show()
Photo by ©iambipin

The axis argument specifies the axis ('x', 'y' or 'both') to which the parameters are applied. direction places the ticks inside the axes, outside the axes, or both ('in', 'out' or 'inout'). The length and width arguments give the tick length and tick width (as floats). labelcolor sets the color of the tick labels, while colors sets both the tick color and the label color. grid_color sets the color of the grid and grid_alpha sets its opacity. In short, plt.tick_params() can change the important properties of the grid: grid_color (color), grid_alpha (float), grid_linewidth (float) and grid_linestyle (string).
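The example above uses only grid_color and grid_alpha. A short sketch of the remaining two grid properties, grid_linewidth and grid_linestyle, might look like this; the particular values are arbitrary.

```python
import numpy as np
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
x = np.linspace(0, 5, 100)
ax.plot(x, np.sin(x), 'c--', linewidth=3)

# Style the grid lines through tick_params instead of through grid()
ax.tick_params(axis='both', grid_color='gray', grid_alpha=0.5,
               grid_linewidth=1.5, grid_linestyle=':')
ax.grid(True)   # the grid still has to be switched on
```

Note that tick_params() only styles the grid lines; without the call to grid() they stay hidden.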

Adding Legend

A legend is the wording on a map or diagram that explains the symbols used in it. The plt.legend() method places the legend on the axes.

When multiple lines are present in a plot, the code varies a bit from the usual practice.

import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(0, 7, 100)
line1, = plt.plot(x, np.sin(x), label='sin')
line2, = plt.plot(x, np.cos(x), label='cos')
plt.legend(handles=[line1, line2], loc='best')
plt.grid(True, color='aqua', alpha=0.6, linestyle='dashdot')
plt.show()
Photo by ©iambipin

The position of the legend can be changed by passing an appropriate value ('lower right', 'lower left', 'upper right', 'upper left', etc.) to loc.

plt.legend(handles=[line1, line2], loc='lower right')
The position of legend changed to the lower right position | Photo by ©iambipin
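Passing handles explicitly is one way to build the legend. As a sketch of an alternative, legend() called without handles collects every plotted line that carries a label; the legend title used here is an arbitrary addition for illustration.

```python
import numpy as np
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
x = np.linspace(0, 7, 100)
ax.plot(x, np.sin(x), label='sin')
ax.plot(x, np.cos(x), label='cos')

# With no handles given, legend() collects every artist that has a label
legend = ax.legend(loc='lower right', title='Function')

print([text.get_text() for text in legend.get_texts()])   # ['sin', 'cos']
```

The explicit handles list remains useful when only some of the lines should appear in the legend.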

Let’s add all these functions to our code and see what Plot 1 looks like now.

import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as stats
mu = 0
std = 1
x = np.linspace(start=-4, stop=4, num=100)
y = stats.norm.pdf(x, mu, std)
plt.plot(x, y, label='PDF')
plt.xlabel('x', fontsize=15)
plt.ylabel('PDF', fontsize=15)
plt.grid(True, which='major', color='DarkTurquoise', alpha=0.4, linestyle=':', linewidth=2)
plt.minorticks_on()
plt.grid(True, which='minor', color='beige', alpha=0.2, linestyle='-', linewidth=2)
plt.legend()
plt.show()
Plot 2: Normal Distribution | Photo by ©iambipin

Saving Figure to an Image

One of the several ways to save the plot as an image is to right-click on the plot and select the ‘Save image as’ option (the default option in any web browser).

Photo by ©iambipin

The other option is to use the plt.savefig() method. Given a bare filename, it saves the current figure to the current working directory.

plt.savefig('normal_distribution.png')

To verify that the image has been saved properly, use the Image object of the IPython.display module. The following code displays the image.

from IPython.display import Image
Image('normal_distribution.png')
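Outside a notebook, the same check can be made by testing that the file exists. The sketch below also passes two commonly used savefig() options, dpi and bbox_inches; the temporary output directory is an assumption made only so the example leaves no stray files behind.

```python
import os
import tempfile

import numpy as np
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
x = np.linspace(-4, 4, 100)
ax.plot(x, np.exp(-0.5 * x**2) / np.sqrt(2 * np.pi), label='PDF')
ax.legend()

# Hypothetical output location, chosen only so the sketch leaves no stray files
out_path = os.path.join(tempfile.mkdtemp(), 'normal_distribution.png')

# dpi sets the resolution; bbox_inches='tight' trims surplus white space
fig.savefig(out_path, dpi=150, bbox_inches='tight')

print(os.path.exists(out_path))   # True
```

In a script it is generally safer to call savefig() before plt.show(), since the figure may no longer be available once the window has been closed.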

Increasing the size of Figure

By default, the size of the plot displayed by Jupyter notebook is pretty small. Matplotlib's default figure size is only 640x480 pixels (6.4 x 4.8 inches at 100 dpi). However, the saved images have even smaller dimensions.

Photo by ©iambipin

The saved file has dimensions 433px * 288px. Let’s tweak the settings to get a custom size for the plot. plt.rcParams serves this purpose. The general syntax is as follows:

plt.rcParams['figure.figsize'] = [width, height]

Matplotlib uses matplotlibrc configuration files to customize all kinds of properties, which we call ‘rc settings’ or ‘rc parameters’. Defaults of almost every property in Matplotlib can be controlled: figure size and DPI, line width, color and style, axes, axis and grid properties, text and font properties and so on. Once a matplotlibrc file is found, Matplotlib does not search any of the other paths. The location of the currently active matplotlibrc file can be seen by typing the following commands:

import matplotlib
matplotlib.matplotlib_fname()

Conversion to Centimeters

The important point to note here is that figsize takes its values in inches. To specify the size in centimeters (cm), divide each value by 2.54, as 1 inch = 2.54 cm.

plt.rcParams['figure.figsize'] = [10/2.54, 8/2.54]

Now try to customize the size of Plot 2 using rcParams.

import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as stats
# set the size before plotting: rcParams only affects figures created afterwards
plt.rcParams['figure.figsize'] = [10/2.54, 8/2.54]
mu = 0
std = 1
x = np.linspace(start=-4, stop=4, num=100)
y = stats.norm.pdf(x, mu, std)
plt.plot(x, y, label='PDF')
plt.xlabel('x', fontsize=15)
plt.ylabel('PDF', fontsize=15)
plt.grid(True, color='DarkTurquoise', alpha=0.2, linestyle=':', linewidth=2)
plt.legend()
plt.show()
Plot 3: Normal Distribution | Photo by ©iambipin
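Setting plt.rcParams changes the default for every later figure in the session. As a sketch of two narrower alternatives, the size can be given to a single figure through the figsize argument, or the default can be changed temporarily with plt.rc_context(). The helper cm_to_inch is hypothetical and exists only for this illustration.

```python
import numpy as np
import matplotlib.pyplot as plt

def cm_to_inch(*lengths_cm):
    """Convert centimeters to inches; hypothetical helper for illustration."""
    return [length / 2.54 for length in lengths_cm]

size = cm_to_inch(20, 12)   # a 20 cm x 12 cm figure

# Option 1: size one figure only
fig1 = plt.figure(figsize=size)

# Option 2: change the default only inside the with-block
with plt.rc_context({'figure.figsize': size}):
    fig2 = plt.figure()

print(fig1.get_size_inches(), fig2.get_size_inches())
```

Both options leave the global default untouched, so other plots in the notebook keep their original size.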

Setting Style

plt.style.use() switches to one of the easy-to-use preset plotting styles. The style package provides a wide array of preset styles that make plots more attractive and, in some cases, larger.

plt.style.use('classic')

plt.style.available lists the names of all the styles that are available for use.

print(plt.style.available)
['bmh', 'classic', 'dark_background', 'fast', 'fivethirtyeight', 'ggplot', 'grayscale', 'seaborn-bright', 'seaborn-colorblind', 'seaborn-dark-palette', 'seaborn-dark', 'seaborn-darkgrid', 'seaborn-deep', 'seaborn-muted', 'seaborn-notebook', 'seaborn-paper', 'seaborn-pastel', 'seaborn-poster', 'seaborn-talk', 'seaborn-ticks', 'seaborn-white', 'seaborn-whitegrid', 'seaborn', 'Solarize_Light2', 'tableau-colorblind10', '_classic_test']

If we wish to have a dark background for the plot, plt.style.use('dark_background') will serve the purpose.

import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as stats
plt.style.use('dark_background')
mu = 0
std = 1
x = np.linspace(start=-4, stop=4, num=100)
y = stats.norm.pdf(x, mu, std)
plt.plot(x, y, label='PDF')
plt.xlabel('x', fontsize=15)
plt.ylabel('PDF', fontsize=15)
plt.grid(True, color='DarkTurquoise', alpha=0.2, linestyle=':', linewidth=2)
plt.legend()
plt.show()
Plot 4: Normal Distribution | Photo by ©iambipin
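plt.style.use() changes the style for the rest of the session. A minimal sketch of a temporary alternative is plt.style.context(), which applies a style only inside a with-block. The PDF is computed with NumPy here just to keep the snippet self-contained.

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(-4, 4, 100)
y = np.exp(-0.5 * x**2) / np.sqrt(2 * np.pi)   # standard normal PDF

# The style applies only to figures created inside the with-block
with plt.style.context('dark_background'):
    fig, ax = plt.subplots()
    ax.plot(x, y, label='PDF')
    ax.grid(True, color='DarkTurquoise', alpha=0.2, linestyle=':', linewidth=2)
    ax.legend()

print('dark_background' in plt.style.available)   # True
```

The set of available style names depends on the installed Matplotlib version (for example, the seaborn styles were renamed with a seaborn-v0_8 prefix in Matplotlib 3.6), so printing plt.style.available is the reliable way to check a name before using it.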

Conclusion

Matplotlib is the de facto Python visualization library. I have covered the important aspects of pyplot that make your plots in Jupyter notebook stand out. I believe the information shared here will make your plots more meaningful and beautiful.

Data visualization is changing as I type. As someone rightly said, the only thing that never changes is change itself. Permanence is an illusion. Many new tools such as Tableau, Bokeh, and Plotly are emerging. Maybe in the future a more feature-rich and technically sophisticated tool will replace Matplotlib as the numero uno. Till then, Happy Coding!!!
