Skip to content

Commit

Permalink
Fix data parallel
Browse files Browse the repository at this point in the history
  • Loading branch information
havardhhaugen committed Jan 17, 2025
1 parent 4fff6aa commit 5064b70
Showing 1 changed file with 16 additions and 9 deletions.
25 changes: 16 additions & 9 deletions bris/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import torch
from numpy import datetime64
from torch.utils.data import IterableDataset
from einops import rearrange


class Dataset(IterableDataset):
Expand Down Expand Up @@ -43,7 +44,7 @@ def per_worker_init(self, n_workers, worker_id):
raise RuntimeError(
"Warning: Underlying dataset does not implement 'per_worker_init'."
)

def set_comm_group_info(
self,
global_rank: int,
Expand All @@ -53,18 +54,24 @@ def set_comm_group_info(
reader_group_rank: int,
reader_group_size: int,
) -> None:

self.global_rank = global_rank
self.model_comm_group_id = model_comm_group_id
self.model_comm_group_rank = model_comm_group_rank
self.model_comm_num_groups = model_comm_num_groups
self.reader_group_rank = reader_group_rank
self.reader_group_size = reader_group_size
if hasattr(self.dataCls, "set_comm_group_info"):
self.dataCls.set_comm_group_info(
global_rank=global_rank,
model_comm_group_id=model_comm_group_id,
model_comm_group_rank=model_comm_group_rank,
model_comm_num_groups=model_comm_num_groups,
reader_group_rank=reader_group_rank,
reader_group_size=reader_group_size,
)
else:
raise RuntimeError(
"Warning: Underlying dataset does not implement 'set_comm_group_info'."
)


def __iter__(
self,
) -> tuple[torch.Tensor, datetime64] | tuple[tuple[torch.Tensor], datetime64]:

for idx, x in enumerate(iter(self.dataCls)):
yield (x, str(self.data.dates[idx + self.dataCls.multi_step - 1]))
yield (x, str(self.data.dates[self.dataCls.chunk_index_range[idx]]))

0 comments on commit 5064b70

Please sign in to comment.