From bb2586cca6f33cd6eeb9b68d00815f418dcd8846 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A5vard=20Homleid=20Haugen?= Date: Thu, 9 Jan 2025 15:12:30 +0000 Subject: [PATCH] Attempt to generalize field_shape --- bris/data/datamodule.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/bris/data/datamodule.py b/bris/data/datamodule.py index bfb8b38..eda3de6 100644 --- a/bris/data/datamodule.py +++ b/bris/data/datamodule.py @@ -303,15 +303,13 @@ def field_shape(self) -> tuple: return recursive_list_to_tuple(field_shape) def _get_field_shape(self, decoder_index, dataset_index): - - if isinstance(self.data_reader, anemoi.datasets.data.subset.Subset): - data_reader = self.data_reader.dataset - else: - data_reader = self.data_reader + data_reader = self.data_reader + while isinstance(data_reader, (anemoi.datasets.data.subset.Subset, anemoi.datasets.data.select.Select)): + data_reader = data_reader.dataset if hasattr(data_reader, "datasets"): dataset = data_reader.datasets[decoder_index] - if isinstance(dataset, anemoi.datasets.data.subset.Subset): + while isinstance(dataset, (anemoi.datasets.data.subset.Subset, anemoi.datasets.data.select.Select)): dataset = dataset.dataset if hasattr(dataset, "datasets"):