
Graph Machine Learning with Python Part 4: Supervised & Semi-Supervised Learning

Classifying and Predicting Paintings in the Metropolitan Museum of Art

Network of Paintings in the MET. Image by author

Introduction

This story will explore how we can reason from and model graphs using labels via Supervised and Semi-Supervised Learning. I’ll be using the MET Art Collections dataset, building on my previous parts on metrics, Unsupervised Learning, and more. Be sure to check out the previous story before this one, as I won’t cover all of those concepts again here:

Graph Machine Learning with Python Part 3: Unsupervised Learning

Table of Contents

  1. Feature-based Methods
  2. Embedding-based Methods
  3. Collective Classification
  4. Summary

Feature-based Methods

The easiest way to conduct Supervised Learning is to use graph measures as features, either in a new dataset or in addition to an existing one. I have seen this method yield positive results for modeling tasks, but success depends heavily on (1) how you model the data as a graph (what the inputs, outputs, edges, etc. are) and (2) which metrics you use.

Depending on the prediction task, we can compute node-level, edge-level, and graph-level metrics. These metrics carry rich information about the entity itself and about its relationship to other entities. This can be seen as a classical Supervised ML task, but the emphasis here is on Feature Selection: depending on your prediction task, you may choose different graph metrics to predict the label.

In our dataset, the label we’ll choose is the "Highlight" field. The choice of label can change depending on your goal, but for this project my assumption is that there are many art pieces out there worthy of being highlighted (publicly displayed) but that remain largely hidden. Let’s see if we can use graphs to build models that help predict which other artworks should be highlighted.

0 10733 
1 305
Name: Highlight, dtype: int64

We see a large class imbalance here, which makes sense: only a small share of artworks can be highlighted, and museums usually hold vast collections in their archives.

The dataframe just contains a variety of the graph metrics that we’ll use as features to predict "Highlight".
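
To make this concrete, here is a minimal sketch of how such a feature dataframe can be assembled, assuming G is the artwork network built in part 3 (with a hypothetical "type" attribute distinguishing artwork nodes from attribute nodes) and met_df is the MET metadata dataframe; the metric choices are illustrative:

import networkx as nx
import pandas as pd

# Node-level metrics for the artwork nodes only (attribute nodes are skipped).
artwork_nodes = [n for n, d in G.nodes(data=True) if d.get("type") == "artwork"]

features = pd.DataFrame({
    "degree_centrality": pd.Series(nx.degree_centrality(G)),
    "pagerank": pd.Series(nx.pagerank(G)),
    "clustering": pd.Series(nx.clustering(G)),
}).loc[artwork_nodes]

# Join the graph metrics with the label we want to predict.
data = features.join(met_df["Highlight"].astype(int))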

Let’s randomly oversample our dataset and then predict for our target variable.
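
The exact model isn’t shown here, so as one possible sketch (continuing from the dataframe above), we can use imbalanced-learn’s RandomOverSampler with a plain random forest as the base model:

from imblearn.over_sampling import RandomOverSampler
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

X = data.drop(columns=["Highlight"])
y = data["Highlight"]

# Hold out a test set first, then oversample only the training data.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, stratify=y, random_state=42
)
X_res, y_res = RandomOverSampler(random_state=42).fit_resample(X_train, y_train)

clf = RandomForestClassifier(random_state=42).fit(X_res, y_res)
y_pred = clf.predict(X_test)

print("Accuracy", accuracy_score(y_test, y_pred))
print("Precision", precision_score(y_test, y_pred))
print("Recall", recall_score(y_test, y_pred))
print("F1-score", f1_score(y_test, y_pred))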

Accuracy 0.7889492753623188 
Precision 0.08961303462321792 
Recall 0.6984126984126984 
F1-score 0.15884476534296027

Not the greatest results, but we only used a base model, so we could likely improve performance considerably with tuning or feature engineering – that’s outside the scope of this story. One thing to note is that we want to optimize for recall for our problem: it’s better to get "relevant" results than "exact" results in this case, so it is good to see a relatively high recall score (and remember we’re only using a few graph metrics to do this!).

Finally, this was a node-level prediction task where each node in question is an artwork. The same concept can just as easily be applied to edge-level or graph-level tasks (alongside traditional features), making it highly versatile.

Embedding-based Methods

Shallow embedding-based methods for Supervised Learning differ from their unsupervised counterparts in that they attempt to find the best solution for a node-, edge-, or graph-level prediction task. The two main algorithms here are Label Propagation and Label Spreading.

These two methods fall in the realm of Semi-Supervised Learning because there are only a few labeled nodes in the graph. The main assumption here is that there is homophily in the network, meaning similar nodes may be more likely to attach to each other than dissimilar ones.

Graphs add at least one more dimension of modeling choices, which makes your assumptions especially important to validate. For example, in part 3 I deliberately designed this network in a particular way, but it could easily have been crafted differently: I chose to create a node for each attribute an artwork may have rather than embedding those attributes in the artwork node itself.

Image by author

This may still hold for our use case (i.e., artworks from similar periods or cultures are naturally more connected to each other than they would be otherwise), but these are the types of choices you should reevaluate alongside your modeling assumptions.

Label Propagation

This algorithm propagates a node’s label to the neighboring nodes that have the highest probability of being reached from that node.

sklearn has already implemented this algorithm for us, with two main kernels: knn and rbf. Parameters like gamma and n_neighbors control the "distance" the model scans around the labeled nodes.

For our use case, we technically only have one label (1 for "Highlight"), and we also have a large volume of artworks connected to nodes like "unknown_Period" or "unknown_Culture", as seen in part 3 (those nodes had high centrality measures). This essentially means that most model configurations will simply propagate "Highlight" to nearly all of our nodes, indicating that every artwork should be highlighted. Although there’s nothing wrong with that, it doesn’t provide much value from a model perspective – we could have just said that and been done with it.

Highlight     Is Public Domain 
No Highlight  True                6567               
              False               4166 
Highlight     True                 235               
              False                 70
dtype: int64

What we need to do first is create a heuristic that lets us control which artworks are "eligible" to be highlighted. In our dataset, only about 300 artworks are highlighted, and of those, more than 75% are in the public domain. We can therefore assign a 0 to nodes that are neither highlighted nor in the public domain, indicating that they are "not eligible" to be highlighted; nodes that are not highlighted but are in the public domain remain unlabeled. This can easily be changed to whatever heuristic makes sense given the data and the domain.
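
A minimal sketch of that heuristic in pandas, assuming the metadata dataframe is called met_df and that "Highlight" and "Is Public Domain" are boolean columns (as the counts above suggest):

import numpy as np

highlighted = met_df["Highlight"].astype(bool)
public_domain = met_df["Is Public Domain"].astype(bool)

# 1 = highlighted, 0 = "not eligible" (not highlighted and not public domain), -1 = unlabeled.
met_df["label"] = np.select(
    [highlighted, ~highlighted & ~public_domain],
    [1, 0],
    default=-1,
)
print(met_df["label"].value_counts())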

-1    6567  
0     4166  
1     305 
Name: Highlight, dtype: int64

Note that the model needs unlabeled nodes to have a -1 value.

You can extract the model’s probabilities for each class using the predict_proba() method, and use the transduction_ attribute after the model has been fit to attach the inferred labels to the rows.
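
A sketch of how this looks with sklearn, assuming X_nodes is a numeric feature matrix for the artworks (for example, the graph-metric features from earlier) and labels is the -1/0/1 vector built above; the kernel choice and parameter values are illustrative:

from sklearn.semi_supervised import LabelPropagation
import pandas as pd

# kernel="knn" uses n_neighbors; kernel="rbf" uses gamma instead.
lp = LabelPropagation(kernel="knn", n_neighbors=7, max_iter=1000)
lp.fit(X_nodes, labels)  # rows with label -1 are treated as unlabeled

probs = lp.predict_proba(X_nodes)        # class probabilities per node
inferred = pd.Series(lp.transduction_)   # labels assigned to every row after fitting
print(probs)
print(inferred.value_counts())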

array([[4.46560147e-17, 1.00000000e+00],        
      [4.46560147e-17, 1.00000000e+00],        
      [4.89043690e-17, 1.00000000e+00],        
      ...,        
      [1.00000000e+00, 4.60058884e-51],        
      [1.00000000e+00, 1.50981524e-52],        
      [1.00000000e+00, 1.07327710e-50]])
0    6300 
1    4738 
dtype: int64

We now have over 4,000 more artworks to potentially consider for highlighting in museums! Model evaluation should also be done, and there are parameters (gamma or n_neighbors) that can be tuned to ensure a "good" split, but this is a fairly subjective task, so task-based evaluation with domain experts is likely the best strategy.

Mathematically, the algorithm multiplies a transition matrix by the matrix of class values for each node at a given time step, and this update is iterated until convergence (or until a maximum number of iterations is reached).

Let’s cover some of the notation and how the matrices come together. Graph Machine Learning by Stamile et al. explains it quite well if you’d like further detail.

  • Adjacency matrix of the graph
  • Diagonal degree matrix of the graph
  • Transition matrix that indicates the probability of each node traversing to another
  • Probability matrix of class assignment at time t
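
Writing A for the adjacency matrix, D for the diagonal degree matrix, T for the transition matrix, and Y^{(t)} for the class-probability matrix at step t, the update can be summarized (following the standard formulation in Stamile et al.) as:

T = D^{-1} A, \qquad T_{ij} = \frac{A_{ij}}{\sum_k A_{ik}}

Y^{(t+1)} = T \, Y^{(t)}

with the rows of Y that correspond to labeled nodes clamped back to their original labels after each iteration.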

Label Spreading

The Label Spreading algorithm addresses a key limitation of Label Propagation: how the initial node labels are treated. Label Propagation assumes the labels initially provided in the dataset are ground truth and cannot be changed during training. Label Spreading relaxes this constraint and allows originally labeled nodes to be relabeled during training.

Why does this matter? It can be very beneficial if there is error or bias in the initially labeled nodes: any error present at labeling time gets propagated throughout the network, and that can be difficult to detect through metrics alone.
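
The sklearn usage mirrors Label Propagation; here is a sketch under the same assumptions (X_nodes and labels as before), where alpha controls how strongly the original labels can be overridden:

from sklearn.semi_supervised import LabelSpreading
import pandas as pd

ls = LabelSpreading(kernel="knn", n_neighbors=7, alpha=0.2, max_iter=1000)
ls.fit(X_nodes, labels)

probs = ls.predict_proba(X_nodes)
inferred = pd.Series(ls.transduction_)
print(probs)
print(inferred.value_counts())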

array([[1.00000000e+00, 2.14018170e-10],        
      [1.00000000e+00, 2.05692171e-10],        
      [1.00000000e+00, 2.13458111e-10],        
      ...,        
      [1.00000000e+00, 4.27035952e-18],        
      [1.00000000e+00, 3.85172946e-28],        
      [1.00000000e+00, 1.33374342e-32]])
0    6127 
1    4911 
dtype: int64

Although it looks fairly similar on the surface, this can yield quite different results due to the difference in probabilities. Instead of calculating a transition matrix, the Label Spreading algorithm calculates a normalized graph Laplacian matrix which, like the transition matrix, provides a compact matrix representation of the graph’s nodes and edges.

  • Normalized graph Laplacian matrix
  • Probability matrix using alpha, a regularization parameter that tunes how much the original solution influences each iteration
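
In the same notation as before, the normalized graph Laplacian and the Label Spreading update (with regularization parameter alpha) can be written as:

\mathcal{L} = D^{-1/2} A D^{-1/2}

Y^{(t+1)} = \alpha \, \mathcal{L} \, Y^{(t)} + (1 - \alpha) \, Y^{(0)}

so each iteration blends the propagated labels with the original labeling Y^{(0)} rather than clamping the labeled rows outright.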

Collective Classification

The power of graphs is to use key concepts such as homophily and influence to extract patterns through connections. Homophily refers to the tendency of nodes to associate with other nodes they are similar to and influence refers to the ability of social connections to affect the individual characteristics of a person/node. Both of these effects are captured in graphs and we can leverage these as assumptions to conduct really powerful Supervised and Semi-Supervised ML tasks.

I discussed above some of the more common techniques, but we can dive a bit deeper into this expansive world and unpack some of the theories behind the classification of unknown labels in a network. This part will closely follow Dr. Jure Leskovec’s course on Machine Learning with Graphs on Stanford Online. Let’s clarify the premises that our methods will rely on:

  1. Similar nodes are typically close together or directly connected
  2. The label of a node may depend on its own features, the labels of the nodes in its vicinity, and the features of the nodes in its vicinity.

Collective Classification is a probabilistic framework built on the Markov Assumption: the label of one node depends on the labels of its neighbors (first-order Markov Chain). It involves 3 steps:

Local Classifier: assign initial labels.

  • Standard classification task to predict label based on node attributes/features. No network information used.

Relational Classifier: Capture correlations between nodes.

  • Uses network information to classify labels of the nodes based on the labels and/or attributes of its neighbors.

Collective Inference: Propagate correlations through the network.

  • Apply the relational classifier to each node iteratively until the inconsistencies between neighboring labels are minimized.

There are three common models used for Collective Classification:

Probabilistic Relational Classifier

  • Class probability of a node is a weighted average of the class probabilities of its neighbors.
  • Labeled nodes are initialized with the ground-truth label. Unlabeled nodes are initialized with 0.5.
  • Each node’s value is updated by summing its first-order neighbors’ labels (ground-truth labels for labeled neighbors, 0.5 for unlabeled ones) and dividing by the number of first-order neighbors.
  • All nodes are updated in random order until convergence or until the maximum number of iterations is reached (a minimal sketch follows this list).

Limitations:

  • Convergence is not guaranteed
  • Model doesn’t use node feature information
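
Here is a minimal sketch of this relational classifier over a networkx-style graph; G and the labels dictionary (node → 0/1 for the labeled nodes) are assumed inputs, and convergence is checked with a simple tolerance:

import random

def relational_classifier(G, labels, max_iter=100, tol=1e-3):
    # Labeled nodes keep their ground-truth value; unlabeled nodes start at 0.5.
    probs = {n: labels.get(n, 0.5) for n in G.nodes()}
    unlabeled = [n for n in G.nodes() if n not in labels]

    for _ in range(max_iter):
        random.shuffle(unlabeled)  # update nodes in random order
        max_change = 0.0
        for n in unlabeled:
            neighbors = list(G.neighbors(n))
            if not neighbors:
                continue
            # Class probability is the average of the neighbors' current probabilities.
            new_p = sum(probs[m] for m in neighbors) / len(neighbors)
            max_change = max(max_change, abs(new_p - probs[n]))
            probs[n] = new_p
        if max_change < tol:  # stop once the updates become negligible
            break
    return probs  # node -> estimated P(label = 1)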

Iterative Classification

  • Main idea is to classify nodes based on their attributes as well as the labels of the neighbor set.
  • Two classifiers are essentially trained. One predicts the node label based on the node feature vector. The second predicts the node label based on the node feature vector and the summary of labels of the node’s neighbors.
  • The summary of labels of the neighbors is meant to be a vector and can be represented as: a histogram of the number of each label, the most common label, or the number of different labels in the neighbor set.
  • In phase 1, the two classifiers are trained on the training set. In phase 2, the test-set labels are bootstrapped with classifier 1; then, for each node, the summary vector is calculated and the label is re-predicted with classifier 2, repeating until the labels stabilize (a condensed sketch follows below).

Limitations:

  • Convergence is not guaranteed with this method either.
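
A condensed sketch of the two-phase procedure; node_features (a dataframe indexed by node), labels_train (a node → 0/1 dict), and the graph G are assumed inputs, and the neighbor-label summary used here is a simple two-bin count, one of the summary options mentioned above:

import pandas as pd
from sklearn.linear_model import LogisticRegression

def neighbor_summary(G, node, labels):
    # Histogram of neighbor labels: [count of neighbors labeled 0, count labeled 1].
    counts = [0, 0]
    for m in G.neighbors(node):
        if m in labels:
            counts[labels[m]] += 1
    return counts

# Phase 1: train both classifiers on the labeled nodes.
train_nodes = list(labels_train)
X1 = node_features.loc[train_nodes]
z = pd.DataFrame([neighbor_summary(G, n, labels_train) for n in train_nodes],
                 index=train_nodes, columns=["n_label_0", "n_label_1"])
clf1 = LogisticRegression(max_iter=1000).fit(X1, pd.Series(labels_train))  # features only
clf2 = LogisticRegression(max_iter=1000).fit(pd.concat([X1, z], axis=1),
                                             pd.Series(labels_train))      # features + neighbor summary

# Phase 2: bootstrap the test labels with classifier 1, then iterate with classifier 2.
test_nodes = [n for n in node_features.index if n not in labels_train]
pred = dict(zip(test_nodes, clf1.predict(node_features.loc[test_nodes])))

for _ in range(10):  # fixed cap for brevity; ideally stop once the labels stop changing
    all_labels = {**labels_train, **pred}
    z_test = pd.DataFrame([neighbor_summary(G, n, all_labels) for n in test_nodes],
                          index=test_nodes, columns=["n_label_0", "n_label_1"])
    X2_test = pd.concat([node_features.loc[test_nodes], z_test], axis=1)
    pred = dict(zip(test_nodes, clf2.predict(X2_test)))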

Loopy Belief Propagation

  • Dynamic programming approach where nodes "pass messages" to each other iteratively to answer probability queries. Belief is calculated when consensus is reached.
  • Each node can only interact (pass messages) with its neighbors; a message is received from a neighbor, updated, and then passed along. Think of it like the game of telephone.
Jure Leskovec, Stanford, CS224W: Machine Learning with Graphs
  • The algorithm requires an ordering of the nodes, and the edge directions follow that ordering; this defines how messages are passed.
  • Utilizes a label-label potential matrix, which captures the dependency between a node and its neighbor. It is proportional to the probability of a node belonging to a class given the class of its neighbor.
  • Utilizes prior beliefs, which are proportional to the probability of a node belonging to each class.
  • Utilizes the "message", which is a node’s estimate of the class of the following node.
  • All messages are initialized to 1, and then for each node we compute the message it will send along (the belief that node has about the class of the following node). This is computed by summing, over all of the node’s possible states, the product of the label-label potential, the prior, and the messages received from neighbors in the previous round (see the equations below).
Jure Leskovec, Stanford, CS224W: Machine Learning with Graphs
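
Written out in standard notation (with psi the label-label potential, phi the prior, and m the messages), the message node i sends to node j about class Y_j, and the final belief for node i, are:

m_{i \to j}(Y_j) = \sum_{Y_i} \psi(Y_i, Y_j) \, \phi_i(Y_i) \prod_{k \in N(i) \setminus \{j\}} m_{k \to i}(Y_i)

b_i(Y_i) \propto \phi_i(Y_i) \prod_{k \in N(i)} m_{k \to i}(Y_i)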

Limitations:

  • Issues can arise when the graph contains cycles: if the initial belief of a node is incorrect, it can be reinforced around the cycle and the algorithm may fail to converge. This happens because each message is treated as independent, which is not really true in cycles.

Some of these models need to be built by hand by following the algorithm, but I recommend a quick read of the research papers referenced in the sklearn or networkx documentation – it’s likely some of these techniques are mentioned!

Summary

In this part, I covered how you can use graph information to conduct Supervised and Semi-Supervised Learning. At a minimum, graphs provide rich connectivity and centrality features, along with a wide array of new techniques to expand your repertoire of problem-solving strategies.

If you combine the ideas discussed here with the previous parts on metrics, random worlds and diffusion models, and Unsupervised Learning techniques you can find yourself analyzing your data in a completely new dimension.

Even with all these new ideas and methodologies, I still haven’t discussed what much of the recent hype around graphs has been about – Graph Neural Networks. In the following part(s), I’ll take this series into Graph Neural Networks, but keep in mind that Deep Learning remains relevant for only specific data problems. The majority of data problems are tabular and shouldn’t really need deep neural networks to solve, but it never hurts to expand one’s knowledge base!

There is also a lot that I didn’t cover in this story, to avoid writing a book and overloading it with content, but here are some of the main resources I used and specific topics I think practitioners would gain a lot of value from reviewing:

References

[1] Claudio Stamile, Aldo Marzullo, Enrico Deusebio, Graph Machine Learning

[2] Jure Leskovec, Stanford, CS224W: Machine Learning with Graphs

[3] Easley, David and Kleinberg, Jon. 2010. Networks, Crowds, and Markets: Reasoning About a Highly Connected World

[4] The Metropolitan Museum of Art Open Access

