diff --git a/docs/api.md b/docs/api.md index a42622a8..fb7cfdc6 100644 --- a/docs/api.md +++ b/docs/api.md @@ -47,13 +47,13 @@ learn more! ``TreeNeurons``, ``MeshNeurons``, ``VoxelNeurons`` and ``Dotprops`` are neuron classes. ``NeuronLists`` are containers thereof. -| Class | Description | -|------|------| -| [`navis.TreeNeuron`][] | Skeleton representation of a neuron. | -| [`navis.MeshNeuron`][] | Meshes with vertices and faces. | -| [`navis.VoxelNeuron`][] | 3D images (e.g. from confocal stacks). | -| [`navis.Dotprops`][] | Point cloud + vector representations, used for NBLAST. | -| [`navis.NeuronList`][] | Containers for neurons. | +| Class | Description | +|-------------------------|---------------------------------------------------------| +| [`navis.TreeNeuron`][] | Skeleton representation of a neuron. | +| [`navis.MeshNeuron`][] | Meshes with vertices and faces. | +| [`navis.VoxelNeuron`][] | 3D images (e.g. from confocal stacks). | +| [`navis.Dotprops`][] | Point cloud + vector representations, used for NBLAST. | +| [`navis.NeuronList`][] | Containers for neurons. | ### General Neuron methods @@ -89,6 +89,7 @@ to all neurons: | `Neuron.type` | {{ autosummary("navis.BaseNeuron.type") }} | | `Neuron.soma` | {{ autosummary("navis.BaseNeuron.soma") }} | | `Neuron.bbox` | {{ autosummary("navis.BaseNeuron.bbox") }} | +| `Neuron.is_masked` | {{ autosummary("navis.BaseNeuron.is_masked") }} | !!! note @@ -119,6 +120,8 @@ this neuron type. Note that most of them are simply short-hands for the other | [`TreeNeuron.reroot()`][navis.TreeNeuron.reroot] | {{ autosummary("navis.TreeNeuron.reroot") }} | | [`TreeNeuron.resample()`][navis.TreeNeuron.resample] | {{ autosummary("navis.TreeNeuron.resample") }} | | [`TreeNeuron.snap()`][navis.TreeNeuron.snap] | {{ autosummary("navis.TreeNeuron.snap") }} | +| [`TreeNeuron.mask()`][navis.TreeNeuron.mask] | {{ autosummary("navis.TreeNeuron.mask") }} | +| [`TreeNeuron.unmask()`][navis.TreeNeuron.unmask] | {{ autosummary("navis.TreeNeuron.unmask") }} | In addition, a [`navis.TreeNeuron`][] has a range of different properties: @@ -146,7 +149,6 @@ In addition, a [`navis.TreeNeuron`][] has a range of different properties: | [`TreeNeuron.vertices`][navis.TreeNeuron.vertices] | {{ autosummary("navis.TreeNeuron.vertices") }} | | [`TreeNeuron.volume`][navis.TreeNeuron.volume] | {{ autosummary("navis.TreeNeuron.volume") }} | - #### Skeleton utility functions | Function | Description | @@ -158,7 +160,6 @@ In addition, a [`navis.TreeNeuron`][] has a range of different properties: | [`navis.graph.skeleton_adjacency_matrix()`][navis.graph.skeleton_adjacency_matrix] | {{ autosummary("navis.graph.skeleton_adjacency_matrix") }} | - ### Mesh neurons Properties specific to [`navis.MeshNeuron`][]: @@ -178,6 +179,8 @@ Methods specific to [`navis.MeshNeuron`][]: | [`MeshNeuron.skeletonize()`][navis.MeshNeuron.skeletonize] | {{ autosummary("navis.MeshNeuron.skeletonize") }} | | [`MeshNeuron.snap()`][navis.MeshNeuron.snap] | {{ autosummary("navis.MeshNeuron.snap") }} | | [`MeshNeuron.validate()`][navis.MeshNeuron.validate] | {{ autosummary("navis.MeshNeuron.validate") }} | +| [`MeshNeuron.mask()`][navis.MeshNeuron.mask] | {{ autosummary("navis.MeshNeuron.mask") }} | +| [`MeshNeuron.unmask()`][navis.MeshNeuron.unmask] | {{ autosummary("navis.MeshNeuron.unmask") }} | ### Voxel neurons @@ -215,6 +218,8 @@ These are methods and properties specific to [Dotprops][navis.Dotprops]: | [`Dotprops.alpha`][navis.Dotprops.alpha] | {{ autosummary("navis.Dotprops.alpha") }} | | 
[`Dotprops.to_skeleton()`][navis.Dotprops.to_skeleton] | {{ autosummary("navis.Dotprops.to_skeleton") }} | | [`Dotprops.snap()`][navis.Dotprops.snap] | {{ autosummary("navis.Dotprops.snap") }} | +| [`Dotprops.mask()`][navis.Dotprops.mask] | {{ autosummary("navis.Dotprops.mask") }} | +| [`Dotprops.unmask()`][navis.Dotprops.unmask] | {{ autosummary("navis.Dotprops.unmask") }} | ### Converting between types diff --git a/navis/core/__init__.py b/navis/core/__init__.py index 39b4342e..80973700 100644 --- a/navis/core/__init__.py +++ b/navis/core/__init__.py @@ -18,11 +18,22 @@ from .dotprop import Dotprops from .voxel import VoxelNeuron from .neuronlist import NeuronList +from .masking import NeuronMask from .core_utils import make_dotprops, to_neuron_space, NeuronProcessor from typing import Union NeuronObject = Union[NeuronList, TreeNeuron, BaseNeuron, MeshNeuron] -__all__ = ['Volume', 'Neuron', 'BaseNeuron', 'TreeNeuron', 'MeshNeuron', - 'Dotprops', 'VoxelNeuron', 'NeuronList', 'make_dotprops'] +__all__ = [ + "Volume", + "Neuron", + "BaseNeuron", + "TreeNeuron", + "MeshNeuron", + "NeuronMask", + "Dotprops", + "VoxelNeuron", + "NeuronList", + "make_dotprops", +] diff --git a/navis/core/base.py b/navis/core/base.py index 73bf2bc4..009b0a66 100644 --- a/navis/core/base.py +++ b/navis/core/base.py @@ -46,7 +46,8 @@ def Neuron( - x: Union[nx.DiGraph, str, pd.DataFrame, "TreeNeuron", "MeshNeuron"], **metadata + x: Union[nx.DiGraph, str, pd.DataFrame, "TreeNeuron", "MeshNeuron"], + **metadata, # noqa: F821 ): """Constructor for Neuron objects. Depending on the input, either a `TreeNeuron` or a `MeshNeuron` is returned. @@ -195,6 +196,9 @@ class BaseNeuron(UnitObject): #: Core data table(s) used to calculate hash CORE_DATA = [] + #: Property used to calculate length of neuron + _LENGTH_DATA = None + def __init__(self, **kwargs): # Set a random ID -> may be replaced later self.id = uuid.uuid4() @@ -303,6 +307,14 @@ def __isub__(self, other): """Subtraction with assignment (-=).""" return self.__sub__(other, copy=False) + def __len__(self): + if self._LENGTH_DATA is None: + return None + # Deal with potential empty neurons + if not hasattr(self, self._LENGTH_DATA): + return 0 + return len(getattr(self, self._LENGTH_DATA)) + def _repr_html_(self): frame = self.summary().to_frame() frame.columns = [""] @@ -654,6 +666,7 @@ def copy(self, deepcopy=False) -> "BaseNeuron": def summary(self, add_props=None) -> pd.Series: """Get a summary of this neuron.""" + # Do not remove the list -> otherwise we might change the original! props = list(self.SUMMARY_PROPS) @@ -721,6 +734,87 @@ def plot3d(self, **kwargs): return plot3d(core.NeuronList(self, make_copy=False), **kwargs) + @property + def is_masked(self): + """Test if neuron is masked. + + See Also + -------- + [`navis.BaseNeuron.mask`][] + Mask neuron. + [`navis.BaseNeuron.unmask`][] + Remove mask from neuron. + [`navis.NeuronMask`][] + Context manager for masking neurons. + """ + return hasattr(self, "_masked_data") + + def mask(self, mask): + """Mask neuron.""" + raise NotImplementedError( + f"Masking not implemented for neuron of type {type(self)}." + ) + + def unmask(self): + """Unmask neuron. + + Returns the neuron to its original state before masking. + + Returns + ------- + self + + See Also + -------- + [`Neuron.is_masked`][navis.BaseNeuron.is_masked] + Check if neuron. is masked. + [`Neuron.mask`][navis.BaseNeuron.unmask] + Mask neuron. + [`navis.NeuronMask`][] + Context manager for masking neurons. 
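+
+        Examples
+        --------
+        A minimal sketch using one of the bundled example skeletons. Note that
+        `mask()` itself is implemented by the subclasses (here a `TreeNeuron`)
+        and masking by axon/dendrite labels is just one way to build a boolean
+        mask:
+
+        >>> import navis
+        >>> n = navis.example_neurons(1)
+        >>> _ = navis.split_axon_dendrite(n, label_only=True)
+        >>> _ = n.mask(n.nodes.compartment == 'axon')
+        >>> n.is_masked
+        True
+        >>> _ = n.unmask()
+        >>> n.is_masked
+        False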
+ + """ + if not self.is_masked: + raise ValueError("Neuron is not masked.") + + for k, v in self._masked_data.items(): + if hasattr(self, k): + setattr(self, k, v) + + delattr(self, "_mask") + delattr(self, "_masked_data") + self._clear_temp_attr() + + return self + + def apply_mask(self, inplace=False): + """Apply mask to neuron. + + This will effectively make the mask permanent. + + Parameters + ---------- + inplace : bool + If True will apply mask in-place. If False + will return a copy and the original neuron + will remain masked. + + Returns + ------- + Neuron + Neuron with mask applied. + + """ + if not self.is_masked: + raise ValueError("Neuron is not masked.") + + n = self if inplace else self.copy() + + delattr(n, "_mask") + delattr(n, "_masked_data") + + return n + def map_units( self, units: Union[pint.Unit, str], diff --git a/navis/core/dotprop.py b/navis/core/dotprop.py index 0f51dd1b..d3a69cdc 100644 --- a/navis/core/dotprop.py +++ b/navis/core/dotprop.py @@ -105,11 +105,14 @@ class Dotprops(BaseNeuron): EQ_ATTRIBUTES = ['name', 'n_points', 'k'] #: Temporary attributes that need clearing when neuron data changes - TEMP_ATTR = ['_memory_usage'] + TEMP_ATTR = ['_memory_usage', "_tree"] #: Core data table(s) used to calculate hash _CORE_DATA = ['points', 'vect'] + #: Property used to calculate length of neuron + _LENGTH_DATA = 'points' + def __init__(self, points: np.ndarray, k: int, @@ -230,9 +233,6 @@ def __getstate__(self): return state - def __len__(self): - return len(self.points) - @property def alpha(self): """Alpha value for tangent vectors (optional).""" @@ -539,6 +539,137 @@ def drop_fluff(self, epsilon, keep_size: int = None, n_largest: int = None, inpl if not inplace: return x + def mask(self, mask, copy=True): + """Mask neuron with given mask. + + This is always done in-place! + + Parameters + ---------- + mask : np.ndarray + Mask to apply. Can be: + - 1D array with boolean values + - callable that accepts a neuron and returns a mask + - string with property name + + Returns + ------- + self + The masked neuron. + + See Also + -------- + [`Dotprops.unmask`][navis.Dotprops.unmask] + Remove mask from neuron. + [`Dotprops.is_masked`][navis.Dotprops.is_masked] + Check if neuron is masked. + [`navis.NeuronMask`][] + Context manager for masking neurons. + + """ + if self.is_masked: + raise ValueError( + "Neuron already masked. Layering multiple masks is currently not supported, please unmask first." + ) + + if callable(mask): + mask = mask(self) + elif isinstance(mask, str): + mask = getattr(self, mask) + + mask = np.asarray(mask) + + if mask.dtype != bool: + raise ValueError("Mask must be boolean array.") + elif mask.shape[0] != len(self): + raise ValueError("Mask must have same length as points.") + + self._mask = mask + self._masked_data = {} + self._masked_data['_points'] = self.points + + # Drop soma if masked out + if self.soma is not None: + if isinstance(self.soma, (list, np.ndarray)): + soma_left = self.soma[mask[self.soma]] + self._masked_data['_soma'] = self.soma + + if any(soma_left): + self.soma = soma_left + else: + self.soma = None + elif not mask[self.soma]: + self._masked_data['_soma'] = self.soma + self.soma = None + + # N.B. 
we're directly setting `._nodes`` to avoid overhead from checks + for att in ("_points", "_vect", "_alpha"): + if hasattr(self, att): + self._masked_data[att] = getattr(self, att) + setattr(self, att, getattr(self, att)[mask]) + + if copy: + setattr(self, att, getattr(self, att).copy()) + + if hasattr(self, "_connectors") and "point_ix" in self._connectors.columns: + self._masked_data['connectors'] = self.connectors + self._connectors = self._connectors.loc[ + self.connectors.point_ix.isin(np.arange(len(mask))[mask]) + ] + if copy: + self._connectors = self._connectors.copy() + + self._clear_temp_attr() + + return self + + def unmask(self, reset=True): + """Unmask neuron. + + Returns the neuron to its original state before masking. + + Parameters + ---------- + reset : bool + Whether to reset the neuron to its original state before masking. + If False, edits made to the neuron after masking will be kept. + + Returns + ------- + self + + See Also + -------- + [`Dotprops.is_masked`][navis.Dotprops.is_masked] + Check if neuron is masked. + [`Dotprops.mask`][navis.Dotprops.mask] + Mask neuron. + [`navis.NeuronMask`][] + Context manager for masking neurons. + + """ + if not self.is_masked: + raise ValueError("Neuron is not masked.") + + if reset: + # Unmask and reset to original state + super().unmask() + return self + + mask = self._mask + for k, v in self._masked_data.items(): + # Combine with current data + if hasattr(self, k): + v = np.concatenate((v[~mask], getattr(self, k)), axis=0) + setattr(self, k, v) + + del self._mask + del self._masked_data + + self._clear_temp_attr() + + return self + def recalculate_tangents(self, k: int, inplace=False): """Recalculate tangent vectors and alpha with a new `k`. diff --git a/navis/core/masking.py b/navis/core/masking.py new file mode 100644 index 00000000..345de824 --- /dev/null +++ b/navis/core/masking.py @@ -0,0 +1,204 @@ +# This script is part of navis (http://www.github.com/navis-org/navis). +# Copyright (C) 2018 Philipp Schlegel +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. + +import numpy as np +import pandas as pd + +from .neuronlist import NeuronList +from .skeleton import TreeNeuron +from .dotprop import Dotprops +from .voxel import VoxelNeuron +from .mesh import MeshNeuron + +from .. import utils + +__all__ = ["NeuronMask"] + +# mode = "r"\w"? +class NeuronMask: + """Mask neuron(s) by a specific property. + + Parameters + ---------- + x : Neuron/List + Neuron(s) to mask. + mask : str | array | callable | list | dict + The mask to apply: + - str: The name of the property to mask by + - array: boolean mask + - callable: A function that takes a neuron as input + and returns a boolean mask + - list: a list of the above + - dict: a dictionary mapping neuron IDs to one of the + above + copy_data : bool + Whether to copy the neuron data (e.g. node table for skeletons) + when masking. Setting this to False will may some time and + memory but may lead to e.g. pandas setting-on-copy warnings + if the data is modified. Only set to `True` if you know your + code won't modify the data. 
+ reset_neurons : bool + If True, reset the neurons to their original state after the + context manager exits. If False, will try to incorporate any + changes made to the masked neurons. Note that this may not + work for destructive operations. + validate_mask : bool + If True, validate `mask` against the neurons before setting it. + This is recommended but can come with an overhead (in particular + if `mask` is a callable). + + Examples + -------- + >>> import navis + >>> # Grab a few skeletons + >>> nl = navis.example_neurons(3) + >>> # Label axon and dendrites + >>> navis.split_axon_dendrite(nl, label_only=True) + >>> # Mask by axon + >>> with navis.NeuronMask(nl, lambda x: x.nodes.compartment == 'axon'): + ... print("Axon cable length:", nl.cable_length * nl[0].units) + Axon cable length: [363469.75 411147.1875 390231.8125] nanometer + >>> # Mask by dendrite + >>> with navis.NeuronMask(nl, lambda x: x.nodes.compartment == 'dendrite'): + ... print("Dendrite cable length:", nl.cable_length * nl[0].units) + Dendrite cable length: [1410770.0 1612187.25 1510453.875] nanometer + + See Also + -------- + [`navis.BaseNeuron.is_masked`][] + Check if a neuron is masked. Property exists for all neuron types. + [`navis.BaseNeuron.mask`][] + Mask a neuron. Implementation details depend on the neuron type. + [`navis.BaseNeuron.unmask`][] + Unmask a neuron. Implementation details depend on the neuron type. + + """ + + def __init__(self, x, mask, reset_neurons=True, copy_data=True, validate_mask=True): + self.neurons = x + + if validate_mask: + self.mask = mask + else: + self._mask = mask + + self.reset = reset_neurons + self.copy = copy_data + + @property + def neurons(self): + return self._neurons + + @neurons.setter + def neurons(self, value): + self._neurons = NeuronList(value) + + if any(n.is_masked for n in self._neurons): + raise MaskingError("At least some neuron(s) are already masked") + + @property + def mask(self): + return self._mask + + @mask.setter + def mask(self, mask): + # Validate the mask + if isinstance(mask, str): + for n in self.neurons: + if isinstance(n, TreeNeuron): + if mask not in n.nodes.columns: + raise MaskingError(f"Neuron does not have '{mask}' column") + elif not hasattr(n, mask): + raise MaskingError(f"Neuron does not have '{mask}' attribute") + elif isinstance(mask, (list, np.ndarray, pd.Series)): + if len(self.neurons) == 1 and len(mask) != 1: + # If we only have one neuron, we can accept a single mask + # but we still want to wrap it in a list for consistency + mask = [np.asarray(mask)] + + if len(mask) != len(self.neurons): + raise MaskingError("Number of masks does not match number of neurons") + + # Validate each mask + for m, n in zip(mask, self.neurons): + validate_mask_length(m, n) + elif isinstance(mask, dict): + for n in self.neurons: + if n.id not in mask: + raise MaskingError(f"Neuron {n.id} not in mask dictionary") + validate_mask_length(mask[n.id], n) + elif callable(mask): + # If this is a function, try calling it on the first neuron + test = mask(self.neurons[0]) + if not isinstance(test, (pd.Series, np.ndarray)) or test.dtype != bool: + raise MaskingError("Callable mask must return a boolean array") + validate_mask_length(test, self.neurons[0]) + else: + raise MaskingError(f"Unexpected mask type: {type(mask)}") + + self._mask = mask + + def __enter__(self): + for i, n in enumerate(self.neurons): + if callable(self.mask): + mask = self.mask(n) + elif isinstance(self.mask, dict): + mask = self.mask[n.id] + elif isinstance(self.mask, str): + mask = 
self.mask + else: + mask = self.mask[i] + + n.mask(mask, copy=self.copy) + + return self + + def __exit__(self, *args): + for i, n in enumerate(self.neurons): + n.unmask(reset=self.reset) + + +def validate_mask_length(mask, neuron): + """Validate mask length for a given neuron.""" + if callable(mask): + mask = mask(neuron) + elif isinstance(mask, str): + if isinstance(neuron, TreeNeuron): + mask = neuron.nodes[mask] + else: + mask = getattr(neuron, mask) + + if isinstance(mask, list): + mask = np.asarray(mask) + + if not isinstance(mask, (np.ndarray, pd.Series)) or mask.dtype != bool: + raise MaskingError("Mask must be a boolean array") + + if isinstance(neuron, TreeNeuron): + if len(mask) != len(neuron.nodes): + raise MaskingError("Mask length does not match number of nodes") + elif isinstance(neuron, VoxelNeuron): + if neuron._base_data_type == "grid" and mask.shape != neuron.shape: + raise MaskingError("Mask shape does not match voxel shape") + elif len(neuron.voxels) != len(mask): + raise MaskingError("Mask length does not match number of voxels") + elif isinstance(neuron, Dotprops): + if len(mask) != len(neuron.points): + raise MaskingError("Mask length does not match number of points") + elif isinstance(neuron, MeshNeuron): + if len(mask) != len(neuron.vertices): + raise MaskingError("Mask length does not match number of vertices") + + +class MaskingError(Exception): + pass \ No newline at end of file diff --git a/navis/core/mesh.py b/navis/core/mesh.py index 233076e0..d6f78db0 100644 --- a/navis/core/mesh.py +++ b/navis/core/mesh.py @@ -26,7 +26,7 @@ from typing import Union, Optional -from .. import utils, config, meshes, conversion, graph +from .. import utils, config, meshes, conversion, graph, morpho from .base import BaseNeuron from .neuronlist import NeuronList from .skeleton import TreeNeuron @@ -39,7 +39,7 @@ xxhash = None -__all__ = ['MeshNeuron'] +__all__ = ["MeshNeuron"] # Set up logging logger = config.get_logger(__name__) @@ -89,24 +89,28 @@ class MeshNeuron(BaseNeuron): soma: Optional[Union[list, np.ndarray]] #: Attributes used for neuron summary - SUMMARY_PROPS = ['type', 'name', 'units', 'n_vertices', 'n_faces'] + SUMMARY_PROPS = ["type", "name", "units", "n_vertices", "n_faces"] #: Attributes to be used when comparing two neurons. 
- EQ_ATTRIBUTES = ['name', 'n_vertices', 'n_faces'] + EQ_ATTRIBUTES = ["name", "n_vertices", "n_faces"] #: Temporary attributes that need clearing when neuron data changes - TEMP_ATTR = ['_memory_usage', '_trimesh', '_skeleton', '_igraph', '_graph_nx'] + TEMP_ATTR = ["_memory_usage", "_trimesh", "_skeleton", "_igraph", "_graph_nx"] #: Core data table(s) used to calculate hash - CORE_DATA = ['vertices', 'faces'] - - def __init__(self, - x, - units: Union[pint.Unit, str] = None, - process: bool = True, - validate: bool = False, - **metadata - ): + CORE_DATA = ["vertices", "faces"] + + #: Property used to calculate length of neuron + _LENGTH_DATA = "vertices" + + def __init__( + self, + x, + units: Union[pint.Unit, str] = None, + process: bool = True, + validate: bool = False, + **metadata, + ): """Initialize Mesh Neuron.""" super().__init__() @@ -117,12 +121,12 @@ def __init__(self, if isinstance(x, MeshNeuron): self.__dict__.update(x.copy().__dict__) self.vertices, self.faces = x.vertices, x.faces - elif hasattr(x, 'faces') and hasattr(x, 'vertices'): + elif hasattr(x, "faces") and hasattr(x, "vertices"): self.vertices, self.faces = x.vertices, x.faces elif isinstance(x, dict): - if 'faces' not in x or 'vertices' not in x: + if "faces" not in x or "vertices" not in x: raise ValueError('Dictionary must contain "vertices" and "faces"') - self.vertices, self.faces = x['vertices'], x['faces'] + self.vertices, self.faces = x["vertices"], x["faces"] elif isinstance(x, str) and os.path.isfile(x): m = tm.load(x) self.vertices, self.faces = m.vertices, m.faces @@ -134,10 +138,12 @@ def __init__(self, self._skeleton = TreeNeuron(x) elif isinstance(x, tuple): if len(x) != 2 or any([not isinstance(v, np.ndarray) for v in x]): - raise TypeError('Expect tuple to be two arrays: (vertices, faces)') + raise TypeError("Expect tuple to be two arrays: (vertices, faces)") self.vertices, self.faces = x[0], x[1] else: - raise utils.ConstructionError(f'Unable to construct MeshNeuron from "{type(x)}"') + raise utils.ConstructionError( + f'Unable to construct MeshNeuron from "{type(x)}"' + ) for k, v in metadata.items(): try: @@ -147,9 +153,9 @@ def __init__(self, if process and self.vertices.shape[0]: # For some reason we can't do self._trimesh at this stage - _trimesh = tm.Trimesh(self.vertices, self.faces, - process=process, - validate=validate) + _trimesh = tm.Trimesh( + self.vertices, self.faces, process=process, validate=validate + ) self.vertices = _trimesh.vertices self.faces = _trimesh.faces @@ -174,8 +180,8 @@ def __getstate__(self): state = {k: v for k, v in self.__dict__.items() if not callable(v)} # We don't need the trimesh object - if '_trimesh' in state: - _ = state.pop('_trimesh') + if "_trimesh" in state: + _ = state.pop("_trimesh") return state @@ -188,9 +194,9 @@ def __truediv__(self, other, copy=True): if isinstance(other, numbers.Number) or utils.is_iterable(other): # If a number, consider this an offset for coordinates n = self.copy() if copy else self - _ = np.divide(n.vertices, other, out=n.vertices, casting='unsafe') + _ = np.divide(n.vertices, other, out=n.vertices, casting="unsafe") if n.has_connectors: - n.connectors.loc[:, ['x', 'y', 'z']] /= other + n.connectors.loc[:, ["x", "y", "z"]] /= other # Convert units # Note: .to_compact() throws a RuntimeWarning and returns unchanged @@ -209,9 +215,9 @@ def __mul__(self, other, copy=True): if isinstance(other, numbers.Number) or utils.is_iterable(other): # If a number, consider this an offset for coordinates n = self.copy() if copy else self 
- _ = np.multiply(n.vertices, other, out=n.vertices, casting='unsafe') + _ = np.multiply(n.vertices, other, out=n.vertices, casting="unsafe") if n.has_connectors: - n.connectors.loc[:, ['x', 'y', 'z']] *= other + n.connectors.loc[:, ["x", "y", "z"]] *= other # Convert units # Note: .to_compact() throws a RuntimeWarning and returns unchanged @@ -229,9 +235,9 @@ def __add__(self, other, copy=True): """Implement addition for coordinates (vertices, connectors).""" if isinstance(other, numbers.Number) or utils.is_iterable(other): n = self.copy() if copy else self - _ = np.add(n.vertices, other, out=n.vertices, casting='unsafe') + _ = np.add(n.vertices, other, out=n.vertices, casting="unsafe") if n.has_connectors: - n.connectors.loc[:, ['x', 'y', 'z']] += other + n.connectors.loc[:, ["x", "y", "z"]] += other self._clear_temp_attr() @@ -245,9 +251,9 @@ def __sub__(self, other, copy=True): """Implement subtraction for coordinates (vertices, connectors).""" if isinstance(other, numbers.Number) or utils.is_iterable(other): n = self.copy() if copy else self - _ = np.subtract(n.vertices, other, out=n.vertices, casting='unsafe') + _ = np.subtract(n.vertices, other, out=n.vertices, casting="unsafe") if n.has_connectors: - n.connectors.loc[:, ['x', 'y', 'z']] -= other + n.connectors.loc[:, ["x", "y", "z"]] -= other self._clear_temp_attr() @@ -261,8 +267,8 @@ def bbox(self) -> np.ndarray: mx = np.max(self.vertices, axis=0) if self.has_connectors: - cn_mn = np.min(self.connectors[['x', 'y', 'z']].values, axis=0) - cn_mx = np.max(self.connectors[['x', 'y', 'z']].values, axis=0) + cn_mn = np.min(self.connectors[["x", "y", "z"]].values, axis=0) + cn_mx = np.max(self.connectors[["x", "y", "z"]].values, axis=0) mn = np.min(np.vstack((mn, cn_mn)), axis=0) mx = np.max(np.vstack((mx, cn_mx)), axis=0) @@ -279,7 +285,7 @@ def vertices(self, verts): if not isinstance(verts, np.ndarray): raise TypeError(f'Vertices must be numpy array, got "{type(verts)}"') if verts.ndim != 2: - raise ValueError('Vertices must be 2-dimensional array') + raise ValueError("Vertices must be 2-dimensional array") self._vertices = verts self._clear_temp_attr() @@ -293,16 +299,16 @@ def faces(self, faces): if not isinstance(faces, np.ndarray): raise TypeError(f'Faces must be numpy array, got "{type(faces)}"') if faces.ndim != 2: - raise ValueError('Faces must be 2-dimensional array') + raise ValueError("Faces must be 2-dimensional array") self._faces = faces self._clear_temp_attr() @property @temp_property - def igraph(self) -> 'igraph.Graph': + def igraph(self) -> "igraph.Graph": """iGraph representation of the vertex connectivity.""" # If igraph does not exist, create and return - if not hasattr(self, '_igraph'): + if not hasattr(self, "_igraph"): # This also sets the attribute self._igraph = graph.neuron2igraph(self, raise_not_installed=False) return self._igraph @@ -312,7 +318,7 @@ def igraph(self) -> 'igraph.Graph': def graph(self) -> nx.DiGraph: """Networkx Graph representation of the vertex connectivity.""" # If graph does not exist, create and return - if not hasattr(self, '_graph_nx'): + if not hasattr(self, "_graph_nx"): # This also sets the attribute self._graph_nx = graph.neuron2nx(self) return self._graph_nx @@ -335,13 +341,13 @@ def volume(self) -> float: @property @temp_property - def skeleton(self) -> 'TreeNeuron': + def skeleton(self) -> "TreeNeuron": """Skeleton representation of this neuron. Uses [`navis.conversion.mesh2skeleton`][]. 
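+        Skeletons generated this way carry a `.vertex_map` attribute that maps
+        each mesh vertex to a skeleton node ID (see
+        [`MeshNeuron.skeletonize`][navis.MeshNeuron.skeletonize]).
+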
""" - if not hasattr(self, '_skeleton'): + if not hasattr(self, "_skeleton"): self._skeleton = self.skeletonize() return self._skeleton @@ -357,7 +363,9 @@ def skeleton(self, s): @property def soma(self): """Not implemented for MeshNeurons - use `.soma_pos`.""" - raise AttributeError("MeshNeurons have a soma position (`.soma_pos`), not a soma.") + raise AttributeError( + "MeshNeurons have a soma position (`.soma_pos`), not a soma." + ) @property def soma_pos(self): @@ -365,7 +373,7 @@ def soma_pos(self): Returns `None` if no soma. """ - return getattr(self, '_soma_pos', None) + return getattr(self, "_soma_pos", None) @soma_pos.setter def soma_pos(self, value): @@ -377,38 +385,295 @@ def soma_pos(self, value): try: value = np.asarray(value).astype(np.float64).reshape(3) except BaseException: - raise ValueError(f'Unable to convert soma position "{value}" ' - f'to numeric (3, ) numpy array.') + raise ValueError( + f'Unable to convert soma position "{value}" ' + f"to numeric (3, ) numpy array." + ) self._soma_pos = value @property def type(self) -> str: """Neuron type.""" - return 'navis.MeshNeuron' + return "navis.MeshNeuron" @property @temp_property def trimesh(self): """Trimesh representation of the neuron.""" - if not getattr(self, '_trimesh', None): - self._trimesh = tm.Trimesh(vertices=self._vertices, - faces=self._faces, - process=False) + if not getattr(self, "_trimesh", None): + if hasattr(self, "extra_edges"): + # Only use TrimeshPlus if we actually need it + # to avoid unnecessarily breaking stuff elsewhere + self._trimesh = tm.Trimesh( + vertices=self._vertices, faces=self._faces, process=False + ) + self._trimesh.extra_edges = self.extra_edges + else: + self._trimesh = tm.Trimesh( + vertices=self._vertices, faces=self._faces, process=False + ) return self._trimesh - def copy(self) -> 'MeshNeuron': + def copy(self) -> "MeshNeuron": """Return a copy of the neuron.""" - no_copy = ['_lock'] + no_copy = ["_lock"] # Generate new neuron x = self.__class__(None) # Override with this neuron's data - x.__dict__.update({k: copy.copy(v) for k, v in self.__dict__.items() if k not in no_copy}) + x.__dict__.update( + {k: copy.copy(v) for k, v in self.__dict__.items() if k not in no_copy} + ) return x - def snap(self, locs, to='vertices'): + def mask(self, mask, copy=True): + """Mask neuron with given mask. + + This is always done in-place! + + Parameters + ---------- + mask : np.ndarray + Mask to apply. Can be: + - 1D array with boolean values + - string with property name + - callable that accepts a neuron and returns a valid mask + The mask can be either for vertices or faces but will ultimately be + used to mask out faces. Vertices not participating in any face + will be removed regardless of the mask. + copy : bool + Whether to copy mask a copy of the data. Only applies for connectors. + + Returns + ------- + self + + See Also + -------- + [`MeshNeuron.is_masked`][navis.MeshNeuron.is_masked] + Returns True if neuron is masked. + [`MeshNeuron.unmask`][navis.MeshNeuron.unmask] + Remove mask from neuron. + [`navis.NeuronMask`][] + Context manager for masking neurons. + + """ + if self.is_masked: + raise ValueError( + "Neuron already masked! Layering multiple masks is currently not supported. " + "Please either apply the existing mask or unmask first." 
+ ) + + if callable(mask): + mask = mask(self) + elif isinstance(mask, str): + mask = getattr(self, mask) + + mask = np.asarray(mask) + + # Some checks + if mask.dtype != bool: + raise ValueError("Mask must be boolean array.") + elif len(mask) not in (self.vertices.shape[0], self.faces.shape[0]): + raise ValueError("Mask length does not match number of vertices or faces.") + + # Transate vertex mask to face mask + if mask.shape[0] == self.vertices.shape[0]: + vert_mask = mask + face_mask = np.all(mask[self.faces], axis=1) + else: + face_mask = mask + vert_mask = np.zeros(self.vertices.shape[0], dtype=bool) + vert_mask[np.unique(self.faces[face_mask])] = True + + # Apply mask + verts_new, faces_new, vert_map, face_map = morpho.subset.submesh( + self, vertex_index=np.where(vert_mask)[0], return_map=True + ) + + # The above will have likely dropped some vertices - we need to update the vertex mask + vert_mask = np.zeros(self.vertices.shape[0], dtype=bool) + vert_mask[np.where(vert_map != -1)[0]] = True + + # Track mask, vertices and faces before masking + self._mask = face_mask # mask is always the face mask + self._masked_data = {} + self._masked_data["_vertices"] = self._vertices + self._masked_data["_faces"] = self._faces + + # Update vertices and faces + self._vertices = verts_new + self._faces = faces_new + + # See if we can mask the mesh's skeleton as well + if hasattr(self, "_skeleton"): + # If the skeleton has a vertex map, we can use it to mask the skeleton + if hasattr(self._skeleton, "vertex_map"): + # Generate a mask for the skeleton + # (keep in mind vertex_map are node IDs, not indices) + sk_mask = self._skeleton.nodes.node_id.isin( + self._skeleton.vertex_map[vert_mask] + ) + + # Apply mask + self._skeleton.mask(sk_mask) + + # Last but not least: we need to update the vertex map + # Track the old map. N.B. we're not adding this to + # skeleton._masked_data since the remapping is done by + # the MeshNeuron itself! + self._skeleton._vertex_map_unmasked = self._skeleton.vertex_map + + # Subset the vertex map to the surviving mesh vertices + # N.B. that the node IDs don't change when masking skeletons! + self._skeleton.vertex_map = self._skeleton.vertex_map[vert_mask] + # If the skeleton has no vertex map, we have to ditch it and + # let it be regenerated when needed + else: + self._masked_data["_skeleton"] = self._skeleton + self._skeleton = None # Clear the skeleton + + # See if we need to mask any connectors as well + if hasattr(self, "_connectors"): + # Only mask if there is an actual "vertex_ind" or "face_ind" column + cn_mask = None + if "vertex_ind" in self._connectors.columns: + cn_mask = self._connectors.vertex_id.isin(np.where(vert_mask)[0]) + elif "face_ind" in self._connectors.columns: + cn_mask = self._connectors.face_id.isin(np.where(face_mask)[0]) + + if cn_mask is not None: + self._masked_data["_connectors"] = self._connectors + self._connectors = self._connectors.loc[mask] + if copy: + self._connectors = self._connectors.copy() + + # Check if we need to drop the soma position + if hasattr(self, "soma_pos"): + vid = self.snap(self.soma_pos, to="vertices")[0] + if not vert_mask[vid]: + self._masked_data["_soma_pos"] = self.soma_pos + self.soma_pos = None + + # Clear temporary attributes but keep the skeleton since we already fixed that + self._clear_temp_attr(exclude=["_skeleton"]) + + return self + + def unmask(self, reset=True): + """Unmask neuron. + + Returns the neuron to its original state before masking. 
+ + Parameters + ---------- + reset : bool + Whether to reset the neuron to its original state before masking. + If False, edits made to the neuron after masking will be kept. + + Returns + ------- + self + + See Also + -------- + [`MeshNeuron.is_masked`][navis.MeshNeuron.is_masked] + Returns True if neuron is masked. + [`MeshNeuron.mask`][navis.MeshNeuron.mask] + Mask neuron. + [`navis.NeuronMask`][] + Context manager for masking neurons. + + """ + if not self.is_masked: + raise ValueError("Neuron is not masked.") + + # First fix the skeleton (if it exists) + skeleton = getattr(self, "_skeleton", None) + if skeleton is not None: + # If the skeleton is not masked, it was created after the masking + # - in which case we have to throw it away because we can't recover + # the full neuron state. + if not skeleton.is_masked: + skeleton = None + else: + # Unmask the skeleton as well + skeleton.unmask(reset=reset) + + # Manually restore the vertex map + # N.B. that any destructive action (e.g. twig pruning) may have + # removed nodes from the skeleton. If that's the case, we can't + # restore the vertex map and have to ditch it. + if hasattr(skeleton, "_vertex_map_unmasked"): + skeleton.vertex_map = skeleton._vertex_map_unmasked + + # Important note: currently the skeleton gets ditched whenever the MeshNeuron + # is stale. That's mainly because (a) functions modify the mesh but not + # the vertex map and (b) re-generating the skeleton is usually cheap. + # In the long run, we need to make sure the skeleton is always in sync + # and not cleared unless that's explicitly requested. + # I'm thinking something like a MeshNeuron.sync_skeleton() method that + # can either sync the skeleton with the mesh or vice versa. + + if reset: + # Unmask and reset (this will clear temporary attributes including the skeleton) + super().unmask() + if skeleton is not None: + self.skeleton = skeleton + return self + + # Regenerate the vertex mask from the stored face mask + face_mask = self._mask + vert_mask = np.zeros(len(self._masked_data["_vertices"]), dtype=bool) + vert_mask[np.unique(self._masked_data["_faces"][face_mask])] = True + + # Generate a mesh for the masked-out data: + # The mesh prior to masking + pre_mesh = tm.Trimesh(self._masked_data["_vertices"], self._masked_data["_faces"]) + # The vertices and faces that were masked out + pre_vertices, pre_faces, vert_map, face_map = morpho.subset.submesh( + pre_mesh, faces_index=np.where(~face_mask)[0], return_map=True + ) + + # Combine the two + comb = tm.util.concatenate( + [tm.Trimesh(self.vertices, self.faces), tm.Trimesh(pre_vertices, pre_faces)] + ) + + # Drop duplicate faces + comb.update_faces(comb.unique_faces()) + + # Merge vertices that are exactly the same + comb.merge_vertices(digis=6) + + # Update the neuron + self._vertices, self._faces = np.asarray(comb.vertices), np.asarray(comb.faces) + + del self._mask + del self._masked_data + + self._clear_temp_attr() + + # Check if we can re-use the skeleton + if skeleton is not None: + # Check if the vertex map is still valid + # Note to self: we could do some elaborate checks here to map old to + # most likely new vertex / nodes but that's a bit overkill for now. 
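+            # For now we only require the vertex map to still line up with the
+            # re-merged mesh; if it does not, we drop the skeleton and let it
+            # be regenerated lazily on the next `.skeleton` access.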
+ if hasattr(skeleton, 'vertex_map'): + if skeleton.vertex_map.shape[0] != self._vertices.shape[0]: + skeleton = None + elif skeleton.vertex_map.max() >= self._faces.shape[0]: + skeleton = None + + # If we still have a skeleton at this point, we can re-use it + if skeleton is not None: + self._skeleton = skeleton + + return self + + def snap(self, locs, to="vertices"): """Snap xyz location(s) to closest vertex or synapse. Parameters @@ -436,15 +701,16 @@ def snap(self, locs, to='vertices'): """ locs = np.asarray(locs).astype(self.vertices.dtype) - is_single = (locs.ndim == 1 and len(locs) == 3) - is_multi = (locs.ndim == 2 and locs.shape[1] == 3) + is_single = locs.ndim == 1 and len(locs) == 3 + is_multi = locs.ndim == 2 and locs.shape[1] == 3 if not is_single and not is_multi: - raise ValueError('Expected a single (x, y, z) location or a ' - '(N, 3) array of multiple locations') + raise ValueError( + "Expected a single (x, y, z) location or a " + "(N, 3) array of multiple locations" + ) - if to not in ('vertices', 'vertex', 'connectors', 'connectors'): - raise ValueError('`to` must be "vertices" or "connectors", ' - f'got {to}') + if to not in ("vertices", "vertex", "connectors", "connectors"): + raise ValueError('`to` must be "vertices" or "connectors", ' f"got {to}") # Generate tree tree = scipy.spatial.cKDTree(data=self.vertices) @@ -454,7 +720,9 @@ def snap(self, locs, to='vertices'): return ix, dist - def skeletonize(self, method='wavefront', heal=True, inv_dist=None, **kwargs) -> 'TreeNeuron': + def skeletonize( + self, method="wavefront", heal=True, inv_dist=None, **kwargs + ) -> "TreeNeuron": """Skeletonize mesh. See [`navis.conversion.mesh2skeleton`][] for details. @@ -477,10 +745,13 @@ def skeletonize(self, method='wavefront', heal=True, inv_dist=None, **kwargs) -> Returns ------- skeleton : navis.TreeNeuron + Has a `.vertex_map` attribute that maps each vertex in the + input mesh to a skeleton node ID. """ - return conversion.mesh2skeleton(self, method=method, heal=heal, - inv_dist=inv_dist, **kwargs) + return conversion.mesh2skeleton( + self, method=method, heal=heal, inv_dist=inv_dist, **kwargs + ) def validate(self, inplace=False): """Use trimesh to try and fix some common mesh issues. diff --git a/navis/core/skeleton.py b/navis/core/skeleton.py index 5bcab701..892e7c88 100644 --- a/navis/core/skeleton.py +++ b/navis/core/skeleton.py @@ -25,7 +25,7 @@ from io import BufferedIOBase -from typing import Union, Callable, List, Sequence, Optional, Dict, overload +from typing import Union, Callable, List, Sequence, Optional, Dict from typing_extensions import Literal from .. 
import graph, morpho, utils, config, core, sampling, intersection @@ -39,7 +39,7 @@ except ModuleNotFoundError: xxhash = None -__all__ = ['TreeNeuron'] +__all__ = ["TreeNeuron"] # Set up logging logger = config.get_logger(__name__) @@ -52,15 +52,17 @@ def requires_nodes(func): """Return `None` if neuron has no nodes.""" + @functools.wraps(func) def wrapper(*args, **kwargs): self = args[0] # Return 0 - if isinstance(self.nodes, str) and self.nodes == 'NA': - return 'NA' + if isinstance(self.nodes, str) and self.nodes == "NA": + return "NA" if not isinstance(self.nodes, pd.DataFrame): return None return func(*args, **kwargs) + return wrapper @@ -95,8 +97,8 @@ class TreeNeuron(BaseNeuron): nodes: pd.DataFrame - graph: 'nx.DiGraph' - igraph: 'igraph.Graph' # type: ignore # doesn't know iGraph + graph: "nx.DiGraph" + igraph: "igraph.Graph" # type: ignore # doesn't know iGraph n_branches: int n_leafs: int @@ -117,37 +119,63 @@ class TreeNeuron(BaseNeuron): soma_detection_label: Union[float, int, str] = 1 #: Soma radius (e.g. for plotting). If string, must be column in nodes #: table. Default = 'radius'. - soma_radius: Union[float, int, str] = 'radius' + soma_radius: Union[float, int, str] = "radius" # Set default function for soma finding. Default = [`navis.morpho.find_soma`][] - _soma: Union[Callable[['TreeNeuron'], Sequence[int]], int] = morpho.find_soma + _soma: Union[Callable[["TreeNeuron"], Sequence[int]], int] = morpho.find_soma tags: Optional[Dict[str, List[int]]] = None #: Attributes to be used when comparing two neurons. - EQ_ATTRIBUTES = ['n_nodes', 'n_connectors', 'soma', 'root', - 'n_branches', 'n_leafs', 'cable_length', 'name'] + EQ_ATTRIBUTES = [ + "n_nodes", + "n_connectors", + "soma", + "root", + "n_branches", + "n_leafs", + "cable_length", + "name", + ] #: Temporary attributes that need to be regenerated when data changes. 
- TEMP_ATTR = ['_igraph', '_graph_nx', '_segments', '_small_segments', - '_geodesic_matrix', 'centrality_method', '_simple', - '_cable_length', '_memory_usage', '_adjacency_matrix'] + TEMP_ATTR = [ + "_igraph", + "_graph_nx", + "_segments", + "_small_segments", + "_geodesic_matrix", + "centrality_method", + "_simple", + "_cable_length", + "_memory_usage", + "_adjacency_matrix", + ] #: Attributes used for neuron summary - SUMMARY_PROPS = ['type', 'name', 'n_nodes', 'n_connectors', 'n_branches', - 'n_leafs', 'cable_length', 'soma', 'units'] + SUMMARY_PROPS = [ + "type", + "name", + "n_nodes", + "n_connectors", + "n_branches", + "n_leafs", + "cable_length", + "soma", + "units", + ] #: Core data table(s) used to calculate hash - CORE_DATA = ['nodes:node_id,parent_id,x,y,z'] - - def __init__(self, - x: Union[pd.DataFrame, - BufferedIOBase, - str, - 'TreeNeuron', - nx.DiGraph], - units: Union[pint.Unit, str] = None, - **metadata - ): + CORE_DATA = ["nodes:node_id,parent_id,x,y,z"] + + #: Property used to calculate length of neuron + _LENGTH_DATA = "nodes" + + def __init__( + self, + x: Union[pd.DataFrame, BufferedIOBase, str, "TreeNeuron", nx.DiGraph], + units: Union[pint.Unit, str] = None, + **metadata, + ): """Initialize Skeleton Neuron.""" super().__init__() @@ -157,10 +185,12 @@ def __init__(self, if isinstance(x, pd.DataFrame): self.nodes = x elif isinstance(x, pd.Series): - if not hasattr(x, 'nodes'): - raise ValueError('pandas.Series must have `nodes` entry.') + if not hasattr(x, "nodes"): + raise ValueError("pandas.Series must have `nodes` entry.") elif not isinstance(x.nodes, pd.DataFrame): - raise TypeError(f'Nodes must be pandas DataFrame, got "{type(x.nodes)}"') + raise TypeError( + f'Nodes must be pandas DataFrame, got "{type(x.nodes)}"' + ) self.nodes = x.nodes metadata.update(x.to_dict()) elif isinstance(x, nx.Graph): @@ -182,13 +212,15 @@ def __init__(self, elif isinstance(x, tuple): # Tuple of vertices and edges if len(x) != 2: - raise ValueError('Tuple must have 2 elements: vertices and edges.') + raise ValueError("Tuple must have 2 elements: vertices and edges.") self.nodes = graph.edges2neuron(edges=x[1], vertices=x[0]).nodes elif isinstance(x, type(None)): # This is a essentially an empty neuron pass else: - raise utils.ConstructionError(f'Unable to construct TreeNeuron from "{type(x)}"') + raise utils.ConstructionError( + f'Unable to construct TreeNeuron from "{type(x)}"' + ) for k, v in metadata.items(): try: @@ -218,21 +250,23 @@ def __truediv__(self, other, copy=True): if len(set(other)) == 1: other == other[0] elif len(other) != 4: - raise ValueError('Division by list/array requires 4 ' - 'divisors for x/y/z and radius - ' - f'got {len(other)}') + raise ValueError( + "Division by list/array requires 4 " + "divisors for x/y/z and radius - " + f"got {len(other)}" + ) # If a number, consider this an offset for coordinates n = self.copy() if copy else self - n.nodes[['x', 'y', 'z', 'radius']] /= other + n.nodes[["x", "y", "z", "radius"]] /= other # At this point we can ditch any 4th unit if utils.is_iterable(other): other = other[:3] if n.has_connectors: - n.connectors[['x', 'y', 'z']] /= other + n.connectors[["x", "y", "z"]] /= other - if hasattr(n, 'soma_radius'): + if hasattr(n, "soma_radius"): if isinstance(n.soma_radius, numbers.Number): n.soma_radius /= other @@ -243,7 +277,7 @@ def __truediv__(self, other, copy=True): warnings.simplefilter("ignore") n.units = (n.units * other).to_compact() - n._clear_temp_attr(exclude=['classify_nodes']) + 
n._clear_temp_attr(exclude=["classify_nodes"]) return n return NotImplemented @@ -255,21 +289,23 @@ def __mul__(self, other, copy=True): if len(set(other)) == 1: other == other[0] elif len(other) != 4: - raise ValueError('Multiplication by list/array requires 4' - 'multipliers for x/y/z and radius - ' - f'got {len(other)}') + raise ValueError( + "Multiplication by list/array requires 4" + "multipliers for x/y/z and radius - " + f"got {len(other)}" + ) # If a number, consider this an offset for coordinates n = self.copy() if copy else self - n.nodes[['x', 'y', 'z', 'radius']] *= other + n.nodes[["x", "y", "z", "radius"]] *= other # At this point we can ditch any 4th unit if utils.is_iterable(other): other = other[:3] if n.has_connectors: - n.connectors[['x', 'y', 'z']] *= other + n.connectors[["x", "y", "z"]] *= other - if hasattr(n, 'soma_radius'): + if hasattr(n, "soma_radius"): if isinstance(n.soma_radius, numbers.Number): n.soma_radius *= other @@ -280,7 +316,7 @@ def __mul__(self, other, copy=True): warnings.simplefilter("ignore") n.units = (n.units / other).to_compact() - n._clear_temp_attr(exclude=['classify_nodes']) + n._clear_temp_attr(exclude=["classify_nodes"]) return n return NotImplemented @@ -292,19 +328,21 @@ def __add__(self, other, copy=True): if len(set(other)) == 1: other == other[0] elif len(other) != 3: - raise ValueError('Addition by list/array requires 3' - 'multipliers for x/y/z coordinates ' - f'got {len(other)}') + raise ValueError( + "Addition by list/array requires 3" + "multipliers for x/y/z coordinates " + f"got {len(other)}" + ) # If a number, consider this an offset for coordinates n = self.copy() if copy else self - n.nodes[['x', 'y', 'z']] += other + n.nodes[["x", "y", "z"]] += other # Do the connectors if n.has_connectors: - n.connectors[['x', 'y', 'z']] += other + n.connectors[["x", "y", "z"]] += other - n._clear_temp_attr(exclude=['classify_nodes']) + n._clear_temp_attr(exclude=["classify_nodes"]) return n # If another neuron, return a list of neurons elif isinstance(other, BaseNeuron): @@ -319,19 +357,21 @@ def __sub__(self, other, copy=True): if len(set(other)) == 1: other == other[0] elif len(other) != 3: - raise ValueError('Addition by list/array requires 3' - 'multipliers for x/y/z coordinates ' - f'got {len(other)}') + raise ValueError( + "Addition by list/array requires 3" + "multipliers for x/y/z coordinates " + f"got {len(other)}" + ) # If a number, consider this an offset for coordinates n = self.copy() if copy else self - n.nodes[['x', 'y', 'z']] -= other + n.nodes[["x", "y", "z"]] -= other # Do the connectors if n.has_connectors: - n.connectors[['x', 'y', 'z']] -= other + n.connectors[["x", "y", "z"]] -= other - n._clear_temp_attr(exclude=['classify_nodes']) + n._clear_temp_attr(exclude=["classify_nodes"]) return n return NotImplemented @@ -341,10 +381,10 @@ def __getstate__(self): # Pickling the graphs actually takes longer than regenerating them # from scratch - if '_graph_nx' in state: - _ = state.pop('_graph_nx') - if '_igraph' in state: - _ = state.pop('_igraph') + if "_graph_nx" in state: + _ = state.pop("_graph_nx") + if "_igraph" in state: + _ = state.pop("_igraph") return state @@ -352,7 +392,7 @@ def __getstate__(self): @temp_property def adjacency_matrix(self): """Adjacency matrix of the skeleton.""" - if not hasattr(self, '_adjacency_matrix'): + if not hasattr(self, "_adjacency_matrix"): self._adjacency_matrix = graph.skeleton_adjacency_matrix(self) return self._adjacency_matrix @@ -360,7 +400,7 @@ def adjacency_matrix(self): 
@requires_nodes def vertices(self) -> np.ndarray: """Vertices of the skeleton.""" - return self.nodes[['x', 'y', 'z']].values + return self.nodes[["x", "y", "z"]].values @property @requires_nodes @@ -374,7 +414,7 @@ def edges(self) -> np.ndarray: """ not_root = self.nodes[self.nodes.parent_id >= 0] - return not_root[['node_id', 'parent_id']].values + return not_root[["node_id", "parent_id"]].values @property @requires_nodes @@ -387,7 +427,7 @@ def edge_coords(self) -> np.ndarray: Same but with node IDs instead of x/y/z coordinates. """ - locs = self.nodes.set_index('node_id')[['x', 'y', 'z']] + locs = self.nodes.set_index("node_id")[["x", "y", "z"]] edges = self.edges edges_co = np.zeros((edges.shape[0], 2, 3)) edges_co[:, 0, :] = locs.loc[edges[:, 0]].values @@ -396,10 +436,10 @@ def edge_coords(self) -> np.ndarray: @property @temp_property - def igraph(self) -> 'igraph.Graph': + def igraph(self) -> "igraph.Graph": """iGraph representation of this neuron.""" # If igraph does not exist, create and return - if not hasattr(self, '_igraph'): + if not hasattr(self, "_igraph"): # This also sets the attribute return self.get_igraph() return self._igraph @@ -409,7 +449,7 @@ def igraph(self) -> 'igraph.Graph': def graph(self) -> nx.DiGraph: """Networkx Graph representation of this neuron.""" # If graph does not exist, create and return - if not hasattr(self, '_graph_nx'): + if not hasattr(self, "_graph_nx"): # This also sets the attribute return self.get_graph_nx() return self._graph_nx @@ -419,7 +459,7 @@ def graph(self) -> nx.DiGraph: def geodesic_matrix(self): """Matrix with geodesic (along-the-arbor) distance between nodes.""" # If matrix has not yet been generated or needs update - if not hasattr(self, '_geodesic_matrix'): + if not hasattr(self, "_geodesic_matrix"): # (Re-)generate matrix self._geodesic_matrix = graph.geodesic_matrix(self) @@ -429,7 +469,7 @@ def geodesic_matrix(self): @requires_nodes def leafs(self) -> pd.DataFrame: """Leaf node table.""" - return self.nodes[self.nodes['type'] == 'end'] + return self.nodes[self.nodes["type"] == "end"] @property @requires_nodes @@ -441,7 +481,7 @@ def ends(self): @requires_nodes def branch_points(self): """Branch node table.""" - return self.nodes[self.nodes['type'] == 'branch'] + return self.nodes[self.nodes["type"] == "branch"] @property def nodes(self) -> pd.DataFrame: @@ -461,20 +501,24 @@ def nodes(self, v): def _set_nodes(self, v): # Redefine this function in subclass to change validation - self._nodes = utils.validate_table(v, - required=[('node_id', 'rowId', 'node', 'treenode_id', 'PointNo'), - ('parent_id', 'link', 'parent', 'Parent'), - ('x', 'X'), - ('y', 'Y'), - ('z', 'Z')], - rename=True, - optional={('radius', 'W'): 0}, - restrict=False) + self._nodes = utils.validate_table( + v, + required=[ + ("node_id", "rowId", "node", "treenode_id", "PointNo"), + ("parent_id", "link", "parent", "Parent"), + ("x", "X"), + ("y", "Y"), + ("z", "Z"), + ], + rename=True, + optional={("radius", "W"): 0}, + restrict=False, + ) # Make sure we don't end up with object dtype anywhere as this can # cause problems - for c in ('node_id', 'parent_id'): - if self._nodes[c].dtype == 'O': + for c in ("node_id", "parent_id"): + if self._nodes[c].dtype == "O": self._nodes[c] = self._nodes[c].astype(int) graph.classify_nodes(self) @@ -504,8 +548,7 @@ def is_tree(self) -> bool: @property def subtrees(self) -> List[List[int]]: """List of subtrees. 
Sorted by size as sets of node IDs.""" - return sorted(graph._connected_components(self), - key=lambda x: -len(x)) + return sorted(graph._connected_components(self), key=lambda x: -len(x)) @property def connectors(self) -> pd.DataFrame: @@ -514,7 +557,7 @@ def connectors(self) -> pd.DataFrame: def _get_connectors(self) -> pd.DataFrame: # Redefine this function in subclass to change how nodes are retrieved - return getattr(self, '_connectors', None) + return getattr(self, "_connectors", None) @connectors.setter def connectors(self, v): @@ -528,15 +571,19 @@ def _set_connectors(self, v): if isinstance(v, type(None)): self._connectors = None else: - self._connectors = utils.validate_table(v, - required=[('connector_id', 'id'), - ('node_id', 'rowId', 'node', 'treenode_id'), - ('x', 'X'), - ('y', 'Y'), - ('z', 'Z'), - ('type', 'relation', 'label', 'prepost')], - rename=True, - restrict=False) + self._connectors = utils.validate_table( + v, + required=[ + ("connector_id", "id"), + ("node_id", "rowId", "node", "treenode_id"), + ("x", "X"), + ("y", "Y"), + ("z", "Z"), + ("type", "relation", "label", "prepost"), + ], + rename=True, + restrict=False, + ) @property @requires_nodes @@ -550,8 +597,9 @@ def cycles(self) -> Optional[List[int]]: """ try: - c = nx.find_cycle(self.graph, - source=self.nodes[self.nodes.type == 'end'].node_id.values) + c = nx.find_cycle( + self.graph, source=self.nodes[self.nodes.type == "end"].node_id.values + ) return c except nx.exception.NetworkXNoCycle: return None @@ -559,16 +607,16 @@ def cycles(self) -> Optional[List[int]]: raise @property - def simple(self) -> 'TreeNeuron': + def simple(self) -> "TreeNeuron": """Simplified representation consisting only of root, branch points and leafs.""" - if not hasattr(self, '_simple'): + if not hasattr(self, "_simple"): self._simple = self.copy() # Make sure we don't have a soma, otherwise that node will be preserved self._simple.soma = None # Downsample - self._simple.downsample(float('inf'), inplace=True) + self._simple.downsample(float("inf"), inplace=True) return self._simple @property @@ -592,11 +640,11 @@ def soma(self) -> Optional[Union[str, int]]: if all(pd.isnull(soma)): soma = None elif not any(self.nodes.node_id.isin(soma)): - logger.warning(f'Soma(s) {soma} not found in node table.') + logger.warning(f"Soma(s) {soma} not found in node table.") soma = None else: if soma not in self.nodes.node_id.values: - logger.warning(f'Soma {soma} not found in node table.') + logger.warning(f"Soma {soma} not found in node table.") soma = None return soma @@ -604,7 +652,7 @@ def soma(self) -> Optional[Union[str, int]]: @soma.setter def soma(self, value: Union[Callable, int, None]) -> None: """Set soma.""" - if hasattr(value, '__call__'): + if hasattr(value, "__call__"): self._soma = types.MethodType(value, self) elif isinstance(value, type(None)): self._soma = None @@ -614,7 +662,7 @@ def soma(self, value: Union[Callable, int, None]) -> None: if value in self.nodes.node_id.values: self._soma = value else: - raise ValueError('Soma must be function, None or a valid node ID.') + raise ValueError("Soma must be function, None or a valid node ID.") @property def soma_pos(self) -> Optional[Sequence]: @@ -633,7 +681,7 @@ def soma_pos(self) -> Optional[Sequence]: else: soma = utils.make_iterable(soma) - return self.nodes.loc[self.nodes.node_id.isin(soma), ['x', 'y', 'z']].values + return self.nodes.loc[self.nodes.node_id.isin(soma), ["x", "y", "z"]].values @soma_pos.setter def soma_pos(self, value: Sequence) -> None: @@ -645,16 +693,20 @@ 
def soma_pos(self, value: Sequence) -> None: try: value = np.asarray(value).astype(np.float64).reshape(3) except BaseException: - raise ValueError(f'Unable to convert soma position "{value}" ' - f'to numeric (3, ) numpy array.') + raise ValueError( + f'Unable to convert soma position "{value}" ' + f"to numeric (3, ) numpy array." + ) # Generate tree - id, dist = self.snap(value, to='nodes') + id, dist = self.snap(value, to="nodes") # A sanity check if dist > (self.sampling_resolution * 10): - logger.warning(f'New soma position for {self.id} is suspiciously ' - f'far away from the closest node: {dist}') + logger.warning( + f"New soma position for {self.id} is suspiciously " + f"far away from the closest node: {dist}" + ) self.soma = id @@ -673,26 +725,26 @@ def root(self, value: Union[int, List[int]]) -> None: @property def type(self) -> str: """Neuron type.""" - return 'navis.TreeNeuron' + return "navis.TreeNeuron" @property @requires_nodes def n_branches(self) -> Optional[int]: """Number of branch points.""" - return self.nodes[self.nodes.type == 'branch'].shape[0] + return self.nodes[self.nodes.type == "branch"].shape[0] @property @requires_nodes def n_leafs(self) -> Optional[int]: """Number of leaf nodes.""" - return self.nodes[self.nodes.type == 'end'].shape[0] + return self.nodes[self.nodes.type == "end"].shape[0] @property @temp_property @add_units(compact=True) def cable_length(self) -> Union[int, float]: """Cable length.""" - if not hasattr(self, '_cable_length'): + if not hasattr(self, "_cable_length"): self._cable_length = morpho.cable_length(self) return self._cable_length @@ -700,17 +752,21 @@ def cable_length(self) -> Union[int, float]: @add_units(compact=True, power=2) def surface_area(self) -> float: """Radius-based lateral surface area.""" - if 'radius' not in self.nodes.columns: - raise ValueError(f'Neuron {self.id} does not have radius information') + if "radius" not in self.nodes.columns: + raise ValueError(f"Neuron {self.id} does not have radius information") if any(self.nodes.radius < 0): - logger.warning(f'Neuron {self.id} has negative radii - area will not be correct.') + logger.warning( + f"Neuron {self.id} has negative radii - area will not be correct." + ) if any(self.nodes.radius.isnull()): - logger.warning(f'Neuron {self.id} has NaN radii - area will not be correct.') + logger.warning( + f"Neuron {self.id} has NaN radii - area will not be correct." + ) # Generate radius dict - radii = self.nodes.set_index('node_id').radius.to_dict() + radii = self.nodes.set_index("node_id").radius.to_dict() # Drop root node(s) not_root = self.nodes.parent_id >= 0 # For each cylinder get the height @@ -721,23 +777,27 @@ def surface_area(self) -> float: r1 = nodes.node_id.map(radii).values r2 = nodes.parent_id.map(radii).values - return (np.pi * (r1 + r2) * np.sqrt( (r1-r2)**2 + h**2)).sum() + return (np.pi * (r1 + r2) * np.sqrt((r1 - r2) ** 2 + h**2)).sum() @property @add_units(compact=True, power=3) def volume(self) -> float: """Radius-based volume.""" - if 'radius' not in self.nodes.columns: - raise ValueError(f'Neuron {self.id} does not have radius information') + if "radius" not in self.nodes.columns: + raise ValueError(f"Neuron {self.id} does not have radius information") if any(self.nodes.radius < 0): - logger.warning(f'Neuron {self.id} has negative radii - volume will not be correct.') + logger.warning( + f"Neuron {self.id} has negative radii - volume will not be correct." 
+ ) if any(self.nodes.radius.isnull()): - logger.warning(f'Neuron {self.id} has NaN radii - volume will not be correct.') + logger.warning( + f"Neuron {self.id} has NaN radii - volume will not be correct." + ) # Generate radius dict - radii = self.nodes.set_index('node_id').radius.to_dict() + radii = self.nodes.set_index("node_id").radius.to_dict() # Drop root node(s) not_root = self.nodes.parent_id >= 0 # For each cylinder get the height @@ -748,17 +808,17 @@ def volume(self) -> float: r1 = nodes.node_id.map(radii).values r2 = nodes.parent_id.map(radii).values - return (1/3 * np.pi * (r1**2 + r1 * r2 + r2**2) * h).sum() + return (1 / 3 * np.pi * (r1**2 + r1 * r2 + r2**2) * h).sum() @property def bbox(self) -> np.ndarray: """Bounding box (includes connectors).""" - mn = np.min(self.nodes[['x', 'y', 'z']].values, axis=0) - mx = np.max(self.nodes[['x', 'y', 'z']].values, axis=0) + mn = np.min(self.nodes[["x", "y", "z"]].values, axis=0) + mx = np.max(self.nodes[["x", "y", "z"]].values, axis=0) if self.has_connectors: - cn_mn = np.min(self.connectors[['x', 'y', 'z']].values, axis=0) - cn_mx = np.max(self.connectors[['x', 'y', 'z']].values, axis=0) + cn_mn = np.min(self.connectors[["x", "y", "z"]].values, axis=0) + cn_mx = np.max(self.connectors[["x", "y", "z"]].values, axis=0) mn = np.min(np.vstack((mn, cn_mn)), axis=0) mx = np.max(np.vstack((mx, cn_mx)), axis=0) @@ -780,9 +840,9 @@ def sampling_resolution(self) -> float: def segments(self) -> List[list]: """Neuron broken down into linear segments (see also `.small_segments`).""" # Calculate if required - if not hasattr(self, '_segments'): + if not hasattr(self, "_segments"): # This also sets the attribute - self._segments = self._get_segments(how='length') + self._segments = self._get_segments(how="length") return self._segments @property @@ -790,19 +850,18 @@ def segments(self) -> List[list]: def small_segments(self) -> List[list]: """Neuron broken down into small linear segments (see also `.segments`).""" # Calculate if required - if not hasattr(self, '_small_segments'): + if not hasattr(self, "_small_segments"): # This also sets the attribute - self._small_segments = self._get_segments(how='break') + self._small_segments = self._get_segments(how="break") return self._small_segments - def _get_segments(self, - how: Union[Literal['length'], - Literal['break']] = 'length' - ) -> List[list]: + def _get_segments( + self, how: Union[Literal["length"], Literal["break"]] = "length" + ) -> List[list]: """Generate segments for neuron.""" - if how == 'length': + if how == "length": return graph._generate_segments(self) - elif how == 'break': + elif how == "break": return graph._break_segments(self) else: raise ValueError(f'Unknown method: "{how}"') @@ -830,11 +889,11 @@ def _clear_temp_attr(self, exclude: list = []) -> None: elif self._soma not in self.nodes.node_id.values: self.soma = None - if 'classify_nodes' not in exclude: + if "classify_nodes" not in exclude: # Reclassify nodes graph.classify_nodes(self, inplace=True) - def copy(self, deepcopy: bool = False) -> 'TreeNeuron': + def copy(self, deepcopy: bool = False) -> "TreeNeuron": """Return a copy of the neuron. 
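For orientation, the bounding box and segment properties reformatted above can be inspected directly. A short sketch, again using an example skeleton; note that `.bbox` folds in connector positions if the neuron has any:

```python
import navis

n = navis.example_neurons(1)

print(n.bbox)                 # min/max extents along x, y, z (connectors included if present)
print(len(n.segments))        # linear segments, generated to maximize segment length
print(len(n.small_segments))  # linear segments broken at every branch point
```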
Parameters @@ -850,17 +909,19 @@ def copy(self, deepcopy: bool = False) -> 'TreeNeuron': TreeNeuron """ - no_copy = ['_lock'] + no_copy = ["_lock"] # Generate new empty neuron x = self.__class__(None) # Populate with this neuron's data - x.__dict__.update({k: copy.copy(v) for k, v in self.__dict__.items() if k not in no_copy}) + x.__dict__.update( + {k: copy.copy(v) for k, v in self.__dict__.items() if k not in no_copy} + ) # Copy graphs only if neuron is not stale if not self.is_stale: - if '_graph_nx' in self.__dict__: + if "_graph_nx" in self.__dict__: x._graph_nx = self._graph_nx.copy(as_view=deepcopy is not True) - if '_igraph' in self.__dict__: + if "_igraph" in self.__dict__: if self._igraph is not None: # This is pretty cheap, so we will always make a deep copy x._igraph = self._igraph.copy() @@ -883,7 +944,7 @@ def get_graph_nx(self) -> nx.DiGraph: self._graph_nx = graph.neuron2nx(self) return self._graph_nx - def get_igraph(self) -> 'igraph.Graph': # type: ignore + def get_igraph(self) -> "igraph.Graph": # type: ignore """Calculate and return iGraph representation of neuron. Once calculated stored as `.igraph`. Call function again to update @@ -901,11 +962,173 @@ def get_igraph(self) -> 'igraph.Graph': # type: ignore self._igraph = graph.neuron2igraph(self, raise_not_installed=False) return self._igraph - @overload - def resample(self, resample_to: int, inplace: Literal[False]) -> 'TreeNeuron': ... + def mask(self, mask, copy=True): + """Mask neuron with given mask. + + This is always done in-place! + + Parameters + ---------- + mask : np.ndarray + Mask to apply. Can be: + - 1D array with boolean values + - callable that accepts a neuron and returns a mask + - string with column name in nodes table + + Returns + ------- + self + + See Also + -------- + [`navis.MeshNeuron.unmask`][] + Remove mask from neuron. + [`navis.NeuronMask`][] + Context manager for masking neurons. + + """ + if self.is_masked: + raise ValueError( + "Neuron already masked. Layering multiple masks is currently not supported, please unmask first." + ) + + if callable(mask): + mask = mask(self) + elif isinstance(mask, str): + mask = self.nodes[mask].values + + mask = np.asarray(mask) + + if mask.dtype != bool: + raise ValueError("Mask must be boolean array.") + elif mask.shape[0] != self.nodes.shape[0]: + raise ValueError("Mask must have same length as nodes table.") + + self._mask = mask + self._masked_data = {} + self._masked_data["_nodes"] = self.nodes + + # N.B. we're directly setting `._nodes`` to avoid overhead from checks + self._nodes = self._nodes.loc[mask] + if copy: + self._nodes = self._nodes.copy() - @overload - def resample(self, resample_to: int, inplace: Literal[True]) -> None: ... + if hasattr(self, "_connectors"): + self._masked_data["_connectors"] = self.connectors + self._connectors = self._connectors.loc[ + self._connectors.node_id.isin(self.nodes.node_id) + ] + if copy: + self._connectors = self._connectors.copy() + + self._clear_temp_attr() + + return self + + def unmask(self, reset=True): + """Unmask neuron. + + Returns the neuron to its original state before masking. + + Parameters + ---------- + reset : bool + Whether to reset the neuron to its original state before masking. + If False, edits made to the neuron after masking will be kept. + + Returns + ------- + self + + See Also + -------- + [`TreeNeuron.is_masked`][navis.TreeNeuron.is_masked] + Check if neuron is masked. + [`TreeNeuron.mask`][navis.TreeNeuron.mask] + Mask neuron. 
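The new masking workflow introduced here can be exercised roughly like this. A sketch assuming the mask is a boolean array over the node table; per the docstring, a node-table column name or a callable are accepted as well:

```python
import navis

n = navis.example_neurons(1)
n_before = n.n_nodes

# Keep only nodes above the median y coordinate
n.mask(n.nodes.y.values > n.nodes.y.median())
print(n.is_masked, n.n_nodes)   # True, fewer nodes than before

# ... run analyses on the masked neuron ...

n.unmask()                                 # restore the original node (and connector) tables
print(n.is_masked, n.n_nodes == n_before)  # False, True
```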
+ [`navis.NeuronMask`][] + Context manager for masking neurons. + + """ + if not self.is_masked: + raise ValueError("Neuron is not masked.") + + if reset: + # Unmask and reset to original state + super().unmask() + return self + + mask = self._mask + + # Combine post-mask data with original data + post_nodes = self.nodes + pre_nodes = self._masked_data["_nodes"].iloc[~mask] + + # We need to take care of a few things: + # 1. Make sure don't have any duplicate node IDs + duplicates = post_nodes.node_id.values[ + post_nodes.node_id.isin(pre_nodes.node_id) + ] + if any(duplicates): + # Start with a new max index + max_id = max(post_nodes.node_id.max(), pre_nodes.node_id.max()) + 1 + # New indices + new_ix = dict(zip(duplicates, range(max_id, max_id + len(duplicates)))) + # Update node IDs and parent IDs + post_nodes["node_id"] = post_nodes.node_id.replace(new_ix) + post_nodes["parent_id"] = post_nodes.parent_id.replace(new_ix) + + # Concatenate + self._nodes = pd.concat([pre_nodes, post_nodes], ignore_index=True) + + # 2. Re-connect the root nodes of the masked data + post_roots = post_nodes[post_nodes.parent_id < 0].node_id.values + pre_parents = ( + self._masked_data["_nodes"].set_index("node_id").parent_id.to_dict() + ) + to_fix = {} + for r in post_roots: + # Skip if this root is not in pre_parents + if r not in pre_parents: + continue + # Skip if this was also a root in the pre-masked data + if pre_parents[r] >= 0: + continue + # Skip if the old parent does not exist anymore + if pre_parents[r] not in self.nodes.node_id.values: + continue + # If we made it here, there is a new root that we can connect to the old parent + to_fix[r] = pre_parents[r] + if to_fix: + # Fix parent IDs + self._nodes["parent_id"] = self._nodes["parent_id"].replace(to_fix) + + # 3. See if any parent IDs have ceased to exist + missing_parents = ~self._nodes.parent_id.isin(self._nodes.node_id) & ( + self._nodes.parent_id >= 0 + ) + if any(missing_parents): + self.nodes.loc[missing_parents, "parent_id"] = -1 + + # TODO: Make sure that edges have a consistent orientation + # (not sure this is much of a problem) + + # Connectors + # TODO: check `node_id` of connectors too + if "_connectors" in self._masked_data: + if not hasattr(self, "_connectors"): + self._connectors = self._masked_data["_connectors"] + else: + cn = self._masked_data["_connectors"] + cn = cn.loc[cn.node_id.isin(pre_nodes.node_id.values)] + self._connectors = pd.concat([self._connectors, cn], ignore_index=True) + + del self._mask + del self._masked_data + + self._clear_temp_attr() + + return self def resample(self, resample_to, inplace=False): """Resample neuron to given resolution. @@ -939,18 +1162,6 @@ def resample(self, resample_to, inplace=False): return x return None - @overload - def downsample(self, - factor: float, - inplace: Literal[False], - **kwargs) -> 'TreeNeuron': ... - - @overload - def downsample(self, - factor: float, - inplace: Literal[True], - **kwargs) -> None: ... - def downsample(self, factor=5, inplace=False, **kwargs): """Downsample the neuron by given factor. @@ -987,9 +1198,9 @@ def downsample(self, factor=5, inplace=False, **kwargs): return x return None - def reroot(self, - new_root: Union[int, str], - inplace: bool = False) -> Optional['TreeNeuron']: + def reroot( + self, new_root: Union[int, str], inplace: bool = False + ) -> Optional["TreeNeuron"]: """Reroot neuron to given node ID or node tag. 
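`unmask(reset=False)` is the more involved path: edits made while the mask was active are stitched back into the unmasked data (duplicate node IDs are renumbered, masked-off roots are re-attached to their former parents where possible, and orphaned parent IDs are set to -1). A hedged sketch of what that enables, assuming pruning while masked behaves as on an unmasked neuron:

```python
import navis

n = navis.example_neurons(1)
n_before = n.n_nodes

n.mask(n.nodes.y.values > n.nodes.y.median())
n.prune_twigs(5000, inplace=True)   # edit the masked view (size is in the neuron's units)

n.unmask(reset=False)               # keep the edit instead of restoring the original nodes
print(n.n_nodes < n_before)         # True if anything was pruned
```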
Parameters @@ -1020,9 +1231,9 @@ def reroot(self, return x return None - def prune_distal_to(self, - node: Union[str, int], - inplace: bool = False) -> Optional['TreeNeuron']: + def prune_distal_to( + self, node: Union[str, int], inplace: bool = False + ) -> Optional["TreeNeuron"]: """Cut off nodes distal to given nodes. Parameters @@ -1047,7 +1258,7 @@ def prune_distal_to(self, node = utils.make_iterable(node, force_type=None) for n in node: - prox = graph.cut_skeleton(x, n, ret='proximal')[0] + prox = graph.cut_skeleton(x, n, ret="proximal")[0] # Reinitialise with proximal data x.__init__(prox) # type: ignore # Cannot access "__init__" directly # Remove potential "left over" attributes (happens if we use a copy) @@ -1057,9 +1268,9 @@ def prune_distal_to(self, return x return None - def prune_proximal_to(self, - node: Union[str, int], - inplace: bool = False) -> Optional['TreeNeuron']: + def prune_proximal_to( + self, node: Union[str, int], inplace: bool = False + ) -> Optional["TreeNeuron"]: """Remove nodes proximal to given node. Reroots neuron to cut node. Parameters @@ -1084,7 +1295,7 @@ def prune_proximal_to(self, node = utils.make_iterable(node, force_type=None) for n in node: - dist = graph.cut_skeleton(x, n, ret='distal')[0] + dist = graph.cut_skeleton(x, n, ret="distal")[0] # Reinitialise with distal data x.__init__(dist) # type: ignore # Cannot access "__init__" directly # Remove potential "left over" attributes (happens if we use a copy) @@ -1097,9 +1308,9 @@ def prune_proximal_to(self, return x return None - def prune_by_strahler(self, - to_prune: Union[int, List[int], slice], - inplace: bool = False) -> Optional['TreeNeuron']: + def prune_by_strahler( + self, to_prune: Union[int, List[int], slice], inplace: bool = False + ) -> Optional["TreeNeuron"]: """Prune neuron based on [Strahler order](https://en.wikipedia.org/wiki/Strahler_number). Will reroot neuron to soma if possible. @@ -1132,8 +1343,7 @@ def prune_by_strahler(self, else: x = self.copy() - morpho.prune_by_strahler( - x, to_prune=to_prune, reroot_soma=True, inplace=True) + morpho.prune_by_strahler(x, to_prune=to_prune, reroot_soma=True, inplace=True) # No need to call this as morpho.prune_by_strahler does this already # self._clear_temp_attr() @@ -1142,17 +1352,23 @@ def prune_by_strahler(self, return x return None - def prune_twigs(self, - size: float, - inplace: bool = False, - recursive: Union[int, bool, float] = False - ) -> Optional['TreeNeuron']: + def prune_twigs( + self, + size: float, + mask: Optional[np.ndarray] = None, + inplace: bool = False, + recursive: Union[int, bool, float] = False, + ) -> Optional["TreeNeuron"]: """Prune terminal twigs under a given size. Parameters ---------- size : int | float Twigs shorter than this will be pruned. + mask : iterable | callable, optional + Either a boolean mask, a list of node IDs or a callable taking + a neuron as input and returning one of the former. If provided, + only nodes that are in the mask will be considered for pruning. inplace : bool, optional If False, pruning is performed on copy of original neuron which is then returned. 
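The rerooting and cutting methods reformatted above compose naturally; since they are non-destructive by default, each call returns a new neuron. A brief sketch with an arbitrarily chosen branch point:

```python
import navis

n = navis.example_neurons(1)

bp = n.nodes[n.nodes.type == "branch"].node_id.values[0]  # pick some branch point

rerooted = n.reroot(bp)           # rerooted copy (inplace=False is the default)
proximal = n.prune_distal_to(bp)  # copy with everything distal to that node removed
distal = n.prune_proximal_to(bp)  # copy rerooted to the cut node, proximal part removed
```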
@@ -1172,17 +1388,18 @@ def prune_twigs(self, else: x = self.copy() - morpho.prune_twigs(x, size=size, inplace=True) + morpho.prune_twigs(x, size=size, mask=mask, inplace=True) if not inplace: return x return None - def prune_at_depth(self, - depth: Union[float, int], - source: Optional[int] = None, - inplace: bool = False - ) -> Optional['TreeNeuron']: + def prune_at_depth( + self, + depth: Union[float, int], + source: Optional[int] = None, + inplace: bool = False, + ) -> Optional["TreeNeuron"]: """Prune all neurites past a given distance from a source. Parameters @@ -1220,10 +1437,11 @@ def prune_at_depth(self, return x return None - def cell_body_fiber(self, - reroot_soma: bool = True, - inplace: bool = False, - ) -> Optional['TreeNeuron']: + def cell_body_fiber( + self, + reroot_soma: bool = True, + inplace: bool = False, + ) -> Optional["TreeNeuron"]: """Prune neuron to its cell body fiber. Parameters @@ -1255,11 +1473,12 @@ def cell_body_fiber(self, return x return None - def prune_by_longest_neurite(self, - n: int = 1, - reroot_soma: bool = False, - inplace: bool = False, - ) -> Optional['TreeNeuron']: + def prune_by_longest_neurite( + self, + n: int = 1, + reroot_soma: bool = False, + inplace: bool = False, + ) -> Optional["TreeNeuron"]: """Prune neuron down to the longest neurite. Parameters @@ -1284,8 +1503,7 @@ def prune_by_longest_neurite(self, else: x = self.copy() - graph.longest_neurite( - x, n, inplace=True, reroot_soma=reroot_soma) + graph.longest_neurite(x, n, inplace=True, reroot_soma=reroot_soma) # Clear temporary attributes x._clear_temp_attr() @@ -1294,14 +1512,13 @@ def prune_by_longest_neurite(self, return x return None - def prune_by_volume(self, - v: Union[core.Volume, - List[core.Volume], - Dict[str, core.Volume]], - mode: Union[Literal['IN'], Literal['OUT']] = 'IN', - prevent_fragments: bool = False, - inplace: bool = False - ) -> Optional['TreeNeuron']: + def prune_by_volume( + self, + v: Union[core.Volume, List[core.Volume], Dict[str, core.Volume]], + mode: Union[Literal["IN"], Literal["OUT"]] = "IN", + prevent_fragments: bool = False, + inplace: bool = False, + ) -> Optional["TreeNeuron"]: """Prune neuron by intersection with given volume(s). Parameters @@ -1330,9 +1547,9 @@ def prune_by_volume(self, else: x = self.copy() - intersection.in_volume(x, v, inplace=True, - prevent_fragments=prevent_fragments, - mode=mode) + intersection.in_volume( + x, v, inplace=True, prevent_fragments=prevent_fragments, mode=mode + ) # Clear temporary attributes # x._clear_temp_attr() @@ -1341,9 +1558,7 @@ def prune_by_volume(self, return x return None - def to_swc(self, - filename: Optional[str] = None, - **kwargs) -> None: + def to_swc(self, filename: Optional[str] = None, **kwargs) -> None: """Generate SWC file from this neuron. Parameters @@ -1365,9 +1580,10 @@ def to_swc(self, """ return io.write_swc(self, filename, **kwargs) # type: ignore # double import of "io" - def reload(self, - inplace: bool = False, - ) -> Optional['TreeNeuron']: + def reload( + self, + inplace: bool = False, + ) -> Optional["TreeNeuron"]: """Reload neuron. Must have filepath as `.origin` as attribute. Returns @@ -1376,19 +1592,22 @@ def reload(self, If `inplace=False`. 
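The new `mask` parameter for `prune_twigs` restricts which nodes are even considered for pruning. A sketch using a boolean mask over the node table; per the docstring above, a list of node IDs or a callable returning either would work the same way:

```python
import navis

n = navis.example_neurons(1)

# Only consider twigs whose nodes fall in the posterior half of the arbor
posterior = n.nodes.z.values > n.nodes.z.median()
n.prune_twigs(2000, mask=posterior, inplace=True)   # size is in the neuron's spatial units

# Equivalent with a callable:
# n.prune_twigs(2000, mask=lambda x: x.nodes.z.values > x.nodes.z.median(), inplace=True)
```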
""" - if not hasattr(self, 'origin'): - raise AttributeError('To reload TreeNeuron must have `.origin` ' - 'attribute') + if not hasattr(self, "origin"): + raise AttributeError( + "To reload TreeNeuron must have `.origin` " "attribute" + ) - if self.origin in ('DataFrame', 'string'): - raise ValueError('Unable to reload TreeNeuron: it appears to have ' - 'been created from string or DataFrame.') + if self.origin in ("DataFrame", "string"): + raise ValueError( + "Unable to reload TreeNeuron: it appears to have " + "been created from string or DataFrame." + ) kwargs = {} - if hasattr(self, 'soma_label'): - kwargs['soma_label'] = self.soma_label - if hasattr(self, 'connector_labels'): - kwargs['connector_labels'] = self.connector_labels + if hasattr(self, "soma_label"): + kwargs["soma_label"] = self.soma_label + if hasattr(self, "connector_labels"): + kwargs["connector_labels"] = self.connector_labels x = io.read_swc(self.origin, **kwargs) @@ -1403,7 +1622,7 @@ def reload(self, x2._clear_temp_attr() return x - def snap(self, locs, to='nodes'): + def snap(self, locs, to="nodes"): """Snap xyz location(s) to closest node or synapse. Parameters @@ -1431,15 +1650,16 @@ def snap(self, locs, to='nodes'): """ locs = np.asarray(locs).astype(np.float64) - is_single = (locs.ndim == 1 and len(locs) == 3) - is_multi = (locs.ndim == 2 and locs.shape[1] == 3) + is_single = locs.ndim == 1 and len(locs) == 3 + is_multi = locs.ndim == 2 and locs.shape[1] == 3 if not is_single and not is_multi: - raise ValueError('Expected a single (x, y, z) location or a ' - '(N, 3) array of multiple locations') + raise ValueError( + "Expected a single (x, y, z) location or a " + "(N, 3) array of multiple locations" + ) - if to not in ['nodes', 'connectors']: - raise ValueError('`to` must be "nodes" or "connectors", ' - f'got {to}') + if to not in ["nodes", "connectors"]: + raise ValueError('`to` must be "nodes" or "connectors", ' f"got {to}") # Generate tree tree = graph.neuron2KDTree(self, data=to) @@ -1447,7 +1667,7 @@ def snap(self, locs, to='nodes'): # Find the closest node dist, ix = tree.query(locs) - if to == 'nodes': + if to == "nodes": id = self.nodes.node_id.values[ix] else: id = self.connectors.connector_id.values[ix] diff --git a/navis/core/voxel.py b/navis/core/voxel.py index 4a3e5cc3..b8934851 100644 --- a/navis/core/voxel.py +++ b/navis/core/voxel.py @@ -94,6 +94,9 @@ class VoxelNeuron(BaseNeuron): #: Core data table(s) used to calculate hash CORE_DATA = ['_data'] + #: Property used to calculate length of neuron + _LENGTH_DATA = 'voxels' + def __init__(self, x: Union[np.ndarray], offset: Optional[np.ndarray] = None, @@ -216,6 +219,12 @@ def __sub__(self, other, copy=True): return n return NotImplemented + def __len__(self): + """Return number of voxels.""" + # This is the only neuron for which we implement __len__ separately + # to avoid the potential overhead of generating the voxel representation + return np.product(self.shape) + @property def _base_data_type(self) -> str: """Type of data (grid or voxels) underlying this neuron."""