dig.xgraph.method
Method interfaces under dig.xgraph.method.
DeepLIFT | An implementation of DeepLIFT on graphs, from Learning Important Features Through Propagating Activation Differences.
FlowX | An implementation of FlowX, from FlowX: Towards Explainable Graph Neural Networks via Message Flows.
GNNExplainer | The GNN-Explainer model from the "GNNExplainer: Generating Explanations for Graph Neural Networks" paper for identifying compact subgraph structures and small subsets of node features that play a crucial role in a GNN’s node predictions.
GNN_GI | An implementation of GNN-GI, from Higher-Order Explanations of Graph Neural Networks via Relevant Walks.
GNN_LRP | An implementation of GNN-LRP, from Higher-Order Explanations of Graph Neural Networks via Relevant Walks.
GradCAM | An implementation of GradCAM on graphs, from Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization.
PGExplainer | An implementation of PGExplainer, from Parameterized Explainer for Graph Neural Network.
SubgraphX | The implementation of the paper On Explainability of Graph Neural Networks via Subgraph Explorations.
- class DeepLIFT(model: Module, explain_graph: bool = False)
An implementation of DeepLIFT on graphs, from Learning Important Features Through Propagating Activation Differences.
- Parameters
model (torch.nn.Module) – The target model to explain.
explain_graph (bool, optional) – Whether to explain a graph classification model. (default: False)
Note
For a node classification model, the explain_graph flag is False. For an example, see benchmarks/xgraph.
- forward(x: Tensor, edge_index: Tensor, **kwargs)
Run the explainer for a specific graph instance.
- Parameters
x (torch.Tensor) – The graph instance’s input node features.
edge_index (torch.Tensor) – The graph instance’s edge index.
**kwargs (dict) –
node_idx (int): The index of the node to be explained (for node classification).
sparsity (float): The sparsity used to convert a soft mask into a hard mask. (default: 0.7)
- Returns
(None, edge_masks, related_predictions): edge_masks is a list of edge-level explanations, one per class; related_predictions is a list of dictionaries, one per class, where each dictionary includes four types of predicted probabilities.
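The sparsity keyword controls how a soft edge mask becomes a hard one. The following is a minimal sketch of that idea, assuming sparsity is the fraction of edges to discard and that the highest-scoring edges are the ones kept. The helper name hard_mask_from_soft and the plain-list input are illustrative and are not part of DIG.

```python
def hard_mask_from_soft(soft_mask, sparsity=0.7):
    """Keep the highest-scoring (1 - sparsity) fraction of edges.

    Assumption: `sparsity` is the fraction of edges removed.
    """
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError("sparsity must lie in [0, 1]")
    num_edges = len(soft_mask)
    # round() avoids losing one edge to floating-point error in (1 - sparsity).
    num_keep = round((1.0 - sparsity) * num_edges)
    # Rank edge positions by score, highest first.
    ranked = sorted(range(num_edges), key=lambda i: soft_mask[i], reverse=True)
    keep = set(ranked[:num_keep])
    return [1 if i in keep else 0 for i in range(num_edges)]


soft = [0.9, 0.1, 0.5, 0.3, 0.8, 0.2, 0.7, 0.4, 0.6, 0.05]
hard = hard_mask_from_soft(soft, sparsity=0.7)
```

With ten edges and sparsity 0.7, three edges survive in this sketch.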
- class FlowX(model, epochs=500, lr=0.3, explain_graph=False, molecule=False)
An implementation of FlowX, from FlowX: Towards Explainable Graph Neural Networks via Message Flows.
- Parameters
model (torch.nn.Module) – The target model to explain.
epochs (int, optional) – The number of training steps.
lr (float, optional) – The learning rate for training the explainer.
explain_graph (bool, optional) – Whether to explain a graph classification model. (default: False)
Note
For a node classification model, the explain_graph flag is False.
- flow_shap(x, edge_index, edge_index_with_loop, walk_indices_list, **kwargs)
Flow Shapley calculations.
- forward(x: Tensor, edge_index: Tensor, **kwargs) -> Union[Tuple[None, List, List[Dict]], Tuple[Dict, List, List[Dict]]]
Run the explainer for a specific graph instance.
- Parameters
x (torch.Tensor) – The graph instance’s input node features.
edge_index (torch.Tensor) – The graph instance’s edge index.
**kwargs (dict) –
node_idx (int, list, tuple, torch.Tensor): The index of the node to be explained (for node classification).
sparsity (float): The sparsity used to convert a soft mask into a hard mask. (default: 0.7)
num_classes (int): The number of classes in the task.
- Returns
(None, masks, related_predictions): masks is a list of edge-level explanations, one per class; related_predictions is a list of dictionaries, one per class, where each dictionary includes four types of predicted probabilities.
- class GNNExplainer(model: Module, epochs: int = 100, lr: float = 0.01, coff_edge_size: float = 0.001, coff_edge_ent: float = 0.001, coff_node_feat_size: float = 1.0, coff_node_feat_ent: float = 0.1, explain_graph: bool = False, indirect_graph_symmetric_weights: bool = False)
The GNN-Explainer model from the “GNNExplainer: Generating Explanations for Graph Neural Networks” paper for identifying compact subgraph structures and small subsets of node features that play a crucial role in a GNN’s node predictions.
Note
For an example, see benchmarks/xgraph.
- Parameters
model (torch.nn.Module) – The GNN module to explain.
epochs (int, optional) – The number of epochs to train. (default: 100)
lr (float, optional) – The learning rate to apply. (default: 0.01)
explain_graph (bool, optional) – Whether to explain a graph classification model. (default: False)
indirect_graph_symmetric_weights (bool, optional) – If True, the explainer first checks whether the input graph has indirect edges and then makes its edge weights symmetric. (default: False)
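The coff_edge_size and coff_edge_ent arguments in the signature above are not described in the parameter list. As a rough illustration, the sketch below follows the regularizers of the original GNNExplainer formulation: a size penalty on the sigmoid of the mask logits plus a mean element-wise entropy penalty. The function name mask_regularizer and the choice of sum versus mean reduction are assumptions, not a statement about DIG's internals.

```python
import math

EPS = 1e-15  # keeps log() finite when a probability saturates


def mask_regularizer(mask_logits, coff_edge_size=0.001, coff_edge_ent=0.001):
    """Size + entropy penalty on a learnable edge mask (illustrative only)."""
    probs = [1.0 / (1.0 + math.exp(-z)) for z in mask_logits]
    # Size term: pushes the mask towards selecting few edges.
    size_term = sum(probs)
    # Entropy term: pushes each mask entry towards 0 or 1.
    entropies = [
        -p * math.log(p + EPS) - (1.0 - p) * math.log(1.0 - p + EPS)
        for p in probs
    ]
    ent_term = sum(entropies) / len(entropies)
    return coff_edge_size * size_term + coff_edge_ent * ent_term


penalty = mask_regularizer([0.0, 0.0])
```

Larger coefficients favor smaller and more binary masks at the cost of fidelity to the model's prediction.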
- forward(x, edge_index, mask_features=False, target_label=None, **kwargs)
Run the explainer for a specific graph instance.
- Parameters
x (torch.Tensor) – The graph instance’s input node features.
edge_index (torch.Tensor) – The graph instance’s edge index.
mask_features (bool, optional) – Whether to use a feature mask. Not recommended. (default: False)
target_label (torch.Tensor, optional) – If given, apply the optimization only to this label.
**kwargs (dict) –
node_idx (int, list, tuple, torch.Tensor): The index of the node to be explained (for node classification).
sparsity (float): The sparsity used to convert a soft mask into a hard mask. (default: 0.7)
num_classes (int): The number of classes in the task.
- Returns
(None, edge_masks, related_predictions): edge_masks is a list of edge-level explanations, one per class; related_predictions is a list of dictionaries, one per class, where each dictionary includes four types of predicted probabilities.
- class GNN_GI(model: Module, explain_graph: bool = False)
An implementation of GNN-GI, from Higher-Order Explanations of Graph Neural Networks via Relevant Walks.
- Parameters
model (torch.nn.Module) – The target model to explain.
explain_graph (bool, optional) – Whether to explain a graph classification model. (default: False)
Note
For a node classification model, the explain_graph flag is False.
- forward(x: Tensor, edge_index: Tensor, **kwargs)
Run the explainer for a specific graph instance.
- Parameters
x (torch.Tensor) – The graph instance’s input node features.
edge_index (torch.Tensor) – The graph instance’s edge index.
**kwargs (dict) –
node_idx (int): The index of the node to be explained (for node classification).
sparsity (float): The sparsity used to convert a soft mask into a hard mask. (default: 0.7)
num_classes (int): The number of classes in the task.
- Returns
(walks, edge_masks, related_predictions): walks is a dictionary containing the walks’ edge indices and the corresponding explanation scores; edge_masks is a list of edge-level explanations, one per class; related_predictions is a list of dictionaries, one per class, where each dictionary includes four types of predicted probabilities.
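To make "walks’ edge indices" concrete: in a GNN with L message-passing layers, a walk can be described as L consecutive edges, where each edge starts at the node the previous one ended on. The sketch below enumerates such walks for a small COO edge list. The function name and the tuple-of-edge-ids representation are assumptions for illustration, and the exact layout of the walks dictionary returned by DIG is not reproduced here.

```python
def enumerate_walks(edge_index, num_layers):
    """List every walk of `num_layers` consecutive edges as a tuple of edge ids."""
    sources, targets = edge_index
    num_edges = len(sources)
    # Walks of length one are the individual edges.
    walks = [(e,) for e in range(num_edges)]
    for _ in range(num_layers - 1):
        # Extend each walk by every edge that starts where the walk ends.
        walks = [
            walk + (e,)
            for walk in walks
            for e in range(num_edges)
            if sources[e] == targets[walk[-1]]
        ]
    return walks


# Directed 3-cycle: edge 0 is 0->1, edge 1 is 1->2, edge 2 is 2->0.
cycle = [[0, 1, 2], [1, 2, 0]]
two_step_walks = enumerate_walks(cycle, num_layers=2)
```

The number of walks grows quickly with depth and node degree, which is why walk-based explainers tend to be expensive on dense graphs.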
- class GNN_LRP(model: Module, explain_graph=False)
An implementation of GNN-LRP, from Higher-Order Explanations of Graph Neural Networks via Relevant Walks.
- Parameters
model (torch.nn.Module) – The target model to explain.
explain_graph (bool, optional) – Whether to explain a graph classification model. (default: False)
Note
For a node classification model, the explain_graph flag is False. GNN-LRP is very model dependent. Please be sure you know how to modify it for different models. For an example, see benchmarks/xgraph.
- forward(x: Tensor, edge_index: Tensor, **kwargs)
Run the explainer for a specific graph instance.
- Parameters
x (torch.Tensor) – The graph instance’s input node features.
edge_index (torch.Tensor) – The graph instance’s edge index.
**kwargs (dict) –
node_idx (int): The index of the node to be explained (for node classification).
sparsity (float): The sparsity used to convert a soft mask into a hard mask. (default: 0.7)
num_classes (int): The number of classes in the task.
- Returns
(walks, edge_masks, related_predictions): walks is a dictionary containing the walks’ edge indices and the corresponding explanation scores; edge_masks is a list of edge-level explanations, one per class; related_predictions is a list of dictionaries, one per class, where each dictionary includes four types of predicted probabilities.
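Walk-level scores and edge-level masks are related but distinct outputs. One simple way to go from the former to the latter is to credit every edge with the scores of the walks that traverse it. The sketch below does exactly that and nothing more. It is an assumption for illustration only, and the aggregation DIG actually uses may differ. The function and argument names are hypothetical.

```python
def edge_scores_from_walks(walk_ids, walk_scores, num_edges):
    """Sum each walk's score onto every distinct edge the walk uses."""
    scores = [0.0] * num_edges
    for walk, score in zip(walk_ids, walk_scores):
        # set() so an edge visited twice by one walk is credited once.
        for edge in set(walk):
            scores[edge] += score
    return scores


walk_ids = [(0, 1), (1, 2), (2, 0)]
walk_scores = [0.5, 0.25, -0.25]
edge_scores = edge_scores_from_walks(walk_ids, walk_scores, num_edges=3)
```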
- class GradCAM(model: Module, explain_graph: bool = False)
An implementation of GradCAM on graphs, from Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization.
- Parameters
model (torch.nn.Module) – The target model to explain.
explain_graph (bool, optional) – Whether to explain a graph classification model. (default: False)
Note
For a node classification model, the explain_graph flag is False. For an example, see benchmarks/xgraph.
- forward(x: Tensor, edge_index: Tensor, **kwargs) -> Union[Tuple[None, List, List[Dict]], Tuple[List, List, List[Dict]]]
Run the explainer for a specific graph instance.
- Parameters
x (torch.Tensor) – The graph instance’s input node features.
edge_index (torch.Tensor) – The graph instance’s edge index.
**kwargs (dict) –
node_idx (int): The index of the node to be explained (for node classification).
sparsity (float): The sparsity used to convert a soft mask into a hard mask. (default: 0.7)
num_classes (int): The number of classes in the task.
- Returns
(None, edge_masks, related_predictions): edge_masks is a list of edge-level explanations, one per class; related_predictions is a list of dictionaries, one per class, where each dictionary includes four types of predicted probabilities.
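Grad-CAM weights each feature channel by the averaged gradient of the class score and applies a ReLU to the weighted sum of activations. The sketch below applies that formula at node level, assuming the last-layer node activations and their gradients are already available as nested lists. The function name is illustrative, and the step that turns node scores into the edge_masks returned above is not shown.

```python
def grad_cam_node_scores(activations, gradients):
    """Node-level Grad-CAM: ReLU(sum_k alpha_k * A[n][k]).

    activations[n][k]: feature k of node n at the chosen layer.
    gradients[n][k]:   d(class score) / d(activations[n][k]).
    """
    num_nodes = len(activations)
    num_channels = len(activations[0])
    # alpha_k: gradient of channel k averaged over all nodes.
    alphas = [
        sum(gradients[n][k] for n in range(num_nodes)) / num_nodes
        for k in range(num_channels)
    ]
    # Weighted channel sum per node, clipped at zero.
    return [
        max(0.0, sum(alphas[k] * activations[n][k] for k in range(num_channels)))
        for n in range(num_nodes)
    ]


acts = [[1.0, 2.0], [3.0, 0.0]]
grads = [[1.0, -1.0], [1.0, -3.0]]
node_scores = grad_cam_node_scores(acts, grads)
```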
- class PGExplainer(model, in_channels: int, device, explain_graph: bool = True, epochs: int = 20, lr: float = 0.005, coff_size: float = 0.01, coff_ent: float = 0.0005, t0: float = 5.0, t1: float = 1.0, sample_bias: float = 0.0, num_hops: Optional[int] = None)
An implementation of PGExplainer, from Parameterized Explainer for Graph Neural Network.
- Parameters
model (torch.nn.Module) – The target model to explain
in_channels (int) – Number of input channels for the explanation network
explain_graph (bool) – Whether to explain a graph classification model (default: True)
epochs (int) – Number of epochs to train the explanation network
lr (float) – Learning rate for training the explanation network
coff_size (float) – Size regularization to constrain the explanation size
coff_ent (float) – Entropy regularization to constrain the connectivity of the explanation
t0 (float) – The temperature at the first epoch
t1 (float) – The temperature at the final epoch
num_hops (int, None) – The number of hops used to extract the neighborhood of the target node (default: None)
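The t0 and t1 parameters fix the temperature at the first and final epoch, which implies some schedule in between. A minimal sketch, assuming a geometric interpolation between the two values as in the PGExplainer paper. The function name temperature is illustrative, and the exact schedule in DIG should be checked against its source.

```python
def temperature(epoch, epochs=20, t0=5.0, t1=1.0):
    """Geometric interpolation from t0 (epoch 0) to t1 (epoch == epochs)."""
    return t0 * (t1 / t0) ** (epoch / epochs)


schedule = [temperature(e) for e in range(21)]
```

A high starting temperature keeps the sampled edge mask smooth early in training, and cooling towards t1 makes the samples closer to binary.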
- __set_masks__(x: Tensor, edge_index: Tensor, edge_mask: Optional[Tensor] = None)
Set the edge weights before message passing.
- Parameters
x (torch.Tensor) – Node feature matrix with shape [num_nodes, dim_node_feature]
edge_index (torch.Tensor) – Graph connectivity in COO format with shape [2, num_edges]
edge_mask (torch.Tensor) – Edge weight matrix before message passing (default: None)
The edge_mask will be randomly initialized when set to None.
Note
When you use __set_masks__(), the explain flag of every torch_geometric.nn.MessagePassing module in model is set to True. In addition, the edge_mask is assigned to all of these modules. Call __clear_masks__() to reset them.
- concrete_sample(log_alpha: Tensor, beta: float = 1.0, training: bool = True)
Sample from the instantiation of the concrete distribution when training.
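For readers unfamiliar with the concrete (Gumbel-sigmoid) relaxation, the sketch below shows the usual form for binary gates: add logistic noise to the logits, divide by the temperature beta, and squash with a sigmoid. When not training, it falls back to a plain sigmoid. This is a pure-Python illustration with the same argument names as the method above. The rng argument and the noise clamping are assumptions, and the handling of sample_bias is omitted.

```python
import math
import random


def _sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)


def concrete_sample(log_alpha, beta=1.0, training=True, rng=random):
    """Relaxed Bernoulli sample per logit; deterministic when not training."""
    gates = []
    for logit in log_alpha:
        if training:
            # Uniform noise, clamped away from 0 and 1 so log() stays finite.
            u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
            logit = (math.log(u) - math.log(1.0 - u) + logit) / beta
        gates.append(_sigmoid(logit))
    return gates


eval_gates = concrete_sample([0.0, 2.0], training=False)
train_gates = concrete_sample([0.0, 2.0], beta=5.0, rng=random.Random(0))
```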
- explain(x: Tensor, edge_index: Tensor, embed: Tensor, tmp: float = 1.0, training: bool = False, **kwargs) -> Tuple[float, Tensor]
Explain the GNN behavior for a graph with the explanation network.
- Parameters
x (torch.Tensor) – Node feature matrix with shape [num_nodes, dim_node_feature]
edge_index (torch.Tensor) – Graph connectivity in COO format with shape [2, num_edges]
embed (torch.Tensor) – Node embedding matrix with shape [num_nodes, dim_embedding]
tmp (float) – The temperature parameter fed to the sampling procedure
training (bool) – Whether the call is part of the training procedure
- Returns
probs (torch.Tensor): The classification probability for the graph with the edge mask applied
edge_mask (torch.Tensor): The probability mask for graph edges
- forward(x: Tensor, edge_index: Tensor, **kwargs) -> Tuple[None, List, List[Dict]]
Explain the GNN behavior for a graph and calculate the metric values. This is the interface for dig.evaluation.XCollector.
- Parameters
x (torch.Tensor) – Node feature matrix with shape [num_nodes, dim_node_feature]
edge_index (torch.Tensor) – Graph connectivity in COO format with shape [2, num_edges]
kwargs (Dict) – The additional parameters
top_k (int): The number of edges in the final explanation results
- Return type
(None, List[torch.Tensor], List[Dict])
- get_subgraph(node_idx: int, x: Tensor, edge_index: Tensor, y: Optional[Tensor] = None, **kwargs) -> Tuple[Tensor, Tensor, Tensor, List, Dict]
Extract the subgraph of the target node.
- Parameters
node_idx (int) – The node index
x (torch.Tensor) – Node feature matrix with shape [num_nodes, dim_node_feature]
edge_index (torch.Tensor) – Graph connectivity in COO format with shape [2, num_edges]
y (torch.Tensor, None) – Node label matrix with shape [num_nodes] (default: None)
kwargs (Dict, None) – Additional parameters
- Return type
(torch.Tensor, torch.Tensor, torch.Tensor, List, Dict)
- class SubgraphX(model, num_classes: int, device, num_hops: Optional[int] = None, verbose: bool = False, explain_graph: bool = True, rollout: int = 20, min_atoms: int = 5, c_puct: float = 10.0, expand_atoms=14, high2low=False, local_radius=4, sample_num=100, reward_method='mc_l_shapley', subgraph_building_method='zero_filling', save_dir: Optional[str] = None, filename: str = 'example', vis: bool = True)
The implementation of the paper On Explainability of Graph Neural Networks via Subgraph Explorations.
- Parameters
model (torch.nn.Module) – The target model to explain
num_classes (int) – Number of classes in the dataset
num_hops (int, None) – The number of hops used to extract the neighborhood of the target node (default: None)
explain_graph (bool) – Whether to explain a graph classification model (default: True)
rollout (int) – Number of iterations used to get the prediction
min_atoms (int) – Number of atoms in a leaf node of the search tree
c_puct (float) – The hyperparameter that encourages exploration
expand_atoms (int) – The number of atoms to expand when extending the child nodes in the search tree
high2low (bool) – Whether to expand child nodes from high degree to low degree when extending the child nodes in the search tree (default: False)
local_radius (int) – The local radius used to calculate l_shapley and mc_l_shapley
sample_num (int) – Number of samples for the Monte Carlo approximation used by mc_shapley and mc_l_shapley
reward_method (str) – The command string to select the reward method (default: mc_l_shapley)
subgraph_building_method (str) – The command string for the subgraph building method, such as zero_filling or split (default: zero_filling)
save_dir (str, None) – Root directory in which to save the explanation results (default: None)
filename (str) – The filename of the results
vis (bool) – Whether to show the visualization (default: True)
Example
>>> # For graph classification task
>>> subgraphx = SubgraphX(model=model, num_classes=2, device=device)
>>> _, explanation_results, related_preds = subgraphx(x, edge_index)
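The c_puct parameter is described only as encouraging exploration. In Monte Carlo tree search this usually means scaling an exploration bonus that is added to each child's average reward before the best child is picked. The sketch below shows one common form of that selection rule. The Child fields and function names are hypothetical, and the exact formula in DIG may differ.

```python
import math
from dataclasses import dataclass


@dataclass
class Child:
    name: str
    total_reward: float  # W: sum of rewards seen through this child
    visits: int          # N: number of times this child was visited
    prior: float         # P: immediate reward estimate for this child


def selection_score(child, total_visits, c_puct):
    """Average reward Q plus an exploration bonus U scaled by c_puct."""
    q = child.total_reward / child.visits if child.visits > 0 else 0.0
    u = c_puct * child.prior * math.sqrt(total_visits) / (1 + child.visits)
    return q + u


def select_child(children, c_puct=10.0):
    total_visits = sum(c.visits for c in children)
    return max(children, key=lambda c: selection_score(c, total_visits, c_puct))


children = [
    Child("visited", total_reward=1.0, visits=2, prior=0.5),
    Child("unvisited", total_reward=0.0, visits=0, prior=0.4),
]
explore_choice = select_child(children, c_puct=10.0).name
exploit_choice = select_child(children, c_puct=0.0).name
```

With a large c_puct the rarely visited child wins in this sketch, while c_puct set to zero reduces the rule to picking the best average reward.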
- __call__(x: Tensor, edge_index: Tensor, **kwargs) -> Tuple[None, List, List[Dict]]
Explain the GNN behavior for the graph using the SubgraphX method.
- Parameters
x (torch.Tensor) – Node feature matrix with shape [num_nodes, dim_node_feature]
edge_index (torch.Tensor) – Graph connectivity in COO format with shape [2, num_edges]
kwargs (Dict) –
- Return type
(None, List[torch.Tensor], List[Dict])
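The reward_method options of SubgraphX refer to Shapley-value approximations, with sample_num setting the number of Monte Carlo samples. The sketch below illustrates the general permutation-sampling idea for scoring a candidate subgraph treated as a single player. The function name mc_shapley, the value_fn callback (standing in for a model prediction on a node subset), and the seeding are assumptions for illustration, and the local-radius restriction of mc_l_shapley is not included.

```python
import random


def mc_shapley(coalition, all_nodes, value_fn, sample_num=100, seed=0):
    """Monte Carlo estimate of the coalition's Shapley value.

    The coalition acts as one player. For each random ordering of the
    players, measure what the coalition adds to those ordered before it.
    """
    rng = random.Random(seed)
    coalition = set(coalition)
    others = [n for n in all_nodes if n not in coalition]
    total = 0.0
    for _ in range(sample_num):
        order = others + [None]  # None marks the coalition's position
        rng.shuffle(order)
        preceding = set(order[: order.index(None)])
        total += value_fn(preceding | coalition) - value_fn(preceding)
    return total / sample_num


# Toy additive "model": the value of a node set is the sum of node weights.
weights = {0: 1.0, 1: 2.0, 2: 3.0, 3: 4.0}
estimate = mc_shapley(
    coalition=[1, 2],
    all_nodes=list(weights),
    value_fn=lambda nodes: sum(weights[n] for n in nodes),
)
```

For an additive value function every sampled marginal contribution equals the coalition's own weight, which makes the toy case easy to check by hand. Real model outputs are not additive, so the estimate varies with sample_num.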