Explainable Graph Neural Networks
A step forward in explainable AI, and why it is hard to adapt existing explanation methods to GNNs
TL;DR
- Explainability is a big topic in deep learning because it enables more reliable and trustworthy predictions.
- Existing explanation methods can’t be easily adapted to Graph Neural Networks due to the irregularity of graph structure.
- A quick look at five groups of GNN explanation methods.
Explainability increases reliability
Explainability in Artificial Intelligence has recently attracted much attention, mainly because it makes the predictions of “black box” models such as Deep Neural Networks more reliable and trustworthy. A good example given in the LIME paper [2] is that you trust your doctor because he or she can explain the diagnosis in terms of your symptoms. Analogously, the predictions of a deep model are more reliable and trustworthy if they can be explained or justified in a human-interpretable way.
Related posts
There have been several posts on TDS in the past month related to explainable AI. Some of them provide hands-on examples to help you get started. A few are listed here if you’re interested in learning more about explainable AI.
- How to explain neural networks using SHAP [3] by Gianluca Malato provided an example of using SHAP to explain a neural network trained to predict the probability of diabetes from 33 features. The explanations are then visualized to show how each feature contributes to the predicted outcome.
- Explainable AI (XAI) — A guide to 7 Packages in Python to Explain Your Models [4] by Prateek Bhatnagar provided an overview with hands-on examples of a couple of great toolkits for explaining deep models to help you get started.
- Explainable Deep Neural Networks [5] by Javier Marin presented a new approach that uses Topological Data Analysis to visualize the hidden layers of a Deep Neural Network and gain insight into how the data are transformed throughout the network.
Why is it hard to adapt existing methods to explain GNNs?
Traditional explanation methods work quite well on Convolutional Neural Networks (CNNs). The example below shows the LIME explanations for the top three predicted class labels of the input image (a). We can clearly see that the image regions driving each prediction match the corresponding class label. For example, the guitar neck contributes most to the prediction “Electric Guitar” (b).
However, when it comes to Graph Neural Networks (GNNs), things become trickier. Unlike the highly regular grids on which CNNs operate, graphs have an irregular structure, which poses many challenges. For example, we can easily interpret the CNN explanations above, but the analogous node-level explanations for a graph are not easy to visualize or interpret.
In the following section, we’ll go through the main idea behind each group of methods, following the taxonomy of a recent survey on explainability in GNNs [1].
Overview of methods in explaining GNN
Gradient/Feature Based Methods: use gradients or hidden feature maps, obtained via back-propagation, as approximations of input importance to explain the predictions.
Perturbation Based Methods: the variation of the output in response to perturbations of the input reflects the importance of the perturbed input region. Put another way, which nodes/edges/features need to be kept so that the final prediction does not deviate too much from that of the original GNN model?
Surrogate Methods: train a simpler, more interpretable surrogate model on the neighboring area of an input node, and use it to explain the GNN's prediction.
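To make the gradient-based idea concrete, here is a minimal sketch. It assumes a hypothetical toy GNN (one mean-aggregation layer with a `tanh` nonlinearity over scalar node features, followed by a sum readout) and back-propagates by hand to get the derivative of the graph score with respect to each node feature; the absolute gradient then serves as a saliency score. The names `forward_and_grad` and `saliency` and the weights `a` and `b` are illustrative and not taken from [1]; in practice an autograd framework would compute these gradients for you.

```python
import math


def neighbors(n, edges):
    """Adjacency lists for an undirected graph with n nodes."""
    nbrs = [[] for _ in range(n)]
    for u, v in edges:
        nbrs[u].append(v)
        nbrs[v].append(u)
    return nbrs


def forward_and_grad(x, edges, a=0.8, b=1.5):
    """Hypothetical toy GNN: h_v = tanh(a * mean of x over v and its neighbors),
    y = b * sum_v h_v. Returns y and dy/dx_u for every node u (manual backprop)."""
    nbrs = neighbors(len(x), edges)
    y = 0.0
    grad = [0.0] * len(x)
    for v in range(len(x)):
        group = [v] + nbrs[v]
        m = sum(x[u] for u in group) / len(group)
        h = math.tanh(a * m)
        y += b * h
        # Chain rule: dy/dm_v = b * a * (1 - h_v^2), and dm_v/dx_u = 1/|group| for u in group.
        dm = b * a * (1.0 - h * h)
        for u in group:
            grad[u] += dm / len(group)
    return y, grad


def saliency(x, edges):
    """Gradient-based node importance: |dy/dx_u|."""
    _, grad = forward_and_grad(x, edges)
    return [abs(g) for g in grad]


if __name__ == "__main__":
    x = [1.0, -0.5, 2.0, 0.3]           # hypothetical node features
    edges = [(0, 1), (1, 2), (2, 3)]    # a 4-node path graph
    scores = saliency(x, edges)
    ranking = sorted(range(len(x)), key=lambda u: scores[u], reverse=True)
    print("saliency:", [round(s, 3) for s in scores])
    print("most important nodes first:", ranking)
```

Nodes whose features most change the output when nudged rank highest; this is the same signal that gradient-based GNN explainers read off a trained model.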
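The perturbation idea can be sketched with simple edge occlusion: remove one edge at a time, re-run the model, and treat the size of the change in output as that edge's importance. The second function follows the "what must be kept" phrasing by greedily keeping the highest-ranked edges until the prediction on the kept subgraph is within a tolerance `tol` of the original. This is a minimal sketch on a hypothetical toy GNN (`gnn_score`), not GNNExplainer or any other specific method surveyed in [1], which typically learn masks by optimization instead of enumerating edges.

```python
import math


def gnn_score(x, edges, a=0.8, b=1.5):
    """Hypothetical toy GNN: mean aggregation over each node's neighborhood, tanh, sum readout."""
    nbrs = [[] for _ in x]
    for u, v in edges:
        nbrs[u].append(v)
        nbrs[v].append(u)
    total = 0.0
    for v in range(len(x)):
        group = [v] + nbrs[v]
        total += b * math.tanh(a * sum(x[u] for u in group) / len(group))
    return total


def edge_importance(x, edges):
    """Occlusion: importance of edge e = |f(G) - f(G without e)|."""
    base = gnn_score(x, edges)
    return {e: abs(base - gnn_score(x, [f for f in edges if f != e])) for e in edges}


def minimal_explanation(x, edges, tol=0.1):
    """Greedily keep the most important edges until the prediction on the kept
    subgraph is within tol of the original prediction."""
    base = gnn_score(x, edges)
    importance = edge_importance(x, edges)
    kept = []
    for e in sorted(edges, key=importance.get, reverse=True):
        if abs(gnn_score(x, kept) - base) <= tol:
            break
        kept.append(e)
    return kept


if __name__ == "__main__":
    x = [1.0, -0.5, 2.0, 0.3]                  # hypothetical node features
    edges = [(0, 1), (1, 2), (2, 3), (0, 3)]
    print("edge importance:", edge_importance(x, edges))
    print("kept edges:", minimal_explanation(x, edges, tol=0.1))
```

A smaller `tol` forces a larger, more faithful explanation; a larger one yields a sparser explanation at the cost of fidelity.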
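A surrogate method can be sketched in the style of LIME [2]: randomly drop the edges between a target node and its neighbors, query the GNN on each perturbed graph, and fit a linear model to the outputs; the fitted weights then act as neighbor importances. This minimal sketch assumes a hypothetical toy node-level GNN (`node_output`), fits ordinary least squares with a tiny ridge term for numerical safety, and omits LIME's proximity weighting of samples. All helper names are illustrative.

```python
import math
import random


def node_output(x, edges, target, a=0.8):
    """Hypothetical toy GNN output for one node: tanh(a * mean feature of target and its neighbors)."""
    group = [target] + [v if u == target else u for u, v in edges if target in (u, v)]
    return math.tanh(a * sum(x[u] for u in group) / len(group))


def solve(A, rhs):
    """Solve a small dense linear system by Gauss-Jordan elimination with partial pivoting."""
    n = len(A)
    M = [row[:] + [r] for row, r in zip(A, rhs)]
    for col in range(n):
        piv = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[piv] = M[piv], M[col]
        for r in range(n):
            if r != col:
                f = M[r][col] / M[col][col]
                M[r] = [p - f * q for p, q in zip(M[r], M[col])]
    return [M[i][n] / M[i][i] for i in range(n)]


def lime_surrogate(x, edges, target, n_samples=200, seed=0, ridge=1e-6):
    """Fit y ~ w0 + sum_i w_i * z_i, where z_i = 1 if the edge to neighbor i is kept."""
    nbrs = sorted({v if u == target else u for u, v in edges if target in (u, v)})
    rng = random.Random(seed)
    rows, ys = [], []
    for _ in range(n_samples):
        z = [rng.randint(0, 1) for _ in nbrs]
        dropped = {n for n, keep in zip(nbrs, z) if not keep}
        # Remove only the edges that connect the target to a dropped neighbor.
        kept = [e for e in edges if not (target in e and set(e) & dropped)]
        rows.append([1.0] + z)
        ys.append(node_output(x, kept, target))
    k = len(rows[0])
    # Normal equations: (Z^T Z + ridge * I) w = Z^T y
    A = [[sum(r[i] * r[j] for r in rows) + (ridge if i == j else 0.0) for j in range(k)]
         for i in range(k)]
    rhs = [sum(r[i] * y for r, y in zip(rows, ys)) for i in range(k)]
    w = solve(A, rhs)
    return w[0], dict(zip(nbrs, w[1:]))


if __name__ == "__main__":
    x = [0.0, 2.0, -1.0, 0.5]                  # hypothetical node features
    edges = [(0, 1), (0, 2), (0, 3), (1, 2)]   # node 0 has neighbors 1, 2, 3
    intercept, weights = lime_surrogate(x, edges, target=0)
    print("neighbor importances for node 0:", weights)
```

The linear surrogate is only trusted locally: it describes how the GNN behaves around this one node, not globally.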
Decomposition Methods: decompose the predictions into several terms, each regarded as the importance score of the corresponding input features.
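The defining property of decomposition methods (for example, LRP-style rules) is conservation: the per-input scores add up to the prediction. The sketch below only illustrates that property on a hypothetical, purely linear toy GNN (mean aggregation plus sum readout, no nonlinearity), where the prediction splits exactly into per-node contributions `x_u * c_u`. Real decomposition methods need layer-specific propagation rules to handle nonlinearities; the function names here are illustrative.

```python
def neighbors(n, edges):
    """Adjacency lists for an undirected graph with n nodes."""
    nbrs = [[] for _ in range(n)]
    for u, v in edges:
        nbrs[u].append(v)
        nbrs[v].append(u)
    return nbrs


def linear_gnn(x, edges, b=1.5):
    """Hypothetical linear toy GNN: y = b * sum_v mean(x over v and its neighbors)."""
    nbrs = neighbors(len(x), edges)
    return b * sum(sum(x[u] for u in [v] + nbrs[v]) / (1 + len(nbrs[v]))
                   for v in range(len(x)))


def decompose(x, edges, b=1.5):
    """Split y into per-node contributions R_u = x_u * c_u that sum exactly to y."""
    nbrs = neighbors(len(x), edges)
    coef = [0.0] * len(x)
    for v in range(len(x)):
        group = [v] + nbrs[v]
        for u in group:
            # Node u's feature enters node v's mean with weight 1/|group|.
            coef[u] += b / len(group)
    return [x[u] * coef[u] for u in range(len(x))]


if __name__ == "__main__":
    x = [1.0, -0.5, 2.0, 0.3]           # hypothetical node features
    edges = [(0, 1), (1, 2), (2, 3)]
    contributions = decompose(x, edges)
    print("prediction:", linear_gnn(x, edges))
    print("contributions:", contributions, "sum:", sum(contributions))
```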
Generation Methods: learn to generate graphs that achieve optimal prediction scores according to the GNN model to be explained.
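Generation methods such as XGNN train a reinforcement-learning graph generator to find graph patterns that the GNN scores highly for a target class. As a much simpler stand-in for a learned generator, the sketch below grows a graph greedily: at each step it tries adding an edge between existing nodes or attaching a new node of one of a few hypothetical types, and keeps whichever move raises the toy model's score the most. The model `score`, the node types in `FEATURES`, and the step budget are all assumptions for illustration.

```python
import math

FEATURES = [-1.0, 0.5, 2.0]  # hypothetical node types, encoded as scalar features


def score(x, edges, a=0.8):
    """Hypothetical toy GNN: mean aggregation, tanh, then mean readout over nodes."""
    nbrs = [[] for _ in x]
    for u, v in edges:
        nbrs[u].append(v)
        nbrs[v].append(u)
    hidden = [math.tanh(a * sum(x[u] for u in [v] + nbrs[v]) / (1 + len(nbrs[v])))
              for v in range(len(x))]
    return sum(hidden) / len(hidden)


def candidate_moves(x, edges):
    """Yield every graph reachable by adding one edge or one new attached node."""
    existing = set(edges)
    n = len(x)
    for i in range(n):
        for j in range(i + 1, n):
            if (i, j) not in existing:
                yield x, edges + [(i, j)]
        for f in FEATURES:
            yield x + [f], edges + [(i, n)]


def generate(start_feature, steps=4):
    """Greedy graph construction that maximizes the toy model's score."""
    x, edges = [start_feature], []
    for _ in range(steps):
        best_x, best_edges = max(candidate_moves(x, edges), key=lambda g: score(*g))
        if score(best_x, best_edges) <= score(x, edges):
            break  # no single move improves the score
        x, edges = best_x, best_edges
    return x, edges


if __name__ == "__main__":
    x, edges = generate(start_feature=-1.0)
    print("node features:", x)
    print("edges:", edges)
    print("score:", round(score(x, edges), 3))
```

The resulting graph is a model-level explanation: it shows what kind of input the model "looks for", rather than explaining one particular prediction.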
What makes a good GNN explanation method?
When it comes to evaluating these methods, there are many things to take into account. The authors of [1] suggested several metrics that check for the following properties.
- Fidelity: predictions based on the explanation should be consistent with those of the original GNN being explained.
- Sparsity: only a small fraction of nodes/edges/features are used as explanations.
- Stability: small changes to the input should not affect the explanation too much.
- Accuracy: the explanation should accurately recover the ground-truth explanation (this only works for synthetic datasets, where the ground truth is known).
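The four properties can be turned into simple scores. The sketch below is one possible instantiation for explanations given as edge sets: fidelity as the gap between the model's output on the full graph and on the explanation subgraph alone (lower is better), sparsity as the fraction of edges left out of the explanation, stability as the Jaccard similarity between the explanations of an input and of a slightly perturbed copy, and accuracy as the fraction of ground-truth edges recovered. The exact metric definitions in [1] may differ in detail (for instance, by averaging over a dataset of graphs), and all function names and the toy model are hypothetical.

```python
def fidelity_gap(model, x, edges, explanation):
    """|f(G) - f(G_s)|: how far the prediction on the explanation subgraph G_s alone
    is from the original prediction (lower = more faithful)."""
    return abs(model(x, edges) - model(x, list(explanation)))


def sparsity(explanation, edges):
    """Fraction of edges NOT used by the explanation (higher = sparser)."""
    return 1.0 - len(set(explanation)) / len(set(edges))


def stability(expl_original, expl_perturbed):
    """Jaccard similarity between explanations before and after a small input change."""
    a, b = set(expl_original), set(expl_perturbed)
    return len(a & b) / len(a | b) if a | b else 1.0


def accuracy(explanation, ground_truth):
    """Fraction of ground-truth edges recovered; only possible on synthetic data."""
    truth = set(ground_truth)
    return len(truth & set(explanation)) / len(truth)


if __name__ == "__main__":
    # Hypothetical toy model: the graph score is the sum of per-edge weights.
    weights = {(0, 1): 0.9, (1, 2): 0.8, (2, 3): 0.05, (0, 3): 0.02}

    def model(x, edges):
        return sum(weights[e] for e in edges)

    edges = list(weights)
    expl = [(0, 1), (1, 2)]
    print("fidelity gap:", fidelity_gap(model, None, edges, expl))
    print("sparsity:", sparsity(expl, edges))
    print("stability:", stability(expl, [(0, 1), (2, 3)]))
    print("accuracy:", accuracy(expl, ground_truth=[(0, 1), (1, 2)]))
```

Note the tension between the first two scores: a larger explanation usually lowers the fidelity gap but also lowers sparsity, so methods are typically compared at matched sparsity levels.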
Conclusion
Explainability is a crucial part of Artificial Intelligence because it enables reliable and trustworthy predictions. However, adapting existing explanation methods to GNNs is not trivial. We’ve taken a quick look at the main groups of existing GNN explanation methods and at some desirable properties of a good GNN explanation method.
References
[1] H. Yuan, H. Yu, S. Gui, S. Ji, Explainability in Graph Neural Networks: A Taxonomic Survey (2020), arXiv
[2] M. T. Ribeiro, S. Singh, C. Guestrin, “Why Should I Trust You?”: Explaining the Predictions of Any Classifier (2016), ACM SIGKDD International Conference on Knowledge Discovery and Data Mining (KDD)
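[3] G. Malato, How to explain neural networks using SHAP, Towards Data Science
[4] P. Bhatnagar, Explainable AI (XAI) — A guide to 7 Packages in Python to Explain Your Models, Towards Data Science
[5] J. Marin, Explainable Deep Neural Networks, Towards Data Science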