Source code for dig.ggraph.method.JTVAE.jtvae

import math
import sys
from optparse import OptionParser

import numpy as np
import rdkit
from rdkit import RDLogger, Chem
from rdkit.Chem import Descriptors
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler

from dig.ggraph.method import Generator

from . import fast_jtnn


class JTVAE(Generator):
    r"""
    The method class for the JTVAE algorithm proposed in the paper `Junction Tree Variational Autoencoder
    for Molecular Graph Generation <https://arxiv.org/abs/1802.04364>`_. This class provides interfaces for
    running random generation with the JTVAE algorithm. Please refer to the `example codes
    <https://github.com/divelab/DIG/tree/dig-stable/examples/ggraph/JTVAE>`_ for usage examples.

    Args:
        list_smiles (list): List of smiles in the training data.
        build_vocab (boolean): Whether the vocabulary needs to be built (first time training with this dataset).
        device (torch.device, optional): The device where the model is deployed.
    """

    def __init__(self, list_smiles, build_vocab=True, device=None):
        super().__init__()
        if build_vocab:
            self.vocab = self.build_vocabulary(list_smiles)
        self.model = None

    def get_model(self, task, config_dict):
        if task == 'rand_gen':
            # config_dict keys: hidden_size, latent_size, depthT, depthG
            self.vae = fast_jtnn.JTNNVAE(
                fast_jtnn.Vocab(self.vocab), **config_dict).cuda()
        elif task == 'cons_optim':
            self.prop_vae = fast_jtnn.JTPropVAE(
                fast_jtnn.Vocab(self.vocab), **config_dict).cuda()
        else:
            raise ValueError('Task {} is not supported in JTVAE!'.format(task))

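    # A minimal usage sketch (illustrative only, not executed as part of this module):
    # construct the method from a list of training SMILES and instantiate the
    # underlying VAE. The config_dict keys follow the comment in get_model above;
    # the concrete values here are placeholders, not prescribed defaults.
    #
    #   jtvae = JTVAE(list_smiles=train_smiles, build_vocab=True)
    #   jtvae.get_model('rand_gen', {'hidden_size': 450, 'latent_size': 56,
    #                                'depthT': 20, 'depthG': 3})
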
    def build_vocabulary(self, list_smiles):
        r"""
        Building the vocabulary for training.

        Args:
            list_smiles (list): The list of smiles strings in the dataset.

        :rtype:
            cset (list): A list of smiles that contains the vocabulary for the training data.
        """
        cset = set()
        for smiles in list_smiles:
            mol = fast_jtnn.MolTree(smiles)
            for c in mol.nodes:
                cset.add(c.smiles)
        # cset_newline = list(map(lambda x: x + "\n", cset))
        return list(cset)

    def _tensorize(self, smiles, assm=True):
        mol_tree = fast_jtnn.MolTree(smiles)
        mol_tree.recover()
        if assm:
            mol_tree.assemble()
            for node in mol_tree.nodes:
                if node.label not in node.cands:
                    node.cands.append(node.label)
        del mol_tree.mol
        for node in mol_tree.nodes:
            del node.mol
        return mol_tree

    def preprocess(self, list_smiles):
        r"""
        Preprocess the molecules.

        Args:
            list_smiles (list): The list of smiles strings in the dataset.

        :rtype:
            preprocessed (list): A list of preprocessed MolTree objects.
        """
        # TODO: In the future, this can be parallelized.
        preprocessed = list(
            map(self._tensorize, tqdm(list_smiles, leave=True)))
        return preprocessed

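    # Preprocessing sketch (illustrative only): convert raw SMILES into MolTree
    # objects before building the training loader. The MolTreeFolder loader named
    # in the train_rand_gen docstring is assumed to be pointed at wherever the
    # processed trees are stored (e.g. pickled to disk); adapt to your own setup.
    #
    #   trees = jtvae.preprocess(train_smiles)   # list of MolTree objects
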
    def train_rand_gen(self, loader, load_epoch, lr, anneal_rate, clip_norm, num_epochs, beta,
                       max_beta, step_beta, anneal_iter, kl_anneal_iter, print_iter, save_iter):
        r"""
        Train the Junction Tree Variational Autoencoder for the random generation task.

        Args:
            loader (MolTreeFolder): The MolTreeFolder loader.
            load_epoch (int): The epoch to load from the state dictionary.
            lr (float): The learning rate for training.
            anneal_rate (float): The multiplicative factor for learning rate annealing.
            clip_norm (float): Clips gradient norm of an iterable of parameters.
            num_epochs (int): The number of training epochs.
            beta (float): The KL regularization weight.
            max_beta (float): The maximum KL regularization weight.
            step_beta (float): The KL regularization weight step size.
            anneal_iter (int): How often to step in annealing the learning rate.
            kl_anneal_iter (int): How often to step in annealing the KL regularization weight.
            print_iter (int): How often to print the iteration statistics.
            save_iter (int): How often to save a model checkpoint.
        """
        vocab = fast_jtnn.Vocab(self.vocab)

        for param in self.vae.parameters():
            if param.dim() == 1:
                nn.init.constant_(param, 0)
            else:
                nn.init.xavier_normal_(param)

        if load_epoch > 0:
            self.vae.load_state_dict(torch.load(
                'checkpoints' + "/model.iter-" + str(load_epoch)))

        print("Model #Params: %dK" %
              (sum([x.nelement() for x in self.vae.parameters()]) / 1000,))

        optimizer = optim.Adam(self.vae.parameters(), lr=lr)
        scheduler = lr_scheduler.ExponentialLR(optimizer, anneal_rate)
        scheduler.step()

        def param_norm(m):
            return math.sqrt(sum([p.norm().item() ** 2 for p in m.parameters()]))

        def grad_norm(m):
            return math.sqrt(sum([p.grad.norm().item() ** 2
                                  for p in m.parameters() if p.grad is not None]))

        total_step = load_epoch
        meters = np.zeros(4)

        for epoch in range(num_epochs):
            for batch in loader:
                total_step += 1
                try:
                    self.vae.zero_grad()
                    loss, kl_div, wacc, tacc, sacc = self.vae(batch, beta)
                    loss.backward()
                    nn.utils.clip_grad_norm_(self.vae.parameters(), clip_norm)
                    optimizer.step()
                except Exception as e:
                    print(e)
                    continue

                meters = meters + np.array([kl_div, wacc * 100, tacc * 100, sacc * 100])

                if total_step % print_iter == 0:
                    meters /= print_iter
                    print("[%d] Beta: %.3f, KL: %.2f, Word: %.2f, Topo: %.2f, Assm: %.2f, PNorm: %.2f, GNorm: %.2f" % (
                        total_step, beta, meters[0], meters[1], meters[2], meters[3],
                        param_norm(self.vae), grad_norm(self.vae)))
                    sys.stdout.flush()
                    meters *= 0

                if total_step % save_iter == 0:
                    torch.save(self.vae.state_dict(),
                               "saved" + "/model.iter-" + str(total_step))

                if total_step % anneal_iter == 0:
                    scheduler.step()
                    print("learning rate: %.6f" % scheduler.get_lr()[0])

                if total_step % kl_anneal_iter == 0 and total_step >= anneal_iter:
                    beta = min(max_beta, beta + step_beta)

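    # Illustrative training call; every hyperparameter value below is a placeholder
    # chosen for the sketch, not a prescribed default.
    #
    #   jtvae.train_rand_gen(loader, load_epoch=0, lr=1e-3, anneal_rate=0.9,
    #                        clip_norm=50.0, num_epochs=20, beta=0.0, max_beta=1.0,
    #                        step_beta=0.002, anneal_iter=40000, kl_anneal_iter=2000,
    #                        print_iter=50, save_iter=5000)
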
    def run_rand_gen(self, num_samples):
        r"""
        Sample new molecules from the trained model.

        Args:
            num_samples (int): Number of samples to generate from the trained model.

        :rtype:
            samples (list): samples is a list of generated molecules.
        """
        torch.manual_seed(0)
        samples = [self.vae.sample_prior() for _ in range(num_samples)]
        return samples

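    # Sampling sketch (illustrative only): after training, or after loading a saved
    # state dict into self.vae, draw new molecules from the prior.
    #
    #   samples = jtvae.run_rand_gen(num_samples=100)
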
    def train_cons_optim(self, loader, batch_size, num_epochs, hidden_size, latent_size, depth, beta, lr):
        r"""
        Train the Junction Tree Variational Autoencoder for the constrained optimization task.

        Args:
            loader (MolTreeFolder): The MolTreeFolder loader.
            batch_size (int): The batch size.
            num_epochs (int): The number of epochs.
            hidden_size (int): The hidden size.
            latent_size (int): The latent size.
            depth (int): The depth of the network.
            beta (float): The KL regularization weight.
            lr (float): The learning rate for training.
        """
        for param in self.prop_vae.parameters():
            if param.dim() == 1:
                nn.init.constant_(param, 0)
            else:
                nn.init.xavier_normal_(param)

        print("Model #Params: %dK" %
              (sum([x.nelement() for x in self.prop_vae.parameters()]) / 1000,))

        optimizer = optim.Adam(self.prop_vae.parameters(), lr=lr)
        scheduler = lr_scheduler.ExponentialLR(optimizer, 0.9)
        scheduler.step()

        PRINT_ITER = 20

        for epoch in range(num_epochs):
            word_acc, topo_acc, assm_acc, steo_acc, prop_acc = 0, 0, 0, 0, 0

            for it, batch in enumerate(tqdm(loader)):
                for mol_tree, _ in batch:
                    for node in mol_tree.nodes:
                        if node.label not in node.cands:
                            node.cands.append(node.label)
                            node.cand_mols.append(node.label_mol)

                self.prop_vae.zero_grad()
                loss, kl_div, wacc, tacc, sacc, dacc, pacc = self.prop_vae(batch, beta)
                loss.backward()
                optimizer.step()

                word_acc += wacc
                topo_acc += tacc
                assm_acc += sacc
                steo_acc += dacc
                prop_acc += pacc

                if (it + 1) % PRINT_ITER == 0:
                    word_acc = word_acc / PRINT_ITER * 100
                    topo_acc = topo_acc / PRINT_ITER * 100
                    assm_acc = assm_acc / PRINT_ITER * 100
                    steo_acc = steo_acc / PRINT_ITER * 100
                    prop_acc /= PRINT_ITER
                    print("KL: %.1f, Word: %.2f, Topo: %.2f, Assm: %.2f, Steo: %.2f, Prop: %.4f" % (
                        kl_div, word_acc, topo_acc, assm_acc, steo_acc, prop_acc))
                    word_acc, topo_acc, assm_acc, steo_acc, prop_acc = 0, 0, 0, 0, 0
                    sys.stdout.flush()

                if (it + 1) % 1500 == 0:  # Fast annealing
                    scheduler.step()
                    print("learning rate: %.6f" % scheduler.get_lr()[0])
                    # torch.save(self.prop_vae.state_dict(), opts.save_path + "/model.iter-%d-%d" % (epoch, it + 1))

            scheduler.step()
            print("learning rate: %.6f" % scheduler.get_lr()[0])
            # torch.save(self.prop_vae.state_dict(), opts.save_path + "/model.iter-" + str(epoch))

    def run_cons_optim(self, list_smiles, sim_cutoff=0.0):
        r"""
        Optimize a set of molecules.

        Args:
            list_smiles (list): List of smiles in the training data.
            sim_cutoff (float): Similarity cutoff.
        """
        # data = []
        # with open(opts.test_path) as f:
        #     for line in f:
        #         s = line.strip("\r\n ").split()[0]
        #         data.append(s)

        res = []
        for smiles in tqdm(list_smiles):
            mol = Chem.MolFromSmiles(smiles)
            score = Descriptors.MolLogP(mol) - fast_jtnn.sascorer.calculateScore(mol)

            new_smiles, sim = self.prop_vae.optimize(
                smiles, sim_cutoff=sim_cutoff, lr=2, num_iter=80)
            new_mol = Chem.MolFromSmiles(new_smiles)
            new_score = Descriptors.MolLogP(new_mol) - fast_jtnn.sascorer.calculateScore(new_mol)

            res.append((new_score - score, sim, score, new_score, smiles, new_smiles))
            print(new_score - score, sim, score, new_score, smiles, new_smiles)

        print(sum([x[0] for x in res]), sum([x[1] for x in res]))
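
# Constrained-optimization sketch (illustrative only; assumes the property VAE has
# been created via get_model('cons_optim', ...) and trained with train_cons_optim).
# `prop_config`, `prop_loader`, and `test_smiles` are placeholder names, and the
# hyperparameter values are placeholders as well.
#
#   jtvae.get_model('cons_optim', prop_config)
#   jtvae.train_cons_optim(prop_loader, batch_size=32, num_epochs=3, hidden_size=450,
#                          latent_size=56, depth=3, beta=0.005, lr=1e-3)
#   jtvae.run_cons_optim(test_smiles, sim_cutoff=0.4)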