From 4341167af804cb24523310abf4fb031d600a3f8b Mon Sep 17 00:00:00 2001 From: shaikh58 Date: Thu, 12 Dec 2024 15:08:31 +0400 Subject: [PATCH] Log train split files (#102) --- dreem/io/config.py | 19 +++++++++++++++++++ dreem/training/train.py | 9 +++++++++ 2 files changed, 28 insertions(+) diff --git a/dreem/io/config.py b/dreem/io/config.py index 2c4e70e..4ac8105 100644 --- a/dreem/io/config.py +++ b/dreem/io/config.py @@ -42,6 +42,8 @@ def __init__(self, cfg: DictConfig, params_cfg: DictConfig | None = None): OmegaConf.set_struct(self.cfg, False) + self._vid_files = {} + def __repr__(self): """Object representation of config class.""" return f"Config({self.cfg})" @@ -276,6 +278,8 @@ def get_dataset( ): dataset_params["normalize_image"] = False + self.data_paths = (mode, vid_files) + return SleapDataset(**dataset_params) elif "tracks" in dataset_params or "source" in dataset_params: @@ -298,6 +302,21 @@ def get_dataset( either `slp_files` or `tracks`/`source`" ) + @property + def data_paths(self): + """Get data paths.""" + return self._vid_files + + @data_paths.setter + def data_paths(self, paths: tuple[str, list[str]]): + """Set data paths. + + Args: + paths: A tuple containing (mode, vid_files) + """ + mode, vid_files = paths + self._vid_files[mode] = vid_files + def get_dataloader( self, dataset: "SleapDataset" | "MicroscopyDataset" | "CellTrackingDataset", diff --git a/dreem/training/train.py b/dreem/training/train.py index f55394f..b6fbd3d 100644 --- a/dreem/training/train.py +++ b/dreem/training/train.py @@ -78,6 +78,15 @@ def run(cfg: DictConfig): run_logger = train_cfg.get_logger() + if run_logger is not None and isinstance(run_logger, pl.loggers.wandb.WandbLogger): + data_paths = train_cfg.data_paths + flattened_paths = [ + [item] for sublist in data_paths.values() for item in sublist + ] + run_logger.log_text( + "training_files", columns=["data_paths"], data=flattened_paths + ) + callbacks = [] _ = callbacks.extend(train_cfg.get_checkpointing()) _ = callbacks.append(pl.callbacks.LearningRateMonitor())