
Churn prediction with Spark

Churn prediction is crucial to customer retention and one of the primary keys for success as it helps businesses proactively reach out to…

Photo by Randy Fath on Unsplash

For the capstone project of Udacity’s Data Scientist Nanodegree, I attempted to build a machine learning pipeline for Churn Prediction for Sparkify, a fictional music streaming service. The code I wrote for this project can be found in this repository.

Sparkify fictional logo provided by Udacity

Sparkify has collected a large database on the activities of thousands of unique users on its platform. As a growing service, Sparkify will continue to generate even more data in the future (presuming it exists, of course). In this case, Spark proves to be a suitable tool as it offers the much-needed scalability.

Problem Statement

The large dataset provided by Sparkify (~12GB) contained detailed information about users’ activities. Each row represented a single user activity and included basic information about the user (user ID, first name, last name, gender, subscription tier, whether or not they had logged in…) as well as information about the nature of the activity (for example, its timestamp, the name of the page visited, or the name of the artist, the song, and the length of the song requested). Each unique user could have multiple activities recorded in the dataset.

Given the dataset, the main tasks were to clean the data (if necessary), perform data exploration and analysis, extract features, and finally implement a machine learning model to classify churned users. Ideally, given a list of a user’s activities, the model should be able to predict whether they were a churner or not (without looking at the cancellation events in the dataset, of course).

The models were evaluated on their F1 score rather than accuracy because the number of churned users was small compared to the total number of users on the platform, so the data was imbalanced.
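To see why accuracy is misleading here, consider a quick sketch with hypothetical numbers (the confusion-matrix counts below are invented for illustration, not taken from the project):

```python
def f1_and_accuracy(tp, fp, fn, tn):
    """Compute the F1 score (for the churn class) and accuracy from confusion-matrix counts."""
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    accuracy = (tp + tn) / (tp + fp + fn + tn)
    return f1, accuracy

# Hypothetical: 100 users, 22 churners; a timid model that only flags 2 of them.
f1, acc = f1_and_accuracy(tp=2, fp=0, fn=20, tn=78)
# acc = 0.80 looks decent, but f1 ≈ 0.17 exposes that almost all churners were missed.
```

On imbalanced data a model can score high accuracy while being useless on the minority class, which is exactly the class we care about.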

Setup

For this project, I used Python 3.7 and Spark 3.0.1 on IBM Watson Studio, whose Lite plan offers two executors (1 vCPU and 4 GB RAM each) and a driver (1 vCPU and 4 GB RAM). Due to hardware constraints, I worked with only a subset (~200MB) of the original dataset (~12GB).

Data Exploration

The sub-dataset contains 543,705 observations and 18 features; each observation represents an activity belonging to one of 448 unique users.

The Data schema - Image by author

The page column is particularly interesting as it contains the type of user activity recorded.

Unique values of page - Image by author

Note that there is no missing data in this column and NextSong is the most popular activity, which is understandable as Sparkify is a music streaming service.

However, the numbers of Upgrade and Downgrade events are huge compared to the numbers of Submit Upgrade, Submit Downgrade, Cancel, and Cancellation Confirmation events, and even to the number of unique users. Hence, we may guess that Upgrade and Downgrade are the pages where users can choose to upgrade or downgrade their subscriptions, not the actual events of doing so.

Clean dataset

The next step is to clean the dataset. Looking at the number of missing values in each column, two groups emerge:

Number of missing values in each column - Image by author
  1. artist, length, and song are not null (contain some values) if and only if page type is NextSong.
  2. firstName, lastName, gender, location, registration, and userAgent are missing when users are not logged in or not registered. In fact, userId is also missing in such cases, but the missing values are disguised as empty strings. Since these users cannot access the primary functionality of Sparkify, we can safely drop all of these observations.
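The project did this filtering in PySpark; the logic itself is simple enough to sketch on plain-Python rows (the dict rows below are hypothetical stand-ins for the dataset's records):

```python
def drop_unregistered(rows):
    """Drop activities from logged-out or guest sessions, whose userId is an empty string."""
    return [row for row in rows if row.get("userId")]

rows = [
    {"userId": "42", "page": "NextSong"},
    {"userId": "",   "page": "Home"},       # guest session: userId disguised as ""
    {"userId": "42", "page": "Thumbs Up"},
]
clean = drop_unregistered(rows)  # keeps only the two rows with a real userId
```

In PySpark the equivalent is roughly `df.filter(df.userId != "")`, which removes the empty-string IDs that a plain null check would miss.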

Define churn

Next, I used the Cancel and Cancellation Confirmation events to define churn. In particular, any user with at least one of these events is considered churned. I also verified that these two events always coincide. There are 99 churned users, around 22% of the users in the dataset.
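The labeling rule can be sketched in a few lines of plain Python (the per-user event lists below are hypothetical; in the project this was a grouped PySpark aggregation):

```python
CHURN_PAGES = {"Cancel", "Cancellation Confirmation"}

def label_churn(events_by_user):
    """Mark a user as churned if any of their events is a cancellation page."""
    return {
        user: any(page in CHURN_PAGES for page in pages)
        for user, pages in events_by_user.items()
    }

events = {
    "u1": ["NextSong", "Thumbs Up", "Cancel", "Cancellation Confirmation"],
    "u2": ["NextSong", "Home"],
}
labels = label_churn(events)  # {"u1": True, "u2": False}
```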

Visualize data

After cleaning the dataset and defining churn, I used some graphs to answer four questions that I believed to be crucial to churn prediction and user behavior.

What is the relationship between user engagement and user retention?

Total number of events/interactions per user - Image by author

We can see that churned users engage slightly less with the service.

Is timestamp useful in this analysis?

Image by author
Image by author

Absolutely yes: last interaction and user lifetime clearly differ between the two user groups. We can also note that long-time users are more loyal to the service.

Do free users cancel their services?

Image by author

Unfortunately, also yes. In fact, free users have approximately the same rate of cancellation as paid users.

What can user activity tell us about attrition?

Image by author

Generally speaking, churned users have fewer Thumbs Up, NextSong, and Add to Playlist events, and more Thumbs Down, Roll Advert, and Downgrade events. Overall, these users did not enjoy the service as much.

Note: Remember that Downgrade is not the actual account downgrading event but probably the page where users can choose to downgrade their account.

Modeling

Feature selection

I grouped the dataset by userId and selected 13 features which I believed were most relevant:

  • Total number of interactions
  • Registration timestamp
  • Last interaction timestamp
  • User lifetime
  • Average interactions during lifetime
  • Level (1 means paid and 0 means free)
  • Percentage of Thumbs Up
  • Percentage of Thumbs Down
  • Percentage of Home
  • Percentage of NextSong
  • Percentage of Add to Playlist
  • Percentage of Roll Advert
  • Percentage of Downgrade
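The per-page percentages can be computed by counting each user's events and dividing by their total; a plain-Python sketch of that step (the page list and tracked pages are illustrative, not the full 13-feature pipeline):

```python
from collections import Counter

def page_percentages(pages, tracked=("Thumbs Up", "Thumbs Down", "NextSong")):
    """Fraction of a user's events that fall on each tracked page type."""
    counts = Counter(pages)
    total = len(pages)
    return {page: counts[page] / total for page in tracked}

feats = page_percentages(["NextSong", "NextSong", "Thumbs Up", "Home"])
# {"Thumbs Up": 0.25, "Thumbs Down": 0.0, "NextSong": 0.5}
```

Lifetime is simply the last interaction timestamp minus the registration timestamp, and the average-interactions feature divides the total event count by that lifetime.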

Then, the resulting dataset was split into train and test datasets (90% and 10%, respectively).

Build pipelines

After obtaining the train and test datasets, I built machine learning pipelines for two different classification methods: logistic regression and random forest. Each pipeline performed three main tasks:

  • Assemble the selected features into vectors, as required by PySpark’s input schema.
  • Standardize the vectors (scale to mean 0 and standard deviation 1) to put all features on a similar scale. Other scaling methods exist, but standardization is among the more robust to outliers.
  • Perform a grid search, using the F1 score to choose the best hyperparameters.
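The standardization step is worth unpacking. PySpark's StandardScaler handles it inside the pipeline; the underlying math on a single feature column looks like this (plain Python, population standard deviation):

```python
def standardize(xs):
    """Scale a feature column to mean 0 and standard deviation 1."""
    n = len(xs)
    mean = sum(xs) / n
    std = (sum((x - mean) ** 2 for x in xs) / n) ** 0.5
    return [(x - mean) / std for x in xs]

z = standardize([1.0, 2.0, 3.0, 4.0, 5.0])
# z has mean 0 and standard deviation 1
```

In the actual pipeline this corresponds to `StandardScaler(withMean=True, withStd=True)` applied after `VectorAssembler`.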

The hyperparameters used in the grid searches are as follows:

For LogisticRegression:

  • regParam: [0.01, 0.1]

For RandomForestClassifier:

  • maxDepth (the maximum depth of a tree): [12, 30]
  • numTrees (the total number of trees): [20, 100]
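A grid search trains one candidate model per combination of hyperparameter values, so the random forest grid above yields four candidates. Enumerating the combinations is a one-liner (a sketch of what PySpark's ParamGridBuilder does under the hood):

```python
from itertools import product

grid = {"maxDepth": [12, 30], "numTrees": [20, 100]}

# Cartesian product of all hyperparameter values: one dict per candidate model.
candidates = [dict(zip(grid, values)) for values in product(*grid.values())]
# 4 candidates: (12, 20), (12, 100), (30, 20), (30, 100)
```

In PySpark this is `ParamGridBuilder().addGrid(rf.maxDepth, [12, 30]).addGrid(rf.numTrees, [20, 100]).build()`, fed to a CrossValidator together with an F1 evaluator.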

Results

The best logistic regression model (regParam=0.01) reached an F1 score of 0.868 on the train dataset and correctly predicted 33 out of 36 users in the test dataset (F1 score 0.912, accuracy 0.917).

With a random forest of maxDepth 12 and numTrees 100, I achieved a better result on the train dataset (F1 score 0.887). The model produced the same predictions on the test dataset (33 correct out of 36, F1 score 0.912, accuracy 0.917).

Prediction results of both methods on the test dataset - Image by author

However, the identical performance on the test dataset did not imply that the random forest model was no better. With only 36 users, the test dataset was simply too small to compare the models effectively.

I would argue that random forest is better than logistic regression in this situation because a tree can isolate a range of values (for example, churned users’ lifetimes mostly fall between 1 and 3 billion ‘time units’), while logistic regression can only split the data by comparing a feature against a single threshold (greater than or less than that value). To capture a range with a logistic-regression-style model, we would need to add hidden layers, that is, use deep learning. Therefore, I believe the random forest model is a better choice.
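The distinction is easy to see with two toy decision rules on a single lifetime feature (the interval [1, 3) is an illustrative stand-in for the billions-scale values in the data):

```python
def tree_rule(lifetime):
    """Two stacked splits, as in a decision tree, can isolate an interval."""
    if lifetime >= 1:        # first split
        if lifetime < 3:     # second split
            return "churn"
    return "stay"

def threshold_rule(lifetime, t=1):
    """A single linear cut, as in logistic regression on one feature,
    can only separate 'above t' from 'below t'."""
    return "churn" if lifetime >= t else "stay"

[tree_rule(x) for x in (0.5, 2, 5)]       # ['stay', 'churn', 'stay']
[threshold_rule(x) for x in (0.5, 2, 5)]  # ['stay', 'churn', 'churn']
```

The tree correctly flags only the middle range, while the single threshold is forced to misclassify one side of it.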

Moreover, a great thing about random forest is that we can actually see how each feature contributed to the result by looking at the feature importance:

Image by author

Just as guessed in the data visualization part above, the user’s subscription level did not contribute much, while last interaction and user lifetime were crucial to the model’s performance. Hence, one of the simplest and most effective ways to improve customer retention is to target users who have not been active recently (low last interaction values) with special offers and discounts. These discounts could also be offered when users visit the cancel page.

Possible further improvements

  • Using last interaction may not be the best idea since we may want to know who doesn’t enjoy Sparkify before they even want to cancel. Instead, we can target people with fewer engagements, more Thumbs Down, more Roll Adverts, and who visit the Downgrade page more often.
  • In this project, I did not use the artist and song users listened to, or the location and gender of users. These features might be useful for a better model.

Conclusion

In summary, I explored the dataset through visualization and analysis, cleaned and transformed the data, and extracted relevant features about Sparkify’s users. From these features, I built machine learning pipelines using logistic regression and random forest and compared the results. The best random forest model reached an F1 score of 0.887 on the train dataset and correctly classified 33 out of 36 users in the test dataset.

Through this project, I learned Spark and was able to apply my knowledge to a practical problem. I also concluded my journey with the Data Scientist Nanodegree on Udacity with a nice capstone project.

