
Let’s Revisit Case-When in Different Libraries Including the New Player: Pandas

How to create conditional columns with different tools.

Photo by JESHOOTS.COM on Unsplash

Whether you’re working on data analysis, data cleaning, or feature engineering, creating a new column based on the values in other columns is a frequent operation.

All the tools I’ve used for data cleaning and manipulation have functions for this task (e.g. SQL, R data.table, PySpark). We now have a new player in the game: Pandas.

It was, of course, already possible to create conditional columns with Pandas, but the library did not have a dedicated case-when function.

With Pandas 2.2.0, the case_when function has been introduced to create a Series object based on one or more conditions.
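Since case_when only exists from version 2.2.0 onward, it’s worth checking which Pandas version you have installed before trying the examples later in this article:

import pandas as pd

# case_when is available only in pandas 2.2.0 and later
print(pd.__version__)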

Let’s revisit how this super helpful operation is done with the commonly used data analysis and manipulation tools.

To keep it consistent and easier to spot differences among tools, we’ll use a small dataset.


SQL

The following is a small SQL table called "mytable".

+-------------+----------+---------+
|           a |        b |       c |
+-------------+----------+---------+
|           0 |        5 |       1 |
|           1 |       -1 |       0 |
|           5 |       20 |       0 |
|           4 |        8 |       1 |
|           4 |        4 |       1 |
|          10 |        7 |       0 |
|           4 |        2 |       1 |
+-------------+----------+---------+

We’ll create a new column based on the values in the existing columns. Here are the conditions:

  • If column a is greater than column b, the new column takes the value in column a
  • If column a is less than column b, it takes the product of columns a and c
  • Otherwise (i.e. column a is equal to column b), it takes the sum of columns a and b

We can create as many conditions as needed and make them a lot more complex, but these three conditions are enough to learn the case-when logic.

Let’s call the new column "d". Here is the SQL code to create this column based on the three conditions above:

select 
  (case 
     when a > b then a
     when a < b then a * c
     else a + b end
  ) as d
from mytable

This SQL code only creates and selects column d from "mytable". If you also need it to return the other columns (i.e. a, b, and c), just write them in the select statement:

select 
  a,
  b,
  c,
  (case 
     when a > b then a
     when a < b then a * c
     else a + b end
  ) as d
from mytable
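The CASE WHEN expression is standard SQL and should run on most databases. If you’d like to try the query without setting up a database server, here is a small sketch (not part of the original example) that loads the sample rows into an in-memory SQLite database from Python and runs it:

import sqlite3

# Build an in-memory SQLite database holding the sample table
conn = sqlite3.connect(":memory:")
conn.execute("create table mytable (a integer, b integer, c integer)")
rows = [(0, 5, 1), (1, -1, 0), (5, 20, 0), (4, 8, 1), (4, 4, 1), (10, 7, 0), (4, 2, 1)]
conn.executemany("insert into mytable values (?, ?, ?)", rows)

# Run the case-when query and print each row of the result
query = """
select
  a, b, c,
  (case
     when a > b then a
     when a < b then a * c
     else a + b end
  ) as d
from mytable
"""
for row in conn.execute(query):
    print(row)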

R data.table

The data.table package is a highly efficient data analysis and manipulation tool for the R programming language.

We’ll now learn how to create the conditional column d using this package. Let’s first create a data table that contains the same columns as in our SQL table.

library(data.table)

mytable <- data.table(
      a=c(0, 1, 5, 4, 4, 10, 4), 
      b=c(5, -1, 20, 8, 4, 7, 2), 
      c=c(1, 0, 0, 1, 1, 0, 1)
)

The case-when logic in data.table can be implemented using the fcase function. We write each condition followed by its corresponding value, separated by commas.

Here is how we can create column d based on the conditions given earlier:

mytable[, d := (fcase(a > b, a, a < b, a*c, a==b, a+b))]

The first expression (a > b) inside the fcase function is the first condition and the second expression (a) is its corresponding value. The third expression (a < b) is the second condition and the fourth expression (a*c) is its corresponding value, and so on.

Now the data.table "mytable" looks as follows:

    a  b c  d
1:  0  5 1  0
2:  1 -1 0  1
3:  5 20 0  0
4:  4  8 1  4
5:  4  4 1  8
6: 10  7 0 10
7:  4  2 1  4

PySpark

PySpark is the Python API for Spark, which is an analytics engine used for large-scale data processing. When it comes to working on datasets with billions of rows, Spark is usually the preferred tool.

The PySpark API is very intuitive and has an easy-to-understand syntax. Let’s first create a Spark DataFrame that contains the same columns and values as before.

from pyspark.sql import SparkSession
from pyspark.sql import functions as F
spark = SparkSession.builder.getOrCreate()

data = [
    {"a": 0, "b": 5, "c": 1},
    {"a": 1, "b": -1, "c": 0},
    {"a": 5, "b": 20, "c": 0},
    {"a": 4, "b": 8, "c": 1},
    {"a": 4, "b": 4, "c": 1},
    {"a": 10, "b": 7, "c": 0},
    {"a": 4, "b": 2, "c": 1}
]

mytable = spark.createDataFrame(data)

We first started a Spark session and then created the DataFrame. Note that if you’re working in an environment like Databricks, you won’t have to create a Spark session explicitly.

We can use the withColumn function to create a new column; to determine its values based on multiple conditions, we’ll use the when function.

mytable = (
    mytable
    .withColumn("d",
                F.when(F.col("a") &gt; F.col("b"), F.col("a"))
                 .when(F.col("a") &lt; F.col("b"), F.col("a") * F.col("c"))
                 .otherwise(F.col("a") + F.col("b")))
)

mytable.show()

+---+---+---+---+
|  a|  b|  c|  d|
+---+---+---+---+
|  0|  5|  1|  0|
|  1| -1|  0|  1|
|  5| 20|  0|  0|
|  4|  8|  1|  4|
|  4|  4|  1|  8|
| 10|  7|  0| 10|
|  4|  2|  1|  4|
+---+---+---+---+

As with the previous tools, we can combine as many conditions as needed. Each condition has its own when call, and the fallback value (i.e. the value used when none of the conditions are met) is specified with the otherwise function.
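As a side note, PySpark also accepts SQL expression strings via F.expr, so the same column could be written with a SQL-style CASE WHEN. A minimal sketch, equivalent to the chained when calls above:

# Equivalent formulation using a SQL expression string
mytable = mytable.withColumn(
    "d",
    F.expr("case when a > b then a when a < b then a * c else a + b end")
)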


Pandas

Pandas might be the most frequently used data analysis and manipulation tool. Before version 2.2.0, Pandas did not have a case-when function; we could still handle these operations with other functions, such as Pandas’ where or NumPy’s where and select. Now that Pandas has a case_when function, let’s see how it handles the task we’ve been demonstrating with the other tools in this article.

Let’s first create the Pandas DataFrame.

import pandas as pd

mytable = pd.DataFrame(
    {
        "a": [0, 1, 5, 4, 4, 10, 4],
        "b": [5, -1, 20, 8, 4, 7, 2],
        "c": [1, 0, 0, 1, 1, 0, 1]
    }
)
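Before looking at case_when itself, here is a minimal sketch of the pre-2.2.0 approach mentioned above, using NumPy’s select function on this DataFrame (the column name d_np is just for illustration):

import numpy as np

# Conditions are evaluated in order; np.select takes the value from the first match
conditions = [
    mytable["a"] > mytable["b"],
    mytable["a"] < mytable["b"],
    mytable["a"] == mytable["b"]
]
choices = [
    mytable["a"],
    mytable["a"] * mytable["c"],
    mytable["a"] + mytable["b"]
]

mytable["d_np"] = np.select(conditions, choices)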

The case_when function takes a caselist argument, which is a list of tuples containing the conditions and their corresponding values.

caselist = [
    (mytable["a"] &gt; mytable["b"], mytable["a"]),
    (mytable["a"] &lt; mytable["b"], mytable["a"] * mytable["c"]),
    (mytable["a"] == mytable["b"], mytable["a"] + mytable["b"])
]

mytable.loc[:, "d"] = mytable["a"].case_when(caselist)

mytable

    a  b  c  d
0   0  5  1  0
1   1 -1  0  1
2   5 20  0  0
3   4  8  1  4
4   4  4  1  8
5  10  7  0 10
6   4  2  1  4

Each tuple in the caselist contains a condition and its corresponding value. It’s important to note that the case_when function is applied to a column (a Series) in a DataFrame. If there are any rows for which none of the given conditions are met, then those rows in the new column take the value from the original column.

In the example above, if there were such rows, the corresponding value in column d would be taken from column a. Let’s try it by removing the final condition from the caselist.

caselist = [
    (mytable["a"] &gt; mytable["b"], mytable["a"]),
    (mytable["a"] &lt; mytable["b"], mytable["a"] * mytable["c"])
]

mytable.loc[:, "d"] = mytable["a"].case_when(caselist)

mytable

    a  b  c  d
0   0  5  1  0
1   1 -1  0  1
2   5 20  0  0
3   4  8  1  4
4   4  4  1  4
5  10  7  0 10
6   4  2  1  4

Look at the row with index 4, where column a equals column b. Since none of the conditions are met, the value in column d is taken from column a.
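If you want a fallback other than the original column’s values, one option (a small sketch, not from the example above) is to call case_when on a Series filled with the desired default; its values are used wherever no condition matches:

# Use -999 (an arbitrary placeholder) wherever no condition in the caselist matches
default = pd.Series(-999, index=mytable.index)
mytable["d"] = default.case_when(caselist)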


Final words

There are a lot of tools and libraries for data cleaning, analysis, and manipulation. In most cases, which tool you’re using is a matter of choice. All these tools are capable of doing the typical tasks efficiently. However, it’s still good to know how to do certain operations using different tools.

In this article, we learned how to create conditional columns using SQL, R data.table, PySpark, and Pandas.

Thank you for reading. Please let me know if you have any feedback.

