GTP
torchmil.models.GTP
Bases: MILModel
Method proposed in the paper GTP: Graph-Transformer for Whole Slide Image Classification.
Forward pass. Given an input bag \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times P}\) with adjacency matrix \(\mathbf{A} \in \mathbb{R}^{N \times N}\), the model optionally applies a feature extractor, \(\text{FeatExt}(\cdot)\), to transform the instance features: \(\mathbf{X} = \text{FeatExt}(\mathbf{X}) \in \mathbb{R}^{N \times D}\).
The bag is processed by a Graph Convolutional Network (GCN) to extract high-level instance embeddings. The GCN operates on the graph described by \(\mathbf{A}\), where nodes correspond to patches and edges are determined by spatial adjacency:
\[\mathbf{H} = \text{GCN}(\mathbf{X}, \mathbf{A}).\]
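As an illustration of how such a graph might be built, the sketch below links patches that are neighbours on a regular grid (8-neighbourhood). The helper `build_adjacency`, the `coords` layout and the neighbourhood rule are assumptions made for this example; they are not part of torchmil.

```python
import torch

def build_adjacency(coords: torch.Tensor, patch_size: int) -> torch.Tensor:
    """Hypothetical helper: binary (N, N) adjacency for patches on a regular grid.

    coords holds the top-left pixel position of each patch, shape (N, 2).
    Two patches are linked when they are one grid step apart
    (horizontally, vertically or diagonally). Self-loops are excluded.
    """
    grid = coords // patch_size                          # (N, 2) grid indices
    diff = (grid[:, None, :] - grid[None, :, :]).abs()   # (N, N, 2) pairwise offsets
    cheb = diff.max(dim=-1).values                       # Chebyshev distance on the grid
    return (cheb == 1).float()

# Three touching patches and one isolated patch.
coords = torch.tensor([[0, 0], [256, 0], [0, 256], [512, 512]])
adj = build_adjacency(coords, patch_size=256)
```

The resulting matrix is symmetric, so it can be used directly as the `adj` input after adding a batch dimension.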
To reduce the number of nodes while preserving structural relationships, a min-cut pooling operation is applied:
\[\mathbf{X}', \mathbf{A}' = \text{MinCutPool}(\mathbf{H}, \mathbf{A}).\]
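The pooling step can be sketched as follows, assuming the usual dense min-cut formulation in which a soft assignment matrix maps the N nodes to K clusters. The linear assignment layer and all variable names are hypothetical; torchmil's internal layer may differ.

```python
import torch

torch.manual_seed(0)
N, D, K = 6, 4, 2                       # nodes, embedding dim, clusters (n_clusters)
H = torch.randn(N, D)                   # node embeddings, e.g. the GCN output
A = torch.ones(N, N) - torch.eye(N)     # toy adjacency: fully connected, no self-loops

assign = torch.nn.Linear(D, K)          # hypothetical assignment layer
S = torch.softmax(assign(H), dim=-1)    # (N, K) soft cluster assignments, rows sum to 1

X_pool = S.T @ H                        # (K, D) pooled node features
A_pool = S.T @ A @ S                    # (K, K) pooled adjacency
```

After pooling, the graph has K nodes regardless of the original bag size, which keeps the cost of the subsequent Transformer encoder fixed.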
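The pooled graph is then passed through a Transformer encoder, where a class token \(\text{CLS}\) is prepended to the pooled node features:
\[\mathbf{Z} = \text{Transformer}\left(\left[\text{CLS}; \mathbf{X}'\right]\right).\]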
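Finally, the class token representation \(\mathbf{z}_{\text{cls}}\), i.e. the encoder output at the class token position, is used for classification:
\[Y_{\text{pred}} = \text{Classifier}(\mathbf{z}_{\text{cls}}).\]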
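The class-token mechanism can be illustrated with standard PyTorch modules. This is a minimal sketch, not the torchmil implementation: the encoder configuration, the zero-initialised token and the linear classifier are assumptions chosen for brevity.

```python
import torch

torch.manual_seed(0)
B, K, D = 2, 5, 16                       # batch size, pooled nodes, embedding dim
X_pool = torch.randn(B, K, D)            # pooled node features

cls_token = torch.nn.Parameter(torch.zeros(1, 1, D))   # learnable class token
layer = torch.nn.TransformerEncoderLayer(
    d_model=D, nhead=4, dim_feedforward=32, batch_first=True
)
encoder = torch.nn.TransformerEncoder(layer, num_layers=1)
classifier = torch.nn.Linear(D, 1)

# Prepend the class token to every bag in the batch.
tokens = torch.cat([cls_token.expand(B, -1, -1), X_pool], dim=1)   # (B, K + 1, D)
Z = encoder(tokens)                                                # (B, K + 1, D)

# Only the class token output (position 0) feeds the classifier.
Y_pred = classifier(Z[:, 0]).squeeze(-1)                           # (B,)
```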
Optionally, GraphCAM can be used to generate class activation maps highlighting the most relevant regions for the classification decision.
Loss function. By default, the model is trained end-to-end using the following per-bag loss:
\[\ell = \ell_{\text{BCE}} + \ell_{\text{MinCut}} + \ell_{\text{Ortho}},\]
where \(\ell_{\text{BCE}}\) is the Binary Cross-Entropy loss, \(\ell_{\text{MinCut}}\) is the MinCut loss, and \(\ell_{\text{Ortho}}\) is the Orthogonality loss, computed during the min-cut pooling operation, see Dense MinCut Pooling.
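For intuition, the two pooling terms can be sketched following the commonly used dense min-cut pooling formulation: a normalised cut term and an orthogonality term on the assignment matrix. Whether torchmil uses exactly these normalisations, and whether the three terms are summed without weights, is an assumption here; `mincut_losses` is a hypothetical helper.

```python
import math
import torch
import torch.nn.functional as F

def mincut_losses(S: torch.Tensor, A: torch.Tensor):
    """Hypothetical helper. S: (N, K) soft assignments, A: (N, N) adjacency."""
    K = S.shape[1]
    deg = torch.diag(A.sum(dim=-1))                       # degree matrix
    # Cut term: rewards assignments that keep strongly connected nodes together.
    mincut = -torch.trace(S.T @ A @ S) / torch.trace(S.T @ deg @ S)
    # Orthogonality term: pushes clusters to be balanced and non-overlapping.
    SS = S.T @ S
    ortho = torch.linalg.norm(SS / torch.linalg.norm(SS) - torch.eye(K) / math.sqrt(K))
    return mincut, ortho

# Two disconnected pairs of nodes, each assigned to its own cluster.
A = torch.tensor([[0., 1., 0., 0.],
                  [1., 0., 0., 0.],
                  [0., 0., 0., 1.],
                  [0., 0., 1., 0.]])
S = torch.tensor([[1., 0.], [1., 0.], [0., 1.], [0., 1.]])
mincut, ortho = mincut_losses(S, A)

# Assumed combination: unweighted sum with the classification loss.
bce = F.binary_cross_entropy_with_logits(torch.tensor([0.0]), torch.tensor([1.0]))
total = bce + mincut + ortho
```

In this toy case the assignment matches the graph structure perfectly, so the cut term reaches its minimum of -1 and the orthogonality term vanishes.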
__init__(in_shape, att_dim=512, n_clusters=100, n_layers=1, n_heads=8, use_mlp=True, dropout=0.0, feat_ext=torch.nn.Identity(), criterion=torch.nn.BCEWithLogitsLoss())
Class constructor.
Parameters:
- in_shape (tuple) – Shape of the input data expected by the feature extractor (excluding the batch dimension). If not provided, it will be lazily initialized.
- att_dim (int, default: 512) – Attention dimension of the transformer encoder.
- n_clusters (int, default: 100) – Number of clusters in min-cut pooling.
- n_layers (int, default: 1) – Number of layers in the transformer encoder.
- n_heads (int, default: 8) – Number of heads in the transformer encoder.
- use_mlp (bool, default: True) – Whether to use an MLP in the transformer encoder.
- dropout (float, default: 0.0) – Dropout rate in the transformer encoder.
- feat_ext (Module, default: Identity()) – Feature extractor.
- criterion (Module, default: BCEWithLogitsLoss()) – Loss function. By default, Binary Cross-Entropy loss from logits for binary classification.
forward(X, adj, mask=None, return_cam=False, return_loss=False)
Forward pass.
Parameters:
- X (Tensor) – Bag features of shape (batch_size, bag_size, ...).
- adj (Tensor) – Adjacency matrix of shape (batch_size, bag_size, bag_size).
- mask (Tensor, default: None) – Mask of shape (batch_size, bag_size).
- return_cam (bool, default: False) – If True, returns the class activation map in addition to Y_pred.
Returns:
- Y_pred (Tensor) – Bag label logits of shape (batch_size,).
- cam (Tensor) – Only returned when return_cam=True. Class activation map of shape (batch_size, bag_size).
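Because bags usually differ in size, the inputs have to be padded to a common bag_size before batching. The helper below is a hypothetical sketch of that step. It assumes the mask marks real instances with True and padding with False, a convention that should be checked against the library before use.

```python
import torch

def collate_bags(bags, adjs):
    """Hypothetical helper: pad variable-size bags to a common bag_size.

    bags: list of (n_i, P) feature tensors.
    adjs: list of (n_i, n_i) adjacency tensors.
    Returns X (B, n_max, P), adj (B, n_max, n_max) and mask (B, n_max).
    """
    B = len(bags)
    n_max = max(b.shape[0] for b in bags)
    P = bags[0].shape[1]
    X = torch.zeros(B, n_max, P)
    adj = torch.zeros(B, n_max, n_max)
    mask = torch.zeros(B, n_max, dtype=torch.bool)
    for i, (b, a) in enumerate(zip(bags, adjs)):
        n = b.shape[0]
        X[i, :n] = b
        adj[i, :n, :n] = a
        mask[i, :n] = True          # assumed convention: True marks a real instance
    return X, adj, mask

bags = [torch.randn(3, 4), torch.randn(2, 4)]
adjs = [torch.ones(3, 3), torch.ones(2, 2)]
X, adj, mask = collate_bags(bags, adjs)

# With a GTP instance `model`, the documented call would then be:
# Y_pred, cam = model(X, adj, mask, return_cam=True)
```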
compute_loss(Y, X, adj, mask=None)
Compute loss given true bag labels.
Parameters:
- Y (Tensor) – Bag labels of shape (batch_size,).
- X (Tensor) – Bag features of shape (batch_size, bag_size, ...).
- adj (Tensor) – Adjacency matrix of shape (batch_size, bag_size, bag_size).
- mask (Tensor, default: None) – Mask of shape (batch_size, bag_size).
Returns:
- Y_pred (Tensor) – Bag label logits of shape (batch_size,).
- loss_dict (dict) – Dictionary containing the loss value.
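A training step built on this interface might look like the sketch below. It assumes that every entry of loss_dict is a scalar tensor and that the total loss is their sum; both points are assumptions. `StubMIL` is a hypothetical stand-in that only mimics the documented compute_loss signature so that the example runs without torchmil.

```python
import torch

def train_step(model, optimizer, Y, X, adj, mask=None):
    """One optimisation step for any model exposing compute_loss(Y, X, adj, mask)."""
    optimizer.zero_grad()
    Y_pred, loss_dict = model.compute_loss(Y, X, adj, mask)
    loss = sum(loss_dict.values())       # assumption: all entries are scalar loss terms
    loss.backward()
    optimizer.step()
    return Y_pred.detach(), loss.item()

class StubMIL(torch.nn.Module):
    """Hypothetical stand-in with the same compute_loss interface (ignores adj and mask)."""
    def __init__(self, in_dim):
        super().__init__()
        self.fc = torch.nn.Linear(in_dim, 1)
        self.criterion = torch.nn.BCEWithLogitsLoss()

    def compute_loss(self, Y, X, adj, mask=None):
        Y_pred = self.fc(X.mean(dim=1)).squeeze(-1)      # mean-pool, then classify
        return Y_pred, {"BCEWithLogitsLoss": self.criterion(Y_pred, Y.float())}

torch.manual_seed(0)
model = StubMIL(in_dim=4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
X = torch.randn(2, 3, 4)
adj = torch.zeros(2, 3, 3)
Y = torch.tensor([0.0, 1.0])
Y_pred, loss = train_step(model, optimizer, Y, X, adj)
```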
predict(X, adj, mask=None, return_inst_pred=True)
Predict bag and (optionally) instance labels.
Parameters:
- X (Tensor) – Bag features of shape (batch_size, bag_size, ...).
- adj (Tensor) – Adjacency matrix of shape (batch_size, bag_size, bag_size).
- mask (Tensor, default: None) – Mask of shape (batch_size, bag_size).
- return_inst_pred (bool, default: True) – If True, returns instance label predictions in addition to bag label predictions.
Returns:
- Y_pred (Tensor) – Bag label logits of shape (batch_size,).
- y_inst_pred (Tensor) – Only returned when return_inst_pred=True. Instance label predictions of shape (batch_size, bag_size).
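Since the bag outputs are logits, they need to be mapped to probabilities before thresholding. A minimal sketch for the binary case, assuming the default BCEWithLogitsLoss criterion and a 0.5 decision threshold (the example logits are made up):

```python
import torch

Y_pred = torch.tensor([-2.0, 0.0, 3.0])   # example bag logits
probs = torch.sigmoid(Y_pred)             # probabilities in (0, 1)
labels = (probs >= 0.5).long()            # hard labels at the assumed 0.5 threshold
```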