Source code for dig.xgraph.method.subgraphx

import os
import math
import copy
import torch
import numpy as np
import networkx as nx
from rdkit import Chem
from torch import Tensor
from textwrap import wrap
from functools import partial
from collections import Counter
import torch.nn.functional as F
from typing import List, Tuple, Dict
from torch_geometric.data import Batch, Data
from torch_geometric.utils import to_networkx
from typing import Callable, Union, Optional
import matplotlib.pyplot as plt
from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_geometric.nn import MessagePassing
from torch_geometric.datasets import MoleculeNet
from .shapley import GnnNets_GC2value_func, GnnNets_NC2value_func, \
    gnn_score, mc_shapley, l_shapley, mc_l_shapley, NC_mc_l_shapley


def find_closest_node_result(results, max_nodes):
    """ return the highest reward tree_node with its subgraph is smaller than max_nodes """

    results = sorted(results, key=lambda x: len(x.coalition))

    result_node = results[0]
    for x in results:
        if len(x.coalition) <= max_nodes and x.P > result_node.P:
            result_node = x
    return result_node
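
# Usage sketch (illustrative): pick the highest-reward explanation with at most
# five nodes from the tree nodes returned by MCTS.mcts():
#   best = find_closest_node_result(results, max_nodes=5)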


def reward_func(reward_method, value_func, node_idx=None,
                local_radius=4, sample_num=100,
                subgraph_building_method='zero_filling'):
    if reward_method.lower() == 'gnn_score':
        return partial(gnn_score,
                       value_func=value_func,
                       subgraph_building_method=subgraph_building_method)

    elif reward_method.lower() == 'mc_shapley':
        return partial(mc_shapley,
                       value_func=value_func,
                       subgraph_building_method=subgraph_building_method,
                       sample_num=sample_num)

    elif reward_method.lower() == 'l_shapley':
        return partial(l_shapley,
                       local_raduis=local_radius,
                       value_func=value_func,
                       subgraph_building_method=subgraph_building_method)

    elif reward_method.lower() == 'mc_l_shapley':
        return partial(mc_l_shapley,
                       local_raduis=local_radius,
                       value_func=value_func,
                       subgraph_building_method=subgraph_building_method,
                       sample_num=sample_num)

    elif reward_method.lower() == 'nc_mc_l_shapley':
        assert node_idx is not None, " Wrong node idx input "
        return partial(NC_mc_l_shapley,
                       node_idx=node_idx,
                       local_raduis=local_radius,
                       value_func=value_func,
                       subgraph_building_method=subgraph_building_method,
                       sample_num=sample_num)

    else:
        raise NotImplementedError
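
# Illustrative sketch (not part of the original module): `reward_func` binds a
# value function to one of the Shapley-style estimators above, and the returned
# partial is later called as `score_func(coalition, data)` (see compute_scores).
# Assuming a trained graph-classification `model`:
#   value_func = GnnNets_GC2value_func(model, target_class=0)
#   score_func = reward_func('mc_l_shapley', value_func, sample_num=100)
#   score = score_func(coalition, data)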


def k_hop_subgraph_with_default_whole_graph(
        edge_index, node_idx=None, num_hops=3, relabel_nodes=False,
        num_nodes=None, flow='source_to_target'):
    r"""Computes the :math:`k`-hop subgraph of :obj:`edge_index` around node
    :attr:`node_idx`.
    It returns (1) the nodes involved in the subgraph, (2) the filtered
    :obj:`edge_index` connectivity, (3) the mapping from node indices in
    :obj:`node_idx` to their new location, and (4) the edge mask indicating
    which edges were preserved.
    Args:
        node_idx (int, list, tuple or :obj:`torch.Tensor`): The central
            node(s).
        num_hops (int): The number of hops :math:`k`.
        edge_index (LongTensor): The edge indices.
        relabel_nodes (bool, optional): If set to :obj:`True`, the resulting
            :obj:`edge_index` will be relabeled to hold consecutive indices
            starting from zero. (default: :obj:`False`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)
        flow (string, optional): The flow direction of :math:`k`-hop
            aggregation (:obj:`"source_to_target"` or
            :obj:`"target_to_source"`). (default: :obj:`"source_to_target"`)
    :rtype: (:class:`LongTensor`, :class:`LongTensor`, :class:`LongTensor`,
             :class:`BoolTensor`)
    """

    num_nodes = maybe_num_nodes(edge_index, num_nodes)

    assert flow in ['source_to_target', 'target_to_source']
    if flow == 'target_to_source':
        row, col = edge_index
    else:
        col, row = edge_index  # col: source nodes, row: target nodes

    node_mask = row.new_empty(num_nodes, dtype=torch.bool)
    edge_mask = row.new_empty(row.size(0), dtype=torch.bool)

    inv = None

    if node_idx is None:
        subsets = torch.tensor([0])
        cur_subsets = subsets
        while 1:
            node_mask.fill_(False)
            node_mask[subsets] = True
            torch.index_select(node_mask, 0, row, out=edge_mask)
            subsets = torch.cat([subsets, col[edge_mask]]).unique()
            if not cur_subsets.equal(subsets):
                cur_subsets = subsets
            else:
                subset = subsets
                break
    else:
        if isinstance(node_idx, (int, list, tuple)):
            node_idx = torch.tensor([node_idx], device=row.device, dtype=torch.int64).flatten()
        elif isinstance(node_idx, torch.Tensor) and len(node_idx.shape) == 0:
            node_idx = torch.tensor([node_idx])
        else:
            node_idx = node_idx.to(row.device)

        subsets = [node_idx]
        for _ in range(num_hops):
            node_mask.fill_(False)
            node_mask[subsets[-1]] = True
            torch.index_select(node_mask, 0, row, out=edge_mask)
            subsets.append(col[edge_mask])
        subset, inv = torch.cat(subsets).unique(return_inverse=True)
        inv = inv[:node_idx.numel()]

    node_mask.fill_(False)
    node_mask[subset] = True
    edge_mask = node_mask[row] & node_mask[col]

    edge_index = edge_index[:, edge_mask]

    if relabel_nodes:
        node_idx = row.new_full((num_nodes,), -1)
        node_idx[subset] = torch.arange(subset.size(0), device=row.device)
        edge_index = node_idx[edge_index]

    return subset, edge_index, inv, edge_mask  # subset[new_node_idx] gives the original node index
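
# Usage sketch (hypothetical 4-node path graph 0-1-2-3): the 1-hop neighborhood
# of node 1 keeps nodes {0, 1, 2} and the four directed edges among them.
#   edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
#                              [1, 0, 2, 1, 3, 2]])
#   subset, sub_edge_index, inv, edge_mask = \
#       k_hop_subgraph_with_default_whole_graph(edge_index, node_idx=1, num_hops=1)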


def compute_scores(score_func, children):
    results = []
    for child in children:
        if child.P == 0:
            score = score_func(child.coalition, child.data)
        else:
            score = child.P
        results.append(score)
    return results


class PlotUtils(object):
    def __init__(self, dataset_name):
        self.dataset_name = dataset_name

    def plot(self, graph, nodelist, figname, **kwargs):
        """ plot function for different dataset """
        if self.dataset_name.lower() in ['ba_2motifs', 'ba_lrp']:
            self.plot_ba2motifs(graph, nodelist, figname=figname)
        elif self.dataset_name.lower() in ['mutag'] + list(MoleculeNet.names.keys()):
            x = kwargs.get('x')
            self.plot_molecule(graph, nodelist, x, figname=figname)
        elif self.dataset_name.lower() in ['ba_shapes', 'tree_grid', 'tree_cycle']:
            y = kwargs.get('y')
            node_idx = kwargs.get('node_idx')
            self.plot_bashapes(graph, nodelist, y, node_idx, figname=figname)
        elif self.dataset_name.lower() in ['Graph-SST2'.lower()]:
            words = kwargs.get('words')
            self.plot_sentence(graph, nodelist, words=words, figname=figname)
        else:
            raise NotImplementedError

    @staticmethod
    def plot_subgraph(graph, nodelist, colors='#FFA500', labels=None, edge_color='gray',
                      edgelist=None, subgraph_edge_color='black', title_sentence=None, figname=None):

        if edgelist is None:
            edgelist = [(n_frm, n_to) for (n_frm, n_to) in graph.edges() if
                                  n_frm in nodelist and n_to in nodelist]
        pos = nx.kamada_kawai_layout(graph)
        pos_nodelist = {k: v for k, v in pos.items() if k in nodelist}

        nx.draw_networkx_nodes(graph, pos,
                               nodelist=list(graph.nodes()),
                               node_color=colors,
                               node_size=300)

        nx.draw_networkx_edges(graph, pos, width=3, edge_color=edge_color, arrows=False)

        nx.draw_networkx_edges(graph, pos=pos_nodelist,
                               edgelist=edgelist, width=6,
                               edge_color=subgraph_edge_color,
                               arrows=False)

        if labels is not None:
            nx.draw_networkx_labels(graph, pos, labels)

        plt.axis('off')
        if title_sentence is not None:
            plt.title('\n'.join(wrap(title_sentence, width=60)))
        if figname is not None:
            plt.savefig(figname)
        plt.show()

    @staticmethod
    def plot_subgraph_with_nodes(graph, nodelist, node_idx, colors='#FFA500', labels=None, edge_color='gray',
                                 edgelist=None, subgraph_edge_color='black', title_sentence=None, figname=None):
        node_idx = int(node_idx)
        if edgelist is None:
            edgelist = [(n_frm, n_to) for (n_frm, n_to) in graph.edges() if
                                  n_frm in nodelist and n_to in nodelist]

        pos = nx.kamada_kawai_layout(graph) # calculate according to graph.nodes()
        pos_nodelist = {k: v for k, v in pos.items() if k in nodelist}

        nx.draw_networkx_nodes(graph, pos,
                               nodelist=list(graph.nodes()),
                               node_color=colors,
                               node_size=300)
        if isinstance(colors, list):
            list_indices = int(np.where(np.array(graph.nodes()) == node_idx)[0])
            node_idx_color = colors[list_indices]
        else:
            node_idx_color = colors

        nx.draw_networkx_nodes(graph, pos=pos,
                               nodelist=[node_idx],
                               node_color=node_idx_color,
                               node_size=600)

        nx.draw_networkx_edges(graph, pos, width=3, edge_color=edge_color, arrows=False)

        nx.draw_networkx_edges(graph, pos=pos_nodelist,
                               edgelist=edgelist, width=3,
                               edge_color=subgraph_edge_color,
                               arrows=False)

        if labels is not None:
            nx.draw_networkx_labels(graph, pos, labels)

        plt.axis('off')
        if title_sentence is not None:
            plt.title('\n'.join(wrap(title_sentence, width=60)))
        if figname is not None:
            plt.savefig(figname)
        plt.show()

    @staticmethod
    def plot_sentence(graph, nodelist, words, edgelist=None, figname=None):
        pos = nx.kamada_kawai_layout(graph)
        words_dict = {i: words[i] for i in graph.nodes}
        if nodelist is not None:
            pos_coalition = {k: v for k, v in pos.items() if k in nodelist}
            nx.draw_networkx_nodes(graph, pos_coalition,
                                   nodelist=nodelist,
                                   node_color='yellow',
                                   node_shape='o',
                                   node_size=500)
        if edgelist is None:
            edgelist = [(n_frm, n_to) for (n_frm, n_to) in graph.edges()
                        if n_frm in nodelist and n_to in nodelist]
            nx.draw_networkx_edges(graph, pos=pos_coalition, edgelist=edgelist, width=5, edge_color='yellow')

        nx.draw_networkx_nodes(graph, pos, nodelist=list(graph.nodes()), node_size=300)

        nx.draw_networkx_edges(graph, pos, width=2, edge_color='grey')
        nx.draw_networkx_labels(graph, pos, words_dict)

        plt.axis('off')
        plt.title('\n'.join(wrap(' '.join(words), width=50)))
        if figname is not None:
            plt.savefig(figname)
        plt.close('all')

    def plot_ba2motifs(self, graph, nodelist, edgelist=None, figname=None):
        return self.plot_subgraph(graph, nodelist, edgelist=edgelist, figname=figname)

    def plot_molecule(self, graph, nodelist, x, edgelist=None, figname=None):
        # collect the text information and node color
        if self.dataset_name == 'mutag':
            node_dict = {0: 'C', 1: 'N', 2: 'O', 3: 'F', 4: 'I', 5: 'Cl', 6: 'Br'}
            node_idxs = {k: int(v) for k, v in enumerate(np.where(x.cpu().numpy() == 1)[1])}
            node_labels = {k: node_dict[v] for k, v in node_idxs.items()}
            node_color = ['#E49D1C', '#4970C6', '#FF5357', '#29A329', 'brown', 'darkslategray', '#F0EA00']
            colors = [node_color[v % len(node_color)] for k, v in node_idxs.items()]

        elif self.dataset_name in MoleculeNet.names.keys():
            element_idxs = {k: int(v) for k, v in enumerate(x[:, 0])}
            node_idxs = element_idxs
            node_labels = {k: Chem.PeriodicTable.GetElementSymbol(Chem.GetPeriodicTable(), int(v))
                           for k, v in element_idxs.items()}
            node_color = ['#29A329', 'lime', '#F0EA00',  'maroon', 'brown', '#E49D1C', '#4970C6', '#FF5357']
            colors = [node_color[(v - 1) % len(node_color)] for k, v in node_idxs.items()]
        else:
            raise NotImplementedError

        self.plot_subgraph(graph, nodelist, colors=colors, labels=node_labels,
                           edgelist=edgelist, edge_color='gray',
                           subgraph_edge_color='black',
                           title_sentence=None, figname=figname)

    def plot_bashapes(self, graph, nodelist, y, node_idx, edgelist=None, figname=None):
        node_idxs = {k: int(v) for k, v in enumerate(y.reshape(-1).tolist())}
        node_color = ['#FFA500', '#4970C6', '#FE0000', 'green']
        colors = [node_color[v % len(node_color)] for k, v in node_idxs.items()]
        self.plot_subgraph_with_nodes(graph, nodelist, node_idx, colors, edgelist=edgelist, figname=figname,
                           subgraph_edge_color='black')
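
# Illustrative usage (assumes `g` is a networkx graph, e.g. obtained via
# torch_geometric.utils.to_networkx, and `coalition` is a list of node indices):
#   plotter = PlotUtils(dataset_name='ba_2motifs')
#   plotter.plot(g, coalition, figname='explanation.png')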


class MCTSNode(object):

    def __init__(self, coalition: list, data: Data, ori_graph: nx.Graph,
                 c_puct: float = 10.0, W: float = 0, N: int = 0, P: float = 0):
        self.data = data
        self.coalition = coalition
        self.ori_graph = ori_graph
        self.c_puct = c_puct
        self.children = []
        self.W = W  # sum of node value
        self.N = N  # times of arrival
        self.P = P  # property score (reward)

    def Q(self):
        return self.W / self.N if self.N > 0 else 0

    def U(self, n):
        return self.c_puct * self.P * math.sqrt(n) / (1 + self.N)
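
    # Child selection in MCTS.mcts_rollout follows a PUCT-style rule: the child
    # maximising Q + U is descended into, where
    #   Q = W / N                                      (mean reward so far)
    #   U = c_puct * P * sqrt(sum_b N_b) / (1 + N)     (exploration bonus)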


class MCTS(object):
    def __init__(self, X: torch.Tensor, edge_index: torch.Tensor, num_hops: int,
                 n_rollout: int = 10, min_atoms: int = 3, c_puct: float = 10.0,
                 expand_atoms: int = 14, high2low: bool = False,
                 node_idx: int = None, score_func: Callable = None):
        """ graph is a networkX graph """
        self.X = X
        self.edge_index = edge_index
        self.num_hops = num_hops
        self.data = Data(x=self.X, edge_index=self.edge_index)
        self.graph = to_networkx(self.data, to_undirected=True)
        self.data = Batch.from_data_list([self.data])
        self.num_nodes = self.graph.number_of_nodes()
        self.score_func = score_func
        self.n_rollout = n_rollout
        self.min_atoms = min_atoms
        self.c_puct = c_puct
        self.expand_atoms = expand_atoms
        self.high2low = high2low

        # extract the sub-graph and change the node indices.
        if node_idx is not None:
            self.ori_node_idx = node_idx
            self.ori_graph = copy.copy(self.graph)
            x, edge_index, subset, edge_mask, kwargs = \
                self.__subgraph__(node_idx, self.X, self.edge_index, self.num_hops)
            self.data = Batch.from_data_list([Data(x=x, edge_index=edge_index)])
            self.graph = self.ori_graph.subgraph(subset.tolist())
            mapping = {int(v): k for k, v in enumerate(subset)}
            self.graph = nx.relabel_nodes(self.graph, mapping)
            self.node_idx = torch.where(subset == self.ori_node_idx)[0]
            self.num_nodes = self.graph.number_of_nodes()
            self.subset = subset

        self.root_coalition = list(range(self.num_nodes))
        self.MCTSNodeClass = partial(MCTSNode, data=self.data, ori_graph=self.graph, c_puct=self.c_puct)
        self.root = self.MCTSNodeClass(self.root_coalition)
        self.state_map = {str(self.root.coalition): self.root}

    def set_score_func(self, score_func):
        self.score_func = score_func

    def __subgraph__(self, node_idx, x, edge_index, num_hops, **kwargs):
        num_nodes, num_edges = x.size(0), edge_index.size(1)
        subset, edge_index, _, edge_mask = k_hop_subgraph_with_default_whole_graph(
            edge_index, node_idx, num_hops, relabel_nodes=True, num_nodes=num_nodes)

        x = x[subset]
        for key, item in kwargs.items():
            if torch.is_tensor(item) and item.size(0) == num_nodes:
                item = item[subset]
            elif torch.is_tensor(item) and item.size(0) == num_edges:
                item = item[edge_mask]
            kwargs[key] = item

        return x, edge_index, subset, edge_mask, kwargs

    def mcts_rollout(self, tree_node):
        cur_graph_coalition = tree_node.coalition
        if len(cur_graph_coalition) <= self.min_atoms:
            return tree_node.P

        # Expand if this node has never been visited
        if len(tree_node.children) == 0:
            node_degree_list = list(self.graph.subgraph(cur_graph_coalition).degree)
            node_degree_list = sorted(node_degree_list, key=lambda x: x[1], reverse=self.high2low)
            all_nodes = [x[0] for x in node_degree_list]

            if len(all_nodes) < self.expand_atoms:
                expand_nodes = all_nodes
            else:
                expand_nodes = all_nodes[:self.expand_atoms]

            for each_node in expand_nodes:
                # For each node, prune it and examine the remaining sub-graph;
                # only the largest connected component is kept as the child coalition.
                subgraph_coalition = [node for node in all_nodes if node != each_node]

                subgraphs = [self.graph.subgraph(c)
                             for c in nx.connected_components(self.graph.subgraph(subgraph_coalition))]
                main_sub = subgraphs[0]
                for sub in subgraphs:
                    if sub.number_of_nodes() > main_sub.number_of_nodes():
                        main_sub = sub

                new_graph_coalition = sorted(list(main_sub.nodes()))

                # check the state map and merge the same sub-graph
                Find_same = False
                for old_graph_node in self.state_map.values():
                    if Counter(old_graph_node.coalition) == Counter(new_graph_coalition):
                        new_node = old_graph_node
                        Find_same = True

                if not Find_same:
                    new_node = self.MCTSNodeClass(new_graph_coalition)
                    self.state_map[str(new_graph_coalition)] = new_node

                Find_same_child = False
                for cur_child in tree_node.children:
                    if Counter(cur_child.coalition) == Counter(new_graph_coalition):
                        Find_same_child = True

                if not Find_same_child:
                    tree_node.children.append(new_node)

            scores = compute_scores(self.score_func, tree_node.children)
            for child, score in zip(tree_node.children, scores):
                child.P = score

        sum_count = sum([c.N for c in tree_node.children])
        selected_node = max(tree_node.children, key=lambda x: x.Q() + x.U(sum_count))
        v = self.mcts_rollout(selected_node)
        selected_node.W += v
        selected_node.N += 1
        return v

    def mcts(self, verbose=True):
        if verbose:
            print(f"The nodes in graph is {self.graph.number_of_nodes()}")
        for rollout_idx in range(self.n_rollout):
            self.mcts_rollout(self.root)
            if verbose:
                print(f"At the {rollout_idx} rollout, {len(self.state_map)} states that have been explored.")

        explanations = [node for _, node in self.state_map.items()]
        explanations = sorted(explanations, key=lambda x: x.P, reverse=True)
        return explanations
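
# Minimal usage sketch (illustrative; `x`, `edge_index` and `score_func` are
# assumed to come from the surrounding pipeline, e.g. from reward_func above):
#   mcts = MCTS(x, edge_index, num_hops=2, n_rollout=5, score_func=score_func)
#   results = mcts.mcts(verbose=False)
#   best = find_closest_node_result(results, max_nodes=5)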


class SubgraphX(object):
    r"""
    The implementation of the paper
    `On Explainability of Graph Neural Networks via Subgraph Explorations <https://arxiv.org/abs/2102.05152>`_.

    Args:
        model (:obj:`torch.nn.Module`): The target model prepared to explain
        num_classes (:obj:`int`): Number of classes for the datasets
        num_hops (:obj:`int`, :obj:`None`): The number of hops to extract the neighborhood of the target node
          (default: :obj:`None`)
        explain_graph (:obj:`bool`): Whether to explain a graph classification model (default: :obj:`True`)
        rollout (:obj:`int`): Number of iterations to get the prediction
        min_atoms (:obj:`int`): Number of atoms of a leaf node in the search tree
        c_puct (:obj:`float`): The hyperparameter that encourages exploration
        expand_atoms (:obj:`int`): The number of atoms to expand when extending the child nodes in the search tree
        high2low (:obj:`bool`): Whether to expand children nodes from high degree to low degree when extending
          the child nodes in the search tree (default: :obj:`False`)
        local_radius (:obj:`int`): Local radius used to calculate :obj:`l_shapley` and :obj:`mc_l_shapley`
        sample_num (:obj:`int`): Number of Monte Carlo samples used to approximate :obj:`mc_shapley` and
          :obj:`mc_l_shapley` (default: :obj:`100`)
        reward_method (:obj:`str`): The command string to select the reward method (default: :obj:`mc_l_shapley`)
        subgraph_building_method (:obj:`str`): The command string for the subgraph building method,
          such as :obj:`zero_filling`, :obj:`split` (default: :obj:`zero_filling`)
        save_dir (:obj:`str`, :obj:`None`): Root directory to save the explanation results (default: :obj:`None`)
        filename (:obj:`str`): The filename of the results
        vis (:obj:`bool`): Whether to show the visualization (default: :obj:`True`)

    Example:
        >>> # For graph classification task
        >>> subgraphx = SubgraphX(model=model, num_classes=2)
        >>> _, explanation_results, related_preds = subgraphx(x, edge_index)
    """

    def __init__(self, model, num_classes: int, device, num_hops: Optional[int] = None,
                 explain_graph: bool = True, rollout: int = 10, min_atoms: int = 3, c_puct: float = 10.0,
                 expand_atoms=14, high2low=False, local_radius=4, sample_num=100,
                 reward_method='mc_l_shapley', subgraph_building_method='zero_filling',
                 save_dir: Optional[str] = None, filename: str = 'example', vis: bool = True):

        self.model = model
        self.model.eval()
        self.device = device
        self.model.to(self.device)
        self.num_classes = num_classes
        self.num_hops = self.update_num_hops(num_hops)
        self.explain_graph = explain_graph

        # mcts hyper-parameters
        self.rollout = rollout
        self.min_atoms = min_atoms
        self.c_puct = c_puct
        self.expand_atoms = expand_atoms
        self.high2low = high2low

        # reward function hyper-parameters
        self.local_radius = local_radius
        self.sample_num = sample_num
        self.reward_method = reward_method
        self.subgraph_building_method = subgraph_building_method

        # saving and visualization
        self.vis = vis
        self.save_dir = save_dir
        self.filename = filename
        self.save = True if self.save_dir is not None else False

    def update_num_hops(self, num_hops):
        if num_hops is not None:
            return num_hops

        k = 0
        for module in self.model.modules():
            if isinstance(module, MessagePassing):
                k += 1
        return k

    def get_reward_func(self, value_func, node_idx=None):
        if self.explain_graph:
            node_idx = None
        else:
            assert node_idx is not None
        return reward_func(reward_method=self.reward_method,
                           value_func=value_func,
                           node_idx=node_idx,
                           local_radius=self.local_radius,
                           sample_num=self.sample_num,
                           subgraph_building_method=self.subgraph_building_method)

    def get_mcts_class(self, x, edge_index, node_idx: int = None, score_func: Callable = None):
        if self.explain_graph:
            node_idx = None
        else:
            assert node_idx is not None
        return MCTS(x, edge_index,
                    node_idx=node_idx,
                    score_func=score_func,
                    num_hops=self.num_hops,
                    n_rollout=self.rollout,
                    min_atoms=self.min_atoms,
                    c_puct=self.c_puct,
                    expand_atoms=self.expand_atoms,
                    high2low=self.high2low)

    def visualization(self, explanation_results: list, prediction: Union[int, Tensor], max_nodes: int,
                      plot_utils: PlotUtils, words: Optional[list] = None, y: Optional[Tensor] = None,
                      vis_name: Optional[str] = None):
        if self.save:
            if vis_name is None:
                vis_name = f"{self.filename}.png"
        else:
            vis_name = None

        results = explanation_results[prediction]
        tree_node_x = find_closest_node_result(results, max_nodes=max_nodes)
        if self.explain_graph:
            if words is not None:
                plot_utils.plot(tree_node_x.ori_graph,
                                tree_node_x.coalition,
                                words=words,
                                figname=vis_name)
            else:
                plot_utils.plot(tree_node_x.ori_graph,
                                tree_node_x.coalition,
                                x=tree_node_x.data.x,
                                figname=vis_name)
        else:
            subset = self.mcts_state_map.subset
            subgraph_y = y[subset].to('cpu')
            subgraph_y = torch.tensor([subgraph_y[node].item()
                                       for node in tree_node_x.ori_graph.nodes()])
            plot_utils.plot(tree_node_x.ori_graph,
                            tree_node_x.coalition,
                            node_idx=self.mcts_state_map.node_idx,
                            y=subgraph_y,
                            figname=vis_name)
    def __call__(self, x: Tensor, edge_index: Tensor, **kwargs) \
            -> Tuple[None, List, List[Dict]]:
        r""" Explain the GNN behavior for the graph using the SubgraphX method.

        Args:
            x (:obj:`torch.Tensor`): Node feature matrix with shape
              :obj:`[num_nodes, dim_node_feature]`
            edge_index (:obj:`torch.Tensor`): Graph connectivity in COO format
              with shape :obj:`[2, num_edges]`
            kwargs (:obj:`Dict`): The additional parameters

              - node_idx (:obj:`int`, :obj:`None`): The target node index when explaining a node classification task
              - max_nodes (:obj:`int`, :obj:`None`): The number of nodes in the final explanation results

        :rtype: (:obj:`None`, List[torch.Tensor], List[Dict])
        """
        node_idx = kwargs.get('node_idx')
        max_nodes = kwargs.get('max_nodes')
        max_nodes = 14 if max_nodes is None else max_nodes  # default max subgraph size

        # collect all the class indices
        labels = tuple(label for label in range(self.num_classes))
        ex_labels = tuple(torch.tensor([label]).to(self.device) for label in labels)

        logits = self.model(x, edge_index)
        probs = F.softmax(logits, dim=-1)
        probs = probs.squeeze()

        explanation_results = []
        related_preds = []
        if self.explain_graph:
            prediction = probs.argmax(-1)
            for label in ex_labels:
                value_func = GnnNets_GC2value_func(self.model, target_class=label)
                payoff_func = self.get_reward_func(value_func)
                self.mcts_state_map = self.get_mcts_class(x, edge_index, score_func=payoff_func)
                results = self.mcts_state_map.mcts(verbose=False)

                # l_shapley score
                data = Data(x=x, edge_index=edge_index)
                tree_node_x = find_closest_node_result(results, max_nodes=max_nodes)
                maskout_node_list = [node for node in range(tree_node_x.data.x.shape[0])
                                     if node not in tree_node_x.coalition]
                maskout_score = gnn_score(maskout_node_list, data, value_func,
                                          subgraph_building_method='zero_filling')
                sparsity_score = 1 - len(tree_node_x.coalition) / tree_node_x.ori_graph.number_of_nodes()

                explanation_results.append(results)
                related_preds.append({'masked': maskout_score,
                                      'origin': probs[label],
                                      'sparsity': sparsity_score})
        else:
            prediction = probs[node_idx].argmax(-1)
            for label in ex_labels:
                self.mcts_state_map = self.get_mcts_class(x, edge_index, node_idx=node_idx)
                self.node_idx = self.mcts_state_map.node_idx
                # mcts will extract the subgraph and relabel the nodes
                value_func = GnnNets_NC2value_func(self.model,
                                                   node_idx=self.mcts_state_map.node_idx,
                                                   target_class=label)
                payoff_func = self.get_reward_func(value_func, node_idx=self.mcts_state_map.node_idx)
                self.mcts_state_map.set_score_func(payoff_func)
                results = self.mcts_state_map.mcts(verbose=False)

                tree_node_x = find_closest_node_result(results, max_nodes=max_nodes)
                original_node_list = [node for node in tree_node_x.ori_graph.nodes]
                maskout_node_list = [node for node in range(tree_node_x.data.x.shape[0])
                                     if node not in tree_node_x.coalition]
                original_score = gnn_score(original_node_list, tree_node_x.data,
                                           value_func=value_func,
                                           subgraph_building_method='zero_filling')
                maskout_score = gnn_score(maskout_node_list, tree_node_x.data,
                                          value_func=value_func,
                                          subgraph_building_method='zero_filling')
                sparsity_score = 1 - len(tree_node_x.coalition) / tree_node_x.ori_graph.number_of_nodes()

                explanation_results.append(results)
                related_preds.append({'maskout': maskout_score,
                                      'origin': original_score,
                                      'sparsity': sparsity_score})

        if self.save:
            torch.save(explanation_results[prediction],
                       os.path.join(self.save_dir, f"{self.filename}.pt"))

        return None, explanation_results, related_preds
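
# Node-classification usage sketch (illustrative; mirrors the graph-classification
# example in the class docstring, with hypothetical `model`, `x`, `edge_index`, `device`):
#   subgraphx = SubgraphX(model, num_classes=4, device=device, explain_graph=False)
#   _, explanation_results, related_preds = subgraphx(x, edge_index, node_idx=10, max_nodes=5)
#   best = find_closest_node_result(explanation_results[0], max_nodes=5)  # explanation for class 0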