
Shifted Mean MIL Dataset

torchmil.datasets.ShiftedMeanMILDataset

Bases: Dataset

Shifted Mean MIL Dataset for correlated instance labels.

This dataset generates synthetic MIL bags in which every positive bag contains a contiguous sequence of R instances whose features have a shifted mean. Because the shifted instances are adjacent, instance labels within a positive bag are correlated. This makes the dataset useful for testing whether MIL algorithms can detect local, instance-to-instance patterns within bags.

This dataset was introduced in the paper "Synthetic Data Reveals Generalization Gaps in Correlated Multiple Instance Learning" (arXiv preprint: https://arxiv.org/abs/2510.25759).

Bag generation:

  • Each bag has a variable number of instances, between S_low and S_high.
  • Features are drawn from a normal distribution N(mu, sigma^2).
  • For positive bags (Y=1), the first K features of a contiguous sequence of R instances are shifted by Delta.
  • Instance labels indicate whether an instance has shifted features (1 if shifted, 0 otherwise).
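The generation procedure can be sketched in plain Python as follows. This is a minimal, hypothetical re-implementation for illustration (the name `generate_bag` is ours), not torchmil's code. It assumes that bag sizes are drawn uniformly from [S_low, S_high] inclusive and that the start of the shifted window is drawn uniformly among the positions where R instances fit, which requires R <= S_low; neither detail is stated in the documentation.

```python
import random


def generate_bag(y, R=3, S_low=15, S_high=45, K=1, M=768,
                 Delta=1.0, mu=0.0, sigma=1.0, rng=None):
    """Hypothetical sketch of the bag-generation procedure (not torchmil's code).

    Assumptions: bag size is uniform over [S_low, S_high] inclusive, and the
    shifted window starts at a uniformly chosen position where R instances fit.
    """
    if rng is None:
        rng = random.Random()
    size = rng.randint(S_low, S_high)  # randint is inclusive on both ends
    # Every feature starts as an i.i.d. draw from N(mu, sigma^2).
    X = [[rng.gauss(mu, sigma) for _ in range(M)] for _ in range(size)]
    y_inst = [0] * size
    if y == 1:
        # Assumed: uniform start index so the R-instance window fits in the bag.
        start = rng.randint(0, size - R)
        for i in range(start, start + R):
            for k in range(K):
                X[i][k] += Delta  # shift the first K features by Delta
            y_inst[i] = 1  # shifted instances are labeled positive
    return X, y, y_inst, size
```

For example, `generate_bag(1, rng=random.Random(0))` returns one positive bag with the default parameters. Passing `sigma=0.0` makes the shift directly visible, since every unshifted entry then equals `mu` exactly.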

Each bag is returned as a TensorDict with the following keys:

  • X: the bag's feature matrix, of shape (bag_size, M)
  • Y: the bag's label (1 for positive, 0 for negative)
  • y_inst: the instance-level labels within the bag
  • bag_size: the number of instances in the bag
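The relationships between these keys can be checked with a small helper. `is_consistent` below is hypothetical and works on a plain dict of Python lists; for a TensorDict, this assumes each tensor is first converted with `.tolist()`. It also assumes R >= 1, so that a bag is positive exactly when at least one instance is shifted, i.e. Y equals the maximum instance label.

```python
def is_consistent(bag):
    """Hypothetical sanity check on a bag stored as a dict of Python lists."""
    X, Y, y_inst, bag_size = bag["X"], bag["Y"], bag["y_inst"], bag["bag_size"]
    return (
        len(X) == bag_size          # one feature row per instance
        and len(y_inst) == bag_size  # one label per instance
        and Y == max(y_inst)         # positive bag iff some instance is shifted (assumes R >= 1)
    )
```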

Example usage:

from torchmil.datasets import ShiftedMeanMILDataset

# Create dataset with 100 bags
dataset = ShiftedMeanMILDataset(N=100, R=3, S_low=15, S_high=45, K=1, M=768,
                                 p_y1=0.5, Delta=1.0, seed=42)

# Get a bag
bag = dataset[0]
X, Y, y_inst = bag['X'], bag['Y'], bag['y_inst']
print(f"Bag shape: {X.shape}")
print(f"Bag label: {Y}")
print(f"Instance labels: {y_inst}")
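Since the shifted instances of a positive bag form a single contiguous run, a useful check on `y_inst` is to locate that run. The helper below is a hypothetical sketch that takes the instance labels as a Python sequence (for example, `y_inst.tolist()`) and returns the start index and length of the run, or `None` for a negative bag.

```python
def shifted_window(y_inst):
    """Hypothetical helper: (start, length) of the positive run, or None.

    Raises ValueError if the positive instance labels are not contiguous.
    """
    positives = [i for i, v in enumerate(y_inst) if v == 1]
    if not positives:
        return None  # negative bag: no shifted instances
    start, end = positives[0], positives[-1]
    if end - start + 1 != len(positives):
        raise ValueError("positive instances are not contiguous")
    return start, len(positives)
```

For any positive bag produced with the parameters in the example above, the returned length should equal R=3.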

__init__(N=100, R=3, S_low=15, S_high=45, K=1, M=768, p_y1=0.5, Delta=1.0, mu=0.0, sigma=1.0, seed=42)

Parameters:

  • N (int, default: 100 ) –

    Number of bags to generate

  • R (int, default: 3 ) –

    Number of contiguous instances with shifted mean in positive bags

  • S_low (int, default: 15 ) –

    Minimum bag size

  • S_high (int, default: 45 ) –

    Maximum bag size

  • K (int, default: 1 ) –

    Number of features to shift in positive instances

  • M (int, default: 768 ) –

    Total number of features

  • p_y1 (float, default: 0.5 ) –

    Probability of generating a positive bag

  • Delta (float, default: 1.0 ) –

    Shift amount for positive instances

  • mu (float, default: 0.0 ) –

    Mean of the normal distribution from which features are drawn

  • sigma (float, default: 1.0 ) –

    Standard deviation of the normal distribution from which features are drawn

  • seed (int, default: 42 ) –

    Random seed for reproducibility
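Together, these parameters control how hard the task is. The helper below computes a few back-of-the-envelope quantities from them. It is a hypothetical sketch, and it assumes bag sizes are uniform over [S_low, S_high] inclusive, which the documentation does not state.

```python
from fractions import Fraction


def expected_stats(N=100, R=3, S_low=15, S_high=45, K=1, M=768,
                   p_y1=0.5, Delta=1.0, sigma=1.0):
    """Hypothetical summary of the task difficulty implied by the parameters.

    Assumes bag sizes are uniform over [S_low, S_high], inclusive.
    """
    sizes = range(S_low, S_high + 1)
    mean_size = Fraction(sum(sizes), len(sizes))
    # Average fraction of shifted instances in a positive bag, E[R / S].
    witness_rate = sum(Fraction(R, s) for s in sizes) / len(sizes)
    return {
        "expected_positive_bags": N * p_y1,
        "mean_bag_size": float(mean_size),
        "positive_bag_witness_rate": float(witness_rate),
        "shifted_feature_fraction": K / M,           # share of features carrying signal
        "per_feature_effect_size": Delta / sigma,    # shift in units of noise std
    }
```

With the defaults, the mean bag size is 30, so on average roughly 11% of the instances in a positive bag are shifted, and only K/M = 1/768 of each shifted instance's features carry the signal, each with an effect size of Delta/sigma = 1.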

__getitem__(index)

Parameters:

  • index (int) –

    Index of the bag to retrieve

Returns:

  • bag_dict ( TensorDict ) –

    TensorDict containing the following keys:

      • X: bag features of shape (bag_size, M)
      • Y: label of the bag
      • y_inst: instance labels of the bag
      • bag_size: number of instances in the bag
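Because bag sizes vary between S_low and S_high, bags returned by `__getitem__` cannot be stacked into a single tensor directly. A common workaround is to pad every bag in a batch to the largest size and keep a mask marking the real instances. The sketch below illustrates this on plain Python lists; `pad_bags` and its mask convention are hypothetical, not part of torchmil's documented API.

```python
def pad_bags(bags, pad_value=0.0):
    """Hypothetical padding helper for variable-size bags.

    Each bag is a dict with 'X' (list of feature rows), 'Y' and 'bag_size',
    mirroring the keys returned by __getitem__ but using plain Python lists.
    """
    max_size = max(b["bag_size"] for b in bags)
    X_batch, mask_batch = [], []
    for b in bags:
        n_feat = len(b["X"][0])
        pad_rows = max_size - b["bag_size"]
        # Append rows filled with pad_value until the bag reaches max_size.
        X_batch.append(b["X"] + [[pad_value] * n_feat for _ in range(pad_rows)])
        # 1 marks a real instance, 0 marks padding.
        mask_batch.append([1] * b["bag_size"] + [0] * pad_rows)
    Y_batch = [b["Y"] for b in bags]
    return X_batch, Y_batch, mask_batch
```

A model consuming the padded batch would then use the mask to ignore padded rows, for example when pooling instance scores.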