Source code for dig.auggraph.method.SMixup.smixup

import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import time
import random
import numpy as np

from dig.auggraph.method.SMixup.model.GCN import GCN
from dig.auggraph.method.SMixup.model.GIN import GIN
from dig.auggraph.method.SMixup.model.GraphMatching import GraphMatching
from dig.auggraph.method.SMixup.utils.sinkhorn import Sinkhorn
from dig.auggraph.method.SMixup.utils.utils import NormalizedDegree, triplet_loss
from dig.auggraph.dataset.aug_dataset import TripleSet

from torch_geometric.datasets import TUDataset
from torch_geometric.utils import to_dense_adj, dense_to_sparse, degree
import torch_geometric.transforms as T
from torch_geometric.loader import DataLoader
from torch_geometric.data import Batch, Data
from torch.nn.functional import softmax

from sklearn.model_selection import KFold, train_test_split
from sklearn.metrics import accuracy_score

class smixup():
    r"""
    The S-Mixup method from the `"Graph Mixup with Soft Alignments"
    <https://icml.cc/virtual/2023/poster/24930>`_ paper.

    Args:
        data_root_path (string): Directory where datasets are saved.
        dataset (string): Dataset name.
        GMNET_conf (dict): Hyperparameters of the graph matching network
            (GMNET), which is used to compute the soft alignments.
    """
    def __init__(self, data_root_path, dataset, GMNET_conf):
        self._get_dataset(data_root_path, dataset)
        self.GMNET_conf = GMNET_conf
        self._get_GMNET(self.GMNET_conf['nlayers'], self.GMNET_conf['nhidden'],
                        self.GMNET_conf['bs'], self.GMNET_conf['lr'],
                        self.GMNET_conf['epochs'])

    def _get_GMNET(self, GMNET_nlayers, GMNET_hidden, GMNET_bs, GMNET_lr, GMNET_epochs):
        # Assemble the configuration of the graph matching network (GMNET).
        # Only conf['dis_param'] is consumed by the GraphMatching constructor.
        conf = {}
        conf_dis_param = {}
        conf_dis_param['num_layers'] = GMNET_nlayers
        conf_dis_param['hidden'] = GMNET_hidden
        conf_dis_param['model_type'] = 'gmnet'
        conf_dis_param['pool_type'] = 'sum'
        conf_dis_param['fuse_type'] = 'abs_diff'
        conf['dis_param'] = conf_dis_param
        conf['batch_size'] = GMNET_bs
        conf['start_lr'] = GMNET_lr
        conf['factor'] = 0.5
        conf['patience'] = 5
        conf['min_lr'] = 0.0000001
        conf['pre_train_path'] = None
        conf['max_num_epochs'] = GMNET_epochs
        conf['dis_param']['in_dim'] = self.dataset[0].x.shape[1]
        self.GMNET = GraphMatching(**conf['dis_param'])

    def _get_dataset(self, data_root_path, dataset):
        if dataset == 'IMDBB':
            dataset = TUDataset(data_root_path, name='IMDB-BINARY', use_node_attr=True)
            num_cls = 2
        elif dataset == 'MUTAG':
            dataset = TUDataset(data_root_path, name='MUTAG', use_node_attr=True)
            num_cls = 2
        elif dataset == 'PROTEINS':
            dataset = TUDataset(data_root_path, name=dataset, use_node_attr=True)
            num_cls = 2
        elif dataset == 'REDDITB':
            dataset = TUDataset(data_root_path, name='REDDIT-BINARY', use_node_attr=True)
            num_cls = 2
        elif dataset == 'IMDBM':
            dataset = TUDataset(data_root_path, name='IMDB-MULTI', use_node_attr=True)
            num_cls = 3
        elif dataset == 'REDDITM5':
            dataset = TUDataset(data_root_path, name='REDDIT-MULTI-5K', use_node_attr=True)
            num_cls = 5
        elif dataset == 'REDDITM12':
            dataset = TUDataset(data_root_path, name='REDDIT-MULTI-12K', use_node_attr=True)
            num_cls = 11
        elif dataset == 'NCI1':
            dataset = TUDataset(data_root_path, name='NCI1', use_node_attr=True)
            num_cls = 2
        else:
            raise ValueError('Unsupported dataset: {}'.format(dataset))

        if dataset.data.x is None:
            # Datasets without node features: use node degrees as features,
            # one-hot encoded when small, normalized otherwise.
            max_degree = 0
            degs = []
            for data in dataset:
                degs += [degree(data.edge_index[0], dtype=torch.long)]
                max_degree = max(max_degree, degs[-1].max().item())

            if max_degree < 1000:
                dataset.transform = T.OneHotDegree(max_degree)
            else:
                deg = torch.cat(degs, dim=0).to(torch.float)
                mean, std = deg.mean().item(), deg.std().item()
                dataset.transform = NormalizedDegree(mean, std)

        self.dataset = dataset
        self.num_cls = num_cls
    def train_test(self, batch_size, cls_model, cls_nlayers, cls_hidden,
                   cls_dropout, cls_lr, cls_epochs, alpha, ckpt_path,
                   sim_method='cos'):
        r"""
        This method first trains a GMNET and then uses it to perform S-Mixup
        while training the classifier, evaluated with 10-fold cross-validation.

        Args:
            batch_size (int): Batch size for training the classifier.
            cls_model (string): Whether to use GCN or GIN as the backbone of the classifier.
            cls_nlayers (int): Number of GNN layers of the classifier.
            cls_hidden (int): Number of hidden units of the classifier.
            cls_dropout (float): Dropout ratio of the classifier.
            cls_lr (float): Initial learning rate for training the classifier.
            cls_epochs (int): Number of training epochs of the classifier.
            alpha (float): Mixup ratio.
            ckpt_path (string): Location for saving checkpoints.
            sim_method (string): Similarity function used to compute the
                assignment matrix. (default: :obj:`cos`)
        """
        criterion = nn.CrossEntropyLoss()
        test_accs = []
        kf = KFold(n_splits=10, shuffle=True)
        for i, (train_idx, test_idx) in enumerate(kf.split(list(range(len(self.dataset))))):
            # Hold out 10% of the training fold for validation.
            train_idx, val_idx = train_test_split(train_idx, test_size=0.1)
            train_set, val_set, test_set = self.dataset[train_idx.tolist()], self.dataset[val_idx.tolist()], self.dataset[test_idx.tolist()]

            # Train the graph matching network on the current training fold.
            self.train_GMNET(train_set)

            train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=8)
            val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=True, num_workers=8)
            test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=8)

            if cls_model == 'GCN':
                model = GCN(self.dataset[0].x.shape[1], self.num_cls, cls_nlayers, cls_hidden, cls_dropout)
            elif cls_model == 'GIN':
                model = GIN(self.dataset[0].x.shape[1], self.num_cls, cls_nlayers, cls_hidden, cls_dropout)
            else:
                raise ValueError('Unsupported classifier backbone: {}'.format(cls_model))
            model.cuda()

            optimizer = optim.Adam(model.parameters(), lr=cls_lr, weight_decay=1e-5)

            best_acc = 0.0
            for epoch in range(1, cls_epochs + 1):
                loss_accum = 0.0
                for step, batch in enumerate(train_loader):
                    # Replace the batch with its mixed version; the loss
                    # interpolates the two label sets with weight lam.
                    batch = self.Mixup(batch, alpha=alpha, sim_method=sim_method)
                    model.train()
                    optimizer.zero_grad()
                    _, output = model(batch)
                    loss = batch.lam * criterion(output, batch.y1.long()) + (1 - batch.lam) * criterion(output, batch.y2.long())
                    loss.backward()
                    optimizer.step()
                    loss_accum += loss.item()
                train_loss = loss_accum / (step + 1)
                print("Epoch [{}] Train_loss {}".format(epoch, train_loss))

                # Validation pass; keep the checkpoint with the best accuracy.
                y_label = []
                y_pred = []
                for step, batch in enumerate(val_loader):
                    batch = batch.cuda()
                    model.eval()
                    _, output = model(batch)
                    pred = torch.argmax(output, dim=1).long()
                    y_pred = y_pred + pred.cpu().detach().numpy().flatten().tolist()
                    y = batch.y.long()
                    y_label = y_label + y.cpu().detach().numpy().flatten().tolist()
                acc_val = accuracy_score(y_label, y_pred)
                print("Epoch [{}] Validation results:".format(epoch), "acc_val: {:.4f}".format(acc_val))
                if acc_val >= best_acc:
                    best_acc = acc_val
                    torch.save(model.state_dict(), ckpt_path + "/best_val.pth")

            # Evaluate the best validation checkpoint on the test fold.
            model.load_state_dict(torch.load(ckpt_path + "/best_val.pth"))
            y_label = []
            y_pred = []
            for step, batch in enumerate(test_loader):
                batch = batch.cuda()
                model.eval()
                _, output = model(batch)
                pred = torch.argmax(output, dim=1).long()
                y_pred = y_pred + pred.cpu().detach().numpy().flatten().tolist()
                y = batch.y.long()
                y_label = y_label + y.cpu().detach().numpy().flatten().tolist()
            acc_test = accuracy_score(y_label, y_pred)
            print("Split {}: acc_test: {:.4f}".format(i, acc_test))
            test_accs.append(acc_test)

        print("Final result: acc_test: {:.4f}+-{:.4f}".format(np.mean(test_accs), np.std(test_accs)))
    def train_GMNET(self, train_set):
        self.GMNET.cuda()
        # Build (anchor, positive, negative) triples and train the matching
        # network with a triplet loss.
        train_set = TripleSet(train_set)
        train_loader = DataLoader(train_set, batch_size=self.GMNET_conf['bs'], shuffle=True, num_workers=8)
        optimizer = optim.Adam(self.GMNET.parameters(), lr=self.GMNET_conf['lr'], weight_decay=1e-4)
        for epoch in range(1, self.GMNET_conf['epochs'] + 1):
            print("====epoch {} ====".format(epoch))
            self.GMNET.train()
            train_loss = 0.0
            for data_batch in train_loader:
                anchor_data, pos_data, neg_data = data_batch
                anchor_data, pos_data, neg_data = anchor_data.cuda(), pos_data.cuda(), neg_data.cuda()
                optimizer.zero_grad()
                x_1, y = self.GMNET(anchor_data, pos_data, pred_head=False)
                x_2, z = self.GMNET(anchor_data, neg_data, pred_head=False)
                loss = triplet_loss(x_1, y, x_2, z)
                loss = torch.mean(loss)
                loss.backward()
                optimizer.step()
                train_loss += loss.item()
            print("Epoch [{}] Train_loss {}".format(epoch, train_loss / len(train_loader)))
        print("GMNET training done.")
        self.GMNET.eval()

    def Mixup(self, batch, alpha, sim_method='cos', normalize_method='softmax', temperature=1.0):
        # Sample the mixup ratio lam from a Beta(alpha, alpha) distribution
        # and keep lam >= 0.5 so the first graph dominates the mixture.
        if alpha > 0:
            lam = np.random.beta(alpha, alpha)
        else:
            lam = 0.5
        lam = max(lam, 1 - lam)

        batch = batch.cuda()
        batch1 = batch.clone()
        data_list = list(batch1.to_data_list())
        # Pair each graph with a randomly chosen partner from the same batch.
        data_list2 = data_list.copy()
        random.shuffle(data_list2)
        batch2 = Batch.from_data_list(data_list2).cuda()

        # Node embeddings from the trained GMNET encoder.
        h1, h2 = self.GMNET.dis_encoder(batch1, batch2, node_emd=True)
        h1, h2 = h1.detach(), h2.detach()
        for i in range(len(data_list)):
            data_list[i].emb = h1[batch1._slice_dict['x'][i]:batch1._slice_dict['x'][i + 1], :]
            data_list2[i].emb = h2[batch2._slice_dict['x'][i]:batch2._slice_dict['x'][i + 1], :]

        mixed_data_list = []
        for i in range(len(data_list)):
            # Pairwise node similarities between the two graphs.
            if sim_method == 'cos':
                emb1 = data_list[i].emb / data_list[i].emb.norm(dim=1)[:, None]
                emb2 = data_list2[i].emb / data_list2[i].emb.norm(dim=1)[:, None]
                match = emb1 @ emb2.T / temperature
            elif sim_method == 'abs_diff':
                match = -(data_list[i].emb.unsqueeze(1) - data_list2[i].emb.unsqueeze(0)).norm(dim=-1)

            # Normalize the similarities into a soft assignment matrix.
            if normalize_method == 'softmax':
                normalized_match = softmax(match.detach().clone(), dim=0)
            elif normalize_method == 'sinkhorn':
                normalized_match = Sinkhorn(match.detach().clone())

            # Mix adjacency matrices: the partner graph's adjacency is mapped
            # into the first graph's node space via the soft assignment.
            # max_num_nodes guards against shape mismatches when a graph has
            # trailing isolated nodes.
            mixed_adj = lam * to_dense_adj(data_list[i].edge_index, max_num_nodes=data_list[i].num_nodes)[0].double() \
                        + (1 - lam) * normalized_match.double() @ to_dense_adj(data_list2[i].edge_index, max_num_nodes=data_list2[i].num_nodes)[0].double() @ normalized_match.double().T
            mixed_adj[mixed_adj < 0.1] = 0  # sparsify: drop weak edges
            mixed_x = lam * data_list[i].x + (1 - lam) * normalized_match.float() @ data_list2[i].x
            edge_index, edge_weights = dense_to_sparse(mixed_adj)
            data = Data(x=mixed_x.float(), edge_index=edge_index, edge_weights=edge_weights,
                        y1=data_list[i].y, y2=data_list2[i].y)
            mixed_data_list.append(data)

        b = Batch.from_data_list(mixed_data_list)
        b.lam = lam
        return b
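
# A minimal usage sketch, not part of the library itself. The GMNET_conf keys
# below are the ones consumed by __init__ above; the concrete hyperparameter
# values, the data root, and the checkpoint path are illustrative assumptions.
#
#     from dig.auggraph.method.SMixup.smixup import smixup
#
#     GMNET_conf = {'nlayers': 3, 'nhidden': 64, 'bs': 32,
#                   'lr': 0.001, 'epochs': 20}
#     runner = smixup('./datasets', 'PROTEINS', GMNET_conf)
#     runner.train_test(batch_size=32, cls_model='GCN', cls_nlayers=4,
#                       cls_hidden=64, cls_dropout=0.5, cls_lr=0.001,
#                       cls_epochs=100, alpha=1.0, ckpt_path='./checkpoints',
#                       sim_method='cos')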