
Tutorial
A special thanks to Alvaro Sanchez Gonzalez from DeepMind and to Bryan Perozzi and Sami Abu-el-haija from Google, who assisted me with this tutorial.
Updated 04/22/2023 with minor fixes and the addition of the Graph Nets approach.
Graph data is everywhere. Graph research is in its infancy, and tools for modeling graph data are only starting to emerge. This makes it the perfect time to jump in if you are a data scientist looking to distinguish yourself. Unfortunately, being on the cutting edge can be difficult due to a lack of tutorials and support. This guide aims to significantly reduce that pain point.
Why TensorFlow-GNN?
TF-GNN was recently released by Google for building Graph Neural Networks in TensorFlow. While there are other GNN libraries out there, TF-GNN’s modeling flexibility, its performance on large-scale graphs thanks to distributed training, and its Google backing mean it will likely emerge as an industry standard. This guide assumes you already understand the merits of this library, but please see this paper for more information and performance comparisons. Also, check out the documentation for TF-GNN. If you are new to GNNs altogether, check out this guide for a conceptual understanding.
What is the downside?
With this library currently in an alpha stage, the code is very strict about the structures, input shapes, and formats it requires. This makes it very difficult to navigate without a guide. Unfortunately, there is not much information out there on using TF-GNN. The guides I could find all focus on the same use case: context-level prediction on a pre-built TensorFlow dataset. As of this writing, there is not a single walk-through for:
- Making edge or node predictions
- Starting with your own Pandas or NetworkX datasets
- Creating holdout datasets
- Model tuning
- Troubleshooting errors you may run into
After a good month of rereading documentation, trial-and-error coding, and some direct help from the TensorFlow developers at Google/DeepMind, I decided to put this guide together.
"Many [hours] died to bring us this information."
What this guide will cover:
First, we will start very simply to get the building blocks down. Then we will move to a more advanced example – college football conference predictions. Here is the outline of what will be covered:
- TF-GNN elements
- Building Blocks
- Graph tensor from Pandas
- Data setup
- Graph tensor from NetworkX
- Feature engineering
- Creating test splits
- Creating a graph TensorFlow dataset
- Building the model
- Node model
- Edge model
- Context model
- Troubleshooting errors
- Parameter tuning
TF-GNN elements
A graph consists of nodes and edges. Here is an example of a simple graph showing people (nodes) who recently had contact with each other (edges):

This same graph could also be represented as node and edge tables. We can also add features to these nodes and edges. For example, we can add ‘age’ as a node feature and an ‘is-friend’ indicator as an edge feature.

When we add edges to TF-GNN, we need to index by number rather than name. We can do that like so:
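import pandas as pd
# Move 'Name' out of the index (if it is there) so rows are numbered 0..n-1
node_df = node_df.reset_index()
# Map each name to its row number
merge_df = node_df.reset_index().set_index('Name').rename(
    columns={'index':'Name1_idx'})
# Look up the integer index of the first person in each edge...
edge_df = pd.merge(edge_df,merge_df['Name1_idx'],
                   how='left',left_on='Name1',right_index=True)
# ...and of the second person
merge_df = merge_df.rename(columns={'Name1_idx':'Name2_idx'})
edge_df = pd.merge(edge_df,merge_df['Name2_idx'],
                   how='left',left_on='Name2',right_index=True)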
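The pandas merges above can be hard to follow, so here is a minimal pure-Python sketch of the same idea. The names and contacts are hypothetical toy data: every name gets its 0-based row position, and each edge endpoint is replaced by that position, which is the integer form ‘Adjacency.from_indices’ expects.

```python
# Hypothetical toy data standing in for the Name column of node_df and the
# Name1/Name2 columns of edge_df
people = ["Sam", "Amy", "Joe"]              # row order defines each integer index
contacts = [("Sam", "Amy"), ("Amy", "Joe")]

# Map each name to its 0-based row position in the node table
name_to_idx = {name: i for i, name in enumerate(people)}

# Replace both endpoints of every edge with integer positions
contact_idx = [(name_to_idx[a], name_to_idx[b]) for a, b in contacts]
print(contact_idx)  # [(0, 1), (1, 2)]
```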

Finally, we might have a context value for the graph. For example, maybe this friend group scored an average of 84% on a certain test. That will not mean much for this single-graph example. If we had other friend graphs, we could perhaps predict scores for new friend groups based on learned group dynamics.
Graph tensor from pandas
With these elements, we can now build the foundation for our GNN: a graph tensor.
import numpy as np
import tensorflow_gnn as tfgnn
graph_tensor = tfgnn.GraphTensor.from_pieces(
    node_sets = {
        "People": tfgnn.NodeSet.from_fields(
            sizes = [len(node_df)],
            features ={
                'Age': np.array(node_df['Age'],
                                dtype='int32').reshape(len(node_df),1)})},
    edge_sets ={
        "Contact": tfgnn.EdgeSet.from_fields(
            sizes = [len(edge_df)],
            features = {
                'Is-friend': np.array(edge_df['Is-friend'],
                                      dtype='int32').reshape(len(edge_df),1)},
            adjacency = tfgnn.Adjacency.from_indices(
                source = ("People", np.array(edge_df['Name1_idx'], dtype='int32')),
                target = ("People", np.array(edge_df['Name2_idx'], dtype='int32'))))
    })
Notice how the features we created fit into the nodes and edges. The indented structure makes it simple to add additional nodes, edges, and features. For example, we could easily add nodes and edges for the movies each friend has watched and include a graph context value this time.
graph_tensor = tfgnn.GraphTensor.from_pieces(
    context = tfgnn.Context.from_fields(
        features ={
            "score": np.array([[0.84]], dtype='float32')
        }),
    node_sets = {
        "People": tfgnn.NodeSet.from_fields(
            sizes = [len(node_df)],
            features ={
                'Age': np.array(node_df['Age'],
                                dtype='int32').reshape(len(node_df),1)}),
        "Movies": tfgnn.NodeSet.from_fields(
            sizes = [len(movie_df)],
            features ={
                # dtype=str gives a NumPy unicode array, stored as tf.string
                'Name': np.array(movie_df['Name'],
                                 dtype=str).reshape(len(movie_df),1),
                'Length': np.array(movie_df['Length'],
                                   dtype='float32').reshape(len(movie_df),1)})},
    edge_sets ={
        "Contact": tfgnn.EdgeSet.from_fields(
            sizes = [len(edge_df)],
            features = {
                'Is-friend': np.array(edge_df['Is-friend'],
                                      dtype='int32').reshape(len(edge_df),1)},
            adjacency = tfgnn.Adjacency.from_indices(
                source = ("People", np.array(edge_df['Name1_idx'], dtype='int32')),
                target = ("People", np.array(edge_df['Name2_idx'], dtype='int32')))),
        'Watched': tfgnn.EdgeSet.from_fields(
            sizes = [len(watched_df)],
            features = {},
            adjacency = tfgnn.Adjacency.from_indices(
                source = ("People", np.array(watched_df['Name_idx'], dtype='int32')),
                target = ("Movies", np.array(watched_df['Movie_idx'], dtype='int32'))))
    })
Note: Be very careful with your dtypes and shapes. Any deviations will cause errors or training issues. The only supported dtypes are ‘int32’, ‘float32’, and ‘string’. If you are having issues, please see the troubleshooting section towards the end of this article.
You may have noticed that the graph tensor is directional with a source and target. This might be fine for Sam watching a movie, but communication is bidirectional. When Sam talks to Amy, Amy is also talking to Sam. For bidirectional data, you will want to duplicate those edges (with source and target reversed) to indicate both directions of data flow.
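Duplicating edges for bidirectional data can be sketched without pandas. In this minimal sketch, edges are hypothetical (source, target, features) tuples over integer node indices; each reversed copy keeps its features, so an ‘Is-friend’ value applies in both directions.

```python
def make_bidirectional(edges):
    """Append a reversed copy (target -> source) of every (source, target, features) edge.

    Features are copied unchanged so each reversed edge keeps its attributes.
    """
    reversed_edges = [(tgt, src, feats) for src, tgt, feats in edges]
    return edges + reversed_edges

# Hypothetical directed "Contact" edges with an 'Is-friend' feature
contacts = [(0, 1, {"Is-friend": 1}), (1, 2, {"Is-friend": 0})]
both_ways = make_bidirectional(contacts)
print(both_ways)
```

This uses the same "originals first, reversed copies after" ordering as the pandas ‘bidirectional’ helper in the test-split section.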

With this foundation, we are now ready to transition to making predictions on a real dataset.
Data setup
The training data is a network of American football games between Division IA colleges during regular season Fall 2000, as compiled by M. Girvan and M. Newman. Node data includes college names and an index of the conference they belong to (e.g. conference 8 = Pac 10). Edges include the two college names, indicating a game was played between them. The data can be pulled as follows (see Google Colab to follow along):
import urllib.request
import io
import zipfile
import networkx as nx
url = "http://www-personal.umich.edu/~mejn/netdata/football.zip"
sock = urllib.request.urlopen(url) # open URL
s = io.BytesIO(sock.read()) # read into BytesIO "file"
sock.close()
zf = zipfile.ZipFile(s) # zipfile object
txt = zf.read("football.txt").decode() # read info file
gml = zf.read("football.gml").decode() # read gml data
# throw away bogus first line with # from mejn files
gml = gml.split("n")[1:]
G = nx.parse_gml(gml) # parse gml data
print(txt)
Graph tensor from NetworkX
Our data is now in a NetworkX graph. Let’s see how it looks with nodes colored by which conference they belong to.
cmap = {0:'#bd2309', 1:'#bbb12d',2:'#1480fa',3:'#14fa2f',4:'#faf214',
5:'#2edfea',6:'#ea2ec4',7:'#ea2e40',8:'#577a4d',9:'#2e46c0',
10:'#f59422',11:'#8086d9'}
colors = [cmap[G.nodes[n]['value']] for n in G.nodes()]
pos = nx.spring_layout(G, seed=1987)
nx.draw_networkx_edges(G, pos, alpha=0.2)
nx.draw_networkx_nodes(G, pos, nodelist=G.nodes(),
node_color=colors, node_size=100)

For our node model, we will attempt to predict the conference a school belongs to. For our edge model, we will attempt to predict whether a game was an in-conference game. Both predictions will be evaluated on a holdout dataset. How do we do this from NetworkX? It is possible to build a graph tensor directly from a graph, using these functions to extract the data:
node_data = G.nodes(data=True)
edge_data = G.edges(data=True)
The problem is, we still want to do some feature engineering and we do not yet have our holdout dataset. For these reasons, I highly recommend taking the approach of converting your graph data to Pandas. Later, we can plug our data into a graph tensor using the method shown in our first example.
node_df = pd.DataFrame.from_dict(dict(G.nodes(data=True)), orient='index')
node_df.index.name = 'school'
node_df.columns = ['conference']
edge_df = nx.to_pandas_edgelist(G)

Feature engineering
Using the base graph, a model might be able to determine if two colleges are in the same conference based on the network. But how would it know which conference specifically? How could it learn the differences between conferences without any node or edge data? For this task, we will need to add more features.
What kind of features should we gather? I am no expert in college football, but I would imagine conferences are put together based on proximity and rank. This guide is focused on TF-GNN so I will add these new features using magic, but you can find the specific code in the linked Google Colab.
For nodes, we will add latitude/longitude and the previous year’s (1999) rank, wins, and conference wins. We will also convert the conference column into 12 dummy variable columns for a softmax prediction.
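The dummy-variable step for the 12 conferences is a plain one-hot encoding. The exact Colab code is not shown here; in pandas, one possible route (an assumption on my part) is ‘pd.get_dummies(node_df['conference'], prefix='conf')’, which yields columns ‘conf_0’ through ‘conf_11’. This minimal pure-Python sketch shows what each row of those columns holds:

```python
NUM_CONFERENCES = 12  # conference indices 0..11 in the football data

def one_hot(conference, num_classes=NUM_CONFERENCES):
    """Return a 0/1 list with a single 1 at position `conference`."""
    if not 0 <= conference < num_classes:
        raise ValueError(f"conference index {conference} is out of range")
    row = [0] * num_classes
    row[conference] = 1
    return row

print(one_hot(8))  # conference 8 (Pac 10): [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0]
```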

For edges, we will calculate the distance between schools, add a name similarity score (maybe schools with the same state in their name are more likely to be in the same conference), and add a target value indicating whether the game was a within-conference game.
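The Colab’s feature code is not reproduced here, so the following is only a minimal sketch of the two edge input features under stated assumptions: ‘euclidean_dist’ is taken as the straight-line distance between two schools’ latitude/longitude pairs (as the feature name suggests), and ‘name_sim_score’ is a hypothetical stand-in built on Python’s ‘difflib.SequenceMatcher’. The coordinates are illustrative, not real school locations.

```python
import math
from difflib import SequenceMatcher

def euclidean_dist(lat1, lon1, lat2, lon2):
    """Straight-line distance in degree space (not a great-circle distance)."""
    return math.hypot(lat1 - lat2, lon1 - lon2)

def name_sim_score(name_a, name_b):
    """Hypothetical name similarity: difflib's ratio on lowercased names, in [0, 1]."""
    return SequenceMatcher(None, name_a.lower(), name_b.lower()).ratio()

# Illustrative coordinates only, not real school locations
print(euclidean_dist(40.0, -111.0, 43.0, -115.0))  # 5.0
print(round(name_sim_score("Utah", "UtahState"), 3))
```

Euclidean distance in degrees distorts east-west distances away from the equator; a great-circle (haversine) distance would be the more geographically faithful alternative.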

Let’s visualize our data with our new information (orange edges indicate a conference game). It definitely appears that geography at least plays a role in conference selection.

Creating test splits
Creating a training split is straightforward; exclude the holdout nodes and edges the same way you normally would. Holdout data, however, is a little different from your typical Machine Learning application. Because the overall connections are important for an accurate prediction, the final prediction will need to be on the entire graph. Once a prediction is made, the results can be filtered down to the holdout data for the final evaluation. I will show this process in more detail at the prediction stage; here is how I create the splits for now:
from sklearn.model_selection import train_test_split
node_train, node_test = train_test_split(node_df,test_size=0.15,random_state=42)
edge_train = edge_df.loc[~((edge_df['source'].isin(node_test.index)) | (edge_df['target'].isin(node_test.index)))]
edge_test = edge_df.loc[(edge_df['source'].isin(node_test.index)) | (edge_df['target'].isin(node_test.index))]
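Here is a minimal pure-Python sketch of the same split logic on a hypothetical 10-node chain graph. It follows the rule used above (an edge is a training edge only if neither endpoint is held out) and rounds the test size up as ‘train_test_split’ does, but it samples with Python’s ‘random’ module, so the held-out nodes will not match scikit-learn’s.

```python
import math
import random

def split_nodes_and_edges(nodes, edges, test_size=0.15, seed=42):
    """Hold out a fraction of the nodes; any edge touching a held-out node is a test edge."""
    rng = random.Random(seed)
    shuffled = list(nodes)
    rng.shuffle(shuffled)
    test_nodes = set(shuffled[:math.ceil(len(shuffled) * test_size)])
    train_nodes = [n for n in nodes if n not in test_nodes]
    # An edge is a training edge only if neither endpoint is held out
    edge_train = [(s, t) for s, t in edges if s not in test_nodes and t not in test_nodes]
    edge_test = [(s, t) for s, t in edges if s in test_nodes or t in test_nodes]
    return train_nodes, test_nodes, edge_train, edge_test

# Hypothetical 10-node chain: A-B, B-C, ..., I-J
nodes = list("ABCDEFGHIJ")
edges = list(zip(nodes, nodes[1:]))
train_nodes, test_nodes, edge_train, edge_test = split_nodes_and_edges(nodes, edges)
print(sorted(test_nodes), len(edge_train), len(edge_test))
```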
With our new splits, we can now make our bidirectional adjustments and add the edge index columns.
def bidirectional(edge_df):
    reverse_df = edge_df.rename(columns={'source':'target','target':'source'})
    reverse_df = reverse_df[edge_df.columns]
    reverse_df = pd.concat([edge_df, reverse_df], ignore_index=True, axis=0)
    return reverse_df
def create_adj_id(node_df,edge_df):
    node_df = node_df.reset_index().reset_index()
    edge_df = pd.merge(edge_df,node_df[['school','index']].rename(columns={"index":"source_id"}),
                       how='left',left_on='source',right_on='school').drop(columns=['school'])
    edge_df = pd.merge(edge_df,node_df[['school','index']].rename(columns={"index":"target_id"}),
                       how='left',left_on='target',right_on='school').drop(columns=['school'])
    edge_df.dropna(inplace=True)
    return node_df, edge_df
edge_full_adj = bidirectional(edge_df)
edge_train_adj = bidirectional(edge_train)
node_full_adj,edge_full_adj = create_adj_id(node_df,edge_full_adj)
node_train_adj,edge_train_adj = create_adj_id(node_train,edge_train_adj)
Creating a TensorFlow dataset
We are now ready to create our graph tensors which we will transform into TensorFlow datasets.
def create_graph_tensor(node_df,edge_df):
    graph_tensor = tfgnn.GraphTensor.from_pieces(
        node_sets = {
            "schools": tfgnn.NodeSet.from_fields(
                sizes = [len(node_df)],
                features ={
                    'Latitude': np.array(node_df['Latitude'], dtype='float32').reshape(len(node_df),1),
                    'Longitude': np.array(node_df['Longitude'], dtype='float32').reshape(len(node_df),1),
                    'Rank': np.array(node_df['Rank'], dtype='int32').reshape(len(node_df),1),
                    'Wins': np.array(node_df['Wins'], dtype='int32').reshape(len(node_df),1),
                    'Conf_wins': np.array(node_df['Conf_wins'], dtype='int32').reshape(len(node_df),1),
                    'conference': np.array(node_df.iloc[:,-12:], dtype='int32'),
                }),
        },
        edge_sets ={
            "games": tfgnn.EdgeSet.from_fields(
                sizes = [len(edge_df)],
                features = {
                    'name_sim_score': np.array(edge_df['name_sim_score'], dtype='float32').reshape(len(edge_df),1),
                    'euclidean_dist': np.array(edge_df['euclidean_dist'], dtype='float32').reshape(len(edge_df),1),
                    'conference_game': np.array(edge_df['conference_game'], dtype='int32').reshape(len(edge_df),1)
                },
                adjacency = tfgnn.Adjacency.from_indices(
                    source = ("schools", np.array(edge_df['source_id'], dtype='int32')),
                    target = ("schools", np.array(edge_df['target_id'], dtype='int32')),
                )),
        })
    return graph_tensor
full_tensor = create_graph_tensor(node_full_adj,edge_full_adj)
train_tensor = create_graph_tensor(node_train_adj,edge_train_adj)
Before creating the dataset, we need a function that will split our graph into our training data and the target we will be predicting (shown as label below). For our node prediction problem, we will make ‘conference’ our label. We also need to drop the ‘conference_game’ feature from the dataset since it would create a data leakage issue (i.e. cheating).
def node_batch_merge(graph):
    graph = graph.merge_batch_to_components()
    node_features = graph.node_sets['schools'].get_features_dict()
    edge_features = graph.edge_sets['games'].get_features_dict()
    label = node_features.pop('conference')
    _ = edge_features.pop('conference_game')
    new_graph = graph.replace_features(
        node_sets={'schools':node_features},
        edge_sets={'games':edge_features})
    return new_graph, label
We will do the reverse for our edge model: drop the ‘conference’ feature and split off ‘conference_game’ as our target (label).
def edge_batch_merge(graph):
    graph = graph.merge_batch_to_components()
    node_features = graph.node_sets['schools'].get_features_dict()
    edge_features = graph.edge_sets['games'].get_features_dict()
    _ = node_features.pop('conference')
    label = edge_features.pop('conference_game')
    new_graph = graph.replace_features(
        node_sets={'schools':node_features},
        edge_sets={'games':edge_features})
    return new_graph, label
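Both helpers follow one rule: pop the target you are predicting as the label, and make sure no target is left among the inputs. A minimal sketch of that rule on plain dictionaries (hypothetical stand-ins for the feature dicts returned by ‘get_features_dict()’, with made-up values) looks like this:

```python
# Both prediction targets in this tutorial; neither may stay among the inputs
TARGETS = {"conference", "conference_game"}

def split_features(node_features, edge_features, label_set, label_name):
    """Return (node inputs, edge inputs, label) with every target removed from the inputs.

    Plain dicts stand in for GraphTensor feature dicts; label_set is "nodes" or "edges".
    """
    node_inputs = {k: v for k, v in node_features.items() if k not in TARGETS}
    edge_inputs = {k: v for k, v in edge_features.items() if k not in TARGETS}
    source = node_features if label_set == "nodes" else edge_features
    return node_inputs, edge_inputs, source[label_name]

# Hypothetical feature values
nodes = {"Latitude": [40.0], "conference": [[0, 1]]}
edges = {"euclidean_dist": [5.0], "conference_game": [1]}
print(split_features(nodes, edges, "edges", "conference_game"))
```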
We can now create our dataset and map it through the function above.
import tensorflow as tf
def create_dataset(graph,function):
    dataset = tf.data.Dataset.from_tensors(graph)
    dataset = dataset.batch(32)
    return dataset.map(function)
#Node Datasets
full_node_dataset = create_dataset(full_tensor,node_batch_merge)
train_node_dataset = create_dataset(train_tensor,node_batch_merge)
#Edge Datasets
full_edge_dataset = create_dataset(full_tensor,edge_batch_merge)
train_edge_dataset = create_dataset(train_tensor,edge_batch_merge)
The order of these procedures is extremely important:
- We create our dataset from the graph tensor.
- We split our dataset into batches (read up on batch sizes).
- In the map function, we merge those batches back into one graph.
- We split/drop the features as needed.
The model will not train (or will not train correctly) if you do not follow this order precisely.
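To see what the merge step does, here is a conceptual pure-Python sketch (not TF-GNN’s implementation) of merging a batch of graphs into a single graph with one component per input graph. Node indices of later graphs are shifted so they stay valid in the combined node list. With a one-graph dataset batched by 32, the batch holds a single graph, so merging changes little beyond turning the batched GraphTensor back into the single (scalar) GraphTensor that, as I understand it, TF-GNN’s Keras layers expect.

```python
def merge_components(graphs):
    """Merge a batch of graphs into one graph whose components are the inputs.

    Each input graph is (num_nodes, edges) with 0-based node indices. Edges of
    later graphs are shifted by the number of nodes that precede them, so every
    index points into one concatenated node list; `sizes` keeps the boundaries.
    """
    sizes, merged_edges, offset = [], [], 0
    for num_nodes, edges in graphs:
        sizes.append(num_nodes)
        merged_edges.extend((s + offset, t + offset) for s, t in edges)
        offset += num_nodes
    return sizes, merged_edges

# Hypothetical batch of two graphs: a 3-node path and a 2-node pair
sizes, edges = merge_components([(3, [(0, 1), (1, 2)]), (2, [(0, 1)])])
print(sizes, edges)  # [3, 2] [(0, 1), (1, 2), (3, 4)]
```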
Building the model
We have our datasets; now for the fun part! First, we define the inputs using our dataset spec.
graph_spec = train_node_dataset.element_spec[0]
input_graph = tf.keras.layers.Input(type_spec=graph_spec)
Now we need to initialize our features. We will create functions for initializing the nodes and edges. Then we map our features through these functions. To keep things simple, I will create a dense layer for each feature.
def set_initial_node_state(node_set, node_set_name):
    features = [
        tf.keras.layers.Dense(32,activation="relu")(node_set['Latitude']),
        tf.keras.layers.Dense(32,activation="relu")(node_set['Longitude']),
        tf.keras.layers.Dense(32,activation="relu")(node_set['Rank']),
        tf.keras.layers.Dense(32,activation="relu")(node_set['Wins']),
        tf.keras.layers.Dense(32,activation="relu")(node_set['Conf_wins'])
    ]
    return tf.keras.layers.Concatenate()(features)
def set_initial_edge_state(edge_set, edge_set_name):
    features = [
        tf.keras.layers.Dense(32,activation="relu")(edge_set['name_sim_score']),
        tf.keras.layers.Dense(32,activation="relu")(edge_set['euclidean_dist'])
    ]
    return tf.keras.layers.Concatenate()(features)
graph = tfgnn.keras.layers.MapFeatures(
    node_sets_fn=set_initial_node_state,
    edge_sets_fn=set_initial_edge_state
)(input_graph)
This previous step leaves a lot of room for customization. For example, we could create word embeddings for string features. We could probably gain some accuracy by hashing a latitude/longitude grid rather than just using a dense layer. TensorFlow has many options available to us.
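The latitude/longitude hashing idea can be sketched as follows. This is an assumption about one way to do it, not code from this guide: snap each coordinate to a grid cell, then hash the cell into a fixed number of buckets. The cell size and bucket count are hypothetical parameters, and the coordinates are illustrative.

```python
import math
import zlib

def grid_cell_id(lat, lon, cell_deg=5.0, num_buckets=64):
    """Hash the grid cell containing (lat, lon) into an id in [0, num_buckets).

    cell_deg and num_buckets are hypothetical knobs to tune. zlib.crc32 is used
    because, unlike Python's built-in hash() for strings, it is stable across runs.
    Different cells can collide in the same bucket, as with any hashing trick.
    """
    row = math.floor(lat / cell_deg)
    col = math.floor(lon / cell_deg)
    return zlib.crc32(f"{row}:{col}".encode()) % num_buckets

# Two nearby points fall in the same 5-degree cell, so they share an id
print(grid_cell_id(40.1, -111.2), grid_cell_id(41.9, -113.0))
```

The resulting integer id could then be fed to an embedding layer (for example ‘tf.keras.layers.Embedding(num_buckets, dim)’) inside ‘set_initial_node_state’, in place of or alongside the dense layers on ‘Latitude’ and ‘Longitude’; that wiring is untested here.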
A few things to note:
- If you have multiple node sets or edge sets, you will need to add ‘if statements’ to apply features to the correct node/edge set.
- Nodes or edges without features can also be initialized with the ‘MakeEmptyFeature’ function.
- For a node-centric problem, initializing edges is optional (read more on node vs edge centric).
- The first node must have at least one feature. You may have to create an embedding on an index if you have no features (results will likely not be very good).
# Examples, do not use for this problem
def set_initial_node_state(node_set, node_set_name):
    if node_set_name == "node_1":
        return tf.keras.layers.Embedding(115,3)(node_set['id'])
    elif node_set_name == "node_2":
        return tfgnn.keras.layers.MakeEmptyFeature()(node_set)
graph = tfgnn.keras.layers.MapFeatures(
    node_sets_fn=set_initial_node_state)(input_graph)
Before we develop our update loop, we need one more helper function. As we add dense layers, we will want to make sure we are using L2 regularization and/or dropout (L1 would work as well).
def dense_layer(units=64,l2_reg=0.1,dropout=0.25,activation='relu'):
    # Dense layer with L2 penalties on kernel and bias, followed by dropout
    regularizer = tf.keras.regularizers.l2(l2_reg)
    return tf.keras.Sequential([
        tf.keras.layers.Dense(units,
                              activation=activation,
                              kernel_regularizer=regularizer,
                              bias_regularizer=regularizer),
        tf.keras.layers.Dropout(dropout)])
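For intuition on what the regularizer does: ‘tf.keras.regularizers.l2(l2_reg)’ adds ‘l2_reg * sum(w ** 2)’ over a weight tensor to the training loss, here for both the kernel and the bias. A tiny worked example with hypothetical weights:

```python
def l2_penalty(weights, l2_reg=0.1):
    """Term an L2 regularizer adds to the loss: l2_reg * sum(w ** 2)."""
    return l2_reg * sum(w * w for w in weights)

kernel = [0.5, -1.0, 2.0]   # hypothetical weights of one tiny layer
print(l2_penalty(kernel))    # 0.1 * (0.25 + 1.0 + 4.0), approximately 0.525
```

Because the penalty grows with the square of each weight, it pushes the layer toward smaller weights, and ‘l2_reg’ controls how hard it pushes.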
Node Model
There are several model architectures out there, but graph convolutional networks are by far the most common (see other approaches described here). Graph convolutions are similar to convolutions commonly used in computer vision problems. The main difference is that graph convolutions work on the irregular data you find with graph structures. Let’s jump into the actual code.
graph_updates = 3 # tunable parameter
for i in range(graph_updates):
    graph = tfgnn.keras.layers.GraphUpdate(
        node_sets = {
            'schools': tfgnn.keras.layers.NodeSetUpdate({
                'games': tfgnn.keras.layers.SimpleConv(
                    message_fn = dense_layer(32),
                    reduce_type="sum",
                    sender_edge_feature = tfgnn.HIDDEN_STATE,
                    receiver_tag=tfgnn.TARGET)},
                tfgnn.keras.layers.NextStateFromConcat(
                    dense_layer(64)))})(graph) #start here
logits = tf.keras.layers.Dense(12,activation='softmax')(graph.node_sets["schools"][tfgnn.HIDDEN_STATE])
node_model = tf.keras.Model(input_graph, logits)
The code above can seem a little confusing because of how TensorFlow stacking works. Remember that the (graph) labeled ‘#start here’ at the end of the ‘GraphUpdate’ function is really the input for the code that comes before it. At first, this (graph) equals the initialized features we mapped previously. The input gets fed into the ‘GraphUpdate’ function becoming the new (graph). With each ‘graph_updates’ loop, the previous ‘GraphUpdate’ becomes the input for the new ‘GraphUpdate’ along with a dense layer specified with the ‘NextStateFromConcat’ function. This diagram should help explain:

The ‘GraphUpdate’ function simply updates the specified states (node, edge, or context) and adds a next state layer. In this case, we are only updating the node states with ‘NodeSetUpdate’ but we will explore an edge-centric approach when we work on our edge model. With this node update, we are applying a convolutional layer along the edges, allowing for information to feed to the node from neighboring nodes and edges. The number of graph updates is a tunable parameter, with each update allowing for information to travel from further nodes. For example, the three updates specified in our case allow for information to travel from up to three nodes away. After our graph updates, the final node state becomes the input for our prediction head labeled ‘logits’. Because we are predicting 12 different conferences, we have a dense layer of 12 units with a softmax activation. Now we can compile the model.
node_model.compile(
tf.keras.optimizers.Adam(learning_rate=0.01),
loss = 'categorical_crossentropy',
metrics = ['categorical_accuracy']
)
node_model.summary()
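The claim that three graph updates let information travel up to three hops can be checked with a toy, linear version of the update. In this sketch (a deliberate simplification, not what TF-GNN computes), the dense ‘message_fn’ and ‘NextStateFromConcat’ layers are replaced by plain addition, and node states are single numbers on a hypothetical five-node path graph with edges stored in both directions:

```python
def message_pass(states, edges):
    """One round of sum aggregation: each target adds up its senders' states,
    then combines the sum with its own state (stand-in for NextStateFromConcat)."""
    incoming = [0.0] * len(states)
    for source, target in edges:
        incoming[target] += states[source]
    return [own + msg for own, msg in zip(states, incoming)]

# Path graph 0-1-2-3-4, with every edge stored in both directions
path = [(i, i + 1) for i in range(4)]
edges = path + [(t, s) for s, t in path]

states = [1.0, 0.0, 0.0, 0.0, 0.0]   # only node 0 carries a signal
for step in range(3):                 # graph_updates = 3
    states = message_pass(states, edges)
print(states)  # [4.0, 5.0, 3.0, 1.0, 0.0]
```

After three rounds, node 3 (three hops from the signal) has a nonzero state while node 4 (four hops away) is still zero.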
And finally, we train the model. I am using a callback to stop training when the model stops improving on the validation dataset (here, when the validation loss stops decreasing). It isn’t perfect, since we have to validate on the full dataset (explained above), which means the validation numbers include data leakage. A perfect solution would be to write a custom evaluation function that only returns accuracy for the validation nodes on the validation data, and for the training nodes on the training data. That is a lot of work (it would take a tutorial in itself) just to land a couple of epochs closer to the ideal stopping point. I chose to keep it simple and live with a marginally less accurate model.
es = tf.keras.callbacks.EarlyStopping(
monitor='val_loss',mode='min',verbose=1,
patience=10,restore_best_weights=True)
node_model.fit(train_node_dataset.repeat(),
validation_data=full_node_dataset,
steps_per_epoch=10,
epochs=1000,
callbacks=[es])
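The stopping rule configured above can be sketched in a few lines. This is a simplified re-implementation of how I understand ‘EarlyStopping’ to behave with its default ‘min_delta=0’ (stop once ‘patience’ epochs in a row fail to beat the best validation loss), not Keras’s actual code; the loss values are hypothetical.

```python
def early_stopping_epoch(val_losses, patience=10):
    """Return (stop_epoch, best_epoch) for an EarlyStopping-style rule on val_loss.

    Training stops once `patience` consecutive epochs fail to beat the best
    loss so far; with restore_best_weights=True the weights from best_epoch
    are kept. Epochs are 0-based.
    """
    best_loss, best_epoch = float("inf"), 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
        elif epoch - best_epoch >= patience:
            return epoch, best_epoch
    return len(val_losses) - 1, best_epoch

print(early_stopping_epoch([5, 4, 3, 3.5, 3.2, 3.1, 3.3], patience=3))  # (5, 2)
```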
Time to see how we did using node_model.predict(full_node_dataset) and printing the results on a map using magic (see Google Colab).

Overall, we had a respectable 88% accuracy (see Google Colab for model parameters). The model seems to have a harder time with the mountain states. Diving in deeper yields some interesting insights. For example, the model falsely predicted Utah to be in the Pac 10 conference. Utah did, in fact, later join that conference (in 2011, when the Pac-10 expanded into the Pac-12). It is entirely possible that the model is correctly identifying how things should be and the ~12% error is really a measurement of human inconsistency when creating conferences. Another way to think about it is with a social network of friends. If the network predicts two people are friends when they have never met, is the model wrong, or are they a good match to be friends? For many (or a majority) of graph problems, these "errors" are what you are really trying to find. They can then be used to recommend products to buy, movies to watch, people you should connect with, etc.
For this case, let’s assume the data is perfect and we are interested in classification accuracy. To really know how well we did, we will need to test the accuracy on our holdout data. To do this, we will make a prediction on the full dataset and filter down to the holdout nodes.
def evaluate_node():
    ### Add raw prediction (one softmax row per school) ###
    yhat = node_model.predict(full_node_dataset)
    yhat_df = pd.DataFrame(yhat,
                           index=node_full_adj['school'],
                           columns=node_full_adj.columns[-12:])
    ### Classify max of softmax output (dummy columns are named 'conf_<n>') ###
    yhat_df = (yhat_df.idxmax(axis=1)
               .str.replace('conf_', '').astype(int)
               .to_frame('conf_yhat'))
    # Align the actual labels on the school name, not on the RangeIndex
    yhat_df['conf_actual'] = node_full_adj.set_index('school')['conference']
    ### Filter down to test nodes ###
    yhat_df = yhat_df.loc[yhat_df.index.isin(node_test.index)]
    ### Calculate accuracy ###
    yhat_df['Accuracy'] = yhat_df['conf_yhat']==yhat_df['conf_actual']
    return yhat_df['Accuracy'].mean()
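Stripped of pandas, the evaluation logic is: take the argmax of each softmax row predicted on the full graph, then score only the holdout nodes. A minimal sketch with hypothetical node ids and scores:

```python
def holdout_accuracy(probabilities, actual, test_ids):
    """Accuracy of argmax predictions restricted to the held-out node ids.

    probabilities: {node_id: list of softmax scores}, predicted on the FULL graph
    actual:        {node_id: true class index}
    test_ids:      ids of the holdout nodes
    """
    hits = 0
    for node_id in test_ids:
        scores = probabilities[node_id]
        predicted = max(range(len(scores)), key=scores.__getitem__)
        hits += predicted == actual[node_id]
    return hits / len(test_ids)

# Hypothetical two-class scores for three nodes; only B and C are held out
probs = {"A": [0.1, 0.9], "B": [0.8, 0.2], "C": [0.3, 0.7]}
actual = {"A": 1, "B": 1, "C": 1}
print(holdout_accuracy(probs, actual, ["B", "C"]))  # 0.5
```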
For this model, accuracy drops to ~72% (Don’t Panic, a drop is expected on a holdout dataset). Given the limited feature engineering, only one year of data, and 12 output predictions – those results are reasonable. Upon visual inspection of the maps below (and comparing to the full map above), most of the errors seem like decent guesses.

Edge Model
Now we will try to predict if a specific game is an in-conference game. We already defined our edge datasets above and most of the steps can be reused with only one change:
### Change to train_edge_dataset ###
graph_spec = train_edge_dataset.element_spec[0]
input_graph = tf.keras.layers.Input(type_spec=graph_spec)
graph = tfgnn.keras.layers.MapFeatures(
node_sets_fn=set_initial_node_state,
edge_sets_fn=set_initial_edge_state
)(input_graph)
We do need to make a few changes to the graph updates, though. First, we need to add an ‘edge_sets’ update to our ‘GraphUpdate’ function. Leaving in the ‘node_sets’ update is optional, but the model does seem to do better for me when I keep it in. Next, we will switch from a GCN to a Graph Nets approach. This method treats edges as first-class citizens (a fancy way of saying each edge gets its own learned hidden state, which is what we are after). Finally, we need to update ‘logits’ to be a one-unit dense layer with a sigmoid activation, since we are predicting a binary (0/1) variable.
graph_updates = 3
for i in range(graph_updates):
    graph = tfgnn.keras.layers.GraphUpdate(
        edge_sets = {'games': tfgnn.keras.layers.EdgeSetUpdate(
            next_state = tfgnn.keras.layers.NextStateFromConcat(
                dense_layer(64,activation='relu')))},
        node_sets = {
            'schools': tfgnn.keras.layers.NodeSetUpdate({
                'games': tfgnn.keras.layers.Pool(
                    tag=tfgnn.TARGET,
                    reduce_type="sum",
                    feature_name = tfgnn.HIDDEN_STATE)},
                tfgnn.keras.layers.NextStateFromConcat(
                    dense_layer(64)))})(graph)
logits = tf.keras.layers.Dense(1,activation='sigmoid')(graph.edge_sets['games'][tfgnn.HIDDEN_STATE])
edge_model = tf.keras.Model(input_graph, logits)
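The order of operations inside one Graph Nets style ‘GraphUpdate’ can be sketched with scalar states: edge sets are updated first, from each edge’s own state plus its source and target node states, and the node update then sum-pools the new edge states arriving at each node. In this simplified sketch the dense layers are replaced by addition (an assumption made purely for readability), and the two-game graph is hypothetical.

```python
def edge_update(edge_states, node_states, edges):
    """Next edge state from [edge, source node, target node].

    A real EdgeSetUpdate feeds the concatenation to a dense layer; summing the
    three parts is a hypothetical stand-in so the data flow is easy to check.
    """
    return [e + node_states[s] + node_states[t]
            for e, (s, t) in zip(edge_states, edges)]

def node_update(node_states, edge_states, edges):
    """Sum-pool the new edge states at each edge's TARGET node, then combine
    with the node's own state (stand-in for NextStateFromConcat)."""
    pooled = [0.0] * len(node_states)
    for e, (_, t) in zip(edge_states, edges):
        pooled[t] += e
    return [n + p for n, p in zip(node_states, pooled)]

edges = [(0, 1), (1, 0), (1, 2), (2, 1)]   # two games, stored in both directions
node_states = [1.0, 2.0, 3.0]
edge_states = [0.5, 0.5, 0.0, 0.0]
edge_states = edge_update(edge_states, node_states, edges)   # edges first
node_states = node_update(node_states, edge_states, edges)   # then nodes
print(edge_states, node_states)  # [3.5, 3.5, 5.0, 5.0] [4.5, 10.5, 8.0]
```

The sigmoid head then reads each edge’s final state, which is why the edges need states of their own for this task.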
We compile the model using ‘binary_crossentropy’ this time.
edge_model.compile(
tf.keras.optimizers.Adam(learning_rate=0.01),
loss = 'binary_crossentropy',
metrics = ['binary_accuracy'] # thresholds the sigmoid output at 0.5
)
edge_model.summary()
And we fit the model using the same callback defined in our node problem.
edge_model.fit(train_edge_dataset.repeat(),
validation_data=full_edge_dataset,
steps_per_epoch=10,
epochs=1000,
callbacks=[es])
yhat = edge_model.predict(full_edge_dataset)
yhat_df = edge_full_adj.copy().set_index(['source','target'])
yhat_df['conf_game_yhat'] = yhat.round(0)
yhat_df = yhat_df.loc[yhat_df.index.isin(
edge_test.set_index(['source','target']).index)]
yhat_df['loss'] = abs(yhat_df['conference_game'] - yhat_df['conf_game_yhat'])
loss = yhat_df['loss'].mean()
print("edge accuracy:",1 - loss)
When evaluated on the holdout dataset, we get 85% accuracy compared to a 56% mean. The model did its job and I am satisfied with those results.
Context Model
This particular problem does not have a context value. Let’s imagine that we sliced the graph above so that we had a separate graph for each conference. These new graphs would show every game played by teams in the conference and ignore all other games. We could then attach a value to each graph, such as how the conference was ranked. Now we can train a model to make context-level predictions.
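Slicing the graph per conference can be sketched in plain Python. The text leaves one detail open, so this sketch assumes a game is kept in a conference’s subgraph when at least one of its two teams belongs to that conference (an interconference game therefore appears in both subgraphs). Team names and conference ids are hypothetical.

```python
def conference_subgraphs(conference_of, games):
    """Split one graph into a (nodes, edges) subgraph per conference.

    Assumption: a game belongs to a conference's subgraph when at least one of
    its two teams is in that conference.
    """
    subgraphs = {}
    for a, b in games:
        for conf in {conference_of[a], conference_of[b]}:
            nodes, edges = subgraphs.setdefault(conf, (set(), []))
            nodes.update((a, b))
            edges.append((a, b))
    return subgraphs

# Hypothetical teams: A and B in conference 0, C in conference 1
conference_of = {"A": 0, "B": 0, "C": 1}
subs = conference_subgraphs(conference_of, [("A", "B"), ("B", "C")])
print(subs)
```

Each subgraph would then become its own GraphTensor, with a context feature holding that conference’s value (for example, its ranking).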
First, we need to add our context values to the graph.
graph_tensor = tfgnn.GraphTensor.from_pieces(
context = tfgnn.Context.from_fields(
features ={
<context_feature>
}),
node_sets = {
...
Next we need to create a new dataset with the context mapped to the label.
def context_batch_merge(graph):
    graph = graph.merge_batch_to_components()
    context_features = graph.context.get_features_dict()
    label = context_features.pop('<context_feature>')
    new_graph = graph.replace_features(
        context=context_features)
    return new_graph, label
We have the ability to set our initial context state. In this case, we are predicting this feature, so it will be absent from our training data. For other models, context may be an input feature and can be set like so:
def set_initial_context_state(context):
    return tf.keras.layers.Dense(32,activation="relu")(context['<context_feature>'])
graph = tfgnn.keras.layers.MapFeatures(
context_fn=set_initial_context_state,
node_sets_fn=set_initial_node_state,
edge_sets_fn=set_initial_edge_state
)(input_graph)
Again, we can optionally add a context update to the ‘GraphUpdate’ (see below). I have not tested this method so feel free to experiment.
graph = tfgnn.keras.layers.GraphUpdate(
    node_sets ={...},
    context = tfgnn.keras.layers.ContextUpdate({
        'schools': tfgnn.keras.layers.Pool(tfgnn.CONTEXT, "mean")},
        tfgnn.keras.layers.NextStateFromConcat(tf.keras.layers.Dense(128))))(graph)
Finally, we update our ‘logits’ for a context prediction:
logits = tfgnn.keras.layers.Pool(tfgnn.CONTEXT, "mean",
node_set_name="schools")(graph)
Troubleshooting errors
I ran into many errors and poorly trained models trying to figure out the code above. While I tried to keep things generic enough to apply to many different problems, you will no doubt run into errors as you make adjustments for your data. The trick is to identify the source of your error. The best way I found to diagnose errors was to create a graph schema.
In our code above, we pulled the graph schema from our dataset. You can, however, build a graph schema directly. For our football example, the graph schema would look like this:
graph_spec = tfgnn.GraphTensorSpec.from_piece_specs(
context_spec = tfgnn.ContextSpec.from_field_specs(
features_spec ={
#Added as an example for context problems
#"conf_rank": tf.TensorSpec(shape=(None,1), dtype=tf.float32),
}),
node_sets_spec={
'schools':
tfgnn.NodeSetSpec.from_field_specs(
features_spec={
'Latitude': tf.TensorSpec((None, 1), tf.float32),
'Longitude': tf.TensorSpec((None, 1), tf.float32),
'Rank': tf.TensorSpec((None, 1), tf.int32),
'Wins': tf.TensorSpec((None, 1), tf.int32),
'Conf_wins': tf.TensorSpec((None, 1), tf.int32),
'conference': tf.TensorSpec((None, 12), tf.int32)
},
sizes_spec=tf.TensorSpec((1,), tf.int32))
},
edge_sets_spec={
'games':
tfgnn.EdgeSetSpec.from_field_specs(
features_spec={
'name_sim_score': tf.TensorSpec((None, 1), tf.float32),
'euclidean_dist': tf.TensorSpec((None, 1), tf.float32),
'conference_game': tf.TensorSpec((None, 1), tf.int32)
},
sizes_spec=tf.TensorSpec((1,), tf.int32),
adjacency_spec=tfgnn.AdjacencySpec.from_incident_node_sets(
'schools', 'schools'))
})
We can test whether our ‘graph_spec’ is at least valid by attempting to build and compile the model. If you get an error, there is likely an issue with your feature shapes or your ‘set_initial…’ functions. If it works, you can verify that the schema you created is compatible with your ‘graph_tensor’.
graph_spec.is_compatible_with(full_tensor)
If false, you can print out ‘full_tensor.spec’ and ‘graph_spec’ to compare each piece to ensure the shapes and dtypes are exactly the same. You can also create a randomly generated graph tensor directly from the ‘graph_spec’.
random_graph = tfgnn.random_graph_tensor(graph_spec)
With this ‘random_graph’ you can attempt to train a model. This should help you determine if your error is with the spec or the model code. If you do not get any errors, you can print the values of the ‘random_graph’ to see how the outputs compare to your ‘graph_tensor’.
print("Nodes:",random_graph.node_sets['schools'].features)
print("Edges:",random_graph.edge_sets['games'].features)
print("Context:",random_graph.context.features)
These steps should allow you to track down the majority of issues you run into.
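When ‘is_compatible_with’ returns False, comparing every piece by eye gets tedious. The helper below is a minimal sketch that compares two ‘{feature_name: (shape, dtype)}’ dictionaries and lists every mismatch. You could fill them by hand from your intended schema and from each feature tensor’s ‘.shape’ and ‘.dtype.name’; the dictionary format is my own convention, not a TF-GNN API. ‘None’ matches any size, as in a ‘(None, 1)’ TensorSpec.

```python
def diff_specs(expected, actual):
    """Compare {feature_name: (shape, dtype)} dicts and list every mismatch.

    None in a shape matches any size, mirroring a (None, 1) TensorSpec.
    """
    problems = []
    for name in sorted(set(expected) | set(actual)):
        if name not in actual:
            problems.append(f"{name}: missing from actual")
        elif name not in expected:
            problems.append(f"{name}: unexpected feature")
        else:
            (exp_shape, exp_dtype), (act_shape, act_dtype) = expected[name], actual[name]
            shapes_match = len(exp_shape) == len(act_shape) and all(
                e is None or a is None or e == a for e, a in zip(exp_shape, act_shape))
            if not shapes_match:
                problems.append(f"{name}: shape {act_shape} != {exp_shape}")
            if exp_dtype != act_dtype:
                problems.append(f"{name}: dtype {act_dtype} != {exp_dtype}")
    return problems

# Hypothetical expected schema vs. what a tensor actually contains
expected = {"Latitude": ((None, 1), "float32"), "Rank": ((None, 1), "int32")}
actual = {"Latitude": ((115, 1), "float64"), "Wins": ((115, 1), "int32")}
for problem in diff_specs(expected, actual):
    print(problem)
```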
Parameter tuning
We have successfully fixed any errors we were having and trained a model. Now we want to tune our hyperparameters for an accurate model. My tuner of choice is the Hyperopt library because of its ease of use and integrated Bayesian optimization. But first we want to convert our modeling code above to a class with variables.
import sys

import tensorflow as tf
import tensorflow_gnn as tfgnn
from hyperopt import fmin, tpe, hp, STATUS_OK, Trials

# NOTE: evaluate_model relies on objects created earlier in this tutorial:
# full_node_dataset, node_full_adj, node_test, full_edge_dataset, edge_full_adj, edge_test

class GCNN:
    def __init__(self, params):
        self.params = params

    def set_initial_node_state(self, node_set, node_set_name):
        # Embed each node feature separately, then concatenate the embeddings
        features = [
            tf.keras.layers.Dense(self.params['feature_dim'], activation="relu")(node_set['Latitude']),
            tf.keras.layers.Dense(self.params['feature_dim'], activation="relu")(node_set['Longitude']),
            tf.keras.layers.Dense(self.params['feature_dim'], activation="relu")(node_set['Rank']),
            tf.keras.layers.Dense(self.params['feature_dim'], activation="relu")(node_set['Wins']),
            tf.keras.layers.Dense(self.params['feature_dim'], activation="relu")(node_set['Conf_wins'])
        ]
        return tf.keras.layers.Concatenate()(features)

    def set_initial_edge_state(self, edge_set, edge_set_name):
        features = [
            tf.keras.layers.Dense(self.params['feature_dim'], activation="relu")(edge_set['name_sim_score']),
            tf.keras.layers.Dense(self.params['feature_dim'], activation="relu")(edge_set['euclidean_dist'])
        ]
        return tf.keras.layers.Concatenate()(features)

    def dense_layer(self, units=64):
        # Dense + ReLU with L2 regularization, followed by dropout
        regularizer = tf.keras.regularizers.l2(self.params['l2_reg'])
        return tf.keras.Sequential([
            tf.keras.layers.Dense(units,
                                  kernel_regularizer=regularizer,
                                  bias_regularizer=regularizer,
                                  activation='relu'),
            tf.keras.layers.Dropout(self.params['dropout'])])

    def build_model(self):
        input_graph = tf.keras.layers.Input(type_spec=self.params['trainset'].element_spec[0])
        graph = tfgnn.keras.layers.MapFeatures(
            node_sets_fn=self.set_initial_node_state,
            edge_sets_fn=self.set_initial_edge_state
        )(input_graph)
        if self.params['loss'] == 'categorical_crossentropy':
            # Node classification: only the school (node) states are updated
            for _ in range(self.params['graph_updates']):
                graph = tfgnn.keras.layers.GraphUpdate(
                    node_sets={
                        'schools': tfgnn.keras.layers.NodeSetUpdate(
                            {'games': tfgnn.keras.layers.SimpleConv(
                                message_fn=self.dense_layer(self.params['message_dim']),
                                reduce_type="sum",
                                receiver_tag=tfgnn.TARGET)},
                            tfgnn.keras.layers.NextStateFromConcat(
                                self.dense_layer(self.params['next_state_dim'])))})(graph)
            logits = tf.keras.layers.Dense(12, activation='softmax')(
                graph.node_sets['schools'][tfgnn.HIDDEN_STATE])
        else:
            # Edge classification: update the game (edge) states, then the school states
            for _ in range(self.params['graph_updates']):
                graph = tfgnn.keras.layers.GraphUpdate(
                    edge_sets={'games': tfgnn.keras.layers.EdgeSetUpdate(
                        next_state=tfgnn.keras.layers.NextStateFromConcat(
                            self.dense_layer(self.params['next_state_dim'])))},
                    node_sets={
                        'schools': tfgnn.keras.layers.NodeSetUpdate(
                            {'games': tfgnn.keras.layers.SimpleConv(
                                message_fn=self.dense_layer(self.params['message_dim']),
                                reduce_type="sum",
                                receiver_tag=tfgnn.TARGET)},
                            tfgnn.keras.layers.NextStateFromConcat(
                                self.dense_layer(self.params['next_state_dim'])))})(graph)
            logits = tf.keras.layers.Dense(1, activation='sigmoid')(
                graph.edge_sets['games'][tfgnn.HIDDEN_STATE])
        return tf.keras.Model(input_graph, logits)

    def train_model(self, trial=True):
        model = self.build_model()
        # Lowercase 'accuracy' lets Keras pick categorical or binary accuracy to match the loss
        model.compile(tf.keras.optimizers.Adam(learning_rate=self.params['learning_rate']),
                      loss=self.params['loss'],
                      metrics=['accuracy'])
        # Stop once validation loss stops improving and keep the best weights
        callbacks = [tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                                      mode='min',
                                                      verbose=1,
                                                      patience=self.params['patience'],
                                                      restore_best_weights=True)]
        model.fit(self.params['trainset'].repeat(),
                  validation_data=self.params['full_dataset'],
                  steps_per_epoch=self.params['steps_per_epoch'],
                  epochs=self.params['epochs'],
                  verbose=0,
                  callbacks=callbacks)
        loss = self.evaluate_model(model, trial=trial)
        if trial:
            # Tuning run: log the sampled hyperparameters and hand the score to Hyperopt
            sys.stdout.flush()
            hypt_params = {
                'graph_updates': self.params['graph_updates'],
                'feature_dim': self.params['feature_dim'],
                'next_state_dim': self.params['next_state_dim'],
                'message_dim': self.params['message_dim'],
                'l2_reg': self.params['l2_reg'],
                'dropout': self.params['dropout'],
                'learning_rate': self.params['learning_rate']}
            print(hypt_params, 'loss:', loss)
            return {'loss': loss, 'status': STATUS_OK}
        else:
            # Final run: return the trained model itself
            print('loss:', loss)
            return model

    def evaluate_model(self, model, trial=True):
        if self.params['loss'] == 'categorical_crossentropy':
            # Node model: pick the most probable conference for each school
            yhat = model.predict(full_node_dataset)
            yhat_df = node_full_adj.set_index('school').iloc[:, -12:].astype(float)
            yhat_df.iloc[:, :] = yhat
            yhat_df = yhat_df.apply(lambda x: x == x.max(), axis=1).astype(int)
            yhat_df = yhat_df.dot(yhat_df.columns).to_frame().rename(columns={0: 'conf_yhat'})
            yhat_df = yhat_df['conf_yhat'].str.replace('conf_', '').astype(int).to_frame()
            yhat_df['conf_actual'] = node_full_adj.set_index('school')['conference']
            # Score only the held-out test schools
            yhat_df = yhat_df.loc[yhat_df.index.isin(node_test.index)]
            yhat_df['Accuracy'] = yhat_df['conf_yhat'] == yhat_df['conf_actual']
            loss = 1 - yhat_df['Accuracy'].mean()
        else:
            # Edge model: round the sigmoid output to a 0/1 conference-game prediction
            yhat = model.predict(full_edge_dataset)
            yhat_df = edge_full_adj.copy().set_index(['source', 'target'])
            yhat_df['conf_game_yhat'] = yhat.round(0).ravel()
            # Score only the held-out test games
            yhat_df = yhat_df.loc[yhat_df.index.isin(
                edge_test.set_index(['source', 'target']).index)]
            yhat_df['loss'] = abs(yhat_df['conference_game'] - yhat_df['conf_game_yhat'])
            loss = yhat_df['loss'].mean()
        return loss
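The scores returned by ‘evaluate_model’ boil down to two misclassification rates, which is why lower is better for Hyperopt: for nodes, the share of test schools whose most probable conference is wrong; for edges, the share of test games whose rounded sigmoid output disagrees with ‘conference_game’. A minimal NumPy sketch of the same arithmetic is below. It is a simplified restatement rather than the code used above, and it assumes integer conference labels that line up with the column order of the softmax outputs plus a boolean test mask; ‘node_error’ and ‘edge_error’ are hypothetical names.

```python
import numpy as np


def node_error(probs, true_conf, test_mask):
    """Return 1 - accuracy over test nodes, using the most probable class per node.

    probs:     (num_nodes, num_classes) softmax outputs, e.g. 12 conferences.
    true_conf: (num_nodes,) integer labels that index the probs columns (assumed).
    test_mask: (num_nodes,) booleans marking the held-out schools.
    """
    mask = np.asarray(test_mask, dtype=bool)
    predicted = np.argmax(probs, axis=1)  # ties resolve to the first column
    correct = predicted[mask] == np.asarray(true_conf)[mask]
    return 1.0 - correct.mean()


def edge_error(probs, true_label, test_mask):
    """Return the share of test edges whose rounded sigmoid output is wrong."""
    mask = np.asarray(test_mask, dtype=bool)
    predicted = np.round(np.ravel(probs))  # (num_edges, 1) -> (num_edges,)
    wrong = np.abs(np.asarray(true_label)[mask] - predicted[mask])
    return wrong.mean()


# Toy example with 3 classes instead of 12; the last school is not in the test set
node_probs = np.array([[0.7, 0.2, 0.1],
                       [0.1, 0.8, 0.1],
                       [0.3, 0.3, 0.4],
                       [0.6, 0.3, 0.1]])
print(node_error(node_probs, [0, 1, 1, 2], [True, True, True, False]))

edge_probs = np.array([[0.9], [0.2], [0.6], [0.4]])
print(edge_error(edge_probs, [1, 0, 0, 0], [True, True, True, True]))
```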
Now we define our parameters. Each tuning parameter can either be set to a fixed value (e.g. ‘dropout’: 0.1) or given a search space for Hyperopt to explore, as I did below. ‘hp.choice’ picks one of the options you list, while ‘hp.uniform’ samples a continuous value between two bounds. Many other options are available in the Hyperopt documentation.
params = {
    ### Tuning parameters ###
    'graph_updates': hp.choice('graph_updates', [2, 3, 4]),
    'feature_dim': hp.choice('feature_dim', [16, 32, 64, 128]),
    'message_dim': hp.choice('message_dim', [16, 32, 64, 128]),
    'next_state_dim': hp.choice('next_state_dim', [16, 32, 64, 128]),
    'l2_reg': hp.uniform('l2_reg', 0.0, 0.3),
    'dropout': hp.choice('dropout', [0, 0.125, 0.25, 0.375, 0.5]),
    'learning_rate': hp.uniform('learning_rate', 0.0, 0.1),
    ### Static parameters ###
    'loss': 'categorical_crossentropy',
    'epochs': 1000,
    'steps_per_epoch': 10,  ### This could also be a tuned parameter
    'patience': 10,
    'trainset': train_node_dataset,
    'full_dataset': full_node_dataset
}
Next, we define a helper function and pass it to ‘fmin’ along with our parameters. Each evaluation trains a full model, so this can take a while depending on your hardware; consider lowering ‘max_evals’ if it is too slow for you. My personal rule of thumb is about 15 evaluations per tuned parameter (7 tuned parameters × 15 ≈ 100 evaluations here), so if you reduce ‘max_evals’, fix some of the parameters to explicit values so that fewer of them are being tuned.
from hyperopt import fmin, tpe, hp, STATUS_OK, Trials

def tune_model(params):
    # Each call builds, trains, and scores one model for a sampled set of params
    return GCNN(params).train_model()

best = fmin(tune_model, params, algo=tpe.suggest,
            max_evals=100, trials=Trials())
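One detail to watch before reusing ‘best’: for parameters defined with ‘hp.choice’, ‘fmin’ returns the index of the selected option rather than the option itself, while ‘hp.uniform’ parameters come back as the sampled value. Hyperopt's ‘space_eval(params, best)’ performs this conversion for you. The plain-Python sketch below only makes the mapping explicit; ‘CHOICE_OPTIONS’ and ‘decode_best’ are hypothetical names, the option lists are copied from the search space above, and the raw indices are illustrative rather than output from an actual run.

```python
# Option lists copied from the hp.choice entries in the search space above
CHOICE_OPTIONS = {
    'graph_updates': [2, 3, 4],
    'feature_dim': [16, 32, 64, 128],
    'message_dim': [16, 32, 64, 128],
    'next_state_dim': [16, 32, 64, 128],
    'dropout': [0, 0.125, 0.25, 0.375, 0.5],
}


def decode_best(raw_best, choice_options=CHOICE_OPTIONS):
    """Turn fmin's index-valued hp.choice results into the actual option values.

    Parameters not listed in choice_options (e.g. hp.uniform ones) pass through.
    """
    decoded = {}
    for name, value in raw_best.items():
        if name in choice_options:
            decoded[name] = choice_options[name][int(value)]
        else:
            decoded[name] = value
    return decoded


# Illustrative fmin-style output: indices for hp.choice, raw floats for hp.uniform
raw_best = {'graph_updates': 2, 'feature_dim': 2, 'next_state_dim': 1,
            'message_dim': 3, 'dropout': 0,
            'l2_reg': 0.095, 'learning_rate': 0.0025}
print(decode_best(raw_best))
```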
Now that we have our best hyperparameters, we can train our final model (NOTE: your accuracy will be slightly different due to how TensorFlow randomly initializes its weights).
### Parameters from my hyperopt run ###
best = {'graph_updates': 4,
        'feature_dim': 64,
        'next_state_dim': 32,
        'message_dim': 128,
        'l2_reg': 0.095,
        'dropout': 0,
        'learning_rate': 0.0025
        }
# Copy params first so the Hyperopt search space is not overwritten
node_params = params.copy()
for param, value in best.items():
    node_params[param] = value
node_model = GCNN(node_params).train_model(trial=False)
We can tune and train our edge model with a few slight adjustments:
params['loss'] = 'binary_crossentropy'
params['trainset'] = train_edge_dataset
params['full_dataset'] = full_edge_dataset
best = fmin(tune_model, params, algo=tpe.suggest,
            max_evals=100, trials=Trials())
### Parameters from my hyperopt run ###
best = {'graph_updates': 4,
        'feature_dim': 64,
        'next_state_dim': 32,
        'message_dim': 128,
        'l2_reg': 0.095,
        'dropout': 0,
        'learning_rate': 0.0025
        }
# Copy params first so the Hyperopt search space is not overwritten
edge_params = params.copy()
for param, value in best.items():
    edge_params[param] = value
edge_model = GCNN(edge_params).train_model(trial=False)
Final thoughts
GNN research is still in its infancy, and new modeling methods are likely to be discovered. With TF-GNN still in an alpha state, there will likely be code changes over the years. Please comment below if you find changes or errors that I have not fixed yet, and I will update this guide as best I can. If you did not like this article, feel free to draw an analogy between me and your favorite historical dictator in the comments. Otherwise, a clap or a nice comment would be appreciated.
My hope is that this guide can be a starting place for more people to enter this field and experiment. Consider this your opportunity to be at the beginning of the next AI wave!
About me
I am a senior data scientist and part-time freelancer with over 12 years of experience. I am always looking to connect, so please feel free to:
- Connect with me on LinkedIn
- Follow me on Twitter
- Visit my website: www.modelforge.ai
- See my other articles
Please feel free to comment below if you have any questions.