PySpark & AWS | Predicting Customer Churn

Patrick De Guzman
Towards Data Science
9 min read · Oct 23, 2019

What do you call a bunch of distributed dinosaur-breeding islands?

Jurassic Spark... *Ba Dum Tss*

One of the most important concerns for companies with subscription-based business models is customer churn. Customers downgrade or discontinue service for various reasons, and the service provider often cannot know when or why customers leave until they leave!

What if the company could anticipate churn before customers actually leave?

This would be powerful! If we can reliably predict whether a customer is likely to churn, we have the chance to retain these customers by intervening with promotions, communicating new features, etc. This is a proactive approach to retaining customers, as opposed to a reactive approach of getting back lost customers.

Udacity provided two similarly-formatted user activity datasets (128MB and 12GB) from a fictional music streaming company, Sparkify, and I used this data to better understand Sparkify’s customers, then predict whether a customer will churn with over 80% accuracy.

There’s no way I could analyze and model the 12GB dataset on my machine locally, so I explored and prototyped my workflow on the 128MB dataset using PySpark’s local mode in a Jupyter Notebook, which lets you simulate working on a distributed cluster on one machine.

Then, I turned the workflow into a Python script, spun up a 4-node AWS EC2 cluster, and executed the script through AWS EMR on the cluster.

Here’s what I’m going to cover in this article:

  1. Feature Engineering & Exploring the Data locally
  2. Data Preprocessing locally
  3. Machine Learning locally
  4. Running steps 1–3 on an AWS Distributed Cluster

Who is this article for?

This article walks you through the steps I took, along with some useful code, to get from a small Jupyter notebook analysis to a 12GB dataset analysis on a 4-node AWS EMR cluster. So, I’m assuming you’re familiar with Python & Machine Learning.

Also, in this post, I show a lot of code snippets of me using PySpark, which is very different from processing data through Pandas and Scikit-Learn libraries. If you’re not familiar with PySpark, just skim it! No big deal.

Finally, if you want a more detailed walkthrough of how the analysis and machine learning work in my code, check out my full Jupyter notebook (hosted on my website) or GitHub Repo.

I hope you enjoy this!

Now Let’s Get Started…

I. Feature Engineering & Exploring the Data

First, I converted the 128MB dataset from JSON format to a Spark DataFrame. Here’s the schema so you know how the data is structured:

Original imported features.

Date Feature Engineering | After some brief exploring, I used the “ts” (timestamp) attribute to create some date features using PySpark UDFs (user-defined functions).
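To make the idea concrete, here is a minimal plain-Python sketch of the kind of function a UDF could wrap. It is not the notebook’s PySpark code: the function name, the choice of features, and the assumption that “ts” is an epoch timestamp in milliseconds are all mine.

```python
from datetime import datetime, timezone


def date_features(ts_ms):
    """Derive calendar features from an epoch timestamp in milliseconds.

    Assumption: "ts" is epoch milliseconds in UTC (not stated in the article).
    """
    dt = datetime.fromtimestamp(ts_ms / 1000, tz=timezone.utc)
    return {
        "hour": dt.hour,
        "day": dt.day,
        "month": dt.month,
        "weekday": dt.weekday(),  # 0 = Monday ... 6 = Sunday
    }


# Hypothetical timestamp, used only for illustration.
features = date_features(1538352117000)
print(features)
```

A function like this takes one value and returns plain Python types, which is the shape a per-row UDF generally needs.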

I created a count plot for the new date features to better understand how the users behave on a high level throughout the period of the dataset.

Note this smaller dataset (128MB) only contains data for about two months.

Events Feature Engineering | I used the “page” feature to flag whenever a user visits a particular page or performs a particular action. These flag features will later help me aggregate user activity on a user-level rather than an event-level.

This is effectively one-hot encoding the “page” feature. In hindsight, this would probably have been faster with PySpark’s native OneHotEncoder (which takes indexed categories, e.g. from a StringIndexer). Oh well! It works.
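As a rough illustration of the flagging logic (plain Python, not the PySpark code from the notebook), each event row gets one 0/1 column per page of interest. The list of pages and the `is_` column naming are assumptions made for this sketch; only “Downgrade” and “Cancellation Confirmation” are named in this article.

```python
# Pages to flag. "NextSong" and "Thumbs Up" are illustrative assumptions.
FLAG_PAGES = ["NextSong", "Thumbs Up", "Downgrade", "Cancellation Confirmation"]


def page_flags(page):
    """Return one 0/1 flag per tracked page for a single event row."""
    return {
        "is_" + p.lower().replace(" ", "_"): int(page == p)
        for p in FLAG_PAGES
    }


row = page_flags("Thumbs Up")
print(row)
```

Because every flag is 0 or 1, summing a flag column per user later gives a count of how often that user performed the action.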

Running Count Feature Engineering | I also leveraged the transactional nature of the data (i.e. one row for each event performed by some user at a point in time) to create some window-calculated features (e.g. running counts of listens within a month). Here’s an example of how features like this were created.

Creating a window for each user per month, counting rows whenever a user listens, and joining back to DF.
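The same logic can be sketched in plain Python (this is an illustration, not the notebook’s PySpark window code). The field names and the “NextSong” page label are assumptions; rows where the user is not listening are left as `None`, mirroring the nulls that appear after joining back to the full DataFrame.

```python
from collections import defaultdict


def running_listens(events):
    """Add a per-user, per-month running count on rows where the user listens.

    `events` must already be sorted by timestamp. Non-listen rows get None.
    """
    counts = defaultdict(int)
    out = []
    for e in events:
        running = None
        if e["page"] == "NextSong":  # assumed label for a song play
            key = (e["user"], e["month"])
            counts[key] += 1
            running = counts[key]
        out.append({**e, "running_listens": running})
    return out


# Hypothetical events, already ordered by "ts".
events = [
    {"user": "a", "month": 10, "ts": 1, "page": "NextSong"},
    {"user": "a", "month": 10, "ts": 2, "page": "Home"},
    {"user": "b", "month": 10, "ts": 3, "page": "NextSong"},
    {"user": "a", "month": 10, "ts": 4, "page": "NextSong"},
    {"user": "a", "month": 11, "ts": 5, "page": "NextSong"},
]
result = [r["running_listens"] for r in running_listens(events)]
print(result)
```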

This creates a column that looks like the fourth one in this picture:

Last Column: Note the null values

Since this is a running count, null values should be front-filled. However, given the distributed nature of Spark’s data structures (DataFrames are built on Resilient Distributed Datasets, or RDDs), I could not find a dedicated function to front-fill null values. To explain, if a dataset is randomly partitioned among several machines in the cloud, how does any one machine know whether the preceding value it sees is truly the immediately preceding value in the dataset?

My best workaround was to create a lag of the running listens column, then replace null values with the lag iteratively. Consecutive null values need this exercise performed several times because we’re effectively replacing the current null value with the previous value (which is null for consecutive null values). Here’s how I implemented it:

Sharing this snippet in case anyone finds it useful. This is the only way I found to front-fill null values that was computationally cheaper than using a complicated cross-join.

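The above snippet front-fills each null value with its preceding value, repeating the pass 6 times, then fills whatever is still null with 0 to save computation time. This may not be the best way to front-fill, but it may be useful if you’re filling a large dataset with relatively few null values. It works because calculating a lag over a window forces Spark to bring the relevant rows together in order, so the immediately preceding value is well defined. (Another option worth trying is Spark’s last function with ignorenulls=True over an ordered window, which handles runs of any length.)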
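To see why consecutive nulls need several passes, here is a small plain-Python simulation of the lag-and-replace idea on a single, already ordered column. It is only a sketch of the logic described above, not the PySpark snippet itself; the function names and the sample column are made up.

```python
def lag(values):
    """Shift a column down by one row, like a lag of 1 over an ordered window."""
    return [None] + values[:-1]


def front_fill(values, passes=6, default=0):
    """Replace nulls with the lagged value `passes` times, then fill leftovers."""
    filled = list(values)
    for _ in range(passes):
        previous = lag(filled)
        # Keep existing values; only nulls take the preceding row's value.
        filled = [p if v is None else v for v, p in zip(filled, previous)]
    return [default if v is None else v for v in filled]


column = [1, None, None, 2, None, 3]
one_pass = front_fill(column, passes=1)
two_passes = front_fill(column, passes=2)
print(one_pass)
print(two_passes)
```

With one pass, the second null in a run of two is still null (its predecessor was null when the lag was taken) and falls through to the default. Each extra pass extends the fill by one more row, which is why the number of passes caps the longest run that gets filled properly.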

Defining Churn | Finally, I defined customer churn as whenever a user visits the “Cancellation Confirmation” or “Downgrade” page.

This labels the “Churn” column of the DataFrame as 1 if a user visits the aforementioned pages.
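In plain Python, the labeling rule might look like the sketch below. Rolling the flag up to one label per user is my assumption about how the label is eventually used; the field names are hypothetical.

```python
CHURN_PAGES = {"Cancellation Confirmation", "Downgrade"}


def label_churn(events):
    """Map each user to 1 if they ever visit a churn page, else 0."""
    labels = {}
    for e in events:
        hit = int(e["page"] in CHURN_PAGES)
        labels[e["user"]] = max(labels.get(e["user"], 0), hit)
    return labels


# Hypothetical events for two users.
labels = label_churn([
    {"user": "a", "page": "NextSong"},
    {"user": "a", "page": "Downgrade"},
    {"user": "a", "page": "NextSong"},
    {"user": "b", "page": "Home"},
])
print(labels)
```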

II. Data Preprocessing

Now let’s take a step back and think about what we have and what we want to achieve.

What we have: a dataset of “events” specifying what activity a user performs at a given point in time.

What we want: predict customer churn.

Intuitively, I don’t think it’s a great idea to predict churn on the dataset in the format that we have it in. Training and predicting on individual user events would be very expensive, since you can have thousands of events every hour (or more!)

So, my approach was to simply aggregate the data for each user. Specifically:

I used the transactional, event-level data to create a metric-based, user-level matrix that summarizes each user’s activity and has only as many rows as there are unique users.

This may not translate well from text, so let me just show you what I did. Here’s an example of code aggregating to see total activities performed by each user:
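For readers skimming without the notebook, this plain-Python sketch shows the shape of the aggregation: many event rows in, one summary row per user out. It is not the PySpark groupBy code from the notebook, and the field names, page labels, and chosen metrics are assumptions for illustration.

```python
from collections import Counter, defaultdict


def user_matrix(events):
    """Collapse event rows into one summary row per user."""
    pages = defaultdict(Counter)   # user -> count of events per page
    sessions = defaultdict(set)    # user -> distinct session ids
    for e in events:
        pages[e["user"]][e["page"]] += 1
        sessions[e["user"]].add(e["session"])
    rows = []
    for user in sorted(pages):
        plays = pages[user]["NextSong"]
        rows.append({
            "user": user,
            "total_events": sum(pages[user].values()),
            "total_plays": plays,
            "total_thumbs_up": pages[user]["Thumbs Up"],
            "avg_plays_per_session": plays / len(sessions[user]),
        })
    return rows


# Hypothetical events: five rows in, two rows out.
matrix = user_matrix([
    {"user": "a", "session": 1, "page": "NextSong"},
    {"user": "a", "session": 1, "page": "NextSong"},
    {"user": "a", "session": 2, "page": "NextSong"},
    {"user": "a", "session": 2, "page": "Thumbs Up"},
    {"user": "b", "session": 1, "page": "Home"},
])
for row in matrix:
    print(row)
```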

This type of aggregation helps me create a neat DataFrame and Visualization like the ones below…

Truncated DF | One row per user with user-specific summary statistics.
Using the summary statistics to compare our churned users vs. our non-churned users.

Interestingly, it appears our churned users were much more active than our non-churned users.

Now that I have the data in a nice, tight matrix, I used this as the basis for the machine learning models.

III. Machine Learning

My approach to modeling was simple and straightforward: try a few algorithms, pick the one that seems most promising, then tune the hyperparameters for this model.

Here’s a schema of the user matrix so you know what’s being modeled:

One user and several metrics/aggregations for this user per row.

PySpark ML requires data to be in a very particular DataFrame format. It needs the features in a single column of vectors, where each element of the vector holds the value of one feature. It also requires the labels to be in their own column.

The below code transforms the original matrix into this ML-friendly format and standardizes the values so they’re on the same scale.
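The target format can be pictured with a plain-Python sketch: each user row becomes a (feature vector, label) pair, with every feature rescaled column by column. This is an illustration only. Mean-centering and the population standard deviation are my choices here, and Spark’s StandardScaler has its own defaults (for example, it does not center by default).

```python
from statistics import mean, pstdev


def to_scaled_vectors(rows, feature_names, label_name):
    """Return (features, label) pairs with each feature standardized column-wise."""
    stats = {}
    for name in feature_names:
        column = [r[name] for r in rows]
        stats[name] = (mean(column), pstdev(column))
    data = []
    for r in rows:
        vector = [
            # Constant columns have zero spread; map them to 0.0 instead of dividing.
            (r[name] - stats[name][0]) / stats[name][1] if stats[name][1] else 0.0
            for name in feature_names
        ]
        data.append((vector, r[label_name]))
    return data


# Hypothetical two-user matrix.
rows = [
    {"plays": 1, "thumbs": 10, "churn": 0},
    {"plays": 3, "thumbs": 10, "churn": 1},
]
data = to_scaled_vectors(rows, ["plays", "thumbs"], "churn")
print(data)
```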

Now that we have our data ready, the below code shows how I initially evaluated the models. I arbitrarily picked Logistic Regression, Random Forest, and Gradient Boosted Trees as candidates for my final churn model.

After initial model evaluation, I found that the GBTClassifier (Gradient Boosted Trees) performed the best with an accuracy of 79.1% and an F-1 score of 0.799.
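For reference, here is how accuracy and an F1 score can be computed from a list of actual and predicted labels. This is a plain-Python sketch of binary F1 for the churn class; the article does not say which F1 variant was reported, so treat the choice as an assumption. The labels below are made up.

```python
def accuracy_and_f1(actual, predicted):
    """Compute accuracy and binary F1 (positive class = 1) from 0/1 labels."""
    pairs = list(zip(actual, predicted))
    tp = sum(1 for a, p in pairs if a == 1 and p == 1)
    fp = sum(1 for a, p in pairs if a == 0 and p == 1)
    fn = sum(1 for a, p in pairs if a == 1 and p == 0)
    accuracy = sum(1 for a, p in pairs if a == p) / len(pairs)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    if precision + recall == 0:
        return accuracy, 0.0
    return accuracy, 2 * precision * recall / (precision + recall)


# Hypothetical labels for eight users.
actual = [1, 1, 0, 0, 1, 0, 0, 0]
predicted = [1, 0, 0, 0, 1, 1, 0, 0]
accuracy, f1 = accuracy_and_f1(actual, predicted)
print(accuracy, f1)
```

Accuracy alone can look flattering when churners are a minority, which is one reason to report F1 alongside it.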

So, I tuned this model using PySpark’s native CrossValidator and ParamGridBuilder to run a grid search, which chooses the best hyperparameters using k-fold cross-validation. Here, I used a small grid of hyperparameters and 3-fold validation because the computation is expensive.
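A quick plain-Python sketch shows why even a small grid gets expensive: every hyperparameter combination is trained once per fold. The grid values below are assumptions (only a max depth of 3 and max bins of 16 appear in this article), and the deterministic fold assignment is a simplification of the random splitting a real cross-validator would do.

```python
from itertools import product


def param_grid(**options):
    """Expand per-parameter value lists into every combination."""
    names = list(options)
    return [dict(zip(names, combo)) for combo in product(*options.values())]


def k_fold_indices(n_rows, k):
    """Yield (train, validation) index lists for k roughly equal folds."""
    folds = [list(range(i, n_rows, k)) for i in range(k)]
    for i in range(k):
        train = [idx for j, fold in enumerate(folds) if j != i for idx in fold]
        yield train, folds[i]


# Hypothetical grid; 5 and 32 are illustrative alternatives.
grid = param_grid(maxDepth=[3, 5], maxBins=[16, 32])
splits = list(k_fold_indices(n_rows=9, k=3))
total_fits = len(grid) * len(splits)
print(total_fits)
```

Four combinations times three folds already means a dozen model fits, so adding one more value to each parameter list roughly doubles the work.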

After tuning then re-evaluating the model, accuracy and F-1 increased to 82.35% and 0.831 respectively! The best model had a max depth of 3 and max bins of 16. I would have tried a more extensive grid search if this didn’t take so long to run on my laptop (it took about 1 hour).

PySpark automatically calculates feature importances for these tree-based models.

Note: For the curious, here’s an interesting write-up explaining how feature importance is calculated, and why it’s actually not that accurate. This isn’t really within the scope of this project, so I’m just glossing over this.

Feature Importances calculated for the tuned GBTClassifier.

It appears the top three most important features to predict churn are:

  • Average song plays per session
  • Total thumbs ups
  • Most consecutive days not playing songs

It also looks like metrics related to error and total song plays were completely irrelevant to predicting churn.

IV. AWS Cluster Deployment

Up until now, I’ve been describing the analysis I performed on the small, 128MB dataset using Spark’s local mode on my laptop.

To run the 12GB dataset through this same analysis, I needed a lot more processing power. That’s where AWS Elastic MapReduce (EMR) and Elastic Compute Cloud (EC2) come in.

Note: EC2 is what lets you work on machines in the cloud. EMR is what you use to easily manage EC2 clusters that already have Spark installed.

After setting up EMR and access to AWS through my terminal, I uploaded my script to my cluster’s Hadoop Distributed File System (HDFS), loaded data from Udacity’s S3 bucket containing the 12GB dataset, and ran the script on the master node from my terminal.

Note: here’s what I used to set up EMR and here’s what I used to set up my access to AWS on my terminal.

Don’t even try to read this. Just FYI: this is what it looks like when you run a Spark app through EMR from your CLI/terminal.

The script takes a long time to finish (it took me about 5 hours on 4 machines). So that all this hard work doesn’t go to waste, I had the script save the model and the pre-processed user matrix at the end.

In case you’re curious, running this analysis and workflow on the full 12GB dataset resulted in a very high Accuracy and F-1 Score!

This was a HUGE improvement over the model trained on the 128MB dataset.

V. Wrap-Up

A lot more features can be engineered from user activity, such as thumbs-ups per day/week/month, thumbs-ups to thumbs-downs ratio, etc. Feature engineering can often improve results more than simply optimizing one algorithm.

Thus, further work can be done extracting more features from our transactional user data to improve our predictions!

Once a model is created, perhaps it can be deployed in production and run every x-amount of days or hours. Once we have a prediction that a user is likely to churn, we have an opportunity to intervene!

To evaluate how well this hypothetically deployed model does, we can run some proof-of-concept analysis and not intervene on its predictions for a given testing period. If the users it predicts will churn end up churning at a higher rate than the average user tends to churn, this can indicate that our model is working correctly!
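The comparison described above boils down to two churn rates and their ratio. Here is a minimal plain-Python sketch, assuming each user record carries a hypothetical `predicted_churn` flag from the model and a `churned` outcome observed at the end of the test period.

```python
def churn_lift(users):
    """Compare churn among users the model flagged against the overall rate.

    Assumes at least one user was flagged and at least one user churned.
    """
    flagged = [u for u in users if u["predicted_churn"]]
    overall_rate = sum(u["churned"] for u in users) / len(users)
    flagged_rate = sum(u["churned"] for u in flagged) / len(flagged)
    return flagged_rate, overall_rate, flagged_rate / overall_rate


# Hypothetical test period: 10 users, 4 flagged by the model, 4 churned overall.
users = (
    [{"predicted_churn": 1, "churned": 1}] * 3
    + [{"predicted_churn": 1, "churned": 0}]
    + [{"predicted_churn": 0, "churned": 1}]
    + [{"predicted_churn": 0, "churned": 0}] * 5
)
flagged_rate, overall_rate, lift = churn_lift(users)
print(flagged_rate, overall_rate, lift)
```

A lift clearly above 1 would suggest the flagged users really do churn more often than average; a lift near 1 would suggest the predictions add little.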

VI. Concluding Thoughts

Learning PySpark and AWS seems like a nightmare (I know from this experience). However, if you’re already familiar with Python and Machine Learning, you don’t need to know that much more to get started with PySpark and AWS.

You know how to process data. You know how machine learning works. The other half of the battle is knowing how to perform these same tasks in PySpark and on an AWS cluster, which you can find out via Google searches and tutorials! I suggest walking slowly through my (or others’) Jupyter notebook so you have a handle on PySpark, then looking up how to run a Spark app on AWS EMR.

It seems like a lot of work, but there is no way in heck I could’ve worked with the 12GB dataset with my laptop. Mind-boggling amounts of data get created every day, and if you want to be able to wrangle that much data, Spark is a great tool to learn!
