Source code for dig.fairgraph.method.run

from .Graphair import graphair,aug_module,GCN,GCN_Body,Classifier

import time

[docs]class run(): r""" This class instantiates Graphair model and implements method to train and evaluate. """ def __init__(self): pass
[docs] def run(self,device,dataset,model='Graphair',epochs=10_000,test_epochs=1_000, lr=1e-4,weight_decay=1e-5): r""" This method runs training and evaluation for a fairgraph model on the given dataset. Check :obj:`examples.fairgraph.Graphair.run_graphair_nba.py` for examples on how to run the Graphair model. :param device: Device for computation. :type device: :obj:`torch.device` :param model: Defaults to `Graphair`. (Note that at this moment, only `Graphair` is supported) :type model: str, optional :param dataset: The dataset to train on. Should be one of :obj:`dig.fairgraph.dataset.fairgraph_dataset.POKEC` or :obj:`dig.fairgraph.dataset.fairgraph_dataset.NBA`. :type dataset: :obj:`object` :param epochs: Number of epochs to train on. Defaults to 10_000. :type epochs: int, optional :param test_epochs: Number of epochs to train the classifier while running evaluation. Defaults to 1_000. :type test_epochs: int,optional :param lr: Learning rate. Defaults to 1e-4. :type lr: float,optional :param weight_decay: Weight decay factor for regularization. Defaults to 1e-5. :type weight_decay: float, optional :raise: :obj:`Exception` when model is not Graphair. At this moment, only Graphair is supported. """ # Train script dataset_name = dataset.name features = dataset.features sens = dataset.sens adj = dataset.adj idx_sens = dataset.idx_sens_train # generate model if model=='Graphair': aug_model = aug_module(features, n_hidden=64, temperature=1).to(device) f_encoder = GCN_Body(in_feats = features.shape[1], n_hidden = 64, out_feats = 64, dropout = 0.1, nlayer = 2).to(device) sens_model = GCN(in_feats = features.shape[1], n_hidden = 64, out_feats = 64, nclass = 1).to(device) classifier_model = Classifier(input_dim=64,hidden_dim=64) model = graphair(aug_model=aug_model,f_encoder=f_encoder,sens_model=sens_model,classifier_model=classifier_model, lr=lr,weight_decay=weight_decay,dataset=dataset_name).to(device) else: raise Exception('At this moment, only Graphair is supported!') # call fit_whole st_time = time.time() model.fit_whole(epochs=epochs,adj=adj, x=features,sens=sens,idx_sens = idx_sens,warmup=0, adv_epoches=1) print("Training time: ", time.time() - st_time) # Test script model.test(adj=adj,features=features,labels=dataset.labels,epochs=test_epochs,idx_train=dataset.idx_train,idx_val=dataset.idx_val,idx_test=dataset.idx_test,sens=sens)