Source code for dig.sslgraph.method.contrastive.model.graphcl

from .contrastive import Contrastive
from dig.sslgraph.method.contrastive.views_fn import NodeAttrMask, EdgePerturbation, \
    UniformSample, RWSample, RandomView


class GraphCL(Contrastive):
    r"""
    Contrastive learning method proposed in the paper `Graph Contrastive Learning with
    Augmentations <https://arxiv.org/abs/2010.13902>`_. You can refer to `the benchmark code
    <https://github.com/divelab/DIG/blob/dig/benchmarks/sslgraph/example_graphcl.ipynb>`_ for
    an example of usage.

    *Alias*: :obj:`dig.sslgraph.method.contrastive.model.`:obj:`GraphCL`.

    Args:
        dim (int): The embedding dimension.
        aug_1 (string, optional): Type of augmentation for the first view, one of
            (:obj:`"dropN"`, :obj:`"permE"`, :obj:`"subgraph"`, :obj:`"maskN"`,
            :obj:`"random2"`, :obj:`"random3"`, :obj:`"random4"`). :obj:`None` leaves
            the input unchanged. (default: :obj:`None`)
        aug_2 (string, optional): Type of augmentation for the second view, with the
            same choices as :obj:`aug_1`. (default: :obj:`None`)
        aug_ratio (float, optional): The ratio of augmentations. A number between [0,1).
            (default: :obj:`0.2`)
        **kwargs (optional): Additional arguments of :class:`dig.sslgraph.method.Contrastive`.

    Raises:
        ValueError: If :obj:`aug_1` or :obj:`aug_2` is not one of the supported names
            or :obj:`None`.
    """

    def __init__(self, dim, aug_1=None, aug_2=None, aug_ratio=0.2, **kwargs):
        views_fn = []

        for aug in [aug_1, aug_2]:
            if aug is None:
                # No augmentation: the view is the input itself.
                views_fn.append(lambda x: x)
            elif aug == 'dropN':
                views_fn.append(UniformSample(ratio=aug_ratio))
            elif aug == 'permE':
                views_fn.append(EdgePerturbation(ratio=aug_ratio))
            elif aug == 'subgraph':
                views_fn.append(RWSample(ratio=aug_ratio))
            elif aug == 'maskN':
                views_fn.append(NodeAttrMask(mask_ratio=aug_ratio))
            elif aug == 'random2':
                candidates = [UniformSample(ratio=aug_ratio),
                              RWSample(ratio=aug_ratio)]
                views_fn.append(RandomView(candidates))
            elif aug == 'random3':
                # NOTE: the 'random3' and 'random4' candidate lists were previously
                # swapped; 'random3' now wraps three candidate views and 'random4'
                # four, as the names indicate.
                candidates = [UniformSample(ratio=aug_ratio),
                              RWSample(ratio=aug_ratio),
                              EdgePerturbation(ratio=aug_ratio)]
                views_fn.append(RandomView(candidates))
            elif aug == 'random4':
                candidates = [UniformSample(ratio=aug_ratio),
                              RWSample(ratio=aug_ratio),
                              EdgePerturbation(ratio=aug_ratio),
                              NodeAttrMask(mask_ratio=aug_ratio)]
                views_fn.append(RandomView(candidates))
            else:
                # ValueError is a subclass of Exception, so callers that caught the
                # former bare Exception keep working.
                raise ValueError(
                    "Aug must be from ['dropN', 'permE', 'subgraph', "
                    "'maskN', 'random2', 'random3', 'random4'] or None."
                )

        super(GraphCL, self).__init__(objective='NCE',
                                      views_fn=views_fn,
                                      z_dim=dim,
                                      proj='MLP',
                                      node_level=False,
                                      **kwargs)

    def train(self, encoders, data_loader, optimizer, epochs, per_epoch_out=False):
        r"""Run the parent class's training generator and yield only the encoder.

        All arguments are forwarded unchanged to the :obj:`train` method of the
        parent class. Each item produced by the parent is unpacked as an
        (encoder, projection head) pair, and only the encoder is yielded.
        """
        # GraphCL removes projection heads after pre-training
        for enc, _proj in super(GraphCL, self).train(encoders, data_loader,
                                                     optimizer, epochs, per_epoch_out):
            yield enc