Source code for dig.threedgraph.utils.geometric_computing

# Based on the code from: https://github.com/klicperajo/dimenet,
# https://github.com/rusty1s/pytorch_geometric/blob/master/torch_geometric/nn/models/dimenet.py

import torch
from torch_scatter import scatter
from torch_sparse import SparseTensor
from math import pi as PI

# Module-level default device: CUDA when torch reports a usable GPU, otherwise CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


def xyz_to_dat(pos, edge_index, num_nodes, use_torsion=False):
    """Compute distances, angles and (optionally) torsions from node positions.

    Args:
        pos: Geometric information (coordinates) for every node in the graph.
            The cross products below are taken along the last dimension, so
            that dimension must have size 3.
        edge_index: Edge index of the graph; unpacked as ``j, i`` where each
            edge points ``j -> i``.
        num_nodes: Number of nodes in the graph.
        use_torsion: If set to :obj:`True`, will return distance, angle and
            torsion, otherwise only return distance and angle (also return
            some useful indices). (default: :obj:`False`)

    Returns:
        ``(dist, angle, torsion, i, j, idx_kj, idx_ji)`` if ``use_torsion``
        is true, otherwise ``(dist, angle, i, j, idx_kj, idx_ji)``.

        * ``dist``: one Euclidean distance per edge ``j -> i``.
        * ``angle``: one angle in ``[0, pi]`` per triplet ``k -> j -> i``.
        * ``torsion``: for each triplet, the minimum torsion in ``(0, 2*pi]``
          over the candidate neighbours of ``j``.
        * ``i``, ``j``: target and source node index of each edge.
        * ``idx_kj``, ``idx_ji``: for each triplet, the edge ids of its
          ``k -> j`` and ``j -> i`` edges.
    """
    j, i = edge_index  # j->i

    # Calculate distances (one per edge).
    dist = (pos[i] - pos[j]).pow(2).sum(dim=-1).sqrt()

    # Sparse adjacency whose stored values are edge ids, so that selecting
    # rows by node index yields the ids of the edges that enter each node.
    value = torch.arange(j.size(0), device=j.device)
    adj_t = SparseTensor(row=i, col=j, value=value,
                         sparse_sizes=(num_nodes, num_nodes))
    adj_t_row = adj_t[j]
    num_triplets = adj_t_row.set_value(None).sum(dim=1).to(torch.long)

    # Node indices (k->j->i) for triplets.
    idx_i = i.repeat_interleave(num_triplets)
    idx_j = j.repeat_interleave(num_triplets)
    idx_k = adj_t_row.storage.col()
    mask = idx_i != idx_k  # drop degenerate triplets i->j->i
    idx_i, idx_j, idx_k = idx_i[mask], idx_j[mask], idx_k[mask]

    # Edge indices (k->j, j->i) for triplets.
    idx_kj = adj_t_row.storage.value()[mask]
    idx_ji = adj_t_row.storage.row()[mask]

    # Calculate angles. 0 to pi
    pos_ji = pos[idx_i] - pos[idx_j]
    pos_jk = pos[idx_k] - pos[idx_j]
    a = (pos_ji * pos_jk).sum(dim=-1)  # cos_angle * |pos_ji| * |pos_jk|
    # dim=-1 is explicit: the legacy default picks the first dimension of
    # size 3, which is the wrong axis when there are exactly three triplets.
    b = torch.cross(pos_ji, pos_jk, dim=-1).norm(dim=-1)  # sin_angle * |pos_ji| * |pos_jk|
    angle = torch.atan2(b, a)

    if not use_torsion:
        return dist, angle, i, j, idx_kj, idx_ji

    # Prepare torsion indices.
    # Build the triplet ids on the same device as the inputs (not on a
    # module-level default) so CPU inputs work on a CUDA-capable host.
    idx_batch = torch.arange(len(idx_i), device=idx_i.device)
    idx_k_n = adj_t[idx_j].storage.col()
    # Number of candidate neighbours of j for every surviving triplet.
    num_triplets_t = num_triplets.repeat_interleave(num_triplets)[mask]
    idx_i_t = idx_i.repeat_interleave(num_triplets_t)
    idx_j_t = idx_j.repeat_interleave(num_triplets_t)
    idx_k_t = idx_k.repeat_interleave(num_triplets_t)
    idx_batch_t = idx_batch.repeat_interleave(num_triplets_t)
    mask_t = idx_i_t != idx_k_n  # the candidate must differ from i
    idx_i_t, idx_j_t, idx_k_t, idx_k_n, idx_batch_t = (
        idx_i_t[mask_t],
        idx_j_t[mask_t],
        idx_k_t[mask_t],
        idx_k_n[mask_t],
        idx_batch_t[mask_t],
    )

    # Calculate torsions.
    pos_j0 = pos[idx_k_t] - pos[idx_j_t]
    pos_ji = pos[idx_i_t] - pos[idx_j_t]
    pos_jk = pos[idx_k_n] - pos[idx_j_t]
    dist_ji = pos_ji.pow(2).sum(dim=-1).sqrt()
    plane1 = torch.cross(pos_ji, pos_j0, dim=-1)
    plane2 = torch.cross(pos_ji, pos_jk, dim=-1)
    a = (plane1 * plane2).sum(dim=-1)  # cos_angle * |plane1| * |plane2|
    b = (torch.cross(plane1, plane2, dim=-1) * pos_ji).sum(dim=-1) / dist_ji
    torsion1 = torch.atan2(b, a)  # -pi to pi
    torsion1[torsion1 <= 0] += 2 * PI  # 0 to 2pi
    # Keep the smallest torsion among the candidates of each triplet.
    torsion = scatter(torsion1, idx_batch_t, reduce='min')

    return dist, angle, torsion, i, j, idx_kj, idx_ji