
This article is a continuation of the Causal Inference series linked below. I strongly recommend going through the articles below first if you are new to the causal inference space.
1. Getting started with Causal Inference
Double Machine Learning (DML) is a method for estimating heterogeneous treatment effects, especially when we have a large number of confounders. It also works with continuous treatments, unlike the meta-learners we discussed in the previous article. The method was introduced by Chernozhukov et al. in the paper Double/Debiased Machine Learning for Treatment and Structural Parameters.
DML is based on a cool regression result (the Frisch-Waugh-Lovell theorem): if you have two sets of features X₁ and X₂, and you estimate the model parameters β₁ and β₂ using the linear regression Y = β₁X₁ + β₂X₂, then the same parameter β₁ can be obtained by following these steps:
- Regress Y on the second set of features to get the predictions Y₁ᵣ = γ₁X₂
- Regress the first set of features on the second set to get the predictions X₁ᵣ = γ₂X₂
- Obtain the residuals Yᵣ = Y - Y₁ᵣ and Xᵣ = X₁ - X₁ᵣ
- Regress the outcome residuals on the feature residuals: Yᵣ = α + β₁Xᵣ
We get the same β₁ as we would get by regressing on all features together. Pretty cool, huh? The first set of features (X₁ above) can be our treatment, and the second set can be the confounders, whose impact on the outcome is accounted for separately.
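To see this in action, here is a minimal sketch using scikit-learn on simulated data; the data-generating process and coefficient values below are made up purely for illustration.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Simulated data: x1 is the feature of interest, x2 plays the role of a confounder
rng = np.random.default_rng(42)
n = 10_000
x2 = rng.normal(size=(n, 1))
x1 = 0.5 * x2 + rng.normal(size=(n, 1))           # x1 is correlated with x2
y = (2.0 * x1 + 3.0 * x2 + rng.normal(size=(n, 1))).ravel()

# Step 1 & 2: regress y and x1 on x2, keep only the residuals
y_res = y - LinearRegression().fit(x2, y).predict(x2)
x1_res = x1 - LinearRegression().fit(x2, x1).predict(x2)

# Step 3: regress the outcome residuals on the feature residuals
beta1 = LinearRegression().fit(x1_res, y_res).coef_[0]

# Compare with the coefficient on x1 from the full regression y ~ x1 + x2
beta1_full = LinearRegression().fit(np.hstack([x1, x2]), y).coef_[0]
print(beta1, beta1_full)   # both are approximately 2.0
```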
Double/Debiased Machine Learning
DML uses the simple idea explained above, but instead of linear regression it uses ML models to obtain the outcome and treatment residuals, so that we can deal with non-linearities and interactions.
Yᵢ - Mᵧ(Xᵢ) = τ(Tᵢ - Mₜ(Xᵢ)) + εᵢ
Mᵧ(Xᵢ) is the prediction obtained by regressing Y on X, and Mₜ(Xᵢ) is the prediction obtained by regressing T on X. τ is the impact of the treatment on the outcome after controlling for confounders, but it is a constant, not a function of any covariate; that is, it represents the ATE (the average treatment effect, discussed at length in the previous articles Getting started with Causal Inference and Methods for Inferring Causality).
We can also estimate CATE, where τ is allowed to vary with the unit’s covariates
Yᵢ - Mᵧ(Xᵢ) = τ(Xᵢ)(Tᵢ - Mₜ(Xᵢ)) + εᵢ
To estimate the CATE using DML, we again use the residuals of the treatment and the outcome, but now we interact the treatment residuals with the other covariates.
Yᵢᵣ = β₁Tᵢᵣ + β₂XᵢTᵢᵣ + εᵢ
so that the estimated CATE is τ(Xᵢ) = β₁ + β₂Xᵢ.
One issue DML can suffer from is overfitting. If the Mᵧ model overfits, the residual Yᵣ will be smaller than it should be, capturing more than just the relationship between X and Y; part of that extra signal is the relationship between T and Y, which makes our causal estimate (the residual regression) biased towards zero. If Mₜ overfits, the treatment residual will have less variance than it should, as if the treatment were the same for everyone, making it very difficult to estimate what would happen under different treatment regimes. It is therefore recommended to use cross-validation (out-of-fold predictions, also known as cross-fitting) for the outcome and treatment models.
Let's code this up and understand it through an example. We will be using synthetic sales data for a juice brand. Our goal is to estimate the price elasticity of sales and how it varies with temperature.
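Below is a sketch of DML by hand. The juice-sales dataframe and its column names (price as the treatment, sales as the outcome, and temp, weekday, and cost as confounders) are assumptions made for illustration, and gradient boosting is just one possible choice for the Mᵧ and Mₜ models. Note the use of out-of-fold predictions (cross_val_predict) to follow the cross-fitting recommendation above.

```python
import pandas as pd
import statsmodels.formula.api as smf
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.model_selection import cross_val_predict

# Hypothetical synthetic juice-sales data: price is the treatment, sales is the
# outcome, and temp/weekday/cost play the role of the confounders X.
df = pd.read_csv("juice_sales.csv")  # assumed file name
X_cols = ["temp", "weekday", "cost"]

# Step 1: cross-fitted (out-of-fold) predictions for the nuisance models My and Mt
m_y = cross_val_predict(GradientBoostingRegressor(), df[X_cols], df["sales"], cv=5)
m_t = cross_val_predict(GradientBoostingRegressor(), df[X_cols], df["price"], cv=5)

# Step 2: residualise the outcome and the treatment
df = df.assign(sales_res=df["sales"] - m_y, price_res=df["price"] - m_t)

# Step 3: final stage, interacting the treatment residual with temperature so the
# price effect (elasticity) is allowed to vary with temp:
#   sales_res = b1 * price_res + b2 * temp * price_res + e
final = smf.ols("sales_res ~ price_res + price_res:temp", data=df).fit()
print(final.summary())

# The estimated CATE is tau(temp) = b1 + b2 * temp
```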


Instead of developing our own DML function, let us now use the DML estimators in the EconML library, developed by Microsoft. I strongly recommend going through the EconML documentation on DML, as it has several versions of DML implemented.
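For example, here is a minimal sketch with EconML's LinearDML estimator, continuing with the same (assumed) dataframe as above; the nuisance models and column choices are again illustrative.

```python
from econml.dml import LinearDML
from sklearn.ensemble import GradientBoostingRegressor

# Linear final stage: the CATE is modelled as a linear function of the features in X
est = LinearDML(
    model_y=GradientBoostingRegressor(),  # My: outcome model
    model_t=GradientBoostingRegressor(),  # Mt: treatment model
    cv=5,                                 # number of cross-fitting folds
    random_state=0,
)

# X: features the effect is allowed to vary with; W: the remaining confounders
est.fit(df["sales"], df["price"], X=df[["temp"]], W=df[["weekday", "cost"]])

cate = est.effect(df[["temp"]])  # per-unit price-elasticity estimates
print(est.summary())             # inference on the final-stage coefficients
```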

Non-Parametric DML
In the above setup, we are using a linear model on the residuals, which means we are assuming that the effect of treatment on the outcome is linear, and this assumption might not always hold.
We have seen this equation
Yᵢ - Mᵧ(Xᵢ) = τ(Xᵢ)(Tᵢ - Mₜ(Xᵢ)) + εᵢ
If we rearrange, we get the error term as
εᵢ = (Yᵢ - Mᵧ(Xᵢ)) - τ(Xᵢ)(Tᵢ - Mₜ(Xᵢ))
Squaring and averaging this error gives us the causal loss function, which we can minimize to obtain τ(Xᵢ), our CATE.
Lₙ(τ(x)) = 1/n Σᵢ ((Yᵢ - Mᵧ(Xᵢ)) - τ(Xᵢ)(Tᵢ - Mₜ(Xᵢ)))²
Using the residual terms, we get
Lₙ(τ(x)) = 1/n Σᵢ (Yᵢᵣ - τ(Xᵢ)Tᵢᵣ)²
Factoring out Tᵢᵣ and isolating τ(Xᵢ), we get
Lₙ(τ(x)) = 1/n Σᵢ Tᵢᵣ² (Yᵢᵣ/Tᵢᵣ - τ(Xᵢ))²
We can minimize the above loss by minimizing what is inside the parentheses while using Tᵢᵣ² as weights. Here are the steps (a code sketch follows the list):
- Create Tᵢᵣ²
- Create Yᵢᵣ/Tᵢᵣ
- Use any ML model that accepts sample weights to predict Yᵢᵣ/Tᵢᵣ from Xᵢ, using Tᵢᵣ² as the weights.
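Continuing with the residuals computed earlier, here is a minimal sketch of this weighted-regression trick; the model choice is illustrative.

```python
from sklearn.ensemble import GradientBoostingRegressor

# Weighted-target trick: predict Yr/Tr with weights Tr², so the fitted function
# is an estimate of tau(X), the CATE.
weights = df["price_res"] ** 2                  # Tr²
y_star = df["sales_res"] / df["price_res"]      # Yr / Tr (assumes no zero residuals)

cate_model = GradientBoostingRegressor()
cate_model.fit(df[["temp"]], y_star, sample_weight=weights)

df["cate_hat"] = cate_model.predict(df[["temp"]])  # tau(temp) for each unit
```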
Let's now estimate the CATE for the above example using non-parametric DML.
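EconML packages these steps up in its NonParamDML estimator, where the final stage can be any ML model that accepts sample weights. A sketch, again with illustrative model and column choices:

```python
from econml.dml import NonParamDML
from sklearn.ensemble import GradientBoostingRegressor

# Non-parametric DML: the final CATE model is itself an ML model, trained
# internally on the weighted target described above.
est = NonParamDML(
    model_y=GradientBoostingRegressor(),      # My: outcome model
    model_t=GradientBoostingRegressor(),      # Mt: treatment model
    model_final=GradientBoostingRegressor(),  # flexible final-stage model for tau(X)
    cv=5,
    random_state=0,
)

est.fit(df["sales"], df["price"], X=df[["temp"]], W=df[["weekday", "cost"]])
cate = est.effect(df[["temp"]])  # non-parametric elasticity estimates
```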

DML is commonly used for price-elasticity problems in industry, as it works well with continuous treatments. Below are a few case studies where Double Machine Learning has been used to solve industry-relevant problems.
Case studies
Customer Segmentation at an online media company using DML
Pricing with DML approach to Causal Inference
Other customer scenarios for Machine Learning based Heterogeneous Treatment Effect