Source code for dig.xgraph.method.flowx

r"""
This is an official implementation of `FlowX: Towards Explainable Graph Neural Networks via Message Flows
<https://arxiv.org/abs/2206.12987>`_.
"""

from itertools import combinations
from typing import List, Tuple, Union, Dict

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.functional import cross_entropy
from torch_geometric.utils.loop import add_self_loops, remove_self_loops

from .base_explainer import WalkBase
from ..models.utils import subgraph, gumbel_softmax


def cross_entropy_with_logit(y_pred: torch.Tensor, y_true: torch.Tensor, **kwargs):
    """Cross-entropy on raw logits; the ``long()`` cast lets float-typed targets be passed."""
    return cross_entropy(y_pred, y_true.long(), **kwargs)
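# A hedged illustration (not part of the library): for logits of shape
# [batch, num_classes] and integer class targets,
#
#     cross_entropy_with_logit(torch.randn(4, 3), torch.tensor([0, 2, 1, 0]))
#
# returns a scalar loss.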


class FlowX(WalkBase):
    r"""
    An implementation of FlowX in `FlowX: Towards Explainable Graph Neural Networks via Message Flows
    <https://arxiv.org/abs/2206.12987>`_.

    Args:
        model (torch.nn.Module): The target model prepared to explain.
        epochs (int, optional): The number of explainer training steps.
        lr (float, optional): The explainer training learning rate.
        explain_graph (bool, optional): Whether the model to explain is a graph classification model.
            (default: :obj:`False`)

    .. note:: For node classification models, the :attr:`explain_graph` flag should be :obj:`False`.

    """
    coeffs = {
        'edge_size': 5e-4,
        'edge_ent': 1e-1
    }

    def __init__(self, model, epochs=500, lr=3e-1, explain_graph=False, molecule=False):
        super().__init__(model=model, epochs=epochs, lr=lr, explain_graph=explain_graph, molecule=molecule)

        # Subsets of layer indices, each paired with the parity of its size.
        self.score_structure = [(i % 2, term_idx)
                                for i in range(1, self.num_layers + 1)
                                for term_idx in combinations(range(self.num_layers), i)]

        # Number of Monte-Carlo sampling iterations and subsets per iteration.
        self.ns_iter = 30
        self.ns_per_iter = None

        # If True, the loss is negated so that training targets Fidelity+.
        self.fidelity_plus = True
        self.score_lr = 0e-5  # alternative: 2e-5

        # Debugging switch: skip mask training entirely.
        self.no_mask = False
        if self.no_mask:
            self.epochs = 1
            self.lr = 0
            self.score_lr = 0
    def forward(self,
                x: Tensor,
                edge_index: Tensor,
                **kwargs
                ) -> Union[Tuple[None, List, List[Dict]], Tuple[Dict, List, List[Dict]]]:
        r"""
        Run the explainer for a specific graph instance.

        Args:
            x (torch.Tensor): The graph instance's input node features.
            edge_index (torch.Tensor): The graph instance's edge index.
            **kwargs (dict):
                :obj:`node_idx` (int, list, tuple, torch.Tensor): The index of the node to be
                explained. (for node classification)
                :obj:`sparsity` (float): The sparsity used to transform the soft edge mask into a
                hard mask. (Default: :obj:`0.7`)
                :obj:`num_classes` (int): The number of classes in the task.

        :rtype: (dict, list, list)

        .. note::
            The return value is :obj:`(walks, masks, related_predictions)`:
            :obj:`walks` holds the flow indices and their scores;
            :obj:`masks` is a list of edge-level explanations, one per class;
            :obj:`related_predictions` is a list of dictionaries, one per class, each containing
            four kinds of predicted probabilities.

        """
        super().forward(x, edge_index, **kwargs)

        # --- Explanation setup ---
        self.model.eval()

        # Original prediction on the unmasked graph
        self.ori_logits_pred = self.model(x, edge_index).softmax(1)

        # Edge index with self-loops appended
        edge_index, _ = remove_self_loops(edge_index)
        edge_index_with_loop, _ = add_self_loops(edge_index, num_nodes=self.num_nodes)
        walk_indices_list = torch.tensor(
            self.walks_pick(edge_index_with_loop.cpu(), list(range(edge_index_with_loop.shape[1])),
                            num_layers=self.num_layers), device=self.device)

        if not self.explain_graph:
            node_idx = kwargs.get('node_idx')
            self.node_idx = node_idx
            assert node_idx is not None
            _, _, _, self.hard_edge_mask = subgraph(
                node_idx, self.__num_hops__, edge_index_with_loop, relabel_nodes=True,
                num_nodes=None, flow=self.__flow__())

            # Keep only the walks that end at the target node
            edge2node_idx = edge_index_with_loop[1] == node_idx
            walk_indices_list_mask = edge2node_idx[walk_indices_list[:, -1]]
            walk_indices_list = walk_indices_list[walk_indices_list_mask]

        self.time_list = []

        labels = tuple(i for i in range(kwargs.get('num_classes')))
        ex_labels = tuple(torch.tensor([label]).to(self.device) for label in labels)

        # Connect the layer-wise edge masks to the model, then score the flows
        with self.connect_mask(self):
            iter_weighted_change_walks_list, iter_changed_subsets_score_list, walk_sample_count = \
                self.flow_shap(x, edge_index, edge_index_with_loop, walk_indices_list, **kwargs)

            walk_score_list = []
            for ex_label in ex_labels:
                # --- training ---
                self.train_mask(x, edge_index, ex_label, walk_indices_list, edge_index_with_loop,
                                iter_weighted_change_walks_list, iter_changed_subsets_score_list,
                                walk_sample_count)
                walk_score_list.append(self.flow_mask.data)

        walks = {'ids': walk_indices_list, 'score': torch.cat(walk_score_list, dim=1)}

        # Map flow scores to edge-level masks (self-loops included)
        masks = []
        for ex_label in ex_labels:
            edge_attr = self.explain_edges_with_loop(x, walks, ex_label)
            mask = self.control_sparsity(edge_attr, kwargs.get('sparsity')).sigmoid()
            masks.append(mask.detach())

        # Evaluate the related predictions under the connected masks
        with self.connect_mask(self):
            related_preds = self.eval_related_pred(x, edge_index, masks, **kwargs)

        return walks, masks, related_preds
    def __loss__(self, raw_preds, x_label):
        if self.explain_graph:
            loss = cross_entropy_with_logit(raw_preds, x_label)
        else:
            loss = cross_entropy_with_logit(raw_preds[self.node_idx].unsqueeze(0), x_label)

        if self.fidelity_plus:
            # Maximize the prediction change: train the mask for Fidelity+.
            loss = - loss

        # Alternative regularizers (kept for reference): pushing the mask towards
        # hard 0/1 values and the target sparsity, or penalizing mask size and
        # entropy via self.coeffs['edge_size'] / self.coeffs['edge_ent'].

        return loss

    def train_mask(self,
                   x: Tensor,
                   edge_index: Tensor,
                   ex_label: Tensor,
                   walk_indices_list,
                   edge_index_with_loop,
                   iter_weighted_change_walks_list,
                   iter_changed_subsets_score_list,
                   walk_sample_count,
                   t0=7.,
                   t1=0.5,
                   **kwargs
                   ) -> None:
        # --- Initialize the trainable necessary/sufficient mask ---
        self.to(x.device)
        self.nec_suf_mask = nn.Parameter(
            1e-1 * nn.init.uniform_(torch.empty((1, iter_weighted_change_walks_list.shape[1], 1),
                                                device=self.device)))
        if self.no_mask:
            # Debugging: a large constant sigmoids to ~1, i.e. effectively no mask.
            self.nec_suf_mask = nn.Parameter(
                100 * torch.ones((1, iter_weighted_change_walks_list.shape[1], 1), device=self.device))
        self.iter_weighted_change_walks_list = nn.Parameter(iter_weighted_change_walks_list.clone().detach())

        # --- Training setup ---
        # Flatten (layer, edge) pairs into plain indices, then build the
        # flow -> layered-edge incidence matrix.
        walk_plain_indices_list = walk_indices_list + \
            (edge_index_with_loop.shape[1] * torch.arange(self.num_layers, device=self.device)).repeat(
                walk_indices_list.shape[0], 1)

        self.flow2layeredge_matrix = torch.stack(
            [(walk_plain_indices_list == i).float().sum(dim=1)
             for i in range(self.num_layers * (self.num_edges + self.num_nodes))],
            dim=1).detach()

        optimizer = torch.optim.Adam([{'params': self.nec_suf_mask}],
                                     # optionally also: {'params': self.iter_weighted_change_walks_list,
                                     #                   'lr': self.score_lr}
                                     lr=self.lr)

        for epoch in range(1, self.epochs + 1):

            # Re-weight the sampled marginal contributions by the current mask,
            # then aggregate them into per-flow scores.
            masked_iter_weighted_change_walks_list = \
                self.iter_weighted_change_walks_list * self.nec_suf_mask.sigmoid()
            walk_scores = (masked_iter_weighted_change_walks_list.unsqueeze(3).repeat(
                1, 1, 1, iter_changed_subsets_score_list.shape[2])
                * iter_changed_subsets_score_list.unsqueeze(2)).sum(1).sum(0)
            # EPS affects the numerical stability of training
            EPS = 1e-18
            shap_flow_score = walk_scores / (walk_sample_count.unsqueeze(1) + EPS)

            # --- score/mask transformation ---
            self.flow_mask = shap_flow_score[:, ex_label]

            # --- project flow scores onto layered edge masks ---
            self.layer_edge_mask = (self.flow_mask * self.flow2layeredge_matrix).view(
                self.flow_mask.shape[0], self.num_layers, -1).sum(0)

            # Normalize to [0, 1]
            mask = self.layer_edge_mask.sum(0)
            mask = mask - mask.min()
            mask = mask / (mask.max() + EPS)
            climb = True
            if climb:
                # Sharpen the mask towards hard 0/1 values.
                mask = mask ** 8
            else:
                # Annealed Gumbel-softmax alternative: temperature decays from
                # t0 to t1 over the first end_epoch epochs.
                end_epoch = 300
                temperature = float(t0 * ((t1 / t0) ** (epoch / end_epoch))) if epoch < end_epoch else t1
                mask = gumbel_softmax(mask, temperature, training=True)
                mask = mask - mask.min()
                mask = mask / (mask.max() + EPS)
            # cur_sparsity can be used to stop early once the target sparsity is reached.
            cur_sparsity = (mask < 0.5).sum().float() / mask.shape[0]

            if self.fidelity_plus:
                mask = 1 - mask  # Fidelity+: evaluate the complement of the explanation
            self.mask = mask

            # --- temporarily install the mask on every layer ---
            temp_edge_mask = []
            for layer_idx in range(self.num_layers):
                # Note: self-loops are placed last because the model appends them itself.
                temp_edge_mask.append(mask)

            with self.temp_mask(self, temp_edge_mask):
                raw_preds = self.model(x, edge_index, **kwargs)
            loss = self.__loss__(raw_preds, ex_label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        return
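    # How the two stages fit together (a hedged summary of the code below, not a
    # formal statement of the paper's estimator): flow_shap draws random
    # permutations of the layered edges and masks them chunk by chunk, recording
    # the model-output change caused by each chunk; train_mask then averages
    # these marginal contributions per flow, roughly
    #
    #     score(flow) ~ sum over permutations and chunks of
    #                   [flow newly eliminated by chunk] * (prev_output - cur_output)
    #                   / walk_sample_count(flow)
    #
    # which acts as a Monte-Carlo approximation of a Shapley value over message
    # flows.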
    def flow_shap(self,
                  x,
                  edge_index,
                  edge_index_with_loop,
                  walk_indices_list,
                  **kwargs
                  ):
        r"""
        Sampling-based flow Shapley calculations.
        """
        # --- Kernel algorithm ---

        # row: walk index, column: total score
        walk_sample_count = torch.zeros(walk_indices_list.shape[0], dtype=torch.float, device=self.device)

        iter_weighted_change_walks_list = []
        iter_changed_subsets_score_list = []

        # Randomly sample ns_iter * ns_per_iter subsets
        for iter_idx in range(self.ns_iter):
            # --- shuffle the pool of (layer, edge) indices ---
            unmask_pool = torch.cat([walk_indices_list[:, layer].unique() + layer * edge_index_with_loop.shape[1]
                                     for layer in range(self.num_layers)])
            self.ns_per_iter = unmask_pool.shape[0] if self.explain_graph or unmask_pool.shape[0] <= 100 else 100
            idx = torch.randperm(unmask_pool.nelement())
            unmask_pool = unmask_pool.view(-1)[idx].view(unmask_pool.size())
            mask_per_sub = unmask_pool.shape[0] // self.ns_per_iter

            weighted_change_walks_list = []
            last_eliminated_walks = torch.zeros(walk_indices_list.shape[0], dtype=torch.bool, device=self.device)
            layer_edge_mask_list = []
            for sub_idx in range(self.ns_per_iter):
                # --- grow the masked subset by one chunk ---
                mask_pool = unmask_pool[: mask_per_sub * (sub_idx + 1)]

                # --- mark the walks eliminated by the newly masked edges ---
                eliminated_layer_edges = unmask_pool[mask_per_sub * sub_idx: mask_per_sub * (sub_idx + 1)]
                walk_plain_indices_list = walk_indices_list + \
                    (edge_index_with_loop.shape[1] * torch.arange(self.num_layers, device=self.device)).repeat(
                        walk_indices_list.shape[0], 1)
                eliminated_walks = torch.stack([walk_plain_indices_list == edge
                                                for edge in eliminated_layer_edges],
                                               dim=0).long().sum(0).sum(1).bool().long()
                weighted_changed_walks = eliminated_walks.clone().float()
                weighted_changed_walks[eliminated_walks == last_eliminated_walks] = 0.
                weighted_changed_walks /= (weighted_changed_walks > 1e-20).sum() + 1e-30
                weighted_change_walks_list.append(weighted_changed_walks)
                last_eliminated_walks = eliminated_walks

                # --- build the subset mask: masked positions become -inf pre-sigmoid ---
                layer_edge_masks = torch.ones((self.num_layers, edge_index_with_loop.shape[1]),
                                              device=self.device)
                layer_edge_masks.view(-1)[mask_pool] -= 2
                layer_edge_mask_list.append(layer_edge_masks)
            weighted_change_walks_list = torch.stack(weighted_change_walks_list, dim=0)
            iter_weighted_change_walks_list.append(weighted_change_walks_list.detach())
            layer_edge_mask_list = torch.stack(layer_edge_mask_list, dim=0) * float('inf')

            # --- compute the subsets' outputs for the current iteration ---
            for layer_idx in range(self.num_layers):
                # Note: self-loops are placed last because the model appends them itself.
                self.edge_mask[layer_idx].data = torch.cat(
                    [layer_edge_mask_list[:, layer_idx, :self.num_edges].reshape(-1),
                     layer_edge_mask_list[:, layer_idx, self.num_edges:].reshape(-1)]).sigmoid()
            batch = self.batch_input(x, edge_index, self.ns_per_iter)
            subsets_output = self.model(data=batch).softmax(1).detach()
            if not self.explain_graph:
                subsets_output = subsets_output.view(self.ns_per_iter, -1,
                                                     kwargs.get('num_classes'))[:, self.node_idx]
                last_subsets_output = torch.cat(
                    [self.ori_logits_pred[self.node_idx].unsqueeze(0), subsets_output.clone()[:-1]], dim=0)
            else:
                last_subsets_output = torch.cat([self.ori_logits_pred, subsets_output.clone()[:-1]], dim=0)
            # Marginal contribution: output drop caused by the newly masked chunk
            changed_subsets_score_list = (last_subsets_output - subsets_output).detach()
            iter_changed_subsets_score_list.append(changed_subsets_score_list)

            walk_sample_count += (weighted_change_walks_list > 1e-30).float().sum(0)

        # shape: iter x subset_idx x flow_idx
        iter_weighted_change_walks_list = torch.stack(iter_weighted_change_walks_list, dim=0)
        iter_changed_subsets_score_list = torch.stack(iter_changed_subsets_score_list, dim=0)

        return iter_weighted_change_walks_list, iter_changed_subsets_score_list, walk_sample_count
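
# A minimal usage sketch (a hedged illustration, not part of the library).
# `model`, `data`, `node_idx`, and `dataset` are assumptions: a trained
# DIG-style GNN classifier and a torch_geometric.data.Data instance; see
# dig.xgraph.models for compatible models.
#
#     explainer = FlowX(model, epochs=500, lr=3e-1, explain_graph=False)
#     walks, masks, related_preds = explainer(
#         data.x, data.edge_index,
#         node_idx=node_idx,   # target node (node classification only)
#         sparsity=0.7,        # hard-mask sparsity
#         num_classes=dataset.num_classes)
#
# masks[c] is the hard edge-level mask explaining class c, and walks['ids'] /
# walks['score'] hold the per-flow indices and scores.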