Source code for dig.sslgraph.dataset.datasets

import re

from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures

from .TUDataset import TUDatasetExt
from .feat_expansion import FeatureExpander, CatDegOnehot, get_max_deg


def get_dataset(name, task, feat_str="deg", root=None):
    r"""A pre-implemented function to retrieve graph datasets from TUDataset.
    Depending on the evaluation task, different node feature augmentations are
    applied following `GraphCL <https://arxiv.org/abs/2010.13902>`_.

    Args:
        name (string): The `name <https://chrsmrrs.github.io/datasets/docs/datasets/>`_
            of the dataset.
        task (string): The evaluation task. Either 'semisupervised' or
            'unsupervised'.
        feat_str (string, optional): The node feature augmentations to be applied,
            *e.g.*, degrees and centrality. (default: :obj:`deg`)
        root (string, optional): Root directory where the dataset should be saved.
            (default: :obj:`None`)

    :rtype: :class:`torch_geometric.data.Dataset` (unsupervised), or
        (:class:`torch_geometric.data.Dataset`,
        :class:`torch_geometric.data.Dataset`) (semisupervised).

    Examples
    --------
    >>> dataset, dataset_pretrain = get_dataset("NCI1", "semisupervised")
    >>> dataset
    NCI1(4110)

    >>> dataset = get_dataset("MUTAG", "unsupervised", feat_str="")
    >>> dataset  # degree not augmented as node attributes
    MUTAG(188)
    """
    root = "." if root is None else root

    if task == "semisupervised":
        # Large or dense datasets use a smaller one-hot degree cap.
        if name in ['REDDIT-BINARY', 'REDDIT-MULTI-5K', 'REDDIT-MULTI-12K']:
            feat_str = feat_str.replace('odeg100', 'odeg10')
        if name in ['DD']:
            feat_str = feat_str.replace('odeg100', 'odeg10')
            feat_str = feat_str.replace('ak3', 'ak1')

        degree = feat_str.find("deg") >= 0
        onehot_maxdeg = re.findall(r"odeg(\d+)", feat_str)
        onehot_maxdeg = int(onehot_maxdeg[0]) if onehot_maxdeg else None

        pre_transform = FeatureExpander(degree=degree,
                                        onehot_maxdeg=onehot_maxdeg,
                                        AK=0).transform

        dataset = TUDatasetExt(root + "/semi_dataset/dataset", name, task,
                               pre_transform=pre_transform, use_node_attr=True,
                               processed_filename="data_%s.pt" % feat_str)
        dataset_pretrain = TUDatasetExt(root + "/semi_dataset/pretrain_dataset/",
                                        name, task,
                                        pre_transform=pre_transform,
                                        use_node_attr=True,
                                        processed_filename="data_%s.pt" % feat_str)

        dataset.data.edge_attr = None
        dataset_pretrain.data.edge_attr = None

        return dataset, dataset_pretrain

    elif task == "unsupervised":
        dataset = TUDatasetExt(root + "/unsuper_dataset/", name=name, task=task)
        if feat_str.find("deg") >= 0:
            # Reload with one-hot degree features appended to node attributes.
            max_degree = get_max_deg(dataset)
            dataset = TUDatasetExt(root + "/unsuper_dataset/", name=name, task=task,
                                   transform=CatDegOnehot(max_degree),
                                   use_node_attr=True)
        return dataset

    else:
        raise ValueError("Wrong task name")
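
A brief usage sketch (not part of the module): the returned datasets can be wrapped in PyTorch Geometric data loaders for pre-training and fine-tuning. The dataset name and batch size below are illustrative assumptions, and in older PyG versions ``DataLoader`` lives in ``torch_geometric.data`` rather than ``torch_geometric.loader``.

    from torch_geometric.loader import DataLoader

    # Hypothetical usage: NCI1 for semi-supervised evaluation; batch_size=128
    # is an arbitrary illustrative choice, not a value set by this module.
    dataset, dataset_pretrain = get_dataset("NCI1", "semisupervised", feat_str="deg")
    pretrain_loader = DataLoader(dataset_pretrain, batch_size=128, shuffle=True)
    finetune_loader = DataLoader(dataset, batch_size=128, shuffle=True)

    for batch in pretrain_loader:
        # Each batch is a torch_geometric.data.Batch with x, edge_index, y,
        # and a batch vector mapping nodes to their graphs.
        pass
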
def get_node_dataset(name, norm_feat=False, root=None):
    r"""A pre-implemented function to retrieve node datasets from Planetoid.

    Args:
        name (string): The name of the dataset (:obj:`"Cora"`,
            :obj:`"CiteSeer"`, :obj:`"PubMed"`).
        norm_feat (bool, optional): Whether to normalize node features.
            (default: :obj:`False`)
        root (string, optional): Root directory where the dataset should be saved.
            (default: :obj:`None`)

    :rtype: :class:`torch_geometric.data.Dataset`

    Example
    -------
    >>> dataset = get_node_dataset("Cora")
    >>> dataset
    Cora()
    """
    root = "." if root is None else root
    transform = NormalizeFeatures() if norm_feat else None
    full_dataset = Planetoid(root + "/node_dataset/", name, transform=transform)

    return full_dataset
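
A brief usage sketch (not part of the module), assuming the standard Planetoid data fields exposed by PyTorch Geometric:

    # Hypothetical usage: load Cora with normalized features and inspect the
    # single graph it contains.
    dataset = get_node_dataset("Cora", norm_feat=True)
    data = dataset[0]  # one torch_geometric.data.Data object
    print(data.num_nodes, data.num_edges, dataset.num_classes)
    print(data.train_mask.sum().item())  # Planetoid provides train/val/test masks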