Multihead Cross-Attention
torchmil.nn.attention.MultiheadCrossAttention
Bases: Module
The Multihead Cross-Attention module, as described in "Attention Is All You Need" (Vaswani et al., 2017).
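Given input bags \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times \texttt{in_dim}}\) and \(\mathbf{Y} = \left[ \mathbf{y}_1, \ldots, \mathbf{y}_M \right]^\top \in \mathbb{R}^{M \times \texttt{in_dim}}\), this module computes:
\[
\mathbf{Q} = \mathbf{X}\mathbf{W}_Q, \quad \mathbf{K} = \mathbf{Y}\mathbf{W}_K, \quad \mathbf{V} = \mathbf{Y}\mathbf{W}_V,
\]
\[
\mathbf{Z} = \operatorname{Softmax}\left( \frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{d}} \right) \mathbf{V} \in \mathbb{R}^{N \times \texttt{att_dim}},
\]
where \(d = \texttt{att_dim}\), the softmax is applied row-wise, and \(\mathbf{W}_Q, \mathbf{W}_K, \mathbf{W}_V \in \mathbb{R}^{\texttt{in_dim} \times \texttt{att_dim}}\) are learnable weight matrices. Queries are computed from \(\mathbf{X}\) and keys and values from \(\mathbf{Y}\), so each row of \(\mathbf{Z}\) is a weighted average of the rows of \(\mathbf{V}\).
If \(\texttt{out_dim} \neq \texttt{att_dim}\), \(\mathbf{Z}\) is passed through a linear layer with output dimension \(\texttt{out_dim}\).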
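To make the computation concrete, here is a minimal single-head sketch in plain Python, using lists of rows instead of tensors. The helper names (`matmul`, `transpose`, `softmax_rows`, `cross_attention`) and the bias-free output matrix `W_O` are illustrative assumptions, not part of torchmil; the actual module uses learned linear layers and is multi-head.

```python
import math
from typing import List, Optional

Matrix = List[List[float]]  # a matrix stored as a list of rows


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """(n x k) @ (k x m) -> (n x m)."""
    return [
        [sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def transpose(a: Matrix) -> Matrix:
    return [list(col) for col in zip(*a)]


def softmax_rows(a: Matrix) -> Matrix:
    """Row-wise softmax; subtracting the row max keeps exp() from overflowing."""
    out = []
    for row in a:
        m = max(row)
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def cross_attention(X: Matrix, Y: Matrix, W_Q: Matrix, W_K: Matrix, W_V: Matrix,
                    W_O: Optional[Matrix] = None) -> Matrix:
    """Single-head cross-attention: queries from X (N rows), keys/values from Y (M rows)."""
    Q = matmul(X, W_Q)  # N x att_dim
    K = matmul(Y, W_K)  # M x att_dim
    V = matmul(Y, W_V)  # M x att_dim
    d = len(W_Q[0])     # d = att_dim
    scores = [[s / math.sqrt(d) for s in row] for row in matmul(Q, transpose(K))]  # N x M
    A = softmax_rows(scores)  # each row of attention weights sums to 1
    out = matmul(A, V)        # N x att_dim
    # Hypothetical bias-free stand-in for the linear layer that maps to out_dim.
    return matmul(out, W_O) if W_O is not None else out


X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # N = 3 instances, in_dim = 2
Y = [[0.5, -0.5], [2.0, 1.0]]             # M = 2 instances, in_dim = 2
eye2 = [[1.0, 0.0], [0.0, 1.0]]           # identity weights, so att_dim = 2
Z = cross_attention(X, Y, eye2, eye2, eye2)
print(len(Z), len(Z[0]))  # 3 2
```

Because every row of attention weights sums to 1, each output row is a convex combination of the value rows. With a single key/value row, every query returns that row unchanged.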
__init__(in_dim, out_dim=None, att_dim=512, n_heads=4, dropout=0.0, learn_weights=True)
Parameters:
- in_dim (int) – Input dimension.
- out_dim (int, default: None) – Output dimension. If None, out_dim = in_dim.
- att_dim (int, default: 512) – Attention dimension; must be divisible by n_heads.
- n_heads (int, default: 4) – Number of heads.
- dropout (float, default: 0.0) – Dropout rate.
- learn_weights (bool, default: True) – If True, learn the weights for query, key, and value. If False, q, k, and v are the same as the input, and therefore in_dim must be divisible by n_heads.
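The defaults and divisibility rules listed above can be captured in a small validation sketch. `CrossAttentionConfig` is a hypothetical helper written for illustration: it mirrors the documented rules but does not claim that torchmil performs these exact checks or raises these errors.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class CrossAttentionConfig:
    """Hypothetical config mirroring the documented __init__ arguments (not torchmil API)."""
    in_dim: int
    out_dim: Optional[int] = None
    att_dim: int = 512
    n_heads: int = 4
    dropout: float = 0.0
    learn_weights: bool = True

    def __post_init__(self) -> None:
        if self.out_dim is None:  # documented default: out_dim = in_dim
            self.out_dim = self.in_dim
        if self.att_dim % self.n_heads != 0:  # documented: att_dim divisible by n_heads
            raise ValueError("att_dim must be divisible by n_heads")
        # With learn_weights=False, q, k, v are the raw inputs, so in_dim is what gets split.
        if not self.learn_weights and self.in_dim % self.n_heads != 0:
            raise ValueError("in_dim must be divisible by n_heads when learn_weights=False")

    @property
    def head_dim(self) -> int:
        # Assumption: heads split att_dim when weights are learned, in_dim otherwise.
        width = self.att_dim if self.learn_weights else self.in_dim
        return width // self.n_heads

    @property
    def needs_out_proj(self) -> bool:
        # Documented: a linear layer maps to out_dim when out_dim != att_dim.
        return self.out_dim != self.att_dim


cfg = CrossAttentionConfig(in_dim=256)
print(cfg.out_dim, cfg.head_dim, cfg.needs_out_proj)  # 256 128 True
```

With the defaults, a 256-dimensional input keeps `out_dim = 256`, which differs from `att_dim = 512`, so the final projection is needed.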
forward(x, y, mask=None)
Forward pass.
Parameters:
- x (Tensor) – Input tensor of shape (batch_size, seq_len_x, in_dim), from which the queries are computed.
- y (Tensor) – Input tensor of shape (batch_size, seq_len_y, in_dim), from which the keys and values are computed.
- mask (Tensor, default: None) – Mask tensor of shape (batch_size, seq_len_x).
Returns:
- y (Tensor) – Output tensor of shape (batch_size, seq_len_x, out_dim).
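A minimal plain-Python sketch of the batched multi-head forward pass, following the documented shapes: `x` is (batch_size, seq_len_x, in_dim), `y` is (batch_size, seq_len_y, in_dim), and the result is (batch_size, seq_len_x, att_dim), or out_dim after the optional projection. Several assumptions apply. Each head scales its scores by 1/sqrt(att_dim / n_heads), the per-head convention of the original Transformer, whereas the single-head description above uses d = att_dim. Dropout and `mask` are omitted because their exact semantics are not specified here. `w_o` is a bias-free stand-in for the out_dim linear layer. The names `attend` and `multihead_cross_attention` are hypothetical, not torchmil API.

```python
import math
from typing import List, Optional

Matrix = List[List[float]]  # list of rows
Batch = List[Matrix]        # batch_size x seq_len x features


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [
        [sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def attend(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """Scaled dot-product attention for one head, scaled by 1/sqrt(head width)."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        total = sum(w)
        out.append([sum(w[j] / total * v[j][c] for j in range(len(v)))
                    for c in range(len(v[0]))])
    return out


def multihead_cross_attention(x: Batch, y: Batch, w_q: Matrix, w_k: Matrix, w_v: Matrix,
                              n_heads: int, w_o: Optional[Matrix] = None) -> Batch:
    """(batch, seq_len_x, in_dim), (batch, seq_len_y, in_dim) -> (batch, seq_len_x, att_dim or out_dim)."""
    att_dim = len(w_q[0])
    if att_dim % n_heads != 0:
        raise ValueError("att_dim must be divisible by n_heads")
    head_dim = att_dim // n_heads
    result = []
    for xb, yb in zip(x, y):  # one pair of bags per batch element
        q, k, v = matmul(xb, w_q), matmul(yb, w_k), matmul(yb, w_v)
        heads = []
        for h in range(n_heads):
            cols = slice(h * head_dim, (h + 1) * head_dim)  # this head's feature slice
            heads.append(attend([r[cols] for r in q], [r[cols] for r in k], [r[cols] for r in v]))
        # Concatenate the head outputs feature-wise: seq_len_x x att_dim.
        z = [[val for head in heads for val in head[i]] for i in range(len(q))]
        result.append(matmul(z, w_o) if w_o is not None else z)
    return result


eye4 = [[1.0 if i == j else 0.0 for j in range(4)] for i in range(4)]
x = [[[1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 1.0, 0.0], [1.0, 1.0, 0.0, 0.0]]]  # (1, 3, 4)
y = [[[0.5, 0.5, 0.5, 0.5], [1.0, -1.0, 2.0, 0.0]]]                      # (1, 2, 4)
out = multihead_cross_attention(x, y, eye4, eye4, eye4, n_heads=2)
print(len(out), len(out[0]), len(out[0][0]))  # 1 3 4
```

The output keeps seq_len_x rows because every query row of `x` produces exactly one output row, regardless of seq_len_y.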