# Source code for dig.sslgraph.method.contrastive.objectives.infonce

import itertools
import torch
import torch.nn.functional as F


def NCE_loss(zs=None, zs_n=None, batch=None, sigma=None, **kwargs):
    '''The InfoNCE (NT-XENT) loss in contrastive learning.

    Args:
        zs (list, optional): List of tensors of shape [batch_size, z_dim].
        zs_n (list, optional): List of tensors of shape [nodes, z_dim].
        batch (Tensor, optional): Required when both :obj:`zs` and :obj:`zs_n`
            are given.
        sigma (ndarray, optional): A 2D-array of shape [:obj:`n_views`, :obj:`n_views`]
            with boolean values, indicating contrast between which two views are
            computed. Only required when number of views is greater than 2. If
            :obj:`sigma[i][j]` = :obj:`True`, infoNCE between :math:`view_i` and
            :math:`view_j` will be computed.
        tau (float, optional): The temperature used in NT-XENT, passed as a
            keyword argument (default ``0.5``).
        norm (bool, optional): Whether to L2-normalize embeddings before
            computing similarities, passed as a keyword argument (default ``True``).

    :rtype: :class:`Tensor`

    Raises:
        AssertionError: if neither :obj:`zs` nor :obj:`zs_n` is given, or if the
            number of views, :obj:`sigma` or :obj:`batch` do not fit the mode.
        ValueError: if fewer than two graph-level views are given and there are
            no node-level views, i.e. there is nothing to contrast.
    '''
    assert zs is not None or zs_n is not None, 'zs or zs_n must be given'
    tau = kwargs.get('tau', 0.5)
    norm = kwargs.get('norm', True)
    # A 'mean' keyword is still accepted (via **kwargs) but has no effect.

    if zs_n is not None and zs is None:
        # Node-level InfoNCE between two views (as in GRACE), symmetrized.
        assert len(zs_n) == 2, 'node-level contrast needs exactly 2 views'
        return (infoNCE_local_intra_node(zs_n[0], zs_n[1], tau, norm, batch)
                + infoNCE_local_intra_node(zs_n[1], zs_n[0], tau, norm, batch)) * 0.5

    if zs_n is not None:
        # Local-global contrast between graph- and node-level embeddings.
        assert len(zs_n) == len(zs), 'zs and zs_n must hold the same number of views'
        assert batch is not None, 'batch is required when both zs and zs_n are given'

        def local_global(z_g, z_n):
            # Keywords pin the node/graph roles to infoNCE_local_global's
            # (z_n, z_g, ...) signature.
            return infoNCE_local_global(z_n=z_n, z_g=z_g, batch=batch,
                                        tau=tau, norm=norm)

        if len(zs) == 1:
            return local_global(zs[0], zs_n[0])
        if len(zs) == 2:
            # Graphs of one view against nodes of the other, both directions.
            return local_global(zs[0], zs_n[1]) + local_global(zs[1], zs_n[0])
        assert sigma is not None and len(sigma) == len(zs), \
            'sigma of shape [n_views, n_views] is required for more than 2 views'
        return sum(local_global(zs[i], zs_n[j]) + local_global(zs[j], zs_n[i])
                   for i, j in itertools.combinations(range(len(zs)), 2)
                   if sigma[i][j])

    # Graph-level contrast only.
    if len(zs) == 2:
        return NT_Xent(zs[0], zs[1], tau, norm)
    if len(zs) > 2:
        assert sigma is not None and len(sigma) == len(zs), \
            'sigma of shape [n_views, n_views] is required for more than 2 views'
        return sum(NT_Xent(zs[i], zs[j], tau, norm)
                   for i, j in itertools.combinations(range(len(zs)), 2)
                   if sigma[i][j])
    # Previously this case fell through and returned None.
    raise ValueError('NCE_loss needs at least two graph-level views in zs, '
                     'or node-level views in zs_n')
def infoNCE_local_intra_node(z1_n, z2_n, tau=0.5, norm=True, batch=None):
    '''Node-level InfoNCE loss between two views (as in GRACE).

    For anchor node :math:`i` of view 1 the positive is node :math:`i` of
    view 2. The denominator sums over every node of view 2 (including the
    positive) and every *other* node of view 1. The loss is averaged over
    all anchor nodes.

    Args:
        z1_n: Tensor of shape [n_nodes, z_dim], node embeddings of view 1.
        z2_n: Tensor of shape [n_nodes, z_dim], node embeddings of view 2.
        tau: Float. Temperature, usually in (0,1].
        norm: Boolean. Whether to L2-normalize embeddings (cosine similarity).
        batch: Optional tensor. If given, anchors are processed in chunks of
            ``batch.size(0)`` rows (to bound peak memory). This gives the same
            loss as the unchunked path.

    :rtype: :class:`Tensor` (scalar)
    '''
    def sim(a, b):
        if norm:
            a = F.normalize(a)
            b = F.normalize(b)
        return torch.mm(a, b.t())

    def row_losses(rows, offset):
        # Per-anchor losses for rows = z1_n[offset:offset + len(rows)].
        # Log space: -log(e^p / sum_x e^x) = logsumexp(x) - p, which avoids
        # the overflow/NaN of exponentiating similarities directly.
        refl = sim(rows, z1_n) / tau
        between = sim(rows, z2_n) / tau
        r = torch.arange(rows.size(0), device=refl.device)
        c = r + offset  # global column of each anchor's own node
        pos = between[r, c]
        self_mask = torch.zeros_like(refl, dtype=torch.bool)
        self_mask[r, c] = True
        # An anchor is not its own negative within view 1.
        refl = refl.masked_fill(self_mask, float('-inf'))
        return torch.logsumexp(torch.cat([refl, between], dim=1), dim=1) - pos

    if batch is None:
        return row_losses(z1_n, 0).mean()

    chunk = batch.size(0)
    num_nodes = z1_n.size(0)
    losses = [row_losses(z1_n[start:start + chunk], start)
              for start in range(0, num_nodes, chunk)]
    return torch.cat(losses).mean()


def infoNCE_local_global(z_n, z_g, batch, tau=0.5, norm=True):
    '''Local-global (node vs. graph) InfoNCE loss -- placeholder.

    Not used by existing methods yet. It ignores its inputs and always
    returns ``0``.

    Args:
        z_n: Tensor of shape [n_nodes, z_dim].
        z_g: Tensor of shape [n_graphs, z_dim].
        batch: Presumably the node-to-graph assignment tensor (TODO confirm
            once implemented).
        tau: Float. Usually in (0,1].
        norm: Boolean. Whether to apply normalization.
    '''
    # TODO: not yet implemented.
    loss = 0
    return loss


def NT_Xent(z1, z2, tau=0.5, norm=True):
    '''NT-Xent loss between two batches of embeddings.

    Row :math:`i` of ``z1`` and row :math:`i` of ``z2`` form the positive pair.
    The other rows of ``z2`` are negatives. As in the original formulation, the
    positive pair is excluded from the denominator, and only the
    ``z1 -> z2`` direction is computed.

    Args:
        z1, z2: Tensor of shape [batch_size, z_dim]
        tau: Float. Usually in (0,1].
        norm: Boolean. Whether to apply normalization (cosine similarity).

    :rtype: :class:`Tensor` (scalar)
    '''
    if norm:
        # F.normalize clamps tiny norms, so a zero embedding gets similarity 0
        # instead of the 0/0 = NaN produced by dividing by raw norms.
        z1 = F.normalize(z1, dim=1)
        z2 = F.normalize(z2, dim=1)
    logits = torch.mm(z1, z2.t()) / tau
    pos = logits.diagonal()
    eye = torch.eye(logits.size(0), logits.size(1), dtype=torch.bool,
                    device=logits.device)
    negatives = logits.masked_fill(eye, float('-inf'))
    # -log(e^pos / sum_{j != i} e^{logit_ij}) in log space (no exp overflow).
    return (torch.logsumexp(negatives, dim=1) - pos).mean()