Graph Convolutional Networks: Introduction to GNNs

Graph Neural Network Course: Chapter 1

Author

Maxime Labonne

Published

February 20, 2022

Find many more architectures and applications using graph neural networks in my book, Hands-On Graph Neural Networks Using Python.

Update

July 2023: the article and the code have been updated and now work with PyTorch Geometric >2.0.

Graph Neural Networks (GNNs) represent one of the most captivating and rapidly evolving architectures within the deep learning landscape. As deep learning models designed to process data structured as graphs, GNNs bring remarkable versatility and powerful learning capabilities.

Among the various types of GNNs, Graph Convolutional Networks (GCNs) have emerged as the most prevalent and broadly applied model. GCNs are innovative due to their ability to leverage both the features of a node and its locality to make predictions, providing an effective way to handle graph-structured data.

In this article, we will delve into the mechanics of the GCN layer and explain its inner workings. Furthermore, we will explore its practical application for node classification tasks, using PyTorch Geometric as our tool of choice.

PyTorch Geometric is a library built on top of PyTorch specifically for developing and implementing GNNs. It is an advanced yet user-friendly library that provides a comprehensive suite of tools for graph-based machine learning. To get started, we need to install it. If you are using Google Colab, PyTorch should already be in place, so a single pip command is enough.

All the code is available on Google Colab and GitHub.

!pip install torch_geometric

import torch
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt

Now that PyTorch Geometric is installed, let’s explore the dataset we will use in this tutorial.

🌐 I. Graph data

Graphs are an essential structure for representing relationships between objects. You can encounter graph data in a multitude of real-world scenarios, such as social and computer networks, chemical structures of molecules, natural language processing, and image recognition, to name a few.

In this article, we will study the infamous and much-used Zachary’s karate club dataset.

Zachary’s karate club dataset captures the relationships formed within a karate club as observed by Wayne W. Zachary during the 1970s. It is a social network in which each node represents a club member, and edges between nodes represent interactions that occurred outside the club environment.

In this particular scenario, the members of the club are split into four distinct groups. Our task is to assign the correct group to each member (node classification), based on the pattern of their interactions.

Let’s import the dataset with PyG’s built-in function and try to understand the Datasets object it uses.

from torch_geometric.datasets import KarateClub

# Import dataset from PyTorch Geometric
dataset = KarateClub()

# Print information
print(dataset)
print('------------')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')
KarateClub()
------------
Number of graphs: 1
Number of features: 34
Number of classes: 4

This dataset only has 1 graph, where each node has a feature vector of 34 dimensions and is part of one out of four classes (our four groups). More generally, the Datasets object can be seen as a collection of Data (graph) objects.

We can further inspect our only graph to learn more about it.

# Print first element
print(f'Graph: {dataset[0]}')
Graph: Data(x=[34, 34], edge_index=[2, 156], y=[34], train_mask=[34])

The Data object is particularly interesting. Printing it offers a good summary of the graph we’re studying:

  • x=[34, 34] is the node feature matrix with shape (number of nodes, number of features). In our case, it means that we have 34 nodes (our 34 members), each node being associated with a 34-dim feature vector.
  • edge_index=[2, 156] represents the graph connectivity (how the nodes are connected) with shape (2, number of directed edges).
  • y=[34] contains the node ground-truth labels. In this problem, every node is assigned to one class (group), so we have one value for each node.
  • train_mask=[34] is an optional attribute that tells which nodes should be used for training with a list of True or False values.

Let’s print each of these tensors to understand what they store. Let’s start with the node features.

data = dataset[0]

print(f'x = {data.x.shape}')
print(data.x)
x = torch.Size([34, 34])
tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 1.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 1., 0., 0.],
        [0., 0., 0.,  ..., 0., 1., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.]])

Here, the node feature matrix x is an identity matrix: it doesn’t contain any relevant information about the nodes. It could contain information like age, skill level, etc., but this is not the case in this dataset. It means we’ll have to classify our nodes just by looking at their connections.

Now, let’s print the edge index.

print(f'edge_index = {data.edge_index.shape}')
print(data.edge_index)
edge_index = torch.Size([2, 156])
tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  1,
          1,  1,  1,  1,  1,  1,  1,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  3,
          3,  3,  3,  3,  3,  4,  4,  4,  5,  5,  5,  5,  6,  6,  6,  6,  7,  7,
          7,  7,  8,  8,  8,  8,  8,  9,  9, 10, 10, 10, 11, 12, 12, 13, 13, 13,
         13, 13, 14, 14, 15, 15, 16, 16, 17, 17, 18, 18, 19, 19, 19, 20, 20, 21,
         21, 22, 22, 23, 23, 23, 23, 23, 24, 24, 24, 25, 25, 25, 26, 26, 27, 27,
         27, 27, 28, 28, 28, 29, 29, 29, 29, 30, 30, 30, 30, 31, 31, 31, 31, 31,
         31, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 33, 33, 33, 33, 33,
         33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33],
        [ 1,  2,  3,  4,  5,  6,  7,  8, 10, 11, 12, 13, 17, 19, 21, 31,  0,  2,
          3,  7, 13, 17, 19, 21, 30,  0,  1,  3,  7,  8,  9, 13, 27, 28, 32,  0,
          1,  2,  7, 12, 13,  0,  6, 10,  0,  6, 10, 16,  0,  4,  5, 16,  0,  1,
          2,  3,  0,  2, 30, 32, 33,  2, 33,  0,  4,  5,  0,  0,  3,  0,  1,  2,
          3, 33, 32, 33, 32, 33,  5,  6,  0,  1, 32, 33,  0,  1, 33, 32, 33,  0,
          1, 32, 33, 25, 27, 29, 32, 33, 25, 27, 31, 23, 24, 31, 29, 33,  2, 23,
         24, 33,  2, 31, 33, 23, 26, 32, 33,  1,  8, 32, 33,  0, 24, 25, 28, 32,
         33,  2,  8, 14, 15, 18, 20, 22, 23, 29, 30, 31, 33,  8,  9, 13, 14, 15,
         18, 19, 20, 22, 23, 26, 27, 28, 29, 30, 31, 32]])

In graph theory and network analysis, connectivity between nodes is stored using a variety of data structures. The edge_index is one such data structure, where the graph’s connections are stored in two lists (156 directed edges, which equate to 78 bidirectional edges). The first list stores the source nodes, while the second one stores the corresponding destination nodes.

This method is known as the coordinate list (COO) format, a common way to store a sparse matrix efficiently. A sparse matrix is a matrix in which most elements are zero. In the COO format, only the non-zero elements are stored (here, as the coordinates of each edge), which saves memory and computational resources.
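To make the COO layout concrete, here is a minimal plain-Python sketch (not PyG’s own code) that turns a list of undirected edges into the two lists of an edge_index by storing every edge in both directions. The helper to_coo and the toy graph are hypothetical, chosen for illustration.

```python
def to_coo(undirected_edges):
    """Return (sources, targets) with every undirected edge stored twice."""
    sources, targets = [], []
    for u, v in undirected_edges:
        # One entry per direction: u -> v and v -> u
        sources += [u, v]
        targets += [v, u]
    return sources, targets


# Hypothetical 4-node toy graph with 3 undirected edges
edges = [(0, 1), (0, 2), (2, 3)]
sources, targets = to_coo(edges)
print(sources)  # [0, 1, 0, 2, 2, 3]
print(targets)  # [1, 0, 2, 0, 3, 2]
```

Applied to the 78 undirected edges of the karate club, the same doubling produces the 156 columns printed above. Unlike PyG’s KarateClub edge_index, which is sorted by source node, this sketch simply keeps the insertion order.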

By contrast, a more intuitive and straightforward way to represent graph connectivity is through an adjacency matrix \(A\). This is a square matrix where each element \(A_{ij}\) specifies the presence or absence of an edge from node \(i\) to node \(j\) in the graph. In other words, a non-zero element \(A_{ij}\) implies a connection from node \(i\) to node \(j\), and a zero indicates no direct connection.

An adjacency matrix, however, is far less space-efficient than the COO format for sparse graphs: it always stores \(n^2\) entries for \(n\) nodes, no matter how few edges there are. Still, for clarity and easy interpretation, the adjacency matrix remains a popular choice for representing graph connectivity.

The adjacency matrix can be inferred from the edge_index with the utility function to_dense_adj().

from torch_geometric.utils import to_dense_adj

A = to_dense_adj(data.edge_index)[0].numpy().astype(int)
print(f'A = {A.shape}')
print(A)
A = (34, 34)
[[0 1 1 ... 1 0 0]
 [1 0 1 ... 0 0 0]
 [1 1 0 ... 0 1 0]
 ...
 [1 0 0 ... 0 1 1]
 [0 0 1 ... 1 0 1]
 [0 0 0 ... 1 1 0]]

With graph data, it is relatively uncommon for nodes to be densely interconnected. As you can see, our adjacency matrix \(A\) is sparse: only 156 of its 34 × 34 = 1,156 entries are non-zero (about 13.5%).

In many real-world graphs, most nodes are connected to only a few other nodes, resulting in a large number of zeros in the adjacency matrix. Storing so many zeros is not efficient at all, which is why the COO format is adopted by PyG.
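Going the other way, the sketch below rebuilds a dense adjacency matrix from COO lists in plain Python, roughly what to_dense_adj() returns for a single graph once its batch dimension is dropped, and measures the fraction of non-zero entries. The helpers coo_to_dense and density are hypothetical names used only for this illustration.

```python
def coo_to_dense(sources, targets, num_nodes):
    """Build an n x n adjacency matrix (list of lists) from COO edge lists."""
    A = [[0] * num_nodes for _ in range(num_nodes)]
    for i, j in zip(sources, targets):
        A[i][j] = 1  # edge from node i to node j
    return A


def density(A):
    """Fraction of non-zero entries in a square matrix."""
    n = len(A)
    nonzero = sum(1 for row in A for value in row if value != 0)
    return nonzero / (n * n)


# Toy graph: 4 nodes, 3 undirected edges stored in both directions
sources = [0, 1, 0, 2, 2, 3]
targets = [1, 0, 2, 0, 3, 2]
A = coo_to_dense(sources, targets, num_nodes=4)
for row in A:
    print(row)
print(f"density = {density(A):.3f}")  # 6 / 16 = 0.375

# Karate club figures from the text: 156 directed edges, 34 nodes
print(f"karate density = {156 / (34 * 34):.3f}")  # 0.135
```

The dense matrix always costs n × n entries, whereas the COO lists grow only with the number of edges, which is the trade-off described above.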

By comparison, the ground-truth labels are easy to understand.

print(f'y = {data.y.shape}')
print(data.y)
y = torch.Size([34])
tensor([1, 1, 1, 1, 3, 3, 3, 1, 0, 1, 3, 1, 1, 1, 0, 0, 3, 1, 0, 1, 0, 1, 0, 0,
        2, 2, 0, 0, 2, 0, 0, 2, 0, 0])

Our node ground-truth labels stored in y simply encode the group number (0, 1, 2, 3) for each node, which is why we have 34 values.

Finally, let’s print the train mask.

print(f'train_mask = {data.train_mask.shape}')
print(data.train_mask)
train_mask = torch.Size([34])
tensor([ True, False, False, False,  True, False, False, False,  True, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False,  True, False, False, False, False, False,
        False, False, False, False])

The train mask marks the nodes to be used for training with True values. These nodes represent the training set, while the others can be considered as the test set. This division helps in model evaluation by providing unseen data for testing.
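Boolean masks select nodes by position. The plain-Python sketch below copies the y and train_mask values printed above, applies the mask, and counts how many nodes belong to each class. It shows that only four nodes (0, 4, 8, and 24) are marked for training, one per class, and that the classes are imbalanced.

```python
from collections import Counter

# Values copied from the printed data.y and data.train_mask above
y = [1, 1, 1, 1, 3, 3, 3, 1, 0, 1, 3, 1, 1, 1, 0, 0, 3, 1, 0, 1, 0, 1, 0, 0,
     2, 2, 0, 0, 2, 0, 0, 2, 0, 0]
train_mask = [True, False, False, False, True, False, False, False, True, False,
              False, False, False, False, False, False, False, False, False, False,
              False, False, False, False, True, False, False, False, False, False,
              False, False, False, False]

# Apply the mask: keep node indices (and their labels) where the mask is True
train_idx = [i for i, keep in enumerate(train_mask) if keep]
train_labels = [y[i] for i in train_idx]
test_idx = [i for i, keep in enumerate(train_mask) if not keep]

print(train_idx)      # [0, 4, 8, 24]
print(train_labels)   # [1, 3, 0, 2]
print(len(test_idx))  # 30
print(sorted(Counter(y).items()))  # [(0, 13), (1, 12), (2, 4), (3, 5)]
```

Keep in mind that the training loop later in this article uses all 34 labels rather than only these four.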

But we’re not done yet! The Data object has a lot more to offer. It provides various utility functions that enable the investigation of several properties of the graph. For instance:

  • is_directed() tells you if the graph is directed. A directed graph signifies that the adjacency matrix is not symmetric, i.e., the direction of edges matters in the connections between nodes.
  • has_isolated_nodes() checks if some nodes are not connected to the rest of the graph. These nodes are likely to pose challenges in tasks like classification due to their lack of connections.
  • has_self_loops() indicates if at least one node is connected to itself. This is distinct from the concept of cycles: a cycle is a path that starts and ends at the same node, traversing other nodes in between.
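These three checks can be expressed directly on the two COO lists. The sketch below is a simplified plain-Python version of that logic, not PyG’s implementation, and its function names merely mirror the Data methods: a graph counts as undirected when every edge (i, j) has its reverse (j, i), a node is isolated when it appears in no edge, and a self-loop is an edge (i, i).

```python
def is_directed(sources, targets):
    """True if some edge (i, j) lacks its reverse (j, i)."""
    edges = set(zip(sources, targets))
    return any((j, i) not in edges for i, j in edges)


def has_isolated_nodes(sources, targets, num_nodes):
    """True if some node appears in no edge at all."""
    connected = set(sources) | set(targets)
    return len(connected) < num_nodes


def has_self_loops(sources, targets):
    """True if at least one edge connects a node to itself."""
    return any(i == j for i, j in zip(sources, targets))


# Toy graph: undirected path 0-1-2 plus an unconnected node 3
src, dst = [0, 1, 1, 2], [1, 0, 2, 1]
print(is_directed(src, dst))            # False
print(has_isolated_nodes(src, dst, 4))  # True (node 3)
print(has_self_loops(src, dst))         # False
```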

In the context of Zachary’s karate club dataset, all these methods return False. This implies that the graph is not directed, does not have any isolated nodes, and none of its nodes are connected to themselves.

print(f'Edges are directed: {data.is_directed()}')
print(f'Graph has isolated nodes: {data.has_isolated_nodes()}')
print(f'Graph has loops: {data.has_self_loops()}')
Edges are directed: False
Graph has isolated nodes: False
Graph has loops: False

Finally, we can convert a graph from PyTorch Geometric to the popular graph library NetworkX using to_networkx. This is particularly useful to visualize a small graph with networkx and matplotlib.

Let’s plot our dataset with a different color for each group.

from torch_geometric.utils import to_networkx

G = to_networkx(data, to_undirected=True)
plt.figure(figsize=(12,12))
plt.axis('off')
nx.draw_networkx(G,
                pos=nx.spring_layout(G, seed=0),
                with_labels=True,
                node_size=800,
                node_color=data.y,
                cmap="hsv",
                vmin=-2,
                vmax=3,
                width=0.8,
                edge_color="grey",
                font_size=14
                )
plt.show()

This plot of Zachary’s karate club displays our 34 nodes, 78 (bidirectional) edges, and 4 labels with 4 different colors. Now that we’ve seen the essentials of loading and handling a dataset with PyTorch Geometric, we can introduce the Graph Convolutional Network architecture.

✉️ II. Graph Convolutional Network

This section aims to introduce and build the graph convolutional layer from the ground up.

In traditional neural networks, linear layers apply a linear transformation to the incoming data. This transformation converts input features \(x\) into hidden vectors \(h\) through the use of a weight matrix \(\mathbf{W}\). Ignoring biases for the time being, this can be expressed as:

\[h = \mathbf{W} x\]

With graph data, an additional layer of complexity is added through the connections between nodes. These connections matter because, in networks, similar nodes are typically assumed to be more likely to be linked to each other than dissimilar ones, a phenomenon known as network homophily.

We can enrich our node representation by merging its features with those of its neighbors. This operation is called convolution, or neighborhood aggregation. Let’s represent the neighborhood of node \(i\) including itself as \(\tilde{\mathcal{N}}_i\).

\[h_i = \sum_{j \in \tilde{\mathcal{N}}_i} \mathbf{W} x_j\]

Unlike filters in Convolutional Neural Networks (CNNs), which hold a different weight for each relative pixel position, our single weight matrix \(\mathbf{W}\) is applied to every neighbor and shared by all nodes. But there is another issue: nodes do not have a fixed number of neighbors like pixels do.

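How do we address cases where one node has only 1 neighbor, and another has 500? If we simply sum the feature vectors, the resulting embedding \(h\) would be much larger for the node with 500 neighbors. To ensure a similar range of values for all nodes and comparability between them, we can normalize the result by the degree of each node, where degree refers to the number of connections a node has. Because \(\tilde{\mathcal{N}}_i\) includes node \(i\) itself, \(\deg(i)\) here counts that self-connection too (\(\deg(i) = |\tilde{\mathcal{N}}_i|\)), so the result is a plain average over the neighborhood.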

\[h_i = \dfrac{1}{\deg(i)} \sum_{j \in \tilde{\mathcal{N}}_i} \mathbf{W} x_j\]
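A small numeric check shows why this normalization matters. The plain-Python sketch below uses a hypothetical star graph whose hub has four neighbors, identical scalar features, and a scalar w standing in for the weight matrix \(\mathbf{W}\); \(\deg(i)\) is taken as \(|\tilde{\mathcal{N}}_i|\), the neighborhood size including node \(i\) itself.

```python
# Hypothetical star graph: hub node 0 linked to leaves 1-4 (undirected)
neighbors = {0: [1, 2, 3, 4], 1: [0], 2: [0], 3: [0], 4: [0]}
x = {node: 1.0 for node in neighbors}  # identical toy features
w = 2.0  # scalar stand-in for the weight matrix W


def closed_neighborhood(i):
    """Neighbors of i plus i itself (the tilde-N_i of the text)."""
    return [i] + neighbors[i]


def sum_aggregate(i):
    """h_i = sum over tilde-N_i of w * x_j."""
    return sum(w * x[j] for j in closed_neighborhood(i))


def mean_aggregate(i):
    """Same sum divided by deg(i) = |tilde-N_i|."""
    nbhd = closed_neighborhood(i)
    return sum(w * x[j] for j in nbhd) / len(nbhd)


print(sum_aggregate(0), sum_aggregate(1))    # 10.0 4.0
print(mean_aggregate(0), mean_aggregate(1))  # 2.0 2.0
```

Without normalization, the hub’s value is 2.5 times that of a leaf purely because it has more neighbors; dividing by the degree puts both on the same scale.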

We’re almost there! Introduced by Kipf and Welling (2016), the graph convolutional layer has one final improvement.

The authors observed that features from nodes with numerous neighbors propagate much more easily than those from more isolated nodes. To offset this effect, they suggested assigning bigger weights to features from nodes with fewer neighbors, thus balancing the influence across all nodes. This operation is written as:

\[h_i = \sum_{j \in \tilde{\mathcal{N}}_i} \dfrac{1}{\sqrt{\deg(i)}\sqrt{\deg(j)}} \mathbf{W} x_j\]

Note that when \(i\) and \(j\) have the same degree, the factor reduces to \(1/\deg(i)\), which is exactly the mean aggregation above. Now, let’s see how to implement it in Python with PyTorch Geometric.
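Before turning to PyTorch Geometric, the normalized layer can be written out in plain Python. This is a minimal sketch under stated assumptions, not GCNConv’s implementation: features are scalars, \(\mathbf{W}\) is a scalar w, \(\deg(i)\) counts node \(i\) itself (as with the self-loops Kipf and Welling add), and there is no bias or activation. The function name gcn_layer is hypothetical.

```python
import math


def gcn_layer(neighbors, x, w):
    """One GCN aggregation with symmetric normalization 1/sqrt(deg(i) deg(j))."""
    # Closed neighborhoods: each node's neighbors plus the node itself
    nbhd = {i: [i] + list(nbrs) for i, nbrs in neighbors.items()}
    deg = {i: len(n) for i, n in nbhd.items()}  # deg(i) = |tilde-N_i|
    return {
        i: sum(w * x[j] / (math.sqrt(deg[i]) * math.sqrt(deg[j])) for j in nbhd[i])
        for i in nbhd
    }


# Star graph: hub 0 (degree 5 with itself) and leaves 1-4 (degree 2)
star = {0: [1, 2, 3, 4], 1: [0], 2: [0], 3: [0], 4: [0]}
h = gcn_layer(star, {i: 1.0 for i in star}, w=2.0)
print(round(h[0], 3), round(h[1], 3))  # 2.93 1.632

# Triangle graph: every node has degree 3, so the layer is a plain mean
triangle = {0: [1, 2], 1: [0, 2], 2: [0, 1]}
h_tri = gcn_layer(triangle, {0: 1.0, 1: 2.0, 2: 3.0}, w=1.0)
print([round(v, 6) for v in h_tri.values()])  # [2.0, 2.0, 2.0]
```

On the star graph, each leaf contributes to the hub with a factor of \(1/\sqrt{5 \cdot 2} \approx 0.316\) instead of the 1/5 a plain average would use, so features from the less-connected nodes get more weight.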

🧠 III. Implementing a GCN

PyTorch Geometric provides the GCNConv module, which directly implements the graph convolutional layer.

In this example, we’ll create a basic Graph Convolutional Network with a single GCN layer, a ReLU activation function, and a linear output layer. This output layer will yield four values corresponding to our four categories, with the highest value determining the class of each node.

In the following code block, we define a GCN whose graph convolutional layer produces 3-dimensional hidden embeddings.

from torch.nn import Linear
from torch_geometric.nn import GCNConv


class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.gcn = GCNConv(dataset.num_features, 3)
        self.out = Linear(3, dataset.num_classes)

    def forward(self, x, edge_index):
        h = self.gcn(x, edge_index).relu()
        z = self.out(h)
        return h, z

model = GCN()
print(model)
GCN(
  (gcn): GCNConv(34, 3)
  (out): Linear(in_features=3, out_features=4, bias=True)
)
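As a quick sanity check on this architecture, we can count its trainable parameters by hand. The arithmetic below assumes GCNConv’s default settings (a 34 × 3 weight matrix plus a separate bias vector of size 3) and a standard Linear layer with bias; linear_params is a hypothetical helper. In PyTorch, sum(p.numel() for p in model.parameters()) should report the same total.

```python
def linear_params(in_features, out_features, bias=True):
    """Weights (in x out) plus an optional bias vector (out)."""
    return in_features * out_features + (out_features if bias else 0)


num_features, hidden, num_classes = 34, 3, 4  # sizes shown in the model printout

gcn_params = linear_params(num_features, hidden)  # GCNConv(34, 3): 102 + 3
out_params = linear_params(hidden, num_classes)   # Linear(3, 4): 12 + 4
print(gcn_params, out_params, gcn_params + out_params)  # 105 16 121
```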

If we added a second GCN layer, our model would not only aggregate feature vectors from the neighbors of each node, but also from the neighbors of these neighbors.

We can stack several graph layers to aggregate more and more distant values, but there’s a catch: if we add too many layers, the aggregation becomes so intense that all the embeddings end up looking the same. This phenomenon is called over-smoothing, and it is one reason GCNs typically use only a few layers.
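Over-smoothing is easy to reproduce in a simplified setting. The sketch below repeatedly applies the mean aggregation over \(\tilde{\mathcal{N}}_i\), with no weights, biases, or activations (a deliberate simplification of a real GCN), to scalar features on a hypothetical 5-node path graph where only the last node starts with a non-zero value.

```python
def mean_aggregate_layer(neighbors, x):
    """One parameter-free layer: average each node's closed neighborhood."""
    out = []
    for i in range(len(x)):
        nbhd = [i] + neighbors[i]  # tilde-N_i: neighbors plus the node itself
        out.append(sum(x[j] for j in nbhd) / len(nbhd))
    return out


# Hypothetical path graph 0-1-2-3-4 with a single "informative" node
path = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}
x = [0.0, 0.0, 0.0, 0.0, 10.0]

spreads = []  # spreads[k] = max - min of the features after k layers
for layer in range(101):
    spreads.append(max(x) - min(x))
    x = mean_aggregate_layer(path, x)

for layer in (0, 1, 2, 5, 10, 50, 100):
    print(f"after {layer:>3} layers: spread = {spreads[layer]:.6f}")
print([round(v, 2) for v in x])  # every node ends up near 20/13
```

After one layer, the signal from node 4 has only reached node 3, illustrating how each extra layer widens the neighborhood by one hop. As layers pile up, the gap between the largest and smallest feature keeps shrinking until every node holds nearly the same value, 20/13, a degree-weighted average of the initial features.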

Now that we’ve defined our GNN, let’s write a simple training loop with PyTorch. I chose a regular cross-entropy loss since it’s a multi-class classification task, with Adam as the optimizer. In this article, we won’t implement a train/test split, to keep things simple and focus on how GNNs learn instead.

The training loop is standard: we try to predict the correct labels, and we compare the GCN’s results to the values stored in data.y. The error is calculated by the cross-entropy loss, backpropagation computes its gradients, and Adam uses them to update our GNN’s weights and biases. Finally, we print metrics every 10 epochs.
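For intuition about the numbers the loop prints, the sketch below computes cross-entropy for a single node by hand: a log-softmax over the logits followed by the negative log-probability of the true class, which is what torch.nn.CrossEntropyLoss computes for each sample before averaging. With four classes and uniform logits, the loss equals ln 4 ≈ 1.386, which is consistent with the 1.40 logged at epoch 0 by the untrained model.

```python
import math


def cross_entropy(logits, target):
    """-log softmax(logits)[target], computed in a numerically stable way."""
    m = max(logits)  # subtract the max before exponentiating to avoid overflow
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


# Uniform logits over the four classes: every class gets probability 1/4
print(round(cross_entropy([0.0, 0.0, 0.0, 0.0], target=2), 3))  # 1.386
# Confident, correct prediction: loss close to 0 (about 0.001)
print(round(cross_entropy([0.0, 0.0, 8.0, 0.0], target=2), 4))
```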

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.02)

# Calculate accuracy
def accuracy(pred_y, y):
    return (pred_y == y).sum() / len(y)

# Data for animations
embeddings = []
losses = []
accuracies = []
outputs = []

# Training loop
for epoch in range(201):
    # Clear gradients
    optimizer.zero_grad()

    # Forward pass
    h, z = model(data.x, data.edge_index)

    # Calculate loss function
    loss = criterion(z, data.y)

    # Calculate accuracy
    acc = accuracy(z.argmax(dim=1), data.y)

    # Compute gradients
    loss.backward()

    # Tune parameters
    optimizer.step()

    # Store data for animations
    embeddings.append(h)
    losses.append(loss)
    accuracies.append(acc)
    outputs.append(z.argmax(dim=1))

    # Print metrics every 10 epochs
    if epoch % 10 == 0:
        print(f'Epoch {epoch:>3} | Loss: {loss:.2f} | Acc: {acc*100:.2f}%')
Epoch   0 | Loss: 1.40 | Acc: 41.18%
Epoch  10 | Loss: 1.21 | Acc: 47.06%
Epoch  20 | Loss: 1.02 | Acc: 67.65%
Epoch  30 | Loss: 0.80 | Acc: 73.53%
Epoch  40 | Loss: 0.59 | Acc: 73.53%
Epoch  50 | Loss: 0.39 | Acc: 94.12%
Epoch  60 | Loss: 0.23 | Acc: 97.06%
Epoch  70 | Loss: 0.13 | Acc: 100.00%
Epoch  80 | Loss: 0.07 | Acc: 100.00%
Epoch  90 | Loss: 0.05 | Acc: 100.00%
Epoch 100 | Loss: 0.03 | Acc: 100.00%
Epoch 110 | Loss: 0.02 | Acc: 100.00%
Epoch 120 | Loss: 0.02 | Acc: 100.00%
Epoch 130 | Loss: 0.02 | Acc: 100.00%
Epoch 140 | Loss: 0.01 | Acc: 100.00%
Epoch 150 | Loss: 0.01 | Acc: 100.00%
Epoch 160 | Loss: 0.01 | Acc: 100.00%
Epoch 170 | Loss: 0.01 | Acc: 100.00%
Epoch 180 | Loss: 0.01 | Acc: 100.00%
Epoch 190 | Loss: 0.01 | Acc: 100.00%
Epoch 200 | Loss: 0.01 | Acc: 100.00%

Great! Without much surprise, we reach 100% accuracy on the training set (full dataset). It means that our model learned to assign every member of the karate club to its correct group.

We can produce a neat visualization by animating the graph to see the evolution of the GNN’s predictions during the training process.

%%capture
from IPython.display import HTML
from matplotlib import animation
plt.rcParams["animation.bitrate"] = 3000

def animate(i):
    G = to_networkx(data, to_undirected=True)
    nx.draw_networkx(G,
                    pos=nx.spring_layout(G, seed=0),
                    with_labels=True,
                    node_size=800,
                    node_color=outputs[i],
                    cmap="hsv",
                    vmin=-2,
                    vmax=3,
                    width=0.8,
                    edge_color="grey",
                    font_size=14
                    )
    plt.title(f'Epoch {i} | Loss: {losses[i]:.2f} | Acc: {accuracies[i]*100:.2f}%',
              fontsize=18, pad=20)

fig = plt.figure(figsize=(12, 12))
plt.axis('off')

anim = animation.FuncAnimation(fig, animate, \
            np.arange(0, 200, 10), interval=500, repeat=True)
html = HTML(anim.to_html5_video())
# %%capture suppresses this cell's output: run display(html) in a separate cell
display(html)

The first predictions are random, but the GCN perfectly labels every node after a while. Indeed, the final graph is the same as the one we plotted at the end of the first section. But what does the GCN really learn?

By aggregating features from neighboring nodes, the GNN learns a vector representation (or embedding) of every node in the network. In our model, the final layer just learns how to use these representations to produce the best classifications. However, embeddings are the real products of GNNs.

Let’s print the embeddings learned by our model.

# Print embeddings
print(f'Final embeddings = {h.shape}')
print(h)
Final embeddings = torch.Size([34, 3])
tensor([[1.9099e+00, 2.3584e+00, 7.4027e-01],
        [2.6203e+00, 2.7997e+00, 0.0000e+00],
        [2.2567e+00, 2.2962e+00, 6.4663e-01],
        [2.0802e+00, 2.8785e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00, 2.9694e+00],
        [0.0000e+00, 0.0000e+00, 3.3817e+00],
        [0.0000e+00, 1.5008e-04, 3.4246e+00],
        [1.7593e+00, 2.4292e+00, 2.4551e-01],
        [1.9757e+00, 6.1032e-01, 1.8986e+00],
        [1.7770e+00, 1.9950e+00, 6.7018e-01],
        [0.0000e+00, 1.1683e-04, 2.9738e+00],
        [1.8988e+00, 2.0512e+00, 2.6225e-01],
        [1.7081e+00, 2.3618e+00, 1.9609e-01],
        [1.8303e+00, 2.1591e+00, 3.5906e-01],
        [2.0755e+00, 2.7468e-01, 1.9804e+00],
        [1.9676e+00, 3.7185e-01, 2.0011e+00],
        [0.0000e+00, 0.0000e+00, 3.4787e+00],
        [1.6945e+00, 2.0350e+00, 1.9789e-01],
        [1.9808e+00, 3.2633e-01, 2.1349e+00],
        [1.7846e+00, 1.9585e+00, 4.8021e-01],
        [2.0420e+00, 2.7512e-01, 1.9810e+00],
        [1.7665e+00, 2.1357e+00, 4.0325e-01],
        [1.9870e+00, 3.3886e-01, 2.0421e+00],
        [2.0614e+00, 5.1042e-01, 2.4872e+00],
        [1.8381e-01, 2.1094e+00, 2.2035e+00],
        [1.8858e-01, 2.0701e+00, 2.1601e+00],
        [2.2553e+00, 4.1764e-01, 2.0231e+00],
        [1.6532e+00, 8.6745e-01, 2.2131e+00],
        [2.4265e-01, 2.1862e+00, 1.6104e+00],
        [2.5709e+00, 4.6342e-02, 2.3627e+00],
        [2.1778e+00, 4.4730e-01, 2.0077e+00],
        [3.8906e-02, 2.3443e+00, 1.9195e+00],
        [3.0748e+00, 0.0000e+00, 3.0789e+00],
        [3.4316e+00, 1.9716e-01, 2.5231e+00]], grad_fn=<ReluBackward0>)

As you can see, embeddings do not need to have the same dimensions as feature vectors. Here, I chose to reduce the number of dimensions from 34 (dataset.num_features) to three to get a nice visualization in 3D.

Let’s plot these embeddings before any training happens, at epoch 0.

# Get the embeddings computed at epoch 0 (before the first parameter update)
embed = embeddings[0].detach().cpu().numpy()

fig = plt.figure(figsize=(12, 12))
ax = fig.add_subplot(projection='3d')
ax.patch.set_alpha(0)
plt.tick_params(left=False,
                bottom=False,
                labelleft=False,
                labelbottom=False)
ax.scatter(embed[:, 0], embed[:, 1], embed[:, 2],
           s=200, c=data.y, cmap="hsv", vmin=-2, vmax=3)

plt.show()

We see every node from Zachary’s karate club with their true labels (and not the model’s predictions). For now, they’re all over the place since the GNN is not trained yet. But if we plot these embeddings at each step of the training loop, we’d be able to visualize what the GNN truly learns.

Let’s see how they evolve over time, as the GCN gets better and better at classifying nodes.

%%capture

def animate(i):
    embed = embeddings[i].detach().cpu().numpy()
    ax.clear()
    ax.scatter(embed[:, 0], embed[:, 1], embed[:, 2],
           s=200, c=data.y, cmap="hsv", vmin=-2, vmax=3)
    plt.title(f'Epoch {i} | Loss: {losses[i]:.2f} | Acc: {accuracies[i]*100:.2f}%',
              fontsize=18, pad=40)

fig = plt.figure(figsize=(12, 12))
plt.axis('off')
ax = fig.add_subplot(projection='3d')
plt.tick_params(left=False,
                bottom=False,
                labelleft=False,
                labelbottom=False)

anim = animation.FuncAnimation(fig, animate, \
              np.arange(0, 200, 10), interval=800, repeat=True)
html = HTML(anim.to_html5_video())
# %%capture suppresses this cell's output: run display(html) in a separate cell
display(html)

Our Graph Convolutional Network (GCN) has effectively learned embeddings that group similar nodes into distinct clusters. This enables the final linear layer to separate them into classes with ease.

Embeddings are not unique to GNNs: they can be found everywhere in deep learning. They don’t have to be 3D either: actually, they rarely are. For instance, language models like BERT produce embeddings with 768 or even 1024 dimensions.

Additional dimensions store more information about nodes, text, images, etc. but they also create bigger models that are more difficult to train. This is why keeping low-dimensional embeddings as long as possible is advantageous.

Conclusion

Graph Convolutional Networks are an incredibly versatile architecture that can be applied in many contexts. In this article, we familiarized ourselves with the PyTorch Geometric library and objects like Datasets and Data. Then, we successfully reconstructed a graph convolutional layer from the ground up. Next, we put theory into practice by implementing a GCN, which gave us an understanding of practical aspects and how individual components interact. Finally, we visualized the training process and obtained a clear perspective of what it involves for such a network.

Zachary’s karate club is a simplistic dataset, but it is good enough to understand the most important concepts in graph data and GNNs. Although we only talked about node classification in this article, there are other tasks GNNs can accomplish: link prediction (e.g., to recommend a friend), graph classification (e.g., to label molecules), graph generation (e.g., to create new molecules), and so on.

Beyond GCN, numerous GNN layers and architectures have been proposed by researchers. In the next article, we’ll introduce the Graph Attention Network (GAT) architecture, which dynamically computes the GCN’s normalization factor and the importance of each connection with an attention mechanism.

If you enjoyed this article, feel free to follow me on Twitter for more GNN content. Thank you!

🌐 Graph Neural Network Course

📕 Hands-On Graph Neural Networks

🔎 Course overview

📝 Chapter 1: Introduction to Graph Neural Networks

📝 Chapter 2: Graph Attention Network

📝 Chapter 3: GraphSAGE

📝 Chapter 4: Graph Isomorphism Network