Fix chunking bug with compound dtypes (#1146)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
pauladkisson and pre-commit-ci[bot] authored Feb 17, 2025
1 parent 89aac67 commit a896663
Showing 7 changed files with 430 additions and 9 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -6,6 +6,8 @@

## Bug Fixes
* `run_conversion` no longer triggers append mode or an index error when `nwbfile_path` points to a faulty file [PR #1180](https://github.com/catalystneuro/neuroconv/pull/1180)
+* `DatasetIOConfiguration` now recommends `chunk_shape = (len(candidate_dataset),)` for datasets with compound dtypes,
+  as used by hdmf >= 3.14.6.

## Features
* Use the latest version of ndx-pose for `DeepLabCutInterface` and `LightningPoseDataInterface` [PR #1128](https://github.com/catalystneuro/neuroconv/pull/1128)
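The recommendation exists because hdmf writes a compound-dtype dataset as a one-dimensional array of records, while `hdmf.utils.get_data_shape()` infers a multi-dimensional shape from the row tuples. A minimal sketch of that mismatch (record field names assumed, not taken from the changed files):

```python
# Minimal sketch of the shape mismatch for compound dtypes (field names assumed).
import numpy as np
from hdmf.utils import get_data_shape

rows = [(0, 10, "ts"), (10, 5, "ts")]  # records in the style of (idx_start, count, timeseries)
print(get_data_shape(rows))  # (2, 3): inferred as if this were a regular 2-D array

compound = np.array([(0, 10), (10, 5)], dtype=[("idx_start", "i4"), ("count", "i4")])
print(compound.shape)  # (2,): on disk, the dataset has a single record dimension
```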
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -39,7 +39,7 @@ dependencies = [
    "PyYAML>=5.4",
    "scipy>=1.4.1",
    "h5py>=3.9.0",
-    "hdmf>=3.13.0,<=3.14.5", # Chunking bug
+    "hdmf>=3.13.0,<4",
    "hdmf_zarr>=0.7.0",
    "pynwb>=2.7.0",
    "pydantic>=2.0.0",
154 changes: 154 additions & 0 deletions src/neuroconv/tools/hdmf.py
@@ -2,9 +2,15 @@

import math
import warnings
from typing import Union

import numpy as np
from hdmf.build.builders import (
    BaseBuilder,
    LinkBuilder,
)
from hdmf.data_utils import GenericDataChunkIterator as HDMFGenericDataChunkIterator
from hdmf.utils import get_data_shape


class GenericDataChunkIterator(HDMFGenericDataChunkIterator): # noqa: D101
@@ -153,3 +159,151 @@ def _get_maxshape(self) -> tuple:

    def _get_data(self, selection: tuple[slice]) -> np.ndarray:
        return self.data[selection]


def get_full_data_shape(
    dataset: Union[GenericDataChunkIterator, np.ndarray, list],
    location_in_file: str,
    builder: Union[BaseBuilder, None] = None,
):
    """Get the full shape of the dataset at the given location in the file.

    Parameters
    ----------
    dataset : hdmf.data_utils.GenericDataChunkIterator | np.ndarray | list
        The dataset to get the shape of.
    location_in_file : str
        The location of the dataset within the NWBFile, e.g. 'acquisition/ElectricalSeries/data'.
    builder : hdmf.build.builders.BaseBuilder | None
        The builder object that would be used to construct the NWBFile object. If None, the dataset is assumed to NOT
        have a compound dtype.

    Notes
    -----
    This function is used instead of hdmf.utils.get_data_shape() to handle datasets with compound dtypes. Currently,
    if a dataset has a compound dtype in NWB, the builder will write it with shape (len(dataset),), but
    hdmf.utils.get_data_shape() will return the shape of the dataset as if it were a regular single-dtype array
    (e.g. (N, M) instead of (N,)).
    """
    if builder is not None and has_compound_dtype(builder=builder, location_in_file=location_in_file):
        return (len(dataset),)
    return get_data_shape(data=dataset)


def has_compound_dtype(builder: BaseBuilder, location_in_file: str) -> bool:
    """
    Determine if the dataset at the given location in the file has a compound dtype.

    Parameters
    ----------
    builder : hdmf.build.builders.BaseBuilder
        The builder object that would be used to construct the NWBFile object.
    location_in_file : str
        The location of the dataset within the NWBFile, e.g. 'acquisition/ElectricalSeries/data'.

    Returns
    -------
    bool
        Whether the dataset has a compound dtype.
    """
    dataset_builder = get_dataset_builder(builder, location_in_file)
    return isinstance(dataset_builder.dtype, list)


def get_dataset_builder(builder: BaseBuilder, location_in_file: str) -> BaseBuilder:
    """Find the appropriate sub-builder for the dataset at the given location in the file.

    This function will traverse the groups in the location_in_file until it reaches a DatasetBuilder,
    and then return that builder.

    Parameters
    ----------
    builder : hdmf.build.builders.BaseBuilder
        The builder object that would be used to construct the NWBFile object.
    location_in_file : str
        The location of the dataset within the NWBFile, e.g. 'acquisition/ElectricalSeries/data'.

    Returns
    -------
    hdmf.build.builders.BaseBuilder
        The builder object for the dataset at the given location.

    Raises
    ------
    ValueError
        If the location_in_file is not found in the builder.

    Notes
    -----
    Items in defined top-level places like electrodes may not be in the groups of the nwbfile-level builder,
    but rather in hidden locations like general/extracellular_ephys/electrodes.
    Also, some items in these top-level locations may interrupt the order of the location_in_file.
    For example, when location_in_file is 'stimulus/AcousticWaveformSeries/data', the builder for that dataset is
    located at 'stimulus/presentation/AcousticWaveformSeries/data'.
    For this reason, we recursively search for the appropriate sub-builder for each name in the location_in_file.

    Also, the first name in location_in_file is inherently suspect due to the way that the location is determined
    in _find_location_in_memory_nwbfile(), and may not be present in the builder. For example, when location_in_file
    is 'lab_meta_data/fiber_photometry/fiber_photometry_table/location/data', the builder for that dataset is
    located at 'general/fiber_photometry/fiber_photometry_table/location/data'.
    """
    split_location = iter(location_in_file.split("/"))
    name = next(split_location)

    # The first name may not be present in the builder at all; if so, skip it.
    if _find_sub_builder(builder, name) is None:
        name = next(split_location)

    # Descend through group builders until the current name refers to a dataset or link.
    while name not in builder.datasets and name not in builder.links:
        builder = _find_sub_builder(builder, name)
        if builder is None:
            raise ValueError(f"Could not find location '{location_in_file}' in builder ({name} is missing).")
        try:
            name = next(split_location)
        except StopIteration:
            raise ValueError(f"Could not find location '{location_in_file}' in builder ({name} is not a dataset).")
    builder = builder[name]
    if isinstance(builder, LinkBuilder):
        builder = builder.builder
    return builder


def _find_sub_builder(builder: BaseBuilder, name: str) -> BaseBuilder:
    """Search breadth-first for a sub-builder by name in a builder object.

    Parameters
    ----------
    builder : hdmf.build.builders.BaseBuilder
        The builder object to search for the sub-builder in.
    name : str
        The name of the sub-builder to search for.

    Returns
    -------
    hdmf.build.builders.BaseBuilder
        The sub-builder with the given name, or None if it could not be found.
    """
    sub_builders = list(builder.groups.values())
    return _recursively_search_sub_builders(sub_builders=sub_builders, name=name)


def _recursively_search_sub_builders(sub_builders: list[BaseBuilder], name: str) -> BaseBuilder:
    """Recursively search for a sub-builder by name in a list of sub-builders.

    Parameters
    ----------
    sub_builders : list[hdmf.build.builders.BaseBuilder]
        The list of sub-builders to search for the sub-builder in.
    name : str
        The name of the sub-builder to search for.

    Returns
    -------
    hdmf.build.builders.BaseBuilder
        The sub-builder with the given name, or None if it could not be found.
    """
    sub_sub_builders = []
    for sub_builder in sub_builders:
        if sub_builder.name == name:
            return sub_builder
        sub_sub_builders.extend(list(sub_builder.groups.values()))
    if len(sub_sub_builders) == 0:
        return None
    return _recursively_search_sub_builders(sub_builders=sub_sub_builders, name=name)
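As a usage sketch (not shown in the diff), these helpers can be exercised on the epochs `timeseries` column, which pynwb stores as a `TimeSeriesReferenceVectorData` with a compound dtype; the location string and printed results below are illustrative assumptions.

```python
# Illustrative sketch: exercising the new helpers on a compound-dtype dataset.
# The epochs "timeseries" column rows are (idx_start, count, timeseries) records.
from pynwb import get_manager
from pynwb.testing.mock.base import mock_TimeSeries
from pynwb.testing.mock.file import mock_NWBFile

from neuroconv.tools.hdmf import get_full_data_shape, has_compound_dtype

nwbfile = mock_NWBFile()
time_series = mock_TimeSeries(name="TimeSeries")
nwbfile.add_acquisition(time_series)
nwbfile.add_epoch(start_time=0.0, stop_time=1.0, timeseries=[time_series])

# Building the file in memory exposes the compound dtype on the DatasetBuilder.
builder = get_manager().build(nwbfile)

dataset = nwbfile.epochs.timeseries.data  # list of compound records
location_in_file = "epochs/timeseries/data"  # assumed location string for this example

print(has_compound_dtype(builder=builder, location_in_file=location_in_file))  # expected: True
print(get_full_data_shape(dataset=dataset, location_in_file=location_in_file, builder=builder))
# expected: (1,) -- one record, rather than a 2-D shape inferred from the record tuples
```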
src/neuroconv/tools/nwb_helpers/_configuration_models/_base_dataset_io.py
@@ -9,7 +9,9 @@
import numpy as np
import zarr
from hdmf import Container
-from hdmf.utils import get_data_shape
+from hdmf.build.builders import (
+    BaseBuilder,
+)
from pydantic import (
    BaseModel,
    ConfigDict,
@@ -22,6 +24,7 @@
from pynwb.ecephys import ElectricalSeries
from typing_extensions import Self

+from neuroconv.tools.hdmf import get_full_data_shape
from neuroconv.utils.str_utils import human_readable_size

from ._pydantic_pure_json_schema_generator import PureJSONSchemaGenerator
@@ -245,7 +248,12 @@ def model_json_schema(cls, **kwargs) -> dict[str, Any]:
        return super().model_json_schema(mode="validation", schema_generator=PureJSONSchemaGenerator, **kwargs)

    @classmethod
-    def from_neurodata_object(cls, neurodata_object: Container, dataset_name: Literal["data", "timestamps"]) -> Self:
+    def from_neurodata_object(
+        cls,
+        neurodata_object: Container,
+        dataset_name: Literal["data", "timestamps"],
+        builder: Union[BaseBuilder, None] = None,
+    ) -> Self:
        """
        Construct an instance of a DatasetIOConfiguration for a dataset in a neurodata object in an NWBFile.
@@ -257,11 +265,13 @@ def from_neurodata_object(cls, neurodata_object: Container, dataset_name: Litera
            The name of the field that will become a dataset when written to disk.
            Some neurodata objects can have multiple such fields, such as `pynwb.TimeSeries` which can have both `data`
            and `timestamps`, each of which can be configured separately.
+        builder : hdmf.build.builders.BaseBuilder, optional
+            The builder object that would be used to construct the NWBFile object. If None, the dataset is assumed to
+            NOT have a compound dtype.
        """
        location_in_file = _find_location_in_memory_nwbfile(neurodata_object=neurodata_object, field_name=dataset_name)

        candidate_dataset = getattr(neurodata_object, dataset_name)
-        full_shape = get_data_shape(data=candidate_dataset)
+        full_shape = get_full_data_shape(dataset=candidate_dataset, location_in_file=location_in_file, builder=builder)
        dtype = _infer_dtype(dataset=candidate_dataset)

        if isinstance(candidate_dataset, GenericDataChunkIterator):
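A minimal sketch of calling the extended classmethod directly (the `HDF5DatasetIOConfiguration` import path and the mock setup are assumptions, not part of this diff): the new `builder` argument stays optional, so existing callers are unaffected.

```python
# Sketch only: construct a configuration with the optional builder argument.
from pynwb import get_manager
from pynwb.testing.mock.base import mock_TimeSeries
from pynwb.testing.mock.file import mock_NWBFile

from neuroconv.tools.nwb_helpers import HDF5DatasetIOConfiguration  # assumed public import path

nwbfile = mock_NWBFile()
nwbfile.add_acquisition(mock_TimeSeries(name="TimeSeries"))

builder = get_manager().build(nwbfile)
configuration = HDF5DatasetIOConfiguration.from_neurodata_object(
    neurodata_object=nwbfile.acquisition["TimeSeries"],
    dataset_name="data",
    builder=builder,  # optional; when omitted, the dataset is assumed not to have a compound dtype
)
print(configuration.full_shape, configuration.chunk_shape)
```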
8 changes: 5 additions & 3 deletions src/neuroconv/tools/nwb_helpers/_dataset_configuration.py
Expand Up @@ -9,7 +9,7 @@
from hdmf.data_utils import DataIO
from hdmf.utils import get_data_shape
from hdmf_zarr import NWBZarrIO
-from pynwb import NWBHDF5IO, NWBFile
+from pynwb import NWBHDF5IO, NWBFile, get_manager
from pynwb.base import DynamicTable, TimeSeriesReferenceVectorData
from pynwb.file import NWBContainer

@@ -102,6 +102,8 @@ def get_default_dataset_io_configurations(
        )

    known_dataset_fields = ("data", "timestamps")
+    manager = get_manager()
+    builder = manager.build(nwbfile)
    for neurodata_object in nwbfile.objects.values():
        if isinstance(neurodata_object, DynamicTable):
            dynamic_table = neurodata_object  # For readability
@@ -134,7 +136,7 @@
                    continue

                dataset_io_configuration = DatasetIOConfigurationClass.from_neurodata_object(
-                    neurodata_object=column, dataset_name=dataset_name
+                    neurodata_object=column, dataset_name=dataset_name, builder=builder
                )

                yield dataset_io_configuration
@@ -168,7 +170,7 @@
                    continue

                dataset_io_configuration = DatasetIOConfigurationClass.from_neurodata_object(
-                    neurodata_object=neurodata_object, dataset_name=known_dataset_field
+                    neurodata_object=neurodata_object, dataset_name=known_dataset_field, builder=builder
                )

                yield dataset_io_configuration
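At the public-API level nothing changes for callers of `get_default_dataset_io_configurations`: the builder is created internally and forwarded to each configuration. A rough usage sketch (mock objects assumed, output values illustrative):

```python
# Sketch only: the generator's signature is unchanged; the builder handling happens internally.
from pynwb.testing.mock.base import mock_TimeSeries
from pynwb.testing.mock.file import mock_NWBFile

from neuroconv.tools.nwb_helpers import get_default_dataset_io_configurations

nwbfile = mock_NWBFile()
nwbfile.add_acquisition(mock_TimeSeries(name="TimeSeries"))

for configuration in get_default_dataset_io_configurations(nwbfile=nwbfile, backend="hdf5"):
    print(configuration.location_in_file, configuration.full_shape, configuration.chunk_shape)
```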
46 changes: 46 additions & 0 deletions (new test module for DatasetIOConfiguration helpers)
@@ -0,0 +1,46 @@
"""Unit tests for helper functions of DatasetIOConfiguration."""

import numpy as np
from pynwb.testing.mock.base import mock_TimeSeries
from pynwb.testing.mock.file import mock_NWBFile

from neuroconv.tools.nwb_helpers._configuration_models._base_dataset_io import (
_find_location_in_memory_nwbfile,
_infer_dtype,
)


def test_find_location_in_memory_nwbfile():
nwbfile = mock_NWBFile()
time_series = mock_TimeSeries(name="TimeSeries")
nwbfile.add_acquisition(time_series)
neurodata_object = nwbfile.acquisition["TimeSeries"]
location = _find_location_in_memory_nwbfile(neurodata_object=neurodata_object, field_name="data")
assert location == "acquisition/TimeSeries/data"


def test_infer_dtype_array():
nwbfile = mock_NWBFile()
time_series = mock_TimeSeries(name="TimeSeries", data=np.array([1.0, 2.0, 3.0], dtype="float64"))
nwbfile.add_acquisition(time_series)
dataset = nwbfile.acquisition["TimeSeries"].data
dtype = _infer_dtype(dataset)
assert dtype == np.dtype("float64")


def test_infer_dtype_list():
nwbfile = mock_NWBFile()
time_series = mock_TimeSeries(name="TimeSeries", data=[1.0, 2.0, 3.0])
nwbfile.add_acquisition(time_series)
dataset = nwbfile.acquisition["TimeSeries"].data
dtype = _infer_dtype(dataset)
assert dtype == np.dtype("float64")


def test_infer_dtype_object():
nwbfile = mock_NWBFile()
time_series = mock_TimeSeries(name="TimeSeries", data=(1.0, 2.0, 3.0))
nwbfile.add_acquisition(time_series)
dataset = nwbfile.acquisition["TimeSeries"]
dtype = _infer_dtype(dataset)
assert dtype == np.dtype("object")