
3 Simple Ways to Create a Waterfall Plot in Python

Learn how to quickly create a presentation-ready plot to aid your data storytelling

Image generated using Midjourney

Waterfall plots (or charts) are frequently used to demonstrate a cumulative change in a certain value over time. Alternatively, they can use fixed categories (for example, certain events) instead of time. As such, this kind of plot can be very useful while delivering presentations to business stakeholders, as we can easily show, for example, the evolution of our company’s revenue/customer base over time.

In this article, I will show you how to easily create waterfall charts in Python. To do so, we will be using 3 different libraries.

Setup and data

As always, we start with importing a few libraries.

import pandas as pd

# plotting
import matplotlib.pyplot as plt
import waterfall_chart
from waterfall_ax import WaterfallChart
import plotly.graph_objects as go

# settings
plt.rcParams["figure.figsize"] = (16, 8)

Then, we prepare fictional data for our toy example. Let’s assume that we are a data scientist in a startup that created some kind of mobile app. In order to prepare for the next all-hands meeting, we were asked to provide a plot showing the user base of our app in 2022. To deliver a complete story, we want to take into account the number of users at the end of 2021 and the monthly count in 2022. To do so, we prepare the following dataframe:

df = pd.DataFrame(
    data={
        "time": ["2021 end", "Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"],
        "users": [100, 120, 110, 150, 160, 190, 240, 200, 230, 240, 250, 280, 300]
    }
)

Approach #1: waterfall_ax

We start with the simplest approach. I must say I was quite surprised to discover that Microsoft developed a small, matplotlib-based library to create waterfall plots. The library is called waterfall_ax and you can read more about it here. To generate a plot using our dataset, we need to run the following:

fig, ax = plt.subplots(1, 1, figsize=(16, 8))
waterfall = WaterfallChart(df["users"].to_list())
wf_ax = waterfall.plot_waterfall(ax=ax, title="# of users in 2022")

One thing to notice about the library is that it works with Python lists and it actually does not support pandas dataframes. That is why we have to use the to_list method while indicating the column with values.

While the plot is definitely presentable, we can do a bit better by including more information and replacing default step names. We do so in the following snippet.

fig, ax = plt.subplots(1, 1, figsize=(16, 8))
waterfall = WaterfallChart(
    df["users"].to_list(),
    step_names=df["time"].to_list(), 
    metric_name="# users", 
    last_step_label="now"
)
wf_ax = waterfall.plot_waterfall(ax=ax, title="# of users in 2022")

Approach #2: waterfall

A slightly more complex approach uses the [waterfall](https://github.com/chrispaulca/waterfall) library. In order to create a plot using that library, we need to add a column containing the deltas, that is, the differences between the steps.

We can easily add a new column to the dataframe and calculate the delta using the diff method. We fill in the NA value in the first row with the number of users from the end of 2021.

df_1 = df.copy()
df_1["delta"] = df_1["users"].diff().fillna(100)
df_1
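As a quick sanity check, the sketch below rebuilds the same delta column and confirms that its cumulative sum gives back the original user counts. It repeats the values from the setup section so that it runs on its own; the names `users`, `delta` and `rebuilt` are only illustrative.

```python
import pandas as pd

# Same user counts as in the setup section (end of 2021, then Jan-Dec 2022).
users = pd.Series(
    [100, 120, 110, 150, 160, 190, 240, 200, 230, 240, 250, 280, 300],
    name="users",
)

# diff() leaves NaN in the first row, so we fill it with the starting value.
delta = users.diff().fillna(users.iloc[0])

# Summing the deltas cumulatively should reproduce the raw values.
rebuilt = delta.cumsum()

print(delta.astype(int).to_list())
print(rebuilt.astype(int).to_list() == users.to_list())
```

Using `users.iloc[0]` instead of a hard-coded 100 is a small design choice: the fill value stays correct if the starting number ever changes.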

Then, we can use the following one-liner to generate the plot:

waterfall_chart.plot(df_1["time"], df_1["delta"])

waterfall also offers the possibility to customize the plot. We do so in the following snippet.

waterfall_chart.plot(
    df_1["time"], 
    df_1["delta"], 
    threshold=0.2, 
    net_label="now", 
    y_lab="# users", 
    Title="# of users in 2022"
);

While most of the additions are quite self-explanatory, it is worth mentioning what the threshold argument does. It is expressed as a fraction of the largest absolute change in the data (in our case that is the initial value of 100, so a threshold of 0.2 corresponds to a cutoff of 20), and it groups together all changes smaller than that cutoff into a new category. By default, that category is called other, but we can customize it with the other_label argument.

In comparison to the previous plot, we can see that the observations with a change of 10 are grouped together: 3 times a +10 and 1 time -10 give a net of +20.
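To make the grouping arithmetic concrete, here is a minimal pure-Python sketch of the idea. It is not the library's code: `group_small_changes` is a hypothetical helper, and it assumes the cutoff is the threshold multiplied by the largest absolute change, which in our data equals the initial value of 100.

```python
labels = ["2021 end", "Jan", "Feb", "Mar", "Apr", "May", "Jun",
          "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"]
deltas = [100, 20, -10, 40, 10, 30, 50, -40, 30, 10, 10, 30, 20]


def group_small_changes(labels, deltas, threshold, other_label="other"):
    """Hypothetical helper: merge all changes below the cutoff into one step.

    Assumption: cutoff = threshold * largest absolute change.
    """
    cutoff = max(abs(d) for d in deltas) * threshold
    kept = [(lab, d) for lab, d in zip(labels, deltas) if abs(d) >= cutoff]
    small = [d for d in deltas if abs(d) < cutoff]
    if small:
        # The merged step carries the net effect of all small changes.
        kept.append((other_label, sum(small)))
    return kept


grouped = group_small_changes(labels, deltas, threshold=0.2)
print(grouped[-1])  # ('other', 20)
```

With a cutoff of 20, only the four changes of magnitude 10 fall below it, which is why the merged step nets out to +20.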

This grouping functionality can be useful when we want to hide quite a lot of individually insignificant values. For example, such grouping logic is used in the shap library when plotting the SHAP values on a waterfall plot.

Approach #3: plotly

While the first two approaches used quite niche libraries, the last one will leverage a library you are surely familiar with: [plotly](https://plotly.com/). Once again, we need to do some preparations on the input dataframe to make it compatible with the plotly approach.

df_2 = df_1.copy()
df_2["delta_text"] = df_2["delta"].astype(str)
df_2["measure"] = ["absolute"] + (["relative"] * 12)
df_2

We created a new column called delta_text which contains the changes encoded as strings. We will use those as labels on the plot. Then, we also defined a measure column, which contains measures used by plotly. There are three types of measures accepted by the library:

  • relative – indicates changes in the sequence,
  • absolute – is used for setting the initial value or resetting the computed total,
  • total – is used for computing sums.
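One way to build intuition for the three measures is to trace the running level by hand. The sketch below is an illustrative helper, not part of plotly; `running_totals` is a made-up name, and it assumes that the value supplied for a "total" step is ignored.

```python
def running_totals(measures, values):
    """Illustrative helper (not plotly code): level reached after each step."""
    total = 0
    levels = []
    for measure, value in zip(measures, values):
        if measure == "absolute":
            total = value   # set or reset the running total
        elif measure == "relative":
            total += value  # apply the change to the running total
        elif measure == "total":
            pass            # display the running total; value assumed ignored
        else:
            raise ValueError(f"unknown measure: {measure!r}")
        levels.append(total)
    return levels


measures = ["absolute", "relative", "relative", "total"]
values = [100, 20, -10, 0]
print(running_totals(measures, values))  # [100, 120, 110, 110]
```

The first step sets the level to 100, the two relative steps move it to 120 and then 110, and the total step simply reports where the sequence ended.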

Having prepared the dataframe, we can create the waterfall plot using the following snippet:

fig = go.Figure(
    go.Waterfall(
        measure=df_2["measure"],
        x=df_2["time"],
        textposition="outside",
        text=df_2["delta_text"],
        y=df_2["delta"],
    )
)

fig.update_layout(
    title="# of users in 2022",
    showlegend=False
)

fig.show()

Naturally, the biggest advantage of using the plotly library is the fact that the plots are fully interactive – we can zoom in, inspect tooltips for additional information (in this case, to see the cumulative sum), etc.

One clear difference from the previous plots is that we do not see the last block showing the net/total. Naturally, we can also add it using plotly. To do so, we must add a new row to the dataframe.

total_row = pd.DataFrame(
    data={
        "time": "now", 
        "users": 0, 
        "delta":0, 
        "delta_text": "", 
        "measure": "total"
    }, 
    index=[0]
)
df_3 = pd.concat([df_2, total_row], ignore_index=True)

As you can see, we do not need to provide concrete values. Instead, we provide the "total" measure, which will be used to calculate the sum. Additionally, we add a "now" label to get the same plot as before.

The code used for generating the plot did not actually change; the only difference is that we are using the dataframe with an additional row.

fig = go.Figure(
    go.Waterfall(
        measure=df_3["measure"],
        x=df_3["time"],
        textposition="outside",
        text=df_3["delta_text"],
        y=df_3["delta"],
    )
)

fig.update_layout(
    title="# of users in 2022",
    showlegend=False
)

fig.show()

You can read more about creating waterfall plots in plotly here.

Wrapping up

  • We showed how to easily and quickly prepare waterfall plots in Python using three different libraries: waterfall_ax, waterfall, and plotly.
  • While creating your plots, it is worth remembering that different libraries use different types of inputs (either raw values or deltas).

As always, any constructive feedback is more than welcome. You can reach out to me on Twitter or in the comments. You can find the code used in this article on GitHub.




All images, unless noted otherwise, are by the author.

