
Machine Learning with R: A Complete Guide to Decision Trees

An easy and straightforward guide to machine learning and classification with decision trees.

Photo by Fabrice Villard on Unsplash

Decision trees are among the most fundamental algorithms in supervised machine learning, used to handle both regression and classification tasks. In a nutshell, you can think of a decision tree as a glorified collection of if-else statements, but more on that later.

Today you’ll learn the basic theory behind the decision tree algorithm and how to implement it in R.

The article is structured as follows:

  • Introduction to Decision Trees
  • Dataset Loading and Preparation
  • Modeling
  • Making Predictions
  • Conclusion

Introduction to Decision Trees

Decision trees are intuitive. All they do is ask questions, such as "Is the gender male?" or "Is the value of a particular variable higher than some threshold?" Based on the answers, either more questions are asked or a classification is made. Simple!

To predict class labels, a decision tree starts from the root (the root node). Choosing which attribute should represent the root node is straightforward: it boils down to figuring out which attribute best separates the training records. A common measure for this, and the default in rpart for classification trees, is Gini impurity. It’s simple math, but it gets tedious to do by hand if you have many attributes.
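To make the calculation concrete: the Gini impurity of a node is 1 minus the sum of squared class proportions, so a pure node scores 0. The sketch below is a brute-force illustration, not rpart's internal code; the helper names gini(), split_gini() and best_split() are hypothetical, and it assumes all features are numeric. It scores every midpoint threshold of every feature and keeps the split with the lowest weighted impurity.

```r
# Gini impurity of a set of class labels: 1 - sum over classes of p_k^2
gini <- function(labels) {
  p <- table(labels) / length(labels)
  1 - sum(p^2)
}

# Weighted Gini impurity after sending rows where goes_left is TRUE to the
# left child and the remaining rows to the right child
split_gini <- function(labels, goes_left) {
  n <- length(labels)
  n_left <- sum(goes_left)
  (n_left / n) * gini(labels[goes_left]) +
    ((n - n_left) / n) * gini(labels[!goes_left])
}

# Brute-force search over every (numeric) feature and every midpoint between
# consecutive distinct values; keeps the first split with the lowest impurity
best_split <- function(data, target) {
  labels <- data[[target]]
  best <- list(feature = NA_character_, threshold = NA_real_, impurity = Inf)
  for (feature in setdiff(names(data), target)) {
    values <- sort(unique(data[[feature]]))
    if (length(values) < 2) next
    thresholds <- (head(values, -1) + tail(values, -1)) / 2
    for (threshold in thresholds) {
      impurity <- split_gini(labels, data[[feature]] < threshold)
      if (impurity < best$impurity) {
        best <- list(feature = feature, threshold = threshold, impurity = impurity)
      }
    }
  }
  best
}

gini(iris$Species)           # about 0.667: three perfectly balanced classes
best_split(iris, "Species")  # candidate root split
```

On iris, the search should land on a petal measurement that separates setosa from the other two species, lowering the weighted impurity from 2/3 to 1/3.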

After determining the root node, the tree "branches out", repeating the same search within each branch to further reduce the impurity left after the root split.

That’s why you’ll often hear the analogy that a decision tree is just a set of if-else statements. The analogy holds to a degree, but the conditions are calculated automatically rather than written by hand. In simple words, the machine learns the best conditions for your data.

Let’s take a look at the following decision tree representation to drive these points home:

Image 1 - Example decision tree (source)

As you can see, the questions Outlook?, Humidity?, and Windy? are used to predict the dependent variable, Play.
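The if-else analogy can be made literal. The function below is a hypothetical, hand-written tree in the spirit of Image 1; its branch values (sunny/overcast/rainy, high/normal humidity, windy TRUE/FALSE) assume the classic weather example and are not read off the figure. A trained model learns conditions like these from data instead of having them typed in.

```r
# Hypothetical hand-coded tree; branch values assume the classic weather example
predict_play <- function(outlook, humidity, windy) {
  if (outlook == "overcast") {
    "yes"                                     # leaf: always play
  } else if (outlook == "sunny") {
    if (humidity == "high") "no" else "yes"   # second question: Humidity?
  } else {                                    # assumed remaining value: "rainy"
    if (windy) "no" else "yes"                # second question: Windy?
  }
}

predict_play("sunny", "high", FALSE)   # "no"
predict_play("rainy", "normal", TRUE)  # "no"
```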

You now know the basic theory behind the algorithm, and you’ll learn how to implement it in R next.


Dataset Loading and Preparation

There’s no machine learning without data, and there’s no working with data without libraries. To follow along, you’ll need rpart for modeling, rpart.plot for visualization, and caret for feature importance and model evaluation.
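A minimal setup sketch. rpart ships with standard R installations as a recommended package, while rpart.plot and caret come from CRAN and would be installed once with install.packages() and attached the same way. Only rpart is attached here, so the snippet runs on a stock R installation.

```r
# rpart is bundled with standard R installations as a recommended package.
# rpart.plot and caret would be attached the same way, after a one-time
# install.packages("rpart.plot") / install.packages("caret").
library(rpart)

# iris is built into R: 150 flowers, four numeric measurements, one factor label
str(iris)
head(iris)
table(iris$Species)
```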

We’ll use the built-in Iris dataset to build our decision tree classifier. This is what the first couple of rows look like (output from the head() function call):

Image 2 - Iris dataset head (image by author)

The dataset is familiar to anyone with even a week of experience in data science and machine learning, so it needs no further introduction. It’s also as clean as datasets come, which saves us a lot of time in this section.

The only thing left to do before moving on to predictive modeling is to split the dataset randomly into training and testing subsets, in a 75:25 ratio.
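One way to do the split with base R only. This is a sketch, not necessarily the original snippet: plain sample() does not stratify by class, whereas caTools::sample.split() or caret::createDataPartition() would keep the class proportions roughly equal in both subsets. The seed value is arbitrary.

```r
set.seed(42)  # arbitrary seed so the split is reproducible

# Draw 75% of the row indices at random for training; the rest is for testing
n <- nrow(iris)
train_idx <- sample(n, size = floor(0.75 * n))
train_set <- iris[train_idx, ]
test_set  <- iris[-train_idx, ]

c(train = nrow(train_set), test = nrow(test_set))  # train 112, test 38
```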

And that’s it! Let’s start with modeling next.


Modeling

We’re using the rpart library to build the model. The syntax is identical to that of linear and logistic regression: put the target variable on the left and the features on the right, separated by the ~ sign. If you want to use all features, put a dot (.) instead of feature names.

Also, don’t forget to specify method = "class", since we’re dealing with a classification task here.

Training the model takes a single call to rpart().
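A minimal sketch of the training step, assuming a base-R 75:25 split like the one described above (recreated here so the snippet runs on its own) and rpart’s default control settings. Printing model lists each node’s split rule, observation count, misclassified count, predicted class and class probabilities.

```r
library(rpart)

# Base-R 75:25 split (seed is arbitrary)
set.seed(42)
train_idx <- sample(nrow(iris), size = floor(0.75 * nrow(iris)))
train_set <- iris[train_idx, ]

# Target on the left of ~, all other columns (.) on the right;
# method = "class" asks rpart for a classification tree
model <- rpart(Species ~ ., data = train_set, method = "class")
model
```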

The output of calling model is shown in the following image:

Image 3 - Decision tree classifier model (image by author)

From this image alone, you can see the rules the decision tree model uses to make classifications. If you’d like a more visual representation, you can use the rpart.plot package to visualize the tree:
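If rpart.plot isn’t available, rpart’s own base-graphics methods draw a plainer version of the same tree. This sketch swaps in plot() and text() for rpart.plot() (whose call is noted in a comment) and rebuilds a model from an assumed base-R split so it runs on its own.

```r
library(rpart)

set.seed(42)  # arbitrary seed for the assumed 75:25 split
train_idx <- sample(nrow(iris), size = floor(0.75 * nrow(iris)))
model <- rpart(Species ~ ., data = iris[train_idx, ], method = "class")

# Base-graphics plot that ships with rpart; rpart.plot::rpart.plot(model)
# draws a more polished version with class probabilities in each node
plot(model, uniform = TRUE, margin = 0.1)
text(model, use.n = TRUE, cex = 0.8)

# The same structure as a table: leaf rows are marked "<leaf>" in var
model$frame[, c("var", "n", "yval")]
```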

Image 4 - Visual representation of the decision tree (image by author)

You can see how many classifications were correct (on the training set) by examining the bottom (leaf) nodes. Setosa was classified correctly every time, versicolor was misclassified as virginica 5% of the time, and virginica was misclassified as versicolor 3% of the time. It’s a simple graph, but you can read everything from it.

Decision trees are also useful for examining feature importance, that is, how much predictive power lies in each feature. You can find out with the varImp() function from the caret package, sorting the importances in descending order.
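As a dependency-free alternative to caret’s varImp(), rpart objects carry a variable.importance element. This sketch swaps that in and rebuilds the model from an assumed base-R split. Because caret computes importance for rpart models its own way, the numbers may differ from those in Image 5.

```r
library(rpart)

set.seed(42)  # arbitrary seed for the assumed 75:25 split
train_idx <- sample(nrow(iris), size = floor(0.75 * nrow(iris)))
model <- rpart(Species ~ ., data = iris[train_idx, ], method = "class")

# rpart's importance scores (primary splits plus weighted surrogate splits),
# sorted from most to least important
importances <- sort(model$variable.importance, decreasing = TRUE)
importances

# Rescaled so the scores sum to 100, which makes them easier to compare
round(100 * importances / sum(importances), 1)
```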

The results are shown in the image below:

Image 5 - Feature importances (image by author)

You’ve built and explored the model so far, but it isn’t of much use until it makes predictions. The next section shows you how to make predictions on previously unseen data and evaluate the model.


Making Predictions

Predicting new instances is now a trivial task. All you have to do is call the predict() function and pass in the testing subset. Also, make sure to specify type = "class"; otherwise, predict() returns class probabilities for a classification tree instead of class labels.
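A minimal sketch of the prediction step, rebuilding an assumed base-R split and model so it runs on its own. The second predict() call shows what comes back without type = "class": a matrix of class probabilities, one column per species.

```r
library(rpart)

set.seed(42)  # arbitrary seed for the assumed 75:25 split
train_idx <- sample(nrow(iris), size = floor(0.75 * nrow(iris)))
train_set <- iris[train_idx, ]
test_set  <- iris[-train_idx, ]
model <- rpart(Species ~ ., data = train_set, method = "class")

# type = "class" returns a factor of predicted labels, one per test row
preds <- predict(model, newdata = test_set, type = "class")
head(preds)

# Without type, a classification tree returns class probabilities instead
probs <- predict(model, newdata = test_set)
head(probs)
```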

The results are shown in the following image:

Image 6 - Decision tree predictions (image by author)

But how good are these predictions? Let’s evaluate. The confusion matrix is one of the most commonly used tools for evaluating classification models. In R, the confusionMatrix() function from the caret package also reports other metrics, such as accuracy, sensitivity, and specificity.

Printing the confusion matrix takes a single function call.
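A base-R sketch of the same evaluation. With caret attached, confusionMatrix(data = preds, reference = test_set$Species) prints the matrix together with accuracy, sensitivity, specificity and more; table() below computes just the matrix, the accuracy, and per-class sensitivity. The split and seed are arbitrary assumptions, so the exact numbers will likely differ from Image 7.

```r
library(rpart)

set.seed(42)  # arbitrary seed for the assumed 75:25 split
train_idx <- sample(nrow(iris), size = floor(0.75 * nrow(iris)))
train_set <- iris[train_idx, ]
test_set  <- iris[-train_idx, ]
model <- rpart(Species ~ ., data = train_set, method = "class")
preds <- predict(model, newdata = test_set, type = "class")

# Rows are predicted classes, columns are true classes
cm <- table(Predicted = preds, Actual = test_set$Species)
cm

# Accuracy = correctly classified (diagonal) / all test rows
accuracy <- sum(diag(cm)) / sum(cm)
accuracy

# Per-class sensitivity (recall) = diagonal / column totals
round(diag(cm) / colSums(cm), 3)
```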

And here are the results:

Image 7 - Confusion matrix on the test set (image by author)

As you can see, there are some misclassifications in the versicolor and virginica classes, similar to what we saw on the training set. Overall, the model’s accuracy is just short of 90%, which is more than acceptable for a simple decision tree classifier.

Conclusion

Decision trees are an excellent introduction to the whole family of tree-based algorithms. A single decision tree is commonly used as a baseline model that more sophisticated tree-based algorithms (such as random forests and gradient boosting) need to outperform.

Today you’ve learned the basic logic and intuition behind decision trees, and how to implement and evaluate the algorithm in R. You can expect the whole suite of tree-based algorithms to be covered soon, so stay tuned if you want to learn more.






Originally published at https://appsilon.com on February 10, 2021.

