dig.xgraph.method

Method interfaces under dig.xgraph.method.

DeepLIFT

An implementation of DeepLIFT on graphs, from Learning Important Features Through Propagating Activation Differences.

GNNExplainer

The GNN-Explainer model from the “GNNExplainer: Generating Explanations for Graph Neural Networks” paper for identifying compact subgraph structures and small subsets of node features that play a crucial role in a GNN’s node predictions.

GNN_GI

An implementation of GNN-GI from Higher-Order Explanations of Graph Neural Networks via Relevant Walks.

GNN_LRP

An implementation of GNN-LRP from Higher-Order Explanations of Graph Neural Networks via Relevant Walks.

GradCAM

An implementation of GradCAM on graphs, from Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization.

PGExplainer

An implementation of PGExplainer from Parameterized Explainer for Graph Neural Network.

SubgraphX

An implementation of the paper On Explainability of Graph Neural Networks via Subgraph Explorations.

class DeepLIFT(model: torch.nn.modules.module.Module, explain_graph=False)[source]

An implementation of DeepLIFT on graphs, from Learning Important Features Through Propagating Activation Differences.

Parameters
  • model (torch.nn.Module) – The target model to be explained.

  • explain_graph (bool, optional) – Whether to explain a graph classification model. (default: False)

Note

For a node classification model, set explain_graph to False. For an example, see benchmarks/xgraph.

forward(x: torch.Tensor, edge_index: torch.Tensor, **kwargs)[source]

Run the explainer for a specific graph instance.

Parameters
  • x (torch.Tensor) – The graph instance’s input node features.

  • edge_index (torch.Tensor) – The graph instance’s edge index.

  • **kwargs (dict) –

    The additional parameters
    • node_idx (int): The index of the node to be explained (for node classification)

    • sparsity (float): The sparsity used to transform a soft mask into a hard mask (default: 0.7)

Return type

(None, list, list)

Note

(None, edge_masks, related_predictions): edge_masks is a list of edge-level explanations, one per class; related_predictions is a list with one dictionary per class, where each dictionary holds four types of predicted probabilities.
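The sparsity keyword controls how a soft (continuous) edge mask becomes a hard one. The sketch below assumes sparsity is the fraction of edges dropped, so the highest-scoring (1 - sparsity) fraction is kept; soft_to_hard_mask is a hypothetical helper for illustration, not part of dig.xgraph.method, and the library operates on tensors rather than lists.

```python
from typing import List, Sequence


def soft_to_hard_mask(soft_mask: Sequence[float], sparsity: float = 0.7) -> List[bool]:
    """Turn per-edge importance scores into a keep/drop mask.

    Assumption: keeps the top (1 - sparsity) fraction of edges by score,
    so sparsity=0.7 keeps roughly 30% of the edges.
    """
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError("sparsity must lie in [0, 1]")
    # int() truncates, so the number of kept edges is rounded down.
    num_keep = int((1.0 - sparsity) * len(soft_mask))
    # Edge indices ordered from most to least important.
    order = sorted(range(len(soft_mask)), key=lambda i: soft_mask[i], reverse=True)
    keep = set(order[:num_keep])
    return [i in keep for i in range(len(soft_mask))]


scores = [0.9, 0.1, 0.5, 0.3, 0.8, 0.2, 0.7, 0.4, 0.6, 0.0]
hard = soft_to_hard_mask(scores, sparsity=0.7)
print(hard)  # keeps the three highest-scoring edges: 0, 4 and 6
```

Higher sparsity therefore yields smaller, more selective explanations.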

class GNNExplainer(model, epochs=100, lr=0.01, explain_graph=False)[source]

The GNN-Explainer model from the “GNNExplainer: Generating Explanations for Graph Neural Networks” paper for identifying compact subgraph structures and small subsets of node features that play a crucial role in a GNN’s node predictions.

Note

For an example, see benchmarks/xgraph.

Parameters
  • model (torch.nn.Module) – The GNN module to explain.

  • epochs (int, optional) – The number of epochs to train. (default: 100)

  • lr (float, optional) – The learning rate to apply. (default: 0.01)

  • explain_graph (bool, optional) – Whether to explain a graph classification model. (default: False)

forward(x, edge_index, mask_features=False, **kwargs)[source]

Run the explainer for a specific graph instance.

Parameters
  • x (torch.Tensor) – The graph instance’s input node features.

  • edge_index (torch.Tensor) – The graph instance’s edge index.

  • mask_features (bool, optional) – Whether to use a feature mask (not recommended). (default: False)

  • **kwargs (dict) –

    The additional parameters
    • node_idx (int): The index of the node to be explained (for node classification)

    • sparsity (float): The sparsity used to transform a soft mask into a hard mask (default: 0.7)

    • num_classes (int): The number of the task’s classes

Return type

(None, list, list)

Note

(None, edge_masks, related_predictions): edge_masks is a list of edge-level explanations, one per class; related_predictions is a list with one dictionary per class, where each dictionary holds four types of predicted probabilities.
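GNNExplainer learns a soft edge mask by trading off three goals: the masked graph should still predict the explained class, the mask should be small, and each mask value should be close to 0 or 1. The sketch below writes that objective out for a single set of mask values, following the general form in the GNNExplainer paper; gnnexplainer_objective is a hypothetical name and the coefficient values are illustrative, not the ones DIG uses internally.

```python
import math
from typing import Sequence


def gnnexplainer_objective(
    target_prob: float,
    edge_mask: Sequence[float],
    size_coeff: float = 0.005,  # illustrative value
    ent_coeff: float = 1.0,  # illustrative value
) -> float:
    """Loss minimized when learning a soft edge mask, GNNExplainer-style.

    target_prob: model probability of the explained class with the mask applied.
    edge_mask:   mask values in [0, 1], one per edge (must be non-empty).
    """
    eps = 1e-15  # avoid log(0)
    # Prediction term: keep the masked graph predictive of the target class.
    pred_loss = -math.log(target_prob + eps)
    # Size term: prefer explanations that use few edges.
    size_loss = size_coeff * sum(edge_mask)
    # Entropy term: minimizing it pushes each mask value toward 0 or 1.
    ent = [-m * math.log(m + eps) - (1 - m) * math.log(1 - m + eps) for m in edge_mask]
    ent_loss = ent_coeff * sum(ent) / len(ent)
    return pred_loss + size_loss + ent_loss


crisp = gnnexplainer_objective(1.0, [1.0, 0.0])
fuzzy = gnnexplainer_objective(1.0, [0.5, 0.5])
print(crisp, fuzzy)  # the undecided mask is penalized by the entropy term
```

In the real explainer this loss would be minimized for the configured number of epochs with the given lr, with the mask values produced by a sigmoid over learnable logits.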

class GNN_GI(model: torch.nn.modules.module.Module, explain_graph=False)[source]

An implementation of GNN-GI from Higher-Order Explanations of Graph Neural Networks via Relevant Walks.

Parameters
  • model (torch.nn.Module) – The target model to be explained.

  • explain_graph (bool, optional) – Whether to explain a graph classification model. (default: False)

Note

For a node classification model, set explain_graph to False.

forward(x: torch.Tensor, edge_index: torch.Tensor, **kwargs)[source]

Run the explainer for a specific graph instance.

Parameters
  • x (torch.Tensor) – The graph instance’s input node features.

  • edge_index (torch.Tensor) – The graph instance’s edge index.

  • **kwargs (dict) –

    The additional parameters
    • node_idx (int): The index of the node to be explained (for node classification)

    • sparsity (float): The sparsity used to transform a soft mask into a hard mask (default: 0.7)

    • num_classes (int): The number of the task’s classes

Return type

(dict, list, list)

Note

(walks, edge_masks, related_predictions): walks is a dictionary containing the walks’ edge indices and their corresponding explanation scores; edge_masks is a list of edge-level explanations, one per class; related_predictions is a list with one dictionary per class, where each dictionary holds four types of predicted probabilities.
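Walk-level scores can be projected onto individual edges to obtain an edge-level view. The sketch below uses one simple aggregation (summing each walk's score onto every edge it traverses); it represents walks as (edge ids, score) pairs, which is an assumption about layout rather than the exact structure of the returned walks dictionary, and walks_to_edge_scores is a hypothetical helper.

```python
from typing import List, Sequence, Tuple

# Hypothetical layout: (edge ids along the walk, relevance score of the walk).
Walk = Tuple[Sequence[int], float]


def walks_to_edge_scores(walks: List[Walk], num_edges: int) -> List[float]:
    """Sum each walk's relevance onto every edge the walk traverses.

    An edge that occurs twice in one walk (e.g. a self-loop) is counted twice.
    """
    edge_scores = [0.0] * num_edges
    for edge_ids, score in walks:
        for e in edge_ids:
            edge_scores[e] += score
    return edge_scores


walks = [((0, 1), 0.5), ((1, 2), 0.25), ((2, 2), -0.125)]
print(walks_to_edge_scores(walks, num_edges=3))
```

Other reductions (maximum, mean, or absolute values) are equally plausible; summing keeps positive and negative relevance able to cancel.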

class GNN_LRP(model: torch.nn.modules.module.Module, explain_graph=False)[source]

An implementation of GNN-LRP from Higher-Order Explanations of Graph Neural Networks via Relevant Walks.

Parameters
  • model (torch.nn.Module) – The target model to be explained.

  • explain_graph (bool, optional) – Whether to explain a graph classification model. (default: False)

Note

For a node classification model, set explain_graph to False. GNN-LRP is highly model-dependent; make sure you know how to adapt it to your model before using it. For an example, see benchmarks/xgraph.

forward(x: torch.Tensor, edge_index: torch.Tensor, **kwargs)[source]

Run the explainer for a specific graph instance.

Parameters
  • x (torch.Tensor) – The graph instance’s input node features.

  • edge_index (torch.Tensor) – The graph instance’s edge index.

  • **kwargs (dict) –

    The additional parameters
    • node_idx (int): The index of the node to be explained (for node classification)

    • sparsity (float): The sparsity used to transform a soft mask into a hard mask (default: 0.7)

    • num_classes (int): The number of the task’s classes

Return type

(dict, list, list)

Note

(walks, edge_masks, related_predictions): walks is a dictionary containing the walks’ edge indices and their corresponding explanation scores; edge_masks is a list of edge-level explanations, one per class; related_predictions is a list with one dictionary per class, where each dictionary holds four types of predicted probabilities.
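The related_predictions dictionaries can be reduced to fidelity-style metrics. The sketch below assumes each dictionary stores the target-class probability under the keys 'origin' (full graph), 'masked' (explanation only) and 'maskout' (graph with the explanation removed); these key names are an assumption, so check the dictionaries your explainer actually returns. fidelity_scores is a hypothetical helper.

```python
from statistics import mean
from typing import Dict, List


def fidelity_scores(related_preds: List[Dict[str, float]]) -> Dict[str, float]:
    """Average Fidelity+ and Fidelity- over explained instances.

    Fidelity+: probability drop when the explanation is removed (higher is better).
    Fidelity-: probability drop when only the explanation is kept (lower is better).
    Assumes the keys 'origin', 'masked' and 'maskout'; other keys are ignored.
    """
    fid_plus = mean(p["origin"] - p["maskout"] for p in related_preds)
    fid_minus = mean(p["origin"] - p["masked"] for p in related_preds)
    return {"fidelity+": fid_plus, "fidelity-": fid_minus}


preds = [
    {"origin": 0.75, "masked": 0.5, "maskout": 0.25},
    {"origin": 1.0, "masked": 0.75, "maskout": 0.5},
]
print(fidelity_scores(preds))
```

A good explanation should score high on Fidelity+ and low on Fidelity-.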

class GradCAM(model, explain_graph=False)[source]

An implementation of GradCAM on graphs, from Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization.

Parameters
  • model (torch.nn.Module) – The target model to be explained.

  • explain_graph (bool, optional) – Whether to explain a graph classification model. (default: False)

Note

For a node classification model, set explain_graph to False. For an example, see benchmarks/xgraph.

forward(x: torch.Tensor, edge_index: torch.Tensor, **kwargs) → Union[Tuple[None, List, List[Dict]], Tuple[Dict, List, List[Dict]]][source]

Run the explainer for a specific graph instance.

Parameters
  • x (torch.Tensor) – The graph instance’s input node features.

  • edge_index (torch.Tensor) – The graph instance’s edge index.

  • **kwargs (dict) –

    The additional parameters
    • node_idx (int): The index of the node to be explained (for node classification)

    • sparsity (float): The sparsity used to transform a soft mask into a hard mask (default: 0.7)

    • num_classes (int): The number of the task’s classes

Return type

(None, list, list)

Note

(None, edge_masks, related_predictions): edge_masks is a list of edge-level explanations, one per class; related_predictions is a list with one dictionary per class, where each dictionary holds four types of predicted probabilities.
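A common graph adaptation of Grad-CAM weights each channel of a GNN layer by its gradient averaged over all nodes, then scores each node by the ReLU of the weighted sum of its activations. The pure-Python sketch below assumes the activations and gradients have already been computed (in practice with autograd) and stops at node scores; graph_grad_cam is a hypothetical helper, not the GradCAM class above.

```python
from typing import List

Matrix = List[List[float]]  # shape [num_nodes][num_channels]


def graph_grad_cam(features: Matrix, grads: Matrix) -> List[float]:
    """Grad-CAM node importance from one layer's activations and gradients.

    features[n][k]: activation of channel k at node n.
    grads[n][k]:    d(target class score) / d(features[n][k]).
    """
    num_nodes = len(features)
    num_channels = len(features[0])
    # Channel weights: gradients averaged over all nodes (global average pooling).
    alpha = [
        sum(grads[n][k] for n in range(num_nodes)) / num_nodes
        for k in range(num_channels)
    ]
    # Node score: ReLU of the weighted sum of that node's channel activations.
    return [
        max(0.0, sum(alpha[k] * features[n][k] for k in range(num_channels)))
        for n in range(num_nodes)
    ]


feats = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
grads = [[1.0, -1.0], [1.0, -1.0], [1.0, -1.0]]
print(graph_grad_cam(feats, grads))
```

The ReLU keeps only nodes whose activations push the target class score up.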

class PGExplainer(model, in_channels: int, device, explain_graph: bool = True, epochs: int = 20, lr: float = 0.005, coff_size: float = 0.01, coff_ent: float = 0.0005, t0: float = 5.0, t1: float = 1.0, num_hops: Optional[int] = None)[source]

An implementation of PGExplainer from Parameterized Explainer for Graph Neural Network.

Parameters
  • model (torch.nn.Module) – The target model to be explained

  • in_channels (int) – Number of input channels for the explanation network

  • explain_graph (bool) – Whether to explain a graph classification model (default: True)

  • epochs (int) – Number of epochs to train the explanation network (default: 20)

  • lr (float) – Learning rate to train the explanation network (default: 0.005)

  • coff_size (float) – Coefficient of the size regularization that constrains the explanation size (default: 0.01)

  • coff_ent (float) – Coefficient of the entropy regularization that constrains the connectivity of the explanation (default: 0.0005)

  • t0 (float) – The temperature at the first epoch (default: 5.0)

  • t1 (float) – The temperature at the final epoch (default: 1.0)

  • num_hops (int, None) – The number of hops used to extract the neighborhood of the target node (default: None)
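The t0 and t1 parameters describe a temperature that moves from t0 at the first epoch to t1 at the final epoch. The sketch below assumes an exponential (geometric) annealing schedule, which is a common choice for PGExplainer-style training; the exact schedule used inside PGExplainer is not documented here, and temperature is a hypothetical helper.

```python
def temperature(epoch: int, epochs: int, t0: float = 5.0, t1: float = 1.0) -> float:
    """Exponentially anneal the sampling temperature from t0 (epoch 0) to t1 (epoch == epochs)."""
    return t0 * (t1 / t0) ** (epoch / epochs)


# With the documented defaults (epochs=20, t0=5.0, t1=1.0):
print([round(temperature(e, 20), 3) for e in (0, 10, 20)])
```

A high early temperature yields soft, exploratory edge samples; lowering it toward t1 makes the sampled masks progressively closer to discrete.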

__clear_masks__()[source]

Clear the edge weights (set them to None) and set the explain flag to False

__set_masks__(x: torch.Tensor, edge_index: torch.Tensor, edge_mask: Optional[torch.Tensor] = None)[source]

Set the edge weights before message passing

Parameters
  • x (torch.Tensor) – Node feature matrix with shape [num_nodes, dim_node_feature]

  • edge_index (torch.Tensor) – Graph connectivity in COO format with shape [2, num_edges]

  • edge_mask (torch.Tensor) – Edge weight matrix before message passing (default: None)

The edge_mask will be randomly initialized when set to None.

Note

Calling __set_masks__() sets the explain flag of every torch_geometric.nn.MessagePassing module in model to True and assigns edge_mask to each of those modules. Call __clear_masks__() to reset them.
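Because the masks stay attached to the model until they are cleared, it can help to pair the two calls so the reset also happens when an exception interrupts the forward pass. A minimal sketch, assuming explainer is a PGExplainer instance whose __set_masks__ and __clear_masks__ behave as documented above; explanation_masks is a hypothetical helper, not part of DIG.

```python
from contextlib import contextmanager


@contextmanager
def explanation_masks(explainer, x, edge_index, edge_mask=None):
    """Apply edge masks for the duration of a with-block, then always clear them.

    Assumes `explainer` exposes __set_masks__(x, edge_index, edge_mask) and
    __clear_masks__() as documented for PGExplainer.
    """
    explainer.__set_masks__(x, edge_index, edge_mask)
    try:
        yield explainer
    finally:
        # Reset even if the code inside the with-block raises.
        explainer.__clear_masks__()
```

Usage would look like `with explanation_masks(explainer, x, edge_index, mask): out = model(x, edge_index)`, after which the model runs unmasked again.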

concrete_sample(log_alpha: torch.Tensor, beta: float = 1.0, training: bool = True)[source]

Sample from an instantiation of the concrete distribution during training
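The binary concrete (relaxed Bernoulli) distribution turns logits log_alpha into differentiable gates in (0, 1): add logistic noise log(u) - log(1 - u) with u drawn uniformly, divide by the temperature beta, and apply a sigmoid. The pure-Python sketch below shows that standard relaxation, assuming that outside training the gates are simply sigmoid(log_alpha); it mirrors the parameter names of concrete_sample but is not DIG's tensor implementation, and the rng argument is a hypothetical addition for reproducibility.

```python
import math
import random
from typing import List, Optional


def _sigmoid(z: float) -> float:
    # Numerically stable logistic function (no overflow for large |z|).
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def concrete_sample(
    log_alpha: List[float],
    beta: float = 1.0,
    training: bool = True,
    rng: Optional[random.Random] = None,
) -> List[float]:
    """Relaxed Bernoulli (binary concrete) gates in (0, 1), one per logit."""
    rng = rng or random.Random()
    gates = []
    for la in log_alpha:
        if training:
            # Logistic noise log(u) - log(1 - u), with u ~ Uniform(0, 1) clipped away from 0 and 1.
            u = min(max(rng.random(), 1e-10), 1.0 - 1e-10)
            noise = math.log(u) - math.log(1.0 - u)
            z = (noise + la) / beta
        else:
            # Deterministic outside training: just squash the logits.
            z = la
        gates.append(_sigmoid(z))
    return gates


print(concrete_sample([0.0, 2.0], training=False))
```

Smaller beta values push the sampled gates closer to 0 or 1, which is why the temperature is annealed during training.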

explain(x: torch.Tensor, edge_index: torch.Tensor, embed: torch.Tensor, tmp: float = 1.0, training: bool = False) → Tuple[float, torch.Tensor][source]

Explain the GNN behavior for the graph with the explanation network

Parameters
  • x (torch.Tensor) – Node feature matrix with shape [num_nodes, dim_node_feature]

  • edge_index (torch.Tensor) – Graph connectivity in COO format with shape [2, num_edges]

  • embed (torch.Tensor) – Node embedding matrix with shape [num_nodes, dim_embedding]

  • tmp (float) – The temperature parameter fed to the sampling procedure

  • training (bool) – Whether in training procedure or not

Returns

  • probs (torch.Tensor) – The classification probability for the graph with the edge mask applied

  • edge_mask (torch.Tensor) – The probability mask for the graph edges

forward(x: torch.Tensor, edge_index: torch.Tensor, **kwargs) → Tuple[None, List, List[Dict]][source]

Explain the GNN behavior for the graph and calculate the metric values. This is the interface used by dig.evaluation.XCollector.

Parameters
  • x (torch.Tensor) – Node feature matrix with shape [num_nodes, dim_node_feature]

  • edge_index (torch.Tensor) – Graph connectivity in COO format with shape [2, num_edges]

  • kwargs (Dict) –

    The additional parameters
    • top_k (int): The number of edges in the final explanation results

    • y (torch.Tensor): The ground-truth labels

Return type

(None, List[torch.Tensor], List[Dict])

get_subgraph(node_idx: int, x: torch.Tensor, edge_index: torch.Tensor, y: Optional[torch.Tensor] = None, **kwargs) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List, Dict][source]

Extract the subgraph around the target node

Parameters
  • node_idx (int) – The node index

  • x (torch.Tensor) – Node feature matrix with shape [num_nodes, dim_node_feature]

  • edge_index (torch.Tensor) – Graph connectivity in COO format with shape [2, num_edges]

  • y (torch.Tensor, None) – Node label matrix with shape [num_nodes] (default: None)

  • kwargs (Dict, None) – Additional parameters

Return type

(torch.Tensor, torch.Tensor, torch.Tensor, List, Dict)
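Extracting a node's neighborhood amounts to collecting every node that can reach it within num_hops message-passing steps, then keeping the edges whose endpoints both lie in that set. The pure-Python sketch below does this with a breadth-first search over edge_index, assuming source-to-target message flow; k_hop_subgraph here is a hypothetical stand-in (PyTorch Geometric provides torch_geometric.utils.k_hop_subgraph for tensors), and whether get_subgraph uses that utility is not stated in this document.

```python
from collections import deque
from typing import List, Sequence, Tuple


def k_hop_subgraph(
    node_idx: int, num_hops: int, edge_index: Sequence[Sequence[int]]
) -> Tuple[List[int], List[int]]:
    """Return (nodes within num_hops of node_idx, ids of edges among them).

    edge_index is in COO format: edge e goes from edge_index[0][e] to edge_index[1][e].
    Messages flow source -> target, so edges are walked backwards from the target node.
    """
    src, dst = edge_index
    incoming = {}
    for s, d in zip(src, dst):
        incoming.setdefault(d, []).append(s)

    dist = {node_idx: 0}
    queue = deque([node_idx])
    while queue:
        v = queue.popleft()
        if dist[v] == num_hops:
            continue  # do not expand beyond the hop limit
        for u in incoming.get(v, []):
            if u not in dist:
                dist[u] = dist[v] + 1
                queue.append(u)

    nodes = sorted(dist)
    keep = set(nodes)
    edges = [e for e, (s, d) in enumerate(zip(src, dst)) if s in keep and d in keep]
    return nodes, edges


# A chain 0 -> 1 -> 2 -> 3 -> 4; the 2-hop neighborhood of node 4.
nodes, edges = k_hop_subgraph(4, 2, [[0, 1, 2, 3], [1, 2, 3, 4]])
print(nodes, edges)
```

Restricting the explanation to this neighborhood is sound because a num_hops-layer GNN cannot receive information from nodes farther away.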

train_explanation_network(dataset)[source]

Train the explanation network by gradient descent (GD) using the Adam optimizer

class SubgraphX(model, num_classes: int, device, num_hops: Optional[int] = None, explain_graph: bool = True, rollout: int = 10, min_atoms: int = 3, c_puct: float = 10.0, expand_atoms=14, high2low=False, local_radius=4, sample_num=100, reward_method='mc_l_shapley', subgraph_building_method='zero_filling', save_dir: Optional[str] = None, filename: str = 'example', vis: bool = True)[source]

An implementation of the paper On Explainability of Graph Neural Networks via Subgraph Explorations.

Parameters
  • model (torch.nn.Module) – The target model to be explained

  • num_classes (int) – Number of classes in the dataset

  • num_hops (int, None) – The number of hops used to extract the neighborhood of the target node (default: None)

  • explain_graph (bool) – Whether to explain a graph classification model (default: True)

  • rollout (int) – Number of iterations to get the prediction (default: 10)

  • min_atoms (int) – Number of atoms of the leaf node in the search tree (default: 3)

  • c_puct (float) – The hyperparameter that encourages exploration (default: 10.0)

  • expand_atoms (int) – The number of atoms to expand when extending the child nodes in the search tree (default: 14)

  • high2low (bool) – Whether to expand child nodes from high degree to low degree when extending the child nodes in the search tree (default: False)

  • local_radius (int) – The local radius used to calculate l_shapley and mc_l_shapley (default: 4)

  • sample_num (int) – Number of Monte Carlo samples used to approximate mc_shapley and mc_l_shapley (default: 100)

  • reward_method (str) – The command string to select the reward method (default: mc_l_shapley)

  • subgraph_building_method (str) – The command string for different subgraph building methods, such as zero_filling and split (default: zero_filling)

  • save_dir (str, None) – Root directory to save the explanation results (default: None)

  • filename (str) – The filename of the results (default: 'example')

  • vis (bool) – Whether to show the visualization (default: True)
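SubgraphX searches over subgraphs with Monte Carlo tree search, and c_puct sets how strongly the search favors rarely visited children over children with a high average reward. The sketch below shows the standard PUCT selection rule, score = W/N + c_puct * P * sqrt(sum of sibling visits) / (1 + N), under the assumption that P is the child's immediate reward (such as a Shapley score); Child and puct_select are hypothetical names, not DIG's tree classes.

```python
import math
from dataclasses import dataclass
from typing import List


@dataclass
class Child:
    total_reward: float  # W: sum of rewards collected through this child
    visits: int  # N: number of times this child was selected
    prior: float  # P: immediate reward of the child, used as a prior


def puct_select(children: List[Child], c_puct: float = 10.0) -> int:
    """Index of the child maximizing Q + U (the PUCT rule used in MCTS)."""
    sum_visits = sum(c.visits for c in children)

    def score(c: Child) -> float:
        q = c.total_reward / c.visits if c.visits > 0 else 0.0  # exploitation
        u = c_puct * c.prior * math.sqrt(sum_visits) / (1 + c.visits)  # exploration
        return q + u

    return max(range(len(children)), key=lambda i: score(children[i]))


kids = [Child(total_reward=2.0, visits=4, prior=0.1), Child(total_reward=0.0, visits=0, prior=0.3)]
print(puct_select(kids))  # unvisited child wins under the default c_puct
```

With c_puct set to 0 the same call would pick the child with the best average reward, which is why a larger c_puct encourages exploration.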

Example

>>> # For graph classification task
>>> subgraphx = SubgraphX(model=model, num_classes=2, device=device)
>>> _, explanation_results, related_preds = subgraphx(x, edge_index)

__call__(x: torch.Tensor, edge_index: torch.Tensor, **kwargs) → Tuple[None, List, List[Dict]][source]

Explain the GNN behavior for the graph using the SubgraphX method

Parameters
  • x (torch.Tensor) – Node feature matrix with shape [num_nodes, dim_node_feature]

  • edge_index (torch.Tensor) – Graph connectivity in COO format with shape [2, num_edges]

  • kwargs (Dict) –

    The additional parameters
    • node_idx (int, None): The target node index when explaining a node classification task

    • max_nodes (int, None): The number of nodes in the final explanation results

Return type

(None, List[torch.Tensor], List[Dict])
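SubgraphX scores candidate subgraphs with Shapley values, and the mc_shapley / mc_l_shapley reward methods approximate them by sampling. The sketch below treats the candidate subgraph as a single player, places it at a random position among the other players (for mc_l_shapley, the nodes within local_radius), and averages its marginal gain over sample_num draws. value_fn stands in for the model's target-class score on a subgraph built with, for example, zero filling; mc_shapley and value_fn are hypothetical names, not DIG's internal functions.

```python
import random
from typing import Callable, FrozenSet, Iterable, List, Optional

ValueFn = Callable[[FrozenSet[int]], float]


def mc_shapley(
    subgraph: Iterable[int],
    neighbors: Iterable[int],
    value_fn: ValueFn,
    sample_num: int = 100,
    rng: Optional[random.Random] = None,
) -> float:
    """Monte Carlo estimate of the Shapley value of `subgraph` treated as one player."""
    rng = rng or random.Random()
    target = frozenset(subgraph)
    others = sorted(set(neighbors) - target)
    total = 0.0
    for _ in range(sample_num):
        # Random permutation of the other players plus the subgraph (None marks the subgraph).
        order: List[Optional[int]] = [*others, None]
        rng.shuffle(order)
        # The coalition is every player that precedes the subgraph in the permutation.
        coalition = frozenset(order[: order.index(None)])
        total += value_fn(coalition | target) - value_fn(coalition)
    return total / sample_num


# Toy additive value function: each node contributes a fixed weight.
weights = {0: 1.0, 1: 2.0, 2: 0.5, 3: 0.25}
estimate = mc_shapley([0, 1], [1, 2, 3], lambda s: sum(weights[n] for n in s), sample_num=50, rng=random.Random(1))
print(estimate)
```

For an additive value function every marginal gain equals the subgraph's own weight, so the estimate is exact; with a real GNN the samples disagree and more samples reduce the variance.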