Prob Smooth Attention Pool
torchmil.nn.attention.ProbSmoothAttentionPool
Bases: Module
Probabilistic Smooth Attention Pooling, proposed in "Probabilistic Smooth Attention for Deep Multiple Instance Learning in Medical Imaging" and "Smooth Attention for Deep Multiple Instance Learning: Application to CT Intracranial Hemorrhage Detection".
Given an input bag \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times \texttt{in_dim}}\), this model computes an attention distribution \(q(\mathbf{f} \mid \mathbf{X}) = \mathcal{N}\left(\mathbf{f} \mid \mathbf{\mu}_{\mathbf{f}}, \operatorname{diag}(\mathbf{\sigma}_{\mathbf{f}}^2) \right)\). The mean \(\mathbf{\mu}_{\mathbf{f}}\) and the variance \(\mathbf{\sigma}_{\mathbf{f}}^2\) are obtained from \(\operatorname{MLP}(\mathbf{X})\), where \(\operatorname{MLP}\) is a multi-layer perceptron, using the weight vectors \(\mathbf{w}_{\mu},\mathbf{w}_{\sigma} \in \mathbb{R}^{2\texttt{att_dim} \times 1}\), respectively.
If covar_mode='zero', the variance vector \(\mathbf{\sigma}_{\mathbf{f}}^2\) is set to zero, resulting in a deterministic attention distribution.
Then, \(M\) samples from the attention distribution are drawn as \(\widehat{\mathbf{f}}^{(m)} \sim q(\mathbf{f} \mid \mathbf{X})\) and stacked column-wise as \(\widehat{\mathbf{F}} = \left[ \widehat{\mathbf{f}}^{(1)}, \ldots, \widehat{\mathbf{f}}^{(M)} \right] \in \mathbb{R}^{N \times M}\). The bag representation is computed from these samples, yielding one representation per sample.
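The sampling and pooling step can be illustrated with a minimal NumPy sketch for a single bag. It is not the library's implementation: the helper name sample_and_pool is hypothetical, and the softmax normalisation over the instance axis is an assumption, since the text here only fixes the shapes of the samples and of the bag representation.

```python
import numpy as np

def sample_and_pool(X, mu_f, sigma2_f, n_samples, rng):
    """Draw attention samples and pool a single bag (illustrative only).

    X:        (N, D) instance features
    mu_f:     (N,)   mean of q(f | X)
    sigma2_f: (N,)   diagonal variance of q(f | X)
    Returns z of shape (D, M) and the samples F of shape (N, M).
    """
    N = X.shape[0]
    eps = rng.standard_normal((N, n_samples))
    # Reparameterised draw: f = mu + sigma * eps, one column per sample.
    F = mu_f[:, None] + np.sqrt(sigma2_f)[:, None] * eps
    # ASSUMPTION: samples are normalised with a softmax over the instance axis.
    F_shift = F - F.max(axis=0, keepdims=True)  # numerical stability
    W = np.exp(F_shift)
    W /= W.sum(axis=0, keepdims=True)
    # Weighted sum of instances: one bag representation per sample.
    z = X.T @ W
    return z, F

rng = np.random.default_rng(0)
X = rng.standard_normal((5, 3))      # N=5 instances, D=3 features
mu_f = rng.standard_normal(5)
z, F = sample_and_pool(X, mu_f, np.full(5, 0.1), n_samples=4, rng=rng)
print(z.shape, F.shape)  # (3, 4) (5, 4)
```

With sigma2_f set to zero every column of F equals mu_f, so all columns of z coincide, which matches the deterministic behaviour described for covar_mode='zero'.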
Kullback-Leibler Divergence. Given a bag with adjacency matrix \(\mathbf{A}\), the KL divergence between the attention distribution and the prior distribution is computed from \(\mathbf{\mu}_{\mathbf{f}}\), \(\mathbf{\Sigma}_{\mathbf{f}} = \operatorname{diag}(\mathbf{\sigma}_{\mathbf{f}}^2)\), and the graph Laplacian matrix \(\mathbf{L} = \mathbf{D} - \mathbf{A}\), where \(\mathbf{D}\) is the degree matrix of \(\mathbf{A}\). The result is given up to a constant term \(\operatorname{const}\) that does not depend on the parameters.
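The quantities named above can be computed with a few lines of NumPy. The sketch below is illustrative only: the helper kl_terms is hypothetical, and the way the terms are combined assumes the standard Gaussian KL divergence against a zero-mean prior whose precision matrix is the Laplacian, which gives \(\tfrac{1}{2}\left[\mathbf{\mu}_{\mathbf{f}}^\top \mathbf{L} \mathbf{\mu}_{\mathbf{f}} + \operatorname{Tr}(\mathbf{L}\mathbf{\Sigma}_{\mathbf{f}}) - \log\det\mathbf{\Sigma}_{\mathbf{f}}\right] + \operatorname{const}\). The library's exact scaling and handling of covar_mode='zero' may differ.

```python
import numpy as np

def kl_terms(mu_f, sigma2_f, A):
    """Building blocks of a Laplacian-prior KL term for one bag (illustrative)."""
    D = np.diag(A.sum(axis=1))              # degree matrix
    L = D - A                               # graph Laplacian
    quad = mu_f @ L @ mu_f                  # mu^T L mu
    trace = np.sum(np.diag(L) * sigma2_f)   # Tr(L Sigma_f) for diagonal Sigma_f
    logdet = np.sum(np.log(sigma2_f))       # log det Sigma_f
    return quad, trace, logdet

# Chain graph with three instances: 0 - 1 - 2
A = np.array([[0., 1., 0.],
              [1., 0., 1.],
              [0., 1., 0.]])

quad, trace, logdet = kl_terms(np.ones(3), np.ones(3), A)
# ASSUMPTION: standard Gaussian KL with prior precision L, constant dropped.
kl = 0.5 * (quad + trace - logdet)
print(quad, trace, logdet, kl)  # 0.0 4.0 0.0 2.0
```

For an unweighted symmetric adjacency matrix, the quadratic term equals the sum of squared differences of the mean attention across edges, so a constant mean gives zero and the term grows as neighbouring instances receive different attention.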
__init__(in_dim=None, att_dim=128, covar_mode='diag', n_samples_train=1000, n_samples_test=5000)
Parameters:
- in_dim (int, default: None) – Input dimension. If not provided, it will be lazily initialized.
- att_dim (int, default: 128) – Attention dimension.
- covar_mode (str, default: 'diag') – Covariance mode. Must be 'diag' or 'zero'.
- n_samples_train (int, default: 1000) – Number of samples during training.
- n_samples_test (int, default: 5000) – Number of samples during testing.
forward(X, adj=None, mask=None, return_att_samples=False, return_att_dist=False, return_kl_div=False, n_samples=None)
In the following, if covar_mode='zero' then n_samples is automatically set to 1 and diag_Sigma_f is set to None.
Parameters:
- X (Tensor) – Bag features of shape (batch_size, bag_size, dim).
- mask (Tensor, default: None) – Mask of shape (batch_size, bag_size).
- adj (Tensor, default: None) – Adjacency matrix of shape (batch_size, bag_size, bag_size). Only required when return_kl_div=True.
- return_att_samples (bool, default: False) – If True, returns samples from the attention distribution f in addition to z.
- return_att_dist (bool, default: False) – If True, returns the attention distribution (mu_f, diag_Sigma_f) in addition to z.
- return_kl_div (bool, default: False) – If True, returns the KL divergence between the attention distribution and the prior distribution.
- n_samples (int, default: None) – Number of samples to draw. If not provided, it will use n_samples_train during training and n_samples_test during testing.
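The rules for choosing the number of samples can be summarised in a small pure-Python sketch. The function resolve_n_samples is a hypothetical helper written for illustration, not part of the library. It assumes that covar_mode='zero' takes precedence over an explicit n_samples, as the note above the parameter list suggests.

```python
def resolve_n_samples(n_samples, training, covar_mode,
                      n_samples_train=1000, n_samples_test=5000):
    """Return the number of attention samples to draw (illustrative only)."""
    if covar_mode == "zero":
        # Deterministic attention: a single sample is enough.
        return 1
    if n_samples is not None:
        # An explicit value overrides the train/test defaults.
        return n_samples
    return n_samples_train if training else n_samples_test

print(resolve_n_samples(None, True, "diag"))   # 1000
print(resolve_n_samples(None, False, "diag"))  # 5000
print(resolve_n_samples(10, True, "diag"))     # 10
print(resolve_n_samples(10, True, "zero"))     # 1
```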
Returns:
- z (Tensor) – Bag representation of shape (batch_size, dim, n_samples).
- f (Tensor) – Samples from the attention distribution of shape (batch_size, bag_size, n_samples). Only returned when return_att_samples=True.
- mu_f (Tensor) – Mean of the attention distribution of shape (batch_size, bag_size, 1). Only returned when return_att_dist=True.
- diag_Sigma_f (Tensor) – Covariance of the attention distribution of shape (batch_size, bag_size, 1). Only returned when return_att_dist=True.
- kl_div (Tensor) – KL divergence between the attention distribution and the prior distribution, of shape (). Only returned when return_kl_div=True.