TransformerABMIL
torchmil.models.TransformerABMIL
Bases: MILModel
Transformer Attention-based Multiple Instance Learning model. Proposed in Sm: enhanced localization in Multiple Instance Learning for medical imaging classification.
Given an input bag \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times P}\), the model optionally applies a feature extractor, \(\text{FeatExt}(\cdot)\), to transform the instance features: \(\mathbf{X} = \text{FeatExt}(\mathbf{X}) \in \mathbb{R}^{N \times D}\). Then, it transforms the instance features using a transformer encoder,
\[ \mathbf{X} = \text{TransformerEncoder}(\mathbf{X}) \in \mathbb{R}^{N \times D}, \]
and finally it aggregates the instance features into a bag representation \(\mathbf{z} \in \mathbb{R}^{D}\) using attention-based pooling,
\[ \mathbf{z} = \mathbf{X}^\top \operatorname{Softmax}(\mathbf{f}) \in \mathbb{R}^{D}, \]
where \(\mathbf{f} = \operatorname{MLP}(\mathbf{X}) \in \mathbb{R}^{N}\) are the attention values. The bag representation \(\mathbf{z}\) is then fed into a classifier (one linear layer) to predict the bag label.
See AttentionPool for more details on the attention-based pooling, and TransformerEncoder for more details on the transformer encoder.
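To make the pipeline concrete, here is a minimal PyTorch sketch of the three stages (transformer encoder, attention pooling, linear classifier). It is not torchmil's implementation: PyTorch's built-in `nn.TransformerEncoder` stands in for torchmil's `TransformerEncoder`, the attention MLP is assumed to be Linear → tanh → Linear, the feature extractor is taken to be the identity, a mask value of 1 is assumed to mark a real instance, and the class name `TransformerABMILSketch` is hypothetical.

```python
import torch
import torch.nn as nn


class TransformerABMILSketch(nn.Module):
    """Hypothetical re-implementation for illustration only (not torchmil's code)."""

    def __init__(self, in_dim: int, att_dim: int = 128, n_heads: int = 8, n_layers: int = 1):
        super().__init__()
        # Stand-in for torchmil's TransformerEncoder: PyTorch's built-in encoder.
        # in_dim must be divisible by n_heads; dropout 0.0 mirrors transf_dropout's default.
        layer = nn.TransformerEncoderLayer(
            d_model=in_dim, nhead=n_heads, dropout=0.0, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(
            layer, num_layers=n_layers, enable_nested_tensor=False
        )
        # Assumed attention MLP: one unnormalized score f_n per instance.
        self.att_mlp = nn.Sequential(
            nn.Linear(in_dim, att_dim), nn.Tanh(), nn.Linear(att_dim, 1)
        )
        # Classifier: a single linear layer on the bag representation z.
        self.classifier = nn.Linear(in_dim, 1)

    def forward(self, X: torch.Tensor, mask: torch.Tensor = None):
        # X: (batch_size, bag_size, D); mask: (batch_size, bag_size), 1 = real instance.
        pad = None if mask is None else ~mask.bool()  # True = ignore (padding)
        X = self.encoder(X, src_key_padding_mask=pad)  # (B, N, D)
        f = self.att_mlp(X).squeeze(-1)  # (B, N) unnormalized attention values
        f_masked = f if mask is None else f.masked_fill(~mask.bool(), float("-inf"))
        s = torch.softmax(f_masked, dim=1)  # (B, N), padded instances get weight 0
        z = torch.bmm(s.unsqueeze(1), X).squeeze(1)  # (B, D), i.e. X^T Softmax(f)
        Y_logits = self.classifier(z).squeeze(-1)  # (B,)
        return Y_logits, f
```

Because padded keys are masked inside the encoder and receive zero pooling weight, the logits of a bag do not depend on the values stored in its padded rows.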
__init__(in_shape, pool_att_dim=128, pool_act='tanh', pool_gated=False, feat_ext=torch.nn.Identity(), transf_att_dim=512, transf_n_layers=1, transf_n_heads=8, transf_use_mlp=True, transf_add_self=True, transf_dropout=0.0, criterion=torch.nn.BCEWithLogitsLoss())
Class constructor.
Parameters:
- in_shape (tuple) – Shape of input data expected by the feature extractor (excluding batch dimension). If not provided, it will be lazily initialized.
- pool_att_dim (int, default: 128) – Attention dimension for pooling.
- pool_act (str, default: 'tanh') – Activation function for pooling. Possible values: 'tanh', 'relu', 'gelu'.
- pool_gated (bool, default: False) – If True, use gated attention in the attention pooling.
- feat_ext (Module, default: Identity()) – Feature extractor.
- transf_att_dim (int, default: 512) – Attention dimension for the transformer encoder.
- transf_n_layers (int, default: 1) – Number of layers in the transformer encoder.
- transf_n_heads (int, default: 8) – Number of attention heads in the transformer encoder.
- transf_use_mlp (bool, default: True) – Whether to use an MLP in the transformer encoder.
- transf_add_self (bool, default: True) – Whether to add the input to the output in the transformer encoder.
- transf_dropout (float, default: 0.0) – Dropout rate in the transformer encoder.
- criterion (Module, default: BCEWithLogitsLoss()) – Loss function. By default, binary cross-entropy loss computed from logits, for binary classification.
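The pooling parameters `pool_att_dim`, `pool_act` and `pool_gated` shape how each attention value \(f_n\) is computed. The sketch below assumes the standard formulation of Ilse et al. (2018): \(f_n = \mathbf{w}^\top \sigma(\mathbf{V}\mathbf{x}_n)\) without gating, and \(f_n = \mathbf{w}^\top \left(\sigma(\mathbf{V}\mathbf{x}_n) \odot \operatorname{sigm}(\mathbf{U}\mathbf{x}_n)\right)\) with gating. Whether torchmil's `AttentionPool` uses exactly this parametrization is an assumption, and `AttentionScoresSketch` is a hypothetical name.

```python
import torch
import torch.nn as nn

# Assumed mapping from pool_act strings to activation modules.
ACTS = {"tanh": nn.Tanh, "relu": nn.ReLU, "gelu": nn.GELU}


class AttentionScoresSketch(nn.Module):
    """Hypothetical scorer illustrating pool_att_dim, pool_act and pool_gated."""

    def __init__(self, in_dim: int, att_dim: int = 128, act: str = "tanh", gated: bool = False):
        super().__init__()
        self.V = nn.Linear(in_dim, att_dim)  # projection into the attention space
        self.act = ACTS[act]()
        # Gating branch: an extra projection passed through a sigmoid.
        self.U = nn.Linear(in_dim, att_dim) if gated else None
        self.w = nn.Linear(att_dim, 1, bias=False)  # maps to one score per instance

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # X: (batch_size, bag_size, in_dim) -> f: (batch_size, bag_size)
        H = self.act(self.V(X))
        if self.U is not None:
            H = H * torch.sigmoid(self.U(X))  # element-wise gate
        return self.w(H).squeeze(-1)
```

Gating costs one additional `in_dim × att_dim` projection (plus bias) and lets the model scale each attention feature by a learned value in \((0, 1)\).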
forward(X, mask=None, return_att=False)
Forward pass.
Parameters:
- X (Tensor) – Bag features of shape (batch_size, bag_size, ...).
- mask (Tensor, default: None) – Mask of shape (batch_size, bag_size).
- return_att (bool, default: False) – If True, returns the attention values (before normalization) in addition to Y_pred.
Returns:
- Y_pred (Tensor) – Bag label logits of shape (batch_size,).
- att (Tensor) – Only returned when return_att=True. Attention values (before normalization) of shape (batch_size, bag_size).
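Bags in a batch usually have different sizes, so they are zero-padded to a common `bag_size` and a mask marks which rows are real. The sketch below shows one way to build `X` and `mask`, and how the unnormalized attention values returned with `return_att=True` can be turned into weights that ignore padding. It assumes that a mask value of 1 marks a real instance; the helper names `pad_bags` and `normalize_attention` are hypothetical.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def pad_bags(bags):
    """Stack bags of shape (N_i, D) into X (B, N_max, D) and mask (B, N_max)."""
    lengths = torch.tensor([b.shape[0] for b in bags])
    X = pad_sequence(bags, batch_first=True)  # zero-padded along the bag dimension
    # Assumed convention: 1 = real instance, 0 = padding.
    mask = (torch.arange(X.shape[1])[None, :] < lengths[:, None]).long()
    return X, mask


def normalize_attention(f, mask):
    """Turn unnormalized attention values f (B, N) into weights summing to 1 over real instances."""
    f = f.masked_fill(mask == 0, float("-inf"))  # padded rows get exp(-inf) = 0
    return torch.softmax(f, dim=1)
```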
compute_loss(Y, X, mask=None)
Compute loss given true bag labels.
Parameters:
- Y (Tensor) – Bag labels of shape (batch_size,).
- X (Tensor) – Bag features of shape (batch_size, bag_size, ...).
- mask (Tensor, default: None) – Mask of shape (batch_size, bag_size).
Returns:
- Y_pred (Tensor) – Bag label logits of shape (batch_size,).
- loss_dict (dict) – Dictionary containing the loss value.
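The following sketch mirrors what `compute_loss` returns (bag logits plus a loss dictionary) and shows how such a dictionary might be used in a training step. It is an illustration under assumptions: `compute_loss_sketch` and `training_step` are hypothetical helpers, `model` is any callable mapping `(X, mask)` to logits of shape `(batch_size,)`, and the dictionary key (the criterion's class name) is a guess, so check the keys torchmil actually returns.

```python
import torch
import torch.nn as nn


def compute_loss_sketch(model, Y, X, mask=None, criterion=None):
    """Hypothetical helper mirroring compute_loss: returns (Y_pred, loss_dict)."""
    if criterion is None:
        criterion = nn.BCEWithLogitsLoss()
    Y_pred = model(X, mask)  # bag logits, (batch_size,)
    # BCEWithLogitsLoss expects float targets with the same shape as the logits.
    loss = criterion(Y_pred, Y.float())
    # Assumed key: the criterion's class name.
    return Y_pred, {type(criterion).__name__: loss}


def training_step(model, optimizer, Y, X, mask=None):
    """One optimization step that sums every entry of the loss dictionary."""
    _, loss_dict = compute_loss_sketch(model, Y, X, mask)
    loss = sum(loss_dict.values())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```

As a sanity check, logits of 0 give a predicted probability of 0.5 for every bag, so the binary cross-entropy equals \(\ln 2 \approx 0.693\) regardless of the labels.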
predict(X, mask=None, return_inst_pred=True)
Predict bag and (optionally) instance labels.
Parameters:
- X (Tensor) – Bag features of shape (batch_size, bag_size, ...).
- mask (Tensor, default: None) – Mask of shape (batch_size, bag_size).
- return_inst_pred (bool, default: True) – If True, returns instance label predictions in addition to bag label predictions.
Returns:
- Y_pred (Tensor) – Bag label logits of shape (batch_size,).
- y_inst_pred (Tensor) – If return_inst_pred=True, instance label predictions of shape (batch_size, bag_size).
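Since `predict` returns logits rather than probabilities, a small post-processing step is typically needed. The sketch below converts bag logits to probabilities and hard labels, and ranks instances to localize the most relevant ones. It assumes that `y_inst_pred` holds per-instance scores where larger means more relevant (in attention-based MIL these are commonly the attention values) and that a mask value of 1 marks a real instance; `summarize_predictions` is a hypothetical helper.

```python
import torch


def summarize_predictions(Y_pred, y_inst_pred, mask=None, k=2):
    """Hypothetical post-processing of predict() outputs."""
    probs = torch.sigmoid(Y_pred)  # bag probabilities, (batch_size,)
    labels = (Y_pred > 0).long()  # equivalent to probs > 0.5
    scores = y_inst_pred
    if mask is not None:
        # Exclude padded instances from the ranking.
        scores = scores.masked_fill(mask == 0, float("-inf"))
    top_idx = scores.topk(k, dim=1).indices  # indices of the k highest-scoring instances
    return probs, labels, top_idx
```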