Skip to content

Commit

Permalink
Add transforms for BYOL (#1325)
Browse files Browse the repository at this point in the history
* Add transforms for BYOL

* Add test

* Fix test

* Update BYOL transforms

* Update all transforms

* Fix typos

* Implement feedback
  • Loading branch information
IgorSusmelj authored Jul 19, 2023
1 parent 9e1ca15 commit 775829b
Show file tree
Hide file tree
Showing 16 changed files with 339 additions and 31 deletions.
8 changes: 5 additions & 3 deletions benchmarks/imagenet/resnet50/byol.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from lightly.loss import NegativeCosineSimilarity
from lightly.models.modules import BYOLPredictionHead, BYOLProjectionHead
from lightly.models.utils import get_weight_decay_parameters, update_momentum
from lightly.transforms import SimCLRTransform
from lightly.transforms import BYOLTransform
from lightly.utils.benchmarking import OnlineLinearClassifier
from lightly.utils.lars import LARS
from lightly.utils.scheduler import CosineWarmupScheduler, cosine_schedule
Expand Down Expand Up @@ -144,5 +144,7 @@ def configure_optimizers(self):
return [optimizer], [scheduler]


# BYOL uses same transform as SimCLR.
transform = SimCLRTransform()
# BYOL uses a slight modification of the SimCLR transforms.
# It uses asymmetric augmentation and solarize.
# Check table 6 in the BYOL paper for more info.
transform = BYOLTransform()
13 changes: 11 additions & 2 deletions docs/source/getting_started/benchmarks/cifar10_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@
from lightly.models import ResNetGenerator, modules, utils
from lightly.models.modules import heads
from lightly.transforms import (
BYOLTransform,
BYOLView1Transform,
BYOLView2Transform,
DINOTransform,
FastSiamTransform,
SimCLRTransform,
Expand Down Expand Up @@ -161,6 +164,12 @@
path_to_train = "/datasets/cifar10/train/"
path_to_test = "/datasets/cifar10/test/"

# Use BYOL augmentations
byol_transform = BYOLTransform(
view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0),
view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0),
)

# Use SimCLR augmentations
simclr_transform = SimCLRTransform(
input_size=32,
Expand Down Expand Up @@ -229,8 +238,8 @@ def create_dataset_train_ssl(model):
Model class for which to select the transform.
"""
model_to_transform = {
BarlowTwinsModel: simclr_transform,
BYOLModel: simclr_transform,
BarlowTwinsModel: byol_transform,
BYOLModel: byol_transform,
DCL: simclr_transform,
DCLW: simclr_transform,
DINOModel: dino_transform,
Expand Down
14 changes: 9 additions & 5 deletions docs/source/getting_started/benchmarks/imagenet100_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from lightly.models import modules, utils
from lightly.models.modules import heads
from lightly.transforms import (
BYOLTransform,
DINOTransform,
FastSiamTransform,
SimCLRTransform,
Expand Down Expand Up @@ -108,14 +109,17 @@
path_to_train = "/datasets/imagenet100/train/"
path_to_test = "/datasets/imagenet100/val/"

# Use BYOL augmentations
byol_transform = BYOLTransform()

# Use SimCLR augmentations
simclr_transform = SimCLRTransform(input_size=input_size)
simclr_transform = SimCLRTransform()

# Use SimSiam augmentations
simsiam_transform = SimSiamTransform(input_size=input_size)
simsiam_transform = SimSiamTransform()

# Multi crop augmentation for FastSiam
fast_siam_transform = FastSiamTransform(input_size=input_size)
fast_siam_transform = FastSiamTransform()

# Multi crop augmentation for SwAV
swav_transform = SwaVTransform()
Expand Down Expand Up @@ -155,8 +159,8 @@ def create_dataset_train_ssl(model):
Model class for which to select the transform.
"""
model_to_transform = {
BarlowTwinsModel: simclr_transform,
BYOLModel: simclr_transform,
BarlowTwinsModel: byol_transform,
BYOLModel: byol_transform,
DINOModel: dino_transform,
FastSiamModel: fast_siam_transform,
MocoModel: simclr_transform,
Expand Down
15 changes: 12 additions & 3 deletions docs/source/getting_started/benchmarks/imagenette_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@
from lightly.models import modules, utils
from lightly.models.modules import heads, masked_autoencoder
from lightly.transforms import (
BYOLTransform,
BYOLView1Transform,
BYOLView2Transform,
DINOTransform,
FastSiamTransform,
MAETransform,
Expand Down Expand Up @@ -154,6 +157,12 @@
path_to_train = "/datasets/imagenette2-160/train/"
path_to_test = "/datasets/imagenette2-160/val/"

# Use BYOL augmentations
byol_transform = BYOLTransform(
view_1_transform=BYOLView1Transform(input_size=input_size),
view_2_transform=BYOLView2Transform(input_size=input_size),
)

# Use SimCLR augmentations
simclr_transform = SimCLRTransform(
input_size=input_size,
Expand Down Expand Up @@ -244,8 +253,8 @@ def create_dataset_train_ssl(model):
Model class for which to select the transform.
"""
model_to_transform = {
BarlowTwinsModel: simclr_transform,
BYOLModel: simclr_transform,
BarlowTwinsModel: byol_transform,
BYOLModel: byol_transform,
DCL: simclr_transform,
DCLW: simclr_transform,
DINOModel: dino_transform,
Expand All @@ -261,7 +270,7 @@ def create_dataset_train_ssl(model):
SwaVModel: swav_transform,
SwaVQueueModel: swav_transform,
SMoGModel: smog_transform,
TiCoModel: simclr_transform,
TiCoModel: byol_transform,
VICRegModel: vicreg_transform,
VICRegLModel: vicregl_transform,
}
Expand Down
13 changes: 11 additions & 2 deletions examples/pytorch/barlowtwins.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@

from lightly.loss import BarlowTwinsLoss
from lightly.models.modules import BarlowTwinsProjectionHead
from lightly.transforms.simclr_transform import SimCLRTransform
from lightly.transforms.byol_transform import (
BYOLTransform,
BYOLView1Transform,
BYOLView2Transform,
)


class BarlowTwins(nn.Module):
Expand All @@ -30,7 +34,12 @@ def forward(self, x):
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

transform = SimCLRTransform(input_size=32)
# BarlowTwins uses BYOL augmentations.
# We disable resizing and gaussian blur for cifar10.
transform = BYOLTransform(
view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0),
view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0),
)
dataset = torchvision.datasets.CIFAR10(
"datasets/cifar10", download=True, transform=transform
)
Expand Down
12 changes: 10 additions & 2 deletions examples/pytorch/byol.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
from lightly.loss import NegativeCosineSimilarity
from lightly.models.modules import BYOLPredictionHead, BYOLProjectionHead
from lightly.models.utils import deactivate_requires_grad, update_momentum
from lightly.transforms.simclr_transform import SimCLRTransform
from lightly.transforms.byol_transform import (
BYOLTransform,
BYOLView1Transform,
BYOLView2Transform,
)
from lightly.utils.scheduler import cosine_schedule


Expand Down Expand Up @@ -49,7 +53,11 @@ def forward_momentum(self, x):
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

transform = SimCLRTransform(input_size=32)
# We disable resizing and gaussian blur for cifar10.
transform = BYOLTransform(
view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0),
view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0),
)
dataset = torchvision.datasets.CIFAR10(
"datasets/cifar10", download=True, transform=transform
)
Expand Down
13 changes: 11 additions & 2 deletions examples/pytorch/tico.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
from lightly.loss.tico_loss import TiCoLoss
from lightly.models.modules.heads import TiCoProjectionHead
from lightly.models.utils import deactivate_requires_grad, update_momentum
from lightly.transforms.simclr_transform import SimCLRTransform
from lightly.transforms.byol_transform import (
BYOLTransform,
BYOLView1Transform,
BYOLView2Transform,
)
from lightly.utils.scheduler import cosine_schedule


Expand Down Expand Up @@ -47,7 +51,12 @@ def forward_momentum(self, x):
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

transform = SimCLRTransform(input_size=32)
# TiCo uses BYOL augmentations.
# We disable resizing and gaussian blur for cifar10.
transform = BYOLTransform(
view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0),
view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0),
)
dataset = torchvision.datasets.CIFAR10(
"datasets/cifar10", download=True, transform=transform
)
Expand Down
13 changes: 11 additions & 2 deletions examples/pytorch_lightning/barlowtwins.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@

from lightly.loss import BarlowTwinsLoss
from lightly.models.modules import BarlowTwinsProjectionHead
from lightly.transforms.simclr_transform import SimCLRTransform
from lightly.transforms.byol_transform import (
BYOLTransform,
BYOLView1Transform,
BYOLView2Transform,
)


class BarlowTwins(pl.LightningModule):
Expand Down Expand Up @@ -39,7 +43,12 @@ def configure_optimizers(self):

model = BarlowTwins()

transform = SimCLRTransform(input_size=32)
# BarlowTwins uses BYOL augmentations.
# We disable resizing and gaussian blur for cifar10.
transform = BYOLTransform(
view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0),
view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0),
)
dataset = torchvision.datasets.CIFAR10(
"datasets/cifar10", download=True, transform=transform
)
Expand Down
12 changes: 10 additions & 2 deletions examples/pytorch_lightning/byol.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
from lightly.loss import NegativeCosineSimilarity
from lightly.models.modules import BYOLPredictionHead, BYOLProjectionHead
from lightly.models.utils import deactivate_requires_grad, update_momentum
from lightly.transforms.simclr_transform import SimCLRTransform
from lightly.transforms.byol_transform import (
BYOLTransform,
BYOLView1Transform,
BYOLView2Transform,
)
from lightly.utils.scheduler import cosine_schedule


Expand Down Expand Up @@ -62,7 +66,11 @@ def configure_optimizers(self):

model = BYOL()

transform = SimCLRTransform(input_size=32)
# We disable resizing and gaussian blur for cifar10.
transform = BYOLTransform(
view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0),
view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0),
)
dataset = torchvision.datasets.CIFAR10(
"datasets/cifar10", download=True, transform=transform
)
Expand Down
13 changes: 11 additions & 2 deletions examples/pytorch_lightning/tico.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@
from lightly.loss.tico_loss import TiCoLoss
from lightly.models.modules.heads import TiCoProjectionHead
from lightly.models.utils import deactivate_requires_grad, update_momentum
from lightly.transforms.simclr_transform import SimCLRTransform
from lightly.transforms.byol_transform import (
BYOLTransform,
BYOLView1Transform,
BYOLView2Transform,
)
from lightly.utils.scheduler import cosine_schedule


Expand Down Expand Up @@ -56,7 +60,12 @@ def configure_optimizers(self):

model = TiCo()

transform = SimCLRTransform(input_size=32)
# TiCo uses BYOL augmentations.
# We disable resizing and gaussian blur for cifar10.
transform = BYOLTransform(
view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0),
view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0),
)
dataset = torchvision.datasets.CIFAR10(
"datasets/cifar10", download=True, transform=transform
)
Expand Down
13 changes: 11 additions & 2 deletions examples/pytorch_lightning_distributed/barlowtwins.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@

from lightly.loss import BarlowTwinsLoss
from lightly.models.modules import BarlowTwinsProjectionHead
from lightly.transforms.simclr_transform import SimCLRTransform
from lightly.transforms.byol_transform import (
BYOLTransform,
BYOLView1Transform,
BYOLView2Transform,
)


class BarlowTwins(pl.LightningModule):
Expand Down Expand Up @@ -42,7 +46,12 @@ def configure_optimizers(self):

model = BarlowTwins()

transform = SimCLRTransform(input_size=32)
# BarlowTwins uses BYOL augmentations.
# We disable resizing and gaussian blur for cifar10.
transform = BYOLTransform(
view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0),
view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0),
)
dataset = torchvision.datasets.CIFAR10(
"datasets/cifar10", download=True, transform=transform
)
Expand Down
12 changes: 10 additions & 2 deletions examples/pytorch_lightning_distributed/byol.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@
from lightly.models.modules import BYOLProjectionHead
from lightly.models.modules.heads import BYOLPredictionHead
from lightly.models.utils import deactivate_requires_grad, update_momentum
from lightly.transforms.simclr_transform import SimCLRTransform
from lightly.transforms.byol_transform import (
BYOLTransform,
BYOLView1Transform,
BYOLView2Transform,
)
from lightly.utils.scheduler import cosine_schedule


Expand Down Expand Up @@ -63,7 +67,11 @@ def configure_optimizers(self):

model = BYOL()

transform = SimCLRTransform(input_size=32)
# We disable resizing and gaussian blur for cifar10.
transform = BYOLTransform(
view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0),
view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0),
)
dataset = torchvision.datasets.CIFAR10(
"datasets/cifar10", download=True, transform=transform
)
Expand Down
13 changes: 11 additions & 2 deletions examples/pytorch_lightning_distributed/tico.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@
from lightly.loss.tico_loss import TiCoLoss
from lightly.models.modules.heads import TiCoProjectionHead
from lightly.models.utils import deactivate_requires_grad, update_momentum
from lightly.transforms.simclr_transform import SimCLRTransform
from lightly.transforms.byol_transform import (
BYOLTransform,
BYOLView1Transform,
BYOLView2Transform,
)
from lightly.utils.scheduler import cosine_schedule


Expand Down Expand Up @@ -56,7 +60,12 @@ def configure_optimizers(self):

model = TiCo()

transform = SimCLRTransform(input_size=32)
# TiCo uses BYOL augmentations.
# We disable resizing and gaussian blur for cifar10.
transform = BYOLTransform(
view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0),
view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0),
)
dataset = torchvision.datasets.CIFAR10(
"datasets/cifar10", download=True, transform=transform
)
Expand Down
5 changes: 5 additions & 0 deletions lightly/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
# Copyright (c) 2020. Lightly AG and its affiliates.
# All Rights Reserved

from lightly.transforms.byol_transform import (
BYOLTransform,
BYOLView1Transform,
BYOLView2Transform,
)
from lightly.transforms.dino_transform import DINOTransform, DINOViewTransform
from lightly.transforms.fast_siam_transform import FastSiamTransform
from lightly.transforms.gaussian_blur import GaussianBlur
Expand Down
Loading

0 comments on commit 775829b

Please sign in to comment.