dig.sslgraph.utils

Utilities under dig.sslgraph.utils.

Encoder

A wrapped torch.nn.Module class for the convenient instantiation of pre-implemented graph encoders.

setup_seed

Set up the random seed for reproducible experiments.

class Encoder(feat_dim, hidden_dim, n_layers=5, pool='sum', gnn='gin', node_level=False, graph_level=True, **kwargs)[source]

A wrapped torch.nn.Module class for the convenient instantiation of pre-implemented graph encoders.

Parameters
  • feat_dim (int) – The dimension of input node features.

  • hidden_dim (int) – The dimension of node-level (local) embeddings.

  • n_layers (int, optional) – The number of GNN layers in the encoder. (default: 5)

  • pool (string, optional) – The global pooling method, sum or mean. (default: sum)

  • gnn (string, optional) – The type of GNN layer, gcn, gin or resgcn. (default: gin)

  • bn (bool, optional) – Whether to include batch normalization. (default: True)

  • act (string, optional) – The activation function, relu or prelu. (default: relu)

  • bias (bool, optional) – Whether to include bias term in Linear. (default: True)

  • xavier (bool, optional) – Whether to apply xavier initialization. (default: True)

  • node_level (bool, optional) – If True, the encoder will output node-level embeddings (local representations). (default: False)

  • graph_level (bool, optional) – If True, the encoder will output graph-level embeddings (global representations). (default: True)

  • edge_weight (bool, optional) – Only applied to GCN. Whether to use edge weight to compute the aggregation. (default: False)

Note

For GCN and GIN encoders, the dimension of the output node-level (local) embeddings will be hidden_dim, whereas the graph-level embeddings will have dimension hidden_dim * n_layers. For ResGCN, the output embeddings for both nodes and graphs will have dimension hidden_dim.

Examples

>>> feat_dim = dataset[0].x.shape[1]
>>> encoder = Encoder(feat_dim, 128, n_layers=3, gnn="gin")
>>> encoder(some_batched_data).shape # graph-level embedding of shape [batch_size, 128*3]
torch.Size([32, 384])
>>> encoder = Encoder(feat_dim, 128, n_layers=5, node_level=True, graph_level=False)
>>> encoder(some_batched_data).shape # node-level embedding of shape [n_nodes, 128]
torch.Size([707, 128])
>>> encoder = Encoder(feat_dim, 128, n_layers=5, node_level=True, graph_level=True)
>>> encoder(some_batched_data) # a tuple of graph-level and node-level embeddings
(tensor([...]), tensor([...]))
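
The batched inputs above can come from a PyTorch Geometric DataLoader. A minimal sketch of how some_batched_data might be produced (the dataset variable, the batch size of 32, and PyTorch Geometric 2.x are assumptions for illustration, not part of this API):

>>> from torch_geometric.loader import DataLoader
>>> loader = DataLoader(dataset, batch_size=32, shuffle=True)
>>> some_batched_data = next(iter(loader))  # a Batch carrying x, edge_index and batch
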
forward(data)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
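
In practice, this means computing embeddings by calling the module instance rather than forward directly; a brief illustration, reusing the assumed some_batched_data from the examples above:

>>> embeddings = encoder(some_batched_data)  # preferred: runs any registered hooks
>>> # encoder.forward(some_batched_data)     # bypasses registered hooks; avoid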

setup_seed(seed)[source]

Set up the random seed for reproducible experiments.

Parameters

seed (int or float) – The number used as the random seed.
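
A minimal usage sketch; the value 42 is an arbitrary illustrative choice, and reproducible weight initialization assumes the utility seeds PyTorch's global generator:

>>> from dig.sslgraph.utils import setup_seed, Encoder
>>> setup_seed(42)  # call once, before building models and data loaders
>>> encoder = Encoder(feat_dim, 128)  # initialization is now reproducible across runs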