Sm Attention Pool
torchmil.nn.attention.SmAttentionPool
Bases: Module
Attention-based pooling with the Sm operator, as proposed in Sm: enhanced localization in Multiple Instance Learning for medical imaging classification.
Given an input bag \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times \texttt{in_dim}}\), this module aggregates the instance features into a bag representation similarly to ABMIL, but it incorporates the \(\texttt{Sm}\) operator to promote smoothness in the attention values; see Sm for more details.
Formally, if \(\texttt{sm_where == "early"}\), the module first applies the \(\texttt{Sm}\) operator to the input bag features \(\mathbf{X}\),
\[ \mathbf{X} \leftarrow \texttt{Sm}(\mathbf{X}) \in \mathbb{R}^{N \times \texttt{in_dim}}. \]
Then, it computes the attention values \(\mathbf{f} \in \mathbb{R}^{N}\) as,
\[ \mathbf{f} = \operatorname{act}\left( \mathbf{X}\mathbf{W} \right)\mathbf{w}, \]
where \(\mathbf{W} \in \mathbb{R}^{\texttt{in_dim} \times \texttt{att_dim}}\) and \(\mathbf{w} \in \mathbb{R}^{\texttt{att_dim} \times 1}\) are learnable parameters, and \(\operatorname{act}\) is the activation function. Finally, it computes the bag representation \(\mathbf{z} \in \mathbb{R}^{\texttt{in_dim}}\) as,
\[ \mathbf{z} = \mathbf{X}^\top \operatorname{Softmax}(\mathbf{f}) = \sum_{n=1}^{N} s_n \mathbf{x}_n, \]
where \(s_n\) is the normalized attention score for the \(n\)-th instance.
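The early-mode computation (smooth \(\mathbf{X}\) with \(\texttt{Sm}\), compute \(\mathbf{f} = \operatorname{act}(\mathbf{X}\mathbf{W})\mathbf{w}\), normalize with a softmax, and take the weighted sum of instances) can be sketched in NumPy for a single bag. This is a minimal sketch, not torchmil's implementation: it assumes that \(\texttt{Sm}\) is the fixed-point iteration \(\mathbf{G} \leftarrow (1-\alpha)\mathbf{X} + \alpha \mathbf{A}\mathbf{G}\) on a symmetrically normalized adjacency matrix \(\mathbf{A}\) (the Sm page gives the authoritative definition), it uses tanh as \(\operatorname{act}\), and the helper names normalize_adj, sm_approx and sm_attention_pool_early are hypothetical.

```python
import numpy as np

def normalize_adj(A):
    """Symmetric normalization D^{-1/2} A D^{-1/2} (assumed convention)."""
    deg = A.sum(axis=1)
    inv_sqrt = np.zeros_like(deg)
    inv_sqrt[deg > 0] = 1.0 / np.sqrt(deg[deg > 0])
    return inv_sqrt[:, None] * A * inv_sqrt[None, :]

def sm_approx(X, A, alpha=0.5, steps=10):
    """Hypothetical Sm sketch: iterate G <- (1 - alpha) X + alpha A G, starting at G = X."""
    G = X.copy()
    for _ in range(steps):
        G = (1.0 - alpha) * X + alpha * (A @ G)
    return G

def sm_attention_pool_early(X, A, W, w, alpha=0.5, steps=10):
    """Single-bag sketch of sm_where='early' with act=tanh."""
    X = sm_approx(X, A, alpha, steps)   # X <- Sm(X), shape (N, in_dim)
    f = np.tanh(X @ W) @ w              # f = act(XW) w, shape (N,)
    s = np.exp(f - f.max())             # subtract the max for numerical stability
    s = s / s.sum()                     # s = Softmax(f)
    z = X.T @ s                         # z = sum_n s_n x_n, shape (in_dim,)
    return z, f, s

rng = np.random.default_rng(0)
N, in_dim, att_dim = 5, 8, 4
X = rng.normal(size=(N, in_dim))
A = np.zeros((N, N))
for i in range(N - 1):                  # toy chain graph 0-1-2-3-4
    A[i, i + 1] = A[i + 1, i] = 1.0
A = normalize_adj(A)
W = rng.normal(size=(in_dim, att_dim))
w = rng.normal(size=att_dim)            # w stored as a vector instead of (att_dim, 1)
z, f, s = sm_attention_pool_early(X, A, W, w)
print(z.shape, f.shape)                 # (8,) (5,)
```

Note that the bag representation is a convex combination of the smoothed instances, because the softmax weights are non-negative and sum to one.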
Spectral normalization: To ensure that the Dirichlet energy decreases after applying the \(\texttt{Sm}\) operator, the linear layers can optionally be normalized using spectral normalization.
The original paper reports that this improves performance.
If spectral_norm=True, the linear layers that come after the \(\texttt{Sm}\) operator are normalized using spectral normalization.
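Spectral normalization divides a weight matrix by its largest singular value \(\sigma_{\max}(\mathbf{W})\), so the normalized matrix has spectral norm 1. This is what keeps the smoothing from being undone: for a linear map \(\mathbf{X} \mapsto \mathbf{X}\mathbf{W}\), the Dirichlet energy \(\operatorname{tr}\left((\mathbf{X}\mathbf{W})^\top \mathbf{L} \mathbf{X}\mathbf{W}\right)\) is at most \(\|\mathbf{W}\|_2^2\) times \(\operatorname{tr}(\mathbf{X}^\top \mathbf{L} \mathbf{X})\). The following NumPy sketch estimates \(\sigma_{\max}\) by power iteration and checks that bound on a toy graph; it is not torchmil's implementation, and spectral_normalize and dirichlet_energy are hypothetical helper names (PyTorch itself provides torch.nn.utils.parametrizations.spectral_norm).

```python
import numpy as np

def spectral_normalize(W, n_iter=200, seed=0):
    """Return W / sigma_max(W), with sigma_max estimated by power iteration."""
    rng = np.random.default_rng(seed)
    u = rng.normal(size=W.shape[0])
    for _ in range(n_iter):
        v = W.T @ u
        v /= np.linalg.norm(v)          # right singular vector estimate
        u = W @ v
        u /= np.linalg.norm(u)          # left singular vector estimate
    sigma = u @ W @ v                   # estimate of sigma_max(W)
    return W / sigma, sigma

def dirichlet_energy(X, A):
    """tr(X^T L X) with the combinatorial Laplacian L = D - A."""
    L = np.diag(A.sum(axis=1)) - A
    return np.trace(X.T @ L @ X)

rng = np.random.default_rng(1)
N = 6
A = np.zeros((N, N))
for i in range(N - 1):                  # toy chain graph
    A[i, i + 1] = A[i + 1, i] = 1.0
X = rng.normal(size=(N, 8))
W = rng.normal(size=(8, 4))
W_sn, sigma = spectral_normalize(W)
print(np.linalg.norm(W_sn, 2))          # approximately 1.0
print(dirichlet_energy(X @ W_sn, A), dirichlet_energy(X, A))
```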
__init__(in_dim, att_dim=128, act='gelu', sm_mode='approx', sm_alpha='trainable', sm_steps=10, sm_where='early', spectral_norm=False)
Parameters:
- in_dim (int) – Input dimension.
- att_dim (int, default: 128) – Attention dimension.
- act (str, default: 'gelu') – Activation function for attention. Possible values: 'tanh', 'relu', 'gelu'.
- sm_mode (str, default: 'approx') – Mode for the Sm operator. Possible values: 'approx', 'exact'.
- sm_alpha (Union[float, str], default: 'trainable') – Alpha value for the Sm operator. If 'trainable', alpha is trainable.
- sm_steps (int, default: 10) – Number of steps for the Sm operator.
- sm_where (str, default: 'early') – Where to apply the Sm operator. Possible values: 'early', 'mid', 'late', 'none'.
- spectral_norm (bool, default: False) – If True, apply spectral normalization to linear layers. If sm_where is 'none', all linear layers are normalized.
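The sm_mode, sm_alpha and sm_steps parameters can be illustrated by comparing an iterative approximation with a closed-form solution. This sketch assumes, for illustration only, that the exact Sm operator is \(\mathbf{G} = (1-\alpha)(\mathbf{I} - \alpha\mathbf{A})^{-1}\mathbf{X}\) for a symmetrically normalized adjacency \(\mathbf{A}\), and that 'approx' runs sm_steps iterations of \(\mathbf{G} \leftarrow (1-\alpha)\mathbf{X} + \alpha\mathbf{A}\mathbf{G}\); these formulas and the helper names sm_exact and sm_approx are assumptions, so check the Sm page before relying on them. Under these assumptions, with \(\alpha < 1\) the 2-norm of the error shrinks at least by a factor \(\alpha\) per step, which is why a small number of steps is usually enough.

```python
import numpy as np

def sm_exact(X, A, alpha):
    """Assumed closed form: G = (1 - alpha) (I - alpha A)^{-1} X."""
    N = A.shape[0]
    return (1.0 - alpha) * np.linalg.solve(np.eye(N) - alpha * A, X)

def sm_approx(X, A, alpha, steps):
    """Assumed 'approx' mode: `steps` iterations of G <- (1 - alpha) X + alpha A G."""
    G = X.copy()
    for _ in range(steps):
        G = (1.0 - alpha) * X + alpha * (A @ G)
    return G

rng = np.random.default_rng(0)
N = 6
A = np.zeros((N, N))
for i in range(N - 1):                  # toy chain graph, every node has degree >= 1
    A[i, i + 1] = A[i + 1, i] = 1.0
deg = A.sum(axis=1)
A = A / np.sqrt(np.outer(deg, deg))     # D^{-1/2} A D^{-1/2}
X = rng.normal(size=(N, 3))

alpha = 0.5                             # a fixed value, as when sm_alpha is a float
G_exact = sm_exact(X, A, alpha)
errors = {}
for steps in (1, 10, 50):               # different sm_steps values
    errors[steps] = np.abs(sm_approx(X, A, alpha, steps) - G_exact).max()
    print(steps, errors[steps])
```

A trainable alpha would simply make \(\alpha\) a learned parameter in the same formulas; the sketch keeps it fixed.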
forward(X, adj, mask=None, return_att=False)
Forward pass.
Parameters:
- X (Tensor) – Bag features of shape (batch_size, bag_size, in_dim).
- adj (Tensor) – Adjacency matrix of shape (batch_size, bag_size, bag_size).
- mask (Tensor, default: None) – Mask of shape (batch_size, bag_size).
- return_att (bool, default: False) – If True, returns the attention values (before normalization) in addition to z.
Returns:
- z (Tensor) – Bag representation of shape (batch_size, in_dim).
- f (Tensor) – Only returned when return_att=True. Attention values (before normalization) of shape (batch_size, bag_size).
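The forward shapes and the role of mask can be checked with a batched NumPy sketch of the early-Sm variant. It is not torchmil's code: it assumes that padded instances are marked with 0 in mask and receive zero attention weight through a masked softmax, that Sm follows the fixed-point form \(\mathbf{G} \leftarrow (1-\alpha)\mathbf{X} + \alpha\mathbf{A}\mathbf{G}\), and it uses tanh as \(\operatorname{act}\); forward_sketch, masked_softmax and chain_adj are hypothetical names. As described for return_att, the returned f holds the attention values before normalization.

```python
import numpy as np

def masked_softmax(f, mask):
    """Softmax over the bag axis; padded instances (mask == 0) get zero weight.
    Each bag must contain at least one valid instance."""
    f = np.where(mask > 0, f, -np.inf)
    f = f - f.max(axis=1, keepdims=True)
    e = np.exp(f)                               # exp(-inf) = 0 for padded positions
    return e / e.sum(axis=1, keepdims=True)

def chain_adj(n_valid, N):
    """Hypothetical toy graph: normalized chain over the first n_valid instances."""
    A = np.zeros((N, N))
    for i in range(n_valid - 1):
        A[i, i + 1] = A[i + 1, i] = 1.0
    deg = A.sum(axis=1)
    inv = np.zeros(N)
    inv[deg > 0] = 1.0 / np.sqrt(deg[deg > 0])
    return inv[:, None] * A * inv[None, :]

def forward_sketch(X, adj, W, w, mask=None, return_att=False, alpha=0.5, steps=10):
    """Batched early-Sm pooling sketch; X: (B, N, D), adj: (B, N, N), mask: (B, N)."""
    B, N, _ = X.shape
    if mask is None:
        mask = np.ones((B, N))
    G = X.copy()
    for _ in range(steps):                      # X <- Sm(X), assumed fixed-point form
        G = (1.0 - alpha) * X + alpha * (adj @ G)
    f = np.tanh(G @ W) @ w                      # (B, N) attention values before normalization
    s = masked_softmax(f, mask)                 # (B, N) normalized scores
    z = np.einsum("bn,bnd->bd", s, G)           # (B, D) bag representation
    return (z, f) if return_att else z

rng = np.random.default_rng(0)
B, N, D, att_dim = 2, 4, 8, 16
X = rng.normal(size=(B, N, D))
adj = np.stack([chain_adj(4, N), chain_adj(3, N)])  # bag 1 has one padded instance
mask = np.array([[1, 1, 1, 1], [1, 1, 1, 0]], dtype=float)
W = rng.normal(size=(D, att_dim))
w = rng.normal(size=att_dim)
z, f = forward_sketch(X, adj, W, w, mask=mask, return_att=True)
print(z.shape, f.shape)                         # (2, 8) (2, 4)
```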