
Visualizations with Matplotlib

Learn how to visualize data using the Python library, matplotlib

Photo by Raymond Pang on Unsplash

Matplotlib is a powerful data visualization library for plotting data with Python. Its most commonly used module is ‘pyplot’, which provides a collection of functions that let you plot data easily.

Installation

Using conda:

conda install matplotlib

Using pip:

pip install matplotlib
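Once the installation finishes, a quick sanity check is to import the package and print its version. The exact version number shown depends on your environment.

```python
# Confirm that Matplotlib can be imported and check which version is installed
import matplotlib

print(matplotlib.__version__)
```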

As a prerequisite, you should have a basic understanding of Python lists and dictionaries, as well as the features and functions of Python’s NumPy library.

Simple line graph

Import the necessary modules.

import numpy as np
import matplotlib.pyplot as plt

Line graphs are among the most commonly plotted graphs and are generally used for data tracked over a period of time, such as stock prices.

Let’s create one with pyplot by passing in a simple Python list.

plt.plot([2,4,5,10])
Image by author

Here you can see the plot of our 4 values.

If you pass a single list of values to the plot function, matplotlib assumes they are y values and automatically generates the x values for you. The default x vector has the same length as the list you pass in, but it starts at 0. Above, the x values are 0, 1, 2, and 3.
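To see this default in action, the sketch below plots the same four values twice, once with only y values and once with an explicit x of range(4), and then inspects the x data of each line. It relies on plt.plot returning a list of Line2D objects; the names implicit and explicit are just for illustration.

```python
import matplotlib.pyplot as plt

y = [2, 4, 5, 10]

# plt.plot returns a list of Line2D objects, one per line drawn
(implicit,) = plt.plot(y)                  # x values generated automatically
(explicit,) = plt.plot(range(len(y)), y)   # x values 0..3 passed explicitly

# Both lines share the same x data: 0, 1, 2, 3
print(list(implicit.get_xdata()), list(explicit.get_xdata()))

# In a standalone script, call plt.show() to open the plot window
```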

To plot x versus y, pass two lists to the function: the x values first, then the y values.

plt.plot([1,3,7,10], [2,4,5,10])
Image by author

Here we can see a different set of x values for our y values. The first x value corresponds to the first y value and so on.
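Because the pairing is positional, both lists must have the same length. In this minimal sketch, matplotlib raises a ValueError rather than guessing how to match the values when the lengths differ (the exact error message may vary between versions).

```python
import matplotlib.pyplot as plt

x = [1, 3, 7, 10]
y = [2, 4, 5, 10]

# Each x value is paired with the y value at the same position
pairs = list(zip(x, y))
print(pairs)  # [(1, 2), (3, 4), (7, 5), (10, 10)]
plt.plot(x, y)

# With three x values but four y values, plot refuses to draw
mismatch_rejected = False
try:
    plt.plot([1, 3, 7], y)
except ValueError as err:
    mismatch_rejected = True
    print("ValueError:", err)
```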

Scatter plots

Scatter plots use dots to represent values for two distinct numeric variables.

You can pass a third argument, a format string, to the plot function to set the color, marker, and line style of the plot.

To obtain a scatter plot in red, we add ‘ro’ as the third argument: ‘r’ stands for red and ‘o’ for circle markers, and since no line style is given, the points are not connected.

plt.plot([1,3,7,10], [2,4,5,10], 'ro')
Image by author

So here we can see that our plot has changed from a blue line to red circles.
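A format string packs a color code, a marker code, and a line-style code into one short string. The sketch below checks how ‘ro’ is interpreted, then draws the same points with plt.scatter, pyplot’s dedicated scatter-plot function. Using plt.scatter here is shown only as an alternative; the example above does not need it.

```python
import matplotlib.pyplot as plt

x = [1, 3, 7, 10]
y = [2, 4, 5, 10]

# 'r' = red, 'o' = circle marker; no line-style code, so the points are not connected
(line,) = plt.plot(x, y, 'ro')
print(line.get_marker(), line.get_linestyle())  # o None

# The same red dots drawn with the dedicated scatter function, in a new figure
plt.figure()
points = plt.scatter(x, y, color='red')
print(points.get_offsets().shape)  # (4, 2): one (x, y) row per point
```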

In practice, we usually use NumPy arrays instead of Python lists when plotting with matplotlib, as shown below.

a = np.array([[12,23,34,45],
              [56,67,78,89]])
plt.plot(a[0], a[1], 'go')
Image by author
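One reason NumPy arrays are preferred is that arithmetic on them is element-wise, so y values can be computed from x values in a single expression, which plain lists do not support. In the example above, a[0] and a[1] simply select the first and second rows of the 2 × 4 array. The values list in this sketch is made up for illustration.

```python
import numpy as np
import matplotlib.pyplot as plt

a = np.array([[12, 23, 34, 45],
              [56, 67, 78, 89]])
x, y = a[0], a[1]          # first row -> x values, second row -> y values
print(x, y)                # [12 23 34 45] [56 67 78 89]

values = [1, 2, 3, 4]
squares = np.array(values) ** 2   # element-wise: 1, 4, 9, 16
# values ** 2 would raise a TypeError: lists do not support element-wise math

plt.plot(values, squares, 'go')
```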

We can also add labels to our axes.

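plt.xlabel('x-axis')  # label for the horizontal axis
plt.ylabel('y-axis')  # label for the vertical axis
plt.plot(a[0], a[1], 'p--')  # pentagon markers ('p') joined by a dashed line ('--')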
Image by author
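The format string ‘p--’ combines the pentagon marker ‘p’ with the dashed line style ‘--’. Since it has no color code, matplotlib picks the next color from its default color cycle. The sketch below shows the equivalent keyword-argument form and reads the axis labels back with plt.gca(), which returns the current axes.

```python
import numpy as np
import matplotlib.pyplot as plt

a = np.array([[12, 23, 34, 45],
              [56, 67, 78, 89]])

plt.xlabel('x-axis')
plt.ylabel('y-axis')

# Short form: a format string
(short,) = plt.plot(a[0], a[1], 'p--')
# Long form: the same marker and line style as keyword arguments
(spelled_out,) = plt.plot(a[0], a[1], marker='p', linestyle='--')

ax = plt.gca()
print(ax.get_xlabel(), ax.get_ylabel())  # x-axis y-axis
```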

Multiple sets of data in a graph

To plot multiple sets of data on one graph, you can pass several sets of arguments to the plot function, or simply call it once per data set. To see how this works, let’s first create an array of evenly spaced numbers.

t = np.linspace(0, 5, 10) 
plt.plot(t, t**2, color = 'red', linestyle = '--') 
Image by author

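‘linspace’ is a NumPy function that returns evenly spaced numbers over an interval (here, 10 values from 0 to 5), and ‘linestyle’ is a plt.plot keyword argument that sets the style of the line.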
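By default np.linspace(start, stop, num) includes both endpoints, so the spacing between neighbouring values is (stop - start) / (num - 1). A quick check with the same arguments as above:

```python
import numpy as np

t = np.linspace(0, 5, 10)   # 10 values, both 0 and 5 included
step = t[1] - t[0]          # (5 - 0) / (10 - 1) = 5/9, about 0.556

print(t.size, t[0], t[-1])  # 10 0.0 5.0
print(round(step, 3))       # 0.556
```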

To draw another line on the same graph, simply call the plot function again.

plt.plot(t, t**2, color = 'red', linestyle = '--')
plt.plot(t, t*3, color='green', linewidth = 2)
Image by author

Here, ‘linewidth’ is another plt.plot keyword argument; it sets the thickness of the line. As you can see, the second line is drawn on the same axes as the first one.
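The same two curves can also come from a single plot call that takes several x, y, format groups in a row. In this sketch the format strings ‘r--’ and ‘g-’ stand in for the color and linestyle keywords used above. Any keyword argument, such as linewidth, would apply to every line in such a call, which is why separate calls are handy when each line needs its own styling.

```python
import numpy as np
import matplotlib.pyplot as plt

t = np.linspace(0, 5, 10)

# Two (x, y, format) groups in one call: red dashed t**2 and green solid t*3
lines = plt.plot(t, t**2, 'r--', t, t*3, 'g-')
print(len(lines))  # 2
```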

Plotting multiple graphs

We can use the subplot function to place more than one plot in a single figure. It takes three arguments: the number of rows, the number of columns, and the index of the current plot, which starts at 1.

plt.subplot(1,3,1)
plt.plot(t,t,'c--')
Image by author
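Subplot indices count from 1 and run left to right, then top to bottom. The sketch below fills a 2 × 2 grid (a layout chosen only for illustration) and titles each panel with the call that created it.

```python
import numpy as np
import matplotlib.pyplot as plt

t = np.linspace(0, 5, 10)

# subplot(2, 2, 1) is top-left, 2 is top-right, 3 is bottom-left, 4 is bottom-right
for index in range(1, 5):
    plt.subplot(2, 2, index)
    plt.plot(t, t**index)
    plt.title(f"subplot(2, 2, {index})")

fig = plt.gcf()
print(len(fig.axes))  # 4
```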

By calling subplot again with the next index, we can add plots of different types right next to the existing one.

plt.subplot(1,3,1)
plt.plot(t,t,'c--')
plt.subplot(1,3,2)
plt.plot(t,t**2,'b-')
plt.subplot(1,3,3)
plt.plot(t,t**3,'mo')
Image by author
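The same 1 × 3 layout can also be built in one call with plt.subplots, which returns the figure together with an array of its axes; you then plot on each axes object directly. The figsize value here is an arbitrary choice for this sketch.

```python
import numpy as np
import matplotlib.pyplot as plt

t = np.linspace(0, 5, 10)

# Create the whole 1 x 3 grid at once; axes is a NumPy array of three Axes objects
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
axes[0].plot(t, t, 'c--')
axes[1].plot(t, t**2, 'b-')
axes[2].plot(t, t**3, 'mo')
```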

Bar graphs

We can also plot data using bar graphs, one of the most common types of graphs. Bar graphs show data that is split into different categories. For example, to compare the sales of different products, I would use a bar graph.

sales = {'Computers': 92,
         'Laptops': 76,
         'Mobiles': 97,
         'TVs': 82,
         'Speakers': 70}
plt.bar(range(len(sales)), list(sales.values()), color = 'pink')
Image by author

In our example, the x coordinates range from 0 to 4, and the y coordinates give the height of each bar. The result is a bar graph of the values in our dictionary, where the height of each bar represents the sales made in that category.
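To check how the bars line up with the dictionary, the sketch below rebuilds the chart and reads each bar’s height back. It relies on Python dictionaries preserving insertion order (guaranteed since Python 3.7), so position 0 is ‘Computers’, position 1 is ‘Laptops’, and so on.

```python
import matplotlib.pyplot as plt

sales = {'Computers': 92, 'Laptops': 76, 'Mobiles': 97, 'TVs': 82, 'Speakers': 70}

positions = list(range(len(sales)))   # [0, 1, 2, 3, 4]: one slot per category
heights = list(sales.values())        # [92, 76, 97, 82, 70], in insertion order

bars = plt.bar(positions, heights, color='pink')

# plt.bar returns a container of Rectangle patches, one per bar
for name, bar in zip(sales, bars):
    print(name, bar.get_height())
```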

To replace these numbers with the names of the categories, we can use plt.xticks as shown below.

plt.bar(range(len(sales)), list(sales.values()), color = 'pink')
plt.xticks(range(len(sales)), list(sales.keys()))
Image by author
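In recent matplotlib versions you can also skip the numeric positions and pass the category names straight to plt.bar. String x values are treated as categories, placed at positions 0, 1, 2 and so on in the order given, and used as tick labels, so this sketch should give the same chart as the xticks approach above.

```python
import matplotlib.pyplot as plt

sales = {'Computers': 92, 'Laptops': 76, 'Mobiles': 97, 'TVs': 82, 'Speakers': 70}

# String x values become categories; their names are used as tick labels
bars = plt.bar(list(sales.keys()), list(sales.values()), color='pink')
```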

Now each bar has the name of its category underneath. Similarly, to add x and y axis labels, we can use the same ‘plt.xlabel’ and ‘plt.ylabel’ functions as before.

plt.bar(range(len(sales)), list(sales.values()), color='pink')
plt.xticks(range(len(sales)), list(sales.keys()))
plt.xlabel("Categories")
plt.ylabel("Sales")
Image by author

Histograms

The last type of graph we will look at in this article is the histogram. Histograms are very common when looking at data like heights and weights, stock prices, etc. In our example, let’s generate 1000 random values from a standard normal distribution.

x = np.random.randn(1000)
plt.hist(x, color = 'green')
Image by author

Here we can see that, in this run, our values range from roughly -4 to 3, and the histogram peaks around 0, where most of the values lie. That is expected, since np.random.randn draws samples from a standard normal distribution.
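Because np.random.randn draws new numbers on every run, the exact range and shape of the histogram will vary slightly between runs. The sketch below uses NumPy’s seeded Generator instead (seed 0 is an arbitrary choice) so the result is reproducible, and inspects what plt.hist returns: the count in each bin, the bin edges, and the drawn bars. By default the data is split into 10 equal-width bins spanning the minimum to the maximum value.

```python
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)   # seeded generator: same numbers on every run
x = rng.standard_normal(1000)    # 1000 samples from a standard normal distribution

# plt.hist returns (bin counts, bin edges, bar patches)
counts, edges, patches = plt.hist(x, color='green')

print(len(counts), len(edges))   # 10 11: 10 bins need 11 edges
print(int(counts.sum()))         # 1000: every sample falls in some bin
tallest = counts.argmax()
print(edges[tallest], edges[tallest + 1])  # edges of the most populated bin
```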

That’s it for this article.


Thank you for giving it a read!


