diff --git a/navis/morpho/manipulation.py b/navis/morpho/manipulation.py index b60fb195..21b0842b 100644 --- a/navis/morpho/manipulation.py +++ b/navis/morpho/manipulation.py @@ -23,11 +23,15 @@ from collections import namedtuple from itertools import combinations -from scipy.spatial import cKDTree from scipy.ndimage import gaussian_filter from typing import Union, Optional, Sequence, List, Set from typing_extensions import Literal +try: + from pykdtree.kdtree import KDTree +except ImportError: + from scipy.spatial import cKDTree as KDTree + from .. import graph, utils, config, core from . import mmetrics, subset @@ -1910,52 +1914,76 @@ def _stitch_mst(x: 'core.TreeNeuron', g = x.graph.to_undirected() - # Extract each fragment's rows and construct a KD-Tree - Fragment = namedtuple('Fragment', ['frag_id', 'df', 'kd']) - fragments = [] - for frag_id, cc in enumerate(nx.connected_components(g)): - if len(cc) == len(g.nodes): - # There's only one component -- no healing necessary - return x - - # Skip if fragment is smaller than threshold - if not isinstance(min_size, type(None)): - if len(cc) < min_size: - continue + # Get connected components + cc = list(nx.connected_components(g)) + if len(cc) == 1: + # There's only one component -- no healing necessary + return x - df = x.nodes.query('node_id in @cc') + # Turn into a dictionary node -> component + cc = {n: i for i, c in enumerate(cc) for n in c} - # If mask, drop everything that is masked out - if not isinstance(mask, type(None)): - df = df[df.node_id.isin(mask)] + # Turn into a Series + cc = x.nodes.node_id.map(cc) - # Filter to leaf nodes if applicable - if nodes == 'LEAFS': - df = df[df['type'].isin(['end', 'root'])] + to_use = x.nodes + # Drop fragments smaller than threshold + if not isinstance(min_size, type(None)): + sizes = cc.value_counts() + above = sizes[sizes >= min_size].index + to_use = to_use[cc.isin(above)] + cc = cc[cc.isin(above)] - if not df.empty: - kd = cKDTree(df[[*'xyz']].values) - 
fragments.append(Fragment(frag_id, df, kd)) + # Filter to leaf nodes if applicable + if nodes == 'LEAFS': + keep = to_use['type'].isin(['end', 'root']) + to_use = to_use[keep] + cc = cc[keep] + + # If mask, drop everything that is masked out + if not isinstance(mask, type(None)): + keep = to_use.node_id.isin(mask) + to_use = to_use[keep] + cc = cc[keep] + + # Collect fragments + Fragment = namedtuple('Fragment', ['frag_id', 'node_ids', 'kd']) + fragments = [] + for frag_id, df in to_use.groupby(cc): + kd = KDTree(df[[*'xyz']].values) + fragments.append(Fragment(frag_id, df.node_id.values, kd)) # Sort from big-to-small, so the calculations below use a # KD tree for the larger point set in every fragment pair. - fragments = sorted(fragments, key=lambda frag: -len(frag.df)) + fragments = sorted(fragments, key=lambda frag: -len(frag.node_ids)) # We could use the full graph and connect all # fragment pairs at their nearest neighbors, # but it's faster to treat each fragment as a # single node and run MST on that quotient graph, # which is tiny. + # Note to self: + # This approach works well if we have a small number of fragments to connect. + # But with a large number of fragments, the number of comparisons grows + # quadratically ((len(fragments) ** 2 - len(fragments)) / 2) and we would be + # better off running a brute force pairwise distance function on all + # relevant nodes and constructing the graph from that. 
frag_graph = nx.Graph() for frag_a, frag_b in combinations(fragments, 2): coords_b = frag_b.kd.data - distances, indexes = frag_a.kd.query(coords_b) + if coords_b.ndim == 1: + coords_b = coords_b.reshape(-1, 3) + distances, indexes = frag_a.kd.query(coords_b, distance_upper_bound=max_dist) + + # Ignore fragments that are too far apart + if np.all(np.isinf(distances)): + continue index_b = np.argmin(distances) index_a = indexes[index_b] - node_a = frag_a.df['node_id'].iloc[index_a] - node_b = frag_b.df['node_id'].iloc[index_b] + node_a = frag_a.node_ids[index_a] + node_b = frag_b.node_ids[index_b] dist_ab = distances[index_b] # Add edge from one fragment to another, @@ -1970,7 +1998,7 @@ def _stitch_mst(x: 'core.TreeNeuron', # For each inter-fragment edge, add the corresponding # fine-grained edge between skeleton nodes in the original graph. - to_add = [[e[2]['node_a'], e[2]['node_b']] for e in frag_edges if e[2]['distance'] <= max_dist] + to_add = [[e[2]['node_a'], e[2]['node_b']] for e in frag_edges] g.add_edges_from(to_add) # Rewire based on graph