diff --git a/tsimcne/tsimcne.py b/tsimcne/tsimcne.py
index 1136047..82ad930 100644
--- a/tsimcne/tsimcne.py
+++ b/tsimcne/tsimcne.py
@@ -180,11 +180,17 @@ class TSimCNE:
         then this needs to be appropriately reflected in the
         projection head as well.
 
-    :param None data_transform: The data augmentations to create the
-        differing views of the input. By default it will use the same
-        augmentations as written in Böhm et al. (2023); random
-        cropping, greyscaling, color jitter, horizontal flips. This
-        parameter should be changed with care.
+    :param None | ``"is_included"`` | torchvision.transforms |
+        ffcv.transforms data_transform: The data augmentations to
+        create the differing views of the input. By default it will
+        use the same augmentations as written in Böhm et al. (2023):
+        random cropping, greyscaling, color jitter, horizontal flips.
+        This parameter should be changed with care.
+
+        If ``data_transform="is_included"``, it is assumed that all
+        of the data augmentations are already included within the
+        dataset object ``X`` (passed to ``tsimcne.fit``), which will
+        return the augmented samples along with a (dummy) label.
 
     :param [1000, 50, 450] total_epochs: A list of the number of
         epochs per training stage. The ratio between the stages
@@ -445,7 +451,7 @@ def fit_transform(
             return_backbone_feat=return_backbone_feat,
         )
 
-    def fit(self, X: torch.utils.data.Dataset | str):
+    def fit(self, X: torch.utils.data.Dataset | str | Path):
         """Learn the mapping from the dataset ``X`` to 2D.
 
         :param X: The image dataset to be used for training. Will be
@@ -462,7 +468,7 @@ def fit(self, X: torch.utils.data.Dataset | str):
             self.use_ffcv = False
         self.check_ffcv(self.use_ffcv)
 
-        train_dl = self.make_dataloader(X, True, None)
+        train_dl = self.make_dataloader(X, True, self.data_transform)
 
         self.data_transform_none = get_transforms_unnormalized(
             size=self.image_size, setting="none", use_ffcv=self.use_ffcv
@@ -613,11 +619,17 @@ def make_dataloader(self, X, train_or_test, data_transform):
             data_transform = get_transforms_unnormalized(
                 size=size, setting="none", use_ffcv=self.use_ffcv
             )
+        else:
+            data_transform = self.data_transform
 
         if not self.use_ffcv:
-            # dataset that returns two augmented views of a given
-            # datapoint (and label)
-            dataset_contrastive = TransformedPairDataset(X, data_transform)
+            if data_transform != "is_included":
+                # dataset that returns two augmented views of a given
+                # datapoint (and label)
+                dataset_contrastive = TransformedPairDataset(X, data_transform)
+            else:
+                dataset_contrastive = X
+
             # wrap dataset into dataloader
             loader = torch.utils.data.DataLoader(
                 dataset_contrastive,
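
Below is a minimal usage sketch for the new ``data_transform="is_included"`` mode. The dataset class `PreAugmentedDataset` and its augmentation pipeline are hypothetical, not part of this patch; the assumption, taken from the comment in `make_dataloader` and the docstring above, is that such a dataset has to yield the same two-augmented-views-plus-label tuples that `TransformedPairDataset` otherwise produces.

```python
import torch
from torchvision import transforms
from torchvision.datasets import CIFAR10

from tsimcne.tsimcne import TSimCNE


class PreAugmentedDataset(torch.utils.data.Dataset):
    """Hypothetical dataset that bakes its augmentations in,
    for use with TSimCNE(data_transform="is_included")."""

    def __init__(self, base_dataset, size=32):
        self.base = base_dataset
        # reduced pipeline for brevity; the library default also
        # applies color jitter and greyscaling (Böhm et al. 2023)
        self.augment = transforms.Compose(
            [
                transforms.RandomResizedCrop(size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
            ]
        )

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        img, label = self.base[idx]  # label may be a dummy, e.g. 0
        # two independently sampled augmentations of the same image
        return self.augment(img), self.augment(img), label


dataset = PreAugmentedDataset(CIFAR10(root="data", download=True))
tsimcne = TSimCNE(data_transform="is_included")
Y = tsimcne.fit_transform(dataset)  # 2D embedding of the dataset
```

If the actual return convention is `((view1, view2), label)` rather than a flat three-tuple, the `__getitem__` above needs adjusting; the hunk only shows that the dataset `X` is handed to the `DataLoader` as-is when `data_transform == "is_included"`.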