The world’s leading publication for data science, AI, and ML professionals.

Are you answering the right churn questions?

How to structure customer churn problems with classification and survival analysis

Photo by Jan Tinneberg on Unsplash
Photo by Jan Tinneberg on Unsplash

Acquiring new customers can be very expensive. It is often times much cheaper for companies to retain existing customers than to acquire new ones. In order to retain existing customers, companies first have to understand who is at risk of leaving. This is where customer Churn or retention models come into the picture.

Churn or attrition refers to customers who stopped using your company’s product or service during a certain time frame. Each company tends to define churn slightly differently ranging from inactivity after a certain time-frame, formal cancellations to government-mandated account closures such as escheatment. This can be very confusing for us as data scientist when we try to build models to identify potential churners.

Even if we are able to work cross-functionally across the organization to define churn or we are fortunate enough to work for a company with clear definitions, we still have to make sure we are answering the right question. Say for example, you work for a large bank where customers formally have to request their accounts be closed in order to be considered churners. This sounds like it would be pretty straight forward problem, but often times it’s not.

From a business perspective, time is money and understanding when the customer will churn can have a massive impact on business outcomes. Do we need to predict who will churn within the next hour, day, week, month or quarter? Do we need to predict the time until customers churn? These questions will impact a potential retention strategy and as such warrants a lot of consideration before building our model.

Once we have defined churn, decided upon the question we need to answer, we can start our model development process. To get an practical understanding of the approaches we can use to predict customer churn, we will go to Kaggle and download the Telco Churn dataset. The Telco dataset has a binary target variable called churn. It is representative of customers who churned within a month’s time.

We will explore building models that answer two separate questions

  1. Which customers will leave within the next month?
  2. When will our customers leave?

We will look at different modeling techniques for these types of use cases including logistic regression, explainable boosting machines, gradient boosting machines and cox regressions. As we develop the models, we will also discuss how we need to transform our target variables to meet our business objectives.

This guide will be using Python and we will begin with importing the necessary libraries.

import pandas as pd
import numpy as np
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score,confusion_matrix,balanced_accuracy_score, accuracy_score, classification_report
import matplotlib.pylab as plt
import seaborn as sns
import shap
import xgboost as xgb
from interpret.glassbox import ExplainableBoostingClassifier, LogisticRegression, ClassificationTree, DecisionListClassifier
from interpret import show
from interpret.perf import ROC
import numpy as np
from lifelines import *
from lifelines import CoxPHFitter, KaplanMeierFitter, GeneralizedGammaFitter
from lifelines.utils import find_best_parametric_model
from lifelines.datasets import load_lymph_node
from lifelines.plotting import qq_plot
shap.initjs()

The data is sources from Kaggle and is the infamous Telco Customer Churn dataset. It’s a very commonly used churn dataset.

df = pd.read_csv('Telco-Customer-Churn.csv')

Data Understanding

Each row represents a customer, each column contains customer’s attributes described on the column Metadata. The data set includes information about:

  • Customers who left within the last month – the column is called Churn
  • Services that each customer has signed up for – phone, multiple lines, internet, online security, online backup, device protection, tech support, and streaming TV and movies
  • Customer account information – how long they’ve been a customer, contract, payment method, paperless billing, monthly charges, and total charges
  • Demographic info about customers – gender, age range, and if they have partners and dependents
df.head()

Lets check out the how many missing values we have:

df.isna().sum()

The Total Charges column has some of missing values it that looks like its just spaces. Let’s replace with nan for now and convert to floats

df['TotalCharges'] = df["TotalCharges"].replace(" ",np.nan).astype('float64')

Some models are unable to handle missing values so lets replace nan with zeroes.

X_train = X_train.replace(np.nan, 0)
X_test = X_test.replace(np.nan, 0)

Next, we check the data types of our columns. Most models can only handle numeric values so we may have to turn them into numerical values.

df.dtypes

Customer ID has all unique values and several of our categorical values have 2–3 values. Some of the object type variables appear to be binary (Yes/No) and as such we can potentially label or one-hot encode them.

df.nunique()

Target Variable

Our target variable is churn which has a flag of yes or no. In this dataset we are looking at people who churn within a one month timeframe, but the classification problem approach could work for any timeframe. The downside is that as we change the timeframe, we often have to rework our data pipeline. For now let’s simply convert our target variable to numeric. Using label encoder, we will transform No into 0 and Yes into 1.

## 1 = Churned , 0 = Not Churned
le = preprocessing.LabelEncoder()
le.fit(df['Churn'])
df['Churn'] = le.transform(df['Churn'])

Exploratory Data Analysis

Lets look at the cases of churn vs. non-churners we have.

sns.countplot(x="Churn", data=df)
Class Balance Plot
Class Balance Plot

Looks like we have a slight imbalance issue where have more cases of non-churners than churners. This is not uncommon in this type of use case as companies who loose half their customers in a short-time frame usually aren’t around for very long. There are several ways to address the imbalance issue such as up-sampling, down-sampling and adjusting model weights.

df.groupby('Churn').count()['customerID']/ df.shape[0]

An alternative style for visualizing the same information is offered by the pointplot() function. This function also encodes the value of the estimate with height on the other axis, but rather than showing a full bar, it plots the point estimate and confidence interval. Additionally, pointplot() connects points from the same hue category. This makes it easy to see how the main relationship is changing as a function of the hue semantic, because your eyes are quite good at picking up on differences of slopes.

sns.pointplot(x="Contract", y="Churn", hue="OnlineSecurity", kind="point", data=df);

We can also compare the Total Charges to see if there is a difference in distributions between churn and non-churners in their total charges.

sns.boxplot(x="Churn", y="TotalCharges", data=df)

Data Transformations

Some models can only work with numerical representations of the data. Due to this limitation, we need to to transform our categorical variables.

for each in ['gender', 'SeniorCitizen', 'Partner', 'Dependents',
       'tenure', 'PhoneService', 'MultipleLines', 'InternetService',
       'OnlineSecurity', 'OnlineBackup', 'DeviceProtection', 'TechSupport', 'StreamingTV', 'StreamingMovies', 'Contract', 'PaperlessBilling',
       'PaymentMethod' ]:
    le.fit(df[each])
    df[each] = le.transform(df[each])

While not all models are linear models, it can also be useful to evaluate the model relationships between variables in terms of correlation. A correlation matrix is a good visual representation of the correlation between variables.

# Set up the matplotlib figure
f, ax = plt.subplots(figsize=(11, 9))
# Generate a mask for the upper triangle
corr = df.corr()
mask = np.triu(np.ones_like(corr, dtype=np.bool))
# Generate a custom diverging colormap
cmap = sns.diverging_palette(220, 10, as_cmap=True)
# Plot correlation matrix
sns.heatmap(corr, cmap =cmap,linewidths=.5, mask=mask)

Train-Test Split

To be able to evaluate model performance, we separate our dataset into training and testing sets. This will allow us to test our models performance on unseen data. Since the data is a binary classification, we would ideally like to keep the same ratio of churners and non-churners. Luckily, this is already built into sklearn.

We will also take the opportunity to drop the churn column and the customer columns as won’t use those as dependent variables to train our classification model.

## Split data into train and test sets
X_train, X_test, y_train, y_test = train_test_split(df.drop(['Churn','customerID'],axis =1), df['Churn'], test_size=0.25, random_state=42)

Binary Classification Models

To answer the question of who will churn within a given timeframe, we frame the problem as a classification problem with a binary outcome. This means, we can use the target variable we defined and simply build a binary classification model.

Classification models help us understand who will likely churn within a specific timeframe. In the Telco example, the timeframe is a month. Depending on the use-case the time-window may be much shorter (real-time call monitoring) or longer (likelihood to churn at next contract date for subscription services). Before we dive in to specific models, its good to get a baseline understanding of what classification models are.

In statistics, classification is the problem of identifying to which of a set of categories (sub-populations) a new observation belongs, on the basis of a training set of data containing observations (or instances) whose category membership is known. [1]

For our Telco churn dataset, this means we will try to predict whether or not a customer will churn using the Churn column in the dataset. As a reminder, the churn field is defined as:

Customers who will leave within the next month

Logistic Regression Model

The first classification model we will build is a logistic regression model. In statistics, the logistic model (or logit model) is used to model the probability of a certain class or event existing. Logistic regression is a statistical model that in its basic form uses a logistic function to model a binary dependent variable.

In practice, logistic regression is very useful if we need non-technical stakeholder agreement. We can implement and derive value from our model by providing intuitive explanations to other business units in order for them to set-up communication triggers, improve processes that cause churn or simply sign-off on the model to go to production.

We will use the interpret library for this model as it has some nice variable importance dashboards built in. Since the logistic regression model will be our baseline model, we will only use some of the features available. In practice, this would be the initial insights points stakeholders and experts will provide you with.

seed = 1255
lr = LogisticRegression(random_state=seed, max_iter = 1000,penalty='l1', solver='liblinear')
lr.fit(X_train[['PhoneService','PaperlessBilling', 'TotalCharges', 'Contract', 'InternetService']], y_train)

Interpret allows us to evaluate the variable importance as a dashboard. Since logistic regression models are locally and globally consistent, we know that the local importance of variable matches the global importance.

This provides great explainable power, but may limit some performance, especially when relationships in the data is non-linear.

lr_global = lr.explain_global(name='Logistic Regression')
show(lr_global)

Using the dashboard, we can also get the ROC curve and see the AUC score. A receiver operating characteristic curve, or ROC curve, is a graphical plot that illustrates the diagnostic ability of a binary classifier system as its discrimination threshold is varied. The ROC curve is created by plotting the true positive rate against the false positive rate at various threshold settings.

lr_perf = ROC(lr.predict_proba).explain_perf(X_test, y_test, name='Logistic Regression')
show(lr_perf)

Using the model we just trained, we will make predictions on the unseen data and evaluate the performance.

## Predict each class 
y_pred = lr.predict(X_test)
## Predict the probablities of each class
y_pred_proba = lr.predict_proba(X_test)

We can calculate some evaluation metrics including Accuracy, ROC AUC and Balanced Accuracy. The balanced accuracy in binary and multiclass classification problems to deal with imbalanced datasets. It is defined as the average of recall obtained on each class.

# Accuracy
LR_accuracy_score = accuracy_score(y_test, y_pred)
# ROC AUC
LR_roc_auc_score = roc_auc_score(y_test,y_pred_proba[:, 1])
# Balanced Accuracy
LR_balanced_accuracy_score = balanced_accuracy_score(y_test, y_pred)
data = pd.DataFrame(columns= ['metric','value'])
data['value'] = [LR_accuracy_score,LR_roc_auc_score,LR_balanced_accuracy_score]
data['metric'] = ['Accuracy Score','ROC AUC','Balanced_Accuracy']
# plot horizontal barplot
sns.set(rc={'figure.figsize':(10,5)})
ax = sns.barplot(x="value", y="metric", data=data)
ax.set(title='Logistic Regression Model Performance') # title barplot
# label each bar in barplot
for p in ax.patches:
    height = p.get_height() # height of each horizontal bar is the same
    width = p.get_width() # width 
 # adding text to each bar
    ax.text(x = width, # x-coordinate position of data label, padded 3 to right of bar
    y = p.get_y()+(height/2), # # y-coordinate position of data label, padded to be in the middle of the bar
    s = '{:.00%}'.format(width), # data label, formatted to ignore decimals
    va = 'center') # sets vertical alignment (va) to center

Lastly, we will generate a classification report that gives us more insights into how the model performs.

target_names = ['class 0', 'class 1']
print(classification_report(y_test, y_pred, target_names=target_names))

Explainable Boosting Machine (EBM)

After building our baseline model, we want to invest some time to build a higher performance model using a more sophisticated framework such as Explainable Boosting Machine (EBM) to improve performance.

EBM is an interpretable model developed at Microsoft Research. It uses modern machine learning techniques like bagging, gradient boosting, and automatic interaction detection to breathe new life into traditional GAMs (Generalized Additive Models).This makes EBMs as accurate as state-of-the-art techniques like random forests and gradient boosted trees. However, unlike these blackbox models, EBMs produce lossless explanations and are editable by domain experts.[6]

Now that we have a high-level understanding of EBM we will train our model. EBM allow us to set certain parameters such as the number of estimators and the number of variable interactions. Each of the parameters impacts performance and training times.

seed =100
ebm = ExplainableBoostingClassifier(random_state=seed,interactions=100, n_estimators = 400, max_tree_splits = 10, n_jobs =3)
ebm.fit(X_train,y_train)

We can review the global explanations or in other words what the model learned overall. Much like we discovered in the exploratory data contracts have a significant impact on the model.

ebm_global = ebm.explain_global(name='EBM')
show(ebm_global)
EBM Global Variable Importance
EBM Global Variable Importance

We can also review the local explanations individually.

ebm_local = ebm.explain_local(X_test[:50], y_test[:50], name='EBM')
show(ebm_local)
EBM Individual Performance
EBM Individual Performance

We can check the performance metrics.

ebm_perf = ROC(ebm.predict_proba).explain_perf(X_test, y_test, name='EBM')
show(ebm_perf)

Predict each class and associated probabilities.

y_pred = ebm.predict(X_test)
## Predict the probablities of each class
y_pred_proba = ebm.predict_proba(X_test)

Calculate accuracy, ROC-AUC and balanced accuracy.

# Traditional Accuracy
EBM_accuracy_score = accuracy_score(y_test, y_pred)
# ROC AUC
EBM_roc_auc_score = roc_auc_score(y_test,y_pred_proba[:, 1])
#Balanced Accuracy
EBM_balanced_accuracy_score = balanced_accuracy_score(y_test, y_pred)
data = pd.DataFrame(columns= ['metric','value'])
data['value'] = [LR_accuracy_score,LR_roc_auc_score,LR_balanced_accuracy_score]
data['metric'] = ['Accuracy Score','ROC AUC','Balanced_Accuracy']
## From: https://medium.com/swlh/quick-guide-to-labelling-data-for-common-seaborn-plots-736e10bf14a9
# plot horizontal barplot
sns.set(rc={'figure.figsize':(10,5)})
ax = sns.barplot(x="value", y="metric", data=data)
ax.set(title='Logistic Regression Model Performance') # title barplot
# label each bar in barplot
for p in ax.patches:
    height = p.get_height() # height of each horizontal bar is the same
    width = p.get_width() # width 
 # adding text to each bar
    ax.text(x = width, # x-coordinate position of data label, padded 3 to right of bar
    y = p.get_y()+(height/2), # # y-coordinate position of data label, padded to be in the middle of the bar
    s = '{:.00%}'.format(width), # data label, formatted to ignore decimals
    va = 'center') # sets vertical alignment (va) to center

Classification Report

target_names = ['class 0', 'class 1']
print(classification_report(y_test, y_pred, target_names=target_names))

XGBoost

Another model we will use is a gradient boosted decision tree model called XGBoost. XGBoost is one of the most laudered models in the industry due to its high-performance.

XGBoost is an optimized distributed gradient boosting library designed to be highly efficient, flexible and portable. It implements machine learning algorithms under the Gradient Boosting framework. XGBoost provides a parallel tree boosting (also known as GBDT, GBM) that solve many data science problems in a fast and accurate way.[2]

We will use most of the default parameters for XGBoost except we will change the scale_pos_weights parameter to account for our class imbalance.

## XGBoost has its own datamatrix that helps speed up computations.
d_train = xgb.DMatrix(X_train, label=y_train)
d_test = xgb.DMatrix(X_test, label=y_test)
## Set up parameters. 
params = {
    "eta": 0.005,
    "objective": "binary:logistic",
    "eval_metric": "auc",
    "scale_pos_weight": 3
}
model = xgb.train(params, d_train, 
                  num_boost_round = 5000, verbose_eval=100)

XGBoost has its own variable importance plot to explain how the model makes predictions at the global level. The challenge with this is that XGBoost uses ensemble of decision trees so depending upon the path each example travels, different variables impact it differently. This means that the global importance from XGBoost is not locally consistent.

xgb.plot_importance(model)
plt.title("xgboost.plot_importance(model)")
plt.show()

Enter SHAP (SHapley Additive exPlanations).

SHAP is a game theoretic approach to explain the output of any machine learning model. It connects optimal credit allocation with local explanations using the classic Shapley values from game theory and their related extensions [3]

SHAP will help us understand how the model makes predictions. This framework helps make the blackbox nature of gradient boosted decision trees more transparent.

## Shap 'utf-8' codec can't decode byte 0xff for xgboost model" issue workaround
## https://github.com/slundberg/shap/issues/1215
model_bytearray = model.save_raw()[4:]
def myfun(self=None):
    return model_bytearray
model.save_raw = myfun
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X_train)

To get an overview of which features are most important for a model we can plot the SHAP values of every feature for every sample.

The summary plot sorts features by the "sum of SHAP value magnitudes over all samples, and uses SHAP values to show the distribution of the impacts each feature has on the model output" [3]. The color represents the feature value (red high, blue low)

shap.summary_plot(shap_values, X_train)
SHAP Summary Plot
SHAP Summary Plot

Let’s look at local explanation produced by SHAP. The individual explanation shows how each feature helps contribute "to push the model output from the base value (the average model output over the training dataset we passed) to the model output" [3]. Features pushing the prediction higher are shown in red, those pushing the prediction lower are in blue.

shap.force_plot(explainer.expected_value, shap_values[0,:], X_train.iloc[0,:])
SHAP individual explaination for one customer
SHAP individual explaination for one customer

While a SHAP summary plot gives a general overview of each feature a SHAP dependence plot show how the model output varies by feature value. Note that every dot is a person, and the vertical dispersion at a single feature value results from interaction effects in the model. The feature used for coloring is automatically chosen to highlight what might be driving these interactions. Note that the row of a SHAP summary plot results from projecting the points of a SHAP dependence plot onto the y-axis, then recoloring by the feature itself.

shap.dependence_plot("TotalCharges", shap_values,X_train, show=False)
pl.xlim(80,225)
pl.show()
SHAP Partial Dependence Plot
SHAP Partial Dependence Plot

Next lets use XGBoost to predict each class predict each class.

y_pred = model.predict(d_test)
## Make sure y_pred is binary (0 or 1)
y_pred = np.array(model.predict(d_test))
y_pred = y_pred > 0.5 
y_pred = y_pred.astype(int) 
## Predict the probablities of each class
y_pred_proba = model.predict(d_test)

Calculate evaluation metrics

## Traditional Accuracy
XGB_accuracy_score = accuracy_score(y_test, y_pred)
#ROC AUC
XGB_roc_auc_score = roc_auc_score(y_test,y_pred_proba)
## Balanced Accuracy
XGB_balanced_accuracy_score = balanced_accuracy_score(y_test, y_pred)
data = pd.DataFrame(columns= ['metric','value'])
data['value'] = [XGB_accuracy_score,XGB_roc_auc_score,XGB_balanced_accuracy_score]
data['metric'] = ['Accuracy Score','ROC AUC','Balanced_Accuracy']
# plot horizontal barplot
sns.set(rc={'figure.figsize':(10,5)})
ax = sns.barplot(x="value", y="metric", data=data)
ax.set(title='XGBoost Model Performance') # title barplot
# label each bar in barplot
for p in ax.patches:
    height = p.get_height() # height of each horizontal bar is the same
    width = p.get_width() # width 
 # adding text to each bar
    ax.text(x = width, # x-coordinate position of data label, padded 3 to right of bar
    y = p.get_y()+(height/2), # # y-coordinate position of data label, padded to be in the middle of the bar
    s = '{:.00%}'.format(width), # data label, formatted to ignore decimals
    va = 'center') # sets vertical alignment (va) to center

Classification Report

target_names = ['class 0', 'class 1']
print(classification_report(y_test, y_pred, target_names=target_names))
XGBoost Classification Report
XGBoost Classification Report

Compare the three different models.

In this scenario, the best model depends on how the model will be used. If we will be offering expensive promotions to retain these customers and the value of retaining them is not very high, it will be more important to identify true non-churners. On the flip side, if the value of retaining the customers is high and the cost of the communication strategy is low, we are more interested in how well the model identifies the true churners.

We have now built a churn model that predicts the likelihood that a customer leaves within the next month. This is usually where most churn model articles end, but as data scientists we know that the work doesn’t end once the ROC plot looks good. Often times, stakeholders have more questions once they start using the model in production.

Survival Analysis

Imagine if your model has been deployed to production and one of the senior leaders in your organization comes to you and says:

"I love the model we have for churn, but I was interested in knowing when someone would churn. Does your model tell you how long the customer will stay?"

We could say that it does and we simply build a model for every time-interval, but that would take a lot of time. As data scientist, we have to take a step-back and identify that this stakeholder is interested a time-to-event model.

We can answer this question through a type of regression called Survival Analysis.

Survival analysis (regression) models time to an event of interest. Survival analysis is a special kind of regression and differs from the conventional regression task as follows:

The label is always positive, since you cannot wait a negative amount of time until the event occurs.

The label may not be fully known, or censored, because "it takes time to measure time."

Often we have additional data aside from the duration that we want to use. The technique is called survival regression – the name implies we regress covariates against another variable – in this case durations. Unfortunately we cannot use traditional methods like linear regression because of censoring.

Survival analysis can be univariate or mulitvariate, much like other regression problems. In order to build intuition on survival analysis, we will start with univariate.

Kaplan Meier

Kaplan Meier is a non-parametric statistic used to estimate the survival function. It’s usually represented by the Kaplan Meier curve which shows the probability of an event at the respective time intervals.

The Kaplan Meier model takes two inputs, the Time variable T and a variable called E, which is a binary flag representing the event happened or not. Translating this to our problem, we can use the tenure variable as our T and the Churn binary variable as our event E.

## Assign the target variables
T= X_train["tenure"]
E = y_train
T_test = X_test["tenure"]
E_test = y_test
## Create Boolean values of churns. (we will use this for scoring with concordance_index_censored from sksurv libary )
y_train_bolean= y_train.replace(0,False).replace(1,True)
y_test_bolean= y_test.replace(0,False).replace(1,True)

Next we fit the KaplanMeier model and plot the curve.

kmf = KaplanMeierFitter()
kmf.fit(T, event_observed=E)
kmf.plot()
Kaplan Meier Curve
Kaplan Meier Curve

A very good survival analysis library in python is called Lifelines. Using the lifelines libary, we can plot residuals QQ plots for several of the parametric models at the same time.

fig, axes = plt.subplots(2, 2, figsize=(9, 9))
timeline = np.linspace(0, 0.25, 100)
wf = WeibullFitter().fit(T, E, label="Weibull", timeline=timeline)
lnf = LogNormalFitter().fit(T, E, label="Log Normal", timeline=timeline)
# plot what we just fit, along with the KMF estimate
kmf.plot_cumulative_density(ax=axes[0][0], ci_show=False)
wf.plot_cumulative_density(ax=axes[0][0], ci_show=False)
qq_plot(wf, ax=axes[0][1])
kmf.plot_cumulative_density(ax=axes[1][0], ci_show=False)
lnf.plot_cumulative_density(ax=axes[1][0], ci_show=False)
qq_plot(lnf, ax=axes[1][1])

We can also pick the best parametric model using the AIC score.

best_model, best_aic_ = find_best_parametric_model(T, E)
best_model.plot_hazard()
Best model based on AIC
Best model based on AIC

Since we have a couple of additional data points, let’s explore survival multivariate regression. As the name implies we regress covariates against another variable – in this case durations.

Cox’s proportional hazard model

The idea behind Cox’s proportional hazard model model is that the log-hazard of an individual is a linear function of their static covariates and a population-level baseline hazard that changes over time

The lifeline cox model takes two inputs. duration and binary event. We will use tenure and churn. First we need to add our churn data back into our dataset.

X_train['Churn'] = y_train
cph = CoxPHFitter()
cph.fit(X_train, duration_col='tenure', event_col='Churn')
cph.print_summary()

With a fitted model, an alternative way to view the coefficients and their ranges is to use the plot method.

cph.plot()

After fitting, we can plot what the survival curves look like as we vary a single feature while holding everything else equal. We can use this to see the impact of the feature. Let’s use this to dive into the Contract variable.

cph.plot_covariate_groups('Contract', [0, 1, 2], cmap='coolwarm')

XGBoost

XGBoost also has a cox-regression objective that we can use to predict time to churn.

We have to do some data preparation to set the model up for predictions. For XGBoost we need to turn censured events into negative values and keep non-censured events as positive.

We do this by simply replacing our zeros in Churn with -1 and multiplying churn with tenure to create our new target.

X_train['Churn'] = y_train
## Replace the customers who haven't churned with 0
X_train['Churn']= X_train['Churn'].replace(0,-1)
## Create target variable
y_cox_train = X_train['Churn'] * X_train['tenure']
## Drop churn and tenure to avoid leakage
X_train.drop(['Churn', 'tenure'], axis = 1, inplace =True)
## Convert these to xgb data matrix
xgb_train = xgb.DMatrix(X_train,label=y_cox_train)

We need to do the same thing to the test set.

X_test['Churn'] = y_test
## Replace the customers who haven't churned with 0
X_test['Churn']= X_test['Churn'].replace(0,-1)
## Create target variable
y_cox_test = X_test['Churn'] * X_test['tenure']
## Drop churn and tenure to avoid leakage
X_test.drop(['Churn','tenure'], axis = 1, inplace =True)
## Convert these to xgb data matrix
xgb_test = xgb.DMatrix(X_test, label=y_cox_test)

Time to set up the model. We use the survival:cox objetive.

# use validation set to choose # of trees
params = {
    "eta": 0.002,
    "max_depth": 3,
    "objective": "survival:cox",
    "subsample": 0.5
}
model_train = xgb.train(params, xgb_train, 10000, evals = [(xgb_test, "test")], verbose_eval=1000)

To evaluate the model performance, we calculate the C-statistic.

The C-statistic (sometimes called the "concordance" statistic or C-index) is a measure of goodness of fit for binary outcomes in a logistic regression model [3].

The C-statistic gives the probability a randomly selected person who churned had a higher risk score than a person who had not churned. The C-statistic should be interpreted similar to an AUC score.

  • A value below 0.5 indicates a very poor model.
  • A value of 0.5 means that the model is no better than predicting an outcome than random chance.
  • Values over 0.7 indicate a good model.
  • Values over 0.8 indicate a strong model.
  • A value of 1 indicates a perfect model.
def c_statistic_harrell(pred, labels):
    total = 0
    matches = 0
    for i in range(len(labels)):
        for j in range(len(labels)):
            if labels[j] > 0 and abs(labels[i]) > labels[j]:
                total += 1
                if pred[j] > pred[i]:
                    matches += 1
    return matches/total
# see how well we can order people by survival
c_statistic_harrell(model_train.predict(xgb_test, ntree_limit=5000), y_cox_test.array)

Out[373]: 0.9886022111068725

Our C-Statistic is very high. We should take a look at our feature importance plots using SHAP Summary Plot to see what variables are most predictive.

shap_values = shap.TreeExplainer(model_train).shap_values(X_train)
shap.summary_plot(shap_values, X_train)
SHAP Summary Plot
SHAP Summary Plot

It appears the charges variables are the most important features. We should confer with a domain expert to understand if those features are leading indicators or potential leakage variables before going to production with our survival analysis.

Conclusion

To help understand customer attrition and churn, we can deploy different types of modeling techniques. In this article, we explored binary classification and survival analysis.

Binary classification requires us to define a binary target with a specific timeframe for the churn event to take place. For classifications tasks, there are multiple modeling approaches with varying levels of explainability. In this article, we explored logistic regression, a GA2M variation called EBM and a gradient boosting technique called XGBoost. Each of these models come with different trade-offs, often times between accuracy and transparency.

The real benefit of using classification is that it’s relatively easy to understand. A lot of people can understand when we talk about churn within the next month etc. Even if we dive into probabilities, many people still understand it. The drawback of the classification approach is that we are only predicting churn within that timeframe. In many practical cases, we may want to know how long until this customer churns. With the classification approach, we would have to build many models to handle this question, which would ultimate confuse end-users.

Survival analysis allows us to model time to event for censored data. In churn cases, the models tend to be right censored, so our approach is to use a cox regression model. We explored two libraries and found that we the more complex XGBoost model outperforms the simpler cox model from the lifelines library.

Survival analysis is a bit more complex to understand than classification analysis, but it is the best way to model time to event. The big benefit is that we are able to understand the how soon the customer will leave.

References

[1] Statistical Classification,Wikipedia

[2] XGBoost Documentation

[3] Lundberg, Scott M and Lee, Su-In, A Unified Approach to Interpreting Model Predictions (2017), Advances in Neural Information Processing Systems 30

[4] Survival Analysis with Acellerated Failure Time, XGBoost Documentation

[5] Cameron Davidson-Pilon, Jonas Kalderstam, Noah Jacobson, Sean Reed, Ben Kuhn, Paul Zivich, … Dave Golland. CamDavidsonPilon/lifelines: v0.25.7 (Version v0.25.7). Zenodo. (2020, December 9).

[6] H. Nori, S. Jenkins, P. Koch, and R. Caruana, InterpretML: A Unified Framework for Machine Learning Interpretability (2019)


Related Articles