Attention Pool
torchmil.nn.attention.AttentionPool
Bases: Module
Attention-based pooling, as proposed in the paper Attention-based Deep Multiple Instance Learning (Ilse et al., 2018).
Given an input bag \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times \texttt{in_dim}}\), this model aggregates the instance features into a bag representation \(\mathbf{z} \in \mathbb{R}^{\texttt{in_dim}}\) as
\[
\mathbf{z} = \mathbf{X}^\top \operatorname{Softmax}(\mathbf{f}) = \sum_{n=1}^{N} s_n \mathbf{x}_n,
\]
where \(\mathbf{f} = \operatorname{MLP}(\mathbf{X}) \in \mathbb{R}^{N}\) are the attention values and \(s_n\) is the normalized attention score for the \(n\)-th instance.
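The pooling step itself can be sketched in a few lines. This is a minimal NumPy illustration for a single bag, assuming the attention values \(\mathbf{f}\) are already given; the helper names `softmax` and `attention_pool` are hypothetical and not part of torchmil.

```python
import numpy as np


def softmax(f: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    e = np.exp(f - f.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def attention_pool(X: np.ndarray, f: np.ndarray) -> np.ndarray:
    """Aggregate a bag X of shape (N, in_dim) using attention values f of shape (N,)."""
    s = softmax(f)  # normalized scores s_n, shape (N,), summing to 1
    return X.T @ s  # z = sum_n s_n x_n, shape (in_dim,)


X = np.array([[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]])
z = attention_pool(X, np.zeros(3))  # equal attention values: z is the mean of the rows, [1, 1] up to rounding
```

Because the scores are a softmax, \(\mathbf{z}\) is always a convex combination of the instances: equal attention values give the mean, and one dominant value makes \(\mathbf{z}\) approach that single instance.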
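To compute the attention values, the \(\operatorname{MLP}\) is defined as
\[
\operatorname{MLP}(\mathbf{X}) = \begin{cases} \operatorname{act}(\mathbf{X}\mathbf{W}_1)\mathbf{w}, & \text{if } \texttt{gated=False}, \\ \left( \operatorname{act}(\mathbf{X}\mathbf{W}_1) \odot \operatorname{sigm}(\mathbf{X}\mathbf{W}_2) \right)\mathbf{w}, & \text{if } \texttt{gated=True}, \end{cases}
\]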
where \(\mathbf{W}_1 \in \mathbb{R}^{\texttt{in_dim} \times \texttt{att_dim}}\), \(\mathbf{W}_2 \in \mathbb{R}^{\texttt{in_dim} \times \texttt{att_dim}}\), \(\mathbf{w} \in \mathbb{R}^{\texttt{att_dim}}\), \(\operatorname{act} \ \colon \mathbb{R} \to \mathbb{R}\) is the activation function, \(\operatorname{sigm} \ \colon \mathbb{R} \to \left] 0, 1 \right[\) is the sigmoid function, and \(\odot\) denotes element-wise multiplication.
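A minimal NumPy sketch of the attention MLP, following the formula with no bias terms. The function `attention_values` and the `ACTIVATIONS` table are hypothetical illustrations, not torchmil's code; treating `'gelu'` as the exact (erf-based) GELU is also an assumption.

```python
import math

import numpy as np

ACTIVATIONS = {
    "tanh": np.tanh,
    "relu": lambda a: np.maximum(a, 0.0),
    # Exact GELU, x * Phi(x), applied elementwise via math.erf (assumed variant).
    "gelu": np.vectorize(lambda a: 0.5 * a * (1.0 + math.erf(a / math.sqrt(2.0)))),
}


def sigmoid(a: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-a))


def attention_values(X, W1, w, act="tanh", W2=None):
    """Return f = MLP(X) with shape (N,); gated attention is used when W2 is given.

    X: (N, in_dim), W1 and W2: (in_dim, att_dim), w: (att_dim,).
    """
    h = ACTIVATIONS[act](X @ W1)  # act(X W1), shape (N, att_dim)
    if W2 is not None:
        h = h * sigmoid(X @ W2)   # element-wise gate with values in (0, 1)
    return h @ w                  # project to one attention value per instance


rng = np.random.default_rng(0)
N, in_dim, att_dim = 5, 8, 4
X = rng.standard_normal((N, in_dim))
W1 = rng.standard_normal((in_dim, att_dim))
W2 = rng.standard_normal((in_dim, att_dim))
w = rng.standard_normal(att_dim)

f = attention_values(X, W1, w)                 # standard attention
f_gated = attention_values(X, W1, w, W2=W2)    # gated attention
print(f.shape)  # (5,)
```

The gate only rescales each component of \(\operatorname{act}(\mathbf{X}\mathbf{W}_1)\) by a factor in \(\left] 0, 1 \right[\); for example, a zero \(\mathbf{W}_2\) gates every component by exactly \(0.5\).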
__init__(in_dim=None, att_dim=128, act='tanh', gated=False)
Parameters:
- in_dim (int, default: None) – Input dimension. If not provided, it will be lazily initialized.
- att_dim (int, default: 128) – Attention dimension.
- act (str, default: 'tanh') – Activation function for attention. Possible values: 'tanh', 'relu', 'gelu'.
- gated (bool, default: False) – If True, use gated attention.
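Lazy initialization means the weight shapes can only be fixed once the first input reveals `in_dim`. The sketch below illustrates the idea with a hypothetical NumPy class `LazyAttentionWeights`; it is not torchmil's mechanism, and the scaled-Gaussian initialization is an assumption.

```python
import math

import numpy as np


class LazyAttentionWeights:
    """Hypothetical holder for W1, W2 and w that creates them on first use.

    Mirrors the in_dim, att_dim and gated arguments; W2 exists only when gated=True.
    """

    def __init__(self, in_dim=None, att_dim=128, gated=False, seed=0):
        self.in_dim, self.att_dim, self.gated = in_dim, att_dim, gated
        self.rng = np.random.default_rng(seed)
        self.W1 = self.W2 = self.w = None
        if in_dim is not None:
            self._build(in_dim)  # eager initialization when in_dim is known

    def _build(self, in_dim):
        scale = 1.0 / math.sqrt(in_dim)  # assumed initialization scale
        self.in_dim = in_dim
        self.W1 = self.rng.standard_normal((in_dim, self.att_dim)) * scale
        if self.gated:
            self.W2 = self.rng.standard_normal((in_dim, self.att_dim)) * scale
        self.w = self.rng.standard_normal(self.att_dim) / math.sqrt(self.att_dim)

    def ensure_built(self, X):
        """Infer in_dim from the last axis of X the first time the weights are needed."""
        if self.W1 is None:
            self._build(X.shape[-1])
        elif X.shape[-1] != self.in_dim:
            raise ValueError(f"expected in_dim={self.in_dim}, got {X.shape[-1]}")


weights = LazyAttentionWeights(att_dim=16, gated=True)
print(weights.W1)  # None: in_dim is not known yet
weights.ensure_built(np.zeros((2, 10, 32)))  # (batch_size, bag_size, in_dim)
print(weights.W1.shape, weights.W2.shape, weights.w.shape)  # (32, 16) (32, 16) (16,)
```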
forward(X, mask=None, return_att=False)
Forward pass.
Parameters:
- X (Tensor) – Bag features of shape (batch_size, bag_size, in_dim).
- mask (Tensor, default: None) – Mask of shape (batch_size, bag_size), used to ignore padded instances when bags of different sizes are batched together.
- return_att (bool, default: False) – If True, returns the attention values (before normalization) in addition to z.
Returns:
- z (Tensor) – Bag representation of shape (batch_size, in_dim).
- f (Tensor) – Only returned when return_att=True. Attention values (before normalization) of shape (batch_size, bag_size).
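A batched version of the pooling step shows how a mask keeps padded instances out of the bag representation. This NumPy sketch assumes the attention values `f` were already computed by the MLP and that the mask uses 1 for real instances and 0 for padding; `masked_softmax` and `pool_forward` are hypothetical names, not torchmil's API. It also assumes every bag has at least one unmasked instance (a fully masked bag would produce NaNs).

```python
import numpy as np


def masked_softmax(f, mask=None):
    """Softmax over the bag axis; instances with mask == 0 receive zero weight."""
    if mask is not None:
        # Assumed convention: 1 (True) = real instance, 0 (False) = padding.
        f = np.where(mask.astype(bool), f, -np.inf)
    e = np.exp(f - f.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def pool_forward(X, f, mask=None, return_att=False):
    """Hypothetical batched pooling step.

    X: (batch_size, bag_size, in_dim) bag features.
    f: (batch_size, bag_size) attention values, before normalization.
    """
    s = masked_softmax(f, mask)        # normalized scores, (batch_size, bag_size)
    z = np.einsum("bn,bnd->bd", s, X)  # bag representations, (batch_size, in_dim)
    return (z, f) if return_att else z


# Two bags padded to bag_size=3; the second bag has only two real instances.
X = np.array([
    [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]],
    [[4.0, 0.0], [0.0, 4.0], [99.0, 99.0]],  # last row is padding
])
f = np.zeros((2, 3))
mask = np.array([[1, 1, 1], [1, 1, 0]])
z, att = pool_forward(X, f, mask, return_att=True)
print(z.shape, att.shape)  # (2, 2) (2, 3)
```

As in the documented return value, the sketch hands back `f` before normalization; applying `masked_softmax(att, mask)` recovers the scores \(s_n\) actually used for pooling, which is what one would inspect for instance-level interpretation.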