Add param for the output dim of 1st stage training
Niklas Böhm committed Dec 7, 2023
1 parent 5c3a658 commit b4fd415
Showing 1 changed file with 10 additions and 2 deletions.
tsimcne/tsimcne.py (10 additions, 2 deletions)
@@ -25,6 +25,7 @@ def __init__(
         n_epochs=100,
         batch_size=512,
         out_dim=2,
+        pretrain_out_dim=128,
         optimizer_name="sgd",
         lr_scheduler_name="cos_annealing",
         lr="auto_batch",
@@ -41,6 +42,7 @@ def __init__(
         self.projection_head = projection_head
         self.n_epochs = n_epochs
         self.out_dim = out_dim
+        self.pretrain_out_dim = pretrain_out_dim
         self.batch_size = batch_size
         self.optimizer_name = optimizer_name
         self.lr_scheduler_name = lr_scheduler_name
@@ -57,7 +59,7 @@ def _handle_parameters(self):
         self.model = make_model(
             backbone=self.backbone,
             projection_head=self.projection_head,
-            out_dim=self.out_dim,
+            out_dim=self.pretrain_out_dim,
         )

         if self.loss == "infonce":
@@ -201,6 +203,10 @@ class TSimCNE:
     of dimensions (so it could also be used to plot a dataset in
     3D, for example).
+    :param 128 pretrain_out_dim: The number of output dimensions
+        during pretraining (the first stage).
     :param "sgd" optimizer: The optimizer to use. Currently only
     ``"sgd"`` is allowed.
@@ -255,6 +261,7 @@ def __init__(
         total_epochs=[1000, 50, 450],
         batch_size=512,
         out_dim=2,
+        pretrain_out_dim=128,
         optimizer="sgd",
         lr_scheduler="cos_annealing",
         lr="auto_batch",
@@ -272,6 +279,7 @@ def __init__(
         self.projection_head = projection_head
         self.data_transform = data_transform
         self.out_dim = out_dim
+        self.pretrain_out_dim = pretrain_out_dim
         self.batch_size = batch_size
         self.optimizer = optimizer
         self.lr_scheduler = lr_scheduler
@@ -440,7 +448,7 @@ def fit(self, X: torch.utils.data.Dataset | str):
             metric=self.metric,
             backbone=self.backbone,
             projection_head=self.projection_head,
-            out_dim=128,
+            out_dim=self.pretrain_out_dim,
             **train_kwargs,
         )

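The new pretrain_out_dim decouples the output dimensionality of the first (pretraining) stage from out_dim, which stays the dimensionality of the final embedding used for visualization; before this commit, fit hard-coded the first stage to 128 dimensions. Below is a minimal usage sketch, assuming only the constructor parameters and the fit signature visible in this diff; the CIFAR-10 dataset and the import path tsimcne.tsimcne are illustrative, not part of the commit.

    import torchvision
    from tsimcne.tsimcne import TSimCNE

    # Any image dataset works here; CIFAR-10 is only an example.
    dataset = torchvision.datasets.CIFAR10(root="data", download=True)

    tsimcne = TSimCNE(
        out_dim=2,             # dimensionality of the final embedding/plot
        pretrain_out_dim=128,  # new: output dim of the first training stage
    )
    # fit accepts a torch Dataset or a str, per the signature in the diff above.
    tsimcne.fit(dataset)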
