Matplotlib + Seaborn + Pandas: An Ideal Amalgamation for Statistical Data Visualisation

Meet Desai
Towards Data Science
11 min read · Sep 29, 2019


Exploratory Data Analysis involves two fundamental steps:

  1. Data Analysis (Data Preprocessing, Cleaning and Manipulation).
  2. Data Visualisation (Visualise relationships in data using different types of plots).

Pandas is the most commonly used library for data analysis in Python.
There are tons of libraries available in Python for data visualisation and, among them, matplotlib is the most commonly used. Matplotlib provides full control over the plot, which makes plot customisation easy, but what it lacks is tight integration with pandas DataFrames. Seaborn is a data visualisation library built on top of matplotlib and closely integrated with pandas.

This post covers:

  1. Different types of plots available in seaborn.
  2. How the integration of pandas with seaborn helps in making complex multidimensional plots with minimal code.
  3. How to customise plots made using seaborn, with the help of matplotlib.

Who should read this post?

If you have working knowledge of matplotlib and pandas, and want to explore seaborn, this is a good place to start. If you are just starting with Python, I would suggest coming back here after getting a basic idea of matplotlib and pandas.

1. Matplotlib

Although many tasks can be accomplished using just the seaborn functions, it is essential to understand the basics of matplotlib for two main reasons:

  1. Behind the scenes, seaborn uses matplotlib to draw the plots.
  2. Some customisation might require direct use of matplotlib.

Here is a quick review of matplotlib basics. The following figure shows the anatomy of a matplotlib Figure.

The three main classes to understand are Figure, Axes and Axis.

Figure

It refers to the whole figure that you see. It is possible to have multiple sub-plots (Axes) in the same figure. In the above example, we have four sub-plots (Axes) in a single figure.

Axes

An Axes refers to the actual plot in the figure. A figure can have multiple Axes, but a given Axes can be part of only one figure. In the above example, we have four Axes in one Figure.

Axis

An Axis refers to an actual axis (x-axis/y-axis) in a specific plot.

Each example in the post assumes that the required modules and datasets have been loaded as shown here.
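A minimal setup sketch, assuming the datasets used throughout are seaborn's `tips` and `iris` example datasets. `sns.load_dataset()` downloads these, so to keep the sketch runnable offline it instead builds small synthetic stand-ins that only mimic their column names (the values are made up). The `matplotlib.use("Agg")` line selects a non-interactive backend for scripts and can be dropped in a notebook.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; omit in a notebook

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# With internet access, the real example datasets can be loaded instead:
#   tips = sns.load_dataset("tips")
#   iris = sns.load_dataset("iris")

rng = np.random.default_rng(0)

# Synthetic stand-in with the same columns as seaborn's "tips" dataset
n = 200
tips = pd.DataFrame({
    "total_bill": rng.uniform(5, 50, n).round(2),
    "day": rng.choice(["Thur", "Fri", "Sat", "Sun"], n),
    "smoker": rng.choice(["Yes", "No"], n),
    "sex": rng.choice(["Male", "Female"], n),
})
# Tips loosely proportional to the bill, plus noise, with a floor of 1
tips["tip"] = (0.15 * tips["total_bill"] + rng.normal(0, 1, n)).clip(lower=1).round(2)

# Synthetic stand-in with the same columns as seaborn's "iris" dataset
iris = pd.DataFrame(
    np.abs(rng.normal([5.8, 3.0, 3.8, 1.2], [0.8, 0.4, 1.8, 0.8], size=(150, 4))),
    columns=["sepal_length", "sepal_width", "petal_length", "petal_width"],
)
iris["species"] = np.repeat(["setosa", "versicolor", "virginica"], 50)

print(tips.head())
print(iris.head())
```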

Let us try to understand the Figure and Axes classes with an example.
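A minimal sketch of `plt.subplots()` using only standard matplotlib calls: first with `nrows=1, ncols=1`, then with a 2 x 2 grid, printing the objects and shapes involved.

```python
import matplotlib.pyplot as plt

# One Figure containing a single Axes
fig, ax = plt.subplots(nrows=1, ncols=1)
print(type(fig).__name__, type(ax).__name__)

# One Figure containing a 2 x 2 grid of Axes, returned as a NumPy array
fig2, axes = plt.subplots(nrows=2, ncols=2)
print(axes.shape)

# Every Axes belongs to exactly one Figure
print(all(a.figure is fig2 for a in axes.flat))
print(len(fig2.axes))
```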

plt.subplots() creates a single Figure instance and (nrows * ncols) Axes instances, and returns the created Figure and Axes instances. In the above example, since we have passed nrows=1 and ncols=1, it creates only a single Axes instance. If nrows > 1 or ncols > 1, it creates a grid of Axes and returns them in a NumPy array of shape (nrows, ncols); with the default squeeze=True, the array is one-dimensional when either nrows or ncols is 1.

The most frequently used methods of the Axes class for customisation are

Here is an example which uses some of the above methods to customise a plot.
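A sketch of common Axes customisation calls (`set_title`, `set_xlabel`, `set_ylabel`, `set_xlim`, `set_ylim`, `set_xticks`, `set_xticklabels`, `grid`, `legend`). The sine/cosine data is an arbitrary choice for illustration, not the example from the original post.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 2 * np.pi, 100)

fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(x, np.sin(x), label="sin(x)")
ax.plot(x, np.cos(x), label="cos(x)", linestyle="--")

# Titles and axis labels
ax.set_title("Sine and cosine")
ax.set_xlabel("x (radians)")
ax.set_ylabel("value")

# Axis limits, tick positions and tick labels
ax.set_xlim(0, 2 * np.pi)
ax.set_ylim(-1.2, 1.2)
ax.set_xticks([0, np.pi / 2, np.pi, 3 * np.pi / 2, 2 * np.pi])
ax.set_xticklabels(["0", "pi/2", "pi", "3pi/2", "2pi"])

# Light grid and a legend built from the line labels
ax.grid(True, alpha=0.3)
ax.legend(loc="upper right")
```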

Now that we have reviewed the basics of matplotlib, let’s move on to seaborn.

2. Seaborn

Each plotting function in seaborn is either a figure-level function or an Axes-level function, and it is essential to understand the difference between the two. As mentioned earlier, a Figure refers to the whole figure that you see, whereas an Axes refers to a specific subplot in the figure. An Axes-level function draws onto a single matplotlib Axes and does not affect the rest of the figure. A figure-level function, on the other hand, controls the entire figure. One way to think about this is that a figure-level function can call different Axes-level functions to draw different types of subplots on different Axes.
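A minimal sketch contrasting the two kinds of function: `sns.scatterplot()` (Axes-level) draws on the Axes it is given and returns it, while `sns.relplot()` (figure-level) builds a new figure with one Axes per facet and returns a `FacetGrid`. The `tips` frame is a synthetic stand-in with the columns of seaborn's `tips` dataset (the real one comes from `sns.load_dataset("tips")`, which needs internet access).

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Synthetic stand-in for seaborn's "tips" dataset (same column names)
rng = np.random.default_rng(0)
n = 200
tips = pd.DataFrame({
    "total_bill": rng.uniform(5, 50, n).round(2),
    "day": rng.choice(["Thur", "Fri", "Sat", "Sun"], n),
    "smoker": rng.choice(["Yes", "No"], n),
    "sex": rng.choice(["Male", "Female"], n),
})
tips["tip"] = (0.15 * tips["total_bill"] + rng.normal(0, 1, n)).clip(lower=1).round(2)

# Axes-level: draws onto the Axes we pass in and returns that Axes
fig, ax = plt.subplots()
returned_ax = sns.scatterplot(x="total_bill", y="tip", data=tips, ax=ax)
print(returned_ax is ax)

# Figure-level: creates its own new figure with one Axes per facet
g = sns.relplot(x="total_bill", y="tip", col="day", data=tips)
print(type(g).__name__, g.axes.shape)
print(g.axes[0, 0].figure is fig)
```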

2.1 Axes Level Functions

Here is an exhaustive list of all Axes-level functions in seaborn

There are two main things to understand to use any Axes-level function:

  1. Different ways to give input data to the Axes-level function.
  2. Specifying the Axes to be used to make the plot.

2.1.1 Different ways to give input data to the Axes-level function

There are three different ways to pass data to an Axes-level function:

1. Lists, arrays or Series

The most common way to pass data to an Axes-level function is using iterables like lists, arrays or pandas Series.
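A sketch of the same scatter plot fed three ways: plain lists, NumPy arrays and named pandas Series. The six (bill, tip) pairs are made up for illustration. With named Series, seaborn uses the Series names as axis labels.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# A handful of made-up (bill, tip) pairs
bills = [12.5, 18.0, 22.3, 30.1, 41.7, 9.8]
tip_values = [2.0, 3.1, 3.4, 4.5, 6.2, 1.5]

fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(12, 4))

# 1. Plain Python lists
sns.scatterplot(x=bills, y=tip_values, ax=axes[0])

# 2. NumPy arrays
sns.scatterplot(x=np.array(bills), y=np.array(tip_values), ax=axes[1])

# 3. pandas Series: their names become the axis labels
sns.scatterplot(
    x=pd.Series(bills, name="total_bill"),
    y=pd.Series(tip_values, name="tip"),
    ax=axes[2],
)
```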

2. Using a pandas DataFrame and names of columns

One of the main reasons for seaborn’s popularity is its ability to work directly with pandas DataFrames. In this method of passing data, column names are passed to the x and y parameters and the DataFrame is passed to the data parameter.
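A sketch of the DataFrame-plus-column-names form: column names go to `x`, `y` (and here also `hue`), and the DataFrame goes to `data`. The `tips` frame is a synthetic stand-in with the columns of seaborn's `tips` dataset (normally `sns.load_dataset("tips")`).

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Synthetic stand-in for seaborn's "tips" dataset (same column names)
rng = np.random.default_rng(0)
n = 200
tips = pd.DataFrame({
    "total_bill": rng.uniform(5, 50, n).round(2),
    "day": rng.choice(["Thur", "Fri", "Sat", "Sun"], n),
    "smoker": rng.choice(["Yes", "No"], n),
    "sex": rng.choice(["Male", "Female"], n),
})
tips["tip"] = (0.15 * tips["total_bill"] + rng.normal(0, 1, n)).clip(lower=1).round(2)

fig, ax = plt.subplots()
# Column names for x, y and hue; the DataFrame itself for data
sns.scatterplot(x="total_bill", y="tip", hue="day", data=tips, ax=ax)
```

Because seaborn knows the column names, it labels the axes and builds the `hue` legend without extra code.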

3. Passing only the DataFrame

In this method of passing data, only the DataFrame is passed to the data parameter. Each numeric column in the dataset will be plotted using this method. This method can be used only with the following Axes-level functions

A specific use case for this method of passing input data is comparing the distribution of multiple numeric variables in a dataset, using any of the above-mentioned Axes-level functions.

sns.boxplot(data=iris)

2.1.2 Specifying the axes to be used to make the plot

Each Axes-level function in seaborn takes an explicit ax argument. The Axes passed to the ax argument is then used to make the plot. This provides great flexibility in terms of controlling which Axes is used for plotting.
For example, let’s say we want to look at the relationship between total bill and tip (using a scatter plot) as well as their distributions (using box plots) in the same figure but on different Axes.

Each Axes-level function also returns the Axes on which the plot has been made. If an Axes has been passed to the ax argument, the same Axes object is returned. The returned Axes object can then be used for further customisation using methods like Axes.set_xlabel(), Axes.set_ylabel() etc.
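A short sketch of customising through the returned Axes; the labels and title are arbitrary, and `tips` is a synthetic stand-in with the columns of seaborn's `tips` dataset.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Synthetic stand-in for seaborn's "tips" dataset (same column names)
rng = np.random.default_rng(0)
n = 200
tips = pd.DataFrame({
    "total_bill": rng.uniform(5, 50, n).round(2),
    "day": rng.choice(["Thur", "Fri", "Sat", "Sun"], n),
    "smoker": rng.choice(["Yes", "No"], n),
    "sex": rng.choice(["Male", "Female"], n),
})
tips["tip"] = (0.15 * tips["total_bill"] + rng.normal(0, 1, n)).clip(lower=1).round(2)

fig, ax = plt.subplots()
scatter_ax = sns.scatterplot(x="total_bill", y="tip", data=tips, ax=ax)
print(scatter_ax is ax)  # the Axes we passed in is the one returned

# Further customisation through the returned Axes
scatter_ax.set_xlabel("Total bill")
scatter_ax.set_ylabel("Tip")
scatter_ax.set_title("Tip vs total bill")
```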

If no Axes is passed to the ax argument, seaborn uses the current (active) Axes to make the plot.
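A sketch of this behaviour using the post's own names `curr_axes` and `scatter_plot_axes`: `plt.sca()` makes the left Axes active, and a `sns.scatterplot()` call without `ax` lands on it. `tips` is a synthetic stand-in with the columns of seaborn's `tips` dataset.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Synthetic stand-in for seaborn's "tips" dataset (same column names)
rng = np.random.default_rng(0)
n = 200
tips = pd.DataFrame({
    "total_bill": rng.uniform(5, 50, n).round(2),
    "day": rng.choice(["Thur", "Fri", "Sat", "Sun"], n),
    "smoker": rng.choice(["Yes", "No"], n),
    "sex": rng.choice(["Male", "Female"], n),
})
tips["tip"] = (0.15 * tips["total_bill"] + rng.normal(0, 1, n)).clip(lower=1).round(2)

fig, (left_ax, right_ax) = plt.subplots(nrows=1, ncols=2, figsize=(10, 4))

plt.sca(left_ax)        # make the left Axes the current (active) one
curr_axes = plt.gca()   # the currently active Axes

# No ax argument: seaborn draws on the current Axes and returns it
scatter_plot_axes = sns.scatterplot(x="total_bill", y="tip", data=tips)

print(id(curr_axes) == id(scatter_plot_axes))
print(len(right_ax.collections))  # nothing was drawn on the other Axes
```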

In the above example, even though we haven’t explicitly passed curr_axes (the currently active Axes) to the ax argument, seaborn still uses it to make the plot, since it is the currently active Axes.
id(curr_axes) == id(scatter_plot_axes) returns True, indicating that they are the same Axes.

If no Axes is passed to the ax argument and there is no currently active Axes object, seaborn creates a new Axes object to make the plot and then returns that Axes object.
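A sketch of that case: after `plt.close("all")` there is no open figure, so the seaborn call ends up with a freshly created figure and Axes. `tips` is a synthetic stand-in with the columns of seaborn's `tips` dataset.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Synthetic stand-in for seaborn's "tips" dataset (same column names)
rng = np.random.default_rng(0)
n = 200
tips = pd.DataFrame({
    "total_bill": rng.uniform(5, 50, n).round(2),
    "day": rng.choice(["Thur", "Fri", "Sat", "Sun"], n),
    "smoker": rng.choice(["Yes", "No"], n),
    "sex": rng.choice(["Male", "Female"], n),
})
tips["tip"] = (0.15 * tips["total_bill"] + rng.normal(0, 1, n)).clip(lower=1).round(2)

plt.close("all")  # no open figures, hence no active Axes

ax = sns.scatterplot(x="total_bill", y="tip", data=tips)
print(plt.get_fignums())       # one new figure now exists
print(ax.figure is plt.gcf())  # and the returned Axes lives in it
```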

The Axes-level functions in seaborn do not have any direct parameter to control the figure size. However, since we can specify which Axes is to be used for plotting by passing it in the ax argument, we can control the figure size as follows:
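A sketch of the idea: size the Figure with matplotlib's `figsize` first, then hand its Axes to the seaborn function. The 10 x 6 inch size is arbitrary, and `tips` is a synthetic stand-in with the columns of seaborn's `tips` dataset.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Synthetic stand-in for seaborn's "tips" dataset (same column names)
rng = np.random.default_rng(0)
n = 200
tips = pd.DataFrame({
    "total_bill": rng.uniform(5, 50, n).round(2),
    "day": rng.choice(["Thur", "Fri", "Sat", "Sun"], n),
    "smoker": rng.choice(["Yes", "No"], n),
    "sex": rng.choice(["Male", "Female"], n),
})
tips["tip"] = (0.15 * tips["total_bill"] + rng.normal(0, 1, n)).clip(lower=1).round(2)

# Create the Figure (and its Axes) with the desired size first...
fig, ax = plt.subplots(figsize=(10, 6))
# ...then pass that Axes to the Axes-level function
sns.scatterplot(x="total_bill", y="tip", data=tips, ax=ax)
print(fig.get_size_inches())
```

Calling `plt.figure(figsize=...)` before a seaborn call without `ax` should work too, since seaborn then draws on that figure's current Axes.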

2.2 Figure Level Functions

When exploring a multi-dimensional dataset, one of the most common use cases for data visualisation is drawing multiple instances of the same plot on different subsets of data. The figure-level functions in seaborn are tailor-made for this use case. A figure-level function has complete control over the entire figure, and each time a figure-level function is called, it creates a new figure which can include multiple Axes, all organised in a meaningful way. The three most generic figure-level interfaces in seaborn are the FacetGrid, PairGrid and JointGrid classes.

2.2.1 FacetGrid

Consider the following use case: we want to visualise the relationship between total bill and tip (via a scatter plot) on different subsets of data. Each subset of data is categorised by a unique combination of values of the following variables:
1. day (Thur, Fri, Sat, Sun)
2. smoker (whether the person is a smoker or not)
3. sex (male or female)
This can be done in matplotlib as follows:
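A sketch of what the plain-matplotlib approach might look like (not the post's original code): smoker and day lay out a 2 x 4 grid of Axes, and sex is shown by colour within each Axes. The colour mapping is an arbitrary choice, and `tips` is a synthetic stand-in with the columns of seaborn's `tips` dataset.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Synthetic stand-in for seaborn's "tips" dataset (same column names)
rng = np.random.default_rng(0)
n = 200
tips = pd.DataFrame({
    "total_bill": rng.uniform(5, 50, n).round(2),
    "day": rng.choice(["Thur", "Fri", "Sat", "Sun"], n),
    "smoker": rng.choice(["Yes", "No"], n),
    "sex": rng.choice(["Male", "Female"], n),
})
tips["tip"] = (0.15 * tips["total_bill"] + rng.normal(0, 1, n)).clip(lower=1).round(2)

days = ["Thur", "Fri", "Sat", "Sun"]
smokers = ["Yes", "No"]
colors = {"Male": "tab:blue", "Female": "tab:orange"}

# Step 1: one Axes per (smoker, day) combination
fig, axes = plt.subplots(nrows=len(smokers), ncols=len(days),
                         figsize=(16, 7), sharex=True, sharey=True)

for i, smoker in enumerate(smokers):
    for j, day in enumerate(days):
        ax = axes[i, j]
        for sex, color in colors.items():
            # Step 2: select the subset for this (smoker, day, sex) combination
            subset = tips[(tips["smoker"] == smoker) & (tips["day"] == day)
                          & (tips["sex"] == sex)]
            # Step 3: scatter-plot that subset on this Axes
            ax.scatter(subset["total_bill"], subset["tip"], s=15, color=color, label=sex)
        ax.set_title(f"smoker = {smoker} | day = {day}")

for ax in axes[-1, :]:
    ax.set_xlabel("total_bill")
for ax in axes[:, 0]:
    ax.set_ylabel("tip")
axes[0, -1].legend(title="sex")
```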

The above code can be broken down into three steps:

  1. Create an Axes (subplot) for each subset of data
  2. Divide the dataset into subsets
  3. On each Axes, draw the scatter plot using the subset of data
    corresponding to that Axes

Step 1 can be done in seaborn using FacetGrid()
Step 2 and Step 3 can be done using FacetGrid.map()

Using FacetGrid, we can divide the dataset along up to three dimensions using the row, col and hue parameters: row and col lay out a grid of Axes, while hue splits the data within each Axes by colour.
Once we have created a FacetGrid, we can draw the same kind of plot on all Axes using FacetGrid.map() by passing the plotting function as an argument. We also need to pass the names of the columns to be used for plotting.

Thus “Matplotlib offers good support to make plots with multiple Axes but seaborn builds on top of it by directly linking the structure of plot with the structure of dataset”. Using FacetGrid, we neither have to explicitly create Axes for each subset nor do we have to explicitly divide the data into subsets. That is done internally by FacetGrid() and FacetGrid.map() respectively.
We can pass different Axes-level functions to FacetGrid.map().

Also, seaborn provides three figure-level functions (high-level interfaces) which use FacetGrid() and FacetGrid.map() in the background:
1. relplot()
2. catplot()
3. lmplot()
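Each of the above figure-level functions uses FacetGrid() to create multiple Axes and then draws an Axes-level plot on each of them. relplot() and catplot() take the kind of Axes-level plot in the kind argument (e.g. kind="scatter" or kind="line" for relplot(), kind="box" or kind="violin" for catplot()), whereas lmplot() has no kind argument and always draws regplot(). So the three functions differ in terms of which Axes-level plots each of them can draw.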
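A sketch of the three high-level interfaces on the same data; the particular `kind` values and facet variables are arbitrary choices for illustration. All three return a `FacetGrid`. `tips` is a synthetic stand-in with the columns of seaborn's `tips` dataset.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Synthetic stand-in for seaborn's "tips" dataset (same column names)
rng = np.random.default_rng(0)
n = 200
tips = pd.DataFrame({
    "total_bill": rng.uniform(5, 50, n).round(2),
    "day": rng.choice(["Thur", "Fri", "Sat", "Sun"], n),
    "smoker": rng.choice(["Yes", "No"], n),
    "sex": rng.choice(["Male", "Female"], n),
})
tips["tip"] = (0.15 * tips["total_bill"] + rng.normal(0, 1, n)).clip(lower=1).round(2)

days = ["Thur", "Fri", "Sat", "Sun"]

# relplot(): kind is "scatter" (the default) or "line"
g_rel = sns.relplot(data=tips, x="total_bill", y="tip", row="smoker", col="day",
                    hue="sex", kind="scatter", col_order=days, height=3)

# catplot(): kind selects a categorical plot, e.g. "strip", "box", "violin", "bar"
g_cat = sns.catplot(data=tips, x="day", y="tip", col="smoker", kind="box",
                    order=days, height=3)

# lmplot(): no kind argument; it draws regplot() on each facet
g_lm = sns.lmplot(data=tips, x="total_bill", y="tip", col="sex", height=3)
```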

Explicitly using FacetGrid provides more flexibility than directly using high-level interfaces like relplot(), catplot() or lmplot(); for example, with FacetGrid() we can also pass custom functions to FacetGrid.map(), whereas with the high-level interfaces we can only use the built-in plot kinds. If you do not need that flexibility, you can directly use the high-level interfaces.

Each of the above three figure-level functions, as well as FacetGrid, returns an instance of FacetGrid. Using the FacetGrid instance, we can get access to individual Axes, which can then be used to tweak the plot (like adding axis labels, titles etc.). Also, controlling the size of figure-level functions is different from controlling the size of matplotlib figures. Instead of setting the overall figure size, we set the height and aspect ratio of each facet (subplot) using the height and aspect parameters.
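A sketch of sizing facets with `height`/`aspect` and tweaking the returned `FacetGrid`, both grid-wide (`set_axis_labels`, `set_titles`) and per Axes (`g.axes` array, `g.axes_dict`). The numbers, labels and reference line are arbitrary, and `tips` is a synthetic stand-in with the columns of seaborn's `tips` dataset.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Synthetic stand-in for seaborn's "tips" dataset (same column names)
rng = np.random.default_rng(0)
n = 200
tips = pd.DataFrame({
    "total_bill": rng.uniform(5, 50, n).round(2),
    "day": rng.choice(["Thur", "Fri", "Sat", "Sun"], n),
    "smoker": rng.choice(["Yes", "No"], n),
    "sex": rng.choice(["Male", "Female"], n),
})
tips["tip"] = (0.15 * tips["total_bill"] + rng.normal(0, 1, n)).clip(lower=1).round(2)

days = ["Thur", "Fri", "Sat", "Sun"]

# height: height of each facet in inches; aspect: facet width = aspect * height
g = sns.relplot(data=tips, x="total_bill", y="tip", col="day",
                col_order=days, height=3, aspect=1.2)

# Grid-wide helpers on the returned FacetGrid
g.set_axis_labels("Total bill", "Tip")
g.set_titles("{col_name}")

# Individual Axes, by position in the g.axes array...
g.axes[0, 0].set_ylim(0, 12)
# ...or by facet level through g.axes_dict
g.axes_dict["Sun"].axhline(5, color="grey", linestyle=":")
```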

Refer to the FacetGrid documentation for more examples.

2.2.2 PairGrid

PairGrid is used to plot pairwise relationships between variables in a dataset. Each subplot shows the relationship between a pair of variables. Consider the following use case: we want to visualise the relationship (via a scatter plot) between every pair of variables. This can be done in matplotlib as follows:
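A sketch of what the plain-matplotlib version might look like (not the post's original code): a k x k grid of Axes, each showing one (x, y) pair of numeric columns. `iris` is a synthetic stand-in with the columns of seaborn's `iris` dataset (normally `sns.load_dataset("iris")`).

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Synthetic stand-in for seaborn's "iris" dataset (same column names)
rng = np.random.default_rng(0)
iris = pd.DataFrame(
    np.abs(rng.normal([5.8, 3.0, 3.8, 1.2], [0.8, 0.4, 1.8, 0.8], size=(150, 4))),
    columns=["sepal_length", "sepal_width", "petal_length", "petal_width"],
)
iris["species"] = np.repeat(["setosa", "versicolor", "virginica"], 50)

variables = ["sepal_length", "sepal_width", "petal_length", "petal_width"]
k = len(variables)

# Step 1: one Axes for each (y, x) pair of variables
fig, axes = plt.subplots(nrows=k, ncols=k, figsize=(10, 10))

for i, y_var in enumerate(variables):
    for j, x_var in enumerate(variables):
        # Step 2: scatter plot of this pair on its Axes
        axes[i, j].scatter(iris[x_var], iris[y_var], s=8)
        if i == k - 1:
            axes[i, j].set_xlabel(x_var)  # label only the bottom row
        if j == 0:
            axes[i, j].set_ylabel(y_var)  # label only the left column
```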

The above code can be broken down into two steps:

  1. Create an Axes for each pair of variables
  2. On each Axes, draw the scatter plot using the data
    corresponding to that pair of variables

Step 1 can be done using PairGrid()
Step 2 can be done using PairGrid.map().

Thus PairGrid() creates Axes for each pair of variables and PairGrid.map() draws the plot on each Axes using data corresponding to that pair of variables. We can pass different Axes-level functions to PairGrid.map().
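A sketch of the two steps with seaborn: `PairGrid()` builds the grid (by default from every numeric column, so `species` is left out) and `map()` draws `sns.scatterplot()` on every Axes, the diagonal included. `iris` is a synthetic stand-in with the columns of seaborn's `iris` dataset.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Synthetic stand-in for seaborn's "iris" dataset (same column names)
rng = np.random.default_rng(0)
iris = pd.DataFrame(
    np.abs(rng.normal([5.8, 3.0, 3.8, 1.2], [0.8, 0.4, 1.8, 0.8], size=(150, 4))),
    columns=["sepal_length", "sepal_width", "petal_length", "petal_width"],
)
iris["species"] = np.repeat(["setosa", "versicolor", "virginica"], 50)

# Step 1: one Axes per pair of numeric variables
g = sns.PairGrid(iris)
# Step 2: the same Axes-level function on every Axes
g.map(sns.scatterplot, s=10)
```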

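It does not make sense to draw a scatter plot on the diagonal Axes, since each of them pairs a variable with itself. It is possible to draw one kind of plot on the diagonal Axes and another kind on the off-diagonal Axes, using PairGrid.map_diag() and PairGrid.map_offdiag().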

It is also possible to draw different kinds of plots on the upper triangular, diagonal and lower triangular Axes, using PairGrid.map_upper(), PairGrid.map_diag() and PairGrid.map_lower().

Seaborn also provides a high-level interface, pairplot(), to plot pairwise relationships of variables if you don’t need all the flexibility of PairGrid(). It uses PairGrid() and its map methods (such as PairGrid.map_diag() and PairGrid.map_offdiag()) in the background.

sns.pairplot(data=iris)

Both PairGrid() and pairplot() return an instance of PairGrid. Using the PairGrid instance, we can get access to individual Axes, which can then be used to tweak the plot, like adding axis labels, titles etc.

Refer to the PairGrid documentation for more examples.

2.2.3 JointGrid

JointGrid is used when we want to plot a bivariate distribution along with the marginal distributions in the same plot. The joint distribution of two variables can be visualised using a scatter plot, regplot or kdeplot. The marginal distribution of each variable can be visualised using histograms and/or KDE plots. The Axes-level function to use for the joint distribution must be passed to JointGrid.plot_joint(). The Axes-level function to use for the marginal distributions must be passed to JointGrid.plot_marginals().

If you don’t need all the flexibility of JointGrid(), seaborn also provides a high-level interface, jointplot(), to plot a bivariate distribution along with the marginal distributions. It uses JointGrid(), JointGrid.plot_joint() and JointGrid.plot_marginals() in the background.

Both JointGrid() and jointplot() return an instance of JointGrid. Using the JointGrid instance, we can get access to individual Axes, which can then be used to tweak the plots, like adding labels, titles etc.

Refer to the JointGrid documentation for more examples.

Summary

Integration of seaborn with pandas helps in making complex multidimensional plots with minimal code. Each plotting function in seaborn is either an Axes-level function or a figure-level function. An Axes-level function draws onto a single matplotlib Axes and does not affect the rest of the figure. A figure-level function, on the other hand, controls the entire figure. Here is a quick summary of Axes-level and figure-level functions:

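Axes-Level: draws onto a single matplotlib Axes (the one passed in ax, otherwise the currently active Axes) and returns that Axes; the figure size is controlled through matplotlib, e.g. by creating the Axes with plt.subplots(figsize=...).

Figure-Level: FacetGrid, relplot(), catplot(), lmplot(), PairGrid, pairplot(), JointGrid and jointplot(); each call creates a new figure and returns a FacetGrid, PairGrid or JointGrid instance whose individual Axes can be tweaked; the size is controlled per subplot, e.g. with the height and aspect parameters.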

Suggestions/tips to make this article better are welcome. Thanks for reading!!!
