Skip to content

Commit

Permalink
Attempt to generalize field_shape
Browse files Browse the repository at this point in the history
  • Loading branch information
Håvard Homleid Haugen committed Jan 9, 2025
1 parent 53e83ff commit bb2586c
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions bris/data/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,15 +303,13 @@ def field_shape(self) -> tuple:
return recursive_list_to_tuple(field_shape)

def _get_field_shape(self, decoder_index, dataset_index):

if isinstance(self.data_reader, anemoi.datasets.data.subset.Subset):
data_reader = self.data_reader.dataset
else:
data_reader = self.data_reader
data_reader = self.data_reader
while isinstance(data_reader, (anemoi.datasets.data.subset.Subset, anemoi.datasets.data.select.Select)):
data_reader = data_reader.dataset

if hasattr(data_reader, "datasets"):
dataset = data_reader.datasets[decoder_index]
if isinstance(dataset, anemoi.datasets.data.subset.Subset):
while isinstance(dataset, (anemoi.datasets.data.subset.Subset, anemoi.datasets.data.select.Select)):
dataset = dataset.dataset

if hasattr(dataset, "datasets"):
Expand Down

0 comments on commit bb2586c

Please sign in to comment.