Source code for dig.threedgraph.method.pronet.pronet

"""
This is an implementation of ProNet model

"""

from torch_geometric.nn import inits, MessagePassing
from torch_geometric.nn import radius_graph

from .features import d_angle_emb, d_theta_phi_emb

from torch_scatter import scatter
from torch_sparse import matmul

import torch
from torch import nn
from torch.nn import Embedding
import torch.nn.functional as F

import numpy as np


num_aa_type = 26         # number of amino acid types
num_side_chain_embs = 8  # dimension of the side-chain embeddings
num_bb_embs = 6          # dimension of the backbone embeddings

def swish(x):
    return x * torch.sigmoid(x)
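
# Note: this swish is identical to PyTorch's built-in F.silu activation;
# torch.allclose(swish(t), F.silu(t)) holds for any floating-point tensor t.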


class Linear(torch.nn.Module):
    """
        A linear layer wrapper similar to PyG's Linear

        Parameters
        ----------
        in_channels (int)
        out_channels (int)
        bias (bool)
        weight_initializer (str): 'glorot' or 'zeros'
    """

    def __init__(self, in_channels, out_channels, bias=True, weight_initializer='glorot'):

        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.weight_initializer = weight_initializer

        self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels))

        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        if self.weight_initializer == 'glorot':
            inits.glorot(self.weight)
        elif self.weight_initializer == 'zeros':
            inits.zeros(self.weight)
        if self.bias is not None:
            inits.zeros(self.bias)

    def forward(self, x):
        """"""
        return F.linear(x, self.weight, self.bias)
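

# A minimal usage sketch of the Linear wrapper above; the sizes and the
# final check are illustrative assumptions, not part of ProNet itself.
def _demo_linear():
    lin = Linear(16, 32)                                    # Glorot-initialized weight (default)
    zeros_lin = Linear(16, 32, weight_initializer='zeros')  # weight and bias start at zero
    x = torch.randn(4, 16)
    out = lin(x)                                            # shape: [4, 32]
    assert torch.equal(zeros_lin(x), torch.zeros(4, 32))    # zero weight/bias -> zero output
    return out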


class TwoLinear(torch.nn.Module):
    """
        A layer with two linear modules

        Parameters
        ----------
        in_channels (int)
        middle_channels (int)
        out_channels (int)
        bias (bool)
        act (bool)
    """

    def __init__(
            self,
            in_channels,
            middle_channels,
            out_channels,
            bias=False,
            act=False
    ):
        super(TwoLinear, self).__init__()
        self.lin1 = Linear(in_channels, middle_channels, bias=bias)
        self.lin2 = Linear(middle_channels, out_channels, bias=bias)
        self.act = act

    def reset_parameters(self):
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, x):
        x = self.lin1(x)
        if self.act:
            x = swish(x)
        x = self.lin2(x)
        if self.act:
            x = swish(x)
        return x
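

# A brief usage sketch of TwoLinear (sizes are illustrative assumptions):
# a two-layer MLP whose swish activations are toggled by `act`.
def _demo_two_linear():
    mlp = TwoLinear(24, 64, 128, act=True)  # 24 -> 64 -> 128, swish after each layer
    return mlp(torch.randn(7, 24))          # shape: [7, 128]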


class EdgeGraphConv(MessagePassing):
    """
        Graph convolution similar to PyG's GraphConv (https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.GraphConv)

        The difference is that this module takes a per-edge feature vector and applies it to the source node features via a Hadamard product, rather than scaling them by a scalar edge weight

        Parameters
        ----------
        in_channels (int)
        out_channels (int)
    """
    def __init__(self, in_channels, out_channels):
        super(EdgeGraphConv, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.lin_l = Linear(in_channels, out_channels)
        self.lin_r = Linear(in_channels, out_channels, bias=False)

        self.reset_parameters()

    def reset_parameters(self):
        self.lin_l.reset_parameters()
        self.lin_r.reset_parameters()

    def forward(self, x, edge_index, edge_weight, size=None):
        x = (x, x)  # duplicate node features for MessagePassing's (source, target) interface
        out = self.propagate(edge_index, x=x, edge_weight=edge_weight, size=size)
        out = self.lin_l(out)
        return out + self.lin_r(x[1])

    def message(self, x_j, edge_weight):
        return edge_weight * x_j

    def message_and_aggregate(self, adj_t, x):
        return matmul(adj_t, x[0], reduce=self.aggr)
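

# A hedged sketch contrasting EdgeGraphConv with PyG's GraphConv: here each
# edge carries a full feature vector, multiplied elementwise with the source
# node's features before aggregation (the graph and sizes are arbitrary).
def _demo_edge_graph_conv():
    conv = EdgeGraphConv(8, 8)
    x = torch.randn(5, 8)                       # 5 nodes, 8 features each
    edge_index = torch.tensor([[0, 1, 2, 3],    # source nodes
                               [1, 2, 3, 4]])   # target nodes
    edge_weight = torch.randn(4, 8)             # one feature vector per edge
    return conv(x, edge_index, edge_weight)     # shape: [5, 8]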


class InteractionBlock(torch.nn.Module):
    def __init__(
            self,
            hidden_channels,
            output_channels,
            num_radial,
            num_spherical,
            num_layers,
            mid_emb,
            act=swish,
            num_pos_emb=16,
            dropout=0,
            level='allatom'
    ):
        super(InteractionBlock, self).__init__()
        self.act = act
        self.dropout = nn.Dropout(dropout)
        
        self.conv0 = EdgeGraphConv(hidden_channels, hidden_channels)
        self.conv1 = EdgeGraphConv(hidden_channels, hidden_channels)
        self.conv2 = EdgeGraphConv(hidden_channels, hidden_channels)

        self.lin_feature0 = TwoLinear(num_radial * num_spherical ** 2, mid_emb, hidden_channels)
        if level == 'aminoacid':
            self.lin_feature1 = TwoLinear(num_radial * num_spherical, mid_emb, hidden_channels)
        elif level == 'backbone' or level == 'allatom':
            self.lin_feature1 = TwoLinear(3 * num_radial * num_spherical, mid_emb, hidden_channels)
        self.lin_feature2 = TwoLinear(num_pos_emb, mid_emb, hidden_channels)

        self.lin_1 = Linear(hidden_channels, hidden_channels)
        self.lin_2 = Linear(hidden_channels, hidden_channels)

        self.lin0 = Linear(hidden_channels, hidden_channels)
        self.lin1 = Linear(hidden_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, hidden_channels)

        self.lins_cat = torch.nn.ModuleList()
        self.lins_cat.append(Linear(3*hidden_channels, hidden_channels))
        for _ in range(num_layers-1):
            self.lins_cat.append(Linear(hidden_channels, hidden_channels))

        self.lins = torch.nn.ModuleList()
        for _ in range(num_layers-1):
            self.lins.append(Linear(hidden_channels, hidden_channels))
        self.final = Linear(hidden_channels, output_channels)

        self.reset_parameters()

    def reset_parameters(self):
        self.conv0.reset_parameters()
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()

        self.lin_feature0.reset_parameters()
        self.lin_feature1.reset_parameters()
        self.lin_feature2.reset_parameters()

        self.lin_1.reset_parameters()
        self.lin_2.reset_parameters()

        self.lin0.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

        for lin in self.lins:
            lin.reset_parameters()
        for lin in self.lins_cat:
            lin.reset_parameters()

        self.final.reset_parameters()


    def forward(self, x, feature0, feature1, pos_emb, edge_index, batch):
        x_lin_1 = self.act(self.lin_1(x))
        x_lin_2 = self.act(self.lin_2(x))
        
        feature0 = self.lin_feature0(feature0)
        h0 = self.conv0(x_lin_1, edge_index, feature0)
        h0 = self.lin0(h0)
        h0 = self.act(h0)
        h0 = self.dropout(h0)

        feature1 = self.lin_feature1(feature1)
        h1 = self.conv1(x_lin_1, edge_index, feature1)
        h1 = self.lin1(h1)
        h1 = self.act(h1)
        h1 = self.dropout(h1)

        feature2 = self.lin_feature2(pos_emb)
        h2 = self.conv2(x_lin_1, edge_index, feature2)
        h2 = self.lin2(h2)
        h2 = self.act(h2)
        h2 = self.dropout(h2)

        h = torch.cat((h0, h1, h2), 1)
        for lin in self.lins_cat:
            h = self.act(lin(h)) 

        h = h + x_lin_2

        for lin in self.lins:
            h = self.act(lin(h)) 
        h = self.final(h)
        return h
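

# A hedged shape-level sketch of InteractionBlock: random tensors stand in
# for the real geometric features computed in ProNet.forward below, with
# feature dimensions matching what lin_feature0/1/2 expect at the
# 'aminoacid' level (all sizes here are illustrative).
def _demo_interaction_block():
    num_nodes, num_edges = 10, 30
    hidden, num_radial, num_spherical, num_pos_emb = 128, 6, 2, 16
    block = InteractionBlock(hidden, hidden, num_radial, num_spherical,
                             num_layers=3, mid_emb=64,
                             num_pos_emb=num_pos_emb, level='aminoacid')
    x = torch.randn(num_nodes, hidden)
    feature0 = torch.randn(num_edges, num_radial * num_spherical ** 2)  # (d, theta, phi) features
    feature1 = torch.randn(num_edges, num_radial * num_spherical)      # (d, tau) features
    pos_emb = torch.randn(num_edges, num_pos_emb)
    edge_index = torch.randint(0, num_nodes, (2, num_edges))
    batch = torch.zeros(num_nodes, dtype=torch.long)
    return block(x, feature0, feature1, pos_emb, edge_index, batch)    # [num_nodes, hidden]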


class ProNet(nn.Module):
    r"""
        The ProNet from the "Learning Protein Representations via Complete 3D Graph Networks" paper.

        Args:
            level (str, optional): The level of protein representations. It could be :obj:`aminoacid`, :obj:`backbone`, or :obj:`allatom`. (default: :obj:`aminoacid`)
            num_blocks (int, optional): Number of building blocks. (default: :obj:`4`)
            hidden_channels (int, optional): Hidden embedding size. (default: :obj:`128`)
            out_channels (int, optional): Size of each output sample. (default: :obj:`1`)
            mid_emb (int, optional): Embedding size used for geometric features. (default: :obj:`64`)
            num_radial (int, optional): Number of radial basis functions. (default: :obj:`6`)
            num_spherical (int, optional): Number of spherical harmonics. (default: :obj:`2`)
            cutoff (float, optional): Cutoff distance for interatomic interactions. (default: :obj:`10.0`)
            max_num_neighbors (int, optional): Max number of neighbors during graph construction. (default: :obj:`32`)
            int_emb_layers (int, optional): Number of embedding layers in the interaction block. (default: :obj:`3`)
            out_layers (int, optional): Number of layers for features after interaction blocks. (default: :obj:`2`)
            num_pos_emb (int, optional): Number of positional embeddings. (default: :obj:`16`)
            dropout (float, optional): Dropout rate. (default: :obj:`0`)
            data_augment_eachlayer (bool, optional): Data augmentation trick. If set to :obj:`True`, will add noise to the node features before each interaction block. (default: :obj:`False`)
            euler_noise (bool, optional): Data augmentation trick. If set to :obj:`True`, will add noise to the Euler angles. (default: :obj:`False`)
    """

    def __init__(
            self,
            level='aminoacid',
            num_blocks=4,
            hidden_channels=128,
            out_channels=1,
            mid_emb=64,
            num_radial=6,
            num_spherical=2,
            cutoff=10.0,
            max_num_neighbors=32,
            int_emb_layers=3,
            out_layers=2,
            num_pos_emb=16,
            dropout=0,
            data_augment_eachlayer=False,
            euler_noise=False,
    ):
        super(ProNet, self).__init__()
        self.cutoff = cutoff
        self.max_num_neighbors = max_num_neighbors
        self.num_pos_emb = num_pos_emb
        self.data_augment_eachlayer = data_augment_eachlayer
        self.euler_noise = euler_noise
        self.level = level
        self.act = swish

        self.feature0 = d_theta_phi_emb(num_radial=num_radial, num_spherical=num_spherical, cutoff=cutoff)
        self.feature1 = d_angle_emb(num_radial=num_radial, num_spherical=num_spherical, cutoff=cutoff)

        if level == 'aminoacid':
            self.embedding = Embedding(num_aa_type, hidden_channels)
        elif level == 'backbone':
            self.embedding = torch.nn.Linear(num_aa_type + num_bb_embs, hidden_channels)
        elif level == 'allatom':
            self.embedding = torch.nn.Linear(num_aa_type + num_bb_embs + num_side_chain_embs, hidden_channels)
        else:
            raise ValueError('Unsupported level: %s' % level)

        self.interaction_blocks = torch.nn.ModuleList(
            [
                InteractionBlock(
                    hidden_channels=hidden_channels,
                    output_channels=hidden_channels,
                    num_radial=num_radial,
                    num_spherical=num_spherical,
                    num_layers=int_emb_layers,
                    mid_emb=mid_emb,
                    act=self.act,
                    num_pos_emb=num_pos_emb,
                    dropout=dropout,
                    level=level
                )
                for _ in range(num_blocks)
            ]
        )

        self.lins_out = torch.nn.ModuleList()
        for _ in range(out_layers - 1):
            self.lins_out.append(Linear(hidden_channels, hidden_channels))
        self.lin_out = Linear(hidden_channels, out_channels)

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

        self.reset_parameters()

    def reset_parameters(self):
        self.embedding.reset_parameters()
        for interaction in self.interaction_blocks:
            interaction.reset_parameters()
        for lin in self.lins_out:
            lin.reset_parameters()
        self.lin_out.reset_parameters()

    def pos_emb(self, edge_index, num_pos_emb=16):
        # Sinusoidal embedding of the sequence offset between the two residues
        # of each edge. From https://github.com/jingraham/neurips19-graph-protein-design
        d = edge_index[0] - edge_index[1]
        frequency = torch.exp(
            torch.arange(0, num_pos_emb, 2, dtype=torch.float32, device=edge_index.device)
            * -(np.log(10000.0) / num_pos_emb)
        )
        angles = d.unsqueeze(-1) * frequency
        E = torch.cat((torch.cos(angles), torch.sin(angles)), -1)
        return E

    def forward(self, batch_data):
        z, pos, batch = torch.squeeze(batch_data.x.long()), batch_data.coords_ca, batch_data.batch
        pos_n = batch_data.coords_n
        pos_c = batch_data.coords_c
        bb_embs = batch_data.bb_embs
        side_chain_embs = batch_data.side_chain_embs
        device = z.device

        if self.level == 'aminoacid':
            x = self.embedding(z)
        elif self.level == 'backbone':
            x = torch.cat([torch.squeeze(F.one_hot(z, num_classes=num_aa_type).float()), bb_embs], dim=1)
            x = self.embedding(x)
        elif self.level == 'allatom':
            x = torch.cat([torch.squeeze(F.one_hot(z, num_classes=num_aa_type).float()), bb_embs, side_chain_embs], dim=1)
            x = self.embedding(x)
        else:
            raise ValueError('Unsupported level: %s' % self.level)

        edge_index = radius_graph(pos, r=self.cutoff, batch=batch, max_num_neighbors=self.max_num_neighbors)
        pos_emb = self.pos_emb(edge_index, self.num_pos_emb)
        j, i = edge_index

        # Calculate distances.
        dist = (pos[i] - pos[j]).norm(dim=1)

        num_nodes = len(z)

        # Calculate angles theta and phi.
        refi0 = (i - 1) % num_nodes
        refi1 = (i + 1) % num_nodes
        a = ((pos[j] - pos[i]) * (pos[refi0] - pos[i])).sum(dim=-1)
        b = torch.cross(pos[j] - pos[i], pos[refi0] - pos[i]).norm(dim=-1)
        theta = torch.atan2(b, a)
        plane1 = torch.cross(pos[refi0] - pos[i], pos[refi1] - pos[i])
        plane2 = torch.cross(pos[refi0] - pos[i], pos[j] - pos[i])
        a = (plane1 * plane2).sum(dim=-1)
        b = (torch.cross(plane1, plane2) * (pos[refi0] - pos[i])).sum(dim=-1) / ((pos[refi0] - pos[i]).norm(dim=-1))
        phi = torch.atan2(b, a)

        feature0 = self.feature0(dist, theta, phi)

        if self.level == 'backbone' or self.level == 'allatom':
            # Calculate Euler angles between the local backbone frames of the
            # two residues on each edge.
            Or1_x = pos_n[i] - pos[i]
            Or1_z = torch.cross(Or1_x, torch.cross(Or1_x, pos_c[i] - pos[i]))
            Or1_z_length = Or1_z.norm(dim=1) + 1e-7

            Or2_x = pos_n[j] - pos[j]
            Or2_z = torch.cross(Or2_x, torch.cross(Or2_x, pos_c[j] - pos[j]))
            Or2_z_length = Or2_z.norm(dim=1) + 1e-7

            Or1_Or2_N = torch.cross(Or1_z, Or2_z)

            angle1 = torch.atan2((torch.cross(Or1_x, Or1_Or2_N) * Or1_z).sum(dim=-1) / Or1_z_length, (Or1_x * Or1_Or2_N).sum(dim=-1))
            angle2 = torch.atan2(torch.cross(Or1_z, Or2_z).norm(dim=-1), (Or1_z * Or2_z).sum(dim=-1))
            angle3 = torch.atan2((torch.cross(Or1_Or2_N, Or2_x) * Or2_z).sum(dim=-1) / Or2_z_length, (Or1_Or2_N * Or2_x).sum(dim=-1))

            if self.euler_noise:
                euler_noise = torch.clip(torch.empty(3, len(angle1)).to(device).normal_(mean=0.0, std=0.025), min=-0.1, max=0.1)
                angle1 += euler_noise[0]
                angle2 += euler_noise[1]
                angle3 += euler_noise[2]

            feature1 = torch.cat((self.feature1(dist, angle1), self.feature1(dist, angle2), self.feature1(dist, angle3)), 1)

        elif self.level == 'aminoacid':
            # Calculate the torsion angle tau from the sequence neighbors of i and j.
            refi = (i - 1) % num_nodes
            refj0 = (j - 1) % num_nodes
            refj = (j - 1) % num_nodes
            refj1 = (j + 1) % num_nodes

            mask = refi0 == j
            refi[mask] = refi1[mask]
            mask = refj0 == i
            refj[mask] = refj1[mask]

            plane1 = torch.cross(pos[j] - pos[i], pos[refi] - pos[i])
            plane2 = torch.cross(pos[j] - pos[i], pos[refj] - pos[j])
            a = (plane1 * plane2).sum(dim=-1)
            b = (torch.cross(plane1, plane2) * (pos[j] - pos[i])).sum(dim=-1) / dist
            tau = torch.atan2(b, a)
            feature1 = self.feature1(dist, tau)

        # Interaction blocks.
        for interaction_block in self.interaction_blocks:
            if self.data_augment_eachlayer:
                # Add clipped Gaussian noise to the node features.
                gaussian_noise = torch.clip(torch.empty(x.shape).to(device).normal_(mean=0.0, std=0.025), min=-0.1, max=0.1)
                x += gaussian_noise
            x = interaction_block(x, feature0, feature1, pos_emb, edge_index, batch)

        y = scatter(x, batch, dim=0)

        for lin in self.lins_out:
            y = self.relu(lin(y))
            y = self.dropout(y)
        y = self.lin_out(y)

        return y

    @property
    def num_params(self):
        return sum(p.numel() for p in self.parameters())
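

# A hedged end-to-end sketch (not part of DIG): build a ProNet at the
# 'aminoacid' level and run it on one random toy protein. The Data fields
# mirror exactly what forward() reads; real inputs come from DIG's protein
# datasets, and all sizes below are illustrative.
def _demo_pronet():
    from torch_geometric.data import Data
    num_res = 20
    data = Data(
        x=torch.randint(0, num_aa_type, (num_res, 1)),  # residue types
        coords_ca=torch.randn(num_res, 3),              # C-alpha coordinates
        coords_n=torch.randn(num_res, 3),               # backbone N coordinates
        coords_c=torch.randn(num_res, 3),               # backbone C coordinates
        bb_embs=torch.randn(num_res, num_bb_embs),
        side_chain_embs=torch.randn(num_res, num_side_chain_embs),
        batch=torch.zeros(num_res, dtype=torch.long),
    )
    model = ProNet(level='aminoacid', out_channels=1)
    return model(data)                                  # shape: [1, 1], one score per protein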