Skip to content

Commit

Permalink
Add option is_included for data_transform
Browse files Browse the repository at this point in the history
This parameter can be used if the augmentations are already fully
specified within the dataset class.
  • Loading branch information
jnboehm committed Jan 30, 2024
1 parent 0fff767 commit f879df3
Showing 1 changed file with 22 additions and 10 deletions.
32 changes: 22 additions & 10 deletions tsimcne/tsimcne.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,11 +180,17 @@ class TSimCNE:
then this needs to be appropriately reflected in the
projection head as well.
:param None data_transform: The data augmentations to create the
differing views of the input. By default it will use the same
augmentations as written in Böhm et al. (2023); random
cropping, greyscaling, color jitter, horizontal flips. This
parameter should be changed with care.
:param None | ``"is_included"`` | torchvision.transforms |
ffcv.transforms data_transform: The data augmentations to
create the differing views of the input. By default it will
use the same augmentations as written in Böhm et al. (2023);
random cropping, greyscaling, color jitter, horizontal flips.
This parameter should be changed with care.
If ``data_transform="is_included"``, then it is assumed that
all of the data augmentations are already included within the
dataset object ``X`` (which is passed to `tsimcne.fit`) and
will return the augmented samples along with a (dummy) label.
:param [1000, 50, 450] total_epochs: A list of the number of
epochs per training stage. The ratio between the stages
Expand Down Expand Up @@ -445,7 +451,7 @@ def fit_transform(
return_backbone_feat=return_backbone_feat,
)

def fit(self, X: torch.utils.data.Dataset | str):
def fit(self, X: torch.utils.data.Dataset | str | Path):
"""Learn the mapping from the dataset ``X`` to 2D.
:param X: The image dataset to be used for training. Will be
Expand All @@ -462,7 +468,7 @@ def fit(self, X: torch.utils.data.Dataset | str):
self.use_ffcv = False
self.check_ffcv(self.use_ffcv)

train_dl = self.make_dataloader(X, True, None)
train_dl = self.make_dataloader(X, True, self.data_transform)

self.data_transform_none = get_transforms_unnormalized(
size=self.image_size, setting="none", use_ffcv=self.use_ffcv
Expand Down Expand Up @@ -613,11 +619,17 @@ def make_dataloader(self, X, train_or_test, data_transform):
data_transform = get_transforms_unnormalized(
size=size, setting="none", use_ffcv=self.use_ffcv
)
else:
data_transform = self.data_transform

if not self.use_ffcv:
# dataset that returns two augmented views of a given
# datapoint (and label)
dataset_contrastive = TransformedPairDataset(X, data_transform)
if data_transform != "is_included":
# dataset that returns two augmented views of a given
# datapoint (and label)
dataset_contrastive = TransformedPairDataset(X, data_transform)
else:
dataset_contrastive = X

# wrap dataset into dataloader
loader = torch.utils.data.DataLoader(
dataset_contrastive,
Expand Down

0 comments on commit f879df3

Please sign in to comment.