Variational Autoencoder (VAE)

torchmil.nn.VariationalAutoEncoder

Bases: Module

Variational Autoencoder (VAE) model for learning latent representations.

The VAE learns a latent representation \(\mathbf{z}\) of input data \(\mathbf{x}\) by maximizing the Evidence Lower Bound (ELBO):

\[ \mathcal{L}(\theta, \phi; \mathbf{x}) = \mathbb{E}_{q_\phi(\mathbf{z}|\mathbf{x})} [\log p_\theta(\mathbf{x}|\mathbf{z})] - \text{KL}(q_\phi(\mathbf{z}|\mathbf{x}) \| p(\mathbf{z})) \]

where \(q_\phi(\mathbf{z}|\mathbf{x})\) is the encoder (posterior) and \(p_\theta(\mathbf{x}|\mathbf{z})\) is the decoder (likelihood). Both the encoder and decoder are implemented as MLPs.

__init__(input_shape=(512,), layer_sizes=[128, 64], activations=['relu', 'None'], covar_mode='single', jitter=1e-07)

Parameters:

  • input_shape (tuple[int], default: (512,) ) –

    Shape of input data (excluding batch dimension).

  • layer_sizes (list[int], default: [128, 64] ) –

    List of hidden layer sizes for the encoder (decoder mirrors this).

  • activations (list[str], default: ['relu', 'None'] ) –

    List of activation functions for each layer. Must have same length as layer_sizes.

  • covar_mode (str, default: 'single' ) –

    Covariance mode for the variational distributions. Options: 'single', 'diagonal'.

  • jitter (float, default: 1e-07 ) –

    Small value added to log_std for numerical stability.
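As a rough illustration of how `layer_sizes` and `activations` fit together, the sketch below checks the length constraint and derives the layer widths of a mirrored decoder. The helper `mlp_widths` is hypothetical, and treating the last entry of `layer_sizes` as the latent dimension is an assumption, not something the documentation above states.

```python
def mlp_widths(input_dim, layer_sizes, activations):
    """Return (encoder_widths, decoder_widths) for a mirrored MLP pair.

    Hypothetical helper: it only illustrates the documented constraints,
    it is not part of torchmil.
    """
    if len(layer_sizes) != len(activations):
        raise ValueError("activations must have the same length as layer_sizes")
    encoder = [input_dim] + list(layer_sizes)   # e.g. 512 -> 128 -> 64
    decoder = encoder[::-1]                     # mirrored: 64 -> 128 -> 512
    return encoder, decoder


enc, dec = mlp_widths(512, [128, 64], ["relu", "None"])
print(enc)  # [512, 128, 64]
print(dec)  # [64, 128, 512]
```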

get_reparameterized_samples(mean, log_std, n_samples=1)

Generate reparameterized samples using the reparameterization trick.

Parameters:

  • mean (Tensor) –

    Mean of the distribution of shape (batch_size, latent_dim).

  • log_std (Tensor) –

    Log standard deviation of shape (batch_size, d_var_enc).

  • n_samples (int, default: 1 ) –

    Number of samples to generate.

Returns:

  • Tensor

    Reparameterized samples of shape (batch_size, n_samples, latent_dim).
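The reparameterization trick writes a Gaussian sample as a deterministic function of the distribution parameters and independent noise, \(\mathbf{z} = \mu + \exp(\log \sigma) \odot \epsilon\) with \(\epsilon \sim \mathcal{N}(0, \mathbf{I})\). The following is a minimal standard-library sketch of that idea, not the library's implementation. The function name `reparameterized_samples` is hypothetical, and broadcasting a width-1 `log_std` over all latent dimensions is an assumed reading of `covar_mode='single'`.

```python
import math
import random


def reparameterized_samples(mean, log_std, n_samples=1, rng=None):
    """Draw z = mean + exp(log_std) * eps with eps ~ N(0, 1).

    mean:    list of rows, shape (batch_size, latent_dim)
    log_std: list of rows, shape (batch_size, latent_dim) or (batch_size, 1);
             a width-1 row is broadcast over all latent dimensions (assumption).
    Returns a nested list of shape (batch_size, n_samples, latent_dim).
    """
    rng = rng or random.Random()
    out = []
    for mu_row, ls_row in zip(mean, log_std):
        if len(ls_row) == 1:
            ls_row = ls_row * len(mu_row)  # broadcast the shared log-std
        draws = []
        for _ in range(n_samples):
            draws.append([
                mu + math.exp(ls) * rng.gauss(0.0, 1.0)
                for mu, ls in zip(mu_row, ls_row)
            ])
        out.append(draws)
    return out


mean = [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]   # (batch_size=2, latent_dim=3)
log_std = [[-1.0], [0.5]]                    # (batch_size=2, 1)
z = reparameterized_samples(mean, log_std, n_samples=4, rng=random.Random(0))
print(len(z), len(z[0]), len(z[0][0]))       # 2 4 3
```

Because the noise is drawn independently of `mean` and `log_std`, gradients can flow through both parameters when the same expression is written with differentiable tensors.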

get_raw_output_enc(X)

Compute the mean and log standard deviation of the posterior distribution \(q(\mathbf{z}\mid \mathbf{x})\).

The posterior distribution is parameterized as:

\(q(\mathbf{z} \mid \mathbf{x}) = \mathcal N(\mathbf{z} \mid \mu(\mathbf{x}), \sigma^2(\mathbf{x}) \mathbf{I})\)

Parameters:

  • X (Tensor) –

    Input data of shape (batch_size, input_dim).

Returns:

  • mean ( Tensor ) –

    Mean vector of shape (batch_size, latent_dim).

  • log_std ( Tensor ) –

    Log standard deviation of shape (batch_size, d_var_enc).

get_raw_output_dec(samples)

Compute the mean and log standard deviation of the likelihood distribution \(p(\mathbf{x}|\mathbf{z})\).

The likelihood distribution is parameterized as:

\(p(\mathbf{x} \mid \mathbf{z}) = \mathcal N (\mathbf{x} \mid \mu(\mathbf{z}), \sigma^2(\mathbf{z}) \mathbf{I})\)

Parameters:

  • samples (Tensor) –

    Samples from the posterior distribution of shape (batch_size, latent_dim).

Returns:

  • mean ( Tensor ) –

    Mean of the likelihood of shape (batch_size, input_dim).

  • log_std ( Tensor ) –

    Log standard deviation of shape (batch_size, d_var_dec).

forward(X, n_samples=1, return_mean_logstd=False)

Forward pass through the VAE encoder.

Note: This method only implements encoding, since the latent variables are used for downstream tasks.

Parameters:

  • X (Tensor) –

    Input data of shape (batch_size, ...).

  • n_samples (int, default: 1 ) –

    Number of Monte Carlo samples to generate from the posterior.

  • return_mean_logstd (bool, default: False ) –

    If True, also returns the posterior mean and log standard deviation.

Returns:

  • posterior_samples ( Tensor | tuple[Tensor, Tensor, Tensor] ) –

    Samples from the posterior \(q(\mathbf{z}|\mathbf{x})\) of shape (batch_size, n_samples, latent_dim).

  • post_mean ( Tensor | tuple[Tensor, Tensor, Tensor] ) –

    Only returned when return_mean_logstd=True. Posterior mean of shape (batch_size, latent_dim).

  • post_log_std ( Tensor | tuple[Tensor, Tensor, Tensor] ) –

    Only returned when return_mean_logstd=True. Posterior log std of shape (batch_size, latent_dim).

get_posterior_samples(X, n_samples=1, return_mean_logstd=False)

Generate samples from the posterior distribution \(q(\mathbf{z}|\mathbf{x})\).

Parameters:

  • X (Tensor) –

    Input data of shape (batch_size, input_dim).

  • n_samples (int, default: 1 ) –

    Number of samples to obtain.

  • return_mean_logstd (bool, default: False ) –

    Whether to return the mean and log_std used to obtain the samples.

Returns:

  • rep_samples ( Tensor | tuple[Tensor, Tensor, Tensor] ) –

    Samples of \(q(\mathbf{z}|\mathbf{x})\) of shape (batch_size, n_samples, latent_dim).

  • mean ( Tensor | tuple[Tensor, Tensor, Tensor] ) –

    Only returned when return_mean_logstd=True. Mean of shape (batch_size, latent_dim).

  • log_std_v ( Tensor | tuple[Tensor, Tensor, Tensor] ) –

    Only returned when return_mean_logstd=True. Log std of shape (batch_size, latent_dim).

complete_forward_samples(X, n_samples=1)

Compute samples from the likelihood \(p(\mathbf{x}|\mathbf{z})\) using a complete forward pass.

This method first encodes the input to obtain latent samples, then decodes these samples to reconstruct the original input.

Parameters:

  • X (Tensor) –

    Input data of shape (batch_size, input_dim).

  • n_samples (int, default: 1 ) –

    Number of Monte Carlo samples for the forward pass.

Returns:

  • reconstructions ( Tensor ) –

    Reconstructed data of shape (batch_size, input_dim).

compute_loss(X, reduction='mean', n_samples=1, return_samples=False)

Compute the ELBO:

\(\mathbb{E}_{q_\phi(\mathbf{z}|\mathbf{x})} [\log p_\theta(\mathbf{x}|\mathbf{z})] - \text{KL}(q_\phi(\mathbf{z}|\mathbf{x}) \| p(\mathbf{z}))\)

Parameters:

  • X (Tensor) –

    Input data of shape (batch_size, input_dim).

  • reduction (str, default: 'mean' ) –

    Way to reduce the loss across instances. Options: 'sum', 'mean', 'none'.

  • n_samples (int, default: 1 ) –

    Number of Monte Carlo samples for the loss computation.

  • return_samples (bool, default: False ) –

    If True, also returns the latent samples used for loss computation.

Returns:

  • loss_dict ( dict | tuple[dict, Tensor] ) –

    Dictionary containing 'VaeELL' (negative expected log-likelihood) and 'VaeKL' (KL divergence).

  • samples ( dict | tuple[dict, Tensor] ) –

    Only returned when return_samples=True. Latent samples of shape (batch_size * n_samples, latent_dim).
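Since 'VaeELL' is described as the negative expected log-likelihood, the two dictionary entries relate to the ELBO as written below. Summing them into a single training objective is a common choice; it is assumed here rather than taken from the library.

\[ -\mathcal{L}(\theta, \phi; \mathbf{x}) = \underbrace{-\mathbb{E}_{q_\phi(\mathbf{z}|\mathbf{x})} [\log p_\theta(\mathbf{x}|\mathbf{z})]}_{\texttt{VaeELL}} + \underbrace{\text{KL}(q_\phi(\mathbf{z}|\mathbf{x}) \| p(\mathbf{z}))}_{\texttt{VaeKL}} \]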

_kl_prior(mean, log_std)

Compute KL divergence between posterior \(q(\mathbf{z}|\mathbf{x})\) and standard normal prior.

Computes \(\text{KL}(q_\phi(\mathbf{z}|\mathbf{x}) \| \mathcal{N}(\mathbf{0}, \mathbf{I}))\) for a multivariate Gaussian posterior with diagonal covariance matrix.

Parameters:

  • mean (Tensor) –

    Posterior mean vectors of shape (batch_size, latent_dim).

  • log_std (Tensor) –

    Posterior log standard deviations of shape (batch_size, latent_dim).

Returns:

  • kl_div ( Tensor ) –

    KL divergence per instance of shape (batch_size,).
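For a diagonal Gaussian posterior with mean \(\mu\) and standard deviation \(\sigma\), the KL divergence to a standard normal prior has the closed form \(\frac{1}{2} \sum_d (\sigma_d^2 + \mu_d^2 - 1 - 2 \log \sigma_d)\). The sketch below evaluates that formula with the standard library. The function name `kl_to_standard_normal` is hypothetical, and whether the library uses exactly this expression internally is an assumption.

```python
import math


def kl_to_standard_normal(mean, log_std):
    """KL(N(mean, diag(exp(log_std)**2)) || N(0, I)) for each row.

    mean, log_std: lists of rows, shape (batch_size, latent_dim).
    Returns a list of length batch_size.
    """
    kl = []
    for mu_row, ls_row in zip(mean, log_std):
        total = 0.0
        for mu, ls in zip(mu_row, ls_row):
            var = math.exp(2.0 * ls)                       # sigma^2
            total += 0.5 * (var + mu * mu - 1.0 - 2.0 * ls)
        kl.append(total)
    return kl


kl = kl_to_standard_normal([[0.0, 0.0], [1.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]])
print(kl)  # [0.0, 0.5]
```

The first row matches the prior exactly, so its divergence is zero; the second row differs only by a unit shift in one dimension, which contributes \(\frac{1}{2}\).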

_diagonal_log_gaussian_pdf(inputs, mean, log_std)

Compute log probability density of a diagonal Gaussian.

Computes \(\log \mathcal{N}(\mathbf{x} \mid \mu, \text{diag}(\sigma^2))\), the log density of the inputs under a Gaussian with diagonal covariance.

Parameters:

  • inputs (Tensor) –

    Input data of shape (batch_size, input_dim).

  • mean (Tensor) –

    Gaussian mean of shape (batch_size, input_dim).

  • log_std (Tensor) –

    Gaussian log standard deviation of shape (batch_size, input_dim).

Returns:

  • log_prob ( Tensor ) –

    Log probability densities of shape (batch_size,).
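The log density of a diagonal Gaussian factorizes over dimensions: \(\log \mathcal{N}(\mathbf{x} \mid \mu, \text{diag}(\sigma^2)) = \sum_d \left[ -\tfrac{1}{2} \log(2\pi) - \log \sigma_d - \tfrac{(x_d - \mu_d)^2}{2 \sigma_d^2} \right]\). The following standard-library sketch evaluates that sum per row. The function name is hypothetical and the code is an illustration of the formula, not the library's implementation.

```python
import math

LOG_2PI = math.log(2.0 * math.pi)


def diagonal_log_gaussian_pdf(inputs, mean, log_std):
    """Log density of each input row under N(mean, diag(exp(log_std)**2)).

    inputs, mean, log_std: lists of rows, shape (batch_size, input_dim).
    Returns a list of length batch_size.
    """
    log_prob = []
    for x_row, mu_row, ls_row in zip(inputs, mean, log_std):
        total = 0.0
        for x, mu, ls in zip(x_row, mu_row, ls_row):
            standardized = (x - mu) * math.exp(-ls)        # (x - mu) / sigma
            total += -0.5 * (LOG_2PI + standardized ** 2) - ls
        log_prob.append(total)
    return log_prob


lp = diagonal_log_gaussian_pdf([[0.0, 0.0]], [[0.0, 0.0]], [[0.0, 0.0]])
print(abs(lp[0] + LOG_2PI) < 1e-12)  # True: two unit-variance dimensions at the mean
```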

log_marginal_likelihood_importance_sampling(X, n_samples=1)

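Compute the log marginal likelihood \(\log p(\mathbf{x})\) via importance sampling. The estimate is computed as

\(\log p(\mathbf{x}) \approx \log \frac{1}{K} \sum_{i=1}^K \frac{p(\mathbf{x}|\mathbf{z}_i)p(\mathbf{z}_i)}{q(\mathbf{z}_i|\mathbf{x})}\)

where \(\mathbf{z}_i \sim q(\mathbf{z}|\mathbf{x})\) and \(K\) is n_samples.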

Parameters:

  • X (Tensor) –

    Input data of shape (batch_size, input_dim).

  • n_samples (int, default: 1 ) –

    Number of importance samples for estimation.

Returns:

  • log_marginal ( Tensor ) –

    Log marginal likelihood estimates of shape (batch_size,).
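To make the importance-sampling estimator concrete, here is a small standard-library sketch on a one-dimensional toy model, \(p(z) = \mathcal{N}(0, 1)\) and \(p(x|z) = \mathcal{N}(z, 1)\), whose marginal \(p(x) = \mathcal{N}(0, 2)\) is known in closed form. The toy model, the helper names, and the use of a log-mean-exp for numerical stability are all assumptions made for illustration; they are not taken from the library.

```python
import math
import random


def log_normal(x, mean, std):
    """Log density of a univariate Gaussian N(x; mean, std**2)."""
    return -0.5 * math.log(2.0 * math.pi) - math.log(std) - 0.5 * ((x - mean) / std) ** 2


def log_mean_exp(values):
    """Numerically stable log of the mean of exp(values)."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values) / len(values))


def log_marginal_is(x, q_mean, q_std, n_samples, rng):
    """Estimate log p(x) for the toy model p(z)=N(0,1), p(x|z)=N(z,1)."""
    log_w = []
    for _ in range(n_samples):
        z = rng.gauss(q_mean, q_std)             # z_i ~ q(z | x)
        log_w.append(
            log_normal(x, z, 1.0)                # log p(x | z_i)
            + log_normal(z, 0.0, 1.0)            # log p(z_i)
            - log_normal(z, q_mean, q_std)       # log q(z_i | x)
        )
    return log_mean_exp(log_w)


x = 1.3
exact = log_normal(x, 0.0, math.sqrt(2.0))       # true log p(x) = log N(x; 0, 2)
# With the exact posterior q(z|x) = N(x/2, 1/2), every weight equals p(x).
estimate = log_marginal_is(x, x / 2.0, math.sqrt(0.5), n_samples=8, rng=random.Random(0))
print(abs(estimate - exact) < 1e-9)  # True
```

When the proposal is only an approximation of the posterior, the weights vary and the estimate becomes a stochastic lower bound in expectation that tightens as the number of samples grows.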


torchmil.nn.VariationalAutoEncoderMIL

Bases: VariationalAutoEncoder

Variational Autoencoder for Multiple Instance Learning.

This class extends the VAE to handle bag-structured data by processing each instance in a bag independently through the VAE and returning results in bag format.

__init__(input_shape=(512,), layer_sizes=[128, 64], activations=['relu', 'None'], covar_mode='single', jitter=1e-07)

Parameters:

  • input_shape (tuple[int], default: (512,) ) –

    Shape of input data (excluding batch dimension).

  • layer_sizes (list[int], default: [128, 64] ) –

    List of hidden layer sizes for the encoder (decoder mirrors this).

  • activations (list[str], default: ['relu', 'None'] ) –

    List of activation functions for each layer. Must have same length as layer_sizes.

  • covar_mode (str, default: 'single' ) –

    Covariance mode for the variational distributions. Options: 'single', 'diagonal'.

  • jitter (float, default: 1e-07 ) –

    Small value added to log_std for numerical stability.

forward(X, n_samples=1, return_mean_logstd=False)

Forward pass for bag-structured data.

This method processes each instance in the bags independently through the VAE encoder. Used in MIL feature extraction where output must be (batch_size, bag_size, latent_dim).

Parameters:

  • X (Tensor) –

    Bag data of shape (batch_size, bag_size, input_dim).

  • n_samples (int, default: 1 ) –

    Number of Monte Carlo samples.

  • return_mean_logstd (bool, default: False ) –

    Whether to return posterior mean and log standard deviation.

Returns:

  • samples ( Tensor | tuple[Tensor, Tensor, Tensor] ) –

    Encoding samples of shape (batch_size, bag_size, n_samples, latent_dim).

  • mean ( Tensor | tuple[Tensor, Tensor, Tensor] ) –

    Only returned when return_mean_logstd=True. Mean of shape (batch_size, bag_size, latent_dim).

  • log_std ( Tensor | tuple[Tensor, Tensor, Tensor] ) –

    Only returned when return_mean_logstd=True. Log std of shape (batch_size, bag_size, latent_dim).
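Processing each instance independently and returning results in bag format can be pictured as flattening the bag dimension, applying a per-instance function, and restoring the bag structure. The sketch below shows that pattern with nested lists. Both `apply_per_instance` and `toy_encode` are hypothetical stand-ins, and whether the library flattens and reshapes in this way is an assumption.

```python
def apply_per_instance(bags, encode):
    """Apply `encode` to every instance of bags shaped (batch_size, bag_size, input_dim).

    `encode` maps a list of n instances to a list of n outputs.
    Returns a nested list of shape (batch_size, bag_size, ...).
    """
    batch_size = len(bags)
    bag_size = len(bags[0])
    flat = [inst for bag in bags for inst in bag]     # (batch_size * bag_size, input_dim)
    encoded = encode(flat)                            # (batch_size * bag_size, ...)
    return [encoded[b * bag_size:(b + 1) * bag_size] for b in range(batch_size)]


def toy_encode(instances):
    """Stand-in for an encoder: one 'sample' with a 2-d latent per instance."""
    return [[[sum(x), max(x)]] for x in instances]    # (n, n_samples=1, latent_dim=2)


bags = [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
        [[0.0, 1.0], [1.0, 0.0], [2.0, 2.0]]]         # (batch_size=2, bag_size=3, input_dim=2)
out = apply_per_instance(bags, toy_encode)
print(len(out), len(out[0]), len(out[0][0]), len(out[0][0][0]))  # 2 3 1 2
```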

log_marginal_likelihood_importance_sampling(X, mask=None, n_samples=1)

Compute log marginal likelihood for bag-structured data via importance sampling.

This method processes each instance in the bags independently and returns log marginal estimates for each instance.

Parameters:

  • X (Tensor) –

    Bag data of shape (batch_size, bag_size, input_dim).

  • mask (Tensor | None, default: None ) –

    Optional binary mask of shape (batch_size, bag_size) indicating valid instances.

  • n_samples (int, default: 1 ) –

    Number of importance samples for estimation.

Returns:

  • log_marginal ( Tensor ) –

    Log marginal likelihood per instance of shape (batch_size, bag_size).

compute_loss(X, mask=None, reduction='mean', n_samples=1, return_samples=False)

Compute VAE loss for bag-structured data.

The loss is computed for each instance in the bags and then aggregated according to the reduction strategy and optional mask.

Parameters:

  • X (Tensor) –

    Bag data of shape (batch_size, bag_size, input_dim).

  • mask (Tensor | None, default: None ) –

    Optional binary mask of shape (batch_size, bag_size) for valid instances.

  • reduction (str, default: 'mean' ) –

    Reduction method ('sum', 'mean', or 'none').

  • n_samples (int, default: 1 ) –

    Number of Monte Carlo samples for loss computation.

  • return_samples (bool, default: False ) –

    Whether to return latent samples used in loss computation.

Returns:

  • loss_dict ( dict | tuple[dict, Tensor] ) –

    Dictionary with 'VaeELL' and 'VaeKL' losses.

  • samples ( dict | tuple[dict, Tensor] ) –

    Only returned when return_samples=True. Latent samples used for loss computation.
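One way to read "aggregated according to the reduction strategy and optional mask" is that padded instances are excluded before reducing. The sketch below illustrates that reading on nested lists. The helper `masked_reduce` is hypothetical, and the exact aggregation the library performs (for example, whether 'mean' divides by the number of valid instances) is an assumption.

```python
def masked_reduce(per_instance, mask=None, reduction="mean"):
    """Aggregate per-instance losses of shape (batch_size, bag_size).

    mask has the same shape; 1 marks a valid instance and 0 marks padding.
    With mask=None every instance counts.
    """
    if mask is None:
        mask = [[1] * len(row) for row in per_instance]
    # Zero out padded entries so they cannot contribute to the result.
    masked = [
        [v if m else 0.0 for v, m in zip(row, mrow)]
        for row, mrow in zip(per_instance, mask)
    ]
    if reduction == "none":
        return masked
    total = sum(sum(row) for row in masked)
    if reduction == "sum":
        return total
    if reduction == "mean":
        n_valid = sum(sum(mrow) for mrow in mask)
        return total / max(n_valid, 1)              # average over valid instances only
    raise ValueError(f"unknown reduction: {reduction!r}")


losses = [[1.0, 2.0, 9.0], [3.0, 9.0, 9.0]]         # 9.0 marks padded positions
mask = [[1, 1, 0], [1, 0, 0]]
print(masked_reduce(losses, mask, "mean"))  # 2.0
print(masked_reduce(losses, mask, "sum"))   # 6.0
```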

complete_forward_samples(X)

Compute reconstructions for bag-structured data via complete forward pass.

This method processes each instance in the bags independently through the VAE encoder-decoder pipeline and returns reconstructions in bag format.

Parameters:

  • X (Tensor) –

    Bag data of shape (batch_size, bag_size, input_dim).

Returns:

  • reconstructions ( Tensor ) –

    Reconstructed bag data of shape (batch_size, bag_size, input_dim).