Skip to content

Commit

Permalink
feat(generators): field from line annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
nkemnitz committed Jan 3, 2025
1 parent 2107a1d commit d29321c
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 1 deletion.
64 changes: 64 additions & 0 deletions tests/unit/tensor_ops/test_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
import numpy as np
import pytest
import torch
from torch import nan

from zetta_utils.db_annotations.annotation import AnnotationDBEntry
from zetta_utils.geometry import Vec3D
from zetta_utils.layer.volumetric.index import VolumetricIndex
from zetta_utils.tensor_ops import generators


Expand Down Expand Up @@ -230,3 +234,63 @@ def test_get_field_from_matrix(mat, size, expected_shape, expected_dtype):
def test_get_field_from_matrix_exceptions(mat, size):
with pytest.raises(ValueError):
generators.get_field_from_matrix(mat, size)


@pytest.mark.parametrize(
"annotations, index, expected",
[
# Single vector
[
[
AnnotationDBEntry.from_dict(
"1", {"type": "line", "pointA": [1.0, 0.0, 0.0], "pointB": [0.0, 0.0, 1.0]}
),
],
VolumetricIndex.from_coords([0, 0, 0], [2, 2, 1], Vec3D(*[1, 1, 1])),
torch.tensor(
[[[[nan, nan], [0.0, nan]], [[nan, nan], [-1.0, nan]]]], dtype=torch.float32
),
],
# Two vectors, same target pixel -- should average
[
[
AnnotationDBEntry.from_dict(
"1", {"type": "line", "pointA": [0.4, 0.0, 0.0], "pointB": [0.0, 0.0, 1.0]}
),
AnnotationDBEntry.from_dict(
"2", {"type": "line", "pointA": [0.6, 0.0, 0.0], "pointB": [1.0, 0.0, 1.0]}
),
],
VolumetricIndex.from_coords([0, 0, 0], [2, 2, 1], Vec3D(*[1, 1, 1])),
torch.tensor(
[[[[0.0, nan], [nan, nan]], [[0.0, nan], [nan, nan]]]], dtype=torch.float32
),
],
# Annotation outside ROI
[
[
AnnotationDBEntry.from_dict(
"1", {"type": "line", "pointA": [0.4, 0.0, 0.0], "pointB": [0.0, 0.0, 1.0]}
),
],
VolumetricIndex.from_coords([0, 0, 1], [2, 2, 2], Vec3D(*[1, 1, 1])),
torch.tensor(
[[[[nan, nan], [nan, nan]], [[nan, nan], [nan, nan]]]], dtype=torch.float32
),
],
],
)
def test_get_field_from_annotations(annotations, index, expected):
result = generators.get_field_from_annotations(annotations, index, device="cpu")
torch.testing.assert_close(result, expected, equal_nan=True)
assert result.device == torch.device("cpu")


def test_get_field_from_annotations_exception():
with pytest.raises(ValueError):
generators.get_field_from_annotations(
[AnnotationDBEntry.from_dict("1", {"type": "point", "point": [0.0, 0.0, 0.0]})],
VolumetricIndex.from_coords([0, 0, 0], [2, 2, 1], Vec3D(*[1, 1, 1])),
)
45 changes: 44 additions & 1 deletion zetta_utils/tensor_ops/generators.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import math
from typing import Callable, Sequence
from typing import Callable, Iterable, Sequence

import affine
import einops
import torch
import torchfields # pylint: disable=unused-import
from neuroglancer.viewer_state import LineAnnotation
from typeguard import typechecked

from zetta_utils import builder
from zetta_utils.db_annotations.annotation import AnnotationDBEntry
from zetta_utils.geometry.vec import VEC3D_PRECISION, Vec3D
from zetta_utils.layer.volumetric.index import VolumetricIndex
from zetta_utils.tensor_ops import convert
from zetta_utils.tensor_typing import Tensor

Expand Down Expand Up @@ -112,6 +116,45 @@ def get_field_from_matrix(
return displacement_field


def get_field_from_annotations(
line_annotations: Iterable[AnnotationDBEntry],
index: VolumetricIndex,
device: torch.types.Device | None = None,
) -> torchfields.Field:
"""
Returns a sparse 2D displacement field based on the provided line annotations.
:param line_annotations: Iterable line annotations.
:param index: VolumetricIndex specifying bounds and resolution of returned field.
:param device: Device to use for returned field.
:return: DisplacementField for the given line annotations. Unspecified values are NaN.
"""
sparse_field = torch.full((1, 2, index.shape[0], index.shape[1]), torch.nan, device=device)
contrib_sum = torch.zeros(sparse_field.shape[-2:], device=device)
for line in line_annotations:
annotation = line.ng_annotation
if not isinstance(annotation, LineAnnotation):
raise ValueError(f"Expected LineAnnotation, got {type(annotation)}")

pointA: Vec3D = Vec3D(*annotation.pointA) / index.resolution
pointB: Vec3D = Vec3D(*annotation.pointB) / index.resolution
if index.contains(round(pointA, VEC3D_PRECISION)):
x = math.floor(pointA.x) - index.start[0]
y = math.floor(pointA.y) - index.start[1]
# Contribution is 1 - sqrt(2) at corner of pixel, 1 at center
# Maybe should consider adjacent pixel, but hopefully good enough for now
contrib = 1.0 - math.sqrt((pointA.x % 1 - 0.5) ** 2 + (pointA.y % 1 - 0.5) ** 2)
contrib_sum[x, y] += contrib

if sparse_field[0, :, x, y].isnan().all():
sparse_field[0, :, x, y] = 0.0

# Field xy is flipped
sparse_field[0, :, x, y] += contrib * torch.tensor(pointB - pointA)[:2].flipud()

return sparse_field / contrib_sum


# https://gist.github.com/vadimkantorov/ac1b097753f217c5c11bc2ff396e0a57
@builder.register("rand_perlin_2d")
@typechecked
Expand Down

0 comments on commit d29321c

Please sign in to comment.