VAEABMIL
torchmil.models.VAEABMIL
Bases: MILModel
Variational Autoencoder - Attention-based Multiple Instance Learning (VAEABMIL) model, proposed in the paper Using Variational Autoencoders for Out of Distribution Detection in Histological Multiple Instance Learning.
The model trains a Variational Autoencoder (VAE) on the instance features jointly with the attention-based MIL classifier. The learned latent representation is used both as the input to attention-based multiple instance learning and to detect out-of-distribution instances and bags.
Given an input bag \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times P}\), the model uses the VAE encoder to obtain an approximate posterior distribution \(q(\mathbf{z} \mid \mathbf{x})\) for each instance \(\mathbf{x}\) in the bag. The instance features are then replaced by latent samples, \(\mathbf{X} = [\mathbf{z}_1, \ldots, \mathbf{z}_N]^\top \in \mathbb{R}^{N \times D}\) with \(\mathbf{z}_i \sim q(\mathbf{z}_i \mid \mathbf{x}_i)\).
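The sampling step above can be sketched in plain Python, assuming (as in a standard VAE, but not stated here) that the encoder outputs a mean and a log-variance per instance, so that each instance's approximate posterior is a diagonal Gaussian. Drawing \(\mathbf{z}_i = \boldsymbol{\mu}_i + \boldsymbol{\sigma}_i \odot \boldsymbol{\epsilon}\) with \(\boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})\) (the reparameterization trick) keeps each sample differentiable with respect to the encoder outputs. The helper sample_latents and the example encoder outputs are hypothetical and not part of torchmil; its n_samples argument plays a role analogous to the Monte Carlo n_samples of forward.

```python
import math
import random


def sample_latents(mus, log_vars, n_samples=1, rng=None):
    """Draw z_i ~ N(mu_i, diag(exp(log_var_i))) for every instance of a bag
    with the reparameterization trick z = mu + sigma * eps, eps ~ N(0, I).

    mus, log_vars: N lists of length D (one per instance).
    Returns n_samples bags, each an N x D nested list.
    """
    rng = rng or random.Random(0)
    samples = []
    for _ in range(n_samples):
        bag = []
        for mu, log_var in zip(mus, log_vars):
            sigma = [math.exp(0.5 * lv) for lv in log_var]
            bag.append([m + s * rng.gauss(0.0, 1.0) for m, s in zip(mu, sigma)])
        samples.append(bag)
    return samples


# Hypothetical encoder outputs for a bag with N = 3 instances and D = 2.
mus = [[0.0, 1.0], [2.0, -1.0], [0.5, 0.5]]
log_vars = [[-20.0, -20.0]] * 3  # tiny variance, so samples almost equal the means
Z = sample_latents(mus, log_vars, n_samples=4)
print(len(Z), len(Z[0]), len(Z[0][0]))  # 4 3 2
```

In tensor code the same step is a single expression over the whole bag, which is why the reparameterized form is preferred over sampling directly from the distribution.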
Lastly, it aggregates the instance features into a bag representation \(\mathbf{z} \in \mathbb{R}^{D}\) using attention-based pooling,
\[ \mathbf{z}, \mathbf{f} = \operatorname{AttentionPool}(\mathbf{X}), \qquad \mathbf{z} = \mathbf{X}^\top \operatorname{Softmax}(\mathbf{f}), \]
where \(\mathbf{f} \in \mathbb{R}^{N}\) are the attention values. See AttentionPool for more details on the attention-based pooling. The bag representation \(\mathbf{z}\) is then fed into a classifier (one linear layer) to predict the bag label.
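A minimal sketch of the pooling and classification step, assuming the usual ABMIL form of attention: for the n-th latent sample \(\mathbf{z}_n\), \(f_n = \mathbf{w}^\top \tanh(\mathbf{V}\mathbf{z}_n)\), optionally gated elementwise by \(\operatorname{sigmoid}(\mathbf{U}\mathbf{z}_n)\), followed by a softmax over the instances. This corresponds roughly to att_act='tanh' with gated=False or gated=True. The weights V, U, w, c and b are hypothetical placeholders, and torchmil's AttentionPool may differ in details such as bias terms.

```python
import math


def matvec(M, x):
    # Plain matrix-vector product on nested lists.
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]


def softmax(f):
    mx = max(f)
    e = [math.exp(v - mx) for v in f]
    total = sum(e)
    return [v / total for v in e]


def attention_pool(X, V, w, U=None):
    """X: N x D latent samples; V, U: att_dim x D; w: att_dim.
    Returns (z, f): bag representation of length D and raw attention of length N."""
    f = []
    for x in X:
        h = [math.tanh(v) for v in matvec(V, x)]
        if U is not None:  # gated variant: tanh(Vx) * sigmoid(Ux)
            gate = [1.0 / (1.0 + math.exp(-u)) for u in matvec(U, x)]
            h = [a * g for a, g in zip(h, gate)]
        f.append(sum(wi * hi for wi, hi in zip(w, h)))
    s = softmax(f)  # normalized attention weights, summing to 1
    z = [sum(s[n] * X[n][d] for n in range(len(X))) for d in range(len(X[0]))]
    return z, f


def bag_logit(z, c, b):
    # Single linear layer mapping the bag representation to one logit.
    return sum(ci * zi for ci, zi in zip(c, z)) + b


X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # N = 3 latent samples, D = 2
V = [[1.0, 0.0], [0.0, 1.0]]              # att_dim = 2
z, f = attention_pool(X, V, w=[1.0, 1.0])
print(max(range(3), key=lambda n: f[n]))  # 2: the instance [1, 1] scores highest
print(bag_logit(z, c=[0.5, -0.5], b=0.1))
```

Because the weights come from a softmax, \(\mathbf{z}\) is always a convex combination of the latent samples; with all-zero w every instance receives the same weight and \(\mathbf{z}\) reduces to the mean of the rows of \(\mathbf{X}\).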
__init__(feat_ext, in_shape=None, att_dim=128, att_act='tanh', gated=False, criterion=torch.nn.BCEWithLogitsLoss(), vae_loss_reduction='mean')
Parameters:
- feat_ext (VariationalAutoEncoderMIL) – Variational autoencoder used as the feature extractor.
- in_shape (tuple, default: None) – Shape of the input data expected by the feature extractor (excluding the batch dimension). If not provided, it is lazily initialized.
- att_dim (int, default: 128) – Attention dimension.
- att_act (str, default: 'tanh') – Activation function for attention. Possible values: 'tanh', 'relu', 'gelu'.
- gated (bool, default: False) – If True, use gated attention in the attention pooling.
- criterion (Module, default: BCEWithLogitsLoss()) – Loss function. By default, binary cross-entropy loss computed from logits.
- vae_loss_reduction (str, default: 'mean') – Reduction method for the VAE loss. Possible values: 'sum', 'mean', 'none'.
forward(X, mask=None, return_att=False, return_latent_repr=False, n_samples=1)
Forward pass.
Parameters:
- X (Tensor) – Bag features of shape (batch_size, bag_size, ...).
- mask (Tensor, default: None) – Mask of shape (batch_size, bag_size).
- return_att (bool, default: False) – If True, returns the attention values (before normalization) in addition to Y_pred.
- return_latent_repr (bool, default: False) – If True, returns the latent representation in addition to Y_pred. (Currently not implemented.)
- n_samples (int, default: 1) – Number of Monte Carlo samples to use for the VAE.
Returns:
- Y_pred (Tensor) – Bag label logits of shape (batch_size,).
- att (Tensor) – Only returned when return_att=True. Attention values (before normalization) of shape (batch_size, bag_size).
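Because att holds attention values before normalization, padded instances have to be excluded when the values are turned into weights. A minimal sketch of masked normalization, assuming mask entries are 1 for real instances and 0 for padding (the convention is not stated above); masked_softmax is a hypothetical helper, not a torchmil function.

```python
import math


def masked_softmax(f, mask):
    """Normalize attention values f over the instances where mask is 1.
    Padded positions (mask == 0) receive weight 0."""
    kept = [v for v, m in zip(f, mask) if m]
    mx = max(kept)  # subtract the max over real instances for stability
    e = [math.exp(v - mx) if m else 0.0 for v, m in zip(f, mask)]
    total = sum(e)
    return [v / total for v in e]


# A batch of two bags padded to bag_size = 4; the second bag has 2 real instances.
f_batch = [[0.1, 0.4, -0.2, 0.3], [1.0, 1.0, 5.0, 5.0]]
mask = [[1, 1, 1, 1], [1, 1, 0, 0]]
weights = [masked_softmax(f, m) for f, m in zip(f_batch, mask)]
print(weights[1])  # [0.5, 0.5, 0.0, 0.0]
```

In tensor code the equivalent approach is to set the padded attention values to \(-\infty\) before applying a softmax along the bag dimension; note how the large raw values at the padded positions of the second bag have no effect on the result.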
compute_loss(Y, X, mask=None, n_samples=1)
Compute loss given true bag labels.
Parameters:
- Y (Tensor) – Bag labels of shape (batch_size,).
- X (Tensor) – Bag features of shape (batch_size, bag_size, ...).
- mask (Tensor, default: None) – Mask of shape (batch_size, bag_size).
- n_samples (int, default: 1) – Number of Monte Carlo samples to use for the VAE loss computation.
Returns:
- Y_pred (Tensor) – Bag label logits of shape (batch_size,).
- loss_dict (dict) – Dictionary containing the loss values: the main criterion loss and the VAE loss terms VaeELL and VaeKL.
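The loss terms can be sketched for a single bag under several assumptions that the description above does not confirm: the encoder posterior is a diagonal Gaussian, the decoder is a unit-variance Gaussian, VaeELL is stored as a negative log-likelihood so that every term is minimized, and vae_loss_reduction is applied over the instances of the bag. The dictionary key "criterion", the helper names, and the unweighted sum used for the total are hypothetical.

```python
import math


def bce_with_logits(logit, y):
    # Numerically stable binary cross-entropy computed from a single logit.
    return max(logit, 0.0) - logit * y + math.log1p(math.exp(-abs(logit)))


def kl_diag_gaussian(mu, log_var):
    # Closed-form KL( N(mu, diag(exp(log_var))) || N(0, I) ).
    return 0.5 * sum(m * m + math.exp(lv) - 1.0 - lv for m, lv in zip(mu, log_var))


def neg_gaussian_ll(x, x_hat):
    # -log N(x | x_hat, I): assumes a unit-variance Gaussian decoder.
    return 0.5 * sum((a - b) ** 2 + math.log(2 * math.pi) for a, b in zip(x, x_hat))


def reduce_over_instances(values, reduction):
    if reduction == "sum":
        return sum(values)
    if reduction == "mean":
        return sum(values) / len(values)
    if reduction == "none":
        return values
    raise ValueError(f"unknown reduction: {reduction}")


def vae_abmil_loss(logit, y, X, X_hat, mus, log_vars, reduction="mean"):
    """Hypothetical loss dictionary for one bag of N instances."""
    ell = [neg_gaussian_ll(x, x_hat) for x, x_hat in zip(X, X_hat)]
    kl = [kl_diag_gaussian(mu, lv) for mu, lv in zip(mus, log_vars)]
    return {
        "criterion": bce_with_logits(logit, y),
        "VaeELL": reduce_over_instances(ell, reduction),
        "VaeKL": reduce_over_instances(kl, reduction),
    }


# Perfect reconstructions and a posterior equal to the prior (KL = 0).
X = [[1.0, 2.0], [0.0, 0.5]]
loss = vae_abmil_loss(0.0, 1.0, X, X, mus=[[0.0, 0.0]] * 2, log_vars=[[0.0, 0.0]] * 2)
total = sum(loss.values())  # unweighted sum: an assumption, not torchmil's rule
```

With reduction="none" the VAE terms stay per instance, which is useful when the reconstruction error of each instance is itself the quantity of interest, for example as an out-of-distribution score.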
predict(X, mask=None, return_inst_pred=True)
Predict bag and (optionally) instance labels.
Parameters:
- X (Tensor) – Bag features of shape (batch_size, bag_size, ...).
- mask (Tensor, default: None) – Mask of shape (batch_size, bag_size).
- return_inst_pred (bool, default: True) – If True, returns instance label predictions (attention values) in addition to bag label predictions.
Returns:
- Y_pred (Tensor) – Bag label logits of shape (batch_size,).
- y_inst_pred (Tensor) – Only returned when return_inst_pred=True. Instance label predictions (attention values) of shape (batch_size, bag_size).
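Since Y_pred holds logits (the default criterion is BCEWithLogitsLoss), a sigmoid turns them into bag probabilities, and the attention values returned as instance predictions can be used to rank instances within a bag. A small post-processing sketch for one bag; the 0.5 threshold, top_k, the mask convention (1 for real instances) and summarize_prediction are illustrative choices, not part of torchmil.

```python
import math


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def summarize_prediction(y_logit, att, mask=None, threshold=0.5, top_k=2):
    """Turn one bag's outputs into a probability, a hard label, and the
    indices of the most-attended (non-padded) instances."""
    prob = sigmoid(y_logit)
    idx = [i for i in range(len(att)) if mask is None or mask[i]]
    top = sorted(idx, key=lambda i: att[i], reverse=True)[:top_k]
    return prob, int(prob >= threshold), top


prob, label, top = summarize_prediction(2.0, [0.1, 3.0, -1.0, 9.9], mask=[1, 1, 1, 0])
print(round(prob, 3), label, top)  # 0.881 1 [1, 0]
```

Ranking depends only on the order of the attention values, so the unnormalized values can be used directly; the padded instance with the largest raw value is ignored here.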
log_marginal_likelihood_importance_sampling(X, mask=None, n_samples=1)
Estimate the marginal log-likelihood of the input bag using importance sampling.
Parameters:
- X (Tensor) – Bag features of shape (batch_size, bag_size, ...).
- mask (Tensor, default: None) – Mask of shape (batch_size, bag_size).
- n_samples (int, default: 1) – Number of importance samples to use.
Returns:
- log_likelihood (Tensor) – Estimated marginal log-likelihood of shape (batch_size,).
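Importance sampling estimates the marginal likelihood with the encoder posterior \(q\) as proposal: \(\log p(\mathbf{x}) \approx \log \frac{1}{K}\sum_{k=1}^{K} \frac{p(\mathbf{x} \mid \mathbf{z}_k)\, p(\mathbf{z}_k)}{q(\mathbf{z}_k \mid \mathbf{x})}\) with \(\mathbf{z}_k \sim q(\mathbf{z} \mid \mathbf{x})\) and \(K\) = n_samples, evaluated in log space with log-sum-exp. The sketch below applies this to a toy one-dimensional model, \(z \sim \mathcal{N}(0, 1)\) and \(x \mid z \sim \mathcal{N}(z, 1)\), whose exact marginal \(\mathcal{N}(0, 2)\) makes the estimate checkable. It assumes the bag log-likelihood is the sum of instance log-likelihoods; torchmil would use its VAE's encoder and decoder instead of these closed-form densities, and all helper names are hypothetical.

```python
import math
import random


def log_normal(x, mean, var):
    # Log-density of a univariate Gaussian N(mean, var) at x.
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)


def logsumexp(values):
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def log_marginal_is(x, q_mean, q_var, n_samples, rng):
    """Importance-sampling estimate of log p(x) for the toy model
    z ~ N(0, 1), x | z ~ N(z, 1), with proposal q(z | x) = N(q_mean, q_var)."""
    log_w = []
    for _ in range(n_samples):
        z = rng.gauss(q_mean, math.sqrt(q_var))
        log_joint = log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0)
        log_w.append(log_joint - log_normal(z, q_mean, q_var))
    return logsumexp(log_w) - math.log(n_samples)


def bag_log_marginal(bag, n_samples=1, rng=None):
    # Assumption: instances are modeled independently, so the bag
    # log-likelihood is the sum of the instance log-likelihoods.
    rng = rng or random.Random(0)
    # The exact posterior of the toy model, N(x / 2, 1 / 2), is used as proposal.
    return sum(log_marginal_is(x, x / 2, 0.5, n_samples, rng) for x in bag)


def exact_log_marginal(x):
    return log_normal(x, 0.0, 2.0)  # closed form: p(x) = N(0, 2)


bag = [0.3, -1.2, 2.0]
estimate = bag_log_marginal(bag)
print(abs(estimate - sum(exact_log_marginal(x) for x in bag)) < 1e-9)  # True
```

With the exact posterior as proposal every importance weight equals \(p(x)\), so even a single sample is exact; with an imperfect encoder posterior the estimate has nonzero variance, and increasing n_samples tightens it.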