Source code for dig.auggraph.method.GraphAug.runner_generator

# Author: Youzhi Luo (yzluo@tamu.edu)
# Updated by: Anmol Anand (aanand@tamu.edu)

import os
import torch
from torch_geometric.loader import DataLoader
from torch_geometric.data import Batch
from torch_geometric.datasets import TUDataset
from .model import RewardGenModel
from dig.auggraph.dataset.aug_dataset import Subset, DegreeTrans
from .aug import Augmenter
from dig.auggraph.method.GraphAug.constants import *


class RunnerGenerator(object):
    r"""
    Runs the training of an augmented sample generator model, which uses an
    already trained reward generation model. For a given graph, the
    generator produces an augmented sample together with a likelihood that
    the augmentation is label invariant. The reward generation model then
    evaluates this sample, and a loss computed from these quantities is
    minimized through training. Check
    :obj:`examples.auggraph.GraphAug.run_generator` for examples on how to
    run the generator model.

    Args:
        data_root_path (string): Directory where datasets should be saved.
        dataset_name (:class:`dig.auggraph.method.GraphAug.constants.enums.DatasetName`):
            Name of the graph dataset.
        conf (dict): Hyperparameters for the model. Check
            :obj:`examples.auggraph.GraphAug.conf.generator_conf` for
            examples on how to define the conf dictionary for the
            generator.
    """

    def __init__(self, data_root_path, dataset_name, conf):
        self.conf = conf
        self._get_dataset(data_root_path, dataset_name)
        self._get_model()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.dataset_name = dataset_name
        if conf[BASELINE] == BaselineType.EXP:
            # Exponential moving average of the rewards, initialized lazily
            # on the first training batch.
            self.moving_mean = None

    def _get_dataset(self, data_root_path, dataset_name):
        dataset = TUDataset(data_root_path, name=dataset_name.value)
        if dataset_name in [DatasetName.NCI1, DatasetName.MUTAG, DatasetName.AIDS,
                            DatasetName.NCI109, DatasetName.PROTEINS]:
            self.train_set = Subset(dataset)
            self.val_set = Subset(dataset)
        elif dataset_name in [DatasetName.COLLAB, DatasetName.IMDB_BINARY]:
            # These datasets have no node features, so node degrees are used
            # as input features.
            self.train_set = Subset(dataset, transform=DegreeTrans(dataset))
            self.val_set = Subset(dataset, transform=DegreeTrans(dataset))
            self.post_trans = DegreeTrans(dataset)
        in_dim = self.val_set[0].x.shape[1]
        self.conf[REWARD_GEN_PARAMS][IN_DIMENSION] = in_dim
        self.conf[GENERATOR_PARAMS][IN_DIMENSION] = in_dim
        if AugType.NODE_FM.value in self.conf[GENERATOR_PARAMS][AUG_TYPE_PARAMS]:
            self.conf[GENERATOR_PARAMS][AUG_TYPE_PARAMS][AugType.NODE_FM.value][NODE_FEAT_DIM] = in_dim

    def _get_model(self):
        # The reward generator is restored from a pretrained checkpoint and
        # kept frozen in eval mode; only the augmenter is trained.
        self.reward_generator = RewardGenModel(**self.conf[REWARD_GEN_PARAMS])
        if self.conf[REWARD_GEN_STATE_PATH] is not None:
            self.reward_generator.load_state_dict(torch.load(self.conf[REWARD_GEN_STATE_PATH]))
        self.reward_generator.eval()
        self.generator = Augmenter(**self.conf[GENERATOR_PARAMS])

    def _train_epoch(self, loader, optimizer):
        self.generator.train()
        for data_batch in loader:
            anchor_data = data_batch.to(self.device)
            g_loss_sum, reward_avg = 0, 0
            for _ in range(self.conf[GENERATOR_STEPS]):
                optimizer.zero_grad()
                neg_data, log_likelihoods = self.generator(anchor_data)
                if hasattr(self, 'post_trans'):
                    neg_data = self.post_trans(neg_data)
                neg_out = self.reward_generator(anchor_data, neg_data).view(-1)
                rewards = torch.log(neg_out)
                if self.conf[BASELINE] == BaselineType.MEAN:
                    baseline = torch.mean(rewards)
                    advantages = (rewards - baseline).detach()
                elif self.conf[BASELINE] == BaselineType.EXP:
                    new_mean = torch.mean(rewards)
                    if self.moving_mean is None:
                        self.moving_mean = new_mean.cpu().detach().item()
                    else:
                        new_mean = new_mean.cpu().detach().item()
                        self.moving_mean = self.conf[MOVING_RATIO] * self.moving_mean + (
                            1.0 - self.conf[MOVING_RATIO]) * new_mean
                    advantages = (rewards - self.moving_mean).detach()
                elif isinstance(self.conf[BASELINE], float):
                    advantages = (rewards - self.conf[BASELINE]).detach()
                else:
                    advantages = rewards.detach()
                # REINFORCE-style loss: maximize the advantage-weighted
                # log-likelihood of the sampled augmentations.
                g_loss = -torch.sum(log_likelihoods.view(-1) * advantages)
                if torch.isnan(g_loss) or torch.isinf(g_loss):
                    # Dump the current generator state for debugging before
                    # aborting the run.
                    torch.save(self.generator.state_dict(), os.path.join(self.out_path, 'buggy_model.pt'))
                    exit()
                g_loss.backward()
                optimizer.step()
                g_loss_sum += g_loss.detach().item()
                reward_avg += torch.mean(rewards).detach().item()
            print('Generator loss {}, Average reward {}'.format(
                g_loss_sum, reward_avg / self.conf[GENERATOR_STEPS]))

    def test(self):
        self.generator.eval()
        num_correct, total_rewards = 0, 0.0
        for data in self.val_set:
            data = data.to(self.device)
            aug_data, _ = self.generator(data)
            data1, data2 = Batch.from_data_list([data]), Batch.from_data_list([aug_data])
            output = self.reward_generator(data1, data2)
            # The reward generator predicts the probability that the
            # augmented graph keeps the original label.
            pred = (output.view(-1) > 0.5).long()
            num_correct += pred.sum().item()
            total_rewards += torch.log(output).view(-1)[0].item()
        return num_correct / len(self.val_set), total_rewards / len(self.val_set)
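
    # A note on the objective in _train_epoch: the update is a
    # REINFORCE-style policy gradient. The frozen reward generator outputs
    # a probability p that an augmented sample preserves the label, the
    # reward is log(p), and the loss
    #     g_loss = -sum(log_likelihoods * advantages)
    # increases the likelihood of augmentations with positive advantage.
    # For instance, with rewards log([0.9, 0.5, 0.1]) ~= [-0.105, -0.693,
    # -2.303], BaselineType.MEAN subtracts the batch mean (~= -1.034),
    # giving advantages ~= [0.929, 0.341, -1.269], so the below-average
    # third sample is actively discouraged rather than merely reinforced
    # less. BaselineType.EXP instead subtracts a running mean updated as
    #     m <- MOVING_RATIO * m + (1 - MOVING_RATIO) * batch_mean,
    # a float baseline subtracts that constant, and any other setting uses
    # the raw rewards. Subtracting a baseline reduces the variance of the
    # gradient estimate without changing its expectation.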

    def train_test(self, results_path):
        r"""
        Runs the training for the augmented sample generator and
        periodically validates it, optionally saving the generator
        parameters after each validation.

        Args:
            results_path (string): Directory where the resulting optimal
                parameters of the generator model will be saved.
        """
        self.reward_generator = self.reward_generator.to(self.device)
        self.generator = self.generator.to(self.device)
        out_path = os.path.join(results_path, self.dataset_name.value)
        if not os.path.isdir(out_path):
            os.makedirs(out_path)
        model_path = os.path.join(out_path, 'max_augs_{}'.format(self.conf[GENERATOR_PARAMS][MAX_NUM_AUG]))
        if not os.path.isdir(model_path):
            os.mkdir(model_path)
        self.out_path = out_path

        # Write a header describing the augmentation configuration to the
        # validation record file.
        val_record_file = 'val_record.txt'
        with open(os.path.join(model_path, val_record_file), 'a') as f:
            f.write('Generator results for dataset {} using augmentation below\n'.format(self.dataset_name))
            for aug_type in self.conf[GENERATOR_PARAMS][AUG_TYPE_PARAMS]:
                f.write('{}: {}\n'.format(aug_type, self.conf[GENERATOR_PARAMS][AUG_TYPE_PARAMS][aug_type]))

        train_loader = DataLoader(self.train_set, batch_size=self.conf[BATCH_SIZE],
                                  shuffle=True, num_workers=8)
        optimizer = torch.optim.Adam(self.generator.parameters(),
                                     lr=self.conf[INITIAL_LR], weight_decay=1e-4)

        for epoch in range(self.conf[MAX_NUM_EPOCHS]):
            self._train_epoch(train_loader, optimizer)
            if epoch % self.conf[TEST_INTERVAL] == 0:
                val_acc, avg_reward = self.test()
                print('Epoch {}, validation accuracy {}, average reward {}'.format(epoch, val_acc, avg_reward))
                # Append the epoch record to the same validation record file
                # that received the header above.
                with open(os.path.join(model_path, val_record_file), 'a') as f:
                    f.write('Epoch {}, validation accuracy {}, average reward {}\n'.format(epoch, val_acc, avg_reward))
                if self.conf[SAVE_MODEL]:
                    torch.save(self.generator.state_dict(),
                               os.path.join(model_path, '{}.pt'.format(str(epoch).zfill(4))))
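

# A minimal usage sketch, assuming the example configuration module
# referenced in the class docstring exposes a `generator_conf` dict; the
# dataset choice and the paths below are placeholders.
if __name__ == '__main__':
    # Assumed import: see examples.auggraph.GraphAug.conf.generator_conf
    # for how to define the conf dictionary.
    from examples.auggraph.GraphAug.conf.generator_conf import generator_conf

    runner = RunnerGenerator(data_root_path='./datasets',
                             dataset_name=DatasetName.NCI1,
                             conf=generator_conf)
    runner.train_test(results_path='./results')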