# Author: Youzhi Luo (yzluo@tamu.edu)
# Updated by: Anmol Anand (aanand@tamu.edu)
import os
import torch
import torch.nn.functional as F
import numpy as np
from torch_geometric.loader import DataLoader
from torch_geometric.datasets import TUDataset
from sklearn.model_selection import KFold, train_test_split
from dig.auggraph.method.GraphAug.model import GIN, GCN
from dig.auggraph.dataset.aug_dataset import DegreeTrans, Subset, AUG_trans
from dig.auggraph.method.GraphAug.aug import Augmenter
from dig.auggraph.method.GraphAug.constants import *
class RunnerAugCls(object):
    r"""
    Runs the training of a graph classification model using the augmented
    data generated by the already trained generator. Check
    :obj:`examples.auggraph.GraphAug.run_aug_cls` for examples on how to run
    this augmented classifier model.

    Args:
        data_root_path (string): Directory where datasets should be saved.
        dataset_name (:class:`dig.auggraph.method.GraphAug.constants.enums.DatasetName`):
            Name of the graph dataset.
        conf (dict): Hyperparameters for the model. Check
            :obj:`examples.auggraph.GraphAug.conf.aug_cls_conf` for examples on
            how to define the conf dictionary for the augmented classifier
            model.
    """
def __init__(self, data_root_path, dataset_name, conf):
self.conf = conf
self.dataset = self._get_dataset(data_root_path, dataset_name)
self.augmenter = self._get_aug_model()
self.model = self._get_model()
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.train_data_trans = AUG_trans(self.augmenter, self.device, pre_trans=self.data_trans,
post_trans=self.data_trans)
self.dataset_name = dataset_name
def _get_dataset(self, data_root_path, dataset_name):
dataset = TUDataset(data_root_path, name=dataset_name.value)
if dataset_name in [DatasetName.MUTAG]:
self.data_trans = None
self.conf[IN_DIMENSION] = dataset[0].x.shape[1]
self.conf[EDGE_IN_DIMENSION] = dataset[0].x.shape[1]
if dataset_name in [DatasetName.NCI1, DatasetName.NCI109, DatasetName.PROTEINS]:
self.data_trans = None
self.conf[IN_DIMENSION] = dataset[0].x.shape[1]
elif dataset_name in [DatasetName.COLLAB, DatasetName.IMDB_BINARY]:
self.data_trans = DegreeTrans(dataset)
self.conf[IN_DIMENSION] = self.data_trans(dataset[0]).x.shape[1]
self.conf[NUM_CLASSES] = dataset.num_classes
return dataset
def _get_aug_model(self):
in_dim = self.conf[IN_DIMENSION]
self.conf[GENERATOR_PARAMS][IN_DIMENSION] = in_dim
if AugType.NODE_FM.value in self.conf[GENERATOR_PARAMS][AUG_TYPE_PARAMS]:
self.conf[GENERATOR_PARAMS][AUG_TYPE_PARAMS][AugType.NODE_FM.value][NODE_FEAT_DIM] = in_dim
augmenter = Augmenter(**self.conf[GENERATOR_PARAMS])
if self.conf[AUG_MODEL_PATH] is not None:
augmenter.load_state_dict(torch.load(self.conf[AUG_MODEL_PATH], map_location=torch.device('cpu')))
augmenter.eval()
return augmenter
def _get_model(self):
if self.conf[MODEL_NAME] == CLSModelType.GIN:
return GIN(self.conf[IN_DIMENSION], self.conf[NUM_CLASSES], self.conf[NUM_LAYERS],
self.conf[HIDDEN_UNITS], self.conf[DROPOUT])
elif self.conf[MODEL_NAME] == CLSModelType.GCN:
return GCN(self.conf[IN_DIMENSION], self.conf[NUM_CLASSES], self.conf[NUM_LAYERS], self.conf[HIDDEN_UNITS],
self.conf[DROPOUT])
def _train_epoch(self, loader, optimizer):
self.model.train()
for data_batch in loader:
data_batch = data_batch.to(self.device)
optimizer.zero_grad()
try:
output = self.model(data_batch)
except:
print(data_batch.x.shape, data_batch.edge_index.shape)
print(data_batch.batch)
exit()
loss = F.nll_loss(output, data_batch.y)
loss.backward()
optimizer.step()
def test(self, loader):
self.model.eval()
num_correct = 0
for data_batch in loader:
data_batch = data_batch.to(self.device)
output = self.model(data_batch)
pred = output.max(dim=1)[1]
num_correct += pred.eq(data_batch.y).sum().item()
return num_correct / len(loader.dataset)
[docs] def train_test(self, out_root_path, log_file='record.txt'):
r"""
This method is used to run the training for the classification model
on the augmented graph dataset and then validate the epoch parameters.
Args:
out_root_path (string): Directory where the results of this
augmented classifier model will be saved.
log_file (string): File where training and validation logs are
written.
"""
val_accs, test_accs = [], []
kf = KFold(n_splits=10, shuffle=True)
self.dataset.shuffle()
self.model = self.model.to(self.device)
out_path = os.path.join(out_root_path, self.dataset_name.value)
if not os.path.isdir(out_path):
os.makedirs(out_path)
f = open(os.path.join(out_path, log_file), 'a')
f.write('10-CV results for dataset {} with model {}, num layers {}, hidden {}\n'.format(self.dataset_name,
self.conf[MODEL_NAME].value,
self.conf[NUM_LAYERS],
self.conf[HIDDEN_UNITS]))
f.write('Use the learnable augmentation with params below\n')
for aug_type in self.conf[GENERATOR_PARAMS][AUG_TYPE_PARAMS]:
f.write('{}: {}\n'.format(aug_type, self.conf[GENERATOR_PARAMS][AUG_TYPE_PARAMS][aug_type]))
f.close()
for i, (train_idx, test_idx) in enumerate(kf.split(list(range(len(self.dataset))))):
train_idx, val_idx = train_test_split(train_idx, test_size=0.1)
train_set, val_set, test_set = Subset(self.dataset[train_idx.tolist()], transform=self.train_data_trans), \
Subset(self.dataset[val_idx.tolist()], transform=self.data_trans), Subset(
self.dataset[test_idx.tolist()], transform=self.data_trans)
train_loader = DataLoader(train_set, batch_size=self.conf[BATCH_SIZE], shuffle=True, num_workers=16)
val_loader = DataLoader(val_set, batch_size=self.conf[BATCH_SIZE], shuffle=True)
test_loader = DataLoader(test_set, batch_size=self.conf[BATCH_SIZE], shuffle=True)
self.model.reset_parameters()
optimizer = torch.optim.Adam(self.model.parameters(), lr=self.conf[INITIAL_LR])
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
factor=self.conf[FACTOR],
patience=self.conf[PATIENCE],
min_lr=self.conf[MIN_LR])
best_val_acc, best_test_acc = 0.0, 0.0
for epoch in range(self.conf[MAX_NUM_EPOCHS]):
lr = scheduler.optimizer.param_groups[0]['lr']
self._train_epoch(train_loader, optimizer)
val_acc = self.test(val_loader)
print('Epoch {}, validation accuracy {}'.format(epoch, val_acc))
test_acc = self.test(test_loader)
scheduler.step(val_acc)
if val_acc > best_val_acc:
best_val_acc = val_acc
if test_acc > best_test_acc:
best_test_acc = test_acc
if lr < self.conf[MIN_LR]:
break
val_accs.append(best_val_acc)
test_accs.append(best_test_acc)
f = open(os.path.join(out_path, log_file), 'a')
f.write('Split {}, validation accuracy {}, test accuracy {}\n'.format(i, best_val_acc, best_test_acc))
f.close()
f = open(os.path.join(out_path, log_file), 'a')
f.write('Validation accuracy mean {}, std {}\n'.format(np.mean(val_accs), np.std(val_accs)))
f.write('Test accuracy mean {}, std {}\n'.format(np.mean(test_accs), np.std(test_accs)))
f.close()