Decision Trees for Classification — Complete Example

A detailed example of how to construct a Decision Tree for classification


This article explains how we can use decision trees for classification problems. After explaining important terms, we will develop a decision tree for a simple example dataset.

Introduction

A decision tree is a decision support tool that uses a tree-like model of decisions and their possible consequences, including chance event outcomes, resource costs, and utility. It is one way to display an algorithm that only contains conditional control statements.

Traditionally, decision trees are drawn manually, but they can also be learned using Machine Learning. They can be used for both regression and classification problems. In this article we will focus on classification problems. Let’s consider the following example data:

Example data (constructed by the author)

Using this simplified example we will predict whether a person is going to be an astronaut, depending on their age, whether they like dogs, and whether they like gravity. Before discussing how to construct a decision tree, let’s have a look at the resulting decision tree for our example data.

Final decision tree for example data

We can follow the paths to come to a decision. For example, we can see that a person who doesn’t like gravity is not going to be an astronaut, independent of the other features. On the other hand, we can also see that a person who likes gravity and likes dogs is going to be an astronaut, independent of their age.

Before going into detail about how this tree is constructed, let’s define some important terms.

Terms

Root Node

The top-level node, where the first decision is taken. In our example, the root node is ‘likes gravity’.

Branches

Branches represent sub-trees. Our example has two branches: one is the sub-tree starting at ‘likes dogs’, the other the sub-tree starting at ‘age < 40.5’.

Node

A node represents a split into further (child) nodes. In our example the nodes are ‘likes gravity’, ‘likes dogs’ and ‘age < 40.5’.

Leaf

Leaves are at the end of the branches, i.e. they don’t split any more. They represent the possible outcomes for each action. In our example, the leaves are represented by ‘yes’ and ‘no’.

Parent Node

A node which precedes a (child) node is called a parent node. In our example ‘likes gravity’ is a parent node of ‘likes dogs’ and ‘likes dogs’ is a parent node of ‘age < 40.5’.

Child Node

A node under another node is a child node. In our example ‘likes dogs’ is a child node of ‘likes gravity’ and ‘age < 40.5’ is a child node of ‘likes dogs’.

Splitting

The process of dividing a node into two (child) nodes.

Pruning

Removing the (child) nodes of a parent node is called pruning. A tree is grown through splitting and shrunk through pruning. In our example, removing the node ‘age < 40.5’ would prune the tree.

Decision tree illustration

We can also observe that a decision tree allows us to mix data types: we can use numerical data (‘age’) and categorical data (‘likes dogs’, ‘likes gravity’) in the same tree.

Create a Decision Tree

The most important step in creating a decision tree is the splitting of the data. We need to find a way to split the data set D into two data sets D_1 and D_2. There are different criteria that can be used to find the next split; for an overview see e.g. here. We will concentrate on one of them: the Gini Impurity, which is a criterion for categorical target variables and also the criterion used by the Python library scikit-learn.

Gini Impurity

The Gini Impurity for a split of the data set D into D_1 and D_2 is calculated as follows:

Gini(D) = (n_1 / n) * Gini(D_1) + (n_2 / n) * Gini(D_2),

with n = n_1 + n_2 the size of the data set D, and the Gini Impurity of a single node given by

Gini(D_i) = 1 - Σ_{j=1}^{c} p_j²,

with D_1 and D_2 the subsets of D, p_j the probability of samples belonging to class j at a given node, and c the number of classes. The lower the Gini Impurity, the higher the homogeneity of the node. The Gini Impurity of a pure node is zero. To split a decision tree using the Gini Impurity, the following steps need to be performed.

  1. For each possible split, calculate the Gini Impurity of each child node
  2. Calculate the Gini Impurity of each split as the weighted average Gini Impurity of child nodes
  3. Select the split with the lowest value of Gini Impurity

Repeat steps 1–3 until no further split is possible.
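To make these formulas and steps concrete, here is a minimal Python sketch; the function names and the toy label lists are my own choices, not part of the original example:

from collections import Counter

def gini(labels):
    # Gini Impurity of a single node: 1 minus the sum of squared class probabilities
    n = len(labels)
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())

def gini_of_split(left_labels, right_labels):
    # Weighted average of the two child nodes' Gini Impurities
    n = len(left_labels) + len(right_labels)
    return (len(left_labels) / n) * gini(left_labels) + (len(right_labels) / n) * gini(right_labels)

print(gini(['yes', 'yes', 'yes']))       # pure node: 0.0
print(gini(['yes', 'no', 'yes', 'no']))  # maximally mixed binary node: 0.5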

To understand this better, let’s have a look at an example.

First Example: Decision Tree with two binary features

Before creating the decision tree for our entire dataset, we will first consider a subset that contains only two features: ‘likes gravity’ and ‘likes dogs’.

The first thing we have to decide is which feature is going to be the root node. We do that by predicting the target with only one of the features at a time and then using the feature with the lowest Gini Impurity as the root node. That is, in our case we build two shallow trees, each with just the root node and two leaves. In the first case we use ‘likes gravity’ as the root node and in the second case ‘likes dogs’. We then calculate the Gini Impurity for both. The trees look like this:

Image by the author

The Gini Impurity for these trees is calculated as follows:

Case 1: ‘likes gravity’ as the root node

Dataset 1:

Dataset 2:

The Gini Impurity is the weighted mean of both:

Case 2: ‘likes dogs’ as the root node

Dataset 1:

Dataset 2:

The Gini Impurity is the weighted mean of both:

That is, the first case has the lower Gini Impurity and is the chosen split. In this simple example, only one feature remains, and we can build the final decision tree.

Final Decision Tree considering only the features ‘likes gravity’ and ‘likes dogs’
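As a rough sketch of this root-node selection (not the article’s original code), assume the example data is stored in a pandas data frame df with the column names used above and ‘yes’/‘no’ values for the binary features; then the comparison could look like this:

def gini(labels):
    # Gini Impurity of a single node
    labels = list(labels)
    n = len(labels)
    return 1.0 - sum((labels.count(c) / n) ** 2 for c in set(labels))

def gini_of_binary_split(df, feature, target='going_to_be_an_astronaut'):
    # Weighted Gini Impurity of splitting on a binary 'yes'/'no' feature
    left = df[df[feature] == 'yes'][target]
    right = df[df[feature] == 'no'][target]
    n = len(left) + len(right)
    return (len(left) / n) * gini(left) + (len(right) / n) * gini(right)

# Build the two shallow trees and keep the feature with the lower impurity
impurities = {f: gini_of_binary_split(df, f) for f in ['likes gravity', 'likes dogs']}
root_feature = min(impurities, key=impurities.get)
print(impurities, root_feature)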

Second Example: Add a numerical Variable

Until now, we considered only a subset of our data set: the categorical variables. Now we will add the numerical variable ‘age’. The criterion for splitting is the same. We already know the Gini Impurities for ‘likes gravity’ and ‘likes dogs’. The calculation of the Gini Impurity for a numerical variable is similar; however, the decision requires more calculations. The following steps need to be done (a short code sketch follows the list):

  1. Sort the data frame by the numerical variable (‘age’)
  2. Calculate the mean of neighbouring values
  3. Calculate the Gini Impurity for all splits for each of these means
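As mentioned above, here is a short sketch of these three steps; df is again the assumed pandas data frame with the example data, and the helper names are illustrative:

def gini(labels):
    # Gini Impurity of a single node
    labels = list(labels)
    n = len(labels)
    return 1.0 - sum((labels.count(c) / n) ** 2 for c in set(labels))

def candidate_thresholds(df, feature='age'):
    # Steps 1 and 2: sort the values and take the means of neighbouring values
    values = sorted(df[feature].unique())
    return [(a + b) / 2 for a, b in zip(values[:-1], values[1:])]

def gini_of_threshold(df, threshold, feature='age', target='going_to_be_an_astronaut'):
    # Step 3: weighted Gini Impurity of the split 'feature < threshold'
    left = df[df[feature] < threshold][target]
    right = df[df[feature] >= threshold][target]
    n = len(left) + len(right)
    return (len(left) / n) * gini(left) + (len(right) / n) * gini(right)

for t in candidate_thresholds(df):
    print(t, gini_of_threshold(df, t))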

This is again our data, sorted by age, and the mean of neighbouring values is given on the left-hand side.

The data set sorted by age. The left hand side shows the mean of neighbouring values for age.

We then have the following possible splits.

Possible splits for age and their Gini Impurity.

We can see that the Gini Impurity of all possible ‘age’ splits is higher than that of ‘likes gravity’ and ‘likes dogs’. The lowest Gini Impurity is achieved when using ‘likes gravity’, i.e. this is our root node and the first split.

The first split of the tree. ‘likes gravity’ is the root node.

The subset Dataset 2 is already pure, that is, this node is a leaf and no further splitting is necessary. The branch on the left-hand side, Dataset 1, is not pure and can be split further. We do this in the same way as before: we calculate the Gini Impurity for each remaining feature, ‘likes dogs’ and ‘age’.

Possible splits for Dataset 1.

We see that the lowest Gini Impurity is given by the split on ‘likes dogs’. We can now build our final tree.

Final Decision Tree.

Using Python

In Python, we can use the scikit-learn class DecisionTreeClassifier to build a Decision Tree for classification. Note that scikit-learn also provides DecisionTreeRegressor for using Decision Trees for regression. Assuming that our data is stored in a data frame ‘df’, we can then train a classifier using the ‘fit’ method:

from sklearn.tree import DecisionTreeClassifier

clf = DecisionTreeClassifier()
X = df[['age', 'likes dogs', 'likes gravity']]  # feature columns
y = df['going_to_be_an_astronaut']              # target column
clf.fit(X, y)

We can visualize the resulting tree using the ‘plot_tree’ function. It is the same tree as the one we built by hand; only the splitting criterion is labelled with ‘<=’ instead of ‘<’, and the ‘True’ and ‘False’ paths point in the opposite direction. That is, there are only some differences in appearance.

from sklearn.tree import plot_tree
plot_tree(clf, feature_names=['age', 'likes_dogs', 'likes_gravity'], fontsize=8);
Resulting Decision Tree using scikit-learn.
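As a small optional addition (not part of the original article), scikit-learn can also print the learned rules as plain text using export_text:

from sklearn.tree import export_text
print(export_text(clf, feature_names=['age', 'likes_dogs', 'likes_gravity']))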

Advantages and Disadvantages of Decision Trees

When working with decision trees, it is important to know their advantages and disadvantages. Below you can find a list of pros and cons. This list, however, is by no means complete.

Pros

  • Decision trees are intuitive, easy to understand and interpret.
  • Decision trees are not strongly affected by outliers and missing values.
  • The data doesn’t need to be scaled.
  • Numerical and categorical data can be combined.
  • Decision trees are non-parametric algorithms.

Cons

  • Overfitting is a common problem. Pruning may help to overcome this (a short scikit-learn sketch follows this list).
  • Although decision trees can be used for regression problems, they cannot really predict continuous variables, as the predictions must be grouped into discrete values.
  • Training a decision tree is relatively expensive.
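To counter the overfitting mentioned in the first point above, DecisionTreeClassifier offers parameters for pre- and post-pruning; the values below are placeholders rather than tuned settings:

from sklearn.tree import DecisionTreeClassifier

# Pre-pruning: stop growing the tree early
clf_shallow = DecisionTreeClassifier(max_depth=3, min_samples_leaf=5)

# Post-pruning: cost-complexity pruning (larger ccp_alpha removes more nodes)
clf_pruned = DecisionTreeClassifier(ccp_alpha=0.01)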

Conclusion

In this article, we discussed a simple but detailed example of how to construct a decision tree for a classification problem and how it can be used to make predictions. A crucial step in creating a decision tree is to find the best split of the data into two subsets, and a common criterion for this is the Gini Impurity. It is also the criterion used by the Python library scikit-learn, which is often used in practice to build Decision Trees. It is important to keep in mind the limitations of decision trees, of which the most prominent one is the tendency to overfit.

References

All images unless otherwise noted are by the author.
