# Source code for dig.threedgraph.utils.geometric_computing

```python
# Based on the code from: https://github.com/klicperajo/dimenet,
# https://github.com/rusty1s/pytorch_geometric/blob/master/torch_geometric/nn/models/dimenet.py

import torch
from torch_scatter import scatter
from torch_sparse import SparseTensor
from math import pi as PI

# Module-level default device: CUDA when it is available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def xyz_to_dat(pos, edge_index, num_nodes, use_torsion=False):
    """
    Compute the distance, angle, and (optionally) torsion from geometric information.

    Args:
        pos: Node coordinates; the last dimension must have size 3
            (cross products are taken over ``dim=-1``).
        edge_index: Edge index of the graph with two rows ``(j, i)``; each
            column is a directed edge ``j -> i``.
        num_nodes: Number of nodes in the graph.
        use_torsion: If set to :obj:`True`, will return distance, angle and
            torsion, otherwise only return distance and angle (plus some
            useful indices). (default: :obj:`False`)

    Returns:
        ``(dist, angle, torsion, i, j, idx_kj, idx_ji)`` if ``use_torsion``
        is :obj:`True`, otherwise ``(dist, angle, i, j, idx_kj, idx_ji)``:

        - ``dist``: Euclidean length of every edge ``j -> i``.
        - ``angle``: angle in ``[0, pi]`` at ``j`` for every triplet
          ``k -> j -> i`` with ``k != i``.
        - ``torsion``: for every triplet, the minimum over all incoming
          neighbors ``k_n`` of ``j`` (``k_n != i``) of the signed angle around
          the ``j -> i`` axis between the planes through ``k`` and ``k_n``,
          mapped into ``(0, 2*pi]``. The pairing ``k_n == k`` contributes
          ``2*pi``.
        - ``i``, ``j``: target and source node of every edge.
        - ``idx_kj``, ``idx_ji``: for every triplet, the index of edge
          ``k -> j`` and of edge ``j -> i``.
    """
    j, i = edge_index  # j->i

    # Calculate distances. # number of edges
    dist = (pos[i] - pos[j]).pow(2).sum(dim=-1).sqrt()

    # Sparse adjacency whose entry (i, j) stores the index of edge j->i, so
    # row n lists every incoming edge k->n of node n.
    value = torch.arange(j.size(0), device=j.device)
    adj_t = SparseTensor(row=i, col=j, value=value, sparse_sizes=(num_nodes, num_nodes))
    adj_t_row = adj_t[j]
    # For every edge j->i: the number of edges k->j entering its source j.
    num_triplets = adj_t_row.set_value(None).sum(dim=1).to(torch.long)

    # Node indices (k->j->i) for triplets.
    idx_i = i.repeat_interleave(num_triplets)
    idx_j = j.repeat_interleave(num_triplets)
    idx_k = adj_t_row.storage.col()
    # Drop degenerate triplets i->j->i that walk back to the same node.
    mask = idx_i != idx_k
    idx_i, idx_j, idx_k = idx_i[mask], idx_j[mask], idx_k[mask]

    # Edge indices (k->j, j->i) for triplets.
    idx_kj = adj_t_row.storage.value()[mask]
    idx_ji = adj_t_row.storage.row()[mask]

    # Calculate angles. 0 to pi
    pos_ji = pos[idx_i] - pos[idx_j]
    pos_jk = pos[idx_k] - pos[idx_j]
    a = (pos_ji * pos_jk).sum(dim=-1)  # cos_angle * |pos_ji| * |pos_jk|
    # dim=-1 is explicit: without it torch.cross uses the first dimension of
    # size 3, which is the triplet dimension when there are exactly 3 triplets.
    b = torch.cross(pos_ji, pos_jk, dim=-1).norm(dim=-1)  # sin_angle * |pos_ji| * |pos_jk|
    angle = torch.atan2(b, a)

    if use_torsion:
        # Prepare torsion idxes: pair every triplet (k, j, i) with every
        # incoming neighbor k_n of j. Indices are created on the inputs'
        # device so CPU inputs work even when CUDA is available.
        idx_batch = torch.arange(idx_i.size(0), device=idx_i.device)
        idx_k_n = adj_t[idx_j].storage.col()
        # A triplet built from edge j->i has num_triplets[edge] candidates k_n.
        num_triplets_t = num_triplets.repeat_interleave(num_triplets)[mask]
        idx_i_t = idx_i.repeat_interleave(num_triplets_t)
        idx_j_t = idx_j.repeat_interleave(num_triplets_t)
        idx_k_t = idx_k.repeat_interleave(num_triplets_t)
        idx_batch_t = idx_batch.repeat_interleave(num_triplets_t)
        mask_t = idx_i_t != idx_k_n
        idx_i_t, idx_j_t, idx_k_t, idx_k_n, idx_batch_t = (
            idx_i_t[mask_t],
            idx_j_t[mask_t],
            idx_k_t[mask_t],
            idx_k_n[mask_t],
            idx_batch_t[mask_t],
        )

        # Calculate torsions.
        pos_j0 = pos[idx_k_t] - pos[idx_j_t]
        pos_ji = pos[idx_i_t] - pos[idx_j_t]
        pos_jk = pos[idx_k_n] - pos[idx_j_t]
        dist_ji = pos_ji.pow(2).sum(dim=-1).sqrt()
        plane1 = torch.cross(pos_ji, pos_j0, dim=-1)
        plane2 = torch.cross(pos_ji, pos_jk, dim=-1)
        a = (plane1 * plane2).sum(dim=-1)  # cos_angle * |plane1| * |plane2|
        b = (torch.cross(plane1, plane2, dim=-1) * pos_ji).sum(dim=-1) / dist_ji
        torsion1 = torch.atan2(b, a)  # -pi to pi
        torsion1[torsion1 <= 0] += 2 * PI  # 0 to 2pi
        # One torsion per triplet (the minimum over its k_n partners); the
        # explicit dim_size pins the output length to the number of triplets.
        torsion = scatter(torsion1, idx_batch_t, reduce='min', dim_size=idx_i.size(0))

        return dist, angle, torsion, i, j, idx_kj, idx_ji

    return dist, angle, i, j, idx_kj, idx_ji

```