Sm Transformer
torchmil.nn.transformers.SmTransformerEncoder
Bases: Encoder
A Transformer encoder with the \(\texttt{Sm}\) operator, skip connections and layer normalization.
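Given an input bag \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times D}\), it sets \(\mathbf{X}^{0} = \mathbf{X}\) and applies \(L\) = n_layers Sm Transformer layers in sequence, where layer \(l\) maps \(\mathbf{X}^{l-1}\) to \(\mathbf{X}^{l}\) for \(l = 1, \ldots, L\).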
This module outputs \(\text{SmTransformerEncoder}(\mathbf{X}) = \mathbf{X}^{L}\) if add_self=False,
and \(\text{SmTransformerEncoder}(\mathbf{X}) = \mathbf{X}^{L} + \mathbf{X}\) if add_self=True.
See Sm for more details on the Sm operator.
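The layer stacking and the add_self option described above can be sketched in plain PyTorch. This is a minimal illustration under stated assumptions, not torchmil's implementation: each Sm Transformer layer is replaced by a hypothetical stand-in (torch.nn.Linear), and encoder_output is a hypothetical helper. The sketch shows that \(\mathbf{X}^{L} + \mathbf{X}\) is only defined when the last layer's output has the same shape as the input bag.

```python
import torch
from torch import nn


def encoder_output(X: torch.Tensor, layers: nn.ModuleList, add_self: bool = False) -> torch.Tensor:
    """Apply layers in sequence and optionally add the input bag to the result.

    X has shape (batch_size, bag_size, in_dim). Each entry of `layers` is a
    hypothetical stand-in for one Sm Transformer layer.
    """
    Y = X  # X^0 = X
    for layer in layers:
        Y = layer(Y)  # X^l computed from X^{l-1}
    if add_self:
        # X^L + X needs X^L and X to have the same shape
        if Y.shape != X.shape:
            raise ValueError(
                f"add_self=True needs matching shapes, got {tuple(Y.shape)} and {tuple(X.shape)}"
            )
        Y = Y + X
    return Y


torch.manual_seed(0)
batch_size, bag_size, in_dim, n_layers = 2, 5, 8, 4
X = torch.randn(batch_size, bag_size, in_dim)
# Hypothetical stand-ins for the n_layers Sm Transformer layers
layers = nn.ModuleList([nn.Linear(in_dim, in_dim) for _ in range(n_layers)])

Y_plain = encoder_output(X, layers, add_self=False)  # X^L
Y_self = encoder_output(X, layers, add_self=True)    # X^L + X
print(Y_plain.shape)  # torch.Size([2, 5, 8])
```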
__init__(in_dim, out_dim=None, att_dim=512, n_heads=4, n_layers=4, use_mlp=True, add_self=False, dropout=0.0, sm_alpha='trainable', sm_mode='approx', sm_steps=10)
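Class constructor.
Parameters:
- in_dim (int) – Input dimension.
- out_dim (int, default: None) – Output dimension. If None, out_dim = in_dim.
- att_dim (int, default: 512) – Attention dimension.
- n_heads (int, default: 4) – Number of heads.
- n_layers (int, default: 4) – Number of layers.
- use_mlp (bool, default: True) – Whether to use a feedforward layer.
- add_self (bool, default: False) – Whether to add the input to the output. If True, att_dim must be equal to in_dim.
- dropout (float, default: 0.0) – Dropout rate.
- sm_alpha (float or str, default: 'trainable') – Alpha value for the Sm operator. If 'trainable', alpha is learned during training.
- sm_mode (str, default: 'approx') – Mode of the Sm operator. With the default 'approx', the operator is approximated using sm_steps steps.
- sm_steps (int, default: 10) – Number of steps used to approximate the exact Sm operator.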
forward(X, adj, mask=None, return_att=False)
Forward method.
Parameters:
- X (Tensor) – Input tensor of shape (batch_size, bag_size, in_dim).
- adj (Tensor) – Adjacency matrix of shape (batch_size, bag_size, bag_size).
- mask (Tensor, default: None) – Mask tensor of shape (batch_size, bag_size).
Returns:
- Y (Tensor) – Output tensor of shape (batch_size, bag_size, in_dim).
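One way to prepare inputs with the shapes documented above is to pad variable-size bags into a single batch. This is a minimal sketch under stated assumptions: build_inputs and radius are hypothetical, the adjacency matrix is derived from hypothetical 2-D instance coordinates (for example, patch positions), two instances are connected when they are at most radius apart, and the mask uses 1 for real instances and 0 for padding. torchmil may expect a different mask dtype or adjacency normalization.

```python
import torch


def build_inputs(bags, coords, radius=1.5):
    """Pad variable-size bags into X, adj and mask tensors.

    bags:   list of tensors, each of shape (n_i, in_dim)
    coords: list of tensors, each of shape (n_i, 2), hypothetical instance positions
    radius: hypothetical distance threshold that defines adjacency
    """
    batch_size = len(bags)
    bag_size = max(b.shape[0] for b in bags)  # pad every bag to the largest one
    in_dim = bags[0].shape[1]

    X = torch.zeros(batch_size, bag_size, in_dim)
    adj = torch.zeros(batch_size, bag_size, bag_size)
    mask = torch.zeros(batch_size, bag_size)

    for i, (feats, pos) in enumerate(zip(bags, coords)):
        n = feats.shape[0]
        X[i, :n] = feats
        mask[i, :n] = 1.0  # 1 marks a real instance, 0 marks padding
        dist = torch.cdist(pos, pos)  # (n, n) pairwise Euclidean distances
        A = (dist <= radius).float()
        A.fill_diagonal_(0.0)  # no self-loops
        adj[i, :n, :n] = A  # padded rows and columns stay zero
    return X, adj, mask


torch.manual_seed(0)
bags = [torch.randn(3, 4), torch.randn(5, 4)]
coords = [
    torch.tensor([[0.0, 0.0], [1.0, 0.0], [5.0, 5.0]]),
    torch.tensor([[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [3.0, 3.0], [3.0, 4.0]]),
]
X, adj, mask = build_inputs(bags, coords)
print(X.shape, adj.shape, mask.shape)
```

Tensors built this way match the documented forward(X, adj, mask=None) signature, so they could be passed to an encoder instance as encoder(X, adj, mask=mask).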
torchmil.nn.transformers.SmTransformerLayer
Bases: Layer
One layer of the Transformer encoder with the \(\texttt{Sm}\) operator.
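Given an input bag \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times D}\), this module applies multi-head self-attention combined with the \(\texttt{Sm}\) operator and, if use_mlp=True, a feedforward layer, using skip connections and layer normalization, and outputs \(\mathbf{Y}\).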
See Sm for more details on the Sm operator.
__init__(in_dim, out_dim=None, att_dim=512, n_heads=4, use_mlp=True, dropout=0.0, sm_alpha='trainable', sm_mode='approx', sm_steps=10)
Class constructor.
Parameters:
- in_dim (int) – Input dimension.
- out_dim (int, default: None) – Output dimension. If None, out_dim = in_dim.
- att_dim (int, default: 512) – Attention dimension.
- n_heads (int, default: 4) – Number of heads.
- use_mlp (bool, default: True) – Whether to use a feedforward layer.
- dropout (float, default: 0.0) – Dropout rate.
- sm_alpha (float or str, default: 'trainable') – Alpha value for the Sm operator. If 'trainable', alpha is learned during training.
- sm_mode (str, default: 'approx') – Mode of the Sm operator. With the default 'approx', the operator is approximated using sm_steps steps.
- sm_steps (int, default: 10) – Number of steps used to approximate the exact Sm operator.
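The interplay of sm_alpha, sm_mode and sm_steps can be illustrated with a graph-smoothing operator. The sketch below rests on an assumption rather than on torchmil's code: it uses a smoothing operator that solves \((\mathbf{I} - \alpha \hat{\mathbf{A}})\mathbf{U} = (1 - \alpha)\mathbf{F}\), where \(\hat{\mathbf{A}}\) is the row-normalized adjacency matrix and \(\alpha\) plays the role of sm_alpha. The "exact" branch solves this linear system directly, while the "approx" branch runs a fixed number of fixed-point iterations \(\mathbf{U} \leftarrow \alpha \hat{\mathbf{A}}\mathbf{U} + (1 - \alpha)\mathbf{F}\), mirroring sm_steps. The helper smooth and its "exact" mode name are hypothetical; see Sm for the operator torchmil actually uses.

```python
import torch


def smooth(f: torch.Tensor, adj: torch.Tensor, alpha: float = 0.5,
           mode: str = "approx", steps: int = 10) -> torch.Tensor:
    """Hypothetical graph smoothing, used only to illustrate exact vs. approximate modes.

    f:   (batch_size, bag_size, dim) signal to smooth
    adj: (batch_size, bag_size, bag_size) adjacency matrix
    Solves (I - alpha * A_hat) U = (1 - alpha) * F, with A_hat the row-normalized adj.
    """
    # Row-normalize adj; the clamp avoids division by zero for isolated nodes
    deg = adj.sum(dim=-1, keepdim=True).clamp(min=1e-12)
    A_hat = adj / deg
    if mode == "exact":
        eye = torch.eye(adj.shape[-1], dtype=f.dtype)  # broadcasts over the batch
        return torch.linalg.solve(eye - alpha * A_hat, (1 - alpha) * f)
    if mode == "approx":
        u = f
        for _ in range(steps):
            # Fixed-point iteration; converges because alpha < 1 and A_hat is row-stochastic
            u = alpha * torch.bmm(A_hat, u) + (1 - alpha) * f
        return u
    raise ValueError(f"unknown mode: {mode}")


# A 3-instance path graph 0 - 1 - 2 with a signal concentrated on instance 0
adj = torch.tensor([[[0.0, 1.0, 0.0],
                     [1.0, 0.0, 1.0],
                     [0.0, 1.0, 0.0]]])
f = torch.tensor([[[1.0], [0.0], [0.0]]])
u_exact = smooth(f, adj, alpha=0.5, mode="exact")
u_approx = smooth(f, adj, alpha=0.5, mode="approx", steps=50)
print(torch.allclose(u_exact, u_approx, atol=1e-5))  # True
```

Because \(\hat{\mathbf{A}}\) is row-stochastic, each iteration shrinks the error by at least a factor \(\alpha\) in the max norm, so a few steps can suffice for moderate \(\alpha\); a larger \(\alpha\) smooths more strongly but needs more steps to reach the same accuracy.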
forward(X, adj, mask=None, return_att=False)
Forward method.
Parameters:
- X (Tensor) – Input tensor of shape (batch_size, bag_size, in_dim).
- adj (Tensor) – Adjacency matrix of shape (batch_size, bag_size, bag_size).
- mask (Tensor, default: None) – Mask tensor of shape (batch_size, bag_size).
- return_att (bool, default: False) – If True, returns attention weights, of shape (batch_size, n_heads, bag_size, bag_size).
Returns:
- Y (Tensor) – Output tensor of shape (batch_size, bag_size, in_dim).
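The attention weights returned with return_att=True have shape (batch_size, n_heads, bag_size, bag_size), which is the shape produced by standard multi-head scaled dot-product self-attention. The sketch below computes such weights with generic attention, not torchmil's code: the projection matrices W_q and W_k are hypothetical random weights, treating mask entries equal to 0 as padded keys is an assumed convention, and how the Sm operator interacts with these weights is not modelled.

```python
import math

import torch


def multihead_attention_weights(X, mask, n_heads, W_q, W_k):
    """Return softmax attention weights of shape (batch_size, n_heads, bag_size, bag_size).

    X:        (batch_size, bag_size, in_dim)
    mask:     (batch_size, bag_size), 1 for real instances, 0 for padding (assumed convention)
    W_q, W_k: (in_dim, att_dim) hypothetical projections; att_dim must be divisible by n_heads
    """
    batch_size, bag_size, _ = X.shape
    att_dim = W_q.shape[1]
    head_dim = att_dim // n_heads

    def split_heads(T):
        # (batch, bag, att_dim) -> (batch, n_heads, bag, head_dim)
        return T.view(batch_size, bag_size, n_heads, head_dim).transpose(1, 2)

    Q = split_heads(X @ W_q)
    K = split_heads(X @ W_k)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(head_dim)  # (batch, n_heads, bag, bag)
    # Padded instances must not receive attention: mask out their key columns
    key_mask = mask[:, None, None, :].bool()
    scores = scores.masked_fill(~key_mask, float("-inf"))
    return torch.softmax(scores, dim=-1)


torch.manual_seed(0)
batch_size, bag_size, in_dim, att_dim, n_heads = 2, 5, 8, 16, 4
X = torch.randn(batch_size, bag_size, in_dim)
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])  # first bag has 2 padded slots
W_q = torch.randn(in_dim, att_dim)
W_k = torch.randn(in_dim, att_dim)
att = multihead_attention_weights(X, mask, n_heads, W_q, W_k)
print(att.shape)  # torch.Size([2, 4, 5, 5])
```

Masking key columns rather than query rows keeps every row finite as long as each bag has at least one real instance, so the softmax never produces NaN.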