Source code for dig.sslgraph.method.contrastive.model.grace

from .contrastive import Contrastive
from dig.sslgraph.method.contrastive.views_fn import NodeAttrMask, EdgePerturbation, Sequential

[docs]class GRACE(Contrastive): r""" Contrastive learning method proposed in the paper `Deep Graph Contrastive Representation Learning <>`_. You can refer to `the benchmark code <>`_ for an example of usage. *Alias*: :obj:`dig.sslgraph.method.contrastive.model.`:obj:`GRACE`. Args: dim (int): The embedding dimension. dropE_rate_1, dropE_rate_2 (float): The ratio of the edge dropping augmentation for view 1. A number between [0,1). maskN_rate_1, maskN_rate_2 (float): The ratio of the node masking augmentation for view 2. A number between [0,1). **kwargs (optinal): Additional arguments of :class:`dig.sslgraph.method.Contrastive`. """ def __init__(self, dim, dropE_rate_1, dropE_rate_2, maskN_rate_1, maskN_rate_2, **kwargs): view_fn_1 = Sequential([EdgePerturbation(drop=True, ratio=dropE_rate_1), NodeAttrMask(mask_ratio=maskN_rate_1)]) view_fn_2 = Sequential([EdgePerturbation(drop=True, ratio=dropE_rate_2), NodeAttrMask(mask_ratio=maskN_rate_2)]) views_fn = [view_fn_1, view_fn_2] super(GRACE, self).__init__(objective='NCE', views_fn=views_fn, graph_level=False, node_level=True, z_n_dim=dim, proj_n='MLP', **kwargs)
[docs] def train(self, encoders, data_loader, optimizer, epochs, per_epoch_out=False): # GRACE removes projection heads after pre-training for enc, proj in super().train(encoders, data_loader, optimizer, epochs, per_epoch_out): yield enc