Introduction to Machine Learning with Graphs

This is a conceptual introduction to machine learning with graphs and the tasks that are relevant in this field of research.

Phuc Kien Nguyen
Towards Data Science

Image by author.

What are graphs?

Graphs are data structures that describe relationships and interactions between entities in complex systems. In general, a graph contains a collection of entities, called nodes, and a collection of interactions between pairs of nodes, called edges. Nodes represent entities, which can be any type of object that is relevant to our problem domain. By connecting nodes with edges, we end up with a graph (network) of nodes. In addition, every node and edge may have attributes (or features) that contain additional information about those instances.

For example, if we would like to know how the clients of a bank are related to each other through their money transactions, we can model the entire client data set as a network of clients who transfer money to each other. In this network, each node is a client, and an edge between two clients represents that those two clients have transferred money to each other at least once. Node attributes could be bank account numbers and dates of birth. Edge attributes could be the total amount of money transferred, the number of transactions, and the date of the first transaction. Fig 1 shows a visualization of this network.

Fig 1. An Undirected Homogeneous Graph. Image by author.
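
To make node and edge attributes concrete, here is a minimal Python sketch of the client network as plain data structures. It is an illustration under stated assumptions: the client Cindy, the account numbers, dates, and amounts are made up, and the class names `Client` and `Transfers` are hypothetical. In practice a graph library or graph database would handle this bookkeeping.

```python
from dataclasses import dataclass


@dataclass
class Client:
    """Node attributes for one bank client (hypothetical values below)."""
    account_number: str
    date_of_birth: str


@dataclass
class Transfers:
    """Edge attributes summarizing all transfers between two clients."""
    total_amount: float
    num_transactions: int
    first_transaction: str


# Nodes: one entry per client, keyed by name (all values are invented).
nodes = {
    "Adam": Client("ACC-001", "1985-03-12"),
    "Barry": Client("ACC-002", "1990-07-01"),
    "Cindy": Client("ACC-003", "1978-11-23"),
}

# Undirected edges: a frozenset key makes {Adam, Barry} the same edge as {Barry, Adam}.
edges = {
    frozenset({"Adam", "Barry"}): Transfers(1250.0, 3, "2021-01-04"),
    frozenset({"Barry", "Cindy"}): Transfers(80.0, 1, "2021-02-15"),
}


def neighbors(name):
    """Return the clients that have exchanged money with `name`."""
    return sorted(other for pair in edges if name in pair for other in pair - {name})


print(neighbors("Barry"))                                # ['Adam', 'Cindy']
print(edges[frozenset({"Barry", "Adam"})].total_amount)  # 1250.0
```

Using an unordered key for each edge is what makes this representation undirected: the lookup succeeds regardless of which client is named first.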

Undirected Graphs vs Directed Graphs

Graphs that don’t include the direction of an interaction between a node pair are called undirected graphs (Needham & Hodler). The graph in Fig 1 is undirected because, for our business problem, we only want to find out whether clients are related through having sent money to each other. We’re not interested in who is the sender or the beneficiary of the money, so the direction of the interaction is not relevant. However, if we are interested in who is the sender and who is the beneficiary, then the direction of the money transfer matters. Graphs that include the direction of an interaction between a pair of nodes are called directed graphs, and we visualize this by placing arrows on the edges. The example above could also be modeled as a directed graph. In that case, an arrow from node Adam to node Barry would mean that money has been transferred from Adam to Barry, making Adam the sender and Barry the beneficiary in this money transaction relationship. Fig 2 shows this situation.

Fig 2. A Directed Homogeneous Graph. Image by author.
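
The difference between the two views can be sketched with adjacency sets built from the same transfer log. The log below is hypothetical (Cindy and the amounts are invented). The point is that the undirected view stores every edge in both directions and therefore loses the information about who paid whom.

```python
from collections import defaultdict

# Hypothetical transfer log: (sender, beneficiary, amount).
transfers = [
    ("Adam", "Barry", 500.0),
    ("Adam", "Barry", 750.0),
    ("Cindy", "Barry", 80.0),
]

directed = defaultdict(set)    # sender -> set of beneficiaries
undirected = defaultdict(set)  # client -> set of clients it exchanged money with

for sender, beneficiary, _amount in transfers:
    directed[sender].add(beneficiary)
    # The undirected view records the edge in both directions,
    # so it no longer knows who was the sender.
    undirected[sender].add(beneficiary)
    undirected[beneficiary].add(sender)

print(sorted(directed["Adam"]))     # ['Barry']
print(sorted(directed["Barry"]))    # [] (Barry never sent money)
print(sorted(undirected["Barry"]))  # ['Adam', 'Cindy']
```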

Homogeneous vs Heterogeneous Graphs

Another way to distinguish graphs is by the types of nodes they contain. When all nodes in a graph are of one and the same type, the graph is called a homogeneous graph. The example in Fig 1 is a homogeneous graph because all nodes represent the type ‘person’. Now imagine that not all clients of the bank are persons, but that some are companies. Then the graph would include another node type representing ‘companies’, with its own set of node features such as the chamber of commerce number and the date of establishment. A graph that has more than one node type is called a heterogeneous graph. In addition, there are also graphs with different types of edges, meaning that nodes can interact with each other in multiple ways. These graphs are called multi-relational graphs. Note that in many use cases, the edges in a heterogeneous graph are constrained by the types of nodes they connect, because some interactions can only occur between two specific node types. Fig 3 shows a network in which money is transferred between companies and persons. This graph has directed edges and heterogeneous nodes.

Fig 3. A Directed Heterogeneous Graph. Image by author.
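
One way to express such node-type constraints is a schema of allowed (source type, edge type, target type) triples that every new edge is checked against. This is a hypothetical sketch: the company ‘Acme BV’ and the second edge type `director_of`, which makes the graph multi-relational, are invented for illustration.

```python
# Hypothetical schema: allowed (source type, edge type, target type) triples.
SCHEMA = {
    ("person", "transfers_to", "person"),
    ("person", "transfers_to", "company"),
    ("company", "transfers_to", "person"),
    ("company", "transfers_to", "company"),
    ("person", "director_of", "company"),  # only a person can direct a company
}

# Node type per node (invented example nodes).
node_types = {"Adam": "person", "Barry": "person", "Acme BV": "company"}


def is_valid_edge(src, relation, dst):
    """Check whether an edge respects the node-type constraints in SCHEMA."""
    return (node_types[src], relation, node_types[dst]) in SCHEMA


print(is_valid_edge("Adam", "transfers_to", "Acme BV"))  # True
print(is_valid_edge("Acme BV", "director_of", "Adam"))   # False
```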

When creating a graph, it is key to understand which business problem you are trying to solve, because that determines whether the edges should have directions and whether the graph should be homogeneous or heterogeneous. This matters because graph algorithms that calculate network properties give different results on directed graphs than on undirected graphs. For instance, running a network algorithm on an undirected model of a payment transaction network implicitly assumes that money flows in both directions along every edge, which is highly unlikely.
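
To illustrate how direction changes an algorithm's output, the sketch below runs a plain power-iteration PageRank on the same hypothetical payments, once as a directed graph and once as an undirected graph. The `pagerank` function is a simplified implementation written for this example (damping factor 0.85, nodes without outgoing edges spread their score uniformly), not a particular library's API, and the payments involving Cindy and Dave are invented.

```python
def pagerank(adj, damping=0.85, iterations=100):
    """Power-iteration PageRank. `adj` maps each node to a list of out-neighbors.

    Dangling nodes (no out-edges) distribute their score uniformly over all nodes.
    """
    nodes = sorted(set(adj) | {v for outs in adj.values() for v in outs})
    n = len(nodes)
    rank = {v: 1.0 / n for v in nodes}
    for _ in range(iterations):
        new = {v: (1.0 - damping) / n for v in nodes}
        for u in nodes:
            outs = adj.get(u, [])
            targets = outs if outs else nodes  # dangling node: spread to everyone
            share = damping * rank[u] / len(targets)
            for v in targets:
                new[v] += share
        rank = new
    return rank


# Hypothetical payments: Adam and Cindy pay Barry, Barry pays Dave.
payments = [("Adam", "Barry"), ("Cindy", "Barry"), ("Barry", "Dave")]

directed, undirected = {}, {}
for src, dst in payments:
    directed.setdefault(src, []).append(dst)
    undirected.setdefault(src, []).append(dst)
    undirected.setdefault(dst, []).append(src)  # direction discarded

pr_directed = pagerank(directed)
pr_undirected = pagerank(undirected)
print(max(pr_directed, key=pr_directed.get))      # Dave
print(max(pr_undirected, key=pr_undirected.get))  # Barry
```

In the directed graph, the score accumulates at Dave, the final beneficiary of the money flow. In the undirected graph, Barry comes out on top simply because he has the most connections.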

What is machine learning with graphs?

Machine learning has become a key approach to solving problems by learning from historical data to find patterns and predict future events. When we try to predict a target output value from labeled input data, we are approaching the problem in a supervised fashion. If the goal is to find patterns in our data, often by grouping data points into clusters, then our approach is called unsupervised. Machine learning with graphs blurs this distinction because of two key differences in how it approaches the problem.

ML with graphs learns from connections between data points
The first key difference between machine learning with graphs and traditional (un)supervised methods is that the latter learn from the properties of individual data points. Those properties, or features, don’t include information on how individual data points are connected to each other, even though the relationships between data points provide valuable information about the data set. With graphs, data points are represented by nodes, and information about their relationships is captured in the edges of the network. Traditional machine learning approaches, by contrast, require data scientists to hand-pick that relational information and translate it into features during the ‘feature engineering’ step of the machine learning development cycle (Hamilton).
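
As an illustration of that manual step, the sketch below hand-engineers a few relational features per client from a hypothetical transfer log: the number of counterparties, the total amount sent, and the total amount received. The feature choice and the log are assumptions for illustration; each resulting row could be appended to a tabular data set.

```python
from collections import defaultdict

# Hypothetical transfer log: (sender, beneficiary, amount).
transfers = [
    ("Adam", "Barry", 500.0),
    ("Adam", "Cindy", 200.0),
    ("Cindy", "Barry", 80.0),
]

features = defaultdict(lambda: {"counterparties": set(), "sent": 0.0, "received": 0.0})
for sender, beneficiary, amount in transfers:
    features[sender]["counterparties"].add(beneficiary)
    features[sender]["sent"] += amount
    features[beneficiary]["counterparties"].add(sender)
    features[beneficiary]["received"] += amount

# Flatten into one fixed-length row per client: [n_counterparties, sent, received].
rows = {
    name: [len(f["counterparties"]), f["sent"], f["received"]]
    for name, f in features.items()
}
print(rows["Barry"])  # [2, 0.0, 580.0]
```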

ML with graphs is semi-supervised learning
The second key difference is that machine learning with graphs tries to solve the same problems that supervised and unsupervised models attempt to solve, but it does not strictly require labels for every data point during training. With machine learning on graphs, we train the model on the full graph, which also includes all the unlabeled nodes. Although labels are missing on some of these nodes, we can still use all the information about their neighboring nodes and edges, including those in our test set, to improve the model during training. This is significantly different from supervised models, where unlabeled data is not used during training. Because machine learning on graphs uses both labeled and unlabeled data to train the model, it is often called semi-supervised (Hamilton). I would like to point out that semi-supervised learning is not exclusive to machine learning on graphs. There is a sizable field of research and application dedicated to, for example, generative models that learn from unlabeled data to improve the performance of a supervised classifier.
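
The two ingredients of this semi-supervised setup can be sketched in a few lines. The toy graph, features, and labels are hypothetical, and the mean aggregation with a logistic model is a simplifying assumption rather than any specific graph neural network. First, every node aggregates features from its neighbors across the full graph, labeled or not. Second, the training loss is computed only over the labeled nodes.

```python
import math

# Hypothetical toy graph: one scalar feature per node, labels for only some nodes.
adj = {"A": ["B", "C"], "B": ["A", "D"], "C": ["A"], "D": ["B"]}
x = {"A": 0.9, "B": 0.8, "C": 0.1, "D": 0.7}
labels = {"A": 1, "C": 0}  # B and D are unlabeled (e.g. test nodes)

# Step 1: aggregate over the FULL graph. A's representation mixes in the
# feature of unlabeled node B, so B still shapes what the model learns.
h = {v: (x[v] + sum(x[u] for u in adj[v])) / (1 + len(adj[v])) for v in adj}


def masked_bce(w, b):
    """Step 2: binary cross-entropy over labeled nodes only (the training mask)."""
    total = 0.0
    for v, y in labels.items():
        p = 1.0 / (1.0 + math.exp(-(w * h[v] + b)))
        total -= y * math.log(p) + (1 - y) * math.log(1 - p)
    return total / len(labels)


print(round(h["A"], 3))                # 0.6
print(round(masked_bce(0.0, 0.0), 4))  # 0.6931
```

Any optimizer can then minimize `masked_bce`. The unlabeled nodes never appear in the loss directly, but they influence it through the aggregated representations of their labeled neighbors.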

Recently well-studied and applied machine learning techniques on graphs can be roughly divided into three tasks: node embedding, node classification, and link prediction. I will describe these tasks in general terms, to show what they entail and how they can be used in practice.

Node Embedding

Creating node embeddings is the task of aggregating information about a node’s position in the graph and its local neighborhood. This aggregation results in an encoded feature vector, called a node embedding, that summarizes the properties of a node and its relationships with its neighboring nodes. Computing node embeddings is in many cases the first step towards implementing a machine learning model with graphs. Typically, the resulting node embedding is passed into a downstream machine learning model as an additional feature during training. The intuition is that adding the node embedding improves the performance of that downstream model because it contains structural information about the data points that was not captured by the model’s initial feature set (Hamilton, Ying & Leskovec).

Fig 4. Generate a node embedding. Image by author.
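
A minimal, parameter-free version of this aggregation can be sketched with NumPy: each node's embedding concatenates its own features with features propagated over one and two steps of neighbor averaging. This is an illustration on a hypothetical four-node graph, not a learned embedding method such as node2vec or GraphSAGE, which would optimize the vectors during training.

```python
import numpy as np

# Hypothetical undirected graph with 4 nodes (adjacency matrix) and 2 features per node.
A = np.array([
    [0, 1, 1, 0],
    [1, 0, 1, 0],
    [1, 1, 0, 1],
    [0, 0, 1, 0],
], dtype=float)
X = np.array([
    [1.0, 0.0],
    [0.0, 1.0],
    [1.0, 1.0],
    [0.0, 0.0],
])


def aggregate_embedding(A, X, hops=2):
    """Concatenate node features with features averaged over 1..hops neighbor steps."""
    deg = A.sum(axis=1, keepdims=True)
    P = A / np.where(deg == 0, 1.0, deg)  # row-normalized adjacency: mean over neighbors
    layers, H = [X], X
    for _ in range(hops):
        H = P @ H  # one more step of neighbor averaging
        layers.append(H)
    return np.concatenate(layers, axis=1)


Z = aggregate_embedding(A, X)
print(Z.shape)  # (4, 6)
```

Each row of `Z` is a node embedding that could be passed to a downstream model alongside the node's regular features.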

For example, let’s take the transaction data between Adam and Barry from above and apply a random forest classifier to determine whether a person is committing fraud. A typical machine learning life cycle would start by transforming raw data into structured data using feature engineering techniques, so that we can pass those features into our random forest model to predict the labels. With node embedding, however, we learn a feature representation of each person in the transaction network. We then pass that node embedding, as a feature vector, into our random forest model alongside all other features and predict the labels as usual. See Fig 5. Recent applications show that combining such network features with regular features can strongly boost performance.

Fig 5. Using a node embedding in combination with regular features as an input feature vector into a downstream random forest classifier. Image by author.
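
Combining an embedding with regular features amounts to joining both sources on the client id and concatenating them into one input vector per client. The feature values, embeddings, and labels below are hypothetical. A downstream classifier such as scikit-learn's `RandomForestClassifier` would then be fit on the resulting matrix.

```python
import numpy as np

clients = ["Adam", "Barry", "Cindy"]

# Hypothetical regular (engineered) features per client: [age, account_age_years].
regular = {"Adam": [34.0, 5.0], "Barry": [51.0, 12.0], "Cindy": [27.0, 1.0]}

# Hypothetical 3-dimensional node embeddings produced by some node-embedding method.
embedding = {
    "Adam": [0.12, -0.40, 0.88],
    "Barry": [0.05, 0.31, -0.27],
    "Cindy": [-0.66, 0.10, 0.45],
}

fraud_label = np.array([0, 0, 1])  # hypothetical labels, same order as `clients`

# Join both sources on the client id so row i always describes clients[i],
# then concatenate regular features and embedding into one input vector.
X = np.array([regular[c] + embedding[c] for c in clients])
print(X.shape)  # (3, 5)

# X and fraud_label can now be fed to a downstream classifier, for example
# sklearn.ensemble.RandomForestClassifier().fit(X, fraud_label).
```

Joining on an explicit id, rather than relying on two arrays happening to share the same row order, avoids silently pairing one client's embedding with another client's features.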

Node Classification

Above we’ve seen how we can improve a machine learning model by adding a node embedding as an input feature vector to a random forest model. However, it is also possible to classify nodes directly from the graph structure without relying on a downstream machine learning model. This task is called node classification. Its goal is to predict the label of each node based on its associations with neighboring nodes. True labels are available for only a subset of the nodes in the graph. In other words: given a partially labeled graph, predict the labels of the unlabeled nodes.

For example, suppose we model a wildlife trading network to identify illegal trading activities and parties. In this graph, each node could be a buyer or a seller, and an edge would represent a trade transaction between a buyer and a seller. Node attributes could include name, date of birth, and bank account number, while edge attributes could include product name, trade document number, and price. Let’s say that a subset of the sellers has been labeled as illegal traders based on reported data. Node classification then aims to predict whether a trader who has not been labeled should be labeled as illegal or not, by looking at that trader’s trading behavior with others in the network. Nodes that have features similar to, and edges to, nodes labeled as illegal are more likely to be illegal as well. See Fig 6.

Fig 6. Node classification: Given a graph with labeled and unlabeled nodes, predict the labels of the unlabeled nodes based on their node features and their neighborhood nodes. Image by author.
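
One classic way to do this is label propagation, sketched below on a hypothetical trade network. Labeled sellers keep their label (1.0 for illegal, 0.0 for legitimate), and every unlabeled node repeatedly takes the average score of its neighbors until the scores settle. Note that this simple variant uses only the graph structure and ignores node features; graph neural networks are a common alternative that use both.

```python
# Hypothetical trade network: undirected edges between sellers (S*) and buyers (B*).
edges = [
    ("S1", "B1"), ("S1", "B2"), ("S2", "B1"), ("S2", "B2"),
    ("S2", "B3"), ("S3", "B3"), ("S3", "B4"),
]
known = {"S1": 1.0, "S3": 0.0}  # 1.0 = reported illegal, 0.0 = known legitimate

adj = {}
for u, v in edges:
    adj.setdefault(u, set()).add(v)
    adj.setdefault(v, set()).add(u)

# Unlabeled nodes start undecided; labeled nodes stay clamped to their label.
score = {v: known.get(v, 0.5) for v in adj}
for _ in range(100):
    score = {
        v: known[v] if v in known else sum(score[u] for u in adj[v]) / len(adj[v])
        for v in adj
    }

predicted_illegal = sorted(v for v in adj if v not in known and score[v] > 0.5)
print(predicted_illegal)  # ['B1', 'B2', 'S2']
```

Seller S2 ends up flagged because most of its buyers also trade with the reported seller S1, while B3 and B4 lean towards the legitimate seller S3.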

Link Prediction

The task of link prediction is to determine the probability of a link between two nodes in a graph (Zhang & Chen). A well-known class of approaches to this task is called heuristic methods. These methods calculate a similarity score between two nodes based on heuristics such as graph distance, common neighbors, or PageRank. Heuristic methods essentially capture graph, node, and edge properties at a single point in time, and we can calculate those properties directly from the graph to obtain a similarity score for each node pair. We then sort the node pairs by their similarity score and predict that an edge should exist between the highest-scoring pairs. Finally, we evaluate our predictions by checking whether edges that did not exist in our initial time frame appear at a later point in time.
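
The heuristic pipeline described above (score, rank, then check against a later snapshot) can be sketched with two common similarity scores: the number of common neighbors and the Jaccard coefficient. The graph and the edge that "appears later" are hypothetical. The ranking here uses common neighbors; passing `jaccard` as the sort key instead would give the Jaccard ranking.

```python
from itertools import combinations

# Hypothetical undirected snapshot of a network at time t.
adj = {
    "A": {"B", "C", "D"},
    "B": {"A", "C"},
    "C": {"A", "B", "E"},
    "D": {"A", "E"},
    "E": {"C", "D"},
}


def common_neighbors(u, v):
    return len(adj[u] & adj[v])


def jaccard(u, v):
    union = adj[u] | adj[v]
    return len(adj[u] & adj[v]) / len(union) if union else 0.0


# Score every pair that is not yet connected and rank by score (highest first).
candidates = [(u, v) for u, v in combinations(sorted(adj), 2) if v not in adj[u]]
ranked = sorted(candidates, key=lambda pair: common_neighbors(*pair), reverse=True)
print(ranked)  # [('A', 'E'), ('C', 'D'), ('B', 'D'), ('B', 'E')]

# Evaluate the top-k predictions against edges that appeared later (hypothetical t+1).
appeared_later = {("A", "E")}
top_k = ranked[:2]
precision_at_2 = sum(pair in appeared_later for pair in top_k) / len(top_k)
print(precision_at_2)  # 0.5
```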

A drawback of heuristic methods is that they rely on fixed assumptions, such as that two nodes with common neighbors are more likely to connect. This might hold when the graph represents a social network such as the illegal wildlife trade example above, but in a protein-protein interaction network, for instance, two proteins that share many common neighbors are less likely to interact (Hamilton). The challenge is therefore to determine which heuristic makes sense for calculating the similarity score in our use case. Recent studies show that we can instead learn which heuristic best serves our link prediction task (Zhang & Chen). The idea is that, for each edge, we take the two nodes it connects and generate an embedding for each of them. We then pass those two node embeddings into a function that combines them (by concatenation, sum, average, distance, etc.) into a new feature vector, which is passed to a downstream binary classifier. See Fig 7.

Fig 7. Link Prediction: Given a node pair without having a link between them, predict if they become connected in the future based on their node features and neighborhood nodes. Image by author.
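
The combine-and-classify step can be sketched as follows. The node embeddings are random placeholders standing in for learned ones, the operators mirror the combinations listed above, and the downstream binary classifier is plain logistic regression trained by gradient descent. All of these are assumptions for illustration; the sketch does not reproduce the enclosing-subgraph method of Zhang & Chen.

```python
import numpy as np

rng = np.random.default_rng(0)
# Placeholder node embeddings; in practice these come from a node-embedding model.
emb = {v: rng.normal(size=4) for v in "ABCDE"}


def edge_features(u, v, op="hadamard"):
    """Combine two node embeddings into one edge feature vector."""
    zu, zv = emb[u], emb[v]
    if op == "concat":
        return np.concatenate([zu, zv])  # order-dependent
    if op == "sum":
        return zu + zv
    if op == "average":
        return (zu + zv) / 2.0
    if op == "hadamard":
        return zu * zv  # element-wise product
    if op == "l1":
        return np.abs(zu - zv)  # element-wise distance
    raise ValueError(f"unknown operator: {op}")


# Training pairs: existing edges (label 1) and sampled non-edges (label 0).
positives = [("A", "B"), ("A", "C"), ("B", "C"), ("C", "E")]
negatives = [("A", "E"), ("B", "D"), ("B", "E"), ("C", "D")]
X = np.array([edge_features(u, v) for u, v in positives + negatives])
y = np.array([1.0] * len(positives) + [0.0] * len(negatives))

# Downstream binary classifier: logistic regression fit by gradient descent.
w, b = np.zeros(X.shape[1]), 0.0
for _ in range(1000):
    p = 1.0 / (1.0 + np.exp(-(X @ w + b)))
    w -= 0.1 * X.T @ (p - y) / len(y)
    b -= 0.1 * float(np.mean(p - y))


def link_probability(u, v):
    """Predicted probability that an edge exists between u and v."""
    return float(1.0 / (1.0 + np.exp(-(edge_features(u, v) @ w + b))))


print(round(link_probability("D", "E"), 3))
```

For undirected graphs, symmetric operators such as sum, average, Hadamard product, or L1 distance give the same feature vector for (u, v) and (v, u); concatenation does not, which can matter when the edge has no natural direction.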

Conclusion

Graph analytics focuses on investigating the relationships between data points in our data set. Representing data as graphs allows us to discover relationships and patterns that could be overlooked if we model our data as isolated data points. Machine learning on graphs helps us encode such graph structures so that they can be exploited further by machine learning models (Hamilton).

Sources

Hamilton. Graph Representation Learning.
https://www.cs.mcgill.ca/~wlh/grl_book/

Hamilton, Ying & Leskovec. Representation Learning on Graphs: Methods and Applications.
https://arxiv.org/abs/1709.05584

Zhang & Chen. Link Prediction Based on Graph Neural Networks.
https://papers.nips.cc/paper/2018/file/53f0d7c537d99b3824f0f99d62ea2428-Paper.pdf

Needham & Hodler. Graph Algorithms.
https://www.oreilly.com/library/view/graph-algorithms/9781492047674/

Data Scientist at ABN Amro Bank | Anti-Money Laundering | Co-Author of Graph-Powered Analytics and Machine Learning with TigerGraph | Amsterdam