Nyström Attention
torchmil.nn.attention.NystromAttention
Bases: Module
Nyström attention, as described in the paper "Nyströmformer: A Nyström-Based Algorithm for Approximating Self-Attention".
Implementation based on the official code.
__init__(in_dim, out_dim=None, att_dim=512, n_heads=4, learn_weights=True, n_landmarks=256, pinv_iterations=6)
Parameters:
- in_dim (int) – Input dimension.
- out_dim (int, default: None) – Output dimension. If None, out_dim = in_dim.
- att_dim (int, default: 512) – Attention dimension. Must be divisible by n_heads.
- n_heads (int, default: 4) – Number of heads.
- learn_weights (bool, default: True) – If True, learn the weights for query, key, and value. If False, q, k, and v are the same as the input, and therefore in_dim must be divisible by n_heads.
- n_landmarks (int, default: 256) – Number of landmarks.
- pinv_iterations (int, default: 6) – Number of iterations used to approximate the Moore-Penrose pseudo-inverse.
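To give a feel for what pinv_iterations controls, the sketch below implements the iterative pseudo-inverse approximation described in the Nyströmformer paper, in plain Python on a small square matrix. The helper names (matmul, scaled_identity_minus, iterative_pinv) are hypothetical and are not part of torchmil; whether torchmil's internals match this iteration step for step is an assumption based on the statement that the implementation follows the official code.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def scaled_identity_minus(c, m):
    """Return c * I - m for a square matrix m."""
    n = len(m)
    return [[(c if i == j else 0.0) - m[i][j] for j in range(n)] for i in range(n)]


def iterative_pinv(a, iterations=6):
    """Approximate the pseudo-inverse of a square matrix a.

    Iteration from the Nystromformer paper:
        Z_0     = A^T / (||A||_1 * ||A||_inf)
        Z_{k+1} = 0.25 * Z_k (13 I - A Z_k (15 I - A Z_k (7 I - A Z_k)))
    """
    n = len(a)
    max_row_sum = max(sum(abs(v) for v in row) for row in a)
    max_col_sum = max(sum(abs(a[i][j]) for i in range(n)) for j in range(n))
    scale = max_row_sum * max_col_sum
    # Initial guess: scaled transpose of a.
    z = [[a[i][j] / scale for i in range(n)] for j in range(n)]
    for _ in range(iterations):
        az = matmul(a, z)
        inner = scaled_identity_minus(7.0, az)
        inner = scaled_identity_minus(15.0, matmul(az, inner))
        inner = scaled_identity_minus(13.0, matmul(az, inner))
        z = [[0.25 * v for v in row] for row in matmul(z, inner)]
    return z


# A small row-stochastic matrix, similar in spirit to a softmax kernel.
a = [[0.6, 0.4], [0.3, 0.7]]
z = iterative_pinv(a, iterations=6)
product = matmul(a, z)  # should be close to the 2x2 identity
```

More iterations bring the product of the matrix and its approximate pseudo-inverse closer to the identity, at the cost of extra matrix multiplications per forward pass.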
forward(x, mask=None, return_att=False)
Forward pass.
Parameters:
- x (Tensor) – Input tensor of shape (batch_size, seq_len, in_dim).
- mask (Tensor, default: None) – Mask tensor of shape (batch_size, seq_len).
- return_att (bool, default: False) – Whether to return attention weights.
Returns:
- y (Tensor) – Output tensor of shape (batch_size, seq_len, att_dim).
- att (Tensor) – Only returned when return_att=True. Attention weights of shape (batch_size, n_heads, seq_len, seq_len).
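The documented constraints and shapes can be written down as a small checker, which may be handy in unit tests or when sizing memory before calling forward. This is a minimal sketch in plain Python: expected_nystrom_shapes is a hypothetical helper, not a torchmil function, and it only restates the shapes and divisibility rules listed above.

```python
def expected_nystrom_shapes(batch_size, seq_len, in_dim,
                            att_dim=512, n_heads=4, learn_weights=True):
    """Return the documented tensor shapes for a given configuration.

    Hypothetical helper: it mirrors the documentation above and does not
    call torchmil.
    """
    # att_dim must be divisible by n_heads.
    if att_dim % n_heads != 0:
        raise ValueError(f"att_dim={att_dim} is not divisible by n_heads={n_heads}")
    # Without learned projections, q, k, and v are the input itself,
    # so in_dim must also be divisible by n_heads.
    if not learn_weights and in_dim % n_heads != 0:
        raise ValueError(f"in_dim={in_dim} is not divisible by n_heads={n_heads}")
    return {
        "x": (batch_size, seq_len, in_dim),
        "mask": (batch_size, seq_len),
        "y": (batch_size, seq_len, att_dim),
        "att": (batch_size, n_heads, seq_len, seq_len),
    }


shapes = expected_nystrom_shapes(batch_size=2, seq_len=1000, in_dim=384)
# Number of attention weights materialised when return_att=True.
att_elements = 1
for size in shapes["att"]:
    att_elements *= size
```

Because att has shape (batch_size, n_heads, seq_len, seq_len), its size grows quadratically with seq_len, so requesting it with return_att=True on long sequences can be expensive.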