DeepGraphSurv
torchmil.models.DeepGraphSurv
Bases: Module
DeepGraphSurv model, as proposed in Graph CNN for Survival Analysis on Whole Slide Pathological Images.
Given an input bag \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times P}\) with adjacency matrix \(\mathbf{A} \in \mathbb{R}^{N \times N}\), the model optionally applies a feature extractor, \(\text{FeatExt}(\cdot)\), to transform the instance features: \(\mathbf{X} = \text{FeatExt}(\mathbf{X}) \in \mathbb{R}^{N \times D}\).
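The sketch below shows the shapes involved. The small MLP standing in for \(\text{FeatExt}\) and all sizes are arbitrary assumptions. Because `nn.Linear` acts only on the last dimension, applying it to a padded batch of shape `(batch_size, bag_size, P)` maps each instance from \(P\) to \(D\) features independently.

```python
import torch
from torch import nn

# Hypothetical sizes: 2 bags, each padded to 6 instances with P = 32 input features.
batch_size, bag_size, P, D = 2, 6, 32, 16
X = torch.randn(batch_size, bag_size, P)

# Hypothetical FeatExt: nn.Linear transforms only the trailing feature dimension,
# so each instance is mapped from P to D features independently.
feat_ext = nn.Sequential(nn.Linear(P, D), nn.ReLU())
X_feat = feat_ext(X)
print(X_feat.shape)  # torch.Size([2, 6, 16])
```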
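Then, the representation branch transforms the instance features using a Graph Convolutional Network (GCN), and the attention branch computes the attention values \(\mathbf{f}\) using another GCN:

\[
\begin{aligned}
\mathbf{H} &= \operatorname{GCN}_{\text{rep}}(\mathbf{X}, \mathbf{A}) \in \mathbb{R}^{N \times \texttt{hidden_dim}}, \\
\mathbf{f} &= \operatorname{GCN}_{\text{att}}(\mathbf{X}, \mathbf{A}) \in \mathbb{R}^{N}.
\end{aligned}
\]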
These GCNs are built from DeepGCN layers (see DeepGCNLayer) that combine a Chebyshev graph convolution (see ChebConv), LayerNorm, and ReLU activation.
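The following sketch gives some intuition for a Chebyshev graph convolution on dense, batched adjacency matrices in plain PyTorch. It is not torchmil's ChebConv, and `scaled_laplacian` and `DenseChebConv` are hypothetical names. It assumes the symmetric normalization \(\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2}\mathbf{A}\mathbf{D}^{-1/2}\) and follows the PyTorch Geometric convention that `K` counts the polynomial terms \(T_0, \ldots, T_{K-1}\). The final lines use a plain conv, LayerNorm, ReLU sequence as a simplified stand-in for one layer. The real DeepGCN layer may differ, for example by adding residual connections.

```python
import torch
from torch import nn


def scaled_laplacian(adj: torch.Tensor, lambda_max: float = 2.0) -> torch.Tensor:
    """Return 2 L / lambda_max - I for dense adjacency matrices of shape (B, N, N)."""
    n = adj.shape[-1]
    eye = torch.eye(n, dtype=adj.dtype, device=adj.device).expand_as(adj)
    deg = adj.sum(dim=-1)
    # Isolated (or padded) nodes have degree 0; give them a zero scaling instead of inf.
    d_inv_sqrt = torch.where(deg > 0, deg.pow(-0.5), torch.zeros_like(deg))
    lap = eye - d_inv_sqrt.unsqueeze(-1) * adj * d_inv_sqrt.unsqueeze(-2)
    return 2.0 * lap / lambda_max - eye


class DenseChebConv(nn.Module):
    """Hypothetical dense Chebyshev convolution: out = sum_k T_k(L_tilde) X W_k."""

    def __init__(self, in_dim: int, out_dim: int, K: int = 5):
        super().__init__()
        # One weight matrix per Chebyshev term; a single bias on the first one.
        self.lins = nn.ModuleList(
            [nn.Linear(in_dim, out_dim, bias=(k == 0)) for k in range(K)]
        )

    def forward(self, x, adj, lambda_max: float = 2.0):
        lap = scaled_laplacian(adj, lambda_max)
        t_prev = x                      # T_0(L_tilde) X = X
        out = self.lins[0](t_prev)
        if len(self.lins) > 1:
            t_curr = lap @ x            # T_1(L_tilde) X = L_tilde X
            out = out + self.lins[1](t_curr)
            for lin in self.lins[2:]:
                # Chebyshev recurrence: T_k = 2 L_tilde T_{k-1} - T_{k-2}
                t_next = 2 * (lap @ t_curr) - t_prev
                out = out + lin(t_next)
                t_prev, t_curr = t_curr, t_next
        return out


# Simplified stand-in for one branch layer: conv, then LayerNorm, then ReLU.
conv, norm = DenseChebConv(16, 8, K=5), nn.LayerNorm(8)
X = torch.randn(2, 6, 16)
adj = (torch.rand(2, 6, 6) > 0.5).float()
adj = ((adj + adj.transpose(1, 2)) > 0).float()  # symmetrize
H = torch.relu(norm(conv(X, adj)))
print(H.shape)  # torch.Size([2, 6, 8])
```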
Writing \(\mathbf{H} = \left[ \mathbf{h}_1, \ldots, \mathbf{h}_N \right]^\top\), the attention values are used to compute the bag representation \(\mathbf{z} \in \mathbb{R}^{\texttt{hidden_dim}}\) as

\[
\mathbf{z} = \mathbf{H}^\top \operatorname{Softmax}(\mathbf{f}) = \sum_{n=1}^{N} s_n \mathbf{h}_n,
\]

where \(s_n = \exp(f_n) / \sum_{m=1}^{N} \exp(f_m)\) is the normalized attention score for the \(n\)-th instance. The bag representation \(\mathbf{z}\) is then fed into a classifier (a single linear layer) to predict the bag label.
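The pooling step for a padded batch can be sketched as follows. The function name `attention_pool` is hypothetical, and the convention that `mask` is 1 for real instances and 0 for padding is an assumption. Padded positions are filled with \(-\infty\) before the softmax, so their scores \(s_n\) come out exactly zero.

```python
import torch
from torch import nn


def attention_pool(H, f, mask=None):
    """Compute z = H^T Softmax(f) per bag. H: (B, N, D), f: (B, N), mask: (B, N)."""
    if mask is not None:
        # Padded instances (mask == 0) get -inf, hence zero weight after the softmax.
        f = f.masked_fill(mask == 0, float("-inf"))
    s = torch.softmax(f, dim=1)                    # normalized scores s_n, shape (B, N)
    z = torch.bmm(s.unsqueeze(1), H).squeeze(1)    # sum_n s_n h_n, shape (B, D)
    return z, s


torch.manual_seed(0)
H = torch.randn(2, 4, 3)                  # hidden_dim = 3 (arbitrary)
f = torch.randn(2, 4)                     # unnormalized attention values
mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])
z, s = attention_pool(H, f, mask)

classifier = nn.Linear(3, 1)              # the single linear layer
Y_pred = classifier(z).squeeze(-1)        # bag logits, shape (B,)
print(z.shape, Y_pred.shape)  # torch.Size([2, 3]) torch.Size([2])
```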
__init__(in_shape=None, n_layers_rep=1, n_layers_att=1, hidden_dim=None, att_dim=128, dropout=0.0, K=5, compute_lambda_max=False, feat_ext=torch.nn.Identity(), criterion=torch.nn.BCEWithLogitsLoss())
Parameters:
- in_shape (tuple, default: None) – Shape of input data expected by the feature extractor (excluding batch dimension). If not provided, it will be lazily initialized.
- n_layers_rep (int, default: 1) – Number of ChebConv layers in the representation branch.
- n_layers_att (int, default: 1) – Number of ChebConv layers in the attention branch.
- hidden_dim (int, default: None) – Hidden dimension. If not provided, it will be set to the feature dimension.
- att_dim (int, default: 128) – Attention dimension.
- dropout (float, default: 0.0) – Dropout rate.
- K (int, default: 5) – Order of the Chebyshev polynomial approximation for the ChebConv layers.
- compute_lambda_max (bool, default: False) – If True, computes the largest eigenvalue \(\lambda_{\max}\) of the normalized graph Laplacian, which the ChebConv layers use to rescale the Laplacian. If False, \(\lambda_{\max}\) is set to 2.0.
- feat_ext (Module, default: Identity()) – Feature extractor.
- criterion (Module, default: BCEWithLogitsLoss()) – Loss function.
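The sketch below illustrates what `compute_lambda_max` trades off. It assumes, as in PyTorch Geometric's ChebConv, that \(\lambda_{\max}\) refers to the symmetric normalized Laplacian, whose eigenvalues always lie in \([0, 2]\). Under that assumption, the default of 2.0 is a safe upper bound that avoids an eigendecomposition per graph. `normalized_laplacian` and `get_lambda_max` are hypothetical helpers, not torchmil functions.

```python
import torch


def normalized_laplacian(adj: torch.Tensor) -> torch.Tensor:
    """L = I - D^{-1/2} A D^{-1/2} for symmetric dense adjacency matrices (..., N, N)."""
    deg = adj.sum(dim=-1)
    d_inv_sqrt = torch.where(deg > 0, deg.pow(-0.5), torch.zeros_like(deg))
    eye = torch.eye(adj.shape[-1], dtype=adj.dtype, device=adj.device)
    return eye - d_inv_sqrt.unsqueeze(-1) * adj * d_inv_sqrt.unsqueeze(-2)


def get_lambda_max(adj: torch.Tensor, compute_lambda_max: bool = False) -> torch.Tensor:
    """Largest Laplacian eigenvalue per graph, or the 2.0 upper bound when not computed."""
    if not compute_lambda_max:
        return torch.full(adj.shape[:-2], 2.0, dtype=adj.dtype, device=adj.device)
    # eigvalsh returns eigenvalues in ascending order, so the last one is the largest.
    return torch.linalg.eigvalsh(normalized_laplacian(adj))[..., -1]


# Complete graph on 3 nodes: normalized Laplacian eigenvalues are 0, 1.5, 1.5.
adj = (torch.ones(3, 3) - torch.eye(3)).unsqueeze(0)
print(get_lambda_max(adj, compute_lambda_max=True))  # approximately 1.5
print(get_lambda_max(adj))  # tensor([2.])
```

On this graph, the computed value (1.5) is tighter than the default bound, so the rescaled Laplacian uses more of the interval \([-1, 1]\) on which the Chebyshev polynomials are defined.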
forward(X, adj, mask=None, return_att=False)
Forward pass.
Parameters:
- X (Tensor) – Bag features of shape (batch_size, bag_size, ...).
- adj (Tensor) – Adjacency matrix of shape (batch_size, bag_size, bag_size).
- mask (Tensor, default: None) – Mask of shape (batch_size, bag_size).
- return_att (bool, default: False) – If True, returns attention values (before normalization) in addition to Y_pred.
Returns:
- Y_pred (Tensor) – Bag label logits of shape (batch_size,).
- att (Tensor) – Only returned when return_att=True. Attention values (before normalization) of shape (batch_size, bag_size).
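`forward` expects padded, batched tensors. The sketch below shows one way to assemble them from bags of different sizes. `pad_bags` is a hypothetical helper, not part of torchmil. Using 1 for real instances and 0 for padding in `mask` is an assumed convention.

```python
import torch


def pad_bags(bags, adjs):
    """Pad bags of shape (N_i, P) and adjacencies (N_i, N_i) to a common bag_size."""
    batch_size, bag_size = len(bags), max(x.shape[0] for x in bags)
    P = bags[0].shape[1]
    X = torch.zeros(batch_size, bag_size, P)
    adj = torch.zeros(batch_size, bag_size, bag_size)
    mask = torch.zeros(batch_size, bag_size)
    for i, (x, a) in enumerate(zip(bags, adjs)):
        n = x.shape[0]
        X[i, :n] = x          # real instances first, zero padding after
        adj[i, :n, :n] = a    # padded nodes have no edges
        mask[i, :n] = 1       # 1 = real instance, 0 = padding
    return X, adj, mask


bags = [torch.randn(4, 8), torch.randn(2, 8)]
adjs = [torch.ones(4, 4) - torch.eye(4), torch.tensor([[0.0, 1.0], [1.0, 0.0]])]
X, adj, mask = pad_bags(bags, adjs)
print(X.shape, adj.shape, mask.shape)
# torch.Size([2, 4, 8]) torch.Size([2, 4, 4]) torch.Size([2, 4])
```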
compute_loss(Y, X, adj, mask=None)
Compute the loss given the true bag labels.
Parameters:
- Y (Tensor) – Bag labels of shape (batch_size,).
- X (Tensor) – Bag features of shape (batch_size, bag_size, ...).
- adj (Tensor) – Adjacency matrix of shape (batch_size, bag_size, bag_size).
- mask (Tensor, default: None) – Mask of shape (batch_size, bag_size).
Returns:
- Y_pred (Tensor) – Bag label logits of shape (batch_size,).
- loss_dict (dict) – Dictionary containing the loss computed by criterion.
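The sketch below shows how the default criterion applies to bag logits. `bag_loss` is a hypothetical helper, and the dictionary key is an assumption rather than torchmil's actual key. The main detail is that `BCEWithLogitsLoss` expects float targets with the same shape as the logits.

```python
import torch
from torch import nn


def bag_loss(Y, Y_pred, criterion=None):
    """Hypothetical helper: apply a BCE-with-logits criterion to bag logits."""
    criterion = criterion if criterion is not None else nn.BCEWithLogitsLoss()
    # BCEWithLogitsLoss needs float targets with the same shape as the logits (batch_size,).
    loss = criterion(Y_pred, Y.float())
    return Y_pred, {"BCEWithLogitsLoss": loss}  # the dictionary key is an assumption


Y = torch.tensor([0, 1, 1])
Y_pred = torch.zeros(3)                  # logit 0 means probability 0.5
_, loss_dict = bag_loss(Y, Y_pred)
print(loss_dict["BCEWithLogitsLoss"].item())  # approximately 0.6931 (ln 2)
```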
predict(X, adj, mask=None, return_inst_pred=False)
Predict bag and (optionally) instance labels.
Parameters:
- X (Tensor) – Bag features of shape (batch_size, bag_size, ...).
- adj (Tensor) – Adjacency matrix of shape (batch_size, bag_size, bag_size).
- mask (Tensor, default: None) – Mask of shape (batch_size, bag_size).
- return_inst_pred (bool, default: False) – If True, returns instance predictions.
Returns:
- Y_pred (Tensor) – Bag label logits of shape (batch_size,).
- y_inst_pred (Tensor) – Only returned when return_inst_pred=True. Instance label predictions of shape (batch_size, bag_size).
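`predict` returns logits rather than probabilities. With the default `BCEWithLogitsLoss` criterion, a sigmoid converts them to bag-level probabilities. The logits below are made-up values, and the 0.5 threshold is an arbitrary choice.

```python
import torch

# Hypothetical bag logits, as returned in Y_pred.
Y_pred = torch.tensor([-2.0, 0.5, 3.0])
probs = torch.sigmoid(Y_pred)        # estimated P(Y = 1) per bag
labels = (probs > 0.5).long()        # equivalent to thresholding Y_pred at 0
print(labels)  # tensor([0, 1, 1])
```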