Tutorial for Self-Supervised GNNs

In this tutorial, we show how to use the DIG library [1] to build self-supervised learning (SSL) frameworks for training Graph Neural Networks (GNNs). Specifically, we give a unified and general description of the contrastive SSL framework, following the survey [2], and introduce the Contrastive base class built on that description. We then provide two examples that use the base class: a customized contrastive framework and the GraphCL [3] framework.

Contrastive Learning Frameworks

Contrastive learning has achieved great success in natural language processing and computer vision. Inspired by this success, contrastive learning methods have been adapted to graph data; examples include GraphCL, MVGRL [5] and InfoGraph [6].

A contrastive learning framework consists of multiple (k>=2) views of the input variable. It seeks to maximize the mutual information between the representations of different views. In particular, it learns to discriminate jointly sampled view pairs from independently sampled view pairs. Concretely, two views generated from the same instance are considered as a positive pair and two views generated from different instances are considered as a negative pair.

The computation pipeline of a contrastive learning framework can be formulated as

\[\mathbf{w_i} = \mathcal{T}_i (\mathbf{A}, \mathbf{X})\]
\[\mathbf{h_i} = f_i (\mathbf{w_i}), i = 1, \cdots, k\]
\[\max_{\{f_i\}_{i=1}^k} \frac{1}{\sum_{i \ne j}\sigma_{ij} } \left[ \sum_{i\ne j} \sigma_{ij} \mathcal{I}(\mathbf{h_i}, \mathbf{h_j}) \right]\]

where \((\mathbf{A}, \mathbf{X})\) is a given graph, treated as a random variable drawn from \(\mathcal{P}\); \(\mathcal{T}_1,\cdots,\mathcal{T}_k\) are transformations (augmentations) that produce the views \(\mathbf{w_1}, \cdots, \mathbf{w_k}\); \(f_1, \cdots, f_k\) are encoding networks that generate the output representations \(\mathbf{h_1}, \cdots, \mathbf{h_k}\); and \(\mathcal{I}(\mathbf{h_i}, \mathbf{h_j})\) is the mutual information between \(\mathbf{h_i}\) and \(\mathbf{h_j}\). The indicator \(\sigma_{ij} \in \{0,1\}\) equals 1 if the mutual information is computed between \(\mathbf{h_i}\) and \(\mathbf{h_j}\), and 0 otherwise.
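
The pipeline above can be sketched in a few lines of plain Python. This is only an illustration of the formula, not DIG code: the function name contrastive_objective and the toy transforms, encoders and MI estimate are hypothetical stand-ins, and a real framework would operate on batches of graphs with trainable encoders.

```python
from itertools import permutations

def contrastive_objective(graph, transforms, encoders, mi_estimate, sigma):
    """Average pairwise MI estimate over the view pairs (i, j) with sigma[i][j] == 1."""
    views = [t(graph) for t in transforms]           # w_i = T_i(A, X)
    reps = [f(w) for f, w in zip(encoders, views)]   # h_i = f_i(w_i)
    # Ordered pairs with i != j that are switched on by the indicator sigma.
    pairs = [(i, j) for i, j in permutations(range(len(reps)), 2) if sigma[i][j]]
    total = sum(mi_estimate(reps[i], reps[j]) for i, j in pairs)
    return total / len(pairs)

# Toy stand-ins (assumptions for illustration only): the "graph" is a number,
# the encoders are identities, and the "MI estimate" is a negative squared distance.
toy_transforms = [lambda g: g + 1.0, lambda g: g * 2.0]
toy_encoders = [lambda w: w, lambda w: w]
toy_mi = lambda a, b: -(a - b) ** 2
toy_sigma = [[0, 1], [1, 0]]
value = contrastive_objective(2.0, toy_transforms, toy_encoders, toy_mi, toy_sigma)
```

Training would then maximize this quantity with respect to the parameters of the encoders.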

The “Contrastive” Base Class

DIG-sslgraph provides the Contrastive base class to implement existing or create customized contrastive learning frameworks.

Contrastive(objective, views_fn, graph_level=True, node_level=False, z_dim=None, z_n_dim=None, proj=None, proj_n=None, neg_by_crpt=False, tau=0.5, device=None, choice_model='last', model_path='models')

Following the computation pipeline, to construct a particular contrastive framework, one first needs to specify the key components of the Contrastive base class: 1) objective, the objective function for MI maximization, 2) views_fn, a list of transformations that generate multiple views of a graph, and 3) graph_level and node_level, which select whether to learn graph-level or node-level representations.

Objectives (MI Estimators)

Two lower-bound mutual information estimators are implemented in DIG:

  • Jensen-Shannon Estimator (JSE):

\[\hat{\mathcal{I}}^{(JS)}(\mathbf{h_i}, \mathbf{h_j}) = \mathbb{E}_{(\mathbf{A}, \mathbf{X}) \sim \mathcal{P}} \left[ \log(\mathcal{D}(\mathbf{h_i}, \mathbf{h_j})) \right] + \mathbb{E}_{[(\mathbf{A}, \mathbf{X}), (\mathbf{A'}, \mathbf{X'})] \sim \mathcal{P} \times \mathcal{P}} \left[ \log(1-\mathcal{D}(\mathbf{h_i}, \mathbf{h_j'})) \right]\]

where \(\mathbf{h_i}, \mathbf{h_j}\) in the first term are computed from \((\mathbf{A}, \mathbf{X})\) distributed from \(\mathcal{P}\), \(\mathbf{h_i}\) and \(\mathbf{h_j}'\) in the second term are computed from \((\mathbf{A}, \mathbf{X})\) and \((\mathbf{A'}, \mathbf{X'})\) identically and independently distributed from the distribution \(\mathcal{P}\).

  • InfoNCE (NCE):

\[\begin{split}\hat{\mathcal{I}}^{(NCE)}(\mathbf{h_i}, \mathbf{h_j}) &= \mathbb{E}_{(\mathbf{A}, \mathbf{X}) \sim \mathcal{P}} \left[ \mathcal{D}(\mathbf{h_i}, \mathbf{h_j}) - \mathbb{E}_{\mathbf{K}\sim \mathcal{P}^N} \left[ \log \left( \frac{1}{N} \sum_{(\mathbf{A'}, \mathbf{X'}) \in \mathbf{K}} e^{\mathcal{D}(\mathbf{h_i}, \mathbf{h_j}')} \right) \,\Big|\, (\mathbf{A}, \mathbf{X}) \right] \right] \\ &= \mathbb{E}_{[(\mathbf{A}, \mathbf{X}), \mathbf{K}] \sim \mathcal{P} \times \mathcal{P}^N} \left[ \log \frac{e^{\mathcal{D}(\mathbf{h_i}, \mathbf{h_j})}}{\sum_{(\mathbf{A'}, \mathbf{X'}) \in \mathbf{K}} e^{\mathcal{D}(\mathbf{h_i}, \mathbf{h_j}')}}\right] + \log N\end{split}\]

where \(\mathbf{K}\) consists of \(N\) random variables independently and identically distributed from \(\mathcal{P}\), \(\mathbf{h_i}, \mathbf{h_j}\) are the representations of the i-th and j-th views of \((\mathbf{A}, \mathbf{X})\), and \(\mathbf{h_j}'\) is the representation of the j-th view of \((\mathbf{A'}, \mathbf{X'})\).

In addition to the type of MI estimator, users can specify proj and proj_n to control whether projection heads with trainable parameters are included when computing the MI estimates, and which projection heads to use. A projection head turns the MI estimator into a parameterized estimator and can bring a performance gain to certain contrastive methods.
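
The two estimators above can be sketched as mini-batch Monte Carlo estimates in plain Python. This is a simplified illustration under stated assumptions, not the DIG implementation: the critic is taken to be a dot product (passed through a sigmoid for JSE so that it lies in (0, 1)), the negatives for each graph are the other graphs in the batch, InfoNCE uses the second form of the estimator with the positive pair included in the sum, and no projection head or temperature is applied. The function names are hypothetical.

```python
import math

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def jse_estimate(h_i, h_j):
    """JSE over a batch: row n of h_i and row n of h_j come from the same graph.

    Needs at least two graphs in the batch so that negative pairs exist.
    """
    n = len(h_i)
    # Positive pairs: two views of the same graph.
    pos = [math.log(sigmoid(dot(h_i[a], h_j[a]))) for a in range(n)]
    # Negative pairs: views of two different graphs.
    neg = [math.log(1.0 - sigmoid(dot(h_i[a], h_j[b])))
           for a in range(n) for b in range(n) if a != b]
    return sum(pos) / len(pos) + sum(neg) / len(neg)

def info_nce_estimate(h_i, h_j):
    """InfoNCE over a batch, with the whole batch (size N) playing the role of K."""
    n = len(h_i)
    total = 0.0
    for a in range(n):
        scores = [dot(h_i[a], h_j[b]) for b in range(n)]
        total += scores[a] - math.log(sum(math.exp(s) for s in scores))
    return total / n + math.log(n)

# Toy representations: in `aligned` the two views of each graph agree,
# in `shuffled` each view is paired with the view of the other graph.
aligned = [[2.0, 0.0], [0.0, 2.0]]
shuffled = [[0.0, 2.0], [2.0, 0.0]]
```

Both estimates are larger when the views of the same graph agree than when they are mismatched, and this InfoNCE estimate cannot exceed \(\log N\) for a batch of \(N\) graphs.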

View Generation

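A variety of view generation functions \(\mathcal{T}\) are implemented in DIG. They belong to three types, given in order below: feature transformations, structure transformations, and sampling-based transformations. To perform multi-view contrastive learning, the number of view generators (len(views_fn)) should be no less than 2.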

\[\mathcal{T}_{feat}(\mathbf{A}, \mathbf{X}) = (\mathbf{A}, \mathcal{T}_X(\mathbf{X}))\]

where \(\mathcal{T}_X: \mathbb{R}^{|V|\times d} \to \mathbb{R}^{|V|\times d}\) performs the transformation on the feature matrix \(\mathbf{X}\).

\[\mathcal{T}_{struct}(\mathbf{A}, \mathbf{X}) = (\mathcal{T}_A(\mathbf{A}), \mathbf{X})\]

where \(\mathcal{T}_A: \mathbb{R}^{|V|\times |V|} \to \mathbb{R}^{|V|\times |V|}\) performs the transformation on the adjacency matrix \(\mathbf{A}\).

\[\mathcal{T}_{sample}(\mathbf{A}, \mathbf{X}) = (\mathbf{A}[S;S], \mathbf{X}[S])\]

where \(S \subseteq V\) denotes a subset of nodes and \([\cdot]\) selects certain rows and columns from a matrix based on indices of nodes in \(S\).
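
As a rough illustration of the three types, the sketch below implements one toy transformation of each kind on a dense adjacency matrix and feature matrix stored as nested lists. The function names are hypothetical and the code does not reproduce the DIG view functions, which operate on graph data objects rather than dense matrices.

```python
import random

def mask_features(A, X, mask_ratio, rng):
    """Feature transformation: zero the feature rows of a random subset of nodes."""
    n = len(X)
    masked = set(rng.sample(range(n), int(n * mask_ratio)))
    X_new = [[0.0] * len(row) if u in masked else row[:] for u, row in enumerate(X)]
    return A, X_new

def drop_edges(A, X, ratio, rng):
    """Structure transformation: remove a random subset of undirected edges."""
    n = len(A)
    edges = [(u, v) for u in range(n) for v in range(u + 1, n) if A[u][v]]
    dropped = rng.sample(edges, int(len(edges) * ratio))
    A_new = [row[:] for row in A]
    for u, v in dropped:
        A_new[u][v] = A_new[v][u] = 0
    return A_new, X

def sample_nodes(A, X, ratio, rng):
    """Sampling-based transformation: keep the subgraph induced by a node subset S."""
    n = len(A)
    S = sorted(rng.sample(range(n), max(1, int(n * ratio))))
    A_new = [[A[u][v] for v in S] for u in S]   # A[S;S]
    X_new = [X[u][:] for u in S]                # X[S]
    return A_new, X_new

# A path graph on four nodes with two features per node.
A = [[0, 1, 0, 0], [1, 0, 1, 0], [0, 1, 0, 1], [0, 0, 1, 0]]
X = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
```

Each function returns a new view and leaves its inputs unmodified, so several views can be generated from the same graph.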

Level of Representations

DIG-sslgraph provides three different representation levels at which to perform contrastive learning. By default, the base class performs graph-level contrast. To perform node-level contrast, set graph_level=False and node_level=True. If both graph_level and node_level are True, the contrastive method performs local-global contrast. In this case, the number of view generators (len(views_fn)) can be 1.
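
The three settings can be summarized in a small lookup. The helper level_config and the min_views key are hypothetical conveniences for illustration; only graph_level and node_level are arguments of the Contrastive base class.

```python
def level_config(level):
    """Map a contrast level to the flags described above.

    "min_views" is an illustrative field (not a Contrastive argument) recording
    the smallest number of view generators mentioned in this tutorial.
    """
    configs = {
        "graph": {"graph_level": True, "node_level": False, "min_views": 2},
        "node": {"graph_level": False, "node_level": True, "min_views": 2},
        "local-global": {"graph_level": True, "node_level": True, "min_views": 1},
    }
    return configs[level]

# The two flags can be passed on to a Contrastive subclass as keyword arguments.
flags = {k: v for k, v in level_config("local-global").items() if k != "min_views"}
```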

Creating Customized Contrastive Methods

The simplest way to create a customized contrastive method is to define a subclass of Contrastive that specifies the corresponding components and overrides the method train(). Below is an example that employs two node attribute masking view functions and the “JSE” objective with an “MLP” projection head for graph-level contrastive learning.

from dig.sslgraph.method.contrastive.views_fn import NodeAttrMask
from dig.sslgraph.method import Contrastive

class SSLModel(Contrastive):
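    def __init__(self, z_dim, mask_ratio, **kwargs):

        objective = "JSE"
        proj = "MLP"
        # Two node-attribute masking views, each applied independently.
        mask_i = NodeAttrMask(mask_ratio=mask_ratio)
        mask_j = NodeAttrMask(mask_ratio=mask_ratio)
        views_fn = [mask_i, mask_j]

        # Remaining arguments follow the Contrastive signature above:
        # an MLP projection head and graph-level contrast only.
        super(SSLModel, self).__init__(objective=objective,
                                       views_fn=views_fn,
                                       z_dim=z_dim,
                                       proj=proj,
                                       graph_level=True,
                                       node_level=False,
                                       **kwargs)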

    def train(self, encoder, data_loader, optimizer, epochs, per_epoch_out=False):
        for enc, proj in super(SSLModel, self).train(encoder, data_loader,
                                                    optimizer, epochs, per_epoch_out):
            yield enc

# embed_dim is the dimension of the encoder's output embedding.
embed_dim = 128
ssl_model = SSLModel(z_dim=embed_dim, mask_ratio=0.1)

Below is another example that uses the Contrastive base class to implement GraphCL, which employs random augmentations to generate views and optimizes the “NCE” objective.

import sys, torch
import torch.nn as nn
from dig.sslgraph.method import Contrastive
from dig.sslgraph.method.contrastive.views_fn import NodeAttrMask, EdgePerturbation, \
    UniformSample, RWSample, RandomView

class GraphCL(Contrastive):

    def __init__(self, dim, aug_1=None, aug_2=None, aug_ratio=0.2, **kwargs):

        views_fn = []

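        # Assumed mapping from augmentation names to the imported view functions.
        for aug in [aug_1, aug_2]:
            if aug is None:
                # Identity view: the graph is passed through unchanged.
                views_fn.append(lambda x: x)
            elif aug == 'dropN':
                views_fn.append(UniformSample(ratio=aug_ratio))
            elif aug == 'permE':
                views_fn.append(EdgePerturbation(ratio=aug_ratio))
            elif aug == 'subgraph':
                views_fn.append(RWSample(ratio=aug_ratio))
            elif aug == 'maskN':
                views_fn.append(NodeAttrMask(mask_ratio=aug_ratio))
            # Assumed: 'randomK' draws one of K candidate views at random; the
            # candidates listed after UniformSample are a reconstruction.
            elif aug == 'random2':
                candidates = [UniformSample(ratio=aug_ratio),
                              RWSample(ratio=aug_ratio)]
                views_fn.append(RandomView(candidates))
            elif aug == 'random4':
                candidates = [UniformSample(ratio=aug_ratio),
                              RWSample(ratio=aug_ratio),
                              EdgePerturbation(ratio=aug_ratio),
                              NodeAttrMask(mask_ratio=aug_ratio)]
                views_fn.append(RandomView(candidates))
            elif aug == 'random3':
                candidates = [UniformSample(ratio=aug_ratio),
                              RWSample(ratio=aug_ratio),
                              EdgePerturbation(ratio=aug_ratio)]
                views_fn.append(RandomView(candidates))
            else:
                raise Exception("Aug must be from ['dropN', 'permE', 'subgraph', "
                                "'maskN', 'random2', 'random3', 'random4'] or None.")

        # Assumed remaining arguments, following the Contrastive signature above:
        # an MLP projection head and graph-level contrast only.
        super(GraphCL, self).__init__(objective='NCE',
                                      views_fn=views_fn,
                                      z_dim=dim,
                                      proj='MLP',
                                      node_level=False,
                                      **kwargs)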

    def train(self, encoders, data_loader, optimizer, epochs, per_epoch_out=False):
        # GraphCL removes projection heads after pre-training
        for enc, proj in super(GraphCL, self).train(encoders, data_loader,
                                                    optimizer, epochs, per_epoch_out):
            yield enc

Note that the train() method of the base class returns a generator that yields the trained encoder and projection heads at each iteration. This is because some contrastive methods (such as MVGRL) also require the projection heads in downstream tasks.
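
The generator pattern used by train() can be illustrated with a toy example that has no dependency on DIG. The class names and the string stand-ins for the encoder and projection head are hypothetical, and the handling of per_epoch_out (yield after every epoch when True, otherwise only after the last epoch) is an assumption made for this sketch.

```python
class ToyContrastive:
    """Stand-in for a base class whose train() yields (encoder, projection head)."""

    def train(self, encoder, epochs, per_epoch_out=False):
        proj_head = "proj"
        for epoch in range(1, epochs + 1):
            encoder = f"enc@{epoch}"  # stand-in for one epoch of parameter updates
            if per_epoch_out or epoch == epochs:
                yield encoder, proj_head


class ToyEncoderOnly(ToyContrastive):
    """Subclass that drops the projection head, as the examples above do."""

    def train(self, encoder, epochs, per_epoch_out=False):
        for enc, _proj in super().train(encoder, epochs, per_epoch_out):
            yield enc


final_only = list(ToyEncoderOnly().train("enc@0", epochs=3))
every_epoch = list(ToyEncoderOnly().train("enc@0", epochs=3, per_epoch_out=True))
```

A method that needs the projection head downstream would simply yield both values instead of discarding the second one.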

Evaluation of encapsulated models

You can always write your own code for flexible evaluation of the contrastive methods defined above. However, we also provide pre-implemented evaluation tools for convenience. The tools work with most datasets from pytorch-geometric. Below is an example of performing semi-supervised evaluation for GraphCL. More examples can be found in the runnable Jupyter notebooks in the benchmark.

As the first step, we load the dataset NCI1, which is a typical dataset for graph classification. One can also use different datasets for pretraining and finetuning.

from dig.sslgraph.dataset import get_dataset
dataset, dataset_pretrain = get_dataset('NCI1', task='semisupervised')
feat_dim = dataset[0].x.shape[1]
embed_dim = 128

Then we employ ResGCN [4] as the graph encoder and run the evaluation.

from dig.sslgraph.utils import Encoder
from dig.sslgraph.method import GraphCL
from dig.sslgraph.evaluation import GraphSemisupervised

encoder = Encoder(feat_dim, embed_dim, n_layers=3, gnn='resgcn')
graphcl = GraphCL(embed_dim, aug_1='subgraph', aug_2='subgraph')
evaluator = GraphSemisupervised(dataset, dataset_pretrain, label_rate=0.01)
evaluator.evaluate(learning_model=graphcl, encoder=encoder)

