# Source code for dig.ggraph.utils.gen_mol_from_one_shot_tensor

### Code adapted from MoFlow (under MIT License) https://github.com/calvin-zcx/moflow
import re
import numpy as np
from rdkit import Chem



# Valence limit per element, keyed by atomic number. construct_mol compares the
# valence reported for an offending atom against this value when deciding
# whether to assign a +1 formal charge.
ATOM_VALENCY = {
    6: 4,   # C
    7: 3,   # N
    8: 2,   # O
    9: 1,   # F
    15: 3,  # P
    16: 2,  # S
    17: 1,  # Cl
    35: 1,  # Br
    53: 1,  # I
}
# Integer bond order (1, 2, 3) -> RDKit bond type used when adding bonds.
bond_decoder_m = {
    1: Chem.rdchem.BondType.SINGLE,
    2: Chem.rdchem.BondType.DOUBLE,
    3: Chem.rdchem.BondType.TRIPLE,
}


def construct_mol(x, A, atomic_num_list):
    """Build an RDKit ``RWMol`` from one sample's node and edge score arrays.

    Args:
        x: Per-node score array. The arg-max along axis 1 picks, for every
            node, an index into ``atomic_num_list``. Nodes whose arg-max is
            the last index of ``atomic_num_list`` are treated as absent and
            are dropped.
        A: Per-edge score array laid out as (edge_type, num_node, num_node).
            The arg-max along axis 0 picks the edge channel of every node
            pair: channels 0, 1, 2 become single, double and triple bonds,
            channel 3 means "no bond".
        atomic_num_list: Atomic number associated with each channel of ``x``.

    Returns:
        The editable molecule. N, O and S atoms that exceed their entry in
        ``ATOM_VALENCY`` by exactly one right after a bond is added get a
        formal charge of +1 (e.g. [O+], [N+], [S+]); negative charges and
        forms such as [NH+] are not handled.
    """
    mol = Chem.RWMol()

    # Decode the atom channel of every node and drop the "no atom" channel.
    atom_channels = np.argmax(x, axis=1)
    keep = atom_channels != len(atomic_num_list) - 1
    for channel in atom_channels[keep]:
        mol.AddAtom(Chem.Atom(int(atomic_num_list[channel])))

    # A (edge_type, num_node, num_node)
    bond_channels = np.array(np.argmax(A, axis=0))
    bond_channels = bond_channels[keep, :][:, keep]
    # Channel 3 -> 0 (no bond); channels 0..2 -> bond orders 1..3.
    bond_order = np.where(bond_channels == 3, 0, bond_channels + 1)

    # Only the strictly lower triangle is used so each pair is added once;
    # np.nonzero walks it in row-major order.
    lower = np.tril(bond_order, k=-1)
    for begin, end in zip(*np.nonzero(lower)):
        mol.AddBond(int(begin), int(end), bond_decoder_m[bond_order[begin, end]])

        # Re-check after every bond so the atom that just became
        # over-valent can be charged immediately.
        ok, atomid_valence = check_valency(mol)
        if ok:
            continue

        assert len(atomid_valence) == 2
        idx, valence = atomid_valence
        atom = mol.GetAtomWithIdx(idx)
        an = atom.GetAtomicNum()
        if an in (7, 8, 16) and (valence - ATOM_VALENCY[an]) == 1:
            atom.SetFormalCharge(1)
    return mol


def check_valency(mol):
    """Run RDKit's property sanitization to detect valence violations.

    :param mol: the RDKit molecule to check.
    :return: ``(True, None)`` when ``Chem.SanitizeMol`` with
        ``SANITIZE_PROPERTIES`` succeeds. Otherwise ``(False, numbers)``
        where ``numbers`` is the list of all integers found in the error
        message from its first ``'#'`` onwards. Callers in this module
        expect exactly two numbers and treat them as
        ``[atom index, valence]`` -- this relies on the wording of RDKit's
        valence error message (TODO confirm for the RDKit version in use).
    """
    try:
        Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_PROPERTIES)
    except ValueError as err:
        message = str(err)
        # Keep only the part of the message starting at the '#' marker.
        tail = message[message.find('#'):]
        return False, [int(token) for token in re.findall(r'\d+', tail)]
    return True, None

    
def correct_mol(x):
    """Repair valence violations by repeatedly lowering bond orders.

    While ``check_valency`` reports an offending atom, the highest-order bond
    attached to that atom (the first one on ties) is reduced by one order;
    a single bond is removed altogether.

    Args:
        x: The editable RDKit molecule to fix. It is modified in place.

    Returns:
        The same molecule object. It passes ``check_valency`` unless the
        offending atom has no bonds left to lower, in which case the
        molecule is returned as is instead of looping forever.

    Raises:
        AssertionError: If the valence error message does not yield exactly
            two numbers.
    """
    mol = x
    while True:
        flag, atomid_valence = check_valency(mol)
        if flag:
            break

        assert len(atomid_valence) == 2
        idx = atomid_valence[0]
        bonds = [
            (int(b.GetBondType()), b.GetBeginAtomIdx(), b.GetEndAtomIdx())
            for b in mol.GetAtomWithIdx(idx).GetBonds()
        ]
        if not bonds:
            # Nothing left to weaken on this atom: the molecule cannot
            # change any further, so another iteration would never end.
            break

        # max() returns the first maximal element, matching a stable
        # descending sort followed by taking the head.
        order, start, end = max(bonds, key=lambda tup: tup[0])
        lowered = order - 1
        mol.RemoveBond(start, end)
        if lowered >= 1:
            mol.AddBond(start, end, bond_decoder_m[lowered])
    return mol


def valid_mol_can_with_seg(x, largest_connected_comp=True):
    """Round-trip a molecule through SMILES, optionally keeping one fragment.

    Args:
        x: An RDKit molecule, or ``None``.
        largest_connected_comp: If true and the SMILES contains several
            '.'-separated fragments, only the fragment with the longest
            SMILES string is kept (the first one on ties). Note this ranks
            fragments by string length, not by atom count.

    Returns:
        ``None`` if ``x`` is ``None``; otherwise the result of
        ``Chem.MolFromSmiles`` on the selected SMILES string.
    """
    if x is None:
        return None
    sm = Chem.MolToSmiles(x, isomericSmiles=True)
    if largest_connected_comp and '.' in sm:
        # e.g. 'C.CC.CCc1ccc(N)cc1CCC=O' -> 'CCc1ccc(N)cc1CCC=O'
        sm = max(sm.split('.'), key=len)
    # Parse once, after the fragment choice, instead of parsing the full
    # SMILES first and throwing that result away.
    return Chem.MolFromSmiles(sm)
    
    
def gen_mol_from_one_shot_tensor(adj, x, atomic_num_list, correct_validity=True, largest_connected_comp=True):
    r"""Construct molecules from the node tensors and adjacency tensors generated by one-shot molecular graph generation methods.

    Args:
        adj (Tensor): The adjacency tensor with shape [:obj:`number of samples`, :obj:`number of possible bond types`, :obj:`maximum number of atoms`, :obj:`maximum number of atoms`].
        x (Tensor): The node tensor with shape [:obj:`number of samples`, :obj:`number of possible atom types`, :obj:`maximum number of atoms`].
        atomic_num_list (list): A list to specify what atom each channel of the 2nd dimension of :obj:`x` corresponds to.
        correct_validity (bool, optional): Whether to use the validity correction introduced by the paper `MoFlow: an invertible flow model for generating molecular graphs <https://arxiv.org/pdf/2006.10137.pdf>`_. (default: :obj:`True`)
        largest_connected_comp (bool, optional): Whether to use the largest connected component as the final molecule in the validity correction. (default: :obj:`True`)

    :rtype:
        A list of rdkit mol object. The length of the list is :obj:`number of samples`.

    Examples
    --------
    >>> adj = torch.rand(2, 4, 38, 38)
    >>> x = torch.rand(2, 10, 38)
    >>> atomic_num_list = [6, 7, 8, 9, 15, 16, 17, 35, 53, 0]
    >>> gen_mols = gen_mol_from_one_shot_tensor(adj, x, atomic_num_list)
    >>> gen_mols
    [<rdkit.Chem.rdchem.Mol>, <rdkit.Chem.rdchem.Mol>]
    """
    # Move the atom-type axis last so each sample is (num_atoms, num_atom_types),
    # the layout construct_mol takes its arg-max over.
    x = x.permute(0, 2, 1)
    adj = adj.cpu().detach().numpy()
    x = x.cpu().detach().numpy()

    if not correct_validity:
        # Raw decoding only: molecules are returned exactly as constructed.
        gen_mols = [
            construct_mol(x_elem, adj_elem, atomic_num_list)
            for x_elem, adj_elem in zip(x, adj)
        ]
    else:
        gen_mols = []
        for x_elem, adj_elem in zip(x, adj):
            mol = construct_mol(x_elem, adj_elem, atomic_num_list)
            cmol = correct_mol(mol)
            vcmol = valid_mol_can_with_seg(cmol, largest_connected_comp=largest_connected_comp)
            gen_mols.append(vcmol)
    return gen_mols