diff --git a/tsimcne/tsimcne.py b/tsimcne/tsimcne.py
index f22d521..f4b3cdf 100644
--- a/tsimcne/tsimcne.py
+++ b/tsimcne/tsimcne.py
@@ -25,6 +25,7 @@ def __init__(
         n_epochs=100,
         batch_size=512,
         out_dim=2,
+        pretrain_out_dim=128,
         optimizer_name="sgd",
         lr_scheduler_name="cos_annealing",
         lr="auto_batch",
@@ -41,6 +42,7 @@ def __init__(
         self.projection_head = projection_head
         self.n_epochs = n_epochs
         self.out_dim = out_dim
+        self.pretrain_out_dim = pretrain_out_dim
         self.batch_size = batch_size
         self.optimizer_name = optimizer_name
         self.lr_scheduler_name = lr_scheduler_name
@@ -57,7 +59,7 @@ def _handle_parameters(self):
         self.model = make_model(
             backbone=self.backbone,
             projection_head=self.projection_head,
-            out_dim=self.out_dim,
+            out_dim=self.pretrain_out_dim,
         )
 
         if self.loss == "infonce":
@@ -201,6 +203,10 @@ class TSimCNE:
        of dimensions (so it could also be used to plot a dataset
        in 3D, for example).
 
+
+    :param 128 pretrain_out_dim: The number of output dimensions
+       during pretraining (the first stage).
+
     :param "sgd" optimizer: The optimizer to use. Currently only
        ``"sgd"`` is allowed.
 
@@ -255,6 +261,7 @@ def __init__(
         total_epochs=[1000, 50, 450],
         batch_size=512,
         out_dim=2,
+        pretrain_out_dim=128,
         optimizer="sgd",
         lr_scheduler="cos_annealing",
         lr="auto_batch",
@@ -272,6 +279,7 @@ def __init__(
         self.projection_head = projection_head
         self.data_transform = data_transform
         self.out_dim = out_dim
+        self.pretrain_out_dim = pretrain_out_dim
         self.batch_size = batch_size
         self.optimizer = optimizer
         self.lr_scheduler = lr_scheduler
@@ -440,7 +448,7 @@ def fit(self, X: torch.utils.data.Dataset | str):
             metric=self.metric,
             backbone=self.backbone,
             projection_head=self.projection_head,
-            out_dim=128,
+            out_dim=self.pretrain_out_dim,
             **train_kwargs,
         )
 
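Usage sketch (not part of the diff): with this change, the output dimensionality of the projection head during the first (pretraining) stage becomes configurable instead of being hard-coded to 128. A minimal, hypothetical example, assuming the package's top-level `TSimCNE` import and the `fit`/`transform` methods shown in the diff; `dataset` and the value 256 are placeholders:

    from tsimcne import TSimCNE

    # `dataset` is a placeholder torch.utils.data.Dataset yielding images.
    # pretrain_out_dim controls the projection head's output size in the
    # pretraining stage; out_dim remains the final embedding dimensionality.
    tsimcne = TSimCNE(out_dim=2, pretrain_out_dim=256)
    tsimcne.fit(dataset)
    Y = tsimcne.transform(dataset)  # 2D embedding of the dataset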