diff --git a/zetta_utils/tensor_ops/generators.py b/zetta_utils/tensor_ops/generators.py index 29db307ab..02a797910 100644 --- a/zetta_utils/tensor_ops/generators.py +++ b/zetta_utils/tensor_ops/generators.py @@ -1,13 +1,17 @@ import math -from typing import Callable, Sequence +from typing import Callable, Iterable, Sequence import affine import einops import torch import torchfields # pylint: disable=unused-import +from neuroglancer.viewer_state import LineAnnotation from typeguard import typechecked from zetta_utils import builder +from zetta_utils.db_annotations.annotation import AnnotationDBEntry +from zetta_utils.geometry.vec import VEC3D_PRECISION, Vec3D +from zetta_utils.layer.volumetric.index import VolumetricIndex from zetta_utils.tensor_ops import convert from zetta_utils.tensor_typing import Tensor @@ -112,6 +116,37 @@ def get_field_from_matrix( return displacement_field +def get_field_from_annotations( + line_annotations: Iterable[AnnotationDBEntry], + index: VolumetricIndex, + device: torch.types.Device | None = None, +) -> torchfields.Field: + """ + Returns a sparse 2D displacement field based on the provided line annotations. + + :param line_annotations: List of line annotations. + :param index: VolumetricIndex specifying bounds and resolution of returned field. + :param device: Device to use for returned field. + :return: DisplacementField for the given line annotations. Unspecified values are NaN. + """ + sparse_field = torch.full((1, 2, index.shape[0], index.shape[1]), torch.nan, device=device) + for line in line_annotations: + annotation = line.ng_annotation + if not isinstance(annotation, LineAnnotation): + raise ValueError(f"Expected LineAnnotation, got {type(annotation)}") + + pointA: Vec3D = Vec3D(*annotation.pointA) / index.resolution + pointB: Vec3D = Vec3D(*annotation.pointB) / index.resolution + if index.contains(round(pointA, VEC3D_PRECISION)): + sparse_field[ + 0, :, math.floor(pointA.x) - index.start[0], math.floor(pointA.y) - index.start[1] + ] = torch.tensor(pointB - pointA)[ + :2 + ].flipud() # Don't know why flipud is needed + + return sparse_field + + # https://gist.github.com/vadimkantorov/ac1b097753f217c5c11bc2ff396e0a57 @builder.register("rand_perlin_2d") @typechecked