What do UberEats and Pinterest have in common?
They both use **GraphSAGE** to power their recommender systems on a massive scale: millions and billions of nodes and edges.
- 🖼️ Pinterest developed its own version called PinSAGE to recommend the most relevant images (pins) to its users. → Their graph has 18 billion connections and 3 billion nodes.
- 🍽️ UberEats also reported using a modified version of GraphSAGE to suggest dishes, restaurants, and cuisines. → UberEats claims to support more than 600,000 restaurants and 66 million users.
In this tutorial, we’ll use a dataset with 20k nodes instead of billions, because Google Colab cannot handle our ambitions. We will stick to the original GraphSAGE architecture, but the variants mentioned above also bring exciting features we will discuss.
You can run the code with the following Google Colab notebook.
🌐 I. PubMed dataset

In this article, we will use the PubMed dataset. As we saw in the previous article, PubMed is part of the Planetoid dataset (MIT license). Here’s a quick summary:
- It contains 19,717 scientific publications about diabetes from PubMed’s database;
- Node features are TF-IDF weighted word vectors with 500 dimensions, which is an efficient way of summarizing documents without transformers;
- The task is a multi-class classification with three categories: diabetes mellitus experimental, diabetes mellitus type 1, and diabetes mellitus type 2.
This is the beauty and the curse of deep learning: I don’t know anything about diabetes, but I’ll still feel pretty satisfied if we reach 70% accuracy. At least we’re not building the next IBM Watson.
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='.', name="Pubmed")
data = dataset[0]
# Print information about the dataset
print(f'Dataset: {dataset}')
print('-------------------')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of nodes: {data.x.shape[0]}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')
# Print information about the graph
print(f'\nGraph:')
print('------')
print(f'Training nodes: {sum(data.train_mask).item()}')
print(f'Evaluation nodes: {sum(data.val_mask).item()}')
print(f'Test nodes: {sum(data.test_mask).item()}')
print(f'Edges are directed: {data.is_directed()}')
print(f'Graph has isolated nodes: {data.has_isolated_nodes()}')
print(f'Graph has loops: {data.has_self_loops()}')
Dataset: Pubmed()
-------------------
Number of graphs: 1
Number of nodes: 19717
Number of features: 500
Number of classes: 3
Graph:
------
Training nodes: 60
Evaluation nodes: 500
Test nodes: 1000
Edges are directed: False
Graph has isolated nodes: False
Graph has loops: False
As we can see, PubMed has an insanely low number of training nodes compared to the whole graph. There are only 60 samples to learn how to classify the 1000 test nodes.
Despite this challenge, GNNs manage to obtain high levels of accuracy. Here’s the leaderboard of known techniques (a more exhaustive benchmark can be found on PapersWithCode):

I couldn’t find any result for GraphSAGE on PubMed with this specific setting (60 training nodes, 1000 test nodes), so I don’t expect a great accuracy. But another metric can be just as relevant when working with large graphs: training time.
🧙‍♂️ II. GraphSAGE in theory


The GraphSAGE algorithm can be divided into two steps:
- Neighbor sampling;
- Aggregation.
🎰 A. Neighbor sampling
Mini-batching is a common technique used in Machine Learning.
It works by breaking down a dataset into smaller batches, which allows us to train models more effectively. Mini-batching has several benefits:
- Improved accuracy – mini-batches help to reduce overfitting (gradients are averaged), as well as variance in error rates;
- Increased speed – mini-batches are processed in parallel and take less time to train than larger batches;
- Improved scalability – an entire dataset can exceed the GPU memory, but smaller batches can get around this limitation.
Mini-batching is so useful it became standard in regular neural networks. However, it is not as straightforward with graph data, since splitting the dataset into smaller chunks would break essential connections between nodes.
So, what can we do? In recent years, researchers developed different strategies to create graph mini-batches. The one we’re interested in is called neighbor sampling. There are many other techniques you can find on PyG’s documentation, such as subgraph clustering.
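For illustration, here is what subgraph clustering could look like with PyG’s ClusterData and ClusterLoader objects (a minimal sketch; check PyG’s documentation for the exact parameters, and note that this technique is not used in the rest of this tutorial):

from torch_geometric.loader import ClusterData, ClusterLoader

# Partition the graph into 128 clusters, then group clusters into mini-batches
cluster_data = ClusterData(data, num_parts=128)
cluster_loader = ClusterLoader(cluster_data, batch_size=32, shuffle=True)

# Each batch is a subgraph built from a group of clusters
for batch in cluster_loader:
    print(batch)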

Neighbor sampling considers only a fixed number of random neighbors. Here’s the process:
- We define the number of neighbors we want at 1 hop, at 2 hops (neighbors of neighbors), and so on;
- The sampler looks at the list of neighbors, of neighbors of neighbors, etc. of a target node and randomly selects a predefined number of them;
- The sampler outputs a subgraph containing the target node and the randomly selected neighboring nodes.
This process is repeated for every node in a list or for the entire graph. However, creating a subgraph for each node is not efficient, which is why we process them in batches instead. In this case, each subgraph is shared by multiple target nodes.
Neighbor sampling has an added benefit. Sometimes, we observe extremely popular nodes that act like hubs, such as celebrities on social media. Obtaining the hidden vectors of these nodes can be computationally very expensive since it requires calculating the hidden vectors of thousands or even millions of neighbors. GraphSAGE fixes this issue by simply ignoring most of the nodes!
In PyG, neighbor sampling is implemented through the [NeighborLoader](https://pytorch-geometric.readthedocs.io/en/latest/modules/loader.html#torch_geometric.loader.NeighborLoader) object. Let’s say we want 5 neighbors and 10 of their neighbors (num_neighbors). As we discussed, we can also specify a batch_size to speed up the process by creating subgraphs for multiple target nodes.
Python">from torch_geometric.loader import NeighborLoader
from torch_geometric.utils import to_networkx
# Create batches with neighbor sampling
train_loader = NeighborLoader(
data,
num_neighbors=[5, 10],
batch_size=16,
input_nodes=data.train_mask,
)
# Print each subgraph
for i, subgraph in enumerate(train_loader):
print(f'Subgraph {i}: {subgraph}')
# Plot each subgraph
fig = plt.figure(figsize=(16,16))
for idx, (subdata, pos) in enumerate(zip(train_loader, [221, 222, 223, 224])):
G = to_networkx(subdata, to_undirected=True)
ax = fig.add_subplot(pos)
ax.set_title(f'Subgraph {idx}')
plt.axis('off')
nx.draw_networkx(G,
pos=nx.spring_layout(G, seed=0),
with_labels=True,
node_size=200,
node_color=subdata.y,
cmap="cool",
font_size=10
)
plt.show()
Subgraph 0: Data(x=[389, 500], edge_index=[2, 448], batch_size=16)
Subgraph 1: Data(x=[264, 500], edge_index=[2, 314], batch_size=16)
Subgraph 2: Data(x=[283, 500], edge_index=[2, 330], batch_size=16)
Subgraph 3: Data(x=[189, 500], edge_index=[2, 229], batch_size=12)

We created 4 subgraphs of various sizes. They can be processed in parallel, and they’re easier to fit on a GPU since they’re smaller.
The number of neighbors is an important parameter since pruning our graph removes a lot of information. How much, exactly? Well, quite a lot. We can visualize this effect by looking at the node degrees (number of neighbors).
from torch_geometric.utils import degree
from collections import Counter

def plot_degree(data):
    # Get the degree of each node
    degrees = degree(data.edge_index[0]).numpy()
    # Count the number of nodes for each degree
    numbers = Counter(degrees)
    # Bar plot
    fig, ax = plt.subplots(figsize=(18, 6))
    ax.set_xlabel('Node degree')
    ax.set_ylabel('Number of nodes')
    plt.bar(numbers.keys(),
            numbers.values(),
            color='#0A047A')

# Plot node degrees from the original graph
plot_degree(data)

# Plot node degrees from the last subgraph
plot_degree(subdata)


In this example, the maximum node degree of our subgraphs is 5, which is much lower than the original max value. It’s important to remember this tradeoff when talking about GraphSAGE.
PinSAGE implements another sampling solution using **random walks**. It has two main objectives:
- Sample a fixed number of neighbors (like GraphSAGE);
- Obtain their relative importance (important nodes are seen more frequently than others).
This strategy feels a bit like a fast attention mechanism. It assigns weights to nodes and increases the relevance of the most popular ones.
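PinSAGE’s actual sampler is not reproduced here, but the idea can be sketched in a few lines (a rough, hypothetical helper using networkx and uniform random walks; random_walk_neighbors is not part of any library):

import random
from collections import Counter
import networkx as nx

def random_walk_neighbors(G, target, num_walks=50, walk_length=3, top_k=5):
    """Rough sketch of random-walk-based neighbor sampling (PinSAGE-style)."""
    visits = Counter()
    for _ in range(num_walks):
        node = target
        for _ in range(walk_length):
            neighbors = list(G.neighbors(node))
            if not neighbors:
                break
            node = random.choice(neighbors)
            visits[node] += 1
    visits.pop(target, None)  # don't count the target node itself
    # Visit counts act as a proxy for the relative importance of each neighbor
    return visits.most_common(top_k)

The normalized visit counts can then be used as weights during aggregation, which is what gives important (frequently visited) nodes more influence.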
💥 B. Aggregation
The aggregation process determines how to combine the feature vectors to produce the node embeddings. The original paper presents three ways of aggregating features:
- Mean aggregator;
- LSTM aggregator;
- Pooling aggregator.

The mean aggregator is the simplest one. The idea is close to a GCN approach:
- The hidden features of the target node and its selected neighbors are averaged (Ñᵢ);
- A linear transformation with a weight matrix 𝐖 is applied.

The result can then be fed to a non-linear activation function like ReLU.
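Here is a minimal sketch of that computation for a single target node (hypothetical tensors, not SAGEConv’s internals):

import torch

def mean_aggregate(h_target, h_neighbors, W):
    """Mean aggregator sketch: average, linear transformation, non-linearity.
    h_target:    [dim_in]                 features of the target node
    h_neighbors: [num_neighbors, dim_in]  features of the sampled neighbors
    W:           [dim_in, dim_out]        weight matrix
    """
    # Average the target node with its sampled neighbors
    h_mean = torch.cat([h_target.unsqueeze(0), h_neighbors]).mean(dim=0)
    # Linear transformation followed by a non-linearity
    return torch.relu(h_mean @ W)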
The LSTM aggregator can seem like a weird idea because this architecture is sequential: it assigns an order to our unordered nodes. This is why the authors randomly shuffle them to force the LSTM to only consider the hidden features. It is the best performing technique in their benchmarks.
The pooling aggregator feeds each neighbor’s hidden vector to a feedforward neural network. A max-pooling operation is applied to the result.
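A sketch of the pooling aggregator, under the same hypothetical setup (an element-wise maximum over independently transformed neighbor vectors):

import torch

def pooling_aggregate(h_neighbors, mlp):
    """Pooling aggregator sketch.
    h_neighbors: [num_neighbors, dim_in] features of the sampled neighbors
    mlp:         a feedforward network, e.g. torch.nn.Sequential(Linear, ReLU)
    """
    # Transform each neighbor independently, then keep the element-wise maximum
    return mlp(h_neighbors).max(dim=0).values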
🧠 III. GraphSAGE in PyTorch Geometric
We can easily implement a GraphSAGE architecture in PyTorch Geometric with the SAGEConv layer. This implementation uses two weight matrices instead of one, like UberEats’ version of GraphSAGE:

Let’s create a network with two SAGEConv layers:
- The first one will use ReLU as the activation function and a dropout layer;
- The second one will directly output the node embeddings.
As we’re dealing with a multi-class classification task, we’ll use the cross-entropy loss as our loss function. I also added an L2 regularization of 0.0005 for good measure.
To see the benefits of GraphSAGE, let’s compare it with a GCN and a GAT without any sampling.
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv

class GraphSAGE(torch.nn.Module):
    """GraphSAGE"""
    def __init__(self, dim_in, dim_h, dim_out):
        super().__init__()
        self.sage1 = SAGEConv(dim_in, dim_h)
        self.sage2 = SAGEConv(dim_h, dim_out)
        self.optimizer = torch.optim.Adam(self.parameters(),
                                          lr=0.01,
                                          weight_decay=5e-4)

    def forward(self, x, edge_index):
        h = self.sage1(x, edge_index)
        h = torch.relu(h)
        h = F.dropout(h, p=0.5, training=self.training)
        h = self.sage2(h, edge_index)
        return h, F.log_softmax(h, dim=1)

    def fit(self, data, epochs):
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = self.optimizer

        self.train()
        for epoch in range(epochs+1):
            acc = 0
            val_loss = 0
            val_acc = 0

            # Train on batches
            for batch in train_loader:
                optimizer.zero_grad()
                _, out = self(batch.x, batch.edge_index)
                loss = criterion(out[batch.train_mask], batch.y[batch.train_mask])
                acc += accuracy(out[batch.train_mask].argmax(dim=1),
                                batch.y[batch.train_mask])
                loss.backward()
                optimizer.step()

                # Validation
                val_loss += criterion(out[batch.val_mask], batch.y[batch.val_mask])
                val_acc += accuracy(out[batch.val_mask].argmax(dim=1),
                                    batch.y[batch.val_mask])

            # Print metrics every 10 epochs
            if epoch % 10 == 0:
                print(f'Epoch {epoch:>3} | Train Loss: {loss/len(train_loader):.3f} '
                      f'| Train Acc: {acc/len(train_loader)*100:>6.2f}% | Val Loss: '
                      f'{val_loss/len(train_loader):.2f} | Val Acc: '
                      f'{val_acc/len(train_loader)*100:.2f}%')
With GraphSAGE, we loop through batches (our 4 subgraphs) created by the neighbor sampling process. The way we calculate the accuracy and the validation loss is also different because of that.
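The accuracy() helper called in fit() was introduced in a previous article of this series; a minimal version could look like this (a sketch, assuming plain exact-match accuracy):

def accuracy(pred_y, y):
    """Proportion of correct predictions."""
    return ((pred_y == y).sum() / len(y)).item()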
Here are the results (in terms of accuracy and training time) for the GCN, the GAT, and GraphSAGE:
GCN test accuracy: 78.40% (52.6 s)
GAT test accuracy: 77.10% (18min 7s)
GraphSAGE test accuracy: 77.20% (12.4 s)
The three models obtain similar results in terms of accuracy. We expect the GAT to perform better because its aggregation mechanism is more nuanced, but it’s not always the case.
The real difference is the training time: GraphSAGE is 88 times faster than the GAT and 4 times faster than the GCN in this example!
Here lies the true power of GraphSAGE. We do lose a lot of information by pruning our graph with neighbor sampling. The final node embeddings might not be as good as what we could find with a GCN or a GAT. But this is not the point: GraphSAGE is designed to improve scalability. In turn, it can lead to building larger graphs that can improve accuracy.

This work was done in a supervised training setting (node classification), but we could also train GraphSAGE in an unsupervised way.
In this case, we can’t use the cross-entropy loss. We have to engineer a loss function that forces nodes that are nearby in the original graph to remain close to each other in the embedding space. Conversely, the same function must ensure that distant nodes in the graph must have distant representations in the embedding space. This is the loss that is presented in GraphSAGE’s paper.
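The paper’s loss combines a positive term (nodes that co-occur on short random walks are pulled together) with a negative-sampling term (randomly sampled nodes are pushed apart). Here is a minimal sketch of this loss, assuming we already have embeddings for the targets, their positive pairs, and negative samples:

import torch
import torch.nn.functional as F

def unsupervised_sage_loss(z_u, z_pos, z_neg):
    """Sketch of the unsupervised GraphSAGE loss.
    z_u:   [batch, dim]     embeddings of the target nodes
    z_pos: [batch, dim]     embeddings of nodes co-occurring with the targets on random walks
    z_neg: [batch, Q, dim]  embeddings of Q negative samples per target
    """
    # Positive term: nearby nodes should have similar embeddings
    pos_score = (z_u * z_pos).sum(dim=-1)
    pos_loss = -F.logsigmoid(pos_score).mean()
    # Negative term: randomly sampled (distant) nodes should be pushed apart
    neg_score = (z_u.unsqueeze(1) * z_neg).sum(dim=-1)
    neg_loss = -F.logsigmoid(-neg_score).mean()
    return pos_loss + neg_loss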
In the case of PinSAGE and UberEats’ modified GraphSAGE, we’re dealing with recommender systems.
The goal is to correctly rank the most relevant items (pins, restaurants) for each user, which is very different. We don’t only want to know what the closest embeddings are, we have to produce the best rankings possible. This is why these systems are also trained in an unsupervised way, but with another loss function: a max-margin ranking loss.
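As a rough sketch (a hypothetical helper, using dot products as the relevance score), a max-margin ranking loss only asks that a relevant item scores at least a margin higher than a negative one:

import torch

def max_margin_ranking_loss(z_query, z_pos, z_neg, margin=0.1):
    """Sketch of a PinSAGE-style max-margin ranking loss.
    z_query: [batch, dim] query (user) embeddings
    z_pos:   [batch, dim] embeddings of relevant items
    z_neg:   [batch, dim] embeddings of negative (irrelevant) items
    """
    pos_score = (z_query * z_pos).sum(dim=-1)
    neg_score = (z_query * z_neg).sum(dim=-1)
    # Penalize negatives that score within `margin` of the positive item
    return torch.clamp(neg_score - pos_score + margin, min=0).mean()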
Conclusion
GraphSAGE is an incredibly fast architecture to process large graphs. It might not be as accurate as a GCN or a GAT, but it is an essential model for handling massive amounts of data. It delivers this speed thanks to a clever combination of 1/ neighbor sampling to prune the graph and 2/ fast aggregation with a mean aggregator in this example. In this article,
- We explored a new dataset with PubMed, which is several times larger than the previous one;
- We explained the idea behind neighbor sampling, which only considers a predefined number of random neighbors at each hop;
- We saw the three aggregators presented in GraphSAGE’s paper and focused on the mean aggregator;
- We benchmarked **three models** (GraphSAGE, GAT, and GCN) in terms of accuracy and training time.
We saw three architectures with the same end application: node classification. But GNNs have been successfully applied to other tasks. In the next tutorials, I’d like to use them in two different contexts: graph and edge prediction. This will be a good way to discover new datasets and applications where GNNs dominate the state of the art.
If you enjoyed this article, let’s connect on Twitter @maximelabonne for more graph learning content. Thanks for your attention! 📣
Next article
Chapter 4: How to Design the Most Powerful Graph Neural Network