Source code for dig.xgraph.dataset.mol_dataset

import os
import re
import torch
import numpy as np
from torch_geometric.utils import dense_to_sparse
from import Data, InMemoryDataset, \

    from rdkit import Chem
except ImportError:
    Chem = None
import os.path as osp
import zipfile
import gzip

x_map = {
    list(range(0, 119)),
    'chirality': [
    list(range(0, 11)),
    list(range(-5, 7)),
    list(range(0, 9)),
    list(range(0, 5)),
    'hybridization': [
    'is_aromatic': [False, True],
    'is_in_ring': [False, True],

e_map = {
    'bond_type': [
    'stereo': [
    'is_conjugated': [False, True],

def maybe_log(path, log=True):
    if log:
        print('Extracting', path)

def extract_gz(path, folder, log=True):
    maybe_log(path, log)
    with, 'r') as r:
        with open(osp.join(folder, '.'.join(os.path.basename(path).split('.')[:-1])), 'wb') as w:

def extract_zip(path, folder, log=True):
    r"""Extracts a zip archive to a specific folder.

        path (string): The path to the tar archive.
        folder (string): The folder.
        log (bool, optional): If :obj:`False`, will not print anything to the
            console. (default: :obj:`True`)
    maybe_log(path, log)
    with zipfile.ZipFile(path, 'r') as f:

[docs]class MoleculeDataset(InMemoryDataset): r""" The extension of MoleculeNet with `MUTAG <>`_. The `MoleculeNet benchmark collection <>`_ from the `MoleculeNet: A Benchmark for Molecular Machine Learning <>`_ paper, containing datasets from physical chemistry, biophysics and physiology. The MoleculeNet datasets come with the additional node and edge features introduced by the `Open Graph Benchmark <>`_, and the node features in MUTAG dataset are one hot features denoting the atom types. Args: root (string): Root directory where the dataset should be saved. name (string): The name of the dataset (:obj:`"MUTAG"`, :obj:`"ESOL"`, :obj:`"FreeSolv"`, :obj:`"Lipo"`, :obj:`"PCBA"`, :obj:`"MUV"`, :obj:`"HIV"`, :obj:`"BACE"`, :obj:`"BBPB"`, :obj:`"Tox21"`, :obj:`"ToxCast"`, :obj:`"SIDER"`, :obj:`"ClinTox"`). transform (callable, optional): A function/transform that takes in an :obj:`` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) pre_filter (callable, optional): A function that takes in an :obj:`` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) """ url = '{}' mutag_url = '{}' # Format: name: [display_name, url_name, filename, smiles_idx, y_idx] names = { 'mutag': ['MUTAG', '', None, None], 'esol': ['ESOL', 'delaney-processed.csv', 'delaney-processed.csv', -1, -2], 'freesolv': ['FreeSolv', 'SAMPL.csv', 'SAMPL.csv', 1, 2], 'lipo': ['Lipophilicity', 'Lipophilicity.csv', 'Lipophilicity.csv', 2, 1], 'pcba': ['PCBA', 'pcba.csv.gz', 'pcba.csv', -1, slice(0, 128)], 'muv': ['MUV', 'muv.csv.gz', 'muv.csv', -1, slice(0, 17)], 'hiv': ['HIV', 'HIV.csv', 'HIV.csv', 0, -1], 'bace': ['BACE', 'bace.csv', 'bace.csv', 0, 2], 'bbbp': ['BBBP', 'BBBP.csv', 'BBBP.csv', -1, -2], 'tox21': ['Tox21', 'tox21.csv.gz', 'tox21.csv', -1, slice(0, 12)], 'toxcast': ['ToxCast', 'toxcast_data.csv.gz', 'toxcast_data.csv', 0, slice(1, 618)], 'sider': ['SIDER', 'sider.csv.gz', 'sider.csv', 0, slice(1, 28)], 'clintox': ['ClinTox', 'clintox.csv.gz', 'clintox.csv', 0, slice(1, 3)], } def __init__(self, root, name, transform=None, pre_transform=None, pre_filter=None): if Chem is None: raise ImportError('`MoleculeNet` requires `rdkit`.') = name.lower() assert in self.names.keys() super(MoleculeDataset, self).__init__(root, transform, pre_transform, pre_filter), self.slices = torch.load(self.processed_paths[0]) @property def raw_dir(self): return osp.join(self.root,, 'raw') @property def processed_dir(self): return osp.join(self.root,, 'processed') @property def raw_file_names(self): if == 'MUTAG'.lower(): return ['MUTAG_A.txt', 'MUTAG_graph_labels.txt', 'MUTAG_graph_indicator.txt', 'MUTAG_node_labels.txt', 'README.txt'] else: return self.names[][2] @property def processed_file_names(self): return ''
[docs] def download(self): if == 'MUTAG'.lower(): url = self.mutag_url.format(self.names[][1]) else: url = self.url.format(self.names[][1]) path = download_url(url, self.raw_dir) if self.names[][1][-2:] == 'gz': extract_gz(path, self.raw_dir) os.unlink(path) elif self.names[][1][-3:] == 'zip': extract_zip(path, self.raw_dir) os.unlink(path)
[docs] def process(self): if == 'MUTAG'.lower(): with open(os.path.join(self.raw_dir, 'MUTAG_node_labels.txt'), 'r') as f: nodes_all_temp = nodes_all = [int(i) for i in nodes_all_temp] adj_all = np.zeros((len(nodes_all), len(nodes_all))) with open(os.path.join(self.raw_dir, 'MUTAG_A.txt'), 'r') as f: adj_list = for item in adj_list: lr = item.split(', ') l = int(lr[0]) r = int(lr[1]) adj_all[l - 1, r - 1] = 1 with open(os.path.join(self.raw_dir, 'MUTAG_graph_indicator.txt'), 'r') as f: graph_indicator_temp = graph_indicator = [int(i) for i in graph_indicator_temp] graph_indicator = np.array(graph_indicator) with open(os.path.join(self.raw_dir, 'MUTAG_graph_labels.txt'), 'r') as f: graph_labels_temp = graph_labels = [int(i) for i in graph_labels_temp] data_list = [] for i in range(1, 189): idx = np.where(graph_indicator == i) graph_len = len(idx[0]) adj = adj_all[idx[0][0]:idx[0][0] + graph_len, idx[0][0]:idx[0][0] + graph_len] label = int(graph_labels[i - 1] == 1) feature = nodes_all[idx[0][0]:idx[0][0] + graph_len] nb_clss = 7 targets = np.array(feature).reshape(-1) one_hot_feature = np.eye(nb_clss)[targets] data_example = Data(x=torch.from_numpy(one_hot_feature).float(), edge_index=dense_to_sparse(torch.from_numpy(adj))[0], y=label) data_list.append(data_example) else: with open(self.raw_paths[0], 'r') as f: dataset ='\n')[1:-1] dataset = [x for x in dataset if len(x) > 0] # Filter empty lines. data_list = [] for line in dataset: line = re.sub(r'\".*\"', '', line) # Replace ".*" strings. line = line.split(',') smiles = line[self.names[][3]] ys = line[self.names[][4]] ys = ys if isinstance(ys, list) else [ys] ys = [float(y) if len(y) > 0 else float('NaN') for y in ys] y = torch.tensor(ys, dtype=torch.float).view(1, -1) mol = Chem.MolFromSmiles(smiles) if mol is None: continue xs = [] for atom in mol.GetAtoms(): x = [] x.append(x_map['atomic_num'].index(atom.GetAtomicNum())) x.append(x_map['chirality'].index(str(atom.GetChiralTag()))) x.append(x_map['degree'].index(atom.GetTotalDegree())) x.append(x_map['formal_charge'].index(atom.GetFormalCharge())) x.append(x_map['num_hs'].index(atom.GetTotalNumHs())) x.append(x_map['num_radical_electrons'].index( atom.GetNumRadicalElectrons())) x.append(x_map['hybridization'].index( str(atom.GetHybridization()))) x.append(x_map['is_aromatic'].index(atom.GetIsAromatic())) x.append(x_map['is_in_ring'].index(atom.IsInRing())) xs.append(x) x = torch.tensor(xs, dtype=torch.long).view(-1, 9) edge_indices, edge_attrs = [], [] for bond in mol.GetBonds(): i = bond.GetBeginAtomIdx() j = bond.GetEndAtomIdx() e = [] e.append(e_map['bond_type'].index(str(bond.GetBondType()))) e.append(e_map['stereo'].index(str(bond.GetStereo()))) e.append(e_map['is_conjugated'].index(bond.GetIsConjugated())) edge_indices += [[i, j], [j, i]] edge_attrs += [e, e] edge_index = torch.tensor(edge_indices) edge_index = edge_index.t().to(torch.long).view(2, -1) edge_attr = torch.tensor(edge_attrs, dtype=torch.long).view(-1, 3) # Sort indices. if edge_index.numel() > 0: perm = (edge_index[0] * x.size(0) + edge_index[1]).argsort() edge_index, edge_attr = edge_index[:, perm], edge_attr[perm] data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, smiles=smiles) if self.pre_filter is not None and not self.pre_filter(data): continue if self.pre_transform is not None: data = self.pre_transform(data) data_list.append(data), self.processed_paths[0])
def __repr__(self): return '{}({})'.format(self.names[][0], len(self))
if __name__ == '__main__': dataset = MoleculeDataset(root='.', name='MUTAG')