import time
import os
import copy

import torch
from torch.optim import Adam
from tqdm import tqdm
from rdkit import Chem

from dig.ggraph.method import Generator
from dig.ggraph.utils import gen_mol_from_one_shot_tensor
from dig.ggraph.utils import qed, calculate_min_plogp, reward_target_molecule_similarity
from .energy_func import EnergyFunc
from .util import rescale_adj, requires_grad, clip_grad

class GraphEBM(Generator):
    r"""
    The method class for the GraphEBM algorithm proposed in the paper
    *GraphEBM: Molecular Graph Generation with Energy-Based Models*. This class
    provides interfaces for running random generation, goal-directed generation
    (including property optimization and constrained optimization), and
    compositional generation with the GraphEBM algorithm. Please refer to the
    GraphEBM benchmark code in the DIG repository for usage examples.

    Args:
        n_atom (int): Maximum number of atoms.
        n_atom_type (int): Number of possible atom types.
        n_edge_type (int): Number of possible bond types.
        hidden (int): Hidden dimensions.
        device (torch.device, optional): The device where the model is deployed.
            Defaults to CUDA when available, otherwise CPU.
    """

    # Minimum QED score for a generated molecule to be kept by run_prop_opt.
    _QED_THRESHOLD = 0.930
    # Similarity thresholds reported (in this order) by run_const_prop_opt.
    _SIM_THRESHOLDS = (0.0, 0.2, 0.4, 0.6)
    # Initial number of result slots in run_const_prop_opt; the result lists
    # grow if the initialization loader yields more molecules than this.
    _CONST_OPT_SLOTS = 800

    def __init__(self, n_atom, n_atom_type, n_edge_type, hidden, device=None):
        super(GraphEBM, self).__init__()
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device
        self.energy_function = EnergyFunc(n_atom_type, hidden, n_edge_type).to(self.device)
        self.n_atom = n_atom
        self.n_atom_type = n_atom_type
        self.n_edge_type = n_edge_type

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _load_weights(self, model, checkpoint_path):
        """Load the state dict stored at ``checkpoint_path`` into ``model``."""
        print("Loading paramaters from {}".format(checkpoint_path))
        # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only
        # machine (and vice versa).
        model.load_state_dict(torch.load(checkpoint_path, map_location=self.device))

    @staticmethod
    def _freeze(model):
        """Put ``model`` in eval mode and switch off gradients for its parameters."""
        # list(): Module.parameters() is a one-shot generator.
        requires_grad(list(model.parameters()), False)
        model.eval()

    def _batch_to_tensors(self, batch):
        """Return fresh float32 copies of a batch's node and adjacency tensors on ``self.device``."""
        # NOTE(review): assumes batches expose dense ``x`` and ``adj`` tensors, as
        # the dig.ggraph QM9/ZINC250k datasets do — confirm for other datasets.
        # copy=True guarantees new leaf tensors, so later in-place updates and
        # requires_grad flags never touch the loader's own tensors.
        x = batch.x.to(device=self.device, dtype=torch.float32, copy=True)
        adj = batch.adj.to(device=self.device, dtype=torch.float32, copy=True)
        return x, adj

    @staticmethod
    def _langevin_dynamics(energy, x, adj, c, ld_step, ld_noise, step_alpha, clamp,
                           grad_clamp, on_step=None):
        """Run ``ld_step`` Langevin-dynamics updates on ``x`` and ``adj`` in place.

        Every step adds Gaussian noise with standard deviation ``ld_noise``,
        moves the samples by ``step_alpha`` times the gradient of
        ``energy(adj, x).sum()`` (optionally clamped to
        ``[-grad_clamp, grad_clamp]``), then clamps ``x`` to ``[0, 1 + c]`` and
        ``adj`` to ``[0, 1]``.

        Args:
            energy: Callable ``(adj, x) -> Tensor`` returning per-sample energies.
            x, adj: Leaf tensors updated in place; ``requires_grad`` is enabled here.
            step_alpha (float): Signed multiplier applied to the gradient
                (negative values move the samples towards lower energy).
            clamp (bool): Whether to clamp the gradient before the update.
            grad_clamp (float): Gradient clamp threshold used when ``clamp`` is set.
            on_step: Optional callable ``(x, adj)`` invoked after every step.
        """
        x.requires_grad_(True)
        adj.requires_grad_(True)
        noise_x = torch.randn_like(x)
        noise_adj = torch.randn_like(adj)
        for _ in range(ld_step):
            noise_x.normal_(0, ld_noise)
            noise_adj.normal_(0, ld_noise)
            with torch.no_grad():
                x.add_(noise_x)
                adj.add_(noise_adj)

            energy(adj, x).sum().backward()

            with torch.no_grad():
                if clamp:
                    x.grad.clamp_(-grad_clamp, grad_clamp)
                    adj.grad.clamp_(-grad_clamp, grad_clamp)
                x.add_(x.grad, alpha=step_alpha)
                adj.add_(adj.grad, alpha=step_alpha)
                # Keep samples inside the (dequantized) valid value ranges.
                x.clamp_(0, 1 + c)
                adj.clamp_(0, 1)
            # Drop the gradients so the next backward pass starts from zero.
            x.grad = None
            adj.grad = None

            if on_step is not None:
                on_step(x, adj)

    @staticmethod
    def _decode(x, adj, atomic_num_list):
        """Convert continuous samples into RDKit molecules after symmetrizing ``adj``."""
        x = x.detach().clone()
        adj = adj.detach()
        adj = (adj + adj.permute(0, 1, 3, 2)) / 2
        return gen_mol_from_one_shot_tensor(adj, x, atomic_num_list, correct_validity=True)

    @staticmethod
    def _mean(values):
        """Mean of a list of scalar tensors as a float (NaN for an empty list)."""
        if not values:
            return float('nan')
        return torch.stack(values).mean().item()

    def _train(self, loader, lr, wd, max_epochs, c, ld_step, ld_noise, ld_step_size,
               clamp, alpha, save_interval, save_dir, goal_directed):
        """Shared training loop for random and goal-directed generation.

        When ``goal_directed`` is true, each positive sample's energy term is
        weighted by ``1 + exp(y)``, where ``y`` is read from ``batch.y``.
        """
        # Materialize the parameters: Module.parameters() returns a one-shot
        # generator that Adam would exhaust, turning every later
        # requires_grad(parameters, ...) call into a silent no-op.
        parameters = list(self.energy_function.parameters())
        optimizer = Adam(parameters, lr=lr, betas=(0.0, 0.999), weight_decay=wd)
        os.makedirs(save_dir, exist_ok=True)

        for epoch in range(max_epochs):
            t_start = time.time()
            losses_reg = []
            losses_en = []
            losses = []
            for batch in tqdm(loader):
                ### Dequantization
                pos_x, pos_adj = self._batch_to_tensors(batch)
                pos_x += c * torch.rand_like(pos_x)
                pos_adj += c * torch.rand_like(pos_adj)

                ### Langevin dynamics (negative samples start from uniform noise)
                neg_x = torch.rand_like(pos_x) * (1 + c)
                neg_adj = torch.rand_like(pos_adj)
                pos_adj = rescale_adj(pos_adj)

                requires_grad(parameters, False)
                self.energy_function.eval()
                # NOTE(review): training moves negatives by +ld_step_size * grad while
                # every generation method uses -ld_step_size; kept as-is — confirm the sign.
                self._langevin_dynamics(self.energy_function, neg_x, neg_adj, c, ld_step,
                                        ld_noise, ld_step_size, clamp, grad_clamp=0.01)
                neg_x = neg_x.detach()
                neg_adj = neg_adj.detach()

                ### Training by backprop
                requires_grad(parameters, True)
                self.energy_function.train()
                self.energy_function.zero_grad()
                pos_out = self.energy_function(pos_adj, pos_x)
                neg_out = self.energy_function(neg_adj, neg_x)
                loss_reg = pos_out ** 2 + neg_out ** 2  # energy magnitudes regularizer
                if goal_directed:
                    # NOTE(review): assumes batch.y holds the target property — confirm.
                    pos_y = batch.y.to(self.device)
                    loss_en = (1 + torch.exp(pos_y)) * pos_out - neg_out
                else:
                    loss_en = pos_out - neg_out  # loss for shaping energy function
                loss = (loss_en + alpha * loss_reg).mean()
                loss.backward()
                clip_grad(optimizer)
                optimizer.step()

                # Detach so each batch's autograd graph is not kept alive for the epoch.
                losses_reg.append(loss_reg.mean().detach())
                losses_en.append(loss_en.mean().detach())
                losses.append(loss.detach())
            t_end = time.time()

            ### Save checkpoints
            if (epoch + 1) % save_interval == 0:
                torch.save(self.energy_function.state_dict(),
                           os.path.join(save_dir, 'epoch_{}.pt'.format(epoch + 1)))
                print('Saving checkpoint at epoch ', epoch + 1)
            print('==========================================')
            print('Epoch: {:03d}, Loss: {:.6f}, Energy Loss: {:.6f}, Regularizer Loss: {:.6f}, Sec/Epoch: {:.2f}'.format(
                epoch + 1, self._mean(losses), self._mean(losses_en), self._mean(losses_reg), t_end - t_start))
            print('==========================================')

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def train_rand_gen(self, loader, lr, wd, max_epochs, c, ld_step, ld_noise, ld_step_size, clamp, alpha, save_interval, save_dir):
        r"""
        Running training for random generation task.

        Args:
            loader: The data loader for loading training samples. It is supposed to use dig.ggraph.dataset.QM9/ZINC250k
                as the dataset class, wrapped in a data loader that yields dense ``x``/``adj`` batches.
            lr (float): The learning rate for training.
            wd (float): The weight decay factor for training.
            max_epochs (int): The maximum number of training epochs.
            c (float): The scaling hyperparameter for dequantization.
            ld_step (int): The number of iteration steps of Langevin dynamics.
            ld_noise (float): The standard deviation of the added noise in Langevin dynamics.
            ld_step_size (float): The step size of Langevin dynamics.
            clamp (bool): Whether to use gradient clamp in Langevin dynamics.
            alpha (float): The weight coefficient for loss function.
            save_interval (int): The frequency to save the model parameters to .pt files,
                *e.g.*, if save_interval=2, the model parameters will be saved for every 2 training epochs.
            save_dir (str): the directory to save the model parameters.
        """
        self._train(loader, lr, wd, max_epochs, c, ld_step, ld_noise, ld_step_size,
                    clamp, alpha, save_interval, save_dir, goal_directed=False)

    def run_rand_gen(self, checkpoint_path, n_samples, c, ld_step, ld_noise, ld_step_size, clamp, atomic_num_list):
        r"""
        Running graph generation for random generation task.

        Args:
            checkpoint_path (str): The path of the trained model, *i.e.*, the .pt file.
            n_samples (int): the number of molecules to generate.
            c (float): The scaling hyperparameter for dequantization.
            ld_step (int): The number of iteration steps of Langevin dynamics.
            ld_noise (float): The standard deviation of the added noise in Langevin dynamics.
            ld_step_size (float): The step size of Langevin dynamics.
            clamp (bool): Whether to use gradient clamp in Langevin dynamics.
            atomic_num_list (list): The list used to indicate atom types.

        :rtype:
            gen_mols (list): A list of generated molecules represented by rdkit Chem.Mol objects.
        """
        self._load_weights(self.energy_function, checkpoint_path)
        self._freeze(self.energy_function)

        ### Initialization
        print("Initializing samples...")
        gen_x = torch.rand(n_samples, self.n_atom_type, self.n_atom, device=self.device) * (1 + c)
        gen_adj = torch.rand(n_samples, self.n_edge_type, self.n_atom, self.n_atom, device=self.device)

        ### Langevin dynamics (descend the energy)
        print("Generating samples...")
        self._langevin_dynamics(self.energy_function, gen_x, gen_adj, c, ld_step, ld_noise,
                                -ld_step_size, clamp, grad_clamp=0.01)
        return self._decode(gen_x, gen_adj, atomic_num_list)

    def train_goal_directed(self, loader, lr, wd, max_epochs, c, ld_step, ld_noise, ld_step_size, clamp, alpha, save_interval, save_dir):
        r"""
        Running training for goal-directed generation task.

        Args:
            loader: The data loader for loading training samples. It is supposed to use dig.ggraph.dataset.QM9/ZINC250k
                as the dataset class, wrapped in a data loader that yields dense ``x``/``adj`` batches with property ``y``.
            lr (float): The learning rate for training.
            wd (float): The weight decay factor for training.
            max_epochs (int): The maximum number of training epochs.
            c (float): The scaling hyperparameter for dequantization.
            ld_step (int): The number of iteration steps of Langevin dynamics.
            ld_noise (float): The standard deviation of the added noise in Langevin dynamics.
            ld_step_size (float): The step size of Langevin dynamics.
            clamp (bool): Whether to use gradient clamp in Langevin dynamics.
            alpha (float): The weight coefficient for loss function.
            save_interval (int): The frequency to save the model parameters to .pt files,
                *e.g.*, if save_interval=2, the model parameters will be saved for every 2 training epochs.
            save_dir (str): the directory to save the model parameters.
        """
        self._train(loader, lr, wd, max_epochs, c, ld_step, ld_noise, ld_step_size,
                    clamp, alpha, save_interval, save_dir, goal_directed=True)

    def run_prop_opt(self, checkpoint_path, initialization_loader, c, ld_step, ld_noise, ld_step_size, clamp, atomic_num_list, train_smiles):
        r"""
        Running graph generation for goal-directed generation task: property optimization.

        Molecules are decoded after every Langevin step; each valid molecule that is
        not in ``train_smiles`` and whose QED score exceeds 0.930 is collected.

        Args:
            checkpoint_path (str): The path of the trained model, *i.e.*, the .pt file.
            initialization_loader: The data loader for loading samples to initialize the Langevin dynamics.
                It is supposed to use dig.ggraph.dataset.QM9/ZINC250k as the dataset class, wrapped in a data loader.
            c (float): The scaling hyperparameter for dequantization.
            ld_step (int): The number of iteration steps of Langevin dynamics.
            ld_noise (float): The standard deviation of the added noise in Langevin dynamics.
            ld_step_size (float): The step size of Langevin dynamics.
            clamp (bool): Whether to use gradient clamp in Langevin dynamics.
            atomic_num_list (list): The list used to indicate atom types.
            train_smiles (list): A list of smiles string corresponding to training samples.

        :rtype:
            save_mols_list (list), prop_list (list): save_mols_list is a list of generated molecules with high QED
            scores represented by rdkit Chem.Mol objects; prop_list is a list of the corresponding QED scores.
        """
        self._load_weights(self.energy_function, checkpoint_path)
        self._freeze(self.energy_function)
        known_smiles = set(train_smiles)  # O(1) novelty checks
        save_mols_list = []
        prop_list = []

        def collect(x, adj):
            # Decode the current Langevin state and keep novel, high-QED molecules.
            for mol in self._decode(x, adj, atomic_num_list):
                if mol is None or Chem.MolToSmiles(mol) in known_smiles:
                    continue
                tmp_qed = qed(mol)
                if tmp_qed > self._QED_THRESHOLD:
                    save_mols_list.append(mol)
                    prop_list.append(tmp_qed)

        for batch in tqdm(initialization_loader):
            ### Initialization from dataset molecules
            gen_x, gen_adj = self._batch_to_tensors(batch)
            ### Langevin dynamics, evaluating the samples after every step
            self._langevin_dynamics(self.energy_function, gen_x, gen_adj, c, ld_step, ld_noise,
                                    -ld_step_size, clamp, grad_clamp=0.01, on_step=collect)
        return save_mols_list, prop_list

    def run_const_prop_opt(self, checkpoint_path, initialization_loader, c, ld_step, ld_noise, ld_step_size, clamp, atomic_num_list, train_smiles):
        r"""
        Running graph generation for goal-directed generation task: constrained property optimization.

        Each starting molecule gets one result slot per similarity threshold
        (0.0, 0.2, 0.4, 0.6). After every Langevin step, a slot keeps the optimized
        molecule with the largest plogp improvement seen so far among those whose
        similarity to the starting molecule meets the threshold.

        Args:
            checkpoint_path (str): The path of the trained model, *i.e.*, the .pt file.
            initialization_loader: The data loader for loading samples to initialize the Langevin dynamics.
                It is supposed to use dig.ggraph.dataset.QM9/ZINC250k as the dataset class, wrapped in a data loader.
            c (float): The scaling hyperparameter for dequantization.
            ld_step (int): The number of iteration steps of Langevin dynamics.
            ld_noise (float): The standard deviation of the added noise in Langevin dynamics.
            ld_step_size (float): The step size of Langevin dynamics.
            clamp (bool): Whether to use gradient clamp in Langevin dynamics.
            atomic_num_list (list): The list used to indicate atom types.
            train_smiles (list): A list of smiles string corresponding to training samples (unused).

        :rtype:
            mols_0_list (list), mols_2_list (list), mols_4_list (list), mols_6_list (list),
            imp_0_list (list), imp_2_list (list), imp_4_list (list), imp_6_list (list):
            They are lists of optimized molecules (represented by rdkit Chem.Mol objects, or None when no
            improvement was found) and the corresponding improvements under the threshold 0.0, 0.2, 0.4, 0.6,
            respectively. Each list has at least 800 entries, one slot per initialization molecule in loader order.
        """
        self._load_weights(self.energy_function, checkpoint_path)
        self._freeze(self.energy_function)

        thresholds = self._SIM_THRESHOLDS
        best_mols = {t: [None] * self._CONST_OPT_SLOTS for t in thresholds}
        best_imps = {t: [0] * self._CONST_OPT_SLOTS for t in thresholds}

        offset = 0  # global slot index of the current batch's first molecule
        for batch in tqdm(initialization_loader):
            ### Initialization from dataset molecules
            gen_x, gen_adj = self._batch_to_tensors(batch)
            ori_mols = gen_mol_from_one_shot_tensor(gen_adj, gen_x, atomic_num_list, correct_validity=True)
            # The starting molecules never change, so score them once per batch.
            ori_plogps = [calculate_min_plogp(m) if m is not None else None for m in ori_mols]

            # Grow the result lists when the loader holds more molecules than slots.
            missing = offset + len(ori_mols) - len(best_mols[thresholds[0]])
            if missing > 0:
                for t in thresholds:
                    best_mols[t].extend([None] * missing)
                    best_imps[t].extend([0] * missing)

            def track(x, adj, ori_mols=ori_mols, ori_plogps=ori_plogps, offset=offset):
                # Record, per threshold, the best improvement found so far.
                gen_mols = self._decode(x, adj, atomic_num_list)
                for mol_idx, (tmp_mol, ori_mol, ori_plogp) in enumerate(zip(gen_mols, ori_mols, ori_plogps)):
                    if tmp_mol is None or ori_mol is None:
                        continue
                    imp_p = calculate_min_plogp(tmp_mol) - ori_plogp
                    current_sim = reward_target_molecule_similarity(tmp_mol, ori_mol)
                    slot = offset + mol_idx
                    for t in thresholds:
                        if current_sim >= t and imp_p > best_imps[t][slot]:
                            best_mols[t][slot] = tmp_mol
                            best_imps[t][slot] = imp_p

            ### Langevin dynamics, evaluating the samples after every step
            self._langevin_dynamics(self.energy_function, gen_x, gen_adj, c, ld_step, ld_noise,
                                    -ld_step_size, clamp, grad_clamp=0.1, on_step=track)
            offset += len(ori_mols)

        return (tuple(best_mols[t] for t in thresholds)
                + tuple(best_imps[t] for t in thresholds))

    def run_comp_gen(self, checkpoint_path_qed, checkpoint_path_plogp, n_samples, c, ld_step, ld_noise, ld_step_size, clamp, atomic_num_list):
        r"""
        Running graph generation for compositional generation task.

        The two energy functions are combined as ``0.5 * E_qed + 0.5 * E_plogp``.
        Note that the QED checkpoint is loaded into ``self.energy_function``.

        Args:
            checkpoint_path_qed (str): The path of the model trained on QED property, *i.e.*, the .pt file.
            checkpoint_path_plogp (str): The path of the model trained on plogp property, *i.e.*, the .pt file.
            n_samples (int): the number of molecules to generate.
            c (float): The scaling hyperparameter for dequantization.
            ld_step (int): The number of iteration steps of Langevin dynamics.
            ld_noise (float): The standard deviation of the added noise in Langevin dynamics.
            ld_step_size (float): The step size of Langevin dynamics.
            clamp (bool): Whether to use gradient clamp in Langevin dynamics.
            atomic_num_list (list): The list used to indicate atom types.

        :rtype:
            gen_mols (list): A list of generated molecules represented by rdkit Chem.Mol objects.
        """
        model_qed = self.energy_function
        model_plogp = copy.deepcopy(self.energy_function)
        self._load_weights(model_qed, checkpoint_path_qed)
        self._load_weights(model_plogp, checkpoint_path_plogp)
        self._freeze(model_qed)
        self._freeze(model_plogp)

        def energy(adj, x):
            return 0.5 * model_qed(adj, x) + 0.5 * model_plogp(adj, x)

        ### Initialization
        print("Initializing samples...")
        gen_x = torch.rand(n_samples, self.n_atom_type, self.n_atom, device=self.device) * (1 + c)
        gen_adj = torch.rand(n_samples, self.n_edge_type, self.n_atom, self.n_atom, device=self.device)

        ### Langevin dynamics (descend the combined energy)
        print("Generating samples...")
        self._langevin_dynamics(energy, gen_x, gen_adj, c, ld_step, ld_noise,
                                -ld_step_size, clamp, grad_clamp=0.01)
        return self._decode(gen_x, gen_adj, atomic_num_list)