iRPE Multihead Self-Attention
torchmil.nn.attention.iRPEMultiheadSelfAttention
Bases: Module
Multihead Self-Attention with image Relative Position Encoding (iRPE), as described in "Rethinking and Improving Relative Position Encoding for Vision Transformer".
The iRPE implementation is based on the official codebase.
__init__(in_dim, out_dim=None, att_dim=512, n_heads=4, dropout=0.0, learn_weights=True, rpe_ratio=1.9, rpe_method='product', rpe_mode='contextual', rpe_shared_head=True, rpe_skip=1, rpe_on='k')
Parameters:
- in_dim (int) – Input dimension.
- att_dim (int, default: 512) – Attention dimension. Must be divisible by n_heads.
- out_dim (int, default: None) – Output dimension. If None, out_dim = in_dim.
- n_heads (int, default: 4) – Number of heads.
- dropout (float, default: 0.0) – Dropout rate.
- learn_weights (bool, default: True) – If True, learn the weights for query, key, and value. If False, q, k, and v are the same as the input, and therefore in_dim must be divisible by n_heads.
- rpe_ratio (float, default: 1.9) – Relative position encoding ratio.
- rpe_method (str, default: 'product') – Relative position encoding method. Possible values: ['euc', 'quant', 'cross', 'product'].
- rpe_mode (str, default: 'contextual') – Relative position encoding mode. Possible values: [None, 'bias', 'contextual'].
- rpe_shared_head (bool, default: True) – Whether to share weights across heads.
- rpe_skip (int, default: 1) – Relative position encoding skip. Possible values: [0, 1].
- rpe_on (str, default: 'k') – Where to apply relative positional encoding. Possible values: ['q', 'k', 'v', 'qk', 'kv', 'qkv'].
Note. When 'v' is in rpe_on, rpe_mode must be 'contextual'.
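The constraints listed above can be checked before constructing the module. The following is a minimal sketch in plain Python; `check_irpe_args` and the three constant sets are hypothetical names introduced for illustration and are not part of torchmil. It only mirrors the rules stated in the parameter list and does not reproduce whatever validation the library itself performs.

```python
# Hypothetical helper, NOT part of torchmil: it restates the documented
# constraints on the constructor arguments as explicit checks.
RPE_METHODS = {"euc", "quant", "cross", "product"}
RPE_MODES = {None, "bias", "contextual"}
RPE_ON = {"q", "k", "v", "qk", "kv", "qkv"}


def check_irpe_args(in_dim, att_dim=512, n_heads=4, learn_weights=True,
                    rpe_method="product", rpe_mode="contextual",
                    rpe_skip=1, rpe_on="k"):
    """Return a list of problems; an empty list means the arguments are consistent."""
    problems = []
    if att_dim % n_heads != 0:
        problems.append("att_dim must be divisible by n_heads")
    # Without learned projections, q, k, and v are the input itself,
    # so the input dimension is what gets split across heads.
    if not learn_weights and in_dim % n_heads != 0:
        problems.append("in_dim must be divisible by n_heads when learn_weights=False")
    if rpe_method not in RPE_METHODS:
        problems.append("unknown rpe_method")
    if rpe_mode not in RPE_MODES:
        problems.append("unknown rpe_mode")
    if rpe_skip not in (0, 1):
        problems.append("rpe_skip must be 0 or 1")
    if rpe_on not in RPE_ON:
        problems.append("unknown rpe_on")
    elif "v" in rpe_on and rpe_mode != "contextual":
        problems.append("rpe_mode must be 'contextual' when 'v' is in rpe_on")
    return problems


if __name__ == "__main__":
    print(check_irpe_args(in_dim=128))                                # default settings
    print(check_irpe_args(in_dim=128, rpe_on="kv", rpe_mode="bias"))  # violates the note above
```

With the default att_dim=512 and n_heads=4, each head works on 512 / 4 = 128 dimensions, which is why the divisibility requirement exists.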
forward(x, mask=None, return_att=False, height=None, width=None)
Forward pass.
Parameters:
- x (Tensor) – Input tensor of shape (batch_size, seq_len, in_dim).
- mask (Tensor, default: None) – Mask tensor of shape (batch_size, seq_len).
- height (int, default: None) – Height of the input sequence. If None, height = floor(sqrt(seq_len)).
- width (int, default: None) – Width of the input sequence. If None, width = floor(sqrt(seq_len)).
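The default grid size follows directly from the formula above. As a small illustration, the sketch below computes it in plain Python; `default_grid` is a hypothetical helper written for this example, not a torchmil function. It shows that the default grid matches the sequence exactly only when seq_len is a perfect square.

```python
import math


def default_grid(seq_len, height=None, width=None):
    """Hypothetical helper: fill in a missing height/width with floor(sqrt(seq_len))."""
    side = math.isqrt(seq_len)  # integer floor of the square root
    if height is None:
        height = side
    if width is None:
        width = side
    return height, width


if __name__ == "__main__":
    for seq_len in (196, 200):
        h, w = default_grid(seq_len)
        # seq_len - h * w is the number of tokens the default grid does not account for
        print(seq_len, (h, w), "tokens outside the grid:", seq_len - h * w)
```

For seq_len = 200 the default is a 14 x 14 grid, which accounts for only 196 positions. How the remaining tokens are treated is not stated here, so passing height and width explicitly is the safer choice for non-square inputs.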
Returns:
- y (Tensor) – Output tensor of shape (batch_size, seq_len, att_dim).