dig.xgraph.method
Method interfaces under dig.xgraph.method.
DeepLIFT | An implementation of DeepLIFT on graphs, from Learning Important Features Through Propagating Activation Differences.
GNNExplainer | The GNN-Explainer model from the “GNNExplainer: Generating Explanations for Graph Neural Networks” paper for identifying compact subgraph structures and small subsets of node features that play a crucial role in a GNN’s node predictions.
GNN_GI | An implementation of GNN-GI, from Higher-Order Explanations of Graph Neural Networks via Relevant Walks.
GNN_LRP | An implementation of GNN-LRP, from Higher-Order Explanations of Graph Neural Networks via Relevant Walks.
GradCAM | An implementation of GradCAM on graphs, from Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization.
PGExplainer | An implementation of PGExplainer, from Parameterized Explainer for Graph Neural Network.
SubgraphX | An implementation of the paper On Explainability of Graph Neural Networks via Subgraph Explorations.
class DeepLIFT(model: torch.nn.modules.module.Module, explain_graph: bool = False)
An implementation of DeepLIFT on graphs, from Learning Important Features Through Propagating Activation Differences.
- Parameters
model (torch.nn.Module) – The target model to explain.
explain_graph (bool, optional) – Whether to explain a graph classification model. (default: False)
Note
For a node classification model, set the explain_graph flag to False. For an example, see benchmarks/xgraph.
forward(x: torch.Tensor, edge_index: torch.Tensor, **kwargs)
Run the explainer for a specific graph instance.
- Parameters
x (torch.Tensor) – The graph instance’s input node features.
edge_index (torch.Tensor) – The graph instance’s edge index.
**kwargs (dict) –
node_idx (int): The index of the node to be explained. (for node classification)
sparsity (float): The sparsity used to transform a soft mask into a hard mask. (default: 0.7)
- Return type
(None, edge_masks, related_predictions): edge_masks is a list of edge-level explanations, one per class; related_predictions is a list of dictionaries, one per class, each containing four types of predicted probabilities.
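The sparsity keyword controls how a soft (real-valued) edge mask becomes a hard keep-or-drop mask. The following is a minimal sketch in plain Python lists rather than tensors, assuming that sparsity is the fraction of edges discarded, so the top (1 - sparsity) share of edges by score is kept. The helper name soft_to_hard_mask is hypothetical and not part of dig.xgraph; the library's own thresholding may differ.

```python
from typing import List


def soft_to_hard_mask(soft_mask: List[float], sparsity: float = 0.7) -> List[bool]:
    """Keep the highest-scoring (1 - sparsity) fraction of edges.

    Hypothetical helper for illustration only.
    """
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError("sparsity must lie in [0, 1]")
    num_keep = int((1.0 - sparsity) * len(soft_mask))
    # Rank edge positions by score, highest first.
    ranked = sorted(range(len(soft_mask)), key=lambda i: soft_mask[i], reverse=True)
    keep = set(ranked[:num_keep])
    return [i in keep for i in range(len(soft_mask))]


scores = [0.9, 0.1, 0.4, 0.8, 0.3, 0.2, 0.7, 0.05, 0.6, 0.5]
hard = soft_to_hard_mask(scores, sparsity=0.7)
```

With ten edges and the default sparsity of 0.7, three edges survive under this reading.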
class GNNExplainer(model: torch.nn.modules.module.Module, epochs: int = 100, lr: float = 0.01, explain_graph: bool = False)
The GNN-Explainer model from the “GNNExplainer: Generating Explanations for Graph Neural Networks” paper for identifying compact subgraph structures and small subsets of node features that play a crucial role in a GNN’s node predictions.
Note
For an example, see benchmarks/xgraph.
- Parameters
model (torch.nn.Module) – The GNN module to explain.
epochs (int, optional) – The number of epochs to train. (default: 100)
lr (float, optional) – The learning rate to apply. (default: 0.01)
explain_graph (bool, optional) – Whether to explain a graph classification model. (default: False)
forward(x, edge_index, mask_features=False, **kwargs)
Run the explainer for a specific graph instance.
- Parameters
x (torch.Tensor) – The graph instance’s input node features.
edge_index (torch.Tensor) – The graph instance’s edge index.
mask_features (bool, optional) – Whether to use a feature mask. Not recommended. (default: False)
**kwargs (dict) –
node_idx (int): The index of the node to be explained. (for node classification)
sparsity (float): The sparsity used to transform a soft mask into a hard mask. (default: 0.7)
num_classes (int): The number of classes in the task.
- Return type
(None, edge_masks, related_predictions): edge_masks is a list of edge-level explanations, one per class; related_predictions is a list of dictionaries, one per class, each containing four types of predicted probabilities.
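Each related_predictions dictionary holds four predicted probabilities. The sketch below shows how such a dictionary might be turned into fidelity-style scores. It assumes the four entries are keyed origin, masked, maskout, and zero; these key names and the helper fidelity_scores are assumptions for illustration, so check the dictionaries your installed version actually returns.

```python
from typing import Dict


def fidelity_scores(pred: Dict[str, float]) -> Dict[str, float]:
    """Compare the original prediction with predictions on perturbed inputs.

    Assumed keys: 'origin' (unmodified graph), 'masked' (only the explanation
    kept), 'maskout' (explanation removed). 'zero' is not used here.
    """
    return {
        # Probability drop when the edges marked important are removed.
        "fidelity": pred["origin"] - pred["maskout"],
        # Probability drop when only the edges marked important are kept.
        "fidelity_inv": pred["origin"] - pred["masked"],
    }


example = {"origin": 0.75, "masked": 0.5, "maskout": 0.25, "zero": 0.125}
scores = fidelity_scores(example)
```

Under these assumptions, a good explanation gives a large fidelity value and a small fidelity_inv value.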
class GNN_GI(model: torch.nn.modules.module.Module, explain_graph: bool = False)
An implementation of GNN-GI, from Higher-Order Explanations of Graph Neural Networks via Relevant Walks.
- Parameters
model (torch.nn.Module) – The target model to explain.
explain_graph (bool, optional) – Whether to explain a graph classification model. (default: False)
Note
For a node classification model, set the explain_graph flag to False.
forward(x: torch.Tensor, edge_index: torch.Tensor, **kwargs)
Run the explainer for a specific graph instance.
- Parameters
x (torch.Tensor) – The graph instance’s input node features.
edge_index (torch.Tensor) – The graph instance’s edge index.
**kwargs (dict) –
node_idx (int): The index of the node to be explained. (for node classification)
sparsity (float): The sparsity used to transform a soft mask into a hard mask. (default: 0.7)
num_classes (int): The number of classes in the task.
- Return type
(walks, edge_masks, related_predictions): walks is a dictionary containing the walks’ edge indices and their explanation scores; edge_masks is a list of edge-level explanations, one per class; related_predictions is a list of dictionaries, one per class, each containing four types of predicted probabilities.
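The walks dictionary pairs each walk's edge indices with a score. One simple way to read edge-level importance out of such a structure is to add up the scores of all walks that pass through an edge. The sketch below assumes the dictionary has an ids entry (one list of edge indices per walk) and a score entry (one number per walk); both key names, the plain-list layout, and the summation rule are assumptions for illustration, not the library's documented behavior.

```python
from collections import defaultdict
from typing import Dict


def edge_scores_from_walks(walks: Dict[str, list]) -> Dict[int, float]:
    """Sum walk scores onto the edges each walk traverses (hypothetical helper)."""
    totals: Dict[int, float] = defaultdict(float)
    for edge_ids, score in zip(walks["ids"], walks["score"]):
        # Count each edge once per walk, even if the walk revisits it.
        for edge_id in set(edge_ids):
            totals[edge_id] += score
    return dict(totals)


# Three two-step walks over edges 0, 1 and 2.
walks = {"ids": [[0, 1], [0, 2], [2, 2]], "score": [0.5, 0.25, -0.125]}
edge_scores = edge_scores_from_walks(walks)
```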
class GNN_LRP(model: torch.nn.modules.module.Module, explain_graph=False)
An implementation of GNN-LRP, from Higher-Order Explanations of Graph Neural Networks via Relevant Walks.
- Parameters
model (torch.nn.Module) – The target model to explain.
explain_graph (bool, optional) – Whether to explain a graph classification model. (default: False)
Note
For a node classification model, set the explain_graph flag to False. GNN-LRP is highly model dependent. Make sure you know how to modify it for different models. For an example, see benchmarks/xgraph.
forward(x: torch.Tensor, edge_index: torch.Tensor, **kwargs)
Run the explainer for a specific graph instance.
- Parameters
x (torch.Tensor) – The graph instance’s input node features.
edge_index (torch.Tensor) – The graph instance’s edge index.
**kwargs (dict) –
node_idx (int): The index of the node to be explained. (for node classification)
sparsity (float): The sparsity used to transform a soft mask into a hard mask. (default: 0.7)
num_classes (int): The number of classes in the task.
- Return type
(walks, edge_masks, related_predictions): walks is a dictionary containing the walks’ edge indices and their explanation scores; edge_masks is a list of edge-level explanations, one per class; related_predictions is a list of dictionaries, one per class, each containing four types of predicted probabilities.
class GradCAM(model: torch.nn.modules.module.Module, explain_graph: bool = False)
An implementation of GradCAM on graphs, from Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization.
- Parameters
model (torch.nn.Module) – The target model to explain.
explain_graph (bool, optional) – Whether to explain a graph classification model. (default: False)
Note
For a node classification model, set the explain_graph flag to False. For an example, see benchmarks/xgraph.
forward(x: torch.Tensor, edge_index: torch.Tensor, **kwargs) → Union[Tuple[None, List, List[Dict]], Tuple[Dict, List, List[Dict]]]
Run the explainer for a specific graph instance.
- Parameters
x (torch.Tensor) – The graph instance’s input node features.
edge_index (torch.Tensor) – The graph instance’s edge index.
**kwargs (dict) –
node_idx (int): The index of the node to be explained. (for node classification)
sparsity (float): The sparsity used to transform a soft mask into a hard mask. (default: 0.7)
num_classes (int): The number of classes in the task.
- Return type
(None, edge_masks, related_predictions): edge_masks is a list of edge-level explanations, one per class; related_predictions is a list of dictionaries, one per class, each containing four types of predicted probabilities.
class PGExplainer(model, in_channels: int, device, explain_graph: bool = True, epochs: int = 20, lr: float = 0.005, coff_size: float = 0.01, coff_ent: float = 0.0005, t0: float = 5.0, t1: float = 1.0, num_hops: Optional[int] = None)
An implementation of PGExplainer, from Parameterized Explainer for Graph Neural Network.
- Parameters
model (torch.nn.Module) – The target model to explain
in_channels (int) – Number of input channels for the explanation network
explain_graph (bool) – Whether to explain a graph classification model (default: True)
epochs (int) – Number of epochs to train the explanation network
lr (float) – Learning rate to train the explanation network
coff_size (float) – Size regularization to constrain the explanation size
coff_ent (float) – Entropy regularization to constrain the connectivity of the explanation
t0 (float) – The temperature at the first epoch
t1 (float) – The temperature at the final epoch
num_hops (int, None) – The number of hops used to extract the neighborhood of the target node (default: None)
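The t0 and t1 parameters fix the temperature at the start and end of training, which implies some schedule in between. A minimal sketch of one plausible choice, a geometric interpolation, is shown below. The function name temperature and the exact formula are assumptions, not taken from this page; the defaults mirror the constructor signature above.

```python
def temperature(epoch: int, epochs: int = 20, t0: float = 5.0, t1: float = 1.0) -> float:
    """Geometric interpolation: equals t0 at epoch 0 and t1 at epoch == epochs.

    Hypothetical schedule for illustration only.
    """
    return t0 * (t1 / t0) ** (epoch / epochs)


# Temperatures for epochs 0..20 with the default settings.
schedule = [temperature(e) for e in range(21)]
```

A high early temperature keeps the sampled edge gates soft; lowering it pushes them toward 0 or 1 as training proceeds.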
__set_masks__(x: torch.Tensor, edge_index: torch.Tensor, edge_mask: Optional[torch.Tensor] = None)
Set the edge weights before message passing.
- Parameters
x (torch.Tensor) – Node feature matrix with shape [num_nodes, dim_node_feature]
edge_index (torch.Tensor) – Graph connectivity in COO format with shape [2, num_edges]
edge_mask (torch.Tensor) – Edge weight matrix before message passing (default: None)
The edge_mask will be randomly initialized when set to None.
Note
When you use __set_masks__(), the explain flag of every torch_geometric.nn.MessagePassing module in model is set to True. In addition, the edge_mask is assigned to all of these modules. Call __clear_masks__() to reset.
concrete_sample(log_alpha: torch.Tensor, beta: float = 1.0, training: bool = True)
Sample from the instantiation of the concrete distribution when training.
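For readers unfamiliar with the concrete distribution, the sketch below shows the standard binary concrete (Gumbel-sigmoid) relaxation on plain Python floats: logistic noise is added to log_alpha, the sum is divided by the temperature beta, and a sigmoid maps it to a gate in [0, 1]. Whether the library's method matches this term by term is an assumption. The rng argument and the sigmoid helper are additions for this illustration and are not part of the signature above.

```python
import math
import random
from typing import List, Optional


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def concrete_sample(log_alpha: List[float], beta: float = 1.0, training: bool = True,
                    rng: Optional[random.Random] = None) -> List[float]:
    """Relaxed Bernoulli gates; deterministic sigmoid(log_alpha) when not training."""
    rng = rng or random.Random()
    gates = []
    for la in log_alpha:
        if training:
            # Logistic noise log(u) - log(1 - u), with u clamped away from 0 and 1.
            u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
            noise = math.log(u) - math.log(1.0 - u)
            gates.append(sigmoid((noise + la) / beta))
        else:
            gates.append(sigmoid(la))
    return gates


eval_gates = concrete_sample([0.0, 2.0], training=False)
train_gates = concrete_sample([0.0, 2.0], beta=0.5, rng=random.Random(0))
```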
explain(x: torch.Tensor, edge_index: torch.Tensor, embed: torch.Tensor, tmp: float = 1.0, training: bool = False, **kwargs) → Tuple[float, torch.Tensor]
Explain the GNN behavior for a graph with the explanation network.
- Parameters
x (torch.Tensor) – Node feature matrix with shape [num_nodes, dim_node_feature]
edge_index (torch.Tensor) – Graph connectivity in COO format with shape [2, num_edges]
embed (torch.Tensor) – Node embedding matrix with shape [num_nodes, dim_embedding]
tmp (float) – The temperature parameter fed to the sample procedure
training (bool) – Whether in training procedure or not
- Returns
probs (torch.Tensor): The classification probability for the graph with the edge mask applied
edge_mask (torch.Tensor): The probability mask for graph edges
forward(x: torch.Tensor, edge_index: torch.Tensor, **kwargs) → Tuple[None, List, List[Dict]]
Explain the GNN behavior for a graph and calculate the metric values. This is the interface for dig.evaluation.XCollector.
- Parameters
x (torch.Tensor) – Node feature matrix with shape [num_nodes, dim_node_feature]
edge_index (torch.Tensor) – Graph connectivity in COO format with shape [2, num_edges]
kwargs (Dict) – The additional parameters
top_k (int): The number of edges in the final explanation results
y (torch.Tensor): The ground-truth labels
- Return type
(None, List[torch.Tensor], List[Dict])
get_subgraph(node_idx: int, x: torch.Tensor, edge_index: torch.Tensor, y: Optional[torch.Tensor] = None, **kwargs) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List, Dict]
Extract the subgraph of the target node.
- Parameters
node_idx (int) – The node index
x (torch.Tensor) – Node feature matrix with shape [num_nodes, dim_node_feature]
edge_index (torch.Tensor) – Graph connectivity in COO format with shape [2, num_edges]
y (torch.Tensor, None) – Node label matrix with shape [num_nodes] (default: None)
kwargs (Dict, None) – Additional parameters
- Return type
(torch.Tensor, torch.Tensor, torch.Tensor, List, Dict)
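To make the idea of extracting a target node's neighborhood concrete, here is a small sketch of k-hop subgraph extraction over a COO edge list, written with plain Python lists instead of tensors. It assumes messages flow from source to target, so the neighborhood consists of nodes that can reach the target within num_hops steps. The function k_hop_subgraph below and its return layout are hypothetical and do not reproduce the method's actual return values.

```python
from typing import Dict, List, Tuple


def k_hop_subgraph(node_idx: int, num_hops: int, edge_index: List[List[int]]
                   ) -> Tuple[List[int], List[List[int]], List[bool], Dict[int, int]]:
    """Return (kept nodes, relabelled edges, per-edge keep mask, old-to-new node map)."""
    sources, targets = edge_index
    reached = {node_idx}
    frontier = {node_idx}
    for _ in range(num_hops):
        # Nodes that send a message into the current frontier.
        frontier = {s for s, t in zip(sources, targets) if t in frontier} - reached
        reached |= frontier
    subset = sorted(reached)
    relabel = {old: new for new, old in enumerate(subset)}
    # Keep an edge only if both endpoints are inside the neighborhood.
    edge_mask = [s in reached and t in reached for s, t in zip(sources, targets)]
    sub_edges = [
        [relabel[s] for s, keep in zip(sources, edge_mask) if keep],
        [relabel[t] for t, keep in zip(targets, edge_mask) if keep],
    ]
    return subset, sub_edges, edge_mask, relabel


# Path graph 0 - 1 - 2 - 3 - 4, stored as directed edges in both directions.
edge_index = [[0, 1, 1, 2, 2, 3, 3, 4], [1, 0, 2, 1, 3, 2, 4, 3]]
subset, sub_edges, edge_mask, relabel = k_hop_subgraph(2, 1, edge_index)
```

The old-to-new node map is what lets a caller locate the target node inside the relabelled subgraph.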
class SubgraphX(model, num_classes: int, device, num_hops: Optional[int] = None, explain_graph: bool = True, rollout: int = 20, min_atoms: int = 5, c_puct: float = 10.0, expand_atoms=14, high2low=False, local_radius=4, sample_num=100, reward_method='mc_l_shapley', subgraph_building_method='zero_filling', save_dir: Optional[str] = None, filename: str = 'example', vis: bool = True)
An implementation of the paper On Explainability of Graph Neural Networks via Subgraph Explorations.
- Parameters
model (torch.nn.Module) – The target model to explain
num_classes (int) – Number of classes in the dataset
num_hops (int, None) – The number of hops used to extract the neighborhood of the target node (default: None)
explain_graph (bool) – Whether to explain a graph classification model (default: True)
rollout (int) – Number of iterations to get the prediction
min_atoms (int) – Number of atoms of the leaf node in the search tree
c_puct (float) – The hyperparameter that encourages exploration
expand_atoms (int) – The number of atoms to expand when extending the child nodes in the search tree
high2low (bool) – Whether to expand child nodes from high degree to low degree when extending the child nodes in the search tree (default: False)
local_radius (int) – The local radius used to calculate l_shapley, mc_l_shapley
sample_num (int) – Number of samples for the Monte Carlo approximation in mc_shapley, mc_l_shapley
reward_method (str) – The command string to select the reward method (default: mc_l_shapley)
subgraph_building_method (str) – The command string for different subgraph building methods, such as zero_filling, split (default: zero_filling)
save_dir (str, None) – Root directory to save the explanation results (default: None)
filename (str) – The filename of results
vis (bool) – Whether to show the visualization (default: True)
Example
>>> # For graph classification task
>>> subgraphx = SubgraphX(model=model, num_classes=2, device=device)
>>> _, explanation_results, related_preds = subgraphx(x, edge_index)
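The c_puct parameter is described as encouraging exploration in the search tree. A minimal sketch of how such a constant typically enters a Monte Carlo tree search selection rule is shown below, assuming the common PUCT form Q + c_puct * P * sqrt(total visits) / (1 + N). The function puct_scores and the toy numbers are hypothetical, and the library's exact rule may differ.

```python
import math
from typing import List


def puct_scores(q: List[float], n: List[int], p: List[float],
                c_puct: float = 10.0) -> List[float]:
    """Score each child: average reward q plus an exploration bonus.

    q: mean reward per child, n: visit counts, p: prior (e.g. immediate reward).
    """
    total_visits = sum(n)
    return [
        q_i + c_puct * p_i * math.sqrt(total_visits) / (1 + n_i)
        for q_i, n_i, p_i in zip(q, n, p)
    ]


q = [0.5, 0.25]   # child 0 currently looks better
n = [3, 1]        # but child 1 has been visited less often
p = [0.5, 0.5]
greedy = puct_scores(q, n, p, c_puct=0.0)
exploring = puct_scores(q, n, p, c_puct=10.0)
```

With c_puct set to 0 the search would pick child 0 on average reward alone; with the default of 10.0 the bonus for the rarely visited child 1 dominates.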
__call__(x: torch.Tensor, edge_index: torch.Tensor, **kwargs) → Tuple[None, List, List[Dict]]
Explain the GNN behavior for the graph using the SubgraphX method.
- Parameters
x (torch.Tensor) – Node feature matrix with shape [num_nodes, dim_node_feature]
edge_index (torch.Tensor) – Graph connectivity in COO format with shape [2, num_edges]
kwargs (Dict) –
- Return type
(None, List[torch.Tensor], List[Dict])