Skip to content

Commit

Permalink
refactored from_ methods
Browse files (browse the repository at this point in the history)
  • Loading branch information
pauladkisson committed Aug 14, 2024
1 parent 4ac6e33 commit ce267fb
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 44 deletions.
2 changes: 1 addition & 1 deletion src/neuroconv/tools/nwb_helpers/_backend_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,4 @@ def get_existing_backend_configuration(nwbfile: NWBFile) -> Union[HDF5BackendCon
else:
raise ValueError(f"The backend of the NWBFile from io {read_io} is not recognized.")
BackendConfigurationClass = BACKEND_CONFIGURATIONS[backend]
return BackendConfigurationClass.from_existing_nwbfile(nwbfile=nwbfile)
return BackendConfigurationClass.from_nwbfile(nwbfile=nwbfile, mode="existing")
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@

from ._base_dataset_io import DatasetIOConfiguration
from ._pydantic_pure_json_schema_generator import PureJSONSchemaGenerator
from .._dataset_configuration import get_default_dataset_io_configurations
from .._dataset_configuration import (
get_default_dataset_io_configurations,
get_existing_dataset_io_configurations,
)


class BackendConfiguration(BaseModel):
Expand Down Expand Up @@ -56,11 +59,16 @@ def model_json_schema(cls, **kwargs) -> Dict[str, Any]:
return super().model_json_schema(mode="validation", schema_generator=PureJSONSchemaGenerator, **kwargs)

@classmethod
def from_nwbfile(cls, nwbfile: NWBFile) -> Self:
default_dataset_configurations = get_default_dataset_io_configurations(nwbfile=nwbfile, backend=cls.backend)
def from_nwbfile(cls, nwbfile: NWBFile, mode: Literal["default", "existing"] = "default") -> Self:
if mode == "default":
dataset_io_configurations = get_default_dataset_io_configurations(nwbfile=nwbfile, backend=cls.backend)
elif mode == "existing":
dataset_io_configurations = get_existing_dataset_io_configurations(nwbfile=nwbfile, backend=cls.backend)
else:
raise ValueError(f"mode must be either 'default' or 'existing' but got {mode}")
dataset_configurations = {
default_dataset_configuration.location_in_file: default_dataset_configuration
for default_dataset_configuration in default_dataset_configurations
for default_dataset_configuration in dataset_io_configurations
}

return cls(dataset_configurations=dataset_configurations)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@
from typing import ClassVar, Dict, Literal, Type

from pydantic import Field
from pynwb import H5DataIO, NWBFile
from typing_extensions import Self
from pynwb import H5DataIO

from ._base_backend import BackendConfiguration
from ._hdf5_dataset_io import HDF5DatasetIOConfiguration
from .._dataset_configuration import get_existing_dataset_io_configurations


class HDF5BackendConfiguration(BackendConfiguration):
Expand All @@ -24,13 +22,3 @@ class HDF5BackendConfiguration(BackendConfiguration):
"information for writing the datasets to disk using the HDF5 backend."
)
)

@classmethod
def from_existing_nwbfile(cls, nwbfile: NWBFile) -> Self:
existing_dataset_configurations = get_existing_dataset_io_configurations(nwbfile=nwbfile, backend=cls.backend)
dataset_configurations = {
existing_dataset_configuration.location_in_file: existing_dataset_configuration
for existing_dataset_configuration in existing_dataset_configurations
}

return cls(dataset_configurations=dataset_configurations)
Original file line number Diff line number Diff line change
Expand Up @@ -82,26 +82,36 @@ def get_data_io_kwargs(self) -> Dict[str, Any]:
return dict(chunks=self.chunk_shape, **compression_bundle)

@classmethod
def from_existing_neurodata_object(
cls, neurodata_object: Container, dataset_name: Literal["data", "timestamps"]
def from_neurodata_object(
cls,
neurodata_object: Container,
dataset_name: Literal["data", "timestamps"],
mode: Literal["default", "existing"] = "default",
) -> Self:
location_in_file = _find_location_in_memory_nwbfile(neurodata_object=neurodata_object, field_name=dataset_name)
full_shape = getattr(neurodata_object, dataset_name).shape
dtype = getattr(neurodata_object, dataset_name).dtype
chunk_shape = getattr(neurodata_object, dataset_name).chunks
buffer_shape = getattr(neurodata_object, dataset_name).maxshape
compression_method = getattr(neurodata_object, dataset_name).compression
compression_opts = getattr(neurodata_object, dataset_name).compression_opts
compression_options = dict(compression_opts=compression_opts)
return cls(
object_id=neurodata_object.object_id,
object_name=neurodata_object.name,
location_in_file=location_in_file,
dataset_name=dataset_name,
full_shape=full_shape,
dtype=dtype,
chunk_shape=chunk_shape,
buffer_shape=buffer_shape,
compression_method=compression_method,
compression_options=compression_options,
)
if mode == "default":
return super().from_neurodata_object(neurodata_object=neurodata_object, dataset_name=dataset_name)
elif mode == "existing":
location_in_file = _find_location_in_memory_nwbfile(
neurodata_object=neurodata_object, field_name=dataset_name
)
full_shape = getattr(neurodata_object, dataset_name).shape
dtype = getattr(neurodata_object, dataset_name).dtype
chunk_shape = getattr(neurodata_object, dataset_name).chunks
buffer_shape = getattr(neurodata_object, dataset_name).maxshape
compression_method = getattr(neurodata_object, dataset_name).compression
compression_opts = getattr(neurodata_object, dataset_name).compression_opts
compression_options = dict(compression_opts=compression_opts)
return cls(
object_id=neurodata_object.object_id,
object_name=neurodata_object.name,
location_in_file=location_in_file,
dataset_name=dataset_name,
full_shape=full_shape,
dtype=dtype,
chunk_shape=chunk_shape,
buffer_shape=buffer_shape,
compression_method=compression_method,
compression_options=compression_options,
)
else:
raise ValueError(f"mode must be either 'default' or 'existing' but got {mode}")
12 changes: 8 additions & 4 deletions src/neuroconv/tools/nwb_helpers/_dataset_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,10 @@ def get_existing_dataset_io_configurations(
if any(axis_length == 0 for axis_length in full_shape):
continue

dataset_io_configuration = DatasetIOConfigurationClass.from_existing_neurodata_object(
neurodata_object=column, dataset_name=dataset_name
dataset_io_configuration = DatasetIOConfigurationClass.from_neurodata_object(
neurodata_object=column,
dataset_name=dataset_name,
mode="existing",
)

yield dataset_io_configuration
Expand All @@ -227,8 +229,10 @@ def get_existing_dataset_io_configurations(
if any(axis_length == 0 for axis_length in full_shape):
continue

dataset_io_configuration = DatasetIOConfigurationClass.from_existing_neurodata_object(
neurodata_object=neurodata_object, dataset_name=known_dataset_field
dataset_io_configuration = DatasetIOConfigurationClass.from_neurodata_object(
neurodata_object=neurodata_object,
dataset_name=known_dataset_field,
mode="existing",
)

yield dataset_io_configuration

0 comments on commit ce267fb

Please sign in to comment.