Skip to content

Commit

Permalink
Add a faster version of "from_set".
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 715289056
  • Loading branch information
mjanusz authored and copybara-github committed Jan 14, 2025
1 parent 73b76a7 commit 16145f5
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 12 deletions.
158 changes: 146 additions & 12 deletions connectomics/segmentation/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
from scipy import spatial


# Alias for the k-d tree class used in annotations and for building trees.
# The public `scipy.spatial` export is used so that no private module path
# (and no lint suppression) is needed.
cKDTree = spatial.cKDTree


def from_subvolume(vol3d: np.ndarray) -> nx.Graph:
"""Returns the RAG for a 3d subvolume.
Expand Down Expand Up @@ -48,17 +51,98 @@ def from_subvolume(vol3d: np.ndarray) -> nx.Graph:
unique_joint_labels = np.unique(x)

seg_nbor_pairs = set(
zip(unique_joint_labels & 0xFFFFFFFF, unique_joint_labels >> 32))
zip(unique_joint_labels & 0xFFFFFFFF, unique_joint_labels >> 32)
)
g.add_edges_from(seg_nbor_pairs, dim=dim)

return g


def from_set(kdts: dict[int, spatial._ckdtree.cKDTree]) -> nx.Graph:
def _graph_from_pairs(
    g: nx.Graph,
    pairs: dict[tuple[int, int], tuple[float, int, int]],
) -> nx.Graph:
  """Greedily adds the closest segment-pair edges that join new components.

  Candidate pairs are visited in order of increasing distance. A pair becomes
  an edge only if its two segments are not already linked through edges
  accepted earlier in this call.

  Args:
    g: graph to which the edges are added (modified in place)
    pairs: map from segment ID pairs to tuples of (distance, index1, index2)

  Returns:
    the input graph `g` with the selected edges added; every added edge
    carries an 'idx' attribute mapping each endpoint ID to its point index
  """
  # Sorting the tuples orders candidates by distance first; the remaining
  # fields only break ties.
  candidates = sorted(
      (dist, idx1, idx2, ids) for ids, (dist, idx1, idx2) in pairs.items()
  )

  components = nx.utils.UnionFind()
  for _, idx1, idx2, (id1, id2) in candidates:
    if components[id1] != components[id2]:
      components.union(id1, id2)
      g.add_edge(id1, id2, idx={id1: idx1, id2: idx2})

  return g


def _connect_components(g: nx.Graph, kdts: dict[int, cKDTree]) -> nx.Graph:
  """Links the connected components of `g` so that at most one remains.

  The points of all segments in a component are pooled into a single k-d
  tree, the pooled trees are connected with `from_set`, and every resulting
  component-level edge is mapped back to the two segments that own the
  points it joins.

  Args:
    g: initial graph defining how segments are connected (modified in place)
    kdts: map from segment IDs to k-d trees of associated spatial coordinates

  Returns:
    the input graph `g` with all components connected
  """
  if nx.number_connected_components(g) <= 1:
    return g

  # Pool the points of each component, remembering for every pooled point
  # which segment it came from and its index within that segment's tree.
  pooled_kdts = {}
  owner_seg = {}
  owner_idx = {}
  for cc_id, members in enumerate(nx.connected_components(g)):
    coords, segs, local = [], [], []
    for seg_id in members:
      seg_points = kdts[seg_id].data
      num_points = len(seg_points)
      coords.extend(seg_points)
      segs.extend([seg_id] * num_points)
      local.extend(range(num_points))

    pooled_kdts[cc_id] = cKDTree(np.array(coords))
    owner_seg[cc_id] = segs
    owner_idx[cc_id] = local

  # Translate component-level edges back to segment-level edges.
  for cc_a, cc_b, attrs in from_set(pooled_kdts).edges(data=True):
    pos_a = attrs['idx'][cc_a]
    pos_b = attrs['idx'][cc_b]
    seg_a = owner_seg[cc_a][pos_a]
    seg_b = owner_seg[cc_b][pos_b]
    g.add_edge(
        seg_a,
        seg_b,
        idx={seg_a: owner_idx[cc_a][pos_a], seg_b: owner_idx[cc_b][pos_b]},
    )

  assert nx.number_connected_components(g) <= 1
  return g


def from_set(kdts: dict[int, cKDTree]) -> nx.Graph:
"""Builds a RAG for a set of segments relying on their spatial proximity.
A typical use case is to transform an equivalence set into a graph using
skeleton or other point-based representation of segments.
skeleton or other point-based representation of segments. This has O(N^2)
complexity.
Args:
kdts: map from segment IDs to k-d trees of associated spatial coordinates
Expand All @@ -74,19 +158,69 @@ def from_set(kdts: dict[int, spatial._ckdtree.cKDTree]) -> nx.Graph:
for j in range(i + 1, len(segment_ids)):
dist, idx = kdts[segment_ids[i]].query(kdts[segment_ids[j]].data, k=1)
ii = np.argmin(dist)
pairs[(segment_ids[i], segment_ids[j])] = (dist[ii], idx[ii], ii)
pairs[(segment_ids[i], segment_ids[j])] = (
dist[ii],
int(idx[ii]),
int(ii),
)

dists = [(v[0], v[1], v[2], k) for k, v in pairs.items()]
dists.sort()
g = nx.Graph()
g.add_nodes_from(segment_ids)

return _graph_from_pairs(g, pairs)


def from_set_nn(kdts: dict[int, cKDTree], max_dist: float) -> nx.Graph:
  """Like 'from_set', but uses a more efficient two-stage procedure.

  First, a local neighborhood search is performed O(N log N), followed
  by O(n^2) reconnection of 'n' connected components if necessary.

  Args:
    kdts: map from segment IDs to k-d trees of associated spatial coordinates
    max_dist: maximum distance within which to search for neighbors (typical
      distance between segments) in physical units

  Returns:
    adjacency graph with greedily chosen edges connecting the most proximal
    segment pairs
  """
  g = nx.Graph()
  g.add_nodes_from(kdts)

  # Flatten all segments into one point list, remembering for every point
  # the segment it belongs to and its index within that segment's tree.
  all_points = []
  point_to_seg = []
  point_to_idx = []
  for seg_id, kdt in kdts.items():
    num_points = len(kdt.data)
    all_points.extend(kdt.data)
    point_to_seg.extend([seg_id] * num_points)
    point_to_idx.extend(range(num_points))

  # Nothing to search when there are no points at all.
  if not point_to_seg:
    return g

  all_points = np.array(all_points)
  combined_kdt = cKDTree(all_points)

  # For every pair of distinct segments with points within 'max_dist' of
  # each other, keep the closest pair of points.
  pairs = {}
  for i, point in enumerate(all_points):
    seg_i = point_to_seg[i]
    for j in combined_kdt.query_ball_point(point, max_dist):
      # Visit every unordered point pair once; this also skips i == j.
      if i >= j:
        continue

      seg_j = point_to_seg[j]
      if seg_i == seg_j:
        continue

      pair = seg_i, seg_j
      dist = np.linalg.norm(point - all_points[j])
      if pair not in pairs or dist < pairs[pair][0]:
        pairs[pair] = (dist, point_to_idx[i], point_to_idx[j])

  g = _graph_from_pairs(g, pairs)
  # Segments with no neighbor within 'max_dist' may still be disconnected.
  return _connect_components(g, kdts)
16 changes: 16 additions & 0 deletions connectomics/segmentation/rag_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@ def test_from_set_points(self):
g = rag.from_set(kdts)
self.assertTrue(nx.utils.edges_equal(g.edges(), ((1, 2), (2, 3))))

# All segments will be connected in the 1st (subquadratic) pass.
g2 = rag.from_set_nn(kdts, max_dist=10)
self.assertTrue(nx.utils.graphs_equal(g, g2))

# All segments will be connected in the 2nd (quadratic) pass.
g2 = rag.from_set_nn(kdts, max_dist=0.1)
self.assertTrue(nx.utils.graphs_equal(g, g2))

def test_from_set_skeletons(self):
# Each segment is associated with a short skeleton fragment.
skels = {
Expand Down Expand Up @@ -72,6 +80,14 @@ def test_from_set_skeletons(self):
self.assertEqual(g.edges[2, 3]['idx'][2], 3)
self.assertEqual(g.edges[2, 3]['idx'][3], 2)

# All segments will be connected in the 1st (subquadratic) pass.
g2 = rag.from_set_nn(kdts, max_dist=10)
self.assertTrue(nx.utils.graphs_equal(g, g2))

# All segments will be connected in the 2nd (quadratic) pass.
g2 = rag.from_set_nn(kdts, max_dist=0.1)
self.assertTrue(nx.utils.graphs_equal(g, g2))

def test_from_subvolume(self):
seg = np.zeros((10, 10, 2), dtype=np.uint64)
seg[2:, :, 0] = 1
Expand Down

0 comments on commit 16145f5

Please sign in to comment.