Sm Attention Pool

torchmil.nn.attention.SmAttentionPool

Bases: Module

Attention-based pooling with the Sm operator, as proposed in Sm: enhanced localization in Multiple Instance Learning for medical imaging classification.

Given an input bag \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times \texttt{in_dim}}\), this module aggregates the instance features into a bag representation in the same way as ABMIL, but incorporates the \(\texttt{Sm}\) operator to promote smoothness in the attention values. See Sm for more details.

Formally, if \(\texttt{sm_where == "early"}\), the module first applies the \(\texttt{Sm}\) operator to the input bag features \(\mathbf{X}\),

\[\begin{gather} \mathbf{X} = \operatorname{Sm}(\mathbf{X}). \end{gather}\]

Then, it computes the attention values \(\mathbf{f} \in \mathbb{R}^{N}\) as,

\[\begin{gather} \mathbf{f} = \operatorname{act}( \operatorname{Sm}(\mathbf{X} \mathbf{W}^\top)) \mathbf{w}, \quad \text{if } \texttt{sm_where == "mid"},\\ \mathbf{f} = \operatorname{Sm}(\operatorname{act}(\mathbf{X} \mathbf{W}^\top) \mathbf{w}), \quad \text{if } \texttt{sm_where == "late"}, \end{gather}\]

where \(\mathbf{W} \in \mathbb{R}^{\texttt{att_dim} \times \texttt{in_dim}}\) and \(\mathbf{w} \in \mathbb{R}^{\texttt{att_dim} \times 1}\) are learnable parameters, and \(\operatorname{act}\) is the activation function. Then, it computes the bag representation \(\mathbf{z} \in \mathbb{R}^{\texttt{in_dim}}\) as,

\[\begin{gather} \mathbf{z} = \mathbf{X}^\top \operatorname{Softmax}(\mathbf{f}) = \sum_{n=1}^N s_n \mathbf{x}_n, \end{gather}\]

where \(s_n\) is the normalized attention score for the \(n\)-th instance.
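The equations above can be sketched in plain PyTorch as follows. This is a minimal illustration under stated assumptions, not the library's code. `smooth_stand_in` is a hypothetical neighbour-averaging substitute for \(\texttt{Sm}\), whose definition is not given on this page. The `"early"` and `"none"` branches are assumed to use the ABMIL-style score \(\operatorname{act}(\mathbf{X}\mathbf{W}^\top)\mathbf{w}\). The mask convention (1 marks a real instance, 0 marks padding) is also an assumption.

```python
import torch


def smooth_stand_in(F, adj, alpha=0.5):
    """Hypothetical stand-in for Sm (NOT the library's operator):
    blend each row of F with the mean of its neighbours in adj."""
    deg = adj.sum(dim=-1, keepdim=True).clamp(min=1.0)  # (B, N, 1)
    return (1.0 - alpha) * F + alpha * (adj @ F) / deg


def sm_attention_pool_sketch(X, adj, W, w, sm_where="early",
                             act=torch.nn.functional.gelu, mask=None):
    """X: (B, N, in_dim), adj: (B, N, N), W: (att_dim, in_dim), w: (att_dim, 1)."""
    if sm_where == "early":
        # Following the reassignment X = Sm(X) in the text, the smoothed
        # features are also the ones pooled at the end.
        X = smooth_stand_in(X, adj)
    H = X @ W.T                                   # (B, N, att_dim)
    if sm_where == "mid":
        H = smooth_stand_in(H, adj)
    f = act(H) @ w                                # (B, N, 1)
    if sm_where == "late":
        f = smooth_stand_in(f, adj)
    f = f.squeeze(-1)                             # (B, N), unnormalized attention

    # Assumed mask convention: 0 marks padding, which gets zero weight.
    scores = f if mask is None else f.masked_fill(mask == 0, float("-inf"))
    s = torch.softmax(scores, dim=1)              # (B, N), rows sum to 1
    z = torch.einsum("bn,bnd->bd", s, X)          # z = X^T softmax(f)
    return z, f
```

The three `if` branches make the only difference between the modes visible: which tensor the smoothing step is applied to before the softmax-weighted sum.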

Spectral normalization: To ensure that the Dirichlet energy still decreases after the \(\texttt{Sm}\) operator is applied, the linear layers can optionally use spectral normalization. The original paper reports better performance with this option. If spectral_norm=True, spectral normalization is applied to the linear layers that follow the \(\texttt{Sm}\) operator.
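The link between spectral normalization and Dirichlet energy can be checked numerically. The sketch below assumes the energy is \(\operatorname{tr}(\mathbf{F}^\top \mathbf{L} \mathbf{F})\) with the combinatorial Laplacian \(\mathbf{L} = \mathbf{D} - \mathbf{A}\); the paper may use a normalized variant. It rescales a weight matrix by its exact largest singular value. Both helper names are illustrative. In practice, PyTorch's `torch.nn.utils.parametrizations.spectral_norm` estimates that singular value by power iteration, and whether this module relies on it is an assumption here.

```python
import torch


def dirichlet_energy(F, adj):
    """tr(F^T L F) with L = D - A, for a single bag: F (N, d), adj (N, N)."""
    lap = torch.diag(adj.sum(dim=-1)) - adj
    return torch.trace(F.T @ lap @ F)


def scale_to_unit_spectral_norm(W):
    """Divide W by its largest singular value (computed exactly here)."""
    sigma = torch.linalg.matrix_norm(W, ord=2)
    return W / sigma


torch.manual_seed(0)
N, in_dim, att_dim = 6, 8, 4
F = torch.randn(N, in_dim, dtype=torch.float64)

# Chain graph: instance i is connected to instance i + 1.
adj = torch.zeros(N, N, dtype=torch.float64)
idx = torch.arange(N - 1)
adj[idx, idx + 1] = 1.0
adj[idx + 1, idx] = 1.0

W = 3.0 * torch.randn(att_dim, in_dim, dtype=torch.float64)
W_sn = scale_to_unit_spectral_norm(W)

energy_in = dirichlet_energy(F, adj)
energy_out = dirichlet_energy(F @ W_sn.T, adj)  # energy after the linear map
print(float(energy_in), float(energy_out))
```

Since \(\operatorname{tr}(\mathbf{W}\mathbf{F}^\top \mathbf{L} \mathbf{F}\mathbf{W}^\top) \le \sigma_{\max}(\mathbf{W})^2 \operatorname{tr}(\mathbf{F}^\top \mathbf{L} \mathbf{F})\), a layer with spectral norm at most 1 cannot increase the energy, so the smoothing gained from \(\texttt{Sm}\) is not undone by the layers after it.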

__init__(in_dim, att_dim=128, act='gelu', sm_mode='approx', sm_alpha='trainable', sm_steps=10, sm_where='early', spectral_norm=False)

Parameters:

  • in_dim (int) –

    Input dimension.

  • att_dim (int, default: 128 ) –

    Attention dimension.

  • act (str, default: 'gelu' ) –

    Activation function for attention. Possible values: 'tanh', 'relu', 'gelu'.

  • sm_mode (str, default: 'approx' ) –

    Mode for the Sm operator. Possible values: 'approx', 'exact'.

  • sm_alpha (Union[float, str], default: 'trainable' ) –

    Alpha value for the Sm operator. If 'trainable', alpha is trainable.

  • sm_steps (int, default: 10 ) –

    Number of steps for the Sm operator.

  • sm_where (str, default: 'early' ) –

    Where to apply the Sm operator. Possible values: 'early', 'mid', 'late', 'none'.

  • spectral_norm (bool, default: False ) –

    If True, apply spectral normalization to linear layers. If sm_where is 'none', all linear layers are normalized.

forward(X, adj, mask=None, return_att=False)

Forward pass.

Parameters:

  • X (Tensor) –

    Bag features of shape (batch_size, bag_size, in_dim).

  • adj (Tensor) –

    Adjacency matrix of shape (batch_size, bag_size, bag_size).

  • mask (Tensor, default: None ) –

    Mask of shape (batch_size, bag_size).

  • return_att (bool, default: False ) –

    If True, returns attention values (before normalization) in addition to z.

Returns:

  • z ( Tensor ) –

    Bag representation of shape (batch_size, in_dim).

  • f ( Tensor ) –

    Only returned when return_att=True. Attention values (before normalization) of shape (batch_size, bag_size).
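A possible end-to-end call, using the constructor and `forward` signatures documented above, might look like the following. The helpers `chain_adjacency` and `padding_mask` are hypothetical and only build inputs of the documented shapes. This sketch also assumes that a dense, unnormalized adjacency is acceptable and that a mask value of 1 marks a real instance. `run_example` requires torchmil to be installed.

```python
import torch


def chain_adjacency(batch_size, bag_size):
    """Dense adjacency of shape (batch_size, bag_size, bag_size)
    linking each instance to the next one in the bag."""
    adj = torch.zeros(batch_size, bag_size, bag_size)
    idx = torch.arange(bag_size - 1)
    adj[:, idx, idx + 1] = 1.0
    adj[:, idx + 1, idx] = 1.0
    return adj


def padding_mask(bag_lengths, bag_size):
    """Mask of shape (batch_size, bag_size).
    Assumed convention: 1 = real instance, 0 = padding."""
    positions = torch.arange(bag_size).unsqueeze(0)        # (1, bag_size)
    lengths = torch.as_tensor(bag_lengths).unsqueeze(1)    # (batch_size, 1)
    return (positions < lengths).float()


def run_example():
    # Import path as given at the top of this page.
    from torchmil.nn.attention import SmAttentionPool

    batch_size, bag_size, in_dim = 2, 6, 32
    X = torch.randn(batch_size, bag_size, in_dim)
    adj = chain_adjacency(batch_size, bag_size)
    mask = padding_mask([6, 4], bag_size)  # second bag has 2 padded slots

    pool = SmAttentionPool(in_dim=in_dim, att_dim=128,
                           sm_where="late", spectral_norm=True)
    z, f = pool(X, adj, mask=mask, return_att=True)
    # Documented shapes: z is (batch_size, in_dim), f is (batch_size, bag_size).
    return z.shape, f.shape
```

With `return_att=False` (the default), only `z` is returned, so the call should not be unpacked into two values.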