Graph Convolutional Networks: Introduction to GNNs

Graph Neural Network Course: Chapter 1

Author

Maxime Labonne

Published

February 20, 2022

Find a lot more architectures and applications using graph neural networks in my book, Hands-On Graph Neural Networks Using Python 👇

Graph Neural Networks (GNNs) represent one of the most captivating and rapidly evolving architectures within the deep learning landscape. As deep learning models designed to process data structured as graphs, GNNs bring remarkable versatility and powerful learning capabilities.

Among the various types of GNNs, Graph Convolutional Networks (GCNs) have emerged as the most prevalent and broadly applied model. GCNs are innovative because they use both the features of a node and those of its local neighborhood to make predictions, which makes them an effective way to handle graph-structured data.

In this article, we will delve into the mechanics of the GCN layer and explain its inner workings. Furthermore, we will explore its practical application for node classification tasks, using PyTorch Geometric as our tool of choice.

PyTorch Geometric is an extension of PyTorch built specifically for developing and implementing GNNs. It is an advanced yet user-friendly library that provides a comprehensive suite of tools for graph-based machine learning. To get started, we need to install PyTorch Geometric. If you are using Google Colab, PyTorch should already be installed, so we only need to run a few additional commands.

All the code is available on Google Colab and GitHub.

# Install PyTorch Geometric
import torch
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-{torch.__version__}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-{torch.__version__}.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git

# Numpy for matrices
import numpy as np

# Visualization libraries
import matplotlib.pyplot as plt
import networkx as nx

Now that PyTorch Geometric is installed, let’s explore the dataset we will use in this tutorial.

🌐 I. Graph data

Graphs are a nonlinear type of data you can find everywhere: social networks, computer networks, molecules, text, images, and so on. In this article, we will study the famous and much-used Zachary’s karate club dataset.

Zachary’s karate club represents the relationships within a karate club studied by Wayne W. Zachary in the 1970s. It is a kind of social network, where every node is a member, and members who interacted outside the club are connected to each other.

In this example, the club is divided into four groups: we would like to assign the right group to every member (node classification) just by looking at their connections.

Let’s import the dataset with PyG’s built-in function and try to understand the Dataset object it uses.

from torch_geometric.datasets import KarateClub

# Import dataset from PyTorch Geometric
dataset = KarateClub()

# Print information
print(dataset)
print('------------')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')
KarateClub()
------------
Number of graphs: 1
Number of features: 34
Number of classes: 4

This dataset only has 1 graph, where each node has a feature vector of 34 dimensions and belongs to one of four classes (our four groups). Actually, the Dataset object can be seen as a collection of Data (graph) objects.

We can further inspect our unique graph to know more about it.

# Print first element
print(f'Graph: {dataset[0]}')
Graph: Data(x=[34, 34], edge_index=[2, 156], y=[34], train_mask=[34])

The Data object is particularly interesting. Printing it offers a good summary of the graph we’re studying:

  • x=[34, 34] is the node feature matrix with shape (number of nodes, number of features). In our case, it means that we have 34 nodes (our 34 members), each node being associated with a 34-dim feature vector.
  • edge_index=[2, 156] represents the graph connectivity (how the nodes are connected) with shape (2, number of directed edges).
  • y=[34] is the node ground-truth labels. In this problem, every node is assigned to one class (group), so we have one value for each node.
  • train_mask=[34] is an optional attribute that tells which nodes should be used for training with a list of True or False statements.

Let’s print each of these tensors to understand what they store. Let’s start with the node features.

data = dataset[0]

print(f'x = {data.x.shape}')
print(data.x)
x = torch.Size([34, 34])
tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 1.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 1., 0., 0.],
        [0., 0., 0.,  ..., 0., 1., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.]])

Here, the node feature matrix x is an identity matrix: it doesn’t contain any relevant information about the nodes. It could contain information like age, skill level, etc. but this is not the case in this dataset. It means we’ll have to classify our nodes just by looking at their connections.
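One consequence of one-hot features is worth seeing on a small case: multiplying a weight matrix by a one-hot vector simply picks out one column of that matrix. The sketch below uses plain Python and a made-up 2×3 weight matrix (its values and the `matvec` helper are illustrative assumptions, not part of the dataset or of PyG).

```python
# Hypothetical weight matrix W with shape (hidden_dim=2, num_nodes=3).
W = [
    [0.1, 0.2, 0.3],
    [0.4, 0.5, 0.6],
]

def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * v for w, v in zip(row, vector)) for row in matrix]

# One-hot feature vector of node 1 (second row of a 3x3 identity matrix).
x_1 = [0.0, 1.0, 0.0]

h_1 = matvec(W, x_1)
column_1 = [row[1] for row in W]

print(h_1)       # W x_1
print(column_1)  # column 1 of W, which holds the same values
```

In other words, with identity features the first layer effectively learns one vector per node, and everything the model knows about how nodes relate comes from the graph connectivity.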

Now, let’s print the edge index.

print(f'edge_index = {data.edge_index.shape}')
print(data.edge_index)
edge_index = torch.Size([2, 156])
tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  1,
          1,  1,  1,  1,  1,  1,  1,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  3,
          3,  3,  3,  3,  3,  4,  4,  4,  5,  5,  5,  5,  6,  6,  6,  6,  7,  7,
          7,  7,  8,  8,  8,  8,  8,  9,  9, 10, 10, 10, 11, 12, 12, 13, 13, 13,
         13, 13, 14, 14, 15, 15, 16, 16, 17, 17, 18, 18, 19, 19, 19, 20, 20, 21,
         21, 22, 22, 23, 23, 23, 23, 23, 24, 24, 24, 25, 25, 25, 26, 26, 27, 27,
         27, 27, 28, 28, 28, 29, 29, 29, 29, 30, 30, 30, 30, 31, 31, 31, 31, 31,
         31, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 33, 33, 33, 33, 33,
         33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33],
        [ 1,  2,  3,  4,  5,  6,  7,  8, 10, 11, 12, 13, 17, 19, 21, 31,  0,  2,
          3,  7, 13, 17, 19, 21, 30,  0,  1,  3,  7,  8,  9, 13, 27, 28, 32,  0,
          1,  2,  7, 12, 13,  0,  6, 10,  0,  6, 10, 16,  0,  4,  5, 16,  0,  1,
          2,  3,  0,  2, 30, 32, 33,  2, 33,  0,  4,  5,  0,  0,  3,  0,  1,  2,
          3, 33, 32, 33, 32, 33,  5,  6,  0,  1, 32, 33,  0,  1, 33, 32, 33,  0,
          1, 32, 33, 25, 27, 29, 32, 33, 25, 27, 31, 23, 24, 31, 29, 33,  2, 23,
         24, 33,  2, 31, 33, 23, 26, 32, 33,  1,  8, 32, 33,  0, 24, 25, 28, 32,
         33,  2,  8, 14, 15, 18, 20, 22, 23, 29, 30, 31, 33,  8,  9, 13, 14, 15,
         18, 19, 20, 22, 23, 26, 27, 28, 29, 30, 31, 32]])

The edge_index stores the graph connectivity in a rather counterintuitive way. Here, we have two lists of 156 entries each: the first list contains the source nodes and the second one the destination nodes, so every column describes one directed edge (156 directed edges, i.e., 78 bidirectional edges). This is called the coordinate list (COO) format and is just one way of efficiently storing a sparse matrix.

A more intuitive way to represent the graph connectivity would be a simple adjacency matrix \(A\), where a non-zero element \(A_{ij}\) indicates a connection from \(i\) to \(j\).

The adjacency matrix can be inferred from the edge_index with a utility function.

from torch_geometric.utils import to_dense_adj

A = to_dense_adj(data.edge_index)[0].numpy().astype(int)
print(f'A = {A.shape}')
print(A)
A = (34, 34)
[[0 1 1 ... 1 0 0]
 [1 0 1 ... 0 0 0]
 [1 1 0 ... 0 1 0]
 ...
 [1 0 0 ... 0 1 1]
 [0 0 1 ... 1 0 1]
 [0 0 0 ... 1 1 0]]

With graph data, nodes are rarely highly interconnected. For example, our adjacency matrix \(A\) is very sparse (filled with zeros). Storing so many zeros is not efficient at all, which is why the COO format is adopted by PyG.
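To make the link between the two formats concrete, here is a minimal sketch on a toy graph, written in plain Python so it runs without PyTorch. The `coo_to_dense` helper and the toy edges are assumptions made for illustration; this is not how `to_dense_adj` is implemented.

```python
# Toy undirected graph with 4 nodes and 3 edges: 0-1, 0-2, 2-3.
# Each undirected edge is stored twice, once per direction.
sources      = [0, 1, 0, 2, 2, 3]
destinations = [1, 0, 2, 0, 3, 2]
num_nodes = 4

def coo_to_dense(sources, destinations, num_nodes):
    """Build a dense adjacency matrix from two COO lists."""
    adjacency = [[0] * num_nodes for _ in range(num_nodes)]
    for src, dst in zip(sources, destinations):
        adjacency[src][dst] = 1
    return adjacency

A_toy = coo_to_dense(sources, destinations, num_nodes)
for row in A_toy:
    print(row)

# Storage comparison: COO keeps 2 integers per directed edge,
# the dense matrix keeps num_nodes**2 entries.
print(f"COO entries: {2 * len(sources)}, dense entries: {num_nodes ** 2}")
```

For the karate club, the same count gives 2 × 156 = 312 stored integers in COO format versus 34 × 34 = 1,156 entries in the dense matrix, and the gap widens quickly as graphs grow.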

In contrast, ground-truth labels are easy to understand.

print(f'y = {data.y.shape}')
print(data.y)
y = torch.Size([34])
tensor([1, 1, 1, 1, 3, 3, 3, 1, 0, 1, 3, 1, 1, 1, 0, 0, 3, 1, 0, 1, 0, 1, 0, 0,
        2, 2, 0, 0, 2, 0, 0, 2, 0, 0])

Our node ground-truth labels stored in y simply encode the group number (0, 1, 2, 3) for each node, which is why we have 34 values.

Finally, let’s print the train mask.

print(f'train_mask = {data.train_mask.shape}')
print(data.train_mask)
train_mask = torch.Size([34])
tensor([ True, False, False, False,  True, False, False, False,  True, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False,  True, False, False, False, False, False,
        False, False, False, False])

The train mask shows which nodes are supposed to be used for training with True statements. These nodes represent the training set, while the others can be considered as the test set.
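As a quick illustration of how such a mask is used, the sketch below rebuilds the labels and the mask from the outputs printed above as plain Python lists and selects the training nodes. The variable names are chosen for this example only.

```python
# Labels copied from the printed data.y.
y = [1, 1, 1, 1, 3, 3, 3, 1, 0, 1, 3, 1, 1, 1, 0, 0, 3, 1, 0, 1, 0, 1, 0, 0,
     2, 2, 0, 0, 2, 0, 0, 2, 0, 0]

# Mask copied from the printed data.train_mask: True at positions 0, 4, 8, 24.
train_mask = [False] * 34
for idx in (0, 4, 8, 24):
    train_mask[idx] = True

train_nodes = [i for i, keep in enumerate(train_mask) if keep]
test_nodes = [i for i, keep in enumerate(train_mask) if not keep]
train_labels = [y[i] for i in train_nodes]

print(f"Training nodes: {train_nodes}")
print(f"Their labels:   {train_labels}")
print(f"Remaining nodes: {len(test_nodes)}")
```

The mask keeps exactly one labeled node per group. With PyTorch tensors, the same selection can be written with boolean indexing, for example `data.y[data.train_mask]`.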

But we’re not done yet! The Data object has a lot more to offer: many graph properties can be checked using utility functions. For example:

  • is_directed() tells you if the graph is directed, which means that the adjacency matrix is not symmetric
  • has_isolated_nodes() checks if some nodes are not connected to the rest of the graph (and will probably be harder to classify)
  • has_self_loops() indicates if at least one node is connected to itself. This is not the same as a cycle, which is a path that starts and ends at the same node.

All of these properties return False for Zachary’s karate club.

print(f'Edges are directed: {data.is_directed()}')
print(f'Graph has isolated nodes: {data.has_isolated_nodes()}')
print(f'Graph has loops: {data.has_self_loops()}')
Edges are directed: False
Graph has isolated nodes: False
Graph has loops: False
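To clarify what these three checks mean, here is one possible way to compute them from COO lists in plain Python. This is a sketch under simple assumptions (a node whose only edge is a self-loop counts as isolated) and is not meant to reproduce PyG's actual implementation.

```python
def is_directed(sources, destinations):
    """True if some edge (i, j) has no matching reverse edge (j, i)."""
    edges = set(zip(sources, destinations))
    return any((dst, src) not in edges for src, dst in edges)

def has_isolated_nodes(sources, destinations, num_nodes):
    """True if some node has no edge to any other node."""
    connected = set()
    for src, dst in zip(sources, destinations):
        if src != dst:  # self-loops do not connect a node to the rest
            connected.update((src, dst))
    return len(connected) < num_nodes

def has_self_loops(sources, destinations):
    """True if at least one edge starts and ends at the same node."""
    return any(src == dst for src, dst in zip(sources, destinations))

# Toy graph: edge 0-1 stored in both directions, node 2 only has a self-loop.
toy_sources = [0, 1, 2]
toy_destinations = [1, 0, 2]

print(is_directed(toy_sources, toy_destinations))
print(has_isolated_nodes(toy_sources, toy_destinations, num_nodes=3))
print(has_self_loops(toy_sources, toy_destinations))
```

On this toy graph, only the first check is False. Unlike the karate club, the toy graph has a self-loop and an isolated node.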

Finally, we can convert a graph from PyTorch Geometric to the popular graph library NetworkX using to_networkx. This is particularly useful to visualize a small graph with NetworkX and Matplotlib.

Let’s plot our dataset with a different color for each group.

from torch_geometric.utils import to_networkx

G = to_networkx(data, to_undirected=True)
plt.figure(figsize=(12,12))
plt.axis('off')
nx.draw_networkx(G,
                pos=nx.spring_layout(G, seed=0),
                with_labels=True,
                node_size=800,
                node_color=data.y,
                cmap="hsv",
                vmin=-2,
                vmax=3,
                width=0.8,
                edge_color="grey",
                font_size=14
                )
plt.show()

This plot of Zachary’s karate club displays our 34 nodes, 78 (bidirectional) edges, and 4 labels with 4 different colors. Now that we’ve seen the essentials of loading and handling a dataset with PyTorch Geometric, we can introduce the Graph Convolutional Network.

✉️ II. Graph Convolutional Network

In this section, let’s try to redesign the graph convolutional layer from scratch.

In neural networks, linear layers apply a linear transformation to the incoming data. They transform input features \(x\) into hidden vectors \(h\) using a weight matrix \(\mathbf{W}\). If we ignore biases, we can write:

\[h = \mathbf{W} x\]

With graph data, we have access to connections between nodes. Why is that relevant? In most networks, we make the hypothesis that similar nodes are more likely to be connected to each other than dissimilar ones (it’s called network homophily).

We can enrich our node representation by aggregating its features with those of its neighbors. This operation is called convolution, or neighborhood aggregation. Let’s denote \(\tilde{\mathcal{N}}_i\) the neighborhood of node \(i\) including itself.

\[h_i = \sum_{j \in \tilde{\mathcal{N}}_i} \mathbf{W} x_j\]

Unlike filters in Convolutional Neural Networks (CNNs), our weight matrix \(\mathbf{W}\) is unique and shared among every node. But there is another issue: nodes do not have a fixed number of neighbors like pixels do.

What if one node only has 1 neighbor, and another one has 500 of them? We would add 500 values instead of just one: the resulting embedding \(h\) would be much larger for the node with 500 neighbors.

However, this doesn’t make sense: nodes should always be comparable, so they need to have a similar range of values. To address this issue, we can normalize the result based on the number of connections. In graph theory, this number is called a degree.

\[h_i = \dfrac{1}{\deg(i)} \sum_{j \in \tilde{\mathcal{N}}_i} \mathbf{W} x_j\]
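The effect of this normalization can be checked on a small star graph. The sketch below is plain Python with a hypothetical scalar value standing in for \(\mathbf{W} x_j\), and it assumes that \(\deg(i)\) counts the node itself, consistent with a neighborhood that includes \(i\).

```python
# Star graph: node 0 is a hub connected to five leaf nodes.
neighbors = {0: [1, 2, 3, 4, 5], 1: [0], 2: [0], 3: [0], 4: [0], 5: [0]}

# Hypothetical scalar value of W x_j for every node (all equal on purpose).
wx = {node: 1.0 for node in neighbors}

def neighborhood(i):
    """Neighbors of node i plus node i itself."""
    return neighbors[i] + [i]

def sum_aggregate(i):
    """Plain sum over the neighborhood."""
    return sum(wx[j] for j in neighborhood(i))

def mean_aggregate(i):
    # Assumption: deg(i) counts the node itself, like the neighborhood does.
    return sum_aggregate(i) / len(neighborhood(i))

for node in (0, 1):
    print(f"node {node}: sum = {sum_aggregate(node)}, mean = {mean_aggregate(node)}")
```

With the plain sum, the hub's value is three times larger than a leaf's even though every node carries the same feature. After dividing by the degree, both land on the same scale.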

We’re almost there! Introduced by Kipf et al. in 2016, the graph convolutional layer has one final improvement.

Indeed, the authors noticed that features from nodes with a lot of neighbors will spread much more easily than those from more isolated nodes. To counterbalance this effect, they proposed to give bigger weights to features from nodes with few neighbors. This operation can be written as follows:

\[h_i = \sum_{j \in \tilde{\mathcal{N}}_i} \dfrac{1}{\sqrt{\deg(i)}\sqrt{\deg(j)}} \mathbf{W} x_j\]

Notice that when \(i\) and \(j\) have the same number of neighbors, it is equivalent to our own layer. Now, let’s see how to implement it in Python.
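Both claims can be verified numerically. The following sketch compares the degree-averaged layer with the symmetric normalization on two toy graphs, again in plain Python with hypothetical scalar values in place of \(\mathbf{W} x_j\) and with the assumption that the degree counts the node itself.

```python
import math

def degree(node, neighbors):
    # Assumption: the degree counts the node itself (self-loop included).
    return len(neighbors[node]) + 1

def mean_aggregate(i, neighbors, wx):
    """Average of W x_j over the neighborhood of i (self included)."""
    hood = neighbors[i] + [i]
    return sum(wx[j] for j in hood) / degree(i, neighbors)

def gcn_aggregate(i, neighbors, wx):
    """Symmetric normalization: each term is scaled by 1 / (sqrt(deg i) * sqrt(deg j))."""
    hood = neighbors[i] + [i]
    deg_i = degree(i, neighbors)
    return sum(
        wx[j] / (math.sqrt(deg_i) * math.sqrt(degree(j, neighbors)))
        for j in hood
    )

# Triangle: every node has the same degree, so both layers should agree.
triangle = {0: [1, 2], 1: [0, 2], 2: [0, 1]}
wx_triangle = {0: 1.0, 1: 2.0, 2: 3.0}
print(mean_aggregate(0, triangle, wx_triangle), gcn_aggregate(0, triangle, wx_triangle))

# Star: the hub (node 0) has many neighbors, each leaf only one.
star = {0: [1, 2, 3, 4, 5], 1: [0], 2: [0], 3: [0], 4: [0], 5: [0]}
hub_weight_in_leaf = 1 / (math.sqrt(degree(1, star)) * math.sqrt(degree(0, star)))
leaf_self_weight = 1 / degree(1, star)
print(f"weight of the hub in a leaf's sum: {hub_weight_in_leaf:.3f}")
print(f"weight of the leaf itself:         {leaf_self_weight:.3f}")
```

On the triangle, the two layers give the same result up to floating-point rounding. On the star, the well-connected hub contributes to a leaf with a smaller coefficient than the leaf's own feature, which is the counterbalancing effect described above.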

🧠 III. Implementing a GCN

PyTorch Geometric directly implements the graph convolutional layer using GCNConv.

In this example, we will create a simple GCN with only one GCN layer, a ReLU activation function, and one linear layer. This final layer will output four values, corresponding to our four groups. The highest value will determine the class of each node.

In the following code block, we define the GCN model with a 3-dimensional hidden layer.

from torch.nn import Linear
from torch_geometric.nn import GCNConv


class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.gcn = GCNConv(dataset.num_features, 3)
        self.out = Linear(3, dataset.num_classes)

    def forward(self, x, edge_index):
        h = self.gcn(x, edge_index).relu()
        z = self.out(h)
        return h, z

model = GCN()
print(model)
GCN(
  (gcn): GCNConv(34, 3)
  (out): Linear(in_features=3, out_features=4, bias=True)
)

If we added a second GCN layer, our model would not only aggregate feature vectors from the neighbors of each node, but also from the neighbors of these neighbors.

We can stack several graph layers to aggregate more and more distant values, but there’s a catch: if we add too many layers, the aggregation becomes so intense that all the embeddings end up looking the same. This phenomenon is called over-smoothing and can be a real problem when you have too many layers.
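Over-smoothing can be observed even without a neural network. The sketch below is a deliberately simplified stand-in for a stack of graph layers: it applies only the averaging step over each neighborhood, with no weights and no activation, on a hypothetical 5-node path graph, and it tracks how far apart the node values remain.

```python
# Path graph 0 - 1 - 2 - 3 - 4 with one scalar feature per node.
neighbors = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}
features = [0.0, 0.25, 0.5, 0.75, 1.0]

def mean_layer(values):
    """One round of mean aggregation over each neighborhood (self included)."""
    new_values = []
    for i in range(len(values)):
        hood = neighbors[i] + [i]
        new_values.append(sum(values[j] for j in hood) / len(hood))
    return new_values

def spread(values):
    """Gap between the largest and smallest node value."""
    return max(values) - min(values)

spreads = [spread(features)]
values = features
for layer in range(30):
    values = mean_layer(values)
    spreads.append(spread(values))

for depth in (0, 1, 5, 30):
    print(f"after {depth:>2} layers: spread = {spreads[depth]:.4f}")
```

The spread shrinks with every extra round of aggregation, so after enough layers all nodes carry almost the same value and can no longer be told apart. A trained GCN also has weights and nonlinearities, so this only illustrates the tendency and does not predict the behavior of a specific model.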

Now that we’ve defined our GNN, let’s write a simple training loop with PyTorch. I chose a regular cross-entropy loss since it’s a multi-class classification task, with Adam as optimizer. We could use the training mask, but we will ignore it for this exploratory exercise.
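For readers who have not used cross-entropy before, here is a minimal sketch of what it computes for a single node, in plain Python. The logits are made-up values for the four groups. `torch.nn.CrossEntropyLoss` works on the raw outputs in the same spirit and, by default, averages the result over all nodes.

```python
import math

def softmax(logits):
    """Turn raw scores into probabilities that sum to 1."""
    highest = max(logits)
    exps = [math.exp(v - highest) for v in logits]  # shift for numerical stability
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(logits, target):
    """Negative log-probability assigned to the correct class."""
    return -math.log(softmax(logits)[target])

logits = [0.5, 2.0, -1.0, 0.1]  # hypothetical raw scores for the 4 groups
target = 1                      # hypothetical ground-truth group

probs = softmax(logits)
predicted = probs.index(max(probs))  # same role as z.argmax(dim=1)
loss = cross_entropy(logits, target)

print(f"Predicted group: {predicted}")
print(f"Loss if the true group is {target}: {loss:.2f}")
print(f"Loss if the true group were 2: {cross_entropy(logits, 2):.2f}")
```

The loss is small when the highest score is on the correct group and grows as the model puts more probability on a wrong one, which is the signal used to update the model.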

The training loop is standard: we try to predict the correct labels, and we compare the GCN’s results to the values stored in data.y. The error is calculated by the cross-entropy loss, its gradients are computed by backpropagation, and Adam uses them to update our GNN’s weights and biases. Finally, we print metrics every 10 epochs.

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.02)

# Calculate accuracy
def accuracy(pred_y, y):
    return (pred_y == y).sum() / len(y)

# Data for animations
embeddings = []
losses = []
accuracies = []
outputs = []

# Training loop
for epoch in range(201):
    # Clear gradients
    optimizer.zero_grad()

    # Forward pass
    h, z = model(data.x, data.edge_index)

    # Calculate loss function
    loss = criterion(z, data.y)

    # Calculate accuracy
    acc = accuracy(z.argmax(dim=1), data.y)

    # Compute gradients
    loss.backward()

    # Tune parameters
    optimizer.step()

    # Store data for animations
    embeddings.append(h)
    losses.append(loss)
    accuracies.append(acc)
    outputs.append(z.argmax(dim=1))

    # Print metrics every 10 epochs
    if epoch % 10 == 0:
        print(f'Epoch {epoch:>3} | Loss: {loss:.2f} | Acc: {acc*100:.2f}%')
Epoch   0 | Loss: 1.35 | Acc: 38.24%
Epoch  10 | Loss: 1.21 | Acc: 38.24%
Epoch  20 | Loss: 1.08 | Acc: 41.18%
Epoch  30 | Loss: 0.92 | Acc: 70.59%
Epoch  40 | Loss: 0.72 | Acc: 73.53%
Epoch  50 | Loss: 0.54 | Acc: 88.24%
Epoch  60 | Loss: 0.41 | Acc: 88.24%
Epoch  70 | Loss: 0.33 | Acc: 88.24%
Epoch  80 | Loss: 0.29 | Acc: 88.24%
Epoch  90 | Loss: 0.26 | Acc: 88.24%
Epoch 100 | Loss: 0.24 | Acc: 88.24%
Epoch 110 | Loss: 0.23 | Acc: 88.24%
Epoch 120 | Loss: 0.22 | Acc: 88.24%
Epoch 130 | Loss: 0.22 | Acc: 88.24%
Epoch 140 | Loss: 0.21 | Acc: 88.24%
Epoch 150 | Loss: 0.20 | Acc: 88.24%
Epoch 160 | Loss: 0.20 | Acc: 91.18%
Epoch 170 | Loss: 0.19 | Acc: 97.06%
Epoch 180 | Loss: 0.17 | Acc: 100.00%
Epoch 190 | Loss: 0.14 | Acc: 100.00%
Epoch 200 | Loss: 0.12 | Acc: 100.00%

Great! Unsurprisingly, we reach 100% accuracy on the training set, which here contains all 34 nodes since we ignored the training mask. It means that our model learned to assign every member of the karate club to its correct group.

We can produce a neat visualization by animating the graph and see the evolution of the GNN’s predictions during the training process.

%%capture
from IPython.display import HTML
from matplotlib import animation
plt.rcParams["animation.bitrate"] = 3000

def animate(i):
    G = to_networkx(data, to_undirected=True)
    nx.draw_networkx(G,
                    pos=nx.spring_layout(G, seed=0),
                    with_labels=True,
                    node_size=800,
                    node_color=outputs[i],
                    cmap="hsv",
                    vmin=-2,
                    vmax=3,
                    width=0.8,
                    edge_color="grey",
                    font_size=14
                    )
    plt.title(f'Epoch {i} | Loss: {losses[i]:.2f} | Acc: {accuracies[i]*100:.2f}%',
              fontsize=18, pad=20)

fig = plt.figure(figsize=(12, 12))
plt.axis('off')

anim = animation.FuncAnimation(fig, animate, \
            np.arange(0, 200, 10), interval=500, repeat=True)
html = HTML(anim.to_html5_video())
display(html)

The first predictions are random, but the GCN perfectly labels every node after a while. Indeed, the final graph is the same as the one we plotted at the end of the first section. But what does the GCN really learn?

By aggregating features from neighboring nodes, the GNN learns a vector representation (or embedding) of every node in the network. In our model, the final layer just learns how to use these representations to produce the best classifications. However, embeddings are the real products of GNNs.

Let’s print the embeddings learned by our model.

# Print embeddings
print(f'Final embeddings = {h.shape}')
print(h)
Final embeddings = torch.Size([34, 3])
tensor([[2.3756e+00, 5.1330e-01, 0.0000e+00],
        [3.2511e+00, 1.4347e+00, 0.0000e+00],
        [2.0562e+00, 1.5209e+00, 0.0000e+00],
        [2.9461e+00, 1.1436e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00],
        [2.4295e+00, 1.0172e+00, 0.0000e+00],
        [6.3575e-01, 2.6594e+00, 0.0000e+00],
        [1.9876e+00, 1.3767e+00, 0.0000e+00],
        [0.0000e+00, 6.0713e-04, 0.0000e+00],
        [2.2577e+00, 1.1747e+00, 0.0000e+00],
        [2.3823e+00, 1.1449e+00, 0.0000e+00],
        [2.2940e+00, 1.3844e+00, 0.0000e+00],
        [1.9155e-01, 2.7673e+00, 0.0000e+00],
        [1.8206e-01, 2.7194e+00, 0.0000e+00],
        [0.0000e+00, 2.0684e-03, 0.0000e+00],
        [2.3367e+00, 1.1042e+00, 0.0000e+00],
        [1.7925e-01, 2.7942e+00, 0.0000e+00],
        [2.0630e+00, 1.4096e+00, 0.0000e+00],
        [1.9360e-01, 2.7587e+00, 0.0000e+00],
        [2.2845e+00, 1.1088e+00, 0.0000e+00],
        [1.8486e-01, 2.7376e+00, 0.0000e+00],
        [0.0000e+00, 2.8447e+00, 0.0000e+00],
        [0.0000e+00, 8.9724e-01, 0.0000e+00],
        [0.0000e+00, 9.5606e-01, 0.0000e+00],
        [2.1157e-01, 2.8055e+00, 0.0000e+00],
        [2.6385e-01, 2.4765e+00, 0.0000e+00],
        [2.9965e-01, 8.5145e-01, 0.0000e+00],
        [0.0000e+00, 3.3316e+00, 0.0000e+00],
        [4.0497e-01, 2.8716e+00, 0.0000e+00],
        [0.0000e+00, 6.8132e-01, 0.0000e+00],
        [0.0000e+00, 4.1963e+00, 0.0000e+00],
        [0.0000e+00, 3.8991e+00, 0.0000e+00]], grad_fn=<ReluBackward0>)

As you can see, embeddings do not need to have the same dimensions as feature vectors. Here, I chose to reduce the number of dimensions from 34 (dataset.num_features) to three to get a nice visualization in 3D.

Let’s plot these embeddings before any training happens, at epoch 0.

# Get the embeddings stored at epoch 0
# (h[0] would only select the first node of the final embeddings)
embed = embeddings[0].detach().cpu().numpy()

fig = plt.figure(figsize=(12, 12))
ax = fig.add_subplot(projection='3d')
ax.patch.set_alpha(0)
plt.tick_params(left=False,
                bottom=False,
                labelleft=False,
                labelbottom=False)
ax.scatter(embed[:, 0], embed[:, 1], embed[:, 2],
           s=200, c=data.y, cmap="hsv", vmin=-2, vmax=3)

plt.show()

We see every node from Zachary’s karate club with their true labels (and not the model’s predictions). For now, they’re all over the place since the GNN is not trained yet. But if we plot these embeddings at each step of the training loop, we’d be able to visualize what the GNN truly learns.

Let’s see how they evolve over time, as the GCN gets better and better at classifying nodes.

%%capture

def animate(i):
    embed = embeddings[i].detach().cpu().numpy()
    ax.clear()
    ax.scatter(embed[:, 0], embed[:, 1], embed[:, 2],
           s=200, c=data.y, cmap="hsv", vmin=-2, vmax=3)
    plt.title(f'Epoch {i} | Loss: {losses[i]:.2f} | Acc: {accuracies[i]*100:.2f}%',
              fontsize=18, pad=40)

fig = plt.figure(figsize=(12, 12))
plt.axis('off')
ax = fig.add_subplot(projection='3d')
plt.tick_params(left=False,
                bottom=False,
                labelleft=False,
                labelbottom=False)

anim = animation.FuncAnimation(fig, animate, \
              np.arange(0, 200, 10), interval=800, repeat=True)
html = HTML(anim.to_html5_video())
display(html)

We see that our GCN learned embeddings that group nodes from the same classes into nice clusters. Then, the final layer can easily separate them into different classes.

Embeddings are not unique to GNNs: they can be found everywhere in deep learning. They don’t have to be 3D either: actually, they rarely are. For instance, language models like BERT produce embeddings with 768 or even 1024 dimensions.

Additional dimensions store more information about nodes, text, images, etc. but they also create bigger models that are more difficult to train. This is why it’s better to keep low-dimensional embeddings as long as possible.

Conclusion

Graph Convolutional Networks are an incredibly versatile architecture that can be applied in many contexts. In this article,

  • We learned to use the PyTorch Geometric library to explore graph data with the Dataset and Data objects
  • We redesigned a graph convolutional layer from scratch
  • We implemented a GNN with a GCN layer
  • We visualized what training means for a GCN

Zachary’s karate club is a simplistic dataset, but it is good enough to understand the most important concepts in graph data and GNNs.

We only talked about node classification in this article, but there are other tasks GNNs can accomplish: link prediction (e.g., to recommend a friend), graph classification (e.g., to label molecules), graph generation (e.g., to create new molecules), and so on.

Beyond GCN, numerous GNN layers and architectures have been proposed by researchers. In the next article, we’ll talk about Graph Attention Networks (GATs), which implicitly compute GCN’s normalization factor and the importance of each connection with an attention mechanism.

If you enjoyed this article, feel free to follow me on Twitter for more GNN content. Thank you and have a great day! 📣
