diff --git a/monai/data/utils.py b/monai/data/utils.py index 8c5ae88289..164fa78814 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -50,8 +50,12 @@ issequenceiterable, look_up_option, optional_import, + pytorch_after, ) +if pytorch_after(1, 13): + # import private code for reuse purposes, comment in case things break in the future + from torch.utils.data._utils.collate import collate_tensor_fn, default_collate_fn_map pd, _ = optional_import("pandas") DataFrame, _ = optional_import("pandas", name="DataFrame") nib, _ = optional_import("nibabel") @@ -444,6 +448,23 @@ def pickle_operations(data, key=PICKLE_KEY_SUFFIX, is_encode: bool = True): return data +def collate_meta_tensor_fn(batch, *, collate_fn_map=None): + """ + Collate a sequence of meta tensor into a single batched metatensor. This is called by `collage_meta_tensor` + and so should not be used as a collate function directly in dataloaders. + """ + collate_fn = collate_tensor_fn if pytorch_after(1, 13) else default_collate + collated = collate_fn(batch) # type: ignore + meta_dicts = [i.meta or TraceKeys.NONE for i in batch] + common_ = set.intersection(*[set(d.keys()) for d in meta_dicts if isinstance(d, dict)]) + if common_: + meta_dicts = [{k: d[k] for k in common_} if isinstance(d, dict) else TraceKeys.NONE for d in meta_dicts] + collated.meta = default_collate(meta_dicts) + collated.applied_operations = [i.applied_operations or TraceKeys.NONE for i in batch] + collated.is_batch = True + return collated + + def collate_meta_tensor(batch): """collate a sequence of meta tensor sequences/dictionaries into a single batched metatensor or a dictionary of batched metatensor""" @@ -451,15 +472,7 @@ def collate_meta_tensor(batch): raise NotImplementedError() elem_0 = first(batch) if isinstance(elem_0, MetaObj): - collated = default_collate(batch) - meta_dicts = [i.meta or TraceKeys.NONE for i in batch] - common_ = set.intersection(*[set(d.keys()) for d in meta_dicts if isinstance(d, dict)]) - if common_: - meta_dicts = [{k: d[k] for k in common_} if isinstance(d, dict) else TraceKeys.NONE for d in meta_dicts] - collated.meta = default_collate(meta_dicts) - collated.applied_operations = [i.applied_operations or TraceKeys.NONE for i in batch] - collated.is_batch = True - return collated + return collate_meta_tensor_fn(batch) if isinstance(elem_0, Mapping): return {k: collate_meta_tensor([d[k] for d in batch]) for k in elem_0} if isinstance(elem_0, (tuple, list)): @@ -479,9 +492,16 @@ def list_data_collate(batch: Sequence): Need to use this collate if apply some transforms that can generate batch data. """ + + if pytorch_after(1, 13): + # needs to go here to avoid circular import + from monai.data.meta_tensor import MetaTensor + + default_collate_fn_map.update({MetaTensor: collate_meta_tensor_fn}) elem = batch[0] data = [i for k in batch for i in k] if isinstance(elem, list) else batch key = None + collate_fn = default_collate if pytorch_after(1, 13) else collate_meta_tensor try: if config.USE_META_DICT: data = pickle_operations(data) # bc 0.9.0 @@ -490,9 +510,9 @@ def list_data_collate(batch: Sequence): for k in elem: key = k data_for_batch = [d[key] for d in data] - ret[key] = collate_meta_tensor(data_for_batch) + ret[key] = collate_fn(data_for_batch) else: - ret = collate_meta_tensor(data) + ret = collate_fn(data) return ret except RuntimeError as re: re_str = str(re)