From d96702d7f64cd45c4b746d17d72b4a972c94eb2e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Niklas=20B=C3=B6hm?= Date: Sun, 15 Dec 2024 13:10:11 +0100 Subject: [PATCH] Remove unnecesasry code from simclr_like.py --- tsimcne/models/simclr_like.py | 22 ---------------------- 1 file changed, 22 deletions(-) diff --git a/tsimcne/models/simclr_like.py b/tsimcne/models/simclr_like.py index 83b481d..468ba84 100644 --- a/tsimcne/models/simclr_like.py +++ b/tsimcne/models/simclr_like.py @@ -3,8 +3,6 @@ import torch.nn.functional as F import torchvision -from ..base import ProjectBase - def make_model( backbone, @@ -83,26 +81,6 @@ def make_projection_head(name="mlp", in_dim=512, hidden_dim=1024, out_dim=128): ) -class SimCLRModel(ProjectBase): - def __init__(self, path, random_state=None, **kwargs): - super().__init__(path, random_state=random_state) - self.kwargs = kwargs - - def get_deps(self): - return [] - - def load(self): - pass - - def compute(self): - self.torch_seed = self.random_state.integers(2**64 - 1, dtype="uint") - self.model = make_model(**self.kwargs, seed=self.torch_seed) - - def save(self): - save_data = dict(model=self.model, model_sd=self.model.state_dict()) - self.save_lambda_alt(self.outdir / "model.pt", save_data, torch.save) - - class ContrastiveFC(nn.Module): def __init__( self,