Shifted Mean MIL Dataset
torchmil.datasets.ShiftedMeanMILDataset
Bases: Dataset
Shifted Mean MIL Dataset for correlated instance labels.
This dataset generates synthetic MIL bags in which positive bags contain a contiguous sequence of R instances with shifted mean features. As a result, instance labels within a positive bag are correlated, which makes the dataset useful for testing whether MIL algorithms can detect local instance-to-instance patterns within bags.
This dataset was introduced in the paper "Synthetic Data Reveals Generalization Gaps in Correlated Multiple Instance Learning" (arXiv preprint: https://arxiv.org/abs/2510.25759).
Bag generation:
- Each bag has a variable number of instances, between S_low and S_high.
- Features are drawn from a normal distribution N(mu, sigma^2).
- In positive bags (Y=1), a contiguous sequence of R instances has its first K features shifted by Delta.
- Instance labels indicate whether an instance has shifted features (1 if shifted, 0 otherwise).
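The generation procedure can be sketched in NumPy as follows. This is a minimal sketch, not torchmil's implementation: it returns plain dicts of NumPy arrays instead of TensorDicts, and it assumes that bag sizes are drawn uniformly from the inclusive range [S_low, S_high] and that the start of the shifted run is uniform over all positions where R instances fit. The helpers generate_bag and generate_dataset are hypothetical.

```python
import numpy as np


def generate_bag(y, R=3, S_low=15, S_high=45, K=1, M=768,
                 Delta=1.0, mu=0.0, sigma=1.0, rng=None):
    """Hypothetical sketch of one bag; not torchmil's code.

    Assumptions: bag size ~ uniform integer in [S_low, S_high] (inclusive),
    and the shifted run starts at a uniformly chosen valid position.
    """
    rng = np.random.default_rng() if rng is None else rng
    bag_size = int(rng.integers(S_low, S_high + 1))  # inclusive upper bound (assumption)
    X = rng.normal(mu, sigma, size=(bag_size, M))    # every feature ~ N(mu, sigma^2)
    y_inst = np.zeros(bag_size, dtype=np.int64)
    if y == 1:
        # The run of R instances must fit entirely inside the bag.
        start = int(rng.integers(0, bag_size - R + 1))
        X[start:start + R, :K] += Delta  # shift the first K features of R consecutive instances
        y_inst[start:start + R] = 1      # mark exactly the shifted instances
    return {"X": X, "Y": int(y), "y_inst": y_inst, "bag_size": bag_size}


def generate_dataset(N=100, p_y1=0.5, seed=42, **kwargs):
    """Draw N bag labels with P(Y=1) = p_y1, then generate each bag."""
    rng = np.random.default_rng(seed)  # one generator, so the whole dataset is reproducible
    labels = (rng.random(N) < p_y1).astype(int)
    return [generate_bag(y, rng=rng, **kwargs) for y in labels]
```

For debugging, calling generate_bag with sigma=0.0 makes the structure exact: every feature equals mu except the first K features of the R shifted instances, which equal mu + Delta.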
Each bag is returned as a TensorDict with the following keys:
- X: The bag's feature matrix of shape (bag_size, M)
- Y: The bag's label (1 for positive, 0 for negative)
- y_inst: The instance-level labels within the bag
- bag_size: The number of instances in the bag
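Taken together with the generation procedure, these keys satisfy a few invariants: one instance label per row of X, bag_size matching both, Y equal to the largest instance label, and, in positive bags, exactly R ones forming a single contiguous run. The hypothetical helper bag_is_consistent below checks them; it is a minimal sketch that assumes the values can be converted with np.asarray (which holds for CPU tensors).

```python
import numpy as np


def bag_is_consistent(bag, R):
    """Hypothetical check of the invariants implied by X, Y, y_inst and bag_size."""
    X = np.asarray(bag["X"])
    y_inst = np.asarray(bag["y_inst"]).astype(int).reshape(-1)
    Y = int(np.asarray(bag["Y"]).reshape(-1)[0])            # works for scalars and 0-d/1-element arrays
    bag_size = int(np.asarray(bag["bag_size"]).reshape(-1)[0])

    # One feature row and one instance label per instance.
    if not (X.shape[0] == bag_size == y_inst.shape[0]):
        return False
    # A bag is positive exactly when it contains a shifted instance.
    if Y != int(y_inst.max(initial=0)):
        return False
    if Y == 1:
        idx = np.flatnonzero(y_inst)  # sorted positions of shifted instances
        # R shifted instances that are contiguous span exactly R - 1 positions.
        return idx.size == R and int(idx[-1] - idx[0]) == R - 1
    return True
```

Assuming the dataset's tensors live on the CPU, a real bag could be checked with bag_is_consistent(dataset[0], R=3).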
Example usage:
from torchmil.datasets import ShiftedMeanMILDataset
# Create dataset with 100 bags
dataset = ShiftedMeanMILDataset(N=100, R=3, S_low=15, S_high=45, K=1, M=768,
                                p_y1=0.5, Delta=1.0, seed=42)
# Get a bag
bag = dataset[0]
X, Y, y_inst = bag['X'], bag['Y'], bag['y_inst']
print(f"Bag shape: {X.shape}")
print(f"Bag label: {Y}")
print(f"Instance labels: {y_inst}")
__init__(N=100, R=3, S_low=15, S_high=45, K=1, M=768, p_y1=0.5, Delta=1.0, mu=0.0, sigma=1.0, seed=42)
Parameters:
- N (int, default: 100) – Number of bags to generate.
- R (int, default: 3) – Number of contiguous instances with shifted mean in positive bags.
- S_low (int, default: 15) – Minimum bag size.
- S_high (int, default: 45) – Maximum bag size.
- K (int, default: 1) – Number of features shifted in positive instances (the first K features).
- M (int, default: 768) – Total number of features.
- p_y1 (float, default: 0.5) – Probability of generating a positive bag.
- Delta (float, default: 1.0) – Amount added to the shifted features of positive instances.
- mu (float, default: 0.0) – Mean of the normal distribution from which features are drawn.
- sigma (float, default: 1.0) – Standard deviation of the normal distribution from which features are drawn.
- seed (int, default: 42) – Random seed for reproducibility.
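These parameters are not independent. Because every positive bag must contain R contiguous shifted instances, R cannot exceed the smallest possible bag size S_low, and at most M features can be shifted. The hypothetical helper validate_params below collects such constraints as a sketch; whether and how torchmil itself validates its arguments is not stated in this documentation.

```python
def validate_params(N, R, S_low, S_high, K, M, p_y1, sigma):
    """Hypothetical consistency check; returns a list of problems (empty if none)."""
    problems = []
    if N < 1:
        problems.append("N must be at least 1")
    if not 1 <= R <= S_low:
        # The shifted run must fit even in the smallest possible bag.
        problems.append("R must satisfy 1 <= R <= S_low")
    if S_low > S_high:
        problems.append("S_low must not exceed S_high")
    if not 1 <= K <= M:
        # Only the first K of the M features are shifted.
        problems.append("K must satisfy 1 <= K <= M")
    if not 0.0 <= p_y1 <= 1.0:
        problems.append("p_y1 must be a probability in [0, 1]")
    if sigma < 0:
        problems.append("sigma must be non-negative")
    return problems
```

With the defaults listed above, validate_params returns an empty list.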
__getitem__(index)
Parameters:
- index (int) – Index of the bag to retrieve.
Returns:
- bag_dict (TensorDict) – TensorDict containing the following keys:
  - X: Bag features of shape (bag_size, M)
  - Y: Label of the bag
  - y_inst: Instance labels of the bag
  - bag_size: Number of instances in the bag