# Source code for dig.threedgraph.method.run


import time
import os
import torch
from torch.optim import Adam
from torch_geometric.data import DataLoader
import numpy as np
from torch.autograd import grad
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm

class run():
    r"""
    The base script for running different 3DGN methods.

    The class holds no state: :meth:`run` drives the whole
    train / validate / test loop and delegates the per-epoch work to
    :meth:`train` and :meth:`val`.
    """

    def __init__(self):
        pass

    def run(self, device, train_dataset, valid_dataset, test_dataset, model, loss_func, evaluation,
            epochs=500, batch_size=32, vt_batch_size=32, lr=0.0005, lr_decay_factor=0.5,
            lr_decay_step_size=50, weight_decay=0, energy_and_force=False, p=100,
            save_dir='', log_dir=''):
        r"""
        The run script for training and validation.

        Every epoch trains on ``train_dataset``, evaluates on ``valid_dataset`` and
        ``test_dataset``, and remembers the test result obtained at the epoch with the
        lowest validation result seen so far.

        Args:
            device (torch.device): Device for computation.
            train_dataset: Training data.
            valid_dataset: Validation data.
            test_dataset: Test data.
            model: Which 3DGN model to use. Should be one of the SchNet, DimeNetPP, and SphereNet.
            loss_func (function): The loss function used for training.
            evaluation (function): The evaluation function.
            epochs (int, optional): Number of total training epochs. (default: :obj:`500`)
            batch_size (int, optional): Number of samples in each minibatch in training. (default: :obj:`32`)
            vt_batch_size (int, optional): Number of samples in each minibatch in validation/testing. (default: :obj:`32`)
            lr (float, optional): Initial learning rate. (default: :obj:`0.0005`)
            lr_decay_factor (float, optional): Learning rate decay factor. (default: :obj:`0.5`)
            lr_decay_step_size (int, optional): Number of epochs after which the learning rate is multiplied by ``lr_decay_factor``. (default: :obj:`50`)
            weight_decay (float, optional): Weight decay factor of the regularization term. (default: :obj:`0`)
            energy_and_force (bool, optional): If set to :obj:`True`, will predict energy and take the minus derivative of the energy with respect to the atomic positions as predicted forces. (default: :obj:`False`)
            p (int, optional): The forces' weight for a joint loss of forces and conserved energy during training. (default: :obj:`100`)
            save_dir (str, optional): The path to save trained models. If set to :obj:`''`, will not save the model. (default: :obj:`''`)
            log_dir (str, optional): The path to save log files. If set to :obj:`''`, will not save the log files. (default: :obj:`''`)
        """
        model = model.to(device)
        num_params = sum(param.numel() for param in model.parameters())
        print(f'#Params: {num_params}')

        optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
        scheduler = StepLR(optimizer, step_size=lr_decay_step_size, gamma=lr_decay_factor)

        train_loader = DataLoader(train_dataset, batch_size, shuffle=True)
        valid_loader = DataLoader(valid_dataset, vt_batch_size, shuffle=False)
        test_loader = DataLoader(test_dataset, vt_batch_size, shuffle=False)

        best_valid = float('inf')
        best_test = float('inf')

        # exist_ok avoids the check-then-create race of exists() + makedirs().
        if save_dir != '':
            os.makedirs(save_dir, exist_ok=True)
        writer = None
        if log_dir != '':
            os.makedirs(log_dir, exist_ok=True)
            writer = SummaryWriter(log_dir=log_dir)

        try:
            for epoch in range(1, epochs + 1):
                print("\n=====Epoch {}".format(epoch), flush=True)

                print('\nTraining...', flush=True)
                train_mae = self.train(model, optimizer, train_loader, energy_and_force, p, loss_func, device)

                print('\n\nEvaluating...', flush=True)
                valid_mae = self.val(model, valid_loader, energy_and_force, p, evaluation, device)

                print('\n\nTesting...', flush=True)
                test_mae = self.val(model, test_loader, energy_and_force, p, evaluation, device)

                print()
                print({'Train': train_mae, 'Validation': valid_mae, 'Test': test_mae})

                if writer is not None:
                    writer.add_scalar('train_mae', train_mae, epoch)
                    writer.add_scalar('valid_mae', valid_mae, epoch)
                    writer.add_scalar('test_mae', test_mae, epoch)

                # Model selection is done on the validation result only; the test
                # result is merely recorded for the selected epoch.
                if valid_mae < best_valid:
                    best_valid = valid_mae
                    best_test = test_mae
                    if save_dir != '':
                        print('Saving checkpoint...')
                        checkpoint = {
                            'epoch': epoch,
                            'model_state_dict': model.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict(),
                            'scheduler_state_dict': scheduler.state_dict(),
                            'best_valid_mae': best_valid,
                            'num_params': num_params,
                        }
                        torch.save(checkpoint, os.path.join(save_dir, 'valid_checkpoint.pt'))

                scheduler.step()

            print(f'Best validation MAE so far: {best_valid}')
            print(f'Test MAE when got best validation result: {best_test}')
        finally:
            # Close the event file even when an epoch raises, so already
            # written scalars are flushed and the file handle is released.
            if writer is not None:
                writer.close()

    def train(self, model, optimizer, train_loader, energy_and_force, p, loss_func, device):
        r"""
        The script for training.

        Runs one pass over ``train_loader``, taking one optimizer step per batch.

        Args:
            model: Which 3DGN model to use. Should be one of the SchNet, DimeNetPP, and SphereNet.
            optimizer (Optimizer): Pytorch optimizer for trainable parameters in training.
            train_loader (Dataloader): Dataloader for training.
            energy_and_force (bool): If set to :obj:`True`, will predict energy and take the minus derivative of the energy with respect to the atomic positions as predicted forces.
            p (int): The forces' weight for a joint loss of forces and conserved energy during training.
            loss_func (function): The loss function used for training.
            device (torch.device): The device where the model is deployed.

        :rtype: Training loss averaged over the batches ( :obj:`mae`). NaN when
            ``train_loader`` yields no batch.
        """
        model.train()
        loss_accum = 0
        num_batches = 0
        for batch_data in tqdm(train_loader):
            optimizer.zero_grad()
            batch_data = batch_data.to(device)
            out = model(batch_data)
            if energy_and_force:
                # create_graph=True keeps the force differentiable so that the
                # force loss can be back-propagated to the model parameters.
                force = -grad(outputs=out, inputs=batch_data.pos, grad_outputs=torch.ones_like(out),
                              create_graph=True, retain_graph=True)[0]
                e_loss = loss_func(out, batch_data.y.unsqueeze(1))
                f_loss = loss_func(force, batch_data.force)
                loss = e_loss + p * f_loss
            else:
                loss = loss_func(out, batch_data.y.unsqueeze(1))
            loss.backward()
            optimizer.step()
            loss_accum += loss.detach().cpu().item()
            num_batches += 1

        if num_batches == 0:
            # The mean over zero batches is undefined; report NaN instead of
            # failing on an unbound loop variable.
            return float('nan')
        return loss_accum / num_batches

    def val(self, model, data_loader, energy_and_force, p, evaluation, device):
        r"""
        The script for validation/test.

        Args:
            model: Which 3DGN model to use. Should be one of the SchNet, DimeNetPP, and SphereNet.
            data_loader (Dataloader): Dataloader for validation or test.
            energy_and_force (bool): If set to :obj:`True`, will predict energy and take the minus derivative of the energy with respect to the atomic positions as predicted forces.
            p (int): The forces' weight for a joint loss of forces and conserved energy.
            evaluation (function): The function used for evaluation.
            device (torch.device): The device where the model is deployed.

        :rtype: Evaluation result ( :obj:`mae`). With ``energy_and_force`` this is
            ``energy_mae + p * force_mae``.
        """
        model.eval()

        # Chunks are collected in lists and concatenated once at the end.
        # Each list starts with an empty float tensor so that an empty loader
        # still produces empty tensors and the dtype promotion of the
        # concatenation is the same as when starting from torch.Tensor([]).
        preds = [torch.Tensor([]).to(device)]
        targets = [torch.Tensor([]).to(device)]
        preds_force = [torch.Tensor([]).to(device)]
        targets_force = [torch.Tensor([]).to(device)]

        # Autograd is only needed to differentiate the energy w.r.t. the
        # positions; for energy-only evaluation no graph is recorded at all.
        with torch.set_grad_enabled(bool(energy_and_force)):
            for batch_data in tqdm(data_loader):
                batch_data = batch_data.to(device)
                out = model(batch_data)
                if energy_and_force:
                    # First-order gradient only: nothing is back-propagated
                    # through the force here, so the graph is neither created
                    # for the gradient nor retained afterwards.
                    force = -grad(outputs=out, inputs=batch_data.pos, grad_outputs=torch.ones_like(out))[0]
                    preds_force.append(force.detach())
                    targets_force.append(batch_data.force)
                preds.append(out.detach())
                targets.append(batch_data.y.unsqueeze(1))

        input_dict = {"y_true": torch.cat(targets, dim=0), "y_pred": torch.cat(preds, dim=0)}

        if energy_and_force:
            input_dict_force = {"y_true": torch.cat(targets_force, dim=0),
                                "y_pred": torch.cat(preds_force, dim=0)}
            energy_mae = evaluation.eval(input_dict)['mae']
            force_mae = evaluation.eval(input_dict_force)['mae']
            print({'Energy MAE': energy_mae, 'Force MAE': force_mae})
            return energy_mae + p * force_mae

        return evaluation.eval(input_dict)['mae']