Thoughts and Theory
A Comprehensive Case Study of GraphSage with Hands-on Experience using the PyTorch Geometric Library and Open Graph Benchmark's Amazon Product Recommendation Dataset
This blog post provides a comprehensive theoretical and practical study of GraphSage, an inductive graph representation learning algorithm. For the practical part, we use the popular PyTorch Geometric library together with an Open Graph Benchmark dataset. We use the ogbn-products dataset, an undirected and unweighted graph representing an Amazon product co-purchasing network, to predict shopping preferences. Nodes represent products sold on Amazon, and edges between two products indicate that the products are purchased together. The goal is to predict the category of a product in a multi-class classification setup, where the 47 top-level categories are used as target labels, making it a node classification task.
So in brief here is the outline of the blog:
- What is GraphSage
- Importance of Neighbourhood Sampling
- Getting Hands-on Experience with GraphSage and PyTorch Geometric Library
- Open-Graph-Benchmark’s Amazon Product Recommendation Dataset
- Creating and Saving a model
- Generating Graph Embeddings Visualisations and Observations
Power Up!!
I have conducted a workshop on the topic "Machine Learning on Graphs with PyTorch Geometric, NVIDIA Triton, and ArangoDB: Thinking Beyond Euclidean Space". This workshop digs deeper into the importance of graph data structures, applications of Graph ML, the motivation behind graph representation learning, and how to use Graph ML in production with the NVIDIA Triton Inference Server and ArangoDB in a real-world application.
What is Graph Representation Learning?
Once a graph has been created by incorporating meaningful relationships (edges) between all the entities (nodes), the next question is how to integrate information about the graph structure (e.g. a node's global position in the graph or its local neighbourhood structure) into a machine learning model. One way to extract structural information from a graph is to compute graph statistics such as node degrees, clustering coefficients, kernel functions, or hand-engineered features that estimate local neighbourhood structures. However, with these methods we cannot perform end-to-end learning, i.e. the features cannot be learned with the help of a loss function during training. To tackle this problem, representation learning approaches have been adopted to encode the structural information about the graph into Euclidean space (a vector/embedding space).
The key idea behind graph representation learning is to learn a mapping function that embeds nodes, or entire (sub)graphs, as points in a low-dimensional vector space (from the non-Euclidean graph domain to an embedding space). The aim is to optimize this mapping so that nodes which are nearby in the original network remain close to each other in the embedding space, while unconnected nodes are pushed apart. By doing this, we preserve the geometric relationships of the original network inside the embedding space. The diagram below depicts the mapping process: the encoder enc maps nodes u and v to the low-dimensional vectors z_u and z_v:

Let's understand this more intuitively with an interesting example: the graph structure of the Zachary Karate Club social network. In this graph, the nodes represent persons, and an edge exists between two persons if they are friends. The colouring in the graph represents different communities. Figure A) shows the Zachary Karate Club social network, and B) illustrates the 2D visualisation of node embeddings created from the Karate graph using the DeepWalk method. If you analyse both diagrams, you will find that the mapping of nodes from the graph structure (a non-Euclidean or irregular domain) to the embedding space (figure B) is done in such a manner that the distances between nodes in the embedding space mirror closeness in the original graph (preserving the structure of each node's neighbourhood). For example, the communities of people marked violet and green are in close proximity in the Karate graph, while the violet and sea-green communities are far away from each other. When the DeepWalk method is applied to the Karate graph (in order to learn the node embeddings), we can observe the same proximity behaviour when the learned node embeddings are visualised in 2D space.

We can use these learned node embeddings for various machine-learning downstream tasks:
1) It can be used as a feature input for downstream ML tasks (eg. community detection via node classification or link prediction)
2) We could construct a KNN/cosine-similarity graph from the embeddings and use it to make recommendations (e.g. product recommendations); see the sketch after this list
3) Visual exploration of data by reducing embeddings to 2 or 3 dimensions using the UMAP or t-SNE algorithms (e.g. before performing clustering)
4) Dataset Comparisons
5) Transfer Learning
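Following up on points 2) and 3), here is a minimal sketch of turning learned node embeddings into recommendations via cosine similarity. Everything here is hypothetical (random vectors stand in for real GraphSage embeddings); it only illustrates the mechanics, using scikit-learn:

import numpy as np
from sklearn.neighbors import NearestNeighbors

# Hypothetical embedding matrix: one 256-d vector per product node,
# e.g. as produced by a trained GraphSage model.
embeddings = np.random.rand(1000, 256).astype(np.float32)

# Build a KNN index over cosine distance.
knn = NearestNeighbors(n_neighbors=6, metric="cosine").fit(embeddings)

# Recommend the 5 most similar products to product 42
# (the first hit is the query node itself, so we drop it).
distances, indices = knn.kneighbors(embeddings[42:43])
recommended = indices[0][1:]
print("Products similar to 42:", recommended)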
GraphSage Motivation!!
In this blog post/notebook, we will cover the GraphSage (Sample and Aggregate) algorithm, an inductive (i.e. it can generalize to unseen nodes) deep learning method developed by Hamilton, Ying, and Leskovec (2017) that generates low-dimensional vector representations for the nodes of a graph. This is in contrast with previous graph machine learning methods like Graph Convolutional Networks (GCNs) or DeepWalk, which are inherently transductive: they can only generate embeddings for the nodes present in the fixed graph seen during training. This implies that if the graph evolves in the future and new nodes (unseen during training) make their way in, we need to retrain the model on the whole graph in order to compute embeddings for the new nodes. This limitation makes transductive approaches inefficient on ever-evolving graphs (like social networks or protein-protein interaction networks) because of their inability to generalize to unseen nodes. The other main limitation of transductive approaches (mainly DeepWalk or Node2Vec) is that they cannot leverage node features, e.g. text attributes, node profile information, node degrees, etc. The GraphSage algorithm, on the other hand, exploits both the rich node features and the topological structure of each node's neighbourhood simultaneously to efficiently generate representations for new nodes without retraining.
Some of the Popular GraphSage Use Cases:
1) Dynamic graphs: graphs which evolve over time, like social network graphs from Facebook, LinkedIn or Twitter, posts on Reddit, or users and videos on YouTube.
2) Generated node embeddings via unsupervised loss function can be used for various downstream Machine Learning tasks like node classification, clustering, and link prediction.
3) Real-world applications which require computing embeddings for their subgraphs
4) Protein-Protein interaction graphs: Here, the trained embedding generator can predict the node embeddings for the data collected on new species/organisms
5) Uber Eats: It uses the power of Graph ML to suggest to its users the dishes, restaurants, and cuisines they might like next. To make these recommendations, Uber Eats uses the GraphSage algorithm because of its inductive nature and its ability to scale up to billions of nodes.
6) Pinterest: It uses the power of PinSage (another version of GraphSage) for making visual recommendations (pins are visual bookmarks, e.g. for buying clothes or other products). PinSage is a random-walk-based GraphSage algorithm which learns embeddings for billions of nodes in web-scale graphs.
Working Principles of GraphSage

The working process of GraphSage is mainly divided into two steps: the first is performing neighbourhood sampling of the input graph, and the second is learning aggregation functions at each search depth. We will discuss each of these steps in detail, starting with the motivation for sampling a node's neighbourhood. Afterwards, we will discuss the importance of learning aggregator functions, which is what gives the GraphSage algorithm its inductive property.
What is the importance of Neighbourhood Sampling?
Let's understand this from the perspective of the Graph Convolutional Network (GCN) diagram described below. A GCN is an algorithm which can leverage both the graph topological information (i.e. a node's neighbourhood) and the node features, and then distil this information to generate node representations (dense vector embeddings). The diagram below represents the working process of GCNs intuitively. On the left-hand side, we have a sample input graph whose nodes are represented by their corresponding feature vectors (e.g. node degree or text embeddings). We start by defining a search depth K, which tells the algorithm up to what depth it should gather information from the neighbourhood of a target node. Here, K is a hyperparameter, and it also corresponds to the number of layers in the GCN.
At K=0, the GCN initialises all node embeddings to their original feature vectors. Now, say we want to compute the embedding of target node 0 at layer K=1: we aggregate (with a permutation-invariant function) the feature vectors of all nodes at a 1-hop distance from node 0, including node 0 itself (at this layer we are aggregating the original K=0 feature representations). For target node 0, the GCN uses a mean aggregator to compute the mean of the neighbourhood node features along with its own features (a self-loop). After K=1, target node 0 knows the information about its immediate neighbourhood; this process is shown below in the GCN image (r.h.s.). We repeat this process for all nodes in the graph (i.e. for every node, we aggregate over its 1-hop neighbourhood) in order to find the new representation of each node at each layer.
Note: As the search depth increases, so does the reach of the target node when aggregating features from its local neighbourhood. For example, at K=1 the target node knows about its local neighbourhood at a 1-hop distance; at K=2 it additionally knows about the neighbours of those 1-hop nodes, i.e. everything up to a 2-hop distance.
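To make the mean aggregation concrete, here is a toy sketch of one K=1 step on a hypothetical 4-node graph, written in plain PyTorch (unrelated to the actual GCN/PyG implementations):

import torch

# Toy graph: adjacency list mapping each node to its 1-hop neighbours
# (a made-up example, not the ogbn-products graph).
neighbours = {0: [1, 2, 3], 1: [0, 2], 2: [0, 1], 3: [0]}

# K=0: every node is initialised to its original feature vector.
h = torch.tensor([[1.0, 0.0],
                  [0.0, 1.0],
                  [1.0, 1.0],
                  [2.0, 0.0]])

# K=1: each node mean-aggregates its neighbours' features along with
# its own (the self-loop).
h_next = torch.stack([h[[v] + neighbours[v]].mean(dim=0) for v in neighbours])
print(h_next[0])  # node 0 now carries information about its 1-hop neighbourhood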

Issues with the GCN approach
As we have discussed above, GCNs compute node representations using neighbourhood aggregation. For training purposes, we can represent the k-hop neighbourhood of a target node as a computation graph and feed these computation graphs in mini-batches in order to learn the weights of the network (i.e. applying stochastic gradient descent). The diagram below illustrates a computation graph for target node 0 up to its 2-hop neighbourhood. The problems with this approach are:
1) Computationally expensive: for each node, we need to generate the complete K-hop neighbourhood computation graph and then aggregate plenty of information from its surroundings. As we go deeper into the neighbourhood (large K), the computation graph grows exponentially, which can become a problem when fitting these big computation graphs inside GPU memory.
2) The curse of hub nodes (celebrity nodes): Hub nodes are nodes with a very high degree, e.g. a very popular celebrity with millions of connections. In that case, we need to aggregate information from millions of nodes in order to compute the embedding of the hub node, so the generated computation graph is huge. This problem is illustrated diagrammatically below (r.h.s.).

Therefore, the idea is not to take the entire K-hop neighbourhood of a target node, but to select a few nodes at random from each hop in order to generate the computation graph. This process is known as neighbourhood sampling, and it is what gives the GraphSage algorithm its unique ability to scale up to billions of nodes. With this approach, when we encounter a hub node we no longer take its entire K-hop neighbourhood but instead sample a fixed number of nodes at each layer or search depth K, making the resulting computation graph far more efficient for the GPU to handle. The toy sketch and the diagram below show this process by sampling at most 2 neighbours at each hop.
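A minimal illustration (a hypothetical adjacency dict, not PyG's NeighborSampler):

import random

def sample_neighbourhood(neighbours, target, sizes):
    """Recursively sample at most sizes[k] neighbours per node at hop k.

    A toy illustration of GraphSage-style neighbourhood sampling;
    `neighbours` is a plain adjacency dict, not a PyG data structure.
    """
    if not sizes:
        return {}
    k_sample = random.sample(neighbours[target],
                             min(sizes[0], len(neighbours[target])))
    sampled = {target: k_sample}
    for u in k_sample:
        sampled.update(sample_neighbourhood(neighbours, u, sizes[1:]))
    return sampled

# Hub node 0 has 100 neighbours, but we only ever pull 2 per hop.
graph = {0: list(range(1, 101)), **{i: [0] for i in range(1, 101)}}
print(sample_neighbourhood(graph, target=0, sizes=[2, 2]))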

Why is GraphSage called an Inductive Representation Learning algorithm?
GraphSage is an inductive version of GCNs, which implies that it does not require the whole graph structure during learning and can generalize well to unseen nodes. It is a branch of graph neural networks that learns node representations by sampling and aggregating neighbours at multiple search depths or hops. Its inductive property rests on the premise that we do not need to learn an embedding for each node; instead, we learn an aggregation function (any differentiable function such as mean, pooling, or an LSTM) which, given the information (features) from the local neighbourhood of a node, knows how to combine them (the learning happens via stochastic gradient descent) such that the aggregated feature representation of a node v includes the information about its local surroundings or neighbourhood.
GraphSage differs from GCNs in two ways: 1) instead of taking the entire K-hop neighbourhood of a target node, GraphSage first samples (prunes) the K-hop neighbourhood computation graph and then performs the feature aggregation on this sampled graph in order to generate the embedding of the target node; 2) during learning, GraphSage learns aggregator functions to generate the node embeddings, whereas GCNs make use of the symmetrically normalized graph Laplacian.
The diagram below illustrates how GraphSage's target node 0 aggregates information from its sampled local neighbours at search depth K=1. If we observe the r.h.s. graph, we find that at K=1 the target node 0 now has information about its surroundings up to 1-hop.

Formal Explanation of GraphSage
As explained above, the key concept of GraphSage is to learn how to aggregate feature information from a node's local neighbourhood. Now, let's understand more formally how GraphSage generates a node embedding at each layer K using forward propagation. We will build this understanding with visuals and then map it to the pseudocode from the GraphSage paper. But before that, let's define some notation used in the paper.
Defining Notations:

As seen in the GraphSage diagram above, at K=1 the target node 0 aggregates information (features) from its local neighbours up to 1-hop. Similarly, at K=2 the target node 0 aggregates information from its local neighbours up to 2-hops, i.e. it now knows what lies in its neighbourhood up to 2-hops away. We can therefore iterate this process so that target node 0 incrementally obtains more and more information from further reaches of the graph. We carry out this information-gathering process for every node in the original graph (∀v ∈ V). Let's add some visuals to understand this iterative process more intuitively:
The following image depicts the computation graph of target node 0 at layer K=0; at this point, all nodes in the graph are initialised to their original feature vectors. Our aim is to find the final representation of node 0 (i.e. z_0) at layer K=2 through an iterative local-neighbourhood information-gathering process, sometimes known as the message-passing approach.
Therefore, we can represent this step formally as:

h_v^0 = x_v, ∀v ∈ V (every node starts from its original feature vector x_v)
Note: Since Medium does not support subscripts, I will write the hidden-layer representation of node v at layer k as h_v^k:
superscript k denotes -> the kth layer
subscript v denotes -> the node id
Neighbourhood Aggregation (At K=1)
Since nodes gather information incrementally from deeper depths of the graph, we start our iteration at search depth 1 and continue up to K. At K=1, we aggregate the neighbouring node representations of our target node 0, i.e. the previous-layer (K=0) representations of nodes 2 and 3 (h_2^0 and h_3^0), into a single aggregated vector h_N(0)^1. At the same time step, nodes 2, 3, and 9 also aggregate the feature vectors from their respective local neighbourhoods up to a distance of 1-hop. At this point, each node in the computation graph knows what kind of information lies in its immediate surroundings.
Therefore, we can represent this step formally as:

h_N(v)^k = AGGREGATE_k({h_u^(k-1), ∀u ∈ N(v)})
Update
Once we have the aggregated neighbourhood representation h_N(0)^1, the next step is to concatenate (combine) it with the node's previous-layer representation h_0^0. A transformation is then applied to this concatenated output by multiplying it with a learnable weight matrix W^k; you can think of this process as analogous to applying convolutional kernels (learnable weight matrices) on images in order to extract features. In the end, we apply a non-linear activation function to the transformed output, enabling the network to learn and perform more complex tasks.
Important Note: The GraphSage algorithm learns the weight matrix individually at each search depth K or you can also say that it learns how to aggregate information from a node’s neighbourhood at each search depth.
Hence, we can represent this step formally as:

h_v^k = σ(W^k · CONCAT(h_v^(k-1), h_N(v)^k))
Normalizing Node Embeddings
Subsequently, an L2 normalization is applied to the node representation h_v^k (at this time step, h_0^1), which helps the algorithm maintain the general distribution of the node embeddings. This step is computed as:

h_v^k = h_v^k / ||h_v^k||_2, ∀v ∈ V
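Putting the three steps together, here is a minimal sketch of a single GraphSage layer for one node with a mean aggregator. The dimensions and weights are made up for illustration; this is not PyG's SAGEConv, just the equations above written as code:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
in_dim, out_dim = 4, 8
W = torch.nn.Linear(2 * in_dim, out_dim)  # W^k acting on CONCAT(self, neigh)

def graphsage_layer(h_self, h_neigh):
    """One GraphSage step for a single node (illustrative only).

    h_self:  (in_dim,) node representation at layer k-1, i.e. h_v^(k-1)
    h_neigh: (num_sampled, in_dim) sampled neighbours' layer k-1 representations
    """
    agg = h_neigh.mean(dim=0)                    # AGGREGATE (mean)
    h = torch.relu(W(torch.cat([h_self, agg])))  # UPDATE: concat, W^k, sigma
    return F.normalize(h, p=2, dim=0)            # NORMALIZE to unit L2 norm

h0_self = torch.rand(in_dim)      # h_0^0: the target node's own features
h0_neigh = torch.rand(3, in_dim)  # features of 3 sampled neighbours
h1 = graphsage_layer(h0_self, h0_neigh)
print(h1.norm())  # ~1.0 after normalization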
Node Embeddings at K=2
The information gathering from the node's local neighbourhood at K=1 is now complete. At K=2, nodes explore the further reaches of the graph, i.e. going beyond their immediate neighbourhoods and looking at a hop distance of 2. Again we perform the local-neighbourhood aggregation, but this time the target node 0 will have the information of its neighbours at both 1-hop and 2-hop distances. We then repeat the update and normalisation steps for search depth K=2. Since we have set K=2 to illustrate the flow of the GraphSage algorithm, we stop here. After K=2, each node in the computation graph is represented by its final node embedding, z_v.
This workflow is shown below in the image:

Now, we can easily map our understanding to the following GraphSage algorithm from the paper:

Loss Function: Learning the parameters
The authors have recorded the results in the paper by using two different types of loss functions which are as follows:
Unsupervised case: As described in the graph representation learning section, the aim is to optimize the mapping so that nodes which are nearby in the original network remain close to each other in the embedding space, while unconnected nodes are pushed apart.
Supervised case: The authors use the regular cross-entropy loss for the node classification task.
Below is the unsupervised loss function used in the paper:

J_G(z_u) = −log(σ(z_u⊤ z_v)) − Q · E_(v_n ∼ P_n(v)) [log(σ(−z_u⊤ z_(v_n)))]

where v is a node that co-occurs with u on a fixed-length random walk, σ is the sigmoid function, P_n is the negative-sampling distribution, and Q defines the number of negative samples.
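As a sanity check of the formula, here is a sketch of how this loss could be computed in PyTorch for a single anchor node; z_u, z_pos, and z_negs are stand-in tensors, not outputs of the model in this post:

import torch
import torch.nn.functional as F

def unsupervised_loss(z_u, z_pos, z_negs):
    """GraphSage unsupervised loss for one anchor node u.

    z_u:    (d,) embedding of the anchor node
    z_pos:  (d,) embedding of a node co-occurring with u on a random walk
    z_negs: (Q, d) embeddings of Q negative samples drawn from P_n(v)
    """
    pos_term = -F.logsigmoid(torch.dot(z_u, z_pos))
    # Summing over the Q negatives approximates Q * E[log sigma(-z_u . z_vn)].
    neg_term = -F.logsigmoid(-(z_negs @ z_u)).sum()
    return pos_term + neg_term

d, Q = 64, 5
z_u, z_pos, z_negs = torch.rand(d), torch.rand(d), torch.rand(Q, d)
print(unsupervised_loss(z_u, z_pos, z_negs))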
Hands-On-Experience on GraphSage with PyTorch Geometric Library and OGB Benchmark Dataset!
We will understand the working process of GraphSage in more detail with the help of a real-world dataset from the Open Graph Benchmark (OGB) datasets. The OGB is a collection of realistic, large-scale, and diverse benchmark datasets for machine learning on graphs developed by Stanford University.
Heads-up!!!
There is a lot of code ahead. If you are interested in getting your hands dirty with it (which I would really encourage you to do), I have already prepared a Google Colab notebook you can play around with.
Dataset
We use the ogbn-products dataset, an undirected and unweighted graph representing an Amazon product co-purchasing network. Nodes represent products sold on Amazon, and edges between two products indicate that the products are purchased together. Node features are bag-of-words features taken from the product descriptions. The goal is to predict the category of a product in a multi-class classification setup, where the 47 top-level categories are used as target labels, making it a node classification task.
Let's start with installing the necessary libraries
# Installing Pytorch Geometric
%%capture
!pip install -q torch-scatter -f https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html
!pip install -q torch-sparse -f https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html
!pip install -q torch-cluster -f https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html
!pip install -q torch-geometric
!pip install ogb
!pip install umap-learn
Importing Necessary Libraries
import os.path as osp
import collections

import torch
import torch.nn.functional as F
from tqdm import tqdm
from torch_geometric.data import NeighborSampler
from torch_geometric.nn import SAGEConv

import pandas as pd
import numpy as np
from pandas.core.common import flatten

# importing the ogb dataset and evaluator
from ogb.nodeproppred import PygNodePropPredDataset, Evaluator

import seaborn as sns
import matplotlib.pyplot as plt
sns.set(rc={'figure.figsize': (16.7, 8.27)})
sns.set_theme(style="ticks")

from scipy.special import softmax
import umap
Download and load the dataset
root = osp.join(osp.dirname(osp.realpath('./')), 'data', 'products')
dataset = PygNodePropPredDataset('ogbn-products', root)
Getting train, validation and test index
# split_idx contains a dictionary of train, validation and test node indices
split_idx = dataset.get_idx_split()
# predefined ogb evaluator method used for validation of predictions
evaluator = Evaluator(name='ogbn-products')
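As a quick illustration of the evaluator's interface (dummy tensors here; in the real pipeline y_pred would come from out.argmax(dim=-1, keepdim=True) on the model output):

import torch

y_true = torch.tensor([[0], [1], [2], [1], [0]])  # shape [num_nodes, 1]
y_pred = torch.tensor([[0], [1], [2], [2], [0]])
print(evaluator.eval({'y_true': y_true, 'y_pred': y_pred}))  # {'acc': 0.8}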
Let’s check the training, validation and test node split.
# lets check the node ids distribution of train, test and val
print('Number of training nodes:', split_idx['train'].size(0))
print('Number of validation nodes:', split_idx['valid'].size(0))
print('Number of test nodes:', split_idx['test'].size(0))
Number of training nodes: 196615
Number of validation nodes: 39323
Number of test nodes: 2213091
Loading the dataset
data = dataset[0]
Graph Statistics
# lets check some graph statistics of ogb-product graph
print("Number of nodes in the graph:", data.num_nodes)
print("Number of edges in the graph:", data.num_edges)
print("Node feature matrix with shape:", data.x.shape) # [num_nodes, num_node_features]
print("Graph connectivity in COO format with shape:", data.edge_index.shape) # [2, num_edges]
print("Target to train against :", data.y.shape)
print("Node feature length", dataset.num_features)
Number of nodes in the graph: 2449029
Number of edges in the graph: 123718280
Node feature matrix with shape: torch.Size([2449029, 100])
Graph connectivity in COO format with shape: torch.Size([2, 123718280])
Target to train against : torch.Size([2449029, 1])
Node feature length 100
Checking the number of unique labels
# there are 47 unique categories of product
data.y.unique()
Load the mapping from integer label to real product category, provided inside the dataset
df = pd.read_csv('/data/products/ogbn_products/mapping/labelidx2productcategory.csv.gz')
Let’s see some of the product categories
df[:10]

Creating a dictionary of product categories and corresponding integer label
label_idx, prod_cat = df.iloc[: ,0].values, df.iloc[: ,1].values
label_mapping = dict(zip(label_idx, prod_cat))
# counting the number of samples for each category
y = data.y.tolist()
y = list(flatten(y))
count_y = collections.Counter(y)
print(count_y)
Neighbourhood Sampling
This module iteratively samples neighbours (at each layer) and constructs bipartite graphs that simulate the actual computation flow of GNNs.
sizes: denotes how many neighbours we want to sample for each node in each layer.
The NeighborSampler holds the current batch_size, the IDs n_id of all nodes involved in the computation, and a list of bipartite graph objects via the tuple (edge_index, e_id, size), where edge_index represents the bipartite edges between source and target nodes, e_id denotes the IDs of the original edges in the full graph, and size holds the shape of the bipartite graph.
The actual computation graphs are then returned in reverse mode, meaning that we pass messages from a larger set of nodes to a smaller one, until we reach the nodes for which we originally wanted to compute embeddings.
train_idx = split_idx['train']
train_loader = NeighborSampler(data.edge_index, node_idx=train_idx,
sizes=[15, 10, 5], batch_size=1024,
shuffle=True)
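To see what the sampler yields, we can peek at one mini-batch (purely for inspection, not part of training):

# Each batch gives: the number of target nodes (`batch_size`), the ids
# `n_id` of every node touched while sampling, and one (edge_index, e_id,
# size) bipartite adjacency per layer (three here, for sizes=[15, 10, 5]).
batch_size, n_id, adjs = next(iter(train_loader))
print(batch_size, n_id.shape)
for edge_index, e_id, size in adjs:
    print(size)  # (num_source_nodes, num_target_nodes) of each bipartite graph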
GraphSage Algorithm
class SAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=3):
        super(SAGE, self).__init__()
        self.num_layers = num_layers
        self.convs = torch.nn.ModuleList()
        self.convs.append(SAGEConv(in_channels, hidden_channels))
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
        self.convs.append(SAGEConv(hidden_channels, out_channels))

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()

    def forward(self, x, adjs):
        # `train_loader` computes the k-hop neighborhood of a batch of nodes,
        # and returns, for each layer, a bipartite graph object, holding the
        # bipartite edges `edge_index`, the index `e_id` of the original edges,
        # and the size/shape `size` of the bipartite graph.
        # Target nodes are also included in the source nodes so that one can
        # easily apply skip-connections or add self-loops.
        for i, (edge_index, _, size) in enumerate(adjs):
            xs = []
            x_target = x[:size[1]]  # Target nodes are always placed first.
            x = self.convs[i]((x, x_target), edge_index)
            if i != self.num_layers - 1:
                x = F.relu(x)
                x = F.dropout(x, p=0.5, training=self.training)
            xs.append(x)
            if i == 0:
                x_all = torch.cat(xs, dim=0)
                layer_1_embeddings = x_all
            elif i == 1:
                x_all = torch.cat(xs, dim=0)
                layer_2_embeddings = x_all
            elif i == 2:
                x_all = torch.cat(xs, dim=0)
                layer_3_embeddings = x_all
        # return x.log_softmax(dim=-1)
        return layer_1_embeddings, layer_2_embeddings, layer_3_embeddings

    def inference(self, x_all):
        pbar = tqdm(total=x_all.size(0) * self.num_layers)
        pbar.set_description('Evaluating')
        # Compute representations of nodes layer by layer, using *all*
        # available edges. This leads to faster computation in contrast to
        # immediately computing the final representations of each batch.
        # Note: relies on a globally defined `subgraph_loader` (a
        # NeighborSampler over all nodes); see the inference section below.
        total_edges = 0
        for i in range(self.num_layers):
            xs = []
            for batch_size, n_id, adj in subgraph_loader:
                edge_index, _, size = adj.to(device)
                total_edges += edge_index.size(1)
                x = x_all[n_id].to(device)
                x_target = x[:size[1]]
                x = self.convs[i]((x, x_target), edge_index)
                if i != self.num_layers - 1:
                    x = F.relu(x)
                xs.append(x)
                pbar.update(batch_size)
            if i == 0:
                x_all = torch.cat(xs, dim=0)
                layer_1_embeddings = x_all
            elif i == 1:
                x_all = torch.cat(xs, dim=0)
                layer_2_embeddings = x_all
            elif i == 2:
                x_all = torch.cat(xs, dim=0)
                layer_3_embeddings = x_all
        pbar.close()
        return layer_1_embeddings, layer_2_embeddings, layer_3_embeddings
Instantiate model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SAGE(dataset.num_features, 256, dataset.num_classes, num_layers=3)
model = model.to(device)
Load Node Feature Matrix and Node labels
x = data.x.to(device)
y = data.y.squeeze().to(device)
Training
def train(epoch):
    model.train()
    total_loss = total_correct = 0
    for batch_size, n_id, adjs in train_loader:
        # `adjs` holds a list of `(edge_index, e_id, size)` tuples.
        adjs = [adj.to(device) for adj in adjs]
        optimizer.zero_grad()
        l1_emb, l2_emb, l3_emb = model(x[n_id], adjs)
        out = l3_emb.log_softmax(dim=-1)
        loss = F.nll_loss(out, y[n_id[:batch_size]])
        loss.backward()
        optimizer.step()
        total_loss += float(loss)
        total_correct += int(out.argmax(dim=-1).eq(y[n_id[:batch_size]]).sum())
    loss = total_loss / len(train_loader)
    approx_acc = total_correct / train_idx.size(0)
    return loss, approx_acc
Epochs!!
optimizer = torch.optim.Adam(model.parameters(), lr=0.003)

for epoch in range(1, 21):
    loss, acc = train(epoch)
    print(f'Epoch {epoch:02d}, Loss: {loss:.4f}, Approx. Train: {acc:.4f}')
Saving the model for the inference part
We need to save the model before the inference part because, due to RAM limits, Google Colab cannot hold two graph loaders in memory at the same time. Therefore, we first train with train_loader and then make inferences on the test data using the saved model.
Here you can save the model either to your Google Drive or locally on your computer.
#torch.save(model, '/content/drive/MyDrive/model_weights/graph_embeddings/model.pt')
# saving model in mydrive
from google.colab import drive
drive.mount('/content/drive')
fp = '/content/drive/MyDrive/model.pt'
torch.save(model, './model.pt')
torch.save(model, fp)
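In the inference session, the saved model can then be restored with torch.load. Note that this saves the full pickled module, so the SAGE class definition must be available (defined or imported) before loading:

# Restore the trained model in a fresh session; map_location keeps this
# working whether or not a GPU is attached.
model = torch.load('/content/drive/MyDrive/model.pt', map_location=device)
model.eval()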
Inference: Let’s check GraphSage Inductive Power!!
This part makes use of the trained GraphSage model to compute node embeddings and perform node-category prediction on the test data. Afterwards, we compare the UMAP visualisations of the node embeddings at the 3 different layers of GraphSage and draw some interesting observations.
To get a better intuition of how the visualisations at each layer of GraphSage are computed, I encourage you to run the inference part in the Google Colab notebook I have prepared.
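For reference, here is a rough sketch of the inference flow from the notebook. The full-neighbourhood loader uses sizes=[-1] (take every 1-hop neighbour at each layer); the batch size, the 10k-node plotting sample, and the UMAP settings are illustrative choices, not requirements:

# Full-neighbourhood sampler over *all* nodes (node_idx=None); it is read
# globally by model.inference, so it must be defined before the call.
subgraph_loader = NeighborSampler(data.edge_index, node_idx=None,
                                  sizes=[-1], batch_size=4096,
                                  shuffle=False)

model.eval()
with torch.no_grad():
    l1_emb, l2_emb, l3_emb = model.inference(x)

# Project a random sample of the final-layer embeddings to 2D with UMAP.
sample = torch.randperm(l3_emb.size(0))[:10000]
proj = umap.UMAP(n_components=2).fit_transform(l3_emb[sample].cpu().numpy())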
GraphSage Layer-1 Node Embeddings Visualization

Observation
The node-embedding visualisation at layer 1 shows that the model is not yet able to separate the product categories well (the embeddings of different product categories lie very close together). Therefore we cannot predict or estimate with high probability which two products might be bought together in the future, or which other products someone who buys one product might also be interested in.
GraphSage Layer-2 Node Embeddings Visualization

Observation
At layer 2 we can see some separate clusters of product categories forming, and we can draw some valuable insights from them, e.g. Movies & TV sits near CDs & Vinyl, Beauty near Health & Personal Care, and Video Games near Toys & Games. The Books and Beauty clusters, however, are very far away from each other.
GraphSage Layer-3 Node Embeddings Visualization

Observation
At layer 3 the node representations are a little finer than at layer 2, as we can see some more distinct clusters, e.g. Cell Phones & Accessories vs Electronics.
Acknowledgements
I would like to thank the whole ML team of ArangoDB for providing me with valuable feedback about the blog.
Want to connect with me: LinkedIn
References (more learning material)
- fastgraphml: A low-code framework to accelerate the Graph ML model development process
- https://www.arangodb.com/2021/08/a-comprehensive-case-study-of-graphsage-using-pytorchgeometric/ (the original blog post)
- Inductive Representation Learning on Large Graphs
- http://web.stanford.edu/class/cs224w/slides/17-scalable.pdf
- A Voyage through Graph Machine Learning Universe: Motivation, Applications, Datasets, Graph ML Libraries, Graph Databases
- https://medium.com/pinterest-engineering/pinsage-a-new-graph-convolutional-neural-network-for-web-scale-recommender-systems-88795a107f48
- https://eng.uber.com/uber-eats-graph-learning/
- More stuff related to Graph ML