IIBMIL
torchmil.models.IIBMIL
Bases: Module
Integrated Instance-Level and Bag-Level Multiple Instance Learning (IIB-MIL) model, proposed in the paper IIB-MIL: Integrated Instance-Level and Bag-Level Multiple Instances Learning with Label Disambiguation for Pathological Image Analysis.
Given an input bag \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times P}\), the model optionally applies a feature extractor, \(\text{FeatExt}(\cdot)\), to transform the instance features: \(\mathbf{X} = \text{FeatExt}(\mathbf{X}) \in \mathbb{R}^{N \times D}\).
Then, a TransformerEncoder is applied to transform the instance features using context information. Subsequently, the model uses bag-level and instance-level supervision:
Bag-level supervision: The instances are aggregated into a class token using a transformer decoder. A linear layer is then applied to predict the bag label.
Instance-level supervision: Consists of four steps.
- Using an instance classifier, obtain the probability of instance \(i\) belonging to class \(c\), denoted as \(p_{i,c}\).
- The prototype \(\mathbf{p}_{c,t} \in \mathbb{R}^{D}\) of class \(c\) at time \(t\) is updated using a momentum update rule based on the set of instances with the top \(k\) highest probabilities of belonging to class \(c\). Writing \(\mathbf{P}_t = \left[ \mathbf{p}_{1,t}, \ldots, \mathbf{p}_{C,t} \right]^\top \in \mathbb{R}^{C \times D}\), the prototype label \(z_{i}\) of each instance is obtained as \(z_{i} = \text{argmax}_{c} \ \left( \mathbf{P}_t \mathbf{x}_i \right)_c\).
- Compute instance-level soft labels using the prototype labels and a momentum update.
- Compute the instance-level cross-entropy loss using the soft labels and the instance classifier.
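The arithmetic behind the prototype and soft-label steps can be made concrete with a small framework-free sketch. It is not the torchmil implementation: the function names are hypothetical, and two details are assumptions of this sketch rather than statements from the text above, namely that the top-\(k\) instances are averaged before the momentum update and that the soft label is moved toward the one-hot encoding of the prototype label.

```python
def dot(u, v):
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def update_prototype(proto, instances, probs, k, m=0.9):
    """Momentum update of one class prototype from its top-k instances.

    proto:     current prototype, list of D floats
    instances: N instance embeddings, each a list of D floats
    probs:     N probabilities of belonging to this class
    """
    # Indices of the k instances with the highest probability for this class.
    top = sorted(range(len(instances)), key=lambda i: probs[i], reverse=True)[:k]
    # Assumption: the selected instances are summarised by their mean.
    mean = [sum(instances[i][d] for i in top) / len(top) for d in range(len(proto))]
    # Exponential moving average: keep a fraction m of the old prototype.
    return [m * p + (1.0 - m) * x for p, x in zip(proto, mean)]


def prototype_labels(prototypes, instances):
    """z_i = argmax_c <p_c, x_i> for every instance."""
    return [
        max(range(len(prototypes)), key=lambda c: dot(prototypes[c], x))
        for x in instances
    ]


def update_soft_label(soft, z, m=0.9):
    """Assumption: move the soft label toward the one-hot vector of z."""
    return [m * s + (1.0 - m) * (1.0 if c == z else 0.0) for c, s in enumerate(soft)]
```

A larger momentum `m` makes prototypes and soft labels change more slowly, which is the lever that keeps noisy instance predictions from dominating the targets.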
A Note about Prototype Updates. This class does not automatically call update_prototypes during the loss computation (e.g., inside compute_loss). This is a deliberate design choice since updating prototypes at every iteration (per batch) can lead to rapid prototype collapse.
Therefore, you must explicitly determine when to update the prototypes by calling the update_prototypes method manually within your training loop.
Note that torchmil's Trainer does not automatically call update_prototypes during training, so you will need to implement a custom training loop if you want to use this functionality.
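A custom loop along these lines is one way to control how often prototypes are refreshed. This is a minimal sketch, not a torchmil utility: the function name, the `update_every` parameter, and the `(X, Y, mask)` batch layout are hypothetical, and summing the entries of `loss_dict` into a single objective is an assumption about its contents. Only `compute_loss` and `update_prototypes` are used as documented on this page.

```python
def train_with_prototype_updates(model, loader, optimizer,
                                 n_epochs=10, update_every=50, proto_m=0.9):
    """Train `model`, calling update_prototypes once every `update_every` steps.

    `loader` is assumed to yield (X, Y, mask) batches (hypothetical layout).
    Returns the mean training loss of each epoch.
    """
    history = []
    step = 0
    for _ in range(n_epochs):
        model.train()
        epoch_loss, n_batches = 0.0, 0
        for X, Y, mask in loader:
            optimizer.zero_grad()
            # compute_loss returns the bag logits and a dictionary of losses.
            _, loss_dict = model.compute_loss(Y, X, mask)
            # Assumption: all entries of loss_dict are summed into one objective.
            loss = sum(loss_dict.values())
            loss.backward()
            optimizer.step()

            epoch_loss += float(loss)
            n_batches += 1
            step += 1
            # Refresh prototypes only occasionally, not on every batch.
            if step % update_every == 0:
                model.update_prototypes(X, mask, proto_m=proto_m)
        history.append(epoch_loss / max(n_batches, 1))
    return history
```

The update interval is a tuning choice: a small `update_every` approaches the per-batch behaviour that the note above warns about, while a very large one leaves the prototypes stale.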
__init__(in_shape=None, att_dim=256, n_layers_encoder=1, n_layers_decoder=1, use_mlp_encoder=True, use_mlp_decoder=False, n_heads=4, feat_ext=torch.nn.Identity(), criterion=torch.nn.BCEWithLogitsLoss())
Parameters:
-
in_shape(tuple, default:None) –Shape of input data expected by the feature extractor (excluding batch dimension). If not provided, it will be lazily initialized.
-
att_dim(int, default:256) –Attention dimension.
-
n_layers_encoder(int, default:1) –Number of layers in the transformer encoder.
-
n_layers_decoder(int, default:1) –Number of layers in the transformer decoder.
-
use_mlp_encoder(bool, default:True) –If True, uses a multi-layer perceptron (MLP) in the encoder.
-
use_mlp_decoder(bool, default:False) –If True, uses a multi-layer perceptron (MLP) in the decoder.
-
n_heads(int, default:4) –Number of attention heads.
-
feat_ext(Module, default:Identity()) –Feature extractor.
-
criterion(Module, default:BCEWithLogitsLoss()) –Loss function. By default, Binary Cross-Entropy loss from logits.
forward(X, mask=None, return_inst_pred=False, return_X_enc=False)
Forward pass.
Parameters:
-
X(Tensor) –Bag features of shape (batch_size, bag_size, ...).
-
mask(Tensor, default:None) –Mask of shape (batch_size, bag_size).
-
return_inst_pred(bool, default:False) –If True, returns instance label logits in addition to Y_pred.
-
return_X_enc(bool, default:False) –If True, returns instance embeddings in addition to Y_pred.
Returns:
-
Y_pred(Tensor) –Bag label logits of shape (batch_size,).
-
y_inst_pred(Tensor) –Only returned when return_inst_pred=True. Instance label logits of shape (batch_size, bag_size).
-
X_enc(Tensor) –Only returned when return_X_enc=True. Instance embeddings of shape (batch_size, bag_size, att_dim).
compute_loss(Y, X, mask=None)
Compute loss given true bag labels.
Parameters:
-
Y(Tensor) –Bag labels of shape (batch_size,).
-
X(Tensor) –Bag features of shape (batch_size, bag_size, ...).
-
mask(Tensor, default:None) –Mask of shape (batch_size, bag_size).
Returns:
-
Y_pred(Tensor) –Bag label logits of shape (batch_size,).
-
loss_dict(dict) –Dictionary containing the loss value.
predict(X, mask=None, return_inst_pred=True)
Predict bag and (optionally) instance labels.
Parameters:
-
X(Tensor) –Bag features of shape (batch_size, bag_size, ...).
-
mask(Tensor, default:None) –Mask of shape (batch_size, bag_size).
-
return_inst_pred(bool, default:True) –If True, returns instance label predictions in addition to bag label predictions.
Returns:
-
Y_pred(Tensor) –Bag label logits of shape (batch_size,).
-
y_inst_pred(Tensor) –Only returned when return_inst_pred=True. Instance label predictions of shape (batch_size, bag_size).
update_prototypes(X, mask=None, proto_m=0.9)
Update prototypes.
Parameters:
-
X(Tensor) –Bag features of shape (batch_size, bag_size, ...).
-
mask(Tensor, default:None) –Mask of shape (batch_size, bag_size).
-
proto_m(float, default:0.9) –Momentum for updating prototypes.
Returns:
-
None–None