dig.sslgraph.utils
Utilities under dig.sslgraph.utils.
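Encoder | A wrapped torch.nn.Module class for the convenient instantiation of pre-implemented graph encoders.
| Sets up the random seed for reproducible experiments.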
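Seeding for reproducibility usually means seeding every random number generator an experiment touches. The sketch below is a hypothetical helper, `setup_seed_sketch`, and not DIG's own seed utility: it seeds Python's `random`, plus NumPy and PyTorch when they are installed, and then checks that re-seeding reproduces the same draws.

```python
import random


def setup_seed_sketch(seed: int) -> None:
    """Hypothetical helper: seed every installed random number generator.

    Illustrative only; DIG's own seed utility may differ.
    """
    # Python's built-in RNG (random.shuffle, random.sample, ...).
    random.seed(seed)

    # NumPy's global RNG, if NumPy is installed.
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass

    # PyTorch CPU and CUDA RNGs, if PyTorch is installed.
    try:
        import torch
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # silently ignored when CUDA is absent
        # Prefer deterministic cuDNN kernels over autotuned (faster) ones.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    except ImportError:
        pass


# Re-seeding reproduces the same random draws.
setup_seed_sketch(0)
first = [random.random() for _ in range(3)]
setup_seed_sketch(0)
second = [random.random() for _ in range(3)]
print(first == second)  # True
```

Wrapping the NumPy and PyTorch calls in `try`/`except ImportError` keeps the helper usable in environments where those libraries are missing.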
class Encoder(feat_dim, hidden_dim, n_layers=5, pool='sum', gnn='gin', bn=True, act='relu', bias=True, xavier=True, node_level=False, graph_level=True, edge_weight=False)

A wrapped torch.nn.Module class for the convenient instantiation of pre-implemented graph encoders.

Parameters
feat_dim (int) – The dimension of input node features.
hidden_dim (int) – The dimension of node-level (local) embeddings.
n_layers (int, optional) – The number of GNN layers in the encoder. (default: 5)
pool (string, optional) – The global pooling method, sum or mean. (default: sum)
gnn (string, optional) – The type of GNN layer, gcn, gin or resgcn. (default: gin)
bn (bool, optional) – Whether to include batch normalization. (default: True)
act (string, optional) – The activation function, relu or prelu. (default: relu)
bias (bool, optional) – Whether to include the bias term in Linear layers. (default: True)
xavier (bool, optional) – Whether to apply Xavier initialization. (default: True)
node_level (bool, optional) – If True, the encoder outputs node-level embeddings (local representations). (default: False)
graph_level (bool, optional) – If True, the encoder outputs graph-level embeddings (global representations). (default: True)
edge_weight (bool, optional) – Only applies to GCN. Whether to use edge weights when computing the aggregation. (default: False)
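The pool parameter decides how node embeddings are collapsed into one vector per graph. The pure-Python sketch below shows sum versus mean pooling over a batch-assignment vector, the operation that PyTorch Geometric's `global_add_pool` and `global_mean_pool` perform on tensors. The helper name `global_pool` is hypothetical, and the `batch` convention (entry i is the graph index of node i) is assumed from PyTorch Geometric rather than stated here.

```python
from typing import List

Vector = List[float]


def global_pool(node_emb: List[Vector], batch: List[int], pool: str = "sum") -> List[Vector]:
    """Hypothetical helper: pool node embeddings into one vector per graph.

    batch[i] is the index of the graph that node i belongs to.
    """
    if pool not in ("sum", "mean"):
        raise ValueError(f"pool must be 'sum' or 'mean', got {pool!r}")
    num_graphs = max(batch) + 1
    dim = len(node_emb[0])
    totals = [[0.0] * dim for _ in range(num_graphs)]
    counts = [0] * num_graphs
    # Accumulate each node's embedding into its graph's running total.
    for vec, g in zip(node_emb, batch):
        counts[g] += 1
        for d, value in enumerate(vec):
            totals[g][d] += value
    if pool == "mean":
        # Guard against graph indices that received no nodes.
        return [[v / max(counts[g], 1) for v in totals[g]] for g in range(num_graphs)]
    return totals


# Three nodes with 2-dimensional embeddings; nodes 0 and 1 form graph 0.
nodes = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
batch = [0, 0, 1]
print(global_pool(nodes, batch, "sum"))   # [[4.0, 6.0], [5.0, 6.0]]
print(global_pool(nodes, batch, "mean"))  # [[2.0, 3.0], [5.0, 6.0]]
```

Sum pooling keeps information about graph size, while mean pooling normalizes it away; which works better depends on the dataset.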
Note
For GCN and GIN encoders, the output node-level (local) embedding has dimension hidden_dim, whereas the graph-level (global) embedding has dimension hidden_dim * n_layers. For ResGCN, both the node-level and graph-level output embeddings have dimension hidden_dim.

Examples
>>> feat_dim = dataset[0].x.shape[1]
>>> encoder = Encoder(feat_dim, 128, n_layers=3, gnn="gin")
>>> encoder(some_batched_data).shape  # graph-level embedding of shape [batch_size, 128*3]
torch.Size([32, 384])

>>> encoder = Encoder(feat_dim, 128, n_layers=5, node_level=True, graph_level=False)
>>> encoder(some_batched_data).shape  # node-level embedding of shape [n_nodes, 128]
torch.Size([707, 128])

>>> encoder = Encoder(feat_dim, 128, n_layers=5, node_level=True, graph_level=True)
>>> encoder(some_batched_data)  # a tuple of graph-level and node-level embeddings
(tensor([...]), tensor([...]))
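The output widths described in the Note can be checked with simple arithmetic. The sketch below is a hypothetical helper, `expected_output_shapes`, that does not call DIG; it only encodes the Note's rules (GCN/GIN graph-level width is hidden_dim * n_layers, everything else is hidden_dim) and returns the shapes the encoder should produce, in the same (graph-level, node-level) order as the tuple example. The graph and node counts passed in are illustrative.

```python
from typing import Tuple, Union

Shape = Tuple[int, int]


def expected_output_shapes(
    gnn: str,
    hidden_dim: int,
    n_layers: int,
    num_graphs: int,
    num_nodes: int,
    node_level: bool = False,
    graph_level: bool = True,
) -> Union[Shape, Tuple[Shape, Shape]]:
    """Hypothetical helper: output shapes implied by the Note on dimensions."""
    if gnn not in ("gcn", "gin", "resgcn"):
        raise ValueError(f"unknown gnn type {gnn!r}")
    if not (node_level or graph_level):
        raise ValueError("at least one of node_level and graph_level must be True")
    # Per the Note: GCN/GIN graph-level width is hidden_dim * n_layers,
    # while ResGCN keeps hidden_dim for graph-level output.
    graph_dim = hidden_dim if gnn == "resgcn" else hidden_dim * n_layers
    graph_shape = (num_graphs, graph_dim)
    # Node-level width is hidden_dim for every encoder type.
    node_shape = (num_nodes, hidden_dim)
    if node_level and graph_level:
        return graph_shape, node_shape  # (graph-level, node-level) order
    return graph_shape if graph_level else node_shape


print(expected_output_shapes("gin", 128, 3, num_graphs=32, num_nodes=707))
# (32, 384)
print(expected_output_shapes("gin", 128, 5, 32, 707, node_level=True, graph_level=False))
# (707, 128)
print(expected_output_shapes("gin", 128, 5, 32, 707, node_level=True, graph_level=True))
# ((32, 640), (707, 128))
```

A graph-level head placed on top of a GIN or GCN encoder therefore needs an input size of hidden_dim * n_layers, not hidden_dim.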