# Source code for dig.auggraph.method.GraphAug.runner_reward_gen

# Author: Youzhi Luo (yzluo@tamu.edu)
# Updated by: Anmol Anand (aanand@tamu.edu)

import os
import torch
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from torch_geometric.datasets import TUDataset
from .model import RewardGenModel
from dig.auggraph.dataset.aug_dataset import DegreeTrans, TripleSet
from dig.auggraph.method.GraphAug.constants import *


class RunnerRewardGen(object):
    r"""
    Runs the training of a reward generation model which will be able to
    distinguish between graphs with different labels. Check
    :obj:`examples.auggraph.GraphAug.run_reward_gen` for examples on how to
    run the reward generation model.

    Args:
        data_root_path (string): Directory where datasets should be saved.
        dataset_name (:class:`dig.auggraph.method.GraphAug.constants.enums.DatasetName`):
            Name of the graph dataset.
        conf (dict): Hyperparameters for the model. Check
            :obj:`examples.auggraph.GraphAug.conf.reward_gen_conf` for examples
            on how to define the conf dictionary for the reward generator.
    """

    def __init__(self, data_root_path, dataset_name, conf):
        self.conf = conf
        # NOTE: _get_dataset must run before _get_model — it writes the input
        # dimension into conf[REWARD_GEN_PARAMS], which the model reads.
        self._get_dataset(data_root_path, dataset_name)
        self.model = self._get_model()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.dataset_name = dataset_name

    def _get_dataset(self, data_root_path, dataset_name):
        """Load the TU dataset and build train/validation triple sets.

        Datasets without native node features (COLLAB, IMDB-BINARY) get
        degree-based features via :class:`DegreeTrans`. Also records the node
        feature dimension in ``conf[REWARD_GEN_PARAMS][IN_DIMENSION]``.

        Raises:
            ValueError: If ``dataset_name`` is not one of the supported datasets.
        """
        dataset = TUDataset(data_root_path, name=dataset_name.value)
        if dataset_name in [DatasetName.NCI1, DatasetName.MUTAG, DatasetName.PROTEINS, DatasetName.NCI109]:
            self.train_set = TripleSet(dataset)
            self.val_set = TripleSet(dataset)
        elif dataset_name in [DatasetName.COLLAB, DatasetName.IMDB_BINARY]:
            self.train_set = TripleSet(dataset, transform=DegreeTrans(dataset))
            self.val_set = TripleSet(dataset, transform=DegreeTrans(dataset))
        else:
            # Fail fast with a clear message instead of a later AttributeError
            # on self.train_set.
            raise ValueError('Unsupported dataset for reward generation: {}'.format(dataset_name))
        self.conf[REWARD_GEN_PARAMS][IN_DIMENSION] = self.train_set[0][0].x.shape[1]

    def _get_model(self):
        """Instantiate the reward generation model from the conf parameters."""
        return RewardGenModel(**self.conf[REWARD_GEN_PARAMS])

    def _train_epoch(self, loader, optimizer):
        """Run one training epoch over triples (anchor, positive, negative).

        The model is trained to output 1 for (anchor, positive) pairs and 0
        for (anchor, negative) pairs via binary cross-entropy.
        """
        self.model.train()
        for data_batch in loader:
            anchor_data, pos_data, neg_data = data_batch
            anchor_data, pos_data, neg_data = anchor_data.to(self.device), pos_data.to(self.device), neg_data.to(self.device)
            optimizer.zero_grad()
            pos_out = self.model(anchor_data, pos_data).view(-1)
            pos_loss = F.binary_cross_entropy(pos_out, torch.ones_like(pos_out))
            neg_out = self.model(anchor_data, neg_data).view(-1)
            neg_loss = F.binary_cross_entropy(neg_out, torch.zeros_like(neg_out))
            loss = pos_loss + neg_loss
            loss.backward()
            optimizer.step()

    def test(self, loader):
        """Evaluate pair-classification accuracy on ``loader``.

        Returns:
            tuple: (overall accuracy, positive-pair accuracy, negative-pair
            accuracy). A positive pair is correct when the model output
            exceeds 0.5; a negative pair when it is below 0.5.
        """
        self.model.eval()
        num_correct, num_pos_correct, num_neg_correct = 0, 0, 0
        with torch.no_grad():
            for data_batch in loader:
                anchor_data, pos_data, neg_data = data_batch
                anchor_data, pos_data, neg_data = anchor_data.to(self.device), pos_data.to(self.device), neg_data.to(self.device)
                output = self.model(anchor_data, pos_data)
                pred = (output.view(-1) > 0.5).long()
                num_correct += pred.sum().item()
                num_pos_correct += pred.sum().item()
                output = self.model(anchor_data, neg_data)
                pred = (output.view(-1) < 0.5).long()
                num_correct += pred.sum().item()
                num_neg_correct += pred.sum().item()
        # Each batch contributes one positive and one negative pair per
        # anchor, hence the factor of 2 in the overall denominator.
        return num_correct / (2 * len(loader.dataset)), num_pos_correct / len(loader.dataset), num_neg_correct / len(loader.dataset)

    def train_test(self, results_path, num_save=30):
        r"""
        This method is used to run the training for the reward generation
        model and validate the epoch parameters.

        Args:
            results_path (string): Directory where the resulting optimal
                parameters of the reward generation model will be saved.
            num_save (int): Number of final epochs for which model parameters
                will be saved.
        """
        self.model = self.model.to(self.device)
        out_path = os.path.join(results_path, self.dataset_name.value)
        # exist_ok avoids the racy isdir-then-create pattern.
        os.makedirs(out_path, exist_ok=True)
        model_dir_name = self.conf[REWARD_GEN_PARAMS][MODEL_TYPE]
        model_dir = os.path.join(out_path, model_dir_name.value)
        os.makedirs(model_dir, exist_ok=True)
        log_file = 'record.txt'
        # Context managers guarantee the log file is closed even if a write fails.
        with open(os.path.join(out_path, log_file), 'a') as f:
            f.write('Reward generator classification results for dataset {} with model parameters {}\n'.format(
                self.dataset_name, self.conf[REWARD_GEN_PARAMS]))
        train_loader = DataLoader(self.train_set, batch_size=self.conf[BATCH_SIZE], shuffle=True, num_workers=16)
        val_loader = DataLoader(self.val_set, batch_size=self.conf[BATCH_SIZE], shuffle=True)
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.conf[INITIAL_LR], weight_decay=1e-4)
        for epoch in range(self.conf[MAX_NUM_EPOCHS]):
            self._train_epoch(train_loader, optimizer)
            # Only checkpoint the final num_save epochs to limit disk usage.
            if model_dir_name is not None and epoch > self.conf[MAX_NUM_EPOCHS] - num_save:
                torch.save(self.model.state_dict(), os.path.join(model_dir, '{}.pt'.format(str(epoch).zfill(4))))
            val_acc, val_pos_acc, val_neg_acc = self.test(val_loader)
            print('Epoch {}, validation accuracy {}, accuracy of positive samples {}, accuracy of negative samples {}'.format(epoch, val_acc, val_pos_acc, val_neg_acc))
            with open(os.path.join(out_path, log_file), 'a') as f:
                f.write('Epoch {}, validation accuracy {}, accuracy of positive samples {}, accuracy of negative samples {}\n'.format(epoch, val_acc, val_pos_acc, val_neg_acc))