Double Machine Learning, Simplified: Part 2 — Targeting & the CATE

Learn how to utilize DML for estimating idiosyncratic treatment effects to enable data-driven targeting

Jacob Pieniazek
10 min read · Jul 31, 2023


This article is the second in a two-part series on simplifying and democratizing Double Machine Learning. In part 1, we covered the fundamentals of Double Machine Learning, along with two basic causal inference applications. Now, in part 2, we will extend this knowledge to turn our causal inference problem into a prediction task, wherein we predict individual-level treatment effects to aid decision making and data-driven targeting.

Double Machine Learning, as we learned in part 1 of this series, is a highly flexible, partially linear causal inference method for estimating the average treatment effect (ATE) of a treatment. Specifically, it can be utilized to model highly non-linear confounding relationships in observational data (especially when our set of controls/confounders is of extremely high dimensionality) and/or to reduce the variation in our key outcome in experimental settings. Estimating the ATE is particularly useful for understanding the average impact of a specific treatment, which can inform future decision making. However, extrapolating this treatment effect assumes a degree of homogeneity in the effect; that is, regardless of the population we roll the treatment out to, we anticipate the effect to be similar to the ATE. What if we are limited in the number of individuals we can target for future rollout, and thus want to understand among which subpopulations the treatment was most effective in order to drive a highly effective rollout?

The issue described above concerns estimating treatment effect heterogeneity; that is, how does our treatment impact different subsets of the population? Luckily for us, DML provides a powerful framework to do exactly this. Specifically, we can make use of DML to estimate the Conditional Average Treatment Effect (CATE). First, let's revisit our definition of the ATE:

ATE = E[y(1) − y(0)]
(1) Average Treatment Effect

Now with the CATE, we estimate the ATE conditional on a set of values for our covariates, X:

CATE = E[y(1) − y(0) | X]
(2) Conditional Average Treatment Effect

For example, if we wanted to know the treatment effect for males versus females, we could estimate the CATE conditional on the covariate being equal to each subgroup of interest. Note that we can estimate highly aggregated CATEs (i.e., at the male vs. female level) or we can allow X to take on an extremely high dimensionality and thus closely estimate each individual's treatment effect. You may immediately notice the benefit of being able to do this: we can utilize this information to make highly informed decisions in future targeting of the treatment! Even more notably, we can create a CATE function to predict the treatment effect for previously unexposed individuals!
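
Concretely, in the notation of eq. (2), the two subgroup CATEs in this example are simply:

CATE(male) = E[y(1) − y(0) | gender = male]
CATE(female) = E[y(1) − y(0) | gender = female]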

DML provides two primary methodologies for estimating this CATE function; namely, Linear DML and Non-Parametric DML. We will show how to estimate the CATE mathematically and then provide examples for each case.

Note: Unbiased estimation of the CATE still requires the exogeneity/CIA/Ignorability assumption to hold as covered in part 1.

Everything demonstrated below can and should be extended to the experimental setting (RCT or A/B Testing), where exogeneity is satisfied by construction, as covered in application 2 of part 1.

Linear DML for Estimating the CATE

Estimating the CATE in the linear DML framework is a simple extension of DML for estimating the ATE, as done in part 1:

y − M_y(X) = β₀ + β₁(T − M_T(X)) + ε
(3) DML for Estimating the ATE

where y is our outcome, T is our treatment, & M_y and M_T are both flexible ML models (our nuisance functions) used to predict y and T, respectively, given our confounders and/or controls, X. To estimate the CATE function using Linear DML, we can simply include interaction terms between the treatment residuals and our covariates. Observe:

y − M_y(X) = β₀ + β₁(T − M_T(X)) + (T − M_T(X))·XΩ + ε
(4) Linear DML for Estimating the CATE

where Ω is the vector of coefficients on the interaction terms. Now our CATE function, call it τ, takes the form τ(X) = β₁ + XΩ, where we can predict each individual's CATE given X. If T is continuous, this CATE function corresponds to a 1-unit increase in T. Note that in eq. (3), τ(X) = β₁; that is, τ(X) is assumed constant. Let's take a look at this in action!

First, let's use the same causal DAG from part 1, where we will be looking at the effect of an individual's time spent on the website on their purchase amount, or sales, in the past month (assuming we observe all confounders):

Let's then simulate this DGP using a process similar to the one utilized in part 1 (note that all values & data are chosen and generated arbitrarily for demonstrative purposes, and thus are not meant to carry much real-world intuition outside of our estimates of the CATE). Observe that we now include interaction terms in the sales DGP to model the CATE, or treatment effect heterogeneity (note that the DGP in part 1 had no treatment effect heterogeneity by construction):

import numpy as np
import pandas as pd

# Sample Size
N = 100_000

# Confounders (X)
age = np.random.randint(low=18,high=75,size=N)
num_social_media_profiles = np.random.choice([0,1,2,3,4,5,6,7,8,9,10], size = N)
yr_membership = np.random.choice([0,1,2,3,4,5,6,7,8,9,10], size = N)

# Arbitrary Covariates (Z)
Z = np.random.normal(loc=50, scale = 25, size = N)

# Error Terms
ε1 = np.random.normal(loc=20,scale=5,size=N)
ε2 = np.random.normal(loc=40,scale=15,size=N)

# Treatment (T = g(X) + ε1)
time_on_website = np.maximum(10
- 0.01*age
- 0.001*age**2
+ num_social_media_profiles
- 0.01 * num_social_media_profiles**2
- 0.01*(age * num_social_media_profiles)
+ 0.2 * yr_membership
+ 0.001 * yr_membership**2
- 0.01 * (age * yr_membership)
+ 0.2 * (num_social_media_profiles * yr_membership)
+ 0.01 * (num_social_media_profiles * np.log(age) * age * yr_membership**(1/2))
+ ε1
,0)

# Outcome (y = f(T,X,Z) + ε2)
sales = np.maximum(25
+ 5 * time_on_website # Baseline Treatment Effect
- 0.2 * time_on_website * age # Heterogeneity
+ 2 * time_on_website * num_social_media_profiles # Heterogeneity
+ 2 * time_on_website * yr_membership # Heterogeneity
- 0.1*age
- 0.001*age**2
+ 8 * num_social_media_profiles
- 0.1 * num_social_media_profiles**2
- 0.01*(age * num_social_media_profiles)
+ 2 * yr_membership
+ 0.1 * yr_membership**2
- 0.01 * (age * yr_membership)
+ 3 * (num_social_media_profiles * yr_membership)
+ 0.1 * (num_social_media_profiles * np.log(age) * age * yr_membership**(1/2))
+ 0.5 * Z
+ ε2
,0)

df = pd.DataFrame(np.array([sales,time_on_website,age,num_social_media_profiles,yr_membership,Z]).T
,columns=["sales","time_on_website","age","num_social_media_profiles","yr_membership","Z"])

Now, to estimate our CATE function, as outlined in eq. (4), we can run:

import statsmodels.formula.api as smf
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.model_selection import cross_val_predict

# DML Procedure for Estimating the CATE
M_sales = GradientBoostingRegressor()
M_time_on_website = GradientBoostingRegressor()

df['residualized_sales'] = df["sales"] - cross_val_predict(M_sales, df[["age","num_social_media_profiles","yr_membership"]], df['sales'], cv=3)
df['residualized_time_on_website'] = df['time_on_website'] - cross_val_predict(M_time_on_website, df[["age","num_social_media_profiles","yr_membership"]], df['time_on_website'], cv=3)

DML_model = smf.ols(formula='residualized_sales ~ 1 + residualized_time_on_website + residualized_time_on_website:age + residualized_time_on_website:num_social_media_profiles + residualized_time_on_website:yr_membership', data = df).fit()
print(DML_model.summary())

With the following results:

Here we can see that linear DML closely recovered the true DGP for the CATE (see the coefficients on the interaction terms in the sales DGP). Let's evaluate the performance of our CATE function by comparing the linear DML predictions to the true CATE for a 1-hour increase in time spent on the website:

from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error

# Predict CATE of 1 hour increase
df_predictions = df[['residualized_time_on_website','age','num_social_media_profiles','yr_membership']].copy()
df_predictions['linear_DML_CATE']= (DML_model.predict(df_predictions.assign(residualized_time_on_website= lambda x : x.residualized_time_on_website + 1))
- DML_model.predict(df_predictions))

# True CATE of 1 hour increase
df_predictions['true_CATE'] = 5 - 0.2 * df_predictions['age'] + 2 * df_predictions['num_social_media_profiles'] + 2 * df_predictions['yr_membership']

# Performance Metrics
print(f"MSE: {mean_squared_error(df_predictions['true_CATE'], df_predictions['linear_DML_CATE'])}")
print(f"MAE: {mean_absolute_error(df_predictions['true_CATE'], df_predictions['linear_DML_CATE'])}")
print(f"R2: {r2_score(df_predictions['true_CATE'], df_predictions['linear_DML_CATE'])}")

Here we obtain an MSE of ~0.45, an MAE of ~0.55, & an R² of ~0.99. Plotting the distributions of the predicted CATE and true CATE, we obtain:

Additionally, plotting the predicted values versus the true values we obtain:
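
For reference, both plots can be reproduced along the following lines (a minimal matplotlib sketch; binning and styling choices are illustrative):

import matplotlib.pyplot as plt

# Left: distributions of predicted vs. true CATE; Right: predicted vs. true values
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
ax1.hist(df_predictions['true_CATE'], bins=50, alpha=0.5, label='True CATE')
ax1.hist(df_predictions['linear_DML_CATE'], bins=50, alpha=0.5, label='Linear DML CATE')
ax1.set_xlabel('CATE (1 hour increase)')
ax1.legend()

ax2.scatter(df_predictions['true_CATE'], df_predictions['linear_DML_CATE'], s=2, alpha=0.3)
ax2.set_xlabel('True CATE')
ax2.set_ylabel('Predicted CATE')
plt.show()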

Overall, we have pretty impressive performance! However, the primary limitation of this approach is that we must manually specify the functional form of the CATE function; thus, if we only include linear interaction terms, we may not capture the true CATE function. In our example, we simulated the DGP to have only these linear interaction terms, and thus the performance is strong by construction. Let's see what happens when we tweak the DGP for the CATE to be arbitrarily non-linear:

# Outcome (y = f(T,X,Z) + ε2)
sales = np.maximum( 25
+ 5 * time_on_website # Baseline Treatment Effect
- 0.2 * time_on_website * age # Heterogeneity
- 0.0005 * time_on_website * age**2 # Heterogeneity
+ 0.8 * time_on_website * num_social_media_profiles # Heterogeneity
+ 0.001 * time_on_website * num_social_media_profiles**2 # Heterogeneity
+ 0.8 * time_on_website * yr_membership # Heterogeneity
+ 0.001 * time_on_website * yr_membership**2 # Heterogeneity
+ 0.005 * time_on_website * yr_membership * num_social_media_profiles * age # Heterogeneity
+ 0.005 * time_on_website * (yr_membership**3 / (1 + num_social_media_profiles**2)) * np.log(age) ** 2 # Heterogeneity
- 0.1*age
- 0.001*age**2
+ 8 * num_social_media_profiles
- 0.1 * num_social_media_profiles**2
- 0.01*(age * num_social_media_profiles)
+ 2 * yr_membership
+ 0.1 * yr_membership**2
- 0.01 * (age * yr_membership)
+ 3 * (num_social_media_profiles * yr_membership)
+ 0.1 * (num_social_media_profiles * np.log(age) * age * yr_membership**(1/2))
+ 0.5 * Z
+ ε2
,0)

df = pd.DataFrame(np.array([sales,time_on_website,age,num_social_media_profiles,yr_membership,Z]).T
,columns=["sales","time_on_website","age","num_social_media_profiles","yr_membership","Z"])

# DML Procedure
M_sales = GradientBoostingRegressor()
M_time_on_website = GradientBoostingRegressor()

df['residualized_sales'] = df["sales"] - cross_val_predict(M_sales, df[["age","num_social_media_profiles","yr_membership"]], df['sales'], cv=3)
df['residualized_time_on_website'] = df['time_on_website'] - cross_val_predict(M_time_on_website, df[["age","num_social_media_profiles","yr_membership"]], df['time_on_website'], cv=3)

DML_model = smf.ols(formula='residualized_sales ~ 1 + residualized_time_on_website + residualized_time_on_website:age + residualized_time_on_website:num_social_media_profiles + residualized_time_on_website:yr_membership', data = df).fit()

# Predict CATE of 1 hour increase
df_predictions = df[['residualized_time_on_website','age','num_social_media_profiles','yr_membership']].copy()
df_predictions['linear_DML_CATE']= (DML_model.predict(df_predictions.assign(residualized_time_on_website= lambda x : x.residualized_time_on_website + 1))
- DML_model.predict(df_predictions))

# True CATE of 1 hour increase
df_predictions['true_CATE'] = (5 - 0.2*df_predictions['age'] - 0.0005*df_predictions['age']**2 + 0.8*df_predictions['num_social_media_profiles'] + 0.001*df_predictions['num_social_media_profiles']**2
+ 0.8*df_predictions['yr_membership'] + 0.001*df_predictions['yr_membership']**2 + 0.005*df_predictions['yr_membership']*df_predictions['num_social_media_profiles']*df_predictions['age']
+ 0.005 * (df_predictions['yr_membership']**3 / (1 + df_predictions['num_social_media_profiles']**2)) * np.log(df_predictions['age'])**2)

# Performance Metrics
print(f"MSE: {mean_squared_error(df_predictions['true_CATE'], df_predictions['linear_DML_CATE'])}")
print(f"MAE: {mean_absolute_error(df_predictions['true_CATE'], df_predictions['linear_DML_CATE'])}")
print(f"R2: {r2_score(df_predictions['true_CATE'], df_predictions['linear_DML_CATE'])}")

Here we see a stark decrease in performance, obtaining an MSE of ~55.92, an MAE of ~4.50, & an R² of ~0.65. Plotting the distributions of the predicted CATE and true CATE, we obtain:

Additionally, plotting the predicted values versus the true values we obtain:

This non-linearity in the CATE function is precisely where Non-Parametric DML can shine!

Non-Parametric DML for Estimating the CATE

Non-Parametric DML goes one step further and allows another flexible non-parametric ML model to be utilized for learning the CATE function! Let's take a look at how, mathematically, we can do exactly this. Let τ(X) continue to denote our CATE function, and let ỹ = y − M_y(X) and T̃ = T − M_T(X) denote our residualized outcome and treatment. Let's start by defining our error term relative to eq. (3) (note that we drop the intercept β₀, as we are not interested in this parameter for the CATE; we could similarly drop it in the linear DML formulation, but for the sake of simplicity and consistency with part 1, we do not):

ε = ỹ − τ(X)·T̃
(5) Error in DML Framework

Then define the causal loss function as such (note this is just the MSE!):

L(τ) = (1/N) Σᵢ (ỹᵢ − τ(Xᵢ)·T̃ᵢ)²
(6) Causal Loss Function

What does this mean? We can directly learn τ(X) with any flexible ML model via minimizing our causal loss function! This amounts to a weighted regression problem with our target and weights, respectively, as:

L(τ) = (1/N) Σᵢ T̃ᵢ² · (ỹᵢ/T̃ᵢ − τ(Xᵢ))²   ⟹   targetᵢ = ỹᵢ/T̃ᵢ,   weightᵢ = T̃ᵢ²
(7) Target & Weights in Non-Parametric DML

Take a moment and soak in the elegance of this result… We can directly learn the CATE function & predict an individual's CATE given only our residualized outcome, ỹ, and treatment, T̃!
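
As a quick numerical sanity check of the factorization in eq. (7) (with arbitrary illustrative values):

import numpy as np

# (ỹ − τ·T̃)² equals T̃²·(ỹ/T̃ − τ)² for any nonzero T̃
y_res, T_res, tau = 2.0, 0.5, 1.2
assert np.isclose((y_res - tau * T_res) ** 2, T_res ** 2 * (y_res / T_res - tau) ** 2)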

Let's take a look at this in action now. We will reuse the DGP for the non-linear CATE function from the example above where linear DML performed poorly. To construct our Non-Parametric DML model, we can run:

# Define Target & Weights
df['target'] = df['residualized_sales'] / df['residualized_time_on_website']
df['weights'] = df['residualized_time_on_website']**2

# Non-Parametric CATE Model
CATE_model = GradientBoostingRegressor()
CATE_model.fit(df[["age","num_social_media_profiles","yr_membership"]], df['target'], sample_weight=df['weights'])

And to make predictions + evaluate performance:

# CATE Predictions
df_predictions['Non_Parametric_DML_CATE'] = CATE_model.predict(df[["age","num_social_media_profiles","yr_membership"]])

# Performance Metrics
print(f"MSE: {mean_squared_error(df_predictions['true_CATE'], df_predictions['Non_Parametric_DML_CATE'])}")
print(f"MAE: {mean_absolute_error(df_predictions['true_CATE'], df_predictions['Non_Parametric_DML_CATE'])}")
print(f"R2: {r2_score(df_predictions['true_CATE'], df_predictions['Non_Parametric_DML_CATE'])}")

Here we obtain much superior performance over linear DML, with an MSE of 4.61, an MAE of 1.37, & an R² of 0.97. Plotting the distributions of the predicted CATE and true CATE, we obtain:

Additionally, plotting the predicted values versus the true values we obtain:

Here we can see that, although not perfect, the non-parametric DML approach was able to model the non-linearities in the CATE function far better than the linear DML approach. We can, of course, further improve performance by tuning our model. Note that we can also use explainable AI tools, such as SHAP values, to understand the nature of our treatment effect heterogeneity!
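
For instance, a minimal sketch using the shap library (assuming it is installed; the exact plot choice is illustrative):

import shap

# Explain the non-parametric CATE model to surface the drivers of heterogeneity
X_cols = ["age", "num_social_media_profiles", "yr_membership"]
explainer = shap.TreeExplainer(CATE_model)
shap_values = explainer.shap_values(df[X_cols])
shap.summary_plot(shap_values, df[X_cols])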

Conclusion

And there you have it! Thank you for taking the time to read through my article. I hope it has taught you how to go beyond estimating only the ATE & utilize DML to estimate the CATE, enabling you to understand heterogeneity in treatment effects and drive more causal inference- & data-driven targeting schemes.

As always, I hope you have enjoyed reading this as much as I enjoyed writing it!

Resources

[1] V. Chernozhukov, D. Chetverikov, M. Demirer, E. Duflo, C. Hansen, and W. Newey. Double Machine Learning for Treatment and Causal Parameters. ArXiv e-prints, July 2016.

Access all the code via this GitHub Repo: https://github.com/jakepenzak/Blog-Posts

I appreciate you reading my post! My posts on Medium seek to explore real-world and theoretical applications utilizing econometric and statistical/machine learning techniques. Additionally, I seek to provide posts on the theoretical underpinnings of various methodologies via theory and simulations. Most importantly, I write to learn and help others learn! I hope to make complex topics slightly more accessible to all. If you enjoyed this post, please consider following me on Medium!
