
Retain Customers with Time to Event Modeling-Driven Intervention

Pinpointing customer churn at the right time using XGBoost

Charles Frenzel
Towards Data Science
Jun 3, 2021


By Charles Frenzel, Baichuan Sun and Yin Song

It usually costs more to acquire a customer than it does to retain a customer.

Focusing on customer retention enables companies to maximize customer revenue over their lifetime.

This blog post will show you how to train a model to predict both the risk AND the time of a customer attrition event using XGBoost.

Combined with a production-level, end-to-end Machine Learning pipeline such as the Customer Churn Pipeline on AWS, which has time-to-event churn modeling built in, such a model enables timely interventions to stop customer attrition.

Image by Author

Customer attrition, or simply churn, is when a customer ‘leaves’ before maximal revenue is achieved. Preventing these events to retain revenue is such an established goal that churn models are often among the very first machine learning solutions a company puts into production.

However, these models are seldom optimal because they rely on binary classification flags (churn: yes or no). Churn classification models do not tell you WHEN a customer is likely to leave; they only indicate that it will happen within a certain number of days or months.

This blog post shows how to measure churn risk over a customer’s lifecycle to find the point in time at which a churn intervention is needed.

Paradigm Shift: from Event Only to Timely Intervention

One reason churn models could be better is that many rely on an arbitrary time threshold for a fixed binary outcome. This means that time is held constant! For example, a model might assume that a customer has churned after 40 days of inactivity.

Flagging customers based on such heuristics leads to slippage, mainly:

  • Customers churning before the threshold.
  • Customers churning far, far after the threshold.
  • Differences in customer lifetimes being ignored.

It is probably a mistake to treat a customer who is at risk of leaving within 40 days the same as a customer who remains for more than 100 days. Traditional churn modeling does not make this distinction.

For example, in the chart below, Customer B is captured accurately by the model because they leave at exactly the point in time at which the threshold was set (40 days). Customer A churns after the threshold and is lost because the model cannot account for them. Customer C does the opposite, staying on far longer than the data time window. They will most likely churn eventually, but a classification model cannot tell us when.

Image by Author

The only point in time here is the “within 40 days” threshold. Because the model does not account for time, we have no clear idea when a marketing intervention is needed, which leads to preventable customer attrition.

Re-framing the Problem to Know When

Rather than using a binary classifier, we are going to re-frame the problem as a time-dependent one. This enables us to intervene at the right time to stop customer attrition before it happens. Instead of relying on thresholds, we now treat churn as a continuous, time-conditioned event. As the graph below shows, we now know when attrition risk is highest.

Image by Author

Time is no longer held constant: we now track risk over time to determine when a marketing intervention is needed to retain the customer. If we model both the time and the event, the right moment to intervene and prevent attrition becomes apparent. A modeling technique called Survival Analysis lets us do this, and with modern Machine Learning libraries it is now straightforward. A deep dive into Survival Analysis and the math behind it is out of scope; we encourage you to read the many great posts on Towards Data Science for more information.

A Quick Glimpse at the Data

In this example, you will use a synthetic churn dataset for an imaginary telecommunications company with the outcome ‘Churn?’ flagged as either True (churned) or False (did not churn). Features include customer details such as plan and usage information. The churn dataset is publicly available and mentioned in the book Discovering Knowledge in Data by Daniel T. Larose. It is attributed by the author to the University of California Irvine Repository of Machine Learning Datasets. The notebook for all the code is located here.

import numpy as np
import pandas as pd

df = pd.read_csv("../../data/churn.txt")

# denote the churn event (1 = churned, 0 = still active) and the duration
df["event"] = np.where(df["churn?"] == "False.", 0, 1)
df = df.rename(columns={"account_length": "duration"})

df = df.drop(columns=["churn?"])

df = df.dropna()
df = df.drop_duplicates()
df.head()
Image by Author

Examining the target further shows that there are 5,000 records in total, of which 49.96% churned, so the dataset is balanced on the target. Real-world data is often not like this: a churn event could make up 1% of millions of records. There are strategies to remedy that, but they are out of scope for this blog post.

Looking at the duration, represented by Account Length (our time component), the median is 102 days, close to the mean of about 101.7 days.

print("Total Records:",df.shape[0],"\n")
print("Percent Churn Rate:",df.event.mean())
print("")
print("Duration Intervals")
print(df['duration'].describe())
Total Records: 5000

Percent Churn Rate: 0.4996

Duration Intervals
count 5000.0000
mean 101.6758
std 57.5968
min 1.0000
25% 52.0000
50% 102.0000
75% 151.0000
max 200.0000

Data for survival models differs from a traditional classification problem and requires:

  • A censor: for our purposes, these are customers who have yet to churn. Read about right censoring here.
  • Duration: the duration, or time t, of the customer’s activity. In this case, it’s Account Length in days.
  • Event: the binary target, in this case whether the customer terminated their phone plan, marked by Churn?.

We can plot the first 10 customers on a timeline to understand how right-censored data works and how the problem is framed.

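from lifelines.plotting import plot_lifetimes

# durations first, then the event indicator (event_observed)
ax = plot_lifetimes(df.head(10)['duration'], df.head(10)['event'])

_ = ax.set_xlabel("Duration: Account Length (days)")
_ = ax.set_ylabel("Customer Number")
_ = ax.set_title("Observed Customer Attrition")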
Image By Author

In the plot above, the red lines indicate customers who have left, with the dots marking the specific point in time. Blue lines are customers who are still active up to the time measured on the x-axis (Duration).

Here we see that customer number 8 did not churn until day 195, while customers 0 and 4 left after 163 and 146 days respectively. All other customers are still active.

Notice how all customers are placed on the same time scale because the data is analytically aligned. Each customer might have joined at a different time, but we count each customer’s days from the start of their own account. This is what allows us to right-censor the data on the churn event. Real-world data needs both censoring and aligning before modeling can begin.
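To make the censoring and aligning step concrete, here is a minimal sketch for a hypothetical raw customer table with `signup_date` and `cancel_date` columns (empty for active customers) and an assumed observation cutoff date. These column names and the cutoff are illustrative only; they are not part of the churn dataset used in this post.

```python
import pandas as pd


def to_survival_format(raw: pd.DataFrame, cutoff: str) -> pd.DataFrame:
    """Align customers to their own start date and right-censor at `cutoff`.

    Assumes hypothetical columns `signup_date` and `cancel_date`
    (missing for customers who are still active).
    """
    cutoff_ts = pd.Timestamp(cutoff)
    start = pd.to_datetime(raw["signup_date"])
    end = pd.to_datetime(raw["cancel_date"])

    # churn only counts as observed if it happened on or before the cutoff
    event = end.notna() & (end <= cutoff_ts)
    # observed churners stop at their cancel date; everyone else is censored at the cutoff
    stop = end.where(event, cutoff_ts)

    return pd.DataFrame(
        {
            "duration": (stop - start).dt.days,  # days since each customer's own start
            "event": event.astype(int),          # 1 = churned, 0 = right censored
        },
        index=raw.index,
    )


raw = pd.DataFrame(
    {
        "signup_date": ["2021-01-01", "2021-02-15", "2021-03-01"],
        "cancel_date": ["2021-03-12", None, "2021-07-01"],
    }
)
print(to_survival_format(raw, cutoff="2021-06-01"))
```

Note that the third customer cancels after the cutoff, so at analysis time they are still treated as active (censored at the cutoff).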

The Risk of Churn

A more informative approach is to estimate the Survival Function S(t): the probability that a customer has not yet churned by day t. For this purpose, we will use the Kaplan-Meier estimator, which is defined as:

$$\hat{S}(t) = \prod_{t_i \le t} \left(1 - \frac{d_i}{n_i}\right)$$

Source: Lifelines

where d_i is the number of churn events at time t_i and n_i is the number of customers at risk of churn just prior to time t_i.
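To see what the estimator computes, here is a from-scratch NumPy sketch. It is for intuition only (lifelines’ KaplanMeierFitter, used below, is what this post relies on), and the function names and toy data are illustrative. It also reads off the median survival time as the first time the curve drops to 0.5 or below, which is the usual definition.

```python
import numpy as np


def kaplan_meier(durations, events):
    """Return the distinct event times and the Kaplan-Meier estimate S(t) at each."""
    durations = np.asarray(durations, dtype=float)
    events = np.asarray(events, dtype=bool)

    times = np.unique(durations[events])  # distinct times at which churn was observed
    survival = []
    s = 1.0
    for t in times:
        n_i = np.sum(durations >= t)             # customers still at risk just before t
        d_i = np.sum((durations == t) & events)  # churn events at t
        s *= 1.0 - d_i / n_i
        survival.append(s)
    return times, np.array(survival)


def median_survival_time(times, survival):
    """First time at which the estimated survival drops to 0.5 or below (inf if never)."""
    below = np.nonzero(survival <= 0.5)[0]
    return times[below[0]] if below.size else np.inf


durations = [5, 8, 8, 12, 15, 20]
events = [1, 1, 0, 1, 0, 1]  # the second customer at t=8 is censored
times, surv = kaplan_meier(durations, events)
print(times)
print(surv)
print(median_survival_time(times, surv))
```

The censored customers never cause a drop in the curve, but they shrink the at-risk count n_i for later event times, which is how censoring enters the estimate.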

We will use the excellent Python package lifelines to plot the Survival Function, as this function is a component of the final churn model.

import matplotlib.pyplot as plt
from lifelines import KaplanMeierFitter

kmf = KaplanMeierFitter()

kmf.fit(df['duration'], event_observed=df['event'])

kmf.plot_survival_function()
_ = plt.title('Survival Function for Telco Churn')
_ = plt.xlabel("Duration: Account Length (days)")
_ = plt.ylabel("Survival Probability (Share Not Yet Churned)")
# red dashed line at the median survival time
_ = plt.axvline(x=kmf.median_survival_time_, color='r', linestyle='--')
Image By Author

Let’s look at the median survival time: the point by which half of the customers have churned. In this graph it is marked by the red dotted line; by about 152 days, half of the customers have churned. This is helpful because it gives an overall baseline for when intervention is needed. However, for each individual customer it is uninformative.

What is missing is the point in time at which churn risk is highest for each customer.

For that, we will create a model using Cox’s Proportional Hazards, which models the hazard as h(t | x) = h0(t) * exp(f(x)), where f(x) is a log-risk function of the customer’s features. The hazard is conditioned on the customer remaining until time t or later, which allows us to estimate the risk of churn over time. This will enable us to score each customer and anticipate when a marketing intervention is needed. However, before we proceed, we need to preprocess the data.

Data Splitting and Preprocessing

First, we split the data into training and test sets. For this example, we’ll also use the test set for validation. In practice, you want all three splits (training, validation and test) so that you don’t tune to the validation set.

Next, we take the numeric and categorical features and preprocess them for downstream modeling. For categorical features, we first impute missing values with a constant and then one-hot encode them. For numeric features, we fill missing values with the median and then standardize them to zero mean and unit variance. This is all wrapped in scikit-learn’s Pipeline and ColumnTransformer for simplicity’s sake.

In the Churn Pipeline, all these steps are included, and the fitted preprocessor is saved for use at inference time.

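from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler

SEED = 123  # assumed value; the notebook's SEED is not shown

df_train, df_test = train_test_split(df, test_size=0.20, random_state=SEED)

# numeric features: all non-object/category columns except the two targets
numerical_idx = (
    df_train.select_dtypes(exclude=["object", "category"])
    .drop(columns=["event", "duration"])
    .columns.tolist()
)

# categorical features: object/category columns
categorical_idx = df_train.select_dtypes(include=["object", "category"]).columns.tolist()

numeric_transformer = Pipeline(
    steps=[
        ("imputer", SimpleImputer(strategy="median")),
        ("scaler", StandardScaler()),
    ]
)

categorical_transformer = Pipeline(
    steps=[
        ("imputer", SimpleImputer(strategy="constant", fill_value="missing")),
        # `sparse_output` replaced the older `sparse` argument in scikit-learn 1.2
        ("onehot", OneHotEncoder(sparse_output=False, handle_unknown="ignore")),
    ]
)

preprocessor = ColumnTransformer(
    [
        ("numerical", numeric_transformer, numerical_idx),
        ("categorical", categorical_transformer, categorical_idx),
    ],
    remainder="passthrough",
)

train_features = preprocessor.fit_transform(df_train.drop(columns=["event", "duration"]))
test_features = preprocessor.transform(df_test.drop(columns=["event", "duration"]))

# feature names for the DMatrix and the SHAP plots
feature_names = preprocessor.get_feature_names_out().tolist()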

Converting the Target for XGBoost

We will use XGBoost’s DMatrix format with its native (non-scikit-learn) API. The survival:cox objective requires a transformation of the target: the duration becomes the label, kept positive for observed churn events and made negative for right-censored customers. Rather than a tuple of event and duration, or a binary outcome, this gives a single signed continuous target. For more about how survival analysis works in XGBoost, see this tutorial.

import numpy as np
import pandas as pd
import xgboost as xgb


def survival_y_cox(dframe: pd.DataFrame) -> np.ndarray:
    """Returns an array of outcomes encoded for XGBoost's survival:cox objective."""
    y_survival = []

    for _, row in dframe[["duration", "event"]].iterrows():
        if row["event"]:
            # uncensored: churn observed at `duration`
            y_survival.append(int(row["duration"]))
        else:
            # right censored: still active at `duration`, encoded as a negative label
            y_survival.append(-int(row["duration"]))
    return np.array(y_survival)


dm_train = xgb.DMatrix(
    train_features, label=survival_y_cox(df_train), feature_names=feature_names
)

dm_test = xgb.DMatrix(
    test_features, label=survival_y_cox(df_test), feature_names=feature_names
)
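Because survival_y_cox visits rows one at a time, a vectorized equivalent can be handy on larger tables. This is a sketch rather than the notebook’s code; it assumes the same `duration` and `event` columns and, like the loop version, that durations are strictly positive (a duration of 0 could not carry the censoring sign).

```python
import numpy as np
import pandas as pd


def survival_y_cox_vectorized(dframe: pd.DataFrame) -> np.ndarray:
    """Signed survival label for XGBoost's survival:cox objective.

    Positive duration = churn observed; negative duration = right censored.
    """
    duration = dframe["duration"].to_numpy(dtype=int)
    observed = dframe["event"].to_numpy(dtype=bool)
    return np.where(observed, duration, -duration)


toy = pd.DataFrame({"duration": [163, 128, 146], "event": [1, 0, 1]})
print(survival_y_cox_vectorized(toy))
```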

More on the Hazard Function

The hazard function provides customer attrition risk — telling us when churn is most likely to happen.

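Whereas the Survival function S(t) = P(T > t) returns the probability that a customer has not churned by time t, the Hazard function h(t) gives the instantaneous rate of churn at time t for a customer who has stayed up to t (multiplied by a small interval δt, it approximates the probability of churning within that interval):

$$h(t) = \lim_{\delta t \to 0} \frac{P(t \le T \le t + \delta t \mid T > t)}{\delta t}$$

Source: Lifelines

On a side note, the survival function can also be recovered from the Hazard function, because:

$$S(t) = \exp\left(-\int_0^t h(z)\, dz\right)$$

Source: Lifelines Documentation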
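The identity between hazard and survival can be checked numerically. The sketch below assumes an illustrative Weibull hazard (a common parametric choice, not something fitted in this post; the shape and scale values are arbitrary), integrates it with the trapezoidal rule to get the cumulative hazard H(t), and compares exp(-H(t)) with the closed-form Weibull survival function.

```python
import numpy as np

# Illustrative Weibull hazard; shape k and scale lam are arbitrary, not fitted values
k, lam = 1.5, 150.0
t = np.linspace(0.0, 200.0, 2001)  # days
hazard = (k / lam) * (t / lam) ** (k - 1)

# cumulative hazard H(t): integral of h from 0 to t via the trapezoidal rule
increments = 0.5 * (hazard[1:] + hazard[:-1]) * np.diff(t)
cum_hazard = np.concatenate([[0.0], np.cumsum(increments)])

s_numeric = np.exp(-cum_hazard)       # S(t) = exp(-H(t))
s_closed = np.exp(-((t / lam) ** k))  # closed-form Weibull survival function

print(f"S(100): numeric={s_numeric[1000]:.4f}, closed form={s_closed[1000]:.4f}")
```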

Gradient Boosting and Cox’s Partial Likelihood

In Gradient Boosting, multiple base learners g (here, trees with parameters θ_m and weights β_m) are combined into a boosted ensemble defined as an additive model:

$$f(\mathbf{x}) = \sum_{m=1}^{M} \beta_m \, g(\mathbf{x}; \theta_m)$$

Source: Scikit-Survival Documentation

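For the survival analysis case, the objective is to maximize the log partial likelihood, replacing the traditional linear model f(x) with the additive model:

$$\arg\max_{f} \sum_{i=1}^{n} \delta_i \left[ f(\mathbf{x}_i) - \log\left( \sum_{j \in \mathcal{R}_i} \exp\big(f(\mathbf{x}_j)\big) \right) \right]$$

where δ_i is 1 if customer i churned and 0 if censored, and R_i is the set of customers still at risk at customer i’s churn time.

Source: Scikit-Survival Documentation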

For more on this, see this excellent tutorial for the Python library scikit-survival.
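The objective above can be written out directly. Below is a small NumPy sketch of the negative log partial likelihood for a vector of log-risk scores f(x_i), using the Breslow convention for ties (everyone with a duration of at least t_i is in the risk set). It is meant for intuition: XGBoost computes its cox-nloglik metric internally, and we do not claim this function reproduces those exact numbers.

```python
import numpy as np


def cox_neg_log_partial_likelihood(scores, durations, events):
    """Negative log partial likelihood of log-risk scores f(x_i).

    Ties use the Breslow convention: the risk set R_i holds everyone whose
    duration is >= t_i. Only customers with an observed churn contribute.
    """
    scores = np.asarray(scores, dtype=float)
    durations = np.asarray(durations, dtype=float)
    events = np.asarray(events, dtype=bool)

    total = 0.0
    for i in np.nonzero(events)[0]:
        at_risk = durations >= durations[i]  # risk set R_i
        log_denominator = np.log(np.sum(np.exp(scores[at_risk])))
        total += scores[i] - log_denominator  # f(x_i) - log sum_{j in R_i} exp f(x_j)
    return -total


durations = [5, 8, 12, 15, 20]
events = [1, 1, 0, 1, 0]
flat = np.zeros(5)                                   # every customer equally risky
informative = np.array([2.0, 1.0, 0.0, -1.0, -2.0])  # earlier churners scored riskier

print(cox_neg_log_partial_likelihood(flat, durations, events))
print(cox_neg_log_partial_likelihood(informative, durations, events))
```

Scores that rank earlier churners as riskier give a lower loss than uninformative scores, and adding the same constant to every score leaves the loss unchanged, since only relative risk matters in the partial likelihood.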

All the usual parameters apply here, except that we’ve changed the objective to survival:cox, which trains boosted survival trees.

params = {
"eta": 0.1,
"max_depth": 3,
"objective": "survival:cox",
"tree_method": "hist",
"subsample": 0.8,
"seed": 123
}
bst = xgb.train(
params,
dm_train,
num_boost_round=300,
evals=[(dm_train, "train"), (dm_test, "test")],
verbose_eval=int(1e1),
early_stopping_rounds=10
)
[0] train-cox-nloglik:7.25501 test-cox-nloglik:5.86755
...
[151] train-cox-nloglik:6.67063 test-cox-nloglik:5.39344

A Note On Predictions

Predictions for this model are returned on the hazard ratio scale (i.e., as HR = exp(marginal_prediction) in the proportional hazard function h(t) = h0(t) * HR). This means the output can be used either as the exponentiated hazard ratio or as the non-exponentiated margin (the log hazard ratio). For predicting when churn is most likely to occur, we want the exponentiated version, since it reads intuitively as a risk score (even though, technically speaking, it is not a probability). For more on how the output is produced, see the XGBoost documentation for the survival:cox objective.
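The relationship is easy to see with numbers. In the sketch below, the margins and the baseline hazard are invented for illustration. For a trained Booster, predict returns the hazard ratio by default under survival:cox, while predict(..., output_margin=True) returns the untransformed margin, so exponentiating the margin gives the default prediction. Because the baseline hazard h0(t) cancels, the ratio of two customers’ hazards is the same at every t; that is the “proportional” in proportional hazards.

```python
import numpy as np

# Hypothetical margins (log hazard ratios) for two customers, i.e. the kind of
# values bst.predict(dm_test, output_margin=True) returns
margin = np.array([0.4, -0.3])
hazard_ratio = np.exp(margin)  # hazard-ratio scale, the default survival:cox output

# Made-up baseline hazard h0(t) over 200 days
t = np.arange(1, 201)
baseline_hazard = 0.002 + 0.00001 * t

# Proportional hazards: h(t | x) = h0(t) * HR(x)
hazard = baseline_hazard[None, :] * hazard_ratio[:, None]

ratio_over_time = hazard[0] / hazard[1]
print(hazard_ratio)
print(ratio_over_time.min(), ratio_over_time.max())  # equal: exp(0.4 - (-0.3))
```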

Examining Global Predictions

We can then take the hazard scores, the churn risk conditioned on time t (in this case Account Length), and look at when churn is most likely to occur overall.

Bucketed Churn Risk (Image by Author)

Bucketing the values into time periods shows that churn risk is highest at days 53 to 62, followed by days 80 to 102. In practice, you should ignore the final bar (days 191 to 200), as this is the point of truncation.
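A bucketed view like this can be produced with pandas. The sketch below is one plausible approach rather than the notebook’s code: it assumes a `duration` column and a `preds` column holding log hazard ratios (so they are exponentiated first), uses fixed 10-day buckets, and runs on synthetic demo data in place of the real predictions.

```python
import numpy as np
import pandas as pd


def bucket_risk(df: pd.DataFrame, width: int = 10) -> pd.Series:
    """Mean hazard ratio per `width`-day bucket of account length.

    Assumes a `duration` column (days) and a `preds` column of log hazard
    ratios; both column names are assumptions for this sketch.
    """
    edges = np.arange(0, df["duration"].max() + width, width)
    buckets = pd.cut(df["duration"], bins=edges)  # right-closed: (0, 10], (10, 20], ...
    return np.exp(df["preds"]).groupby(buckets, observed=True).mean()


rng = np.random.default_rng(0)
demo = pd.DataFrame(
    {
        "duration": rng.integers(1, 201, size=1000),  # synthetic account lengths
        "preds": rng.normal(size=1000),               # synthetic log hazard ratios
    }
)
risk = bucket_risk(demo)
print(risk.sort_values(ascending=False).head(3))  # the three riskiest buckets
```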

Evaluating Performance

For survival models, Harrell’s Concordance Index and the Brier Score are commonly used for evaluation.

Harrell’s Concordance Index (Harrell et al. 1982)

The concordance index or C-index is a generalization of the area under the ROC curve (AUC) that can take into account censored data.

It can be thought of as a goodness-of-fit measure for models that produce risk scores: it measures the model’s ability to correctly rank survival times based on individual risk scores.

A value of C = 0.5 indicates that the risk scores are no better than a coin flip, while C = 1 indicates a perfect ranking.

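This is expressed as the number of concordant pairs / (number of concordant pairs + number of discordant pairs), or:

$$C = \frac{\sum_{i,j} \mathbf{1}_{T_j < T_i} \cdot \mathbf{1}_{\eta_j > \eta_i} \cdot \delta_j}{\sum_{i,j} \mathbf{1}_{T_j < T_i} \cdot \delta_j}$$

where T is the observed duration, η the predicted risk score and δ the event indicator.

Source: PySurvival Documentation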
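The formula translates almost line for line into code. This O(n²) sketch follows it directly: a pair is comparable when customer j has an observed churn and T_j < T_i, and concordant when j also has the higher risk score. The formula above does not say how to treat tied risk scores; the sketch counts them as half a concordant pair, which is a common convention.

```python
import numpy as np


def harrell_c_index(durations, events, risk_scores):
    """Harrell's C: share of comparable pairs ranked correctly by risk score."""
    t = np.asarray(durations, dtype=float)
    e = np.asarray(events, dtype=bool)
    r = np.asarray(risk_scores, dtype=float)

    concordant = 0.0
    comparable = 0
    for j in np.nonzero(e)[0]:  # j must have an observed churn (delta_j = 1)
        later = t > t[j]        # customers i with T_i > T_j
        comparable += int(later.sum())
        concordant += float(np.sum(r[j] > r[later]))
        concordant += 0.5 * float(np.sum(r[j] == r[later]))  # ties count as half
    return concordant / comparable


durations = [5, 8, 12, 15, 20]
events = [1, 1, 0, 1, 0]
print(harrell_c_index(durations, events, [5, 4, 3, 2, 1]))  # earlier churners riskier
print(harrell_c_index(durations, events, [1, 2, 3, 4, 5]))  # ranking reversed
print(harrell_c_index(durations, events, [0, 0, 0, 0, 0]))  # no information
```

The three calls correspond to a perfect ranking (C = 1), a fully reversed one (C = 0) and a constant score (C = 0.5, the coin flip).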

Brier Score (Brier 1950)

The Brier score evaluates the accuracy of a predicted survival function at a given time t. It is the average squared distance between the observed survival status and the predicted survival probability, and it is always a number between 0 and 1, with 0 being the best possible value.

Source: PySurvival Documentation
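Without censoring, the score is simply a mean squared error between the observed status at t (1 if the customer is still active, 0 otherwise) and the predicted survival probability. The sketch below assumes every churn time is fully observed, which is exactly the assumption the weighting discussed next removes. It also shows that a constant 0.5 forecast scores exactly 0.25.

```python
import numpy as np


def brier_score_uncensored(durations, surv_prob_at_t, t):
    """Brier score at time t when every customer's churn time is observed.

    surv_prob_at_t[i] is the predicted probability that customer i is still
    active at time t. No censoring weights are applied.
    """
    still_active = (np.asarray(durations, dtype=float) > t).astype(float)
    p = np.asarray(surv_prob_at_t, dtype=float)
    return float(np.mean((still_active - p) ** 2))


durations = [30, 60, 90, 120]
t = 75
print(brier_score_uncensored(durations, [0.0, 0.0, 1.0, 1.0], t))  # perfect forecast
print(brier_score_uncensored(durations, [0.5, 0.5, 0.5, 0.5], t))  # constant 0.5
```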

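However, if the dataset contains right-censored samples, the score must be adjusted by weighting. That’s where scikit-survival’s Brier score metric comes to the rescue. As stated in its documentation, the time-dependent Brier score is the mean squared error at time point t, weighted by inverse probability of censoring weights:

$$BS^c(t) = \frac{1}{n} \sum_{i=1}^{n} \left[ \frac{I(y_i \le t \land \delta_i = 1)\,\big(0 - \hat{\pi}(t \mid \mathbf{x}_i)\big)^2}{\hat{G}(y_i)} + \frac{I(y_i > t)\,\big(1 - \hat{\pi}(t \mid \mathbf{x}_i)\big)^2}{\hat{G}(t)} \right]$$

where π̂(t | x) is the predicted probability of remaining event-free up to t and Ĝ is the Kaplan-Meier estimate of the censoring distribution.

Source: Scikit-Survival Documentation

This means the measurement is adjusted for right-censored data and is therefore more reliable.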

Let’s score the model!

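from sksurv import metrics as surv_metrics
from sksurv.util import Surv

# Assumed setup (not shown in the post): structured (event, time) arrays for
# scikit-survival, and test-set risk scores stored as log hazard ratios
y_train = Surv.from_arrays(event=df_train["event"].astype(bool), time=df_train["duration"])
y_test = Surv.from_arrays(event=df_test["event"].astype(bool), time=df_test["duration"])
df_test["preds"] = bst.predict(dm_test, output_margin=True)

print("CIC")
print(
    surv_metrics.concordance_index_ipcw(
        y_train,
        y_test,
        df_test['preds'],
        tau=100  # truncation time: only consider the first 100 days
    )
)

print("Brier Score")
times, score = surv_metrics.brier_score(
    y_train, y_test, df_test['preds'], df_test['duration'].max() - 1
)
print(score)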
CIC
(0.7514910949902487, 177342, 58706, 0, 1218)

Brier Score
[0.37630957]

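These results are okay given that the data was not prepared with survival analysis in mind. The concordance index comes out at 0.75, which is clearly better than random chance (0.5). The Brier Score of 0.376 is not so good; ideally, we’d want 0.25 or lower, as the PySurvival documentation indicates. Note that concordance_index_ipcw computes Uno’s censoring-weighted variant of the C-index rather than Harrell’s original estimator, and that scikit-survival’s brier_score expects predicted survival probabilities at the given time rather than raw risk scores, so the Brier Score above should be read with caution.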

There are other ways to evaluate the model further. For example, the scikit-survival package offers a wide variety of evaluation metrics, such as the time-dependent area under the ROC curve.

Model Explainability with SHAP

The good news is that SHAP (SHapley Additive exPlanations) provides a fast, exact algorithm for tree-based ensembles that helps explain feature importance within the model. Specifically, it allows us to understand what raises and lowers attrition risk for customers.

import pandas as pd
import shap

explainer = shap.TreeExplainer(bst, feature_names=feature_names)
shap_values = explainer.shap_values(test_features)

shap.summary_plot(shap_values, pd.DataFrame(test_features, columns=feature_names))
Image by Author

Moreover, we now have a customer-level explanation of how each churn risk score is calculated, which can help inform the choice of intervention strategy. For example, night charges, night calls and evening minutes all drive customer attrition. Combined with day minutes (as a positive predictor), this suggests that customers making calls in the evening or at night are at higher risk. An intervention strategy could use this by sending intervention communications to at-risk customers during the evening and at night.

This provides details that are easy to show to business users, along with ways to further analyze which features drive customer attrition.

idx_sample = 128
shap.force_plot(
explainer.expected_value,
shap_values[idx_sample, :],
pd.DataFrame(test_features, columns=feature_names).iloc[idx_sample, :],
matplotlib=True,
)

print(f"The real label is Churn={y_test[idx_sample][0]}")
How a Non-Churner was Calculated, Image by Author
The real label is Churn=False

Lastly, since this is a tree-based model, we can also plot what the individual trees look like. Although the training call above allows up to 300 boosting rounds, early stopping halts training before that limit (bst.best_iteration holds the best round). Let’s take the first tree and look at its splits to understand how inference works.

xgb.plot_tree(bst, rankdir="LR", num_trees=0)
fig = plt.gcf()
fig.set_size_inches(150, 100)
First Tree Grown, Image by Author

Unsurprisingly, the first splits are on the most predictive features: in this case, night charge, night calls, day minutes and evening minutes. This tree is also a great way to show business users what the model is doing.

The Final Check

For a final check, we will treat the problem as a classification task simply because we can. Since the model itself is conditioned on time, these metrics do not really measure how well it performs technically. We run this only for the naysayers, to show that it works.

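import numpy as np
from sklearn import metrics

# exponentiate the margin (log hazard ratio) to get the hazard ratio
y_preds = np.exp(df_test["preds"])
y_pred = np.where(y_preds > 0.5, 1, 0)

print(f"Accuracy score: {metrics.accuracy_score(df_test.event, y_pred)}")
# AUC on the thresholded 0/1 labels, as in the original run; passing the
# continuous y_preds instead gives the usual threshold-free AUC
print(f"Area Under the Curve {metrics.roc_auc_score(df_test.event, y_pred)}")
print("")
print(metrics.classification_report(df_test.event, y_pred))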
Accuracy score: 0.932
Area Under the Curve 0.9339674490815218

              precision    recall  f1-score   support

           0       0.97      0.90      0.93       527
           1       0.89      0.97      0.93       473

    accuracy                           0.93      1000
   macro avg       0.93      0.93      0.93      1000
weighted avg       0.94      0.93      0.93      1000

Amazingly, the model outperforms the 86% accuracy reported when this dataset first appeared in the blogosphere in 2017. Beyond accuracy, we can also identify WHEN a customer is likely to churn, which allows for timely marketing intervention to retain customers.

Conclusion

This blog post showed how to train a churn model with a time component. Survival Analysis with Cox Proportional Hazards helps prevent customer attrition by pinpointing when churn risk is highest, allowing proactive, point-in-time intervention to stop customers from leaving.

Now that you have a model, it’s time to productionize it, for example with the Customer Churn Pipeline on AWS, which includes a template for Time to Event churn modeling. Once the pipeline is configured, you can run batch inference on records over time, accumulating scores to make intervention decisions.

Customer Risk by Day Interval (Image by Author)

(The table above shows the accumulated results of batch jobs by customer over days. Each customer now has a churn risk history that can be monitored to flag when to intervene.)
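A minimal sketch of how accumulated batch scores might be turned into such a history with pandas. It assumes a hypothetical long-format table with `customer_id`, `day` and `hazard_ratio` columns and a hypothetical intervention threshold; the Customer Churn Pipeline’s actual output schema may differ.

```python
import pandas as pd

# Hypothetical batch-inference output: one row per customer per scoring run.
# Column names and the threshold below are illustrative assumptions.
scores = pd.DataFrame(
    {
        "customer_id": ["a", "a", "a", "b", "b", "b"],
        "day": [30, 60, 90, 30, 60, 90],
        "hazard_ratio": [0.8, 1.1, 1.9, 0.6, 0.7, 0.65],
    }
)

# Risk history: customers as rows, scoring days as columns
history = scores.pivot(index="customer_id", columns="day", values="hazard_ratio")

RISK_THRESHOLD = 1.5  # hypothetical cut-off for triggering an intervention
latest = history.iloc[:, -1]  # most recent scoring run
to_contact = latest[latest > RISK_THRESHOLD].index.tolist()

print(history)
print("Intervene with:", to_contact)
```

Keeping the full history, rather than only the latest score, also makes it possible to react to a rising trend before a fixed threshold is crossed.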

Using Machine Learning for Survival Analysis is a great way to frame problems such as churn. We encourage you to explore the references below for the different techniques available.

References

Lifelines, Cameron Davidson-Pilon 2014

XGBoost: A Scalable Tree Boosting System, Chen & Guestrin 2016

PySurvival: Open source package for Survival Analysis modeling, Fotso 2019

scikit-survival: A Library for Time-to-Event Analysis Built on Top of scikit-learn, Sebastian Polsterl 2020

SHAP (SHapley Additive exPlanations), Lundberg 2017
