dig.xgraph.utils

Method interfaces under dig.xgraph.method.

MCTS

Monte Carlo Tree Search Method.

class MCTS(X: torch.Tensor, edge_index: torch.Tensor, num_hops: int, n_rollout: int = 10, min_atoms: int = 3, c_puct: float = 10.0, expand_atoms: int = 14, high2low: bool = False, node_idx: Optional[int] = None, score_func: Optional[Callable] = None, device='cpu')

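Monte Carlo Tree Search over subgraphs of the input graph. Each tree node holds a subgraph, its children are smaller subgraphs obtained by pruning nodes from it, and the reward of each tree node is computed with score_func.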
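During each rollout the search must choose which child subgraph to descend into, trading off children with a high mean reward against children that are rarely visited; c_puct weights the second term. The following is a minimal sketch, assuming a PUCT-style selection rule Q + c_puct · P · √(ΣN) / (1 + N), where P is the child's reward from score_func, Q = W / N is its mean value, and ΣN is the total visit count of its siblings. This formula is our reading of SubgraphX-style search, not a documented guarantee of this class, and TreeNode and select_child are hypothetical names, not part of dig.xgraph.

```python
import math
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class TreeNode:
    """Hypothetical tree node: a subgraph (sorted node ids) plus search statistics."""
    coalition: Tuple[int, ...]
    P: float = 0.0  # prior reward, e.g. from a score function
    W: float = 0.0  # total value accumulated over visits
    N: int = 0      # visit count
    children: List["TreeNode"] = field(default_factory=list)

    def Q(self) -> float:
        # Mean value; an unvisited node counts as 0.
        return self.W / self.N if self.N > 0 else 0.0

    def U(self, sibling_visits: int, c_puct: float) -> float:
        # Exploration bonus: large for high-prior, rarely visited nodes.
        return c_puct * self.P * math.sqrt(sibling_visits) / (1 + self.N)


def select_child(node: TreeNode, c_puct: float) -> TreeNode:
    """Return the child maximizing Q + U (assumed PUCT-style rule)."""
    total_visits = sum(child.N for child in node.children)
    return max(node.children, key=lambda c: c.Q() + c.U(total_visits, c_puct))


if __name__ == "__main__":
    visited = TreeNode((0, 1, 2), P=0.2, W=1.8, N=3)  # Q = 0.6
    unvisited = TreeNode((1, 2, 3), P=0.9)            # Q = 0.0
    root = TreeNode((0, 1, 2, 3), children=[visited, unvisited])
    print(select_child(root, c_puct=10.0).coalition)  # (1, 2, 3): exploration wins
    print(select_child(root, c_puct=0.0).coalition)   # (0, 1, 2): greedy on Q
```

With a large c_puct the unvisited child with a high prior is chosen; with c_puct = 0 the rule reduces to greedy selection on mean reward.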

Parameters
  • X (torch.Tensor) – The input node feature matrix, with shape [num_nodes, num_node_features].

  • edge_index (torch.Tensor) – The edge indices in COO format, with shape [2, num_edges].

  • num_hops (int) – The number of hops k that defines the neighborhood of node_idx.

  • n_rollout (int) – The number of rollouts used to build the Monte Carlo tree.

  • min_atoms (int) – The maximum number of atoms (nodes) in the subgraph of a leaf node of the Monte Carlo tree; a tree node whose subgraph is this small is not expanded further.

  • c_puct (float) – The hyper-parameter that weights the exploration term during the search; larger values encourage more exploration.

  • expand_atoms (int) – The maximum number of children to expand from a tree node.

  • high2low (bool) – Whether to expand child tree nodes starting from high-degree nodes and moving to low-degree nodes.

  • node_idx (int) – The index of the target node whose neighborhood is extracted for the search; defaults to None.

  • score_func (Callable) – The reward function for a tree node, such as mc_shapley or mc_l_shapley.

  • device – The device used for computation (default 'cpu').
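The parameters min_atoms, expand_atoms and high2low together control how a tree node is expanded. A minimal sketch, assuming that each child is formed by pruning one node from the parent subgraph and keeping the largest remaining connected component, with nodes tried in degree order (our reading of SubgraphX-style expansion, not a documented contract of this class). The graph is a plain adjacency dict rather than X/edge_index, and expand and connected_components are hypothetical names, not dig.xgraph functions.

```python
from typing import Dict, List, Sequence, Set, Tuple


def connected_components(adj: Dict[int, Set[int]], nodes: Set[int]) -> List[Set[int]]:
    """Connected components of the subgraph induced by `nodes` (iterative DFS)."""
    seen: Set[int] = set()
    components = []
    for start in sorted(nodes):
        if start in seen:
            continue
        stack, comp = [start], set()
        seen.add(start)
        while stack:
            u = stack.pop()
            comp.add(u)
            for v in adj.get(u, ()):
                if v in nodes and v not in seen:
                    seen.add(v)
                    stack.append(v)
        components.append(comp)
    return components


def expand(adj: Dict[int, Set[int]], coalition: Sequence[int], expand_atoms: int = 14,
           high2low: bool = False, min_atoms: int = 3) -> List[Tuple[int, ...]]:
    """Hypothetical expansion step: return the distinct child subgraphs of `coalition`."""
    nodes = set(coalition)
    if len(nodes) <= min_atoms:
        return []  # leaf: small enough, not expanded further
    # Degree of each node inside the current subgraph.
    degree = {u: sum(1 for v in adj.get(u, ()) if v in nodes) for u in nodes}
    # Stable sort by degree: descending if high2low, ascending otherwise.
    order = sorted(sorted(nodes), key=lambda u: degree[u], reverse=high2low)
    children: List[Tuple[int, ...]] = []
    for pruned in order[:expand_atoms]:
        comps = connected_components(adj, nodes - {pruned})
        if not comps:
            continue
        child = tuple(sorted(max(comps, key=len)))  # keep the largest component
        if child not in children:  # merge duplicate subgraphs
            children.append(child)
    return children


if __name__ == "__main__":
    # Path graph 0-1-2-3-4
    adj = {0: {1}, 1: {0, 2}, 2: {1, 3}, 3: {2, 4}, 4: {3}}
    print(expand(adj, (0, 1, 2, 3, 4), expand_atoms=2, high2low=True))
    # [(2, 3, 4), (0, 1)]
```

Pruning in high-to-low degree order tends to cut the graph in its middle first, so with a small expand_atoms the children shrink faster than they would when low-degree (peripheral) nodes are pruned first.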