Skip to content

Commit

Permalink
add more postprocessing methods:
Browse files Browse the repository at this point in the history
- smooth
- despike
- rename remove_hairs -> remove_bristles
- update docs
  • Loading branch information
schlegelp committed Mar 31, 2024
1 parent 30d58a0 commit 76d99f6
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 25 deletions.
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ separate functions:
results (SWC, nodes/vertices, edges, etc.) and allows for quick visualization
- SWC tables are now strictly conforming to the format (continuous node IDs,
parents always listed before their children, etc)
- `Skeleton` results contain a mesh to skeleton mapping as `.mesh_map` property
- `Skeleton` results contain a mesh to skeleton mapping as `.mesh_map` property
- added an example mesh: to load it use `skeletor.example_mesh()`
- `skeletor` now has proper tests and a simple [documentation](https://navis-org.github.io/skeletor/)

Expand Down
16 changes: 11 additions & 5 deletions skeletor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
1. Some pre-processing of the mesh (e.g. fixing some potential errors like
degenerate faces, unreferenced vertices, etc.)
2. The skeletonization itself
3. Some post-processing of the skeleton (e.g. adding radius information)
3. Some post-processing of the skeleton (e.g. adding radius information, smoothing, etc.)
------
Expand All @@ -84,6 +84,9 @@
| **postprocessing** | |
| `skeletor.post.clean_up()` | fix some potential errors in the skeleton |
| `skeletor.post.radii()` | add radius information using various methods |
| `skeletor.post.smooth()` | smooth the skeleton |
| `skeletor.post.remove_bristles()` | remove single-node bristles from the skeleton |
| `skeletor.post.despike()` | smooth out jumps in the skeleton |
------
Expand Down Expand Up @@ -194,7 +197,7 @@
>>> mesh = sk.example_mesh()
>>> # Alternatively use trimesh to load/construct your own mesh:
>>> # import trimesh as tm
>>> # mesh = tm.Trimesh(vertices, faces)
>>> # mesh = tm.Trimesh(vertices, faces)
>>> # mesh = tm.load_mesh('some/mesh.obj')
>>> # Run some general clean-up (see docstring for details)
>>> fixed = sk.pre.fix_mesh(mesh, remove_disconnected=5, inplace=False)
Expand All @@ -217,9 +220,12 @@
- meshes need to be triangular (we are using `trimesh`)
- use `sk.pre.simplify` if your mesh is very complex (half a million vertices is
where things start getting sluggish)
- a good mesh contraction is often half the battle
- a good mesh contraction is often half the battle but it can be tricky to get
to work
- if the mesh consists of multiple disconnected pieces the skeleton will
likewise be fragmented (i.e. will have multiple roots)
- it's often a good idea to fix issues with the skeleton in postprocessing rather
than trying to get the skeletonization to be perfect
# Benchmarks
Expand Down Expand Up @@ -256,8 +262,8 @@
"""

__version__ = "1.2.3"
__version_vector__ = (1, 2, 3)
__version__ = "1.3.0"
__version_vector__ = (1, 3, 0)

from . import skeletonize
from . import pre
Expand Down
11 changes: 9 additions & 2 deletions skeletor/post/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@
centered inside the mesh
- superfluous "hairs" on otherwise straight bits
`skeletor.post.smooth` will smooth out the skeleton.
`skeletor.post.despike` can help you remove spikes in the skeleton where
single nodes are out of alignment.
`skeletor.post.remove_bristles` will remove bristles from the skeleton.
### Computing radius information
Only `skeletor.skeletonize.by_wavefront()` provides radii off the bat. For all
Expand All @@ -38,7 +45,7 @@
"""

from .radiusextraction import radii
from .postprocessing import clean_up
from .postprocessing import clean_up, smooth, despike, remove_bristles

__docformat__ = "numpy"
__all__ = ['radii', 'clean_up']
__all__ = ["radii", "clean_up", "smooth", "despike", "remove_bristles"]
159 changes: 143 additions & 16 deletions skeletor/post/postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,10 @@ def clean_up(s, mesh=None, validate=False, inplace=False, **kwargs):
return s


def remove_hairs(s, mesh=None, los_only=True, inplace=False):
"""Remove "hairs" that sometimes occurr along the backbone.
def remove_bristles(s, mesh=None, los_only=False, inplace=False):
"""Remove "bristles" that sometimes occurr along the backbone.
Works by finding terminal twigs that consist of only a single node. We will
then remove those that are within line of sight of their parent.
Note that this is currently not used for clean up as it does not work very
well: removes as many correct hairs as genuine small branches.
Works by finding terminal twigs that consist of only a single node.
Parameters
----------
Expand All @@ -105,16 +101,16 @@ def remove_hairs(s, mesh=None, los_only=True, inplace=False):
Original mesh (e.g. before contraction). If not provided will
use the mesh associated with ``s``.
los_only : bool
If True, will only remove hairs that are in line of sight of
their parent. If False, will remove all single-node hairs.
If True, will only remove bristles that are in line of sight of
their parent. If False, will remove all single-node bristles.
inplace : bool
If False will make and return a copy of the skeleton. If True,
will modify the `s` inplace.
Returns
-------
SWC : pandas.DataFrame
SWC with line-of-sight twigs removed.
s : skeletor.Skeleton
Skeleton with single-node twigs removed.
"""
if isinstance(mesh, type(None)):
Expand Down Expand Up @@ -198,8 +194,8 @@ def recenter_vertices(s, mesh=None, inplace=False):
Returns
-------
SWC : pandas.DataFrame
SWC with line-of-sight twigs removed.
s : skeletor.Skeleton
Skeleton with vertices recentered.
"""
if isinstance(mesh, type(None)):
Expand Down Expand Up @@ -347,8 +343,8 @@ def drop_line_of_sight_twigs(s, mesh=None, max_dist='auto', inplace=False):
Returns
-------
SWC : pandas.DataFrame
SWC with line-of-sight twigs removed.
s : skeletor.Skeleton
Skeleton with line-of-sight twigs removed.
"""
# Make a copy of the SWC
Expand Down Expand Up @@ -472,7 +468,7 @@ def drop_parallel_twigs(s, theta=0.01, inplace=False):
Returns
-------
SWC : pandas.DataFrame
s : skeletor.Skeleton
SWC with parallel twigs removed.
"""
Expand Down Expand Up @@ -553,3 +549,134 @@ def drop_parallel_twigs(s, theta=0.01, inplace=False):
s.reindex(inplace=True)

return s

def smooth(s,
           window: int = 3,
           to_smooth=('x', 'y', 'z'),
           inplace: bool = False):
    """Smooth skeleton using rolling windows.

    Each linear segment is smoothed independently with a trailing rolling
    mean, so values do not bleed across branch points.

    Parameters
    ----------
    s :         skeletor.Skeleton
                Skeleton to be processed.
    window :    int, optional
                Size (N observations) of the rolling window in number of
                nodes.
    to_smooth : iterable of str
                Columns of the node table to smooth. Should work with any
                numeric column (e.g. 'radius').
    inplace :   bool
                If False will make and return a copy of the skeleton. If
                True, will modify the `s` inplace.

    Returns
    -------
    s :         skeletor.Skeleton
                Skeleton with smoothed node table.

    Raises
    ------
    ValueError
                If any column in `to_smooth` is missing from the node table.

    """
    if not inplace:
        s = s.copy()

    # Work on a node-ID-indexed copy so rows can be addressed by node ID
    nodes = s.swc.set_index('node_id', inplace=False).copy()

    # Validate requested columns up front so we fail before touching data
    to_smooth = np.array(to_smooth)
    miss = to_smooth[~np.isin(to_smooth, nodes.columns)]
    if len(miss):
        raise ValueError(f'Column(s) not found in node table: {miss}')

    # Go over each linear segment and smooth it
    for seg in s.get_segments():
        this_co = nodes.loc[seg, to_smooth]

        # Trailing rolling mean; min_periods=1 keeps the segment ends
        # (first window-1 nodes) from turning into NaNs
        interp = this_co.rolling(window, min_periods=1).mean()

        nodes.loc[seg, to_smooth] = interp.values

    # Write the smoothed table back, restoring node_id as a column
    s.swc = nodes.reset_index(drop=False, inplace=False)

    return s

def despike(s,
            sigma=5,
            max_spike_length=1,
            inplace=False,
            reverse=False):
    r"""Remove spikes in skeleton.

    For each node A, the euclidean distance to its next successor (parent)
    B and that node's successor C (i.e. A->B->C) is computed. If
    :math:`\frac{dist(A,B)}{dist(A,C)}>sigma`, node B is considered a spike
    and realigned between A and C.

    Parameters
    ----------
    s :                 skeletor.Skeleton
                        Skeleton to be processed.
    sigma :             float | int, optional
                        Threshold for spike detection. Smaller sigma = more
                        aggressive spike detection.
    max_spike_length :  int, optional
                        Determines how long (# of nodes) a spike can be.
    inplace :           bool, optional
                        If False, a copy of the neuron is returned.
    reverse :           bool, optional
                        If True, will **also** walk the segments from proximal
                        to distal. Use this to catch spikes on e.g. terminal
                        nodes.

    Returns
    -------
    s :                 skeletor.Skeleton
                        Despiked neuron.

    """
    if not inplace:
        s = s.copy()

    # Index nodes table by node ID
    # NOTE(review): this reads `s.nodes` but writes `s.swc` at the end -
    # presumably these are two views of the same table; confirm on Skeleton
    this_nodes = s.nodes.set_index('node_id', inplace=False)

    # Copy so we never mutate the list returned by get_segments()
    segments = s.get_segments()
    segs_to_walk = list(segments)

    if reverse:
        # Walk each segment in reverse (proximal -> distal) as well to
        # catch spikes at terminal nodes
        segs_to_walk += [seg[::-1] for seg in segments]

    # For each spike length -> do this in reverse to correct the long
    # spikes first
    for l in range(max_spike_length, 0, -1):
        # Go over all segments
        for seg in segs_to_walk:
            # Segments shorter than the A->B->C window can't contain a spike
            if len(seg) < l + 2:
                continue

            # Get nodes A, B and C of this segment
            this_A = this_nodes.loc[seg[:-l - 1]]
            this_B = this_nodes.loc[seg[l:-1]]
            this_C = this_nodes.loc[seg[l + 1:]]

            # Get coordinates
            A = this_A[['x', 'y', 'z']].values
            B = this_B[['x', 'y', 'z']].values
            C = this_C[['x', 'y', 'z']].values

            # Calculate euclidean distances A->B and A->C
            dist_AB = np.linalg.norm(A - B, axis=1)
            dist_AC = np.linalg.norm(A - C, axis=1)

            # Ratio of detour (A->B) to direct distance (A->C). Provide an
            # explicit zero-filled `out` so positions masked by `where`
            # (dist_AC == 0) are 0 instead of uninitialized memory, which
            # could otherwise register as spurious spikes.
            ratio = np.divide(dist_AB, dist_AC,
                              out=np.zeros_like(dist_AB),
                              where=dist_AC != 0)

            # Get the spikes
            spikes_ix = np.where(ratio > sigma)[0]
            spikes = this_B.iloc[spikes_ix]

            if not spikes.empty:
                # Place spike node(s) halfway between A and C
                new_positions = A[spikes_ix] + (C[spikes_ix] - A[spikes_ix]) / 2

                this_nodes.loc[spikes.index, ['x', 'y', 'z']] = new_positions

    # Reassign node table, restoring node_id as a column
    s.swc = this_nodes.reset_index(drop=False, inplace=False)

    return s
8 changes: 7 additions & 1 deletion skeletor/pre/meshcontraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,11 +165,17 @@ def contract(mesh, epsilon=1e-06, iter_lim=100, time_lim=None, precision=1e-07,
postfix=[1, iter_lim, 1]) as pbar:
for i in range(iter_lim):
# Get Laplace weights

if operator == 'cotangent':
L = laplacian_cotangent(dm, normalized=True)
else:
L = laplacian_umbrella(dm)

"""
import robust_laplacian
L, M_ = robust_laplacian.mesh_laplacian(np.array(dm.vertices),
np.array(dm.faces),
mollify_factor=1e-3)
"""
V = getMeshVPos(dm)
A = sp.sparse.vstack([WL.dot(L), WH])
b = np.vstack((zeros, WH.dot(V)))
Expand Down

0 comments on commit 76d99f6

Please sign in to comment.