Seeing the random forest from the decision trees: An explanation of Random Forest

Michelle Jane Tat
Towards Data Science
10 min read · Apr 16, 2017


So many bagged decision trees, so little processing resources.

Over this past weekend, I got a little bored and decided to brush up on my R a bit. I have been programming almost exclusively in Python as a fellow at Insight, but I actually had not done much predictive analytics in R, save for pretty vanilla linear regressions. I wanted a somewhat clean data source where I could play around with modeling a bit in R, and a good source of clean data is good ole’ kaggle. I decided to work on a Video Game Sales data set.

Decisions, decisions…

Decision trees have many cousins, such as bagged decision trees, random forests, and gradient boosted decision trees, which are commonly referred to as ensemble methods.

To understand the more complicated ensemble methods, I think it’s good to first understand the most common building blocks: decision trees and random forests. Let’s take the simplest example: regression using a decision tree. For a given data set with n dimensions, you can grow a ‘decision tree’ with branches and leaves. The goal of a decision tree is to determine the branch splits that reduce the residual sum of squares the most, and to provide leaves that are as predictive as possible. Perhaps a figure will help…

Taken from the book An Introduction to Statistical Learning

The figure above represents a baseball-related data set, where we want to predict the log salary of a player. On the left, if a player has fewer than 4.5 years of experience, their predicted log salary is 5.11 (salary is measured in thousands of dollars and then log-transformed). If a player has more than 4.5 years of experience but fewer than 117.5 hits, their predicted log salary is 6.00. In the figure on the right, those predicted values correspond to the subspaces R1, R2, and R3, respectively.

For those who like to see the math: decision trees essentially grow branches that reduce the sum of squared errors across the Rj sub-spaces (regions) of the data.
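
Concretely (in the notation of An Introduction to Statistical Learning), the tree chooses regions R_1, …, R_J that minimize the residual sum of squares

\sum_{j=1}^{J} \sum_{i \in R_j} (y_i - \hat{y}_{R_j})^2

where \hat{y}_{R_j} is the mean response of the training observations that fall into region R_j.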

The above example uses continuous data, but we can extend this to classification. In a classification setting, we are essentially growing branches that reduce classification error, although it’s not as straightforward as that. In the classification setting, we take an entropy-like measure, and try to reduce the amount of entropy at each branch to provide the best branch split. The Gini Index is a commonly used metric.
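
Written out (again in ISL’s notation), the Gini index for the mth region is

G_m = \sum_{k=1}^{K} \hat{p}_{mk} (1 - \hat{p}_{mk})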

Here, p̂_mk represents the proportion of observations in the mth region that belong to the kth class. In essence, the Gini index is a measure of variance across the classes: the higher the variance, the more misclassification there is, so lower values of the Gini index yield better splits.

Bagging those predictions…

Decision trees are commonly referred to as being “greedy”. This is simply a function of how the algorithm tries to determine the best way to reduce error at each split. Unfortunately, this greediness tends to lead to over-fitting and poor generalization.

One method used to combat this is called bootstrap aggregation, or ‘bagging’ for short. If you understand the idea of bootstrapping in statistics (in terms of estimating the variance and error of an unknown population), then bagging will feel similar when applied to decision trees.

In bagging, we decide how many bootstrap samples we want to draw from our data set, fit a decision tree to each bootstrap sample, and then aggregate the predictions of all those trees. This gives us a more robust result that is less prone to overfitting.

Further, each bootstrap sample typically leaves out about one third of the observations. We can then run each tree on the observations it never saw and obtain out-of-bag (OOB) error rates. This is essentially a decision tree’s built-in version of cross-validation, although you could still perform cross-validation on top of the out-of-bag error rates!
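
As a rough sketch (not the code used later in this post), you can mimic bagging with the randomForest package by letting every split consider all of the predictors; the fitted object then reports the out-of-bag error directly. The iris data set here is just a stand-in for illustration.

# A minimal sketch of bagging: setting mtry to the total number of predictors
# makes each tree consider every feature at every split, which is bagging.
library(randomForest)
data(iris)                          # stand-in data set for illustration
p <- ncol(iris) - 1                 # number of predictors
bag_fit <- randomForest(Species ~ ., data = iris, mtry = p, ntree = 200)
# The printed summary reports the out-of-bag (OOB) error estimate,
# computed from the ~1/3 of rows each tree never saw
print(bag_fit)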

Enter Random Forest

Now that we have a general understanding of decision trees and bagging, the concept of random forest is relatively straightforward. A vanilla random forest is a set of bagged decision trees in which, at each split, the algorithm considers only a random sample of m predictors. This decorrelates the trees in the forest and is useful for automatically combating multicollinearity.

In classification, all trees are aggregated back together. From this aggregation, the model essentially takes a poll / vote to assign data to a category.

For a given observation, we record the class that each bagged tree predicts for it. We then look across all the trees and count how many times each class was predicted. The observation is assigned to the class predicted by the majority of the bagged trees.
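
As a toy illustration (the per-tree predictions below are made up, not output from any model in this post), the majority vote boils down to taking the most common class across each row of tree predictions:

# Toy illustration of the majority vote: rows = observations, columns = trees
tree_preds <- matrix(c("Sports", "Sports", "Action",
                       "RPG",    "RPG",    "RPG"),
                     nrow = 2, byrow = TRUE)
# For each observation, pick the class predicted by the most trees
majority_vote <- apply(tree_preds, 1, function(votes) names(which.max(table(votes))))
majority_vote   # "Sports" "RPG"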

A random forest grows many trees on bootstrapped samples of the data, considering a random subset of features at each split. The trees are aggregated together at the end.

A dangerously brief example applying Random Forest in R, using the Video Game Sales kaggle data set

An overview of the dataset can be found here.

All my terribly messy code can be found on my github.

The goal for this example was to see if sales numbers and the console a game was released on could predict its genre (e.g., sports, action, RPG, strategy, etc.).

In this example, I make use of caret and ggplot2. I use the package dummies to generate dummy variables for categorical predictors.

I wanted to get some practice using caret, which is essentially R’s version of scikit-learn. But first, as with any data set, it’s worth exploring it a little bit. My general approach is to look for quirkiness in the data first, explore potential correlations, then dig a bit deeper to see if there are any other trends worth noting in the data. Ideally, you will want to examine the data in every which way before modeling it. For brevity, I skipped some of the data exploration and jumped towards some modeling.

First, I inspected the data for missing values. There were a ton of NAs, so I went ahead and did K-Nearest Neighbor imputation using the DMwR package.
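
The imputation step looks roughly like this (a sketch: I’m assuming the raw kaggle table is loaded into a data frame called vg and that the missing values sit in the numeric score and count columns):

# A sketch of the KNN imputation step (vg is assumed to hold the raw kaggle data)
library(DMwR)
colSums(is.na(vg))                              # how many NAs does each column have?
num_cols <- sapply(vg, is.numeric)              # impute only the numeric columns
vg[num_cols] <- knnImputation(vg[num_cols], k = 10)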

Next, I wanted to generally inspect the sales data to find if there were any outliers. There were. And the distribution was highly skewed.

Most sales were far less than $20 million total.

I went ahead and normalized them using a log transform.
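
The transform itself is short (a sketch: the sales column names are assumptions based on the kaggle file):

# A sketch of the log transform (column names assumed from the kaggle file)
sales_cols <- c("NA_Sales", "EU_Sales", "JP_Sales", "Other_Sales", "Global_Sales")
vg[sales_cols] <- log(vg[sales_cols] + 1)   # +1 avoids log(0) for games with no recorded sales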

Normalized-ish, but still sorta skewed.

From here, I generated dummy variables for the different consoles each game was released on, and then examined the correlations. Global sales, not surprisingly, were correlated with all the other sales variables. Critic scores and counts were not. Not pictured here are the correlations by console; there was nothing of note there, given the sparsity of the console dummy data. If multicollinearity were a huge concern, one could simply remove the Global Sales variable and keep the other sales variables.
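
Here is roughly what that step looks like (a sketch: the column names are assumptions based on the kaggle file, dummy.data.frame comes from the dummies package, and I’m assuming the score columns are numeric after imputation):

# A sketch of the dummy-coding and correlation check
library(dummies)
vg <- dummy.data.frame(vg, names = "Platform")   # one 0/1 column per console
# Correlations among the sales and review columns
round(cor(vg[, c("NA_Sales", "EU_Sales", "JP_Sales", "Other_Sales",
                 "Global_Sales", "Critic_Score", "Critic_Count")]), 2)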

In caret, I did an 80%-20% train-test split, as is common practice when modeling (a sketch of the split follows the list of genres below). I relabeled all the genres as numbers, as follows:

  1. Sports
  2. Platformer
  3. Racing
  4. RPG
  5. Puzzle
  6. Miscellaneous
  7. Shooter
  8. Simulation
  9. Action
  10. Fighting
  11. Adventure
  12. Strategy

I did some grid searching on the number of features available at each tree split. Recall that a random forest doesn’t consider all available features when it creates a split at each node in the tree. This number of candidate features is a tunable hyperparameter of the model.

# Start with the square root of the number of columns as the number of
# features available at each split (the usual default for classification)
mtry <- sqrt(ncol(vg))
tunegrid <- expand.grid(.mtry = mtry)

In the code snippet above, I took the square root of the number of columns as the initial number of features available at each split. The search then expands on that starting point: caret fits the model at the initial mtry value, tries additional values of mtry in subsequent fit iterations, and assesses the model each time.
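
For instance (an illustrative grid, not the exact values used here, and assuming the cleaned data frame vg from earlier), searching over several candidate values of mtry explicitly would look like:

# An explicit grid over several candidate values of mtry (illustrative values)
tunegrid <- expand.grid(.mtry = c(floor(sqrt(ncol(vg))), 10, 15, 21))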

metric <- 'Accuracy'
control <- trainControl(method = 'repeatedcv', number = 10, repeats = 2, search = 'random', savePredictions = TRUE)

Next, I set my metric to accuracy, since this is a classification problem. I use cross validation to evaluate whether my training data is wonky in any way; using 5–10 folds (the number argument) is typical. I use a random search because it’s a bit quicker and less computationally intensive.

Using caret, I trained two models. One with 15 bagged trees. Another with 500 bagged trees. The 500 tree model took some time to run (maybe about 30 minutes?). One could easily incorporate the number of bagged trees in a grid search. For brevity (and time), I just compared two models.

Note that I allowed caret to use a Box Cox transformation to determine how to normalize the data appropriately (in this case it effectively log-transformed the data).

# Model 1: 15 trees
model_train1 <- train(Genre ~ ., data = vg_train, method = 'rf', trControl = control,
                      tuneGrid = tunegrid, metric = metric, ntree = 15, preProcess = c('BoxCox'))
# Model 2: 500 trees
model_train2 <- train(Genre ~ ., data = vg_train, method = 'rf', trControl = control,
                      tuneGrid = tunegrid, metric = metric, ntree = 500, preProcess = c('BoxCox'))

The results from my cross validation show that the 500 tree model did a tiny bit better…but only a tiny bit. 21 features per split seems appropriate given the cross validation results.

Model 1 with 15 bagged trees.
Model 2 with 500 bagged trees

My accuracy is utterly terrible, however: the overall accuracy of Model 2 is only 34.4%.

Random forests also let us look at feature importances, which reflect how much the Gini Index decreases across the splits that use a given feature. The more a feature decreases the Gini Index, the more important it is. The figure below rates the features from 0–100, with 100 being the most important.

It seems user count, and critic count are particularly important. However, given how poor the model fit is, I’m not sure how entirely useful interpreting any of these variables is. I’ve included a snippet of the variable importance code in case you want to replicate this.

# Save the variable importance values from our model object generated from caret.
x<-varImp(model_train2, scale = TRUE)
# Get the row names of the variable importance data
rownames(x$importance)
# Convert the variable importance data into a dataframe
importance <- data.frame(rownames(x$importance), x$importance$Overall)
# Relabel the data
names(importance)<-c('Platform', 'Importance')
# Order the data from greatest importance to least important
importance <- transform(importance, Platform = reorder(Platform, Importance))
# Plot the data with ggplot.
ggplot(data=importance, aes(x=Platform, y=Importance)) +
geom_bar(stat = 'identity', colour = "blue", fill = "white") + coord_flip()

We can look at a confusion matrix to see how much correct classification and misclassification there was. The diagonal indicates the % of correct classifications for each genre; the off-diagonal cells indicate the % of times the model misclassified one genre as another.
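
A sketch of how to get a confusion matrix with caret on the held-out test set (assuming the model_train2 object from above and the vg_test split from earlier, with Genre stored as a factor):

# A sketch of the confusion matrix on the 20% test set
test_preds <- predict(model_train2, newdata = vg_test)
confusionMatrix(test_preds, vg_test$Genre)   # counts plus per-class statistics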

It’s pretty awful. Shooters were classified correctly 68% of the time, but they were also misclassified as Strategy games a large percentage of the time.

Takeaways

Why did our model do so poorly? There are several possible reasons. The model appears to be underfitting the data. That could mean random forest is not complex enough to capture the trends in the data, and we might have to use a more complex model. However, the more likely culprit is that the features are simply not predictive of video game genre.

And if our features are sort of crappy, we can do a couple of things. First, we can engineer additional features from the data set we already have. For example, we might create a variable that captures the average critic score of each genre as a predictor (though that might be uninteresting). A second easy option is to take the aggregate sum of sales for each genre and apply it across the entire data set.

What we more likely have to do, though, is scrape additional information from a video game repository that has more historical sales data for each genre of video game. So many options!

Or maybe the answer is even simpler. It could be that the data are imbalanced in terms of classes. If this were the case (and it is, if you examine the data further), you may want to prune back or combine genres to rectify this.

Conclusions

Random forest is a commonly used machine learning model and is often described as a good off-the-shelf choice. In many cases it outperforms its parametric equivalents, and it is less computationally intensive to boot. Of course, as with any model, make sure you know why you are choosing it, such as a random forest (hint: maybe you don’t know the distribution of your data, maybe your data is very high dimensional, maybe you have lots of collinearity, maybe you want a model that is relatively easy to interpret). Don’t go about choosing a model willy-nilly like I did here. :-)

Resources

I don’t really delve deep into the mechanics of random forest here. If you want to take a deep dive, I highly recommend two books: An Introduction to Statistical Learning and The Elements of Statistical Learning.

The latter is considered the machine learning bible to some!


