"""Source code for dig.xgraph.method.subgraphx."""

import os
import math
import copy
import torch
import numpy as np
import networkx as nx
from rdkit import Chem
from torch import Tensor
from textwrap import wrap
from functools import partial
from collections import Counter
from typing import List, Tuple, Dict
from torch_geometric.data import Batch, Data
from torch_geometric.utils import to_networkx
from typing import Callable, Union, Optional
import matplotlib.pyplot as plt
from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.datasets import MoleculeNet
from torch_geometric.utils import remove_self_loops
from .shapley import GnnNetsGC2valueFunc, GnnNetsNC2valueFunc, \
    gnn_score, mc_shapley, l_shapley, mc_l_shapley, NC_mc_l_shapley, sparsity


def find_closest_node_result(results, max_nodes):
    """Return the highest-reward tree node whose coalition has at most ``max_nodes`` nodes.

    Args:
        results: Non-empty iterable of tree nodes, each exposing ``coalition``
            (a sized collection of node indices) and ``P`` (its reward).
        max_nodes: Maximum allowed coalition size. ``None`` means no size
            limit, so the highest-reward node overall is returned.

    Returns:
        The eligible node with the largest ``P``; ties go to the node with the
        smaller coalition. If no node satisfies the size limit, the node with
        the smallest coalition is returned as a fallback.

    Raises:
        IndexError: if ``results`` is empty.
    """
    # Sorting by size (stable) makes ties in ``P`` resolve to the smaller
    # coalition, because ``max`` keeps the first maximal element it meets.
    results = sorted(results, key=lambda x: len(x.coalition))
    if max_nodes is None:
        candidates = results
    else:
        candidates = [x for x in results if len(x.coalition) <= max_nodes]
    if not candidates:
        # Nothing fits the limit: fall back to the smallest explanation.
        return results[0]
    return max(candidates, key=lambda x: x.P)


def reward_func(reward_method, value_func, node_idx=None,
                local_radius=4, sample_num=100,
                subgraph_building_method='zero_filling'):
    """Return the reward function used to score MCTS coalitions.

    The result is one of the scorers imported from ``.shapley`` with the
    relevant hyper-parameters bound through :func:`functools.partial`.

    Args:
        reward_method (str): Case-insensitive name, one of ``'gnn_score'``,
            ``'mc_shapley'``, ``'l_shapley'``, ``'mc_l_shapley'`` or
            ``'nc_mc_l_shapley'``.
        value_func (Callable): Value function of the model being explained.
        node_idx (int, optional): Target node; required by ``'nc_mc_l_shapley'``.
        local_radius (int): Neighbourhood radius for the ``l_shapley`` variants.
        sample_num (int): Number of Monte Carlo samples for the sampling variants.
        subgraph_building_method (str): How a coalition is turned into a graph
            (e.g. ``'zero_filling'``).

    Raises:
        AssertionError: ``'nc_mc_l_shapley'`` was requested without ``node_idx``.
        NotImplementedError: ``reward_method`` is not recognised.
    """
    method = reward_method.lower()

    if method == 'gnn_score':
        return partial(gnn_score,
                       value_func=value_func,
                       subgraph_building_method=subgraph_building_method)

    if method == 'mc_shapley':
        return partial(mc_shapley,
                       value_func=value_func,
                       subgraph_building_method=subgraph_building_method,
                       sample_num=sample_num)

    if method == 'l_shapley':
        return partial(l_shapley,
                       local_radius=local_radius,
                       value_func=value_func,
                       subgraph_building_method=subgraph_building_method)

    if method == 'mc_l_shapley':
        return partial(mc_l_shapley,
                       local_radius=local_radius,
                       value_func=value_func,
                       subgraph_building_method=subgraph_building_method,
                       sample_num=sample_num)

    if method == 'nc_mc_l_shapley':
        assert node_idx is not None, " Wrong node idx input "
        return partial(NC_mc_l_shapley,
                       node_idx=node_idx,
                       local_radius=local_radius,
                       value_func=value_func,
                       subgraph_building_method=subgraph_building_method,
                       sample_num=sample_num)

    raise NotImplementedError(f"Unknown reward method: {reward_method!r}")


def k_hop_subgraph_with_default_whole_graph(
        edge_index, node_idx=None, num_hops=3, relabel_nodes=False,
        num_nodes=None, flow='source_to_target'):
    r"""Computes the :math:`k`-hop subgraph of :obj:`edge_index` around node
    :attr:`node_idx`.
    It returns (1) the nodes involved in the subgraph, (2) the filtered
    :obj:`edge_index` connectivity, (3) the mapping from node indices in
    :obj:`node_idx` to their new location, and (4) the edge mask indicating
    which edges were preserved.

    If :attr:`node_idx` is :obj:`None`, the expansion starts from node ``0``
    and repeats without a hop limit until no new node is added. For an
    undirected :obj:`edge_index` this is the connected component of node
    ``0`` (the whole graph only when the graph is connected), and the
    returned mapping (3) is :obj:`None`.

    Args:
        edge_index (LongTensor): The edge indices.
        node_idx (int, list, tuple or :obj:`torch.Tensor`, optional): The
            central node(s). (default: :obj:`None`)
        num_hops: (int): The number of hops :math:`k`. Ignored when
            :attr:`node_idx` is :obj:`None`.
        relabel_nodes (bool, optional): If set to :obj:`True`, the resulting
            :obj:`edge_index` will be relabeled to hold consecutive indices
            starting from zero. (default: :obj:`False`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)
        flow (string, optional): The flow direction of :math:`k`-hop
            aggregation (:obj:`"source_to_target"` or
            :obj:`"target_to_source"`). (default: :obj:`"source_to_target"`)
    :rtype: (:class:`LongTensor`, :class:`LongTensor`, :class:`LongTensor`,
             :class:`BoolTensor`)
    """
    if num_nodes is None:
        # maybe_num_nodes returns an explicit num_nodes unchanged, so it is
        # only needed when the caller did not supply one.
        num_nodes = maybe_num_nodes(edge_index, num_nodes)

    assert flow in ['source_to_target', 'target_to_source']
    if flow == 'target_to_source':
        row, col = edge_index
    else:
        col, row = edge_index  # edge_index 0 to 1, col: source, row: target

    node_mask = row.new_empty(num_nodes, dtype=torch.bool)
    edge_mask = row.new_empty(row.size(0), dtype=torch.bool)

    inv = None

    if node_idx is None:
        # Grow from node 0 until the node set stops changing. The seed lives
        # on edge_index's device so the torch.cat below also works on GPU.
        subset = torch.tensor([0], device=row.device)
        while True:
            node_mask.fill_(False)
            node_mask[subset] = True
            torch.index_select(node_mask, 0, row, out=edge_mask)
            expanded = torch.cat([subset, col[edge_mask]]).unique()
            if expanded.equal(subset):
                break
            subset = expanded
    else:
        # Normalise node_idx to a 1-D tensor on edge_index's device
        # (ints, numpy ints, lists, tuples, 0-d and n-d tensors).
        if isinstance(node_idx, torch.Tensor):
            node_idx = node_idx.to(row.device).flatten()
        else:
            node_idx = torch.tensor(node_idx, device=row.device, dtype=torch.int64).flatten()

        subsets = [node_idx]
        for _ in range(num_hops):
            node_mask.fill_(False)
            node_mask[subsets[-1]] = True
            torch.index_select(node_mask, 0, row, out=edge_mask)
            subsets.append(col[edge_mask])
        subset, inv = torch.cat(subsets).unique(return_inverse=True)
        # The first numel(node_idx) inverse indices locate the seeds in subset.
        inv = inv[:node_idx.numel()]

    # Keep only edges whose both endpoints are in the subset.
    node_mask.fill_(False)
    node_mask[subset] = True
    edge_mask = node_mask[row] & node_mask[col]

    edge_index = edge_index[:, edge_mask]

    if relabel_nodes:
        mapping = row.new_full((num_nodes,), -1)
        mapping[subset] = torch.arange(subset.size(0), device=row.device)
        edge_index = mapping[edge_index]

    return subset, edge_index, inv, edge_mask  # subset: key new node idx; value original node idx


def compute_scores(score_func, children):
    """Return one reward per child, evaluating only children not yet scored.

    A child whose cached reward ``P`` equals 0 is treated as unscored and is
    evaluated with ``score_func(child.coalition, child.data)``; every other
    child contributes its cached ``P`` unchanged.
    """
    return [
        child.P if child.P != 0 else score_func(child.coalition, child.data)
        for child in children
    ]


class PlotUtils(object):
    """Dataset-aware helpers that draw an explanation subgraph with matplotlib.

    Args:
        dataset_name (str): Dataset name used by :meth:`plot` to pick a
            drawing routine; matched case-insensitively.
        is_show (bool): Whether to call ``plt.show()`` after drawing.
    """

    def __init__(self, dataset_name, is_show=True):
        self.dataset_name = dataset_name
        self.is_show = is_show

    def plot(self, graph, nodelist, figname, title_sentence=None, **kwargs):
        """Plot ``nodelist`` inside ``graph`` with the routine for ``self.dataset_name``.

        Extra keyword arguments depend on the dataset: ``x`` (node features)
        for molecule datasets, ``y`` and ``node_idx`` for the synthetic
        node-classification datasets, ``words`` for the sentence graphs.

        Raises:
            NotImplementedError: if the dataset name is not supported.
        """
        name = self.dataset_name.lower()
        if name in ['ba_2motifs', 'ba_lrp']:
            self.plot_ba2motifs(graph, nodelist, title_sentence=title_sentence, figname=figname)
        elif name in ['mutag'] + list(MoleculeNet.names.keys()):
            x = kwargs.get('x')
            self.plot_molecule(graph, nodelist, x, title_sentence=title_sentence, figname=figname)
        elif name in ['ba_shapes', 'ba_community', 'tree_grid', 'tree_cycle']:
            y = kwargs.get('y')
            node_idx = kwargs.get('node_idx')
            self.plot_bashapes(graph, nodelist, y, node_idx, title_sentence=title_sentence, figname=figname)
        elif name in ['graph_sst2', 'graph_sst5', 'twitter']:
            words = kwargs.get('words')
            self.plot_sentence(graph, nodelist, words=words, title_sentence=title_sentence, figname=figname)
        else:
            raise NotImplementedError

    def _save_and_show(self, figname):
        """Save the current figure to ``figname`` (if given), show it when
        ``self.is_show`` is set, and close it when it was saved."""
        if figname is not None:
            plt.savefig(figname)
        if self.is_show:
            plt.show()
        if figname is not None:
            plt.close()

    def plot_subgraph(self,
                      graph,
                      nodelist,
                      colors: Union[None, str, List[str]] = '#FFA500',
                      labels=None,
                      edge_color='gray',
                      edgelist=None,
                      subgraph_edge_color='black',
                      title_sentence=None,
                      figname=None):
        """Draw ``graph`` and overlay the edges among ``nodelist`` with thick lines.

        Args:
            graph: networkx graph to draw.
            nodelist: Nodes of the explanation subgraph.
            colors: One color, or one color per node of ``graph`` in node order.
            labels: Optional ``{node: text}`` mapping drawn on the nodes.
            edge_color: Color of all edges of ``graph``.
            edgelist: Edges to highlight; defaults to the edges induced by ``nodelist``.
            subgraph_edge_color: Color of the highlighted edges.
            title_sentence: Optional title, wrapped at 60 characters.
            figname: Optional path the figure is saved to.
        """
        if edgelist is None:
            edgelist = [(n_frm, n_to) for (n_frm, n_to) in graph.edges()
                        if n_frm in nodelist and n_to in nodelist]
        pos = nx.kamada_kawai_layout(graph)
        pos_nodelist = {k: v for k, v in pos.items() if k in nodelist}

        nx.draw_networkx_nodes(graph, pos,
                               nodelist=list(graph.nodes()),
                               node_color=colors,
                               node_size=300)

        nx.draw_networkx_edges(graph, pos, width=3, edge_color=edge_color, arrows=False)

        nx.draw_networkx_edges(graph, pos=pos_nodelist,
                               edgelist=edgelist, width=6,
                               edge_color=subgraph_edge_color,
                               arrows=False)

        if labels is not None:
            nx.draw_networkx_labels(graph, pos, labels)

        plt.axis('off')
        if title_sentence is not None:
            plt.title('\n'.join(wrap(title_sentence, width=60)))
        self._save_and_show(figname)

    def plot_subgraph_with_nodes(self,
                                 graph,
                                 nodelist,
                                 node_idx,
                                 colors='#FFA500',
                                 labels=None,
                                 edge_color='gray',
                                 edgelist=None,
                                 subgraph_edge_color='black',
                                 title_sentence=None,
                                 figname=None):
        """Like :meth:`plot_subgraph`, but also draws ``node_idx`` enlarged.

        When ``colors`` is a list it is indexed in ``graph.nodes()`` order, and
        the target node keeps its own color.
        """
        node_idx = int(node_idx)
        if edgelist is None:
            edgelist = [(n_frm, n_to) for (n_frm, n_to) in graph.edges()
                        if n_frm in nodelist and n_to in nodelist]

        pos = nx.kamada_kawai_layout(graph)  # calculate according to graph.nodes()
        pos_nodelist = {k: v for k, v in pos.items() if k in nodelist}

        nx.draw_networkx_nodes(graph, pos,
                               nodelist=list(graph.nodes()),
                               node_color=colors,
                               node_size=300)
        if isinstance(colors, list):
            # Position of the target node in graph.nodes() order; avoids the
            # deprecated conversion of a 1-element numpy array to int.
            node_idx_color = colors[list(graph.nodes()).index(node_idx)]
        else:
            node_idx_color = colors

        nx.draw_networkx_nodes(graph, pos=pos,
                               nodelist=[node_idx],
                               node_color=node_idx_color,
                               node_size=600)

        nx.draw_networkx_edges(graph, pos, width=3, edge_color=edge_color, arrows=False)

        nx.draw_networkx_edges(graph, pos=pos_nodelist,
                               edgelist=edgelist, width=3,
                               edge_color=subgraph_edge_color,
                               arrows=False)

        if labels is not None:
            nx.draw_networkx_labels(graph, pos, labels)

        plt.axis('off')
        if title_sentence is not None:
            plt.title('\n'.join(wrap(title_sentence, width=60)))
        self._save_and_show(figname)

    def plot_sentence(self, graph, nodelist, words, edgelist=None, title_sentence=None, figname=None):
        """Draw a sentence graph labelled with ``words``, highlighting ``nodelist`` in yellow.

        The title is the wrapped sentence, followed on a new line by the
        wrapped ``title_sentence`` when one is given.
        """
        pos = nx.kamada_kawai_layout(graph)
        words_dict = {i: words[i] for i in graph.nodes}
        if nodelist is not None:
            pos_coalition = {k: v for k, v in pos.items() if k in nodelist}
            nx.draw_networkx_nodes(graph, pos_coalition,
                                   nodelist=nodelist,
                                   node_color='yellow',
                                   node_shape='o',
                                   node_size=500)
            if edgelist is None:
                edgelist = [(n_frm, n_to) for (n_frm, n_to) in graph.edges()
                            if n_frm in nodelist and n_to in nodelist]
            # Drawn for caller-supplied edge lists too; full ``pos`` also
            # covers endpoints outside the coalition.
            nx.draw_networkx_edges(graph, pos=pos, edgelist=edgelist, width=5, edge_color='yellow', arrows=True)

        nx.draw_networkx_nodes(graph, pos, nodelist=list(graph.nodes()), node_size=300)

        nx.draw_networkx_edges(graph, pos, width=2, edge_color='grey')
        nx.draw_networkx_labels(graph, pos, words_dict)

        plt.axis('off')
        title = '\n'.join(wrap(' '.join(words), width=50))
        if title_sentence is not None:
            # Separate the sentence from the explanation title by a newline.
            title += '\n' + '\n'.join(wrap(title_sentence, width=60))
        plt.title(title)
        self._save_and_show(figname)

    def plot_ba2motifs(self,
                       graph,
                       nodelist,
                       edgelist=None,
                       title_sentence=None,
                       figname=None):
        """Draw a BA-2Motifs / BA-LRP graph with the default single node color."""
        return self.plot_subgraph(graph, nodelist,
                                  edgelist=edgelist,
                                  title_sentence=title_sentence,
                                  figname=figname)

    def plot_molecule(self,
                      graph,
                      nodelist,
                      x,
                      edgelist=None,
                      title_sentence=None,
                      figname=None):
        """Draw a molecule graph with element labels and per-element colors.

        For MUTAG the element is read from the one-hot columns of ``x``; for
        MoleculeNet datasets column 0 of ``x`` is used as the atomic number
        (passed to RDKit's periodic table).

        Raises:
            NotImplementedError: if the dataset is neither MUTAG nor MoleculeNet.
        """
        # Match case-insensitively, like plot() does when dispatching here.
        name = self.dataset_name.lower()
        if name == 'mutag':
            node_dict = {0: 'C', 1: 'N', 2: 'O', 3: 'F', 4: 'I', 5: 'Cl', 6: 'Br'}
            node_idxs = {k: int(v) for k, v in enumerate(np.where(x.cpu().numpy() == 1)[1])}
            node_labels = {k: node_dict[v] for k, v in node_idxs.items()}
            node_color = ['#E49D1C', '#4970C6', '#FF5357', '#29A329', 'brown', 'darkslategray', '#F0EA00']
            colors = [node_color[v % len(node_color)] for k, v in node_idxs.items()]

        elif name in MoleculeNet.names.keys():
            element_idxs = {k: int(v) for k, v in enumerate(x[:, 0])}
            node_idxs = element_idxs
            periodic_table = Chem.GetPeriodicTable()
            node_labels = {k: periodic_table.GetElementSymbol(int(v))
                           for k, v in element_idxs.items()}
            node_color = ['#29A329', 'lime', '#F0EA00',  'maroon', 'brown', '#E49D1C', '#4970C6', '#FF5357']
            colors = [node_color[(v - 1) % len(node_color)] for k, v in node_idxs.items()]
        else:
            raise NotImplementedError

        self.plot_subgraph(graph, nodelist,
                           colors=colors,
                           labels=node_labels,
                           edgelist=edgelist,
                           edge_color='gray',
                           subgraph_edge_color='black',
                           title_sentence=title_sentence,
                           figname=figname)

    def plot_bashapes(self,
                      graph,
                      nodelist,
                      y,
                      node_idx,
                      edgelist=None,
                      title_sentence=None,
                      figname=None):
        """Draw a node-classification graph colored by label ``y``, target ``node_idx`` enlarged."""
        node_idxs = {k: int(v) for k, v in enumerate(y.reshape(-1).tolist())}
        node_color = ['#FFA500', '#4970C6', '#FE0000', 'green']
        colors = [node_color[v % len(node_color)] for k, v in node_idxs.items()]
        self.plot_subgraph_with_nodes(graph,
                                      nodelist,
                                      node_idx,
                                      colors,
                                      edgelist=edgelist,
                                      title_sentence=title_sentence,
                                      figname=figname,
                                      subgraph_edge_color='black')


class MCTSNode(object):
    """A node of the Monte Carlo search tree built by ``MCTS``.

    Each node holds a coalition (node indices of the graph being explained)
    plus the statistics used during tree search.

    Args:
        coalition: Node indices kept in this tree node.
        data: Graph data object the coalition indexes into.
        ori_graph: networkx graph the coalition's nodes belong to.
        c_puct: Exploration constant used in :meth:`U`.
        W: Accumulated value over all visits.
        N: Visit count.
        P: Reward of this node.
        load_dict: If given, state is restored from it via :meth:`load_info`.
        device: Device ``data`` is moved to by :meth:`load_info`.
    """

    def __init__(self, coalition: Optional[list] = None, data: "Data" = None,
                 ori_graph: "nx.Graph" = None,
                 c_puct: float = 10.0, W: float = 0, N: int = 0, P: float = 0,
                 load_dict: Optional[Dict] = None, device='cpu'):
        self.data = data
        self.coalition = coalition
        self.ori_graph = ori_graph
        self.device = device
        self.c_puct = c_puct
        self.children = []
        self.W = W  # sum of node value
        self.N = N  # times of arrival
        self.P = P  # property score (reward)
        if load_dict is not None:
            self.load_info(load_dict)

    def Q(self):
        """Mean value over visits (0 before the first visit)."""
        return self.W / self.N if self.N > 0 else 0

    def U(self, n):
        """Exploration bonus given ``n`` total visits among the siblings."""
        return self.c_puct * self.P * math.sqrt(n) / (1 + self.N)

    @property
    def info(self):
        """Serializable snapshot of this node with its data on the CPU.

        PyG's ``Data.to`` converts the object it is called on and returns it,
        so a shallow copy is moved instead; ``self.data`` (shared by the
        nodes of a tree) therefore keeps its original device.
        """
        info_dict = {
            'data': copy.copy(self.data).to('cpu'),
            'coalition': self.coalition,
            'ori_graph': self.ori_graph,
            'W': self.W,
            'N': self.N,
            'P': self.P
        }
        return info_dict

    def load_info(self, info_dict):
        """Restore state from a dict produced by :attr:`info`; returns ``self``.

        ``data`` is moved to ``self.device`` and children are reset.
        """
        self.W = info_dict['W']
        self.N = info_dict['N']
        self.P = info_dict['P']
        self.coalition = info_dict['coalition']
        self.ori_graph = info_dict['ori_graph']
        self.data = info_dict['data'].to(self.device)
        self.children = []
        return self


class MCTS(object):
    r"""
    Monte Carlo Tree Search Method.

    Args:
        X (:obj:`torch.Tensor`): Input node features
        edge_index (:obj:`torch.Tensor`): The edge indices.
        num_hops (:obj:`int`): The number of hops :math:`k`.
        n_rollout (:obj:`int`): The number of sequence to build the monte carlo tree.
        min_atoms (:obj:`int`): The number of atoms for the subgraph in the monte carlo tree leaf node.
        c_puct (:obj:`float`): The hyper-parameter to encourage exploration while searching.
        expand_atoms (:obj:`int`): The number of children to expand.
        high2low (:obj:`bool`): Whether to expand children tree node from high degree nodes to low degree nodes.
        node_idx (:obj:`int`): The target node index to extract the neighborhood.
        score_func (:obj:`Callable`): The reward function for tree node, such as mc_shapley and mc_l_shapley.
    """

    def __init__(self, X: torch.Tensor, edge_index: torch.Tensor, num_hops: int,
                 n_rollout: int = 10, min_atoms: int = 3, c_puct: float = 10.0,
                 expand_atoms: int = 14, high2low: bool = False,
                 node_idx: int = None, score_func: Callable = None, device='cpu'):
        self.X = X
        self.edge_index = edge_index
        self.device = device
        self.num_hops = num_hops
        self.data = Data(x=self.X, edge_index=self.edge_index)
        # Self loops are removed only for the networkx view used during expansion.
        graph_data = Data(x=self.X, edge_index=remove_self_loops(self.edge_index)[0])
        self.graph = to_networkx(graph_data, to_undirected=True)
        self.data = Batch.from_data_list([self.data])
        self.num_nodes = self.graph.number_of_nodes()
        self.score_func = score_func
        self.n_rollout = n_rollout
        self.min_atoms = min_atoms
        self.c_puct = c_puct
        self.expand_atoms = expand_atoms
        self.high2low = high2low
        self.new_node_idx = None

        # extract the sub-graph and change the node indices.
        if node_idx is not None:
            if isinstance(node_idx, Tensor):
                node_idx = node_idx.item()
            self.ori_node_idx = node_idx
            self.ori_graph = copy.copy(self.graph)
            x, edge_index, subset, edge_mask, kwargs = \
                self.__subgraph__(node_idx, self.X, self.edge_index, self.num_hops)
            self.data = Batch.from_data_list([Data(x=x, edge_index=edge_index)])
            self.graph = self.ori_graph.subgraph(subset.tolist())
            mapping = {int(v): k for k, v in enumerate(subset)}
            self.graph = nx.relabel_nodes(self.graph, mapping)
            # Index of the target node after relabelling (may legitimately be 0).
            self.new_node_idx = torch.where(subset == self.ori_node_idx)[0].item()
            self.num_nodes = self.graph.number_of_nodes()
            self.subset = subset

        self.root_coalition = list(range(self.num_nodes))
        self.MCTSNodeClass = partial(MCTSNode, data=self.data, ori_graph=self.graph,
                                     c_puct=self.c_puct, device=self.device)
        self.root = self.MCTSNodeClass(self.root_coalition)
        # Coalitions are always stored sorted, so str(coalition) identifies a state.
        self.state_map = {str(self.root.coalition): self.root}

    def set_score_func(self, score_func):
        """Replace the reward function used to score new tree nodes."""
        self.score_func = score_func

    @staticmethod
    def __subgraph__(node_idx, x, edge_index, num_hops, **kwargs):
        """Extract the ``num_hops``-hop subgraph around ``node_idx`` with relabelled nodes.

        Tensors in ``kwargs`` sized like the nodes or the edges are filtered
        accordingly. Returns ``(x, edge_index, subset, edge_mask, kwargs)``.
        """
        num_nodes, num_edges = x.size(0), edge_index.size(1)
        subset, edge_index, _, edge_mask = k_hop_subgraph_with_default_whole_graph(
            edge_index, node_idx, num_hops, relabel_nodes=True, num_nodes=num_nodes)

        x = x[subset]
        for key, item in kwargs.items():
            if torch.is_tensor(item) and item.size(0) == num_nodes:
                item = item[subset]
            elif torch.is_tensor(item) and item.size(0) == num_edges:
                item = item[edge_mask]
            kwargs[key] = item

        return x, edge_index, subset, edge_mask, kwargs

    def _expand(self, tree_node):
        """Create the children of ``tree_node`` by pruning one node at a time.

        For every pruned node only one connected component of the remainder is
        kept: the one containing the target node (node classification) or the
        largest one (graph classification). Identical coalitions are merged
        through ``self.state_map``.
        """
        node_degree_list = list(self.graph.subgraph(tree_node.coalition).degree)
        node_degree_list = sorted(node_degree_list, key=lambda x: x[1], reverse=self.high2low)
        all_nodes = [x[0] for x in node_degree_list]

        # The target node must never be pruned. Compare with None explicitly:
        # the relabelled target index can be 0, which is falsy.
        if self.new_node_idx is not None:
            expand_nodes = [node for node in all_nodes if node != self.new_node_idx]
        else:
            expand_nodes = all_nodes
        expand_nodes = expand_nodes[:self.expand_atoms]

        for each_node in expand_nodes:
            subgraph_coalition = [node for node in all_nodes if node != each_node]
            subgraphs = [self.graph.subgraph(c)
                         for c in nx.connected_components(self.graph.subgraph(subgraph_coalition))]
            if self.new_node_idx is not None:
                # Components are disjoint, so exactly one contains the target.
                main_sub = next(sub for sub in subgraphs if self.new_node_idx in sub)
            else:
                # First largest component (ties keep the earliest one).
                main_sub = max(subgraphs, key=lambda sub: sub.number_of_nodes())

            new_graph_coalition = sorted(main_sub.nodes())

            # check the state map and merge the same sub-graph
            state_key = str(new_graph_coalition)
            new_node = self.state_map.get(state_key)
            if new_node is None:
                new_node = self.MCTSNodeClass(new_graph_coalition)
                self.state_map[state_key] = new_node

            if all(child.coalition != new_graph_coalition for child in tree_node.children):
                tree_node.children.append(new_node)

    def mcts_rollout(self, tree_node):
        """Run one selection/expansion/backup pass from ``tree_node``; return the leaf reward."""
        if len(tree_node.coalition) <= self.min_atoms:
            return tree_node.P

        # Expand if this node has never been visited
        if len(tree_node.children) == 0:
            self._expand(tree_node)
            scores = compute_scores(self.score_func, tree_node.children)
            for child, score in zip(tree_node.children, scores):
                child.P = score

        sum_count = sum(c.N for c in tree_node.children)
        selected_node = max(tree_node.children, key=lambda x: x.Q() + x.U(sum_count))
        v = self.mcts_rollout(selected_node)
        selected_node.W += v
        selected_node.N += 1
        return v

    def mcts(self, verbose=True):
        """Run ``n_rollout`` rollouts and return all explored states sorted by reward (descending)."""
        if verbose:
            print(f"The nodes in graph is {self.graph.number_of_nodes()}")
        for rollout_idx in range(self.n_rollout):
            self.mcts_rollout(self.root)
            if verbose:
                print(f"At the {rollout_idx} rollout, {len(self.state_map)} states that have been explored.")

        explanations = list(self.state_map.values())
        explanations = sorted(explanations, key=lambda x: x.P, reverse=True)
        return explanations
class SubgraphX(object):
    r"""
    The implementation of paper
    `On Explainability of Graph Neural Networks via Subgraph Explorations <https://arxiv.org/abs/2102.05152>`_.

    Args:
        model (:obj:`torch.nn.Module`): The target model prepared to explain
        num_classes(:obj:`int`): Number of classes for the datasets
        device: Device the model is moved to and the search runs on
        num_hops(:obj:`int`, :obj:`None`): The number of hops to extract neighborhood of target node
          (default: :obj:`None`, i.e. the number of :class:`MessagePassing` layers in ``model``)
        verbose(:obj:`bool`): Whether the tree search prints its progress (default: :obj:`False`)
        explain_graph(:obj:`bool`): Whether to explain graph classification model (default: :obj:`True`)
        rollout(:obj:`int`): Number of iteration to get the prediction
        min_atoms(:obj:`int`): Number of atoms of the leaf node in search tree
        c_puct(:obj:`float`): The hyperparameter which encourages the exploration
        expand_atoms(:obj:`int`): The number of atoms to expand
          when extend the child nodes in the search tree
        high2low(:obj:`bool`): Whether to expand children nodes from high degree to low degree when
          extend the child nodes in the search tree (default: :obj:`False`)
        local_radius(:obj:`int`): Number of local radius to calculate :obj:`l_shapley`, :obj:`mc_l_shapley`
        sample_num(:obj:`int`): Sampling time of monte carlo sampling approximation for
          :obj:`mc_shapley`, :obj:`mc_l_shapley` (default: :obj:`100`)
        reward_method(:obj:`str`): The command string to select the reward function, one of
          :obj:`gnn_score`, :obj:`mc_shapley`, :obj:`l_shapley`, :obj:`mc_l_shapley`,
          :obj:`nc_mc_l_shapley` (default: :obj:`mc_l_shapley`)
        subgraph_building_method(:obj:`str`): The command string for different subgraph building method,
          such as :obj:`zero_filling`, :obj:`split` (default: :obj:`zero_filling`)
        save_dir(:obj:`str`, :obj:`None`): Root directory to save the explanation results (default: :obj:`None`)
        filename(:obj:`str`): The filename of results
        vis(:obj:`bool`): Whether to show the visualization (default: :obj:`True`)

    Example:
        >>> # For graph classification task
        >>> subgraphx = SubgraphX(model=model, num_classes=2, device=device)
        >>> _, explanation_results, related_preds = subgraphx(x, edge_index)
    """

    def __init__(self, model, num_classes: int, device, num_hops: Optional[int] = None, verbose: bool = False,
                 explain_graph: bool = True, rollout: int = 20, min_atoms: int = 5, c_puct: float = 10.0,
                 expand_atoms=14, high2low=False, local_radius=4, sample_num=100, reward_method='mc_l_shapley',
                 subgraph_building_method='zero_filling', save_dir: Optional[str] = None,
                 filename: str = 'example', vis: bool = True):

        self.model = model
        self.model.eval()
        self.device = device
        self.model.to(self.device)
        self.num_classes = num_classes
        self.num_hops = self.update_num_hops(num_hops)
        self.explain_graph = explain_graph
        self.verbose = verbose

        # mcts hyper-parameters
        self.rollout = rollout
        self.min_atoms = min_atoms
        self.c_puct = c_puct
        self.expand_atoms = expand_atoms
        self.high2low = high2low

        # reward function hyper-parameters
        self.local_radius = local_radius
        self.sample_num = sample_num
        self.reward_method = reward_method
        self.subgraph_building_method = subgraph_building_method

        # saving and visualization
        self.vis = vis
        self.save_dir = save_dir
        self.filename = filename
        self.save = self.save_dir is not None

    def update_num_hops(self, num_hops):
        """Return ``num_hops`` or, when None, the number of MessagePassing layers in the model."""
        if num_hops is not None:
            return num_hops

        return sum(1 for module in self.model.modules() if isinstance(module, MessagePassing))

    def get_reward_func(self, value_func, node_idx=None):
        """Build the reward function for MCTS; ``node_idx`` is required for node classification."""
        if self.explain_graph:
            node_idx = None
        else:
            assert node_idx is not None
        return reward_func(reward_method=self.reward_method,
                           value_func=value_func,
                           node_idx=node_idx,
                           local_radius=self.local_radius,
                           sample_num=self.sample_num,
                           subgraph_building_method=self.subgraph_building_method)

    def get_mcts_class(self, x, edge_index, node_idx: int = None, score_func: Callable = None):
        """Create the MCTS search object with this explainer's hyper-parameters."""
        if self.explain_graph:
            node_idx = None
        else:
            assert node_idx is not None
        return MCTS(x, edge_index,
                    node_idx=node_idx,
                    device=self.device,
                    score_func=score_func,
                    num_hops=self.num_hops,
                    n_rollout=self.rollout,
                    min_atoms=self.min_atoms,
                    c_puct=self.c_puct,
                    expand_atoms=self.expand_atoms,
                    high2low=self.high2low)

    def visualization(self, results: list, max_nodes: int, plot_utils: PlotUtils, words: Optional[list] = None,
                      y: Optional[Tensor] = None, title_sentence: Optional[str] = None,
                      vis_name: Optional[str] = None):
        """Plot the best explanation with at most ``max_nodes`` nodes.

        For node classification, ``y`` holds labels of the original graph and
        the last ``explain`` call's search object supplies the subgraph mapping.
        """
        if self.save:
            if vis_name is None:
                vis_name = f"{self.filename}.png"
        else:
            vis_name = None
        tree_node_x = find_closest_node_result(results, max_nodes=max_nodes)
        if self.explain_graph:
            if words is not None:
                plot_utils.plot(tree_node_x.ori_graph,
                                tree_node_x.coalition,
                                words=words,
                                title_sentence=title_sentence,
                                figname=vis_name)
            else:
                plot_utils.plot(tree_node_x.ori_graph,
                                tree_node_x.coalition,
                                x=tree_node_x.data.x,
                                title_sentence=title_sentence,
                                figname=vis_name)
        else:
            subset = self.mcts_state_map.subset
            subgraph_y = y[subset].to('cpu')
            subgraph_y = torch.tensor([subgraph_y[node].item()
                                       for node in tree_node_x.ori_graph.nodes()])
            plot_utils.plot(tree_node_x.ori_graph,
                            tree_node_x.coalition,
                            node_idx=self.mcts_state_map.new_node_idx,
                            title_sentence=title_sentence,
                            y=subgraph_y,
                            figname=vis_name)

    def read_from_MCTSInfo_list(self, MCTSInfo_list):
        """Rebuild MCTSNode objects from saved info dicts (flat list or list of per-label lists)."""
        if isinstance(MCTSInfo_list[0], dict):
            ret_list = [MCTSNode(device=self.device).load_info(node_info) for node_info in MCTSInfo_list]
        elif isinstance(MCTSInfo_list[0][0], dict):
            ret_list = []
            for single_label_MCTSInfo_list in MCTSInfo_list:
                single_label_ret_list = [MCTSNode(device=self.device).load_info(node_info)
                                         for node_info in single_label_MCTSInfo_list]
                ret_list.append(single_label_ret_list)
        return ret_list

    def write_from_MCTSNode_list(self, MCTSNode_list):
        """Convert MCTSNode objects (flat list or list of per-label lists) into info dicts."""
        if isinstance(MCTSNode_list[0], MCTSNode):
            ret_list = [node.info for node in MCTSNode_list]
        elif isinstance(MCTSNode_list[0][0], MCTSNode):
            ret_list = []
            for single_label_MCTSNode_list in MCTSNode_list:
                single_label_ret_list = [node.info for node in single_label_MCTSNode_list]
                ret_list.append(single_label_ret_list)
        return ret_list

    def explain(self, x: Tensor, edge_index: Tensor, label: int,
                max_nodes: int = 5,
                node_idx: Optional[int] = None,
                saved_MCTSInfo_list: Optional[List[List]] = None):
        """Search (or reload) explanations for one ``label``.

        Args:
            saved_MCTSInfo_list: Previously saved info dicts for this label;
                when given, the tree search is skipped.

        Returns:
            ``(results, related_pred)``: the explored tree nodes as info dicts,
            and a dict with the ``masked``, ``maskout``, ``origin`` and
            ``sparsity`` scores of the selected explanation.
        """
        probs = self.model(x, edge_index).squeeze().softmax(dim=-1)
        if self.explain_graph:
            value_func = GnnNetsGC2valueFunc(self.model, target_class=label)
            if saved_MCTSInfo_list:
                results = self.read_from_MCTSInfo_list(saved_MCTSInfo_list)
            else:
                payoff_func = self.get_reward_func(value_func)
                self.mcts_state_map = self.get_mcts_class(x, edge_index, score_func=payoff_func)
                results = self.mcts_state_map.mcts(verbose=self.verbose)
        else:
            # mcts will extract the subgraph and relabel the nodes, so the value
            # function must target the relabelled node index.
            self.mcts_state_map = self.get_mcts_class(x, edge_index, node_idx=node_idx)
            self.new_node_idx = self.mcts_state_map.new_node_idx
            value_func = GnnNetsNC2valueFunc(self.model,
                                             node_idx=self.new_node_idx,
                                             target_class=label)
            if saved_MCTSInfo_list:
                results = self.read_from_MCTSInfo_list(saved_MCTSInfo_list)
            else:
                payoff_func = self.get_reward_func(value_func, node_idx=self.new_node_idx)
                self.mcts_state_map.set_score_func(payoff_func)
                results = self.mcts_state_map.mcts(verbose=self.verbose)

        tree_node_x = find_closest_node_result(results, max_nodes=max_nodes)

        coalition = set(tree_node_x.coalition)
        num_nodes = tree_node_x.data.x.shape[0]
        # keep the important structure
        masked_node_list = [node for node in range(num_nodes) if node in coalition]
        # remove the important structure; for node classification the target
        # node itself is kept when the important structure is removed
        maskout_node_list = [node for node in range(num_nodes) if node not in coalition]
        if not self.explain_graph:
            maskout_node_list += [self.new_node_idx]

        masked_score = gnn_score(masked_node_list,
                                 tree_node_x.data,
                                 value_func=value_func,
                                 subgraph_building_method=self.subgraph_building_method)

        maskout_score = gnn_score(maskout_node_list,
                                  tree_node_x.data,
                                  value_func=value_func,
                                  subgraph_building_method=self.subgraph_building_method)

        sparsity_score = sparsity(masked_node_list, tree_node_x.data,
                                  subgraph_building_method=self.subgraph_building_method)

        results = self.write_from_MCTSNode_list(results)
        related_pred = {'masked': masked_score,
                        'maskout': maskout_score,
                        'origin': probs[node_idx, label].item(),
                        'sparsity': sparsity_score}

        return results, related_pred

    def __call__(self, x: Tensor, edge_index: Tensor, **kwargs)\
            -> Tuple[None, List, List[Dict]]:
        r""" explain the GNN behavior for the graph using SubgraphX method

        Args:
            x (:obj:`torch.Tensor`): Node feature matrix with shape
              :obj:`[num_nodes, dim_node_feature]`
            edge_index (:obj:`torch.Tensor`): Graph connectivity in COO format
              with shape :obj:`[2, num_edges]`
            kwargs(:obj:`Dict`):
              The additional parameters
                - node_idx (:obj:`int`, :obj:`None`): The target node index when explain node classification task
                - max_nodes (:obj:`int`, :obj:`None`): The number of nodes in the final explanation results
                  (default: :obj:`5`, the same default as :meth:`explain`)

        :rtype: (:obj:`None`, List[torch.Tensor], List[Dict])
        """
        node_idx = kwargs.get('node_idx')
        max_nodes = kwargs.get('max_nodes', 5)  # default max subgraph size

        # collect all the class index
        labels = tuple(label for label in range(self.num_classes))
        ex_labels = tuple(torch.tensor([label]).to(self.device) for label in labels)

        related_preds = []
        explanation_results = []
        saved_results = None
        if self.save:
            save_path = os.path.join(self.save_dir, f"{self.filename}.pt")
            if os.path.isfile(save_path):
                # NOTE: torch.load unpickles; only load result files written by this class.
                saved_results = torch.load(save_path)

        for label_idx, label in enumerate(ex_labels):
            # The saved file stores one result list per label; pass only this
            # label's list (labels missing from the file are searched afresh).
            label_saved_results = None
            if saved_results is not None and label_idx < len(saved_results):
                label_saved_results = saved_results[label_idx]
            results, related_pred = self.explain(x, edge_index,
                                                 label=label,
                                                 max_nodes=max_nodes,
                                                 node_idx=node_idx,
                                                 saved_MCTSInfo_list=label_saved_results)
            related_preds.append(related_pred)
            explanation_results.append(results)

        if self.save:
            torch.save(explanation_results,
                       os.path.join(self.save_dir, f"{self.filename}.pt"))

        return None, explanation_results, related_preds