
Implementing a Decision Tree From Scratch

Recursively refine your understanding by implementing a decision tree with just Python and NumPy


Decision Tree From Scratch [Image by Author]

Decision trees are simple and easy to explain. They can easily be displayed graphically, which makes them straightforward to interpret. They are also a popular and successful weapon of choice in machine learning competitions (e.g. on Kaggle).

Being simple on the surface, however, does not mean the algorithm and the underlying mechanisms are boring or even trivial.

In the following sections, we are going to implement a decision tree for classification in a step-by-step fashion using just Python and NumPy. We will also learn about the concepts of entropy and information gain, which provide us with the means to evaluate possible splits, hence allowing us to grow a decision tree in a reasonable way.

But before diving straight into the implementation details, let’s establish some basic intuition about decision trees in general.


Decision Trees 101: Dendrology

Tree-based methods are simple and useful for interpretation since the underlying mechanisms are considered quite similar to human decision-making.

The methods involve stratifying or segmenting the predictor space into a number of simpler regions. When making a prediction, we simply use the mean or mode of the region the new observation belongs to as a response value.

An example of a segmented predictor space [Image by Author]

Since the splitting rules to segment the predictor space can be best described by a tree-based structure, the supervised learning algorithm is called a Decision Tree.

Decision trees can be used for both regression and classification tasks.

A simplified example of a decision tree [Image by Author]

Now that we know what a decision tree is and why it can be useful, the next question is: how do we grow one?

Planting a seed: How to grow a decision tree

Loosely speaking, the process of building a decision tree mainly involves two steps:

  1. Dividing the predictor space into several distinct, non-overlapping regions
  2. Predicting the most-common class label for the region any new observation belongs to

As simple as it sounds, one fundamental question arises: how do we split the predictor space?

In order to split the predictor space into distinct regions, we use binary recursive splitting, which grows our decision tree until we reach a stopping criterion. Since we need a reasonable way to decide which splits are useful and which are not, we also need a metric for evaluation purposes.

In information theory, entropy describes the average level of information or uncertainty and can be defined as follows:
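For a node whose samples belong to C classes with proportions p_1, …, p_C, the entropy is

$$E = -\sum_{i=1}^{C} p_i \log_2(p_i)$$

An entropy of 0 means the node is pure (only a single class is present), while the maximum is reached when all classes are equally frequent.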

We can leverage the concept of entropy to calculate the information gain resulting from a possible split.

Suppose we have a dataset containing patient data, and we want to classify each patient as having either a high or a low risk of suffering a heart attack. Imagine a possible decision tree like the following:

An example decision tree to compute information gain [Image by Author]

In order to calculate the split’s information gain (IG), we simply compute the sum of weighted entropies of the children and subtract it from the parent’s entropy.
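Written out, with N samples in the parent node and N_left and N_right samples in the two children, this reads:

$$IG = E(\text{parent}) - \left(\frac{N_{\text{left}}}{N}\, E(\text{left}) + \frac{N_{\text{right}}}{N}\, E(\text{right})\right)$$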

Applying this to our example, the split yields an information gain of roughly 0.395. An information gain of 1 would have been the best possible result here, so a considerable amount of uncertainty, i.e. entropy, remains in the child nodes.

Equipped with the concepts of entropy and information gain, we simply need to evaluate all possible splits at the current growth stage of the tree (greedy approach), select the best one, and continue growing recursively until we reach a stopping criterion.
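To make the two concepts concrete, here is a minimal NumPy sketch of both quantities. The function names and the patient counts in the usage example are chosen purely for illustration (they are not the numbers from the figure above), and the class labels are assumed to be integer-encoded.

```python
import numpy as np

def entropy(y):
    """Shannon entropy of an array of integer-encoded class labels."""
    proportions = np.bincount(y) / len(y)
    # 0 * log2(0) is treated as 0, so empty classes are skipped
    return -np.sum([p * np.log2(p) for p in proportions if p > 0])

def information_gain(y_parent, y_left, y_right):
    """Parent entropy minus the weighted entropy of the two children."""
    n = len(y_parent)
    if len(y_left) == 0 or len(y_right) == 0:
        return 0.0
    child_entropy = (len(y_left) / n) * entropy(y_left) + (len(y_right) / n) * entropy(y_right)
    return entropy(y_parent) - child_entropy

# e.g. a parent node with 10 high-risk and 10 low-risk patients,
# split into children with (8 high, 2 low) and (2 high, 8 low)
parent = np.array([1] * 10 + [0] * 10)
left, right = np.array([1] * 8 + [0] * 2), np.array([1] * 2 + [0] * 8)
print(information_gain(parent, left, right))  # ≈ 0.278
```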


Introducing the Algorithm

Now that we have covered all the basics, we can start implementing the learning algorithm.

But before diving straight into the implementation details, we will take a quick look at the main computational steps of the algorithm to provide a high-level overview as well as some basic structure.

The main algorithm can essentially be divided into three steps:

  1. Initialization of parameters (e.g. maximum depth, minimum samples per split) and creation of a helper class
  2. Building the decision tree, involving binary recursive splitting, evaluating each possible split at the current stage, and continuing to grow the tree until a stopping criterion is satisfied
  3. Making a prediction, which can be described as traversing the tree recursively and returning the most-common class label as a response value

Since building the tree contains multiple steps, we will rely heavily on the use of helper functions in order to keep our code as clean as possible.

The algorithm will be implemented in two classes: the main class containing the algorithm itself and a helper class defining a node. Below, we take a look at the skeleton classes, which serve as a kind of blueprint guiding us through the implementation in the next section.
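One possible blueprint looks like the following. This is only a sketch: the method names and default values are illustrative, and the full code on GitHub may differ in detail.

```python
class Node:
    """Stores a single split (feature, threshold, children) or a leaf value."""
    def __init__(self, feature=None, threshold=None, left=None, right=None, *, value=None):
        ...

class DecisionTree:
    """Decision tree classifier grown by binary recursive splitting."""
    def __init__(self, max_depth=100, min_samples_split=2):
        ...

    def fit(self, X, y):
        ...

    def predict(self, X):
        ...

    # internal helpers
    def _build_tree(self, X, y, depth=0): ...
    def _best_split(self, X, y, features): ...
    def _is_finished(self, depth): ...
    def _traverse_tree(self, x, node): ...
```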


Implementation From Scratch

Basic Setup and the Node

Let’s kick off our implementation with some basic housekeeping. First of all, we define some basic parameters for our main class, namely the stopping criteria max_depth and min_samples_split, as well as the root node.

Next, we define a small helper class, which stores our splits in a node. The node contains information about the feature, the threshold value, and the connected left and right children, which will be useful when we recursively traverse the tree to make a prediction.
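A minimal version of this setup, consistent with the blueprint above, might look as follows (the default values for max_depth and min_samples_split are merely sensible assumptions):

```python
class Node:
    def __init__(self, feature=None, threshold=None, left=None, right=None, *, value=None):
        self.feature = feature      # index of the feature used for the split
        self.threshold = threshold  # threshold the feature is compared against
        self.left = left            # left child (feature value <= threshold)
        self.right = right          # right child (feature value > threshold)
        self.value = value          # most common class label, set only for leaf nodes

    def is_leaf(self):
        return self.value is not None


class DecisionTree:
    def __init__(self, max_depth=100, min_samples_split=2):
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.root = None
```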

Building the tree

Now, things get a little bit more involved. Thus, in the following, we will heavily rely on the use of several helper methods to stay organized.

We start our building process by calling the fit() method, which simply invokes our core method _build_tree().

Within the core method, we simply gather some information about the dataset (number of samples, features, and unique class labels), which will be needed in order to decide if the stopping criteria are met.
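Sketched as methods of the DecisionTree class defined above; the remainder of _build_tree() is filled in over the following paragraphs.

```python
# methods of the DecisionTree class
def fit(self, X, y):
    # start the recursive building process at the root node
    self.root = self._build_tree(X, y)

def _build_tree(self, X, y, depth=0):
    # gather basic information about the current subset of the data
    self.n_samples, self.n_features = X.shape
    self.n_class_labels = len(np.unique(y))
    # ... stopping criteria, best split, and recursion follow below ...
```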

Our helper method _is_finished() is used to evaluate the stopping criteria. If, for example, fewer samples remain than the minimum required per split, the method returns True and the building process stops at the current branch.

If our building process has finished, we compute the most common class label and save that value in a leaf node.

Note: The stopping criteria serve as an exit strategy to stop the recursive growth. Without a stopping mechanism in place, we would create an endless recursion.
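A sketch of the stopping check; the corresponding leaf creation appears in the completed _build_tree() further below.

```python
# method of the DecisionTree class
def _is_finished(self, depth):
    # stop if the tree is deep enough, the node is pure,
    # or too few samples remain to perform another split
    return (depth >= self.max_depth
            or self.n_class_labels == 1
            or self.n_samples < self.min_samples_split)
```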

We continue to grow our tree by computing the best split at the current stage. To find it, we loop over all feature indices and unique threshold values and calculate the information gain for each combination. As we learned in the earlier sections, the information gain tells us how much uncertainty a proposed split removes.

Once we obtain the information gain for a specific feature-threshold combination, we compare the result to our previous iterations. If we find a better split, we store the associated parameters in a dictionary.

After looping through all combinations we return the best feature and threshold as a tuple.
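One way to write this search as methods of the class; the entropy and information-gain logic mirrors the standalone functions sketched earlier, and the helper names (_create_split, _information_gain, _best_split) are again illustrative.

```python
# methods of the DecisionTree class
def _entropy(self, y):
    proportions = np.bincount(y) / len(y)
    return -np.sum([p * np.log2(p) for p in proportions if p > 0])

def _create_split(self, X_feat, thresh):
    # indices of the samples going to the left / right child
    left_idx = np.argwhere(X_feat <= thresh).flatten()
    right_idx = np.argwhere(X_feat > thresh).flatten()
    return left_idx, right_idx

def _information_gain(self, X_feat, y, thresh):
    parent_entropy = self._entropy(y)
    left_idx, right_idx = self._create_split(X_feat, thresh)
    n, n_left, n_right = len(y), len(left_idx), len(right_idx)
    if n_left == 0 or n_right == 0:
        return 0.0
    child_entropy = (n_left / n) * self._entropy(y[left_idx]) \
        + (n_right / n) * self._entropy(y[right_idx])
    return parent_entropy - child_entropy

def _best_split(self, X, y, features):
    split = {"score": -1.0, "feat": None, "thresh": None}
    for feat in features:
        # evaluate every unique value of the feature as a candidate threshold
        for thresh in np.unique(X[:, feat]):
            score = self._information_gain(X[:, feat], y, thresh)
            if score > split["score"]:
                split["score"], split["feat"], split["thresh"] = score, feat, thresh
    return split["feat"], split["thresh"]
```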

Now, we can finish our core method by growing the children recursively. We therefore use the best feature and threshold to split the data into a left and a right branch.

Next, we call our core method from itself (that’s the recursive part here) in order to start the building process for the children.

Once the stopping criteria are satisfied, the recursion unwinds and returns the nodes, leaving us with a fully grown decision tree.
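Putting the pieces together, the completed core method might look roughly like this:

```python
# method of the DecisionTree class
def _build_tree(self, X, y, depth=0):
    self.n_samples, self.n_features = X.shape
    self.n_class_labels = len(np.unique(y))

    # stopping criteria: return a leaf holding the most common class label
    if self._is_finished(depth):
        most_common_label = np.argmax(np.bincount(y))
        return Node(value=most_common_label)

    # greedily search for the best feature / threshold combination
    best_feat, best_thresh = self._best_split(X, y, range(self.n_features))

    # grow the children recursively on the two resulting partitions
    left_idx, right_idx = self._create_split(X[:, best_feat], best_thresh)
    left_child = self._build_tree(X[left_idx, :], y[left_idx], depth + 1)
    right_child = self._build_tree(X[right_idx, :], y[right_idx], depth + 1)
    return Node(best_feat, best_thresh, left_child, right_child)
```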

Making a prediction – or traversing the tree

We have finished most of the hard work by now; we only need to create one more helper method and we are done.

Making a prediction can be implemented by recursively traversing the tree. That means, for every sample in our dataset, we compare the node’s feature and threshold values to the current sample’s values and decide whether to take a left or a right turn.

Once we reach a leaf node we simply return the most common class label as our prediction.
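A sketch of the traversal and the public predict() method:

```python
# methods of the DecisionTree class
def _traverse_tree(self, x, node):
    # leaf reached: return the stored class label
    if node.is_leaf():
        return node.value
    # otherwise decide whether to take a left or a right turn
    if x[node.feature] <= node.threshold:
        return self._traverse_tree(x, node.left)
    return self._traverse_tree(x, node.right)

def predict(self, X):
    # traverse the tree once per sample, starting at the root
    return np.array([self._traverse_tree(x, self.root) for x in X])
```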

And this is it! We finished our implementation of a decision tree.


Testing the Classifier

Having finished the implementation, we still need to test our classifier.

For testing purposes, we will use the classic breast cancer Wisconsin dataset [1], a binary classification dataset containing 569 samples with 30 features each.

After importing the dataset, we can split it into training and test sets.

We instantiate our classifier, fit it on the training data, and make our predictions. Using our accuracy helper function, we obtain an accuracy of ~95.6 %, confirming that our algorithm works.
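A possible test run, using scikit-learn only to load and split the data. The accuracy helper and the chosen hyperparameters are illustrative, and the exact score depends on the random train/test split.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

def accuracy(y_true, y_pred):
    return np.sum(y_true == y_pred) / len(y_true)

data = load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(
    data.data, data.target, test_size=0.2, random_state=1
)

clf = DecisionTree(max_depth=10)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)
print(f"Accuracy: {accuracy(y_test, y_pred):.4f}")
```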


Conclusion

In this article, we implemented a decision tree for classification from scratch with just the use of Python and NumPy. We also learned about the underlying mechanisms and concepts like entropy and information gain.

Understanding the basics of decision trees will prove useful when tackling the more advanced extensions like bagging, random forest, and boosting. A deeper understanding of the algorithm will also be helpful when trying to optimize the hyperparameters of a learning algorithm based on a decision tree.

You can find the full code here on my GitHub.





References / Further Material:

