From 5064b70ac5b3a917ab18ba5ab164ba0d27e38233 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A5vard=20Haugen?= Date: Fri, 17 Jan 2025 15:15:04 +0200 Subject: [PATCH] Fix data parallel --- bris/data/dataset.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/bris/data/dataset.py b/bris/data/dataset.py index 02fd34c..5c4b997 100644 --- a/bris/data/dataset.py +++ b/bris/data/dataset.py @@ -3,6 +3,7 @@ import torch from numpy import datetime64 from torch.utils.data import IterableDataset +from einops import rearrange class Dataset(IterableDataset): @@ -43,7 +44,7 @@ def per_worker_init(self, n_workers, worker_id): raise RuntimeError( "Warning: Underlying dataset does not implement 'per_worker_init'." ) - + def set_comm_group_info( self, global_rank: int, @@ -53,13 +54,19 @@ def set_comm_group_info( reader_group_rank: int, reader_group_size: int, ) -> None: - - self.global_rank = global_rank - self.model_comm_group_id = model_comm_group_id - self.model_comm_group_rank = model_comm_group_rank - self.model_comm_num_groups = model_comm_num_groups - self.reader_group_rank = reader_group_rank - self.reader_group_size = reader_group_size + if hasattr(self.dataCls, "set_comm_group_info"): + self.dataCls.set_comm_group_info( + global_rank=global_rank, + model_comm_group_id=model_comm_group_id, + model_comm_group_rank=model_comm_group_rank, + model_comm_num_groups=model_comm_num_groups, + reader_group_rank=reader_group_rank, + reader_group_size=reader_group_size, + ) + else: + raise RuntimeError( + "Warning: Underlying dataset does not implement 'set_comm_group_info'." + ) def __iter__( @@ -67,4 +74,4 @@ def __iter__( ) -> tuple[torch.Tensor, datetime64] | tuple[tuple[torch.Tensor], datetime64]: for idx, x in enumerate(iter(self.dataCls)): - yield (x, str(self.data.dates[idx + self.dataCls.multi_step - 1])) + yield (x, str(self.data.dates[self.dataCls.chunk_index_range[idx]]))