
An Interactive Visualisation for your Graph Neural Network Explanations

A step-by-step guide on how to build one in five easy steps, with code already written for you.


Image by author. PubMed dataset network composed of 530 nodes and 778 edges, drawn using gravis

Introduction

I use Graph Neural Networks in my day-to-day job, and I have wasted many days due to the lack of a decent network visualisation tool when trying to explain and review the outputs of a newly trained model.

So this has motivated me to write this article, where I provide a step-by-step guide on how to create a fully interactive network visualisation of a GNN model explanation, in Python, WITHOUT having to rely on expensive third-party solutions.

We go from static plots like this:

Image by author. The default plot generated using torch-geometric. The graph is static, layout is unclear as the graph becomes complex, and node/edge importances are inconspicuous.

To interactive and insightful visualisations like this:

Image by author. Interactive plot generated using gravis. Nodes can be drag-and-dropped, and clicking on nodes can display their features and their importance. Green indicates higher importance, blue indicates lower importance.

This visualisation will differ from existing methods (e.g. matplotlib, networkx) as it will allow us to:

  • Drag and drop nodes.
  • Click on both nodes and edges to reveal feature importance values.
  • Colour nodes and edges according to their importance.
  • Apply different shapes to nodes according to the model prediction.
  • Change the display settings dynamically using a menu-panel.

Whilst being:

  • convenient to implement.
  • compatible with networkx, a commonly used network analysis package in Python.
  • reasonably scalable with the size of the graph (we’re not building something for production, but it needs to handle at least hundreds of nodes).

We will cover the following:

  • Installation and setup.
  • The example node classification task, the dataset, and the model we will use in this article.
  • A quick intro to explaining a GNN using torch-geometric, and what the explanation object looks like.
  • A step-by-step guide to creating the visualisation.

This article will not explain the fundamentals of GNNs or the explanation algorithms used here, as there are many articles and videos out there that already do this.

The models and algorithms I chose are arbitrary and can be substituted with those of your choice.

Also, I have chosen examples that often appear in tutorials, so that people who are familiar with them can go straight to the visualisation part and save some time.

You can also find the Jupyter Notebook that accompanies this article here:

GitHub – bl3e967/medium-gnn-explanation-and-visualisation

…which you can have open side-by-side with this article whilst reading.

So, enough with the intro, let’s get started.

Installation and Set Up

Packages

To get us started, we first need to install our dependencies.

We will be using torch and torch-geometric to train our GNN.

We also install some additional torch-related dependencies required by helper functions we will use later.

pip install torch torch_geometric
pip install torch-sparse torch-scatter -f https://data.pyg.org/whl/torch-{your-torch-version}+cpu.html

For visualisation, we use gravis, my favourite visualisation package:

pip install gravis

For more info on what this package has to offer, I have written about it in detail here:

The New Best Python Package for Visualising Network Graphs

And finally some additional dependencies:

pip install numpy scikit-learn matplotlib networkx

The Task, the Dataset, and the Model

Before we start visualising, we need (1) an example task and (2) a model to explain.

Multi-class Node Classification Task

We will be using the PubMed Dataset¹ (CC BY-SA), a citation network of scientific publications from the PubMed² database relating to diabetes.

Each publication belongs to one of three anonymised categories, and our model is tasked with predicting the correct class for each publication.

Dataset Description

I provide the description from the PubMed dataset source:

The dataset consists of a single large network, composed of 19.7K nodes and 88.7K edges. Each node has a TF/IDF weighted word vector from a dictionary consisting of 500 unique words.

So each node is a publication, an edge links two publications where one cites the other, and the features represent words.

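In terms of labelled examples, this is quite a small dataset: although there are 19.7K nodes, only 60 are designated for training, 500 for validation and 1000 for testing. The remaining nodes are not used to compute the loss or the evaluation metrics, although they still take part in message passing as neighbours of the labelled nodes.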

Download and Preprocess the Data

We will use the torch-geometric dataset implementation which downloads and preprocesses the data for you in a single line of code.

from torch_geometric.datasets import Planetoid

dataset = Planetoid(root="./data/", name='PubMed')
data_train = dataset[0]  # PubMed consists of a single graph

print(data_train)
Data(x=[19717, 500], edge_index=[2, 88648], y=[19717], train_mask=[19717], val_mask=[19717], test_mask=[19717])

You can see that the data object consists of several properties. Here is a breakdown of what they mean:

  • x : The node features, of shape [number_of_nodes, number_of_features]
  • edge_index : an edge list of shape [2, number_of_edges]. The first row holds the source node (from_node) of each edge and the second row the target node (to_node). Nodes are referred to by their row index in x, where their feature values are stored.
  • y : The labels, shape [number_of_nodes] where it can take values [0, 1, 2]
  • train_mask , val_mask, test_mask: a boolean tensor of shape [number_of_nodes] indicating whether the node in the corresponding index in x should be used for training/validation/testing. The sum of the elements in each mask equals 60, 500 and 1000 respectively.
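To make the edge_index and mask conventions concrete, here is a minimal sketch on a hand-made four-node graph. The toy arrays are made up for illustration, and NumPy arrays stand in for the torch tensors; incoming_neighbours is a hypothetical helper, not part of torch-geometric.

```python
import numpy as np

# Toy graph: 4 nodes, 3 undirected citations stored as 6 directed edges.
# Row 0 holds the source node of each edge, row 1 the target node.
edge_index = np.array([
    [0, 1, 1, 2, 2, 3],
    [1, 0, 2, 1, 3, 2],
])

# Boolean masks select which nodes belong to each split.
train_mask = np.array([True, False, False, False])
val_mask = np.array([False, True, False, False])
test_mask = np.array([False, False, True, True])

def incoming_neighbours(node, edge_index):
    """Return the source nodes of all edges that point at `node`."""
    sources, targets = edge_index
    return sorted(sources[targets == node].tolist())

print(incoming_neighbours(2, edge_index))
print(int(train_mask.sum()), int(val_mask.sum()), int(test_mask.sum()))
```

Summing a mask counts the nodes in that split, which is how the 60/500/1000 figures above can be checked on the real data.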

The Model

To explain a GNN we obviously need a trained model. We can use the below code to define a simple two-layer Graph Attention Network. I have set the hyperparameters according to this paper.

import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv

class GAT(torch.nn.Module):
    def __init__(self, input_dim):
        super().__init__()

        # fix hparams according to paper
        self.n_heads = 8
        self.dropout = 0.6
        self.hidden_dim = 8
        self.num_classes = 3

        # layer 1 concatenates its 8 heads, so layer 2 receives n_heads * hidden_dim features
        self.conv1 = GATConv(input_dim, self.hidden_dim, heads=self.n_heads)
        self.conv2 = GATConv(
            self.n_heads * self.hidden_dim, self.num_classes, heads=1
        )

    def forward(self, x, edge_index, adj=None):
        # Layer 1
        x = self.conv1(x, edge_index)
        x = F.dropout(F.elu(x), p=self.dropout, training=self.training)

        # Layer 2
        x = self.conv2(x, edge_index)
        x = F.dropout(F.elu(x), p=self.dropout, training=self.training)

        # return class probabilities (the Explainer below uses return_type='probs')
        return F.softmax(x, dim=1)

We can now train it using the below code. You can also replicate the training through the notebook I provided.

First, we define what happens in a single train-validation iteration.

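import torch
import torch.nn.functional as F

def train(model, data, optimizer):
    # initialise
    model.train()
    optimizer.zero_grad()

    # optimise: the model outputs probabilities, but F.nll_loss expects
    # log-probabilities, so take the log (clamped to avoid log(0))
    out = model(data.x, data.edge_index)
    log_probs = out.clamp_min(1e-12).log()
    loss = F.nll_loss(log_probs[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

    return loss.item()

@torch.no_grad()
def val(model, data):
    # compute loss (again on log-probabilities)
    model.eval()
    out = model(data.x, data.edge_index)
    log_probs = out.clamp_min(1e-12).log()
    loss = F.nll_loss(log_probs[data.val_mask], data.y[data.val_mask])

    # compute metrics
    pred = out.argmax(dim=1)
    acc = (pred[data.val_mask] == data.y[data.val_mask]).sum().item() / data.val_mask.sum().item()

    return acc, loss.item()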
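The loss and accuracy computed in train and val can be reproduced by hand. The NumPy sketch below, on made-up probabilities, mirrors F.nll_loss with its default mean reduction and the masked accuracy in val. It is an illustration, not the PyTorch implementation, and it shows why the model's probabilities should be logged first: F.nll_loss simply negates and averages whatever values sit at the target indices.

```python
import numpy as np

def nll_loss(log_probs, targets):
    """Mean negative log-likelihood, mirroring F.nll_loss(reduction='mean')."""
    return float(-log_probs[np.arange(len(targets)), targets].mean())

def masked_accuracy(probs, labels, mask):
    """Fraction of masked nodes whose arg-max prediction equals the label."""
    pred = probs.argmax(axis=1)
    return float((pred[mask] == labels[mask]).sum() / mask.sum())

# Made-up softmax outputs for three nodes and three classes.
probs = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.3, 0.3, 0.4],
])
labels = np.array([0, 1, 0])
mask = np.array([True, True, True])

print(nll_loss(np.log(probs), labels))  # cross-entropy on log-probabilities
print(nll_loss(probs, labels))          # what you get if raw probabilities are passed
print(masked_accuracy(probs, labels, mask))
```

Passing raw probabilities yields a negative "loss" bounded in [-1, 0], which is a quick way to spot the mistake in a loss curve.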

We then set up the training by defining the hyperparameters, model and optimizer:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# initialise model
num_node_features = data_train['x'].shape[1]
model = GAT(num_node_features)

# Push to GPU if available
model.to(device)
data_train = data_train.to(device)

# initialise optimizer for training
LEARNING_RATE = 0.01
WEIGHT_DECAY = 1E-5
optimizer = torch.optim.Adam(
    model.parameters(), 
    lr=LEARNING_RATE, 
    weight_decay=WEIGHT_DECAY
)

And now, we run through the training loop.

train_loss_vals = []
val_loss_vals = []
acc_vals = []

n_epochs = 100
for epoch in range(n_epochs):
    # train val iteration
    train_loss = train(model, data_train, optimizer)
    acc, val_loss = val(model, data_train)

    # keep track of metrics 
    train_loss_vals.append(train_loss)
    val_loss_vals.append(val_loss)
    acc_vals.append(acc)

    print(f'Epoch: {epoch:03d}, Loss: {val_loss:.4f}, Accuracy: {acc:.4f}')

Below are the loss curves and accuracy plots:

Image by author. Training and Validation Loss Curve
Image by author. Validation accuracy over training epochs

For the purposes of this article, we just need a model that is good enough, so we don’t bother with any further hyperparameter optimisation.

How to Explain the Model

  1. Initialise the Explainer object

torch_geometric provides multiple explanation algorithms, and for this article I will use the GNNExplainer. You can find all the explanation algorithms supported by torch-geometric and their use cases here.

from torch_geometric.explain import Explainer, GNNExplainer

explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=200),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='node',
        return_type='probs',
    ),
)
  • We use a node_mask_type = 'attributes' to get a feature-level breakdown of importance values (i.e. importance values for each feature on the node).
  • edge_mask_type = 'object' provides us with edge-level importance values (i.e. importance values for each edge). If we had edge-level features, we could change this to 'attributes' to get feature-level importance values.
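node_mask_type controls the shape of node_mask. The hypothetical helper below simply records the shapes that, as I read the torch_geometric Explainer documentation, correspond to each mask type, so you can sanity-check the mask you get back. It is a lookup table for illustration, not part of the library.

```python
def expected_mask_shape(mask_type, num_nodes, num_features):
    """Shape of Explanation.node_mask for each node_mask_type (hypothetical helper)."""
    shapes = {
        'object': (num_nodes, 1),                 # one value per node
        'common_attributes': (1, num_features),   # one value per feature, shared by all nodes
        'attributes': (num_nodes, num_features),  # one value per feature per node
    }
    return shapes[mask_type]

# With 'attributes', a 97-node subgraph with 500 features gives a [97, 500] mask.
print(expected_mask_shape('attributes', 97, 500))
```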
  2. Sample nodes and their local neighbourhood

We use the NeighborLoader class to sample the nodes we want to explain, together with their local neighbourhoods.

from torch_geometric.loader import NeighborLoader

loader = NeighborLoader(
    data=data_train,
    num_neighbors=[10, 10],
    input_nodes=data_train.test_mask, # sample from the test set only
)

The num_neighbors parameter specifies how many neighbours to sample per node at each hop. So num_neighbors=[10, 10] means we randomly sample two hops deep into the network: up to 10 neighbours of the seed node in the first hop, then up to 10 neighbours of each of those nodes in the second hop.
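To illustrate what the per-hop fan-out means, here is a simplified pure-Python sketch of fan-out neighbour sampling on an adjacency dictionary. It is not PyG's implementation of NeighborLoader (which works on edge_index, relabels nodes and is far more efficient); sample_two_hop and the toy graph are hypothetical.

```python
import random

def sample_two_hop(adjacency, seed, num_neighbors, rng=None):
    """Simplified sketch of fan-out neighbour sampling.

    adjacency: dict mapping each node to a list of its neighbours.
    num_neighbors: per-hop fan-outs, e.g. [10, 10]; each fan-out is applied
    to every node on the current frontier, not to the hop as a whole.
    """
    rng = rng or random.Random(0)
    sampled = {seed}
    frontier = [seed]
    edges = []
    for fanout in num_neighbors:
        next_frontier = []
        for node in frontier:
            neighbours = adjacency.get(node, [])
            chosen = rng.sample(neighbours, min(fanout, len(neighbours)))
            for nbr in chosen:
                edges.append((nbr, node))  # messages flow from neighbour to node
                if nbr not in sampled:
                    sampled.add(nbr)
                    next_frontier.append(nbr)
        frontier = next_frontier
    return sampled, edges

# A seed node with 5 first-hop neighbours, sampled with num_neighbors=[3, 3].
adj = {0: [1, 2, 3, 4, 5], 1: [0, 6], 2: [0, 7, 8], 3: [0], 4: [0, 9, 10, 11, 12], 5: [0]}
nodes, edges = sample_two_hop(adj, 0, [3, 3])
print(sorted(nodes), len(edges))
```

With a fan-out at least as large as every node's degree, the sketch returns the full two-hop neighbourhood.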

Image by author. An example of neighbour sampling when num_neighbors = [3, 3] on a graph where the source node (blue) has 5 nodes in the first hop, and successive nodes have between 1 and 4 neighbours. In each hop, up to three neighbours per node are randomly sampled.

As we want to plot something interesting, we will loop over the loader to find an example subgraph with more than 40 nodes.

for batch in loader:
    # stop at the first sampled subgraph with more than 40 nodes
    if batch['x'].shape[0] > 40:
        break
print(batch)
Data(x=[97, 500], edge_index=[2, 110], y=[97], train_mask=[97], val_mask=[97], test_mask=[97], n_id=[97], e_id=[110], input_id=[1], batch_size=1)

We have ended up with a two-hop subgraph consisting of 97 nodes and 110 edges.

  3. Run the Explanation Algorithm

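We can now run the explanation algorithm on our sampled subgraph. NeighborLoader places the seed nodes first, so the node at index 0 is the test node the subgraph was sampled around, and that is the node we explain: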

explanation = explainer(batch.x, batch.edge_index, index=0)
print(explanation)
Explanation(node_mask=[97, 500], edge_mask=[110], prediction=[97, 3], target=[97], index=[1], x=[97, 500], edge_index=[2, 110])

The Explanation Object

An Explanation object is composed of the following attributes (for brevity, I’ve excluded any we don’t need to consider for this article):

Explanation(
  node_mask=[97, 500], # Node importance mask, shape [#nodes, #features]
  edge_mask=[110],     # Edge importance mask, shape [#edges]
  prediction=[97, 3],  # The predicted prob values, shape [#nodes, #classes]
  target=[97],         # The class predictions by the model, shape [#nodes]
  ...
  x=[97, 500],         # node features [#nodes, #features]
  edge_index=[2, 110]  # edge connections [2, #edges]
)
  • Node mask: This contains importance values for each feature of each node in the graph, because we specified node_mask_type='attributes' when initialising the Explainer object. Summing a node’s importance values across all of its features gives what we will use as its node-level importance.
  • Edge mask: This contains importance values for each edge, as we set edge_mask_type='object'. If our dataset had edge features, we could set edge_mask_type='attributes' to get feature-level importances, as we did for nodes.
  • Prediction: The output of the model. In our case, our model provides probability values.
  • Target: The predicted class from the model. Note, this is not the label.
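The sketch below shows how these attributes relate, using random stand-in arrays rather than a real explanation. It assumes, as I understand PyG's behaviour for explanation_type='model' with multiclass classification, that target is the arg-max of prediction; the node-level importance is the per-node sum over features that this article uses.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-ins for explanation.node_mask ([#nodes, #features]) and
# explanation.prediction ([#nodes, #classes]); random values for illustration.
node_mask = rng.random((5, 8))
logits = rng.normal(size=(5, 3))
prediction = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)

# With return_type='probs' each row of prediction is a probability distribution.
assert np.allclose(prediction.sum(axis=1), 1.0)

# For a model explanation, target holds the predicted class, i.e. the arg-max.
target = prediction.argmax(axis=1)

# Node-level importance as used in this article: sum over the feature axis.
node_importance = node_mask.sum(axis=1)
print(target, node_importance.round(3))
```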

Visualising our Explanation

Given a model, a network, and an explanation, we can start to pull everything together into an interactive visualisation.

Step 1: Convert the Explanation into a NetworkX Graph object

We first translate our explanation object into a networkx graph object, as the latter is compatible with our visualisation package, gravis. The full code to do this is provided below.

from torch_geometric.utils import to_networkx

# store labels into the explanation object - we want to use it in our vis.
# store it as 'label' and not 'y' - gravis expects 'y' to store coords
explanation['label'] = batch['y']

# rename 'x' as gravis expects 'x' to store coords
explanation['node_features'] = batch['x']

# convert the explanation into a networkx graph object.
g = to_networkx(
  explanation, 

  # store these whole arrays as graph attributes; helper functions unpack them per node/edge later
  graph_attrs=['node_mask', 'edge_mask', 'target', 'label'], 

  # the node features are unique for each node, so are node attributes.
  node_attrs=['node_features']
)

Let’s break this down.

Firstly, the explanation object has everything for our visualisation except the labels y.

The labels will be useful to have as part of our node-level display so we manually add this.

explanation['label'] = batch['y']

Note that we save the labels under the key label instead of y. Similarly, we store the feature values x under the key node_features.

explanation['node_features'] = batch['x']

This is because gravis expects the keys x and y to be coordinates. If we don’t do this, the nodes will be fixed in place in weird locations.

Finally, we use the to_networkx function to translate explanation into a networkx.DiGraph instance whilst specifying the attributes we want to retain during the conversion.

# convert the explanation into a networkx graph object.
g = to_networkx(
  explanation, 

  # store these whole arrays as graph attributes; helper functions unpack them per node/edge later
  graph_attrs=['node_mask', 'edge_mask', 'target', 'label'], 

  # the node features are unique for each node, so are node attributes.
  node_attrs=['node_features']
)
  • node_mask and edge_mask are required for determining colours and thicknesses for nodes and edges.
  • node_features and label are useful information that we wish to display when the user clicks on a node.
  • We need the target to set node shapes based on the model prediction. Remember, target stores the predicted class, not the label.

So far, plotting this in gravis will give us a plot like below:

import gravis as gv

gv.d3(g, graph_height=700)
Image by author. An initial plot of our network

Step 2: Add Node and Edge Colours

Next, let’s colour the nodes and edges according to their importance. We can use the below helper functions to extract this information from node_mask and edge_mask.

import numpy as np
import networkx as nx

def set_node_level_importance(g):
    '''Get node level importance'''
    # sum over each feature importance per node to get the overall node importance
    node_level_importance = np.array(g.graph['node_mask']).sum(axis=1)

    # assign the importance value to each node as an attribute
    node_level_importance_dict = { i : node_level_importance[i] for i in g.nodes }
    nx.set_node_attributes(g, node_level_importance_dict, name="importance")

def set_edge_level_importance(g):
    '''Get edge level importance'''
    edge_level_importance = g.graph['edge_mask']

    # assign the importance value to each edge as an attribute
    edge_level_importance_dict = { edge : edge_level_importance[i] for i, edge in enumerate(g.edges) }
    nx.set_edge_attributes(g, edge_level_importance_dict, name="importance")

For nodes, we have a tensor of shape [#nodes, #features], so we sum over the individual feature values for each node to get node level importances.

node_level_importance = np.array(g.graph['node_mask']).sum(axis=1)

Meanwhile, edge_mask is already on the edge level (shape = [#edges]) so we don’t need to sum anything.

edge_level_importance = g.graph['edge_mask'] 
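One caveat worth checking: networkx iterates a DiGraph’s edges grouped by source node (in node insertion order), which need not match the column order of edge_index, so pairing edge_mask with enumerate(g.edges) can attach an importance value to the wrong edge. One way to rule this out, sketched below as my own suggestion rather than the article’s code, is to key each importance by its (source, target) pair; in practice you might pass explanation.edge_index.tolist() and explanation.edge_mask.tolist(), and hand the result to nx.set_edge_attributes. edge_importance_by_pair is a hypothetical helper.

```python
def edge_importance_by_pair(edge_index, edge_mask):
    """Key each edge's importance by its (source, target) pair.

    edge_index: two lists [sources, targets], in the same column order as edge_mask.
    Returns a dict usable with nx.set_edge_attributes(g, ..., name='importance').
    """
    sources, targets = edge_index
    if not (len(sources) == len(targets) == len(edge_mask)):
        raise ValueError("edge_index and edge_mask must describe the same edges")
    return {(u, v): w for u, v, w in zip(sources, targets, edge_mask)}

# Toy example: edges listed by target node, so iterating them grouped by
# source node (as networkx does) would visit them in a different order.
edge_index = [[2, 0, 1], [0, 1, 2]]
edge_mask = [0.9, 0.1, 0.5]
print(edge_importance_by_pair(edge_index, edge_mask))
```

Duplicate (source, target) pairs, if any, collapse to a single entry, just as they do in a networkx DiGraph.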

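We then want to map the importance values onto a colourmap. In this instance we use matplotlib’s winter colourmap, scaled between the minimum and maximum importance in the subgraph: green marks the most important nodes and edges, and blue the least important. Because the scaling is relative, blue means less important than the rest of the subgraph rather than a negative contribution (GNNExplainer’s masks are non-negative).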

Image by author, generated using Matplotlib. Low values are mapped to blue, high values are mapped to green.

To do this, we use this helpful utility class to convert numeric values into an RGB string:

from matplotlib import cm
import matplotlib.pyplot as plt
import matplotlib as mpl

class MplColorHelper:
    def __init__(self, cmap_name, start_val, stop_val):
        self.cmap_name = cmap_name
        self.cmap = plt.get_cmap(cmap_name)
        self.norm = mpl.colors.Normalize(vmin=start_val, vmax=stop_val)
        self.scalarMap = cm.ScalarMappable(norm=self.norm, cmap=self.cmap)

    def get_rgba(self, val):
        return self.scalarMap.to_rgba(val, bytes=True)

    def get_rgb_str(self, val):
        r, g, b, a = self.get_rgba(val)
        return f"rgb({r},{g},{b})"
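To see roughly what MplColorHelper does under the hood, here is a dependency-free sketch of the same idea: min-max normalise a value, then interpolate linearly along the winter colourmap, which runs from pure blue to a blue-green. winter_rgb_str is a hypothetical helper and an approximation; matplotlib uses a lookup table and byte conversion, so its output can differ by a unit or so per channel.

```python
def winter_rgb_str(val, vmin, vmax):
    """Approximate matplotlib's 'winter' colourmap with linear interpolation.

    winter runs from blue (0, 0, 255) at the minimum to green (0, 255, 128)
    at the maximum: red stays 0, green rises, blue falls to half intensity.
    """
    # Min-max normalise to [0, 1]; a flat range maps everything to 0 (like Normalize).
    t = 0.0 if vmax == vmin else (val - vmin) / (vmax - vmin)
    t = min(max(t, 0.0), 1.0)  # clip values outside [vmin, vmax]
    r = 0
    g = round(255 * t)
    b = round(255 * (1.0 - 0.5 * t))
    return f"rgb({r},{g},{b})"

print(winter_rgb_str(0.2, 0.2, 0.9))  # least important value: blue
print(winter_rgb_str(0.9, 0.2, 0.9))  # most important value: green
```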

All we need to do is specify the min-max values over which to scale the colormap, then compute the RGB string, and assign it to each node.

def set_node_colors(g, cmap='winter'):
    '''Set colours based on importance values'''

    # scale our colourmap to be between the min-max importance
    vals = nx.get_node_attributes(g, 'importance').values()
    min_val, max_val = min(vals), max(vals)

    # initialise colour helper
    node_color_generator = MplColorHelper(cmap, min_val, max_val)

    # get rgb string for each node
    node_colors = { 
        node : node_color_generator.get_rgb_str(data['importance']) for node, data in g.nodes(data=True) 
    }
    nx.set_node_attributes(g, node_colors, name='color')

The same applies for edges.

def set_edge_colors(g, cmap='winter'):
    '''Set colours based on importance values'''
    # scale our colourmap to be between the min-max importance
    vals = nx.get_edge_attributes(g, 'importance').values()
    min_val, max_val = min(vals), max(vals)

    # initialise colour helper
    edge_color_generator = MplColorHelper(cmap, min_val, max_val)

    # get rgb string for each edge
    edge_colors = {
        (u, v) : edge_color_generator.get_rgb_str(data['importance']) for u, v, data in g.edges(data=True)
    }
    nx.set_edge_attributes(g, edge_colors, name='color')

We now plot our graph – gravis deduces the colour automatically by looking for the color attribute.

We can also set the node and edge size according to importance by using the arguments node_size_data_source and edge_size_data_source.

set_node_level_importance(g)
set_node_colors(g)

set_edge_level_importance(g)
set_edge_colors(g)

gv.d3(
    g, 

    graph_height=700,

    # we now have importance values to use to size nodes
    node_size_data_source='importance',
    use_node_size_normalization=True,
    node_size_normalization_min=15,
    node_size_normalization_max=35,

    # we also have importance values to use to size edges
    edge_size_data_source='importance',
    use_edge_size_normalization=True,
    edge_size_normalization_min=1,
    edge_size_normalization_max=5,

    # we tweak the visualisation parameters to prevent nodes
    # from overlapping too much. 
    use_collision_force=True,
    collision_force_radius=50,
    collision_force_strength=1
)

and voila! We have now coloured and sized our nodes and edges according to their importance.

Image by author. Larger, greener nodes are of higher importance compared to smaller blue nodes.

Step 3: Add Node Shapes According to the Model Prediction

We want to add an extra layer of information to our graphs by using different node shapes to represent the model prediction. We do this by making use of the class predictions stored in target:

def set_node_shapes(g, class_to_label_map:dict):
    target = g.graph['target']
    for i, node in enumerate(g.nodes()):
        g.nodes[node]['shape'] = class_to_label_map[target[i]]

For each node, we set a node attribute called shape given the target value; the desired shape per class is provided through class_to_label_map.

# node shapes according to their predicted class
class_to_label_map = {
    0 : 'dot',
    1 : 'rectangle',
    2 : 'hexagon',
}

set_node_level_importance(g)
set_node_colors(g)
set_node_shapes(g, class_to_label_map)

set_edge_level_importance(g)
set_edge_colors(g)

# this remains the same as before
gv.d3(
    g, 
    graph_height=700,
    node_size_data_source='importance',
    use_node_size_normalization=True,
    node_size_normalization_min=15,
    node_size_normalization_max=35,
    edge_size_data_source='importance',
    use_edge_size_normalization=True,
    edge_size_normalization_min=1,
    edge_size_normalization_max=5,
    use_collision_force=True,
    collision_force_radius=50,
    collision_force_strength=1,
)
Image by author. Our network now visualised with shapes. Circle for class 0, Rectangle for class 1, Hexagon for class 2.

And voila, we can immediately see the predicted class for every node!

We can see that Node 0 is predicted as class 2 (hexagon) and is coloured bright green, whilst its immediate neighbours are closer to blue.

This seems to suggest that the contents of the publication itself were one of the main factors that led to the prediction.
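If you would rather check that impression numerically than by eye, a small helper can compare a node’s importance with the mean importance of its direct neighbours. compare_with_neighbours is hypothetical and not part of the article’s code; in practice you would pass explanation.edge_index.tolist() and the node-level importances computed earlier.

```python
def compare_with_neighbours(node, edge_index, node_importance):
    """Return (node importance, mean importance of its direct neighbours or None)."""
    sources, targets = edge_index
    # neighbours in either direction, excluding self-loops
    neighbours = {u for u, v in zip(sources, targets) if v == node}
    neighbours |= {v for u, v in zip(sources, targets) if u == node}
    neighbours.discard(node)
    if not neighbours:
        return node_importance[node], None
    mean_nbr = sum(node_importance[n] for n in neighbours) / len(neighbours)
    return node_importance[node], mean_nbr

# Toy subgraph: node 0 is cited by nodes 1 and 2, which are cited by 3 and 4.
edge_index = [[1, 2, 3, 4], [0, 0, 1, 2]]
importance = [5.0, 1.0, 2.0, 0.5, 0.2]
print(compare_with_neighbours(0, edge_index, importance))
```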

This makes us wonder what the Node 0 feature importances look like – wouldn’t it be great to see this breakdown? Let’s see how we can do this in the next section.

Step 4: Display Importance Information When Clicking on Nodes and Edges

gravis provides an information panel at the bottom of the display, which by default only shows us the node label.

Image by author. The information panel displays the node label after clicking on node 0.

We can do better than this. For nodes, let’s populate this display with:

  • The label
  • Node level importance
  • Features and their importance

and for edges:

  • the edge level importance

For nodes, we have importance values per feature. We unpack this value from node_mask and save it into each node under the attribute feature_importance:

def set_feature_level_importance_per_node(g):
    feature_level_importance = g.graph['node_mask']
    feature_level_importance_per_node = {
        node : feature_level_importance[node] for node in g.nodes
    }
    nx.set_node_attributes(g, feature_level_importance_per_node, name='feature_importance')

Next, the contents of the info panel are set by adding a click attribute, constructed as a str, to each node and edge.

To keep things neat, we declare a class called ClickComponents.

class ClickComponents:
    HEADER_COMPONENT = "Importance Values:"
    LABEL_COMPONENT = "Label: {}"
    NEWLINE = "\n"

    NODE_HEADER = LABEL_COMPONENT + NEWLINE + HEADER_COMPONENT + NEWLINE
    EDGE_HEADER = HEADER_COMPONENT + NEWLINE

    # one row each for node- and edge-level importance, and one per feature ("name: value")
    NODE_IMPORTANCE_ROW = "Node Importance: {}" + NEWLINE
    EDGE_IMPORTANCE_ROW = "Edge Importance: {}" + NEWLINE
    FEATURE_IMPORTANCE_ROW = "{}: {}" + NEWLINE

and we use this to construct our str as such:

An illustration of how each ClickComponent attribute is used to construct the contents of the string.
def set_node_click_display(g, feature_names:list[str]):
    for _, data in g.nodes(data=True):
        # initialise click information as empty string. 
        data['click'] = ""

        # add label information (first two lines)
        data['click'] += ClickComponents.NODE_HEADER.format(data['label'])

        # set header + node importance value (third line)
        node_attribution = data['importance']
        data['click'] += ClickComponents.NODE_IMPORTANCE_ROW.format(node_attribution)

        # set feature level importance values 
        # (remaining lines, one for each feature)
        vals = data['feature_importance']
        idx = np.argsort(vals)              # default ascending order
        descending_order_idx = idx[::-1]    # change to descending order
        ordered_vals = np.array(vals)[descending_order_idx]
        ordered_names = np.array(feature_names)[descending_order_idx]

        for name, val in zip(ordered_names, ordered_vals):
            data['click'] += ClickComponents.FEATURE_IMPORTANCE_ROW.format(name, val)

For edges, it is simply:


def set_edge_click_display(g):
    for _, _, data in g.edges(data=True):
        # set header + edge attribution value
        edge_attribution = data['importance']
        data['click'] = ClickComponents.EDGE_HEADER
        data['click'] += ClickComponents.EDGE_IMPORTANCE_ROW.format(edge_attribution)

        # No feature level importance values for edges. 

Note that for nodes, we have a feature_names parameter.

For the PubMed dataset, we do not know the original feature names or values, but for your own models you can easily retain this information in your preprocessing and pass it in. Here, we just give them arbitrary names.

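import gravis as gv
import networkx as nx

def set_node_label(g):
    '''Copy each node's label from the graph-level list onto the node itself'''
    # a minimal version: set_node_click_display reads data['label'] per node
    # (the accompanying notebook's version may differ)
    labels = g.graph['label']
    nx.set_node_attributes(g, {node: labels[node] for node in g.nodes}, name='label')

set_node_level_importance(g)
set_node_colors(g)
set_node_shapes(g, class_to_label_map)
set_feature_level_importance_per_node(g)

set_node_label(g)
# add made up feature names
set_node_click_display(g, [f"Feature {i}" for i in range(500)])

set_edge_level_importance(g)
set_edge_colors(g)
set_edge_click_display(g)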

Now, when we click on node 0, the display looks like this:

Image by author. Feature importance breakdown for Node 0.

Remember, our node features were word vectors, so each feature here corresponds to a word used in the publication that is Node 0. With the original vocabulary, this display would show which words the model considered important.

Alas, we do not have this information at hand, but I hope you get the gist of what we can do and how useful this would be.
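With 500 features per node, the panel also gets long. If you do retain the vocabulary during preprocessing, a variant like the one below could replace the feature loop in set_node_click_display to show only the top k words. This is my own sketch: top_k_feature_rows and the five-word vocabulary are hypothetical stand-ins, not the real PubMed dictionary.

```python
import numpy as np

def top_k_feature_rows(importances, feature_names, k=10):
    """Format the k most important features as 'name: value' lines, highest first."""
    importances = np.asarray(importances)
    order = np.argsort(importances)[::-1][:k]  # indices in descending importance
    return "".join(f"{feature_names[i]}: {importances[i]:.4f}\n" for i in order)

# Hypothetical vocabulary standing in for the real TF/IDF dictionary.
vocab = ["insulin", "glucose", "mice", "retinopathy", "obesity"]
print(top_k_feature_rows([0.1, 0.7, 0.05, 0.3, 0.2], vocab, k=3))
```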

Well done – you have a fully interactive network graph visualisation

Image by author

I will leave it here and let you explore the features that gravis offers yourself. I have also covered gravis in a dedicated post here if you want more details.

The New Best Python Package for Visualising Network Graphs

Some pointers on things you can take a look at:

Explore the settings bar on the right.

  • You can use the Data selection tab to change the text on each node. By default it displays the node ID, but this can be changed to any value that we added to the networkx graph object, like label or importance values.
  • You can change the graph layout algorithm parameters through the layout algorithm tab. Handy for more complicated graphs.

Modify the code

  • Try modifying the code above to plot other types of graphs, such as heterogeneous graphs. You can find code to help you with this in the article above.
  • Use a different explanation algorithm or model, or utilise the code to display something different, for example, knowledge graphs.
Image by author. The same plot, now displaying importance values instead of node id.

If you liked my article, please give it a clap and share it with anyone else who might be interested.

If there is anything you read here that you want me to dive into more detail in a different article, do leave a comment with some suggestions!

References

[1] Galileo Mark Namata, Ben London, Lise Getoor, and Bert Huang. Query-driven active surveying for collective classification. In International Workshop on Mining and Learning with Graphs (MLG), Edinburgh, Scotland, 2012.

[2] Courtesy of the National Library of Medicine. https://www.nlm.nih.gov/databases/download/terms_and_conditions.html

