Source code for dig.sslgraph.utils.encoders

from functools import partial
import torch
import torch.nn.functional as F

from torch.nn import Parameter
from torch.nn import Sequential, Linear, BatchNorm1d
from torch_scatter import scatter_add
from torch_geometric.nn import GINConv, GCNConv, global_add_pool, global_mean_pool
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops
from torch_geometric.nn.inits import glorot, zeros


class Encoder(torch.nn.Module):
    r"""A wrapped :class:`torch.nn.Module` class for the convenient instantiation of
    pre-implemented graph encoders.

    Args:
        feat_dim (int): The dimension of input node features.
        hidden_dim (int): The dimension of node-level (local) embeddings.
        n_layers (int, optional): The number of GNN layers in the encoder.
            (default: :obj:`5`)
        pool (string, optional): The global pooling method, :obj:`sum` or :obj:`mean`.
            (default: :obj:`sum`)
        gnn (string, optional): The type of GNN layer, :obj:`gcn`, :obj:`gin` or
            :obj:`resgcn`. (default: :obj:`gin`)
        bn (bool, optional): Whether to include batch normalization.
            (default: :obj:`True`)
        act (string, optional): The activation function, :obj:`relu` or :obj:`prelu`.
            (default: :obj:`relu`)
        bias (bool, optional): Whether to include a bias term in the Linear layers.
            (default: :obj:`True`)
        xavier (bool, optional): Whether to apply Xavier initialization.
            (default: :obj:`True`)
        node_level (bool, optional): If :obj:`True`, the encoder will output node-level
            embeddings (local representations). (default: :obj:`False`)
        graph_level (bool, optional): If :obj:`True`, the encoder will output graph-level
            embeddings (global representations). (default: :obj:`True`)
        edge_weight (bool, optional): Only applied to GCN. Whether to use edge weights
            to compute the aggregation. (default: :obj:`False`)

    Note
    ----
    For GCN and GIN encoders, the dimension of the output node-level (local) embedding is
    :obj:`hidden_dim`, whereas the graph-level (global) embedding has dimension
    :obj:`hidden_dim` * :obj:`n_layers`. For ResGCN, the output embeddings for both nodes
    and graphs have dimension :obj:`hidden_dim`.

    Examples
    --------
    >>> feat_dim = dataset[0].x.shape[1]
    >>> encoder = Encoder(feat_dim, 128, n_layers=3, gnn="gin")
    >>> encoder(some_batched_data).shape  # graph-level embedding of shape [batch_size, 128*3]
    torch.Size([32, 384])

    >>> encoder = Encoder(feat_dim, 128, n_layers=5, node_level=True, graph_level=False)
    >>> encoder(some_batched_data).shape  # node-level embedding of shape [n_nodes, 128]
    torch.Size([707, 128])

    >>> encoder = Encoder(feat_dim, 128, n_layers=5, node_level=True, graph_level=True)
    >>> encoder(some_batched_data)  # a tuple of graph-level and node-level embeddings
    (tensor([...]), tensor([...]))
    """

    def __init__(self, feat_dim, hidden_dim, n_layers=5, pool='sum',
                 gnn='gin', node_level=False, graph_level=True, **kwargs):
        super(Encoder, self).__init__()
        if gnn == 'gin':
            self.encoder = GIN(feat_dim, hidden_dim, n_layers, pool, **kwargs)
        elif gnn == 'gcn':
            self.encoder = GCN(feat_dim, hidden_dim, n_layers, pool, **kwargs)
        elif gnn == 'resgcn':
            self.encoder = ResGCN(feat_dim, hidden_dim, num_conv_layers=n_layers,
                                  global_pool=pool)
        self.node_level = node_level
        self.graph_level = graph_level
    def forward(self, data):
        z_g, z_n = self.encoder(data)
        if self.node_level and self.graph_level:
            return z_g, z_n
        elif self.graph_level:
            return z_g
        else:
            return z_n
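
# Illustrative sketch (not part of the original module): how the node_level and
# graph_level flags select what Encoder.forward returns. `feat_dim` and `batch`
# below are hypothetical placeholders for a feature dimension and a batched
# torch_geometric data object.
#
#     enc_both = Encoder(feat_dim, 128, n_layers=3, gnn='gin',
#                        node_level=True, graph_level=True)
#     z_g, z_n = enc_both(batch)   # tuple: graph-level [n_graphs, 128 * 3],
#                                  #        node-level  [n_nodes, 128]
#
#     enc_graph = Encoder(feat_dim, 128, n_layers=3, gnn='gin')
#     z_g = enc_graph(batch)       # graph-level embedding only (default)
#
#     enc_node = Encoder(feat_dim, 128, n_layers=3, gnn='gin',
#                        node_level=True, graph_level=False)
#     z_n = enc_node(batch)        # node-level embedding only
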
# Multi-layer GIN encoder; returns (graph-level, node-level) embeddings.
class GIN(torch.nn.Module):
    def __init__(self, feat_dim, hidden_dim, n_layers=3, pool='sum',
                 bn=False, act='relu', bias=True, xavier=True):
        super(GIN, self).__init__()

        if bn:
            self.bns = torch.nn.ModuleList()
        else:
            self.bns = None
        self.convs = torch.nn.ModuleList()
        self.n_layers = n_layers
        self.pool = pool

        self.act = torch.nn.PReLU() if act == 'prelu' else torch.nn.ReLU()

        for i in range(n_layers):
            start_dim = hidden_dim if i else feat_dim
            nn = Sequential(Linear(start_dim, hidden_dim, bias=bias),
                            self.act,
                            Linear(hidden_dim, hidden_dim, bias=bias))
            if xavier:
                self.weights_init(nn)
            conv = GINConv(nn)
            self.convs.append(conv)
            if bn:
                self.bns.append(BatchNorm1d(hidden_dim))

    def weights_init(self, module):
        for m in module.modules():
            if isinstance(m, Linear):
                torch.nn.init.xavier_uniform_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.fill_(0.0)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        xs = []
        for i in range(self.n_layers):
            x = self.convs[i](x, edge_index)
            x = self.act(x)
            if self.bns is not None:
                x = self.bns[i](x)
            xs.append(x)

        # Graph-level representation: concatenation of the pooled output of every layer.
        if self.pool == 'sum':
            xpool = [global_add_pool(x, batch) for x in xs]
        else:
            xpool = [global_mean_pool(x, batch) for x in xs]
        global_rep = torch.cat(xpool, 1)

        return global_rep, x


# Multi-layer GCN encoder; returns (graph-level, node-level) embeddings.
class GCN(torch.nn.Module):
    def __init__(self, feat_dim, hidden_dim, n_layers=3, pool='sum',
                 bn=False, act='relu', bias=True, xavier=True, edge_weight=False):
        super(GCN, self).__init__()

        if bn:
            self.bns = torch.nn.ModuleList()
        else:
            self.bns = None
        self.convs = torch.nn.ModuleList()
        self.acts = torch.nn.ModuleList()
        self.n_layers = n_layers
        self.pool = pool
        self.edge_weight = edge_weight
        self.normalize = not edge_weight
        self.add_self_loops = not edge_weight

        if act == 'prelu':
            a = torch.nn.PReLU()
        else:
            a = torch.nn.ReLU()

        for i in range(n_layers):
            start_dim = hidden_dim if i else feat_dim
            conv = GCNConv(start_dim, hidden_dim, bias=bias,
                           add_self_loops=self.add_self_loops,
                           normalize=self.normalize)
            if xavier:
                self.weights_init(conv.lin)
            self.convs.append(conv)
            self.acts.append(a)
            if bn:
                self.bns.append(BatchNorm1d(hidden_dim))

    def weights_init(self, m):
        # m is the internal linear transform of a GCNConv (conv.lin).
        if hasattr(m, 'weight') and m.weight is not None:
            torch.nn.init.xavier_uniform_(m.weight.data)
        if hasattr(m, 'bias') and m.bias is not None:
            m.bias.data.fill_(0.0)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        if self.edge_weight:
            edge_attr = data.edge_attr
        else:
            edge_attr = None
        xs = []
        for i in range(self.n_layers):
            x = self.convs[i](x, edge_index, edge_attr)
            x = self.acts[i](x)
            if self.bns is not None:
                x = self.bns[i](x)
            xs.append(x)

        # Graph-level representation: concatenation of the pooled output of every layer.
        if self.pool == 'sum':
            xpool = [global_add_pool(x, batch) for x in xs]
        else:
            xpool = [global_mean_pool(x, batch) for x in xs]
        global_rep = torch.cat(xpool, 1)

        return global_rep, x


# GCN convolution with optional symmetric edge normalization; with gfn=True it skips
# message passing and acts as a pure feature transform.
class ResGCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels, improved=False, cached=False,
                 bias=True, edge_norm=True, gfn=False):
        super(ResGCNConv, self).__init__('add')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.improved = improved
        self.cached = cached
        self.cached_result = None
        self.edge_norm = edge_norm
        self.gfn = gfn

        self.weight = Parameter(torch.Tensor(in_channels, out_channels))

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.weights_init()

    def weights_init(self):
        glorot(self.weight)
        zeros(self.bias)
        self.cached_result = None

    @staticmethod
    def norm(edge_index, num_nodes, edge_weight, improved=False, dtype=None):
        if edge_weight is None:
            edge_weight = torch.ones((edge_index.size(1), ),
                                     dtype=dtype,
                                     device=edge_index.device)
        edge_weight = edge_weight.view(-1)
        assert edge_weight.size(0) == edge_index.size(1)

        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
        edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
        # Add edge_weight for loop edges.
        loop_weight = torch.full((num_nodes, ),
                                 1 if not improved else 2,
                                 dtype=edge_weight.dtype,
                                 device=edge_weight.device)
        edge_weight = torch.cat([edge_weight, loop_weight], dim=0)

        row, col = edge_index
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

    def forward(self, x, edge_index, edge_weight=None):
        x = torch.matmul(x, self.weight)
        if self.gfn:
            return x

        if not self.cached or self.cached_result is None:
            if self.edge_norm:
                edge_index, norm = ResGCNConv.norm(
                    edge_index, x.size(0), edge_weight, self.improved, x.dtype)
            else:
                norm = None
            self.cached_result = edge_index, norm

        edge_index, norm = self.cached_result
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        if self.edge_norm:
            return norm.view(-1, 1) * x_j
        else:
            return x_j

    def update(self, aggr_out):
        if self.bias is not None:
            aggr_out = aggr_out + self.bias
        return aggr_out

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)


# ResGCN encoder with optional residual connections and batch normalization;
# returns (graph-level, node-level) embeddings of dimension hidden_dim.
class ResGCN(torch.nn.Module):
    def __init__(self, feat_dim, hidden_dim, num_feat_layers=1, num_conv_layers=3,
                 num_fc_layers=2, xg_dim=None, bn=True, gfn=False, collapse=False,
                 residual=False, global_pool="sum", dropout=0, edge_norm=True):
        super(ResGCN, self).__init__()
        assert num_feat_layers == 1, "more feat layers are not now supported"
        self.conv_residual = residual
        self.fc_residual = False  # no skip-connections for fc layers.
        self.collapse = collapse
        self.bn = bn

        assert "sum" in global_pool or "mean" in global_pool, global_pool
        if "sum" in global_pool:
            self.global_pool = global_add_pool
        else:
            self.global_pool = global_mean_pool
        self.dropout = dropout
        GConv = partial(ResGCNConv, edge_norm=edge_norm, gfn=gfn)

        if xg_dim is not None:  # Utilize graph level features.
            self.use_xg = True
            self.bn1_xg = BatchNorm1d(xg_dim)
            self.lin1_xg = Linear(xg_dim, hidden_dim)
            self.bn2_xg = BatchNorm1d(hidden_dim)
            self.lin2_xg = Linear(hidden_dim, hidden_dim)
        else:
            self.use_xg = False

        if collapse:
            self.bn_feat = BatchNorm1d(feat_dim)
            self.bns_fc = torch.nn.ModuleList()
            self.lins = torch.nn.ModuleList()
            if "gating" in global_pool:
                self.gating = torch.nn.Sequential(
                    Linear(feat_dim, feat_dim),
                    torch.nn.ReLU(),
                    Linear(feat_dim, 1),
                    torch.nn.Sigmoid())
            else:
                self.gating = None
            hidden_in = feat_dim  # input dimension of the first fc layer
            for i in range(num_fc_layers - 1):
                self.bns_fc.append(BatchNorm1d(hidden_in))
                self.lins.append(Linear(hidden_in, hidden_dim))
                hidden_in = hidden_dim
        else:
            self.bn_feat = BatchNorm1d(feat_dim)
            feat_gfn = True  # set true so GCNConv is feat transform
            self.conv_feat = ResGCNConv(feat_dim, hidden_dim, gfn=feat_gfn)
            if "gating" in global_pool:
                self.gating = torch.nn.Sequential(
                    Linear(hidden_dim, hidden_dim),
                    torch.nn.ReLU(),
                    Linear(hidden_dim, 1),
                    torch.nn.Sigmoid())
            else:
                self.gating = None
            self.bns_conv = torch.nn.ModuleList()
            self.convs = torch.nn.ModuleList()
            for i in range(num_conv_layers):
                self.bns_conv.append(BatchNorm1d(hidden_dim))
                self.convs.append(GConv(hidden_dim, hidden_dim))
            self.bn_hidden = BatchNorm1d(hidden_dim)
            self.bns_fc = torch.nn.ModuleList()
            self.lins = torch.nn.ModuleList()
            for i in range(num_fc_layers - 1):
                self.bns_fc.append(BatchNorm1d(hidden_dim))
                self.lins.append(Linear(hidden_dim, hidden_dim))

        # BN initialization.
        for m in self.modules():
            if isinstance(m, (torch.nn.BatchNorm1d)):
                torch.nn.init.constant_(m.weight, 1)
                torch.nn.init.constant_(m.bias, 0.0001)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        if self.use_xg:
            # xg is of shape [n_graphs, feat_dim]
            xg = self.bn1_xg(data.xg) if self.bn else data.xg
            xg = F.relu(self.lin1_xg(xg))
            xg = self.bn2_xg(xg) if self.bn else xg
            xg = F.relu(self.lin2_xg(xg))
        else:
            xg = None

        x = self.bn_feat(x) if self.bn else x
        x = F.relu(self.conv_feat(x, edge_index))

        for i, conv in enumerate(self.convs):
            x_ = self.bns_conv[i](x) if self.bn else x
            x_ = F.relu(conv(x_, edge_index))
            x = x + x_ if self.conv_residual else x_
        local_rep = x

        gate = 1 if self.gating is None else self.gating(x)
        x = self.global_pool(x * gate, batch)
        x = x if xg is None else x + xg

        for i, lin in enumerate(self.lins):
            x_ = self.bns_fc[i](x) if self.bn else x
            x_ = F.relu(lin(x_))
            x = x + x_ if self.fc_residual else x_

        x = self.bn_hidden(x)
        if self.dropout > 0:
            x = F.dropout(x, p=self.dropout, training=self.training)

        return x, local_rep
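

if __name__ == '__main__':
    # Minimal smoke test, illustrative only and not part of the DIG library: build a
    # small random two-graph batch and compare the embedding shapes produced by the
    # three pre-implemented backends. Feature and hidden sizes are arbitrary
    # assumptions made for this example.
    from torch_geometric.data import Batch, Data

    data_list = [Data(x=torch.randn(n, 8),
                      edge_index=torch.randint(0, n, (2, 3 * n)))
                 for n in (10, 15)]
    batch = Batch.from_data_list(data_list)

    for gnn in ('gin', 'gcn', 'resgcn'):
        enc = Encoder(feat_dim=8, hidden_dim=16, n_layers=3, gnn=gnn,
                      node_level=True, graph_level=True)
        enc.eval()  # use running batch-norm statistics for this tiny batch
        z_g, z_n = enc(batch)
        # GIN/GCN concatenate per-layer pooled features (graph-level dim 16 * 3 = 48);
        # ResGCN keeps a single hidden_dim (16) for both node- and graph-level outputs.
        print(gnn, tuple(z_g.shape), tuple(z_n.shape))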