Skip to content

Commit

Permalink
fix: numerical stability on decimal resolutions
Browse files Browse the repository at this point in the history
  • Loading branch information
dodamih authored and nkemnitz committed Dec 6, 2024
1 parent 358847e commit 5d2d0f4
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 0 deletions.
40 changes: 40 additions & 0 deletions tests/unit/geometry/test_bbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,46 @@ def test_constructor(
assert result.bounds == expected_bounds


@pytest.mark.parametrize(
"bbox, vec, expected",
[
[
BBox3D(bounds=((0, 1), (0, 2), (0, 3))),
Vec3D(1, 2, 3),
BBox3D(bounds=((0, 1), (0, 1), (0, 1))),
],
[
BBox3D(bounds=((0, 5734.4), (0, 8601.6), (0, 1))),
Vec3D(2.8, 2.8, 1),
BBox3D(bounds=((0, 2048), (0, 3072), (0, 1))),
],
],
)
def test_truediv(bbox: BBox3D, vec: Vec3D, expected: BBox3D):
result = bbox / vec
assert result == expected


@pytest.mark.parametrize(
"bbox, vec, expected",
[
[
BBox3D(bounds=((0, 1), (0, 2), (0, 3))),
Vec3D(1, 2, 3),
BBox3D(bounds=((0, 1), (0, 4), (0, 9))),
],
[
BBox3D(bounds=((0, 2048), (0, 3072), (0, 1))),
Vec3D(2.8, 2.8, 1),
BBox3D(bounds=((0, 5734.4), (0, 8601.6), (0, 1))),
],
],
)
def test_mul(bbox: BBox3D, vec: Vec3D, expected: BBox3D):
result = bbox * vec
assert result == expected


@pytest.mark.parametrize(
"bbox, resolution, allow_slice_rounding, expected",
[
Expand Down
19 changes: 19 additions & 0 deletions zetta_utils/geometry/bbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,19 @@ class BBox3D: # pylint: disable=too-many-public-methods # fundamental class
unit: str = DEFAULT_UNIT
pprint_px_resolution: Sequence[float] = (1, 1, 1)

def __attrs_post_init__(self):
object.__setattr__(
self,
"bounds",
tuple(
(
round(start, VEC3D_PRECISION),
round(end, VEC3D_PRECISION),
)
for start, end in self.bounds
),
)

@property
def ndim(self) -> int:
return 3
Expand Down Expand Up @@ -154,6 +167,12 @@ def from_points(

return cls(bounds=bounds, unit=unit)

def __truediv__(self, vec: Vec3D) -> BBox3D:
return BBox3D.from_coords(self.start / vec, self.end / vec)

def __mul__(self, vec: Vec3D) -> BBox3D:
return BBox3D.from_coords(self.start * vec, self.end * vec)

def get_slice(
self,
dim: int,
Expand Down
10 changes: 10 additions & 0 deletions zetta_utils/layer/volumetric/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,16 @@ def from_coords(
allow_slice_rounding=allow_slice_rounding,
)

def __truediv__(self, vec: Vec3D) -> VolumetricIndex:
return VolumetricIndex(
self.resolution, self.bbox / vec, self.chunk_id, self.allow_slice_rounding
)

def __mul__(self, vec: Vec3D) -> VolumetricIndex:
return VolumetricIndex(
self.resolution, self.bbox * vec, self.chunk_id, self.allow_slice_rounding
)

def to_slices(self):
"""
Represent this index as a tuple of slices.
Expand Down

0 comments on commit 5d2d0f4

Please sign in to comment.