Using sklearn, graphviz and dtreeviz Python packages for fancy visualization of decision trees

Data visualization plays a key role in data analysis and machine learning because it reveals the hidden patterns behind the data. Model visualization, in turn, helps you interpret the model. With the many Python packages available today, the visualization process is easy.
Tree-based models such as Decision Trees, Random Forests and XGBoost are very popular for supervised learning (classification and regression) tasks. This is because these models fit non-linear data well, and such data is common in real-world applications.
The building block of any tree-based model is the Decision Tree; a Random Forest, for example, consists of multiple decision trees. Today, we’ll discuss 4 different ways to visualize individual decision trees in a Random Forest. Please note that many of the methods discussed here also apply to other tree-based models, not just to Random Forests. We’ll use the sklearn, graphviz and dtreeviz Python packages, which make it easy to create visualizations with just a few lines of code.
Here are the 4 ways to visualize trees that we’ll discuss today:
- Plot decision trees using sklearn.tree.plot_tree() function
- Plot decision trees using sklearn.tree.export_graphviz() function
- Plot decision trees using dtreeviz Python package
- Print decision tree details using sklearn.tree.export_text() function
The first three methods display the decision tree as a graph. The last method prints it as a text report.
Prerequisites
I recommend reading the following articles I’ve written, as they are prerequisites for today’s content.
- Random forests – An ensemble of decision trees
- Train a regression model using a decision tree
- 9 Guidelines to master Scikit-learn without giving up in the middle
Building a random forest model on "wine data"
Before discussing the above 4 methods, let’s build a random forest model on the "wine data". This model will be the input for all 4 methods. The "wine dataset" is one of Scikit-learn’s built-in datasets.
X is the feature matrix and y is the label column. The "wine data" has 3 class labels: ‘class_0’, ‘class_1’ and ‘class_2’. Both X and y are used as inputs to the random forest model. Since this is a classification task, we build a RandomForestClassifier() on the "wine data". For regression tasks, you can use RandomForestRegressor().
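Here is a minimal sketch of this step, assuming n_estimators=100 (to match the 100 trees referred to in the rest of this tutorial) and an arbitrary random_state=42, chosen only so that the results are reproducible.

```python
from sklearn.datasets import load_wine
from sklearn.ensemble import RandomForestClassifier

# Load the built-in "wine data": X is the feature matrix, y is the label column
wine = load_wine()
X = wine.data
y = wine.target

# A forest of 100 trees; random_state=42 is an assumed value for reproducibility
rf = RandomForestClassifier(n_estimators=100, random_state=42)
rf.fit(X, y)

print(X.shape)                    # (178, 13)
print(wine.target_names.tolist())  # ['class_0', 'class_1', 'class_2']
```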
The model is now fitted on "wine data" and can be accessed through the rf variable.
Accessing individual decision trees in a random forest
The number of trees in a random forest is defined by the n_estimators parameter of the RandomForestClassifier() or RandomForestRegressor() class. The model we built above has 100 trees. Each tree can be accessed with:
rf.estimators_[index]
Here, rf is the random forest model, and index ranges from 0 to 99 (both inclusive): 0 refers to the first decision tree and 99 to the last one.
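A short sketch of this indexing, assuming the same wine-data forest (random_state=42 is again an arbitrary choice). Each element of rf.estimators_ is an ordinary fitted DecisionTreeClassifier, so its own methods, such as get_depth() and get_n_leaves(), are available.

```python
from sklearn.datasets import load_wine
from sklearn.ensemble import RandomForestClassifier

X, y = load_wine(return_X_y=True)
rf = RandomForestClassifier(n_estimators=100, random_state=42).fit(X, y)

# rf.estimators_ is a list holding the fitted trees of the forest
print(len(rf.estimators_))  # 100

first_tree = rf.estimators_[0]   # index 0: the first tree
last_tree = rf.estimators_[99]   # index 99: the last tree (same object as rf.estimators_[-1])

# Each tree is a regular fitted decision tree, so tree-level methods work on it
print(type(first_tree).__name__)  # DecisionTreeClassifier
print(first_tree.get_depth(), first_tree.get_n_leaves())
```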
Plot decision trees using sklearn.tree.plot_tree() function
This is the simplest way to visualize a decision tree. You do not need to install any special Python package: scikit-learn and matplotlib ship with Anaconda, so if you’ve already installed Anaconda, you’re all set! However, this function does not adjust the size of the figure automatically, so the contents of larger decision trees will not be clear. To avoid this, use the figsize argument of plt.figure() to control the figure size.
Let’s plot the first decision tree (accessed by index 0) in our random forest model using this method.
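A minimal sketch of this step, assuming the same wine-data forest. The figsize of (15, 10) and the filled/rounded styling options are illustrative choices, not values from the original code; enlarge the figure if the node text is still hard to read.

```python
import matplotlib.pyplot as plt
from sklearn.datasets import load_wine
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import plot_tree

wine = load_wine()
rf = RandomForestClassifier(n_estimators=100, random_state=42).fit(wine.data, wine.target)

# plot_tree does not resize the figure, so create a large one up front
fig = plt.figure(figsize=(15, 10))
annotations = plot_tree(
    rf.estimators_[0],                       # the first tree in the forest
    feature_names=wine.feature_names,        # label splits with real feature names
    class_names=wine.target_names.tolist(),  # 'class_0', 'class_1', 'class_2'
    filled=True,                             # color nodes by majority class
    rounded=True,                            # rounded node boxes
)

# In a notebook the figure is displayed automatically; here we also save it
fig.savefig("figure_name.png")
```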

You can save the figure as a PNG file by running:
fig.savefig('figure_name.png')
To learn more about the parameters of the sklearn.tree.plot_tree() function, please read its documentation.
Plot decision trees using sklearn.tree.export_graphviz() function
Compared to the previous method, this method has one advantage and one disadvantage. The advantage is that the figure is sized automatically, so you do not need to worry about it when you plot larger trees. The disadvantage is that you need to install the graphviz Python package by running the following command in your Anaconda command prompt.
pip install graphviz
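Note that the pip package provides only the Python interface; rendering also needs the Graphviz software itself. If that didn’t work for you, try the following command instead: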
conda install -c anaconda graphviz
Let’s plot the last decision tree (accessed by index 99) in our random forest model using Graphviz.
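A minimal sketch of this step, assuming the same wine-data forest. Passing out_file=None makes export_graphviz return the tree as a DOT-language string, which is then wrapped in graphviz.Source with format="png"; that format setting is what decides the file type when graph.render() is called. The filled/rounded options are illustrative choices.

```python
import graphviz
from sklearn.datasets import load_wine
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import export_graphviz

wine = load_wine()
rf = RandomForestClassifier(n_estimators=100, random_state=42).fit(wine.data, wine.target)

# With out_file=None, export_graphviz returns the DOT source as a string
dot_data = export_graphviz(
    rf.estimators_[99],                      # the last tree in the forest
    out_file=None,
    feature_names=wine.feature_names,
    class_names=wine.target_names.tolist(),
    filled=True,
    rounded=True,
)

# format="png" sets the output type used later by graph.render()
graph = graphviz.Source(dot_data, format="png")
graph  # as the last expression in a notebook cell, this displays the tree
```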

You can save the figure by running:
graph.render('figure_name')
graph.render() writes the DOT source to a file named figure_name and the rendered graph to figure_name.png, because we specified format="png" in the graphviz.Source() function. If you specify format="pdf", the graph will be saved as a PDF file instead.
Plot decision trees using dtreeviz Python package
The dtreeviz Python package can be used to plot decision trees. It creates some nice visualizations. Let’s see them in action for both classification and regression datasets.
Before that, you need to install dtreeviz in your Anaconda environment by running the following line of code:
pip install dtreeviz
If you get an error due to a PATH problem when running dtreeviz in your Anaconda environment, it is better to use dtreeviz in the Google Colab notebook environment. Run the following line of code to install dtreeviz there.
!pip install dtreeviz
Now, you’re ready to run dtreeviz in your own Colab environment.
Dtreeviz – Plot a decision tree on a classification dataset
Now, we use dtreeviz to plot the last decision tree of the random forest model built on the "wine dataset" which is used for classification. Here is the code.
After running the above code, a pop-up window will appear asking to save the SVG file on your computer. You can convert the SVG file into a PNG file by using this free online converter. Set a higher value (e.g. 250) in the Pixel Density box to get a high-resolution image.

Dtreeviz – Plot a decision tree on a regression dataset
Now, we use dtreeviz to plot the first decision tree of the random forest model built on the "boston house-prices dataset" which is used for regression. Here is the code.
After running the above code, a pop-up window will appear asking to save the SVG file on your computer. You can convert the SVG file into a PNG file by using this free online converter. Set a higher value (e.g. 250) in the Pixel Density box to get a high-resolution image.

Print decision tree details using sklearn.tree.export_text() function
In contrast to the previous 3 methods, this method builds the decision tree in the form of a text report.
Let’s print the decision tree details of the first decision tree (accessed by index 0) in our random forest model using this method.
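A minimal sketch of this step, assuming the same wine-data forest. export_text returns the tree’s split rules as a plain string, one line per node; passing feature_names replaces the default feature_0, feature_1, ... labels with the real column names.

```python
from sklearn.datasets import load_wine
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import export_text

wine = load_wine()
rf = RandomForestClassifier(n_estimators=100, random_state=42).fit(wine.data, wine.target)

# Text report of the first tree; each "|---" line is a split or a leaf
text_report = export_text(rf.estimators_[0], feature_names=wine.feature_names)
print(text_report)
```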

The details are exactly the same as the figure we obtained using the sklearn.tree.plot_tree() function.
Summary
The sklearn, graphviz and dtreeviz Python packages provide high-level functions to plot decision trees. The sklearn functions are the easiest to use and still give detailed visualizations. The graphviz and dtreeviz Python packages must be installed separately before you can use them. You may get an error due to a PATH problem when running dtreeviz in your Anaconda environment; in that case, it is better to use dtreeviz in the Google Colab notebook environment. dtreeviz creates nicer visualizations than the other methods, and it is my favourite one. What’s your favourite one? Let me know in the comment section. Please also share your valuable feedback so that I can improve my future content.
Thanks for reading!
This tutorial was designed and created by Rukshan Pramoditha, the Author of Data Science 365 Blog.
Read my other articles at https://rukshanpramoditha.medium.com
2021–05–07