DRF: A Random Forest for (almost) everything

Jeffrey Näf
Towards Data Science
16 min read · Feb 1, 2022


Photo by Geran de Klerk on Unsplash

When Breiman introduced the Random Forest (RF) algorithm in 2001, did he know the tremendous effect it would have? Nowadays RF is a heavily used tool in many parts of data science. It’s easy to see why — RF is easy to use and known for its high performance in an extremely wide range of tasks. This alone is impressive, but what makes it even more interesting is that generally no tuning is required to obtain these results.

In this article we discuss a pretty massive extension of the original RF, the Distributional Random Forest (DRF), which we recently developed in this paper:

Distributional Random Forests: Heterogeneity Adjustment and Multivariate Distributional Regression (jmlr.org)

joint work with Loris Michel, who also contributed to this article. Additional code and a Python implementation can be found here: https://github.com/lorismichel/drf

We will start with a small introduction to Random Forests and then go through the steps needed to develop DRF. A small list of examples awaits at the end, which we will happily extend in the future.

Before we go any further, we first need to get some notation out of the way: We will in this article loosely refer to X, Y as random vectors from which we assume to observe an iid sample (yi, xi), i=1,…,n. This is quite common in research papers. If we deal with random variables instead (i.e. univariate), we just write Y or X. The main goal in many applications of data science is predicting Y from some features X, say we want to predict a label “dog” (Y) from a vector of pixel intensities (X). Less glamorously formulated, a prediction is almost always a function of the conditional distribution P(Y|X=x). So we ask: “if the vector X is fixed to a value x, what is a certain distributional aspect of Y?”. The most prominent example is the conditional mean or expectation, i.e. the expected value under the distribution P(Y|X=x). This is what people usually want when they ask for a prediction. For example, if we predict with linear regression, we plug a value x into the linear formula and get an expected value of Y as the prediction. The image classification example above is already different — here we might ask for the most likely value, i.e. for which label l is P(Y=l|X=x) maximized. Long story short: most predictions are aspects of a conditional distribution, most often the conditional expectation.
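To make this concrete, here is a tiny toy simulation (our own made-up model, not from the paper) in which the usual mean prediction, a prediction interval and a tail probability at a fixed x are all just different functionals of the same conditional distribution P(Y|X=x):

## Toy model (an assumption for illustration): Y = X + (1 + X^2)*eps with eps ~ N(0,1)
set.seed(1)
x <- 1
y_cond <- x + (1 + x^2)*rnorm(1e5)   # (approximate) draws from Y | X = x
mean(y_cond)                         # conditional mean: the usual "prediction"
quantile(y_cond, c(0.1, 0.9))        # conditional quantiles: a prediction interval
mean(y_cond > 3)                     # a conditional probability, P(Y > 3 | X = x)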

Intriguingly, people later figured out RF can be used not just for conditional expectation or mean predictions, but also to predict other aspects of a (conditional) distribution. For instance, one can predict quantiles and thus obtain prediction intervals measuring uncertainty. DRF takes this approach to a new level: You can feed it a multivariate response Y together with the typical vector of features X and it will then estimate the whole conditional distribution of the multivariate Y given any (realization) x. This estimate is given in the form of simple weights, making it easy to simulate from the distribution or calculate quantities of interest, i.e. “predictions”.

A thorough explanation of RF is given here, for instance, so we just touch on this briefly: In a RF, several (ideally) independent decision trees are fitted. Each tree only gets a random subset of the data and, at each split, a random subset of the candidate variables, and it recursively splits the sample according to the features in X.

A Random Forest in action. A new point x is dropped down each tree (in red) and lands in a leaf or terminal node. Source: Image generated with TikZ code from stackexchange.

So when we “drop down” a new point x, it will end up in a leaf for each tree. A leaf is a set of observations i, and taking the average over all yi in that leaf gives the prediction of one tree. These per-tree predictions are then averaged to give the final result. Thus, for a given x, if you want to predict the conditional mean of Y given X=x, you:

1. “Drop” x down each tree (this is indicated in red in the above figure). Since the splitting rules were made on X, your new point x is guaranteed to land in some leaf node.

2. For each tree, average the responses yi in that leaf to get the tree’s estimate of the conditional mean.

3. Average these per-tree estimates over all trees to get the final prediction.

Importantly, averaging the predictions of all trees leads to a marked reduction in variance; the errors get “averaged out”, so to speak. This is why it is important to make the trees as independent as possible.
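Before moving on, here is a small sanity check of steps 1 to 3 (a toy sketch with its own simulated data, separate from the example below): ranger’s predict.all = TRUE returns each tree’s leaf average, and simply averaging those values reproduces the standard RF prediction.

library(ranger)
set.seed(42)
d <- data.frame(y = rnorm(500), x1 = rnorm(500), x2 = rnorm(500))
rf <- ranger(y ~ ., data = d, num.trees = 50)
xnew <- data.frame(x1 = 0, x2 = 0)
## Steps 1 and 2: each tree's prediction for xnew (the average of the yi in its leaf)
per_tree <- predict(rf, data = xnew, predict.all = TRUE)$predictions
## Step 3: averaging over the trees matches the usual RF prediction
mean(per_tree)
predict(rf, data = xnew)$predictions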

We now work towards an extension of this idea to not just make a mean prediction but to instead predict the whole (conditional) distribution. That is, we want to predict P(Y|X=x) for a test point x. As mentioned already, many things one usually wants to predict are direct functions of this conditional distribution: The (conditional) mean, conditional quantiles, conditional variances, covariances and so on and so on. It turns out the form of DRF also makes it suitable for causal analysis, beyond what is possible even for state-of-the-art methods. We give a small example at the end of this document and potentially a lot more in later posts.

For now, we derive step by step how to get to the Distributional Random Forest. The first step is to get the weights.

The Theory

Step 1: Gaining some weight

The key to this is to get the weights that RF delivers. This is usually overlooked in a first introduction to RF, but it is vital here:

Instead of directly calculating the mean in a leaf node, one calculates the weights that are implicitly used when doing the mean calculation. The weight wi(x) is a function of (1) the test point x and (2) an observation i. That is, if we drop x down a tree, we observe in which leaf it ends up. All observations that are in that leaf get a 1, all others 0. So if we end up in a leaf with observations (1,3,10), then the weight of observations 1,3,10 for that tree is 1, while all other observations get 0. We then further divide that weight by the number of elements in the leaf node. In the example before, we had 3 observations in the leaf node, so the weights for observations 1,3, 10 are 1/3 each, while all other observations still get a weight of 0 in this tree. Averaging these weights over all trees gives the final weight wi(x).

Why are these weights useful? First, it can be shown that calculating the mean as we do in a traditional Random Forest is the same as summing up wi(x)*yi. This looks a bit ugly, but it is basically just about exchanging two sums. So if this mean estimate is sensible, it should also make sense that taking other functions of yi, say yi², gives an estimate of the mean of Y² given X=x. Or, for some number t, you could use I(yi<t), which is 1 if yi < t and 0 otherwise, and see the sum over wi(x)I(yi<t) as an estimate of the conditional probability that Y < t. Thus you can directly obtain the conditional cumulative distribution function (cdf), which in theory fully characterizes the conditional distribution. Second, these weights actually give you a nearest-neighbor estimate over observations that are similar, in a certain sense, to the query x. As we will see below, this similarity is measured in a way that ideally makes it possible to see the associated yi’s as an iid sample from the conditional distribution.

This can be taken even further. In particular, if you sample from Y with these weights, you will get an approximate sample of the conditional distribution. Before we go on, let’s make a quick example using the famous ranger package in R. The package does not directly provide these weights, so we need to implement them manually, which is a good exercise here:

library(ranger)
library(ggplot2)
n<-1000
## Simulate some data:
X<-matrix(rnorm(n*2),ncol=2)
Y=2*X[,1] + 0.5*X[,2]^2 + rnorm(n)
Y<-as.matrix(Y)
## Get a test point
x<-matrix(c(1,-0.5),ncol=2)
# Fit Random forest
RF<-ranger(Y~., data=data.frame(Y=Y, X=X),min.node.size = 10, num.trees=500)
## Get the leaf/terminal node indices of each observation in the training samples
terminalX <- predict(RF, data=data.frame(X=X), type="terminalNodes")$predictions
## Get the leaf/terminal node of the test point
terminalxnew <- predict(RF, data=data.frame(X=x), type="terminalNodes")$predictions
## For the leafs in which x ends up, what are the number of observations (we need to normalize by that)
divid<-sapply(1:ncol(terminalxnew[1,, drop=F]), function(j) {sum(terminalX[,j]==terminalxnew[1,j])} )

# Average the weights over the trees:
weights<-rowSums(t(sapply(1:nrow(terminalX), function(j) as.numeric(terminalX[j,]==terminalxnew[1,])/divid )), na.rm=T)/length(terminalxnew[1,])
## We can now sample according to those weights from Y
Ysample<-Y[sample(1:dim(Y)[1], prob=weights, replace = T),]
## True conditional distribution
Ytrue<-2*x[,1] + 0.5*x[,2]^2 + rnorm(n)
p1 <- hist(Ytrue, freq=F)
p2 <- hist(Ysample, freq=F)
plot( p2, col=rgb(0,0,1,1/4), freq=F)
plot( p1, col=rgb(1,0,0,1/4), add=T, freq=F)
## Plot the weights by looking at the point x and its nearest neighbors according to the weights
ggplot(data = data.frame(X=X), aes(x = X.1, y = X.2)) +
geom_point(aes(alpha = exp(weights ))) + geom_point(data=data.frame(X=x), color = "red")
## Get the estimated conditional mean:
mean(Ytrue)
mean(Ysample)
sum(weights*Y)
The results of the above code: On the left, we see a plot of the Xi (the gray points) and, in red, the test point x. The more weight the RF assigns to an observation Xi, the darker its point. On the right, a simulation from the true conditional distribution (red) is compared against the points drawn from Y according to the weights. Image by author.
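Given the weights we just computed, the conditional cdf and conditional quantiles described in Step 1 are only a few lines away (the weighted-quantile helper below is our own little inversion of the cdf, not a ranger function):

## Estimated conditional cdf at x: P(Y <= t | X = x) as a weighted sum
F_hat <- function(t) sum(weights * (Y <= t))
F_hat(0)
## A simple weighted quantile, inverting the estimated cdf
q_hat <- function(alpha) {
  ord <- order(Y)
  Y[ord][which(cumsum(weights[ord]) >= alpha)[1]]
}
c(q_hat(0.025), q_hat(0.975))   # a rough 95% prediction interval for Y | X = x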

Step 2: Use a different splitting criterion

There is another way to look at the Random Forest algorithm: It’s a homogeneity machine. In each split in a tree, the split in X is chosen such that the two samples of Y in the resulting nodes are as “different” as possible. The picture below shows a small example for univariate X and Y.

Illustration of the splitting done in a RF. Image by author.

In such an example, a given tree will likely split X at S as in the picture. Then all Yi with Xi < S will be thrown into node 1 and all Yi with Xi >= S into node 2, thus identifying clusters of data. If we do this long enough, each leaf node will have a very homogeneous sample of Y; in other words, all yi in a leaf node will be similar in some way, because all the “dissimilar” observations have been cut away. Thus for each tree, you group your data into buckets of similar things.

Why does this make sense? Because in essence, RF is a nearest-neighbor algorithm. If you give it an x, it drops it down the trees until it lands in a bucket or leaf with “similar” observations. From these observations, the conditional mean is then calculated. That is, in each tree, only the observations in the leaf are considered, nothing else. So it is like a k-NN where the distance is not measured by the Euclidean distance but decided by the forest. The forest, in turn, labels those xi’s as “similar” that have “similar” yi’s. So the similarity of the xi’s gets decided based on their associated yi’s, which makes a lot of sense if your goal is to infer things about Y. In fact, even k-NN methods assume something along the lines of “the conditional distribution P(Y|X=xi) for xi close (in Euclidean distance) to x is about the same”. The figure below shows an illustration of this: You can see for each value xi in the sample the associated true conditional distribution, from which yi is sampled. A perfect version of DRF would recognize that the conditional distributions for (x1, x4, x6) and for (x3, x5, x7) are similar (no matter what their Euclidean distances actually are) and treat the corresponding groups of yi, (y1, y4, y6) and (y3, y5, y7), each as an iid sample from the conditional distribution.

For each value of xi in the sample, the conditional distribution of Y|X=xi is shown, with Yi being a draw from this distribution. As can be seen, some distributions are more similar to others. Image by author.

Ideally, this would mean that in practice the homogeneous yi samples we end up with in a leaf are actually approximately an iid sample from the conditional distribution of Y|X=x. This is what justifies taking the (weighted) mean.

Unfortunately, in the original RF this approach doesn’t work as intended beyond conditional mean prediction. Again, what we want is a splitting criterion that renders the distribution of Y in the two resulting nodes as different as possible. What we get in the original RF instead is simply a split that makes the difference in means between the two samples as large as possible. In the above picture, such a method might group all but x2 into one group, because x1, x3, x4, x6, x7 all have very similar means. But of course, as is already visible in the picture above, a distribution is not defined by its mean. Two normal distributions can have the same mean but very different variances or higher moments. Generally speaking, lots of distributions can share the same mean while otherwise being very different.

Two distributions with the same mean. If these were the income distributions of two different countries, would they be seen as the same? Image by author.

The key is that each split should depend on a measure of the difference in distribution between the two resulting nodes. That way, not just differences in mean or variance are detected, but any difference in distribution. DRF solves this problem by adjusting the splitting criterion usually employed in a RF to harness the theoretical and practical power of kernels and the so-called MMD (Maximum Mean Discrepancy) criterion. The MMD can be approximated very efficiently and is, in principle, able to detect any difference in distribution. Theoretically speaking, we thereby send each point yi into an infinite-dimensional space, the Reproducing Kernel Hilbert Space, and actually compare the means in that space. Through the magic of kernel methods, this comparison of means is actually a comparison of distributions: it turns out that in this special space, means correspond to distributions. What this means in practice is the following: A leaf node will contain xi’s that are similar, in the sense that the distributions of the corresponding yi in that bucket are similar. Thus if the conditional distributions of Y given xi and given xj are similar, they will be grouped into the same bucket. This can in principle be true even if xi and xj are far apart in Euclidean space (i.e. if they are not nearest neighbors in a k-NN sense). So if we use these weights to calculate conditional quantities we are interested in, we use a nearest-neighbor method that deems xi and xj to be similar when the distributions of their associated yi, yj are similar. In particular, under some smoothness assumptions, the sample in the leaf node that x ends up in is approximately an iid sample from P(Y|X=x).
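To get a feeling for what the MMD measures, here is a minimal sketch (our own toy code, not the drf internals) of the biased squared-MMD estimate with a Gaussian kernel; the bandwidth sigma is simply picked by hand:

## Squared MMD between two samples Y1, Y2 (rows are observations), Gaussian kernel
mmd2 <- function(Y1, Y2, sigma = 1) {
  k <- function(A, B) {
    # Gaussian kernel matrix between the rows of A and B
    D2 <- outer(rowSums(A^2), rowSums(B^2), "+") - 2 * A %*% t(B)
    exp(-D2 / (2 * sigma^2))
  }
  mean(k(Y1, Y1)) + mean(k(Y2, Y2)) - 2 * mean(k(Y1, Y2))
}
## Same mean, different spread: the MMD still tells the samples apart
Y1 <- matrix(rnorm(200, mean = 0, sd = 1), ncol = 1)
Y2 <- matrix(rnorm(200, mean = 0, sd = 3), ncol = 1)
Y3 <- matrix(rnorm(200, mean = 0, sd = 1), ncol = 1)
mmd2(Y1, Y2)   # clearly larger than...
mmd2(Y1, Y3)   # ...the value for two samples from the same distribution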

Step 3: Use a multivariate response

This step is actually easy, since the MMD also allows comparing multivariate distributions. Distinguishing more than just the mean becomes even more relevant for multivariate responses, since differences in distribution can be even more intricate. For instance, two multivariate distributions can have the same mean and variance for all marginals, but different covariances between the elements.
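Reusing the mmd2 sketch from above, here is a toy illustration of exactly this point: two bivariate samples with identical marginal means and variances but opposite correlation are still told apart (MASS::mvrnorm is only used to simulate the data):

library(MASS)
S_pos <- matrix(c(1, 0.8, 0.8, 1), 2, 2)
S_neg <- matrix(c(1, -0.8, -0.8, 1), 2, 2)
Ya <- mvrnorm(300, mu = c(0, 0), Sigma = S_pos)
Yb <- mvrnorm(300, mu = c(0, 0), Sigma = S_neg)
Yc <- mvrnorm(300, mu = c(0, 0), Sigma = S_pos)
mmd2(Ya, Yb)   # same marginals, different dependence: clearly positive
mmd2(Ya, Yc)   # same distribution: close to zero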

Examples

Let’s look at some small examples. The goal here is just to provide very simple simulated examples to get a feeling for the method. First, we redo what we did manually above:

library(drf)
# Fit DRF
DRF<-drf(X,Y)
weights<-predict(DRF,x)$weights[1,]
### We can now sample according to those weights from Y
Ysample<-Y[sample(1:dim(Y)[1], prob=weights, replace = T),]
## True conditional distribution
Ytrue<-2*x[,1] + 0.5*x[,2]^2 + rnorm(n)
p1 <- hist(Ytrue, freq=F)
p2 <- hist(Ysample, freq=F)
plot( p2, col=rgb(0,0,1,1/4), freq=F)
plot( p1, col=rgb(1,0,0,1/4), add=T, freq=F)
ggplot(data = data.frame(X=X), aes(x = X.1, y = X.2)) +
geom_point(aes(alpha = exp(weights ))) + geom_point(data=data.frame(X=x), color = "red")
## Get the estimated conditional mean:
mean(Ytrue)
mean(Ysample)
sum(weights*Y)

Now yielding these better-looking results:

The results of the above code: On the left, we see a plot of the Xi (the gray points) and, in red, the test point x. The more weight DRF assigns to an observation Xi, the darker its point. On the right, a simulation from the true conditional distribution (red) is compared against the points drawn from Y according to the weights. Image by author.

We can also predict conditional quantiles. Doing this, for instance, gives a prediction interval for the values of Y|x, so that values drawn from this distribution should lie in the interval approximately 95% of the time:

# Predict quantiles for a test point.
quants<-predict(DRF,x, functional = "quantile", quantiles=c(0.025, 0.975))$quantile
q1<-quants[1]
q2<-quants[2]
mean(Ytrue >=q1 & Ytrue <=q2)

The last line checks what fraction of the new samples, simulated from the conditional distribution, lie within the interval [q1, q2]. The result, in this case, is about 94%, which is close to the 95% we would hope for.
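The same weights deliver any other functional we might care about. As a small sketch (reusing weights, Y and Ytrue from above), here is the conditional variance:

## Conditional mean and variance computed directly from the DRF weights
cmean <- sum(weights*Y)
cvar <- sum(weights*(Y - cmean)^2)
cvar
## Compare with the empirical variance of the sample from the true conditional distribution
var(Ytrue)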

Two-Dimensional Response

Here we construct a crazy difficult example with a two-dimensional response and just calculate a variety of predictions. We first simulate the data:

n<-5000
d<-2
p<-3
## Simulate X and Y
X<-matrix( cbind(runif(n)*2, rnorm(n, mean = 1), rt(n, df=8)),nrow=n)
Y<-matrix(NA, ncol=2, nrow=n)
Y[,1]<-X[,1]^2 + X[,2]*X[,1] + X[,3] + rnorm(n)
Y[,2] <- 2*Y[,1]*X[,2] + 0.1*X[,3] + rnorm(n)

These are pretty crazy relationships for a nonparametric method to estimate, with an X whose elements have widely different distributions. One thing to note is the term 2*Y[,1]*X[,2]: it means that the correlation between the first and second element of Y is positive when the second element of X is positive, but negative when the second element of X is negative. Overall, Y looks like this:

Y as simulated in the above code. Image by author.
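If you want to see the correlation flip directly in the simulated data, one possibility (our own plotting sketch, not the code behind the figure above) is to split the scatterplot of Y by the sign of the second element of X:

library(ggplot2)
## Scatterplot of Y, split by the sign of X[,2]: the orientation of the point cloud flips
df <- data.frame(Y1 = Y[,1], Y2 = Y[,2],
                 X2sign = ifelse(X[,2] > 0, "X[,2] > 0", "X[,2] < 0"))
ggplot(df, aes(x = Y1, y = Y2)) +
  geom_point(alpha = 0.3) +
  facet_wrap(~X2sign)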

We now take two test points and do some wild predictions, just because we can:

library(drf)
# Fit DRF
DRF<- drf(X=X, Y=Y, num.features=200)
# Choose a few test points:
x= matrix(c(0.093, -0.5, 1.37, 0.093, 0.5, 1.37) , ncol=p, nrow=2, byrow=T)
# mean prediction
(predict(DRF, newdata=x, functional="mean")$mean)
# correlation prediction
matrix(predict(DRF, newdata=x[1,,drop=F], functional="cor")$cor, nrow=d,ncol=d)
matrix(predict(DRF, newdata=x[2,,drop=F], functional="cor")$cor, nrow=d,ncol=d)
# Estimated probability that Y is smaller than 0:
weightstotal<-predict(DRF, newdata=x)$weights
p1<-sum(weightstotal[1,]* (Y[,1] + Y[,2]<= 0) )
p2<-sum(weightstotal[2,]* (Y[,1] + Y[,2] <= 0) )
# Bootstrapping the estimated probability of the sum being <=0 for both points:
B<-100
pb1<-matrix(NA, ncol=1,nrow=B)
pb2<-matrix(NA, ncol=1,nrow=B)
for (b in 1:B){
  Ybx1<-Y[sample(1:n, size=n, replace=T, prob=weightstotal[1,]), ]
  Ybx2<-Y[sample(1:n, size=n, replace=T, prob=weightstotal[2,]), ]
  pb1[b] <- mean(Ybx1[,1] + Ybx1[,2] <= 0)
  pb2[b] <- mean(Ybx2[,1] + Ybx2[,2] <= 0)
}
ggplot(data.frame(x=1:2, y=c(p1, p2)), aes(x = x, y = y)) +
geom_point(size = 4) +
geom_errorbar(aes(ymax = c(quantile(pb1,1- 0.025), quantile(pb2, 1-0.025)), ymin = c(quantile(pb1, 0.025),quantile(pb2, 0.025) )))

We don’t describe the results in detail, but it is interesting that DRF manages to correctly detect a negative correlation when the second element of x is negative and a positive correlation when it is positive. Furthermore, we snuck in a new aspect: We can even do a (conditional) bootstrap of our estimates. In this case, we applied it to the estimated probability that the sum of the elements of Y is smaller than or equal to 0. This results in the following confidence intervals for this quantity:

Bootstrapped CIs for P(Y1 + Y2 ≤ 0| X=x) for the two test points.

From Prediction to Causal Effect

To make things more interesting, we study a medical example where we would like to get a causal effect (it's entirely made up, with completely unrealistic numbers, though it is inspired by real problems — the different reactions men and women can have to medications).

In this example we simulate an outcome, say a blood-thinning effect (B), that should be regulated by some medication. We also know the age and sex of the patients and we simulate the following relationship: For a male patient, independent of age, the medication linearly increases the blood-thinning effect. For female patients aged 50 or older, the medication also increases B, and to a higher extent than for males. For female patients younger than 50, however, the effect completely reverses and the medication leads to a lower blood-thinning effect. The exact data-generating process is here:

n<-1000
# We randomly sample from different ages…
Age<-sample(20:70, n, replace=T)
# and sexes
Sex <- rbinom(n, size=1,prob=0.6) # 1=woman, 0=man
# W denotes the dose of the medication, whose causal effect we are directly interested in
W<-sample( c(5,10,50,100),n, replace=T)
# B is the thing we want to understand
B<- 60+ (0.5*W)*(Sex==0) + (-0.5*W)*(Sex==1)*(Age<50) + (0.8*W)*(Sex==1)*(Age>=50) + rnorm(n)

First an illustration of this relationship (which we can plot like this since we know the truth):

Illustration of the simulated effect of the drug by age group and gender. Image by author.

A way to approach this with DRF is to (1) take Y=(B,W) (so again our Y is two-dimensional) and X=(Age,Sex), (2) get the weights for a given x and then (3) run a linear regression of B on W, weighted with those weights. This gives an estimator of the effect of W on B given that X is fixed to x:

get_CATE_DRF = function(fit, newdata){
  out = predict(fit, newdata)
  ret = c()
  for(i in 1:nrow(newdata)){
    # Weighted regression of B (out$y[,1]) on W (out$y[,2]) with the DRF weights of test point i
    ret = c(ret, lm(out$y[,1]~out$y[,2], weights=out$weights[i,])$coefficients[2])
  }
  return(ret)
}

The result can then be seen here for different test points:

library(drf)
# Construct the data matrices
X<-matrix(cbind(Age, Sex), ncol=2)
Y<-matrix(cbind(B,W), ncol=2)
# Fit
DRF<-drf(Y=Y, X=X)
# Get a test point (changing this test point gives different queries)
x<-matrix(c(40, 1), nrow=1)
# Get the weights
weights <- predict(DRF, newdata=x)$weights[1,]
# Study the weights
head(X[order(weights, decreasing=T),])
# Result for given test point
(CATE_DRF = get_CATE_DRF(DRF, newdata=x))

As can be seen, when the code is run for 1000 data points, these results recover the direction and ordering of the true effects: In this example, we get an effect of -0.26 for a woman in her 40s (x=(40,1)), an effect of 0.48 for a woman in her 60s (x=(60,1)) and an effect of 0.28 for a 30-year-old man (x=(30,0)).
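For completeness, here is a small convenience sketch evaluating all three queries from above in one call, next to the true slopes implied by the data-generating process:

## The three test points from the text: woman in her 40s, woman in her 60s, man in his 30s
xall <- matrix(c(40, 1,
                 60, 1,
                 30, 0), ncol=2, byrow=TRUE)
get_CATE_DRF(DRF, newdata=xall)
## True slopes from the simulation: -0.5 (woman under 50), 0.8 (woman 50 or older), 0.5 (man)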

In some ways, this is still a very simple example and even a linear regression with age and sex as predictors might work well. The important thing is that DRF has no prior assumptions here (such as linearity) and learns the relationships all by itself, even when the effect of X is nonlinear. It gets a lot harder to estimate this effect for smaller sample sizes, but the general direction is often not too bad.

Conclusion

This article explained the Distributional Random Forest method (hopefully in an understandable way). The method is a Random Forest in which each tree splits the sample according to X in such a way that observations with similar conditional distributions of Y end up in the same leaf node. If a new point x is then dropped down a tree, it ends up in a leaf node with other xi’s whose conditional distribution of Y is similar. This results in weights that, averaged over all trees, give an estimate of the conditional distribution in a simple form. This is a purely nonparametric estimate of P(Y|X=x), from which a lot of interesting quantities can be estimated.

At the end of this article, we just want to caution that estimating a multivariate conditional distribution nonparametrically is a daunting task. It makes sense especially when there are lots of observations and when a complicated relationship is suspected. However, sometimes just assuming a linear model with a Gaussian distribution does just as well. What makes DRF so versatile is that, even in cases where a parametric model is more appropriate, the weights may still be useful for a semi-parametric approach.

A large number of additional examples can be found in the original paper, or potentially in future Medium articles. We hope that DRF will help in a lot of data-dependent tasks!


I am a researcher with a PhD in statistics and always happy to study and share research, data-science skills, deep math and life-changing views.