From 775829bbe5c63ca18974e70e7451fc8a57d97f5a Mon Sep 17 00:00:00 2001 From: Igor Susmelj Date: Wed, 19 Jul 2023 17:57:37 +0200 Subject: [PATCH] Add transforms for BYOL (#1325) * Add transforms for BYOL * Add test * Fix test * Update BYOL transforms * Update all transforms * Fix typos * Implement feedback --- benchmarks/imagenet/resnet50/byol.py | 8 +- .../benchmarks/cifar10_benchmark.py | 13 +- .../benchmarks/imagenet100_benchmark.py | 14 +- .../benchmarks/imagenette_benchmark.py | 15 +- examples/pytorch/barlowtwins.py | 13 +- examples/pytorch/byol.py | 12 +- examples/pytorch/tico.py | 13 +- examples/pytorch_lightning/barlowtwins.py | 13 +- examples/pytorch_lightning/byol.py | 12 +- examples/pytorch_lightning/tico.py | 13 +- .../barlowtwins.py | 13 +- .../pytorch_lightning_distributed/byol.py | 12 +- .../pytorch_lightning_distributed/tico.py | 13 +- lightly/transforms/__init__.py | 5 + lightly/transforms/byol_transform.py | 175 ++++++++++++++++++ tests/transforms/test_byol_transform.py | 26 +++ 16 files changed, 339 insertions(+), 31 deletions(-) create mode 100644 lightly/transforms/byol_transform.py create mode 100644 tests/transforms/test_byol_transform.py diff --git a/benchmarks/imagenet/resnet50/byol.py b/benchmarks/imagenet/resnet50/byol.py index bfc2f1080..0a1c7eaf2 100644 --- a/benchmarks/imagenet/resnet50/byol.py +++ b/benchmarks/imagenet/resnet50/byol.py @@ -10,7 +10,7 @@ from lightly.loss import NegativeCosineSimilarity from lightly.models.modules import BYOLPredictionHead, BYOLProjectionHead from lightly.models.utils import get_weight_decay_parameters, update_momentum -from lightly.transforms import SimCLRTransform +from lightly.transforms import BYOLTransform from lightly.utils.benchmarking import OnlineLinearClassifier from lightly.utils.lars import LARS from lightly.utils.scheduler import CosineWarmupScheduler, cosine_schedule @@ -144,5 +144,7 @@ def configure_optimizers(self): return [optimizer], [scheduler] -# BYOL uses same transform as SimCLR. -transform = SimCLRTransform() +# BYOL uses a slight modification of the SimCLR transforms. +# It uses asymmetric augmentation and solarize. +# Check table 6 in the BYOL paper for more info. +transform = BYOLTransform() diff --git a/docs/source/getting_started/benchmarks/cifar10_benchmark.py b/docs/source/getting_started/benchmarks/cifar10_benchmark.py index 30ef3aa80..509c008cf 100644 --- a/docs/source/getting_started/benchmarks/cifar10_benchmark.py +++ b/docs/source/getting_started/benchmarks/cifar10_benchmark.py @@ -84,6 +84,9 @@ from lightly.models import ResNetGenerator, modules, utils from lightly.models.modules import heads from lightly.transforms import ( + BYOLTransform, + BYOLView1Transform, + BYOLView2Transform, DINOTransform, FastSiamTransform, SimCLRTransform, @@ -161,6 +164,12 @@ path_to_train = "/datasets/cifar10/train/" path_to_test = "/datasets/cifar10/test/" +# Use BYOL augmentations +byol_transform = BYOLTransform( + view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0), + view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0), +) + # Use SimCLR augmentations simclr_transform = SimCLRTransform( input_size=32, @@ -229,8 +238,8 @@ def create_dataset_train_ssl(model): Model class for which to select the transform. 
""" model_to_transform = { - BarlowTwinsModel: simclr_transform, - BYOLModel: simclr_transform, + BarlowTwinsModel: byol_transform, + BYOLModel: byol_transform, DCL: simclr_transform, DCLW: simclr_transform, DINOModel: dino_transform, diff --git a/docs/source/getting_started/benchmarks/imagenet100_benchmark.py b/docs/source/getting_started/benchmarks/imagenet100_benchmark.py index e4fa787cb..3dcc1beb2 100644 --- a/docs/source/getting_started/benchmarks/imagenet100_benchmark.py +++ b/docs/source/getting_started/benchmarks/imagenet100_benchmark.py @@ -50,6 +50,7 @@ from lightly.models import modules, utils from lightly.models.modules import heads from lightly.transforms import ( + BYOLTransform, DINOTransform, FastSiamTransform, SimCLRTransform, @@ -108,14 +109,17 @@ path_to_train = "/datasets/imagenet100/train/" path_to_test = "/datasets/imagenet100/val/" +# Use BYOL augmentations +byol_transform = BYOLTransform() + # Use SimCLR augmentations -simclr_transform = SimCLRTransform(input_size=input_size) +simclr_transform = SimCLRTransform() # Use SimSiam augmentations -simsiam_transform = SimSiamTransform(input_size=input_size) +simsiam_transform = SimSiamTransform() # Multi crop augmentation for FastSiam -fast_siam_transform = FastSiamTransform(input_size=input_size) +fast_siam_transform = FastSiamTransform() # Multi crop augmentation for SwAV swav_transform = SwaVTransform() @@ -155,8 +159,8 @@ def create_dataset_train_ssl(model): Model class for which to select the transform. """ model_to_transform = { - BarlowTwinsModel: simclr_transform, - BYOLModel: simclr_transform, + BarlowTwinsModel: byol_transform, + BYOLModel: byol_transform, DINOModel: dino_transform, FastSiamModel: fast_siam_transform, MocoModel: simclr_transform, diff --git a/docs/source/getting_started/benchmarks/imagenette_benchmark.py b/docs/source/getting_started/benchmarks/imagenette_benchmark.py index 7f63a00b9..b5c5e8e71 100644 --- a/docs/source/getting_started/benchmarks/imagenette_benchmark.py +++ b/docs/source/getting_started/benchmarks/imagenette_benchmark.py @@ -90,6 +90,9 @@ from lightly.models import modules, utils from lightly.models.modules import heads, masked_autoencoder from lightly.transforms import ( + BYOLTransform, + BYOLView1Transform, + BYOLView2Transform, DINOTransform, FastSiamTransform, MAETransform, @@ -154,6 +157,12 @@ path_to_train = "/datasets/imagenette2-160/train/" path_to_test = "/datasets/imagenette2-160/val/" +# Use BYOL augmentations +byol_transform = BYOLTransform( + view_1_transform=BYOLView1Transform(input_size=input_size), + view_2_transform=BYOLView2Transform(input_size=input_size), +) + # Use SimCLR augmentations simclr_transform = SimCLRTransform( input_size=input_size, @@ -244,8 +253,8 @@ def create_dataset_train_ssl(model): Model class for which to select the transform. 
""" model_to_transform = { - BarlowTwinsModel: simclr_transform, - BYOLModel: simclr_transform, + BarlowTwinsModel: byol_transform, + BYOLModel: byol_transform, DCL: simclr_transform, DCLW: simclr_transform, DINOModel: dino_transform, @@ -261,7 +270,7 @@ def create_dataset_train_ssl(model): SwaVModel: swav_transform, SwaVQueueModel: swav_transform, SMoGModel: smog_transform, - TiCoModel: simclr_transform, + TiCoModel: byol_transform, VICRegModel: vicreg_transform, VICRegLModel: vicregl_transform, } diff --git a/examples/pytorch/barlowtwins.py b/examples/pytorch/barlowtwins.py index 8bf4524dc..c6631518d 100644 --- a/examples/pytorch/barlowtwins.py +++ b/examples/pytorch/barlowtwins.py @@ -8,7 +8,11 @@ from lightly.loss import BarlowTwinsLoss from lightly.models.modules import BarlowTwinsProjectionHead -from lightly.transforms.simclr_transform import SimCLRTransform +from lightly.transforms.byol_transform import ( + BYOLTransform, + BYOLView1Transform, + BYOLView2Transform, +) class BarlowTwins(nn.Module): @@ -30,7 +34,12 @@ def forward(self, x): device = "cuda" if torch.cuda.is_available() else "cpu" model.to(device) -transform = SimCLRTransform(input_size=32) +# BarlowTwins uses BYOL augmentations. +# We disable resizing and gaussian blur for cifar10. +transform = BYOLTransform( + view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0), + view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0), +) dataset = torchvision.datasets.CIFAR10( "datasets/cifar10", download=True, transform=transform ) diff --git a/examples/pytorch/byol.py b/examples/pytorch/byol.py index cf80fdf11..4abcdf0a0 100644 --- a/examples/pytorch/byol.py +++ b/examples/pytorch/byol.py @@ -11,7 +11,11 @@ from lightly.loss import NegativeCosineSimilarity from lightly.models.modules import BYOLPredictionHead, BYOLProjectionHead from lightly.models.utils import deactivate_requires_grad, update_momentum -from lightly.transforms.simclr_transform import SimCLRTransform +from lightly.transforms.byol_transform import ( + BYOLTransform, + BYOLView1Transform, + BYOLView2Transform, +) from lightly.utils.scheduler import cosine_schedule @@ -49,7 +53,11 @@ def forward_momentum(self, x): device = "cuda" if torch.cuda.is_available() else "cpu" model.to(device) -transform = SimCLRTransform(input_size=32) +# We disable resizing and gaussian blur for cifar10. +transform = BYOLTransform( + view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0), + view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0), +) dataset = torchvision.datasets.CIFAR10( "datasets/cifar10", download=True, transform=transform ) diff --git a/examples/pytorch/tico.py b/examples/pytorch/tico.py index bbda91bd1..82274f3f9 100644 --- a/examples/pytorch/tico.py +++ b/examples/pytorch/tico.py @@ -11,7 +11,11 @@ from lightly.loss.tico_loss import TiCoLoss from lightly.models.modules.heads import TiCoProjectionHead from lightly.models.utils import deactivate_requires_grad, update_momentum -from lightly.transforms.simclr_transform import SimCLRTransform +from lightly.transforms.byol_transform import ( + BYOLTransform, + BYOLView1Transform, + BYOLView2Transform, +) from lightly.utils.scheduler import cosine_schedule @@ -47,7 +51,12 @@ def forward_momentum(self, x): device = "cuda" if torch.cuda.is_available() else "cpu" model.to(device) -transform = SimCLRTransform(input_size=32) +# TiCo uses BYOL augmentations. +# We disable resizing and gaussian blur for cifar10. 
+transform = BYOLTransform( + view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0), + view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0), +) dataset = torchvision.datasets.CIFAR10( "datasets/cifar10", download=True, transform=transform ) diff --git a/examples/pytorch_lightning/barlowtwins.py b/examples/pytorch_lightning/barlowtwins.py index 4053fe65d..86d069e5a 100644 --- a/examples/pytorch_lightning/barlowtwins.py +++ b/examples/pytorch_lightning/barlowtwins.py @@ -9,7 +9,11 @@ from lightly.loss import BarlowTwinsLoss from lightly.models.modules import BarlowTwinsProjectionHead -from lightly.transforms.simclr_transform import SimCLRTransform +from lightly.transforms.byol_transform import ( + BYOLTransform, + BYOLView1Transform, + BYOLView2Transform, +) class BarlowTwins(pl.LightningModule): @@ -39,7 +43,12 @@ def configure_optimizers(self): model = BarlowTwins() -transform = SimCLRTransform(input_size=32) +# BarlowTwins uses BYOL augmentations. +# We disable resizing and gaussian blur for cifar10. +transform = BYOLTransform( + view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0), + view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0), +) dataset = torchvision.datasets.CIFAR10( "datasets/cifar10", download=True, transform=transform ) diff --git a/examples/pytorch_lightning/byol.py b/examples/pytorch_lightning/byol.py index 2b227d34d..a87feb4df 100644 --- a/examples/pytorch_lightning/byol.py +++ b/examples/pytorch_lightning/byol.py @@ -12,7 +12,11 @@ from lightly.loss import NegativeCosineSimilarity from lightly.models.modules import BYOLPredictionHead, BYOLProjectionHead from lightly.models.utils import deactivate_requires_grad, update_momentum -from lightly.transforms.simclr_transform import SimCLRTransform +from lightly.transforms.byol_transform import ( + BYOLTransform, + BYOLView1Transform, + BYOLView2Transform, +) from lightly.utils.scheduler import cosine_schedule @@ -62,7 +66,11 @@ def configure_optimizers(self): model = BYOL() -transform = SimCLRTransform(input_size=32) +# We disable resizing and gaussian blur for cifar10. +transform = BYOLTransform( + view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0), + view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0), +) dataset = torchvision.datasets.CIFAR10( "datasets/cifar10", download=True, transform=transform ) diff --git a/examples/pytorch_lightning/tico.py b/examples/pytorch_lightning/tico.py index b55246549..48f9d0db1 100644 --- a/examples/pytorch_lightning/tico.py +++ b/examples/pytorch_lightning/tico.py @@ -8,7 +8,11 @@ from lightly.loss.tico_loss import TiCoLoss from lightly.models.modules.heads import TiCoProjectionHead from lightly.models.utils import deactivate_requires_grad, update_momentum -from lightly.transforms.simclr_transform import SimCLRTransform +from lightly.transforms.byol_transform import ( + BYOLTransform, + BYOLView1Transform, + BYOLView2Transform, +) from lightly.utils.scheduler import cosine_schedule @@ -56,7 +60,12 @@ def configure_optimizers(self): model = TiCo() -transform = SimCLRTransform(input_size=32) +# TiCo uses BYOL augmentations. +# We disable resizing and gaussian blur for cifar10. 
+transform = BYOLTransform( + view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0), + view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0), +) dataset = torchvision.datasets.CIFAR10( "datasets/cifar10", download=True, transform=transform ) diff --git a/examples/pytorch_lightning_distributed/barlowtwins.py b/examples/pytorch_lightning_distributed/barlowtwins.py index c837dde04..876d7e5f2 100644 --- a/examples/pytorch_lightning_distributed/barlowtwins.py +++ b/examples/pytorch_lightning_distributed/barlowtwins.py @@ -9,7 +9,11 @@ from lightly.loss import BarlowTwinsLoss from lightly.models.modules import BarlowTwinsProjectionHead -from lightly.transforms.simclr_transform import SimCLRTransform +from lightly.transforms.byol_transform import ( + BYOLTransform, + BYOLView1Transform, + BYOLView2Transform, +) class BarlowTwins(pl.LightningModule): @@ -42,7 +46,12 @@ def configure_optimizers(self): model = BarlowTwins() -transform = SimCLRTransform(input_size=32) +# BarlowTwins uses BYOL augmentations. +# We disable resizing and gaussian blur for cifar10. +transform = BYOLTransform( + view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0), + view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0), +) dataset = torchvision.datasets.CIFAR10( "datasets/cifar10", download=True, transform=transform ) diff --git a/examples/pytorch_lightning_distributed/byol.py b/examples/pytorch_lightning_distributed/byol.py index 5dd49a16f..6fb9d88fe 100644 --- a/examples/pytorch_lightning_distributed/byol.py +++ b/examples/pytorch_lightning_distributed/byol.py @@ -13,7 +13,11 @@ from lightly.models.modules import BYOLProjectionHead from lightly.models.modules.heads import BYOLPredictionHead from lightly.models.utils import deactivate_requires_grad, update_momentum -from lightly.transforms.simclr_transform import SimCLRTransform +from lightly.transforms.byol_transform import ( + BYOLTransform, + BYOLView1Transform, + BYOLView2Transform, +) from lightly.utils.scheduler import cosine_schedule @@ -63,7 +67,11 @@ def configure_optimizers(self): model = BYOL() -transform = SimCLRTransform(input_size=32) +# We disable resizing and gaussian blur for cifar10. +transform = BYOLTransform( + view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0), + view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0), +) dataset = torchvision.datasets.CIFAR10( "datasets/cifar10", download=True, transform=transform ) diff --git a/examples/pytorch_lightning_distributed/tico.py b/examples/pytorch_lightning_distributed/tico.py index 0ba938bd5..5ed02bad2 100644 --- a/examples/pytorch_lightning_distributed/tico.py +++ b/examples/pytorch_lightning_distributed/tico.py @@ -8,7 +8,11 @@ from lightly.loss.tico_loss import TiCoLoss from lightly.models.modules.heads import TiCoProjectionHead from lightly.models.utils import deactivate_requires_grad, update_momentum -from lightly.transforms.simclr_transform import SimCLRTransform +from lightly.transforms.byol_transform import ( + BYOLTransform, + BYOLView1Transform, + BYOLView2Transform, +) from lightly.utils.scheduler import cosine_schedule @@ -56,7 +60,12 @@ def configure_optimizers(self): model = TiCo() -transform = SimCLRTransform(input_size=32) +# TiCo uses BYOL augmentations. +# We disable resizing and gaussian blur for cifar10. 
+transform = BYOLTransform( + view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0), + view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0), +) dataset = torchvision.datasets.CIFAR10( "datasets/cifar10", download=True, transform=transform ) diff --git a/lightly/transforms/__init__.py b/lightly/transforms/__init__.py index 59fafe34d..33c70bdad 100644 --- a/lightly/transforms/__init__.py +++ b/lightly/transforms/__init__.py @@ -8,6 +8,11 @@ # Copyright (c) 2020. Lightly AG and its affiliates. # All Rights Reserved +from lightly.transforms.byol_transform import ( + BYOLTransform, + BYOLView1Transform, + BYOLView2Transform, +) from lightly.transforms.dino_transform import DINOTransform, DINOViewTransform from lightly.transforms.fast_siam_transform import FastSiamTransform from lightly.transforms.gaussian_blur import GaussianBlur diff --git a/lightly/transforms/byol_transform.py b/lightly/transforms/byol_transform.py new file mode 100644 index 000000000..d3eff6a5e --- /dev/null +++ b/lightly/transforms/byol_transform.py @@ -0,0 +1,175 @@ +from typing import Optional, Tuple, Union + +import torchvision.transforms as T +from PIL.Image import Image +from torch import Tensor + +from lightly.transforms.gaussian_blur import GaussianBlur +from lightly.transforms.multi_view_transform import MultiViewTransform +from lightly.transforms.rotation import random_rotation_transform +from lightly.transforms.solarize import RandomSolarization +from lightly.transforms.utils import IMAGENET_NORMALIZE + + +class BYOLView1Transform: + def __init__( + self, + input_size: int = 224, + cj_prob: float = 0.8, + cj_strength: float = 1.0, + cj_bright: float = 0.4, + cj_contrast: float = 0.4, + cj_sat: float = 0.2, + cj_hue: float = 0.1, + min_scale: float = 0.08, + random_gray_scale: float = 0.2, + gaussian_blur: float = 1.0, + solarization_prob: float = 0.0, + kernel_size: Optional[float] = None, + sigmas: Tuple[float, float] = (0.1, 2), + vf_prob: float = 0.0, + hf_prob: float = 0.5, + rr_prob: float = 0.0, + rr_degrees: Union[None, float, Tuple[float, float]] = None, + normalize: Union[None, dict] = IMAGENET_NORMALIZE, + ): + color_jitter = T.ColorJitter( + brightness=cj_strength * cj_bright, + contrast=cj_strength * cj_contrast, + saturation=cj_strength * cj_sat, + hue=cj_strength * cj_hue, + ) + + transform = [ + T.RandomResizedCrop(size=input_size, scale=(min_scale, 1.0)), + random_rotation_transform(rr_prob=rr_prob, rr_degrees=rr_degrees), + T.RandomHorizontalFlip(p=hf_prob), + T.RandomVerticalFlip(p=vf_prob), + T.RandomApply([color_jitter], p=cj_prob), + T.RandomGrayscale(p=random_gray_scale), + GaussianBlur(kernel_size=kernel_size, sigmas=sigmas, prob=gaussian_blur), + RandomSolarization(prob=solarization_prob), + T.ToTensor(), + ] + if normalize: + transform += [T.Normalize(mean=normalize["mean"], std=normalize["std"])] + self.transform = T.Compose(transform) + + def __call__(self, image: Union[Tensor, Image]) -> Tensor: + """ + Applies the transforms to the input image. + + Args: + image: + The input image to apply the transforms to. + + Returns: + The transformed image. 
+
+        """
+        return self.transform(image)
+
+
+class BYOLView2Transform:
+    def __init__(
+        self,
+        input_size: int = 224,
+        cj_prob: float = 0.8,
+        cj_strength: float = 1.0,
+        cj_bright: float = 0.4,
+        cj_contrast: float = 0.4,
+        cj_sat: float = 0.2,
+        cj_hue: float = 0.1,
+        min_scale: float = 0.08,
+        random_gray_scale: float = 0.2,
+        gaussian_blur: float = 0.1,
+        solarization_prob: float = 0.2,
+        kernel_size: Optional[float] = None,
+        sigmas: Tuple[float, float] = (0.1, 2),
+        vf_prob: float = 0.0,
+        hf_prob: float = 0.5,
+        rr_prob: float = 0.0,
+        rr_degrees: Union[None, float, Tuple[float, float]] = None,
+        normalize: Union[None, dict] = IMAGENET_NORMALIZE,
+    ):
+        color_jitter = T.ColorJitter(
+            brightness=cj_strength * cj_bright,
+            contrast=cj_strength * cj_contrast,
+            saturation=cj_strength * cj_sat,
+            hue=cj_strength * cj_hue,
+        )
+
+        transform = [
+            T.RandomResizedCrop(size=input_size, scale=(min_scale, 1.0)),
+            random_rotation_transform(rr_prob=rr_prob, rr_degrees=rr_degrees),
+            T.RandomHorizontalFlip(p=hf_prob),
+            T.RandomVerticalFlip(p=vf_prob),
+            T.RandomApply([color_jitter], p=cj_prob),
+            T.RandomGrayscale(p=random_gray_scale),
+            GaussianBlur(kernel_size=kernel_size, sigmas=sigmas, prob=gaussian_blur),
+            RandomSolarization(prob=solarization_prob),
+            T.ToTensor(),
+        ]
+        if normalize:
+            transform += [T.Normalize(mean=normalize["mean"], std=normalize["std"])]
+        self.transform = T.Compose(transform)
+
+    def __call__(self, image: Union[Tensor, Image]) -> Tensor:
+        """
+        Applies the transforms to the input image.
+
+        Args:
+            image:
+                The input image to apply the transforms to.
+
+        Returns:
+            The transformed image.
+
+        """
+        return self.transform(image)
+
+
+class BYOLTransform(MultiViewTransform):
+    """Implements the transformations for BYOL[0].
+
+    Input to this transform:
+        PIL Image or Tensor.
+
+    Output of this transform:
+        List of Tensor of length 2.
+
+    Applies the following augmentations by default:
+        - Random resized crop
+        - Random horizontal flip
+        - Color jitter
+        - Random gray scale
+        - Gaussian blur
+        - Solarization
+        - ImageNet normalization
+
+    Note that BYOL uses augmentations similar to those of SimCLR v1 and v2. The
+    main differences are that BYOL applies gaussian blur and solarization
+    asymmetrically across the two views and uses a weaker color jitter.
+
+    - [0]: Bootstrap Your Own Latent, 2020, https://arxiv.org/pdf/2006.07733.pdf
+
+    Attributes:
+        view_1_transform: The transform for the first view.
+        view_2_transform: The transform for the second view.
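+
+    Example, mirroring the usage in the repository examples and tests (the
+    blank PIL image only serves as a placeholder input):
+
+        >>> from PIL import Image
+        >>>
+        >>> # Default transform with the asymmetric view settings from the paper.
+        >>> transform = BYOLTransform()
+        >>>
+        >>> # Per-view overrides, e.g. for 32x32 cifar10 images without blur.
+        >>> transform = BYOLTransform(
+        ...     view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0),
+        ...     view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0),
+        ... )
+        >>>
+        >>> # Calling the transform returns one augmented tensor per view.
+        >>> view_1, view_2 = transform(Image.new("RGB", (100, 100)))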
+ """ + + def __init__( + self, + view_1_transform: Union[BYOLView1Transform, None] = None, + view_2_transform: Union[BYOLView2Transform, None] = None, + ): + # We need to initialize the transforms here + view_1_transform = view_1_transform or BYOLView1Transform() + view_2_transform = view_2_transform or BYOLView2Transform() + super().__init__(transforms=[view_1_transform, view_2_transform]) diff --git a/tests/transforms/test_byol_transform.py b/tests/transforms/test_byol_transform.py new file mode 100644 index 000000000..b13950709 --- /dev/null +++ b/tests/transforms/test_byol_transform.py @@ -0,0 +1,26 @@ +from PIL import Image + +from lightly.transforms.byol_transform import ( + BYOLTransform, + BYOLView1Transform, + BYOLView2Transform, +) + + +def test_view_on_pil_image(): + single_view_transform = BYOLView1Transform(input_size=32) + sample = Image.new("RGB", (100, 100)) + output = single_view_transform(sample) + assert output.shape == (3, 32, 32) + + +def test_multi_view_on_pil_image(): + multi_view_transform = BYOLTransform( + view_1_transform=BYOLView1Transform(input_size=32), + view_2_transform=BYOLView2Transform(input_size=32), + ) + sample = Image.new("RGB", (100, 100)) + output = multi_view_transform(sample) + assert len(output) == 2 + assert output[0].shape == (3, 32, 32) + assert output[1].shape == (3, 32, 32)