VAEABMIL

torchmil.models.VAEABMIL

Bases: MILModel

Variational Autoencoder - Attention-based Multiple Instance Learning (VAEABMIL) model, proposed in the paper "Using Variational Autoencoders for Out of Distribution Detection in Histological Multiple Instance Learning".

The model trains a Variational Autoencoder (VAE) on instance features jointly with the MIL classifier. The learned latent representation serves two purposes: it is the input to attention-based multiple instance learning, and it is used to detect out-of-distribution instances and bags.

Given an input bag \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times P}\), the model uses the VAE to obtain an approximate posterior distribution \(p(\mathbf{z} \mid \mathbf{x})\) for each instance \(\mathbf{x}\) in the bag. The bag is then replaced by its latent samples, \(\mathbf{X} = \left[ \mathbf{z}_1, \ldots, \mathbf{z}_N \right]^\top \in \mathbb{R}^{N \times D}\), with \(\mathbf{z}_i \sim p(\mathbf{z}_i \mid \mathbf{x}_i)\).

Lastly, it aggregates the latent instance features into a bag representation \(\mathbf{z} \in \mathbb{R}^{D}\) using attention-based pooling,

\[ \mathbf{z}, \mathbf{f} = \operatorname{AttentionPool}(\mathbf{X}), \]

where \(\mathbf{f} \in \mathbb{R}^{N}\) are the attention values. See AttentionPool for more details on the attention-based pooling. The bag representation \(\mathbf{z}\) is then fed into a classifier (one linear layer) to predict the bag label.
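The pipeline described above can be sketched in plain PyTorch. This is a minimal illustration under stated assumptions, not torchmil's implementation: the `encoder`, `att_hidden`, `att_out`, and `classifier` layers are hypothetical stand-ins, the encoder is assumed to output a diagonal Gaussian (mean and log-variance), and a single unbatched bag is used.

```python
import torch
from torch import nn

torch.manual_seed(0)

N, P, D, att_dim = 5, 16, 8, 4  # bag size, input dim, latent dim, attention dim

# Hypothetical stand-in for the VAE encoder: outputs mean and log-variance.
encoder = nn.Linear(P, 2 * D)
# Hypothetical attention scorer: f_i = w^T tanh(V z_i).
att_hidden = nn.Linear(D, att_dim)
att_out = nn.Linear(att_dim, 1, bias=False)
# One linear layer as the bag classifier.
classifier = nn.Linear(D, 1)

X = torch.randn(N, P)                                     # one bag of N instances
mu, logvar = encoder(X).chunk(2, dim=-1)                  # each of shape (N, D)
Z = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)   # reparameterized sample z_i

f = att_out(torch.tanh(att_hidden(Z))).squeeze(-1)        # (N,) unnormalized attention values
s = torch.softmax(f, dim=0)                               # (N,) normalized attention weights
z_bag = (s.unsqueeze(-1) * Z).sum(dim=0)                  # (D,) bag representation
logit = classifier(z_bag).squeeze(-1)                     # scalar bag logit
```

The sampling step uses the reparameterization trick so that gradients can flow from the classifier back into the encoder, which is one way the two parts could be trained jointly.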

__init__(feat_ext, in_shape=None, att_dim=128, att_act='tanh', gated=False, criterion=torch.nn.BCEWithLogitsLoss(), vae_loss_reduction='mean')

Parameters:

  • feat_ext (VariationalAutoEncoderMIL) –

    Variational Autoencoder used as feature extractor.

  • in_shape (tuple, default: None ) –

    Shape of input data expected by the feature extractor (excluding batch dimension). If not provided, it will be lazily initialized.

  • att_dim (int, default: 128 ) –

    Attention dimension.

  • att_act (str, default: 'tanh' ) –

    Activation function for attention. Possible values: 'tanh', 'relu', 'gelu'.

  • gated (bool, default: False ) –

    If True, use gated attention in the attention pooling.

  • criterion (Module, default: BCEWithLogitsLoss() ) –

    Loss function. By default, Binary Cross-Entropy loss from logits.

  • vae_loss_reduction (str, default: 'mean' ) –

    Reduction method for VAE loss. Possible values: 'sum', 'mean', 'none'.
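
To make the roles of `att_dim`, `att_act`, and `gated` concrete, the following sketch shows one common way such options are wired into an attention scorer, following the standard (gated) attention formulation used in attention-based MIL. The class `AttentionScores` is hypothetical and is not torchmil's `AttentionPool`; the actual layer may differ in details such as bias terms.

```python
import torch
from torch import nn


class AttentionScores(nn.Module):
    """Illustrative attention scorer (hypothetical, not torchmil's AttentionPool)."""

    def __init__(self, in_dim, att_dim=128, att_act="tanh", gated=False):
        super().__init__()
        acts = {"tanh": torch.tanh, "relu": torch.relu, "gelu": nn.functional.gelu}
        self.act = acts[att_act]
        self.fc1 = nn.Linear(in_dim, att_dim)
        # The gate branch exists only when gated attention is requested.
        self.gate = nn.Linear(in_dim, att_dim) if gated else None
        self.fc2 = nn.Linear(att_dim, 1, bias=False)

    def forward(self, Z):
        # Z: (batch_size, bag_size, in_dim)
        H = self.act(self.fc1(Z))
        if self.gate is not None:
            # Gated attention: element-wise product with a sigmoid gate.
            H = H * torch.sigmoid(self.gate(Z))
        return self.fc2(H).squeeze(-1)  # (batch_size, bag_size), unnormalized


scores = AttentionScores(in_dim=8, att_dim=4, att_act="gelu", gated=True)(
    torch.randn(2, 7, 8)
)
```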

forward(X, mask=None, return_att=False, return_latent_repr=False, n_samples=1)

Forward pass.

Parameters:

  • X (Tensor) –

    Bag features of shape (batch_size, bag_size, ...).

  • mask (Tensor, default: None ) –

    Mask of shape (batch_size, bag_size).

  • return_att (bool, default: False ) –

    If True, returns attention values (before normalization) in addition to Y_pred.

  • return_latent_repr (bool, default: False ) –

    If True, returns latent representation in addition to Y_pred. (Currently not implemented)

  • n_samples (int, default: 1 ) –

    Number of Monte Carlo samples to use for the VAE.

Returns:

  • Y_pred ( Tensor ) –

    Bag label logits of shape (batch_size,).

  • att ( Tensor ) –

    Only returned when return_att=True. Attention values (before normalization) of shape (batch_size, bag_size).
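
Because the returned attention values are unnormalized, a caller who wants per-instance weights that sum to one has to normalize them. The sketch below shows one way to do this with a masked softmax. The helper `normalize_attention` is hypothetical, and it assumes that a mask value of 1 marks a real instance and 0 marks padding, which the text above does not state.

```python
import torch


def normalize_attention(att, mask=None):
    """Softmax over the bag dimension, ignoring masked-out (padded) instances."""
    if mask is not None:
        # Assumption: mask == 1 for real instances, 0 for padding.
        att = att.masked_fill(~mask.bool(), float("-inf"))
    return torch.softmax(att, dim=1)


att = torch.tensor([[0.5, 1.5, -2.0], [0.0, 0.0, 3.0]])  # (batch_size, bag_size)
mask = torch.tensor([[1, 1, 0], [1, 1, 1]])
weights = normalize_attention(att, mask)
```

Note that a bag whose instances are all masked out would produce NaN weights with this approach, so such bags need separate handling.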

compute_loss(Y, X, mask=None, n_samples=1)

Compute loss given true bag labels.

Parameters:

  • Y (Tensor) –

    Bag labels of shape (batch_size,).

  • X (Tensor) –

    Bag features of shape (batch_size, bag_size, ...).

  • mask (Tensor, default: None ) –

    Mask of shape (batch_size, bag_size).

  • n_samples (int, default: 1 ) –

    Number of Monte Carlo samples to use for the VAE loss computation.

Returns:

  • Y_pred ( Tensor ) –

    Bag label logits of shape (batch_size,).

  • loss_dict ( dict ) –

    Dictionary containing the loss values. Includes the main criterion loss and VAE losses (VaeELL and VaeKL).
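
As an illustration of how such a dictionary might be consumed, the sketch below computes the closed-form KL term for a diagonal Gaussian posterior against a standard normal prior and sums the dictionary entries into a single training objective. The key `"BCEWithLogitsLoss"`, the placeholder values, and the assumption that every entry is already a quantity to minimize are illustrative only; the actual keys, signs, and weighting should be checked against the library.

```python
import torch


def gaussian_kl(mu, logvar):
    """KL( N(mu, diag(exp(logvar))) || N(0, I) ), summed over the latent dimension."""
    return 0.5 * (mu.pow(2) + logvar.exp() - 1.0 - logvar).sum(dim=-1)


# A posterior equal to the prior gives zero KL for every instance.
mu = torch.zeros(2, 5, 8)       # (batch_size, bag_size, latent_dim)
logvar = torch.zeros(2, 5, 8)
kl = gaussian_kl(mu, logvar)    # (batch_size, bag_size)

# Hypothetical loss dictionary with placeholder values.
loss_dict = {
    "BCEWithLogitsLoss": torch.tensor(0.7),
    "VaeELL": torch.tensor(1.2),
    "VaeKL": kl.mean(),
}
total_loss = sum(loss_dict.values())  # single scalar to backpropagate
```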

predict(X, mask=None, return_inst_pred=True)

Predict bag and (optionally) instance labels.

Parameters:

  • X (Tensor) –

    Bag features of shape (batch_size, bag_size, ...).

  • mask (Tensor, default: None ) –

    Mask of shape (batch_size, bag_size).

  • return_inst_pred (bool, default: True ) –

    If True, returns instance label predictions (attention values) in addition to bag label predictions.

Returns:

  • Y_pred ( Tensor ) –

    Bag label logits of shape (batch_size,).

  • y_inst_pred ( Tensor ) –

    If return_inst_pred=True, returns instance label predictions (attention values) of shape (batch_size, bag_size).
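
Since the bag predictions are logits, they usually need to be converted before being reported as probabilities or hard labels. A minimal sketch, assuming binary classification (consistent with the default BCEWithLogitsLoss criterion) and a decision threshold of 0.5, which is a choice and not something the library prescribes:

```python
import torch

Y_logits = torch.tensor([-2.0, 0.0, 3.0])  # example bag logits, shape (batch_size,)
Y_prob = torch.sigmoid(Y_logits)           # probabilities in (0, 1)
Y_label = (Y_prob > 0.5).long()            # hard labels with a 0.5 threshold
```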

log_marginal_likelihood_importance_sampling(X, mask=None, n_samples=1)

Estimate the marginal log-likelihood of the input bag using importance sampling.

Parameters:

  • X (Tensor) –

    Bag features of shape (batch_size, bag_size, ...).

  • mask (Tensor, default: None ) –

    Mask of shape (batch_size, bag_size).

  • n_samples (int, default: 1 ) –

    Number of importance samples to use.

Returns:

  • log_likelihood ( Tensor ) –

    Estimated marginal log-likelihood of shape (batch_size,).
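
The general importance sampling estimator behind this kind of method is \(\log p(\mathbf{x}) \approx \log \frac{1}{K} \sum_{k=1}^{K} \frac{p(\mathbf{x} \mid \mathbf{z}_k)\, p(\mathbf{z}_k)}{q(\mathbf{z}_k \mid \mathbf{x})}\) with \(\mathbf{z}_k \sim q(\mathbf{z} \mid \mathbf{x})\). The sketch below applies it to a toy one-dimensional Gaussian model where the exact answer is known. The function `log_marginal_is` and the toy model are hypothetical; how the library aggregates instances into a per-bag value is not shown here.

```python
import math

import torch
from torch.distributions import Normal


def log_marginal_is(x, q, log_joint, n_samples):
    """Importance sampling estimate of log p(x) using proposal q(z | x)."""
    z = q.sample((n_samples,))                # (K, ...) samples from the proposal
    log_w = log_joint(x, z) - q.log_prob(z)   # log importance weights
    # Log of the average weight, computed stably with logsumexp.
    return torch.logsumexp(log_w, dim=0) - math.log(n_samples)


# Toy model: z ~ N(0, 1), x | z ~ N(z, 1), hence x ~ N(0, 2)
# and the exact posterior is z | x ~ N(x / 2, 1 / 2).
def log_joint(x, z):
    return Normal(0.0, 1.0).log_prob(z) + Normal(z, 1.0).log_prob(x)


torch.manual_seed(0)
x = torch.tensor([0.3, -1.2])
q = Normal(x / 2, math.sqrt(0.5))             # proposal equal to the exact posterior
estimate = log_marginal_is(x, q, log_joint, n_samples=16)
exact = Normal(0.0, math.sqrt(2.0)).log_prob(x)
```

In this toy case the proposal equals the true posterior, so every importance weight equals \(p(\mathbf{x})\) and the estimate matches the exact value up to floating-point error. With an approximate posterior, as in a trained VAE, the estimate is a stochastic lower bound in expectation that tightens as `n_samples` grows.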