From 76d99f6b8d00fb468cf702783ce90ae964b8f319 Mon Sep 17 00:00:00 2001
From: Philipp Schlegel
Date: Sun, 31 Mar 2024 11:45:35 +0100
Subject: [PATCH] add more postprocessing methods:

- smooth
- despike
- rename remove_hairs -> remove_bristles
- update docs

---
 NEWS.md                         |   2 +-
 skeletor/__init__.py            |  16 +++-
 skeletor/post/__init__.py       |  11 ++-
 skeletor/post/postprocessing.py | 159 ++++++++++++++++++++++++++++----
 skeletor/pre/meshcontraction.py |   8 +-
 5 files changed, 171 insertions(+), 25 deletions(-)

diff --git a/NEWS.md b/NEWS.md
index 43bf7ef..bccbd2c 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -48,7 +48,7 @@ separate functions:
   results (SWC, nodes/vertices, edges, etc.) and allows for quick visualization
 - SWC tables are now strictly conforming to the format (continuous node IDs,
   parents always listed before their childs, etc)
-- `Skeleton` results contain a mesh to skeleton mapping as `.mesh_map` property 
+- `Skeleton` results contain a mesh to skeleton mapping as `.mesh_map` property
 - added an example mesh: to load it use `skeletor.example_mesh()`
 - `skeletor` now has proper tests and a simple
   [documentation](https://navis-org.github.io/skeletor/)
diff --git a/skeletor/__init__.py b/skeletor/__init__.py
index 76a10dd..c06a6bc 100644
--- a/skeletor/__init__.py
+++ b/skeletor/__init__.py
@@ -60,7 +60,7 @@
 1. Some pre-processing of the mesh (e.g. fixing some potential errors like
    degenerate faces, unreferenced vertices, etc.)
 2. The skeletonization itself
-3. Some post-processing of the skeleton (e.g. adding radius information)
+3. Some post-processing of the skeleton (e.g. adding radius information, smoothing, etc.)
 
 ------
 
@@ -84,6 +84,9 @@
 | **postprocessing** |  |
 | `skeletor.post.clean_up()` | fix some potential errors in the skeleton |
 | `skeletor.post.radii()` | add radius information using various method |
+| `skeletor.post.smooth()` | smooth the skeleton |
+| `skeletor.post.remove_bristles()` | remove single-node bristles from the skeleton |
+| `skeletor.post.despike()` | smooth out jumps in the skeleton |
 
 ------
 
@@ -194,7 +197,7 @@
 >>> mesh = sk.example_mesh()
 >>> # Alternatively use trimesh to load/construct your own mesh:
 >>> # import trimesh as tm
->>> # mesh = tm.Trimesh(vertices, faces) 
+>>> # mesh = tm.Trimesh(vertices, faces)
 >>> # mesh = tm.load_mesh('some/mesh.obj')
 >>> # Run some general clean-up (see docstring for details)
 >>> fixed = sk.pre.fix_mesh(mesh, remove_disconnected=5, inplace=False)
@@ -217,9 +220,12 @@
 - meshes need to be triangular (we are using `trimesh`)
 - use `sk.pre.simplify` if your mesh is very complex (half a million vertices
   is where things start getting sluggish)
-- a good mesh contraction is often half the battle
+- a good mesh contraction is often half the battle but it can be tricky to
+  get right
 - if the mesh consists of multiple disconnected pieces the skeleton will
   likewise be fragmented (i.e. will have multiple roots)
+- it's often a good idea to fix issues with the skeleton in postprocessing
+  rather than trying to get the skeletonization itself perfect
 
 # Benchmarks
 
@@ -256,8 +262,8 @@
 
 """
 
-__version__ = "1.2.3"
-__version_vector__ = (1, 2, 3)
+__version__ = "1.3.0"
+__version_vector__ = (1, 3, 0)
 
 from . import skeletonize
 from . import pre
diff --git a/skeletor/post/__init__.py b/skeletor/post/__init__.py
index 1e2639f..866dee0 100644
--- a/skeletor/post/__init__.py
+++ b/skeletor/post/__init__.py
@@ -29,6 +29,13 @@
    centered inside the mesh
  - superfluous "hairs" on otherwise straight bits
 
+`skeletor.post.smooth` will smooth out the skeleton.
+
+`skeletor.post.despike` can help you remove spikes in the skeleton where
+single nodes are out of alignment.
+
+`skeletor.post.remove_bristles` will remove bristles from the skeleton.
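+
+A minimal sketch of one possible postprocessing pipeline using these
+functions (the parameter values are illustrative only):
+
+>>> import skeletor as sk
+>>> mesh = sk.example_mesh()
+>>> skel = sk.skeletonize.by_wavefront(mesh)
+>>> skel = sk.post.remove_bristles(skel)
+>>> skel = sk.post.despike(skel, sigma=5)
+>>> skel = sk.post.smooth(skel, window=5)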
+
 ### Computing radius information
 
 Only `skeletor.skeletonize.by_wavefront()` provides radii off the bat. For all
@@ -38,7 +45,7 @@
 
 """
 from .radiusextraction import radii
-from .postprocessing import clean_up
+from .postprocessing import clean_up, smooth, despike, remove_bristles
 
 __docformat__ = "numpy"
-__all__ = ['radii', 'clean_up']
+__all__ = ["radii", "clean_up", "smooth", "despike", "remove_bristles"]
diff --git a/skeletor/post/postprocessing.py b/skeletor/post/postprocessing.py
index 856e215..af20ae2 100644
--- a/skeletor/post/postprocessing.py
+++ b/skeletor/post/postprocessing.py
@@ -88,14 +88,10 @@ def clean_up(s, mesh=None, validate=False, inplace=False, **kwargs):
     return s
 
 
-def remove_hairs(s, mesh=None, los_only=True, inplace=False):
-    """Remove "hairs" that sometimes occurr along the backbone.
+def remove_bristles(s, mesh=None, los_only=False, inplace=False):
+    """Remove "bristles" that sometimes occur along the backbone.
 
-    Works by finding terminal twigs that consist of only a single node. We will
-    then remove those that are within line of sight of their parent.
-
-    Note that this is currently not used for clean up as it does not work very
-    well: removes as many correct hairs as genuine small branches.
+    Works by finding terminal twigs that consist of only a single node.
 
     Parameters
     ----------
@@ -105,16 +101,16 @@ def remove_hairs(s, mesh=None, los_only=True, inplace=False):
                 Original mesh (e.g. before contraction). If not provided will
                 use the mesh associated with ``s``.
     los_only :  bool
-                If True, will only remove hairs that are in line of sight of
-                their parent. If False, will remove all single-node hairs.
+                If True, will only remove bristles that are in line of sight of
+                their parent. If False, will remove all single-node bristles.
     inplace :   bool
                 If False will make and return a copy of the skeleton. If True,
                 will modify the `s` inplace.
 
     Returns
     -------
-    SWC :       pandas.DataFrame
-                SWC with line-of-sight twigs removed.
+    s :         skeletor.Skeleton
+                Skeleton with single-node twigs removed.
 
     """
     if isinstance(mesh, type(None)):
@@ -198,8 +194,8 @@ def recenter_vertices(s, mesh=None, inplace=False):
 
     Returns
     -------
-    SWC :       pandas.DataFrame
-                SWC with line-of-sight twigs removed.
+    s :         skeletor.Skeleton
+                Skeleton with vertices recentered.
 
     """
     if isinstance(mesh, type(None)):
@@ -347,8 +343,8 @@ def drop_line_of_sight_twigs(s, mesh=None, max_dist='auto', inplace=False):
 
     Returns
     -------
-    SWC :       pandas.DataFrame
-                SWC with line-of-sight twigs removed.
+    s :         skeletor.Skeleton
+                Skeleton with line-of-sight twigs removed.
 
     """
     # Make a copy of the SWC
@@ -472,7 +468,7 @@ def drop_parallel_twigs(s, theta=0.01, inplace=False):
 
     Returns
     -------
-    SWC :       pandas.DataFrame
-                SWC with parallel twigs removed.
+    s :         skeletor.Skeleton
+                Skeleton with parallel twigs removed.
 
     """
@@ -553,3 +549,134 @@ def drop_parallel_twigs(s, theta=0.01, inplace=False):
     s.reindex(inplace=True)
 
     return s
+
+
+def smooth(s,
+           window: int = 3,
+           to_smooth: list = ['x', 'y', 'z'],
+           inplace: bool = False):
+    """Smooth skeleton using rolling windows.
+
+    Parameters
+    ----------
+    s :         skeletor.Skeleton
+                Skeleton to be processed.
+    window :    int, optional
+                Size (N observations) of the rolling window in number of
+                nodes.
+    to_smooth : list
+                Columns of the node table to smooth. Should work with any
+                numeric column (e.g. 'radius').
+    inplace :   bool
+                If False will make and return a copy of the skeleton. If
+                True, will modify the `s` inplace.
+
+    Returns
+    -------
+    s :         skeletor.Skeleton
+                Skeleton with smoothed node table.
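+
+    Examples
+    --------
+    A minimal sketch (the `window` value is illustrative only):
+
+    >>> import skeletor as sk
+    >>> skel = sk.skeletonize.by_wavefront(sk.example_mesh())
+    >>> # Smooth the x/y/z coordinates with a rolling window of 5 nodes
+    >>> smoothed = sk.post.smooth(skel, window=5)
+    >>> # Any numeric column works, e.g. the radius
+    >>> smoothed = sk.post.smooth(smoothed, window=5, to_smooth=['radius'])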
+
+    """
+    if not inplace:
+        s = s.copy()
+
+    # Prepare nodes (set node ID as index)
+    nodes = s.swc.set_index('node_id', inplace=False).copy()
+
+    to_smooth = np.array(to_smooth)
+    miss = to_smooth[~np.isin(to_smooth, nodes.columns)]
+    if len(miss):
+        raise ValueError(f'Column(s) not found in node table: {miss}')
+
+    # Go over each segment and smooth
+    for seg in s.get_segments():
+        # Get this segment's values for the columns to smooth
+        this_co = nodes.loc[seg, to_smooth]
+
+        # Rolling average over `window` nodes (min_periods=1 avoids NaNs at
+        # the start of each segment)
+        interp = this_co.rolling(window, min_periods=1).mean()
+
+        nodes.loc[seg, to_smooth] = interp.values
+
+    # Reassign nodes
+    s.swc = nodes.reset_index(drop=False, inplace=False)
+
+    return s
+
+
+def despike(s,
+            sigma=5,
+            max_spike_length=1,
+            inplace=False,
+            reverse=False):
+    r"""Remove spikes in skeleton.
+
+    For each node A, the Euclidean distance to its parent B and to B's parent
+    C (i.e. A->B->C) is computed. If
+    :math:`\frac{dist(A,B)}{dist(A,C)} > \sigma`, node B is considered a
+    spike and realigned between A and C.
+
+    Parameters
+    ----------
+    s :                 skeletor.Skeleton
+                        Skeleton to be processed.
+    sigma :             float | int, optional
+                        Threshold for spike detection. Smaller sigma = more
+                        aggressive spike detection.
+    max_spike_length :  int, optional
+                        Determines how long (# of nodes) a spike can be.
+    inplace :           bool, optional
+                        If False, a copy of the skeleton is returned.
+    reverse :           bool, optional
+                        If True, will **also** walk the segments from
+                        proximal to distal. Use this to catch spikes on e.g.
+                        terminal nodes.
+
+    Returns
+    -------
+    s :                 skeletor.Skeleton
+                        Despiked skeleton.
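+
+    Examples
+    --------
+    A minimal sketch (the `sigma` value is illustrative only):
+
+    >>> import skeletor as sk
+    >>> skel = sk.skeletonize.by_wavefront(sk.example_mesh())
+    >>> # Realign single-node spikes; `reverse=True` also walks each segment
+    >>> # from proximal to distal to catch spikes close to terminal nodes
+    >>> despiked = sk.post.despike(skel, sigma=5, reverse=True)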
+
+    """
+    if not inplace:
+        s = s.copy()
+
+    # Index nodes table by node ID
+    this_nodes = s.swc.set_index('node_id', inplace=False)
+
+    segs_to_walk = s.get_segments()
+
+    if reverse:
+        # Also walk each segment in reverse (proximal to distal) to catch
+        # spikes close to the terminal nodes
+        segs_to_walk = segs_to_walk + [seg[::-1] for seg in segs_to_walk]
+
+    # For each spike length do -> do this in reverse to correct the long
+    # spikes first
+    for l in list(range(1, max_spike_length + 1))[::-1]:
+        # Go over all segments
+        for seg in segs_to_walk:
+            # Get nodes A, B and C of this segment
+            this_A = this_nodes.loc[seg[:-l - 1]]
+            this_B = this_nodes.loc[seg[l:-1]]
+            this_C = this_nodes.loc[seg[l + 1:]]
+
+            # Get coordinates
+            A = this_A[['x', 'y', 'z']].values
+            B = this_B[['x', 'y', 'z']].values
+            C = this_C[['x', 'y', 'z']].values
+
+            # Calculate Euclidean distances A->B and A->C
+            dist_AB = np.linalg.norm(A - B, axis=1)
+            dist_AC = np.linalg.norm(A - C, axis=1)
+
+            # Get the spikes (default the ratio to 0 where dist_AC is 0)
+            ratio = np.divide(dist_AB, dist_AC,
+                              out=np.zeros_like(dist_AB),
+                              where=dist_AC != 0)
+            spikes_ix = np.where(ratio > sigma)[0]
+            spikes = this_B.iloc[spikes_ix]
+
+            if not spikes.empty:
+                # Interpolate new position(s) between A and C
+                new_positions = A[spikes_ix] + (C[spikes_ix] - A[spikes_ix]) / 2
+
+                this_nodes.loc[spikes.index, ['x', 'y', 'z']] = new_positions
+
+    # Reassign node table
+    s.swc = this_nodes.reset_index(drop=False, inplace=False)
+
+    return s
\ No newline at end of file
diff --git a/skeletor/pre/meshcontraction.py b/skeletor/pre/meshcontraction.py
index ba1525e..70cc5d9 100644
--- a/skeletor/pre/meshcontraction.py
+++ b/skeletor/pre/meshcontraction.py
@@ -165,11 +165,17 @@ def contract(mesh, epsilon=1e-06, iter_lim=100, time_lim=None, precision=1e-07,
               postfix=[1, iter_lim, 1]) as pbar:
         for i in range(iter_lim):
             # Get Laplace weights
+
             if operator == 'cotangent':
                 L = laplacian_cotangent(dm, normalized=True)
             else:
                 L = laplacian_umbrella(dm)
-
+            # Alternative Laplacian via the robust_laplacian package
+            # (kept disabled for now):
+            # import robust_laplacian
+            # L, M_ = robust_laplacian.mesh_laplacian(np.array(dm.vertices),
+            #                                         np.array(dm.faces),
+            #                                         mollify_factor=1e-3)
             V = getMeshVPos(dm)
             A = sp.sparse.vstack([WL.dot(L), WH])
             b = np.vstack((zeros, WH.dot(V)))