
Introduction
I use Graph Neural Networks in my day-to-day job, and I have wasted many days due to the lack of a decent network visualisation tool when trying to explain and review the outputs of a newly trained model.
So this has motivated me to write this article, where I provide a step-by-step guide on how to create a fully interactive network visualisation of a GNN model explanation, in Python, WITHOUT having to rely on expensive third-party solutions.
We go from static plots like this:

To interactive and insightful visualisations like this:

This visualisation will differ from existing methods (e.g. `matplotlib`, `networkx`) as it will allow us to:
- Drag and drop nodes.
- Click on both nodes and edges to reveal feature importance values.
- Colour nodes and edges according to their importance.
- Apply different shapes to nodes according to the model prediction.
- Change the display settings dynamically using a menu-panel.
Whilst being:
- Convenient to implement.
- Compatible with `networkx`, a commonly used network analysis package in Python.
- Moderately scalable with graph size (we're not building something for production, but it needs to handle at least hundreds of nodes).
We will cover the following:
- Installation and setup.
- The example node classification task, the dataset, and the model we will use in this article.
- A quick intro to explaining a GNN using `torch-geometric`, and what the explanation object looks like.
- A step-by-step guide on creating the visualisation.
This article will not be explaining the fundamentals of GNNs nor the explanation algorithms used here, as there are many articles and videos out there that already do this.
The models and algorithms I chose are arbitrary and can be substituted with those of your choice.
Also, I have chosen examples that often appear in tutorials, so that people who are familiar with them can go straight to the visualisation part and save some time.
You can also find the Jupyter Notebook that accompanies this article here:
…which you can have open side-by-side with this article whilst reading.
So, enough with the intro, let’s get started.
Installation and Set Up
Packages
To get us started, we first need to install our dependencies.
We will be using `torch` and `torch-geometric` to train our GNN. We also install some additional torch-related dependencies required for some helper functions we will be using later.
pip install torch torch_geometric
pip install torch-sparse torch-scatter -f https://data.pyg.org/whl/torch-{your-torch-version}+cpu.html
For visualisation, we use `gravis`, my favourite visualisation package:
pip install gravis
For more info on what this package has to offer, I have written about it in detail here:
And finally some additional dependencies:
pip install numpy scikit-learn matplotlib networkx
The Task, the Dataset, and the Model
Before we start visualising, we need (1) an example task and (2) a model to explain.
Multi-class Node Classification Task
We will be using the _PubMed Dataset_¹ (CC BY-SA), a citation network of scientific publications from the PubMed² database relating to diabetes.
Each publication belongs to one of three anonymised categories, and our model is tasked with predicting the correct class for each publication.
Dataset Description
I provide the description from the PubMed dataset source:
The dataset consists of a single large network, composed of 19.7K nodes and 88.7K edges. Each node has a TF/IDF weighted word vector from a dictionary consisting of 500 unique words.
So each node is a publication, an edge links two publications where one cites the other, and the features represent words.
This is quite a small dataset: there are 19.7K nodes, but only 60 of them are designated for training, 500 for validation and 1000 for testing. The remaining nodes still provide graph structure for message passing, but are never used as supervised targets.
Download and Preprocess the Data
We will use the torch-geometric
dataset implementation which downloads and preprocesses the data for you in a single line of code.
from torch_geometric.datasets import Planetoid
data_train = Planetoid(root="./data/", name='PubMed').data
print(data_train)
Data(x=[19717, 500], edge_index=[2, 88648], y=[19717], train_mask=[19717], val_mask=[19717], test_mask=[19717])
You can see that the data object consists of several properties. Here is a breakdown of what they mean:
- `x`: the node features, of shape `[number_of_nodes, number_of_features]`.
- `edge_index`: an edge list of shape `[2, number_of_edges]` which defines the connections `from_node` and `to_node`. Nodes are represented by the index at which you can find the node's corresponding feature values in `x`.
- `y`: the labels, of shape `[number_of_nodes]`, taking values `[0, 1, 2]`.
- `train_mask`, `val_mask`, `test_mask`: `boolean` tensors of shape `[number_of_nodes]` indicating whether the node at the corresponding index in `x` should be used for training/validation/testing. The sum of the elements in each mask equals 60, 500 and 1000 respectively.
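As a quick sanity check (a snippet of my own, not from the original walkthrough), we can confirm the split sizes quoted above:
# the boolean masks should sum to the documented split sizes
print(data_train.train_mask.sum().item())  # 60
print(data_train.val_mask.sum().item())    # 500
print(data_train.test_mask.sum().item())   # 1000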
The Model
To explain a GNN we obviously need a trained model. We can use the code below to define a simple two-layer Graph Attention Network. I have set the hyperparameters according to this paper.
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv

class GAT(torch.nn.Module):
    def __init__(self, input_dim):
        super(GAT, self).__init__()
        # fix hparams according to paper
        self.n_heads = 8
        self.dropout = 0.6
        self.hidden_dim = 8
        self.num_classes = 3
        self.conv1 = GATConv(input_dim, self.hidden_dim, heads=self.n_heads)
        self.conv2 = GATConv(
            self.n_heads * self.hidden_dim, self.num_classes, heads=1
        )

    def forward(self, x, edge_index, adj=None):
        # Layer 1
        x = self.conv1(x, edge_index)
        x = F.dropout(F.elu(x), p=self.dropout, training=self.training)
        # Layer 2
        x = self.conv2(x, edge_index)
        x = F.dropout(F.elu(x), p=self.dropout, training=self.training)
        # return probabilities (the explainer later expects return_type='probs')
        return F.softmax(x, dim=1)
We can now train it using the below code. You can also replicate the training through the notebook I provided.
First, we define what happens in a single train-validation iteration.
def train(model, data, optimizer):
    # initialise
    model.train()
    optimizer.zero_grad()
    # optimize
    out = model(data.x, data.edge_index)
    # nll_loss expects log-probabilities; the model outputs probabilities
    loss = F.nll_loss(out[data.train_mask].log(), data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()

@torch.no_grad()  # no gradients needed for validation
def val(model, data):
    # compute loss
    model.eval()
    out = model(data.x, data.edge_index)
    loss = F.nll_loss(out[data.val_mask].log(), data.y[data.val_mask])
    # compute metrics
    pred = out.argmax(dim=1)
    acc = (pred[data.val_mask] == data.y[data.val_mask]).sum().item() / data.val_mask.sum().item()
    return acc, loss.item()
We then set up the training by defining the hyperparameters, model and optimizer:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# initialise model
num_node_features = data_train['x'].shape[1]
model = GAT(num_node_features)
# Push to GPU if available
model.to(device)
data_train = data_train.to(device)
# initialise optimizer for training
LEARNING_RATE = 0.01
WEIGHT_DECAY = 1E-5
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY
)
And now, we run through the training loop.
train_loss_vals = []
val_loss_vals = []
acc_vals = []
n_epochs = 100
for epoch in range(n_epochs):
    # train val iteration
    train_loss = train(model, data_train, optimizer)
    acc, val_loss = val(model, data_train)
    # keep track of metrics
    train_loss_vals.append(train_loss)
    val_loss_vals.append(val_loss)
    acc_vals.append(acc)
    print(f'Epoch: {epoch:03d}, Loss: {val_loss:.4f}, Accuracy: {acc:.4f}')
Below are the loss curves and accuracy plots:


For the purposes of this article, we just need a model that is good enough so we don’t bother with any further hyperparameter optimisation.
How to Explain the Model
1. Initialise the Explainer object
`torch_geometric` provides multiple explanation algorithms, and for this article I will use the `GNNExplainer`. You can find all the explanation algorithms supported by `torch-geometric` and their use cases here.
from torch_geometric.explain import Explainer, GNNExplainer
explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=200),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='node',
        return_type='probs',
    ),
)
- We use `node_mask_type='attributes'` to get a feature-level breakdown of importance values (i.e. importance values for each feature on each node).
- `edge_mask_type='object'` provides us with edge-level importance values (i.e. an importance value for each edge). If we had edge-level features, we could change this to `'attributes'` to get feature-level importance values.
2. Sample nodes and their local neighbourhood
We use the `NeighborLoader` class to sample a node together with the local neighbourhood that we will explain.
from torch_geometric.loader import NeighborLoader
loader = NeighborLoader(
    data=data_train,
    num_neighbors=[10, 10],
    input_nodes=data_train.test_mask, # sample from the test set only
)
The `num_neighbors` parameter specifies how many neighbours to sample at each hop. So `num_neighbors=[10, 10]` means we randomly sample two hops deep into the network, taking up to 10 neighbours in the first hop and up to 10 neighbours of each of those in the second hop.
![Image by author. An example of neighbour sampling when _num_neighbors = [3, 3]_ on a graph where the source node (blue) has 5 nodes in the first hop, and successive nodes have between 1 and 4 neighbours. In each hop, up to three neighbours are randomly sampled.](https://towardsdatascience.com/wp-content/uploads/2024/01/1OhBtl_n8fgkMlywno2G_-A.png)
As we want to plot something interesting, we will loop over the `loader` to find an example sub-graph that has at least 40 nodes.
for batch in loader:
    if batch['x'].shape[0] > 40:
        break
print(batch)
Data(x=[97, 500], edge_index=[2, 110], y=[97], train_mask=[97], val_mask=[97], test_mask=[97], n_id=[97], e_id=[110], input_id=[1], batch_size=1)
We have ended up with a two-hop sub-graph consisting of 97 nodes and 110 edges.
3. Run the Explanation Algorithm
We can now run the explanation algorithm on our sampled graph, where we explain the node at index 0:
explanation = explainer(batch.x, batch.edge_index, index=0)
print(explanation)
Explanation(node_mask=[97, 500], edge_mask=[110], prediction=[97, 3], target=[97], index=[1], x=[97, 500], edge_index=[2, 110])
The Explanation Object
An `Explanation` object is composed of the following attributes (for brevity, I've excluded any we don't need to consider for this article):
Explanation(
    node_mask=[97, 500],  # Node importance mask, shape [#nodes, #features]
    edge_mask=[110],      # Edge importance mask, shape [#edges]
    prediction=[97, 3],   # The predicted prob values, shape [#nodes, #classes]
    target=[97],          # The class predictions by the model, shape [#nodes]
    ...
    x=[97, 500],          # node features [#nodes, #features]
    edge_index=[2, 110]   # edge connections [2, #edges]
)
- Node mask: importance values for each feature of each node in the graph. This is the result of us specifying `node_mask_type='attributes'` when initialising the `Explainer` object. The sum of importance values across all features equals the node-level importance.
- Edge mask: importance values for each edge, where we set `edge_mask_type='object'`. If our dataset had edge features then, as we did for nodes, we could set `edge_mask_type='attributes'` to get feature-level importances.
- Prediction: the output of the model. In our case, the model outputs probability values.
- Target: the class predicted by the model. Note, this is not the label.
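To make the first point concrete, here is a small illustrative snippet (my own, not from the original walkthrough) showing how node-level importance falls out of `node_mask`:
# summing per-feature importances gives each node's overall importance
node_level_importance = explanation.node_mask.sum(dim=1)  # shape [#nodes]
print(node_level_importance.shape)  # torch.Size([97])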
Visualising our Explanation
Given a model, a network, and an explanation, we can start to pull everything together into an interactive visualisation.
Step 1: Convert the Explanation into a NetworkX Graph object
We firstly translate our explanation object into a `networkx` graph object, as the latter is compatible with our visualisation package, `gravis`. The full code to do this is provided below.
from torch_geometric.utils import to_networkx
# store labels into the explanation object - we want to use it in our vis.
# store it as 'label' and not 'y' - gravis expects 'y' to store coords
explanation['label'] = batch['y']
# rename 'x' as gravis expects 'x' to store coords
explanation['node_features'] = batch['x']
# convert the explanation into a networkx graph object.
g = to_networkx(
    explanation,
    # these attributes are unique to the graph, so are graph attributes
    graph_attrs=['node_mask', 'edge_mask', 'target', 'label'],
    # the node features are unique for each node, so are node attributes.
    node_attrs=['node_features']
)
Let’s break this down.
Firstly, the `explanation` object has everything we need for our visualisation except the labels `y`.
The labels will be useful to have as part of our node-level display so we manually add this.
explanation['label'] = batch['y']
Note, we save the labels under the key `label` instead of `y`. Similarly, we rename the feature values `x` to `node_features`.
explanation['node_features'] = batch['x']
This is because `gravis` expects the keys `x` and `y` to be coordinates. If we don't do this, the nodes will be fixed in place in weird locations.
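As an aside, you can turn this convention to your advantage: setting numeric `x` and `y` node attributes pins nodes at those coordinates. A small hypothetical sketch (on a copy, so our working graph keeps its free-floating layout):
import networkx as nx
# pin nodes along a horizontal line by assigning explicit coordinates
g_pinned = g.copy()
nx.set_node_attributes(g_pinned, {n: float(i * 40) for i, n in enumerate(g_pinned.nodes)}, name='x')
nx.set_node_attributes(g_pinned, {n: 0.0 for n in g_pinned.nodes}, name='y')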
Finally, we use the `to_networkx` function to translate `explanation` into a `networkx.DiGraph` instance whilst specifying the attributes we want to retain during the conversion.
# convert the explanation into a networkx graph object.
g = to_networkx(
    explanation,
    # these attributes are unique to the graph, so are graph attributes
    graph_attrs=['node_mask', 'edge_mask', 'target', 'label'],
    # the node features are unique for each node, so are node attributes.
    node_attrs=['node_features']
)
- `node_mask` and `edge_mask` are required for determining colours and thicknesses for nodes and edges.
- `node_features` and `label` are useful information that we wish to display when the user clicks on a node.
- We need `target` to set node shapes based on the model prediction. Remember, `target` stores the predicted class, not the label.
So far, plotting this in `gravis` will give us a plot like the one below:
import gravis as gv
gv.d3(g, graph_height=700)

Step 2: Add Node and Edge Colours
Next, let's colour the nodes and edges according to their importance. We can use the helper functions below to extract this information from `node_mask` and `edge_mask`.
import numpy as np
import networkx as nx

def set_node_level_importance(g):
    '''Get node level importance'''
    # sum over each feature importance per node to get the overall node importance
    node_level_importance = np.array(g.graph['node_mask']).sum(axis=1)
    # assign the importance value to each node as an attribute
    node_level_importance_dict = {i: node_level_importance[i] for i in g.nodes}
    nx.set_node_attributes(g, node_level_importance_dict, name="importance")

def set_edge_level_importance(g):
    '''Get edge level importance'''
    edge_level_importance = g.graph['edge_mask']
    # assign the importance value to each edge as an attribute
    edge_level_importance_dict = {edge: edge_level_importance[i] for i, edge in enumerate(g.edges)}
    nx.set_edge_attributes(g, edge_level_importance_dict, name="importance")
For nodes, we have a tensor of shape `[#nodes, #features]`, so we sum over the individual feature values for each node to get node-level importances.
node_level_importance = np.array(g.graph['node_mask']).sum(axis=1)
Meanwhile, `edge_mask` is already at the edge level (shape `[#edges]`) so we don't need to sum anything.
edge_level_importance = g.graph['edge_mask']
We then want to map the importance values onto a colormap. In this instance we use the `winter` colourmap from `matplotlib`: green indicates a positive contribution to the final predicted class, while blue indicates a negative contribution.

To do this, we use this helpful utility class to convert numeric values into an `RGB` string:
from matplotlib import cm
import matplotlib.pyplot as plt
import matplotlib as mpl
class MplColorHelper:
    def __init__(self, cmap_name, start_val, stop_val):
        self.cmap_name = cmap_name
        self.cmap = plt.get_cmap(cmap_name)
        self.norm = mpl.colors.Normalize(vmin=start_val, vmax=stop_val)
        self.scalarMap = cm.ScalarMappable(norm=self.norm, cmap=self.cmap)

    def get_rgba(self, val):
        return self.scalarMap.to_rgba(val, bytes=True)

    def get_rgb_str(self, val):
        r, g, b, a = self.get_rgba(val)
        return f"rgb({r},{g},{b})"
All we need to do is specify the min-max values over which to scale the colormap, then compute the `RGB` string and assign it to each node.
def set_node_colors(g, cmap='winter'):
    '''Set colours based on importance values'''
    # scale our colourmap to be between the min-max importance
    vals = nx.get_node_attributes(g, 'importance').values()
    min_val, max_val = min(vals), max(vals)
    # initialise colour helper
    node_color_generator = MplColorHelper(cmap, min_val, max_val)
    # get rgb string for each node
    node_colors = {
        node: node_color_generator.get_rgb_str(data['importance']) for node, data in g.nodes(data=True)
    }
    nx.set_node_attributes(g, node_colors, name='color')
The same applies for edges.
def set_edge_colors(g, cmap='winter'):
    '''Set colours based on importance values'''
    # scale our colourmap to be between the min-max importance
    vals = nx.get_edge_attributes(g, 'importance').values()
    min_val, max_val = min(vals), max(vals)
    # initialise colour helper
    edge_color_generator = MplColorHelper(cmap, min_val, max_val)
    # get rgb string for each edge
    edge_colors = {
        (u, v): edge_color_generator.get_rgb_str(data['importance']) for u, v, data in g.edges(data=True)
    }
    nx.set_edge_attributes(g, edge_colors, name='color')
We now plot our graph; `gravis` deduces the colour automatically by looking for the `color` attribute. We can also set the node and edge sizes according to importance using the arguments `node_size_data_source` and `edge_size_data_source`.
set_node_level_importance(g)
set_node_colors(g)
set_edge_level_importance(g)
set_edge_colors(g)
gv.d3(
    g,
    graph_height=700,
    # we now have importance values to use to size nodes
    node_size_data_source='importance',
    use_node_size_normalization=True,
    node_size_normalization_min=15,
    node_size_normalization_max=35,
    # we also have importance values to use to size edges
    edge_size_data_source='importance',
    use_edge_size_normalization=True,
    edge_size_normalization_min=1,
    edge_size_normalization_max=5,
    # we tweak the visualisation parameters to prevent nodes
    # from overlapping too much.
    use_collision_force=True,
    collision_force_radius=50,
    collision_force_strength=1
)
And voilà! We have now coloured and sized our nodes and edges according to their importance.

Step 3: Add Node Shapes According to the Model Prediction
We want to add an extra layer of information to our graph by using different node shapes to represent the model prediction. We do this by making use of the class predictions stored in `target`:
def set_node_shapes(g, class_to_label_map: dict):
    target = g.graph['target']
    for i, node in enumerate(g.nodes()):
        # int() handles both list entries and tensor scalars
        g.nodes[node]['shape'] = class_to_label_map[int(target[i])]
For each node, we set a node attribute called `shape` according to the `target` value; the desired shape per class is provided through `class_to_label_map`.
# node shapes according to their predicted class
class_to_label_map = {
    0: 'dot',
    1: 'rectangle',
    2: 'hexagon',
}
set_node_level_importance(g)
set_node_colors(g)
set_node_shapes(g, class_to_label_map)
set_edge_level_importance(g)
set_edge_colors(g)
# this remains the same as before
gv.d3(
    g,
    graph_height=700,
    node_size_data_source='importance',
    use_node_size_normalization=True,
    node_size_normalization_min=15,
    node_size_normalization_max=35,
    edge_size_data_source='importance',
    use_edge_size_normalization=True,
    edge_size_normalization_min=1,
    edge_size_normalization_max=5,
    use_collision_force=False,
    collision_force_radius=50,
    collision_force_strength=1,
)

And voila, we can immediately see the predicted class for every node!
We can see that Node 0 is predicted as class 2 (Hexagon), and it is coloured bright green whilst its immediate neighbours are closer to blue.
This seems to suggest that the publication's own contents were one of the main factors that led to the prediction.
This makes us wonder what the Node 0 feature importances look like – wouldn’t it be great to see this breakdown? Let’s see how we can do this in the next section.
Step 4: Display Importance Information When Clicking on Nodes and Edges
`gravis` provides an information panel at the bottom of the display, which by default only shows us the node label.

We can do better than this. For nodes, let’s populate this display with:
- The label
- Node-level importance
- Features and their importance
and for edges:
- The edge-level importance
For nodes, we have importance values per feature. We unpack these from `node_mask` and save them into each node under the attribute `feature_importance`:
def set_feature_level_importance_per_node(g):
    feature_level_importance = g.graph['node_mask']
    feature_level_importance_per_node = {
        node: feature_level_importance[node] for node in g.nodes
    }
    nx.set_node_attributes(g, feature_level_importance_per_node, name='feature_importance')
Next, the actual contents of the info panel are set by adding a `click` attribute to each node and edge, which is constructed as a `str`.
To keep things neat, we declare a class called `ClickComponents`.
class ClickComponents():
    HEADER_COMPONENT = "Importance Values:"
    LABEL_COMPONENT = "Label: {}"
    NEWLINE = "\n"
    NODE_HEADER = LABEL_COMPONENT + NEWLINE + HEADER_COMPONENT + NEWLINE
    EDGE_HEADER = HEADER_COMPONENT + NEWLINE
    NODE_IMPORTANCE_ROW = "Node Importance: {}" + NEWLINE
    EDGE_IMPORTANCE_ROW = "Edge Importance: {}" + NEWLINE
    # one row per feature, e.g. "Feature 42: 0.0123"
    # (exact wording is my reconstruction; the row is used below)
    FEATURE_IMPORTANCE_ROW = "{}: {}" + NEWLINE
and we use this to construct our `str` as such:

def set_node_click_display(g, feature_names: list[str]):
    for _, data in g.nodes(data=True):
        # initialise click information as empty string.
        data['click'] = ""
        # add label information (first two lines)
        data['click'] += ClickComponents.NODE_HEADER.format(data['label'])
        # set header + node importance value (third line)
        node_attribution = data['importance']
        data['click'] += ClickComponents.NODE_IMPORTANCE_ROW.format(node_attribution)
        # set feature level importance values
        # (remaining lines, one for each feature)
        vals = data['feature_importance']
        idx = np.argsort(vals)  # default ascending order
        descending_order_idx = idx[::-1]  # change to descending order
        ordered_vals = np.array(vals)[descending_order_idx]
        ordered_names = np.array(feature_names)[descending_order_idx]
        for name, val in zip(ordered_names, ordered_vals):
            data['click'] += ClickComponents.FEATURE_IMPORTANCE_ROW.format(name, val)
For edges, it is simply:
def set_edge_click_display(g):
    for _, _, data in g.edges(data=True):
        # set header + edge attribution value
        edge_attribution = data['importance']
        data['click'] = ClickComponents.EDGE_HEADER
        data['click'] += ClickComponents.EDGE_IMPORTANCE_ROW.format(edge_attribution)
        # No feature level importance values for edges.
Note that for nodes, we have a `feature_names` parameter. For the PubMed dataset, we do not know the original feature names or values, but for your own models you can easily retain this information during preprocessing and pass it in. Here, we just give each feature an arbitrary name.
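One caveat: the snippet below calls a helper, `set_node_label`, whose definition isn't shown in this article. Since the click display reads `data['label']` per node, the helper presumably copies the graph-level labels onto each node. A minimal sketch, assuming `g.graph['label']` is index-aligned with `g.nodes`:
def set_node_label(g):
    '''Copy graph-level labels onto each node so the click display can
    read data['label'] (sketch; the original helper is not shown).'''
    labels = g.graph['label']
    label_dict = {node: int(labels[i]) for i, node in enumerate(g.nodes)}
    nx.set_node_attributes(g, label_dict, name='label')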
import gravis as gv
set_node_level_importance(g)
set_node_colors(g)
set_node_shapes(g, class_to_label_map)
set_feature_level_importance_per_node(g)
set_node_label(g)
# add made up feature names
set_node_click_display(g, [f"Feature {i}" for i in range(500)])
set_edge_level_importance(g)
set_edge_colors(g)
set_edge_click_display(g)
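The render call itself is unchanged; for completeness, we re-plot with the same `gv.d3` invocation used in the previous step:
gv.d3(
    g,
    graph_height=700,
    node_size_data_source='importance',
    use_node_size_normalization=True,
    node_size_normalization_min=15,
    node_size_normalization_max=35,
    edge_size_data_source='importance',
    use_edge_size_normalization=True,
    edge_size_normalization_min=1,
    edge_size_normalization_max=5,
)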
Now, when we click on node 0, the display looks like this:

Remember, our node features were word vectors, so each feature here corresponds to some word used in the publication that is Node 0. From this display we would have been able to see which words the model thought were important.
Alas, we do not have this information at hand, but I hope you get the gist of what we can do and how useful this would be.
Well done – You have a fully interactive network graph visualisation

I will leave it here and let you explore the features that `gravis` offers yourself. I have also covered `gravis` in a dedicated post here if you want more details.
Some pointers on things you can take a look at:
Explore the settings bar on the right.
- You can use the `Data selection` tab to change the text displayed on each node. By default it displays the node ID, but this can be changed to any value that we added to the `networkx` graph object, like the label or importance values.
- You can change the graph layout algorithm parameters through the `layout algorithm` tab. Handy for more complicated graphs.
Modify the code
- Try and modify the code above to plot other types of graphs, such as heterogeneous graphs. You can find code to help you with this in the article above.
- Use a different explanation algorithm or model, or utilise the code to display something different, for example, knowledge graphs.

If you liked my article, please give it a clap and share it with anyone else who might be interested.
If there is anything you read here that you want me to dive into more detail in a different article, do leave a comment with some suggestions!
References
[1] Galileo Mark Namata, Ben London, Lise Getoor, and Bert Huang. Query-driven active surveying for collective classification. In International Workshop on Mining and Learning with Graphs (MLG), Edinburgh, Scotland, 2012.
[2] Courtesy of the National Library of Medicine. https://www.nlm.nih.gov/databases/download/terms_and_conditions.html