From 3234937b514047ab2b9774abc54cbb6e945ab447 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Niklas=20B=C3=B6hm?= Date: Tue, 30 Jan 2024 20:23:11 +0100 Subject: [PATCH] Add default resnet18 from torch to model dict --- tsimcne/models/simclr_like.py | 9 ++++++++- tsimcne/tsimcne.py | 12 +++++++----- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/tsimcne/models/simclr_like.py b/tsimcne/models/simclr_like.py index 62b508b..e7d9f95 100644 --- a/tsimcne/models/simclr_like.py +++ b/tsimcne/models/simclr_like.py @@ -357,6 +357,12 @@ def forward(self, x): return self.model.avgpool(feat) +def pytorch_resnet(in_channel=3): + model = torchvision.models.resnet18(pretrained=False) + model.fc = torch.nn.Identity() + return model + + def resnet18(**kwargs): return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) @@ -394,7 +400,8 @@ def efficientnet_v2_l(**kwargs): model_dict = { - "resnet18": [resnet18, 512], + "resnet18": [pytorch_resnet, 512], + "resnet18_simclr": [resnet18, 512], "resnet34": [resnet34, 512], "resnet50": [resnet50, 2048], "resnet101": [resnet101, 2048], diff --git a/tsimcne/tsimcne.py b/tsimcne/tsimcne.py index 1fe8385..1136047 100644 --- a/tsimcne/tsimcne.py +++ b/tsimcne/tsimcne.py @@ -20,7 +20,7 @@ def __init__( model=None, loss="infonce", metric=None, - backbone="resnet18", + backbone="resnet18_simclr", projection_head="mlp", n_epochs=100, batch_size=512, @@ -160,9 +160,11 @@ class TSimCNE: kernel). Another option is ``"cosine"`` to get the default SimCLR loss. - :param "resnet18" backbone: Backbone to use for the contrastive - model. Defaults to ResNet18. Other options are - ``"resnet50"``, etc. or simply pass in a torch model directly. + :param "resnet18_simclr" backbone: Backbone to use for the + contrastive model. Defaults to ResNet18 as defined in the + original SimCLR paper (so with a smaller kernel size). Other + options are ``"resnet50"``, etc. or simply pass in a torch + model directly. :param "mlp" projection_head: The projection head that maps from the backbone features down to the ``"out_dim"``. Also accepts @@ -275,7 +277,7 @@ def __init__( model=None, loss="infonce", metric=None, - backbone="resnet18", + backbone="resnet18_simclr", projection_head="mlp", data_transform=None, total_epochs=[1000, 50, 450],