add support for multiple learning rate schedulers in sequential mod… #99

Closed · wants to merge 3 commits
Changes from 1 commit
2 changes: 1 addition & 1 deletion dreem/inference/eval.py
@@ -52,7 +52,7 @@ def run(cfg: DictConfig) -> dict[int, sio.Labels]:
"persistent_tracking", False
)
logger.info(f"Computing the following metrics:")
logger.info(model.metrics.test)
logger.info(model.metrics['test'])
model.test_results["save_path"] = eval_cfg.cfg.runner.save_path
logger.info(f"Saving results to {model.test_results['save_path']}")

2 changes: 1 addition & 1 deletion dreem/models/gtr_runner.py
@@ -247,7 +247,7 @@ def configure_optimizers(self) -> dict:
"scheduler": scheduler,
"monitor": "val_loss",
"interval": "epoch",
"frequency": 10,
"frequency": 1,
},
}

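Note on the hunk above: this dict is the lr_scheduler configuration that PyTorch Lightning reads from configure_optimizers. With "interval": "epoch" and "frequency": 1, the scheduler is stepped once per epoch rather than every tenth epoch; "monitor" names the logged metric that plateau-style schedulers watch. Below is a minimal standalone sketch of that return shape; the toy model, optimizer, and scheduler are placeholders for illustration, not the ones dreem's GTRRunner actually configures.

import torch

# Placeholder model/optimizer/scheduler; dreem builds these from its own config.
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)

# The dict shape configure_optimizers() hands back to Lightning.
lightning_config = {
    "optimizer": optimizer,
    "lr_scheduler": {
        "scheduler": scheduler,
        "monitor": "val_loss",   # logged metric watched by plateau-style schedulers
        "interval": "epoch",     # step on epoch boundaries
        "frequency": 1,          # every epoch; the old value of 10 stepped only every 10th epoch
    },
}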
15 changes: 14 additions & 1 deletion dreem/models/model_utils.py
@@ -150,14 +150,27 @@ def init_scheduler(
     if config is None:
         return None
     scheduler = config.get("name")
+
     if scheduler is None:
         scheduler = "ReduceLROnPlateau"
+
     scheduler_params = {
         param: val for param, val in config.items() if param.lower() != "name"
     }

     try:
-        scheduler_class = getattr(torch.optim.lr_scheduler, scheduler)
+        # if a list is provided, apply each one sequentially
+        if isinstance(scheduler, list):
+            schedulers = []
+            milestones = scheduler_params.get("milestones", None)
+            for ix, s in enumerate(scheduler):
+                params = scheduler_params[str(ix)]
+                schedulers.append(getattr(torch.optim.lr_scheduler, s)(optimizer, **params))
+            scheduler_class = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers, milestones)
+            return scheduler_class
+
+        else:
+            scheduler_class = getattr(torch.optim.lr_scheduler, scheduler)
     except AttributeError:
         if scheduler_class is None:
             print(
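To make the new list branch concrete, here is a minimal standalone sketch of what it builds. The config layout (a list of scheduler class names under "name", per-scheduler kwargs keyed by their index as a string, and a top-level "milestones" list passed to SequentialLR) is inferred from the diff above; the specific schedulers, optimizer, and milestone value are placeholders, not project defaults.

import torch

# Toy model/optimizer purely for illustration.
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Hypothetical config that would exercise the list branch of init_scheduler.
config = {
    "name": ["LinearLR", "ExponentialLR"],
    "0": {"start_factor": 0.1, "total_iters": 5},  # kwargs for the first scheduler
    "1": {"gamma": 0.9},                           # kwargs for the second scheduler
    "milestones": [5],                             # switch from scheduler 0 to 1 at epoch 5
}

# Equivalent of the new branch, written out standalone.
scheduler_params = {k: v for k, v in config.items() if k.lower() != "name"}
schedulers = [
    getattr(torch.optim.lr_scheduler, name)(optimizer, **scheduler_params[str(ix)])
    for ix, name in enumerate(config["name"])
]
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer, schedulers, scheduler_params["milestones"]
)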