Decision Trees for Classification — Complete Example
A detailed example of how to construct a Decision Tree for classification
This article explains how we can use decision trees for classification problems. After explaining important terms, we will develop a decision tree for a simple example dataset.
Introduction
Traditionally, decision trees have been drawn manually, but they can also be learned automatically using Machine Learning. They can be used for both regression and classification problems. In this article we will focus on classification problems. Let’s consider the following example data:
Using this simplified example we will predict whether a person is going to be an astronaut, depending on their age, whether they like dogs, and whether they like gravity. Before discussing how to construct a decision tree, let’s have a look at the resulting decision tree for our example data.
We can follow the paths to come to a decision. For example, we can see that a person who doesn’t like gravity is not going to be an astronaut, independent of the other features. On the other hand, we can also see that a person who likes gravity and likes dogs is going to be an astronaut, independent of their age.
Before going into detail about how this tree is constructed, let’s define some important terms.
Terms
Root Node
The top-level node. The first decision that is taken. In our example the root node is ‘likes gravity’.
Branches
Branches represent sub-trees. Our example has two branches: one is the sub-tree starting at ‘likes dogs’, the other the sub-tree starting at ‘age < 40.5’.
Node
A node represents a split into further (child) nodes. In our example the nodes are ‘likes gravity’, ‘likes dogs’ and ‘age < 40.5’.
Leaf
Leaves are at the end of the branches, i.e. they don’t split any further. They represent the possible outcomes. In our example the leaves are ‘yes’ and ‘no’.
Parent Node
A node which precedes a (child) node is called a parent node. In our example ‘likes gravity’ is a parent node of ‘likes dogs’ and ‘likes dogs’ is a parent node of ‘age < 40.5’.
Child Node
A node under another node is a child node. In our example ‘likes dogs’ is a child node of ‘likes gravity’ and ‘age < 40.5’ is a child node of ‘likes dogs’.
Splitting
The process of dividing a node into two (child) nodes.
Pruning
Removing the (child) nodes of a parent node is called pruning. A tree is grown through splitting and shrunk through pruning. In our example, if we removed the node ‘age < 40.5’, we would prune the tree.
We can also observe that a decision tree allows us to mix data types: numerical data (‘age’) and categorical data (‘likes dogs’, ‘likes gravity’) can be used in the same tree.
Create a Decision Tree
The most important step in creating a decision tree is the splitting of the data. We need to find a way to split the data set D into two data sets D_1 and D_2. Different criteria can be used to find the next split; for an overview see, e.g., the article by Sharma in the references. We will concentrate on one of them: the Gini Impurity, a criterion for categorical target variables that is also the default criterion used by the Python library scikit-learn.
Gini Impurity
The Gini Impurity for a data set D that is split into two subsets D_1 and D_2 is calculated as follows:

Gini(D) = (n_1 / n) · Gini(D_1) + (n_2 / n) · Gini(D_2),

with n = n_1 + n_2 the size of the data set (D) and

Gini(D_i) = 1 − Σ_j p_j² (the sum running over j = 1, …, c),

with D_1 and D_2 subsets of D, p_j the probability of samples belonging to class j at a given node, and c the number of classes. The lower the Gini Impurity, the higher is the homogeneity of the node. The Gini Impurity of a pure node is zero. To split a decision tree using Gini Impurity, the following steps need to be performed:
- For each possible split, calculate the Gini Impurity of each child node
- Calculate the Gini Impurity of each split as the weighted average Gini Impurity of child nodes
- Select the split with the lowest value of Gini Impurity
Repeat steps 1–3 until no further split is possible.
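The steps above can be sketched in a few lines of plain Python. The function names below are our own for illustration, not from any library: one function computes the Gini Impurity of a single node, the other the weighted average impurity of a split into two child nodes.

```python
from collections import Counter

def gini(labels):
    """Gini Impurity of one node: 1 - sum over classes of p_j^2."""
    n = len(labels)
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())

def split_gini(left_labels, right_labels):
    """Weighted average Gini Impurity of a split into two child nodes."""
    n = len(left_labels) + len(right_labels)
    return (len(left_labels) / n * gini(left_labels)
            + len(right_labels) / n * gini(right_labels))

# A pure node has impurity 0; a 50/50 binary node has impurity 0.5.
print(gini(["yes", "yes"]))                       # 0.0
print(gini(["yes", "no"]))                        # 0.5
print(split_gini(["yes", "yes"], ["yes", "no"]))  # 0.25
```

Selecting the best split then amounts to evaluating `split_gini` for every candidate and keeping the one with the lowest value.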
To understand this better, let’s have a look at an example.
First Example: Decision Tree with two binary features
Before creating the decision tree for our entire dataset, we will first consider a subset, that only considers two features: ‘likes gravity’ and ‘likes dogs’.
The first thing we have to decide is which feature will be the root node. We do that by predicting the target with only one feature at a time and then choosing the feature with the lowest Gini Impurity as the root node. That is, in our case we build two shallow trees, each with just the root node and two leaves: in the first case ‘likes gravity’ is the root node, in the second case ‘likes dogs’. We then calculate the Gini Impurity for both. The trees look like this:
The Gini Impurity for these trees is calculated as follows:
Case 1:
Dataset 1:
Dataset 2:
The Gini Impurity is the weighted mean of both:
Case 2:
Dataset 1:
Dataset 2:
The Gini Impurity is the weighted mean of both:
That is, the first case has the lower Gini Impurity and is the chosen split. In this simple example, only one feature remains, and we can build the final decision tree.
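The comparison of the two candidate root nodes can be reproduced in code. Since the article’s data table is an image, the rows below are a hypothetical stand-in chosen so that the conclusion matches the article: splitting on ‘likes gravity’ yields the lower weighted Gini Impurity.

```python
from collections import Counter

def gini(labels):
    """Gini Impurity of one node: 1 - sum over classes of p_j^2."""
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())

# Hypothetical rows (the article's real table is an image):
# each row is (likes_gravity, likes_dogs, astronaut).
rows = [
    (False, True,  False), (False, False, False), (False, True,  False),
    (True,  True,  True),  (True,  True,  True),  (True,  False, False),
    (True,  False, True),
]

def weighted_gini_for_feature(rows, idx):
    """Split on one boolean feature; return the weighted child impurity."""
    left  = [r[-1] for r in rows if r[idx]]
    right = [r[-1] for r in rows if not r[idx]]
    n = len(rows)
    return len(left) / n * gini(left) + len(right) / n * gini(right)

for name, idx in [("likes gravity", 0), ("likes dogs", 1)]:
    print(name, round(weighted_gini_for_feature(rows, idx), 3))
# likes gravity 0.214
# likes dogs 0.476
```

With these rows, ‘likes gravity’ wins and becomes the root node, just as in the worked example.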
Second Example: Add a numerical Variable
Until now, we considered only a subset of our data set: the categorical variables. Now we will add the numerical variable ‘age’. The splitting criterion is the same, and we already know the Gini Impurities for ‘likes gravity’ and ‘likes dogs’. Calculating the Gini Impurity of a numerical variable is similar, but the decision takes more calculations. The following steps need to be done:
- Sort the data frame by the numerical variable (‘age’)
- Calculate the mean of neighbouring values
- Calculate the Gini Impurity for all splits for each of these means
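Step 2 above can be sketched directly. The ages below are made up for illustration (the article’s sorted table is an image); each resulting midpoint is then evaluated as a binary split of the form ‘age < threshold’, exactly like a categorical split.

```python
def candidate_thresholds(values):
    """Means of neighbouring distinct values of a sorted numerical feature."""
    s = sorted(set(values))
    return [(a + b) / 2 for a, b in zip(s, s[1:])]

# Hypothetical ages; duplicates produce no extra threshold.
ages = [24, 30, 36, 36, 42, 44, 46]
print(candidate_thresholds(ages))  # [27.0, 33.0, 39.0, 43.0, 45.0]
```

Each of these thresholds partitions the rows into ‘age < t’ and ‘age >= t’, and the weighted Gini Impurity is computed for every partition.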
This is again our data, sorted by age, with the mean of neighbouring values given on the left-hand side.
We then have the following possible splits.
We can see that the Gini Impurity of every possible ‘age’ split is higher than that of ‘likes gravity’ and ‘likes dogs’. The lowest Gini Impurity results from splitting on ‘likes gravity’, i.e. this is our root node and the first split.
The subset Dataset 2 is already pure, that is, this node is a leaf and no further splitting is necessary. The branch on the left-hand side, Dataset 1, is not pure and can be split further. We proceed in the same way as before: we calculate the Gini Impurity for each remaining feature, ‘likes dogs’ and ‘age’.
We see that the lowest Gini Impurity is given by the split on ‘likes dogs’. We can now build our final tree.
Using Python
In Python, we can use scikit-learn’s DecisionTreeClassifier to build a Decision Tree for classification. Note that scikit-learn also provides DecisionTreeRegressor for using Decision Trees for regression. Assuming our data is stored in a data frame ‘df’, we can train a classifier using the ‘fit’ method:
from sklearn.tree import DecisionTreeClassifier

clf = DecisionTreeClassifier()
X = df[['age', 'likes dogs', 'likes gravity']]
y = df['going_to_be_an_astronaut']
clf.fit(X, y)
We can visualize the resulting tree using the ‘plot_tree’ method. It is the same tree we built by hand, except that the splitting criterion is written with ‘<=’ instead of ‘<’ and the ‘true’ and ‘false’ paths point in the opposite direction. That is, there are only some differences in appearance.
from sklearn.tree import plot_tree

plot_tree(clf, feature_names=['age', 'likes_dogs', 'likes_gravity'], fontsize=8);
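If no plotting backend is available, scikit-learn can also render the tree as plain text via ‘export_text’. The tiny data set below is made up for illustration (the article’s table is an image), with columns in the order age, likes_dogs, likes_gravity:

```python
from sklearn.tree import DecisionTreeClassifier, export_text

# Made-up rows: [age, likes_dogs, likes_gravity] -> astronaut (1 = yes)
X = [[24, 0, 0], [30, 1, 1], [36, 1, 1], [42, 0, 1], [46, 1, 0]]
y = [0, 1, 1, 0, 0]

clf = DecisionTreeClassifier(criterion="gini", random_state=0).fit(X, y)
print(export_text(clf, feature_names=["age", "likes_dogs", "likes_gravity"]))
```

The printed output shows one line per node, with ‘<=’ thresholds for the splits and ‘class:’ labels at the leaves.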
Advantages and Disadvantages of Decision Trees
When working with decision trees, it is important to know their advantages and disadvantages. Below you can find a list of pros and cons. This list, however, is by no means complete.
Pros
- Decision trees are intuitive, easy to understand and interpret.
- Decision trees are largely unaffected by outliers, and missing values can be handled.
- The data doesn’t need to be scaled.
- Numerical and categorical data can be combined.
- Decision trees are non-parametric algorithms.
Cons
- Overfitting is a common problem. Pruning may help to overcome this.
- Although decision trees can be used for regression problems, they cannot truly predict continuous variables: their predictions are piecewise constant, limited to a finite set of values.
- Training a decision tree is relatively expensive.
Conclusion
In this article, we discussed a simple but detailed example of how to construct a decision tree for a classification problem and how it can be used to make predictions. A crucial step in creating a decision tree is to find the best split of the data into two subsets. A common way to do this is the Gini Impurity. This is also used in the scikit-learn library from Python, which is often used in practice to build a Decision Tree. It’s important to keep in mind the limitations of decision trees, of which the most prominent one is the tendency to overfit.
References
- Chris Nicholson, Decision Trees (2020), Pathmind A.I. Wiki: A Beginner’s Guide to Important Topics in AI, Machine Learning, and Deep Learning, https://wiki.pathmind.com/decision-tree
- Abhishek Sharma, 4 Simple Ways to Split a Decision Tree in Machine Learning (an overview of splitting methods) (2020), Analytics Vidhya
All images unless otherwise noted are by the author.