In recent years, machine learning has exploded in popularity, and deep learning models have blown shallow models like XGBoost [4] out of the water on complex tasks such as image and text processing. However, deep models are often less effective than these shallow models on tabular data, and no universal deep learning approach consistently outperforms gradient-boosted trees.
To address this gap, researchers at the Russian internet services company Yandex have proposed a new architecture: Neural Oblivious Decision Ensembles (NODE) [1]. The network builds on lightweight, interpretable neural decision trees and integrates them within a neural network framework, allowing the model to capture complex interactions and dependencies in tabular data while maintaining interpretability.
In this article, I aim to explain how NODE works and the various attributes that make it a robust yet interpretable prediction model. As usual, I encourage everyone to read the original paper. If you want to use NODE, please check out the model's GitHub repository (linked in the Resources section below).
This article is part of a series on Neural Decision Trees, highly explainable architectures that provide predictive power equivalent to traditional deep networks.
NODE Decision Tree Structure
Neural Decision Trees
This article assumes you have some familiarity with Neural Decision Trees. If you don’t, I highly encourage reading my previous article on them for an in-depth explanation. However, in summary: Neural Decision Trees are decision trees that are both soft and oblique.
An oblique tree is one where multiple variables are used to make the decision at each node (usually arranged in a linear combination). For example, to predict a car accident, an oblique tree may produce a branching decision using the rule "car_speed - speed_limit < 10." This differs from "orthogonal" trees like CART (the basic decision tree), which use only one variable at any given node and therefore need more nodes to approximate the same decision boundary.
A soft tree is one where all branching decisions are probabilistic, and the calculations at each node define the probability of going into a particular branch. This is unlike regular, "hard" decision trees like CART, where each branching decision is deterministic.
Since the tree does not restrict the number of variables used in each node, and the branching decisions are continuous rather than hard, the entire tree is differentiable. Because the whole tree is differentiable, it can be integrated into any neural network framework such as PyTorch or TensorFlow and trained using traditional neural optimizers (e.g., stochastic gradient descent or Adam).
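To make this concrete, here is a minimal sketch of a single soft, oblique decision node in PyTorch. The class name `SoftObliqueNode` and the temperature parameter are my own illustrative choices, not part of the NODE paper.

```python
import torch
import torch.nn as nn

class SoftObliqueNode(nn.Module):
    """One branching decision over a linear combination of all input features."""
    def __init__(self, n_features: int, temperature: float = 1.0):
        super().__init__()
        self.weights = nn.Parameter(torch.randn(n_features))  # oblique: all features participate
        self.threshold = nn.Parameter(torch.zeros(1))
        self.temperature = temperature

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, n_features) -> probability of taking the "right" branch
        logit = (x @ self.weights - self.threshold) / self.temperature
        return torch.sigmoid(logit)  # soft: a probability in (0, 1), not a hard 0/1 split

# Probability of branching right for a batch of two samples
node = SoftObliqueNode(n_features=3)
p_right = node(torch.randn(2, 3))  # shape (2,), differentiable w.r.t. weights and threshold
```

Because every operation here is differentiable, gradients flow through the branching probability just as they would through any other network layer.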
NODE Trees
The decision trees NODE uses are slightly different from a traditional neural decision tree. Let’s break down the differences.
Oblivious Nature
The first significant change is that the trees are oblivious: the tree uses the same splitting weights and thresholds for all internal nodes at the same depth. As a result, an Oblivious Decision Tree (ODT) of depth d can be represented as a decision table with 2^d entries. One benefit is that ODTs are more interpretable than traditional decision trees: there are fewer decisions to parse, so the decision path is easier to visualize and understand. However, ODTs are significantly weaker learners than traditional decision trees, precisely because of the constrained nature of their splitting functions.
So if our goal is performance, why would we use ODTs? As the developers of CatBoost [2] showed, ODTs function incredibly well when ensembled together and are less prone to overfitting the data. Additionally, inference with ODTs is extremely efficient, as the splits at every depth can be computed in parallel to quickly find the appropriate entry in the table.
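Here is a hedged sketch of the table-lookup view of an oblivious tree, using hard splits for clarity (the actual NODE trees are soft and use entmax, as described next). The function name and arguments are illustrative.

```python
import torch

def oblivious_tree_lookup(x, split_features, thresholds, table):
    """x: (batch, n_features); split_features, thresholds: length-d lists;
    table: (2**d,) leaf values. Every node at depth i applies the SAME split."""
    index = torch.zeros(x.shape[0], dtype=torch.long)
    for f, t in zip(split_features, thresholds):
        bit = (x[:, f] > t).long()   # the shared comparison at this depth
        index = index * 2 + bit      # build the table index one bit per depth
    return table[index]              # a single lookup per sample

d = 3
out = oblivious_tree_lookup(torch.randn(5, 4),
                            split_features=[0, 2, 1],
                            thresholds=[0.0, 0.5, -0.3],
                            table=torch.randn(2 ** d))
```

Since the d comparisons are independent of one another, they can all be evaluated in parallel, which is what makes ODT inference so cheap.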
Entmax for Feature Selection and Branching
NODE’s second improvement over the traditional neural decision tree is the use of alpha-entmax [3] in place of the softmax and sigmoid activations in its architecture. Alpha-entmax is a generalized version of softmax capable of producing sparse distributions in which most entries are exactly zero. The sparsity is controlled by a parameter (alpha, hence the name): the higher the alpha, the sparser the distribution.
This transformation is used in two key places. The first is sparse feature selection: NODE includes a trainable feature-selection weight matrix F (of size d × n, where n is the number of features and d is the depth of the tree) that is passed through the entmax transformation. Since most entries of the entmax output are equal to zero, only a small number of features end up being used in each decision node.
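Below is a minimal sketch of this sparse feature selection. It uses the pip-installable `entmax` package as one implementation of [3] (NODE's reference code ships its own); the matrix shape follows the description above, but the variable names are my own.

```python
import torch
import torch.nn as nn
from entmax import entmax15  # the alpha = 1.5 case of alpha-entmax

n_features, depth = 10, 4
F = nn.Parameter(torch.randn(depth, n_features))  # trainable feature-selection logits

selection = entmax15(F, dim=-1)  # (depth, n_features): each row sums to 1, most entries are 0
x = torch.randn(32, n_features)  # a batch of tabular rows
selected = x @ selection.t()     # (batch, depth): one sparse feature combination per depth
```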
In addition to feature selection, entmax is also used for the branching probabilities. This is done by taking the branching function’s result, subtracting a learned threshold, and scaling it by a learned factor. This value is then concatenated with 0 and passed through entmax to create a two-class probability distribution, which is exactly what we need for branching.
Using these two-class distributions, we can define a "choice" tensor C by computing the outer product of the branching distributions c at every depth. Multiplying C by the leaf values then produces the tree's output.
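The sketch below puts the two pieces together: a two-class entmax per depth for branching, then an outer product across depths to form the choice tensor C over the 2^d leaves. The thresholds, scales, and leaf values are shown as plain tensors for brevity (they are learned parameters in the real model), and `entmax15` again stands in for the paper's alpha-entmax.

```python
import torch
from entmax import entmax15

batch, depth = 32, 3
selected = torch.randn(batch, depth)   # would come from the feature-selection step above
thresholds = torch.zeros(depth)        # learned thresholds in the real model
scales = torch.ones(depth)             # learned scales in the real model

logits = (selected - thresholds) * scales
two_class = torch.stack([logits, torch.zeros_like(logits)], dim=-1)
c = entmax15(two_class, dim=-1)        # (batch, depth, 2): branch probabilities per depth

# Outer product across depths: a distribution over all 2^depth leaves
C = c[:, 0, :]
for i in range(1, depth):
    C = torch.einsum('bi,bj->bij', C, c[:, i, :]).reshape(batch, -1)

leaf_values = torch.randn(2 ** depth)  # learned leaf responses in the real model
response = C @ leaf_values             # (batch,): the tree's output
```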
Ensembling
As the name suggests, these neural oblivious decision trees are ensembled together. A NODE layer is defined as the concatenation of m individual trees, each with its own branching decisions and leaf values. As mentioned before, this ensembling synergizes with the oblivious nature of the individual trees, increasing accuracy while reducing the chance of overfitting.
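A NODE layer can then be sketched as little more than a module that concatenates the outputs of its m trees. The `NODELayer` name is mine, and the `nn.Linear` stand-ins below take the place of the soft oblivious trees sketched earlier.

```python
import torch
import torch.nn as nn

class NODELayer(nn.Module):
    """Concatenate the outputs of m independently parameterized trees."""
    def __init__(self, trees):
        super().__init__()
        self.trees = nn.ModuleList(trees)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.cat([tree(x) for tree in self.trees], dim=-1)

# nn.Linear modules stand in for soft oblivious trees here
layer = NODELayer([nn.Linear(8, 1) for _ in range(4)])
out = layer(torch.randn(32, 8))  # (32, 4): one output per tree
```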
Multilayer NODE
NODE is a flexible architecture that can be trained alone (resulting in a single ensemble of decision trees) or in a multi-layer structure where each ensemble takes input from the previous layers.
The multi-layer architecture of NODE closely follows the popular DenseNet architecture. Each NODE layer contains several trees whose outputs are concatenated and serve as inputs to subsequent layers. The final output is then obtained by averaging the outputs of all trees from all layers. Since each layer receives the predictions of all previous layers, the network can capture complex dependencies.
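The following is a hedged sketch of that dense, multi-layer stacking: each layer sees the original features concatenated with every previous layer's output, and the final prediction averages every tree output. `DenseNODE` is an illustrative name, and single `nn.Linear` modules again stand in for whole tree ensembles.

```python
import torch
import torch.nn as nn

class DenseNODE(nn.Module):
    def __init__(self, n_features: int, n_layers: int, trees_per_layer: int):
        super().__init__()
        self.layers = nn.ModuleList()
        in_dim = n_features
        for _ in range(n_layers):
            self.layers.append(nn.Linear(in_dim, trees_per_layer))  # stand-in for a tree ensemble
            in_dim += trees_per_layer  # later layers also see this layer's outputs

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        tree_outputs = []
        for layer in self.layers:
            out = layer(x)
            tree_outputs.append(out)
            x = torch.cat([x, out], dim=-1)  # DenseNet-style connectivity
        # Final prediction: average over every tree in every layer
        return torch.cat(tree_outputs, dim=-1).mean(dim=-1)

model = DenseNODE(n_features=8, n_layers=3, trees_per_layer=4)
pred = model(torch.randn(32, 8))  # (32,)
```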
Experimental Performance
To test their architecture, Popov et al. (2019) compared NODE to CatBoost [2], XGBoost [4], a fully connected neural network, mGBDT [5], and DeepForest [6]. They compared the models on six different datasets, once using each model’s default parameters and once with tuned hyperparameters.
The experimental results for NODE are extremely encouraging. For one, the NODE architecture outperforms all the other models with the default parameters. Even with tuned hyperparameters, NODE outperforms most other models on 4 of the 6 chosen datasets.
Conclusion
By incorporating the advantages of decision trees into the neural network architecture, NODE opens up new possibilities for Deep Learning applications in domains where structured tabular data is prevalent, such as finance, healthcare, and customer analytics.
This isn’t to say that NODE is perfect, however. For one, the ensembled nature of the architecture means that many of the local interpretability gains from using neural decision trees are thrown away, and only global feature importance can be gleaned from the model. However, this architecture does provide the building blocks for improving neural interpretability, and a follow-up model (NODE-GAM [7]) has been proposed to bridge the interpretability gap.
Additionally, while NODE outperforms many shallow models, my experience using it has indicated that it takes longer to train, even when using GPUs (a conclusion supported by the experimental results provided by the paper authors).
Overall, this is an extremely promising approach and one I plan to use actively as a component of the deep learning models I develop in the future.
Resources and References
- NODE Paper: https://arxiv.org/abs/1909.06312
- NODE Code: https://github.com/Qwicen/node
- NODE can also be found in the Pytorch Tabular package: https://github.com/manujosephv/pytorch_tabular
- If you are interested in Interpretable Machine Learning or Time Series Forecasting, consider following me: https://medium.com/@upadhyan.
- See my other articles on neural decision trees: https://medium.com/@upadhyan/list/3b4a9cb97b84
References
[1] Popov, S., Morozov, S., & Babenko, A. (2019). Neural oblivious decision ensembles for deep learning on tabular data. Eighth International Conference on Learning Representations.
[2] Prokhorenkova, L., Gusev, G., Vorobev, A., Dorogush, A. V., & Gulin, A. (2018). CatBoost: Unbiased boosting with categorical features. Advances in Neural Information Processing Systems, 31.
[3] Peters, B., Niculae, V., & Martins, A. (2019). Sparse sequence-to-sequence models. In Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics (pp. 1504–1519). Association for Computational Linguistics.
[4] Chen, T., & Guestrin, C. (2016). XGBoost: A scalable tree boosting system. In Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining (pp. 785–794).
[5] Feng, J., Yu, Y., & Zhou, Z. H. (2018). Multi-layered gradient boosting decision trees. Advances in Neural Information Processing Systems, 31.
[6] Zhou, Z. H., & Feng, J. (2019). Deep forest. National Science Review, 6(1), 74–86.
[7] Chang, C. H., Caruana, R., & Goldenberg, A. (2022). NODE-GAM: Neural generalized additive model for interpretable deep learning. In International Conference on Learning Representations.