
Decision Trees and Random Forests – explained with Python implementation.

In this article, I will walk you through the basics of how Decision Tree and Random Forest algorithms work.

I will also show how they are implemented in Python, with the help of an example.

Photo Credits – Filip Cernak on Unsplash

A Decision Tree is a Supervised Machine Learning algorithm that imitates the human thinking process. It makes predictions much like a human mind would in real life. It can be thought of as a series of if-then-else statements, making a decision or prediction at every point as it grows.

A decision tree looks like a flowchart or an inverted tree: it grows from root to leaf, but upside down. Thanks to this structure, we can easily interpret the decision-making/prediction process of a decision tree.

A typical binary split Decision Tree

The starting node is called the Root Node. It splits further, based on some decision criterion, into what are called Internal Nodes. Internal nodes that split further are called parent nodes, and the nodes formed from them are called child nodes. Nodes that do not split further are called Leaf Nodes.

A Decision Tree can be a Classification Tree or a Regression Tree, depending on the type of target variable. In a classification tree, the predicted class is the majority class in the leaf node; in a regression tree, the final predicted value is the average of the values in the leaf node.

There are various Decision Tree algorithms, depending on whether the split is binary or multiway. In a binary split, a parent node splits into two child nodes; in a multiway split, the parent node splits into more than two child nodes.

CART, the Classification and Regression Tree algorithm, builds binary split trees, while the ID3 algorithm builds multiway split trees. We will discuss the CART algorithm in detail. (Hereafter, "Decision Tree" refers to a CART tree.)

A Decision Tree divides the data into subsets by splitting on a chosen attribute. This attribute is chosen based on a homogeneity criterion called the Gini Index.

The Gini Index measures the purity (or, equivalently, the impurity) of the nodes after a split happens, i.e., how pure the child nodes are compared to the parent node.

The tree should split on the attribute that leads to maximum homogeneity of the child nodes. The formula for the Gini Index is:

Gini = 1 − Σᵢ pᵢ²

where pᵢ is the probability of finding a point with label i, and n is the number of classes/possible values of the target variable (the sum runs over i = 1, …, n).

The tree first calculates the Gini index at the parent node. Then, for every candidate attribute, it calculates the weighted Gini index of the resulting child nodes. The attribute that gives maximum homogeneity (the lowest weighted impurity) in the child nodes is chosen and the split is performed. This process is repeated until the attributes are exhausted, and at the end we get what we call a fully grown Decision Tree. The attributes that lead to the initial splits can therefore be considered the most important attributes for prediction.
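To make this concrete, here is a minimal sketch (not the article's original code) of how the Gini impurity of a node and the weighted Gini of a candidate binary split can be computed with NumPy:

```python
import numpy as np

def gini_impurity(labels):
    """Gini impurity of a node: 1 minus the sum of squared class proportions."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def weighted_gini(left_labels, right_labels):
    """Weighted Gini impurity of the two child nodes produced by a binary split."""
    n_left, n_right = len(left_labels), len(right_labels)
    n_total = n_left + n_right
    return (n_left / n_total) * gini_impurity(left_labels) + \
           (n_right / n_total) * gini_impurity(right_labels)

# A split that separates the classes perfectly drives the weighted Gini to zero.
parent = np.array([0, 0, 0, 1, 1, 1])
print(gini_impurity(parent))                  # 0.5 -> maximally impure for two classes
print(weighted_gini(parent[:3], parent[3:]))  # 0.0 -> pure child nodes
```

The CART algorithm evaluates such candidate splits for every attribute and picks the one with the lowest weighted impurity.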

The problem with a fully grown Decision Tree is that it overfits: it memorizes the training data and fails to generalize, leading to poor performance on test data. To solve this problem and get the most out of a decision tree, we need to prune the fully grown tree or truncate its growth.

Pruning is a technique where we cut back the less important branches of a fully grown tree, while truncation is a technique where we control the growth of the tree in order to avoid overfitting.

Truncation is more commonly used and is done by tuning the hyperparameters of the tree, which will be discussed in the example.
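As a hedged illustration of truncation (the hyperparameter values below are placeholders, not tuned for the churn data used later), this is how the growth of a scikit-learn decision tree can be constrained:

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Toy data, only to make the sketch runnable end to end.
X, y = make_classification(n_samples=1000, n_features=20, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

truncated_tree = DecisionTreeClassifier(
    criterion="gini",        # the homogeneity measure discussed above
    max_depth=5,             # limit how deep the tree may grow
    min_samples_split=50,    # a node needs at least this many samples to split
    min_samples_leaf=20,     # every leaf must keep at least this many samples
    random_state=42,
)
truncated_tree.fit(X_train, y_train)
print("Train accuracy:", truncated_tree.score(X_train, y_train))
print("Test accuracy :", truncated_tree.score(X_test, y_test))
```

Pruning a fully grown tree is also possible in scikit-learn (for example, cost-complexity pruning via the ccp_alpha parameter), but in this article we stick to truncation.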

Overall, Decision Trees are efficient algorithms that require zero or minimal data preprocessing. They can handle linear and non-linear data, as well as categorical or numerical data, and make predictions based on the given set of attributes. Most importantly, they are easily interpretable.

Decision Trees also have shortcomings: they overfit, and a small change in the data may change the entire tree structure, which makes them high-variance algorithms. Gini Index calculations make them complex, and they consume a lot of time and memory. They are also greedy algorithms, since they consider only the immediate implications of a split and not the implications of further splits.

Random Forests

A combination of various models (linear regression, logistic regression, decision tree, etc.) brought together as a single model in order to achieve the final outcome is called an Ensemble. In an ensemble, the different models are viewed as one, not separately.

A Random Forest is a powerful ensemble model built from a large number of Decision Trees. It overcomes the shortcomings of a single decision tree and has some additional advantages. That does not mean it is always better than a decision tree, though; there can be cases where a single decision tree performs better than a random forest.

An ensemble works in such a way that every model helps compensate for the shortcomings of the others. If a random forest is built with, say, 10 decision trees, not every tree may perform well on the data, but the stronger trees fill the gaps for the weaker ones. This is what makes an ensemble a powerful Machine Learning model.

The individual trees in a random forest must satisfy two criteria:

  1. Diversity: every model in the ensemble should function independently and complement the other models.
  2. Acceptability: every model in the ensemble should perform at least better than a random guesser.

Bagging, short for Bootstrap Aggregation, is one of the common ensemble techniques used to achieve the diversity criterion. Bootstrapping is the method used to create bootstrapped samples: the given data is sampled uniformly and with replacement. A bootstrapped sample of the same size as the original dataset contains roughly 63% of the unique data points; the remaining draws are repeats. Bootstrapped samples of equal length are created and given as input to the individual decision trees in the random forest, and the results of all the individual models are then aggregated to arrive at a final decision.
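Here is a minimal NumPy sketch of bootstrapping (illustrative only; scikit-learn's RandomForestClassifier does this internally when bootstrap=True):

```python
import numpy as np

rng = np.random.default_rng(42)
n = 10_000
row_indices = np.arange(n)

# A bootstrapped sample: same size as the data, drawn uniformly with replacement.
bootstrap_sample = rng.choice(row_indices, size=n, replace=True)

unique_fraction = np.unique(bootstrap_sample).size / n
print(f"Unique rows in the bootstrap sample: {unique_fraction:.1%}")  # roughly 63%

# The rows that were never drawn form the out-of-bag (OOB) set for this tree.
oob_rows = np.setdiff1d(row_indices, bootstrap_sample)
print(f"Out-of-bag rows: {oob_rows.size / n:.1%}")  # roughly 37%
```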

A typical Random Forest Classifier

Also, not all features are used to train every tree: a fixed-size random subset of features is considered at every node of every tree. This reduces the correlation between the trees and helps them function independently.

In a Random Forest Classifier, the majority class predicted by the individual trees is taken as the final prediction, while in a Random Forest Regressor, the average of the individual predicted values is the final prediction.

With a Random Forest model, the data often does not need a separate validation split. Because each tree is trained on a bootstrapped sample, a set of data points is always left out of that tree's training data; these are called the Out Of Bag (OOB) samples. Thus, for every tree in the ensemble there exists a set of data points that was not present in its training data and that can be used as a validation set to evaluate its performance.

The OOB error is calculated as the ratio of the number of incorrect OOB predictions to the total number of OOB predictions.
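Putting these pieces together, here is a hedged sketch (toy data, illustrative parameter values) of a random forest with an out-of-bag estimate in scikit-learn:

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

X, y = make_classification(n_samples=2000, n_features=20, random_state=42)

forest = RandomForestClassifier(
    n_estimators=200,      # number of trees in the ensemble
    max_features="sqrt",   # random subset of features considered at each split
    bootstrap=True,        # train every tree on its own bootstrapped sample
    oob_score=True,        # evaluate every tree on its out-of-bag rows
    random_state=42,
)
forest.fit(X, y)

print("OOB accuracy:", forest.oob_score_)   # OOB error = 1 - forest.oob_score_
```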

Random Forests also help identify the important features among all the features in the data, so that less important features/attributes can be eliminated, which leads to better predictions. Feature importance is computed from the decrease in impurity (increase in purity) at the splits, using the Gini Index.

The main advantage of random forests over decision trees is that they are stable, low-variance models. They also overcome the overfitting problem of decision trees. Since they use bootstrapped data and random subsets of features, they ensure diversity and robust performance. They are also less susceptible to the curse of dimensionality, because individual trees do not consider all the features at once.

The main disadvantage of random forests is their lack of interpretability: unlike a decision tree, one cannot easily trace how the model arrives at a prediction. They are also complex and computationally expensive.

Now let us see the Python implementation of both Decision tree and Random forest models with the help of a telecom churn data set.

Python Implementation

For telecom operators, retaining highly profitable customers is the number one business goal. The telecommunications industry experiences an average annual churn rate of 15–25%. We will analyse customer-level data of a leading telecom firm, build predictive models to identify customers at high risk of churn, and identify the main indicators of churn.

The dataset contains customer-level information for a span of four consecutive months – June, July, August and September. The months are encoded as 6, 7, 8 and 9, respectively. We need to predict the churn in the last (i.e. the ninth) month using the data (features) from the first three months.

Flow of analysis:

  1. Import the required libraries
  2. Read and understand the data
  3. Data cleanup and preparation
  4. Filter high-value customers
  5. Churn analysis
  6. Exploratory Data Analysis
  7. Model building
  8. Model building – hyperparameter tuning
  9. Best model

The required libraries are imported. We will load more packages from various libraries as and when they are needed later.

The data is loaded into a pandas data frame and analysed.
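For reference, the loading step looks roughly like this (the file name is an assumption; point it at your local copy of the churn dataset):

```python
import pandas as pd

# "telecom_churn_data.csv" is a placeholder file name for the churn dataset.
telecom = pd.read_csv("telecom_churn_data.csv")

print(telecom.shape)   # expected: (99999, 226)
print(telecom.isnull().mean().sort_values(ascending=False).head())  # columns with the most missing values
```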

Initial analysis tells us that the data has 99999 rows and 226 columns. There are a lot of missing values which need to be imputed or dropped.

Columns with more than 30% null values are mostly dropped; for the others, the null values are imputed with mean/median/mode values as appropriate. Finally, we have a dataset with 99999 rows and 185 columns.
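A hedged sketch of the cleanup logic just described (it continues from the loading sketch above; the real column-specific imputation decisions are simplified to median/mode):

```python
import pandas as pd

# telecom is the DataFrame loaded in the previous sketch.
# 1. Drop columns where more than 30% of the values are missing.
null_fraction = telecom.isnull().mean()
telecom = telecom.drop(columns=null_fraction[null_fraction > 0.30].index)

# 2. Impute the remaining gaps: median for numeric columns, mode for the rest.
for col in telecom.columns[telecom.isnull().any()]:
    if pd.api.types.is_numeric_dtype(telecom[col]):
        telecom[col] = telecom[col].fillna(telecom[col].median())
    else:
        telecom[col] = telecom[col].fillna(telecom[col].mode()[0])

print(telecom.shape)   # expected: (99999, 185) after cleanup
```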

Since most of the revenue comes from high value customers, we filter the data accordingly and get a data set with 30011 rows and 185 columns.

We perform Exploratory Data Analysis (EDA) for these customers.

Churn Analysis

We can see that 91.4% of the customers are non-churn and only 8.6% are churn, so this is an imbalanced dataset. Following are some of the analysis graphs from the EDA.

STD incoming minutes of usage within the telecom company network
Roaming incoming minutes of usage for churn and non-churn
Local incoming minutes of usage for churn and non-churn

We split the data into train and test sets, separating the churn variable as y, the dependent variable, and the rest of the features as the independent variables X.
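A sketch of the split (the target column name "churn" and the DataFrame of high-value customers, high_value, are assumptions about the prepared data):

```python
from sklearn.model_selection import train_test_split

# high_value: the cleaned, high-value-customer DataFrame; "churn" is the target column.
X = high_value.drop(columns=["churn"])
y = high_value["churn"]

# Stratify so that the 8.6% churn rate is preserved in both train and test sets.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, stratify=y, random_state=42
)
```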

We need to handle the class imbalance, since any model we build will otherwise perform well on the majority class and poorly on the minority class. We can handle imbalanced classes by balancing them, i.e., increasing the minority class or decreasing the majority class.

There are various techniques to handle class imbalance, listed below (a code sketch using these techniques appears after the list):

  1. Random Under-Sampling
  2. Random Over-Sampling
  3. SMOTE – Synthetic Minority Oversampling Technique
  4. ADASYN – Adaptive Synthetic Sampling Method
  5. SMOTETomek – Over-sampling followed by under-sampling
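Below is a hedged sketch of these resampling techniques using the imbalanced-learn package (resampling is applied to the training set only, never to the test set):

```python
from collections import Counter

from imblearn.combine import SMOTETomek
from imblearn.over_sampling import ADASYN, SMOTE, RandomOverSampler
from imblearn.under_sampling import RandomUnderSampler

samplers = {
    "Random Under-Sampling": RandomUnderSampler(random_state=42),
    "Random Over-Sampling": RandomOverSampler(random_state=42),
    "SMOTE": SMOTE(random_state=42),
    "ADASYN": ADASYN(random_state=42),
    "SMOTETomek": SMOTETomek(random_state=42),
}

# Resample the training data with each technique and inspect the new class balance.
for name, sampler in samplers.items():
    X_res, y_res = sampler.fit_resample(X_train, y_train)
    print(name, Counter(y_res))
```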

We implemented all of these techniques with Logistic Regression, Decision Tree and Random Forest models. We need to concentrate on the models that give us high sensitivity/recall. Following were the results:

Comparison of evaluation metrics for various models

As we can see in the summary table, recall is high for Random Forest with Random Undersampling, and for Logistic Regression with Random Undersampling, Random Oversampling, SMOTE, ADASYN and SMOTE+Tomek. Among the Logistic Regression models, ADASYN has the highest recall.

We will pick the Random Forest with Random Undersampling for further analysis.

We know that a Random Forest gives us feature importances, which can be used to eliminate less important features. We run a Random Forest Classifier and choose the important features:
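A hedged sketch of this step (the importance threshold is illustrative; X_train and y_train come from the split above):

```python
import pandas as pd
from imblearn.under_sampling import RandomUnderSampler
from sklearn.ensemble import RandomForestClassifier

# Under-sample the majority class in the training data (rows change, columns do not).
X_rus, y_rus = RandomUnderSampler(random_state=42).fit_resample(X_train, y_train)

rf = RandomForestClassifier(n_estimators=100, random_state=42)
rf.fit(X_rus, y_rus)

# Rank the features by their impurity-based (Gini) importance.
importances = pd.Series(rf.feature_importances_, index=X_train.columns)
importances = importances.sort_values(ascending=False)
print(importances.head(15))

# Keep only the features above an (illustrative) importance threshold.
selected_features = importances[importances > 0.01].index.tolist()
```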

We perform hyperparameter tuning using GridSearchCV for the Random Forest model with undersampling. Tuning the hyperparameters of a Decision Tree Classifier works the same way; we just need to replace the Random Forest Classifier with a Decision Tree Classifier as the base model.
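A hedged sketch of the grid search (the grid values are placeholders, not the tuned values from the article):

```python
from imblearn.under_sampling import RandomUnderSampler
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV

# Under-sample the training data restricted to the selected features.
X_rus_sel, y_rus_sel = RandomUnderSampler(random_state=42).fit_resample(
    X_train[selected_features], y_train
)

param_grid = {
    "n_estimators": [100, 200, 300],
    "max_depth": [5, 10, 20],
    "min_samples_leaf": [20, 50, 100],
    "max_features": ["sqrt", 0.2],
}

grid_search = GridSearchCV(
    estimator=RandomForestClassifier(random_state=42),
    param_grid=param_grid,
    scoring="recall",   # we care most about catching churners
    cv=3,
    n_jobs=-1,
)
grid_search.fit(X_rus_sel, y_rus_sel)

print("Best recall :", grid_search.best_score_)
print("Best model  :", grid_search.best_estimator_)
```

To tune a Decision Tree instead, swap the estimator for DecisionTreeClassifier(random_state=42) and adjust the grid accordingly.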

We get the following best score and best model.

The grid search provided by [GridSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html#sklearn.model_selection.GridSearchCV) exhaustively generates candidates from a grid of parameter values specified with the param_grid parameter. This process is computationally expensive and time-consuming, especially over a large hyperparameter space.

There is another approach, RandomizedSearchCV, to find the best model. Here, a fixed number of hyperparameter settings is sampled from specified distributions. This makes the process time-efficient and the resulting model less complex.
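Continuing the sketch above, a randomized search over the same kind of parameter space might look like this (the distributions are illustrative):

```python
from scipy.stats import randint
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RandomizedSearchCV

param_distributions = {
    "n_estimators": randint(100, 500),
    "max_depth": randint(3, 30),
    "min_samples_leaf": randint(10, 200),
    "max_features": ["sqrt", "log2", 0.2, 0.4],
}

random_search = RandomizedSearchCV(
    estimator=RandomForestClassifier(random_state=42),
    param_distributions=param_distributions,
    n_iter=30,          # only 30 sampled combinations instead of the full grid
    scoring="recall",
    cv=3,
    n_jobs=-1,
    random_state=42,
)
random_search.fit(X_rus_sel, y_rus_sel)

print("Best recall :", random_search.best_score_)
print("Best model  :", random_search.best_estimator_)
```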

We get the following best score and best model.

Comparing the scores and the complexity of the above models, we conclude that the model obtained using RandomizedSearchCV is the best model.

We make predictions on the test set and get the following results.
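The evaluation step, continuing the sketch above (this assumes churn is encoded as the positive class, 1):

```python
from sklearn.metrics import accuracy_score, classification_report, recall_score

best_model = random_search.best_estimator_
y_pred = best_model.predict(X_test[selected_features])

print("Recall  :", recall_score(y_test, y_pred))
print("Accuracy:", accuracy_score(y_test, y_pred))
print(classification_report(y_test, y_pred))
```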

We get a recall of 81% and an accuracy of 85%. The following is a representation of the top 10 important predictors of churn.

This completes our analysis. I hope the article was informative and easy to understand, and that you enjoyed analyzing the colorful graphs included in the analysis.

Do feel free to comment and give your feedback.

You can connect with me on LinkedIn: https://www.linkedin.com/in/pathakpuja/

Please visit my GitHub profile for the python codes. The code mentioned in the article, as well as the graphs, can be found here: https://github.com/pujappathak

References:

https://scikit-learn.org/0.16/modules/generated/sklearn.grid_search.RandomizedSearchCV.html

https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html

