From 7ab066c1caf674aa644d6ba0b9c742896a259cd2 Mon Sep 17 00:00:00 2001
From: shaikh58
Date: Sat, 23 Nov 2024 01:55:18 +0000
Subject: [PATCH 1/3] - add support for multiple learning rate schedulers in
 sequential mode; params can take list of schedulers in scheduler.name, with
 lr params in scheduler.name."0" etc. based on number of schedulers

- reduce step frequency of lr schedule to each epoch down from 10 epochs
- patch in eval key retrieval
---
 dreem/inference/eval.py     |  2 +-
 dreem/models/gtr_runner.py  |  2 +-
 dreem/models/model_utils.py | 15 ++++++++++++++-
 3 files changed, 16 insertions(+), 3 deletions(-)

diff --git a/dreem/inference/eval.py b/dreem/inference/eval.py
index 0627aa5..afb5762 100644
--- a/dreem/inference/eval.py
+++ b/dreem/inference/eval.py
@@ -52,7 +52,7 @@ def run(cfg: DictConfig) -> dict[int, sio.Labels]:
         "persistent_tracking", False
     )
     logger.info(f"Computing the following metrics:")
-    logger.info(model.metrics.test)
+    logger.info(model.metrics['test'])
     model.test_results["save_path"] = eval_cfg.cfg.runner.save_path
     logger.info(f"Saving results to {model.test_results['save_path']}")
 
diff --git a/dreem/models/gtr_runner.py b/dreem/models/gtr_runner.py
index 2a09ec3..e9fd9f8 100644
--- a/dreem/models/gtr_runner.py
+++ b/dreem/models/gtr_runner.py
@@ -247,7 +247,7 @@ def configure_optimizers(self) -> dict:
                 "scheduler": scheduler,
                 "monitor": "val_loss",
                 "interval": "epoch",
-                "frequency": 10,
+                "frequency": 1,
             },
         }
 
diff --git a/dreem/models/model_utils.py b/dreem/models/model_utils.py
index b886d4f..6d2a873 100644
--- a/dreem/models/model_utils.py
+++ b/dreem/models/model_utils.py
@@ -150,14 +150,27 @@ def init_scheduler(
     if config is None:
         return None
     scheduler = config.get("name")
+
     if scheduler is None:
         scheduler = "ReduceLROnPlateau"
     scheduler_params = {
         param: val for param, val in config.items() if param.lower() != "name"
     }
+
     try:
-        scheduler_class = getattr(torch.optim.lr_scheduler, scheduler)
+        # if a list is provided, apply each one sequentially
+        if isinstance(scheduler, list):
+            schedulers = []
+            milestones = scheduler_params.get("milestones", None)
+            for ix, s in enumerate(scheduler):
+                params = scheduler_params[str(ix)]
+                schedulers.append(getattr(torch.optim.lr_scheduler, s)(optimizer, **params))
+            scheduler_class = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers, milestones)
+            return scheduler_class
+
+        else:
+            scheduler_class = getattr(torch.optim.lr_scheduler, scheduler)
     except AttributeError:
         if scheduler_class is None:
             print(

From f028b17d3d8b0980b2a4fdb4edf26aab53ed3069 Mon Sep 17 00:00:00 2001
From: shaikh58
Date: Sat, 23 Nov 2024 02:12:24 +0000
Subject: [PATCH 2/3] lint

---
 dreem/inference/eval.py     | 2 +-
 dreem/models/model_utils.py | 8 ++++++--
 2 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/dreem/inference/eval.py b/dreem/inference/eval.py
index afb5762..44e4c27 100644
--- a/dreem/inference/eval.py
+++ b/dreem/inference/eval.py
@@ -52,7 +52,7 @@ def run(cfg: DictConfig) -> dict[int, sio.Labels]:
         "persistent_tracking", False
     )
     logger.info(f"Computing the following metrics:")
-    logger.info(model.metrics['test'])
+    logger.info(model.metrics["test"])
     model.test_results["save_path"] = eval_cfg.cfg.runner.save_path
     logger.info(f"Saving results to {model.test_results['save_path']}")
 
diff --git a/dreem/models/model_utils.py b/dreem/models/model_utils.py
index 6d2a873..c176bd4 100644
--- a/dreem/models/model_utils.py
+++ b/dreem/models/model_utils.py
@@ -165,8 +165,12 @@ def init_scheduler(
             milestones = scheduler_params.get("milestones", None)
             for ix, s in enumerate(scheduler):
                 params = scheduler_params[str(ix)]
-                schedulers.append(getattr(torch.optim.lr_scheduler, s)(optimizer, **params))
-            scheduler_class = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers, milestones)
+                schedulers.append(
+                    getattr(torch.optim.lr_scheduler, s)(optimizer, **params)
+                )
+            scheduler_class = torch.optim.lr_scheduler.SequentialLR(
+                optimizer, schedulers, milestones
+            )
             return scheduler_class
 
         else:

From 580fd2b54acd1c3cd39a85cd3ffac80ab055a046 Mon Sep 17 00:00:00 2001
From: shaikh58
Date: Sat, 23 Nov 2024 02:20:50 +0000
Subject: [PATCH 3/3] add sample config params to show how to use sequential
 LR

---
 dreem/training/configs/base.yaml | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/dreem/training/configs/base.yaml b/dreem/training/configs/base.yaml
index 7779cd1..b5b3b76 100644
--- a/dreem/training/configs/base.yaml
+++ b/dreem/training/configs/base.yaml
@@ -45,6 +45,18 @@ scheduler:
     threshold: 1e-4
     threshold_mode: "rel"
 
+  # scheduler: # providing a list of schedulers will apply each scheduler sequentially
+  #   name: ["LinearLR", "LinearLR"]
+  #   "0":
+  #     start_factor: 0.1
+  #     end_factor: 1
+  #     total_iters: 3
+  #   "1":
+  #     start_factor: 1
+  #     end_factor: 0.1
+  #     total_iters: 30
+  #   milestones: [10]
+
 tracker:
   window_size: 8
   use_vis_feats: true
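
Note (not part of the patch): the commented-out sample config above maps onto plain PyTorch as follows. This is a minimal sketch of what init_scheduler assembles from name: ["LinearLR", "LinearLR"], the per-scheduler params under "0"/"1", and milestones: [10]; the toy model, optimizer, and training loop are illustrative assumptions, not code from the repo.

    import torch
    from torch.optim.lr_scheduler import LinearLR, SequentialLR

    # Toy model/optimizer purely to drive the schedule (illustrative only).
    model = torch.nn.Linear(2, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Scheduler "0": warm up from 0.1x to 1x of the base lr over 3 epochs.
    warmup = LinearLR(optimizer, start_factor=0.1, end_factor=1, total_iters=3)
    # Scheduler "1": decay from 1x back down to 0.1x over 30 epochs.
    decay = LinearLR(optimizer, start_factor=1, end_factor=0.1, total_iters=30)
    # The first scheduler is active until the milestone epoch 10 (warmup
    # finishes after 3 epochs, then lr holds at the base rate); from epoch 10
    # on, the second scheduler takes over. SequentialLR requires
    # len(milestones) == len(schedulers) - 1.
    scheduler = SequentialLR(optimizer, [warmup, decay], milestones=[10])

    # With interval="epoch" and frequency=1 (see PATCH 1/3), Lightning steps
    # the scheduler once per epoch, equivalent to:
    for epoch in range(15):
        optimizer.step()
        scheduler.step()
        print(epoch, optimizer.param_groups[0]["lr"])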