diff --git a/bris/data/datamodule.py b/bris/data/datamodule.py index eda3de6..52c5dbb 100644 --- a/bris/data/datamodule.py +++ b/bris/data/datamodule.py @@ -14,11 +14,6 @@ from torch_geometric.data import HeteroData import anemoi.datasets.data.subset -try: - from anemoi.training.data.grid_indices import BaseGridIndices, FullGrid -except ImportError as e: - print("Warning! Could not import BaseGridIndices and FullGrid. Continuing without this module") - from bris.checkpoint import Checkpoint from bris.data.dataset import Dataset from bris.utils import check_anemoi_training, recursive_list_to_tuple @@ -250,6 +245,12 @@ def timeincrement(self) -> int: @cached_property def grid_indices(self): + if check_anemoi_training(self.ckptObj.metadata): + try: + from anemoi.training.data.grid_indices import BaseGridIndices, FullGrid + except ImportError as e: + print("Warning! Could not import BaseGridIndices and FullGrid. Continuing without this module") + reader_group_size = self.config.dataloader.read_group_size grid_indices = FullGrid( nodes_name="data",