From 68218cf709b9953c0e2067f59e566dd8acc4d05a Mon Sep 17 00:00:00 2001 From: Jonny Saunders Date: Fri, 2 Feb 2024 11:15:42 -0800 Subject: [PATCH] Write references in compound datasets as an array (#146) * Write references in compound datasets as an array rather than iteratively * Write references in compound datasets as an array rather than iteratively * Test that compound reference arrays are written as a compound array * Update base_tests_zarrio.py * Use explicit type map instead of stringlike types Expand check for string types Make 'else' a failure condition in `__resolve_dtype_helper__` (rather than implicitly assuming list) * Write strings as objects in compound reference datasets --------- Co-authored-by: Matthew Avaylon Co-authored-by: Oliver Ruebel --- src/hdmf_zarr/backend.py | 50 +++++++++++++++++++++++++++------ tests/unit/base_tests_zarrio.py | 5 ++++ 2 files changed, 47 insertions(+), 8 deletions(-) diff --git a/src/hdmf_zarr/backend.py b/src/hdmf_zarr/backend.py index 7ca788c1..0a3fae8e 100644 --- a/src/hdmf_zarr/backend.py +++ b/src/hdmf_zarr/backend.py @@ -1017,18 +1017,50 @@ def write_dataset(self, **kwargs): # noqa: C901 type_str.append(self.__serial_dtype__(t)[0]) if len(refs) > 0: - dset = parent.require_dataset(name, - shape=(len(data), ), - dtype=object, - object_codec=self.__codec_cls(), - **options['io_settings']) + self._written_builders.set_written(builder) # record that the builder has been written - dset.attrs['zarr_dtype'] = type_str + + # gather items to write + new_items = [] for j, item in enumerate(data): new_item = list(item) for i in refs: new_item[i] = self.__get_ref(item[i], export_source=export_source) - dset[j] = new_item + new_items.append(tuple(new_item)) + + # Create dtype for storage, replacing values to match hdmf's hdf5 behavior + # --- + # TODO: Replace with a simple one-liner once __resolve_dtype_helper__ is + # compatible with zarr's need for fixed-length string dtypes. + # dtype = self.__resolve_dtype_helper__(options['dtype']) + + new_dtype = [] + for field in options['dtype']: + if field['dtype'] is str or field['dtype'] in ( + 'str', 'text', 'utf', 'utf8', 'utf-8', 'isodatetime' + ): + # Zarr does not support variable length strings + new_dtype.append((field['name'], 'O')) + elif isinstance(field['dtype'], dict): + # eg. for some references, dtype will be of the form + # {'target_type': 'Baz', 'reftype': 'object'} + # which should just get serialized as an object + new_dtype.append((field['name'], 'O')) + else: + new_dtype.append((field['name'], self.__resolve_dtype_helper__(field['dtype']))) + dtype = np.dtype(new_dtype) + + # cast and store compound dataset + arr = np.array(new_items, dtype=dtype) + dset = parent.require_dataset( + name, + shape=(len(arr),), + dtype=dtype, + object_codec=self.__codec_cls(), + **options['io_settings'] + ) + dset.attrs['zarr_dtype'] = type_str + dset[...] = arr else: # write a compound datatype dset = self.__list_fill__(parent, name, data, options) @@ -1153,8 +1185,10 @@ def __resolve_dtype_helper__(cls, dtype): return cls.__dtypes.get(dtype) elif isinstance(dtype, dict): return cls.__dtypes.get(dtype['reftype']) - else: + elif isinstance(dtype, list): return np.dtype([(x['name'], cls.__resolve_dtype_helper__(x['dtype'])) for x in dtype]) + else: + raise ValueError(f'Cant resolve dtype {dtype}') @classmethod def get_type(cls, data): diff --git a/tests/unit/base_tests_zarrio.py b/tests/unit/base_tests_zarrio.py index 693f9ff5..d142499d 100644 --- a/tests/unit/base_tests_zarrio.py +++ b/tests/unit/base_tests_zarrio.py @@ -434,6 +434,11 @@ def test_read_reference_compound(self): self.read() builder = self.createReferenceCompoundBuilder()['ref_dataset'] read_builder = self.root['ref_dataset'] + + # ensure the array was written as a compound array + ref_dtype = np.dtype([('id', '