DeepGCNLayer
torchmil.nn.gnns.DeepGCNLayer
Bases: Module
Implementation of a DeepGCN layer.
Adapts the DeepGCNLayer implementation from torch_geometric to batched node features with a dense adjacency matrix.
__init__(conv=None, norm=None, act=None, block='plain', dropout=0.0)
Parameters:
- conv (Module, default: None) – Convolutional layer.
- norm (Module, default: None) – Normalization layer.
- act (Module, default: None) – Activation layer.
- block (str, default: 'plain') – Skip connection type. Possible values: 'res', 'res+', 'dense', 'plain'.
- dropout (float, default: 0.0) – Dropout rate.
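The four block values come from torch_geometric's DeepGCNLayer, which orders the sub-layers differently for each one. The sketch below reproduces those orderings in plain PyTorch under stated assumptions. It assumes conv is called as conv(x, adj) on dense batched tensors. It uses a hypothetical DenseGraphConv (adj @ linear(x)) as the convolution. It also uses LayerNorm because that layer normalizes the last dimension of a (batch_size, n_nodes, dim) tensor. DeepGCNLayerSketch is an illustration, not torchmil's own code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DenseGraphConv(nn.Module):
    """Hypothetical dense graph convolution: out = adj @ (x W + b)."""

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.lin = nn.Linear(in_dim, out_dim)

    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        # x: (batch_size, n_nodes, in_dim), adj: (batch_size, n_nodes, n_nodes)
        return adj @ self.lin(x)


class DeepGCNLayerSketch(nn.Module):
    """Hypothetical re-creation of torch_geometric's DeepGCNLayer block orderings."""

    def __init__(self, conv, norm=None, act=None, block="plain", dropout=0.0):
        super().__init__()
        if block not in ("res", "res+", "dense", "plain"):
            raise ValueError(f"unknown block type: {block!r}")
        self.conv, self.norm, self.act = conv, norm, act
        self.block, self.dropout = block, dropout

    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        if self.block == "res+":
            # Pre-activation residual: norm -> act -> dropout -> conv, then add the input.
            h = x
            if self.norm is not None:
                h = self.norm(h)
            if self.act is not None:
                h = self.act(h)
            h = F.dropout(h, p=self.dropout, training=self.training)
            return x + self.conv(h, adj)

        # 'res', 'dense', 'plain': conv -> norm -> act, then combine with the input.
        h = self.conv(x, adj)
        if self.norm is not None:
            h = self.norm(h)
        if self.act is not None:
            h = self.act(h)
        if self.block == "res":
            h = x + h  # needs conv output dim == in_dim
        elif self.block == "dense":
            h = torch.cat([x, h], dim=-1)  # feature dim grows to in_dim + conv output dim
        # 'plain': no skip connection
        return F.dropout(h, p=self.dropout, training=self.training)


x = torch.randn(2, 5, 8)                # (batch_size, n_nodes, in_dim)
adj = torch.eye(5).repeat(2, 1, 1)      # (batch_size, n_nodes, n_nodes)
for block in ("plain", "res", "res+", "dense"):
    layer = DeepGCNLayerSketch(DenseGraphConv(8, 8), nn.LayerNorm(8), nn.ReLU(), block, 0.1)
    print(block, tuple(layer(x, adj).shape))
# plain (2, 5, 8) / res (2, 5, 8) / res+ (2, 5, 8) / dense (2, 5, 16)
```

In this sketch, 'res' and 'res+' add the convolution output to x, so the convolution must preserve the feature dimension. With 'dense', the output width is in_dim plus the convolution's output width.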
forward(x, adj)
Applies the layer to the node features x, using adj as the graph structure.
Parameters:
- x (Tensor) – Node features of shape (batch_size, n_nodes, in_dim).
- adj (Tensor) – Adjacency matrix of shape (batch_size, n_nodes, n_nodes).
Returns:
- y (Tensor) – Output tensor of shape (batch_size, n_nodes, out_dim).
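The input shapes above assume dense, batched tensors. The snippet below is a sketch of how such inputs might be prepared. It builds node features and a symmetrically normalized adjacency matrix D^{-1/2}(A + I)D^{-1/2}, the preprocessing commonly used with GCN-style convolutions. This page does not say whether DeepGCNLayer expects adj to be normalized or to contain self-loops, so that step is an assumption. normalize_adj is a hypothetical helper and is not part of torchmil.

```python
import torch


def normalize_adj(adj: torch.Tensor, add_self_loops: bool = True) -> torch.Tensor:
    """Hypothetical helper: batched symmetric GCN normalization D^{-1/2} (A + I) D^{-1/2}.

    adj: (batch_size, n_nodes, n_nodes) dense adjacency -> tensor of the same shape.
    """
    if add_self_loops:
        eye = torch.eye(adj.size(-1), dtype=adj.dtype, device=adj.device)
        adj = adj + eye  # broadcasts the identity over the batch dimension
    deg = adj.sum(dim=-1)  # node degrees, (batch_size, n_nodes)
    deg_inv_sqrt = deg.pow(-0.5)
    # Isolated nodes (degree 0, only possible without self-loops) would give inf; zero them out.
    deg_inv_sqrt = deg_inv_sqrt.masked_fill(torch.isinf(deg_inv_sqrt), 0.0)
    # Scale row i by d_i^{-1/2} and column j by d_j^{-1/2}.
    return deg_inv_sqrt.unsqueeze(-1) * adj * deg_inv_sqrt.unsqueeze(-2)


batch_size, n_nodes, in_dim = 2, 4, 3
x = torch.randn(batch_size, n_nodes, in_dim)  # node features

# Path graph 0-1-2-3, shared by every graph in the batch.
a = torch.zeros(n_nodes, n_nodes)
idx = torch.arange(n_nodes - 1)
a[idx, idx + 1] = 1.0
a[idx + 1, idx] = 1.0
adj = normalize_adj(a.expand(batch_size, -1, -1))

print(x.shape, adj.shape)  # torch.Size([2, 4, 3]) torch.Size([2, 4, 4])
```

Both tensors match the shapes listed for forward(x, adj). For graphs of different sizes, each graph would first need to be padded to a common n_nodes.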