
Exploratory data analysis (EDA) is an essential part of a data science or machine learning pipeline. To build a robust and valuable product from data, you need to explore it and understand both the relationships among variables and its underlying structure. One of the most effective tools in EDA is data visualization.
Data visualizations tell us much more than plain numbers, and they are more likely to stick in your head. In this post, we will explore a customer churn dataset using the power of visualizations.
We will create several different visualizations and use each one to introduce a feature of the Matplotlib or Seaborn library.
We start by importing the relevant libraries and reading the dataset into a pandas dataframe.
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(style='darkgrid')
%matplotlib inline
df = pd.read_csv("/content/Churn_Modelling.csv")
df.head()

The dataset contains 10000 customers (i.e. rows) and 14 columns describing the customers and their products at a bank. The goal here is to predict whether a customer will churn (i.e. Exited = 1) using the provided features.
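Before plotting, it can help to confirm the shape and the share of churned customers numerically. The sketch below uses a tiny made-up frame (df_demo, with invented values) in place of the real file; only the column names Age, Balance, and Exited come from the dataset.

```python
import pandas as pd

# Tiny stand-in for the real file; every value here is invented for illustration.
df_demo = pd.DataFrame({
    "Age": [30, 40, 50, 55, 35, 60],
    "Balance": [0.0, 1000.0, 2000.0, 0.0, 1500.0, 2500.0],
    "Exited": [0, 0, 1, 1, 0, 1],
})

n_rows, n_cols = df_demo.shape

# The mean of a 0/1 column is the share of ones, i.e. the churn rate.
churn_share = df_demo["Exited"].mean()

print(n_rows, n_cols, churn_share)
```

On the real dataframe, the same two lines (df.shape and df['Exited'].mean()) should give a quick sense of how imbalanced the target is.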
Let’s start with a catplot, which is Seaborn’s categorical plot.
sns.catplot(x='Gender', y='Age', data=df, hue='Exited', height=8, aspect=1.2)

Finding: Customers between the ages of 45 and 60 are more likely to churn (i.e. leave the company) than customers of other ages. There is no considerable difference between females and males in terms of churn.
The hue parameter colors the data points according to a categorical variable.
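A finding read off a plot can be cross-checked with a grouped churn rate. The following is a minimal sketch on synthetic data: the frame demo, the age bands, and the churn probabilities are all assumptions made up for illustration, chosen so that the 45 to 60 group churns more often.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
n = 2000

# Synthetic stand-in for the real dataframe (values are invented).
demo = pd.DataFrame({
    "Gender": rng.choice(["Female", "Male"], size=n),
    "Age": rng.integers(18, 80, size=n),
})

# Assumption for the demo only: churn is more likely between 45 and 60.
churn_prob = np.where(demo["Age"].between(45, 60), 0.5, 0.1)
demo["Exited"] = (rng.random(n) < churn_prob).astype(int)

# Bin ages into bands, then take the mean of the 0/1 target per band.
age_band = pd.cut(
    demo["Age"],
    bins=[17, 30, 45, 60, 80],
    labels=["18-30", "31-45", "46-60", "61-80"],
)
churn_by_band = demo.groupby(age_band, observed=True)["Exited"].mean()
churn_by_gender = demo.groupby("Gender")["Exited"].mean()

print(churn_by_band)
print(churn_by_gender)
```

Running the same groupby on the real df would show whether the age effect seen in the catplot holds up as a rate rather than as a visual impression.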
The next visualization is the scatter plot, which shows the relationship between two numerical variables. Let’s see if the estimated salary and balance of a customer are related.
plt.figure(figsize=(12,8))
plt.title("Estimated Salary vs Balance", fontsize=16)
sns.scatterplot(x='Balance', y='EstimatedSalary', data=df)

We first used the matplotlib.pyplot interface to create a Figure object and set the title. Then we drew the actual plot on this figure with Seaborn.
Finding: There is no meaningful relationship or correlation between the estimated salary and balance. Balance appears to be roughly normally distributed (excluding the customers with zero balance).
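The scatter plot can be backed up with a Pearson correlation and the share of zero balances. This is a sketch on synthetic data, where the frame demo and its distributions are assumptions (independent columns, with roughly 30% of balances set to zero), not values from the real file.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
n = 2000

# Synthetic balances: roughly normal, with a block of zero-balance customers.
balance = rng.normal(120000, 30000, size=n)
balance[rng.random(n) < 0.3] = 0.0

demo = pd.DataFrame({
    "Balance": balance,
    "EstimatedSalary": rng.uniform(0, 200000, size=n),  # independent of Balance
})

# Pearson correlation between the two columns (Series.corr defaults to Pearson).
pearson_r = demo["Balance"].corr(demo["EstimatedSalary"])

# Share of customers with exactly zero balance.
zero_share = (demo["Balance"] == 0).mean()

# Same correlation, restricted to customers who hold a balance.
nonzero = demo[demo["Balance"] > 0]
pearson_r_nonzero = nonzero["Balance"].corr(nonzero["EstimatedSalary"])

print(pearson_r, zero_share, pearson_r_nonzero)
```

A correlation near zero on the real data would support the finding; checking it with and without the zero-balance rows guards against that cluster distorting the result.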
The next visualization is the boxplot, which shows the distribution of a variable in terms of its median and quartiles.
plt.figure(figsize=(12,8))
ax = sns.boxplot(x='Geography', y='Age', data=df)
ax.set_xlabel("Country", fontsize=16)
ax.set_ylabel("Age", fontsize=16)

We also set the axis labels and their font sizes using set_xlabel and set_ylabel.
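The two styles used so far (drawing on the current pyplot figure, and working with the returned Axes) can be combined by creating the Axes explicitly and passing it to Seaborn through the ax parameter. The sketch below assumes synthetic data: the frame demo and the country labels "A", "B", "C" are placeholders, not values from the dataset.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for scripts; omit this line inside a notebook
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

rng = np.random.default_rng(3)
demo = pd.DataFrame({
    "Geography": rng.choice(["A", "B", "C"], size=300),  # placeholder labels
    "Age": rng.integers(18, 80, size=300),
})

# Create the Figure and Axes up front, then tell Seaborn where to draw.
fig, ax = plt.subplots(figsize=(12, 8))
returned_ax = sns.boxplot(x="Geography", y="Age", data=demo, ax=ax)

# Axes-level Seaborn functions return the Axes they drew on,
# so labels can be set on either name.
ax.set_xlabel("Country", fontsize=16)
ax.set_ylabel("Age", fontsize=16)
```

Passing ax explicitly becomes useful once a figure holds several subplots, because it removes any ambiguity about which one Seaborn draws on.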
Here is the structure of boxplots:

The median is the point in the middle when all points are sorted. Q1 (first or lower quartile) is the median of the lower half of the dataset. Q3 (third or upper quartile) is the median of the upper half of the dataset.
Thus, boxplots give us an idea about the distribution and outliers. In the boxplot we created, there are many outliers (represented by dots) at the top.
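The quantities behind a boxplot can be computed by hand. By default, Matplotlib and Seaborn draw whiskers out to the most extreme points within 1.5 times the interquartile range (IQR) of the box, and anything beyond that is drawn as an outlier dot. The sketch below uses a small invented array of ages; note that np.percentile interpolates between points, so its quartiles can differ slightly from the median-of-halves definition above.

```python
import numpy as np

# Invented ages for illustration; two values sit far above the rest.
ages = np.array([22, 25, 29, 31, 33, 35, 36, 38, 40, 42, 45, 71, 84])

q1, median, q3 = np.percentile(ages, [25, 50, 75])
iqr = q3 - q1

# Fences at 1.5 * IQR beyond the box; points outside are flagged as outliers.
lower_fence = q1 - 1.5 * iqr
upper_fence = q3 + 1.5 * iqr
outliers = ages[(ages < lower_fence) | (ages > upper_fence)]

print(q1, median, q3, iqr)
print(lower_fence, upper_fence, outliers)
```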
Finding: The distribution of the age variable is right-skewed. The mean is greater than the median due to the outliers on the upper side. There is not a considerable difference between countries.
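Right-skewness can be checked numerically by comparing the mean with the median and by looking at the sample skewness, which is positive for a right-skewed variable. A minimal sketch, using an invented series of ages rather than the real column:

```python
import pandas as pd

# Invented ages with a couple of large values on the upper side.
ages = pd.Series([22, 25, 29, 31, 33, 35, 36, 38, 40, 42, 45, 71, 84], name="Age")

mean_age = ages.mean()
median_age = ages.median()
skewness = ages.skew()  # sample skewness; > 0 indicates a longer right tail

print(mean_age, median_age, skewness)
```

On the real data, df['Age'].mean(), df['Age'].median(), and df['Age'].skew() give the same three numbers.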
Right-skewness can also be observed in the univariate distribution of a variable. Let’s create a distplot to observe the distribution.
plt.figure(figsize=(12,8))
plt.title("Distribution of Age", fontsize=16)
sns.distplot(df['Age'], hist=False)

The tail on the right side is heavier than the one on the left. The reason is the outliers, which we also observed in the boxplot.
By default, distplot also draws a histogram, but we turned it off using the hist parameter.
Seaborn also provides different types of pair plots, which give an overview of pairwise relationships among variables. Let’s first take a random sample from our dataset to make the plots less cluttered. The original dataset has 10000 observations; we will take a sample of 100 observations and 4 features.
subset=df[['CreditScore','Age','Balance','EstimatedSalary']].sample(n=100)
g = sns.pairplot(subset, height=2.5)

On the diagonal, we can see the histogram of each variable. The rest of the grid shows the pairwise relationships.
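Because the sample is random, the pair plot changes on every run. If reproducible figures matter, the random_state parameter of sample fixes the draw. A minimal sketch on a synthetic frame (demo and its value ranges are invented; the seed 42 is arbitrary):

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(1)
n = 1000

# Synthetic stand-in for the real dataframe (value ranges are invented).
demo = pd.DataFrame({
    "CreditScore": rng.integers(350, 851, size=n),
    "Age": rng.integers(18, 93, size=n),
    "Balance": rng.uniform(0, 250000, size=n),
    "EstimatedSalary": rng.uniform(0, 200000, size=n),
})

cols = ["CreditScore", "Age", "Balance", "EstimatedSalary"]

# The same random_state yields the same rows on every call.
subset_a = demo[cols].sample(n=100, random_state=42)
subset_b = demo[cols].sample(n=100, random_state=42)

print(subset_a.equals(subset_b))
```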
Another tool for observing pairwise relationships is the heatmap, which takes a matrix and produces a color-encoded plot. Heatmaps are often used to check correlations among features and the target variable.
Let’s first create a correlation matrix of some features using the corr method of pandas.
corr_matrix = df[['CreditScore','Age','Tenure','Balance',
'EstimatedSalary','Exited']].corr()
We can now plot this matrix.
plt.figure(figsize=(12,8))
sns.heatmap(corr_matrix, cmap='Blues_r', annot=True)

Finding: The "Age" and "Balance" columns are positively correlated with customer churn ("Exited").
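When the heatmap has many cells, it can be easier to pull out just the target column of the correlation matrix and sort it. The sketch below does this on synthetic data: the frame demo, its value ranges, and the rule that makes churn more likely with age are all assumptions for illustration.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(7)
n = 2000
age = rng.integers(18, 80, size=n)

# Synthetic stand-in for the real dataframe (values are invented).
demo = pd.DataFrame({
    "CreditScore": rng.integers(350, 851, size=n),
    "Age": age,
    "Tenure": rng.integers(0, 11, size=n),
    "EstimatedSalary": rng.uniform(0, 200000, size=n),
})

# Assumption for the demo only: churn probability rises linearly with age.
churn_prob = 0.05 + 0.6 * (age - 18) / 62
demo["Exited"] = (rng.random(n) < churn_prob).astype(int)

corr_matrix = demo.corr()

# Keep only the correlations with the target, strongest positive first.
target_corr = corr_matrix["Exited"].drop("Exited").sort_values(ascending=False)

print(target_corr)
```

Applied to the corr_matrix built from the real data, the same last line ranks the features by their linear association with Exited.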
As the amount of data increases, it gets trickier to analyze and explore. This is where visualizations come in: they are great tools for exploratory data analysis when used efficiently and appropriately. Visualizations also help you deliver a message to your audience or inform them about your findings.
There is no one-size-fits-all visualization method, so different tasks call for different kinds of visualizations. What all visualizations have in common is that they are great tools for exploratory data analysis and for the storytelling part of data science.
Thank you for reading. Please let me know if you have any feedback.