This is the fifth post in my scikit-learn tutorial series. If you didn’t catch them, I strongly recommend my first four posts; it’ll be way easier to follow along.
This module introduces decision trees. As we'll see, decision trees are a type of supervised learning algorithm that works by recursively splitting the data on feature/threshold pairs, creating a nested, tree-like structure. The leaves of the tree hold the model's predictions.
Personal Disclaimer: When I first started the scikit-learn tutorial, I had never encountered decision trees. At best, I’d read that term here and there but had absolutely no idea what it meant. So I was kind of curious about what decision trees were and afraid it would be more complex and harder to understand than other models I already knew, like linear regression models and support vector machines. But it turns out, decision trees are actually way simpler and easier to understand! And they are powerful!
We'll first discover how decision trees work using a very simple example of a regression problem with a 1D dataset and the MSE loss function, and then a 2D dataset for classification with the Gini and Entropy impurity functions. The idea is to understand how decision trees grow and what the differences between regression and classification are. It is then easy to extrapolate how they work to higher-dimensional problems. Finally, we'll look at some of the hyperparameters decision trees expose.
All images by author.
Decision Tree for 1D Regression (with MSE)
In order to understand and grasp the overall logic behind decision trees, we’ll use a simple example of 1D regression, using DecisionTreeRegressor.
%matplotlib qt
import numpy as np
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeRegressor, plot_tree
np.random.seed(42)
X = np.sort(5 * np.random.rand(80, 1), axis=0)
y = np.sin(X).ravel() + np.random.normal(0, 0.1, X.shape[0])
plt.figure(figsize=(12, 6))
plt.scatter(X, y, s=20, edgecolor="black", c="darkorange", label="data")
plt.title("Data to it")
plt.xlabel("X")
plt.ylabel("target")
plt.legend()

Our goal is to create a regression model that fits the 1D dataset using a decision tree approach. In this 1D feature example, the overall approach follows these steps:
- Try all possible threshold values in the X-space that split the dataset into 2 sub-datasets: a "left" dataset with the samples where X < threshold and a "right" dataset with the samples where X > threshold.
- Use the mean target value of the left subset as its prediction, and the mean target value of the right subset as its prediction.
- Based on a criterion that measures how well the left and right means "fit" the corresponding target values, select the best threshold.
- This creates 2 new datasets. Apply the same procedure to each of them separately until the mean of each sub-dataset "fits" its target samples well enough.
Let’s see how that works out with a single iteration: to do this, we set the max depth of the decision tree to 1, meaning a single split is done from the root of the tree – the node that contains the whole dataset.
X_test = np.arange(0.0, 5.0, 0.01)[:, np.newaxis]
def plot_decision_regression(max_depth=1):
    regressor = DecisionTreeRegressor(max_depth=max_depth)
    regressor.fit(X, y)
    # Predict over the test grid to plot the regression line
    y_pred = regressor.predict(X_test)
    # Plot the Decision Tree regression result and structure
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    # Plot the regression result on the left
    axes[0].scatter(X, y, s=20, edgecolor="black", c="darkorange", label="data")
    axes[0].plot(X_test, y_pred, color="cornflowerblue", label="prediction")
    axes[0].set_title("Decision Tree Regression")
    axes[0].set_xlabel("X")
    axes[0].set_ylabel("target")
    axes[0].legend()
    # Plot the Decision Tree structure on the right
    plot_tree(regressor, filled=True, feature_names=["X"], ax=axes[1])
    axes[1].set_title("Decision Tree Structure")
    fig.suptitle(f"Decision tree regressor with max depth={max_depth}")
    # Show the combined plot
    fig.tight_layout()
    return fig, axes, regressor
fig, axes, dtr = plot_decision_regression()
axes[0].axvline(dtr.tree_.threshold[0], color="gray", linestyle="--")

This plot shows how the data is split into 2 subsets, left and right, based on the threshold value of 3.186. The prediction associated with the left subset is the mean of the corresponding samples, which is 0.611, and the prediction for the right subset is -0.754. These "approximations" of the real data are not perfect, and this imperfection is measured by the squared error between the sample values and the mean.
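These numbers can also be read directly from the fitted tree object: its tree_ attribute exposes, for each node, the split threshold, the mean prediction, and the impurity (here the MSE). A minimal sketch, using the dtr regressor fitted above:

# Node 0 is the root; nodes 1 and 2 are its left and right leaves.
t = dtr.tree_
print("root threshold:", t.threshold[0])   # ~3.186 (leaves report a threshold of -2)
print("node means:", t.value.ravel())      # root mean, then ~0.611 (left) and ~-0.754 (right)
print("node MSEs:", t.impurity)            # squared error of each node around its mean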
The MSE Loss Function
Under the hood, sklearn found that the best value to split on was the threshold of 3.186, according to the (weighted) mean squared error of the left and right subsets. Mathematically, the model finds the solution to:

$$t^* = \underset{t}{\arg\min} \left[ \frac{n_\text{left}(t)}{n} \, \text{MSE}\big(D_\text{left}(t)\big) + \frac{n_\text{right}(t)}{n} \, \text{MSE}\big(D_\text{right}(t)\big) \right]$$

In other words, it finds the x-threshold value that minimizes the (weighted) sum of the left and right mean squared errors, where the MSE of a subset is computed against the mean of its target values:

$$\text{MSE}(D) = \frac{1}{|D|} \sum_{(x_i, y_i) \in D} (y_i - \bar{y}_D)^2, \qquad \bar{y}_D = \frac{1}{|D|} \sum_{(x_i, y_i) \in D} y_i$$
Once this optimal threshold is found, the left and right datasets are stored in new nodes of the decision tree.
In the right plot, we can see the structure of the decision tree that represents how the split was done and what the new nodes are. This is done using a plotting helper function from sklearn.
At this point, we can say that the model is not very good at learning the data: it underfits a lot. To allow the model to "fit more," we can add an additional depth or "step" to the decision tree. The idea is to apply the exact same procedure that was applied to the root node (the original whole dataset) to both left and right subsets. This means that new sub-left and sub-right nodes are going to be created for both left and right nodes.
fig, axes, dtr = plot_decision_regression(2)
axes[0].axvline(dtr.tree_.threshold[0], color="gray", linestyle="--", alpha=0.5)
axes[0].axvline(dtr.tree_.threshold[1], color="gray", linestyle="--")
axes[0].axvline(dtr.tree_.threshold[4], color="gray", linestyle="--")

We can see that the previous left subset, with mean value 0.611, is now split into 2 new subsets with mean values of 0.687 to the left of the threshold 2.612 and 0.231 to the right. Similarly, the previous right subset, with mean value -0.754, is now split into subsets with mean values of -0.419 to the left of the threshold 3.901 and -0.932 to the right.
Note that where the mean squared errors of the leaves at depth 1 were about 0.1, they are now about 0.03, which means the model fits the data better, as can also be seen visually with the regression line.
We can now repeat the same procedure by adding a new depth level to the model:
_ = plot_decision_regression(3)

As expected, increasing the max depth of the model increased the "fitting" performance. You can now see how max_depth acts like a hyperparameter to control the complexity of the model. Setting the max depth to a (relatively) big value of 6 will make the model overfit: since the number of leaves grows exponentially with the depth, overfitting can appear very quickly:
_ = plot_decision_regression(6)

The following gif shows how the tree grows and the regression evolves as the max depth increases, from 1 to 20:

As you can see, increasing the max depth increases the number of leaves, hence the number of "constant sections" equal to the mean of the samples in each section. At some point, some branches stop growing because their leaves contain only one or two samples, so some branches end up deeper than others. At a depth of 12, the tree is "fully" grown: increasing the max depth further does not change the final structure.
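One way to check this is to refit the regressor at increasing max depths and watch the actual depth and number of leaves saturate. A minimal sketch, reusing the X and y data from above (the depths listed are arbitrary):

# Once every leaf is pure or cannot be split further, the tree stops growing:
# the actual depth and leaf count no longer change even if max_depth increases.
for depth in (4, 8, 12, 16, 20):
    reg = DecisionTreeRegressor(max_depth=depth).fit(X, y)
    print(f"max_depth={depth:2d} -> depth={reg.get_depth():2d}, leaves={reg.get_n_leaves()}")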
Finally, notice that when creating new leaves/splits, existing thresholds are kept and not modified: for example, the first split at x=3.186 is kept all along.
Before moving to a more complex example, let's review how the threshold value is selected at each step. As mentioned above, the model selects the threshold that leads to the 2 best-fitting left and right nodes. Internally, the candidate thresholds are derived from the sorted sample values (sklearn tests the midpoints between consecutive values).
The following reproduces the way splits are computed and tested, and we can visualize this process for the first split:
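Here is a minimal sketch of that brute-force search, assuming the X and y arrays defined earlier; the helper name is arbitrary, and the candidate thresholds are taken as midpoints between consecutive sorted samples, which is close to (but not exactly) what sklearn does internally:

def split_loss(X, y, threshold):
    # Weighted MSE of the two subsets obtained by splitting at `threshold`
    left, right = y[X.ravel() < threshold], y[X.ravel() >= threshold]
    n = len(y)
    return sum(len(s) / n * np.mean((s - s.mean()) ** 2) for s in (left, right) if len(s) > 0)

xs = np.sort(X.ravel())
candidates = (xs[:-1] + xs[1:]) / 2                 # one candidate between each pair of samples
losses = [split_loss(X, y, t) for t in candidates]
best = candidates[np.argmin(losses)]
print("best first split:", best)                    # should be close to the 3.186 found by sklearn

# Visualize the loss of every candidate split
plt.figure(figsize=(10, 4))
plt.plot(candidates, losses)
plt.axvline(best, color="gray", linestyle="--")
plt.xlabel("candidate threshold")
plt.ylabel("weighted MSE of the split")
plt.title("Loss of each candidate first split")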

This way, we can see that a loss score is associated with each possible split, and the split with the smallest loss is selected. Remember that this procedure is applied for each node recursively, so the next step is then to apply it to the left node, and then also to the right node.
And that is it; decision trees are as simple as that. The important concept to understand is that at each step (or node), the best threshold that splits the data into 2 new nodes is found, and the process is applied recursively. The "best" threshold is computed by minimizing the error between the actual value of the training samples and the corresponding prediction given by the mean of the samples in the node.
In the case of classification, the approach is pretty similar. Let's consider a 1D feature and a binary target class, so each sample x has a corresponding 0/1 target class. Again, the best split is found by trying all possible values of the x feature, but we cannot use the MSE to quantify how good a split is, since y is no longer a continuous variable but a binary class. Instead, we have to use another metric that quantifies how well the nodes are "separated" into one class or the other. In other words, the objective is to have nodes that contain only one kind of target class. Such a metric is the Gini impurity or the entropy: don't worry too much about them for now, just remember that they play the same role as the MSE but for class targets instead of continuous targets. We'll see a complete example below.
In the end, the approach for decision trees, whether for regression or classification, is really the same: only the metric used to select the best split changes (for example, MSE for regression and Gini for classification).
Decision Tree for 2D Classification
In the previous section, we saw the example of a 1D regression, where we used a 1D feature data X to fit a continuous target value y. In this second example, we are going to see a 2D example for a classification problem: so we have 2 input features X1 and X2, and a discrete target variable y, with only 2 levels for now, so let’s say 0 and 1 (or True/False if you will).
The algorithm's approach is very similar to the previous example:
- At each node, find the best split to separate the current dataset into 2 subsets (the "left" and the "right").
- Apply recursively to each subnode.
The difference here will be:
- We have a classification problem and not a regression problem, so we’ll use a classification metric to compute how well a threshold separates the dataset into a "True" subset and a "False" subset.
- We have 2 features X1 and X2: the idea is simply to try all possible thresholds for X1 AND all possible thresholds for X2, and keep the best one of all.
Let’s first quickly review the Gini Impurity metric, which quantifies how well a dataset is separated into 2 subsets.
Gini Impurity Loss Function
The Gini impurity of a categorical dataset measures how often a randomly chosen element would be incorrectly labeled if it were labeled randomly and independently according to the distribution of labels in the set. In other words, it computes the error rate of the following procedure:
- Select a sample at random: the probability of picking a sample of class i is its frequency n_i/n.
- Assign it a label drawn at random from that same distribution of frequencies.
- Compute the probability that the assigned label is wrong, i.e. combine the probability of selecting a sample of class i with the probability of assigning it to a class j != i.
Let's consider $i = 1..N$ categories, and $p_i = n_i/n$ the corresponding frequencies in a dataset of $n$ elements. The probability of randomly selecting a sample of category $i$ is $p_i$, and the probability of randomly mislabeling it is $1 - p_i$. For a single category, the probability of randomly selecting and randomly mislabeling is then $p_i(1 - p_i)$. The Gini impurity is the sum of this probability over all categories:

$$G = \sum_{i=1}^{N} p_i (1 - p_i) = 1 - \sum_{i=1}^{N} p_i^2$$
Another way to see the Gini impurity is as follows: for a given dataset $D$ with $N$ classes, the Gini impurity ranges between two extreme cases:
- Perfectly determined: if all samples belong to a single class, the probabilities are 0 for all labels except one, which has probability 1; hence the Gini impurity is

$$G = 1 - \sum_{i=1}^{N} p_i^2 = 1 - 1 = 0$$

- Perfectly uniformly distributed: if the labels are uniformly distributed, all probabilities are equal to $1/N$, and the Gini impurity is

$$G = 1 - \sum_{i=1}^{N} \frac{1}{N^2} = 1 - \frac{1}{N}$$

So for a binary class, the Gini impurity of a perfectly uniform distribution is $1 - 0.5 = 0.5$, and for a perfectly sorted one it is 0.
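As a quick sanity check, here is a tiny sketch of the Gini impurity, a direct translation of the formula above rather than sklearn's internal implementation:

def gini(labels):
    # 1 minus the sum of squared class frequencies
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

print(gini([0, 0, 0, 0]))   # 0.0 -> perfectly sorted
print(gini([0, 0, 1, 1]))   # 0.5 -> perfectly uniform binary labels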

You can compare this to the loss function used for regression: to choose a split, the weighted MSE of the left and right subsets is simply replaced by their weighted Gini impurity:

$$L(t) = \frac{n_\text{left}(t)}{n} \, G\big(D_\text{left}(t)\big) + \frac{n_\text{right}(t)}{n} \, G\big(D_\text{right}(t)\big)$$
Entropy Loss Function
Another loss function worth mentioning is the entropy, also called log-loss. It does pretty much the same thing as the Gini impurity: it quantifies how well a dataset is "sorted" or not.
It is expressed as:

$$H = -\sum_{i=1}^{N} p_i \log_2 p_i$$
And the loss for choosing the best split is:

$$L(t) = \frac{n_\text{left}(t)}{n} \, H\big(D_\text{left}(t)\big) + \frac{n_\text{right}(t)}{n} \, H\big(D_\text{right}(t)\big)$$
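The same kind of sketch works for the entropy (base-2 logarithm here) and for the corresponding split loss; the helper names are illustrative:

def entropy(labels):
    # -sum p_i * log2(p_i) over the classes present in `labels`
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return -np.sum(p * np.log2(p))

def split_impurity(y_left, y_right, impurity=entropy):
    # Weighted impurity of a split; also works with the `gini` function above
    n = len(y_left) + len(y_right)
    return len(y_left) / n * impurity(y_left) + len(y_right) / n * impurity(y_right)

print(entropy([0, 0, 1, 1]))              # 1.0 for uniformly distributed binary labels
print(split_impurity([0, 0, 0], [1, 1]))  # 0.0 -> a perfect split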
A 2D Classification Example
Let's review the process of how a decision tree is built in the case of a 2D classification problem.
The following plot shows how all possible thresholds are tested for both features X1 and X2. Once they have all been tested, the best one, i.e. the one with the lowest impurity, is picked. Then the dataset is split according to that threshold for that feature: remember that a split only splits the dataset "along" one feature with one threshold, never several at the same time. On the other hand, the next split could be on any feature, whichever provides the best split.
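To make this concrete, here is a minimal, naive sketch of that exhaustive search over both features, reusing the gini helper defined above (for illustration only, not sklearn's actual implementation); it assumes a 2D array X and a label vector y like the ones built below:

def best_split(X, y, impurity=gini):
    # Try every candidate threshold on every feature and keep the split
    # with the lowest weighted impurity.
    best_feature, best_threshold, best_loss = None, None, np.inf
    n = len(y)
    for feature in range(X.shape[1]):
        values = np.unique(X[:, feature])
        for threshold in (values[:-1] + values[1:]) / 2:   # midpoints between distinct values
            left = X[:, feature] < threshold
            loss = left.sum() / n * impurity(y[left]) + (~left).sum() / n * impurity(y[~left])
            if loss < best_loss:
                best_feature, best_threshold, best_loss = feature, threshold, loss
    return best_feature, best_threshold, best_loss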

Now let’s see an actual sklearn decision tree classifier grow, and the corresponding predictions for a more complex dataset:
%matplotlib qt
import numpy as np
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree
from sklearn.datasets import make_circles, make_classification, make_moons
X, y = make_moons(n_samples=100, noise=0.5, random_state=0)
def plot_decision_boundaries(depth):
    # Fit a decision tree classifier limited to the given depth
    tree_model = DecisionTreeClassifier(random_state=42, max_depth=depth)
    tree_model.fit(X, y)
    fig = plt.figure(figsize=(12, 6))
    # Plot the data and the decision regions on the left
    ax = plt.subplot(1, 2, 1)
    ax.set_title('Decision tree classification')
    ax.scatter(X[:, 0], X[:, 1], c=y, cmap='viridis', edgecolors='k', s=100, label="data")
    plot_step = 0.01
    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, plot_step), np.arange(y_min, y_max, plot_step))
    Z = tree_model.predict(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)
    ax.contourf(xx, yy, Z, alpha=0.3, cmap='viridis')
    ax.set_xlabel('Feature X1')
    ax.set_ylabel('Feature X2')
    # Plot the tree structure on the right
    ax = plt.subplot(1, 2, 2)
    tree.plot_tree(tree_model, filled=True, feature_names=['Feature 0', 'Feature 1'], class_names=['Class 0', 'Class 1'], ax=ax)
    ax.set_title('Decision Tree Structure')
    fig.suptitle(f'Decision tree classifier for 2D feature with max_depth={depth}')
    fig.tight_layout()
    return fig
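The helper above only builds the plots; a minimal usage sketch (the depths listed here are arbitrary) to actually grow the classifier and see both the decision regions and the tree:

for depth in (1, 2, 3, 5, 10):
    _ = plot_decision_boundaries(depth)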

Underfitting, Overfitting, and Regularization of Decision Trees
Just like any other supervised models, decision trees are subject to overfitting and underfitting. And, like other models, decision trees expose hyperparameters that enable us to control the model complexity, so we can find that sweet spot between underfitting and overfitting.
We already saw one hyperparameter that controls the model's complexity: the maximum allowed depth. As we saw, the more depth is allowed, the more specific the tree can get, creating very small regions. At some point, we clearly see that adding depth levels makes the tree "learn" the noise. On the other hand, keeping the maximum depth very small only allows a few regions, so the model doesn't have the complexity to learn the data's overall shape.
Another hyperparameter we briefly touched on is the criterion, or metric, used to select the best split. For regression, we only saw the Mean Squared Error (MSE), but other criteria exist, like the Mean Absolute Error (MAE), which is the average distance (L1 norm) to the median. For classification, we saw two possible criteria: the Gini impurity and the Log-Loss/Entropy.
When creating new splits, we can impose a check and stop splitting a node if it contains too few samples. Using the min_samples_split hyperparameter, we can specify the minimum number of samples a node must contain to be considered for splitting. By default, this hyperparameter is equal to 2, which means that a node will be split only if it has at least 2 samples. In this case, the default value means more or less that the tree can grow down to leaves with a single sample. Using a bigger value means that the tree will stop growing earlier, effectively limiting its complexity. Note that sklearn also allows specifying this number as a fraction of the total number of samples: for example, min_samples_split=0.02 means that a node with fewer than 2% of the total number of samples will not be split.
Another similar hyperparameter is min_samples_leaf, which controls the minimum number of samples a leaf can have. If a split would create nodes with fewer samples than that number, the split is aborted and the branch stops growing. Like min_samples_split, this number can also be given as a float between 0.0 and 1.0, interpreted as a fraction of the total number of samples.
Other hyperparameters allow limiting, for example, the number of features to test for finding the best split or limiting the total number of leaves.
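For instance, a minimal sketch combining a few of these constraints (the specific values are arbitrary, just for illustration):

# max_features: how many features are considered at each split,
# max_leaf_nodes: hard cap on the total number of leaves,
# min_samples_split given as a float: a fraction of the total number of samples.
constrained_tree = DecisionTreeClassifier(
    max_features=1,
    max_leaf_nodes=8,
    min_samples_split=0.02,
    random_state=42,
)
constrained_tree.fit(X, y)
print(constrained_tree.get_depth(), constrained_tree.get_n_leaves())   # at most 8 leaves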
Let’s see a quick example with the same classification problem, but by playing around with some hyperparameters.
fig, axes = plt.subplots(3, 4, figsize=(18, 12))
X, y = make_moons(n_samples=100, noise=0.5, random_state=0)
def plot_decision_boundaries(ax1, ax2, max_depth=None, min_samples_split=2, min_samples_leaf=1):
    # Fit a tree with the given regularization hyperparameters
    tree_model = DecisionTreeClassifier(
        random_state=42,
        max_depth=max_depth,
        min_samples_split=min_samples_split,
        min_samples_leaf=min_samples_leaf
    )
    tree_model.fit(X, y)
    # Decision regions on the first axis
    ax1.set_title('Decision tree classification')
    ax1.scatter(X[:, 0], X[:, 1], c=y, cmap='viridis', edgecolors='k', s=100, label="data")
    plot_step = 0.01
    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, plot_step), np.arange(y_min, y_max, plot_step))
    Z = tree_model.predict(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)
    ax1.contourf(xx, yy, Z, alpha=0.3, cmap='viridis')
    # Tree structure on the second axis
    tree.plot_tree(tree_model, filled=True, feature_names=['Feature 0', 'Feature 1'], class_names=['Class 0', 'Class 1'], ax=ax2)
    ax2.set_title('Decision Tree Structure')
hyperparams = [
    (None, 2, 1, "Baseline"),
    (3, 2, 1, "max_depth=3"),
    (5, 2, 1, "max_depth=5"),
    (None, 5, 1, "min_samples_split=5"),
    (None, 10, 1, "min_samples_split=10"),
    (None, 2, 10, "min_samples_leaf=10"),
]
for i, tup in enumerate(hyperparams):
    ax1 = axes.flatten()[i*2]
    ax2 = axes.flatten()[i*2+1]
    plot_decision_boundaries(ax1, ax2, *tup[:-1])
    ax1.set_title(tup[-1])
fig.tight_layout()

As you can see, we can get many different trees by playing around with hyperparameters. Selecting the best ones is a hyperparameter tuning problem, which we saw how to conduct in a previous post.
Wrapup
So let’s recap what we learned about decision trees:
- They work recursively, splitting the root dataset into branches until a stopping condition ends a branch's growth, effectively creating leaves.
- The splits are tested by trying all possible threshold values, for all features.
- The "best" split is identified using a loss function like the MSE for regression or Gini for classification.
- The very nature of decision trees allows us to control their complexity with many hyperparameters, like the maximum depth or the minimum number of samples required to split a node.
_Bonus visualisation tool: in addition to the nice visualisation provided by sklearn with the plot_tree function, you can also use the dtreeviz package to visualise a decision tree: https://github.com/parrt/dtreeviz_
You might like some of my other posts, make sure to check them out:
Fourier-transforms for time-series