
Heterogeneous Treatment Effect Using Double Machine Learning

Causal Inference using Double Machine Learning

Photo by Markus Winkler on Unsplash

This article is a continuation of the Causal Inference series linked below. I strongly recommend going through the articles below first if you are new to the causal inference space.

  1. Getting started with Causal Inference
  2. Methods for Inferring Causality
  3. Heterogeneous Treatment Effect using Meta learners

Double Machine Learning (DML) is a method for estimating heterogeneous treatment effects, especially when we have a large number of confounders. Unlike the meta-learners we discussed in the previous article, it also works with continuous treatments. The method was introduced by Chernozhukov et al. in the paper Double/Debiased Machine Learning for Treatment and Structural Parameters.

DML is based on a cool regression result, known as the Frisch-Waugh-Lovell theorem, which states that if you have two sets of features X₁ and X₂, and you estimate the model parameters β₁ and β₂ using the linear regression Y = β₁X₁ + β₂X₂, then the same parameters β₁ can be obtained by following these steps:

  1. Regress Y on the second set of features: Y₁ᵣ = γ₁X₂
  2. Regress the first set of features on the second set of features: X₁ᵣ = γ₂X₂
  3. Obtain the residuals: Yᵣ = Y − Y₁ᵣ, Xᵣ = X₁ − X₁ᵣ
  4. Regress the outcome residuals on the feature residuals: Yᵣ = α + β₁Xᵣ

We get the same β₁ as we would get by regressing all features together. Pretty cool, huh? The first set of features (X₁ above) can be our treatment, and the second set can be the confounders, whose impact on the outcome is estimated separately.
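
To see this in action, here is a minimal sketch on simulated data with scikit-learn; the data-generating process and variable names are made up purely for illustration.

```python
# A minimal sketch of the result above on simulated data: the coefficient on X1
# from the full regression equals the one from the residual-on-residual regression.
import numpy as np
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(0)
n = 10_000
X2 = rng.normal(size=(n, 3))                                   # second set of features
X1 = X2 @ np.array([0.5, -0.2, 0.1]) + rng.normal(size=n)      # first set, correlated with X2
Y = 2.0 * X1 + X2 @ np.array([1.0, 0.5, -1.0]) + rng.normal(size=n)

# Regressing everything together
beta1_full = LinearRegression().fit(np.column_stack([X1, X2]), Y).coef_[0]

# Steps 1-3: regress Y and X1 on X2, keep the residuals
Y_res = Y - LinearRegression().fit(X2, Y).predict(X2)
X1_res = X1 - LinearRegression().fit(X2, X1).predict(X2)

# Step 4: regress the outcome residuals on the feature residuals
beta1_res = LinearRegression().fit(X1_res.reshape(-1, 1), Y_res).coef_[0]

print(beta1_full, beta1_res)   # both are ~2.0 and match each other
```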

Double/Debiased Machine Learning

DML uses this very simple idea, but instead of linear regression it uses ML models to estimate the outcome and treatment residuals, so that we can deal with non-linearities and interactions.

Yᵢ − Mᵧ(Xᵢ) = τ(Tᵢ − Mₜ(Xᵢ)) + εᵢ

Mᵧ(Xᵢ) is the prediction of the outcome obtained by regressing Y on X, and Mₜ(Xᵢ) is the prediction of the treatment obtained by regressing T on X. τ is the impact of the treatment on the outcome after controlling for confounders, but here it is a constant and not a function of any covariate; that is, it represents the ATE (Average Treatment Effect, discussed at length in the previous articles Getting started with Causal Inference and Methods for Inferring Causality).

We can also estimate the CATE, where τ is allowed to vary with the unit’s covariates:

Yᵢ − Mᵧ(Xᵢ) = τ(Xᵢ)(Tᵢ − Mₜ(Xᵢ)) + εᵢ

To estimate the CATE using DML, we again use the residuals of the treatment and the outcome, but now we interact the treatment residuals with the other covariates.

Yᵣ = β₁ + β₂XᵢTᵢᵣ
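
Here is a minimal sketch of this final stage with a single effect modifier x (think temperature); Y_res and T_res are assumed to be the residuals from the previous step, and the regression also includes the treatment residual on its own, which is common in practice.

```python
# A minimal sketch of the linear CATE final stage: regress the outcome residuals
# on the treatment residuals and their interaction with a covariate x.
import numpy as np
from sklearn.linear_model import LinearRegression

def linear_cate(Y_res, T_res, x):
    design = np.column_stack([T_res, x * T_res])   # [T_res, x * T_res]
    model = LinearRegression().fit(design, Y_res)
    b1, b2 = model.coef_
    return b1 + b2 * x                             # tau(x) = b1 + b2 * x, per unit
```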

One issue that DML can suffer from is overfitting. If the Mᵧ model overfits, the residual Yᵣ will be smaller than it should be: the model captures more than the relationship between X and Y, and part of that "something more" is the relationship between T and Y, which biases our causal estimate (the residual regression) towards zero. If Mₜ overfits, the treatment residual will have less variance than it should; it is as if the treatment were the same for everyone, making it very difficult to estimate what would happen under different treatment regimes. It is therefore recommended to use cross-fitting, i.e. out-of-fold (cross-validated) predictions, for the outcome and treatment models, as sketched below.
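
A minimal sketch of cross-fitted residuals, assuming generic scikit-learn regressors for the two nuisance models (the choice of GradientBoostingRegressor is illustrative):

```python
# Each unit's prediction comes from models trained on the other folds, so the
# nuisance models cannot "memorize" the unit they produce residuals for.
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.model_selection import cross_val_predict

def cross_fitted_residuals(X, T, Y, n_folds=5):
    y_hat = cross_val_predict(GradientBoostingRegressor(), X, Y, cv=n_folds)
    t_hat = cross_val_predict(GradientBoostingRegressor(), X, T, cv=n_folds)
    return Y - y_hat, T - t_hat   # outcome residuals, treatment residuals
```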

Let’s code this up and understand it through an example. We will use synthetic sales data for a juice brand. Our goal is to estimate the price elasticity of sales and how it varies with temperature.
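
The original synthetic dataset is not reproduced here, so the sketch below builds a hypothetical stand-in with a similar structure: price is the continuous treatment, sales is the outcome, and temperature, weekday, and cost are covariates. The true price effect is made to depend on temperature so there is a heterogeneous effect to recover.

```python
# Hypothetical stand-in for the juice-sales data (purely illustrative numbers).
import numpy as np
import pandas as pd

rng = np.random.default_rng(42)
n = 5_000
temp = rng.uniform(15, 35, n)                                   # daily temperature
weekday = rng.integers(0, 7, n)
cost = rng.uniform(0.5, 1.5, n)
price = 3 + 0.5 * cost + 0.03 * temp + rng.normal(0, 0.3, n)    # price depends on confounders
price_effect = -4 + 0.1 * temp                                  # hotter days are less price sensitive
sales = 50 + price_effect * price + 0.8 * temp + 2 * (weekday >= 5) + rng.normal(0, 5, n)

df = pd.DataFrame({"sales": sales, "price": price, "temp": temp,
                   "weekday": weekday, "cost": cost})
```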

Instead of developing our own DML function, let us now use the DML estimators in the EconML library, developed by Microsoft. I strongly recommend going through the EconML documentation on DML, as they have several versions of DML implemented in the library.
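
A sketch using EconML's LinearDML on the stand-in dataframe above; treating temp as the effect modifier (X) and weekday/cost as controls (W), as well as the nuisance-model choices, are illustrative assumptions.

```python
from econml.dml import LinearDML
from sklearn.ensemble import GradientBoostingRegressor

Y = df["sales"].values
T = df["price"].values                 # continuous treatment
X = df[["temp"]].values                # covariates the CATE may vary with
W = df[["weekday", "cost"]].values     # additional confounders to control for

est = LinearDML(model_y=GradientBoostingRegressor(),
                model_t=GradientBoostingRegressor(),
                cv=5, random_state=0)  # cv controls the cross-fitting folds
est.fit(Y, T, X=X, W=W)

cate = est.effect(X)                   # effect of a one-unit price increase per row
print(est.summary())                   # coefficients of the linear CATE model
```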

Non-Parametric DML

In the above setup, we are using a linear model on the residuals, which means we are assuming that the effect of treatment on the outcome is linear, and this assumption might not always hold.

We have seen this equation

Yᵢ − Mᵧ(Xᵢ) = τ(Xᵢ)(Tᵢ − Mₜ(Xᵢ)) + εᵢ

If we rearrange, we get the error term as

εᵢ = (Yᵢ − Mᵧ(Xᵢ)) − τ(Xᵢ)(Tᵢ − Mₜ(Xᵢ))

Squaring and averaging this error term gives us what we call the causal loss function, which we can minimize to obtain τ(Xᵢ), our CATE.

Lₙ(τ(x)) = 1/n Σ ((Yᵢ − Mᵧ(Xᵢ)) − τ(Xᵢ)(Tᵢ − Mₜ(Xᵢ)))²

Writing this in terms of the residuals, we get:

Lₙ(τ(x)) = 1/n Σ (Yᵢᵣ − τ(Xᵢ)Tᵢᵣ)²

Factoring out Tᵢᵣ and isolating τ(Xᵢ):

Lₙ(τ(x)) = 1/n Σ Tᵢᵣ² (Yᵢᵣ/Tᵢᵣ − τ(Xᵢ))²

We can minimize the above loss by minimizing what’s inside the parentheses, using Tᵢᵣ² as weights. Here are the steps:

  1. Create the weights Tᵢᵣ²
  2. Create the transformed target Yᵢᵣ/Tᵢᵣ
  3. Use any ML model to predict Yᵢᵣ/Tᵢᵣ from the covariates Xᵢ, using Tᵢᵣ² as sample weights (see the sketch below).
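
Here is a minimal sketch of these three steps, assuming cross-fitted residuals Y_res and T_res and a covariate matrix X are already available; any regressor that accepts sample weights can play the role of the final model.

```python
from sklearn.ensemble import GradientBoostingRegressor

def nonparam_cate(X, Y_res, T_res):
    weights = T_res ** 2                            # step 1
    target = Y_res / T_res                          # step 2 (assumes T_res is never exactly zero)
    model = GradientBoostingRegressor()
    model.fit(X, target, sample_weight=weights)     # step 3: weighted regression
    return model                                    # model.predict(X_new) estimates tau(x)
```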

Let’s estimate the CATE for the above example using non-parametric DML.
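
A sketch with EconML's NonParamDML, which runs this weighted final stage internally; it reuses the hypothetical dataframe from earlier, and the model choices are again illustrative.

```python
from econml.dml import NonParamDML
from sklearn.ensemble import GradientBoostingRegressor

est_np = NonParamDML(model_y=GradientBoostingRegressor(),
                     model_t=GradientBoostingRegressor(),
                     model_final=GradientBoostingRegressor(),  # must accept sample weights
                     cv=5, random_state=0)
est_np.fit(df["sales"].values, df["price"].values,
           X=df[["temp"]].values, W=df[["weekday", "cost"]].values)

cate_np = est_np.effect(df[["temp"]].values)   # non-parametric CATE estimates
```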

DML is commonly used for price elasticity problems in industry, as it works well with continuous treatments. Below are a few case studies where Double Machine Learning has been used to solve industry-relevant problems.

Case studies

Customer Segmentation at an online media company using DML

Pricing with DML approach to Causal Inference

Other customer scenarios for Machine Learning based Heterogeneous Treatment Effect

