Dense MinCut pooling
torchmil.nn.gnns.dense_mincut_pool
dense_mincut_pool(x, adj, s, mask=None, temp=1.0)
Dense MinCut Pooling.
Adapts the implementation from torch_geometric.
Parameters:
- x (Tensor) – Input tensor of shape (batch_size, n_nodes, in_dim).
- adj (Tensor) – Adjacency tensor of shape (batch_size, n_nodes, n_nodes).
- s (Tensor) – Dense learned assignment tensor of shape (batch_size, n_nodes, n_cluster).
- mask (Tensor, default: None) – Mask tensor of shape (batch_size, n_nodes) indicating which nodes are valid.
- temp (float, default: 1.0) – Temperature of the softmax applied to s to obtain the soft cluster assignments.
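Because every input is dense and batched, graphs (bags) with different numbers of nodes have to be zero-padded to a common n_nodes, with mask marking the real nodes. The following is a minimal sketch of preparing such inputs, assuming a boolean mask where True means "real node". The helper `pad_graphs` is hypothetical and not part of torchmil, and the random `s` stands in for the output of whatever layer produces the assignment logits.

```python
import torch


def pad_graphs(xs, adjs):
    """Zero-pad variable-size graphs into dense batched tensors plus a node mask.

    Hypothetical helper for illustration; not part of torchmil.
    xs:   list of (n_i, in_dim) node-feature tensors
    adjs: list of (n_i, n_i) adjacency tensors
    """
    batch_size = len(xs)
    n_max = max(x.shape[0] for x in xs)
    in_dim = xs[0].shape[1]

    x = torch.zeros(batch_size, n_max, in_dim)
    adj = torch.zeros(batch_size, n_max, n_max)
    mask = torch.zeros(batch_size, n_max, dtype=torch.bool)
    for b, (xb, ab) in enumerate(zip(xs, adjs)):
        n = xb.shape[0]
        x[b, :n] = xb          # real node features; padded rows stay zero
        adj[b, :n, :n] = ab    # padded nodes get no edges
        mask[b, :n] = True     # True marks real nodes
    return x, adj, mask


# Two bags with 3 and 5 instances, 4 features each, fully connected graphs.
xs = [torch.randn(3, 4), torch.randn(5, 4)]
adjs = [torch.ones(3, 3) - torch.eye(3), torch.ones(5, 5) - torch.eye(5)]
x, adj, mask = pad_graphs(xs, adjs)

n_cluster = 2
s = torch.randn(x.shape[0], x.shape[1], n_cluster)  # placeholder assignment logits
print(x.shape, adj.shape, s.shape, mask.shape)
# torch.Size([2, 5, 4]) torch.Size([2, 5, 5]) torch.Size([2, 5, 2]) torch.Size([2, 5])
```

Padding the adjacency with zeros keeps the degrees of the real nodes unchanged, so the mask only has to remove the padded rows of x and s.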
Returns:
- x_ (Tensor) – Pooled node feature tensor of shape (batch_size, n_cluster, in_dim).
- adj_ (Tensor) – Coarsened adjacency tensor of shape (batch_size, n_cluster, n_cluster).
- mincut_loss (Tensor) – MinCut loss.
- ortho_loss (Tensor) – Orthogonality loss.
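To make the four return values concrete, here is an illustrative re-implementation modeled on the `dense_mincut_pool` function in torch_geometric, which this function adapts. It is a sketch, not torchmil's source: the name `mincut_pool_sketch` is hypothetical, and torchmil's version may differ in details such as epsilon handling or loss reduction. With S = softmax(s / temp), the pooled features are S^T X, the coarsened adjacency is S^T A S (self-loops removed, then symmetrically degree-normalized), the MinCut loss is -Tr(S^T A S) / Tr(S^T D S) with D the degree matrix, and the orthogonality loss is ||S^T S / ||S^T S||_F - I / ||I||_F||_F, both averaged over the batch.

```python
import torch


def mincut_pool_sketch(x, adj, s, mask=None, temp=1.0):
    """Illustrative dense MinCut pooling, modeled on torch_geometric's dense_mincut_pool.

    Hypothetical stand-in for torchmil.nn.gnns.dense_mincut_pool; details may differ.
    """
    eps = 1e-15
    batch_size, n_nodes, _ = x.shape
    k = s.shape[-1]

    # Soft cluster assignments: softmax over clusters, scaled by the temperature.
    s = torch.softmax(s / temp, dim=-1)

    if mask is not None:
        # Zero out padded nodes so they contribute nothing to the pooled outputs.
        m = mask.view(batch_size, n_nodes, 1).to(x.dtype)
        x, s = x * m, s * m

    st = s.transpose(1, 2)
    x_ = st @ x          # (batch_size, n_cluster, in_dim)
    adj_ = st @ adj @ s  # (batch_size, n_cluster, n_cluster)

    # MinCut loss: -Tr(S^T A S) / Tr(S^T D S), averaged over the batch.
    mincut_num = torch.einsum("bii->b", adj_)
    deg = torch.diag_embed(adj.sum(dim=-1))
    mincut_den = torch.einsum("bii->b", st @ deg @ s)
    mincut_loss = (-(mincut_num / mincut_den)).mean()

    # Orthogonality loss: ||S^T S / ||S^T S||_F - I / ||I||_F||_F, averaged over the batch.
    ss = st @ s
    i_k = torch.eye(k, dtype=ss.dtype, device=ss.device)
    ortho_loss = torch.linalg.matrix_norm(
        ss / torch.linalg.matrix_norm(ss, keepdim=True) - i_k / torch.linalg.matrix_norm(i_k)
    ).mean()

    # Remove self-loops and symmetrically normalize the coarsened adjacency.
    idx = torch.arange(k, device=adj_.device)
    adj_[:, idx, idx] = 0
    d = torch.sqrt(adj_.sum(dim=-1))[:, None] + eps  # (batch_size, 1, n_cluster)
    adj_ = (adj_ / d) / d.transpose(1, 2)

    return x_, adj_, mincut_loss, ortho_loss


torch.manual_seed(0)
batch_size, n_nodes, in_dim, n_cluster = 2, 6, 4, 3
x = torch.randn(batch_size, n_nodes, in_dim)
a = torch.rand(batch_size, n_nodes, n_nodes)
adj = (a + a.transpose(1, 2)) / 2  # symmetric, non-negative edge weights
s = torch.randn(batch_size, n_nodes, n_cluster)

x_, adj_, mincut_loss, ortho_loss = mincut_pool_sketch(x, adj, s)
print(x_.shape, adj_.shape)  # torch.Size([2, 3, 4]) torch.Size([2, 3, 3])
```

For a symmetric, non-negative adjacency the MinCut loss lies in [-1, 0], reaching -1 when no edge weight crosses between clusters, while the orthogonality loss pushes the clusters toward similar sizes. In training, both losses are typically added to the task loss.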