Skip to content

Commit

Permalink
make heal_skeleton() faster:
Browse files Browse the repository at this point in the history
- use pykdtree if available
- cleverer filtering
- better use of pandas groupby
  • Loading branch information
schlegelp committed Jan 26, 2024
1 parent b92b9c2 commit daba1fa
Showing 1 changed file with 56 additions and 28 deletions.
84 changes: 56 additions & 28 deletions navis/morpho/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,15 @@

from collections import namedtuple
from itertools import combinations
from scipy.spatial import cKDTree
from scipy.ndimage import gaussian_filter
from typing import Union, Optional, Sequence, List, Set
from typing_extensions import Literal

try:
from pykdtree.kdtree import KDTree
except ImportError:
from scipy.spatial import cKDTree as KDTree

from .. import graph, utils, config, core
from . import mmetrics, subset

Expand Down Expand Up @@ -1910,52 +1914,76 @@ def _stitch_mst(x: 'core.TreeNeuron',

g = x.graph.to_undirected()

# Extract each fragment's rows and construct a KD-Tree
Fragment = namedtuple('Fragment', ['frag_id', 'df', 'kd'])
fragments = []
for frag_id, cc in enumerate(nx.connected_components(g)):
if len(cc) == len(g.nodes):
# There's only one component -- no healing necessary
return x

# Skip if fragment is smaller than threshold
if not isinstance(min_size, type(None)):
if len(cc) < min_size:
continue
# Get connected components
cc = list(nx.connected_components(g))
if len(cc) == 1:
# There's only one component -- no healing necessary
return x

df = x.nodes.query('node_id in @cc')
# Turn into a dictionary node -> component
cc = {n: i for i, c in enumerate(cc) for n in c}

# If mask, drop everything that is masked out
if not isinstance(mask, type(None)):
df = df[df.node_id.isin(mask)]
# Turn into a Series
cc = x.nodes.node_id.map(cc)

# Filter to leaf nodes if applicable
if nodes == 'LEAFS':
df = df[df['type'].isin(['end', 'root'])]
to_use = x.nodes
# Drop fragments smaller than threshold
if not isinstance(min_size, type(None)):
sizes = cc.value_counts()
above = sizes[sizes >= min_size].index
to_use = to_use[cc.isin(above)]
cc = cc[cc.isin(above)]

if not df.empty:
kd = cKDTree(df[[*'xyz']].values)
fragments.append(Fragment(frag_id, df, kd))
# Filter to leaf nodes if applicable
if nodes == 'LEAFS':
keep = to_use['type'].isin(['end', 'root'])
to_use = to_use[keep]
cc = cc[keep]

# If mask, drop everything that is masked out
if not isinstance(mask, type(None)):
keep = to_use.node_id.isin(mask)
to_use = to_use[keep]
cc = cc[keep]

# Collect fragments
Fragment = namedtuple('Fragment', ['frag_id', 'node_ids', 'kd'])
fragments = []
for frag_id, df in to_use.groupby(cc):
kd = KDTree(df[[*'xyz']].values)
fragments.append(Fragment(frag_id, df.node_id.values, kd))

# Sort from big-to-small, so the calculations below use a
# KD tree for the larger point set in every fragment pair.
fragments = sorted(fragments, key=lambda frag: -len(frag.df))
fragments = sorted(fragments, key=lambda frag: -len(frag.node_ids))

# We could use the full graph and connect all
# fragment pairs at their nearest neighbors,
# but it's faster to treat each fragment as a
# single node and run MST on that quotient graph,
# which is tiny.
# Note to self:
# This approach works well if we have a small number of fragments to connect
# But with a large number of fragments, the number of comparisons grows
# exponentially (len(fragments) ** 2 - len(fragments)) / 2) and we would be
# better off running a brute force pairwise distance function on all
# relevant nodes and constructing the graph from that.
frag_graph = nx.Graph()
for frag_a, frag_b in combinations(fragments, 2):
coords_b = frag_b.kd.data
distances, indexes = frag_a.kd.query(coords_b)
if coords_b.ndim == 1:
coords_b = coords_b.reshape(-1, 3)
distances, indexes = frag_a.kd.query(coords_b, distance_upper_bound=max_dist)

# Ignore fragments that are too far apart
if np.all(np.isinf(distances)):
continue

index_b = np.argmin(distances)
index_a = indexes[index_b]

node_a = frag_a.df['node_id'].iloc[index_a]
node_b = frag_b.df['node_id'].iloc[index_b]
node_a = frag_a.node_ids[index_a]
node_b = frag_b.node_ids[index_b]
dist_ab = distances[index_b]

# Add edge from one fragment to another,
Expand All @@ -1970,7 +1998,7 @@ def _stitch_mst(x: 'core.TreeNeuron',

# For each inter-fragment edge, add the corresponding
# fine-grained edge between skeleton nodes in the original graph.
to_add = [[e[2]['node_a'], e[2]['node_b']] for e in frag_edges if e[2]['distance'] <= max_dist]
to_add = [[e[2]['node_a'], e[2]['node_b']] for e in frag_edges]
g.add_edges_from(to_add)

# Rewire based on graph
Expand Down

0 comments on commit daba1fa

Please sign in to comment.