import torch.nn as nn
from .contrastive import Contrastive

class ProjHead(nn.Module):
    """Projection head: three stacked linear layers plus a linear shortcut.

    The output is ``block(x) + linear_shortcut(x)``, where ``block`` maps
    ``input_dim -> out_dim -> out_dim -> out_dim`` and ``linear_shortcut``
    maps ``input_dim -> out_dim`` directly.

    Args:
        input_dim (int): Size of the last dimension of the input.
        out_dim (int): Size of the last dimension of the output.
    """

    def __init__(self, input_dim, out_dim):
        # nn.Module must be initialised before any submodule is assigned;
        # otherwise the attribute assignments below raise AttributeError.
        super().__init__()
        # NOTE(review): there is no activation between the linear layers, so
        # this stack is mathematically a single affine map. Confirm whether
        # nonlinearities (e.g. nn.ReLU) were intended between the layers.
        self.block = nn.Sequential(
            nn.Linear(input_dim, out_dim),
            nn.Linear(out_dim, out_dim),
            nn.Linear(out_dim, out_dim),
        )
        self.linear_shortcut = nn.Linear(input_dim, out_dim)

    def forward(self, x):
        """Return the sum of the stacked projection and the shortcut."""
        return self.block(x) + self.linear_shortcut(x)
class InfoG_enc(nn.Module):
    """Wrap an encoder and linearly project its first output.

    Args:
        encoder: Callable invoked as ``encoder(data)``; its result must
            unpack into exactly two values.
        z_g_dim (int): Input feature size of the linear projection.
        z_n_dim (int): Output feature size of the linear projection.
    """

    def __init__(self, encoder, z_g_dim, z_n_dim):
        super().__init__()
        # Create the projection before storing the encoder so that submodule
        # registration order (fc, then encoder) is unchanged.
        self.fc = nn.Linear(in_features=z_g_dim, out_features=z_n_dim)
        self.encoder = encoder

    def forward(self, data):
        """Encode ``data`` and return the projected first output.

        The second value returned by the encoder is discarded.
        """
        first_out, _second_out = self.encoder(data)
        return self.fc(first_out)

class InfoGraph(Contrastive):
    r"""
    Contrastive learning method proposed in the paper `InfoGraph: Unsupervised
    and Semi-supervised Graph-Level Representation Learning via Mutual
    Information Maximization <https://arxiv.org/abs/1908.01000>`_. You can
    refer to the benchmark code for an example of usage.

    *Alias*: :obj:`dig.sslgraph.method.contrastive.model.`:obj:`InfoGraph`.

    Args:
        g_dim (int): The embedding dimension for graph-level (global)
            representations.
        n_dim (int): The embedding dimension for node-level (local)
            representations. Typically, when jumping knowledge is included in
            the encoder, we have :obj:`g_dim` = :obj:`n_layers` * :obj:`n_dim`.
        **kwargs (optional): Additional arguments of
            :class:`dig.sslgraph.method.Contrastive`.
    """

    def __init__(self, g_dim, n_dim, **kwargs):
        # A single identity view: the input is passed through unchanged.
        views_fn = [lambda x: x]
        # Both heads project into the node-level dimension n_dim.
        proj = ProjHead(g_dim, n_dim)
        proj_n = ProjHead(n_dim, n_dim)
        super().__init__(objective='JSE',
                         views_fn=views_fn,
                         node_level=True,
                         z_dim=g_dim,
                         z_n_dim=n_dim,
                         proj=proj,
                         proj_n=proj_n,
                         **kwargs)

    def train(self, encoders, data_loader, optimizer, epochs, per_epoch_out=False):
        """Run the parent training loop, yielding each encoder wrapped in
        :class:`InfoG_enc`.

        All arguments are forwarded unchanged to ``Contrastive.train``.

        Yields:
            InfoG_enc: A wrapper around each encoder produced by the parent
            training generator.
        """
        # The projection heads yielded by the parent are not used here.
        # NOTE(review): self.z_dim / self.z_n_dim are presumably set by
        # Contrastive.__init__ from the z_dim / z_n_dim arguments -- confirm.
        # NOTE(review): InfoG_enc builds a new nn.Linear on every yield and
        # nothing in this method trains it -- confirm this is intended.
        for enc, (_proj, _proj_n) in super().train(
                encoders, data_loader, optimizer, epochs, per_epoch_out):
            yield InfoG_enc(enc, self.z_dim, self.z_n_dim)