From 26746b0b4fad6aaa5eec080f90aaa49ea73ec4a1 Mon Sep 17 00:00:00 2001 From: Sergiy Date: Fri, 17 May 2024 23:22:58 +0000 Subject: [PATCH] feat: np based layer --- docs/source/examples.rst | 4 +- ...chunkable_apply_flow_quick_start_guide.rst | 10 ++--- .../layer/volumetric/cloudvol/test_backend.py | 42 +++++-------------- .../volumetric/tensorstore/test_backend.py | 17 ++++---- tests/unit/layer/volumetric/test_layer.py | 32 +++++++------- tests/unit/layer/volumetric/test_layer_set.py | 8 ++-- tests/unit/layer/volumetric/test_tools.py | 14 ++++--- zetta_utils/internal | 2 +- zetta_utils/layer/backend_base.py | 7 ++-- zetta_utils/layer/db_layer/backend.py | 2 +- zetta_utils/layer/db_layer/layer.py | 2 +- zetta_utils/layer/layer_base.py | 12 ++++-- zetta_utils/layer/layer_set/backend.py | 11 +++-- zetta_utils/layer/layer_set/layer.py | 7 ++-- zetta_utils/layer/volumetric/backend.py | 11 +++-- zetta_utils/layer/volumetric/build.py | 6 ++- .../layer/volumetric/cloudvol/backend.py | 38 +++++++---------- .../layer/volumetric/cloudvol/build.py | 9 ++-- .../layer/volumetric/constant/backend.py | 13 +++--- .../layer/volumetric/constant/build.py | 6 +-- zetta_utils/layer/volumetric/frontend.py | 13 +++--- zetta_utils/layer/volumetric/layer.py | 14 ++++--- .../layer/volumetric/layer_set/backend.py | 12 +++--- .../layer/volumetric/layer_set/build.py | 3 +- .../layer/volumetric/layer_set/layer.py | 21 +++++++--- zetta_utils/layer/volumetric/protocols.py | 1 + .../layer/volumetric/tensorstore/backend.py | 22 ++++------ .../layer/volumetric/tensorstore/build.py | 9 ++-- zetta_utils/layer/volumetric/tools.py | 17 ++++---- .../common/apply_mask_fn.py | 10 ++--- .../common/interpolate_flow.py | 6 +-- .../common/volumetric_apply_flow.py | 26 ++++++++---- .../common/volumetric_callable_operation.py | 5 ++- zetta_utils/tensor_ops/convert.py | 23 ++++++++++ 34 files changed, 236 insertions(+), 199 deletions(-) diff --git a/docs/source/examples.rst b/docs/source/examples.rst index 1a5a61c9d..4b5c73eea 100644 --- a/docs/source/examples.rst +++ b/docs/source/examples.rst @@ -98,7 +98,7 @@ Layers for CloudVolume IO: ... ) >>> data = cvl[Vec3D(64, 64, 40), 7500:7564, 2250:2314, 2000:2001] >>> data.shape # channel, x, y, z - torch.Size([1, 64, 64, 1]) + (1, 64, 64, 1) >>> from zetta_utils.layer.volumetric.cloudvol import build_cv_layer @@ -114,7 +114,7 @@ Layers for CloudVolume IO: ... ) >>> data = cvl[120000:121024, 36000:37024, 2000:2001] # (4, 4, 40) indexing >>> data.shape # channel, x, y, z - torch.Size([1, 64, 64, 1]) + (1, 64, 64, 1) Layer sets for grouping layers together: diff --git a/docs/source/subchunkable_apply_flow_quick_start_guide.rst b/docs/source/subchunkable_apply_flow_quick_start_guide.rst index b7e2d037f..cd7cbdec9 100644 --- a/docs/source/subchunkable_apply_flow_quick_start_guide.rst +++ b/docs/source/subchunkable_apply_flow_quick_start_guide.rst @@ -49,7 +49,7 @@ For this guide, we will use FAFB v15, which is a ``precomputed`` dataset, and us ... ) >>> data = cvl[Vec3D(64, 64, 40), 13000:13100, 4000:4100, 2000:2001] >>> data.shape # channel, x, y, z - torch.Size([1, 100, 100, 1]) + (1, 100, 100, 1) .. collapse:: Vec3D @@ -89,7 +89,7 @@ For instance, suppose that you were looking at the FAFB v15 in neuroglancer, and ... ) >>> data = cvl[211200:216000, 64800:69600, 2000:2002] # (4, 4, 40) indexing >>> data.shape # channel, x, y, z at (192, 192, 80) resolution - torch.Size([1, 100, 100, 1]) + (1, 100, 100, 1) This feature can be used to: @@ -171,7 +171,7 @@ Using ``VolumetricIndex``, the first example above becomes: ... ) >>> data = cvl[idx] >>> data.shape # channel, x, y, z - torch.Size([1, 100, 100, 1]) + (1, 100, 100, 1) .. note:: Since ``VolumetricIndex`` already contains the resolution information, the ``index_resolution`` provided at the initialisation of ``VolumetricLayer`` is overridden when indexing into it using a ``VolumetricIndex``. @@ -273,9 +273,9 @@ We initialise the ``VolumetricLayer`` with this ``DataProcessor``, and compare t ... resolution = Vec3D(64, 64, 40) ... ) >>> cvl_without_proc[idx].min() - tensor(-2.9492) + -2.9491823 >>> cvl_with_proc[idx].min() - tensor(0.) + 0.0 This ``VolumetricLayer`` will now apply the ``__call__`` from the ``ThresholdProcessor`` before returning the output for each read. diff --git a/tests/unit/layer/volumetric/cloudvol/test_backend.py b/tests/unit/layer/volumetric/cloudvol/test_backend.py index e75cf61d0..c7a274c46 100644 --- a/tests/unit/layer/volumetric/cloudvol/test_backend.py +++ b/tests/unit/layer/volumetric/cloudvol/test_backend.py @@ -4,7 +4,6 @@ import numpy as np import pytest -import torch from zetta_utils.geometry import BBox3D, IntVec3D, Vec3D from zetta_utils.layer.volumetric import VolumetricIndex @@ -40,14 +39,7 @@ def test_cv_backend_dtype(clear_caches_reset_mocks): info_spec = PrecomputedInfoSpec(reference_path=LAYER_X0_PATH, data_type="uint8") cvb = CVBackend(path=LAYER_X0_PATH, info_spec=info_spec, on_info_exists="overwrite") - assert cvb.dtype == torch.uint8 - - -def test_cv_backend_dtype_exc(clear_caches_reset_mocks): - info_spec = PrecomputedInfoSpec(reference_path=LAYER_X0_PATH, data_type="nonsense") - cvb = CVBackend(path=LAYER_SCRATCH0_PATH, info_spec=info_spec, on_info_exists="overwrite") - with pytest.raises(ValueError): - cvb.dtype + assert cvb.dtype == np.dtype("uint8") def test_cv_backend_info_expect_same_exc(clear_caches_reset_mocks, mocker): @@ -101,8 +93,8 @@ def test_cv_backend_info_overwrite(clear_caches_reset_mocks, path, reference, mo def test_cv_backend_read(clear_caches_reset_mocks, mocker): - data_read = torch.ones([3, 4, 5, 2]) - expected = torch.ones([2, 3, 4, 5]) + data_read = np.ones([3, 4, 5, 2]) + expected = np.ones([2, 3, 4, 5]) cv_m = mocker.MagicMock() cv_m.__getitem__ = mocker.MagicMock(return_value=data_read) mocker.patch("cloudvolume.CloudVolume.__new__", return_value=cv_m) @@ -125,8 +117,8 @@ def test_cv_backend_write(clear_caches_reset_mocks, mocker): cv_m.__setitem__ = mocker.MagicMock() mocker.patch("cloudvolume.CloudVolume.__new__", return_value=cv_m) cvb = CVBackend(path=LAYER_SCRATCH0_PATH, info_spec=info_spec, on_info_exists="overwrite") - value = torch.ones([2, 3, 4, 5]) - expected_written = torch.ones([3, 4, 5, 2]) # channel as ch 0 + value = np.ones([2, 3, 4, 5]) + expected_written = np.ones([3, 4, 5, 2]) # channel as ch 0 index = VolumetricIndex( bbox=BBox3D.from_slices((slice(0, 1), slice(1, 2), slice(2, 3))), @@ -148,7 +140,7 @@ def test_cv_backend_write_scalar(clear_caches_reset_mocks, mocker): cv_m.__setitem__ = mocker.MagicMock() mocker.patch("cloudvolume.CloudVolume.__new__", return_value=cv_m) cvb = CVBackend(path=LAYER_SCRATCH0_PATH, info_spec=info_spec, on_info_exists="overwrite") - value = torch.tensor([1]) + value = np.array([1]) expected_written = 1 index = VolumetricIndex( @@ -163,7 +155,7 @@ def test_cv_backend_write_scalar(clear_caches_reset_mocks, mocker): def test_cv_backend_read_uint63(clear_caches_reset_mocks, mocker): data_read = np.array([[[[2 ** 63 - 1]]]], dtype=np.uint64) - expected = torch.tensor([[[[2 ** 63 - 1]]]], dtype=torch.int64) + expected = np.array([[[[2 ** 63 - 1]]]], dtype=np.int64) cv_m = mocker.MagicMock() cv_m.__getitem__ = mocker.MagicMock(return_value=data_read) mocker.patch("cloudvolume.CloudVolume.__new__", return_value=cv_m) @@ -177,20 +169,6 @@ def test_cv_backend_read_uint63(clear_caches_reset_mocks, mocker): cv_m.__getitem__.assert_called_with(index.bbox.to_slices(index.resolution)) -def test_cv_backend_read_uint64_exc(clear_caches_reset_mocks, mocker): - data_read = np.array([[[[2 ** 63 + 1]]]], dtype=np.uint64) - cv_m = mocker.MagicMock() - cv_m.__getitem__ = mocker.MagicMock(return_value=data_read) - mocker.patch("cloudvolume.CloudVolume.__new__", return_value=cv_m) - cvb = CVBackend(path=LAYER_UINT63_0_PATH) - index = VolumetricIndex( - bbox=BBox3D.from_slices((slice(0, 1), slice(0, 1), slice(0, 1))), - resolution=Vec3D(1, 1, 1), - ) - with pytest.raises(ValueError): - cvb.read(index) - - def test_cv_backend_write_scalar_uint63(clear_caches_reset_mocks, mocker): info_spec = PrecomputedInfoSpec( reference_path=LAYER_UINT63_0_PATH, @@ -200,7 +178,7 @@ def test_cv_backend_write_scalar_uint63(clear_caches_reset_mocks, mocker): cv_m.dtype = "uint64" mocker.patch("cloudvolume.CloudVolume.__new__", return_value=cv_m) cvb = CVBackend(path=LAYER_SCRATCH0_PATH, info_spec=info_spec, on_info_exists="overwrite") - value = torch.tensor([2 ** 63 - 1], dtype=torch.int64) + value = np.array([2 ** 63 - 1], dtype=np.int64) expected_written = np.uint64(2 ** 63 - 1) index = VolumetricIndex( @@ -222,7 +200,7 @@ def test_cv_backend_write_scalar_uint63_exc(clear_caches_reset_mocks, mocker): cv_m.dtype = "uint64" mocker.patch("cloudvolume.CloudVolume.__new__", return_value=cv_m) cvb = CVBackend(path=LAYER_SCRATCH0_PATH, info_spec=info_spec, on_info_exists="overwrite") - value = torch.tensor([-1], dtype=torch.int64) + value = np.array([-1], dtype=np.int64) index = VolumetricIndex( bbox=BBox3D.from_slices((slice(0, 1), slice(0, 1), slice(0, 1))), @@ -236,7 +214,7 @@ def test_cv_backend_write_scalar_uint63_exc(clear_caches_reset_mocks, mocker): "data_in,expected_exc", [ # Too many dims - [torch.ones((1, 2, 3, 4, 5, 6)), ValueError], + [np.ones((1, 2, 3, 4, 5, 6)), ValueError], ], ) def test_cv_backend_write_exc(clear_caches_reset_mocks, data_in, expected_exc, mocker): diff --git a/tests/unit/layer/volumetric/tensorstore/test_backend.py b/tests/unit/layer/volumetric/tensorstore/test_backend.py index 79c63805f..fdd33c5b3 100644 --- a/tests/unit/layer/volumetric/tensorstore/test_backend.py +++ b/tests/unit/layer/volumetric/tensorstore/test_backend.py @@ -6,7 +6,6 @@ import numpy as np import pytest import tensorstore -import torch from zetta_utils.geometry import BBox3D, IntVec3D, Vec3D from zetta_utils.layer.volumetric import ( @@ -48,7 +47,7 @@ def test_ts_backend_dtype(clear_caches_reset_mocks): info_spec = PrecomputedInfoSpec(reference_path=LAYER_X0_PATH, data_type="uint8") tsb = TSBackend(path=LAYER_X0_PATH, info_spec=info_spec, on_info_exists="overwrite") - assert tsb.dtype == torch.uint8 + assert tsb.dtype == np.dtype("uint8") def test_ts_backend_dtype_exc(clear_caches_reset_mocks): @@ -148,8 +147,8 @@ def test_ts_backend_read_partial(clear_caches_reset_mocks, mocker): resolution=Vec3D(1, 1, 1), ) result = tsb.read(index) - assert result[:, 0:1, :, :] == torch.zeros((1, 1, 1, 1), dtype=torch.uint8) - assert result[:, 1:2, :, :] == torch.ones((1, 1, 1, 1), dtype=torch.uint8) + assert result[:, 0:1, :, :] == np.zeros((1, 1, 1, 1), dtype=np.uint8) + assert result[:, 1:2, :, :] == np.ones((1, 1, 1, 1), dtype=np.uint8) def test_ts_backend_write_idx(clear_caches_reset_mocks, mocker): @@ -160,7 +159,7 @@ def test_ts_backend_write_idx(clear_caches_reset_mocks, mocker): default_voxel_offset=IntVec3D(1, 2, 3), ) tsb = TSBackend(path=LAYER_SCRATCH0_PATH, info_spec=info_spec, on_info_exists="overwrite") - value = torch.ones([1, 3, 5, 7], dtype=torch.uint8) + value = np.ones([1, 3, 5, 7], dtype=np.uint8) index = VolumetricIndex( bbox=BBox3D.from_slices((slice(1, 4), slice(2, 7), slice(3, 10))), @@ -185,7 +184,7 @@ def test_ts_backend_write_idx_partial(clear_caches_reset_mocks, mocker): on_info_exists="overwrite", enforce_chunk_aligned_writes=False, ) - value = torch.ones([1, 3, 4, 5], dtype=torch.uint8) + value = np.ones([1, 3, 4, 5], dtype=np.uint8) index = VolumetricIndex( bbox=BBox3D.from_slices((slice(-2, 1), slice(-3, 1), slice(-4, 1))), @@ -205,7 +204,7 @@ def test_ts_backend_write_scalar_idx(clear_caches_reset_mocks, mocker): default_voxel_offset=IntVec3D(1, 2, 3), ) tsb = TSBackend(path=LAYER_SCRATCH0_PATH, info_spec=info_spec, on_info_exists="overwrite") - value = torch.tensor([1], dtype=torch.uint8) + value = np.array([1], dtype=np.uint8) index = VolumetricIndex( bbox=BBox3D.from_slices((slice(1, 4), slice(2, 7), slice(3, 10))), @@ -221,7 +220,7 @@ def test_ts_backend_write_scalar_idx(clear_caches_reset_mocks, mocker): "data_in,expected_exc", [ # Too many dims - [torch.ones((1, 2, 3, 4, 5, 6)), ValueError], + [np.ones((1, 2, 3, 4, 5, 6)), ValueError], ], ) def test_ts_backend_write_exc_dims(data_in, expected_exc, clear_caches_reset_mocks, mocker): @@ -243,7 +242,7 @@ def test_ts_backend_write_exc_dims(data_in, expected_exc, clear_caches_reset_moc "data_in,expected_exc", [ # idx not chunk aligned - [torch.ones((3, 3, 3, 3), dtype=torch.uint8), ValueError], + [np.ones((3, 3, 3, 3), dtype=np.uint8), ValueError], ], ) def test_ts_backend_write_exc_chunks(data_in, expected_exc, clear_caches_reset_mocks, mocker): diff --git a/tests/unit/layer/volumetric/test_layer.py b/tests/unit/layer/volumetric/test_layer.py index 8b8d08955..b7cd90c69 100644 --- a/tests/unit/layer/volumetric/test_layer.py +++ b/tests/unit/layer/volumetric/test_layer.py @@ -1,8 +1,8 @@ # pylint: disable=missing-docstring,redefined-outer-name,unused-argument,pointless-statement,line-too-long,protected-access,unsubscriptable-object from __future__ import annotations +import numpy as np import pytest -import torch from zetta_utils.geometry import BBox3D, Vec3D from zetta_utils.layer.volumetric import ( @@ -11,10 +11,12 @@ build_volumetric_layer, ) +from ...helpers import assert_array_equal + def test_data_resolution_read_interp(mocker): backend = mocker.MagicMock() - backend.read = mocker.MagicMock(return_value=torch.ones((2, 2, 2, 2)) * 2) + backend.read = mocker.MagicMock(return_value=np.ones((2, 2, 2, 2)) * 2) layer = build_volumetric_layer( backend, @@ -32,9 +34,9 @@ def test_data_resolution_read_interp(mocker): bbox=BBox3D.from_slices((slice(0, 3), slice(0, 3), slice(0, 3))), ) ) - assert torch.equal( + assert_array_equal( read_data, - torch.ones((2, 1, 1, 1)), + np.ones((2, 1, 1, 1)), ) @@ -60,11 +62,11 @@ def test_data_resolution_write_interp(mocker): bbox=BBox3D.from_slices((slice(0, 4), slice(0, 4), slice(0, 4))), ) - layer[0:4, 0:4, 0:4] = torch.ones((2, 1, 1, 1)) + layer[0:4, 0:4, 0:4] = np.ones((2, 1, 1, 1)) assert backend.write.call_args.kwargs["idx"] == idx - assert torch.equal( + assert_array_equal( backend.write.call_args.kwargs["data"], - torch.ones((2, 2, 2, 2)) * 2, + np.ones((2, 2, 2, 2)) * 2, ) @@ -79,9 +81,9 @@ def test_write_scalar(mocker): ) layer[0:1, 0:1, 0:1] = 1.0 - assert torch.equal( + assert_array_equal( backend.write.call_args.kwargs["data"], - torch.Tensor([1]), + np.array([1]), ) @@ -97,16 +99,16 @@ def test_write_scalar_with_processor(mocker): ) layer[0:1, 0:1, 0:1] = 1.0 - assert torch.equal( + assert_array_equal( backend.write.call_args.kwargs["data"], - torch.Tensor([2]), + np.array([2]), ) def test_read_write_with_idx_processor(mocker): backend = mocker.MagicMock() backend.write = mocker.MagicMock() - backend.read = mocker.MagicMock(return_value=torch.ones((2, 1, 1, 1))) + backend.read = mocker.MagicMock(return_value=np.ones((2, 1, 1, 1))) layer = build_volumetric_layer( backend, @@ -122,14 +124,14 @@ def test_read_write_with_idx_processor(mocker): bbox=BBox3D.from_slices((slice(2, 4), slice(4, 6), slice(8, 10))), ) layer[0:1, 0:1, 0:1] = 1.0 - assert torch.equal( + assert_array_equal( backend.write.call_args.kwargs["data"], - torch.Tensor([2]), + np.array([2]), ) assert backend.write.call_args.kwargs["idx"] == expected_idx data_read = layer[0:1, 0:1, 0:1] - assert torch.equal(data_read, torch.zeros((2, 1, 1, 1))) + assert_array_equal(data_read, np.zeros((2, 1, 1, 1))) assert backend.write.call_args.kwargs["idx"] == expected_idx diff --git a/tests/unit/layer/volumetric/test_layer_set.py b/tests/unit/layer/volumetric/test_layer_set.py index bea2e3d94..8eb39035b 100644 --- a/tests/unit/layer/volumetric/test_layer_set.py +++ b/tests/unit/layer/volumetric/test_layer_set.py @@ -1,4 +1,4 @@ -import torch +import numpy as np from zetta_utils.geometry import BBox3D, Vec3D from zetta_utils.layer.volumetric import VolumetricIndex, build_volumetric_layer_set @@ -6,13 +6,13 @@ def test_read(mocker): layer_a = mocker.MagicMock() - layer_a.read_with_procs = mocker.MagicMock(return_value=torch.tensor([1])) + layer_a.read_with_procs = mocker.MagicMock(return_value=np.array([1])) layer_b = mocker.MagicMock() - layer_b.read_with_procs = mocker.MagicMock(return_value=torch.tensor([2])) + layer_b.read_with_procs = mocker.MagicMock(return_value=np.array([2])) layer_set = build_volumetric_layer_set(layers={"a": layer_a, "b": layer_b}) idx = VolumetricIndex(bbox=BBox3D(bounds=((0, 1), (0, 1), (0, 1))), resolution=Vec3D(1, 1, 1)) result = layer_set[idx] - assert result == {"a": torch.Tensor([1]), "b": torch.Tensor([2])} + assert result == {"a": np.array([1]), "b": np.array([2])} layer_a.read_with_procs.called_with(idx=idx) layer_b.read_with_procs.called_with(idx=idx) diff --git a/tests/unit/layer/volumetric/test_tools.py b/tests/unit/layer/volumetric/test_tools.py index 39deba017..73674f444 100644 --- a/tests/unit/layer/volumetric/test_tools.py +++ b/tests/unit/layer/volumetric/test_tools.py @@ -1,6 +1,6 @@ # pylint: disable=missing-docstring,redefined-outer-name,unused-argument,pointless-statement,line-too-long,protected-access,too-few-public-methods +import numpy as np import pytest -import torch from zetta_utils.geometry import BBox3D, IntVec3D, Vec3D from zetta_utils.layer.volumetric import ( @@ -166,9 +166,11 @@ def test_roi_mask_processor_read( ) processor.process_index(idx, "read") - data = {target: torch.rand(*data_shape) for target in targets} + data = {target: np.random.rand(*data_shape).astype(np.float32) for target in targets} for target in existing_masks: - data[target + "_mask"] = torch.ones(*data_shape) # Pre-existing mask for the target + data[target + "_mask"] = np.ones(data_shape).astype( + np.float32 + ) # Pre-existing mask for the target processed_data = processor.process_data(data, "read") @@ -176,13 +178,13 @@ def test_roi_mask_processor_read( assert target in processed_data if target in existing_masks: - assert torch.all(processed_data[target + "_mask"] == data[target + "_mask"]) + assert np.all(processed_data[target + "_mask"] == data[target + "_mask"]) else: mask = processed_data[target + "_mask"] inside_roi = mask[ 0, expected_mask_region[0], expected_mask_region[1], expected_mask_region[2] ] - assert torch.all(inside_roi == 1) + assert np.all(inside_roi == 1) full_slice = slice(None) outside_roi_slices = tuple( @@ -192,4 +194,4 @@ def test_roi_mask_processor_read( outside_roi = mask[ 0, outside_roi_slices[0], outside_roi_slices[1], outside_roi_slices[2] ] - assert torch.all(outside_roi == 0) + assert np.all(outside_roi == 0) diff --git a/zetta_utils/internal b/zetta_utils/internal index b5f4b532e..a60d4ddc3 160000 --- a/zetta_utils/internal +++ b/zetta_utils/internal @@ -1 +1 @@ -Subproject commit b5f4b532e265b5c172d60f48b6ca07b2487cfe5a +Subproject commit a60d4ddc339b179bcdccedc748384a089e826a22 diff --git a/zetta_utils/layer/backend_base.py b/zetta_utils/layer/backend_base.py index 576d234cd..1421fe700 100644 --- a/zetta_utils/layer/backend_base.py +++ b/zetta_utils/layer/backend_base.py @@ -6,9 +6,10 @@ IndexT = TypeVar("IndexT") DataT = TypeVar("DataT") +DataWriteT = TypeVar("DataWriteT") -class Backend(ABC, Generic[IndexT, DataT]): # pylint: disable=too-few-public-methods +class Backend(ABC, Generic[IndexT, DataT, DataWriteT]): # pylint: disable=too-few-public-methods @property @abstractmethod def name(self) -> str: @@ -19,11 +20,11 @@ def read(self, idx: IndexT) -> DataT: """Reads data from the given index""" @abstractmethod - def write(self, idx: IndexT, data: DataT): + def write(self, idx: IndexT, data: DataWriteT): """Writes given data to the given index""" @abstractmethod - def with_changes(self, **kwargs) -> Backend[IndexT, DataT]: # pragma: no cover + def with_changes(self, **kwargs) -> Backend[IndexT, DataT, DataWriteT]: # pragma: no cover """Remakes the Layer with the requested backend changes. The kwargs are not typed since the implementation is currently based on `attrs.evolve` and the base Backend class does not have any attrs, leaving all implementation to the inherited diff --git a/zetta_utils/layer/db_layer/backend.py b/zetta_utils/layer/db_layer/backend.py index b200a2865..c7c411430 100644 --- a/zetta_utils/layer/db_layer/backend.py +++ b/zetta_utils/layer/db_layer/backend.py @@ -14,7 +14,7 @@ DBDataT = Sequence[DBRowDataT] -class DBBackend(Backend[DBIndex, DBDataT]): # pylint: disable=too-few-public-methods +class DBBackend(Backend[DBIndex, DBDataT, DBDataT]): # pylint: disable=too-few-public-methods @abstractmethod def __contains__(self, idx: str) -> bool: ... diff --git a/zetta_utils/layer/db_layer/layer.py b/zetta_utils/layer/db_layer/layer.py index 1a7e08703..7ba3d26de 100644 --- a/zetta_utils/layer/db_layer/layer.py +++ b/zetta_utils/layer/db_layer/layer.py @@ -29,7 +29,7 @@ def is_rowdata_seq(values: Sequence[Any]) -> TypeGuard[Sequence[DBRowDataT]]: @attrs.mutable -class DBLayer(Layer[DBIndex, DBDataT]): +class DBLayer(Layer[DBIndex, DBDataT, DBDataT]): backend: DBBackend readonly: bool = False diff --git a/zetta_utils/layer/layer_base.py b/zetta_utils/layer/layer_base.py index 488966395..fbc35f46f 100644 --- a/zetta_utils/layer/layer_base.py +++ b/zetta_utils/layer/layer_base.py @@ -10,13 +10,14 @@ BackendIndexT = TypeVar("BackendIndexT") BackendDataT = TypeVar("BackendDataT") +BackendDataWriteT = TypeVar("BackendDataWriteT") BackendT = TypeVar("BackendT", bound=Backend) LayerT = TypeVar("LayerT", bound="Layer") @attrs.frozen -class Layer(Generic[BackendIndexT, BackendDataT]): - backend: Backend[BackendIndexT, BackendDataT] +class Layer(Generic[BackendIndexT, BackendDataT, BackendDataWriteT]): + backend: Backend[BackendIndexT, BackendDataT, BackendDataWriteT] readonly: bool = False index_procs: tuple[IndexProcessor[BackendIndexT], ...] = () @@ -25,7 +26,10 @@ class Layer(Generic[BackendIndexT, BackendDataT]): ..., ] = () write_procs: tuple[ - Union[DataProcessor[BackendDataT], JointIndexDataProcessor[BackendDataT, BackendIndexT]], + Union[ + DataProcessor[BackendDataWriteT], + JointIndexDataProcessor[BackendDataWriteT, BackendIndexT], + ], ..., ] = () @@ -55,7 +59,7 @@ def read_with_procs( def write_with_procs( self, idx: BackendIndexT, - data: BackendDataT, + data: BackendDataWriteT, ): if self.readonly: raise IOError(f"Attempting to write to a read only layer {self}") diff --git a/zetta_utils/layer/layer_set/backend.py b/zetta_utils/layer/layer_set/backend.py index 893732841..b7c0dc7b2 100644 --- a/zetta_utils/layer/layer_set/backend.py +++ b/zetta_utils/layer/layer_set/backend.py @@ -9,16 +9,19 @@ IndexT = TypeVar("IndexT") DataT = TypeVar("DataT") +DataWriteT = TypeVar("DataWriteT") @attrs.frozen -class LayerSetBackend(Backend[IndexT, dict[str, DataT]]): # pylint: disable=too-few-public-methods - layers: dict[str, Layer[IndexT, DataT]] +class LayerSetBackend( + Backend[IndexT, dict[str, DataT], dict[str, DataWriteT]] +): # pylint: disable=too-few-public-methods + layers: dict[str, Layer[IndexT, DataT, DataWriteT]] def read(self, idx: IndexT) -> dict[str, DataT]: return {k: v.read_with_procs(idx) for k, v in self.layers.items()} - def write(self, idx: IndexT, data: dict[str, DataT]): + def write(self, idx: IndexT, data: dict[str, DataWriteT]): for k, v in data.items(): self.layers[k].write_with_procs(idx, v) @@ -26,5 +29,5 @@ def write(self, idx: IndexT, data: dict[str, DataT]): def name(self) -> str: return f"LayerSet[f{'_'.join(self.layers.keys())}]" # pragma: no cover - def with_changes(self, **kwargs) -> LayerSetBackend[IndexT, DataT]: + def with_changes(self, **kwargs) -> LayerSetBackend[IndexT, DataT, DataWriteT]: return attrs.evolve(self, **kwargs) # pragma: no cover diff --git a/zetta_utils/layer/layer_set/layer.py b/zetta_utils/layer/layer_set/layer.py index 634d3eab3..9ea057a93 100644 --- a/zetta_utils/layer/layer_set/layer.py +++ b/zetta_utils/layer/layer_set/layer.py @@ -7,6 +7,7 @@ IndexT = TypeVar("IndexT") DataT = TypeVar("DataT") +DataWriteT = TypeVar("DataWriteT") LayerSetDataProcT = Union[ DataProcessor[dict[str, DataT]], @@ -15,8 +16,8 @@ @attrs.frozen -class LayerSet(Layer[IndexT, dict[str, DataT]]): - backend: LayerSetBackend[IndexT, DataT] +class LayerSet(Layer[IndexT, dict[str, DataT], dict[str, DataWriteT]]): + backend: LayerSetBackend[IndexT, DataT, DataWriteT] readonly: bool = False @@ -27,5 +28,5 @@ class LayerSet(Layer[IndexT, dict[str, DataT]]): def __getitem__(self, idx: IndexT) -> dict[str, DataT]: return self.read_with_procs(idx=idx) - def __setitem__(self, idx: IndexT, data: dict[str, DataT]): + def __setitem__(self, idx: IndexT, data: dict[str, DataWriteT]): self.write_with_procs(idx=idx, data=data) diff --git a/zetta_utils/layer/volumetric/backend.py b/zetta_utils/layer/volumetric/backend.py index 42d17146a..e48c1b8bb 100644 --- a/zetta_utils/layer/volumetric/backend.py +++ b/zetta_utils/layer/volumetric/backend.py @@ -5,7 +5,7 @@ from typing import Literal, TypeVar import attrs -import torch +import numpy as np from zetta_utils.geometry import Vec3D @@ -13,10 +13,13 @@ from . import VolumetricIndex DataT = TypeVar("DataT") +DataWriteT = TypeVar("DataWriteT") @attrs.mutable -class VolumetricBackend(Backend[VolumetricIndex, DataT]): # pylint: disable=too-few-public-methods +class VolumetricBackend( + Backend[VolumetricIndex, DataT, DataWriteT] +): # pylint: disable=too-few-public-methods @property @abstractmethod def is_local(self) -> bool: @@ -24,7 +27,7 @@ def is_local(self) -> bool: @property @abstractmethod - def dtype(self) -> torch.dtype: + def dtype(self) -> np.dtype: ... @property @@ -75,7 +78,7 @@ def get_dataset_size(self, resolution: Vec3D) -> Vec3D[int]: "dataest_size_res" = (dataset_size, resolution): Tuple[Vec3D[int], Vec3D] """ - def with_changes(self, **kwargs) -> VolumetricBackend[DataT]: + def with_changes(self, **kwargs) -> VolumetricBackend[DataT, DataWriteT]: return attrs.evolve(self, **kwargs) # pragma: no cover @abstractmethod diff --git a/zetta_utils/layer/volumetric/build.py b/zetta_utils/layer/volumetric/build.py index 4e9420448..0bbbe4be7 100644 --- a/zetta_utils/layer/volumetric/build.py +++ b/zetta_utils/layer/volumetric/build.py @@ -5,6 +5,7 @@ from typing import Iterable, Sequence import torch +from numpy import typing as npt from typeguard import typechecked from zetta_utils.geometry import Vec3D @@ -31,10 +32,11 @@ def build_volumetric_layer( readonly: bool = False, index_procs: Iterable[IndexProcessor[VolumetricIndex]] = (), read_procs: Iterable[ - DataProcessor[torch.Tensor] | JointIndexDataProcessor[torch.Tensor, VolumetricIndex] + DataProcessor[npt.NDArray] | JointIndexDataProcessor[npt.NDArray, VolumetricIndex] ] = (), write_procs: Iterable[ - DataProcessor[torch.Tensor] | JointIndexDataProcessor[torch.Tensor, VolumetricIndex] + DataProcessor[npt.NDArray | torch.Tensor] + | JointIndexDataProcessor[torch.Tensor | npt.NDArray, VolumetricIndex] ] = (), ) -> VolumetricLayer: """Build a Volumetric Layer. diff --git a/zetta_utils/layer/volumetric/cloudvol/backend.py b/zetta_utils/layer/volumetric/cloudvol/backend.py index f315943ac..44820f900 100644 --- a/zetta_utils/layer/volumetric/cloudvol/backend.py +++ b/zetta_utils/layer/volumetric/cloudvol/backend.py @@ -8,10 +8,9 @@ import cachetools import cloudvolume as cv import numpy as np -import torch from cloudvolume import CloudVolume +from numpy import typing as npt -from zetta_utils import tensor_ops from zetta_utils.common import abspath, is_local from zetta_utils.geometry import Vec3D @@ -78,8 +77,8 @@ def _clear_cv_cache(path: str | None = None) -> None: # pragma: no cover class CVBackend(VolumetricBackend): # pylint: disable=too-few-public-methods """ Backend for peforming IO on Neuroglancer datasts using CloudVolume library. - Read data will be a ``torch.Tensor`` in ``CXYZ`` dimension order. - Write data is expected to be a ``torch.Tensor`` or ``np.ndarray`` in ``CXYZ`` + Read data will be a ``npt.NDArray`` in ``CXYZ`` dimension order. + Write data is expected to be a ``npt.NDArray`` or ``np.ndarray`` in ``CXYZ`` dimension order. :param path: CloudVolume path. Can be given as relative or absolute. :param cv_kwargs: Parameters that will be passed to the CloudVolume constructor. @@ -137,17 +136,10 @@ def name(self, name: str) -> None: # pragma: no cover ) @property - def dtype(self) -> torch.dtype: + def dtype(self) -> np.dtype: result = _get_cv_cached(self.path, **self.cv_kwargs) - dtype = result.data_type - try: - return getattr(torch, dtype) - except Exception as e: - raise ValueError( # pylint: disable=raise-missing-from - f"CVBackend has data_type '{dtype}'," - " which cannot be parsed as a valid torch dtype." - ) from e + return result.data_type @property def num_channels(self) -> int: # pragma: no cover @@ -200,28 +192,26 @@ def clear_disk_cache(self) -> None: # pragma: no cover def clear_cache(self) -> None: # pragma: no cover _clear_cv_cache(self.path) - def read(self, idx: VolumetricIndex) -> torch.Tensor: + def read(self, idx: VolumetricIndex) -> npt.NDArray: # Data out: cxyz cvol = _get_cv_cached(self.path, idx.resolution, **self.cv_kwargs) data_raw = cvol[idx.to_slices()] - result_np = np.transpose(data_raw, (3, 0, 1, 2)) - result = tensor_ops.to_torch(result_np) - return result + result = np.transpose(data_raw, (3, 0, 1, 2)) + return np.array(result) - def write(self, idx: VolumetricIndex, data: torch.Tensor): + def write(self, idx: VolumetricIndex, data: npt.NDArray): # Data in: cxyz # Write format: xyzc (b == 1) - data_np = tensor_ops.convert.to_np(data) - if data_np.size == 1 and len(data_np.shape) == 1: - data_final = data_np[0] - elif len(data_np.shape) == 4: - data_final = np.transpose(data_np, (1, 2, 3, 0)) + if data.size == 1 and len(data.shape) == 1: + data_final = data[0] + elif len(data.shape) == 4: + data_final = np.transpose(data, (1, 2, 3, 0)) else: raise ValueError( "Data written to CloudVolume backend must be in `cxyz` dimension format, " - f"but got a tensor of with ndim == {data_np.ndim}" + f"but got a tensor of with ndim == {data.ndim}" ) cvol = _get_cv_cached(self.path, idx.resolution, **self.cv_kwargs) diff --git a/zetta_utils/layer/volumetric/cloudvol/build.py b/zetta_utils/layer/volumetric/cloudvol/build.py index ed7125693..1e5c875f1 100644 --- a/zetta_utils/layer/volumetric/cloudvol/build.py +++ b/zetta_utils/layer/volumetric/cloudvol/build.py @@ -4,6 +4,7 @@ from typing import Any, Iterable, Literal, Sequence, Union import torch +from numpy import typing as npt from zetta_utils import builder from zetta_utils.tensor_ops import InterpolationMode @@ -42,14 +43,14 @@ def build_cv_layer( # pylint: disable=too-many-locals index_procs: Iterable[IndexProcessor[VolumetricIndex]] = (), read_procs: Iterable[ Union[ - DataProcessor[torch.Tensor], - JointIndexDataProcessor[torch.Tensor, VolumetricIndex], + DataProcessor[npt.NDArray], + JointIndexDataProcessor[npt.NDArray, VolumetricIndex], ] ] = (), write_procs: Iterable[ Union[ - DataProcessor[torch.Tensor], - JointIndexDataProcessor[torch.Tensor, VolumetricIndex], + DataProcessor[npt.NDArray | torch.Tensor], + JointIndexDataProcessor[npt.NDArray | torch.Tensor, VolumetricIndex], ] ] = (), ) -> VolumetricLayer: # pragma: no cover # trivial conditional, delegation only diff --git a/zetta_utils/layer/volumetric/constant/backend.py b/zetta_utils/layer/volumetric/constant/backend.py index 39de6903e..f838b30c5 100644 --- a/zetta_utils/layer/volumetric/constant/backend.py +++ b/zetta_utils/layer/volumetric/constant/backend.py @@ -4,7 +4,8 @@ from typing import Literal, Union import attrs -import torch +import numpy as np +from numpy import typing as npt from zetta_utils.geometry import Vec3D @@ -33,8 +34,8 @@ def name(self, name: str) -> None: # pragma: no cover raise NotImplementedError("cannot set `name` for `ConstantVolumetricBackend` directly;") @property - def dtype(self) -> torch.dtype: # pragma: no cover - return torch.float + def dtype(self) -> np.dtype: # pragma: no cover + return np.dtype("float32") @property def is_local(self) -> bool: # pragma: no cover @@ -73,11 +74,11 @@ def use_compression(self, value: bool) -> None: # pragma: no cover def clear_cache(self) -> None: # pragma: no cover pass - def read(self, idx: VolumetricIndex) -> torch.Tensor: + def read(self, idx: VolumetricIndex) -> npt.NDArray: # Data out: cxyz slices = idx.to_slices() result = ( - torch.ones( + np.ones( ( self.num_channels, slices[0].stop - slices[0].start, @@ -89,7 +90,7 @@ def read(self, idx: VolumetricIndex) -> torch.Tensor: ) return result - def write(self, idx: VolumetricIndex, data: torch.Tensor): # pragma: no cover + def write(self, idx: VolumetricIndex, data: npt.NDArray): # pragma: no cover raise RuntimeError("cannot perform `write` operation on a ConstantVolumetricBackend") def with_changes(self, **kwargs) -> ConstantVolumetricBackend: # pragma: no cover diff --git a/zetta_utils/layer/volumetric/constant/build.py b/zetta_utils/layer/volumetric/constant/build.py index 6f9009350..9139e3285 100644 --- a/zetta_utils/layer/volumetric/constant/build.py +++ b/zetta_utils/layer/volumetric/constant/build.py @@ -3,7 +3,7 @@ from typing import Iterable, Sequence, Union -import torch +from numpy import typing as npt from zetta_utils import builder from zetta_utils.tensor_ops import InterpolationMode @@ -27,8 +27,8 @@ def build_constant_volumetric_layer( # pylint: disable=too-many-locals index_procs: Iterable[IndexProcessor[VolumetricIndex]] = (), read_procs: Iterable[ Union[ - DataProcessor[torch.Tensor], - JointIndexDataProcessor[torch.Tensor, VolumetricIndex], + DataProcessor[npt.NDArray], + JointIndexDataProcessor[npt.NDArray, VolumetricIndex], ] ] = (), ) -> VolumetricLayer: # pragma: no cover # trivial conditional, delegation only diff --git a/zetta_utils/layer/volumetric/frontend.py b/zetta_utils/layer/volumetric/frontend.py index 1159a42be..743e4ceda 100644 --- a/zetta_utils/layer/volumetric/frontend.py +++ b/zetta_utils/layer/volumetric/frontend.py @@ -3,6 +3,7 @@ from typing import Optional, Union import attrs +import numpy as np import torch from numpy import typing as npt @@ -109,17 +110,17 @@ def convert_write( self, idx_user: UserVolumetricIndex, data_user: npt.NDArray | torch.Tensor | float | int | bool, - ) -> tuple[VolumetricIndex, torch.Tensor]: + ) -> tuple[VolumetricIndex, npt.NDArray]: idx = self.convert_idx(idx_user) if isinstance(data_user, (float, int)): dtype_mapping = { - float: torch.float32, - int: torch.int32, - bool: torch.int32, + float: np.dtype("float32"), + int: np.dtype("int32"), + bool: np.dtype("int32"), } dtype = dtype_mapping[type(data_user)] - data = torch.Tensor([data_user]).to(dtype) + data = np.array([data_user]).astype(dtype) else: - data = tensor_ops.convert.to_torch(data_user) + data = tensor_ops.convert.to_np(data_user) return idx, data diff --git a/zetta_utils/layer/volumetric/layer.py b/zetta_utils/layer/volumetric/layer.py index 6ba25762a..7bf6dc513 100644 --- a/zetta_utils/layer/volumetric/layer.py +++ b/zetta_utils/layer/volumetric/layer.py @@ -15,21 +15,25 @@ ) VolumetricDataProcT = Union[ - DataProcessor[torch.Tensor], JointIndexDataProcessor[torch.Tensor, VolumetricIndex] + DataProcessor[npt.NDArray], JointIndexDataProcessor[npt.NDArray, VolumetricIndex] +] +VolumetricDataWriteProcT = Union[ + DataProcessor[npt.NDArray | torch.Tensor], + JointIndexDataProcessor[npt.NDArray | torch.Tensor, VolumetricIndex], ] @attrs.frozen -class VolumetricLayer(Layer[VolumetricIndex, torch.Tensor]): - backend: VolumetricBackend[torch.Tensor] +class VolumetricLayer(Layer[VolumetricIndex, npt.NDArray, npt.NDArray | torch.Tensor]): + backend: VolumetricBackend[npt.NDArray, npt.NDArray | torch.Tensor] frontend: VolumetricFrontend readonly: bool = False index_procs: tuple[IndexProcessor[VolumetricIndex], ...] = () read_procs: tuple[VolumetricDataProcT, ...] = () - write_procs: tuple[VolumetricDataProcT, ...] = () + write_procs: tuple[VolumetricDataWriteProcT, ...] = () - def __getitem__(self, idx: UserVolumetricIndex) -> torch.Tensor: + def __getitem__(self, idx: UserVolumetricIndex) -> npt.NDArray: idx_backend = self.frontend.convert_idx(idx) return self.read_with_procs(idx=idx_backend) diff --git a/zetta_utils/layer/volumetric/layer_set/backend.py b/zetta_utils/layer/volumetric/layer_set/backend.py index 156d10729..5dc5e1ce6 100644 --- a/zetta_utils/layer/volumetric/layer_set/backend.py +++ b/zetta_utils/layer/volumetric/layer_set/backend.py @@ -1,10 +1,12 @@ # pylint: disable=missing-docstring from __future__ import annotations -from typing import Literal +from typing import Literal, Mapping import attrs +import numpy as np import torch +from numpy import typing as npt from zetta_utils.geometry import Vec3D @@ -13,7 +15,7 @@ @attrs.frozen class VolumetricSetBackend( - VolumetricBackend[dict[str, torch.Tensor]] + VolumetricBackend[dict[str, npt.NDArray], Mapping[str, npt.NDArray | torch.Tensor]] ): # pylint: disable=too-few-public-methods layers: dict[str, VolumetricLayer] @@ -30,7 +32,7 @@ def name(self, name: str) -> None: # pragma: no cover ) @property - def dtype(self) -> torch.dtype: # pragma: no cover + def dtype(self) -> np.dtype: # pragma: no cover dtypes = {k: v.backend.dtype for k, v in self.layers.items()} if not len(set(dtypes.values())) == 1: raise ValueError( @@ -166,10 +168,10 @@ def assert_idx_is_chunk_aligned(self, idx: VolumetricIndex) -> None: # pragma: for e in self.layers.values(): e.backend.assert_idx_is_chunk_aligned(idx=idx) - def read(self, idx: VolumetricIndex) -> dict[str, torch.Tensor]: + def read(self, idx: VolumetricIndex) -> dict[str, npt.NDArray]: return {k: v.read_with_procs(idx) for k, v in self.layers.items()} - def write(self, idx: VolumetricIndex, data: dict[str, torch.Tensor]): + def write(self, idx: VolumetricIndex, data: Mapping[str, npt.NDArray | torch.Tensor]): for k, v in data.items(): self.layers[k].write_with_procs(idx, v) diff --git a/zetta_utils/layer/volumetric/layer_set/build.py b/zetta_utils/layer/volumetric/layer_set/build.py index d31e09b76..2176b9759 100644 --- a/zetta_utils/layer/volumetric/layer_set/build.py +++ b/zetta_utils/layer/volumetric/layer_set/build.py @@ -7,6 +7,7 @@ from zetta_utils import builder from zetta_utils.geometry import Vec3D +from zetta_utils.layer.volumetric.layer_set.layer import VolumetricSetDataWriteProcT from ... import IndexProcessor from .. import VolumetricFrontend, VolumetricIndex, VolumetricLayer @@ -23,7 +24,7 @@ def build_volumetric_layer_set( allow_slice_rounding: bool = False, index_procs: Iterable[IndexProcessor[VolumetricIndex]] = (), read_procs: Iterable[VolumetricSetDataProcT] = (), - write_procs: Iterable[VolumetricSetDataProcT] = (), + write_procs: Iterable[VolumetricSetDataWriteProcT] = (), ) -> VolumetricLayerSet: """Build a set of volumetric layers. :param layers: Mapping from layer names to layers. diff --git a/zetta_utils/layer/volumetric/layer_set/layer.py b/zetta_utils/layer/volumetric/layer_set/layer.py index 7f01922c6..674d57727 100644 --- a/zetta_utils/layer/volumetric/layer_set/layer.py +++ b/zetta_utils/layer/volumetric/layer_set/layer.py @@ -4,6 +4,7 @@ import attrs import torch +from numpy import typing as npt from typeguard import typechecked from ... import DataProcessor, IndexProcessor, JointIndexDataProcessor, Layer @@ -11,14 +12,20 @@ from . import VolumetricSetBackend VolumetricSetDataProcT = Union[ - DataProcessor[dict[str, torch.Tensor]], - JointIndexDataProcessor[dict[str, torch.Tensor], VolumetricIndex], + DataProcessor[dict[str, npt.NDArray]], + JointIndexDataProcessor[dict[str, npt.NDArray], VolumetricIndex], +] +VolumetricSetDataWriteProcT = Union[ + DataProcessor[Mapping[str, npt.NDArray | torch.Tensor]], + JointIndexDataProcessor[Mapping[str, npt.NDArray | torch.Tensor], VolumetricIndex], ] @typechecked @attrs.frozen -class VolumetricLayerSet(Layer[VolumetricIndex, dict[str, torch.Tensor]]): +class VolumetricLayerSet( + Layer[VolumetricIndex, dict[str, npt.NDArray], Mapping[str, npt.NDArray | torch.Tensor]] +): backend: VolumetricSetBackend frontend: VolumetricFrontend @@ -26,14 +33,16 @@ class VolumetricLayerSet(Layer[VolumetricIndex, dict[str, torch.Tensor]]): index_procs: tuple[IndexProcessor[VolumetricIndex], ...] = () read_procs: tuple[VolumetricSetDataProcT, ...] = () - write_procs: tuple[VolumetricSetDataProcT, ...] = () + write_procs: tuple[VolumetricSetDataWriteProcT, ...] = () - def __getitem__(self, idx: UserVolumetricIndex) -> dict[str, torch.Tensor]: + def __getitem__(self, idx: UserVolumetricIndex) -> dict[str, npt.NDArray]: idx_backend = self.frontend.convert_idx(idx) return self.read_with_procs(idx=idx_backend) def __setitem__( - self, idx: UserVolumetricIndex, data: Mapping[str, Union[torch.Tensor, int, float, bool]] + self, + idx: UserVolumetricIndex, + data: Mapping[str, Union[npt.NDArray, torch.Tensor, int, float, bool]], ): idx_backend: VolumetricIndex | None = None idx_last: VolumetricIndex | None = None diff --git a/zetta_utils/layer/volumetric/protocols.py b/zetta_utils/layer/volumetric/protocols.py index 5ddcaccd4..f27cc8db1 100644 --- a/zetta_utils/layer/volumetric/protocols.py +++ b/zetta_utils/layer/volumetric/protocols.py @@ -7,6 +7,7 @@ IndexT = TypeVar("IndexT", bound=VolumetricIndex) DataT = TypeVar("DataT") + VolumetricBasedLayerProtocolT = TypeVar( "VolumetricBasedLayerProtocolT", bound="VolumetricBasedLayerProtocol" ) diff --git a/zetta_utils/layer/volumetric/tensorstore/backend.py b/zetta_utils/layer/volumetric/tensorstore/backend.py index 81c743fb7..d9efe8181 100644 --- a/zetta_utils/layer/volumetric/tensorstore/backend.py +++ b/zetta_utils/layer/volumetric/tensorstore/backend.py @@ -10,6 +10,7 @@ import numpy as np import tensorstore import torch +from numpy import typing as npt from typeguard import suppress_type_checks from zetta_utils import tensor_ops @@ -66,8 +67,8 @@ def _clear_ts_cache(path: str | None = None) -> None: # pragma: no cover class TSBackend(VolumetricBackend): # pylint: disable=too-few-public-methods """ Backend for peforming IO on Neuroglancer datasts using TensorStore library. - Read data will be a ``torch.Tensor`` in ``CXYZ`` dimension order. - Write data is expected to be a ``torch.Tensor`` or ``np.ndarray`` in ``CXYZ`` + Read data will be a ``npt.NDArray`` in ``CXYZ`` dimension order. + Write data is expected to be a ``npt.NDArray`` or ``np.ndarray`` in ``CXYZ`` dimension order. :param path: Precomputed path. Can be given as relative or absolute. :param info_spec: Specification for the info file for the layer. If None, the @@ -137,13 +138,9 @@ def name(self, name: str) -> None: # pragma: no cover ) @property - def dtype(self) -> torch.dtype: - try: - result = _get_ts_at_resolution(self.path, self.cache_bytes_limit) - dtype = result.dtype.name - return getattr(torch, dtype) - except Exception as e: - raise e + def dtype(self) -> np.dtype: + result = _get_ts_at_resolution(self.path, self.cache_bytes_limit) + return result.dtype.name @property # TODO: Figure out a way to access 'multiscale metadata' directly @@ -190,7 +187,7 @@ def use_compression(self, value: bool) -> None: # pragma: no cover def clear_cache(self) -> None: # pragma: no cover _clear_ts_cache(self.path) - def read(self, idx: VolumetricIndex) -> torch.Tensor: + def read(self, idx: VolumetricIndex) -> npt.NDArray: # Data out: cxyz ts = _get_ts_at_resolution(self.path, self.cache_bytes_limit, str(list(idx.resolution))) @@ -208,11 +205,10 @@ def read(self, idx: VolumetricIndex) -> torch.Tensor: else: data_final = data_raw - result_np = np.transpose(data_final, (3, 0, 1, 2)) - result = tensor_ops.to_torch(result_np) + result = np.transpose(data_final, (3, 0, 1, 2)) return result - def write(self, idx: VolumetricIndex, data: torch.Tensor): + def write(self, idx: VolumetricIndex, data: torch.Tensor | npt.NDArray): if self._enforce_chunk_aligned_writes: self.assert_idx_is_chunk_aligned(idx) diff --git a/zetta_utils/layer/volumetric/tensorstore/build.py b/zetta_utils/layer/volumetric/tensorstore/build.py index 42b64e423..3a542d77a 100644 --- a/zetta_utils/layer/volumetric/tensorstore/build.py +++ b/zetta_utils/layer/volumetric/tensorstore/build.py @@ -4,6 +4,7 @@ from typing import Any, Iterable, Sequence, Union import torch +from numpy import typing as npt from zetta_utils import builder from zetta_utils.tensor_ops import InterpolationMode @@ -38,14 +39,14 @@ def build_ts_layer( # pylint: disable=too-many-locals index_procs: Iterable[IndexProcessor[VolumetricIndex]] = (), read_procs: Iterable[ Union[ - DataProcessor[torch.Tensor], - JointIndexDataProcessor[torch.Tensor, VolumetricIndex], + DataProcessor[npt.NDArray], + JointIndexDataProcessor[npt.NDArray, VolumetricIndex], ] ] = (), write_procs: Iterable[ Union[ - DataProcessor[torch.Tensor], - JointIndexDataProcessor[torch.Tensor, VolumetricIndex], + DataProcessor[npt.NDArray | torch.Tensor], + JointIndexDataProcessor[npt.NDArray | torch.Tensor, VolumetricIndex], ] ] = (), ) -> VolumetricLayer: # pragma: no cover # trivial conditional, delegation only diff --git a/zetta_utils/layer/volumetric/tools.py b/zetta_utils/layer/volumetric/tools.py index 15306222e..2758aad5e 100644 --- a/zetta_utils/layer/volumetric/tools.py +++ b/zetta_utils/layer/volumetric/tools.py @@ -3,7 +3,8 @@ from typing import List, Literal, Optional, Sequence, Tuple import attrs -import torch +import numpy as np +from numpy import typing as npt from typeguard import typechecked from zetta_utils import builder, log, tensor_ops @@ -114,7 +115,7 @@ def process_index( ) return result - def process_data(self, data: torch.Tensor, mode: Literal["read", "write"]) -> torch.Tensor: + def process_data(self, data: npt.NDArray, mode: Literal["read", "write"]) -> npt.NDArray: assert self.prepared_scale_factor is not None result = tensor_ops.interpolate( @@ -139,11 +140,11 @@ def process_index( ) -> VolumetricIndex: return idx - def process_data(self, data: torch.Tensor, mode: Literal["read", "write"]) -> torch.Tensor: + def process_data(self, data: npt.NDArray, mode: Literal["read", "write"]) -> npt.NDArray: if self.invert: - if not data.dtype == torch.uint8: + if not data.dtype == np.uint8: raise NotImplementedError("InvertProcessor is only supported for UInt8 layers.") - result = torch.bitwise_not(data) + 2 + result = np.bitwise_not(data) + 2 else: result = data return result @@ -201,15 +202,15 @@ def process_index( return idx def process_data( - self, data: dict[str, torch.Tensor], mode: Literal["read", "write"] - ) -> dict[str, torch.Tensor]: + self, data: dict[str, npt.NDArray], mode: Literal["read", "write"] + ) -> dict[str, npt.NDArray]: assert self.prepared_subidx is not None for target in self.targets: assert target in data if target + "_mask" in data: continue - roi_mask = torch.zeros_like(data[target]) + roi_mask = np.zeros_like(data[target]) extra_dims = roi_mask.ndim - len(self.prepared_subidx) slices = [slice(0, None) for _ in range(extra_dims)] slices += list(self.prepared_subidx) diff --git a/zetta_utils/mazepa_layer_processing/common/apply_mask_fn.py b/zetta_utils/mazepa_layer_processing/common/apply_mask_fn.py index 02f1a47a2..a404e73b1 100644 --- a/zetta_utils/mazepa_layer_processing/common/apply_mask_fn.py +++ b/zetta_utils/mazepa_layer_processing/common/apply_mask_fn.py @@ -2,19 +2,19 @@ from typing import Iterable -import torch +from numpy import typing as npt from zetta_utils import builder @builder.register("apply_mask_fn") def apply_mask_fn( - src: torch.Tensor, - masks: Iterable[torch.Tensor], + src: npt.NDArray, + masks: Iterable[npt.NDArray], fill_value: float = 0, -) -> torch.Tensor: +) -> npt.NDArray: result = src for mask in masks: result[mask > 0] = fill_value - result = result.to(src.dtype) + result = result.astype(src.dtype) return result diff --git a/zetta_utils/mazepa_layer_processing/common/interpolate_flow.py b/zetta_utils/mazepa_layer_processing/common/interpolate_flow.py index ed499c27e..95064bfa1 100644 --- a/zetta_utils/mazepa_layer_processing/common/interpolate_flow.py +++ b/zetta_utils/mazepa_layer_processing/common/interpolate_flow.py @@ -2,7 +2,7 @@ from typing import Optional, Sequence, Union -import torch +from numpy import typing as npt from zetta_utils import builder, mazepa, tensor_ops from zetta_utils.common import ComparablePartial @@ -17,11 +17,11 @@ def _interpolate( - src: torch.Tensor, + src: npt.NDArray, scale_factor: Union[float, Sequence[float]], mode: tensor_ops.InterpolationMode, mask_value_thr: float = 0, -) -> torch.Tensor: +) -> npt.NDArray: # This dummy function is necessary to rename `src` to `data` arg result = tensor_ops.interpolate( data=src, diff --git a/zetta_utils/mazepa_layer_processing/common/volumetric_apply_flow.py b/zetta_utils/mazepa_layer_processing/common/volumetric_apply_flow.py index d55817f59..3396592cb 100644 --- a/zetta_utils/mazepa_layer_processing/common/volumetric_apply_flow.py +++ b/zetta_utils/mazepa_layer_processing/common/volumetric_apply_flow.py @@ -10,6 +10,7 @@ import attrs import cachetools import fsspec +import numpy as np import torch from typeguard import suppress_type_checks from typing_extensions import ParamSpec @@ -22,6 +23,7 @@ VolumetricIndexChunker, ) from zetta_utils.mazepa import semaphore +from zetta_utils.tensor_ops import convert from ..operation_protocols import VolumetricOpProtocol @@ -83,7 +85,10 @@ def __call__( with suppress_type_checks(): if len(src_layers) == 0: return - res = torch.zeros((dst.backend.num_channels, *red_idx.shape), dtype=dst.backend.dtype) + res = torch.zeros( + (dst.backend.num_channels, *red_idx.shape), + dtype=convert.to_torch_dtype(dst.backend.dtype), + ) assert len(src_layers) > 0 if processing_blend_pad != Vec3D[int](0, 0, 0): for src_idx, layer in zip(src_idxs, src_layers): @@ -101,6 +106,10 @@ def __call__( dst[red_idx] = res +def is_floating_point_dtype(dtype: np.dtype) -> bool: + return np.issubdtype(dtype, np.floating) + + @mazepa.taskable_operation_cls @attrs.mutable class ReduceByWeightedSum(ReduceOperation): @@ -124,15 +133,16 @@ def __call__( with suppress_type_checks(): if len(src_layers) == 0: return - if not dst.backend.dtype.is_floating_point and processing_blend_pad != Vec3D[int]( - 0, 0, 0 - ): + if not is_floating_point_dtype(dst.backend.dtype) and processing_blend_pad != Vec3D[ + int + ](0, 0, 0): # backend is integer, but blending is requested - need to use float to avoid # rounding errors res = torch.zeros((dst.backend.num_channels, *red_idx.shape), dtype=torch.float) else: res = torch.zeros( - (dst.backend.num_channels, *red_idx.shape), dtype=dst.backend.dtype + (dst.backend.num_channels, *red_idx.shape), + dtype=convert.to_torch_dtype(dst.backend.dtype), ) assert len(src_layers) > 0 if processing_blend_pad != Vec3D[int](0, 0, 0): @@ -147,7 +157,7 @@ def __call__( intscn, subidx = src_idx.get_intersection_and_subindex(red_idx) subidx_channels = [slice(0, res.shape[0])] + list(subidx) with semaphore("read"): - if not dst.backend.dtype.is_floating_point: + if not is_floating_point_dtype(dst.backend.dtype): # Temporarily convert integer cutout to float for rounding res[subidx_channels] = ( res[subidx_channels] + layer[intscn].float() * weight @@ -155,8 +165,8 @@ def __call__( else: res[subidx_channels] = res[subidx_channels] + layer[intscn] * weight - if not dst.backend.dtype.is_floating_point: - res = res.round().to(dtype=dst.backend.dtype) + if not is_floating_point_dtype(dst.backend.dtype): + res = res.round().to(dtype=convert.to_torch_dtype(dst.backend.dtype)) else: for src_idx, layer in zip(src_idxs, src_layers): intscn, subidx = src_idx.get_intersection_and_subindex(red_idx) diff --git a/zetta_utils/mazepa_layer_processing/common/volumetric_callable_operation.py b/zetta_utils/mazepa_layer_processing/common/volumetric_callable_operation.py index 4ce3b9e5c..33ac22353 100644 --- a/zetta_utils/mazepa_layer_processing/common/volumetric_callable_operation.py +++ b/zetta_utils/mazepa_layer_processing/common/volumetric_callable_operation.py @@ -7,6 +7,7 @@ import attrs import torch +from numpy import typing as npt from typing_extensions import ParamSpec from zetta_utils import builder, mazepa, tensor_ops @@ -37,7 +38,7 @@ class VolumetricCallableOperation(Generic[P]): `fn` does either of these things. """ - fn: Callable[P, torch.Tensor] + fn: Callable[P, npt.NDArray | torch.Tensor] fn_semaphores: Sequence[SemaphoreType] | None = None crop_pad: Sequence[int] = (0, 0, 0) res_change_mult: Sequence[float] = (1, 1, 1) @@ -111,7 +112,7 @@ def __call__( # pylint: disable=keyword-arg-before-vararg # TODO: remove as soon as `interpolate_flow` is cut and ComputeField is configured # to use subchunkable def build_chunked_volumetric_callable_flow_schema( - fn: Callable[P, torch.Tensor], + fn: Callable[P, npt.NDArray | torch.Tensor], chunker: IndexChunker[IndexT], crop_pad: Vec3D[int] = Vec3D[int](0, 0, 0), res_change_mult: Vec3D = Vec3D(1, 1, 1), diff --git a/zetta_utils/tensor_ops/convert.py b/zetta_utils/tensor_ops/convert.py index 039e1c525..d4872de15 100644 --- a/zetta_utils/tensor_ops/convert.py +++ b/zetta_utils/tensor_ops/convert.py @@ -10,6 +10,29 @@ from zetta_utils.tensor_ops.common import supports_dict from zetta_utils.tensor_typing import Tensor, TensorTypeVar +TORCH_TO_NP_DTYPE_MAP: dict[torch.dtype, np.dtype] = { + torch.float32: np.dtype("float32"), + torch.float64: np.dtype("float64"), + torch.float16: np.dtype("float16"), + torch.int32: np.dtype("int32"), + torch.int64: np.dtype("int64"), + torch.int16: np.dtype("int16"), + torch.int8: np.dtype("int8"), + torch.uint8: np.dtype("uint8"), + torch.bool: np.dtype("bool"), +} +NP_TO_TORCH_DTYPE_MAP: dict[np.dtype, torch.dtype] = { + v: k for k, v in TORCH_TO_NP_DTYPE_MAP.items() +} + + +def dtype_to_np_dtype(dtype: torch.dtype) -> np.dtype: # pragma: no cover + return TORCH_TO_NP_DTYPE_MAP[dtype] + + +def to_torch_dtype(dtype: np.dtype) -> torch.dtype: # pragma: no cover + return NP_TO_TORCH_DTYPE_MAP[dtype] + @typechecked def to_np(data: Tensor) -> npt.NDArray: