Building Interactive Data Visualizations in Python: An Introduction to Plotly

Discover the power of interactive visualizations for Data Analysis and Machine Learning

Image by Gerd Altmann on Pixabay

Data visualization is one of the most important tasks for data professionals. It helps us understand the data and raises new questions for further investigation.

But data visualization is not confined to the Exploratory Data Analysis phase. We may also need to present the data, often to an audience, to help them draw conclusions.

In Python, we generally use matplotlib and seaborn as libraries to plot our graphs.

However, sometimes we need interactive visualizations: in some cases to understand the data better, in others simply to present our solutions more effectively.

In this article, we’ll talk about Plotly, a Python library for making interactive visualizations.


What is Plotly?

As we can read on their website:

Plotly’s Python graphing library makes interactive, publication-quality graphs. Examples of how to make line plots, scatter plots, area charts, bar charts, error bars, box plots, histograms, heatmaps, subplots, multiple-axes, polar charts, and bubble charts.

Plotly.py is free and open source and you can view the source, report issues or contribute on GitHub.

So, Plotly is a free and open-source Python library for making interactive visualizations.

As we can see on their website, it lets us create plots for different purposes: AI/ML, statistical, scientific, financial, and many more.

Since we’re interested in Machine Learning and Data Science, we’ll show some plots related to this field and how to create them in Python.

Finally, to install it, we have to type:

$ pip install plotly

1. Interactive bubble charts

One useful feature of Plotly is interactive bubble charts.

In bubble charts, the bubbles can sometimes overlap, making the data difficult to read. If the chart is interactive, we can hover over each bubble and read its values more easily.

Let’s see an example:

import plotly.express as px
import pandas as pd
import numpy as np

# Generate random data
np.random.seed(42)
n = 50
x = np.random.rand(n)
y = np.random.rand(n)
z = np.random.rand(n) * 100  # Third variable for bubble size

# Create a DataFrame
data = pd.DataFrame({'X': x, 'Y': y, 'Z': z})

# Create the scatter plot with bubble size with Plotly
fig = px.scatter(data, x='X', y='Y', size='Z',
      title='Interactive Scatter Plot with Bubble Plot')

# Style the text labels (no label text is set in this example, so this has no visible effect)
fig.update_traces(textposition='top center', textfont=dict(size=11))

# Update layout properties
fig.update_layout(
    xaxis_title='X-axis',
    yaxis_title='Y-axis',
    showlegend=False
)

# Display the interactive plot
fig.show()

And we get:

The interactive bubble chart we've coded. Image by Federico Trotta.

So, we’ve created some data with NumPy and stored it in a Pandas data frame. Then, we’ve created the interactive plot with the function px.scatter(), which reads the data from the data frame and takes the title as an argument (unlike Matplotlib, where we typically set the title with a separate call after creating the plot).

2. Interactive correlation matrices

One of the tasks I sometimes struggle with is visualizing correlation matrices properly. They can be hard to read when we have many features.

One way to solve the problem is to use Plotly to create an interactive visualization.

To do so, let’s create a Pandas data frame with 10 columns and build an interactive correlation matrix with Plotly:

import pandas as pd
import numpy as np
import plotly.figure_factory as ff

# Create random data
np.random.seed(42)
data = np.random.rand(100, 10)

# Create DataFrame
columns = ['Column' + str(i+1) for i in range(10)]
df = pd.DataFrame(data, columns=columns)

# Compute the correlation matrix and round values to 2 decimals
correlation_matrix = df.corr().round(2)

# Create interactive correlation matrix with Plotly
figure = ff.create_annotated_heatmap(
    z=correlation_matrix.values,
    x=list(correlation_matrix.columns),
    y=list(correlation_matrix.index),
    colorscale='Viridis',
    showscale=True
)

# Set axis labels
figure.update_layout(
    title='Correlation Matrix',
    xaxis=dict(title='Columns'),
    yaxis=dict(title='Columns')
)

# Display the interactive correlation matrix
figure.show()

And we get:

The interactive correlation matrix we created. Image by Federico Trotta.

So, in a very simple way, we can create an interactive correlation matrix with the function ff.create_annotated_heatmap().

3. Interactive ML plots

In Machine Learning, we sometimes need to compare quantities graphically, and in these cases our plots can be hard to read.

The typical case is a ROC curve where we compare the performance of different ML models. Sometimes the curves intersect and we’re not able to tell them apart properly.

To improve our visualizations, we can use Plotly to plot the ROC curves, with the AUC of each model in the legend, like so:

import pandas as pd
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.svm import SVC
from sklearn.metrics import roc_curve, auc
import plotly.graph_objects as go

# Create synthetic binary classification data
X, y = make_classification(n_samples=1000, n_features=20, random_state=42)

# Split the data into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Scale the data: fit the scaler on the train set only to avoid leaking test-set statistics
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# Initialize the models
knn = KNeighborsClassifier()
rf = RandomForestClassifier()
dt = DecisionTreeClassifier()
svm = SVC(probability=True)

# Fit the models on the train set
knn.fit(X_train, y_train)
rf.fit(X_train, y_train)
dt.fit(X_train, y_train)
svm.fit(X_train, y_train)

# Predict probabilities on the test set
knn_probs = knn.predict_proba(X_test)[:, 1]
rf_probs = rf.predict_proba(X_test)[:, 1]
dt_probs = dt.predict_proba(X_test)[:, 1]
svm_probs = svm.predict_proba(X_test)[:, 1]

# Calculate the false positive rate (FPR) and true positive rate (TPR) for ROC curve
knn_fpr, knn_tpr, _ = roc_curve(y_test, knn_probs)
rf_fpr, rf_tpr, _ = roc_curve(y_test, rf_probs)
dt_fpr, dt_tpr, _ = roc_curve(y_test, dt_probs)
svm_fpr, svm_tpr, _ = roc_curve(y_test, svm_probs)

# Calculate the AUC (Area Under the Curve) for ROC curve
knn_auc = auc(knn_fpr, knn_tpr)
rf_auc = auc(rf_fpr, rf_tpr)
dt_auc = auc(dt_fpr, dt_tpr)
svm_auc = auc(svm_fpr, svm_tpr)

# Create an interactive AUC/ROC curve using Plotly
fig = go.Figure()
fig.add_trace(go.Scatter(x=knn_fpr, y=knn_tpr, name='KNN (AUC = {:.2f})'.format(knn_auc)))
fig.add_trace(go.Scatter(x=rf_fpr, y=rf_tpr, name='Random Forest (AUC = {:.2f})'.format(rf_auc)))
fig.add_trace(go.Scatter(x=dt_fpr, y=dt_tpr, name='Decision Tree (AUC = {:.2f})'.format(dt_auc)))
fig.add_trace(go.Scatter(x=svm_fpr, y=svm_tpr, name='SVM (AUC = {:.2f})'.format(svm_auc)))
fig.update_layout(title='AUC/ROC Curve',
                  xaxis=dict(title='False Positive Rate'),
                  yaxis=dict(title='True Positive Rate'),
                  legend=dict(x=0.7, y=0.2))
# Show plot
fig.show()

And we get:

The AUC/ROC curve we've created. Image by Federico Trotta.

So, with the method fig.add_trace(), passing a go.Scatter object each time, we’ve added one ROC curve for each ML model we’ve used (KNN, SVM, Decision Tree, and Random Forest).

This way, it’s easy to inspect the details and the values in the regions where the curves intersect.


Conclusions

In this article, we’ve given a quick introduction to Plotly and seen how we can use it for better visualizations in data analysis and Machine Learning.

As we can see, it’s a low-code library that helps us visualize our data better and present our results more clearly.


Federico Trotta

Hi, I’m Federico Trotta and I’m a freelance Technical Writer.


