### Code adapted from MoFlow (under MIT License) https://github.com/calvin-zcx/moflow
import re
import numpy as np
from rdkit import Chem
# Reference valence per atomic number. construct_mol compares the valence
# reported by RDKit against these values when deciding whether an N/O/S atom
# should receive a +1 formal charge.
ATOM_VALENCY = {
    6: 4,   # C
    7: 3,   # N
    8: 2,   # O
    9: 1,   # F
    15: 3,  # P
    16: 2,  # S
    17: 1,  # Cl
    35: 1,  # Br
    53: 1,  # I
}

# Integer bond order -> RDKit bond type used when adding bonds to a molecule.
bond_decoder_m = {
    1: Chem.rdchem.BondType.SINGLE,
    2: Chem.rdchem.BondType.DOUBLE,
    3: Chem.rdchem.BondType.TRIPLE,
}
def construct_mol(x, A, atomic_num_list):
    """Build an editable RDKit molecule from one node matrix and one edge tensor.

    Args:
        x: Per-node atom-type scores; the arg-max along axis 1 selects the
            atom type of each node. A node whose arg-max is the last index of
            ``atomic_num_list`` is treated as "no atom" and dropped.
        A: Edge scores laid out as (edge_type, num_node, num_node); the
            arg-max along axis 0 selects the bond type of each node pair.
            Channel 3 is treated as "no bond"; channels 0, 1 and 2 become
            single, double and triple bonds respectively.
        atomic_num_list: Atomic number for each atom-type channel of ``x``.

    Returns:
        The ``Chem.RWMol`` that was built. After each bond is added the
        valency is checked; if the offending atom is N, O or S and exceeds its
        reference valence by exactly one, it is given a +1 formal charge
        (e.g. [O+], [N+], [S+]). Negative charges and [NH+]-like cases are
        not handled.
    """
    rw_mol = Chem.RWMol()

    atom_types = np.argmax(x, axis=1)
    keep = atom_types != len(atomic_num_list) - 1
    for type_idx in atom_types[keep]:
        rw_mol.AddAtom(Chem.Atom(int(atomic_num_list[type_idx])))

    # A (edge_type, num_node, num_node)
    bond_orders = np.array(np.argmax(A, axis=0))
    # Restrict the adjacency to the nodes that actually became atoms.
    bond_orders = bond_orders[keep, :][:, keep]
    # Shift so that 0 means "no bond" and 1/2/3 index bond_decoder_m.
    bond_orders[bond_orders == 3] = -1
    bond_orders += 1

    for start, end in zip(*np.nonzero(bond_orders)):
        # Only the lower triangle is used, so each pair is bonded once.
        if not start > end:
            continue
        rw_mol.AddBond(int(start), int(end), bond_decoder_m[bond_orders[start, end]])

        is_valid, atomid_valence = check_valency(rw_mol)
        if is_valid:
            continue

        assert len(atomid_valence) == 2
        atom_idx = atomid_valence[0]
        valence = atomid_valence[1]
        atomic_num = rw_mol.GetAtomWithIdx(atom_idx).GetAtomicNum()
        if atomic_num in (7, 8, 16) and (valence - ATOM_VALENCY[atomic_num]) == 1:
            rw_mol.GetAtomWithIdx(atom_idx).SetFormalCharge(1)

    return rw_mol
def check_valency(mol):
    """Check whether any atom in ``mol`` fails RDKit's property sanitization.

    Only the ``SANITIZE_PROPERTIES`` step of ``Chem.SanitizeMol`` is run.

    :return: ``(True, None)`` when sanitization succeeds. Otherwise
        ``(False, numbers)`` where ``numbers`` is the list of integers found
        in the error message from its first ``'#'`` onwards; callers in this
        file treat it as ``[atom index, valence]``.
    """
    try:
        Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_PROPERTIES)
    except ValueError as err:
        message = str(err)
        # Keep only the part of the message that starts at the '#' marker.
        tail = message[message.find('#'):]
        return False, [int(token) for token in re.findall(r'\d+', tail)]
    return True, None
def correct_mol(x):
    """Repair valency violations in ``x`` by lowering bond orders.

    While ``check_valency`` reports a problem, the highest-order bond on the
    offending atom is reduced by one order (a single bond is removed
    entirely), and the check is repeated.

    Args:
        x: An editable RDKit molecule; it is modified in place.

    Returns:
        The same molecule object, once ``check_valency`` reports no problem.
    """
    mol = x
    while True:
        is_valid, atomid_valence = check_valency(mol)
        if is_valid:
            return mol

        assert len(atomid_valence) == 2
        offending_atom = mol.GetAtomWithIdx(atomid_valence[0])
        candidates = [
            (int(bond.GetBondType()), bond.GetBeginAtomIdx(), bond.GetEndAtomIdx())
            for bond in offending_atom.GetBonds()
        ]
        # NOTE(review): if the offending atom has no bonds nothing changes and
        # the loop repeats with the same state.
        if not candidates:
            continue

        # First bond with the highest order (ties keep iteration order).
        order, start, end = max(candidates, key=lambda entry: entry[0])
        lowered = order - 1
        mol.RemoveBond(start, end)
        if lowered >= 1:
            mol.AddBond(start, end, bond_decoder_m[lowered])
def valid_mol_can_with_seg(x, largest_connected_comp=True):
    """Round-trip a molecule through SMILES, optionally keeping one fragment.

    Args:
        x: An RDKit molecule, or ``None``.
        largest_connected_comp: When true and the SMILES contains several
            ``'.'``-separated fragments, only the fragment with the longest
            SMILES string is kept (the first one on ties).

    Returns:
        ``None`` if ``x`` is ``None``; otherwise the result of
        ``Chem.MolFromSmiles`` on the selected SMILES string.
    """
    if x is None:
        return None

    sm = Chem.MolToSmiles(x, isomericSmiles=True)
    if largest_connected_comp and '.' in sm:
        # e.g. 'C.CC.CCc1ccc(N)cc1CCC=O' -> 'CCc1ccc(N)cc1CCC=O'
        # Pick the fragment before parsing so the full multi-fragment SMILES
        # is not parsed only to be thrown away.
        sm = max(sm.split('.'), key=len)
    return Chem.MolFromSmiles(sm)
def gen_mol_from_one_shot_tensor(adj, x, atomic_num_list, correct_validity=True, largest_connected_comp=True):
    r"""Construct molecules from the node tensors and adjacency tensors generated by one-shot molecular graph generation methods.

    Args:
        adj (Tensor): The adjacency tensor with shape [:obj:`number of samples`, :obj:`number of possible bond types`, :obj:`maximum number of atoms`, :obj:`maximum number of atoms`].
        x (Tensor): The node tensor with shape [:obj:`number of samples`, :obj:`number of possible atom types`, :obj:`maximum number of atoms`].
        atomic_num_list (list): A list to specify what atom each channel of the 2nd dimension of :obj:`x` corresponds to.
        correct_validity (bool, optional): Whether to use the validity correction introduced by the paper `MoFlow: an invertible flow model for generating molecular graphs <https://arxiv.org/pdf/2006.10137.pdf>`_. (default: :obj:`True`)
        largest_connected_comp (bool, optional): Whether to use the largest connected component as the final molecule in the validity correction. (default: :obj:`True`)

    Returns:
        list: One entry per sample, in input order. Each entry is whatever the
        construction (and, if enabled, correction) helpers return for that
        sample, normally an rdkit mol object.

    Examples
    --------
    >>> adj = torch.rand(2, 4, 38, 38)
    >>> x = torch.rand(2, 10, 38)
    >>> atomic_num_list = [6, 7, 8, 9, 15, 16, 17, 35, 53, 0]
    >>> gen_mols = gen_mol_from_one_shot_tensor(adj, x, atomic_num_list)
    >>> gen_mols
    [<rdkit.Chem.rdchem.Mol>, <rdkit.Chem.rdchem.Mol>]
    """
    # construct_mol takes the atom-type scores on the last axis, so move the
    # channel axis behind the atom axis before converting to NumPy.
    x = x.permute(0, 2, 1)
    adj = adj.cpu().detach().numpy()
    x = x.cpu().detach().numpy()

    gen_mols = []
    for x_elem, adj_elem in zip(x, adj):
        mol = construct_mol(x_elem, adj_elem, atomic_num_list)
        if correct_validity:
            mol = correct_mol(mol)
            mol = valid_mol_can_with_seg(mol, largest_connected_comp=largest_connected_comp)
        gen_mols.append(mol)
    return gen_mols