Variational Autoencoder (VAE)
torchmil.nn.VariationalAutoEncoder
Bases: Module
Variational Autoencoder (VAE) model for learning latent representations.
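The VAE learns a latent representation \(\mathbf{z}\) of input data \(\mathbf{x}\) by maximizing the Evidence Lower Bound (ELBO):
\(\mathrm{ELBO}(\theta, \phi) = \mathbb{E}_{q_\phi(\mathbf{z}|\mathbf{x})}[\log p_\theta(\mathbf{x}|\mathbf{z})] - \mathrm{KL}[q_\phi(\mathbf{z}|\mathbf{x}) \,\|\, p(\mathbf{z})]\)
where \(q_\phi(\mathbf{z}|\mathbf{x})\) is the encoder (posterior), \(p_\theta(\mathbf{x}|\mathbf{z})\) is the decoder (likelihood), and \(p(\mathbf{z})\) is the standard normal prior. Both the encoder and decoder are implemented as MLPs.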
__init__(input_shape=(512,), layer_sizes=[128, 64], activations=['relu', 'None'], covar_mode='single', jitter=1e-07)
Parameters:
-
input_shape(tuple[int], default:(512,)) –Shape of input data (excluding batch dimension).
-
layer_sizes(list[int], default:[128, 64]) –List of hidden layer sizes for the encoder (decoder mirrors this).
-
activations(list[str], default:['relu', 'None']) –List of activation functions for each layer. Must have same length as layer_sizes.
-
covar_mode(str, default:'single') –Covariance mode for the variational distributions. Options: 'single', 'diagonal'.
-
jitter(float, default:1e-07) –Small value added to log_std for numerical stability.
get_reparameterized_samples(mean, log_std, n_samples=1)
Generate reparameterized samples using the reparameterization trick.
Parameters:
-
mean(Tensor) –Mean of the distribution of shape (batch_size, latent_dim).
-
log_std(Tensor) –Log standard deviation of shape (batch_size, d_var_enc).
-
n_samples(int, default:1) –Number of samples to generate.
Returns:
-
Tensor–Reparameterized samples of shape
(batch_size, n_samples, latent_dim).
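The reparameterization trick can be sketched in plain PyTorch as follows. This is an illustrative standalone function, not torchmil's implementation: the name `reparameterized_samples` is hypothetical, and the example assumes `log_std` has a trailing dimension of 1 (one shared standard deviation per instance), which is only a guess at what `d_var_enc` means for `covar_mode='single'`.

```python
import torch


def reparameterized_samples(mean, log_std, n_samples=1):
    """Draw z = mean + exp(log_std) * eps, with eps ~ N(0, I).

    mean:    (batch_size, latent_dim)
    log_std: (batch_size, 1) or (batch_size, latent_dim)
    returns: (batch_size, n_samples, latent_dim)
    """
    batch_size, latent_dim = mean.shape
    eps = torch.randn(
        batch_size, n_samples, latent_dim, dtype=mean.dtype, device=mean.device
    )
    # unsqueeze(1) adds the sample axis; a trailing dimension of 1 in
    # log_std broadcasts over latent_dim.
    return mean.unsqueeze(1) + log_std.exp().unsqueeze(1) * eps


torch.manual_seed(0)
mean = torch.zeros(4, 3, requires_grad=True)
log_std = torch.zeros(4, 1, requires_grad=True)  # assumed shared std per instance
z = reparameterized_samples(mean, log_std, n_samples=5)

# The noise is sampled independently of the parameters, so gradients
# flow back to both mean and log_std.
z.sum().backward()
```

Because the randomness enters only through `eps`, the samples stay differentiable with respect to `mean` and `log_std`, which is what allows the ELBO to be optimized by gradient descent.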
get_raw_output_enc(X)
Compute the mean and log standard deviation of the posterior distribution \(q(\mathbf{z}\mid \mathbf{x})\).
The posterior distribution is parameterized as:
\(q(\mathbf{z} \mid \mathbf{x}) = \mathcal N(\mathbf{z} \mid \mu(\mathbf{x}), \sigma^2(\mathbf{x}) \mathbf{I})\)
Parameters:
-
X(Tensor) –Input data of shape (batch_size, input_dim).
Returns:
-
mean(Tensor) –Mean vector of shape (batch_size, latent_dim).
-
log_std(Tensor) –Log standard deviation of shape (batch_size, d_var_enc).
get_raw_output_dec(samples)
Compute the mean and log standard deviation of the likelihood distribution \(p(\mathbf{x}|\mathbf{z})\).
The likelihood distribution is parameterized as:
\(p(\mathbf{x} \mid \mathbf{z}) = \mathcal N(\mathbf{x} \mid \mu(\mathbf{z}), \sigma^2(\mathbf{z}) \mathbf{I})\)
Parameters:
-
samples(Tensor) –Samples from the posterior distribution of shape (batch_size, latent_dim).
Returns:
-
mean(Tensor) –Mean of the likelihood of shape (batch_size, input_dim).
-
log_std(Tensor) –Log standard deviation of shape (batch_size, d_var_dec).
forward(X, n_samples=1, return_mean_logstd=False)
Forward pass through the VAE encoder.
Note: This method only implements encoding, since the latent variables are used for downstream tasks.
Parameters:
-
X(Tensor) –Input data of shape (batch_size, ...).
-
n_samples(int, default:1) –Number of Monte Carlo samples to generate from the posterior.
-
return_mean_logstd(bool, default:False) –If True, also returns the posterior mean and log standard deviation.
Returns:
-
posterior_samples(Tensor | tuple[Tensor, Tensor, Tensor]) –Samples from the posterior \(q(\mathbf{z}|\mathbf{x})\) of shape (batch_size, n_samples, latent_dim).
-
post_mean(Tensor | tuple[Tensor, Tensor, Tensor]) –Only returned when return_mean_logstd=True. Posterior mean of shape (batch_size, latent_dim).
-
post_log_std(Tensor | tuple[Tensor, Tensor, Tensor]) –Only returned when return_mean_logstd=True. Posterior log std of shape (batch_size, latent_dim).
get_posterior_samples(X, n_samples=1, return_mean_logstd=False)
Generate samples from the posterior distribution \(q(\mathbf{z}|\mathbf{x})\).
Parameters:
-
X(Tensor) –Input data of shape (batch_size, input_dim).
-
n_samples(int, default:1) –Number of samples to draw.
-
return_mean_logstd(bool, default:False) –Whether to return the mean and log_std used to obtain the samples.
Returns:
-
rep_samples(Tensor | tuple[Tensor, Tensor, Tensor]) –Samples from \(q(\mathbf{z}|\mathbf{x})\) of shape (batch_size, n_samples, latent_dim).
-
mean(Tensor | tuple[Tensor, Tensor, Tensor]) –Only returned when return_mean_logstd=True. Mean of shape (batch_size, latent_dim).
-
log_std_v(Tensor | tuple[Tensor, Tensor, Tensor]) –Only returned when return_mean_logstd=True. Log std of shape (batch_size, latent_dim).
complete_forward_samples(X, n_samples=1)
Compute samples from the likelihood \(p(\mathbf{x}|\mathbf{z})\) using a complete forward pass.
This method first encodes the input to obtain latent samples, then decodes these samples to reconstruct the original input.
Parameters:
-
X(Tensor) –Input data of shape (batch_size, input_dim).
-
n_samples(int, default:1) –Number of Monte Carlo samples for the forward pass.
Returns:
-
reconstructions(Tensor) –Reconstructed data of shape (batch_size, input_dim).
compute_loss(X, reduction='mean', n_samples=1, return_samples=False)
Compute the ELBO:
\(\mathbb{E}_{q(\mathbf{z}|\mathbf{x})}[\log p(\mathbf{x}|\mathbf{z})] - \mathrm{KL}[q(\mathbf{z}|\mathbf{x}) \,\|\, p(\mathbf{z})]\)
Parameters:
-
X(Tensor) –Input data of shape (batch_size, input_dim).
-
reduction(str, default:'mean') –Way to reduce the loss across instances. Options: 'sum', 'mean', 'none'.
-
n_samples(int, default:1) –Number of Monte Carlo samples for the loss computation.
-
return_samples(bool, default:False) –If True, also returns the latent samples used for loss computation.
Returns:
-
loss_dict(dict | tuple[dict, Tensor]) –Dictionary containing 'VaeELL' (negative expected log-likelihood) and 'VaeKL' (KL divergence).
-
samples(dict | tuple[dict, Tensor]) –Only returned when return_samples=True. Latent samples of shape (batch_size * n_samples, latent_dim).
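To make the two loss terms concrete, here is a minimal single-sample sketch of how a negative expected log-likelihood and a KL term could be assembled into a dictionary with the keys documented above. Everything apart from the key names is an assumption for illustration: the encoder and decoder are hypothetical single linear layers, the likelihood is taken to be a unit-variance Gaussian, and 'mean' reduction is simply an average over instances. It is not torchmil's code.

```python
import torch
from torch import nn
from torch.distributions import Normal

torch.manual_seed(0)
batch_size, input_dim, latent_dim = 4, 8, 2

# Hypothetical stand-ins for the VAE's MLP encoder and decoder.
encoder = nn.Linear(input_dim, 2 * latent_dim)  # outputs [mean, log_std]
decoder = nn.Linear(latent_dim, input_dim)      # outputs the likelihood mean

X = torch.randn(batch_size, input_dim)

# Encode, then draw one reparameterized sample per instance.
mean, log_std = encoder(X).chunk(2, dim=-1)
z = mean + log_std.exp() * torch.randn_like(mean)

# Expected log-likelihood term, estimated with that single sample
# (unit-variance Gaussian likelihood assumed here).
log_lik = Normal(decoder(z), 1.0).log_prob(X).sum(dim=-1)  # (batch_size,)

# Closed-form KL between the diagonal Gaussian posterior and N(0, I).
kl = 0.5 * (mean**2 + (2 * log_std).exp() - 2 * log_std - 1).sum(dim=-1)

# 'mean' reduction over instances; minimizing the sum of both entries
# maximizes the ELBO.
loss_dict = {"VaeELL": (-log_lik).mean(), "VaeKL": kl.mean()}
loss = loss_dict["VaeELL"] + loss_dict["VaeKL"]
```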
_kl_prior(mean, log_std)
Compute KL divergence between posterior \(q(\mathbf{z}|\mathbf{x})\) and standard normal prior.
Computes \(D_{KL}(q_\phi(\mathbf{z}|\mathbf{x}) || \mathcal{N}(0, I))\) for a multivariate Gaussian posterior with diagonal covariance matrix.
Parameters:
-
mean(Tensor) –Posterior mean vectors of shape (batch_size, latent_dim).
-
log_std(Tensor) –Posterior log standard deviations of shape (batch_size, latent_dim).
Returns:
-
kl_div(Tensor) –KL divergence per instance of shape (batch_size,).
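For a diagonal Gaussian posterior and a standard normal prior, the KL divergence has the closed form \(\frac{1}{2} \sum_j (\sigma_j^2 + \mu_j^2 - 1 - 2 \log \sigma_j)\). The sketch below writes that formula out and checks it against `torch.distributions.kl_divergence`. The function name `kl_to_standard_normal` is hypothetical and the code is only an illustration of the formula, not the library's implementation.

```python
import torch
from torch.distributions import Normal, kl_divergence


def kl_to_standard_normal(mean, log_std):
    """KL(N(mean, diag(exp(log_std)^2)) || N(0, I)), one value per row."""
    var = torch.exp(2.0 * log_std)
    return 0.5 * torch.sum(var + mean**2 - 1.0 - 2.0 * log_std, dim=-1)


torch.manual_seed(0)
mean = torch.randn(4, 3)
log_std = 0.1 * torch.randn(4, 3)

kl = kl_to_standard_normal(mean, log_std)  # shape (4,)

# Reference value: per-dimension KL from torch.distributions, summed
# over the latent dimension.
prior = Normal(torch.zeros_like(mean), torch.ones_like(mean))
reference = kl_divergence(Normal(mean, log_std.exp()), prior).sum(dim=-1)
```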
_diagonal_log_gaussian_pdf(inputs, mean, log_std)
Compute log probability density of a diagonal Gaussian.
Computes \(\log \mathcal{N}(x; \mu, \sigma^2 I)\) for inputs with diagonal covariance.
Parameters:
-
inputs(Tensor) –Input data of shape (batch_size, input_dim).
-
mean(Tensor) –Gaussian mean of shape (batch_size, input_dim).
-
log_std(Tensor) –Gaussian log standard deviation of shape (batch_size, input_dim).
Returns:
-
log_prob(Tensor) –Log probability densities of shape (batch_size,).
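With a diagonal covariance the log density factorizes over dimensions, so it is the sum of one-dimensional Gaussian log densities: \(\sum_j \left[-\frac{1}{2}\log(2\pi) - \log\sigma_j - \frac{(x_j - \mu_j)^2}{2\sigma_j^2}\right]\). A minimal sketch of that computation, using the hypothetical name `diagonal_log_gaussian_pdf` and checked against `torch.distributions.Normal`, might look like this. It illustrates the formula and is not taken from the library.

```python
import math

import torch
from torch.distributions import Normal


def diagonal_log_gaussian_pdf(inputs, mean, log_std):
    """Log density of a diagonal Gaussian, summed over the last dimension."""
    var = torch.exp(2.0 * log_std)
    log_prob_per_dim = (
        -0.5 * math.log(2.0 * math.pi)
        - log_std
        - 0.5 * (inputs - mean) ** 2 / var
    )
    return log_prob_per_dim.sum(dim=-1)


torch.manual_seed(0)
inputs = torch.randn(4, 6)
mean = torch.randn(4, 6)
log_std = 0.1 * torch.randn(4, 6)

log_prob = diagonal_log_gaussian_pdf(inputs, mean, log_std)  # shape (4,)

# Reference value from torch.distributions.
reference = Normal(mean, log_std.exp()).log_prob(inputs).sum(dim=-1)
```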
log_marginal_likelihood_importance_sampling(X, n_samples=1)
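Compute the log marginal likelihood \(\log p(\mathbf{x})\) via importance sampling. The estimate is computed as
\(\log p(\mathbf{x}) \approx \log \frac{1}{K} \sum_{i=1}^K \frac{p(\mathbf{x}|\mathbf{z}_i)p(\mathbf{z}_i)}{q(\mathbf{z}_i|\mathbf{x})}\), where \(\mathbf{z}_i \sim q(\mathbf{z}|\mathbf{x})\) and \(K\) is n_samples.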
Parameters:
-
X(Tensor) –Input data of shape (batch_size, input_dim).
-
n_samples(int, default:1) –Number of importance samples for estimation.
Returns:
-
log_marginal(Tensor) –Log marginal likelihood estimates of shape (batch_size,).
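The estimator is usually evaluated in log space with `torch.logsumexp` to avoid underflow in the importance weights. The sketch below applies it to a toy one-dimensional model chosen so the answer is known analytically: \(p(z) = \mathcal N(0, 1)\) and \(p(x|z) = \mathcal N(z, 1)\), which gives \(p(x) = \mathcal N(0, 2)\). The model, the function name `log_marginal_is`, and the proposal are all assumptions made for illustration. This is not how torchmil parameterizes its distributions.

```python
import math

import torch
from torch.distributions import Normal


def log_marginal_is(x, q_mean, q_std, n_samples):
    """Importance-sampling estimate of log p(x) for the toy model above.

    x, q_mean, q_std: tensors of shape (batch_size,)
    returns:          tensor of shape (batch_size,)
    """
    q = Normal(q_mean, q_std)
    z = q.sample((n_samples,))               # (n_samples, batch_size)
    log_prior = Normal(0.0, 1.0).log_prob(z)
    log_lik = Normal(z, 1.0).log_prob(x)     # x broadcasts over samples
    log_w = log_lik + log_prior - q.log_prob(z)
    # log of the average weight: logsumexp over samples minus log K.
    return torch.logsumexp(log_w, dim=0) - math.log(n_samples)


torch.manual_seed(0)
x = torch.tensor([-1.0, 0.0, 2.0])

# With the exact posterior N(x / 2, 1 / 2) as proposal, every importance
# weight equals p(x), so the estimate matches the true value for any K.
estimate = log_marginal_is(
    x, q_mean=x / 2, q_std=torch.full_like(x, math.sqrt(0.5)), n_samples=16
)
exact = Normal(0.0, math.sqrt(2.0)).log_prob(x)
```

With an approximate posterior the estimate is a stochastic lower bound on \(\log p(\mathbf{x})\) in expectation, and it tightens as `n_samples` grows.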
torchmil.nn.VariationalAutoEncoderMIL
Bases: VariationalAutoEncoder
Variational Autoencoder for Multiple Instance Learning.
This class extends the VAE to handle bag-structured data by processing each instance in a bag independently through the VAE and returning results in bag format.
__init__(input_shape=(512,), layer_sizes=[128, 64], activations=['relu', 'None'], covar_mode='single', jitter=1e-07)
Parameters:
-
input_shape(tuple[int], default:(512,)) –Shape of input data (excluding batch dimension).
-
layer_sizes(list[int], default:[128, 64]) –List of hidden layer sizes for the encoder (decoder mirrors this).
-
activations(list[str], default:['relu', 'None']) –List of activation functions for each layer. Must have same length as layer_sizes.
-
covar_mode(str, default:'single') –Covariance mode for the variational distributions. Options: 'single', 'diagonal'.
-
jitter(float, default:1e-07) –Small value added to log_std for numerical stability.
forward(X, n_samples=1, return_mean_logstd=False)
Forward pass for bag-structured data.
This method processes each instance in the bags independently through the VAE encoder.
Used in MIL feature extraction where output must be (batch_size, bag_size, latent_dim).
Parameters:
-
X(Tensor) –Bag data of shape (batch_size, bag_size, input_dim).
-
n_samples(int, default:1) –Number of Monte Carlo samples.
-
return_mean_logstd(bool, default:False) –Whether to return posterior mean and log standard deviation.
Returns:
-
samples(Tensor | tuple[Tensor, Tensor, Tensor]) –Encoding samples of shape (batch_size, bag_size, n_samples, latent_dim).
-
mean(Tensor | tuple[Tensor, Tensor, Tensor]) –Only returned when return_mean_logstd=True. Mean of shape (batch_size, bag_size, latent_dim).
-
log_std(Tensor | tuple[Tensor, Tensor, Tensor]) –Only returned when return_mean_logstd=True. Log std of shape (batch_size, bag_size, latent_dim).
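Processing each instance independently is commonly done by folding the bag dimension into the batch dimension, applying the per-instance model, and unfolding again. The sketch below shows that pattern with a hypothetical `instance_encoder` (a single linear layer standing in for the VAE encoder). Whether torchmil reshapes in exactly this way is an assumption.

```python
import torch
from torch import nn

torch.manual_seed(0)
batch_size, bag_size, input_dim, latent_dim = 2, 5, 8, 3

# Hypothetical per-instance encoder standing in for the VAE encoder.
instance_encoder = nn.Linear(input_dim, latent_dim)

X = torch.randn(batch_size, bag_size, input_dim)

# Fold bags into the batch axis so every instance is one row.
flat = X.reshape(batch_size * bag_size, input_dim)
flat_latent = instance_encoder(flat)

# Unfold back to bag format.
latent = flat_latent.reshape(batch_size, bag_size, latent_dim)
```

Since no operation mixes rows, `latent[b, i]` depends only on `X[b, i]`, which is what "independently" means here.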
log_marginal_likelihood_importance_sampling(X, mask=None, n_samples=1)
Compute log marginal likelihood for bag-structured data via importance sampling.
This method processes each instance in the bags independently and returns log marginal estimates for each instance.
Parameters:
-
X(Tensor) –Bag data of shape (batch_size, bag_size, input_dim).
-
mask(Tensor | None, default:None) –Optional binary mask of shape (batch_size, bag_size) indicating valid instances.
-
n_samples(int, default:1) –Number of importance samples for estimation.
Returns:
-
log_marginal(Tensor) –Log marginal likelihood per instance of shape (batch_size, bag_size).
compute_loss(X, mask=None, reduction='mean', n_samples=1, return_samples=False)
Compute VAE loss for bag-structured data.
The loss is computed for each instance in the bags and then aggregated according to the reduction strategy and optional mask.
Parameters:
-
X(Tensor) –Bag data of shape (batch_size, bag_size, input_dim).
-
mask(Tensor | None, default:None) –Optional binary mask of shape (batch_size, bag_size) indicating valid instances.
-
reduction(str, default:'mean') –Reduction method ('sum', 'mean', or 'none').
-
n_samples(int, default:1) –Number of Monte Carlo samples for loss computation.
-
return_samples(bool, default:False) –Whether to return latent samples used in loss computation.
Returns:
-
loss_dict(dict | tuple[dict, Tensor]) –Dictionary with 'VaeELL' and 'VaeKL' losses.
-
samples(dict | tuple[dict, Tensor]) –Only returned when return_samples=True. Latent samples used for loss computation.
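One plausible way to combine a per-instance loss with a binary mask is to zero out padded instances and, for 'mean', divide by the number of valid instances rather than by the padded bag size. The helper below, with the hypothetical name `masked_reduce`, sketches that convention. The documentation above does not state how torchmil normalizes, so treat this as an assumption rather than a description of its behavior.

```python
import torch


def masked_reduce(per_instance, mask=None, reduction="mean"):
    """Reduce a (batch_size, bag_size) loss tensor, ignoring masked-out entries."""
    if mask is None:
        mask = torch.ones_like(per_instance)
    mask = mask.to(per_instance.dtype)
    masked = per_instance * mask
    if reduction == "none":
        return masked
    if reduction == "sum":
        return masked.sum()
    if reduction == "mean":
        # Average over valid instances only; clamp guards against an empty mask.
        return masked.sum() / mask.sum().clamp(min=1.0)
    raise ValueError(f"Unknown reduction: {reduction!r}")


# Two bags padded to length 3; the entries equal to 100 are padding.
per_instance = torch.tensor([[1.0, 2.0, 100.0], [3.0, 100.0, 100.0]])
mask = torch.tensor([[1, 1, 0], [1, 0, 0]])

mean_loss = masked_reduce(per_instance, mask, reduction="mean")
sum_loss = masked_reduce(per_instance, mask, reduction="sum")
```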
complete_forward_samples(X)
Compute reconstructions for bag-structured data via complete forward pass.
This method processes each instance in the bags independently through the VAE encoder-decoder pipeline and returns reconstructions in bag format.
Parameters:
-
X(Tensor) –Bag data of shape (batch_size, bag_size, input_dim).
Returns:
-
reconstructions(Tensor) –Reconstructed bag data of shape (batch_size, bag_size, input_dim).