Explaining the decisions of XGBoost models using counterfactual examples

Model interpretability — fault detection, identification and diagnosis

Pierre Blanchart
Towards Data Science


Counterfactual reasoning is a general paradigm for interpretability. It is about determining the minimal changes we would need to apply to an input data point so that a classification model assigns it to another class.

A typical application scenario is fault detection and diagnosis. Let’s imagine we can precisely monitor a manufacturing process using sensors placed all along the production chain (typically in each workstation). Using this data, we can track products through each stage of the fabrication process. The data associated with a product that ends up being faulty can thus be recorded for further analysis, in an attempt to trace back where the fault was introduced. In particular, we would like to find the workstation(s) where the fault was introduced and, if the data allows it, diagnose what may have gone wrong there. You may have guessed by now that the main goal is to avoid stopping the production chain for long periods while trying to discover what went wrong. An artificial intelligence able to provide such answers in the blink of an eye may prove very useful and, above all, good for business, which is a good enough excuse for all the theoretical hardships you’ll go through reading the rest of this article.

If you don’t want this to happen to you, read on :). You’ll no longer have to taste every biscuit that comes out of your production line to ensure it meets grandma’s standards. You’ll have a model that predicts it, and each time a biscuit is not up to standard, you’ll be able to say exactly where on the production line something went wrong, why, and even what minimal actions would bring things back to normal. All with a 100% AI-powered solution. [Image by Author]

The good news is that we are able to design very efficient fault detection models (FDM). These can analyse, in real time, huge amounts of heterogeneous data (numerical, categorical …) measured at different steps of production, in order to predict whether a given manufactured item is faulty. The less good news is that it is much more difficult to explain the decisions of such models in order to provide a fault diagnosis. Commonly used FDMs are generally very complex black-box models, and anything but transparent.

In this post, we show that for a category of models known as tree ensemble models, which includes popular high-performance models such as XGBoost, LightGBM and random forests, we can use an approach called “counterfactual explanations” to explain their decisions. For the sake of simplicity, we only consider here binary classification models that classify data into two classes: normal / faulty.

For a given query point classified as faulty by the model, we compute a virtual point called a counterfactual example (referred to as a CF example in the following). It is the point of the input space closest to the query point, in terms of Euclidean distance, among all points classified as normal by the model. This point is virtual since it does not necessarily exist in the training set. Most of the time it does not, and we build it from the parameters of the FDM. The geometric intuition behind CF examples is illustrated in the diagram below.

Decision regions of a two-class classifier. The regions where the model predicts class 1 are labeled “C1”, and the regions where the model predicts class 2 are labeled “C2”. The closest counterfactual example of a point P#i is denoted by CF#i. It is the closest point to P#i which is classified in class 2 by the model. In the figure above, we placed all the points P#i in class 1, so we look for their respective counterfactual examples in class 2. [Image by Author]

For a faulty data point, its associated CF example tells us the minimal changes needed to bring it back to the normal class. You saw it coming miles away: this is what we use to perform fault diagnosis. The approach is very powerful, since it can spot imperceptible changes that differentiate faulty data from normal data. These changes can be multivariate, in the sense that several input characteristics may have changed compared to the normal state. Most importantly, our input characteristics are associated with physical equipment (workstations) and with a manufacturing process. Thus, if comparing the faulty input data with the CF example reveals that, to get back to normal, I need to lower the temperature in station 25 by 0.1 °C and increase the pressure in station 31 by 0.05 bar, I can quickly schedule this intervention on my production chain to prevent new faults and avoid wasting more time and materials. That is what is particularly interesting with counterfactual explanations: they give you a precise idea of the minimal action to take to correct the problem. Well, I hope I have convinced you that having a CF example associated with your faulty data is the key to quickly solving your problems and, potentially, to great savings (so cliché to think that everything is about saving money and time …).
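Turning a CF example into a diagnosis boils down to listing the features whose values differ between the faulty point and its CF example. Here is a minimal Python sketch, assuming each feature is named after its workstation and physical quantity; the feature names, the tolerance and the numbers are all hypothetical.

```python
# Hypothetical feature names: each one is tied to a workstation and a physical quantity.
FEATURES = ["station25_temperature", "station25_pressure",
            "station31_temperature", "station31_pressure"]


def diagnose(query, cf, features, tol=1e-9):
    """List (feature, query value, CF value, change) for every feature the CF example moves."""
    changes = []
    for name, q, c in zip(features, query, cf):
        delta = c - q
        if abs(delta) > tol:  # unchanged features are not part of the diagnosis
            changes.append((name, q, c, delta))
    return changes


# Made-up faulty measurement and CF example, for illustration only
faulty = [180.3, 2.10, 95.0, 1.45]
cf = [180.2, 2.10, 95.0, 1.50]
for name, q, c, d in diagnose(faulty, cf, FEATURES):
    print(f"{name}: {q} -> {c} ({d:+.2f})")
```

Only the two moved features are reported, which is what makes the explanation sparse and directly actionable on the production line.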

In this part, we unroll an efficient algorithmic approach to compute the closest CF example associated with faulty data in the case of tree ensemble models. I’ll try to keep things intuitive, but sometimes I’ll have to dig into the maths. Sensitive readers can skip these parts: doing so does not hinder the understanding of the whole, and you can always come back to them later, once you’re done with the problem solving on your production chain and have plenty of free time at your disposal.

First of all, we need to look into the specifics of tree ensemble models. CF examples are computed from the model itself (meaning the model parameters), so understanding how these models work is an essential step.

Let’s start with the decision tree, which is the base component of such models. More precisely, the binary decision tree: in each node of such a tree, we analyse the value of one input characteristic by comparing it to a single threshold value. If the value of the characteristic is higher than the threshold, we go to the right branch; if not, we go to the left one. We repeat this logic until we end up in a leaf of the tree, which is associated with a score (a vote for one class). Let’s now analyse what this means geometrically. I tried to represent it in the diagram below. If you think the figure is useless, I’ll recycle it as digital art in another post, but rest assured I put all my talent into it.

Decision regions in the input feature space corresponding to the leaves of a decision tree. These regions are boxes / multidimensional intervals, potentially open on some sides. The input space is two-dimensional. In a node Ni, we indicate which dimension dj of the input data is analysed, and the threshold value ci to which it is compared. A node Ni is thus associated with a pair (dj, ci). It can be checked that even if a feature is tested several times on the path between the root and a leaf of the tree, at most two tests (one giving a lower bound, one giving an upper bound) are effective in characterizing the decision region associated with the leaf, the other tests being redundant. This property holds whatever the number of input dimensions. The path between N1 and F1 contains an example of redundant testing on feature d2: in node N2, we test “d2 < 2.5”, and in node N4, we test “d2 < 1.4”, which is equivalent to testing only “d2 < 1.4”. [Image by Author]
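The redundancy argument translates directly into code: walking a root-to-leaf path, each test can only tighten a lower or an upper bound, so the leaf box keeps the tightest of each. A minimal Python sketch, assuming a hypothetical convention where going left means x[dim] < threshold and going right means x[dim] >= threshold, so boxes are products of half-open intervals [lo, hi):

```python
import math


def path_to_box(tests, n_dims):
    """Turn the tests on a root-to-leaf path into the leaf's box.

    tests: list of (dim, threshold, went_right). Assumed convention: left means
    x[dim] < threshold, right means x[dim] >= threshold, hence half-open [lo, hi).
    """
    lo = [-math.inf] * n_dims
    hi = [math.inf] * n_dims
    for dim, threshold, went_right in tests:
        if went_right:
            lo[dim] = max(lo[dim], threshold)  # only the tightest lower bound survives
        else:
            hi[dim] = min(hi[dim], threshold)  # only the tightest upper bound survives
    return list(zip(lo, hi))


# The two tests on d2 (index 1) from the caption: "d2 < 2.5", then "d2 < 1.4"
print(path_to_box([(1, 2.5, False), (1, 1.4, False)], n_dims=2))
# -> [(-inf, inf), (-inf, 1.4)]
```

Dimensions never tested stay unbounded, which is exactly the "open on some sides" case.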

It appears that the leaves of a decision tree are, geometrically speaking, multidimensional boxes (which can be open on some sides). To be mathematically more correct, we would say that the leaves are multidimensional intervals. So, what about a tree ensemble? You may easily have drawn the conclusion yourself: it is no more than a collection of multidimensional intervals / boxes, with their associated class votes / scores. To predict an input element, we just need to figure out which boxes it belongs to (boxes may intersect each other) and sum up the associated scores. Mathematically, if we represent a tree ensemble model F as a pair (B, S), where B is a set of boxes/leaves and S the set of associated scores, the prediction function associated with F is simply:

$$F(X) = \sum_{n=1}^{N} \delta_{B_n}(X)\, S_n$$

Tree ensemble model prediction function. N is the number of leaves of the model. Bn refers to the n-th leaf and Sn to the score associated with this leaf. Sn is a K-dimensional vector, where K is the number of classes of the classification problem. It is often a sparse vector that votes for a single class only (i.e. with only one non-null coefficient). To get the decision of the model for an input point X, we compute “arg max(F(X))”. We have δBn(X)=1 if X belongs to the leaf Bn, and δBn(X)=0 otherwise. [Image by Author]
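Seen this way, a tree ensemble is just an array of boxes plus an array of score vectors. Here is a minimal NumPy sketch of the prediction function above, assuming half-open bounds [lo, hi) per dimension; the toy model (two trees, two leaves each) and its scores are made up.

```python
import numpy as np


def predict_scores(x, boxes, scores):
    """F(X) = sum_n delta_{B_n}(X) * S_n for a tree ensemble stored as boxes.

    boxes:  array (N, D, 2) holding [lo, hi) bounds per leaf and dimension
    scores: array (N, K) holding the score vector S_n of each leaf
    """
    x = np.asarray(x, dtype=float)
    # delta_{B_n}(x): 1 if x lies in leaf n along every dimension
    inside = np.all((boxes[:, :, 0] <= x) & (x < boxes[:, :, 1]), axis=1)
    return inside.astype(float) @ scores


inf = np.inf
# Toy model with made-up numbers: two trees, two leaves each, 2 features, 2 classes
boxes = np.array([
    [[-inf, 1.0], [-inf, inf]],  # tree 1, left leaf:  x0 < 1
    [[1.0, inf], [-inf, inf]],   # tree 1, right leaf: x0 >= 1
    [[-inf, inf], [-inf, 0.5]],  # tree 2, left leaf:  x1 < 0.5
    [[-inf, inf], [0.5, inf]],   # tree 2, right leaf: x1 >= 0.5
])
scores = np.array([[1.0, 0.0], [0.0, 1.0], [0.0, 0.7], [0.6, 0.0]])

F = predict_scores([1.5, 0.2], boxes, scores)
print(F, "-> class", int(np.argmax(F)))
```

Because each tree partitions the input space, every point falls in exactly one leaf per tree, so the sum always has one term per tree.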

You may think that a model such as XGBoost yields more complex decision functions but, in fact, no… For binary classification, it produces one score between 0 and 1 (with the usual logistic objective, the sum of the leaf scores plus a constant bias, passed through a sigmoid function), which, once compared with a threshold value (0.5 by default), tells you which class your data falls into. Since the sigmoid is increasing, this comparison is equivalent to comparing the sum of the leaf scores to a threshold, so the box view above still applies. I’ll write another post on how to adapt the interpretability approach described here to tree ensemble models for multi-class classification and regression, so follow my channel (for your own good, nothing commercial about this, science is my only fuel).
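The equivalence between thresholding the probability and thresholding the raw sum of leaf scores can be checked in a few lines. A minimal Python sketch, assuming a logistic link and a hypothetical constant `bias` added to the sum of the leaf values:

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def logit(p):
    return math.log(p / (1.0 - p))


def predict_label(leaf_values, bias=0.0, threshold=0.5):
    """Class 1 iff sigmoid(sum of leaf values + bias) > threshold."""
    margin = bias + sum(leaf_values)
    return int(sigmoid(margin) > threshold)


def predict_label_margin(leaf_values, bias=0.0, threshold=0.5):
    """Same decision without the sigmoid: compare the raw margin to logit(threshold)."""
    margin = bias + sum(leaf_values)
    return int(margin > logit(threshold))


# Made-up leaf values reached by one input in each tree
print(predict_label([0.3, -0.1]), predict_label_margin([0.3, -0.1]))
```

Since the sigmoid is strictly increasing, the decision only depends on the sign of the margin minus logit(threshold), and that margin is constant on every region formed by intersecting leaves. Changing the threshold (as done later for plausibility) simply moves logit(threshold).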

At this point, we’ll try to determine our CF example based on a geometric decomposition of the decision regions of the model as a collection of boxes. That’s where things get tricky, since we cannot directly use the decomposition mentioned above. The following example should be enough to convince you that we can’t, and that the world is nastier than it looks.

The intersection of several leaves / boxes creates new regions whose score is determined by summing the scores of the individual boxes that intersect to form them. For instance, the score of the green region is S1 + S2 + S3, and the score of the red region is S1 + S3. [Image by Author]

In the figure above, the intersection of the three boxes B1, B2 and B3 creates two new decision regions (highlighted in red and green). We cannot tell for sure which class these regions belong to unless we sum the scores of the boxes that intersect to form them. Well, guess what? That’s what we do. We perform a “super decomposition” of the original decomposition that determines all the intersection regions, as well as all the other ones. For algorithmic reasons, we compute a “super decomposition” that is also a collection of boxes. So, it will rather look like this:

Super decomposition of an initial collection of three boxes that intersect each other. The super decomposition also takes the form of a collection of boxes. The latter is not necessarily the simplest box-like decomposition that can be found (in terms of the number of boxes). [Image by Author]

All the difficulty of the problem lies in designing an algorithm that computes a box-like super decomposition while escaping the underlying combinatorics of the problem. Indeed, formulated in a brute-force way, the problem amounts to determining, for every group of “k” boxes, whether these “k” boxes form a maximal intersection zone, i.e. whether their intersection is non-empty and no box outside the “k” considered ones intersects this zone. This problem is combinatorial and intractable even for a moderately large number of boxes. We propose an algorithm that builds the intersection boxes dimension by dimension. We use a handy property of boxes with sides perpendicular to the coordinate axes: if two boxes do not intersect along one particular dimension, they do not intersect at all. The figure below illustrates this idea.

If two boxes do not intersect along one particular dimension, they do not intersect at all. In the left diagram, the two boxes intersect along dimension d1, but not along dimension d2. In the right diagram, the two boxes intersect along dimension d2, but not along dimension d1. [Image by Author]
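The property is easy to state in code: two axis-aligned boxes intersect if and only if their projections overlap in every dimension, so a single non-overlapping dimension is enough to rule them out. A minimal Python sketch, assuming half-open intervals [lo, hi); the boxes are made up.

```python
def intervals_intersect(a, b):
    """Half-open intervals [lo, hi) overlap iff each one starts before the other ends."""
    return a[0] < b[1] and b[0] < a[1]


def boxes_intersect(box_a, box_b):
    """Axis-aligned boxes intersect iff their projections intersect in every dimension.

    all() stops at the first dimension without overlap: one failure is enough.
    """
    return all(intervals_intersect(a, b) for a, b in zip(box_a, box_b))


# Toy boxes given as [(lo, hi) along d1, (lo, hi) along d2]
A = [(0, 2), (0, 1)]
B = [(1, 3), (2, 3)]  # overlaps A along d1 but not along d2
C = [(1, 3), (0.5, 2)]
print(boxes_intersect(A, B), boxes_intersect(A, C))  # False True
```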

This property gives rise to a tree-like hierarchical exploration structure. Each level corresponds to one dimension of the input space. In each node, we compute the maximal intersection of boxes that are stored in this node according to the dimension corresponding to the current level of the tree-like structure. If two boxes do not intersect according to this dimension, they will no longer appear together in the same node in a further level of the tree, since, according to the property above, they have no chance to form an intersection region together.

Putting my amazing talent as a digital art creator back into action, I tried below to “run” the algorithm graphically, having reached the limits of what words can provide, and not wanting to choke you with a hardcore algorithmic proof. The algorithm is mostly about building the tree-like structure mentioned previously. The leaves of this structure contain a box-like super decomposition of the intersection regions formed by the leaves of the tree ensemble model.

Tree-like exploration structure built by the box-decomposition algorithm. We consider here a model with two leaves (F1, F2) in a two-dimensional feature space. The algorithm proceeds dimension by dimension. The first level corresponds to dimension d1 and to the decomposition of the model leaves into maximal intersection intervals along d1. After this level, we end up with three independent subsets of leaves, to which we apply the same 1D decomposition procedure, but this time along dimension d2. The hatched regions represent the resulting maximal intersection boxes produced by the algorithm. [Image by Author]
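The tree-like exploration can be sketched as a recursion over dimensions: at each level, the boxes still together in a node are split into 1D pieces along the current dimension, and each piece becomes a child node holding only the boxes that cover it. This is a naive Python sketch (no pre-sorting, no pruning), not the author's C++ implementation; half-open intervals, the box coordinates and the scores are assumptions made for illustration.

```python
def split_1d(intervals):
    """Cut the line at every endpoint; keep the pieces covered by at least one interval.

    intervals: dict id -> (lo, hi), half-open [lo, hi). Returns [(lo, hi, ids)].
    """
    cuts = sorted({v for lo, hi in intervals.values() for v in (lo, hi)})
    pieces = []
    for lo, hi in zip(cuts, cuts[1:]):
        # every interval either covers [lo, hi) entirely or misses it entirely
        ids = frozenset(i for i, (a, b) in intervals.items() if a <= lo and hi <= b)
        if ids:
            pieces.append((lo, hi, ids))
    return pieces


def decompose(boxes, ids=None, dim=0, prefix=()):
    """Naive dimension-by-dimension super decomposition.

    boxes: dict id -> [(lo, hi) for each dimension].
    Yields (region, ids): region is a box (tuple of (lo, hi)), ids is exactly the set
    of input boxes containing it; every other box misses it entirely.
    """
    if ids is None:
        ids = frozenset(boxes)
    n_dims = len(next(iter(boxes.values())))
    if dim == n_dims:
        yield prefix, ids
        return
    for lo, hi, sub in split_1d({i: boxes[i][dim] for i in ids}):
        # boxes that do not overlap along `dim` never share a node again
        yield from decompose(boxes, sub, dim + 1, prefix + ((lo, hi),))


# Three overlapping leaves with made-up coordinates and scores
boxes = {"B1": [(0, 4), (0, 4)], "B2": [(1, 3), (1, 5)], "B3": [(2, 6), (2, 3)]}
scores = {"B1": 0.4, "B2": -0.3, "B3": 0.2}
for region, ids in decompose(boxes):
    print(region, sorted(ids), round(sum(scores[i] for i in ids), 2))
```

Each output region carries the summed score of the leaves that contain it, which is all we need to decide its class.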

In each node of the tree-like structure, we always solve the same intersection problem: given a list of boxes and the dimension of the input space associated with the node, we look for the maximal intersections of the boxes along this dimension, i.e. we solve the combinatorial problem mentioned previously, but in one dimension. We thus no longer deal with boxes but with 1D intervals, which are the projections of the boxes onto the considered dimension. Fortunately, in 1D, the problem is no longer combinatorial, and even becomes linear in the number of considered intervals (once their endpoints are sorted). The idea is to place all the intervals on a 1D axis, as in the figure below.

Finding regions (intervals) of maximal intersection in 1D from a collection of 1D-intervals. A region is said to be of maximal intersection if it corresponds to the zone of intersection of k intervals, where k is the maximum number of intervals that intersect in that zone. In order to compute these regions, we place all the intervals on a 1D-axis, and we create a new maximal intersection region each time an interval starts or ends. On the example above, from the collection of intervals {I1, I2, I3}, we extracted five maximal intersection regions Z1, Z2, Z3, Z4, and Z5. [Image by Author]

We then notice that a new maximal intersection 1D region (which is a 1D interval) begins each time an interval starts or ends, except at the last interval ending, which terminates the last maximal intersection region. Hence, from an initial collection of N intervals, we get at most 2N-1 maximal intersection regions. Algorithmically, we have to sort the beginnings and endings of all the intervals together, which amounts to sorting 2N points. In practice, this sorting can be skipped in the super decomposition algorithm by pre-sorting the boxes of the tree ensemble model separately in each dimension, and by observing that extracting a subset from an ordered set with a masking operation yields an ordered subset. We can thus consider a subset of boxes in a node without having to sort them again along the dimension associated with the node. The processing in each node of the tree-like structure thus ends up surprisingly cheap compared to the initial complexity of the problem.
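The 1D step and the pre-sorting trick can be sketched as a sweep line: sort the 2N endpoints once, then walk through them, closing a region at every endpoint while at least one interval is active. Filtering the sorted events with a mask keeps them sorted, so a node can reuse the global order. A minimal Python sketch, assuming half-open intervals [lo, hi); the intervals are made up.

```python
def endpoint_events(intervals):
    """Sort the 2N endpoints of half-open intervals [lo, hi) once.

    Each event is (position, kind, id) with kind 0 = end and 1 = start, so at equal
    positions ends come before starts, which matches [lo, hi) semantics.
    """
    events = []
    for i, (lo, hi) in enumerate(intervals):
        events.append((lo, 1, i))
        events.append((hi, 0, i))
    return sorted(events)


def sweep(events, keep=None):
    """One pass over pre-sorted events; returns [(lo, hi, ids)] maximal intersection regions.

    `keep` masks a subset of interval ids: filtering a sorted list keeps it sorted,
    so a node of the search structure can reuse the global order without re-sorting.
    """
    active, regions, prev = set(), [], None
    for pos, kind, i in events:
        if keep is not None and i not in keep:
            continue
        if active and pos > prev:
            regions.append((prev, pos, frozenset(active)))  # a region closes at each endpoint
        if kind == 1:
            active.add(i)
        else:
            active.discard(i)
        prev = pos
    return regions


events = endpoint_events([(0, 4), (1, 3), (2, 6)])
print(sweep(events))               # 5 regions = 2N - 1 for N = 3
print(sweep(events, keep={0, 2}))  # same events, masked, no re-sort
```

Gaps covered by no interval produce no region, so the 2N-1 bound is an upper bound, reached when the intervals overlap in a chain as here.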

You may have anticipated by now that the super decomposition algorithm may end up producing so many decision regions that they won’t even fit in your computer’s memory. And, cruelly, that is exactly what happens. So, we now look for optimizations that still find the exact CF example while limiting the number of regions actually built. The main idea is that we don’t need to build the whole decision region of the model, but just what lies around the query point, enough to find an example from another class. So, we would like to bound the size of the search region in which the exact CF example lies. As a first try, we can use the training data to do so: given a query point, we look for the closest training point that is classified as normal by the model (notice that I said “classified” and not “labeled”). Its distance to the query point provides a first, quite honest upper bound on the size of the search region. Using it, we only build the regions that lie within this upper bound. This comes in handy, since the algorithm used to build the decision regions proceeds dimension by dimension: taking more dimensions into account can only increase the distance to the query point, so if, at a given dimension, the partially built decision region is already farther than the upper bound, we can stop the exploration at the corresponding node of the search tree. This approach goes by the generic name of “branch-and-bound”, and is particularly useful in any scenario where you need to explore the branches of a tree-like structure (which is obviously our case).
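The two ingredients of this bound are simple to write down: the distance to the nearest training point the model classifies as normal, and a distance from the query to a box known only on its first few dimensions. A minimal Python sketch; `predict_is_normal` is a hypothetical stand-in for the model's decision, and all numbers are made up.

```python
import math


def initial_upper_bound(query, train_points, predict_is_normal):
    """Distance from the query to the closest training point the model *classifies* as normal."""
    best = math.inf
    for p in train_points:
        if predict_is_normal(p):
            best = min(best, math.dist(query, p))
    return best


def partial_sq_dist(query, partial_box):
    """Squared distance from the query to a box known only on its first len(partial_box) dims."""
    total = 0.0
    for q, (lo, hi) in zip(query, partial_box):
        if q < lo:
            total += (lo - q) ** 2
        elif q > hi:
            total += (q - hi) ** 2
    return total


def can_prune(query, partial_box, upper_bound):
    """Adding dimensions can only increase the distance, so stop once it exceeds the bound."""
    return partial_sq_dist(query, partial_box) > upper_bound ** 2


query = [0.0, 0.0, 0.0]
train = [[3.0, 0.0, 0.0], [1.0, 1.0, 0.0], [0.5, 0.0, 0.0]]
ub = initial_upper_bound(query, train, lambda p: p[0] >= 1.0)  # toy decision rule
print(ub, can_prune(query, [(2.0, 5.0)], ub))
```

A node whose partial box already lies farther than the bound is dropped with all its descendants, however many dimensions remain to be processed.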

I hope you’re still following me on this dangerous branch-and-bound exploration. There are many other practical optimizations, the most effective one being to explore the search structure in a depth-first manner. This very quickly provides an upper bound that is better than the one computed from the training set alone. In practice, we maintain several depth-first explorations in parallel using multi-threading to keep it effective. Each time a thread finds a tighter upper bound than the current tightest one, we update the tightest upper bound and communicate it to all threads for immediate use. This results in very fast pruning of the search tree, and lets us exhaustively explore the decision regions of the model in the blink of an eye (sometimes a bit more, depending on how fast you blink).
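Putting the pieces together gives an exact CF search: a depth-first traversal of the dimension-by-dimension decomposition that visits the closest 1D pieces first, prunes any node already farther than the best normal region found so far, and tightens that bound as soon as a closer one is found. This is a minimal, single-threaded Python sketch, not the author's multi-threaded C++ package; it assumes a binary model whose leaves carry scalar scores, half-open boxes, and a hypothetical `is_normal` rule applied to the summed score.

```python
import math


def nearest_cf(query, leaves, is_normal, upper_bound=math.inf, eps=1e-9):
    """Depth-first branch-and-bound search for the closest point classified as normal.

    leaves: list of (box, value); box = [(lo, hi) per dimension], half-open [lo, hi);
            value = scalar leaf score, summed over all leaves containing a point.
    is_normal: hypothetical decision rule applied to that sum.
    Returns (distance, cf_point), or (upper_bound, None) if nothing beats the bound.
    """
    n_dims = len(query)
    best = {"d2": upper_bound ** 2, "point": None}

    def dfs(ids, dim, region, d2):
        if d2 >= best["d2"]:
            return  # prune: this branch cannot beat the tightest bound found so far
        if dim == n_dims:
            if is_normal(sum(leaves[i][1] for i in ids)):
                # Clamp the query into the region; upper ends are open, hence the tiny nudge
                point = [min(max(q, lo), hi - eps) for q, (lo, hi) in zip(query, region)]
                best["d2"], best["point"] = d2, point  # tighter bound, used right away
            return
        cuts = sorted({v for i in ids for v in leaves[i][0][dim]})
        pieces = []
        for lo, hi in zip(cuts, cuts[1:]):
            sub = [i for i in ids
                   if leaves[i][0][dim][0] <= lo and hi <= leaves[i][0][dim][1]]
            if sub:
                q = query[dim]
                gap = lo - q if q < lo else (q - hi if q > hi else 0.0)
                pieces.append((gap * gap, lo, hi, sub))
        # Depth-first, closest pieces first: good bounds are found early
        for g2, lo, hi, sub in sorted(pieces, key=lambda t: t[0]):
            dfs(sub, dim + 1, region + [(lo, hi)], d2 + g2)

    dfs(list(range(len(leaves))), 0, [], 0.0)
    return math.sqrt(best["d2"]), best["point"]


inf = math.inf
# Toy binary model (made-up values): two depth-1 trees on two features
leaves = [
    ([(-inf, 1.0), (-inf, inf)], -1.0), ([(1.0, inf), (-inf, inf)], 1.0),  # tree 1
    ([(-inf, inf), (-inf, 0.5)], 0.5), ([(-inf, inf), (0.5, inf)], -0.5),  # tree 2
]
print(nearest_cf([0.0, 0.0], leaves, is_normal=lambda s: s > 0))  # (1.0, [1.0, 0.0])
```

Passing the training-set bound as `upper_bound` prunes from the very first node; the parallel version described above would share `best` across threads.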

Last point: dimensionality plays on your side for once (it’s rare to be able to write this, so I would almost call this moment historical and, if I weren’t typing, you would see my handwriting trembling). To put it in plain words, the more dimensions you add, the higher the chance that the distance from a partially constructed box to the query point exceeds the upper bound. This simple effect causes the number of regions created while exploring the search tree to stabilize and, sometimes, even to (slightly) decrease beyond a certain number of dimensions. Well, let’s not get over-enthusiastic anyway: better add good optimizations and pray for a quick outcome.

So, now that I have laid out the founding principles, let’s get our hands on the code. In my immense kindness, and for your beautiful eyes, I coded all the optimizations above (and even more) in a nice C++ program that I wrapped in R (and maybe I’ll write a Python wrapper as well in the near future, so that it can be used with your favorite high-level programming language). I’ll show you in another blog post how to do nice and easy Rcpp wrapping, to the point that you may think about abandoning Python.

The R package can be found on my GitHub. It requires the “Boost” and “TBB” C++ libraries. Check that the paths to these libraries are correctly set in the “Makevars” file located in the “./src” folder, or that they can be found in standard system paths. To install the package into the folder INSTALL_PATH, run the following commands from a terminal (Linux users only):

cd PACKAGE_SOURCE_FOLDER
R CMD INSTALL -l /INSTALL_PATH ./ --preclean

Once you have (successfully) installed the R package, let’s move on to the use cases. I chose to keep the description relatively short to avoid being terribly redundant with the GitHub repository. If you’re interested, be aware that the code is still under development and waiting for outside contributors to turn it into a great package for CF interpretability of large tree ensemble models, one that could someday find its place in the XGBoost code base (yeah, the author is a dreamer …).

As a first example, let’s consider a dataset for consumer credit approval/denial. It is quite different from the industrial fault detection scenario I advertised at the beginning of this post, but CF interpretability approaches extend to any classification scenario made of a normal class (credit approval) and an “abnormal” one (credit denial). Also, small datasets are nice for showcasing something, and the academic world has been using them for decades now (is that a solid argument?). The dataset lets us learn a mapping between 20 input features, such as the number of ongoing credits, the income, the age, the purpose of the credit …, and the decision to grant the credit or not.

I’ll spare you all the feature formatting and the training details of the XGBoost model; you can easily reverse-engineer them from the demo scripts. Let’s jump directly to the CF example computation. We first need to select a point of the test data that corresponds to a credit denial. We check that it is both classified and labeled as a “credit denial”. To ensure the latter, we simply look at the ground truth. This check matters because the trained model does not go higher than 75 percent in terms of two-class accuracy, so there are many false positives (points classified as a “credit denial” which aren’t in reality). Given a selected point, which we call the “query point”, we compute its associated closest CF example. Below, I explain the meaning of each argument of the function used to determine the CF example. I felt obliged to do it here, since explanations are lacking in the code.
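Selecting valid query points amounts to intersecting two masks: points the model classifies as denials and points labeled as denials. A minimal NumPy sketch, assuming denial is encoded as class 1 and that `proba_denial` holds the model's predicted denial probabilities (with an XGBoost scikit-learn classifier, for instance the second column of `predict_proba`); the numbers are made up.

```python
import numpy as np


def select_queries(proba_denial, y_true, denial_label=1, threshold=0.5):
    """Indices of test points that are both *classified* and *labeled* as credit denial."""
    proba_denial = np.asarray(proba_denial)
    y_true = np.asarray(y_true)
    classified = proba_denial > threshold  # model says "denial"
    labeled = y_true == denial_label       # ground truth says "denial"
    return np.flatnonzero(classified & labeled)


# Made-up predictions and labels; index 1 is a false positive and is excluded
print(select_queries([0.9, 0.8, 0.2, 0.6], [1, 0, 1, 1]))
```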

Signature of the main method in charge of computing the CF example.

The table below shows some results for three individuals. For each one, we give the query point (client status) and the associated CF examples (recommendation + restricted recommendation). When there is no change between the query point and the CF example, the cell is left blank. For each individual, we compute two CF examples. The first one (line “recommendation”) corresponds to the “full actionability” scenario, i.e., the individual can control/change all the characteristics concerning him. The second one corresponds to the “partial actionability” scenario, i.e., the individual can only control/change a few variables. For instance, it is more realistic to consider that the individual can’t control his age, even though research on rejuvenating lotions has made great progress.

For three different individuals, we diagnose the bank’s decision to refuse them the credit they asked for, and we show the recommendation provided by the CF approach, i.e. what they would minimally need to change for the credit to be granted. For each individual, we issue two recommendations: the first one considers that the individual can control/change all the characteristics (“recommendation” line). The second one considers that the individual can only control a restricted number of characteristics (“restricted recommendation” line), which is often more realistic (the individual can’t change his age or marital status, for instance). Anyway, I sincerely hope this will help you get the 100-buck credit you needed to buy the Bugatti Gran Turismo of your dreams. [Image by Author]
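Partial actionability can be expressed as a constraint on the candidate regions: along a frozen feature, the CF example must keep the query's value, so only regions containing that value qualify, and the distance only accumulates over the free features. A minimal Python sketch over a precomputed list of regions; the input format, the feature meanings (age, income) and all numbers are hypothetical.

```python
import math


def restricted_cf(query, regions, frozen, eps=1e-9):
    """Closest normal region when the features in `frozen` must keep their values.

    regions: list of (box, is_normal) with box = [(lo, hi) per dimension], half-open;
             e.g. regions from a super decomposition (hypothetical input format).
    frozen:  set of feature indices the individual cannot change (age, marital status ...).
    """
    best_d2, best_point = math.inf, None
    for box, normal in regions:
        if not normal:
            continue
        if any(not (box[j][0] <= query[j] < box[j][1]) for j in frozen):
            continue  # reaching this region would require changing a frozen feature
        point = [q if j in frozen else min(max(q, lo), hi - eps)
                 for j, (q, (lo, hi)) in enumerate(zip(query, box))]
        d2 = sum((p - q) ** 2 for p, q in zip(point, query))
        if d2 < best_d2:
            best_d2, best_point = d2, point
    return math.sqrt(best_d2), best_point


inf = math.inf
# Toy regions over (age, income) with made-up boundaries
regions = [([(40, inf), (-inf, inf)], True),    # older applicants: approved
           ([(-inf, 40), (2000, inf)], True),   # younger, higher income: approved
           ([(-inf, 40), (-inf, 2000)], False)]
print(restricted_cf([30, 1500], regions, frozen=set()))  # full actionability
print(restricted_cf([30, 1500], regions, frozen={0}))    # age cannot change
```

In the unrestricted case the cheapest move is to "age" by 10 years; freezing age forces the recommendation onto income instead.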

The second use case concerns misclassified digits of the MNIST dataset. We extract two of the ten classes that are ambiguous by nature (such as 1 and 7, or 5 and 6). Then, we train a two-class XGBoost classification model that learns to discriminate between these two classes. As a query point, we choose a point belonging to one class that the model misclassifies into the other class. We also ensure that the input digit is visually ambiguous (meaning the human eye can’t really tell which class it belongs to). This last point is particularly important, since our goal is not to detect potential weaknesses of the classification model, but the presence of abnormal values in the input data. So, we need to be able to believe that the input data is faulty in some way, and that the classifier is just doing its job. Only under these assumptions can we draw a parallel between misclassified and faulty data, and apply our CF approach to misclassified examples as we would to faulty data.

In this use case, we introduce one more aspect: the notion of plausibility. To make a CF example more plausible (i.e. more visually convincing), we can add a constraint on the decision score the model assigns to the CF example. By forcing the model to predict the right class with more confidence, we obtain CF examples that look more and more like an element of the right class (by “right class”, I mean the class associated with the query point in the ground truth). Enforcing plausibility comes at the expense of distortion: higher plausibility causes the CF example to be farther from the query point in terms of Euclidean distance. The figure below illustrates this idea.

Two binary classification problems (5 vs 6 and 4 vs 9). First line: example of a “5” misclassified as a “6” by the model. The initial query image is shown on the left. We vary the decision threshold “eps” between 0.5 and 0.2. Getting a CF example that the model classifies into the right class with more confidence comes at the expense of the distortion introduced to the initial query data, i.e., the lower the decision threshold “eps”, the higher the distance “dCF” to the initial query. Visually, we see that the CF approach provides plausible changes of the initial query data that make it look more like a “5”. Second line: example of a “4” misclassified as a “9” by the model. [Image by Author]
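Why distortion grows as eps decreases can be seen with a toy candidate list: lowering eps shrinks the set of accepted regions, and the minimum distance over a smaller set can only stay the same or increase. A minimal Python sketch, assuming each candidate region is summarized by its distance to the query and the model's probability for the wrong class; both the format and the numbers are made up.

```python
def cf_distance_for_threshold(candidates, eps):
    """Distance of the closest candidate region the model accepts at threshold eps.

    candidates: list of (distance to the query, model probability of the wrong class),
    one per region of the decision space (toy input format, an assumption).
    A region is accepted only if that probability is below eps.
    """
    accepted = [d for d, p in candidates if p < eps]
    return min(accepted) if accepted else float("inf")


# Made-up regions: the farther from the query, the more confidently they are classified
candidates = [(1.0, 0.45), (1.5, 0.30), (2.2, 0.15), (3.0, 0.05)]
for eps in (0.5, 0.4, 0.3, 0.2):
    print(eps, cf_distance_for_threshold(candidates, eps))
```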

Other plausibility criteria could be enforced as well: for instance, we could check that the CF example lies in a maximal intersection region that contains at least one element of the training dataset. This would avoid picking out-of-distribution CF examples that are unrealistic or unreachable in real life.
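This support criterion is a simple filter applied to candidate regions before the distance comparison. A minimal Python sketch, assuming regions are half-open boxes and the training set is a list of points; the data is made up.

```python
def has_training_support(box, train_points):
    """True if at least one training point lies inside the half-open box [lo, hi) per dim."""
    return any(all(lo <= x < hi for x, (lo, hi) in zip(p, box)) for p in train_points)


def plausible_regions(regions, train_points):
    """Keep only the candidate regions that contain at least one training example."""
    return [box for box in regions if has_training_support(box, train_points)]


train = [(0.5, 0.5), (3.0, 3.0)]
regions = [[(0, 1), (0, 1)], [(1, 2), (1, 2)], [(2, 4), (2, 4)]]
print(plausible_regions(regions, train))  # the empty middle region is discarded
```

In a branch-and-bound search, the same test would be applied to each candidate region before it is allowed to update the best CF example.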

Last point before I let you go eat your non-faulty wafer: it is not mandatory to use the Euclidean distance between a query point and its corresponding CF example. Other distances can be used as well without changing the decomposition algorithm, provided they can be written as a coordinate-wise sum, i.e. computed as “d(X, Y) = g( ∑ di(X[i], Y[i]) )”, where “di” is a specific operation applied to the i-th coordinate and “g” a monotonically increasing function. For instance, di(X[i], Y[i]) = (Y[i] - X[i])² with g the identity gives the squared Euclidean distance, and the same di with g the square root gives the Euclidean distance.
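The coordinate-wise form is what keeps the dimension-by-dimension pruning valid: if every per-coordinate term is non-negative, partial sums can only grow as dimensions are added, and since g is increasing, comparing the final distances is the same as comparing the sums. A minimal Python sketch of a few distances written this way; the dictionary of distances is an illustrative assumption, not the package's API.

```python
import math

# d(X, Y) = g(sum_i d_i(X[i], Y[i])): (per-coordinate term d_i, increasing function g)
DISTANCES = {
    "euclidean": (lambda x, y: (y - x) ** 2, math.sqrt),
    "squared_euclidean": (lambda x, y: (y - x) ** 2, lambda s: s),
    "manhattan": (lambda x, y: abs(y - x), lambda s: s),
}


def coordinatewise_distance(X, Y, name="euclidean"):
    term, g = DISTANCES[name]
    return g(sum(term(x, y) for x, y in zip(X, Y)))


def partial_sum(X, Y, n_dims, name="euclidean"):
    """Sum of the first n_dims terms: non-negative terms mean it can only grow with n_dims."""
    term, _ = DISTANCES[name]
    return sum(term(x, y) for x, y in list(zip(X, Y))[:n_dims])


for name in DISTANCES:
    print(name, coordinatewise_distance([0, 0], [3, 4], name))
```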

Well, what should you keep in mind from this very long blog post (sorry, I might have been overly verbose in this first attempt to communicate with the world)? First, that tree ensemble models allow the computation of exact CF examples, while being great models for fault detection (especially gradient boosted trees). Second, that CF examples, in addition to localizing the fault / anomaly, give a precise idea of the minimal action to take to correct it. Moreover, the CF method presented here has the great advantage of producing sparse explanations, i.e. it suggests changes to a reduced number of input characteristics, which makes the explanation more intelligible for a human user.

If you’re interested in the topic, and want to know more about the possibilities of the method, you can read the full article here of which I happen to be the author (but that’s a purely random coincidence). If you want to use the package in your application, please cite: An exact counterfactual-example-based approach to tree-ensemble models interpretability

In a next post, I’ll show you an extension of CF explanation to regression problems, and teach you (if you let me) how CF reasoning can be deployed for profit maximization, or, at least, how to make the sale price of your house climb by 10 000 dollars just by changing the color of the carpet floor in the kitchen. Well, I said too much (or not enough), but if this first post is a success, I’ll deliver you the full recipe :).
