From c835e53f64c4609647642d424eee4ed2addee85a Mon Sep 17 00:00:00 2001 From: Philipp Schlegel Date: Thu, 24 Oct 2024 12:20:24 +0100 Subject: [PATCH 1/6] teach submesh to optionally return a vertex/face map --- navis/morpho/subset.py | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/navis/morpho/subset.py b/navis/morpho/subset.py index 82c5b438..3e75a89f 100644 --- a/navis/morpho/subset.py +++ b/navis/morpho/subset.py @@ -359,7 +359,7 @@ def _subset_treeneuron(x, subset, keep_disc_cn, prevent_fragments): return x -def submesh(mesh, *, faces_index=None, vertex_index=None): +def submesh(mesh, *, faces_index=None, vertex_index=None, return_map=False): """Re-imlementation of trimesh.submesh that is faster for our use case. Notably we: @@ -382,6 +382,9 @@ def submesh(mesh, *, faces_index=None, vertex_index=None): Indices of faces to keep. vertex_index : array-like Indices of vertices to keep. + return_map : bool, optional + If True, will return a mapping of old to new vertex and + face indices. Returns ------- @@ -389,6 +392,14 @@ def submesh(mesh, *, faces_index=None, vertex_index=None): Vertices of submesh. faces : np.ndarray Faces of submesh. + vert_map : np.ndarray + Only returned if `return_map` is True. Mapping of old vertex indices + to new vertex indices. Vertices that are not in the submesh have a + value of -1. + face_map : np.ndarray + Only returned if `return_map` is True. Mapping of old face indices + to new face indices. Faces that are not in the submesh have a value + of -1. """ if faces_index is None and vertex_index is None: @@ -439,4 +450,14 @@ def submesh(mesh, *, faces_index=None, vertex_index=None): # (making a copy to allow `mask` to be garbage collected) faces = mask[faces].copy() - return vertices, faces + if not return_map: + return vertices, faces + else: + face_map = np.full(len(original_faces), -1, dtype=np.int32) + face_map[faces_index] = np.arange(len(faces_index)) + vert_map = np.full(len(original_vertices), -1, dtype=np.int32) + vert_map[unique] = np.arange(len(unique)) + return vertices, faces, vert_map, face_map + + + From 6a501ae97cb76a6b0644ff020688d90b1bf06d5b Mon Sep 17 00:00:00 2001 From: Philipp Schlegel Date: Thu, 24 Oct 2024 12:20:44 +0100 Subject: [PATCH 2/6] fix bug in subset --- navis/morpho/subset.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/navis/morpho/subset.py b/navis/morpho/subset.py index 3e75a89f..c828f13d 100644 --- a/navis/morpho/subset.py +++ b/navis/morpho/subset.py @@ -308,13 +308,9 @@ def _subset_treeneuron(x, subset, keep_disc_cn, prevent_fragments): axis=1, ) - # Make sure any new roots or leafs are properly typed - # We won't produce new slabs but roots and leaves might change - x.nodes.loc[x.nodes.parent_id < 0, "type"] = "root" - x.nodes.loc[ - (~x.nodes.node_id.isin(x.nodes.parent_id.values) & (x.nodes.parent_id >= 0)), - "type", - ] = "end" + # Make sure nodes are correctly classified (need to do this because we're not actually clearing + # the temporary attributes) + graph.classify_nodes(x, inplace=True) # Filter connectors if not keep_disc_cn and x.has_connectors: From 88a2dec1225dad2cc294e076ede1bf63a89d0f9a Mon Sep 17 00:00:00 2001 From: Philipp Schlegel Date: Thu, 24 Oct 2024 12:44:21 +0100 Subject: [PATCH 3/6] first pass at masking interface: - add mask/unmask and apply_mask methods to TreeNeuron, MeshNeuron and Dotprops - add is_masked property for all neurons - add `navis.NeuronMask` class - add __length__ to all neurons - dotprops: clear `_tree` with temporary attributes --- docs/api.md | 23 +- navis/core/__init__.py | 15 +- navis/core/base.py | 96 +++++- navis/core/dotprop.py | 139 ++++++++- navis/core/masking.py | 204 ++++++++++++ navis/core/mesh.py | 399 ++++++++++++++++++++---- navis/core/skeleton.py | 692 +++++++++++++++++++++++++++-------------- navis/core/voxel.py | 9 + 8 files changed, 1261 insertions(+), 316 deletions(-) create mode 100644 navis/core/masking.py diff --git a/docs/api.md b/docs/api.md index a42622a8..fb7cfdc6 100644 --- a/docs/api.md +++ b/docs/api.md @@ -47,13 +47,13 @@ learn more! ``TreeNeurons``, ``MeshNeurons``, ``VoxelNeurons`` and ``Dotprops`` are neuron classes. ``NeuronLists`` are containers thereof. -| Class | Description | -|------|------| -| [`navis.TreeNeuron`][] | Skeleton representation of a neuron. | -| [`navis.MeshNeuron`][] | Meshes with vertices and faces. | -| [`navis.VoxelNeuron`][] | 3D images (e.g. from confocal stacks). | -| [`navis.Dotprops`][] | Point cloud + vector representations, used for NBLAST. | -| [`navis.NeuronList`][] | Containers for neurons. | +| Class | Description | +|-------------------------|---------------------------------------------------------| +| [`navis.TreeNeuron`][] | Skeleton representation of a neuron. | +| [`navis.MeshNeuron`][] | Meshes with vertices and faces. | +| [`navis.VoxelNeuron`][] | 3D images (e.g. from confocal stacks). | +| [`navis.Dotprops`][] | Point cloud + vector representations, used for NBLAST. | +| [`navis.NeuronList`][] | Containers for neurons. | ### General Neuron methods @@ -89,6 +89,7 @@ to all neurons: | `Neuron.type` | {{ autosummary("navis.BaseNeuron.type") }} | | `Neuron.soma` | {{ autosummary("navis.BaseNeuron.soma") }} | | `Neuron.bbox` | {{ autosummary("navis.BaseNeuron.bbox") }} | +| `Neuron.is_masked` | {{ autosummary("navis.BaseNeuron.is_masked") }} | !!! note @@ -119,6 +120,8 @@ this neuron type. Note that most of them are simply short-hands for the other | [`TreeNeuron.reroot()`][navis.TreeNeuron.reroot] | {{ autosummary("navis.TreeNeuron.reroot") }} | | [`TreeNeuron.resample()`][navis.TreeNeuron.resample] | {{ autosummary("navis.TreeNeuron.resample") }} | | [`TreeNeuron.snap()`][navis.TreeNeuron.snap] | {{ autosummary("navis.TreeNeuron.snap") }} | +| [`TreeNeuron.mask()`][navis.TreeNeuron.mask] | {{ autosummary("navis.TreeNeuron.mask") }} | +| [`TreeNeuron.unmask()`][navis.TreeNeuron.unmask] | {{ autosummary("navis.TreeNeuron.unmask") }} | In addition, a [`navis.TreeNeuron`][] has a range of different properties: @@ -146,7 +149,6 @@ In addition, a [`navis.TreeNeuron`][] has a range of different properties: | [`TreeNeuron.vertices`][navis.TreeNeuron.vertices] | {{ autosummary("navis.TreeNeuron.vertices") }} | | [`TreeNeuron.volume`][navis.TreeNeuron.volume] | {{ autosummary("navis.TreeNeuron.volume") }} | - #### Skeleton utility functions | Function | Description | @@ -158,7 +160,6 @@ In addition, a [`navis.TreeNeuron`][] has a range of different properties: | [`navis.graph.skeleton_adjacency_matrix()`][navis.graph.skeleton_adjacency_matrix] | {{ autosummary("navis.graph.skeleton_adjacency_matrix") }} | - ### Mesh neurons Properties specific to [`navis.MeshNeuron`][]: @@ -178,6 +179,8 @@ Methods specific to [`navis.MeshNeuron`][]: | [`MeshNeuron.skeletonize()`][navis.MeshNeuron.skeletonize] | {{ autosummary("navis.MeshNeuron.skeletonize") }} | | [`MeshNeuron.snap()`][navis.MeshNeuron.snap] | {{ autosummary("navis.MeshNeuron.snap") }} | | [`MeshNeuron.validate()`][navis.MeshNeuron.validate] | {{ autosummary("navis.MeshNeuron.validate") }} | +| [`MeshNeuron.mask()`][navis.MeshNeuron.mask] | {{ autosummary("navis.MeshNeuron.mask") }} | +| [`MeshNeuron.unmask()`][navis.MeshNeuron.unmask] | {{ autosummary("navis.MeshNeuron.unmask") }} | ### Voxel neurons @@ -215,6 +218,8 @@ These are methods and properties specific to [Dotprops][navis.Dotprops]: | [`Dotprops.alpha`][navis.Dotprops.alpha] | {{ autosummary("navis.Dotprops.alpha") }} | | [`Dotprops.to_skeleton()`][navis.Dotprops.to_skeleton] | {{ autosummary("navis.Dotprops.to_skeleton") }} | | [`Dotprops.snap()`][navis.Dotprops.snap] | {{ autosummary("navis.Dotprops.snap") }} | +| [`Dotprops.mask()`][navis.Dotprops.mask] | {{ autosummary("navis.Dotprops.mask") }} | +| [`Dotprops.unmask()`][navis.Dotprops.unmask] | {{ autosummary("navis.Dotprops.unmask") }} | ### Converting between types diff --git a/navis/core/__init__.py b/navis/core/__init__.py index 39b4342e..80973700 100644 --- a/navis/core/__init__.py +++ b/navis/core/__init__.py @@ -18,11 +18,22 @@ from .dotprop import Dotprops from .voxel import VoxelNeuron from .neuronlist import NeuronList +from .masking import NeuronMask from .core_utils import make_dotprops, to_neuron_space, NeuronProcessor from typing import Union NeuronObject = Union[NeuronList, TreeNeuron, BaseNeuron, MeshNeuron] -__all__ = ['Volume', 'Neuron', 'BaseNeuron', 'TreeNeuron', 'MeshNeuron', - 'Dotprops', 'VoxelNeuron', 'NeuronList', 'make_dotprops'] +__all__ = [ + "Volume", + "Neuron", + "BaseNeuron", + "TreeNeuron", + "MeshNeuron", + "NeuronMask", + "Dotprops", + "VoxelNeuron", + "NeuronList", + "make_dotprops", +] diff --git a/navis/core/base.py b/navis/core/base.py index 73bf2bc4..009b0a66 100644 --- a/navis/core/base.py +++ b/navis/core/base.py @@ -46,7 +46,8 @@ def Neuron( - x: Union[nx.DiGraph, str, pd.DataFrame, "TreeNeuron", "MeshNeuron"], **metadata + x: Union[nx.DiGraph, str, pd.DataFrame, "TreeNeuron", "MeshNeuron"], + **metadata, # noqa: F821 ): """Constructor for Neuron objects. Depending on the input, either a `TreeNeuron` or a `MeshNeuron` is returned. @@ -195,6 +196,9 @@ class BaseNeuron(UnitObject): #: Core data table(s) used to calculate hash CORE_DATA = [] + #: Property used to calculate length of neuron + _LENGTH_DATA = None + def __init__(self, **kwargs): # Set a random ID -> may be replaced later self.id = uuid.uuid4() @@ -303,6 +307,14 @@ def __isub__(self, other): """Subtraction with assignment (-=).""" return self.__sub__(other, copy=False) + def __len__(self): + if self._LENGTH_DATA is None: + return None + # Deal with potential empty neurons + if not hasattr(self, self._LENGTH_DATA): + return 0 + return len(getattr(self, self._LENGTH_DATA)) + def _repr_html_(self): frame = self.summary().to_frame() frame.columns = [""] @@ -654,6 +666,7 @@ def copy(self, deepcopy=False) -> "BaseNeuron": def summary(self, add_props=None) -> pd.Series: """Get a summary of this neuron.""" + # Do not remove the list -> otherwise we might change the original! props = list(self.SUMMARY_PROPS) @@ -721,6 +734,87 @@ def plot3d(self, **kwargs): return plot3d(core.NeuronList(self, make_copy=False), **kwargs) + @property + def is_masked(self): + """Test if neuron is masked. + + See Also + -------- + [`navis.BaseNeuron.mask`][] + Mask neuron. + [`navis.BaseNeuron.unmask`][] + Remove mask from neuron. + [`navis.NeuronMask`][] + Context manager for masking neurons. + """ + return hasattr(self, "_masked_data") + + def mask(self, mask): + """Mask neuron.""" + raise NotImplementedError( + f"Masking not implemented for neuron of type {type(self)}." + ) + + def unmask(self): + """Unmask neuron. + + Returns the neuron to its original state before masking. + + Returns + ------- + self + + See Also + -------- + [`Neuron.is_masked`][navis.BaseNeuron.is_masked] + Check if neuron. is masked. + [`Neuron.mask`][navis.BaseNeuron.unmask] + Mask neuron. + [`navis.NeuronMask`][] + Context manager for masking neurons. + + """ + if not self.is_masked: + raise ValueError("Neuron is not masked.") + + for k, v in self._masked_data.items(): + if hasattr(self, k): + setattr(self, k, v) + + delattr(self, "_mask") + delattr(self, "_masked_data") + self._clear_temp_attr() + + return self + + def apply_mask(self, inplace=False): + """Apply mask to neuron. + + This will effectively make the mask permanent. + + Parameters + ---------- + inplace : bool + If True will apply mask in-place. If False + will return a copy and the original neuron + will remain masked. + + Returns + ------- + Neuron + Neuron with mask applied. + + """ + if not self.is_masked: + raise ValueError("Neuron is not masked.") + + n = self if inplace else self.copy() + + delattr(n, "_mask") + delattr(n, "_masked_data") + + return n + def map_units( self, units: Union[pint.Unit, str], diff --git a/navis/core/dotprop.py b/navis/core/dotprop.py index 0f51dd1b..d3a69cdc 100644 --- a/navis/core/dotprop.py +++ b/navis/core/dotprop.py @@ -105,11 +105,14 @@ class Dotprops(BaseNeuron): EQ_ATTRIBUTES = ['name', 'n_points', 'k'] #: Temporary attributes that need clearing when neuron data changes - TEMP_ATTR = ['_memory_usage'] + TEMP_ATTR = ['_memory_usage', "_tree"] #: Core data table(s) used to calculate hash _CORE_DATA = ['points', 'vect'] + #: Property used to calculate length of neuron + _LENGTH_DATA = 'points' + def __init__(self, points: np.ndarray, k: int, @@ -230,9 +233,6 @@ def __getstate__(self): return state - def __len__(self): - return len(self.points) - @property def alpha(self): """Alpha value for tangent vectors (optional).""" @@ -539,6 +539,137 @@ def drop_fluff(self, epsilon, keep_size: int = None, n_largest: int = None, inpl if not inplace: return x + def mask(self, mask, copy=True): + """Mask neuron with given mask. + + This is always done in-place! + + Parameters + ---------- + mask : np.ndarray + Mask to apply. Can be: + - 1D array with boolean values + - callable that accepts a neuron and returns a mask + - string with property name + + Returns + ------- + self + The masked neuron. + + See Also + -------- + [`Dotprops.unmask`][navis.Dotprops.unmask] + Remove mask from neuron. + [`Dotprops.is_masked`][navis.Dotprops.is_masked] + Check if neuron is masked. + [`navis.NeuronMask`][] + Context manager for masking neurons. + + """ + if self.is_masked: + raise ValueError( + "Neuron already masked. Layering multiple masks is currently not supported, please unmask first." + ) + + if callable(mask): + mask = mask(self) + elif isinstance(mask, str): + mask = getattr(self, mask) + + mask = np.asarray(mask) + + if mask.dtype != bool: + raise ValueError("Mask must be boolean array.") + elif mask.shape[0] != len(self): + raise ValueError("Mask must have same length as points.") + + self._mask = mask + self._masked_data = {} + self._masked_data['_points'] = self.points + + # Drop soma if masked out + if self.soma is not None: + if isinstance(self.soma, (list, np.ndarray)): + soma_left = self.soma[mask[self.soma]] + self._masked_data['_soma'] = self.soma + + if any(soma_left): + self.soma = soma_left + else: + self.soma = None + elif not mask[self.soma]: + self._masked_data['_soma'] = self.soma + self.soma = None + + # N.B. we're directly setting `._nodes`` to avoid overhead from checks + for att in ("_points", "_vect", "_alpha"): + if hasattr(self, att): + self._masked_data[att] = getattr(self, att) + setattr(self, att, getattr(self, att)[mask]) + + if copy: + setattr(self, att, getattr(self, att).copy()) + + if hasattr(self, "_connectors") and "point_ix" in self._connectors.columns: + self._masked_data['connectors'] = self.connectors + self._connectors = self._connectors.loc[ + self.connectors.point_ix.isin(np.arange(len(mask))[mask]) + ] + if copy: + self._connectors = self._connectors.copy() + + self._clear_temp_attr() + + return self + + def unmask(self, reset=True): + """Unmask neuron. + + Returns the neuron to its original state before masking. + + Parameters + ---------- + reset : bool + Whether to reset the neuron to its original state before masking. + If False, edits made to the neuron after masking will be kept. + + Returns + ------- + self + + See Also + -------- + [`Dotprops.is_masked`][navis.Dotprops.is_masked] + Check if neuron is masked. + [`Dotprops.mask`][navis.Dotprops.mask] + Mask neuron. + [`navis.NeuronMask`][] + Context manager for masking neurons. + + """ + if not self.is_masked: + raise ValueError("Neuron is not masked.") + + if reset: + # Unmask and reset to original state + super().unmask() + return self + + mask = self._mask + for k, v in self._masked_data.items(): + # Combine with current data + if hasattr(self, k): + v = np.concatenate((v[~mask], getattr(self, k)), axis=0) + setattr(self, k, v) + + del self._mask + del self._masked_data + + self._clear_temp_attr() + + return self + def recalculate_tangents(self, k: int, inplace=False): """Recalculate tangent vectors and alpha with a new `k`. diff --git a/navis/core/masking.py b/navis/core/masking.py new file mode 100644 index 00000000..345de824 --- /dev/null +++ b/navis/core/masking.py @@ -0,0 +1,204 @@ +# This script is part of navis (http://www.github.com/navis-org/navis). +# Copyright (C) 2018 Philipp Schlegel +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. + +import numpy as np +import pandas as pd + +from .neuronlist import NeuronList +from .skeleton import TreeNeuron +from .dotprop import Dotprops +from .voxel import VoxelNeuron +from .mesh import MeshNeuron + +from .. import utils + +__all__ = ["NeuronMask"] + +# mode = "r"\w"? +class NeuronMask: + """Mask neuron(s) by a specific property. + + Parameters + ---------- + x : Neuron/List + Neuron(s) to mask. + mask : str | array | callable | list | dict + The mask to apply: + - str: The name of the property to mask by + - array: boolean mask + - callable: A function that takes a neuron as input + and returns a boolean mask + - list: a list of the above + - dict: a dictionary mapping neuron IDs to one of the + above + copy_data : bool + Whether to copy the neuron data (e.g. node table for skeletons) + when masking. Setting this to False will may some time and + memory but may lead to e.g. pandas setting-on-copy warnings + if the data is modified. Only set to `True` if you know your + code won't modify the data. + reset_neurons : bool + If True, reset the neurons to their original state after the + context manager exits. If False, will try to incorporate any + changes made to the masked neurons. Note that this may not + work for destructive operations. + validate_mask : bool + If True, validate `mask` against the neurons before setting it. + This is recommended but can come with an overhead (in particular + if `mask` is a callable). + + Examples + -------- + >>> import navis + >>> # Grab a few skeletons + >>> nl = navis.example_neurons(3) + >>> # Label axon and dendrites + >>> navis.split_axon_dendrite(nl, label_only=True) + >>> # Mask by axon + >>> with navis.NeuronMask(nl, lambda x: x.nodes.compartment == 'axon'): + ... print("Axon cable length:", nl.cable_length * nl[0].units) + Axon cable length: [363469.75 411147.1875 390231.8125] nanometer + >>> # Mask by dendrite + >>> with navis.NeuronMask(nl, lambda x: x.nodes.compartment == 'dendrite'): + ... print("Dendrite cable length:", nl.cable_length * nl[0].units) + Dendrite cable length: [1410770.0 1612187.25 1510453.875] nanometer + + See Also + -------- + [`navis.BaseNeuron.is_masked`][] + Check if a neuron is masked. Property exists for all neuron types. + [`navis.BaseNeuron.mask`][] + Mask a neuron. Implementation details depend on the neuron type. + [`navis.BaseNeuron.unmask`][] + Unmask a neuron. Implementation details depend on the neuron type. + + """ + + def __init__(self, x, mask, reset_neurons=True, copy_data=True, validate_mask=True): + self.neurons = x + + if validate_mask: + self.mask = mask + else: + self._mask = mask + + self.reset = reset_neurons + self.copy = copy_data + + @property + def neurons(self): + return self._neurons + + @neurons.setter + def neurons(self, value): + self._neurons = NeuronList(value) + + if any(n.is_masked for n in self._neurons): + raise MaskingError("At least some neuron(s) are already masked") + + @property + def mask(self): + return self._mask + + @mask.setter + def mask(self, mask): + # Validate the mask + if isinstance(mask, str): + for n in self.neurons: + if isinstance(n, TreeNeuron): + if mask not in n.nodes.columns: + raise MaskingError(f"Neuron does not have '{mask}' column") + elif not hasattr(n, mask): + raise MaskingError(f"Neuron does not have '{mask}' attribute") + elif isinstance(mask, (list, np.ndarray, pd.Series)): + if len(self.neurons) == 1 and len(mask) != 1: + # If we only have one neuron, we can accept a single mask + # but we still want to wrap it in a list for consistency + mask = [np.asarray(mask)] + + if len(mask) != len(self.neurons): + raise MaskingError("Number of masks does not match number of neurons") + + # Validate each mask + for m, n in zip(mask, self.neurons): + validate_mask_length(m, n) + elif isinstance(mask, dict): + for n in self.neurons: + if n.id not in mask: + raise MaskingError(f"Neuron {n.id} not in mask dictionary") + validate_mask_length(mask[n.id], n) + elif callable(mask): + # If this is a function, try calling it on the first neuron + test = mask(self.neurons[0]) + if not isinstance(test, (pd.Series, np.ndarray)) or test.dtype != bool: + raise MaskingError("Callable mask must return a boolean array") + validate_mask_length(test, self.neurons[0]) + else: + raise MaskingError(f"Unexpected mask type: {type(mask)}") + + self._mask = mask + + def __enter__(self): + for i, n in enumerate(self.neurons): + if callable(self.mask): + mask = self.mask(n) + elif isinstance(self.mask, dict): + mask = self.mask[n.id] + elif isinstance(self.mask, str): + mask = self.mask + else: + mask = self.mask[i] + + n.mask(mask, copy=self.copy) + + return self + + def __exit__(self, *args): + for i, n in enumerate(self.neurons): + n.unmask(reset=self.reset) + + +def validate_mask_length(mask, neuron): + """Validate mask length for a given neuron.""" + if callable(mask): + mask = mask(neuron) + elif isinstance(mask, str): + if isinstance(neuron, TreeNeuron): + mask = neuron.nodes[mask] + else: + mask = getattr(neuron, mask) + + if isinstance(mask, list): + mask = np.asarray(mask) + + if not isinstance(mask, (np.ndarray, pd.Series)) or mask.dtype != bool: + raise MaskingError("Mask must be a boolean array") + + if isinstance(neuron, TreeNeuron): + if len(mask) != len(neuron.nodes): + raise MaskingError("Mask length does not match number of nodes") + elif isinstance(neuron, VoxelNeuron): + if neuron._base_data_type == "grid" and mask.shape != neuron.shape: + raise MaskingError("Mask shape does not match voxel shape") + elif len(neuron.voxels) != len(mask): + raise MaskingError("Mask length does not match number of voxels") + elif isinstance(neuron, Dotprops): + if len(mask) != len(neuron.points): + raise MaskingError("Mask length does not match number of points") + elif isinstance(neuron, MeshNeuron): + if len(mask) != len(neuron.vertices): + raise MaskingError("Mask length does not match number of vertices") + + +class MaskingError(Exception): + pass \ No newline at end of file diff --git a/navis/core/mesh.py b/navis/core/mesh.py index 233076e0..d6f78db0 100644 --- a/navis/core/mesh.py +++ b/navis/core/mesh.py @@ -26,7 +26,7 @@ from typing import Union, Optional -from .. import utils, config, meshes, conversion, graph +from .. import utils, config, meshes, conversion, graph, morpho from .base import BaseNeuron from .neuronlist import NeuronList from .skeleton import TreeNeuron @@ -39,7 +39,7 @@ xxhash = None -__all__ = ['MeshNeuron'] +__all__ = ["MeshNeuron"] # Set up logging logger = config.get_logger(__name__) @@ -89,24 +89,28 @@ class MeshNeuron(BaseNeuron): soma: Optional[Union[list, np.ndarray]] #: Attributes used for neuron summary - SUMMARY_PROPS = ['type', 'name', 'units', 'n_vertices', 'n_faces'] + SUMMARY_PROPS = ["type", "name", "units", "n_vertices", "n_faces"] #: Attributes to be used when comparing two neurons. - EQ_ATTRIBUTES = ['name', 'n_vertices', 'n_faces'] + EQ_ATTRIBUTES = ["name", "n_vertices", "n_faces"] #: Temporary attributes that need clearing when neuron data changes - TEMP_ATTR = ['_memory_usage', '_trimesh', '_skeleton', '_igraph', '_graph_nx'] + TEMP_ATTR = ["_memory_usage", "_trimesh", "_skeleton", "_igraph", "_graph_nx"] #: Core data table(s) used to calculate hash - CORE_DATA = ['vertices', 'faces'] - - def __init__(self, - x, - units: Union[pint.Unit, str] = None, - process: bool = True, - validate: bool = False, - **metadata - ): + CORE_DATA = ["vertices", "faces"] + + #: Property used to calculate length of neuron + _LENGTH_DATA = "vertices" + + def __init__( + self, + x, + units: Union[pint.Unit, str] = None, + process: bool = True, + validate: bool = False, + **metadata, + ): """Initialize Mesh Neuron.""" super().__init__() @@ -117,12 +121,12 @@ def __init__(self, if isinstance(x, MeshNeuron): self.__dict__.update(x.copy().__dict__) self.vertices, self.faces = x.vertices, x.faces - elif hasattr(x, 'faces') and hasattr(x, 'vertices'): + elif hasattr(x, "faces") and hasattr(x, "vertices"): self.vertices, self.faces = x.vertices, x.faces elif isinstance(x, dict): - if 'faces' not in x or 'vertices' not in x: + if "faces" not in x or "vertices" not in x: raise ValueError('Dictionary must contain "vertices" and "faces"') - self.vertices, self.faces = x['vertices'], x['faces'] + self.vertices, self.faces = x["vertices"], x["faces"] elif isinstance(x, str) and os.path.isfile(x): m = tm.load(x) self.vertices, self.faces = m.vertices, m.faces @@ -134,10 +138,12 @@ def __init__(self, self._skeleton = TreeNeuron(x) elif isinstance(x, tuple): if len(x) != 2 or any([not isinstance(v, np.ndarray) for v in x]): - raise TypeError('Expect tuple to be two arrays: (vertices, faces)') + raise TypeError("Expect tuple to be two arrays: (vertices, faces)") self.vertices, self.faces = x[0], x[1] else: - raise utils.ConstructionError(f'Unable to construct MeshNeuron from "{type(x)}"') + raise utils.ConstructionError( + f'Unable to construct MeshNeuron from "{type(x)}"' + ) for k, v in metadata.items(): try: @@ -147,9 +153,9 @@ def __init__(self, if process and self.vertices.shape[0]: # For some reason we can't do self._trimesh at this stage - _trimesh = tm.Trimesh(self.vertices, self.faces, - process=process, - validate=validate) + _trimesh = tm.Trimesh( + self.vertices, self.faces, process=process, validate=validate + ) self.vertices = _trimesh.vertices self.faces = _trimesh.faces @@ -174,8 +180,8 @@ def __getstate__(self): state = {k: v for k, v in self.__dict__.items() if not callable(v)} # We don't need the trimesh object - if '_trimesh' in state: - _ = state.pop('_trimesh') + if "_trimesh" in state: + _ = state.pop("_trimesh") return state @@ -188,9 +194,9 @@ def __truediv__(self, other, copy=True): if isinstance(other, numbers.Number) or utils.is_iterable(other): # If a number, consider this an offset for coordinates n = self.copy() if copy else self - _ = np.divide(n.vertices, other, out=n.vertices, casting='unsafe') + _ = np.divide(n.vertices, other, out=n.vertices, casting="unsafe") if n.has_connectors: - n.connectors.loc[:, ['x', 'y', 'z']] /= other + n.connectors.loc[:, ["x", "y", "z"]] /= other # Convert units # Note: .to_compact() throws a RuntimeWarning and returns unchanged @@ -209,9 +215,9 @@ def __mul__(self, other, copy=True): if isinstance(other, numbers.Number) or utils.is_iterable(other): # If a number, consider this an offset for coordinates n = self.copy() if copy else self - _ = np.multiply(n.vertices, other, out=n.vertices, casting='unsafe') + _ = np.multiply(n.vertices, other, out=n.vertices, casting="unsafe") if n.has_connectors: - n.connectors.loc[:, ['x', 'y', 'z']] *= other + n.connectors.loc[:, ["x", "y", "z"]] *= other # Convert units # Note: .to_compact() throws a RuntimeWarning and returns unchanged @@ -229,9 +235,9 @@ def __add__(self, other, copy=True): """Implement addition for coordinates (vertices, connectors).""" if isinstance(other, numbers.Number) or utils.is_iterable(other): n = self.copy() if copy else self - _ = np.add(n.vertices, other, out=n.vertices, casting='unsafe') + _ = np.add(n.vertices, other, out=n.vertices, casting="unsafe") if n.has_connectors: - n.connectors.loc[:, ['x', 'y', 'z']] += other + n.connectors.loc[:, ["x", "y", "z"]] += other self._clear_temp_attr() @@ -245,9 +251,9 @@ def __sub__(self, other, copy=True): """Implement subtraction for coordinates (vertices, connectors).""" if isinstance(other, numbers.Number) or utils.is_iterable(other): n = self.copy() if copy else self - _ = np.subtract(n.vertices, other, out=n.vertices, casting='unsafe') + _ = np.subtract(n.vertices, other, out=n.vertices, casting="unsafe") if n.has_connectors: - n.connectors.loc[:, ['x', 'y', 'z']] -= other + n.connectors.loc[:, ["x", "y", "z"]] -= other self._clear_temp_attr() @@ -261,8 +267,8 @@ def bbox(self) -> np.ndarray: mx = np.max(self.vertices, axis=0) if self.has_connectors: - cn_mn = np.min(self.connectors[['x', 'y', 'z']].values, axis=0) - cn_mx = np.max(self.connectors[['x', 'y', 'z']].values, axis=0) + cn_mn = np.min(self.connectors[["x", "y", "z"]].values, axis=0) + cn_mx = np.max(self.connectors[["x", "y", "z"]].values, axis=0) mn = np.min(np.vstack((mn, cn_mn)), axis=0) mx = np.max(np.vstack((mx, cn_mx)), axis=0) @@ -279,7 +285,7 @@ def vertices(self, verts): if not isinstance(verts, np.ndarray): raise TypeError(f'Vertices must be numpy array, got "{type(verts)}"') if verts.ndim != 2: - raise ValueError('Vertices must be 2-dimensional array') + raise ValueError("Vertices must be 2-dimensional array") self._vertices = verts self._clear_temp_attr() @@ -293,16 +299,16 @@ def faces(self, faces): if not isinstance(faces, np.ndarray): raise TypeError(f'Faces must be numpy array, got "{type(faces)}"') if faces.ndim != 2: - raise ValueError('Faces must be 2-dimensional array') + raise ValueError("Faces must be 2-dimensional array") self._faces = faces self._clear_temp_attr() @property @temp_property - def igraph(self) -> 'igraph.Graph': + def igraph(self) -> "igraph.Graph": """iGraph representation of the vertex connectivity.""" # If igraph does not exist, create and return - if not hasattr(self, '_igraph'): + if not hasattr(self, "_igraph"): # This also sets the attribute self._igraph = graph.neuron2igraph(self, raise_not_installed=False) return self._igraph @@ -312,7 +318,7 @@ def igraph(self) -> 'igraph.Graph': def graph(self) -> nx.DiGraph: """Networkx Graph representation of the vertex connectivity.""" # If graph does not exist, create and return - if not hasattr(self, '_graph_nx'): + if not hasattr(self, "_graph_nx"): # This also sets the attribute self._graph_nx = graph.neuron2nx(self) return self._graph_nx @@ -335,13 +341,13 @@ def volume(self) -> float: @property @temp_property - def skeleton(self) -> 'TreeNeuron': + def skeleton(self) -> "TreeNeuron": """Skeleton representation of this neuron. Uses [`navis.conversion.mesh2skeleton`][]. """ - if not hasattr(self, '_skeleton'): + if not hasattr(self, "_skeleton"): self._skeleton = self.skeletonize() return self._skeleton @@ -357,7 +363,9 @@ def skeleton(self, s): @property def soma(self): """Not implemented for MeshNeurons - use `.soma_pos`.""" - raise AttributeError("MeshNeurons have a soma position (`.soma_pos`), not a soma.") + raise AttributeError( + "MeshNeurons have a soma position (`.soma_pos`), not a soma." + ) @property def soma_pos(self): @@ -365,7 +373,7 @@ def soma_pos(self): Returns `None` if no soma. """ - return getattr(self, '_soma_pos', None) + return getattr(self, "_soma_pos", None) @soma_pos.setter def soma_pos(self, value): @@ -377,38 +385,295 @@ def soma_pos(self, value): try: value = np.asarray(value).astype(np.float64).reshape(3) except BaseException: - raise ValueError(f'Unable to convert soma position "{value}" ' - f'to numeric (3, ) numpy array.') + raise ValueError( + f'Unable to convert soma position "{value}" ' + f"to numeric (3, ) numpy array." + ) self._soma_pos = value @property def type(self) -> str: """Neuron type.""" - return 'navis.MeshNeuron' + return "navis.MeshNeuron" @property @temp_property def trimesh(self): """Trimesh representation of the neuron.""" - if not getattr(self, '_trimesh', None): - self._trimesh = tm.Trimesh(vertices=self._vertices, - faces=self._faces, - process=False) + if not getattr(self, "_trimesh", None): + if hasattr(self, "extra_edges"): + # Only use TrimeshPlus if we actually need it + # to avoid unnecessarily breaking stuff elsewhere + self._trimesh = tm.Trimesh( + vertices=self._vertices, faces=self._faces, process=False + ) + self._trimesh.extra_edges = self.extra_edges + else: + self._trimesh = tm.Trimesh( + vertices=self._vertices, faces=self._faces, process=False + ) return self._trimesh - def copy(self) -> 'MeshNeuron': + def copy(self) -> "MeshNeuron": """Return a copy of the neuron.""" - no_copy = ['_lock'] + no_copy = ["_lock"] # Generate new neuron x = self.__class__(None) # Override with this neuron's data - x.__dict__.update({k: copy.copy(v) for k, v in self.__dict__.items() if k not in no_copy}) + x.__dict__.update( + {k: copy.copy(v) for k, v in self.__dict__.items() if k not in no_copy} + ) return x - def snap(self, locs, to='vertices'): + def mask(self, mask, copy=True): + """Mask neuron with given mask. + + This is always done in-place! + + Parameters + ---------- + mask : np.ndarray + Mask to apply. Can be: + - 1D array with boolean values + - string with property name + - callable that accepts a neuron and returns a valid mask + The mask can be either for vertices or faces but will ultimately be + used to mask out faces. Vertices not participating in any face + will be removed regardless of the mask. + copy : bool + Whether to copy mask a copy of the data. Only applies for connectors. + + Returns + ------- + self + + See Also + -------- + [`MeshNeuron.is_masked`][navis.MeshNeuron.is_masked] + Returns True if neuron is masked. + [`MeshNeuron.unmask`][navis.MeshNeuron.unmask] + Remove mask from neuron. + [`navis.NeuronMask`][] + Context manager for masking neurons. + + """ + if self.is_masked: + raise ValueError( + "Neuron already masked! Layering multiple masks is currently not supported. " + "Please either apply the existing mask or unmask first." + ) + + if callable(mask): + mask = mask(self) + elif isinstance(mask, str): + mask = getattr(self, mask) + + mask = np.asarray(mask) + + # Some checks + if mask.dtype != bool: + raise ValueError("Mask must be boolean array.") + elif len(mask) not in (self.vertices.shape[0], self.faces.shape[0]): + raise ValueError("Mask length does not match number of vertices or faces.") + + # Transate vertex mask to face mask + if mask.shape[0] == self.vertices.shape[0]: + vert_mask = mask + face_mask = np.all(mask[self.faces], axis=1) + else: + face_mask = mask + vert_mask = np.zeros(self.vertices.shape[0], dtype=bool) + vert_mask[np.unique(self.faces[face_mask])] = True + + # Apply mask + verts_new, faces_new, vert_map, face_map = morpho.subset.submesh( + self, vertex_index=np.where(vert_mask)[0], return_map=True + ) + + # The above will have likely dropped some vertices - we need to update the vertex mask + vert_mask = np.zeros(self.vertices.shape[0], dtype=bool) + vert_mask[np.where(vert_map != -1)[0]] = True + + # Track mask, vertices and faces before masking + self._mask = face_mask # mask is always the face mask + self._masked_data = {} + self._masked_data["_vertices"] = self._vertices + self._masked_data["_faces"] = self._faces + + # Update vertices and faces + self._vertices = verts_new + self._faces = faces_new + + # See if we can mask the mesh's skeleton as well + if hasattr(self, "_skeleton"): + # If the skeleton has a vertex map, we can use it to mask the skeleton + if hasattr(self._skeleton, "vertex_map"): + # Generate a mask for the skeleton + # (keep in mind vertex_map are node IDs, not indices) + sk_mask = self._skeleton.nodes.node_id.isin( + self._skeleton.vertex_map[vert_mask] + ) + + # Apply mask + self._skeleton.mask(sk_mask) + + # Last but not least: we need to update the vertex map + # Track the old map. N.B. we're not adding this to + # skeleton._masked_data since the remapping is done by + # the MeshNeuron itself! + self._skeleton._vertex_map_unmasked = self._skeleton.vertex_map + + # Subset the vertex map to the surviving mesh vertices + # N.B. that the node IDs don't change when masking skeletons! + self._skeleton.vertex_map = self._skeleton.vertex_map[vert_mask] + # If the skeleton has no vertex map, we have to ditch it and + # let it be regenerated when needed + else: + self._masked_data["_skeleton"] = self._skeleton + self._skeleton = None # Clear the skeleton + + # See if we need to mask any connectors as well + if hasattr(self, "_connectors"): + # Only mask if there is an actual "vertex_ind" or "face_ind" column + cn_mask = None + if "vertex_ind" in self._connectors.columns: + cn_mask = self._connectors.vertex_id.isin(np.where(vert_mask)[0]) + elif "face_ind" in self._connectors.columns: + cn_mask = self._connectors.face_id.isin(np.where(face_mask)[0]) + + if cn_mask is not None: + self._masked_data["_connectors"] = self._connectors + self._connectors = self._connectors.loc[mask] + if copy: + self._connectors = self._connectors.copy() + + # Check if we need to drop the soma position + if hasattr(self, "soma_pos"): + vid = self.snap(self.soma_pos, to="vertices")[0] + if not vert_mask[vid]: + self._masked_data["_soma_pos"] = self.soma_pos + self.soma_pos = None + + # Clear temporary attributes but keep the skeleton since we already fixed that + self._clear_temp_attr(exclude=["_skeleton"]) + + return self + + def unmask(self, reset=True): + """Unmask neuron. + + Returns the neuron to its original state before masking. + + Parameters + ---------- + reset : bool + Whether to reset the neuron to its original state before masking. + If False, edits made to the neuron after masking will be kept. + + Returns + ------- + self + + See Also + -------- + [`MeshNeuron.is_masked`][navis.MeshNeuron.is_masked] + Returns True if neuron is masked. + [`MeshNeuron.mask`][navis.MeshNeuron.mask] + Mask neuron. + [`navis.NeuronMask`][] + Context manager for masking neurons. + + """ + if not self.is_masked: + raise ValueError("Neuron is not masked.") + + # First fix the skeleton (if it exists) + skeleton = getattr(self, "_skeleton", None) + if skeleton is not None: + # If the skeleton is not masked, it was created after the masking + # - in which case we have to throw it away because we can't recover + # the full neuron state. + if not skeleton.is_masked: + skeleton = None + else: + # Unmask the skeleton as well + skeleton.unmask(reset=reset) + + # Manually restore the vertex map + # N.B. that any destructive action (e.g. twig pruning) may have + # removed nodes from the skeleton. If that's the case, we can't + # restore the vertex map and have to ditch it. + if hasattr(skeleton, "_vertex_map_unmasked"): + skeleton.vertex_map = skeleton._vertex_map_unmasked + + # Important note: currently the skeleton gets ditched whenever the MeshNeuron + # is stale. That's mainly because (a) functions modify the mesh but not + # the vertex map and (b) re-generating the skeleton is usually cheap. + # In the long run, we need to make sure the skeleton is always in sync + # and not cleared unless that's explicitly requested. + # I'm thinking something like a MeshNeuron.sync_skeleton() method that + # can either sync the skeleton with the mesh or vice versa. + + if reset: + # Unmask and reset (this will clear temporary attributes including the skeleton) + super().unmask() + if skeleton is not None: + self.skeleton = skeleton + return self + + # Regenerate the vertex mask from the stored face mask + face_mask = self._mask + vert_mask = np.zeros(len(self._masked_data["_vertices"]), dtype=bool) + vert_mask[np.unique(self._masked_data["_faces"][face_mask])] = True + + # Generate a mesh for the masked-out data: + # The mesh prior to masking + pre_mesh = tm.Trimesh(self._masked_data["_vertices"], self._masked_data["_faces"]) + # The vertices and faces that were masked out + pre_vertices, pre_faces, vert_map, face_map = morpho.subset.submesh( + pre_mesh, faces_index=np.where(~face_mask)[0], return_map=True + ) + + # Combine the two + comb = tm.util.concatenate( + [tm.Trimesh(self.vertices, self.faces), tm.Trimesh(pre_vertices, pre_faces)] + ) + + # Drop duplicate faces + comb.update_faces(comb.unique_faces()) + + # Merge vertices that are exactly the same + comb.merge_vertices(digis=6) + + # Update the neuron + self._vertices, self._faces = np.asarray(comb.vertices), np.asarray(comb.faces) + + del self._mask + del self._masked_data + + self._clear_temp_attr() + + # Check if we can re-use the skeleton + if skeleton is not None: + # Check if the vertex map is still valid + # Note to self: we could do some elaborate checks here to map old to + # most likely new vertex / nodes but that's a bit overkill for now. + if hasattr(skeleton, 'vertex_map'): + if skeleton.vertex_map.shape[0] != self._vertices.shape[0]: + skeleton = None + elif skeleton.vertex_map.max() >= self._faces.shape[0]: + skeleton = None + + # If we still have a skeleton at this point, we can re-use it + if skeleton is not None: + self._skeleton = skeleton + + return self + + def snap(self, locs, to="vertices"): """Snap xyz location(s) to closest vertex or synapse. Parameters @@ -436,15 +701,16 @@ def snap(self, locs, to='vertices'): """ locs = np.asarray(locs).astype(self.vertices.dtype) - is_single = (locs.ndim == 1 and len(locs) == 3) - is_multi = (locs.ndim == 2 and locs.shape[1] == 3) + is_single = locs.ndim == 1 and len(locs) == 3 + is_multi = locs.ndim == 2 and locs.shape[1] == 3 if not is_single and not is_multi: - raise ValueError('Expected a single (x, y, z) location or a ' - '(N, 3) array of multiple locations') + raise ValueError( + "Expected a single (x, y, z) location or a " + "(N, 3) array of multiple locations" + ) - if to not in ('vertices', 'vertex', 'connectors', 'connectors'): - raise ValueError('`to` must be "vertices" or "connectors", ' - f'got {to}') + if to not in ("vertices", "vertex", "connectors", "connectors"): + raise ValueError('`to` must be "vertices" or "connectors", ' f"got {to}") # Generate tree tree = scipy.spatial.cKDTree(data=self.vertices) @@ -454,7 +720,9 @@ def snap(self, locs, to='vertices'): return ix, dist - def skeletonize(self, method='wavefront', heal=True, inv_dist=None, **kwargs) -> 'TreeNeuron': + def skeletonize( + self, method="wavefront", heal=True, inv_dist=None, **kwargs + ) -> "TreeNeuron": """Skeletonize mesh. See [`navis.conversion.mesh2skeleton`][] for details. @@ -477,10 +745,13 @@ def skeletonize(self, method='wavefront', heal=True, inv_dist=None, **kwargs) -> Returns ------- skeleton : navis.TreeNeuron + Has a `.vertex_map` attribute that maps each vertex in the + input mesh to a skeleton node ID. """ - return conversion.mesh2skeleton(self, method=method, heal=heal, - inv_dist=inv_dist, **kwargs) + return conversion.mesh2skeleton( + self, method=method, heal=heal, inv_dist=inv_dist, **kwargs + ) def validate(self, inplace=False): """Use trimesh to try and fix some common mesh issues. diff --git a/navis/core/skeleton.py b/navis/core/skeleton.py index 5bcab701..892e7c88 100644 --- a/navis/core/skeleton.py +++ b/navis/core/skeleton.py @@ -25,7 +25,7 @@ from io import BufferedIOBase -from typing import Union, Callable, List, Sequence, Optional, Dict, overload +from typing import Union, Callable, List, Sequence, Optional, Dict from typing_extensions import Literal from .. import graph, morpho, utils, config, core, sampling, intersection @@ -39,7 +39,7 @@ except ModuleNotFoundError: xxhash = None -__all__ = ['TreeNeuron'] +__all__ = ["TreeNeuron"] # Set up logging logger = config.get_logger(__name__) @@ -52,15 +52,17 @@ def requires_nodes(func): """Return `None` if neuron has no nodes.""" + @functools.wraps(func) def wrapper(*args, **kwargs): self = args[0] # Return 0 - if isinstance(self.nodes, str) and self.nodes == 'NA': - return 'NA' + if isinstance(self.nodes, str) and self.nodes == "NA": + return "NA" if not isinstance(self.nodes, pd.DataFrame): return None return func(*args, **kwargs) + return wrapper @@ -95,8 +97,8 @@ class TreeNeuron(BaseNeuron): nodes: pd.DataFrame - graph: 'nx.DiGraph' - igraph: 'igraph.Graph' # type: ignore # doesn't know iGraph + graph: "nx.DiGraph" + igraph: "igraph.Graph" # type: ignore # doesn't know iGraph n_branches: int n_leafs: int @@ -117,37 +119,63 @@ class TreeNeuron(BaseNeuron): soma_detection_label: Union[float, int, str] = 1 #: Soma radius (e.g. for plotting). If string, must be column in nodes #: table. Default = 'radius'. - soma_radius: Union[float, int, str] = 'radius' + soma_radius: Union[float, int, str] = "radius" # Set default function for soma finding. Default = [`navis.morpho.find_soma`][] - _soma: Union[Callable[['TreeNeuron'], Sequence[int]], int] = morpho.find_soma + _soma: Union[Callable[["TreeNeuron"], Sequence[int]], int] = morpho.find_soma tags: Optional[Dict[str, List[int]]] = None #: Attributes to be used when comparing two neurons. - EQ_ATTRIBUTES = ['n_nodes', 'n_connectors', 'soma', 'root', - 'n_branches', 'n_leafs', 'cable_length', 'name'] + EQ_ATTRIBUTES = [ + "n_nodes", + "n_connectors", + "soma", + "root", + "n_branches", + "n_leafs", + "cable_length", + "name", + ] #: Temporary attributes that need to be regenerated when data changes. - TEMP_ATTR = ['_igraph', '_graph_nx', '_segments', '_small_segments', - '_geodesic_matrix', 'centrality_method', '_simple', - '_cable_length', '_memory_usage', '_adjacency_matrix'] + TEMP_ATTR = [ + "_igraph", + "_graph_nx", + "_segments", + "_small_segments", + "_geodesic_matrix", + "centrality_method", + "_simple", + "_cable_length", + "_memory_usage", + "_adjacency_matrix", + ] #: Attributes used for neuron summary - SUMMARY_PROPS = ['type', 'name', 'n_nodes', 'n_connectors', 'n_branches', - 'n_leafs', 'cable_length', 'soma', 'units'] + SUMMARY_PROPS = [ + "type", + "name", + "n_nodes", + "n_connectors", + "n_branches", + "n_leafs", + "cable_length", + "soma", + "units", + ] #: Core data table(s) used to calculate hash - CORE_DATA = ['nodes:node_id,parent_id,x,y,z'] - - def __init__(self, - x: Union[pd.DataFrame, - BufferedIOBase, - str, - 'TreeNeuron', - nx.DiGraph], - units: Union[pint.Unit, str] = None, - **metadata - ): + CORE_DATA = ["nodes:node_id,parent_id,x,y,z"] + + #: Property used to calculate length of neuron + _LENGTH_DATA = "nodes" + + def __init__( + self, + x: Union[pd.DataFrame, BufferedIOBase, str, "TreeNeuron", nx.DiGraph], + units: Union[pint.Unit, str] = None, + **metadata, + ): """Initialize Skeleton Neuron.""" super().__init__() @@ -157,10 +185,12 @@ def __init__(self, if isinstance(x, pd.DataFrame): self.nodes = x elif isinstance(x, pd.Series): - if not hasattr(x, 'nodes'): - raise ValueError('pandas.Series must have `nodes` entry.') + if not hasattr(x, "nodes"): + raise ValueError("pandas.Series must have `nodes` entry.") elif not isinstance(x.nodes, pd.DataFrame): - raise TypeError(f'Nodes must be pandas DataFrame, got "{type(x.nodes)}"') + raise TypeError( + f'Nodes must be pandas DataFrame, got "{type(x.nodes)}"' + ) self.nodes = x.nodes metadata.update(x.to_dict()) elif isinstance(x, nx.Graph): @@ -182,13 +212,15 @@ def __init__(self, elif isinstance(x, tuple): # Tuple of vertices and edges if len(x) != 2: - raise ValueError('Tuple must have 2 elements: vertices and edges.') + raise ValueError("Tuple must have 2 elements: vertices and edges.") self.nodes = graph.edges2neuron(edges=x[1], vertices=x[0]).nodes elif isinstance(x, type(None)): # This is a essentially an empty neuron pass else: - raise utils.ConstructionError(f'Unable to construct TreeNeuron from "{type(x)}"') + raise utils.ConstructionError( + f'Unable to construct TreeNeuron from "{type(x)}"' + ) for k, v in metadata.items(): try: @@ -218,21 +250,23 @@ def __truediv__(self, other, copy=True): if len(set(other)) == 1: other == other[0] elif len(other) != 4: - raise ValueError('Division by list/array requires 4 ' - 'divisors for x/y/z and radius - ' - f'got {len(other)}') + raise ValueError( + "Division by list/array requires 4 " + "divisors for x/y/z and radius - " + f"got {len(other)}" + ) # If a number, consider this an offset for coordinates n = self.copy() if copy else self - n.nodes[['x', 'y', 'z', 'radius']] /= other + n.nodes[["x", "y", "z", "radius"]] /= other # At this point we can ditch any 4th unit if utils.is_iterable(other): other = other[:3] if n.has_connectors: - n.connectors[['x', 'y', 'z']] /= other + n.connectors[["x", "y", "z"]] /= other - if hasattr(n, 'soma_radius'): + if hasattr(n, "soma_radius"): if isinstance(n.soma_radius, numbers.Number): n.soma_radius /= other @@ -243,7 +277,7 @@ def __truediv__(self, other, copy=True): warnings.simplefilter("ignore") n.units = (n.units * other).to_compact() - n._clear_temp_attr(exclude=['classify_nodes']) + n._clear_temp_attr(exclude=["classify_nodes"]) return n return NotImplemented @@ -255,21 +289,23 @@ def __mul__(self, other, copy=True): if len(set(other)) == 1: other == other[0] elif len(other) != 4: - raise ValueError('Multiplication by list/array requires 4' - 'multipliers for x/y/z and radius - ' - f'got {len(other)}') + raise ValueError( + "Multiplication by list/array requires 4" + "multipliers for x/y/z and radius - " + f"got {len(other)}" + ) # If a number, consider this an offset for coordinates n = self.copy() if copy else self - n.nodes[['x', 'y', 'z', 'radius']] *= other + n.nodes[["x", "y", "z", "radius"]] *= other # At this point we can ditch any 4th unit if utils.is_iterable(other): other = other[:3] if n.has_connectors: - n.connectors[['x', 'y', 'z']] *= other + n.connectors[["x", "y", "z"]] *= other - if hasattr(n, 'soma_radius'): + if hasattr(n, "soma_radius"): if isinstance(n.soma_radius, numbers.Number): n.soma_radius *= other @@ -280,7 +316,7 @@ def __mul__(self, other, copy=True): warnings.simplefilter("ignore") n.units = (n.units / other).to_compact() - n._clear_temp_attr(exclude=['classify_nodes']) + n._clear_temp_attr(exclude=["classify_nodes"]) return n return NotImplemented @@ -292,19 +328,21 @@ def __add__(self, other, copy=True): if len(set(other)) == 1: other == other[0] elif len(other) != 3: - raise ValueError('Addition by list/array requires 3' - 'multipliers for x/y/z coordinates ' - f'got {len(other)}') + raise ValueError( + "Addition by list/array requires 3" + "multipliers for x/y/z coordinates " + f"got {len(other)}" + ) # If a number, consider this an offset for coordinates n = self.copy() if copy else self - n.nodes[['x', 'y', 'z']] += other + n.nodes[["x", "y", "z"]] += other # Do the connectors if n.has_connectors: - n.connectors[['x', 'y', 'z']] += other + n.connectors[["x", "y", "z"]] += other - n._clear_temp_attr(exclude=['classify_nodes']) + n._clear_temp_attr(exclude=["classify_nodes"]) return n # If another neuron, return a list of neurons elif isinstance(other, BaseNeuron): @@ -319,19 +357,21 @@ def __sub__(self, other, copy=True): if len(set(other)) == 1: other == other[0] elif len(other) != 3: - raise ValueError('Addition by list/array requires 3' - 'multipliers for x/y/z coordinates ' - f'got {len(other)}') + raise ValueError( + "Addition by list/array requires 3" + "multipliers for x/y/z coordinates " + f"got {len(other)}" + ) # If a number, consider this an offset for coordinates n = self.copy() if copy else self - n.nodes[['x', 'y', 'z']] -= other + n.nodes[["x", "y", "z"]] -= other # Do the connectors if n.has_connectors: - n.connectors[['x', 'y', 'z']] -= other + n.connectors[["x", "y", "z"]] -= other - n._clear_temp_attr(exclude=['classify_nodes']) + n._clear_temp_attr(exclude=["classify_nodes"]) return n return NotImplemented @@ -341,10 +381,10 @@ def __getstate__(self): # Pickling the graphs actually takes longer than regenerating them # from scratch - if '_graph_nx' in state: - _ = state.pop('_graph_nx') - if '_igraph' in state: - _ = state.pop('_igraph') + if "_graph_nx" in state: + _ = state.pop("_graph_nx") + if "_igraph" in state: + _ = state.pop("_igraph") return state @@ -352,7 +392,7 @@ def __getstate__(self): @temp_property def adjacency_matrix(self): """Adjacency matrix of the skeleton.""" - if not hasattr(self, '_adjacency_matrix'): + if not hasattr(self, "_adjacency_matrix"): self._adjacency_matrix = graph.skeleton_adjacency_matrix(self) return self._adjacency_matrix @@ -360,7 +400,7 @@ def adjacency_matrix(self): @requires_nodes def vertices(self) -> np.ndarray: """Vertices of the skeleton.""" - return self.nodes[['x', 'y', 'z']].values + return self.nodes[["x", "y", "z"]].values @property @requires_nodes @@ -374,7 +414,7 @@ def edges(self) -> np.ndarray: """ not_root = self.nodes[self.nodes.parent_id >= 0] - return not_root[['node_id', 'parent_id']].values + return not_root[["node_id", "parent_id"]].values @property @requires_nodes @@ -387,7 +427,7 @@ def edge_coords(self) -> np.ndarray: Same but with node IDs instead of x/y/z coordinates. """ - locs = self.nodes.set_index('node_id')[['x', 'y', 'z']] + locs = self.nodes.set_index("node_id")[["x", "y", "z"]] edges = self.edges edges_co = np.zeros((edges.shape[0], 2, 3)) edges_co[:, 0, :] = locs.loc[edges[:, 0]].values @@ -396,10 +436,10 @@ def edge_coords(self) -> np.ndarray: @property @temp_property - def igraph(self) -> 'igraph.Graph': + def igraph(self) -> "igraph.Graph": """iGraph representation of this neuron.""" # If igraph does not exist, create and return - if not hasattr(self, '_igraph'): + if not hasattr(self, "_igraph"): # This also sets the attribute return self.get_igraph() return self._igraph @@ -409,7 +449,7 @@ def igraph(self) -> 'igraph.Graph': def graph(self) -> nx.DiGraph: """Networkx Graph representation of this neuron.""" # If graph does not exist, create and return - if not hasattr(self, '_graph_nx'): + if not hasattr(self, "_graph_nx"): # This also sets the attribute return self.get_graph_nx() return self._graph_nx @@ -419,7 +459,7 @@ def graph(self) -> nx.DiGraph: def geodesic_matrix(self): """Matrix with geodesic (along-the-arbor) distance between nodes.""" # If matrix has not yet been generated or needs update - if not hasattr(self, '_geodesic_matrix'): + if not hasattr(self, "_geodesic_matrix"): # (Re-)generate matrix self._geodesic_matrix = graph.geodesic_matrix(self) @@ -429,7 +469,7 @@ def geodesic_matrix(self): @requires_nodes def leafs(self) -> pd.DataFrame: """Leaf node table.""" - return self.nodes[self.nodes['type'] == 'end'] + return self.nodes[self.nodes["type"] == "end"] @property @requires_nodes @@ -441,7 +481,7 @@ def ends(self): @requires_nodes def branch_points(self): """Branch node table.""" - return self.nodes[self.nodes['type'] == 'branch'] + return self.nodes[self.nodes["type"] == "branch"] @property def nodes(self) -> pd.DataFrame: @@ -461,20 +501,24 @@ def nodes(self, v): def _set_nodes(self, v): # Redefine this function in subclass to change validation - self._nodes = utils.validate_table(v, - required=[('node_id', 'rowId', 'node', 'treenode_id', 'PointNo'), - ('parent_id', 'link', 'parent', 'Parent'), - ('x', 'X'), - ('y', 'Y'), - ('z', 'Z')], - rename=True, - optional={('radius', 'W'): 0}, - restrict=False) + self._nodes = utils.validate_table( + v, + required=[ + ("node_id", "rowId", "node", "treenode_id", "PointNo"), + ("parent_id", "link", "parent", "Parent"), + ("x", "X"), + ("y", "Y"), + ("z", "Z"), + ], + rename=True, + optional={("radius", "W"): 0}, + restrict=False, + ) # Make sure we don't end up with object dtype anywhere as this can # cause problems - for c in ('node_id', 'parent_id'): - if self._nodes[c].dtype == 'O': + for c in ("node_id", "parent_id"): + if self._nodes[c].dtype == "O": self._nodes[c] = self._nodes[c].astype(int) graph.classify_nodes(self) @@ -504,8 +548,7 @@ def is_tree(self) -> bool: @property def subtrees(self) -> List[List[int]]: """List of subtrees. Sorted by size as sets of node IDs.""" - return sorted(graph._connected_components(self), - key=lambda x: -len(x)) + return sorted(graph._connected_components(self), key=lambda x: -len(x)) @property def connectors(self) -> pd.DataFrame: @@ -514,7 +557,7 @@ def connectors(self) -> pd.DataFrame: def _get_connectors(self) -> pd.DataFrame: # Redefine this function in subclass to change how nodes are retrieved - return getattr(self, '_connectors', None) + return getattr(self, "_connectors", None) @connectors.setter def connectors(self, v): @@ -528,15 +571,19 @@ def _set_connectors(self, v): if isinstance(v, type(None)): self._connectors = None else: - self._connectors = utils.validate_table(v, - required=[('connector_id', 'id'), - ('node_id', 'rowId', 'node', 'treenode_id'), - ('x', 'X'), - ('y', 'Y'), - ('z', 'Z'), - ('type', 'relation', 'label', 'prepost')], - rename=True, - restrict=False) + self._connectors = utils.validate_table( + v, + required=[ + ("connector_id", "id"), + ("node_id", "rowId", "node", "treenode_id"), + ("x", "X"), + ("y", "Y"), + ("z", "Z"), + ("type", "relation", "label", "prepost"), + ], + rename=True, + restrict=False, + ) @property @requires_nodes @@ -550,8 +597,9 @@ def cycles(self) -> Optional[List[int]]: """ try: - c = nx.find_cycle(self.graph, - source=self.nodes[self.nodes.type == 'end'].node_id.values) + c = nx.find_cycle( + self.graph, source=self.nodes[self.nodes.type == "end"].node_id.values + ) return c except nx.exception.NetworkXNoCycle: return None @@ -559,16 +607,16 @@ def cycles(self) -> Optional[List[int]]: raise @property - def simple(self) -> 'TreeNeuron': + def simple(self) -> "TreeNeuron": """Simplified representation consisting only of root, branch points and leafs.""" - if not hasattr(self, '_simple'): + if not hasattr(self, "_simple"): self._simple = self.copy() # Make sure we don't have a soma, otherwise that node will be preserved self._simple.soma = None # Downsample - self._simple.downsample(float('inf'), inplace=True) + self._simple.downsample(float("inf"), inplace=True) return self._simple @property @@ -592,11 +640,11 @@ def soma(self) -> Optional[Union[str, int]]: if all(pd.isnull(soma)): soma = None elif not any(self.nodes.node_id.isin(soma)): - logger.warning(f'Soma(s) {soma} not found in node table.') + logger.warning(f"Soma(s) {soma} not found in node table.") soma = None else: if soma not in self.nodes.node_id.values: - logger.warning(f'Soma {soma} not found in node table.') + logger.warning(f"Soma {soma} not found in node table.") soma = None return soma @@ -604,7 +652,7 @@ def soma(self) -> Optional[Union[str, int]]: @soma.setter def soma(self, value: Union[Callable, int, None]) -> None: """Set soma.""" - if hasattr(value, '__call__'): + if hasattr(value, "__call__"): self._soma = types.MethodType(value, self) elif isinstance(value, type(None)): self._soma = None @@ -614,7 +662,7 @@ def soma(self, value: Union[Callable, int, None]) -> None: if value in self.nodes.node_id.values: self._soma = value else: - raise ValueError('Soma must be function, None or a valid node ID.') + raise ValueError("Soma must be function, None or a valid node ID.") @property def soma_pos(self) -> Optional[Sequence]: @@ -633,7 +681,7 @@ def soma_pos(self) -> Optional[Sequence]: else: soma = utils.make_iterable(soma) - return self.nodes.loc[self.nodes.node_id.isin(soma), ['x', 'y', 'z']].values + return self.nodes.loc[self.nodes.node_id.isin(soma), ["x", "y", "z"]].values @soma_pos.setter def soma_pos(self, value: Sequence) -> None: @@ -645,16 +693,20 @@ def soma_pos(self, value: Sequence) -> None: try: value = np.asarray(value).astype(np.float64).reshape(3) except BaseException: - raise ValueError(f'Unable to convert soma position "{value}" ' - f'to numeric (3, ) numpy array.') + raise ValueError( + f'Unable to convert soma position "{value}" ' + f"to numeric (3, ) numpy array." + ) # Generate tree - id, dist = self.snap(value, to='nodes') + id, dist = self.snap(value, to="nodes") # A sanity check if dist > (self.sampling_resolution * 10): - logger.warning(f'New soma position for {self.id} is suspiciously ' - f'far away from the closest node: {dist}') + logger.warning( + f"New soma position for {self.id} is suspiciously " + f"far away from the closest node: {dist}" + ) self.soma = id @@ -673,26 +725,26 @@ def root(self, value: Union[int, List[int]]) -> None: @property def type(self) -> str: """Neuron type.""" - return 'navis.TreeNeuron' + return "navis.TreeNeuron" @property @requires_nodes def n_branches(self) -> Optional[int]: """Number of branch points.""" - return self.nodes[self.nodes.type == 'branch'].shape[0] + return self.nodes[self.nodes.type == "branch"].shape[0] @property @requires_nodes def n_leafs(self) -> Optional[int]: """Number of leaf nodes.""" - return self.nodes[self.nodes.type == 'end'].shape[0] + return self.nodes[self.nodes.type == "end"].shape[0] @property @temp_property @add_units(compact=True) def cable_length(self) -> Union[int, float]: """Cable length.""" - if not hasattr(self, '_cable_length'): + if not hasattr(self, "_cable_length"): self._cable_length = morpho.cable_length(self) return self._cable_length @@ -700,17 +752,21 @@ def cable_length(self) -> Union[int, float]: @add_units(compact=True, power=2) def surface_area(self) -> float: """Radius-based lateral surface area.""" - if 'radius' not in self.nodes.columns: - raise ValueError(f'Neuron {self.id} does not have radius information') + if "radius" not in self.nodes.columns: + raise ValueError(f"Neuron {self.id} does not have radius information") if any(self.nodes.radius < 0): - logger.warning(f'Neuron {self.id} has negative radii - area will not be correct.') + logger.warning( + f"Neuron {self.id} has negative radii - area will not be correct." + ) if any(self.nodes.radius.isnull()): - logger.warning(f'Neuron {self.id} has NaN radii - area will not be correct.') + logger.warning( + f"Neuron {self.id} has NaN radii - area will not be correct." + ) # Generate radius dict - radii = self.nodes.set_index('node_id').radius.to_dict() + radii = self.nodes.set_index("node_id").radius.to_dict() # Drop root node(s) not_root = self.nodes.parent_id >= 0 # For each cylinder get the height @@ -721,23 +777,27 @@ def surface_area(self) -> float: r1 = nodes.node_id.map(radii).values r2 = nodes.parent_id.map(radii).values - return (np.pi * (r1 + r2) * np.sqrt( (r1-r2)**2 + h**2)).sum() + return (np.pi * (r1 + r2) * np.sqrt((r1 - r2) ** 2 + h**2)).sum() @property @add_units(compact=True, power=3) def volume(self) -> float: """Radius-based volume.""" - if 'radius' not in self.nodes.columns: - raise ValueError(f'Neuron {self.id} does not have radius information') + if "radius" not in self.nodes.columns: + raise ValueError(f"Neuron {self.id} does not have radius information") if any(self.nodes.radius < 0): - logger.warning(f'Neuron {self.id} has negative radii - volume will not be correct.') + logger.warning( + f"Neuron {self.id} has negative radii - volume will not be correct." + ) if any(self.nodes.radius.isnull()): - logger.warning(f'Neuron {self.id} has NaN radii - volume will not be correct.') + logger.warning( + f"Neuron {self.id} has NaN radii - volume will not be correct." + ) # Generate radius dict - radii = self.nodes.set_index('node_id').radius.to_dict() + radii = self.nodes.set_index("node_id").radius.to_dict() # Drop root node(s) not_root = self.nodes.parent_id >= 0 # For each cylinder get the height @@ -748,17 +808,17 @@ def volume(self) -> float: r1 = nodes.node_id.map(radii).values r2 = nodes.parent_id.map(radii).values - return (1/3 * np.pi * (r1**2 + r1 * r2 + r2**2) * h).sum() + return (1 / 3 * np.pi * (r1**2 + r1 * r2 + r2**2) * h).sum() @property def bbox(self) -> np.ndarray: """Bounding box (includes connectors).""" - mn = np.min(self.nodes[['x', 'y', 'z']].values, axis=0) - mx = np.max(self.nodes[['x', 'y', 'z']].values, axis=0) + mn = np.min(self.nodes[["x", "y", "z"]].values, axis=0) + mx = np.max(self.nodes[["x", "y", "z"]].values, axis=0) if self.has_connectors: - cn_mn = np.min(self.connectors[['x', 'y', 'z']].values, axis=0) - cn_mx = np.max(self.connectors[['x', 'y', 'z']].values, axis=0) + cn_mn = np.min(self.connectors[["x", "y", "z"]].values, axis=0) + cn_mx = np.max(self.connectors[["x", "y", "z"]].values, axis=0) mn = np.min(np.vstack((mn, cn_mn)), axis=0) mx = np.max(np.vstack((mx, cn_mx)), axis=0) @@ -780,9 +840,9 @@ def sampling_resolution(self) -> float: def segments(self) -> List[list]: """Neuron broken down into linear segments (see also `.small_segments`).""" # Calculate if required - if not hasattr(self, '_segments'): + if not hasattr(self, "_segments"): # This also sets the attribute - self._segments = self._get_segments(how='length') + self._segments = self._get_segments(how="length") return self._segments @property @@ -790,19 +850,18 @@ def segments(self) -> List[list]: def small_segments(self) -> List[list]: """Neuron broken down into small linear segments (see also `.segments`).""" # Calculate if required - if not hasattr(self, '_small_segments'): + if not hasattr(self, "_small_segments"): # This also sets the attribute - self._small_segments = self._get_segments(how='break') + self._small_segments = self._get_segments(how="break") return self._small_segments - def _get_segments(self, - how: Union[Literal['length'], - Literal['break']] = 'length' - ) -> List[list]: + def _get_segments( + self, how: Union[Literal["length"], Literal["break"]] = "length" + ) -> List[list]: """Generate segments for neuron.""" - if how == 'length': + if how == "length": return graph._generate_segments(self) - elif how == 'break': + elif how == "break": return graph._break_segments(self) else: raise ValueError(f'Unknown method: "{how}"') @@ -830,11 +889,11 @@ def _clear_temp_attr(self, exclude: list = []) -> None: elif self._soma not in self.nodes.node_id.values: self.soma = None - if 'classify_nodes' not in exclude: + if "classify_nodes" not in exclude: # Reclassify nodes graph.classify_nodes(self, inplace=True) - def copy(self, deepcopy: bool = False) -> 'TreeNeuron': + def copy(self, deepcopy: bool = False) -> "TreeNeuron": """Return a copy of the neuron. Parameters @@ -850,17 +909,19 @@ def copy(self, deepcopy: bool = False) -> 'TreeNeuron': TreeNeuron """ - no_copy = ['_lock'] + no_copy = ["_lock"] # Generate new empty neuron x = self.__class__(None) # Populate with this neuron's data - x.__dict__.update({k: copy.copy(v) for k, v in self.__dict__.items() if k not in no_copy}) + x.__dict__.update( + {k: copy.copy(v) for k, v in self.__dict__.items() if k not in no_copy} + ) # Copy graphs only if neuron is not stale if not self.is_stale: - if '_graph_nx' in self.__dict__: + if "_graph_nx" in self.__dict__: x._graph_nx = self._graph_nx.copy(as_view=deepcopy is not True) - if '_igraph' in self.__dict__: + if "_igraph" in self.__dict__: if self._igraph is not None: # This is pretty cheap, so we will always make a deep copy x._igraph = self._igraph.copy() @@ -883,7 +944,7 @@ def get_graph_nx(self) -> nx.DiGraph: self._graph_nx = graph.neuron2nx(self) return self._graph_nx - def get_igraph(self) -> 'igraph.Graph': # type: ignore + def get_igraph(self) -> "igraph.Graph": # type: ignore """Calculate and return iGraph representation of neuron. Once calculated stored as `.igraph`. Call function again to update @@ -901,11 +962,173 @@ def get_igraph(self) -> 'igraph.Graph': # type: ignore self._igraph = graph.neuron2igraph(self, raise_not_installed=False) return self._igraph - @overload - def resample(self, resample_to: int, inplace: Literal[False]) -> 'TreeNeuron': ... + def mask(self, mask, copy=True): + """Mask neuron with given mask. + + This is always done in-place! + + Parameters + ---------- + mask : np.ndarray + Mask to apply. Can be: + - 1D array with boolean values + - callable that accepts a neuron and returns a mask + - string with column name in nodes table + + Returns + ------- + self + + See Also + -------- + [`navis.MeshNeuron.unmask`][] + Remove mask from neuron. + [`navis.NeuronMask`][] + Context manager for masking neurons. + + """ + if self.is_masked: + raise ValueError( + "Neuron already masked. Layering multiple masks is currently not supported, please unmask first." + ) + + if callable(mask): + mask = mask(self) + elif isinstance(mask, str): + mask = self.nodes[mask].values + + mask = np.asarray(mask) + + if mask.dtype != bool: + raise ValueError("Mask must be boolean array.") + elif mask.shape[0] != self.nodes.shape[0]: + raise ValueError("Mask must have same length as nodes table.") + + self._mask = mask + self._masked_data = {} + self._masked_data["_nodes"] = self.nodes + + # N.B. we're directly setting `._nodes`` to avoid overhead from checks + self._nodes = self._nodes.loc[mask] + if copy: + self._nodes = self._nodes.copy() - @overload - def resample(self, resample_to: int, inplace: Literal[True]) -> None: ... + if hasattr(self, "_connectors"): + self._masked_data["_connectors"] = self.connectors + self._connectors = self._connectors.loc[ + self._connectors.node_id.isin(self.nodes.node_id) + ] + if copy: + self._connectors = self._connectors.copy() + + self._clear_temp_attr() + + return self + + def unmask(self, reset=True): + """Unmask neuron. + + Returns the neuron to its original state before masking. + + Parameters + ---------- + reset : bool + Whether to reset the neuron to its original state before masking. + If False, edits made to the neuron after masking will be kept. + + Returns + ------- + self + + See Also + -------- + [`TreeNeuron.is_masked`][navis.TreeNeuron.is_masked] + Check if neuron is masked. + [`TreeNeuron.mask`][navis.TreeNeuron.mask] + Mask neuron. + [`navis.NeuronMask`][] + Context manager for masking neurons. + + """ + if not self.is_masked: + raise ValueError("Neuron is not masked.") + + if reset: + # Unmask and reset to original state + super().unmask() + return self + + mask = self._mask + + # Combine post-mask data with original data + post_nodes = self.nodes + pre_nodes = self._masked_data["_nodes"].iloc[~mask] + + # We need to take care of a few things: + # 1. Make sure don't have any duplicate node IDs + duplicates = post_nodes.node_id.values[ + post_nodes.node_id.isin(pre_nodes.node_id) + ] + if any(duplicates): + # Start with a new max index + max_id = max(post_nodes.node_id.max(), pre_nodes.node_id.max()) + 1 + # New indices + new_ix = dict(zip(duplicates, range(max_id, max_id + len(duplicates)))) + # Update node IDs and parent IDs + post_nodes["node_id"] = post_nodes.node_id.replace(new_ix) + post_nodes["parent_id"] = post_nodes.parent_id.replace(new_ix) + + # Concatenate + self._nodes = pd.concat([pre_nodes, post_nodes], ignore_index=True) + + # 2. Re-connect the root nodes of the masked data + post_roots = post_nodes[post_nodes.parent_id < 0].node_id.values + pre_parents = ( + self._masked_data["_nodes"].set_index("node_id").parent_id.to_dict() + ) + to_fix = {} + for r in post_roots: + # Skip if this root is not in pre_parents + if r not in pre_parents: + continue + # Skip if this was also a root in the pre-masked data + if pre_parents[r] >= 0: + continue + # Skip if the old parent does not exist anymore + if pre_parents[r] not in self.nodes.node_id.values: + continue + # If we made it here, there is a new root that we can connect to the old parent + to_fix[r] = pre_parents[r] + if to_fix: + # Fix parent IDs + self._nodes["parent_id"] = self._nodes["parent_id"].replace(to_fix) + + # 3. See if any parent IDs have ceased to exist + missing_parents = ~self._nodes.parent_id.isin(self._nodes.node_id) & ( + self._nodes.parent_id >= 0 + ) + if any(missing_parents): + self.nodes.loc[missing_parents, "parent_id"] = -1 + + # TODO: Make sure that edges have a consistent orientation + # (not sure this is much of a problem) + + # Connectors + # TODO: check `node_id` of connectors too + if "_connectors" in self._masked_data: + if not hasattr(self, "_connectors"): + self._connectors = self._masked_data["_connectors"] + else: + cn = self._masked_data["_connectors"] + cn = cn.loc[cn.node_id.isin(pre_nodes.node_id.values)] + self._connectors = pd.concat([self._connectors, cn], ignore_index=True) + + del self._mask + del self._masked_data + + self._clear_temp_attr() + + return self def resample(self, resample_to, inplace=False): """Resample neuron to given resolution. @@ -939,18 +1162,6 @@ def resample(self, resample_to, inplace=False): return x return None - @overload - def downsample(self, - factor: float, - inplace: Literal[False], - **kwargs) -> 'TreeNeuron': ... - - @overload - def downsample(self, - factor: float, - inplace: Literal[True], - **kwargs) -> None: ... - def downsample(self, factor=5, inplace=False, **kwargs): """Downsample the neuron by given factor. @@ -987,9 +1198,9 @@ def downsample(self, factor=5, inplace=False, **kwargs): return x return None - def reroot(self, - new_root: Union[int, str], - inplace: bool = False) -> Optional['TreeNeuron']: + def reroot( + self, new_root: Union[int, str], inplace: bool = False + ) -> Optional["TreeNeuron"]: """Reroot neuron to given node ID or node tag. Parameters @@ -1020,9 +1231,9 @@ def reroot(self, return x return None - def prune_distal_to(self, - node: Union[str, int], - inplace: bool = False) -> Optional['TreeNeuron']: + def prune_distal_to( + self, node: Union[str, int], inplace: bool = False + ) -> Optional["TreeNeuron"]: """Cut off nodes distal to given nodes. Parameters @@ -1047,7 +1258,7 @@ def prune_distal_to(self, node = utils.make_iterable(node, force_type=None) for n in node: - prox = graph.cut_skeleton(x, n, ret='proximal')[0] + prox = graph.cut_skeleton(x, n, ret="proximal")[0] # Reinitialise with proximal data x.__init__(prox) # type: ignore # Cannot access "__init__" directly # Remove potential "left over" attributes (happens if we use a copy) @@ -1057,9 +1268,9 @@ def prune_distal_to(self, return x return None - def prune_proximal_to(self, - node: Union[str, int], - inplace: bool = False) -> Optional['TreeNeuron']: + def prune_proximal_to( + self, node: Union[str, int], inplace: bool = False + ) -> Optional["TreeNeuron"]: """Remove nodes proximal to given node. Reroots neuron to cut node. Parameters @@ -1084,7 +1295,7 @@ def prune_proximal_to(self, node = utils.make_iterable(node, force_type=None) for n in node: - dist = graph.cut_skeleton(x, n, ret='distal')[0] + dist = graph.cut_skeleton(x, n, ret="distal")[0] # Reinitialise with distal data x.__init__(dist) # type: ignore # Cannot access "__init__" directly # Remove potential "left over" attributes (happens if we use a copy) @@ -1097,9 +1308,9 @@ def prune_proximal_to(self, return x return None - def prune_by_strahler(self, - to_prune: Union[int, List[int], slice], - inplace: bool = False) -> Optional['TreeNeuron']: + def prune_by_strahler( + self, to_prune: Union[int, List[int], slice], inplace: bool = False + ) -> Optional["TreeNeuron"]: """Prune neuron based on [Strahler order](https://en.wikipedia.org/wiki/Strahler_number). Will reroot neuron to soma if possible. @@ -1132,8 +1343,7 @@ def prune_by_strahler(self, else: x = self.copy() - morpho.prune_by_strahler( - x, to_prune=to_prune, reroot_soma=True, inplace=True) + morpho.prune_by_strahler(x, to_prune=to_prune, reroot_soma=True, inplace=True) # No need to call this as morpho.prune_by_strahler does this already # self._clear_temp_attr() @@ -1142,17 +1352,23 @@ def prune_by_strahler(self, return x return None - def prune_twigs(self, - size: float, - inplace: bool = False, - recursive: Union[int, bool, float] = False - ) -> Optional['TreeNeuron']: + def prune_twigs( + self, + size: float, + mask: Optional[np.ndarray] = None, + inplace: bool = False, + recursive: Union[int, bool, float] = False, + ) -> Optional["TreeNeuron"]: """Prune terminal twigs under a given size. Parameters ---------- size : int | float Twigs shorter than this will be pruned. + mask : iterable | callable, optional + Either a boolean mask, a list of node IDs or a callable taking + a neuron as input and returning one of the former. If provided, + only nodes that are in the mask will be considered for pruning. inplace : bool, optional If False, pruning is performed on copy of original neuron which is then returned. @@ -1172,17 +1388,18 @@ def prune_twigs(self, else: x = self.copy() - morpho.prune_twigs(x, size=size, inplace=True) + morpho.prune_twigs(x, size=size, mask=mask, inplace=True) if not inplace: return x return None - def prune_at_depth(self, - depth: Union[float, int], - source: Optional[int] = None, - inplace: bool = False - ) -> Optional['TreeNeuron']: + def prune_at_depth( + self, + depth: Union[float, int], + source: Optional[int] = None, + inplace: bool = False, + ) -> Optional["TreeNeuron"]: """Prune all neurites past a given distance from a source. Parameters @@ -1220,10 +1437,11 @@ def prune_at_depth(self, return x return None - def cell_body_fiber(self, - reroot_soma: bool = True, - inplace: bool = False, - ) -> Optional['TreeNeuron']: + def cell_body_fiber( + self, + reroot_soma: bool = True, + inplace: bool = False, + ) -> Optional["TreeNeuron"]: """Prune neuron to its cell body fiber. Parameters @@ -1255,11 +1473,12 @@ def cell_body_fiber(self, return x return None - def prune_by_longest_neurite(self, - n: int = 1, - reroot_soma: bool = False, - inplace: bool = False, - ) -> Optional['TreeNeuron']: + def prune_by_longest_neurite( + self, + n: int = 1, + reroot_soma: bool = False, + inplace: bool = False, + ) -> Optional["TreeNeuron"]: """Prune neuron down to the longest neurite. Parameters @@ -1284,8 +1503,7 @@ def prune_by_longest_neurite(self, else: x = self.copy() - graph.longest_neurite( - x, n, inplace=True, reroot_soma=reroot_soma) + graph.longest_neurite(x, n, inplace=True, reroot_soma=reroot_soma) # Clear temporary attributes x._clear_temp_attr() @@ -1294,14 +1512,13 @@ def prune_by_longest_neurite(self, return x return None - def prune_by_volume(self, - v: Union[core.Volume, - List[core.Volume], - Dict[str, core.Volume]], - mode: Union[Literal['IN'], Literal['OUT']] = 'IN', - prevent_fragments: bool = False, - inplace: bool = False - ) -> Optional['TreeNeuron']: + def prune_by_volume( + self, + v: Union[core.Volume, List[core.Volume], Dict[str, core.Volume]], + mode: Union[Literal["IN"], Literal["OUT"]] = "IN", + prevent_fragments: bool = False, + inplace: bool = False, + ) -> Optional["TreeNeuron"]: """Prune neuron by intersection with given volume(s). Parameters @@ -1330,9 +1547,9 @@ def prune_by_volume(self, else: x = self.copy() - intersection.in_volume(x, v, inplace=True, - prevent_fragments=prevent_fragments, - mode=mode) + intersection.in_volume( + x, v, inplace=True, prevent_fragments=prevent_fragments, mode=mode + ) # Clear temporary attributes # x._clear_temp_attr() @@ -1341,9 +1558,7 @@ def prune_by_volume(self, return x return None - def to_swc(self, - filename: Optional[str] = None, - **kwargs) -> None: + def to_swc(self, filename: Optional[str] = None, **kwargs) -> None: """Generate SWC file from this neuron. Parameters @@ -1365,9 +1580,10 @@ def to_swc(self, """ return io.write_swc(self, filename, **kwargs) # type: ignore # double import of "io" - def reload(self, - inplace: bool = False, - ) -> Optional['TreeNeuron']: + def reload( + self, + inplace: bool = False, + ) -> Optional["TreeNeuron"]: """Reload neuron. Must have filepath as `.origin` as attribute. Returns @@ -1376,19 +1592,22 @@ def reload(self, If `inplace=False`. """ - if not hasattr(self, 'origin'): - raise AttributeError('To reload TreeNeuron must have `.origin` ' - 'attribute') + if not hasattr(self, "origin"): + raise AttributeError( + "To reload TreeNeuron must have `.origin` " "attribute" + ) - if self.origin in ('DataFrame', 'string'): - raise ValueError('Unable to reload TreeNeuron: it appears to have ' - 'been created from string or DataFrame.') + if self.origin in ("DataFrame", "string"): + raise ValueError( + "Unable to reload TreeNeuron: it appears to have " + "been created from string or DataFrame." + ) kwargs = {} - if hasattr(self, 'soma_label'): - kwargs['soma_label'] = self.soma_label - if hasattr(self, 'connector_labels'): - kwargs['connector_labels'] = self.connector_labels + if hasattr(self, "soma_label"): + kwargs["soma_label"] = self.soma_label + if hasattr(self, "connector_labels"): + kwargs["connector_labels"] = self.connector_labels x = io.read_swc(self.origin, **kwargs) @@ -1403,7 +1622,7 @@ def reload(self, x2._clear_temp_attr() return x - def snap(self, locs, to='nodes'): + def snap(self, locs, to="nodes"): """Snap xyz location(s) to closest node or synapse. Parameters @@ -1431,15 +1650,16 @@ def snap(self, locs, to='nodes'): """ locs = np.asarray(locs).astype(np.float64) - is_single = (locs.ndim == 1 and len(locs) == 3) - is_multi = (locs.ndim == 2 and locs.shape[1] == 3) + is_single = locs.ndim == 1 and len(locs) == 3 + is_multi = locs.ndim == 2 and locs.shape[1] == 3 if not is_single and not is_multi: - raise ValueError('Expected a single (x, y, z) location or a ' - '(N, 3) array of multiple locations') + raise ValueError( + "Expected a single (x, y, z) location or a " + "(N, 3) array of multiple locations" + ) - if to not in ['nodes', 'connectors']: - raise ValueError('`to` must be "nodes" or "connectors", ' - f'got {to}') + if to not in ["nodes", "connectors"]: + raise ValueError('`to` must be "nodes" or "connectors", ' f"got {to}") # Generate tree tree = graph.neuron2KDTree(self, data=to) @@ -1447,7 +1667,7 @@ def snap(self, locs, to='nodes'): # Find the closest node dist, ix = tree.query(locs) - if to == 'nodes': + if to == "nodes": id = self.nodes.node_id.values[ix] else: id = self.connectors.connector_id.values[ix] diff --git a/navis/core/voxel.py b/navis/core/voxel.py index 4a3e5cc3..b8934851 100644 --- a/navis/core/voxel.py +++ b/navis/core/voxel.py @@ -94,6 +94,9 @@ class VoxelNeuron(BaseNeuron): #: Core data table(s) used to calculate hash CORE_DATA = ['_data'] + #: Property used to calculate length of neuron + _LENGTH_DATA = 'voxels' + def __init__(self, x: Union[np.ndarray], offset: Optional[np.ndarray] = None, @@ -216,6 +219,12 @@ def __sub__(self, other, copy=True): return n return NotImplemented + def __len__(self): + """Return number of voxels.""" + # This is the only neuron for which we implement __len__ separately + # to avoid the potential overhead of generating the voxel representation + return np.product(self.shape) + @property def _base_data_type(self) -> str: """Type of data (grid or voxels) underlying this neuron.""" From 7de016e9608c586b9f5e2db9b4987b36358b001e Mon Sep 17 00:00:00 2001 From: Philipp Schlegel Date: Thu, 24 Oct 2024 14:35:09 +0100 Subject: [PATCH 4/6] a few additional fixes for masking: - fix NeuronMask doctest - TreeNeuron.un/mask: make sure to re-classify - TreeNeuron.unmask: fix re-connecting --- navis/core/masking.py | 2 +- navis/core/skeleton.py | 14 ++++++++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/navis/core/masking.py b/navis/core/masking.py index 345de824..c5a78861 100644 --- a/navis/core/masking.py +++ b/navis/core/masking.py @@ -63,7 +63,7 @@ class NeuronMask: >>> # Grab a few skeletons >>> nl = navis.example_neurons(3) >>> # Label axon and dendrites - >>> navis.split_axon_dendrite(nl, label_only=True) + >>> _ = navis.split_axon_dendrite(nl, label_only=True) >>> # Mask by axon >>> with navis.NeuronMask(nl, lambda x: x.nodes.compartment == 'axon'): ... print("Axon cable length:", nl.cable_length * nl[0].units) diff --git a/navis/core/skeleton.py b/navis/core/skeleton.py index 892e7c88..41937e5d 100644 --- a/navis/core/skeleton.py +++ b/navis/core/skeleton.py @@ -1009,10 +1009,17 @@ def mask(self, mask, copy=True): self._masked_data["_nodes"] = self.nodes # N.B. we're directly setting `._nodes`` to avoid overhead from checks - self._nodes = self._nodes.loc[mask] + self._nodes = self._nodes.loc[mask].drop("type", axis=1, errors="ignore") if copy: self._nodes = self._nodes.copy() + # See if any parent IDs have ceased to exist + missing_parents = ~self._nodes.parent_id.isin(self._nodes.node_id) & ( + self._nodes.parent_id >= 0 + ) + if any(missing_parents): + self.nodes.loc[missing_parents, "parent_id"] = -1 + if hasattr(self, "_connectors"): self._masked_data["_connectors"] = self.connectors self._connectors = self._connectors.loc[ @@ -1092,7 +1099,7 @@ def unmask(self, reset=True): if r not in pre_parents: continue # Skip if this was also a root in the pre-masked data - if pre_parents[r] >= 0: + if pre_parents[r] < 0: continue # Skip if the old parent does not exist anymore if pre_parents[r] not in self.nodes.node_id.values: @@ -1110,6 +1117,9 @@ def unmask(self, reset=True): if any(missing_parents): self.nodes.loc[missing_parents, "parent_id"] = -1 + # Force nodes to be re-classified + self.nodes.drop("type", axis=1, errors="ignore", inplace=True) + # TODO: Make sure that edges have a consistent orientation # (not sure this is much of a problem) From fe2d716c8ad7afc95acf9d7ed66f17c796c98cfa Mon Sep 17 00:00:00 2001 From: Philipp Schlegel Date: Wed, 4 Dec 2024 19:29:18 +0000 Subject: [PATCH 5/6] docs: add masking to API --- docs/api.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/api.md b/docs/api.md index fb7cfdc6..ecf28690 100644 --- a/docs/api.md +++ b/docs/api.md @@ -580,6 +580,14 @@ Functions to export neurons. | [`navis.write_precomputed()`][navis.write_precomputed] | {{ autosummary("navis.write_precomputed") }} | | [`navis.write_parquet()`][navis.write_parquet] | {{ autosummary("navis.write_parquet") }} | +## Masking + +Functions and classes for masking: + +| Function/Class | Description | +|----------------|-------------| +| [`navis.NeuronMask`][navis.NeuronMask] | {{ autosummary("navis.NeuronMask") }} | + ## Utility Various utility functions. From fcc2d440b679e100ccd5f21b317878c085e3f305 Mon Sep 17 00:00:00 2001 From: Philipp Schlegel Date: Wed, 4 Dec 2024 19:32:16 +0000 Subject: [PATCH 6/6] bits and pieces for masking: - add .view() methods to neuron objects - by default, do not copy data - bits and pieces --- navis/core/base.py | 26 +++- navis/core/dotprop.py | 294 ++++++++++++++++++++++++----------------- navis/core/masking.py | 13 +- navis/core/mesh.py | 130 +++++++++++------- navis/core/skeleton.py | 69 +++++++--- navis/morpho/subset.py | 2 +- 6 files changed, 328 insertions(+), 206 deletions(-) diff --git a/navis/core/base.py b/navis/core/base.py index 009b0a66..219f71c1 100644 --- a/navis/core/base.py +++ b/navis/core/base.py @@ -46,7 +46,7 @@ def Neuron( - x: Union[nx.DiGraph, str, pd.DataFrame, "TreeNeuron", "MeshNeuron"], + x: Union[nx.DiGraph, str, pd.DataFrame, "TreeNeuron", "MeshNeuron"], # noqa: F821 **metadata, # noqa: F821 ): """Constructor for Neuron objects. Depending on the input, either a @@ -664,6 +664,10 @@ def copy(self, deepcopy=False) -> "BaseNeuron": return x + def view(self) -> "BaseNeuron": + """Create a view of the neuron without copying data.""" + raise NotImplementedError(f"View not implemented for neuron of type {type(self)}.") + def summary(self, add_props=None) -> pd.Series: """Get a summary of this neuron.""" @@ -687,6 +691,11 @@ def summary(self, add_props=None) -> pd.Series: warnings.simplefilter("ignore") s = pd.Series([getattr(self, at, "NA") for at in props], index=props) + # Show mask status + if self.is_masked: + if "masked" not in s.index: + s["masked"] = True + return s def plot2d(self, **kwargs): @@ -750,7 +759,20 @@ def is_masked(self): return hasattr(self, "_masked_data") def mask(self, mask): - """Mask neuron.""" + """Mask neuron. + + Implementation details depend on the neuron type (see below). + + See Also + -------- + [`navis.TreeNeuron.mask`][] + Mask skeleton. + [`navis.MeshNeuron.mask`][] + Mask mesh. + [`navis.Dotprops.mask`][] + Mask dotprops. + + """ raise NotImplementedError( f"Masking not implemented for neuron of type {type(self)}." ) diff --git a/navis/core/dotprop.py b/navis/core/dotprop.py index d3a69cdc..7d83d5e8 100644 --- a/navis/core/dotprop.py +++ b/navis/core/dotprop.py @@ -37,7 +37,7 @@ except ModuleNotFoundError: from scipy.spatial import cKDTree as KDTree -__all__ = ['Dotprops'] +__all__ = ["Dotprops"] # Set up logging logger = config.get_logger(__name__) @@ -93,34 +93,35 @@ class Dotprops(BaseNeuron): points: np.ndarray alpha: np.ndarray - vect: np.ndarray + vect: np.ndarray k: Optional[int] soma: Optional[Union[list, np.ndarray]] #: Attributes used for neuron summary - SUMMARY_PROPS = ['type', 'name', 'k', 'units', 'n_points'] + SUMMARY_PROPS = ["type", "name", "k", "units", "n_points"] #: Attributes to be used when comparing two neurons. - EQ_ATTRIBUTES = ['name', 'n_points', 'k'] + EQ_ATTRIBUTES = ["name", "n_points", "k"] #: Temporary attributes that need clearing when neuron data changes - TEMP_ATTR = ['_memory_usage', "_tree"] + TEMP_ATTR = ["_memory_usage", "_tree"] #: Core data table(s) used to calculate hash - _CORE_DATA = ['points', 'vect'] + _CORE_DATA = ["points", "vect"] #: Property used to calculate length of neuron - _LENGTH_DATA = 'points' - - def __init__(self, - points: np.ndarray, - k: int, - vect: Optional[np.ndarray] = None, - alpha: Optional[np.ndarray] = None, - units: Union[pint.Unit, str] = None, - **metadata - ): + _LENGTH_DATA = "points" + + def __init__( + self, + points: np.ndarray, + k: int, + vect: Optional[np.ndarray] = None, + alpha: Optional[np.ndarray] = None, + units: Union[pint.Unit, str] = None, + **metadata, + ): """Initialize Dotprops Neuron.""" super().__init__() @@ -144,13 +145,13 @@ def __truediv__(self, other, copy=True): if isinstance(other, numbers.Number) or utils.is_iterable(other): # If a number, consider this an offset for coordinates n = self.copy() if copy else self - _ = np.divide(n.points, other, out=n.points, casting='unsafe') + _ = np.divide(n.points, other, out=n.points, casting="unsafe") if n.has_connectors: - n.connectors.loc[:, ['x', 'y', 'z']] /= other + n.connectors.loc[:, ["x", "y", "z"]] /= other # Force recomputing of KDTree - if hasattr(n, '_tree'): - delattr(n, '_tree') + if hasattr(n, "_tree"): + delattr(n, "_tree") # Convert units # Note: .to_compact() throws a RuntimeWarning and returns unchanged @@ -167,13 +168,13 @@ def __mul__(self, other, copy=True): if isinstance(other, numbers.Number) or utils.is_iterable(other): # If a number, consider this an offset for coordinates n = self.copy() if copy else self - _ = np.multiply(n.points, other, out=n.points, casting='unsafe') + _ = np.multiply(n.points, other, out=n.points, casting="unsafe") if n.has_connectors: - n.connectors.loc[:, ['x', 'y', 'z']] *= other + n.connectors.loc[:, ["x", "y", "z"]] *= other # Force recomputing of KDTree - if hasattr(n, '_tree'): - delattr(n, '_tree') + if hasattr(n, "_tree"): + delattr(n, "_tree") # Convert units # Note: .to_compact() throws a RuntimeWarning and returns unchanged @@ -190,13 +191,13 @@ def __add__(self, other, copy=True): if isinstance(other, numbers.Number) or utils.is_iterable(other): # If a number, consider this an offset for coordinates n = self.copy() if copy else self - _ = np.add(n.points, other, out=n.points, casting='unsafe') + _ = np.add(n.points, other, out=n.points, casting="unsafe") if n.has_connectors: - n.connectors.loc[:, ['x', 'y', 'z']] += other + n.connectors.loc[:, ["x", "y", "z"]] += other # Force recomputing of KDTree - if hasattr(n, '_tree'): - delattr(n, '_tree') + if hasattr(n, "_tree"): + delattr(n, "_tree") return n # If another neuron, return a list of neurons @@ -209,13 +210,13 @@ def __sub__(self, other, copy=True): if isinstance(other, numbers.Number) or utils.is_iterable(other): # If a number, consider this an offset for coordinates n = self.copy() if copy else self - _ = np.subtract(n.points, other, out=n.points, casting='unsafe') + _ = np.subtract(n.points, other, out=n.points, casting="unsafe") if n.has_connectors: - n.connectors.loc[:, ['x', 'y', 'z']] -= other + n.connectors.loc[:, ["x", "y", "z"]] -= other # Force recomputing of KDTree - if hasattr(n, '_tree'): - delattr(n, '_tree') + if hasattr(n, "_tree"): + delattr(n, "_tree") return n return NotImplemented @@ -227,9 +228,9 @@ def __getstate__(self): # The KDTree from pykdtree does not like being pickled # We will have to remove it which will force it to be regenerated # after unpickling - if '_tree' in state: - if 'pykdtree' in str(type(state['_tree'])): - _ = state.pop('_tree') + if "_tree" in state: + if "pykdtree" in str(type(state["_tree"])): + _ = state.pop("_tree") return state @@ -238,8 +239,10 @@ def alpha(self): """Alpha value for tangent vectors (optional).""" if isinstance(self._alpha, type(None)): if isinstance(self.k, type(None)) or (self.k <= 0): - raise ValueError('Unable to calculate `alpha` for Dotprops not ' - 'generated using k-nearest-neighbors.') + raise ValueError( + "Unable to calculate `alpha` for Dotprops not " + "generated using k-nearest-neighbors." + ) self.recalculate_tangents(self.k, inplace=True) return self._alpha @@ -249,7 +252,7 @@ def alpha(self, value): if not isinstance(value, type(None)): value = np.asarray(value) if value.ndim != 1: - raise ValueError(f'alpha must be (N, ) array, got {value.shape}') + raise ValueError(f"alpha must be (N, ) array, got {value.shape}") self._alpha = value @property @@ -259,8 +262,8 @@ def bbox(self) -> np.ndarray: mx = np.max(self.points, axis=0) if self.has_connectors: - cn_mn = np.min(self.connectors[['x', 'y', 'z']].values, axis=0) - cn_mx = np.max(self.connectors[['x', 'y', 'z']].values, axis=0) + cn_mn = np.min(self.connectors[["x", "y", "z"]].values, axis=0) + cn_mx = np.max(self.connectors[["x", "y", "z"]].values, axis=0) mn = np.min(np.vstack((mn, cn_mn)), axis=0) mx = np.max(np.vstack((mx, cn_mx)), axis=0) @@ -270,12 +273,16 @@ def bbox(self) -> np.ndarray: @property def datatables(self) -> List[str]: """Names of all DataFrames attached to this neuron.""" - return [k for k, v in self.__dict__.items() if isinstance(v, pd.DataFrame, np.ndarray)] + return [ + k + for k, v in self.__dict__.items() + if isinstance(v, pd.DataFrame, np.ndarray) + ] @property def kdtree(self): """KDTree for points.""" - if not getattr(self, '_tree', None): + if not getattr(self, "_tree", None): self._tree = KDTree(self.points) return self._tree @@ -290,7 +297,7 @@ def points(self, value): value = np.zeros((0, 3)) value = np.asarray(value) if value.ndim != 2 or value.shape[1] != 3: - raise ValueError(f'points must be (N, 3) array, got {value.shape}') + raise ValueError(f"points must be (N, 3) array, got {value.shape}") self._points = value # Also reset KDtree self._tree = None @@ -307,7 +314,7 @@ def vect(self, value): if not isinstance(value, type(None)): value = np.asarray(value) if value.ndim != 2 or value.shape[1] != 3: - raise ValueError(f'vectors must be (N, 3) array, got {value.shape}') + raise ValueError(f"vectors must be (N, 3) array, got {value.shape}") self._vect = value @property @@ -336,11 +343,11 @@ def soma(self) -> Optional[int]: if not any(soma): soma = None elif any(np.array(soma) < 0) or any(np.array(soma) > self.points.shape[0]): - logger.warning(f'Soma(s) {soma} not found in points.') + logger.warning(f"Soma(s) {soma} not found in points.") soma = None else: if 0 < soma < self.points.shape[0]: - logger.warning(f'Soma {soma} not found in node table.') + logger.warning(f"Soma {soma} not found in node table.") soma = None return soma @@ -348,7 +355,7 @@ def soma(self) -> Optional[int]: @soma.setter def soma(self, value: Union[Callable, int, None]) -> None: """Set soma.""" - if hasattr(value, '__call__'): + if hasattr(value, "__call__"): self._soma = types.MethodType(value, self) elif isinstance(value, type(None)): self._soma = None @@ -358,20 +365,22 @@ def soma(self, value: Union[Callable, int, None]) -> None: if 0 < value < self.points.shape[0]: self._soma = value else: - raise ValueError('Soma must be function, None or a valid node index.') + raise ValueError("Soma must be function, None or a valid node index.") @property def type(self) -> str: """Neuron type.""" - return 'navis.Dotprops' - - def dist_dots(self, - other: 'Dotprops', - alpha: bool = False, - distance_upper_bound: Optional[float] = None, - **kwargs) -> Union[ - Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray, np.ndarray] - ]: + return "navis.Dotprops" + + def dist_dots( + self, + other: "Dotprops", + alpha: bool = False, + distance_upper_bound: Optional[float] = None, + **kwargs, + ) -> Union[ + Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray, np.ndarray] + ]: """Query this Dotprops against another. This function is mainly for `navis.nblast`. @@ -419,9 +428,9 @@ def dist_dots(self, # Scipy's KDTree does not like the distance to be None diub = distance_upper_bound if distance_upper_bound else np.inf - fast_dists, fast_idxs = other.kdtree.query(points, - distance_upper_bound=diub, - **kwargs) + fast_dists, fast_idxs = other.kdtree.query( + points, distance_upper_bound=diub, **kwargs + ) # If upper distance we have to worry about infinite distances if distance_upper_bound: @@ -479,7 +488,7 @@ def downsample(self, factor=5, inplace=False, **kwargs): return x return None - def copy(self) -> 'Dotprops': + def copy(self) -> "Dotprops": """Return a copy of the dotprops. Returns @@ -489,17 +498,40 @@ def copy(self) -> 'Dotprops': """ # Don't copy the KDtree - when using pykdtree, copy.copy throws an # error and the construction is super fast anyway - no_copy = ['_lock', '_tree'] + no_copy = ["_lock", "_tree"] # Generate new empty neuron - note we pass vect and alpha to # prevent calculation on initialization - x = self.__class__(points=np.zeros((0, 3)), k=1, - vect=np.zeros((0, 3)), alpha=np.zeros(0)) + x = self.__class__( + points=np.zeros((0, 3)), k=1, vect=np.zeros((0, 3)), alpha=np.zeros(0) + ) # Populate with this neuron's data - x.__dict__.update({k: copy.copy(v) for k, v in self.__dict__.items() if k not in no_copy}) + x.__dict__.update( + {k: copy.copy(v) for k, v in self.__dict__.items() if k not in no_copy} + ) return x - def drop_fluff(self, epsilon, keep_size: int = None, n_largest: int = None, inplace=False): + def view(self) -> "Dotprops": + """Create a view of the neuron without copying data. + + Be aware that changes to the view may affect the original neuron! + + """ + no_copy = ["_lock"] + + # Generate new empty neuron + x = self.__class__( + points=np.zeros((0, 3)), k=1, vect=np.zeros((0, 3)), alpha=np.zeros(0) + ) + + # Override with this neuron's data + x.__dict__.update({k: v for k, v in self.__dict__.items() if k not in no_copy}) + + return x + + def drop_fluff( + self, epsilon, keep_size: int = None, n_largest: int = None, inplace=False + ): """Remove fluff from neuron. By default, this function will remove all but the largest connected @@ -534,16 +566,20 @@ def drop_fluff(self, epsilon, keep_size: int = None, n_largest: int = None, inpl Base function. See for details and examples. """ - x = morpho.drop_fluff(self, epsilon=epsilon, keep_size=keep_size, n_largest=n_largest, inplace=inplace) + x = morpho.drop_fluff( + self, + epsilon=epsilon, + keep_size=keep_size, + n_largest=n_largest, + inplace=inplace, + ) if not inplace: return x - def mask(self, mask, copy=True): + def mask(self, mask, inplace=False, copy=False) -> "Dotprops": """Mask neuron with given mask. - This is always done in-place! - Parameters ---------- mask : np.ndarray @@ -551,10 +587,16 @@ def mask(self, mask, copy=True): - 1D array with boolean values - callable that accepts a neuron and returns a mask - string with property name + inplace : bool, optional + Whether to mask the neuron inplace. + copy : bool, optional + Whether to copy data (points, vectors, alpha, etc.) after masking. + This is useful if you want to avoid accidentally modifying + the original nodes table. Returns ------- - self + n : Dotprops The masked neuron. See Also @@ -572,56 +614,60 @@ def mask(self, mask, copy=True): "Neuron already masked. Layering multiple masks is currently not supported, please unmask first." ) + n = self + if not inplace: + n = self.view() + if callable(mask): - mask = mask(self) + mask = mask(n) elif isinstance(mask, str): - mask = getattr(self, mask) + mask = getattr(n, mask) mask = np.asarray(mask) if mask.dtype != bool: raise ValueError("Mask must be boolean array.") - elif mask.shape[0] != len(self): + elif mask.shape[0] != len(n): raise ValueError("Mask must have same length as points.") - self._mask = mask - self._masked_data = {} - self._masked_data['_points'] = self.points + n._mask = mask + n._masked_data = {} + n._masked_data["_points"] = n.points # Drop soma if masked out - if self.soma is not None: - if isinstance(self.soma, (list, np.ndarray)): - soma_left = self.soma[mask[self.soma]] - self._masked_data['_soma'] = self.soma + if n.soma is not None: + if isinstance(n.soma, (list, np.ndarray)): + soma_left = n.soma[mask[n.soma]] + n._masked_data["_soma"] = n.soma if any(soma_left): - self.soma = soma_left + n.soma = soma_left else: - self.soma = None - elif not mask[self.soma]: - self._masked_data['_soma'] = self.soma - self.soma = None + n.soma = None + elif not mask[n.soma]: + n._masked_data["_soma"] = n.soma + n.soma = None - # N.B. we're directly setting `._nodes`` to avoid overhead from checks + # Apply the mask and make copy if requested for att in ("_points", "_vect", "_alpha"): - if hasattr(self, att): - self._masked_data[att] = getattr(self, att) - setattr(self, att, getattr(self, att)[mask]) + if hasattr(n, att): + n._masked_data[att] = getattr(n, att) # save original data + setattr(n, att, getattr(n, att)[mask]) # apply mask if copy: - setattr(self, att, getattr(self, att).copy()) + setattr(n, att, getattr(n, att).copy()) # copy masked data if requested - if hasattr(self, "_connectors") and "point_ix" in self._connectors.columns: - self._masked_data['connectors'] = self.connectors - self._connectors = self._connectors.loc[ - self.connectors.point_ix.isin(np.arange(len(mask))[mask]) - ] + if hasattr(n, "_connectors") and "point_ix" in n._connectors.columns: + n._masked_data["connectors"] = n.connectors + n._connectors = n._connectors.loc[ + n.connectors.point_ix.isin(np.arange(len(mask))[mask]) + ] if copy: - self._connectors = self._connectors.copy() + n._connectors = n._connectors.copy() - self._clear_temp_attr() + n._clear_temp_attr() - return self + return n def unmask(self, reset=True): """Unmask neuron. @@ -699,8 +745,9 @@ def recalculate_tangents(self, k: int, inplace=False): # Checks and balances n_points = x.points.shape[0] if n_points < k: - raise ValueError(f"Too few points ({n_points}) to calculate {k} " - "nearest-neighbors") + raise ValueError( + f"Too few points ({n_points}) to calculate {k} " "nearest-neighbors" + ) # Create the KDTree and get the k-nearest neighbors for each point dist, ix = self.kdtree.query(x.points, k=k) @@ -728,7 +775,7 @@ def recalculate_tangents(self, k: int, inplace=False): if not inplace: return x - def snap(self, locs, to='points'): + def snap(self, locs, to="points"): """Snap xyz location(s) to closest point or synapse. Parameters @@ -757,15 +804,16 @@ def snap(self, locs, to='points'): """ locs = np.asarray(locs).astype(np.float64) - is_single = (locs.ndim == 1 and len(locs) == 3) - is_multi = (locs.ndim == 2 and locs.shape[1] == 3) + is_single = locs.ndim == 1 and len(locs) == 3 + is_multi = locs.ndim == 2 and locs.shape[1] == 3 if not is_single and not is_multi: - raise ValueError('Expected a single (x, y, z) location or a ' - '(N, 3) array of multiple locations') + raise ValueError( + "Expected a single (x, y, z) location or a " + "(N, 3) array of multiple locations" + ) - if to not in ['points', 'connectors']: - raise ValueError('`to` must be "points" or "connectors", ' - f'got {to}') + if to not in ["points", "connectors"]: + raise ValueError('`to` must be "points" or "connectors", ' f"got {to}") # Generate tree tree = graph.neuron2KDTree(self, data=to) @@ -775,9 +823,9 @@ def snap(self, locs, to='points'): return ix, dist - def to_skeleton(self, - scale_vec: Union[float, Literal['auto']] = 'auto' - ) -> core.TreeNeuron: + def to_skeleton( + self, scale_vec: Union[float, Literal["auto"]] = "auto" + ) -> core.TreeNeuron: """Turn Dotprop into a TreeNeuron. This does *not* skeletonize the neuron but rather generates a line @@ -801,12 +849,13 @@ def to_skeleton(self, TreeNeuron """ - if not isinstance(scale_vec, numbers.Number) and scale_vec != 'auto': - raise ValueError('`scale_vect` must be "auto" or a number, ' - f'got {scale_vec}') + if not isinstance(scale_vec, numbers.Number) and scale_vec != "auto": + raise ValueError( + '`scale_vect` must be "auto" or a number, ' f"got {scale_vec}" + ) - if scale_vec == 'auto': - scale_vec = self.sampling_resolution * .8 + if scale_vec == "auto": + scale_vec = self.sampling_resolution * 0.8 # Prepare segments - this is based on nat:::plot3d.dotprops halfvect = self.vect / 2 * scale_vec @@ -819,16 +868,16 @@ def to_skeleton(self, segs[1::2] = ends # Generate node table - nodes = pd.DataFrame(segs, columns=['x', 'y', 'z']) - nodes['node_id'] = nodes.index - nodes['parent_id'] = -1 - nodes.loc[1::2, 'parent_id'] = nodes.index.values[::2] + nodes = pd.DataFrame(segs, columns=["x", "y", "z"]) + nodes["node_id"] = nodes.index + nodes["parent_id"] = -1 + nodes.loc[1::2, "parent_id"] = nodes.index.values[::2] # Produce a minimal TreeNeuron tn = core.TreeNeuron(nodes, units=self.units, id=self.id) # Carry over the label - if getattr(self, '_label', None): + if getattr(self, "_label", None): tn._label = self._label # Add some other relevant attributes directly @@ -837,4 +886,3 @@ def to_skeleton(self, tn._soma = self._soma return tn - diff --git a/navis/core/masking.py b/navis/core/masking.py index c5a78861..be74976e 100644 --- a/navis/core/masking.py +++ b/navis/core/masking.py @@ -20,7 +20,6 @@ from .voxel import VoxelNeuron from .mesh import MeshNeuron -from .. import utils __all__ = ["NeuronMask"] @@ -30,7 +29,7 @@ class NeuronMask: Parameters ---------- - x : Neuron/List + x : Neuron/List Neuron(s) to mask. mask : str | array | callable | list | dict The mask to apply: @@ -43,10 +42,8 @@ class NeuronMask: above copy_data : bool Whether to copy the neuron data (e.g. node table for skeletons) - when masking. Setting this to False will may some time and - memory but may lead to e.g. pandas setting-on-copy warnings - if the data is modified. Only set to `True` if you know your - code won't modify the data. + when masking. Set this to `True` if you know your code will modify + the masked data and you want to prevent changes to the original. reset_neurons : bool If True, reset the neurons to their original state after the context manager exits. If False, will try to incorporate any @@ -84,7 +81,7 @@ class NeuronMask: """ - def __init__(self, x, mask, reset_neurons=True, copy_data=True, validate_mask=True): + def __init__(self, x, mask, reset_neurons=True, copy_data=False, validate_mask=True): self.neurons = x if validate_mask: @@ -159,7 +156,7 @@ def __enter__(self): else: mask = self.mask[i] - n.mask(mask, copy=self.copy) + n.mask(mask, copy=self.copy, inplace=True) return self diff --git a/navis/core/mesh.py b/navis/core/mesh.py index d6f78db0..5cca369c 100644 --- a/navis/core/mesh.py +++ b/navis/core/mesh.py @@ -428,10 +428,24 @@ def copy(self) -> "MeshNeuron": return x - def mask(self, mask, copy=True): - """Mask neuron with given mask. + def view(self) -> "MeshNeuron": + """Create a view of the neuron without copying data. + + Be aware that changes to the view may affect the original neuron! + + """ + no_copy = ["_lock"] + + # Generate new empty neuron + x = self.__class__(None) + + # Override with this neuron's data + x.__dict__.update({k: v for k, v in self.__dict__.items() if k not in no_copy}) + + return x - This is always done in-place! + def mask(self, mask, inplace=False, copy=False): + """Mask neuron with given mask. Parameters ---------- @@ -443,8 +457,12 @@ def mask(self, mask, copy=True): The mask can be either for vertices or faces but will ultimately be used to mask out faces. Vertices not participating in any face will be removed regardless of the mask. - copy : bool - Whether to copy mask a copy of the data. Only applies for connectors. + inplace : bool, optional + Whether to mask the neuron inplace. + copy : bool, optional + Whether to copy data (faces, vertices, etc.) after masking. This + is useful if you want to avoid accidentally modifying + the original nodes table. Returns ------- @@ -466,101 +484,111 @@ def mask(self, mask, copy=True): "Please either apply the existing mask or unmask first." ) + n = self + if not inplace: + n = self.view() + if callable(mask): - mask = mask(self) + mask = mask(n) elif isinstance(mask, str): - mask = getattr(self, mask) + mask = getattr(n, mask) mask = np.asarray(mask) # Some checks if mask.dtype != bool: raise ValueError("Mask must be boolean array.") - elif len(mask) not in (self.vertices.shape[0], self.faces.shape[0]): + elif len(mask) not in (n.vertices.shape[0], n.faces.shape[0]): raise ValueError("Mask length does not match number of vertices or faces.") # Transate vertex mask to face mask - if mask.shape[0] == self.vertices.shape[0]: + if mask.shape[0] == n.vertices.shape[0]: vert_mask = mask - face_mask = np.all(mask[self.faces], axis=1) + face_mask = np.all(mask[n.faces], axis=1) + + # Apply mask + verts_new, faces_new, vert_map, face_map = morpho.subset.submesh( + n, vertex_index=np.where(vert_mask)[0], return_map=True + ) else: face_mask = mask - vert_mask = np.zeros(self.vertices.shape[0], dtype=bool) - vert_mask[np.unique(self.faces[face_mask])] = True + vert_mask = np.zeros(n.vertices.shape[0], dtype=bool) + vert_mask[np.unique(n.faces[face_mask])] = True - # Apply mask - verts_new, faces_new, vert_map, face_map = morpho.subset.submesh( - self, vertex_index=np.where(vert_mask)[0], return_map=True - ) + # Apply mask + verts_new, faces_new, vert_map, face_map = morpho.subset.submesh( + n, faces_index=np.where(face_mask)[0], return_map=True + ) # The above will have likely dropped some vertices - we need to update the vertex mask - vert_mask = np.zeros(self.vertices.shape[0], dtype=bool) + vert_mask = np.zeros(n.vertices.shape[0], dtype=bool) vert_mask[np.where(vert_map != -1)[0]] = True # Track mask, vertices and faces before masking - self._mask = face_mask # mask is always the face mask - self._masked_data = {} - self._masked_data["_vertices"] = self._vertices - self._masked_data["_faces"] = self._faces + n._mask = face_mask # mask is always the face mask + n._masked_data = {} + n._masked_data["_vertices"] = n._vertices + n._masked_data["_faces"] = n._faces # Update vertices and faces - self._vertices = verts_new - self._faces = faces_new + n._vertices = verts_new + n._faces = faces_new # See if we can mask the mesh's skeleton as well - if hasattr(self, "_skeleton"): + if hasattr(n, "_skeleton"): # If the skeleton has a vertex map, we can use it to mask the skeleton - if hasattr(self._skeleton, "vertex_map"): + if hasattr(n._skeleton, "vertex_map"): # Generate a mask for the skeleton # (keep in mind vertex_map are node IDs, not indices) - sk_mask = self._skeleton.nodes.node_id.isin( - self._skeleton.vertex_map[vert_mask] + sk_mask = n._skeleton.nodes.node_id.isin( + n._skeleton.vertex_map[vert_mask] ) # Apply mask - self._skeleton.mask(sk_mask) + n._skeleton.mask(sk_mask) # Last but not least: we need to update the vertex map # Track the old map. N.B. we're not adding this to # skeleton._masked_data since the remapping is done by # the MeshNeuron itself! - self._skeleton._vertex_map_unmasked = self._skeleton.vertex_map + n._skeleton._vertex_map_unmasked = n._skeleton.vertex_map # Subset the vertex map to the surviving mesh vertices # N.B. that the node IDs don't change when masking skeletons! - self._skeleton.vertex_map = self._skeleton.vertex_map[vert_mask] + n._skeleton.vertex_map = n._skeleton.vertex_map[vert_mask] # If the skeleton has no vertex map, we have to ditch it and # let it be regenerated when needed else: - self._masked_data["_skeleton"] = self._skeleton - self._skeleton = None # Clear the skeleton + n._masked_data["_skeleton"] = n._skeleton + n._skeleton = None # Clear the skeleton # See if we need to mask any connectors as well - if hasattr(self, "_connectors"): + if hasattr(n, "_connectors"): # Only mask if there is an actual "vertex_ind" or "face_ind" column - cn_mask = None - if "vertex_ind" in self._connectors.columns: - cn_mask = self._connectors.vertex_id.isin(np.where(vert_mask)[0]) - elif "face_ind" in self._connectors.columns: - cn_mask = self._connectors.face_id.isin(np.where(face_mask)[0]) + if "vertex_ind" in n._connectors.columns: + cn_mask = n._connectors.vertex_id.isin(np.where(vert_mask)[0]) + elif "face_ind" in n._connectors.columns: + cn_mask = n._connectors.face_id.isin(np.where(face_mask)[0]) + else: + cn_mask = None if cn_mask is not None: - self._masked_data["_connectors"] = self._connectors - self._connectors = self._connectors.loc[mask] + n._masked_data["_connectors"] = n._connectors + n._connectors = n._connectors.loc[mask] if copy: - self._connectors = self._connectors.copy() + n._connectors = n._connectors.copy() # Check if we need to drop the soma position - if hasattr(self, "soma_pos"): - vid = self.snap(self.soma_pos, to="vertices")[0] + if hasattr(n, "soma_pos"): + vid = n.snap(self.soma_pos, to="vertices")[0] if not vert_mask[vid]: - self._masked_data["_soma_pos"] = self.soma_pos - self.soma_pos = None + n._masked_data["_soma_pos"] = n.soma_pos + n.soma_pos = None - # Clear temporary attributes but keep the skeleton since we already fixed that - self._clear_temp_attr(exclude=["_skeleton"]) + # Clear temporary attributes but keep the skeleton since we already fixed that manually + n._clear_temp_attr(exclude=["_skeleton"]) - return self + return n def unmask(self, reset=True): """Unmask neuron. @@ -631,7 +659,9 @@ def unmask(self, reset=True): # Generate a mesh for the masked-out data: # The mesh prior to masking - pre_mesh = tm.Trimesh(self._masked_data["_vertices"], self._masked_data["_faces"]) + pre_mesh = tm.Trimesh( + self._masked_data["_vertices"], self._masked_data["_faces"] + ) # The vertices and faces that were masked out pre_vertices, pre_faces, vert_map, face_map = morpho.subset.submesh( pre_mesh, faces_index=np.where(~face_mask)[0], return_map=True @@ -661,7 +691,7 @@ def unmask(self, reset=True): # Check if the vertex map is still valid # Note to self: we could do some elaborate checks here to map old to # most likely new vertex / nodes but that's a bit overkill for now. - if hasattr(skeleton, 'vertex_map'): + if hasattr(skeleton, "vertex_map"): if skeleton.vertex_map.shape[0] != self._vertices.shape[0]: skeleton = None elif skeleton.vertex_map.max() >= self._faces.shape[0]: diff --git a/navis/core/skeleton.py b/navis/core/skeleton.py index 41937e5d..62e97302 100644 --- a/navis/core/skeleton.py +++ b/navis/core/skeleton.py @@ -930,6 +930,22 @@ def copy(self, deepcopy: bool = False) -> "TreeNeuron": return x + def view(self) -> "TreeNeuron": + """Create a view of the neuron without copying data. + + Be aware that changes to the view may affect the original neuron! + + """ + no_copy = ["_lock"] + + # Generate new empty neuron + x = self.__class__(None) + + # Override with this neuron's data + x.__dict__.update({k: v for k, v in self.__dict__.items() if k not in no_copy}) + + return x + def get_graph_nx(self) -> nx.DiGraph: """Calculate and return networkX representation of neuron. @@ -962,11 +978,9 @@ def get_igraph(self) -> "igraph.Graph": # type: ignore self._igraph = graph.neuron2igraph(self, raise_not_installed=False) return self._igraph - def mask(self, mask, copy=True): + def mask(self, mask, inplace=False, copy=False) -> "TreeNeuron": """Mask neuron with given mask. - This is always done in-place! - Parameters ---------- mask : np.ndarray @@ -974,10 +988,17 @@ def mask(self, mask, copy=True): - 1D array with boolean values - callable that accepts a neuron and returns a mask - string with column name in nodes table + inplace : bool, optional + Whether to mask the neuron inplace. + copy : bool, optional + Whether to copy data such as the node table after masking. This + is useful if you want to avoid accidentally modifying + the original nodes table. Returns ------- - self + n : TreeNeuron + The masked neuron. See Also -------- @@ -992,45 +1013,49 @@ def mask(self, mask, copy=True): "Neuron already masked. Layering multiple masks is currently not supported, please unmask first." ) + n = self + if not inplace: + n = self.view() + if callable(mask): - mask = mask(self) + mask = mask(n) elif isinstance(mask, str): - mask = self.nodes[mask].values + mask = n.nodes[mask].values mask = np.asarray(mask) if mask.dtype != bool: raise ValueError("Mask must be boolean array.") - elif mask.shape[0] != self.nodes.shape[0]: + elif mask.shape[0] != n.nodes.shape[0]: raise ValueError("Mask must have same length as nodes table.") - self._mask = mask - self._masked_data = {} - self._masked_data["_nodes"] = self.nodes + n._mask = mask + n._masked_data = {} + n._masked_data["_nodes"] = n.nodes # N.B. we're directly setting `._nodes`` to avoid overhead from checks - self._nodes = self._nodes.loc[mask].drop("type", axis=1, errors="ignore") + n._nodes = n._nodes.loc[mask].drop("type", axis=1, errors="ignore") if copy: - self._nodes = self._nodes.copy() + n._nodes = n._nodes.copy() # See if any parent IDs have ceased to exist - missing_parents = ~self._nodes.parent_id.isin(self._nodes.node_id) & ( - self._nodes.parent_id >= 0 + missing_parents = ~n._nodes.parent_id.isin(n._nodes.node_id) & ( + n._nodes.parent_id >= 0 ) if any(missing_parents): - self.nodes.loc[missing_parents, "parent_id"] = -1 + n.nodes.loc[missing_parents, "parent_id"] = -1 - if hasattr(self, "_connectors"): - self._masked_data["_connectors"] = self.connectors - self._connectors = self._connectors.loc[ - self._connectors.node_id.isin(self.nodes.node_id) + if hasattr(n, "_connectors"): + n._masked_data["_connectors"] = n.connectors + n._connectors = n._connectors.loc[ + n._connectors.node_id.isin(n.nodes.node_id) ] if copy: - self._connectors = self._connectors.copy() + n._connectors = n._connectors.copy() - self._clear_temp_attr() + n._clear_temp_attr() - return self + return n def unmask(self, reset=True): """Unmask neuron. diff --git a/navis/morpho/subset.py b/navis/morpho/subset.py index c828f13d..7b0d21fd 100644 --- a/navis/morpho/subset.py +++ b/navis/morpho/subset.py @@ -401,7 +401,7 @@ def submesh(mesh, *, faces_index=None, vertex_index=None, return_map=False): if faces_index is None and vertex_index is None: raise ValueError("Either `faces_index` or `vertex_index` must be provided.") elif faces_index is not None and vertex_index is not None: - raise ValueError("Only one of `faces_index` or `vertex_index` can be provided.") + raise ValueError("Must provide either `faces_index` or `vertex_index`, not both.") # First check if we can return either an empty mesh or the original mesh right away if faces_index is not None: