Skip to content

Commit

Permalink
Add default resnet18 from torch to model dict
Browse files Browse the repository at this point in the history
  • Loading branch information
Niklas Böhm committed Jan 30, 2024
1 parent 830bf0b commit 3234937
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
9 changes: 8 additions & 1 deletion tsimcne/models/simclr_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,12 @@ def forward(self, x):
return self.model.avgpool(feat)


def pytorch_resnet(in_channel=3):
model = torchvision.models.resnet18(pretrained=False)
model.fc = torch.nn.Identity()
return model


def resnet18(**kwargs):
return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)

Expand Down Expand Up @@ -394,7 +400,8 @@ def efficientnet_v2_l(**kwargs):


model_dict = {
"resnet18": [resnet18, 512],
"resnet18": [pytorch_resnet, 512],
"resnet18_simclr": [resnet18, 512],
"resnet34": [resnet34, 512],
"resnet50": [resnet50, 2048],
"resnet101": [resnet101, 2048],
Expand Down
12 changes: 7 additions & 5 deletions tsimcne/tsimcne.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(
model=None,
loss="infonce",
metric=None,
backbone="resnet18",
backbone="resnet18_simclr",
projection_head="mlp",
n_epochs=100,
batch_size=512,
Expand Down Expand Up @@ -160,9 +160,11 @@ class TSimCNE:
kernel). Another option is ``"cosine"`` to get the default
SimCLR loss.
:param "resnet18" backbone: Backbone to use for the contrastive
model. Defaults to ResNet18. Other options are
``"resnet50"``, etc. or simply pass in a torch model directly.
:param "resnet18_simclr" backbone: Backbone to use for the
contrastive model. Defaults to ResNet18 as defined in the
original SimCLR paper (so with a smaller kernel size). Other
options are ``"resnet50"``, etc. or simply pass in a torch
model directly.
:param "mlp" projection_head: The projection head that maps from
the backbone features down to the ``"out_dim"``. Also accepts
Expand Down Expand Up @@ -275,7 +277,7 @@ def __init__(
model=None,
loss="infonce",
metric=None,
backbone="resnet18",
backbone="resnet18_simclr",
projection_head="mlp",
data_transform=None,
total_epochs=[1000, 50, 450],
Expand Down

0 comments on commit 3234937

Please sign in to comment.