dig.xgraph.method
Method interfaces under dig.xgraph.method.
DeepLIFT | An implementation of DeepLIFT on graphs, from Learning Important Features Through Propagating Activation Differences.
GNNExplainer | The GNN-Explainer model from the “GNNExplainer: Generating Explanations for Graph Neural Networks” paper, for identifying compact subgraph structures and small subsets of node features that play a crucial role in a GNN’s node predictions.
GNN_GI | An implementation of GNN-GI, from Higher-Order Explanations of Graph Neural Networks via Relevant Walks.
GNN_LRP | An implementation of GNN-LRP, from Higher-Order Explanations of Graph Neural Networks via Relevant Walks.
GradCAM | An implementation of GradCAM on graphs, from Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization.
PGExplainer | An implementation of PGExplainer, from Parameterized Explainer for Graph Neural Network.
SubgraphX | The implementation of the paper On Explainability of Graph Neural Networks via Subgraph Explorations.
- class DeepLIFT(model: torch.nn.modules.module.Module, explain_graph: bool = False)
An implementation of DeepLIFT on graphs, from Learning Important Features Through Propagating Activation Differences.
- Parameters
model (torch.nn.Module) – The target model to explain.
explain_graph (bool, optional) – Whether to explain a graph classification model. (default: False)
Note
For a node classification model, the explain_graph flag is False. For an example, see benchmarks/xgraph.
- forward(x: torch.Tensor, edge_index: torch.Tensor, **kwargs)
Run the explainer for a specific graph instance.
- Parameters
x (torch.Tensor) – The graph instance’s input node features.
edge_index (torch.Tensor) – The graph instance’s edge index.
**kwargs (dict) –
node_idx (int): The index of the node to be explained (for node classification).
sparsity (float): The sparsity used to transform a soft mask into a hard mask. (default: 0.7)
- Return type
(None, edge_masks, related_predictions): edge_masks is a list of edge-level explanations, one per class; related_predictions is a list of dictionaries, one per class, where each dictionary holds 4 types of predicted probabilities.
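The sparsity argument above can be read as the fraction of edges to discard when a soft mask is turned into a hard one. The following is a minimal, library-independent sketch under that assumption; the helper name soft_to_hard_mask is hypothetical and this is not the library's own implementation.

```python
import math


def soft_to_hard_mask(soft_mask, sparsity=0.7):
    """Keep the highest-scoring (1 - sparsity) fraction of edges (assumed rule)."""
    num_edges = len(soft_mask)
    # Number of surviving edges; the small epsilon guards against
    # floating-point error such as (1 - 0.7) * 10 == 3.0000000000000004.
    num_keep = math.floor((1.0 - sparsity) * num_edges + 1e-9)
    # Edge positions ordered from most to least important.
    ranked = sorted(range(num_edges), key=lambda i: soft_mask[i], reverse=True)
    keep = set(ranked[:num_keep])
    # 1.0 marks a kept edge, 0.0 a removed edge.
    return [1.0 if i in keep else 0.0 for i in range(num_edges)]


soft = [0.9, 0.1, 0.4, 0.8, 0.3, 0.7, 0.2, 0.6, 0.5, 0.05]
hard = soft_to_hard_mask(soft, sparsity=0.7)
```

With ten edges and a sparsity of 0.7, three edges survive, so a higher sparsity yields a smaller explanation.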
- class GNNExplainer(model: torch.nn.modules.module.Module, epochs: int = 100, lr: float = 0.01, coff_size: float = 0.001, coff_ent: float = 0.001, explain_graph: bool = False)
The GNN-Explainer model from the “GNNExplainer: Generating Explanations for Graph Neural Networks” paper, for identifying compact subgraph structures and small subsets of node features that play a crucial role in a GNN’s node predictions.
Note
For an example, see benchmarks/xgraph.
- Parameters
model (torch.nn.Module) – The GNN module to explain.
epochs (int, optional) – The number of epochs to train. (default: 100)
lr (float, optional) – The learning rate to apply. (default: 0.01)
explain_graph (bool, optional) – Whether to explain a graph classification model. (default: False)
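The constructor also accepts coff_size and coff_ent, which this page does not describe. In mask-learning explainers these usually weight a size penalty and an entropy penalty on the learned edge mask. The sketch below illustrates that idea only; the function name mask_regularizer and the exact form of both terms are assumptions, not the library's code.

```python
import math


def mask_regularizer(edge_mask, coff_size=0.001, coff_ent=0.001, eps=1e-15):
    """Size plus entropy penalty on a soft edge mask with values in [0, 1]."""
    # Size term: penalizes the total mask weight, favoring small explanations.
    size_term = coff_size * sum(edge_mask)
    # Entropy term: penalizes values near 0.5, pushing the mask toward 0 or 1.
    entropies = [
        -m * math.log(m + eps) - (1.0 - m) * math.log(1.0 - m + eps)
        for m in edge_mask
    ]
    ent_term = coff_ent * sum(entropies) / len(entropies)
    return size_term + ent_term


binary_penalty = mask_regularizer([1.0, 0.0, 1.0])   # entropy term is ~0
uncertain_penalty = mask_regularizer([0.5, 0.5, 0.5, 0.5])
```

A binary mask pays only the size term, while an undecided mask pays both, which is why larger coefficients tend to give sparser and crisper masks.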
- forward(x, edge_index, mask_features=False, **kwargs)
Run the explainer for a specific graph instance.
- Parameters
x (torch.Tensor) – The graph instance’s input node features.
edge_index (torch.Tensor) – The graph instance’s edge index.
mask_features (bool, optional) – Whether to use a feature mask. Not recommended. (default: False)
**kwargs (dict) –
node_idx (int): The index of the node to be explained (for node classification).
sparsity (float): The sparsity used to transform a soft mask into a hard mask. (default: 0.7)
num_classes (int): The number of classes in the task.
- Return type
(None, edge_masks, related_predictions): edge_masks is a list of edge-level explanations, one per class; related_predictions is a list of dictionaries, one per class, where each dictionary holds 4 types of predicted probabilities.
- class GNN_GI(model: torch.nn.modules.module.Module, explain_graph: bool = False)
An implementation of GNN-GI, from Higher-Order Explanations of Graph Neural Networks via Relevant Walks.
- Parameters
model (torch.nn.Module) – The target model to explain.
explain_graph (bool, optional) – Whether to explain a graph classification model. (default: False)
Note
For a node classification model, the explain_graph flag is False.
- forward(x: torch.Tensor, edge_index: torch.Tensor, **kwargs)
Run the explainer for a specific graph instance.
- Parameters
x (torch.Tensor) – The graph instance’s input node features.
edge_index (torch.Tensor) – The graph instance’s edge index.
**kwargs (dict) –
node_idx (int): The index of the node to be explained (for node classification).
sparsity (float): The sparsity used to transform a soft mask into a hard mask. (default: 0.7)
num_classes (int): The number of classes in the task.
- Return type
(walks, edge_masks, related_predictions): walks is a dictionary holding the walks’ edge indices and their explained scores; edge_masks is a list of edge-level explanations, one per class; related_predictions is a list of dictionaries, one per class, where each dictionary holds 4 types of predicted probabilities.
- class GNN_LRP(model: torch.nn.modules.module.Module, explain_graph=False)
An implementation of GNN-LRP, from Higher-Order Explanations of Graph Neural Networks via Relevant Walks.
- Parameters
model (torch.nn.Module) – The target model to explain.
explain_graph (bool, optional) – Whether to explain a graph classification model. (default: False)
Note
For a node classification model, the explain_graph flag is False. GNN-LRP is highly model dependent, so make sure you know how to adapt it to different models. For an example, see benchmarks/xgraph.
- forward(x: torch.Tensor, edge_index: torch.Tensor, **kwargs)
Run the explainer for a specific graph instance.
- Parameters
x (torch.Tensor) – The graph instance’s input node features.
edge_index (torch.Tensor) – The graph instance’s edge index.
**kwargs (dict) –
node_idx (int): The index of the node to be explained (for node classification).
sparsity (float): The sparsity used to transform a soft mask into a hard mask. (default: 0.7)
num_classes (int): The number of classes in the task.
- Return type
(walks, edge_masks, related_predictions): walks is a dictionary holding the walks’ edge indices and their explained scores; edge_masks is a list of edge-level explanations, one per class; related_predictions is a list of dictionaries, one per class, where each dictionary holds 4 types of predicted probabilities.
- class GradCAM(model: torch.nn.modules.module.Module, explain_graph: bool = False)
An implementation of GradCAM on graphs, from Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization.
- Parameters
model (torch.nn.Module) – The target model to explain.
explain_graph (bool, optional) – Whether to explain a graph classification model. (default: False)
Note
For a node classification model, the explain_graph flag is False. For an example, see benchmarks/xgraph.
- forward(x: torch.Tensor, edge_index: torch.Tensor, **kwargs) -> Union[Tuple[None, List, List[Dict]], Tuple[List, List, List[Dict]]]
Run the explainer for a specific graph instance.
- Parameters
x (torch.Tensor) – The graph instance’s input node features.
edge_index (torch.Tensor) – The graph instance’s edge index.
**kwargs (dict) –
node_idx (int): The index of the node to be explained (for node classification).
sparsity (float): The sparsity used to transform a soft mask into a hard mask. (default: 0.7)
num_classes (int): The number of classes in the task.
- Return type
(None, edge_masks, related_predictions): edge_masks is a list of edge-level explanations, one per class; related_predictions is a list of dictionaries, one per class, where each dictionary holds 4 types of predicted probabilities.
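The per-class dictionaries in related_predictions are typically consumed by comparing the model's probability on the original graph with its probability after edges are kept or removed. This page does not name the four dictionary entries, so the keys used below (origin, masked, maskout, zero) are illustrative assumptions, as is the helper fidelity_scores.

```python
def fidelity_scores(related_pred):
    """Compare predictions before and after masking (hypothetical key names)."""
    origin = related_pred["origin"]      # probability on the unmodified graph
    return {
        # Drop in probability when the explanation edges are removed.
        "fidelity": origin - related_pred["maskout"],
        # Drop in probability when only the explanation edges are kept.
        "fidelity_inv": origin - related_pred["masked"],
    }


example_pred = {"origin": 0.9, "masked": 0.8, "maskout": 0.3, "zero": 0.5}
scores = fidelity_scores(example_pred)
```

Under these assumptions, a good explanation shows a large drop when its edges are removed and a small drop when only its edges are kept.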
- class PGExplainer(model, in_channels: int, device, explain_graph: bool = True, epochs: int = 20, lr: float = 0.005, coff_size: float = 0.01, coff_ent: float = 0.0005, t0: float = 5.0, t1: float = 1.0, sample_bias: float = 0.0, num_hops: Optional[int] = None)
An implementation of PGExplainer, from Parameterized Explainer for Graph Neural Network.
- Parameters
model (torch.nn.Module) – The target model to explain
in_channels (int) – Number of input channels for the explanation network
explain_graph (bool) – Whether to explain a graph classification model (default: True)
epochs (int) – Number of epochs to train the explanation network
lr (float) – Learning rate for training the explanation network
coff_size (float) – Size regularization to constrain the explanation size
coff_ent (float) – Entropy regularization to constrain the connectivity of the explanation
t0 (float) – The temperature at the first epoch
t1 (float) – The temperature at the final epoch
num_hops (int, None) – The number of hops used to extract the neighborhood of the target node (default: None)
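The t0 and t1 parameters describe a temperature that moves from its first-epoch value to its final-epoch value over training. One common way to do this is geometric interpolation, sketched below. The function name temperature and the choice of a geometric schedule are assumptions; this page does not state which schedule the class uses.

```python
def temperature(epoch, epochs=20, t0=5.0, t1=1.0):
    """Geometric interpolation from t0 at epoch 0 to t1 at epoch == epochs."""
    return t0 * (t1 / t0) ** (epoch / epochs)


# Temperatures for epochs 0..20 with the defaults from the signature above.
schedule = [temperature(epoch) for epoch in range(21)]
```

A high early temperature keeps the sampled edge mask smooth, and the lower final temperature makes samples closer to binary decisions.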
- __set_masks__(x: torch.Tensor, edge_index: torch.Tensor, edge_mask: Optional[torch.Tensor] = None)
Set the edge weights before message passing.
- Parameters
x (torch.Tensor) – Node feature matrix with shape [num_nodes, dim_node_feature]
edge_index (torch.Tensor) – Graph connectivity in COO format with shape [2, num_edges]
edge_mask (torch.Tensor) – Edge weight matrix before message passing (default: None)
The edge_mask is randomly initialized when set to None.
Note
When you call __set_masks__(), the explain flag of every torch_geometric.nn.MessagePassing module in model is set to True, and the edge_mask is assigned to all of those modules. Call __clear_masks__() to reset.
- concrete_sample(log_alpha: torch.Tensor, beta: float = 1.0, training: bool = True)
Sample from the instantiation of the concrete distribution during training.
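For readers unfamiliar with the concrete (relaxed Bernoulli) distribution, the sketch below shows one standard way to draw such a sample in plain Python: logistic noise is added to the logit and the result is squashed by a temperature-scaled sigmoid. The bias argument, the clamping constants and the deterministic branch for training=False are assumptions for illustration and may differ from the method documented here.

```python
import math
import random


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def concrete_sample(log_alpha, beta=1.0, training=True, bias=0.0, rng=random):
    """Relaxed Bernoulli sample per logit; beta is the temperature."""
    gates = []
    for logit in log_alpha:
        if training:
            # Uniform noise in [bias, 1 - bias], clamped away from 0 and 1.
            u = bias + (1.0 - 2.0 * bias) * rng.random()
            u = min(max(u, 1e-6), 1.0 - 1e-6)
            logistic_noise = math.log(u) - math.log(1.0 - u)
            gates.append(sigmoid((logistic_noise + logit) / beta))
        else:
            # Without noise the gate is simply the sigmoid of the logit.
            gates.append(sigmoid(logit))
    return gates


eval_gates = concrete_sample([0.0, 2.0], training=False)
train_gates = concrete_sample([0.0, 2.0, -2.0], beta=0.5, rng=random.Random(0))
```

Lower values of beta push training samples toward 0 or 1, while the noise keeps the sampling step differentiable with respect to the logits in an autograd framework.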
- explain(x: torch.Tensor, edge_index: torch.Tensor, embed: torch.Tensor, tmp: float = 1.0, training: bool = False, **kwargs) -> Tuple[float, torch.Tensor]
Explain the GNN behavior for a graph with the explanation network.
- Parameters
x (torch.Tensor) – Node feature matrix with shape [num_nodes, dim_node_feature]
edge_index (torch.Tensor) – Graph connectivity in COO format with shape [2, num_edges]
embed (torch.Tensor) – Node embedding matrix with shape [num_nodes, dim_embedding]
tmp (float) – The temperature parameter fed to the sampling procedure
training (bool) – Whether in the training procedure or not
- Returns
probs (torch.Tensor): The classification probability for the graph with the edge mask applied
edge_mask (torch.Tensor): The probability mask for the graph edges
- forward(x: torch.Tensor, edge_index: torch.Tensor, **kwargs) -> Tuple[None, List, List[Dict]]
Explain the GNN behavior for a graph and calculate the metric values. This is the interface for dig.evaluation.XCollector.
- Parameters
x (torch.Tensor) – Node feature matrix with shape [num_nodes, dim_node_feature]
edge_index (torch.Tensor) – Graph connectivity in COO format with shape [2, num_edges]
kwargs (Dict) – The additional parameters
top_k (int): The number of edges in the final explanation results
- Return type
(None, List[torch.Tensor], List[Dict])
- get_subgraph(node_idx: int, x: torch.Tensor, edge_index: torch.Tensor, y: Optional[torch.Tensor] = None, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List, Dict]
Extract the subgraph of the target node.
- Parameters
node_idx (int) – The node index
x (torch.Tensor) – Node feature matrix with shape [num_nodes, dim_node_feature]
edge_index (torch.Tensor) – Graph connectivity in COO format with shape [2, num_edges]
y (torch.Tensor, None) – Node label matrix with shape [num_nodes] (default: None)
kwargs (Dict, None) – Additional parameters
- Return type
(torch.Tensor, torch.Tensor, torch.Tensor, List, Dict)
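Extracting the subgraph around a target node usually means collecting every node within num_hops message-passing steps and relabeling the surviving edges. The sketch below does this on a COO edge list with plain Python lists. The function name k_hop_subgraph and its return values are illustrative assumptions and do not mirror the five-element tuple returned by the method above.

```python
def k_hop_subgraph(node_idx, num_hops, edge_index):
    """Collect nodes within num_hops of node_idx; edge_index is [sources, targets]."""
    src, dst = edge_index
    nodes = {node_idx}
    frontier = {node_idx}
    for _ in range(num_hops):
        # Messages flow source -> target, so gather sources that feed the frontier.
        frontier = {s for s, t in zip(src, dst) if t in frontier} - nodes
        nodes |= frontier
    subset = sorted(nodes)
    # Map original node ids to consecutive ids inside the subgraph.
    relabel = {old: new for new, old in enumerate(subset)}
    # An edge survives only if both of its endpoints are in the subgraph.
    edge_mask = [s in nodes and t in nodes for s, t in zip(src, dst)]
    sub_edge_index = [
        [relabel[s] for s, keep in zip(src, edge_mask) if keep],
        [relabel[t] for t, keep in zip(dst, edge_mask) if keep],
    ]
    return subset, sub_edge_index, edge_mask, relabel[node_idx]


# Undirected path graph 0 - 1 - 2 - 3 - 4, stored with both edge directions.
path_edges = [[0, 1, 1, 2, 2, 3, 3, 4],
              [1, 0, 2, 1, 3, 2, 4, 3]]
subset, sub_edges, mask, new_idx = k_hop_subgraph(2, 1, path_edges)
```

Restricting the explanation to this neighborhood is sound for an L-layer message-passing model when num_hops equals L, since nodes farther away cannot influence the target node's prediction.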
- class SubgraphX(model, num_classes: int, device, num_hops: Optional[int] = None, verbose: bool = False, explain_graph: bool = True, rollout: int = 20, min_atoms: int = 5, c_puct: float = 10.0, expand_atoms=14, high2low=False, local_radius=4, sample_num=100, reward_method='mc_l_shapley', subgraph_building_method='zero_filling', save_dir: Optional[str] = None, filename: str = 'example', vis: bool = True)
The implementation of the paper On Explainability of Graph Neural Networks via Subgraph Explorations.
- Parameters
model (torch.nn.Module) – The target model to explain
num_classes (int) – Number of classes in the dataset
num_hops (int, None) – The number of hops used to extract the neighborhood of the target node (default: None)
explain_graph (bool) – Whether to explain a graph classification model (default: True)
rollout (int) – Number of iterations to get the prediction
min_atoms (int) – Number of atoms in a leaf node of the search tree
c_puct (float) – The hyperparameter that encourages exploration
expand_atoms (int) – The number of atoms to expand when extending the child nodes in the search tree
high2low (bool) – Whether to expand child nodes from high degree to low degree when extending the child nodes in the search tree (default: False)
local_radius (int) – The local radius used to calculate l_shapley and mc_l_shapley
sample_num (int) – Number of samples for the Monte Carlo approximation in mc_shapley and mc_l_shapley
reward_method (str) – The command string to select the reward method (default: mc_l_shapley)
subgraph_building_method (str) – The command string for the subgraph building method, such as zero_filling or split (default: zero_filling)
save_dir (str, None) – Root directory to save the explanation results (default: None)
filename (str) – The filename of the results
vis (bool) – Whether to show the visualization (default: True)
Example
>>> # For graph classification task
>>> subgraphx = SubgraphX(model=model, num_classes=2, device=device)
>>> _, explanation_results, related_preds = subgraphx(x, edge_index)
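The c_puct parameter is described as the hyperparameter that encourages exploration in the search tree. A widely used selection rule of this kind adds a visit-count-based bonus to each child's average reward, as sketched below. The function selection_score, the child dictionaries and the exact formula are assumptions for illustration; this page does not specify the rule SubgraphX applies.

```python
import math


def selection_score(q, prior, visits, total_visits, c_puct=10.0):
    """Average reward plus an exploration bonus that shrinks with visits."""
    exploration = c_puct * prior * math.sqrt(total_visits) / (1 + visits)
    return q + exploration


def pick_child(children, c_puct=10.0):
    """Return the child with the highest selection score."""
    total = sum(child["visits"] for child in children)
    return max(
        children,
        key=lambda c: selection_score(c["q"], c["prior"], c["visits"], total, c_puct),
    )


# Hypothetical children of one search-tree node.
children = [
    {"name": "a", "q": 0.60, "prior": 0.5, "visits": 10},
    {"name": "b", "q": 0.40, "prior": 0.5, "visits": 0},
]
explored = pick_child(children, c_puct=10.0)
greedy = pick_child(children, c_puct=0.0)
```

With a large c_puct the unvisited child wins despite its lower average reward, whereas c_puct set to 0 reduces the rule to greedy selection.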
- __call__(x: torch.Tensor, edge_index: torch.Tensor, **kwargs) -> Tuple[None, List, List[Dict]]
Explain the GNN behavior for the graph using the SubgraphX method.
- Parameters
x (torch.Tensor) – Node feature matrix with shape [num_nodes, dim_node_feature]
edge_index (torch.Tensor) – Graph connectivity in COO format with shape [2, num_edges]
kwargs (Dict)
- Return type
(None, List[torch.Tensor], List[Dict])