Intro to DeepMind’s Graph-Nets

A short overview of the core components of Graph-Nets

Kristof Neys
Towards Data Science



In October 2018, a team of 27 researchers from DeepMind/Google, MIT, and the University of Edinburgh published a paper entitled: “Relational inductive biases, deep learning, and graph networks”.

The paper is part ‘position paper’ and part practical implementation: it includes a library for building graph neural networks in TensorFlow and Sonnet (Sonnet is DeepMind’s neural-network library built on top of TensorFlow). To date the paper has received almost 1,000 citations, so it warrants some closer investigation.

In this note, Mark Needham and I will first summarize the key theoretical arguments which the paper sets out, and then illustrate the Graph-Nets library through the use of a toy example.

TLDR: Graph-Nets is DeepMind’s low-level Graph Neural Network model and library. It is flexible enough that almost any existing GNN can be implemented from 6 core functions, and it can be extended to temporal graphs.

In Theory…

The premise of the paper is as follows: for AI to achieve human-like abilities it must apply Combinatorial Generalisation (CG), and to realize this objective it must bias its learning towards structured representations and computations.

Combinatorial Generalisation is defined as:

“constructing new inferences, predictions, and behaviors from known building blocks”.

Informally: a small set of elements, such as words, can be productively composed in limitless ways, such as into sentences or a book. Humans have the ability to make “infinite use of finite means”.

How, then, can we improve AI’s capacity for CG?
We do this by the aforementioned “biasing learning towards structured representations and computations…and specifically systems that operate on graphs”. Or, put more succinctly: the paper’s thesis is that to advance AI we need to shift towards learning on graphs.

The paper formalizes this approach by defining ‘Relational reasoning’ and ‘Inductive bias’ as follows:

“relational reasoning is defined as manipulating structured representations of entities and relations using rules.”

Entities are elements with attributes (hence, nodes); a relation is a relationship between them; and a ‘rule’ is a function that maps entities and relations to other entities and relations.

The inductive bias (or learning bias) is the set of assumptions that a learning algorithm uses to predict outputs for inputs it has not encountered.

An example is k-nearest neighbours: the assumption (the bias), fixed at the outset, is that instances that are near each other tend to belong to the same class.


A “relational inductive bias”, then, is an inductive bias that imposes constraints (i.e. assumptions) on the relationships and interactions among entities in a learning process. An example is a Multi-Layer Perceptron (MLP), where the hierarchical processing is a type of relational inductive bias.

Right, that all sounds marvelous, you must be thinking by now, but how can I use this on actual graphs?

Well, the paper makes the complaint that there has been a lack of models with explicit representations of entities and relations, and of learning algorithms that find the rules for computing their interactions.
It then makes the point that graphs, in general, are a representation that supports arbitrary (pairwise) relational structure, and that computations over graphs afford a strong relational inductive bias, stronger than that of RNNs or CNNs.

As such the authors introduce the concept of “Graph Networks (GN)”, defined as:

“A GN framework defines a class of functions for Relational reasoning and Graph-structured representations”

In summary, the GN framework in effect generalizes and extends existing Graph Neural Networks. The key extension that the GN framework offers is its ability to process AND predict sequences of graphs, and thereby the trajectory of a dynamical system over time. In theory, even playing pool can be modeled:

A rigid body system such as a pool table — image extracted from Battaglia et al.

The main unit of computation in the GN framework is the “GN block”, a graph-to-graph module that takes a graph as input, performs computations over the structure, and returns a graph as output.

A GN block — image extracted from Battaglia et al.

In the code, the graphs that such a block consumes and produces are represented by the “GraphsTuple”. A GN block consists of 6 core functions:

  • 3 ‘update’ functions: one each for the node, edge and global attributes
  • 3 ‘aggregation’ functions: aggregating edge attributes per node, aggregating edge attributes globally, and aggregating node attributes globally (sketched in code below)
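
In pseudocode, a single GN block pass looks roughly as follows. This is a schematic sketch of Algorithm 1 in Battaglia et al., not the library’s actual implementation, and all names here are illustrative:

def gn_block(nodes, edges, senders, receivers, u,
             phi_e, phi_v, phi_u,        # the 3 'update' functions
             rho_ev, rho_eu, rho_vu):    # the 3 'aggregation' functions
    # 1. Update each edge from its features, its endpoint nodes, and the globals.
    edges = [phi_e(e, nodes[s], nodes[r], u)
             for e, s, r in zip(edges, senders, receivers)]
    # 2. Aggregate the updated incoming edges per node, then update each node.
    nodes = [phi_v(rho_ev([e for e, r in zip(edges, receivers) if r == i]), v, u)
             for i, v in enumerate(nodes)]
    # 3. Aggregate all edges and all nodes, then update the global attributes.
    u = phi_u(rho_eu(edges), rho_vu(nodes), u)
    return nodes, edges, u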

This approach allows for significant flexibility and as such the GN framework can be used to implement a wide variety of architectures.

In Practice…

The core building block that represents graph-structured data is the GraphsTuple class. This object represents batches of one or more graphs and all network modules take instances of GraphsTuple as input and return GraphsTuple as output.

The graphs are…

  • directed (one-way edges),
  • attributed (node-, edge-, and graph-level features are allowed),
  • multigraphs (multiple edges can connect any two nodes, and self-edges are allowed).

To create a GraphsTuple, and thereby a graph, we create a list of dictionaries with the keys globals, nodes and edges, which hold the respective float-valued feature vectors.

The edge list is represented by senders and receivers, each an integer-valued node index.

We can also load graphs without features by simply omitting keys from the dictionary. Equally, a simple set can be represented as a graph where the node features are the elements of the set; in that case, the dictionary contains only one key-value pair.

Example of how to create a graph in Graph-Nets:
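
A minimal sketch (the graph, its feature sizes, and its values are illustrative choices; utils_tf also has a NumPy twin, utils_np):

import numpy as np
from graph_nets import utils_tf

# One directed graph: 3 nodes and 2 edges, with arbitrary feature sizes.
data_dict = {
    "globals": np.array([0.0, 1.0, 2.0], dtype=np.float32),  # graph-level features
    "nodes": np.array([[1.0, 0.0],                           # one feature row per node
                       [0.0, 1.0],
                       [1.0, 1.0]], dtype=np.float32),
    "edges": np.array([[0.5],                                # one feature row per edge
                       [1.5]], dtype=np.float32),
    "senders": np.array([0, 1]),     # edge i starts at node senders[i]
    "receivers": np.array([1, 2]),   # edge i ends at node receivers[i]
}

# Batch one or more such dictionaries into a single GraphsTuple.
graphs_tuple = utils_tf.data_dicts_to_graphs_tuple([data_dict])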

This graph in the GraphsTuple can be visualised as follows:

GraphsTuple rendering, made with https://arrows.app

Graph-Nets Library & Application

To reiterate, the GN framework defines a class of functions, and as such, the Graph-Nets library provides 51 function classes.

These can be split into three main parts.

  • First, the core modules are given by graph_nets.modules and consist of 7 classes.
  • Secondly, for building custom graph-net modules, the classes in graph_nets.blocks need to be deployed.
  • Finally, the remaining functions are utility functions.

We will briefly describe and illustrate the construction of two modules: GraphNetwork and InteractionNetwork.

How to instantiate a graph network module?

A Graph Network is implemented using the modules.GraphNetwork class, which constructs the core GN block.
Its configuration takes three learnable sub-functions, one each for the edge, node and global updates, which are calls to Sonnet library modules.
These can be linear modules (snt.Linear), multi-layer perceptrons (snt.nets.MLP), or potentially any of the 14 classes in the recurrent module of Sonnet.
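
For instance, a node update could just as well be a small MLP rather than a single linear layer; the layer sizes here are illustrative:

import sonnet as snt

# A two-layer MLP as an update function; snt.nets.MLP takes the per-layer output sizes.
node_model_fn = lambda: snt.nets.MLP(output_sizes=[16, 5])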

The functionality that Sonnet offers in the backend does seem to be quite substantial, allowing for significant flexibility in the design of any GNN.

Once a GraphsTuple is created, as we did earlier, the only work left to create a GraphNetwork is to specify these functions. An example is as follows:

import sonnet as snt
from graph_nets import modules

# Output feature sizes for the updated edge, node and global attributes.
OUTPUT_EDGE_SIZE = 6
OUTPUT_NODE_SIZE = 5
OUTPUT_GLOBAL_SIZE = 3

# Each model_fn returns a fresh Sonnet module to use as the update function.
graph_network = modules.GraphNetwork(
    edge_model_fn=lambda: snt.Linear(output_size=OUTPUT_EDGE_SIZE),
    node_model_fn=lambda: snt.Linear(output_size=OUTPUT_NODE_SIZE),
    global_model_fn=lambda: snt.Linear(output_size=OUTPUT_GLOBAL_SIZE))
output_graphs = graph_network(graphs_tuple)
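
As a quick sanity check (assuming the 3-node, 2-edge toy graph sketched earlier), the returned GraphsTuple carries the transformed features:

# Expected shapes for the toy graph: edges (2, 6), nodes (3, 5), globals (1, 3).
print(output_graphs.edges.shape)
print(output_graphs.nodes.shape)
print(output_graphs.globals.shape)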

Interaction Networks

An example of a “Message Passing Neural Network” is given by modules.InteractionNetwork, which implements the paper by Battaglia et al., “Interaction Networks for Learning about Objects, Relations and Physics”. There, the authors developed a model “which can reason about how objects in complex systems interact”, and for this the message-passing component is a crucial part.

Here again, once a GraphsTuple is constructed, we only need to specify the Sonnet update functions and pass in the GraphsTuple, as follows:

# InteractionNetwork updates edges and nodes only; globals pass through unchanged.
interact = modules.InteractionNetwork(
    edge_model_fn=lambda: snt.Linear(output_size=OUTPUT_EDGE_SIZE),
    node_model_fn=lambda: snt.Linear(output_size=OUTPUT_NODE_SIZE))
interact_graph = interact(graphs_tuple)

This will return a GraphsTuple with updated edges and nodes.
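
A minimal way to inspect the result, again assuming the toy graph from earlier:

print(interact_graph.edges.shape)   # (2, 6): edges expanded to OUTPUT_EDGE_SIZE
print(interact_graph.nodes.shape)   # (3, 5): nodes expanded to OUTPUT_NODE_SIZE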

As can be seen, the node and edge features are transformed and expanded to the configured output sizes, providing the basis for the Interaction Network model.

The authors also provide a number of runnable examples as notebooks.

Conclusion

Graph-Nets is a low-level library for building GNNs in TensorFlow/Sonnet that offers substantial flexibility, allowing for the implementation of most, if not all, existing GNN architectures.

The ability to compose ‘GN blocks’ allows for processing sequences of graphs and for computation over dynamical systems.

As for future work, we intend to explore how we can integrate Graph-Nets with Neo4j, as well as other GNN libraries such as ‘jraph’ and the ‘Deep Graph Library’.
