
XGBoost: How Deep Learning Can Replace Gradient Boosting and Decision Trees – Part 2: Training

A world without if

Photo by Simon Wilkes on Unsplash

In a previous article:

XGBoost: How Deep Learning Can Replace Gradient Boosting and Decision Trees – Part 1

you learned how to rewrite decision trees using a Differentiable Programming approach, as suggested by the NODE paper. The idea of this paper is to replace XGBoost with a Neural Network.

More specifically, after explaining why the process of building Decision Trees is not differentiable, it introduced the necessary mathematical tools to regularize the two main elements associated with a decision node:

  • Feature Selection
  • Branch detection

The NODE paper shows that both can be handled using the entmax function.

To summarize, we have shown how to create a binary tree without using comparison operators.

The previous article ended with open questions regarding training a regularized decision tree. It’s time to answer these questions.

If you’re interested in a deep dive in Gradient Boosting Methods, have a look at my book:

Practical Gradient Boosting: A deep dive into Gradient Boosting in Python

A smooth decision node

First, based on what we presented in the previous article, let’s create a new Python class: SmoothBinaryNode.

This class encodes the behavior of a smooth binary node. There are two key parts in its code:

  • The selection of the features, handled by the function _choices
  • The evaluation of these features, with respect to a given threshold, and the identification of the path to follow: left or right. All this is managed by the methods left and right.

As explained in the previous article, the key to regularizing a binary node (and hence to allowing its training) is the entmax function.
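Since this class is central to everything that follows, here is a minimal sketch of what it could look like, assuming jax. The entmax selector is replaced by a temperature-scaled softmax, a smooth stand-in that plays the same role of an almost one-hot selector, and a predict method is added for convenience; the method names simply follow the description above.

```python
import jax.numpy as jnp
from jax.nn import softmax


def smooth_step(x, alpha=10.0):
    # Smooth stand-in for the entmax-based 0/1 branch gate: a softmax over
    # (alpha * x, 0) tends towards a step function as alpha grows.
    return softmax(jnp.stack([alpha * x, jnp.zeros_like(x)]), axis=0)[0]


class SmoothBinaryNode:
    def __init__(self, weights, bias, leaves, alpha=10.0):
        # weights: logits used to (softly) select one feature
        # bias:    threshold the selected feature is compared to
        # leaves:  values returned on the left / right branches
        self.params = {"weights": jnp.asarray(weights),
                       "bias": jnp.asarray(bias),
                       "leaves": jnp.asarray(leaves)}
        self.alpha = alpha

    def _choices(self, params, x):
        # Soft feature selection: a dot product between the input and an
        # (almost) one-hot vector produced by the softmax selector.
        selector = softmax(self.alpha * params["weights"])
        return jnp.dot(x, selector)

    def right(self, params, x):
        # Degree to which the sample goes right (selected feature above bias)
        feature = self._choices(params, x)
        return smooth_step(feature - params["bias"], self.alpha)

    def left(self, params, x):
        # Degree to which the sample goes left (selected feature below bias)
        return 1.0 - self.right(params, x)

    def predict(self, params, x):
        # Smooth blend of the two leaf values
        return (self.left(params, x) * params["leaves"][0]
                + self.right(params, x) * params["leaves"][1])
```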

The almighty dot product

Note that both feature selection and left/right branch selection are performed with a dot product. The dot product looks like a simple operation if you only consider its implementation, but it is in fact very powerful.

It can be used to compute projections, cosines, or the intersection between rays and triangles… but that is another story.

A simple binary node

Let’s see with a few lines of code how we can use this new class on a very basic binary tree with only one node:
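For instance, reusing the sketch above (the leaf values below are purely illustrative):

```python
import jax.numpy as jnp

node = SmoothBinaryNode(
    weights=jnp.array([1.0, 0.0]),   # (almost) one-hot: select feature 0
    bias=jnp.array(50.0),            # split threshold
    leaves=jnp.array([1.0, 2.0]),    # values returned on the left / right
)

x = jnp.array([25.0, 80.0])          # feature 0 is 25, well below 50
print(node.predict(node.params, x))  # close to the left leaf value, 1.0
```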

In this snippet, we configure the node through its three parameters: leaves holds the values attached to the two leaves (here 1 on the left and 2 on the right, chosen for illustration), bias sets the split threshold to 50, and weights drives the feature selection.

Here is another slightly more complex example, with a two-level tree:
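Still with the hypothetical class above, one way to sketch this is to let the root route each sample, softly, towards one of two child nodes:

```python
import jax.numpy as jnp

# The root routes each sample (softly) towards one of its two children,
# and the prediction is the routing-weighted blend of their outputs.
root = SmoothBinaryNode(weights=jnp.array([1.0, 0.0]),
                        bias=jnp.array(50.0),
                        leaves=jnp.array([0.0, 0.0]))   # root leaves unused
left_child = SmoothBinaryNode(weights=jnp.array([0.0, 1.0]),
                              bias=jnp.array(10.0),
                              leaves=jnp.array([1.0, 2.0]))
right_child = SmoothBinaryNode(weights=jnp.array([0.0, 1.0]),
                               bias=jnp.array(30.0),
                               leaves=jnp.array([3.0, 4.0]))


def predict_tree(x):
    # Blend the children's predictions with the root's soft routing weights
    return (root.left(root.params, x) * left_child.predict(left_child.params, x)
            + root.right(root.params, x) * right_child.predict(right_child.params, x))


x = jnp.array([75.0, 20.0])   # feature 0 > 50: routed to the right child
print(predict_tree(x))        # feature 1 = 20 < 30: close to 3.0
```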

The principle is similar to the previous example, except that two child nodes are attached to the root.

Parameters to learn

In the previous examples, we manually defined three parameters:

  • The weights, used to select features: thanks to the entmax function, the feature with the highest weight relative to the others is the one chosen.
  • The threshold used to split the data into two subsets once a feature is selected, i.e. the parameter bias.
  • The values attached to the leaves of the tree, stored in the parameter leaves.

Those are the parameters that will be learned during the training process.

Defining the objective

As always in Machine Learning, the goal is to find the combination of parameters that minimizes some cost function.

The most common one is the Mean Squared Error, which is quite simple to compute and can be written in Python as follows:
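A possible sketch, reusing the hypothetical SmoothBinaryNode above (make_mse is an illustrative helper name):

```python
import jax
import jax.numpy as jnp


def make_mse(tree, X, y):
    # Returns a closure that captures the tree: the resulting mse function
    # only depends on the parameters, which is what we will differentiate.
    def mse(params):
        predictions = jax.vmap(lambda x: tree.predict(params, x))(X)
        return jnp.mean((predictions - y) ** 2)
    return mse
```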

Note that there is a little subtlety in the code above, as the mse function creates a closure that captures the tree object.

We can apply this function to the simple, one-level tree defined earlier:
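Reusing the single-node node and the make_mse helper sketched above, with targets chosen to match the tree’s predictions exactly:

```python
import jax.numpy as jnp

X = jnp.array([[25.0, 0.0], [75.0, 0.0]])
y = jnp.array([1.0, 2.0])     # exactly the left and right leaf values

mse = make_mse(node, X, y)
print(mse(node.params))       # numerically zero
```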

As expected, the error is zero.

Learning with Gradient Descent

If there is one mathematical tool that is associated with Machine Learning, it’s Gradient Descent. Deep Learning, simple Neural Networks, and even Gradient Boosting (but in a functional space) use Gradient Descent to minimize some kind of cost function to train a model.

The main difficulty in this process is computing the gradient of a complex function, which is the mathematical composition of linear functions (the layers) and non-linear functions (the activation functions).

Even though this difficulty held back Neural Network methods for a few decades, the underlying mathematical principles, the rules of differentiation, have been known for a very long time.

Nowadays, computing the gradient of a complex function, i.e. the composition of many linear and non-linear base functions, can be done simply and efficiently, in a single line of code, using Automatic Differentiation.

As you can see in the code for the class SmoothBinaryNode, we have grouped these three parameters, weights, bias, and leaves, in the variable params.

Using the Gradient Descent method is the standard way to optimize the parameters to minimize the error.

Both Automatic Differentiation and Gradient Descent are concepts that I explore in my book 70 mathematical concepts:

Unveiling 70 Mathematical Concepts with Python: A Practical Guide to Exploring Mathematics Through…

As its name implies, Gradient Descent requires computing the Gradient of the error function, with respect to the parameters we want to optimize.

Thanks to the jax library, it’s damned simple to compute the gradient of any function, using its grad function.
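For instance, the derivative of a toy function is obtained in one line:

```python
import jax

grad_f = jax.grad(lambda x: x ** 2)  # f'(x) = 2x
print(grad_f(3.0))                   # 6.0
```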

Jax uses automatic differentiation to compute derivatives automatically and efficiently. If you’re interested in this compelling method, have a look at my introductory article on the subject:

Differentiable Programming from Scratch

One of the ways to implement it is shown in this piece of code:
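The snippet below is a sketch of that experiment, reusing the hypothetical node, X, y and make_mse defined in the previous snippets:

```python
import jax

mse = make_mse(node, X, y)
grad_mse = jax.grad(mse)

# 1. At the optimal parameters the error is zero, so the gradient vanishes.
print(grad_mse(node.params))

# 2. Perturbing the leaves makes the corresponding gradient component non-zero.
perturbed = dict(node.params)
perturbed["leaves"] = perturbed["leaves"] + 0.5
print(grad_mse(perturbed))

# 3. Plain gradient descent, starting from the perturbed parameters.
learning_rate = 0.1
params = perturbed
for _ in range(1000):
    grads = grad_mse(params)
    params = {k: params[k] - learning_rate * grads[k] for k in params}

print(mse(params))            # back to (almost) zero
```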

This code first uses the jax library to compute the gradient of the error for an optimal set of parameters, i.e. a set of parameters for which the error is zero. We thus check that, as expected, the gradient is zero in this case.

Then we slightly perturb the leaves parameter and check that the gradient becomes non-zero in the direction of this parameter, which is indeed the case. The same holds when modifying the bias parameter.

Finally, starting from a perturbed, non-optimal set of parameters, we use the Gradient Descent method to iteratively update them.

After 1,000 iterations (an arbitrary number), the parameters converge to another set that minimizes the error.

Vanishing Gradient

I must confess that I cheated a little in the example above. I deliberately chose parameters in a non-flat region of the error function.

If you remember the shape of the entmax function, when using a high alpha the transition from 0 to 1 is very steep. This means that for any value even slightly away from 0, the curve of the function is completely flat.

As the gradient is, by definition, the slope of the curve, it is zero in this region.

As the Gradient Descent method updates the parameters by a small increment, the product of the learning rate and the gradient, a zero gradient means no update at all: the optimization fails.

One option to avoid this annoying limitation is to perform batch normalization, to ensure that the features remain in a region where the entmax function is not flat.
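To make the issue concrete, here is a tiny check using the smooth_step stand-in from the sketches above: far from the threshold the gate saturates, and its gradient is numerically zero.

```python
import jax

# Gradient of the smooth gate defined in the first sketch
grad_gate = jax.grad(lambda x: smooth_step(x, alpha=10.0))
print(grad_gate(0.1))    # near the threshold: clearly non-zero
print(grad_gate(25.0))   # far from the threshold: numerically zero
```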


Conclusion

We have seen in this series of two articles how to regularize decision trees. We have shown how to replace feature selection and branch selection with the entmax function and a dot product.

This regularization allows us to use smooth decision trees within the mathematical framework of Differentiable Programming.

Being able to use this formalism is very powerful, as it allows us to mix any kind of Neural Network (a complex differentiable function) with decision trees.

