General MIL model
torchmil.models.MILModel
Bases: Module
Base class for Multiple Instance Learning (MIL) models in torchmil.
Subclasses should implement the following methods:
- forward: Forward pass of the model. Accepts bag features (and optionally other arguments) and returns the bag label prediction (and optionally other outputs).
- compute_loss: Compute inner losses of the model. Accepts bag labels and bag features (and optionally other arguments) and returns the output of the forward method and a dictionary of pairs (loss_name, loss_value). By default, the model has no inner losses, so this dictionary is empty.
- predict: Predict bag and (optionally) instance labels. Accepts bag features (and optionally other arguments) and returns bag label predictions (and optionally instance label predictions).
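As an illustration of this contract, here is a minimal sketch of a subclass. The name MeanPoolMIL and its design (a linear instance scorer followed by mean pooling over the bag) are hypothetical and not part of torchmil. The sketch imports MILModel from torchmil when it is installed and otherwise falls back to torch.nn.Module, so it also runs standalone.

```python
import torch
from torch import nn

try:
    from torchmil.models import MILModel  # real base class, if torchmil is installed
except ImportError:
    MILModel = nn.Module  # fallback so this sketch also runs without torchmil


class MeanPoolMIL(MILModel):
    """Hypothetical toy model: a linear instance scorer followed by mean pooling."""

    def __init__(self, in_dim: int) -> None:
        super().__init__()
        self.inst_clf = nn.Linear(in_dim, 1)

    def _inst_logits(self, X: torch.Tensor) -> torch.Tensor:
        # (batch_size, bag_size, in_dim) -> (batch_size, bag_size)
        return self.inst_clf(X).squeeze(-1)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # Mean-pool instance logits into one bag logit: shape (batch_size,)
        return self._inst_logits(X).mean(dim=1)

    def compute_loss(self, Y: torch.Tensor, X: torch.Tensor):
        # No inner losses, mirroring the documented default: an empty dictionary.
        return self.forward(X), {}

    def predict(self, X: torch.Tensor, return_inst_pred: bool = False):
        inst_logits = self._inst_logits(X)
        Y_pred = inst_logits.mean(dim=1)
        if return_inst_pred:
            return Y_pred, inst_logits  # shapes (batch_size,) and (batch_size, bag_size)
        return Y_pred


X = torch.randn(4, 10, 8)  # 4 bags of 10 instances, 8 features each
model = MeanPoolMIL(in_dim=8)
Y_pred = model(X)
Y_pred_bag, y_inst_pred = model.predict(X, return_inst_pred=True)
```

Here predict returns raw logits; a real model might apply a sigmoid or threshold, which is a design choice left to the subclass.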
__init__(*args, **kwargs)
Initializes the module.
forward(X, *args, **kwargs)
Parameters:
- X (Tensor) – Bag features of shape (batch_size, bag_size, ...).
Returns:
- Y_pred (Tensor) – Bag label prediction of shape (batch_size,).
compute_loss(Y, X, *args, **kwargs)
Parameters:
- Y (Tensor) – Bag labels of shape (batch_size,).
- X (Tensor) – Bag features of shape (batch_size, bag_size, ...).
Returns:
- Y_pred (Tensor) – Bag label prediction of shape (batch_size,).
- loss_dict (dict) – Dictionary containing the loss values.
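The loss dictionary holds the model's inner losses. The sketch below is a hypothetical example: PenalizedMeanPoolMIL returns an L2 penalty on its instance logits as its single inner loss. The training step assumes the inner losses are simply summed, unweighted, with a bag-level BCEWithLogitsLoss; how torchmil's own training utilities combine these terms is not stated here, so treat that sum as an assumption.

```python
import torch
from torch import nn

try:
    from torchmil.models import MILModel  # real base class, if torchmil is installed
except ImportError:
    MILModel = nn.Module  # fallback so this sketch also runs without torchmil


class PenalizedMeanPoolMIL(MILModel):
    """Hypothetical model with one inner loss: an L2 penalty on instance logits."""

    def __init__(self, in_dim: int, penalty_weight: float = 0.01) -> None:
        super().__init__()
        self.inst_clf = nn.Linear(in_dim, 1)
        self.penalty_weight = penalty_weight

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        return self.inst_clf(X).squeeze(-1).mean(dim=1)  # (batch_size,)

    def compute_loss(self, Y: torch.Tensor, X: torch.Tensor):
        inst_logits = self.inst_clf(X).squeeze(-1)  # (batch_size, bag_size)
        Y_pred = inst_logits.mean(dim=1)
        penalty = self.penalty_weight * inst_logits.pow(2).mean()
        # Return the forward output together with the (loss_name, loss_value) pairs.
        return Y_pred, {"l2_penalty": penalty}

    def predict(self, X: torch.Tensor, return_inst_pred: bool = False):
        inst_logits = self.inst_clf(X).squeeze(-1)
        Y_pred = inst_logits.mean(dim=1)
        return (Y_pred, inst_logits) if return_inst_pred else Y_pred


model = PenalizedMeanPoolMIL(in_dim=8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.BCEWithLogitsLoss()

X = torch.randn(4, 10, 8)
Y = torch.tensor([0.0, 1.0, 1.0, 0.0])

# One training step. Assumption: inner losses are added to the bag-level loss unweighted.
Y_pred, loss_dict = model.compute_loss(Y, X)
loss = criterion(Y_pred, Y) + sum(loss_dict.values())
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Because sum() over an empty dictionary returns 0, the same training step also works for models that keep the default, empty loss dictionary.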
predict(X, return_inst_pred=False, *args, **kwargs)
Parameters:
- X (Tensor) – Bag features of shape (batch_size, bag_size, ...).
- return_inst_pred (bool) – If True, also return instance label predictions.
Returns:
- Y_pred (Tensor) – Bag label prediction of shape (batch_size,).
- y_inst_pred (Tensor) – If return_inst_pred=True, instance label predictions of shape (batch_size, bag_size).
torchmil.models.MILModelWrapper
Bases: MILModel
A wrapper class for MIL models in torchmil.
It allows all models that inherit from MILModel to be used through a common interface:
from tensordict import TensorDict
from torchmil.models import MILModelWrapper

model_A = ...  # a MILModel whose forward accepts arguments 'X', 'adj'
model_B = ...  # a MILModel whose forward accepts argument 'X'
model_A_w = MILModelWrapper(model_A)
model_B_w = MILModelWrapper(model_B)
bag = TensorDict({'X': ..., 'adj': ..., ...})  # one key per argument
Y_pred_A = model_A_w(bag)  # calls model_A.forward(X=bag['X'], adj=bag['adj'])
Y_pred_B = model_B_w(bag)  # calls model_B.forward(X=bag['X']); 'adj' is not passed
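The wrapper has to decide which entries of the bag to pass to each method. One plausible way to do this, assumed here rather than taken from torchmil's source, is to match the bag keys against the parameter names reported by inspect.signature. The sketch uses plain dictionaries in place of TensorDict, and ToyWrapper, ModelA, ModelB and call_with_matching_keys are hypothetical stand-ins.

```python
import inspect
from typing import Any, Callable, Mapping


def call_with_matching_keys(fn: Callable[..., Any], bag: Mapping[str, Any], **kwargs: Any) -> Any:
    """Call fn with the bag entries whose keys match fn's parameter names."""
    params = inspect.signature(fn).parameters  # bound methods already exclude 'self'
    accepted = {k: v for k, v in bag.items() if k in params}
    # Explicit keyword arguments override bag entries with the same name.
    return fn(**{**accepted, **kwargs})


class ToyWrapper:
    """Hypothetical stand-in for MILModelWrapper (not torchmil's implementation)."""

    def __init__(self, model: Any) -> None:
        self.model = model

    def __call__(self, bag: Mapping[str, Any], **kwargs: Any) -> Any:
        return self.forward(bag, **kwargs)

    def forward(self, bag, **kwargs):
        return call_with_matching_keys(self.model.forward, bag, **kwargs)

    def compute_loss(self, bag, **kwargs):
        # Assumption: each method is matched against its own signature.
        return call_with_matching_keys(self.model.compute_loss, bag, **kwargs)

    def predict(self, bag, **kwargs):
        return call_with_matching_keys(self.model.predict, bag, **kwargs)


class ModelA:
    """Hypothetical model whose forward also takes an adjacency matrix."""

    def forward(self, X, adj):
        return ("A", X, adj)


class ModelB:
    """Hypothetical model whose forward takes only the bag features."""

    def forward(self, X):
        return ("B", X)


bag = {"X": [1, 2], "adj": [[0, 1], [1, 0]], "Y": 1}
out_A = ToyWrapper(ModelA())(bag)  # ModelA.forward(X=bag["X"], adj=bag["adj"])
out_B = ToyWrapper(ModelB())(bag)  # ModelB.forward(X=bag["X"]); "adj" and "Y" are ignored
```

In this sketch, a method that declares only **kwargs would receive no bag entries, because no bag key matches the name "kwargs"; a real implementation may handle that case differently.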
__init__(model)
Parameters:
- model (MILModel) – The MIL model to wrap.
forward(bag, **kwargs)
Parameters:
- bag (TensorDict) – Dictionary containing one key for each argument accepted by the model's forward method.
Returns:
- out (Any) – Output of the model's forward method.
compute_loss(bag, **kwargs)
Parameters:
- bag (TensorDict) – Dictionary containing one key for each argument accepted by the model's forward method.
Returns:
- out (tuple[Any, dict]) – Output of the model's compute_loss method.
predict(bag, **kwargs)
Parameters:
- bag (TensorDict) – Dictionary containing one key for each argument accepted by the model's forward method.
Returns:
- out (Any) – Output of the model's predict method.