Source code for dig.threedgraph.method.comenet.comenet

from torch_cluster import radius_graph
from torch_geometric.nn import GraphConv, GraphNorm
from torch_geometric.nn import inits

from .features import angle_emb, torsion_emb

from torch_scatter import scatter, scatter_min

from torch.nn import Embedding

import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as F

import math
from math import sqrt

try:
    import sympy as sym
except ImportError:
    sym = None

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def swish(x):
    return x * torch.sigmoid(x)
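
# A quick sanity check (illustrative, not part of the original module): swish
# is the SiLU activation x * sigmoid(x), so swish(0) == 0 and swish(x) ~ x for
# large positive x.
#
#   >>> swish(torch.tensor([0.0, 1.0]))
#   tensor([0.0000, 0.7311])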

class Linear(torch.nn.Module):

    def __init__(self, in_channels, out_channels, bias=True,
                 weight_initializer='glorot',
                 bias_initializer='zeros'):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.weight_initializer = weight_initializer
        self.bias_initializer = bias_initializer

        assert in_channels > 0
        self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels))

        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        if self.in_channels > 0:
            if self.weight_initializer == 'glorot':
                inits.glorot(self.weight)
            elif self.weight_initializer == 'glorot_orthogonal':
                inits.glorot_orthogonal(self.weight, scale=2.0)
            elif self.weight_initializer == 'uniform':
                bound = 1.0 / math.sqrt(self.weight.size(-1))
                torch.nn.init.uniform_(self.weight.data, -bound, bound)
            elif self.weight_initializer == 'kaiming_uniform':
                inits.kaiming_uniform(self.weight, fan=self.in_channels,
                                      a=math.sqrt(5))
            elif self.weight_initializer == 'zeros':
                inits.zeros(self.weight)
            elif self.weight_initializer is None:
                inits.kaiming_uniform(self.weight, fan=self.in_channels,
                                      a=math.sqrt(5))
            else:
                raise RuntimeError(
                    f"Linear layer weight initializer "
                    f"'{self.weight_initializer}' is not supported")

        if self.in_channels > 0 and self.bias is not None:
            if self.bias_initializer == 'zeros':
                inits.zeros(self.bias)
            elif self.bias_initializer is None:
                inits.uniform(self.in_channels, self.bias)
            else:
                raise RuntimeError(
                    f"Linear layer bias initializer "
                    f"'{self.bias_initializer}' is not supported")

    def forward(self, x):
        """Applies a linear transformation to the input features."""
        return F.linear(x, self.weight, self.bias)
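
# Usage sketch for Linear (shapes are illustrative, not from the original
# source):
#
#   >>> lin = Linear(16, 32, weight_initializer='glorot')
#   >>> lin(torch.randn(8, 16)).shape
#   torch.Size([8, 32])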


class TwoLayerLinear(torch.nn.Module):
    def __init__(
            self,
            in_channels,
            middle_channels,
            out_channels,
            bias=False,
            act=False,
    ):
        super(TwoLayerLinear, self).__init__()
        self.lin1 = Linear(in_channels, middle_channels, bias=bias)
        self.lin2 = Linear(middle_channels, out_channels, bias=bias)
        self.act = act

    def reset_parameters(self):
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, x):
        x = self.lin1(x)
        if self.act:
            x = swish(x)
        x = self.lin2(x)
        if self.act:
            x = swish(x)
        return x
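
# Usage sketch (dimensions mirror how ComENet wires this block for its basis
# features, e.g. in_channels = num_radial * num_spherical**2 = 12 under the
# default hyperparameters; the values here are illustrative):
#
#   >>> mlp = TwoLayerLinear(12, 64, 256, act=True)
#   >>> mlp(torch.randn(10, 12)).shape
#   torch.Size([10, 256])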


class EmbeddingBlock(torch.nn.Module):
    def __init__(self, hidden_channels, act=swish):
        super(EmbeddingBlock, self).__init__()
        self.act = act
        self.emb = Embedding(95, hidden_channels)
        self.reset_parameters()

    def reset_parameters(self):
        self.emb.weight.data.uniform_(-sqrt(3), sqrt(3))

    def forward(self, x):
        x = self.act(self.emb(x))
        return x
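
# Usage sketch (illustrative): atomic numbers index into the 95-row embedding
# table, covering elements with atomic number up to 94.
#
#   >>> emb = EmbeddingBlock(256)
#   >>> emb(torch.tensor([8, 1, 1])).shape  # O, H, H
#   torch.Size([3, 256])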


class EdgeGraphConv(GraphConv):

    def message(self, x_j, edge_weight) -> Tensor:
        return x_j if edge_weight is None else edge_weight * x_j
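
# Design note: at the time of writing, PyG's GraphConv treats edge_weight as a
# per-edge *scalar* (edge_weight.view(-1, 1) * x_j). This override instead
# accepts a per-edge feature *vector* of the same width as x_j, so the
# transformed basis features can gate each channel of the neighbor message
# elementwise.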


class SimpleInteractionBlock(torch.nn.Module):
    def __init__(
            self,
            hidden_channels,
            middle_channels,
            num_radial,
            num_spherical,
            num_layers,
            output_channels,
            act=swish
    ):
        super(SimpleInteractionBlock, self).__init__()
        self.act = act

        self.conv1 = EdgeGraphConv(hidden_channels, hidden_channels)

        self.conv2 = EdgeGraphConv(hidden_channels, hidden_channels)

        self.lin1 = Linear(hidden_channels, hidden_channels)

        self.lin2 = Linear(hidden_channels, hidden_channels)

        self.lin_cat = Linear(2 * hidden_channels, hidden_channels)

        self.norm = GraphNorm(hidden_channels)

        # Transformations of Bessel and spherical basis representations.
        self.lin_feature1 = TwoLayerLinear(num_radial * num_spherical ** 2, middle_channels, hidden_channels)
        self.lin_feature2 = TwoLayerLinear(num_radial * num_spherical, middle_channels, hidden_channels)

        # Dense transformations of input messages.
        self.lin = Linear(hidden_channels, hidden_channels)
        self.lins = torch.nn.ModuleList()
        for _ in range(num_layers):
            self.lins.append(Linear(hidden_channels, hidden_channels))
        self.final = Linear(hidden_channels, output_channels)

        self.reset_parameters()

    def reset_parameters(self):
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()

        self.norm.reset_parameters()

        self.lin_feature1.reset_parameters()
        self.lin_feature2.reset_parameters()

        self.lin.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

        self.lin_cat.reset_parameters()

        for lin in self.lins:
            lin.reset_parameters()

        self.final.reset_parameters()

    def forward(self, x, feature1, feature2, edge_index, batch):
        x = self.act(self.lin(x))

        feature1 = self.lin_feature1(feature1)
        h1 = self.conv1(x, edge_index, feature1)
        h1 = self.lin1(h1)
        h1 = self.act(h1)

        feature2 = self.lin_feature2(feature2)
        h2 = self.conv2(x, edge_index, feature2)
        h2 = self.lin2(h2)
        h2 = self.act(h2)

        h = self.lin_cat(torch.cat([h1, h2], 1))

        h = h + x
        for lin in self.lins:
            h = self.act(lin(h)) + h
        h = self.norm(h, batch)
        h = self.final(h)
        return h
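
# Shape sketch (illustrative; with ComENet's defaults num_radial=3 and
# num_spherical=2, feature1 has 3 * 2**2 = 12 columns and feature2 has
# 3 * 2 = 6 columns per edge):
#
#   >>> block = SimpleInteractionBlock(256, 64, 3, 2, 3, 256)
#   >>> x = torch.randn(5, 256)                            # 5 nodes
#   >>> ei = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])    # 4 edges
#   >>> f1, f2 = torch.randn(4, 12), torch.randn(4, 6)
#   >>> batch = torch.zeros(5, dtype=torch.long)
#   >>> block(x, f1, f2, ei, batch).shape
#   torch.Size([5, 256])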


class ComENet(nn.Module):
    r"""
    The ComENet from the `"ComENet: Towards Complete and Efficient Message
    Passing for 3D Molecular Graphs" <https://arxiv.org/abs/2206.08515>`_ paper.

    Args:
        cutoff (float, optional): Cutoff distance for interatomic interactions.
            (default: :obj:`8.0`)
        num_layers (int, optional): Number of building blocks.
            (default: :obj:`4`)
        hidden_channels (int, optional): Hidden embedding size.
            (default: :obj:`256`)
        middle_channels (int, optional): Middle embedding size for the
            two-layer linear block. (default: :obj:`64`)
        out_channels (int, optional): Size of each output sample.
            (default: :obj:`1`)
        num_radial (int, optional): Number of radial basis functions.
            (default: :obj:`3`)
        num_spherical (int, optional): Number of spherical harmonics.
            (default: :obj:`2`)
        num_output_layers (int, optional): Number of linear layers for the
            output blocks. (default: :obj:`3`)
    """

    def __init__(
            self,
            cutoff=8.0,
            num_layers=4,
            hidden_channels=256,
            middle_channels=64,
            out_channels=1,
            num_radial=3,
            num_spherical=2,
            num_output_layers=3,
    ):
        super(ComENet, self).__init__()
        self.out_channels = out_channels
        self.cutoff = cutoff
        self.num_layers = num_layers

        if sym is None:
            raise ImportError("Package `sympy` could not be found.")

        act = swish
        self.act = act

        self.feature1 = torsion_emb(num_radial=num_radial, num_spherical=num_spherical, cutoff=cutoff)
        self.feature2 = angle_emb(num_radial=num_radial, num_spherical=num_spherical, cutoff=cutoff)

        self.emb = EmbeddingBlock(hidden_channels, act)

        self.interaction_blocks = torch.nn.ModuleList(
            [
                SimpleInteractionBlock(
                    hidden_channels,
                    middle_channels,
                    num_radial,
                    num_spherical,
                    num_output_layers,
                    hidden_channels,
                    act,
                )
                for _ in range(num_layers)
            ]
        )

        self.lins = torch.nn.ModuleList()
        for _ in range(num_output_layers):
            self.lins.append(Linear(hidden_channels, hidden_channels))
        self.lin_out = Linear(hidden_channels, out_channels)

        self.reset_parameters()

    def reset_parameters(self):
        self.emb.reset_parameters()
        for interaction in self.interaction_blocks:
            interaction.reset_parameters()
        for lin in self.lins:
            lin.reset_parameters()
        self.lin_out.reset_parameters()

    def _forward(self, data):
        batch = data.batch
        z = data.z.long()
        pos = data.pos
        num_nodes = z.size(0)

        edge_index = radius_graph(pos, r=self.cutoff, batch=batch)
        j, i = edge_index

        vecs = pos[j] - pos[i]
        dist = vecs.norm(dim=-1)

        # Embedding block.
        x = self.emb(z)

        # Calculate distances.
        _, argmin0 = scatter_min(dist, i, dim_size=num_nodes)
        argmin0[argmin0 >= len(i)] = 0
        n0 = j[argmin0]
        add = torch.zeros_like(dist).to(dist.device)
        add[argmin0] = self.cutoff
        dist1 = dist + add

        _, argmin1 = scatter_min(dist1, i, dim_size=num_nodes)
        argmin1[argmin1 >= len(i)] = 0
        n1 = j[argmin1]

        # --------------------------------------------------------
        _, argmin0_j = scatter_min(dist, j, dim_size=num_nodes)
        argmin0_j[argmin0_j >= len(j)] = 0
        n0_j = i[argmin0_j]

        add_j = torch.zeros_like(dist).to(dist.device)
        add_j[argmin0_j] = self.cutoff
        dist1_j = dist + add_j

        # i[argmin] = range(0, num_nodes)
        _, argmin1_j = scatter_min(dist1_j, j, dim_size=num_nodes)
        argmin1_j[argmin1_j >= len(j)] = 0
        n1_j = i[argmin1_j]
        # ----------------------------------------------------------

        # n0, n1 for i
        n0 = n0[i]
        n1 = n1[i]

        # n0, n1 for j
        n0_j = n0_j[j]
        n1_j = n1_j[j]

        # tau: (iref, i, j, jref)
        # When computing tau, do not use n0 and n0_j as the references for
        # i and j: if n0 == j or n0_j == i, the computed tau is zero.
        # So if n0 == j, we choose iref = n1;
        # if n0_j == i, we choose jref = n1_j.
        mask_iref = n0 == j
        iref = torch.clone(n0)
        iref[mask_iref] = n1[mask_iref]
        idx_iref = argmin0[i]
        idx_iref[mask_iref] = argmin1[i][mask_iref]

        mask_jref = n0_j == i
        jref = torch.clone(n0_j)
        jref[mask_jref] = n1_j[mask_jref]
        idx_jref = argmin0_j[j]
        idx_jref[mask_jref] = argmin1_j[j][mask_jref]

        pos_ji, pos_in0, pos_in1, pos_iref, pos_jref_j = (
            vecs,
            vecs[argmin0][i],
            vecs[argmin1][i],
            vecs[idx_iref],
            vecs[idx_jref]
        )

        # Calculate angles.
        a = ((-pos_ji) * pos_in0).sum(dim=-1)
        b = torch.cross(-pos_ji, pos_in0).norm(dim=-1)
        theta = torch.atan2(b, a)
        theta[theta < 0] = theta[theta < 0] + math.pi

        # Calculate torsions.
        dist_ji = pos_ji.pow(2).sum(dim=-1).sqrt()
        plane1 = torch.cross(-pos_ji, pos_in0)
        plane2 = torch.cross(-pos_ji, pos_in1)
        a = (plane1 * plane2).sum(dim=-1)  # cos_angle * |plane1| * |plane2|
        b = (torch.cross(plane1, plane2) * pos_ji).sum(dim=-1) / dist_ji
        phi = torch.atan2(b, a)
        phi[phi < 0] = phi[phi < 0] + math.pi

        # Calculate right torsions.
        plane1 = torch.cross(pos_ji, pos_jref_j)
        plane2 = torch.cross(pos_ji, pos_iref)
        a = (plane1 * plane2).sum(dim=-1)  # cos_angle * |plane1| * |plane2|
        b = (torch.cross(plane1, plane2) * pos_ji).sum(dim=-1) / dist_ji
        tau = torch.atan2(b, a)
        tau[tau < 0] = tau[tau < 0] + math.pi

        feature1 = self.feature1(dist, theta, phi)
        feature2 = self.feature2(dist, tau)

        # Interaction blocks.
        for interaction_block in self.interaction_blocks:
            x = interaction_block(x, feature1, feature2, edge_index, batch)

        for lin in self.lins:
            x = self.act(lin(x))
        x = self.lin_out(x)

        energy = scatter(x, batch, dim=0)
        return energy
    def forward(self, batch_data):
        return self._forward(batch_data)
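
# Usage sketch (a minimal, illustrative example; the attribute names z, pos,
# and batch are exactly what _forward reads from its input):
#
#   >>> from torch_geometric.data import Data, Batch
#   >>> water = Data(z=torch.tensor([8, 1, 1]), pos=torch.randn(3, 3))
#   >>> model = ComENet()
#   >>> model(Batch.from_data_list([water])).shape  # one energy per graph
#   torch.Size([1, 1])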