
Building and Visualizing a Decision Tree in Python

Learn to build and visualize a decision tree model with scikit-learn in Python

Nikhil Adithyan · Published in CodeX · 5 min read · Oct 26, 2020


Photo by Illiya Vjestica on Unsplash

Decision Tree

Decision trees are the building blocks of some of the most powerful supervised learning methods that are used today.

‘A decision tree is basically a binary tree flowchart where each node splits a group of observations according to some feature variable. The goal of a decision tree is to split the data into groups such that every element in one group belongs to the same category.’

One of the great properties of decision trees is that they are very easy to interpret. You do not need to be familiar with machine learning techniques at all to understand what a decision tree is doing; the resulting tree diagram can be read directly as a flowchart.

Python for Decision Tree

Python is a general-purpose programming language and offers data scientists powerful machine learning packages and tools. In this article, we will build our decision tree model using Python's most famous machine learning package, scikit-learn. We will create the model using the 'DecisionTreeClassifier' algorithm provided by scikit-learn, then visualize it using the 'plot_tree' function. Let's do it!

Step-1: Importing the packages

Our primary packages for building the model are pandas, scikit-learn, and NumPy. Follow the code to import the required packages in Python.
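Here is a minimal import block covering everything used in the steps below (matplotlib is assumed here for rendering the tree diagram in Step-6):

# Importing the required packages
import pandas as pd    # data import and manipulation
import numpy as np     # numerical arrays
import matplotlib.pyplot as plt    # rendering the tree diagram
from sklearn.model_selection import train_test_split    # splitting the data
from sklearn.tree import DecisionTreeClassifier, plot_tree    # model and visualization
from sklearn.metrics import accuracy_score    # evaluation metric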

After importing all the required packages for building our model, it’s time to import the data and do some EDA on it.

Step-2: Importing data and EDA

In this step, we will use the pandas package to import the data and do some EDA on it. The dataset we will be using to build our decision tree model is a drug dataset, in which a drug is prescribed to each patient based on certain medical attributes. Let's import the data in Python!

Python Implementation:
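A minimal sketch of the import step; the file name 'drug.csv' is an assumption, so adjust the path to wherever your copy of the dataset lives:

# Importing the drug dataset and viewing the first five rows
df = pd.read_csv('drug.csv')    # assumed file name
print(df.head())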

Output:

   Age Sex      BP Cholesterol  Na_to_K   Drug
0   23   F    HIGH        HIGH   25.355  drugY
1   47   M     LOW        HIGH   13.093  drugC
2   47   M     LOW        HIGH   10.114  drugC
3   28   F  NORMAL        HIGH    7.798  drugX
4   61   F     LOW        HIGH   18.043  drugY

Now we have a clear idea of our dataset. After importing the data, let's get some basic information on it using the 'info' function. The information provided by this function includes the number of entries, the index range, column names, non-null value counts, attribute types, etc.

Python Implementation:
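The summary below comes straight from the DataFrame's 'info' method:

# Getting the basic structure of the dataset
df.info()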

Output:

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 200 entries, 0 to 199
Data columns (total 6 columns):
 #   Column       Non-Null Count  Dtype
---  ------       --------------  -----
 0   Age          200 non-null    int64
 1   Sex          200 non-null    object
 2   BP           200 non-null    object
 3   Cholesterol  200 non-null    object
 4   Na_to_K      200 non-null    float64
 5   Drug         200 non-null    object
dtypes: float64(1), int64(1), object(4)
memory usage: 9.5+ KB

Step-3: Data Processing

We can see that attributes like Sex, BP, and Cholesterol are categorical and of 'object' type. The problem is that the decision tree algorithm in scikit-learn does not accept 'object' (string) values as X variables, so it is necessary to encode these categorical values as numbers. Let's do it in Python!

Python Implementation:
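A minimal sketch of one way to do the encoding. The exact mappings behind the original output are not shown, so the codes below (F/M for Sex, an ordinal LOW/NORMAL/HIGH scale for BP, and LOW = 0 / HIGH = 1 for Cholesterol, as described after the output) are assumptions:

# Encoding the categorical 'object' columns as numeric codes
# (assumed mappings; the original encoding code is not preserved)
df['Sex'] = df['Sex'].map({'F': 0, 'M': 1})
df['BP'] = df['BP'].map({'LOW': 0, 'NORMAL': 1, 'HIGH': 2})
df['Cholesterol'] = df['Cholesterol'].map({'LOW': 0, 'HIGH': 1})
print(df)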

Output:

     Age  Sex  BP  Cholesterol  Na_to_K   Drug
0     23    1   2            1   25.355  drugY
1     47    1   0            1   13.093  drugC
2     47    1   0            1   10.114  drugC
3     28    1   1            1    7.798  drugX
4     61    1   0            1   18.043  drugY
..   ...  ...  ..          ...      ...    ...
195   56    1   0            1   11.567  drugC
196   16    1   0            1   12.006  drugC
197   52    1   1            1    9.894  drugX
198   23    1   1            1   14.020  drugX
199   40    1   0            1   11.349  drugX

[200 rows x 6 columns]

We can observe that all the 'object' values have been processed into numeric codes representing the categorical data. For example, in the Cholesterol attribute, values showing 'LOW' are processed to 0 and 'HIGH' to 1. Now we are ready to create the dependent and independent variables out of our data.

Step-4: Splitting the data

After processing our data into the right structure, we are now set to define the 'X' variable, or the independent variable, and the 'Y' variable, or the dependent variable. Let's do it in Python!

Python Implementation:
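A sketch of defining the two variables; judging from the sample output below, the feature columns are ordered Sex, BP, Age, Cholesterol, Na_to_K (an assumption inferred from the printed values):

# Independent variable (X): the five feature columns as a NumPy array
X = df[['Sex', 'BP', 'Age', 'Cholesterol', 'Na_to_K']].values
# Dependent variable (Y): the prescribed drug
y = df['Drug'].values

print('X variable samples : {}'.format(X[:5]))
print('Y variable samples : {}'.format(y[:5]))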

Output:

X variable samples : [[ 1.     2.    23.     1.    25.355]
 [ 1.     0.    47.     1.    13.093]
 [ 1.     0.    47.     1.    10.114]
 [ 1.     1.    28.     1.     7.798]
 [ 1.     0.    61.     1.    18.043]]

Y variable samples : ['drugY' 'drugC' 'drugC' 'drugX' 'drugY']

We can now split our data into a training set and a testing set from our defined X and Y variables using the 'train_test_split' function in scikit-learn. Follow the code to split the data in Python.

Python Implementation:
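A 160/40 split corresponds to a test size of 20%; the 'random_state' value is an assumption, added only to make the split reproducible:

# Splitting the data: 80% for training, 20% for testing
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)    # assumed random_state

print('X_train shape : {}'.format(X_train.shape))
print('X_test shape : {}'.format(X_test.shape))
print('y_train shape : {}'.format(y_train.shape))
print('y_test shape : {}'.format(y_test.shape))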

Output:

X_train shape : (160, 5)
X_test shape : (40, 5)
y_train shape : (160,)
y_test shape : (40,)

Now we have all the components needed to build our decision tree model, so let's proceed to build it in Python.

Step-5: Building the model & Predictions

Building a decision tree can be done easily with the help of the 'DecisionTreeClassifier' algorithm provided by the scikit-learn package. After that, we can make predictions on our test data using the trained model. Finally, the accuracy of our predicted results can be calculated using the 'accuracy_score' evaluation metric. Let's do this process in Python!

Python Implementation:
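A sketch of the modelling step. The hyperparameters below (criterion='entropy', max_depth=4) are assumptions chosen to keep the tree readable; the accuracy you get will depend on them and on the train/test split:

# Defining, training, and evaluating the decision tree model
model = DecisionTreeClassifier(criterion='entropy', max_depth=4)    # assumed hyperparameters
model.fit(X_train, y_train)          # training on the training set
pred_model = model.predict(X_test)   # predicting the test set

accuracy = accuracy_score(y_test, pred_model)
print('Accuracy of the model is {}%'.format(round(accuracy * 100)))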

Output:

Accuracy of the model is 88%

In the first step of our code, we define a variable called 'model' in which we store the DecisionTreeClassifier model. Next, we fit and train the model using our training set. After that, we define a variable called 'pred_model' in which we store all the values predicted by the model on the test data. Finally, we calculate the accuracy of our predicted values against the actual values, which comes out to 88%.

Step-6: Visualizing the model

Now that we have our decision tree model, let's visualize it by utilizing the 'plot_tree' function provided by the scikit-learn package. Follow the code to produce a beautiful tree diagram out of your decision tree model in Python.

Python Implementation:
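A sketch of the visualization step using scikit-learn's 'plot_tree'; the feature names follow the column order assumed in Step-4:

# Plotting the trained decision tree as a flowchart diagram
feature_names = ['Sex', 'BP', 'Age', 'Cholesterol', 'Na_to_K']

plt.figure(figsize=(20, 10))
plot_tree(model,
          feature_names=feature_names,       # splits labelled with column names
          class_names=list(model.classes_),  # the five drug labels
          filled=True,                       # colour nodes by majority class
          rounded=True)
plt.show()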

Output:

Decision tree diagram (Image by Author)

Conclusions!

There are a lot of techniques and other algorithms used to tune decision trees and to avoid overfitting, such as pruning. Decision trees are usually unstable, meaning a small change in the data can lead to large changes in the optimal tree structure, yet their simplicity makes them a strong candidate for a wide range of applications. Before neural networks became popular, decision trees were among the state-of-the-art algorithms in machine learning. With that, we come to an end, and if you missed any of the coding parts, don't worry, I've provided the full code for this article below.

Happy Machine Learning!

Full code:
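A consolidated sketch of all the steps above; as noted along the way, the file name, category encodings, hyperparameters, and random_state are assumptions:

# Building and visualizing a decision tree in Python
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.metrics import accuracy_score

# 1. Importing data and EDA
df = pd.read_csv('drug.csv')    # assumed file name
print(df.head())
df.info()

# 2. Data processing (assumed mappings)
df['Sex'] = df['Sex'].map({'F': 0, 'M': 1})
df['BP'] = df['BP'].map({'LOW': 0, 'NORMAL': 1, 'HIGH': 2})
df['Cholesterol'] = df['Cholesterol'].map({'LOW': 0, 'HIGH': 1})

# 3. Splitting the data
X = df[['Sex', 'BP', 'Age', 'Cholesterol', 'Na_to_K']].values
y = df['Drug'].values
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

# 4. Building the model & predictions
model = DecisionTreeClassifier(criterion='entropy', max_depth=4)    # assumed hyperparameters
model.fit(X_train, y_train)
pred_model = model.predict(X_test)
print('Accuracy of the model is {}%'.format(round(accuracy_score(y_test, pred_model) * 100)))

# 5. Visualizing the model
plt.figure(figsize=(20, 10))
plot_tree(model,
          feature_names=['Sex', 'BP', 'Age', 'Cholesterol', 'Na_to_K'],
          class_names=list(model.classes_),
          filled=True, rounded=True)
plt.show()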


Nikhil Adithyan

Founder @BacktestZone (https://www.backtestzone.com/), a no-code backtesting platform | Top Writer | Connect with me on LinkedIn: https://bit.ly/3yNuwCJ