# Source code for dig.xgraph.dataset.syn_dataset

import os
import os.path as osp
import pickle
import shutil

import numpy as np
import torch
import tqdm
from torch_geometric.data import Data, InMemoryDataset, download_url
from torch_geometric.data.dataset import files_exist
from torch_geometric.utils import dense_to_sparse


def read_ba2motif_data(folder: str, prefix: str):
    """Read the pickled BA-2Motifs data and convert it to a list of PyG ``Data`` objects."""
    with open(os.path.join(folder, f"{prefix}.pkl"), 'rb') as f:
        dense_edges, node_features, graph_labels = pickle.load(f)

    data_list = []
    for graph_idx in range(dense_edges.shape[0]):
        # Convert the dense adjacency matrix of each graph to a sparse edge index.
        edge_index = dense_to_sparse(torch.from_numpy(dense_edges[graph_idx]))[0]
        # Graph labels are stored one-hot; np.where recovers the class index.
        data_list.append(Data(x=torch.from_numpy(node_features[graph_idx]).float(),
                              edge_index=edge_index,
                              y=torch.from_numpy(np.where(graph_labels[graph_idx])[0])))
    return data_list
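
# A minimal usage sketch (illustrative, not part of the module): the folder below is an
# assumed location for a previously downloaded BA_2Motifs.pkl, not a path fixed by this code.
#
#     data_list = read_ba2motif_data('./datasets/ba_2motifs/raw', 'BA_2Motifs')
#     data = data_list[0]  # a PyG Data object with x, edge_index, and a class label y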


class SynGraphDataset(InMemoryDataset):
    r"""
    The synthetic datasets used in `Parameterized Explainer for Graph Neural Network
    <https://arxiv.org/abs/2011.04573>`_. Each dataset takes a Barabási–Albert (BA) graph or
    a balanced tree as the base graph and randomly attaches specific motifs to it.

    Args:
        root (:obj:`str`): Root data directory in which to save the datasets.
        name (:obj:`str`): The name of the dataset, one of :obj:`BA_shapes`,
            :obj:`BA_Community`, :obj:`Tree_Grid`, :obj:`Tree_Cycle`, or :obj:`BA_2Motifs`.
        transform (:obj:`Callable`, :obj:`None`): A function/transform that takes in an
            :class:`torch_geometric.data.Data` object and returns a transformed version.
            The data object will be transformed before every access. (default: :obj:`None`)
        pre_transform (:obj:`Callable`, :obj:`None`): A function/transform that takes in an
            :class:`torch_geometric.data.Data` object and returns a transformed version.
            The data object will be transformed before being saved to disk. (default: :obj:`None`)

    """

    url = 'https://github.com/divelab/DIG_storage/raw/main/xgraph/datasets/{}'
    # Format: name: [display_name, url_name, filename]
    names = {
        'ba_shapes': ['BA_shapes', 'BA_shapes.pkl', 'BA_shapes'],
        'ba_community': ['BA_Community', 'BA_Community.pkl', 'BA_Community'],
        'tree_grid': ['Tree_Grid', 'Tree_Grid.pkl', 'Tree_Grid'],
        'tree_cycle': ['Tree_Cycle', 'Tree_Cycles.pkl', 'Tree_Cycles'],
        'ba_2motifs': ['BA_2Motifs', 'BA_2Motifs.pkl', 'BA_2Motifs'],
    }

    def __init__(self, root, name, transform=None, pre_transform=None):
        self.name = name.lower()
        super(SynGraphDataset, self).__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_dir(self):
        return osp.join(self.root, self.name, 'raw')

    @property
    def processed_dir(self):
        return osp.join(self.root, self.name, 'processed')

    @property
    def raw_file_names(self):
        return f'{self.names[self.name][2]}.pkl'

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        url = self.url.format(self.names[self.name][1])
        download_url(url, self.raw_dir)

    def process(self):
        if self.name == 'ba_2motifs':
            # BA_2Motifs is a graph classification dataset: one ``Data`` object per graph.
            data_list = read_ba2motif_data(self.raw_dir, self.names[self.name][2])

            if self.pre_filter is not None:
                data_list = [data for data in data_list if self.pre_filter(data)]
            if self.pre_transform is not None:
                data_list = [self.pre_transform(data) for data in data_list]
        else:
            # The remaining datasets are node classification tasks on a single large graph.
            data = self.read_syn_data()
            data = data if self.pre_transform is None else self.pre_transform(data)
            data_list = [data]

        torch.save(self.collate(data_list), self.processed_paths[0])

    def __repr__(self):
        return '{}({})'.format(self.names[self.name][0], len(self))

    def gen_motif_edge_mask(self, data, node_idx=0, num_hops=3):
        if self.name in ['ba_2motifs']:
            # Each BA_2Motifs graph has 20 base BA nodes; nodes 20-24 form the motif,
            # so motif edges are exactly those whose endpoints are both >= 20.
            return torch.logical_and(data.edge_index[0] >= 20, data.edge_index[1] >= 20)
        elif self.name in ['ba_shapes', 'ba_community', 'tree_grid', 'tree_cycle']:
            # Expand from node_idx in a loop to fetch all nodes in the connected motif.
            if data.y[node_idx] == 0:
                # Label 0 marks base-graph nodes, which belong to no motif.
                return torch.zeros_like(data.edge_index[0]).type(torch.bool)

            connected_motif_nodes = set()
            edge_label_matrix = data.edge_label_matrix + data.edge_label_matrix.T
            edge_index = data.edge_index.to('cpu')
            if isinstance(node_idx, torch.Tensor):
                connected_motif_nodes.add(node_idx.item())
            else:
                connected_motif_nodes.add(node_idx)
            for _ in range(num_hops):
                append_node = set()
                for node in connected_motif_nodes:
                    append_node.update(tuple(torch.where(edge_label_matrix[node] != 0)[0].tolist()))
                connected_motif_nodes.update(append_node)
            connected_motif_nodes_tensor = torch.Tensor(list(connected_motif_nodes))
            # An edge is a motif edge iff both endpoints lie among the collected motif nodes.
            frm_mask = (edge_index[0].unsqueeze(1) - connected_motif_nodes_tensor.unsqueeze(0) == 0).any(dim=1)
            to_mask = (edge_index[1].unsqueeze(1) - connected_motif_nodes_tensor.unsqueeze(0) == 0).any(dim=1)
            return torch.logical_and(frm_mask, to_mask)

    def read_syn_data(self):
        with open(self.raw_paths[0], 'rb') as f:
            adj, features, y_train, y_val, y_test, train_mask, val_mask, test_mask, edge_label_matrix = pickle.load(f)

        x = torch.from_numpy(features).float()
        # The one-hot label arrays are disjoint across the split masks; merge them
        # and recover the class indices with np.where.
        y = train_mask.reshape(-1, 1) * y_train + val_mask.reshape(-1, 1) * y_val + test_mask.reshape(-1, 1) * y_test
        y = torch.from_numpy(np.where(y)[1])
        edge_index = dense_to_sparse(torch.from_numpy(adj))[0]
        data = Data(x=x, y=y, edge_index=edge_index)
        data.train_mask = torch.from_numpy(train_mask)
        data.val_mask = torch.from_numpy(val_mask)
        data.test_mask = torch.from_numpy(test_mask)
        data.edge_label_matrix = torch.from_numpy(edge_label_matrix)
        return data
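
# A minimal usage sketch (illustrative assumptions: the root path './datasets' and the
# node index 300, which in BA_shapes is the first motif node after the 300 base nodes):
#
#     dataset = SynGraphDataset(root='./datasets', name='BA_shapes')
#     data = dataset[0]                                           # one large graph
#     mask = dataset.gen_motif_edge_mask(data, node_idx=300)      # bool mask over edges
#     motif_edges = data.edge_index[:, mask]                      # edges inside the motif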


class BA_LRP(InMemoryDataset):
    r"""
    The synthetic graph classification dataset used in `Higher-Order Explanations of Graph
    Neural Networks via Relevant Walks <https://arxiv.org/abs/2006.03589>`_.

    The first class in :class:`~BA_LRP` is the Barabási–Albert (BA) graph, which connects each
    new node :math:`\mathcal{V}` to the current graph :math:`\mathcal{G}` with the preferential
    attachment probability

    .. math:: p(\mathcal{V}) = \frac{Degree(\mathcal{V})}{\sum_{\mathcal{V}' \in \mathcal{G}} Degree(\mathcal{V}')}

    The second class grows slightly faster (every fifth new node attaches with two edges) and
    picks attachment targets without replacement under the inverse preferential attachment model

    .. math:: p(\mathcal{V}) = \frac{Degree(\mathcal{V})^{-1}}{\sum_{\mathcal{V}' \in \mathcal{G}} Degree(\mathcal{V}')^{-1}}

    Args:
        root (:obj:`str`): Root data directory in which to save the dataset.
        num_per_class (:obj:`int`): The number of graphs generated for each class.
        transform (:obj:`Callable`, :obj:`None`): A function/transform that takes in an
            :class:`torch_geometric.data.Data` object and returns a transformed version.
            The data object will be transformed before every access. (default: :obj:`None`)
        pre_transform (:obj:`Callable`, :obj:`None`): A function/transform that takes in an
            :class:`torch_geometric.data.Data` object and returns a transformed version.
            The data object will be transformed before being saved to disk. (default: :obj:`None`)

    .. note:: :class:`~BA_LRP` will automatically generate the dataset
        if the dataset file does not exist in the root directory.

    Example:
        >>> dataset = BA_LRP(root='./datasets')
        >>> loader = DataLoader(dataset, batch_size=32)
        >>> data = next(iter(loader))
        # Batch(batch=[640], edge_index=[2, 1344], x=[640, 1], y=[32, 1])

    The attributes of the batched data are:

    - :obj:`batch`: The assignment vector mapping each node to its graph index
    - :obj:`x`: The node features
    - :obj:`edge_index`: The edge matrix
    - :obj:`y`: The graph label

    """

    url = 'https://github.com/divelab/DIG_storage/raw/main/xgraph/datasets/ba_lrp.pt'

    def __init__(self, root, num_per_class=10000, transform=None, pre_transform=None):
        self.name = 'ba_lrp'
        self.num_per_class = num_per_class
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_dir(self):
        return osp.join(self.root, self.name, 'raw')

    @property
    def processed_dir(self):
        return osp.join(self.root, self.name, 'processed')

    @property
    def raw_file_names(self):
        return ['raw.pt']

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        path = download_url(self.url, self.raw_dir)
        # shutil.move(path, path.replace('ba_lrp_old.pt', 'raw.pt'))
        data_list = torch.load(path)
        pyg_data_list = []
        for data in data_list:
            # Re-wrap each stored record as a PyG ``Data`` object.
            pyg_data_list.append(Data(x=data['x'], edge_index=data['edge_index'], y=data['y']))

        data, slices = self.collate(pyg_data_list)
        torch.save((data, slices), self.raw_paths[0])

    @staticmethod
    def gen_class1():
        # Class 0: standard BA growth with preferential attachment.
        # The graph label y is fixed at construction; it is a single per-graph label.
        x = torch.tensor([[1], [1]], dtype=torch.float)
        edge_index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long)
        data = Data(x=x, edge_index=edge_index, y=torch.tensor([[0]], dtype=torch.float))

        for i in range(2, 20):
            data.x = torch.cat([data.x, torch.tensor([[1]], dtype=torch.float)], dim=0)
            # Attachment probability is proportional to the current node degrees.
            deg = torch.stack([(data.edge_index[0] == node_idx).float().sum() for node_idx in range(i)], dim=0)
            sum_deg = deg.sum(dim=0, keepdim=True)
            probs = (deg / sum_deg).unsqueeze(0)
            prob_dist = torch.distributions.Categorical(probs)
            node_pick = prob_dist.sample().squeeze()
            data.edge_index = torch.cat([data.edge_index,
                                         torch.tensor([[node_pick, i], [i, node_pick]], dtype=torch.long)], dim=1)

        return data

    @staticmethod
    def gen_class2():
        # Class 1: inverse preferential attachment; low-degree targets are favored.
        x = torch.tensor([[1], [1]], dtype=torch.float)
        edge_index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long)
        data = Data(x=x, edge_index=edge_index, y=torch.tensor([[1]], dtype=torch.float))
        epsilon = 1e-30

        for i in range(2, 20):
            data.x = torch.cat([data.x, torch.tensor([[1]], dtype=torch.float)], dim=0)
            # Attachment probability is proportional to the reciprocal of the degrees.
            deg_reciprocal = torch.stack([1 / ((data.edge_index[0] == node_idx).float().sum() + epsilon)
                                          for node_idx in range(i)], dim=0)
            sum_deg_reciprocal = deg_reciprocal.sum(dim=0, keepdim=True)
            probs = (deg_reciprocal / sum_deg_reciprocal).unsqueeze(0)
            prob_dist = torch.distributions.Categorical(probs)
            node_pick = -1
            # Every fifth node attaches with two edges; resample so the same target
            # is not picked twice (selection without replacement).
            for _ in range(1 if i % 5 != 4 else 2):
                new_node_pick = prob_dist.sample().squeeze()
                while new_node_pick == node_pick:
                    new_node_pick = prob_dist.sample().squeeze()
                node_pick = new_node_pick
                data.edge_index = torch.cat([data.edge_index,
                                             torch.tensor([[node_pick, i], [i, node_pick]], dtype=torch.long)], dim=1)

        return data

    def process(self):
        if files_exist(self.raw_paths):
            # The downloaded raw file is already a collated (data, slices) pair; reuse it.
            shutil.copyfile(self.raw_paths[0], self.processed_paths[0])
            return

        # Otherwise generate num_per_class graphs per class from scratch.
        data_list = []
        for _ in tqdm.tqdm(range(self.num_per_class)):
            data_list.append(self.gen_class1())
            data_list.append(self.gen_class2())

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
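
# A minimal usage sketch mirroring the docstring example (assumptions: the root path
# './datasets' is illustrative, and DataLoader is torch_geometric's loader, which lives
# in torch_geometric.loader in recent versions and torch_geometric.data in older ones):
#
#     from torch_geometric.loader import DataLoader
#     dataset = BA_LRP(root='./datasets', num_per_class=10000)
#     loader = DataLoader(dataset, batch_size=32, shuffle=True)
#     batch = next(iter(loader))
#     # batch.y has shape [32, 1]: one binary label per generated graph.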


if __name__ == '__main__':
    # lrp_dataset = BA_LRP(root='.', num_per_class=10000)
    syn_dataset = SynGraphDataset(root='.', name='BA_Community')