Graph Neural Networks

By David Rose

Jul 5, 2021

The TL;DR

What?

A graph is a way of representing network data and the connections within it. In its most basic form a graph has two parts: nodes and edges.

The nodes represent the samples of data and the edges represent some sort of link between them. Sometimes the link (edge) is a single obvious property connecting objects, such as:

  • Distances between cities in miles
  • Friendships on Facebook
  • Citations between papers on arXiv

Or sometimes the network is connected by multiple relevant attributes. With a supply chain network, for example, the routes can be described by:

  • The trucks that deliver on routes between warehouses
  • The type of route (sea, land, air)
  • The average time or cost to transfer between locations

Why?

Compared to tabular datasets there is no assumption that the samples are IID; the express purpose of a graph is to capture the ways in which the samples of data are related to one another.

How?

We can turn a graph into a typical machine learning problem by giving both the nodes and the edges their own features, then classifying a label for a specific node or for the graph as a whole. For COVID contact tracing, for example, we may want to find who has had contact with a specific person (node) when we only have partial knowledge of contacts so far (the labels true/false), based on the edge features (distance between where they live) and node features (demographics, job type, common locations visited).

Features can be anything you would normally use in a dataset (see the construction sketch after this list):

  • numerical embeddings of words on a webpage
  • pixel values of images
  • one-hot encoding of categories
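
As a minimal sketch, here is how a small graph with node and edge features might be constructed in DGL (the library used in the code example later in this post). The feature names "feat" and "weight" and their sizes are purely illustrative:

```python
import dgl
import torch

# Four nodes connected in a ring; edges are given as (source, destination) pairs.
src = torch.tensor([0, 1, 2, 3])
dst = torch.tensor([1, 2, 3, 0])
g = dgl.graph((src, dst), num_nodes=4)

# Node features: a 5-dimensional vector per node (e.g. word embeddings).
g.ndata["feat"] = torch.randn(4, 5)

# Edge features: a scalar per edge (e.g. a distance or cost).
g.edata["weight"] = torch.rand(4, 1)
```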

Some Examples of Graph Networks

The Task

Node Classification

This is one of the most common tasks performed with GNNs. The basic idea is that we take a specific reference node, in this case node A, and perform neighborhood aggregation over the surrounding nodes and the edges between them.

What are the network layers?

  • Nodes: recurrent networks
  • Edges: feed-forward networks

What is the process?

Message passing: In performing the neighborhood aggregation, we pass messages (or embeddings) between the surrounding nodes with respect to our reference node A. This causes the information embedded in the nodes and edges to begin filtering through the network, where each node learns from its neighbors.

[Diagram: classifying a node, with arrows leading from the neighbors to the reference node.]
Source: https://web.stanford.edu/class/cs224w/slides/08-GNN.pdf

Simple Steps (see the code sketch below):

  1. Neighboring nodes pass their messages (embeddings) through the edge networks into the recurrent network on the reference node.
  2. The new embedding of the reference recurrent unit is updated by applying the recurrent function on the current embedding and a summation of the edge network outputs from neighboring nodes.
Source: https://medium.com/dair-ai/an-illustrated-guide-to-graph-neural-networks-d5564a551783
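
Below is a minimal sketch of these two steps in PyTorch, with node 0 standing in for the reference node A. The edge network and recurrent unit here are generic stand-ins (an MLP and a GRU cell), not a specific published architecture:

```python
import torch
import torch.nn as nn

dim = 16
h = torch.randn(5, dim)              # current embeddings for 5 nodes
edges = [(1, 0), (2, 0), (3, 0)]     # neighbors 1, 2, 3 -> reference node 0

edge_mlp = nn.Sequential(nn.Linear(dim, dim), nn.ReLU())  # edge network
gru = nn.GRUCell(dim, dim)                                # recurrent unit

# Step 1: neighbors pass their embeddings through the edge network.
messages = torch.stack([edge_mlp(h[src]) for src, _ in edges])

# Step 2: sum the messages and update the reference node recurrently.
summed = messages.sum(dim=0, keepdim=True)   # shape (1, dim)
h_0_new = gru(summed, h[0:1])                # new embedding for node 0
```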

Then What?

Once you have performed this step a few times, we have a set of new embeddings, different from the ones we began with. The nodes now hold their original information along with an aggregation of the information contained in their surrounding nodes. We can then send these embeddings further along other layers in a pipeline, or sum up all the embeddings to get a vector H that represents the whole graph.

[Diagram: the final node embeddings summed into a single graph-level vector H.]
Source: https://medium.com/dair-ai/an-illustrated-guide-to-graph-neural-networks-d5564a551783

Math Notation

State of each Node

Each node is represented by the state of its neighborhood, built from the quantities below (combined into a single formula after the list):

  • x_v: The node feature
  • x_co[v]: Features of the edges connecting with v
  • h_ne[v]: Embeddings of the neighboring nodes of v
  • x_ne[v]: Features of the neighboring nodes of v
  • f: The transition function that projects these inputs into a d-dimensional space
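
Putting these together (this reconstructs the standard GNN formulation these definitions come from), the state h_v of node v is:

h_v = f(x_v, x_co[v], h_ne[v], x_ne[v])

In words: a node's embedding is a learned function of its own features, the features of its incident edges, and the embeddings and features of its neighbors.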

State of Aggregate Nodes

H and X denote the concatenation of all the h and x values, respectively; the states are refined through an iterative update process.
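
Written out (again reconstructing the standard formulation), the update is applied iteratively, with t indexing the iteration:

H^(t+1) = F(H^t, X)

In the original GNN paper, F is a contraction map, which guarantees that this iteration converges to a unique fixed point.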

Typical Sampling Process

[Diagram: a typical neighborhood sampling process.]
Source: https://web.stanford.edu/class/cs224w/slides/08-GNN.pdf

Code Example

PyTorch or TensorFlow?

When importing dgl we can specify the backend to use with the DGLBACKEND environment variable. In a notebook this can be set with a magic command:

%env DGLBACKEND=pytorch
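
Equivalently, in a plain Python script (a sketch; the variable must be set before dgl is first imported):

```python
import os
os.environ["DGLBACKEND"] = "pytorch"  # must come before the first dgl import

import dgl
```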

Dataset (Reddit)

We use the Reddit dataset, which has already been processed and is ready for download. It is a collection of posts made during the month of September 2014. The label is the subreddit of each post (node), and two posts are connected if the same user commented on both.

Sampling

We use the first 20 days for training and the remaining days for testing (with 30% used for validation).
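
A sketch of loading the dataset with DGL's built-in RedditDataset class (it downloads the data on first use; the train/val/test masks stored on the nodes encode the split described above):

```python
import dgl.data

# Load the preprocessed Reddit dataset; it contains a single large graph.
dataset = dgl.data.RedditDataset()
g = dataset[0]

print("Number of categories:", dataset.num_classes)
print("Node features", g.ndata.keys())
print("Edge features", g.edata.keys())
print(f"Total nodes: {g.num_nodes():,}")
print(f"Total edges: {g.num_edges():,}")
```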

Number of categories: 41
Node features dict_keys(['label', 'feat', 'test_mask', 'train_mask', 'val_mask'])
Edge features dict_keys([])
Total nodes: 232,965
Total edges: 114,615,892

The Model

Here we put together a simple two-layer Graph Convolutional Network (GCN). Each layer computes new node representations by aggregating neighbor information.

DGL layers work easily within PyTorch and can be stacked along with standard PyTorch layers.
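
A sketch of such a model using DGL's GraphConv layer; the hidden size of 16 is an arbitrary choice for illustration:

```python
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import GraphConv

class GCN(nn.Module):
    def __init__(self, in_feats, hidden_feats, num_classes):
        super().__init__()
        self.conv1 = GraphConv(in_feats, hidden_feats)
        self.conv2 = GraphConv(hidden_feats, num_classes)

    def forward(self, g, features):
        # Each GraphConv layer aggregates information from neighboring nodes.
        h = F.relu(self.conv1(g, features))
        return self.conv2(g, h)

# 602-dimensional input features, 41 subreddit classes (from the dataset above).
model = GCN(g.ndata["feat"].shape[1], 16, dataset.num_classes)
```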

Training

The training process is essentially the same as any other PyTorch training loop.
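
A sketch of a full-graph training loop that would produce logs in the format below, continuing from the dataset and model sketches above. (Full-graph training on a graph this large needs a lot of memory; DGL's own tutorials use neighbor sampling for Reddit-scale graphs.)

```python
import torch
import torch.nn.functional as F

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

features, labels = g.ndata["feat"], g.ndata["label"]
train_mask = g.ndata["train_mask"].bool()
val_mask = g.ndata["val_mask"].bool()
test_mask = g.ndata["test_mask"].bool()

best_val = best_test = 0.0
for epoch in range(50):
    # Forward pass over the whole graph; loss on the training nodes only.
    model.train()
    logits = model(g, features)
    loss = F.cross_entropy(logits[train_mask], labels[train_mask])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Evaluate accuracy on the validation and test nodes.
    model.eval()
    with torch.no_grad():
        pred = model(g, features).argmax(dim=1)
        val_acc = (pred[val_mask] == labels[val_mask]).float().mean().item()
        test_acc = (pred[test_mask] == labels[test_mask]).float().mean().item()
    if val_acc > best_val:
        best_val, best_test = val_acc, test_acc

    if epoch % 5 == 0:
        print(f"epoch {epoch}, loss: {loss.item():.3f}, "
              f"val acc: {val_acc:.3f} (best {best_val:.3f}), "
              f"test acc: {test_acc:.3f} (best {best_test:.3f})")
```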

epoch 0, loss: 3.755, val acc: 0.011 (best 0.011), test acc: 0.011 (best 0.011) 
epoch 5, loss: 2.773, val acc: 0.362 (best 0.362), test acc: 0.362 (best 0.362)
epoch 10, loss: 2.177, val acc: 0.598 (best 0.598), test acc: 0.595 (best 0.595)
epoch 15, loss: 1.716, val acc: 0.665 (best 0.665), test acc: 0.660 (best 0.660)
epoch 20, loss: 1.355, val acc: 0.759 (best 0.759), test acc: 0.754 (best 0.754)
epoch 25, loss: 1.083, val acc: 0.824 (best 0.824), test acc: 0.820 (best 0.820)
epoch 30, loss: 0.888, val acc: 0.857 (best 0.857), test acc: 0.853 (best 0.853)
epoch 35, loss: 0.756, val acc: 0.894 (best 0.894), test acc: 0.891 (best 0.891)
epoch 40, loss: 0.661, val acc: 0.902 (best 0.902), test acc: 0.899 (best 0.899)
epoch 45, loss: 0.593, val acc: 0.912 (best 0.912), test acc: 0.909 (best 0.909)

Save the Trained Graph

Once the model is trained we can easily save the graph and load it back later on with the built-in functions.
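
For example, using DGL's graph serialization helpers (this saves the graph and its node data; the model weights themselves would be saved separately, e.g. with torch.save):

```python
from dgl import save_graphs, load_graphs

save_graphs("reddit_graph.bin", [g])          # graph + node data to disk
graphs, _ = load_graphs("reddit_graph.bin")   # returns (graph_list, label_dict)
print(graphs[0])
```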

Graph(num_nodes=232965, num_edges=114615892, ndata_schemes={'val_mask': Scheme(shape=(), dtype=torch.uint8), 'train_mask': Scheme(shape=(), dtype=torch.uint8), 'test_mask': Scheme(shape=(), dtype=torch.uint8), 'label': Scheme(shape=(), dtype=torch.int64), 'feat': Scheme(shape=(602,), dtype=torch.float32)} edata_schemes={})
