Skip to content

Commit

Permalink
fix list-of-list bug with spelunker
Browse files Browse the repository at this point in the history
  • Loading branch information
ceesem committed May 30, 2024
1 parent cec5428 commit cf5d8bd
Showing 1 changed file with 144 additions and 116 deletions.
260 changes: 144 additions & 116 deletions src/nglui/parser/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
import numpy as np
import pandas as pd
from itertools import chain

from ..easyviewer.ev_base.base import SEGMENTATION_LAYER_TYPES


def _is_spelunker_state(state):
"""Check if a state is a spelunker state or not."""
return "dimension" in state.keys()


def layer_names(state):
"""Get all layer names in the state
Expand Down Expand Up @@ -177,7 +185,7 @@ def _get_line_annotations(state, layer_name):


def _get_bbox_annotations(state, layer_name):
return _get_type_annotations(state, layer_name, "line")
return _get_type_annotations(state, layer_name, "axis_aligned_bounding_box")


def _get_group_annotations(state, layer_name):
Expand Down Expand Up @@ -229,6 +237,22 @@ def _extract_group_data(annos):
}


def _flatten_list_of_strings(l):
out = []
for x in l:
if isinstance(x, list):
out.extend(x)
else:
out.append(x)
return out


def _extract_segments(anno):
seg_data = anno.get("segments", [])
seg_list = _flatten_list_of_strings(seg_data)
return [int(x) for x in seg_list]


def _generic_annotations(
state, layer_name, description, linked_segmentations, tags, group, anno_type
):
Expand All @@ -238,10 +262,7 @@ def _generic_annotations(
desc = [anno.get("description", None) for anno in annos]
out.append(desc)
if linked_segmentations:
linked_seg = [
[np.uint64(x) for x in anno.get("segments", [])] for anno in annos
]
out.append(linked_seg)
out.append([_extract_segments(anno) for anno in annos])
if tags:
tag_list = [anno.get("tagIds", []) for anno in annos]
out.append(tag_list)
Expand Down Expand Up @@ -539,20 +560,15 @@ def extract_multicut(state, seg_layer=None):
return np.array(pts), np.array(side), np.array(svids), root_id


def annotation_dataframe(state):
"""Return all annotations across all annotation layers in the state.
Parameters
----------
state : dict
Neuroglancer state dictionary
def _concat_list(d):
d_out = []
for x in d:
for y in x:
d_out.append(y)
return d_out

Returns
-------
pd.DataFrame
Dataframe with columns layer, anno_type, point, pointB, linked_segmentation, tags, anno_id, group_id, description.
"""

def _parse_layer_dataframe(state, ln, expand_tags):
lns = []
points = []
anno_types = []
Expand All @@ -562,104 +578,95 @@ def annotation_dataframe(state):
group_ids = []
descs = []

for ln in annotation_layers(state):
# Points
p_pt, p_desc, p_seg, p_tag, p_grp = point_annotations(
state,
ln,
description=True,
linked_segmentations=True,
tags=True,
group=True,
)
n_p_pts = len(p_pt)
p_type = ["point"] * n_p_pts
p_ln = [ln] * n_p_pts
p_ptB = [np.nan] * n_p_pts

lns.append(p_ln)
points.append(p_pt)
anno_types.append(p_type)
pointBs.append(p_ptB)
linked_segs.append(p_seg)
tags.append(p_tag)
group_ids.append(p_grp)
descs.append(p_desc)

# Lines
l_ptA, l_ptB, l_desc, l_seg, l_tag, l_grp = line_annotations(
state,
ln,
description=True,
linked_segmentations=True,
tags=True,
group=True,
)
n_l_pts = len(l_ptA)
l_type = ["line"] * n_l_pts
l_ln = [ln] * n_l_pts

lns.append(l_ln)
points.append(l_ptA)
anno_types.append(l_type)
pointBs.append(l_ptB)
linked_segs.append(l_seg)
tags.append(l_tag)
group_ids.append(l_grp)
descs.append(l_desc)

# Spheres
s_ptA, s_ptB, s_desc, s_seg, s_tag, s_grp = sphere_annotations(
state,
ln,
description=True,
linked_segmentations=True,
tags=True,
group=True,
)
n_s_pts = len(s_ptA)
s_type = ["sphere"] * n_s_pts
s_ln = [ln] * n_s_pts

lns.append(s_ln)
points.append(s_ptA)
anno_types.append(s_type)
pointBs.append(s_ptB)
linked_segs.append(s_seg)
tags.append(s_tag)
group_ids.append(s_grp)
descs.append(s_desc)

# Bboxes
b_ptA, b_ptB, b_desc, b_seg, b_tag, b_grp = bbox_annotations(
state,
ln,
description=True,
linked_segmentations=True,
tags=True,
group=True,
)
n_b_pts = len(b_ptA)
b_type = ["bbox"] * n_b_pts
b_ln = [ln] * n_b_pts

lns.append(b_ln)
points.append(b_ptA)
anno_types.append(b_type)
pointBs.append(b_ptB)
linked_segs.append(b_seg)
tags.append(b_tag)
group_ids.append(b_grp)
descs.append(b_desc)

def _concat_list(d):
d_out = []
for x in d:
for y in x:
d_out.append(y)
return d_out

df = pd.DataFrame(
p_pt, p_desc, p_seg, p_tag, p_grp = point_annotations(
state,
ln,
description=True,
linked_segmentations=True,
tags=True,
group=True,
)
n_p_pts = len(p_pt)
p_type = ["point"] * n_p_pts
p_ln = [ln] * n_p_pts
p_ptB = [np.nan] * n_p_pts

lns.append(p_ln)
points.append(p_pt)
anno_types.append(p_type)
pointBs.append(p_ptB)
linked_segs.append(p_seg)
tags.append(p_tag)
group_ids.append(p_grp)
descs.append(p_desc)

# Lines
l_ptA, l_ptB, l_desc, l_seg, l_tag, l_grp = line_annotations(
state,
ln,
description=True,
linked_segmentations=True,
tags=True,
group=True,
)
n_l_pts = len(l_ptA)
l_type = ["line"] * n_l_pts
l_ln = [ln] * n_l_pts

lns.append(l_ln)
points.append(l_ptA)
anno_types.append(l_type)
pointBs.append(l_ptB)
linked_segs.append(l_seg)
tags.append(l_tag)
group_ids.append(l_grp)
descs.append(l_desc)

# Spheres
s_ptA, s_ptB, s_desc, s_seg, s_tag, s_grp = sphere_annotations(
state,
ln,
description=True,
linked_segmentations=True,
tags=True,
group=True,
)
n_s_pts = len(s_ptA)
s_type = ["sphere"] * n_s_pts
s_ln = [ln] * n_s_pts

lns.append(s_ln)
points.append(s_ptA)
anno_types.append(s_type)
pointBs.append(s_ptB)
linked_segs.append(s_seg)
tags.append(s_tag)
group_ids.append(s_grp)
descs.append(s_desc)

# Bboxes
b_ptA, b_ptB, b_desc, b_seg, b_tag, b_grp = bbox_annotations(
state,
ln,
description=True,
linked_segmentations=True,
tags=True,
group=True,
)
n_b_pts = len(b_ptA)
b_type = ["bbox"] * n_b_pts
b_ln = [ln] * n_b_pts

lns.append(b_ln)
points.append(b_ptA)
anno_types.append(b_type)
pointBs.append(b_ptB)
linked_segs.append(b_seg)
tags.append(b_tag)
group_ids.append(b_grp)
descs.append(b_desc)

return pd.DataFrame(
{
"layer": _concat_list(lns),
"anno_type": _concat_list(anno_types),
Expand All @@ -671,4 +678,25 @@ def _concat_list(d):
"description": _concat_list(descs),
}
)
return df


def annotation_dataframe(state, expand_tags=False):
"""Return all annotations across all annotation layers in the state.
Parameters
----------
state : dict
Neuroglancer state dictionary
expand_tags : bool, optional
If True, expand tags into separate boolean columns. By default False
Returns
-------
pd.DataFrame
Dataframe with columns layer, anno_type, point, pointB, linked_segmentation, tags, anno_id, group_id, description.
"""
dfs = [
_parse_layer_dataframe(state, ln, expand_tags)
for ln in annotation_layers(state)
]
return pd.concat(dfs, ignore_index=True)

0 comments on commit cf5d8bd

Please sign in to comment.