A Tutorial Using Spark for Big Data: An Example to Predict Customer Churn

Ying Geng
Towards Data Science
9 min read · Jun 4, 2020

Apache Spark has become arguably the most popular tool for analyzing large data sets. As my capstone project for Udacity’s Data Science Nanodegree, I’ll demonstrate the use of Spark for scalable data manipulation and machine learning. For context, we use the user log data from a fictitious music streaming company, Sparkify, to predict which customers are at risk of churning.

The full data set is 12GB. We’ll first analyze a mini subset (128MB) and build classification models using the Spark DataFrame, Spark SQL, and Spark ML APIs in local mode through PySpark, the Python API for Spark. Then we’ll deploy a Spark cluster on AWS to run the models on the full 12GB of data. Hereafter, we assume that Spark and PySpark are installed (a tutorial for installing PySpark).

Set up a Spark session

Before we can read csv, json, or xml data into Spark dataframes, a Spark session needs to be set up. Since Spark 2.0, a Spark session has been the unified entry point for Spark applications. Prior to Spark 2.0, separate Spark contexts were needed to interact with Spark’s different functionalities (a good Medium article on this).

# Set up a SparkSession
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("capstone").getOrCreate()

The Dataset

# Load data and show basic data shape
path = "mini_sparkify_event_data.json"
df = spark.read.json(path)
print((df.count(), len(df.columns)))  # (number of rows, number of columns)

Now that the mini Sparkify user log data set is in the Spark dataframe format, we can do some initial exploration to get familiar with the data. Spark dataframe and Spark SQL modules have methods such as the following: select(), filter(), where(), groupBy(), sort(), dropDuplicates(), count(), avg(), max(), min(). They also have Window functions that are useful for basic analysis (see documentation for syntax). To summarize the data:

  1. The dataset has 286500 rows and 18 columns.
  2. It spans the time period from 2018-09-30 to 2018-12-02.
  3. Each row records one event performed by a user during this time period.
  4. Column definitions are as follows:
-- artist (string): artist's name for a song
-- auth (string): Logged Out | Cancelled | Guest | Logged In
-- firstName (string): user's first name
-- gender (string): Female | Male
-- itemInSession (long): number of items in a session
-- lastName (string): user's last name
-- length (double): a song's length in seconds
-- level (string): paid | free
-- location (string): city and state of the user
-- method (string): HTTP method
-- page (string): which page a user is on at an event
-- registration (long): timestamp of user registration
-- sessionId (long): the ID of the session a user is in at an event
-- song (string): song name
-- status (long): 307 | 404 | 200
-- ts (long): timestamp at each event
-- userAgent (string): user agent string of the client (browser and operating system)
-- userId (string): user ID

5. Some of these columns probably are not very useful for prediction, such as firstName, lastName, method, and userAgent. Categorical features need to be encoded, such as gender and level. Some numerical features will be useful for engineering aggregated behavior features, such as itemInSession, length, page visits, etc.

6. The classes are imbalanced, so we should consider stratified sampling when splitting the data into training and test sets. We should also prefer the f1 score over accuracy as our model evaluation metric.

7. For models, we will try Logistic Regression, Linear SVM, Decision Tree, Random Forest, and Gradient Boosted Trees.

With these initial thoughts, let’s proceed with handling missing values.

Dealing with Missing Values

The seaborn heatmap is a good way to show where missing values are located in the dataset and whether data is missing in some systematic way.

# Let's take a look at where the missing values are located.
import matplotlib.pyplot as plt
import seaborn as sns

plt.figure(figsize=(18, 6))
sns.heatmap(df.toPandas().isnull(), cbar=False)

Note (1): From the heatmap, we can see that firstName, lastName, gender, location, userAgent, and registration are missing at the same rows. We can infer that these missing values come from users who are not registered. Usually, users who are not registered would not have a user ID. We’ll explore this further.

df.select('userId').filter('registration is null').show(3)

It turns out that the userId column does have missing values, but they are coded as empty strings rather than as nulls (“NaN”). The number of such empty values matches the number of rows with a missing registration. Since these records do not even have userId information, we’ll go ahead and delete them.

Note (2): Also, artist, length, and song are missing data at the same rows. These records do not have song-related information. We’ll explore what pages users are on for these rows.

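df_pd = df.toPandas()
# Pages visited on rows where artist is missing
print(df_pd[df_pd.artist.isnull()]['page'].value_counts())
# Pages visited on rows where artist is present
print(df_pd[df_pd.artist.notnull()]['page'].value_counts())

Thumbs Up 12551
Home 10082
Add to Playlist 6526
Add Friend 4277
Roll Advert 3933
Logout 3226
Thumbs Down 2546
Downgrade 2055
Settings 1514
Help 1454
Upgrade 499
About 495
Save Settings 310
Error 252
Submit Upgrade 159
Submit Downgrade 63
Cancel 52
Cancellation Confirmation 52
Name: page, dtype: int64
NextSong 228108
Name: page, dtype: int64

Rows without song information come from every page except NextSong, while all rows with song information are NextSong events. These gaps therefore reflect non-song events rather than data errors.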

Feature Engineering and EDA

Based on intuition and domain knowledge, we decide not to include the columns for firstName, lastName, method, and userAgent in our first-pass modeling for now, since these variables probably do not affect our prediction. We also decide not to include artist, location, song, and status for now. This leaves us with the following columns:

-- gender (string): Female | Male
-- itemInSession (long): number of items in a session
-- length (double): a song's length in seconds
-- level (string): paid | free
-- page (string): which page a user is on at an event
-- registration (long): timestamp of user registration
-- sessionId (long): the ID of the session a user is in at an event
-- ts (long): timestamp at each event
-- userId (string): user ID

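1) Define Churned User: We define a churned user as one who reached the Cancellation Confirmation page. With 52 churned users and 173 retained users, there is approximately a 1:3 class imbalance.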

from pyspark.sql import Window
from pyspark.sql.functions import udf, sum as Fsum
from pyspark.sql.types import IntegerType

# Flag the event that marks a churn
flag_cancellation = udf(lambda x: 1 if x == "Cancellation Confirmation" else 0, IntegerType())
df = df.withColumn("churn", flag_cancellation("page"))
# Create the cross-sectional data that we'll use in analysis and modelling
w1 = Window.partitionBy('userId')
df_user = df.select('userId', 'churn', 'gender', 'level') \
    .withColumn('churned_user', Fsum('churn').over(w1)) \
    .dropDuplicates(['userId']).drop('churn')
df_user.groupby('churned_user').count().show()

+------------+-----+
|churned_user|count|
+------------+-----+
| 0| 173|
| 1| 52|
+------------+-----+

2) Categorical features: For categorical features, we first need to apply label encoding (simply converting each value to a number). Depending on the machine learning model, we may need to further encode these numbers as dummy variables (e.g., one-hot encoding).

In Spark, StringIndexer does the label encoding part:

from pyspark.ml.feature import StringIndexer

indexer = StringIndexer(inputCol="gender",outputCol="genderIndex")
df_user = indexer.fit(df_user).transform(df_user)
indexer = StringIndexer(inputCol="level",outputCol="levelIndex")
df_user = indexer.fit(df_user).transform(df_user)
df_user.show(3)
+------+------+-----+------------+-----------+----------+
|userId|gender|level|churned_user|genderIndex|levelIndex|
+------+------+-----+------------+-----------+----------+
|100010| F| free| 0| 1.0| 0.0|
|200002| M| free| 0| 0.0| 0.0|
| 125| M| free| 1| 0.0| 0.0|
+------+------+-----+------------+-----------+----------+
only showing top 3 rows

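Let’s take a look at how gender and level are related to churn. StringIndexer assigns index 0.0 to the most frequent label, so here 0.0 means male for gender and free for level. By these simple statistics, a larger percentage of male users churned (26%) than female users (19%). For level, only 15% of churned users were on the paid tier versus 23% of retained users, which suggests that free users churned at a higher rate than paid users.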

df_user.groupby('genderIndex').avg('churned_user').show()

+-----------+-------------------+
|genderIndex| avg(churned_user)|
+-----------+-------------------+
| 0.0| 0.2644628099173554|
| 1.0|0.19230769230769232|
+-----------+-------------------+

df_user.groupby('churned_user').avg('levelIndex').show()

+------------+-------------------+
|churned_user| avg(levelIndex)|
+------------+-------------------+
| 0|0.23121387283236994|
| 1|0.15384615384615385|
+------------+-------------------+

Since we will use Logistic Regression and a linear SVM classifier, we need to convert the label-encoded indices to dummy variables. OneHotEncoderEstimator() does this part (in Spark 3.0 and later, this class is named OneHotEncoder):

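from pyspark.ml.feature import OneHotEncoderEstimator

encoder = OneHotEncoderEstimator(inputCols=["genderIndex", "levelIndex"], outputCols=["genderVector", "levelVector"])
model = encoder.fit(df_user)
df_user = model.transform(df_user)
# userId is selected as well so that the call matches the output shown below
df_user.select('userId', 'genderVector', 'levelVector').show(3)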
+------+-------------+-------------+
|userId| genderVector| levelVector|
+------+-------------+-------------+
|100010| (1,[],[])|(1,[0],[1.0])|
|200002|(1,[0],[1.0])|(1,[0],[1.0])|
| 125|(1,[0],[1.0])|(1,[0],[1.0])|
+------+-------------+-------------+
only showing top 3 rows

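The output columns of OneHotEncoderEstimator() are not the same as sklearn’s output. Instead of binary columns, it returns sparse vectors written as (size, [indices], [values]), as shown in the above code snippet. Because the encoder drops the last category by default, a two-category feature such as gender needs only one slot: (1,[0],[1.0]) for index 0.0 and the all-zero (1,[],[]) for index 1.0.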
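As a plain-Python illustration of the sparse notation in the output above (the helper name to_dense is hypothetical and not part of Spark), a (size, indices, values) triple can be expanded into an ordinary list like this:

```python
def to_dense(size, indices, values):
    """Expand a (size, indices, values) sparse triple into a dense list."""
    dense = [0.0] * size
    for i, v in zip(indices, values):
        dense[i] = v
    return dense

# genderVector of user 100010 in the table above: no non-zero entries
first_user = to_dense(1, [], [])
# genderVector of user 200002 in the table above: a single 1.0 at position 0
second_user = to_dense(1, [0], [1.0])
print(first_user, second_user)
```

The same reading applies to longer vectors: only the listed positions hold non-zero values, and every other position is zero.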

3) General activity aggregates: Based on the columns for sessionId, song, artist, length, and registration, we generate aggregated features including:

  • numSessions (number of sessions a user had during this period)
  • numSongs (number of different songs a user listened to)
  • numArtists (number of different artists a user listened to)
  • playTime (total time of playing songs measured in seconds)
  • activeDays (number of days since a user registered)
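To make the five definitions above concrete, here is a minimal plain-Python sketch on a few made-up log rows. It is not the Spark code used in the project: the toy rows, the aggregate helper, the assumption that ts and registration are in milliseconds, and the choice to measure activeDays up to a user's last event are all illustrative assumptions.

```python
from collections import defaultdict

MS_PER_DAY = 1000 * 60 * 60 * 24  # assumes timestamps are in milliseconds

# Hypothetical rows: (userId, sessionId, song, artist, length, ts, registration)
events = [
    ("u1", 1, "Song A", "Artist X", 200.0, 5 * MS_PER_DAY, 0),
    ("u1", 1, "Song B", "Artist Y", 180.0, 6 * MS_PER_DAY, 0),
    ("u1", 2, "Song A", "Artist X", 200.0, 10 * MS_PER_DAY, 0),
    ("u2", 3, None, None, None, 4 * MS_PER_DAY, 2 * MS_PER_DAY),  # non-song event
]

def aggregate(events):
    """Collapse event rows into one feature dictionary per user."""
    acc = defaultdict(lambda: {"sessions": set(), "songs": set(), "artists": set(),
                               "playTime": 0.0, "lastTs": 0, "registration": 0})
    for user, session, song, artist, length, ts, registration in events:
        a = acc[user]
        a["sessions"].add(session)
        if song is not None:  # only song plays carry song, artist, and length
            a["songs"].add(song)
            a["artists"].add(artist)
            a["playTime"] += length
        a["lastTs"] = max(a["lastTs"], ts)
        a["registration"] = registration
    return {
        user: {
            "numSessions": len(a["sessions"]),
            "numSongs": len(a["songs"]),
            "numArtists": len(a["artists"]),
            "playTime": a["playTime"],
            "activeDays": (a["lastTs"] - a["registration"]) / MS_PER_DAY,
        }
        for user, a in acc.items()
    }

features = aggregate(events)
print(features["u1"])
```

In Spark, the same result would typically come from a groupBy on userId with distinct counts and sums, which scales to the full data set.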

4) Page visits aggregates: Based on the page column, we generate aggregated page visit behavior features that count how many times a user visited each type of page during the period.

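from pyspark.sql import Window
from pyspark.sql.functions import count

w2 = Window.partitionBy('userId', 'page')
# All distinct page names, used as the pivot columns
columns = [str(row.page) for row in df.select('page')
           .dropDuplicates().sort('page').collect()]
df_pageVisits = df.select('userId', 'page') \
    .withColumn('pageVisits', count('userId').over(w2)) \
    .groupby('userId') \
    .pivot('page', columns) \
    .mean('pageVisits')
# Spark's drop() takes column names as separate arguments (no axis parameter)
df_pageVisits = df_pageVisits.na.fill(0).drop('Cancel', 'Cancellation Confirmation')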

5) Check for multicollinearity: Tree-based models are not affected by multicollinearity, but since we are also testing linear models (logistic regression and linear SVM), we’ll go ahead and remove highly correlated features.

Correlation heatmap for all features
Correlation heatmap after removing highly correlated features
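The article does not show how the correlated features were removed, so the following is only a sketch of one common approach in plain Python: compute pairwise Pearson correlations and keep a feature only if it is not too strongly correlated with any feature already kept. The 0.9 threshold, the toy columns, and the helper names are assumptions for illustration.

```python
from math import sqrt

def pearson(x, y):
    """Pearson correlation coefficient of two equal-length numeric lists."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = sqrt(sum((a - mx) ** 2 for a in x))
    sy = sqrt(sum((b - my) ** 2 for b in y))
    return cov / (sx * sy)

def drop_correlated(columns, threshold=0.9):
    """Keep a feature only if |r| with every already-kept feature is <= threshold."""
    kept = []
    for name in columns:
        if all(abs(pearson(columns[name], columns[k])) <= threshold for k in kept):
            kept.append(name)
    return kept

# Hypothetical per-user feature columns
toy = {
    "numSongs": [10, 20, 30, 40, 50],
    "playTime": [2100, 3900, 6200, 8100, 9900],  # nearly proportional to numSongs
    "numSessions": [3, 1, 4, 1, 5],
}
print(drop_correlated(toy))
```

Which feature of a correlated pair survives depends on the column order, so it is worth ordering the columns so that the more interpretable feature comes first.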

6) Vector Assembling and Feature Scaling: In Spark, machine learning models require the features to be in a single vector-type column. The VectorAssembler() transformer combines all the feature columns into one vector, as shown in the following code snippet.

# Vector Assembler
from pyspark.ml.feature import VectorAssembler

cols = df_inuse.drop('userId', 'churned_user').columns
assembler = VectorAssembler(inputCols=cols, outputCol='feature_vector')
df_inuse = assembler.transform(df_inuse).select('userId', 'churned_user', 'feature_vector')
df_inuse.take(1)

[Row(userId='100010', churned_user=0, feature_vector=SparseVector(13, {0: 1.0, 2: 52.0, 6: 2.0, 7: 7.0, 8: 11.4259, 9: 1.0, 12: 1.0}))]

After scaling, and with the churned_user column renamed to label, the data looks like this:

df_inuse_scaled.take(1)

[Row(userId='100010', label=0, feature_vector=SparseVector(13, {0: 1.0, 2: 52.0, 6: 2.0, 7: 7.0, 8: 11.4259, 9: 1.0, 12: 1.0}), features=SparseVector(13, {0: 0.3205, 2: 2.413, 6: 0.7817, 7: 0.4779, 8: 0.3488, 9: 2.0013, 12: 2.4356}))]
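The scaling code itself is not shown in the article. The output above (zeros stay zero, and some scaled values exceed 1) would be consistent with Spark's StandardScaler at its default settings, which divides each feature by its standard deviation without centering, but that is an assumption. Under that assumption, the per-column arithmetic can be sketched in plain Python as follows (scale_by_std and the toy values are hypothetical):

```python
from statistics import stdev

def scale_by_std(column):
    """Divide each value by the column's sample standard deviation (no centering)."""
    s = stdev(column)
    return [v / s if s else 0.0 for v in column]

# Hypothetical values of one feature across five users
num_sessions = [7.0, 24.0, 1.0, 15.0, 3.0]
scaled = scale_by_std(num_sessions)
print([round(v, 4) for v in scaled])
```

Skipping the centering step keeps zero entries at zero, so the feature vectors can stay sparse.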

7) Split data into training and testing sets:

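ratio = 0.8
# Stratified sample: keep the same fraction of each class in the training set.
# The class column is named 'label' in the scaled dataframe (see the output above).
train = df_inuse_scaled.sampleBy('label', fractions={0: ratio, 1: ratio}, seed=42)
test = df_inuse_scaled.subtract(train)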
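To show what a per-class sampling fraction means, here is a small plain-Python sketch of the idea behind the split above: each row is kept for training independently, with the probability assigned to its class. The sample_by helper and the toy rows are hypothetical; only the class counts (173 and 52) come from the article.

```python
import random

def sample_by(rows, label_of, fractions, seed=42):
    """Keep each row for training with the probability assigned to its label."""
    rng = random.Random(seed)
    train, test = [], []
    for row in rows:
        target = train if rng.random() < fractions[label_of(row)] else test
        target.append(row)
    return train, test

# Toy population mirroring the article's class counts: 173 retained, 52 churned
rows = [(f"user{i}", 0) for i in range(173)] + [(f"user{i + 173}", 1) for i in range(52)]
train_rows, test_rows = sample_by(rows, label_of=lambda r: r[1], fractions={0: 0.8, 1: 0.8})
print(len(train_rows), len(test_rows))
```

Because every row is drawn independently, the resulting split is only approximately 80/20 within each class, which matters on a data set with just 52 churned users.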

Model Selection

We will compare five baseline models: Logistic Regression, Linear SVM Classifier, Decision Tree, Random Forests, and Gradient Boosted Tree Classifier.

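# initiate the models
from pyspark.ml.classification import (LogisticRegression, LinearSVC,
    DecisionTreeClassifier, RandomForestClassifier, GBTClassifier)

lr = LogisticRegression()
svc = LinearSVC()
dtc = DecisionTreeClassifier()
rfc = RandomForestClassifier()
gbt = GBTClassifier()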

The ParamGridBuilder() class can be used to construct a grid of hyper-parameters to search over. However, since the purpose here is to show Spark’s ML methods, we will not tune the models in depth.

# this line will keep the default hyper-parameters of a model
paramGrid = ParamGridBuilder().build()
# to search over more parameters, we can use the addGrid() method, for example:
paramGrid = ParamGridBuilder() \
    .addGrid(lr.regParam, [0.1, 0.01]) \
    .addGrid(lr.fitIntercept, [False, True]) \
    .addGrid(lr.elasticNetParam, [0.0, 0.5, 1.0]) \
    .build()
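A parameter grid is the Cartesian product of the listed values, so its size grows quickly. The plain-Python sketch below enumerates the same three value lists as the example above (as an illustration only, not through Spark's API) to show how many candidate settings such a grid contains:

```python
from itertools import product

# Same value lists as in the addGrid() example
grid = {
    "regParam": [0.1, 0.01],
    "fitIntercept": [False, True],
    "elasticNetParam": [0.0, 0.5, 1.0],
}
combinations = [dict(zip(grid, values)) for values in product(*grid.values())]
print(len(combinations))  # 2 * 2 * 3 = 12
```

With 3-fold cross validation, each of these 12 settings is fitted once per fold, so the search trains 36 models before the best setting is refit on the full training set.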

We’ll define an evaluation function to run through all five classification models and output their cross validation average metrics (f1).

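from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

def evaluate(model_name, train, test):
    # model_name is an estimator instance such as LogisticRegression()
    evaluator = MulticlassClassificationEvaluator(metricName='f1')
    paramGrid = ParamGridBuilder().build()
    crossval = CrossValidator(estimator=model_name,
                              evaluator=evaluator,
                              estimatorParamMaps=paramGrid,
                              numFolds=3)
    cvModel = crossval.fit(train)
    # average f1 across the folds, one value per parameter setting
    cvModel_metrics = cvModel.avgMetrics
    transformed_data = cvModel.transform(test)
    test_metrics = evaluator.evaluate(transformed_data)
    return (cvModel_metrics, test_metrics)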
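The f1 metric of MulticlassClassificationEvaluator is a weighted F1: the F1 of each class, weighted by that class's share of the rows. The plain-Python sketch below works through the arithmetic for a trivial model that always predicts "no churn", using the article's class counts (173 retained, 52 churned). The helper names are hypothetical, and the numbers describe the full 225-user set, not the author's exact train/test split.

```python
def f1_for_class(y_true, y_pred, cls):
    """F1 score treating `cls` as the positive class."""
    tp = sum(t == cls and p == cls for t, p in zip(y_true, y_pred))
    fp = sum(t != cls and p == cls for t, p in zip(y_true, y_pred))
    fn = sum(t == cls and p != cls for t, p in zip(y_true, y_pred))
    if tp == 0:
        return 0.0
    precision, recall = tp / (tp + fp), tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)

def weighted_f1(y_true, y_pred):
    """Per-class F1 averaged with weights equal to each class's frequency."""
    n = len(y_true)
    return sum(f1_for_class(y_true, y_pred, c) * y_true.count(c) / n
               for c in set(y_true))

y_true = [0] * 173 + [1] * 52
y_pred = [0] * 225  # always predict "no churn"

accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
churn_f1 = f1_for_class(y_true, y_pred, 1)
overall_f1 = weighted_f1(y_true, y_pred)
print(round(accuracy, 3), churn_f1, round(overall_f1, 3))
```

This trivial model reaches an accuracy of about 0.77 while its F1 on the churn class is 0, which is why accuracy alone is misleading here. Its weighted F1 is about 0.67, which gives a rough reference point when reading weighted f1 scores on this data.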

Finally, the performance of the five baseline models is shown in the following code snippet. As we can see, the f1 scores for all of the models are unsatisfactory. We certainly need finer tuning to search for optimized hyper-parameters for these models!

However, if we were to choose from these baseline models, the cross-validation f1 score should be the criterion. In this case, the LinearSVC model would be the model of choice. (Note that its test score, 0.63, is lower than its cross-validation score, 0.68, which suggests some over-fitting.)

model_names = [lr, svc, dtc, rfc, gbt]
for model in model_names:
    a = evaluate(model, train, test)
    print(model, a)

LogisticRegression ([0.6705811320138374], 0.6320191158900836)
LinearSVC ([0.6765153189823112], 0.6320191158900836)
DecisionTreeClassifier ([0.6382104034150818], 0.684376432033105)
RandomForestClassifier ([0.666026954511646], 0.6682863679086347)
GBTClassifier ([0.6525712756381464], 0.6576482830385015)

Deploy on Cloud (AWS)

To run the model on the full 12GB data on AWS, we’ll use basically the same code, except that the plotting portion done with Pandas is removed. One point is worth noting, though: following the Spark course’s instructions on configuring the cluster in the Nanodegree’s extracurricular material, I was able to run the code; however, the session went inactive after a while. This is likely due to insufficient Spark driver memory. Therefore, we need to use the advanced options when configuring the cluster and increase the driver memory.

Step 1: In EMR console, create new cluster
Step 2: Go to advanced options
Step 3: In the box, enter the desired configuration
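The article does not show the configuration that was entered, so the following is only a hypothetical illustration of its general shape. EMR accepts software settings as a JSON list of classifications, and Spark properties such as spark.driver.memory go under the spark-defaults classification. The 8g value is an assumed placeholder, not the author's setting. The short Python sketch below builds such a list and prints the JSON that could be pasted into the box:

```python
import json

# Hypothetical value: choose a size that fits the master node's instance type.
driver_memory = "8g"

emr_configuration = [
    {
        "Classification": "spark-defaults",
        "Properties": {"spark.driver.memory": driver_memory},
    }
]
print(json.dumps(emr_configuration, indent=2))
```

The right amount of driver memory depends on the instance type and on how much data is collected to the driver, so it is worth treating the value as something to tune.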

Conclusion

This project provides an excellent opportunity to learn to manipulate large datasets with Spark and AWS, which are among the most in-demand skills in the field of data science.

Topic-wise, predicting customer churn is a challenging and common problem that data scientists and analysts face in any customer-facing business. The analysis and modeling completed here are meant to highlight the process of a machine learning project using Spark. There is certainly much room to improve the model performance:

1) The aggregated behavior features are simply summations and averages. Weighted averages could be used to emphasize more recent behaviors. Diversity measurements could also be included. Finer hyper-parameter tuning is also needed.

2) Since the dataset is longitudinal, we could potentially use survival models or time series models. Here are some articles on these strategies for modeling customer churn.
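As one possible reading of the recency-weighted idea in point 1, the plain-Python sketch below averages a user's daily activity with weights that decay exponentially with age. The function name, the 7-day half-life, and the toy counts are all assumptions for illustration.

```python
def recency_weighted_mean(values, ages_in_days, half_life_days=7.0):
    """Average values with weights that halve every `half_life_days` of age."""
    weights = [0.5 ** (age / half_life_days) for age in ages_in_days]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)

# Hypothetical daily song counts for one user, oldest first
songs_per_day = [40, 35, 10, 5]
ages = [21, 14, 7, 0]  # days before the end of the observation window

plain_mean = sum(songs_per_day) / len(songs_per_day)
recent_mean = recency_weighted_mean(songs_per_day, ages)
print(plain_mean, recent_mean)
```

For a user whose activity is falling off, the weighted mean sits well below the plain mean, which is the kind of signal that could help separate users who are about to churn.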
