dig.sslgraph.utils
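Utilities under dig.sslgraph.utils: a wrapped torch.nn.Module class (Encoder) for the convenient instantiation of pre-implemented graph encoders, and a helper to set up the seed for reproducible experiments.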
- class Encoder(feat_dim, hidden_dim, n_layers=5, pool='sum', gnn='gin', node_level=False, graph_level=True, **kwargs)
A wrapped torch.nn.Module class for the convenient instantiation of pre-implemented graph encoders.
- Parameters
feat_dim (int) – The dimension of input node features.
hidden_dim (int) – The dimension of node-level (local) embeddings.
n_layers (int, optional) – The number of GNN layers in the encoder. (default: 5)
pool (string, optional) – The global pooling method, sum or mean. (default: sum)
gnn (string, optional) – The type of GNN layer, gcn, gin or resgcn. (default: gin)
bn (bool, optional) – Whether to include batch normalization. (default: True)
act (string, optional) – The activation function, relu or prelu. (default: relu)
bias (bool, optional) – Whether to include a bias term in Linear. (default: True)
xavier (bool, optional) – Whether to apply Xavier initialization. (default: True)
node_level (bool, optional) – If True, the encoder will output node-level embeddings (local representations). (default: False)
graph_level (bool, optional) – If True, the encoder will output graph-level embeddings (global representations). (default: True)
edge_weight (bool, optional) – Only applied to GCN. Whether to use edge weights to compute the aggregation. (default: False)
Note
For GCN and GIN encoders, the output node-level (local) embedding has dimension hidden_dim, whereas the graph-level (global) embedding has dimension hidden_dim * n_layers. For ResGCN, the output embeddings for both nodes and graphs have dimension hidden_dim.
Examples
>>> feat_dim = dataset[0].x.shape[1]
>>> encoder = Encoder(feat_dim, 128, n_layers=3, gnn="gin")
>>> encoder(some_batched_data).shape # graph-level embedding of shape [batch_size, 128*3]
torch.Size([32, 384])
>>> encoder = Encoder(feat_dim, 128, n_layers=5, node_level=True, graph_level=False)
>>> encoder(some_batched_data).shape # node-level embedding of shape [n_nodes, 128]
torch.Size([707, 128])
>>> encoder = Encoder(feat_dim, 128, n_layers=5, node_level=True, graph_level=True)
>>> encoder(some_batched_data) # a tuple of graph-level and node-level embeddings
(tensor([...]), tensor([...]))
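The embedding sizes seen in the examples above (graph-level [batch_size, 128*3] for a 3-layer GIN, node-level [n_nodes, 128]) can be checked by hand before building a downstream head. The following is a minimal sketch of that rule in plain Python. The helper name expected_output_dims is hypothetical and is not part of dig.sslgraph; it only restates the documented dimensions and does not inspect an actual encoder.

```python
def expected_output_dims(hidden_dim, n_layers=5, gnn="gin"):
    """Return (node_dim, graph_dim) as described in the Encoder docs.

    Hypothetical helper for illustration only; not part of dig.sslgraph.
    """
    if gnn in ("gcn", "gin"):
        # Node-level embeddings keep hidden_dim; graph-level embeddings
        # have dimension hidden_dim * n_layers.
        return hidden_dim, hidden_dim * n_layers
    if gnn == "resgcn":
        # ResGCN outputs hidden_dim for both nodes and graphs.
        return hidden_dim, hidden_dim
    raise ValueError(f"unknown gnn type: {gnn!r}")


node_dim, graph_dim = expected_output_dims(128, n_layers=3, gnn="gin")
print(node_dim, graph_dim)  # 128 384
```

A helper like this may be useful for sizing a projection head, whose input width would need to match graph_dim rather than hidden_dim for GCN and GIN encoders.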
- forward(data)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
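To illustrate why calling the instance differs from calling forward directly, here is a toy sketch in plain Python. ToyModule is a hypothetical stand-in written for this example; it mimics only the hook-dispatch idea and is not how torch.nn.Module is actually implemented.

```python
class ToyModule:
    """Hypothetical stand-in for a module with forward hooks (not PyTorch code)."""

    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        # Store a callable to run after every call made through __call__.
        self._forward_hooks.append(hook)

    def forward(self, data):
        # The computation itself; subclasses would override this.
        return data * 2

    def __call__(self, data):
        # Calling the instance runs forward and then every registered hook.
        out = self.forward(data)
        for hook in self._forward_hooks:
            hook(self, data, out)
        return out


seen = []
module = ToyModule()
module.register_forward_hook(lambda mod, inp, out: seen.append(out))

via_call = module(3)             # hooks run
via_forward = module.forward(3)  # hooks are skipped
print(via_call, via_forward, seen)  # 6 6 [6]
```

Both calls return the same value, but only the call through the instance triggers the hook, which is why encoder(data) is preferred over encoder.forward(data).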