diff --git a/dreem/inference/eval.py b/dreem/inference/eval.py
index 0627aa57..44e4c277 100644
--- a/dreem/inference/eval.py
+++ b/dreem/inference/eval.py
@@ -52,7 +52,7 @@ def run(cfg: DictConfig) -> dict[int, sio.Labels]:
         "persistent_tracking", False
     )
     logger.info(f"Computing the following metrics:")
-    logger.info(model.metrics.test)
+    logger.info(model.metrics["test"])
     model.test_results["save_path"] = eval_cfg.cfg.runner.save_path
     logger.info(f"Saving results to {model.test_results['save_path']}")
 
diff --git a/dreem/models/gtr_runner.py b/dreem/models/gtr_runner.py
index 2a09ec3a..e9fd9f84 100644
--- a/dreem/models/gtr_runner.py
+++ b/dreem/models/gtr_runner.py
@@ -247,7 +247,7 @@ def configure_optimizers(self) -> dict:
                 "scheduler": scheduler,
                 "monitor": "val_loss",
                 "interval": "epoch",
-                "frequency": 10,
+                "frequency": 1,
             },
         }
 
diff --git a/dreem/models/model_utils.py b/dreem/models/model_utils.py
index b886d4fb..c176bd42 100644
--- a/dreem/models/model_utils.py
+++ b/dreem/models/model_utils.py
@@ -150,14 +150,31 @@
     if config is None:
         return None
     scheduler = config.get("name")
+
     if scheduler is None:
         scheduler = "ReduceLROnPlateau"
 
     scheduler_params = {
         param: val for param, val in config.items() if param.lower() != "name"
     }
+
     try:
-        scheduler_class = getattr(torch.optim.lr_scheduler, scheduler)
+        # if a list is provided, apply each one sequentially
+        if isinstance(scheduler, list):
+            schedulers = []
+            milestones = scheduler_params.get("milestones", None)
+            for ix, s in enumerate(scheduler):
+                params = scheduler_params[str(ix)]
+                schedulers.append(
+                    getattr(torch.optim.lr_scheduler, s)(optimizer, **params)
+                )
+            scheduler_class = torch.optim.lr_scheduler.SequentialLR(
+                optimizer, schedulers, milestones
+            )
+            return scheduler_class
+
+        else:
+            scheduler_class = getattr(torch.optim.lr_scheduler, scheduler)
     except AttributeError:
         if scheduler_class is None:
             print(
diff --git a/dreem/training/configs/base.yaml b/dreem/training/configs/base.yaml
index 7779cd13..b5b3b766 100644
--- a/dreem/training/configs/base.yaml
+++ b/dreem/training/configs/base.yaml
@@ -45,6 +45,18 @@ scheduler:
   threshold: 1e-4
   threshold_mode: "rel"
 
+  # scheduler: # providing a list of schedulers will apply each scheduler sequentially
+  #   name: ["LinearLR", "LinearLR"]
+  #   "0":
+  #     start_factor: 0.1
+  #     end_factor: 1
+  #     total_iters: 3
+  #   "1":
+  #     start_factor: 1
+  #     end_factor: 0.1
+  #     total_iters: 30
+  #   milestones: [10]
+
 tracker:
   window_size: 8
   use_vis_feats: true