Skip to content

Commit

Permalink
feat(generators): field from line annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
nkemnitz committed Oct 4, 2024
1 parent 2f4c61f commit 8849ff4
Showing 1 changed file with 36 additions and 1 deletion.
37 changes: 36 additions & 1 deletion zetta_utils/tensor_ops/generators.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import math
from typing import Callable, Sequence
from typing import Callable, Iterable, Sequence

import affine
import einops
import torch
import torchfields # pylint: disable=unused-import
from neuroglancer.viewer_state import LineAnnotation
from typeguard import typechecked

from zetta_utils import builder
from zetta_utils.db_annotations.annotation import AnnotationDBEntry
from zetta_utils.geometry.vec import VEC3D_PRECISION, Vec3D
from zetta_utils.layer.volumetric.index import VolumetricIndex
from zetta_utils.tensor_ops import convert
from zetta_utils.tensor_typing import Tensor

Expand Down Expand Up @@ -112,6 +116,37 @@ def get_field_from_matrix(
return displacement_field


def get_field_from_annotations(
line_annotations: Iterable[AnnotationDBEntry],
index: VolumetricIndex,
device: torch.types.Device | None = None,
) -> torchfields.Field:
"""
Returns a sparse 2D displacement field based on the provided line annotations.
:param line_annotations: List of line annotations.
:param index: VolumetricIndex specifying bounds and resolution of returned field.
:param device: Device to use for returned field.
:return: DisplacementField for the given line annotations. Unspecified values are NaN.
"""
sparse_field = torch.full((1, 2, index.shape[0], index.shape[1]), torch.nan, device=device)
for line in line_annotations:
annotation = line.ng_annotation
if not isinstance(annotation, LineAnnotation):
raise ValueError(f"Expected LineAnnotation, got {type(annotation)}")

Check warning on line 136 in zetta_utils/tensor_ops/generators.py

View check run for this annotation

Codecov / codecov/patch

zetta_utils/tensor_ops/generators.py#L132-L136

Added lines #L132 - L136 were not covered by tests

pointA: Vec3D = Vec3D(*annotation.pointA) / index.resolution
pointB: Vec3D = Vec3D(*annotation.pointB) / index.resolution
if index.contains(round(pointA, VEC3D_PRECISION)):
sparse_field[

Check warning on line 141 in zetta_utils/tensor_ops/generators.py

View check run for this annotation

Codecov / codecov/patch

zetta_utils/tensor_ops/generators.py#L138-L141

Added lines #L138 - L141 were not covered by tests
0, :, math.floor(pointA.x) - index.start[0], math.floor(pointA.y) - index.start[1]
] = torch.tensor(pointB - pointA)[
:2
].flipud() # Don't know why flipud is needed

return sparse_field

Check warning on line 147 in zetta_utils/tensor_ops/generators.py

View check run for this annotation

Codecov / codecov/patch

zetta_utils/tensor_ops/generators.py#L147

Added line #L147 was not covered by tests


# https://gist.github.com/vadimkantorov/ac1b097753f217c5c11bc2ff396e0a57
@builder.register("rand_perlin_2d")
@typechecked
Expand Down

0 comments on commit 8849ff4

Please sign in to comment.