Skip to content

Commit

Permalink
feat: make vol layers np/torch writeable
Browse files Browse the repository at this point in the history
  • Loading branch information
supersergiy committed May 15, 2024
1 parent 4a564ee commit 3a8fd6f
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 3 deletions.
6 changes: 6 additions & 0 deletions zetta_utils/layer/volumetric/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import torch

from zetta_utils.tensor_ops import convert
from .index import (
VolumetricIndex,
)
Expand Down Expand Up @@ -27,3 +30,6 @@
from .constant import ConstantVolumetricBackend, build_constant_volumetric_layer
from .layer_set import VolumetricLayerSet, build_volumetric_layer_set
from .protocols import VolumetricBasedLayerProtocol

VolumetricLayerDType = torch.Tensor
to_vol_layer_dtype = convert.to_torch
8 changes: 6 additions & 2 deletions zetta_utils/layer/volumetric/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

import attrs
import torch
from numpy import typing as npt

from zetta_utils import tensor_ops
from zetta_utils.geometry import BBox3D, Vec3D

from . import VolumetricIndex
Expand Down Expand Up @@ -104,7 +106,9 @@ def convert_idx(self, idx_user: UserVolumetricIndex) -> VolumetricIndex:
return result

def convert_write(
self, idx_user: UserVolumetricIndex, data_user: Union[torch.Tensor, float, int, bool]
self,
idx_user: UserVolumetricIndex,
data_user: npt.NDArray | torch.Tensor | float | int | bool,
) -> tuple[VolumetricIndex, torch.Tensor]:
idx = self.convert_idx(idx_user)
if isinstance(data_user, (float, int)):
Expand All @@ -116,6 +120,6 @@ def convert_write(
dtype = dtype_mapping[type(data_user)]
data = torch.Tensor([data_user]).to(dtype)
else:
data = data_user
data = tensor_ops.convert.to_torch(data_user)

return idx, data
5 changes: 4 additions & 1 deletion zetta_utils/layer/volumetric/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import attrs
import torch
from numpy import typing as npt

from .. import DataProcessor, IndexProcessor, JointIndexDataProcessor, Layer
from . import (
Expand Down Expand Up @@ -32,7 +33,9 @@ def __getitem__(self, idx: UserVolumetricIndex) -> torch.Tensor:
idx_backend = self.frontend.convert_idx(idx)
return self.read_with_procs(idx=idx_backend)

def __setitem__(self, idx: UserVolumetricIndex, data: torch.Tensor | float | int | bool):
def __setitem__(
self, idx: UserVolumetricIndex, data: npt.NDArray | torch.Tensor | float | int | bool
):
idx_backend, data_backend = self.frontend.convert_write(idx, data)
self.write_with_procs(idx=idx_backend, data=data_backend)

Expand Down

0 comments on commit 3a8fd6f

Please sign in to comment.