diff --git a/dreem/io/config.py b/dreem/io/config.py index bdc1751..4ac8105 100644 --- a/dreem/io/config.py +++ b/dreem/io/config.py @@ -304,10 +304,9 @@ def get_dataset( @property def data_paths(self): - """Get data paths. - """ + """Get data paths.""" return self._vid_files - + @data_paths.setter def data_paths(self, paths: tuple[str, list[str]]): """Set data paths. diff --git a/dreem/training/train.py b/dreem/training/train.py index 74f04f6..b6fbd3d 100644 --- a/dreem/training/train.py +++ b/dreem/training/train.py @@ -80,8 +80,12 @@ def run(cfg: DictConfig): if run_logger is not None and isinstance(run_logger, pl.loggers.wandb.WandbLogger): data_paths = train_cfg.data_paths - flattened_paths = [[item] for sublist in data_paths.values() for item in sublist] - run_logger.log_text("training_files", columns=["data_paths"], data=flattened_paths) + flattened_paths = [ + [item] for sublist in data_paths.values() for item in sublist + ] + run_logger.log_text( + "training_files", columns=["data_paths"], data=flattened_paths + ) callbacks = [] _ = callbacks.extend(train_cfg.get_checkpointing())