diff --git a/docs/api.md b/docs/api.md index a42622a8..ecf28690 100644 --- a/docs/api.md +++ b/docs/api.md @@ -47,13 +47,13 @@ learn more! ``TreeNeurons``, ``MeshNeurons``, ``VoxelNeurons`` and ``Dotprops`` are neuron classes. ``NeuronLists`` are containers thereof. -| Class | Description | -|------|------| -| [`navis.TreeNeuron`][] | Skeleton representation of a neuron. | -| [`navis.MeshNeuron`][] | Meshes with vertices and faces. | -| [`navis.VoxelNeuron`][] | 3D images (e.g. from confocal stacks). | -| [`navis.Dotprops`][] | Point cloud + vector representations, used for NBLAST. | -| [`navis.NeuronList`][] | Containers for neurons. | +| Class | Description | +|-------------------------|---------------------------------------------------------| +| [`navis.TreeNeuron`][] | Skeleton representation of a neuron. | +| [`navis.MeshNeuron`][] | Meshes with vertices and faces. | +| [`navis.VoxelNeuron`][] | 3D images (e.g. from confocal stacks). | +| [`navis.Dotprops`][] | Point cloud + vector representations, used for NBLAST. | +| [`navis.NeuronList`][] | Containers for neurons. | ### General Neuron methods @@ -89,6 +89,7 @@ to all neurons: | `Neuron.type` | {{ autosummary("navis.BaseNeuron.type") }} | | `Neuron.soma` | {{ autosummary("navis.BaseNeuron.soma") }} | | `Neuron.bbox` | {{ autosummary("navis.BaseNeuron.bbox") }} | +| `Neuron.is_masked` | {{ autosummary("navis.BaseNeuron.is_masked") }} | !!! note @@ -119,6 +120,8 @@ this neuron type. Note that most of them are simply short-hands for the other | [`TreeNeuron.reroot()`][navis.TreeNeuron.reroot] | {{ autosummary("navis.TreeNeuron.reroot") }} | | [`TreeNeuron.resample()`][navis.TreeNeuron.resample] | {{ autosummary("navis.TreeNeuron.resample") }} | | [`TreeNeuron.snap()`][navis.TreeNeuron.snap] | {{ autosummary("navis.TreeNeuron.snap") }} | +| [`TreeNeuron.mask()`][navis.TreeNeuron.mask] | {{ autosummary("navis.TreeNeuron.mask") }} | +| [`TreeNeuron.unmask()`][navis.TreeNeuron.unmask] | {{ autosummary("navis.TreeNeuron.unmask") }} | In addition, a [`navis.TreeNeuron`][] has a range of different properties: @@ -146,7 +149,6 @@ In addition, a [`navis.TreeNeuron`][] has a range of different properties: | [`TreeNeuron.vertices`][navis.TreeNeuron.vertices] | {{ autosummary("navis.TreeNeuron.vertices") }} | | [`TreeNeuron.volume`][navis.TreeNeuron.volume] | {{ autosummary("navis.TreeNeuron.volume") }} | - #### Skeleton utility functions | Function | Description | @@ -158,7 +160,6 @@ In addition, a [`navis.TreeNeuron`][] has a range of different properties: | [`navis.graph.skeleton_adjacency_matrix()`][navis.graph.skeleton_adjacency_matrix] | {{ autosummary("navis.graph.skeleton_adjacency_matrix") }} | - ### Mesh neurons Properties specific to [`navis.MeshNeuron`][]: @@ -178,6 +179,8 @@ Methods specific to [`navis.MeshNeuron`][]: | [`MeshNeuron.skeletonize()`][navis.MeshNeuron.skeletonize] | {{ autosummary("navis.MeshNeuron.skeletonize") }} | | [`MeshNeuron.snap()`][navis.MeshNeuron.snap] | {{ autosummary("navis.MeshNeuron.snap") }} | | [`MeshNeuron.validate()`][navis.MeshNeuron.validate] | {{ autosummary("navis.MeshNeuron.validate") }} | +| [`MeshNeuron.mask()`][navis.MeshNeuron.mask] | {{ autosummary("navis.MeshNeuron.mask") }} | +| [`MeshNeuron.unmask()`][navis.MeshNeuron.unmask] | {{ autosummary("navis.MeshNeuron.unmask") }} | ### Voxel neurons @@ -215,6 +218,8 @@ These are methods and properties specific to [Dotprops][navis.Dotprops]: | [`Dotprops.alpha`][navis.Dotprops.alpha] | {{ autosummary("navis.Dotprops.alpha") }} | | [`Dotprops.to_skeleton()`][navis.Dotprops.to_skeleton] | {{ autosummary("navis.Dotprops.to_skeleton") }} | | [`Dotprops.snap()`][navis.Dotprops.snap] | {{ autosummary("navis.Dotprops.snap") }} | +| [`Dotprops.mask()`][navis.Dotprops.mask] | {{ autosummary("navis.Dotprops.mask") }} | +| [`Dotprops.unmask()`][navis.Dotprops.unmask] | {{ autosummary("navis.Dotprops.unmask") }} | ### Converting between types @@ -575,6 +580,14 @@ Functions to export neurons. | [`navis.write_precomputed()`][navis.write_precomputed] | {{ autosummary("navis.write_precomputed") }} | | [`navis.write_parquet()`][navis.write_parquet] | {{ autosummary("navis.write_parquet") }} | +## Masking + +Functions and classes for masking: + +| Function/Class | Description | +|----------------|-------------| +| [`navis.NeuronMask`][navis.NeuronMask] | {{ autosummary("navis.NeuronMask") }} | + ## Utility Various utility functions. diff --git a/navis/core/__init__.py b/navis/core/__init__.py index 39b4342e..80973700 100644 --- a/navis/core/__init__.py +++ b/navis/core/__init__.py @@ -18,11 +18,22 @@ from .dotprop import Dotprops from .voxel import VoxelNeuron from .neuronlist import NeuronList +from .masking import NeuronMask from .core_utils import make_dotprops, to_neuron_space, NeuronProcessor from typing import Union NeuronObject = Union[NeuronList, TreeNeuron, BaseNeuron, MeshNeuron] -__all__ = ['Volume', 'Neuron', 'BaseNeuron', 'TreeNeuron', 'MeshNeuron', - 'Dotprops', 'VoxelNeuron', 'NeuronList', 'make_dotprops'] +__all__ = [ + "Volume", + "Neuron", + "BaseNeuron", + "TreeNeuron", + "MeshNeuron", + "NeuronMask", + "Dotprops", + "VoxelNeuron", + "NeuronList", + "make_dotprops", +] diff --git a/navis/core/base.py b/navis/core/base.py index 73bf2bc4..219f71c1 100644 --- a/navis/core/base.py +++ b/navis/core/base.py @@ -46,7 +46,8 @@ def Neuron( - x: Union[nx.DiGraph, str, pd.DataFrame, "TreeNeuron", "MeshNeuron"], **metadata + x: Union[nx.DiGraph, str, pd.DataFrame, "TreeNeuron", "MeshNeuron"], # noqa: F821 + **metadata, # noqa: F821 ): """Constructor for Neuron objects. Depending on the input, either a `TreeNeuron` or a `MeshNeuron` is returned. @@ -195,6 +196,9 @@ class BaseNeuron(UnitObject): #: Core data table(s) used to calculate hash CORE_DATA = [] + #: Property used to calculate length of neuron + _LENGTH_DATA = None + def __init__(self, **kwargs): # Set a random ID -> may be replaced later self.id = uuid.uuid4() @@ -303,6 +307,14 @@ def __isub__(self, other): """Subtraction with assignment (-=).""" return self.__sub__(other, copy=False) + def __len__(self): + if self._LENGTH_DATA is None: + return None + # Deal with potential empty neurons + if not hasattr(self, self._LENGTH_DATA): + return 0 + return len(getattr(self, self._LENGTH_DATA)) + def _repr_html_(self): frame = self.summary().to_frame() frame.columns = [""] @@ -652,8 +664,13 @@ def copy(self, deepcopy=False) -> "BaseNeuron": return x + def view(self) -> "BaseNeuron": + """Create a view of the neuron without copying data.""" + raise NotImplementedError(f"View not implemented for neuron of type {type(self)}.") + def summary(self, add_props=None) -> pd.Series: """Get a summary of this neuron.""" + # Do not remove the list -> otherwise we might change the original! props = list(self.SUMMARY_PROPS) @@ -674,6 +691,11 @@ def summary(self, add_props=None) -> pd.Series: warnings.simplefilter("ignore") s = pd.Series([getattr(self, at, "NA") for at in props], index=props) + # Show mask status + if self.is_masked: + if "masked" not in s.index: + s["masked"] = True + return s def plot2d(self, **kwargs): @@ -721,6 +743,100 @@ def plot3d(self, **kwargs): return plot3d(core.NeuronList(self, make_copy=False), **kwargs) + @property + def is_masked(self): + """Test if neuron is masked. + + See Also + -------- + [`navis.BaseNeuron.mask`][] + Mask neuron. + [`navis.BaseNeuron.unmask`][] + Remove mask from neuron. + [`navis.NeuronMask`][] + Context manager for masking neurons. + """ + return hasattr(self, "_masked_data") + + def mask(self, mask): + """Mask neuron. + + Implementation details depend on the neuron type (see below). + + See Also + -------- + [`navis.TreeNeuron.mask`][] + Mask skeleton. + [`navis.MeshNeuron.mask`][] + Mask mesh. + [`navis.Dotprops.mask`][] + Mask dotprops. + + """ + raise NotImplementedError( + f"Masking not implemented for neuron of type {type(self)}." + ) + + def unmask(self): + """Unmask neuron. + + Returns the neuron to its original state before masking. + + Returns + ------- + self + + See Also + -------- + [`Neuron.is_masked`][navis.BaseNeuron.is_masked] + Check if neuron. is masked. + [`Neuron.mask`][navis.BaseNeuron.unmask] + Mask neuron. + [`navis.NeuronMask`][] + Context manager for masking neurons. + + """ + if not self.is_masked: + raise ValueError("Neuron is not masked.") + + for k, v in self._masked_data.items(): + if hasattr(self, k): + setattr(self, k, v) + + delattr(self, "_mask") + delattr(self, "_masked_data") + self._clear_temp_attr() + + return self + + def apply_mask(self, inplace=False): + """Apply mask to neuron. + + This will effectively make the mask permanent. + + Parameters + ---------- + inplace : bool + If True will apply mask in-place. If False + will return a copy and the original neuron + will remain masked. + + Returns + ------- + Neuron + Neuron with mask applied. + + """ + if not self.is_masked: + raise ValueError("Neuron is not masked.") + + n = self if inplace else self.copy() + + delattr(n, "_mask") + delattr(n, "_masked_data") + + return n + def map_units( self, units: Union[pint.Unit, str], diff --git a/navis/core/dotprop.py b/navis/core/dotprop.py index 0f51dd1b..7d83d5e8 100644 --- a/navis/core/dotprop.py +++ b/navis/core/dotprop.py @@ -37,7 +37,7 @@ except ModuleNotFoundError: from scipy.spatial import cKDTree as KDTree -__all__ = ['Dotprops'] +__all__ = ["Dotprops"] # Set up logging logger = config.get_logger(__name__) @@ -93,31 +93,35 @@ class Dotprops(BaseNeuron): points: np.ndarray alpha: np.ndarray - vect: np.ndarray + vect: np.ndarray k: Optional[int] soma: Optional[Union[list, np.ndarray]] #: Attributes used for neuron summary - SUMMARY_PROPS = ['type', 'name', 'k', 'units', 'n_points'] + SUMMARY_PROPS = ["type", "name", "k", "units", "n_points"] #: Attributes to be used when comparing two neurons. - EQ_ATTRIBUTES = ['name', 'n_points', 'k'] + EQ_ATTRIBUTES = ["name", "n_points", "k"] #: Temporary attributes that need clearing when neuron data changes - TEMP_ATTR = ['_memory_usage'] + TEMP_ATTR = ["_memory_usage", "_tree"] #: Core data table(s) used to calculate hash - _CORE_DATA = ['points', 'vect'] - - def __init__(self, - points: np.ndarray, - k: int, - vect: Optional[np.ndarray] = None, - alpha: Optional[np.ndarray] = None, - units: Union[pint.Unit, str] = None, - **metadata - ): + _CORE_DATA = ["points", "vect"] + + #: Property used to calculate length of neuron + _LENGTH_DATA = "points" + + def __init__( + self, + points: np.ndarray, + k: int, + vect: Optional[np.ndarray] = None, + alpha: Optional[np.ndarray] = None, + units: Union[pint.Unit, str] = None, + **metadata, + ): """Initialize Dotprops Neuron.""" super().__init__() @@ -141,13 +145,13 @@ def __truediv__(self, other, copy=True): if isinstance(other, numbers.Number) or utils.is_iterable(other): # If a number, consider this an offset for coordinates n = self.copy() if copy else self - _ = np.divide(n.points, other, out=n.points, casting='unsafe') + _ = np.divide(n.points, other, out=n.points, casting="unsafe") if n.has_connectors: - n.connectors.loc[:, ['x', 'y', 'z']] /= other + n.connectors.loc[:, ["x", "y", "z"]] /= other # Force recomputing of KDTree - if hasattr(n, '_tree'): - delattr(n, '_tree') + if hasattr(n, "_tree"): + delattr(n, "_tree") # Convert units # Note: .to_compact() throws a RuntimeWarning and returns unchanged @@ -164,13 +168,13 @@ def __mul__(self, other, copy=True): if isinstance(other, numbers.Number) or utils.is_iterable(other): # If a number, consider this an offset for coordinates n = self.copy() if copy else self - _ = np.multiply(n.points, other, out=n.points, casting='unsafe') + _ = np.multiply(n.points, other, out=n.points, casting="unsafe") if n.has_connectors: - n.connectors.loc[:, ['x', 'y', 'z']] *= other + n.connectors.loc[:, ["x", "y", "z"]] *= other # Force recomputing of KDTree - if hasattr(n, '_tree'): - delattr(n, '_tree') + if hasattr(n, "_tree"): + delattr(n, "_tree") # Convert units # Note: .to_compact() throws a RuntimeWarning and returns unchanged @@ -187,13 +191,13 @@ def __add__(self, other, copy=True): if isinstance(other, numbers.Number) or utils.is_iterable(other): # If a number, consider this an offset for coordinates n = self.copy() if copy else self - _ = np.add(n.points, other, out=n.points, casting='unsafe') + _ = np.add(n.points, other, out=n.points, casting="unsafe") if n.has_connectors: - n.connectors.loc[:, ['x', 'y', 'z']] += other + n.connectors.loc[:, ["x", "y", "z"]] += other # Force recomputing of KDTree - if hasattr(n, '_tree'): - delattr(n, '_tree') + if hasattr(n, "_tree"): + delattr(n, "_tree") return n # If another neuron, return a list of neurons @@ -206,13 +210,13 @@ def __sub__(self, other, copy=True): if isinstance(other, numbers.Number) or utils.is_iterable(other): # If a number, consider this an offset for coordinates n = self.copy() if copy else self - _ = np.subtract(n.points, other, out=n.points, casting='unsafe') + _ = np.subtract(n.points, other, out=n.points, casting="unsafe") if n.has_connectors: - n.connectors.loc[:, ['x', 'y', 'z']] -= other + n.connectors.loc[:, ["x", "y", "z"]] -= other # Force recomputing of KDTree - if hasattr(n, '_tree'): - delattr(n, '_tree') + if hasattr(n, "_tree"): + delattr(n, "_tree") return n return NotImplemented @@ -224,22 +228,21 @@ def __getstate__(self): # The KDTree from pykdtree does not like being pickled # We will have to remove it which will force it to be regenerated # after unpickling - if '_tree' in state: - if 'pykdtree' in str(type(state['_tree'])): - _ = state.pop('_tree') + if "_tree" in state: + if "pykdtree" in str(type(state["_tree"])): + _ = state.pop("_tree") return state - def __len__(self): - return len(self.points) - @property def alpha(self): """Alpha value for tangent vectors (optional).""" if isinstance(self._alpha, type(None)): if isinstance(self.k, type(None)) or (self.k <= 0): - raise ValueError('Unable to calculate `alpha` for Dotprops not ' - 'generated using k-nearest-neighbors.') + raise ValueError( + "Unable to calculate `alpha` for Dotprops not " + "generated using k-nearest-neighbors." + ) self.recalculate_tangents(self.k, inplace=True) return self._alpha @@ -249,7 +252,7 @@ def alpha(self, value): if not isinstance(value, type(None)): value = np.asarray(value) if value.ndim != 1: - raise ValueError(f'alpha must be (N, ) array, got {value.shape}') + raise ValueError(f"alpha must be (N, ) array, got {value.shape}") self._alpha = value @property @@ -259,8 +262,8 @@ def bbox(self) -> np.ndarray: mx = np.max(self.points, axis=0) if self.has_connectors: - cn_mn = np.min(self.connectors[['x', 'y', 'z']].values, axis=0) - cn_mx = np.max(self.connectors[['x', 'y', 'z']].values, axis=0) + cn_mn = np.min(self.connectors[["x", "y", "z"]].values, axis=0) + cn_mx = np.max(self.connectors[["x", "y", "z"]].values, axis=0) mn = np.min(np.vstack((mn, cn_mn)), axis=0) mx = np.max(np.vstack((mx, cn_mx)), axis=0) @@ -270,12 +273,16 @@ def bbox(self) -> np.ndarray: @property def datatables(self) -> List[str]: """Names of all DataFrames attached to this neuron.""" - return [k for k, v in self.__dict__.items() if isinstance(v, pd.DataFrame, np.ndarray)] + return [ + k + for k, v in self.__dict__.items() + if isinstance(v, pd.DataFrame, np.ndarray) + ] @property def kdtree(self): """KDTree for points.""" - if not getattr(self, '_tree', None): + if not getattr(self, "_tree", None): self._tree = KDTree(self.points) return self._tree @@ -290,7 +297,7 @@ def points(self, value): value = np.zeros((0, 3)) value = np.asarray(value) if value.ndim != 2 or value.shape[1] != 3: - raise ValueError(f'points must be (N, 3) array, got {value.shape}') + raise ValueError(f"points must be (N, 3) array, got {value.shape}") self._points = value # Also reset KDtree self._tree = None @@ -307,7 +314,7 @@ def vect(self, value): if not isinstance(value, type(None)): value = np.asarray(value) if value.ndim != 2 or value.shape[1] != 3: - raise ValueError(f'vectors must be (N, 3) array, got {value.shape}') + raise ValueError(f"vectors must be (N, 3) array, got {value.shape}") self._vect = value @property @@ -336,11 +343,11 @@ def soma(self) -> Optional[int]: if not any(soma): soma = None elif any(np.array(soma) < 0) or any(np.array(soma) > self.points.shape[0]): - logger.warning(f'Soma(s) {soma} not found in points.') + logger.warning(f"Soma(s) {soma} not found in points.") soma = None else: if 0 < soma < self.points.shape[0]: - logger.warning(f'Soma {soma} not found in node table.') + logger.warning(f"Soma {soma} not found in node table.") soma = None return soma @@ -348,7 +355,7 @@ def soma(self) -> Optional[int]: @soma.setter def soma(self, value: Union[Callable, int, None]) -> None: """Set soma.""" - if hasattr(value, '__call__'): + if hasattr(value, "__call__"): self._soma = types.MethodType(value, self) elif isinstance(value, type(None)): self._soma = None @@ -358,20 +365,22 @@ def soma(self, value: Union[Callable, int, None]) -> None: if 0 < value < self.points.shape[0]: self._soma = value else: - raise ValueError('Soma must be function, None or a valid node index.') + raise ValueError("Soma must be function, None or a valid node index.") @property def type(self) -> str: """Neuron type.""" - return 'navis.Dotprops' - - def dist_dots(self, - other: 'Dotprops', - alpha: bool = False, - distance_upper_bound: Optional[float] = None, - **kwargs) -> Union[ - Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray, np.ndarray] - ]: + return "navis.Dotprops" + + def dist_dots( + self, + other: "Dotprops", + alpha: bool = False, + distance_upper_bound: Optional[float] = None, + **kwargs, + ) -> Union[ + Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray, np.ndarray] + ]: """Query this Dotprops against another. This function is mainly for `navis.nblast`. @@ -419,9 +428,9 @@ def dist_dots(self, # Scipy's KDTree does not like the distance to be None diub = distance_upper_bound if distance_upper_bound else np.inf - fast_dists, fast_idxs = other.kdtree.query(points, - distance_upper_bound=diub, - **kwargs) + fast_dists, fast_idxs = other.kdtree.query( + points, distance_upper_bound=diub, **kwargs + ) # If upper distance we have to worry about infinite distances if distance_upper_bound: @@ -479,7 +488,7 @@ def downsample(self, factor=5, inplace=False, **kwargs): return x return None - def copy(self) -> 'Dotprops': + def copy(self) -> "Dotprops": """Return a copy of the dotprops. Returns @@ -489,17 +498,40 @@ def copy(self) -> 'Dotprops': """ # Don't copy the KDtree - when using pykdtree, copy.copy throws an # error and the construction is super fast anyway - no_copy = ['_lock', '_tree'] + no_copy = ["_lock", "_tree"] # Generate new empty neuron - note we pass vect and alpha to # prevent calculation on initialization - x = self.__class__(points=np.zeros((0, 3)), k=1, - vect=np.zeros((0, 3)), alpha=np.zeros(0)) + x = self.__class__( + points=np.zeros((0, 3)), k=1, vect=np.zeros((0, 3)), alpha=np.zeros(0) + ) # Populate with this neuron's data - x.__dict__.update({k: copy.copy(v) for k, v in self.__dict__.items() if k not in no_copy}) + x.__dict__.update( + {k: copy.copy(v) for k, v in self.__dict__.items() if k not in no_copy} + ) + + return x + + def view(self) -> "Dotprops": + """Create a view of the neuron without copying data. + + Be aware that changes to the view may affect the original neuron! + + """ + no_copy = ["_lock"] + + # Generate new empty neuron + x = self.__class__( + points=np.zeros((0, 3)), k=1, vect=np.zeros((0, 3)), alpha=np.zeros(0) + ) + + # Override with this neuron's data + x.__dict__.update({k: v for k, v in self.__dict__.items() if k not in no_copy}) return x - def drop_fluff(self, epsilon, keep_size: int = None, n_largest: int = None, inplace=False): + def drop_fluff( + self, epsilon, keep_size: int = None, n_largest: int = None, inplace=False + ): """Remove fluff from neuron. By default, this function will remove all but the largest connected @@ -534,11 +566,156 @@ def drop_fluff(self, epsilon, keep_size: int = None, n_largest: int = None, inpl Base function. See for details and examples. """ - x = morpho.drop_fluff(self, epsilon=epsilon, keep_size=keep_size, n_largest=n_largest, inplace=inplace) + x = morpho.drop_fluff( + self, + epsilon=epsilon, + keep_size=keep_size, + n_largest=n_largest, + inplace=inplace, + ) if not inplace: return x + def mask(self, mask, inplace=False, copy=False) -> "Dotprops": + """Mask neuron with given mask. + + Parameters + ---------- + mask : np.ndarray + Mask to apply. Can be: + - 1D array with boolean values + - callable that accepts a neuron and returns a mask + - string with property name + inplace : bool, optional + Whether to mask the neuron inplace. + copy : bool, optional + Whether to copy data (points, vectors, alpha, etc.) after masking. + This is useful if you want to avoid accidentally modifying + the original nodes table. + + Returns + ------- + n : Dotprops + The masked neuron. + + See Also + -------- + [`Dotprops.unmask`][navis.Dotprops.unmask] + Remove mask from neuron. + [`Dotprops.is_masked`][navis.Dotprops.is_masked] + Check if neuron is masked. + [`navis.NeuronMask`][] + Context manager for masking neurons. + + """ + if self.is_masked: + raise ValueError( + "Neuron already masked. Layering multiple masks is currently not supported, please unmask first." + ) + + n = self + if not inplace: + n = self.view() + + if callable(mask): + mask = mask(n) + elif isinstance(mask, str): + mask = getattr(n, mask) + + mask = np.asarray(mask) + + if mask.dtype != bool: + raise ValueError("Mask must be boolean array.") + elif mask.shape[0] != len(n): + raise ValueError("Mask must have same length as points.") + + n._mask = mask + n._masked_data = {} + n._masked_data["_points"] = n.points + + # Drop soma if masked out + if n.soma is not None: + if isinstance(n.soma, (list, np.ndarray)): + soma_left = n.soma[mask[n.soma]] + n._masked_data["_soma"] = n.soma + + if any(soma_left): + n.soma = soma_left + else: + n.soma = None + elif not mask[n.soma]: + n._masked_data["_soma"] = n.soma + n.soma = None + + # Apply the mask and make copy if requested + for att in ("_points", "_vect", "_alpha"): + if hasattr(n, att): + n._masked_data[att] = getattr(n, att) # save original data + setattr(n, att, getattr(n, att)[mask]) # apply mask + + if copy: + setattr(n, att, getattr(n, att).copy()) # copy masked data if requested + + if hasattr(n, "_connectors") and "point_ix" in n._connectors.columns: + n._masked_data["connectors"] = n.connectors + n._connectors = n._connectors.loc[ + n.connectors.point_ix.isin(np.arange(len(mask))[mask]) + ] + if copy: + n._connectors = n._connectors.copy() + + n._clear_temp_attr() + + return n + + def unmask(self, reset=True): + """Unmask neuron. + + Returns the neuron to its original state before masking. + + Parameters + ---------- + reset : bool + Whether to reset the neuron to its original state before masking. + If False, edits made to the neuron after masking will be kept. + + Returns + ------- + self + + See Also + -------- + [`Dotprops.is_masked`][navis.Dotprops.is_masked] + Check if neuron is masked. + [`Dotprops.mask`][navis.Dotprops.mask] + Mask neuron. + [`navis.NeuronMask`][] + Context manager for masking neurons. + + """ + if not self.is_masked: + raise ValueError("Neuron is not masked.") + + if reset: + # Unmask and reset to original state + super().unmask() + return self + + mask = self._mask + for k, v in self._masked_data.items(): + # Combine with current data + if hasattr(self, k): + v = np.concatenate((v[~mask], getattr(self, k)), axis=0) + setattr(self, k, v) + + del self._mask + del self._masked_data + + self._clear_temp_attr() + + return self + def recalculate_tangents(self, k: int, inplace=False): """Recalculate tangent vectors and alpha with a new `k`. @@ -568,8 +745,9 @@ def recalculate_tangents(self, k: int, inplace=False): # Checks and balances n_points = x.points.shape[0] if n_points < k: - raise ValueError(f"Too few points ({n_points}) to calculate {k} " - "nearest-neighbors") + raise ValueError( + f"Too few points ({n_points}) to calculate {k} " "nearest-neighbors" + ) # Create the KDTree and get the k-nearest neighbors for each point dist, ix = self.kdtree.query(x.points, k=k) @@ -597,7 +775,7 @@ def recalculate_tangents(self, k: int, inplace=False): if not inplace: return x - def snap(self, locs, to='points'): + def snap(self, locs, to="points"): """Snap xyz location(s) to closest point or synapse. Parameters @@ -626,15 +804,16 @@ def snap(self, locs, to='points'): """ locs = np.asarray(locs).astype(np.float64) - is_single = (locs.ndim == 1 and len(locs) == 3) - is_multi = (locs.ndim == 2 and locs.shape[1] == 3) + is_single = locs.ndim == 1 and len(locs) == 3 + is_multi = locs.ndim == 2 and locs.shape[1] == 3 if not is_single and not is_multi: - raise ValueError('Expected a single (x, y, z) location or a ' - '(N, 3) array of multiple locations') + raise ValueError( + "Expected a single (x, y, z) location or a " + "(N, 3) array of multiple locations" + ) - if to not in ['points', 'connectors']: - raise ValueError('`to` must be "points" or "connectors", ' - f'got {to}') + if to not in ["points", "connectors"]: + raise ValueError('`to` must be "points" or "connectors", ' f"got {to}") # Generate tree tree = graph.neuron2KDTree(self, data=to) @@ -644,9 +823,9 @@ def snap(self, locs, to='points'): return ix, dist - def to_skeleton(self, - scale_vec: Union[float, Literal['auto']] = 'auto' - ) -> core.TreeNeuron: + def to_skeleton( + self, scale_vec: Union[float, Literal["auto"]] = "auto" + ) -> core.TreeNeuron: """Turn Dotprop into a TreeNeuron. This does *not* skeletonize the neuron but rather generates a line @@ -670,12 +849,13 @@ def to_skeleton(self, TreeNeuron """ - if not isinstance(scale_vec, numbers.Number) and scale_vec != 'auto': - raise ValueError('`scale_vect` must be "auto" or a number, ' - f'got {scale_vec}') + if not isinstance(scale_vec, numbers.Number) and scale_vec != "auto": + raise ValueError( + '`scale_vect` must be "auto" or a number, ' f"got {scale_vec}" + ) - if scale_vec == 'auto': - scale_vec = self.sampling_resolution * .8 + if scale_vec == "auto": + scale_vec = self.sampling_resolution * 0.8 # Prepare segments - this is based on nat:::plot3d.dotprops halfvect = self.vect / 2 * scale_vec @@ -688,16 +868,16 @@ def to_skeleton(self, segs[1::2] = ends # Generate node table - nodes = pd.DataFrame(segs, columns=['x', 'y', 'z']) - nodes['node_id'] = nodes.index - nodes['parent_id'] = -1 - nodes.loc[1::2, 'parent_id'] = nodes.index.values[::2] + nodes = pd.DataFrame(segs, columns=["x", "y", "z"]) + nodes["node_id"] = nodes.index + nodes["parent_id"] = -1 + nodes.loc[1::2, "parent_id"] = nodes.index.values[::2] # Produce a minimal TreeNeuron tn = core.TreeNeuron(nodes, units=self.units, id=self.id) # Carry over the label - if getattr(self, '_label', None): + if getattr(self, "_label", None): tn._label = self._label # Add some other relevant attributes directly @@ -706,4 +886,3 @@ def to_skeleton(self, tn._soma = self._soma return tn - diff --git a/navis/core/masking.py b/navis/core/masking.py new file mode 100644 index 00000000..be74976e --- /dev/null +++ b/navis/core/masking.py @@ -0,0 +1,201 @@ +# This script is part of navis (http://www.github.com/navis-org/navis). +# Copyright (C) 2018 Philipp Schlegel +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. + +import numpy as np +import pandas as pd + +from .neuronlist import NeuronList +from .skeleton import TreeNeuron +from .dotprop import Dotprops +from .voxel import VoxelNeuron +from .mesh import MeshNeuron + + +__all__ = ["NeuronMask"] + +# mode = "r"\w"? +class NeuronMask: + """Mask neuron(s) by a specific property. + + Parameters + ---------- + x : Neuron/List + Neuron(s) to mask. + mask : str | array | callable | list | dict + The mask to apply: + - str: The name of the property to mask by + - array: boolean mask + - callable: A function that takes a neuron as input + and returns a boolean mask + - list: a list of the above + - dict: a dictionary mapping neuron IDs to one of the + above + copy_data : bool + Whether to copy the neuron data (e.g. node table for skeletons) + when masking. Set this to `True` if you know your code will modify + the masked data and you want to prevent changes to the original. + reset_neurons : bool + If True, reset the neurons to their original state after the + context manager exits. If False, will try to incorporate any + changes made to the masked neurons. Note that this may not + work for destructive operations. + validate_mask : bool + If True, validate `mask` against the neurons before setting it. + This is recommended but can come with an overhead (in particular + if `mask` is a callable). + + Examples + -------- + >>> import navis + >>> # Grab a few skeletons + >>> nl = navis.example_neurons(3) + >>> # Label axon and dendrites + >>> _ = navis.split_axon_dendrite(nl, label_only=True) + >>> # Mask by axon + >>> with navis.NeuronMask(nl, lambda x: x.nodes.compartment == 'axon'): + ... print("Axon cable length:", nl.cable_length * nl[0].units) + Axon cable length: [363469.75 411147.1875 390231.8125] nanometer + >>> # Mask by dendrite + >>> with navis.NeuronMask(nl, lambda x: x.nodes.compartment == 'dendrite'): + ... print("Dendrite cable length:", nl.cable_length * nl[0].units) + Dendrite cable length: [1410770.0 1612187.25 1510453.875] nanometer + + See Also + -------- + [`navis.BaseNeuron.is_masked`][] + Check if a neuron is masked. Property exists for all neuron types. + [`navis.BaseNeuron.mask`][] + Mask a neuron. Implementation details depend on the neuron type. + [`navis.BaseNeuron.unmask`][] + Unmask a neuron. Implementation details depend on the neuron type. + + """ + + def __init__(self, x, mask, reset_neurons=True, copy_data=False, validate_mask=True): + self.neurons = x + + if validate_mask: + self.mask = mask + else: + self._mask = mask + + self.reset = reset_neurons + self.copy = copy_data + + @property + def neurons(self): + return self._neurons + + @neurons.setter + def neurons(self, value): + self._neurons = NeuronList(value) + + if any(n.is_masked for n in self._neurons): + raise MaskingError("At least some neuron(s) are already masked") + + @property + def mask(self): + return self._mask + + @mask.setter + def mask(self, mask): + # Validate the mask + if isinstance(mask, str): + for n in self.neurons: + if isinstance(n, TreeNeuron): + if mask not in n.nodes.columns: + raise MaskingError(f"Neuron does not have '{mask}' column") + elif not hasattr(n, mask): + raise MaskingError(f"Neuron does not have '{mask}' attribute") + elif isinstance(mask, (list, np.ndarray, pd.Series)): + if len(self.neurons) == 1 and len(mask) != 1: + # If we only have one neuron, we can accept a single mask + # but we still want to wrap it in a list for consistency + mask = [np.asarray(mask)] + + if len(mask) != len(self.neurons): + raise MaskingError("Number of masks does not match number of neurons") + + # Validate each mask + for m, n in zip(mask, self.neurons): + validate_mask_length(m, n) + elif isinstance(mask, dict): + for n in self.neurons: + if n.id not in mask: + raise MaskingError(f"Neuron {n.id} not in mask dictionary") + validate_mask_length(mask[n.id], n) + elif callable(mask): + # If this is a function, try calling it on the first neuron + test = mask(self.neurons[0]) + if not isinstance(test, (pd.Series, np.ndarray)) or test.dtype != bool: + raise MaskingError("Callable mask must return a boolean array") + validate_mask_length(test, self.neurons[0]) + else: + raise MaskingError(f"Unexpected mask type: {type(mask)}") + + self._mask = mask + + def __enter__(self): + for i, n in enumerate(self.neurons): + if callable(self.mask): + mask = self.mask(n) + elif isinstance(self.mask, dict): + mask = self.mask[n.id] + elif isinstance(self.mask, str): + mask = self.mask + else: + mask = self.mask[i] + + n.mask(mask, copy=self.copy, inplace=True) + + return self + + def __exit__(self, *args): + for i, n in enumerate(self.neurons): + n.unmask(reset=self.reset) + + +def validate_mask_length(mask, neuron): + """Validate mask length for a given neuron.""" + if callable(mask): + mask = mask(neuron) + elif isinstance(mask, str): + if isinstance(neuron, TreeNeuron): + mask = neuron.nodes[mask] + else: + mask = getattr(neuron, mask) + + if isinstance(mask, list): + mask = np.asarray(mask) + + if not isinstance(mask, (np.ndarray, pd.Series)) or mask.dtype != bool: + raise MaskingError("Mask must be a boolean array") + + if isinstance(neuron, TreeNeuron): + if len(mask) != len(neuron.nodes): + raise MaskingError("Mask length does not match number of nodes") + elif isinstance(neuron, VoxelNeuron): + if neuron._base_data_type == "grid" and mask.shape != neuron.shape: + raise MaskingError("Mask shape does not match voxel shape") + elif len(neuron.voxels) != len(mask): + raise MaskingError("Mask length does not match number of voxels") + elif isinstance(neuron, Dotprops): + if len(mask) != len(neuron.points): + raise MaskingError("Mask length does not match number of points") + elif isinstance(neuron, MeshNeuron): + if len(mask) != len(neuron.vertices): + raise MaskingError("Mask length does not match number of vertices") + + +class MaskingError(Exception): + pass \ No newline at end of file diff --git a/navis/core/mesh.py b/navis/core/mesh.py index 233076e0..5cca369c 100644 --- a/navis/core/mesh.py +++ b/navis/core/mesh.py @@ -26,7 +26,7 @@ from typing import Union, Optional -from .. import utils, config, meshes, conversion, graph +from .. import utils, config, meshes, conversion, graph, morpho from .base import BaseNeuron from .neuronlist import NeuronList from .skeleton import TreeNeuron @@ -39,7 +39,7 @@ xxhash = None -__all__ = ['MeshNeuron'] +__all__ = ["MeshNeuron"] # Set up logging logger = config.get_logger(__name__) @@ -89,24 +89,28 @@ class MeshNeuron(BaseNeuron): soma: Optional[Union[list, np.ndarray]] #: Attributes used for neuron summary - SUMMARY_PROPS = ['type', 'name', 'units', 'n_vertices', 'n_faces'] + SUMMARY_PROPS = ["type", "name", "units", "n_vertices", "n_faces"] #: Attributes to be used when comparing two neurons. - EQ_ATTRIBUTES = ['name', 'n_vertices', 'n_faces'] + EQ_ATTRIBUTES = ["name", "n_vertices", "n_faces"] #: Temporary attributes that need clearing when neuron data changes - TEMP_ATTR = ['_memory_usage', '_trimesh', '_skeleton', '_igraph', '_graph_nx'] + TEMP_ATTR = ["_memory_usage", "_trimesh", "_skeleton", "_igraph", "_graph_nx"] #: Core data table(s) used to calculate hash - CORE_DATA = ['vertices', 'faces'] - - def __init__(self, - x, - units: Union[pint.Unit, str] = None, - process: bool = True, - validate: bool = False, - **metadata - ): + CORE_DATA = ["vertices", "faces"] + + #: Property used to calculate length of neuron + _LENGTH_DATA = "vertices" + + def __init__( + self, + x, + units: Union[pint.Unit, str] = None, + process: bool = True, + validate: bool = False, + **metadata, + ): """Initialize Mesh Neuron.""" super().__init__() @@ -117,12 +121,12 @@ def __init__(self, if isinstance(x, MeshNeuron): self.__dict__.update(x.copy().__dict__) self.vertices, self.faces = x.vertices, x.faces - elif hasattr(x, 'faces') and hasattr(x, 'vertices'): + elif hasattr(x, "faces") and hasattr(x, "vertices"): self.vertices, self.faces = x.vertices, x.faces elif isinstance(x, dict): - if 'faces' not in x or 'vertices' not in x: + if "faces" not in x or "vertices" not in x: raise ValueError('Dictionary must contain "vertices" and "faces"') - self.vertices, self.faces = x['vertices'], x['faces'] + self.vertices, self.faces = x["vertices"], x["faces"] elif isinstance(x, str) and os.path.isfile(x): m = tm.load(x) self.vertices, self.faces = m.vertices, m.faces @@ -134,10 +138,12 @@ def __init__(self, self._skeleton = TreeNeuron(x) elif isinstance(x, tuple): if len(x) != 2 or any([not isinstance(v, np.ndarray) for v in x]): - raise TypeError('Expect tuple to be two arrays: (vertices, faces)') + raise TypeError("Expect tuple to be two arrays: (vertices, faces)") self.vertices, self.faces = x[0], x[1] else: - raise utils.ConstructionError(f'Unable to construct MeshNeuron from "{type(x)}"') + raise utils.ConstructionError( + f'Unable to construct MeshNeuron from "{type(x)}"' + ) for k, v in metadata.items(): try: @@ -147,9 +153,9 @@ def __init__(self, if process and self.vertices.shape[0]: # For some reason we can't do self._trimesh at this stage - _trimesh = tm.Trimesh(self.vertices, self.faces, - process=process, - validate=validate) + _trimesh = tm.Trimesh( + self.vertices, self.faces, process=process, validate=validate + ) self.vertices = _trimesh.vertices self.faces = _trimesh.faces @@ -174,8 +180,8 @@ def __getstate__(self): state = {k: v for k, v in self.__dict__.items() if not callable(v)} # We don't need the trimesh object - if '_trimesh' in state: - _ = state.pop('_trimesh') + if "_trimesh" in state: + _ = state.pop("_trimesh") return state @@ -188,9 +194,9 @@ def __truediv__(self, other, copy=True): if isinstance(other, numbers.Number) or utils.is_iterable(other): # If a number, consider this an offset for coordinates n = self.copy() if copy else self - _ = np.divide(n.vertices, other, out=n.vertices, casting='unsafe') + _ = np.divide(n.vertices, other, out=n.vertices, casting="unsafe") if n.has_connectors: - n.connectors.loc[:, ['x', 'y', 'z']] /= other + n.connectors.loc[:, ["x", "y", "z"]] /= other # Convert units # Note: .to_compact() throws a RuntimeWarning and returns unchanged @@ -209,9 +215,9 @@ def __mul__(self, other, copy=True): if isinstance(other, numbers.Number) or utils.is_iterable(other): # If a number, consider this an offset for coordinates n = self.copy() if copy else self - _ = np.multiply(n.vertices, other, out=n.vertices, casting='unsafe') + _ = np.multiply(n.vertices, other, out=n.vertices, casting="unsafe") if n.has_connectors: - n.connectors.loc[:, ['x', 'y', 'z']] *= other + n.connectors.loc[:, ["x", "y", "z"]] *= other # Convert units # Note: .to_compact() throws a RuntimeWarning and returns unchanged @@ -229,9 +235,9 @@ def __add__(self, other, copy=True): """Implement addition for coordinates (vertices, connectors).""" if isinstance(other, numbers.Number) or utils.is_iterable(other): n = self.copy() if copy else self - _ = np.add(n.vertices, other, out=n.vertices, casting='unsafe') + _ = np.add(n.vertices, other, out=n.vertices, casting="unsafe") if n.has_connectors: - n.connectors.loc[:, ['x', 'y', 'z']] += other + n.connectors.loc[:, ["x", "y", "z"]] += other self._clear_temp_attr() @@ -245,9 +251,9 @@ def __sub__(self, other, copy=True): """Implement subtraction for coordinates (vertices, connectors).""" if isinstance(other, numbers.Number) or utils.is_iterable(other): n = self.copy() if copy else self - _ = np.subtract(n.vertices, other, out=n.vertices, casting='unsafe') + _ = np.subtract(n.vertices, other, out=n.vertices, casting="unsafe") if n.has_connectors: - n.connectors.loc[:, ['x', 'y', 'z']] -= other + n.connectors.loc[:, ["x", "y", "z"]] -= other self._clear_temp_attr() @@ -261,8 +267,8 @@ def bbox(self) -> np.ndarray: mx = np.max(self.vertices, axis=0) if self.has_connectors: - cn_mn = np.min(self.connectors[['x', 'y', 'z']].values, axis=0) - cn_mx = np.max(self.connectors[['x', 'y', 'z']].values, axis=0) + cn_mn = np.min(self.connectors[["x", "y", "z"]].values, axis=0) + cn_mx = np.max(self.connectors[["x", "y", "z"]].values, axis=0) mn = np.min(np.vstack((mn, cn_mn)), axis=0) mx = np.max(np.vstack((mx, cn_mx)), axis=0) @@ -279,7 +285,7 @@ def vertices(self, verts): if not isinstance(verts, np.ndarray): raise TypeError(f'Vertices must be numpy array, got "{type(verts)}"') if verts.ndim != 2: - raise ValueError('Vertices must be 2-dimensional array') + raise ValueError("Vertices must be 2-dimensional array") self._vertices = verts self._clear_temp_attr() @@ -293,16 +299,16 @@ def faces(self, faces): if not isinstance(faces, np.ndarray): raise TypeError(f'Faces must be numpy array, got "{type(faces)}"') if faces.ndim != 2: - raise ValueError('Faces must be 2-dimensional array') + raise ValueError("Faces must be 2-dimensional array") self._faces = faces self._clear_temp_attr() @property @temp_property - def igraph(self) -> 'igraph.Graph': + def igraph(self) -> "igraph.Graph": """iGraph representation of the vertex connectivity.""" # If igraph does not exist, create and return - if not hasattr(self, '_igraph'): + if not hasattr(self, "_igraph"): # This also sets the attribute self._igraph = graph.neuron2igraph(self, raise_not_installed=False) return self._igraph @@ -312,7 +318,7 @@ def igraph(self) -> 'igraph.Graph': def graph(self) -> nx.DiGraph: """Networkx Graph representation of the vertex connectivity.""" # If graph does not exist, create and return - if not hasattr(self, '_graph_nx'): + if not hasattr(self, "_graph_nx"): # This also sets the attribute self._graph_nx = graph.neuron2nx(self) return self._graph_nx @@ -335,13 +341,13 @@ def volume(self) -> float: @property @temp_property - def skeleton(self) -> 'TreeNeuron': + def skeleton(self) -> "TreeNeuron": """Skeleton representation of this neuron. Uses [`navis.conversion.mesh2skeleton`][]. """ - if not hasattr(self, '_skeleton'): + if not hasattr(self, "_skeleton"): self._skeleton = self.skeletonize() return self._skeleton @@ -357,7 +363,9 @@ def skeleton(self, s): @property def soma(self): """Not implemented for MeshNeurons - use `.soma_pos`.""" - raise AttributeError("MeshNeurons have a soma position (`.soma_pos`), not a soma.") + raise AttributeError( + "MeshNeurons have a soma position (`.soma_pos`), not a soma." + ) @property def soma_pos(self): @@ -365,7 +373,7 @@ def soma_pos(self): Returns `None` if no soma. """ - return getattr(self, '_soma_pos', None) + return getattr(self, "_soma_pos", None) @soma_pos.setter def soma_pos(self, value): @@ -377,38 +385,325 @@ def soma_pos(self, value): try: value = np.asarray(value).astype(np.float64).reshape(3) except BaseException: - raise ValueError(f'Unable to convert soma position "{value}" ' - f'to numeric (3, ) numpy array.') + raise ValueError( + f'Unable to convert soma position "{value}" ' + f"to numeric (3, ) numpy array." + ) self._soma_pos = value @property def type(self) -> str: """Neuron type.""" - return 'navis.MeshNeuron' + return "navis.MeshNeuron" @property @temp_property def trimesh(self): """Trimesh representation of the neuron.""" - if not getattr(self, '_trimesh', None): - self._trimesh = tm.Trimesh(vertices=self._vertices, - faces=self._faces, - process=False) + if not getattr(self, "_trimesh", None): + if hasattr(self, "extra_edges"): + # Only use TrimeshPlus if we actually need it + # to avoid unnecessarily breaking stuff elsewhere + self._trimesh = tm.Trimesh( + vertices=self._vertices, faces=self._faces, process=False + ) + self._trimesh.extra_edges = self.extra_edges + else: + self._trimesh = tm.Trimesh( + vertices=self._vertices, faces=self._faces, process=False + ) return self._trimesh - def copy(self) -> 'MeshNeuron': + def copy(self) -> "MeshNeuron": """Return a copy of the neuron.""" - no_copy = ['_lock'] + no_copy = ["_lock"] # Generate new neuron x = self.__class__(None) # Override with this neuron's data - x.__dict__.update({k: copy.copy(v) for k, v in self.__dict__.items() if k not in no_copy}) + x.__dict__.update( + {k: copy.copy(v) for k, v in self.__dict__.items() if k not in no_copy} + ) return x - def snap(self, locs, to='vertices'): + def view(self) -> "MeshNeuron": + """Create a view of the neuron without copying data. + + Be aware that changes to the view may affect the original neuron! + + """ + no_copy = ["_lock"] + + # Generate new empty neuron + x = self.__class__(None) + + # Override with this neuron's data + x.__dict__.update({k: v for k, v in self.__dict__.items() if k not in no_copy}) + + return x + + def mask(self, mask, inplace=False, copy=False): + """Mask neuron with given mask. + + Parameters + ---------- + mask : np.ndarray + Mask to apply. Can be: + - 1D array with boolean values + - string with property name + - callable that accepts a neuron and returns a valid mask + The mask can be either for vertices or faces but will ultimately be + used to mask out faces. Vertices not participating in any face + will be removed regardless of the mask. + inplace : bool, optional + Whether to mask the neuron inplace. + copy : bool, optional + Whether to copy data (faces, vertices, etc.) after masking. This + is useful if you want to avoid accidentally modifying + the original nodes table. + + Returns + ------- + self + + See Also + -------- + [`MeshNeuron.is_masked`][navis.MeshNeuron.is_masked] + Returns True if neuron is masked. + [`MeshNeuron.unmask`][navis.MeshNeuron.unmask] + Remove mask from neuron. + [`navis.NeuronMask`][] + Context manager for masking neurons. + + """ + if self.is_masked: + raise ValueError( + "Neuron already masked! Layering multiple masks is currently not supported. " + "Please either apply the existing mask or unmask first." + ) + + n = self + if not inplace: + n = self.view() + + if callable(mask): + mask = mask(n) + elif isinstance(mask, str): + mask = getattr(n, mask) + + mask = np.asarray(mask) + + # Some checks + if mask.dtype != bool: + raise ValueError("Mask must be boolean array.") + elif len(mask) not in (n.vertices.shape[0], n.faces.shape[0]): + raise ValueError("Mask length does not match number of vertices or faces.") + + # Transate vertex mask to face mask + if mask.shape[0] == n.vertices.shape[0]: + vert_mask = mask + face_mask = np.all(mask[n.faces], axis=1) + + # Apply mask + verts_new, faces_new, vert_map, face_map = morpho.subset.submesh( + n, vertex_index=np.where(vert_mask)[0], return_map=True + ) + else: + face_mask = mask + vert_mask = np.zeros(n.vertices.shape[0], dtype=bool) + vert_mask[np.unique(n.faces[face_mask])] = True + + # Apply mask + verts_new, faces_new, vert_map, face_map = morpho.subset.submesh( + n, faces_index=np.where(face_mask)[0], return_map=True + ) + + # The above will have likely dropped some vertices - we need to update the vertex mask + vert_mask = np.zeros(n.vertices.shape[0], dtype=bool) + vert_mask[np.where(vert_map != -1)[0]] = True + + # Track mask, vertices and faces before masking + n._mask = face_mask # mask is always the face mask + n._masked_data = {} + n._masked_data["_vertices"] = n._vertices + n._masked_data["_faces"] = n._faces + + # Update vertices and faces + n._vertices = verts_new + n._faces = faces_new + + # See if we can mask the mesh's skeleton as well + if hasattr(n, "_skeleton"): + # If the skeleton has a vertex map, we can use it to mask the skeleton + if hasattr(n._skeleton, "vertex_map"): + # Generate a mask for the skeleton + # (keep in mind vertex_map are node IDs, not indices) + sk_mask = n._skeleton.nodes.node_id.isin( + n._skeleton.vertex_map[vert_mask] + ) + + # Apply mask + n._skeleton.mask(sk_mask) + + # Last but not least: we need to update the vertex map + # Track the old map. N.B. we're not adding this to + # skeleton._masked_data since the remapping is done by + # the MeshNeuron itself! + n._skeleton._vertex_map_unmasked = n._skeleton.vertex_map + + # Subset the vertex map to the surviving mesh vertices + # N.B. that the node IDs don't change when masking skeletons! + n._skeleton.vertex_map = n._skeleton.vertex_map[vert_mask] + # If the skeleton has no vertex map, we have to ditch it and + # let it be regenerated when needed + else: + n._masked_data["_skeleton"] = n._skeleton + n._skeleton = None # Clear the skeleton + + # See if we need to mask any connectors as well + if hasattr(n, "_connectors"): + # Only mask if there is an actual "vertex_ind" or "face_ind" column + if "vertex_ind" in n._connectors.columns: + cn_mask = n._connectors.vertex_id.isin(np.where(vert_mask)[0]) + elif "face_ind" in n._connectors.columns: + cn_mask = n._connectors.face_id.isin(np.where(face_mask)[0]) + else: + cn_mask = None + + if cn_mask is not None: + n._masked_data["_connectors"] = n._connectors + n._connectors = n._connectors.loc[mask] + if copy: + n._connectors = n._connectors.copy() + + # Check if we need to drop the soma position + if hasattr(n, "soma_pos"): + vid = n.snap(self.soma_pos, to="vertices")[0] + if not vert_mask[vid]: + n._masked_data["_soma_pos"] = n.soma_pos + n.soma_pos = None + + # Clear temporary attributes but keep the skeleton since we already fixed that manually + n._clear_temp_attr(exclude=["_skeleton"]) + + return n + + def unmask(self, reset=True): + """Unmask neuron. + + Returns the neuron to its original state before masking. + + Parameters + ---------- + reset : bool + Whether to reset the neuron to its original state before masking. + If False, edits made to the neuron after masking will be kept. + + Returns + ------- + self + + See Also + -------- + [`MeshNeuron.is_masked`][navis.MeshNeuron.is_masked] + Returns True if neuron is masked. + [`MeshNeuron.mask`][navis.MeshNeuron.mask] + Mask neuron. + [`navis.NeuronMask`][] + Context manager for masking neurons. + + """ + if not self.is_masked: + raise ValueError("Neuron is not masked.") + + # First fix the skeleton (if it exists) + skeleton = getattr(self, "_skeleton", None) + if skeleton is not None: + # If the skeleton is not masked, it was created after the masking + # - in which case we have to throw it away because we can't recover + # the full neuron state. + if not skeleton.is_masked: + skeleton = None + else: + # Unmask the skeleton as well + skeleton.unmask(reset=reset) + + # Manually restore the vertex map + # N.B. that any destructive action (e.g. twig pruning) may have + # removed nodes from the skeleton. If that's the case, we can't + # restore the vertex map and have to ditch it. + if hasattr(skeleton, "_vertex_map_unmasked"): + skeleton.vertex_map = skeleton._vertex_map_unmasked + + # Important note: currently the skeleton gets ditched whenever the MeshNeuron + # is stale. That's mainly because (a) functions modify the mesh but not + # the vertex map and (b) re-generating the skeleton is usually cheap. + # In the long run, we need to make sure the skeleton is always in sync + # and not cleared unless that's explicitly requested. + # I'm thinking something like a MeshNeuron.sync_skeleton() method that + # can either sync the skeleton with the mesh or vice versa. + + if reset: + # Unmask and reset (this will clear temporary attributes including the skeleton) + super().unmask() + if skeleton is not None: + self.skeleton = skeleton + return self + + # Regenerate the vertex mask from the stored face mask + face_mask = self._mask + vert_mask = np.zeros(len(self._masked_data["_vertices"]), dtype=bool) + vert_mask[np.unique(self._masked_data["_faces"][face_mask])] = True + + # Generate a mesh for the masked-out data: + # The mesh prior to masking + pre_mesh = tm.Trimesh( + self._masked_data["_vertices"], self._masked_data["_faces"] + ) + # The vertices and faces that were masked out + pre_vertices, pre_faces, vert_map, face_map = morpho.subset.submesh( + pre_mesh, faces_index=np.where(~face_mask)[0], return_map=True + ) + + # Combine the two + comb = tm.util.concatenate( + [tm.Trimesh(self.vertices, self.faces), tm.Trimesh(pre_vertices, pre_faces)] + ) + + # Drop duplicate faces + comb.update_faces(comb.unique_faces()) + + # Merge vertices that are exactly the same + comb.merge_vertices(digis=6) + + # Update the neuron + self._vertices, self._faces = np.asarray(comb.vertices), np.asarray(comb.faces) + + del self._mask + del self._masked_data + + self._clear_temp_attr() + + # Check if we can re-use the skeleton + if skeleton is not None: + # Check if the vertex map is still valid + # Note to self: we could do some elaborate checks here to map old to + # most likely new vertex / nodes but that's a bit overkill for now. + if hasattr(skeleton, "vertex_map"): + if skeleton.vertex_map.shape[0] != self._vertices.shape[0]: + skeleton = None + elif skeleton.vertex_map.max() >= self._faces.shape[0]: + skeleton = None + + # If we still have a skeleton at this point, we can re-use it + if skeleton is not None: + self._skeleton = skeleton + + return self + + def snap(self, locs, to="vertices"): """Snap xyz location(s) to closest vertex or synapse. Parameters @@ -436,15 +731,16 @@ def snap(self, locs, to='vertices'): """ locs = np.asarray(locs).astype(self.vertices.dtype) - is_single = (locs.ndim == 1 and len(locs) == 3) - is_multi = (locs.ndim == 2 and locs.shape[1] == 3) + is_single = locs.ndim == 1 and len(locs) == 3 + is_multi = locs.ndim == 2 and locs.shape[1] == 3 if not is_single and not is_multi: - raise ValueError('Expected a single (x, y, z) location or a ' - '(N, 3) array of multiple locations') + raise ValueError( + "Expected a single (x, y, z) location or a " + "(N, 3) array of multiple locations" + ) - if to not in ('vertices', 'vertex', 'connectors', 'connectors'): - raise ValueError('`to` must be "vertices" or "connectors", ' - f'got {to}') + if to not in ("vertices", "vertex", "connectors", "connectors"): + raise ValueError('`to` must be "vertices" or "connectors", ' f"got {to}") # Generate tree tree = scipy.spatial.cKDTree(data=self.vertices) @@ -454,7 +750,9 @@ def snap(self, locs, to='vertices'): return ix, dist - def skeletonize(self, method='wavefront', heal=True, inv_dist=None, **kwargs) -> 'TreeNeuron': + def skeletonize( + self, method="wavefront", heal=True, inv_dist=None, **kwargs + ) -> "TreeNeuron": """Skeletonize mesh. See [`navis.conversion.mesh2skeleton`][] for details. @@ -477,10 +775,13 @@ def skeletonize(self, method='wavefront', heal=True, inv_dist=None, **kwargs) -> Returns ------- skeleton : navis.TreeNeuron + Has a `.vertex_map` attribute that maps each vertex in the + input mesh to a skeleton node ID. """ - return conversion.mesh2skeleton(self, method=method, heal=heal, - inv_dist=inv_dist, **kwargs) + return conversion.mesh2skeleton( + self, method=method, heal=heal, inv_dist=inv_dist, **kwargs + ) def validate(self, inplace=False): """Use trimesh to try and fix some common mesh issues. diff --git a/navis/core/neuronlist.py b/navis/core/neuronlist.py index ab5e6a61..a59f4f16 100644 --- a/navis/core/neuronlist.py +++ b/navis/core/neuronlist.py @@ -780,12 +780,23 @@ def summary(self, if not isinstance(N, slice): N = slice(N) - return pd.DataFrame(data=[[getattr(n, a, 'NA') for a in props] - for n in config.tqdm(self.neurons[N], - desc='Summarizing', - leave=False, - disable=not progress)], - columns=props) + summary = pd.DataFrame( + data=[ + [getattr(n, a, "NA") for a in props] + for n in config.tqdm( + self.neurons[N], + desc="Summarizing", + leave=False, + disable=not progress, + ) + ], + columns=props, + ) + + if any((n.is_masked for n in self.neurons[N])): + summary['masked'] = [n.is_masked for n in self.neurons[N]] + + return summary def itertuples(self): """Helper to mimic `pandas.DataFrame.itertuples()`.""" diff --git a/navis/core/skeleton.py b/navis/core/skeleton.py index 5bcab701..62e97302 100644 --- a/navis/core/skeleton.py +++ b/navis/core/skeleton.py @@ -25,7 +25,7 @@ from io import BufferedIOBase -from typing import Union, Callable, List, Sequence, Optional, Dict, overload +from typing import Union, Callable, List, Sequence, Optional, Dict from typing_extensions import Literal from .. import graph, morpho, utils, config, core, sampling, intersection @@ -39,7 +39,7 @@ except ModuleNotFoundError: xxhash = None -__all__ = ['TreeNeuron'] +__all__ = ["TreeNeuron"] # Set up logging logger = config.get_logger(__name__) @@ -52,15 +52,17 @@ def requires_nodes(func): """Return `None` if neuron has no nodes.""" + @functools.wraps(func) def wrapper(*args, **kwargs): self = args[0] # Return 0 - if isinstance(self.nodes, str) and self.nodes == 'NA': - return 'NA' + if isinstance(self.nodes, str) and self.nodes == "NA": + return "NA" if not isinstance(self.nodes, pd.DataFrame): return None return func(*args, **kwargs) + return wrapper @@ -95,8 +97,8 @@ class TreeNeuron(BaseNeuron): nodes: pd.DataFrame - graph: 'nx.DiGraph' - igraph: 'igraph.Graph' # type: ignore # doesn't know iGraph + graph: "nx.DiGraph" + igraph: "igraph.Graph" # type: ignore # doesn't know iGraph n_branches: int n_leafs: int @@ -117,37 +119,63 @@ class TreeNeuron(BaseNeuron): soma_detection_label: Union[float, int, str] = 1 #: Soma radius (e.g. for plotting). If string, must be column in nodes #: table. Default = 'radius'. - soma_radius: Union[float, int, str] = 'radius' + soma_radius: Union[float, int, str] = "radius" # Set default function for soma finding. Default = [`navis.morpho.find_soma`][] - _soma: Union[Callable[['TreeNeuron'], Sequence[int]], int] = morpho.find_soma + _soma: Union[Callable[["TreeNeuron"], Sequence[int]], int] = morpho.find_soma tags: Optional[Dict[str, List[int]]] = None #: Attributes to be used when comparing two neurons. - EQ_ATTRIBUTES = ['n_nodes', 'n_connectors', 'soma', 'root', - 'n_branches', 'n_leafs', 'cable_length', 'name'] + EQ_ATTRIBUTES = [ + "n_nodes", + "n_connectors", + "soma", + "root", + "n_branches", + "n_leafs", + "cable_length", + "name", + ] #: Temporary attributes that need to be regenerated when data changes. - TEMP_ATTR = ['_igraph', '_graph_nx', '_segments', '_small_segments', - '_geodesic_matrix', 'centrality_method', '_simple', - '_cable_length', '_memory_usage', '_adjacency_matrix'] + TEMP_ATTR = [ + "_igraph", + "_graph_nx", + "_segments", + "_small_segments", + "_geodesic_matrix", + "centrality_method", + "_simple", + "_cable_length", + "_memory_usage", + "_adjacency_matrix", + ] #: Attributes used for neuron summary - SUMMARY_PROPS = ['type', 'name', 'n_nodes', 'n_connectors', 'n_branches', - 'n_leafs', 'cable_length', 'soma', 'units'] + SUMMARY_PROPS = [ + "type", + "name", + "n_nodes", + "n_connectors", + "n_branches", + "n_leafs", + "cable_length", + "soma", + "units", + ] #: Core data table(s) used to calculate hash - CORE_DATA = ['nodes:node_id,parent_id,x,y,z'] - - def __init__(self, - x: Union[pd.DataFrame, - BufferedIOBase, - str, - 'TreeNeuron', - nx.DiGraph], - units: Union[pint.Unit, str] = None, - **metadata - ): + CORE_DATA = ["nodes:node_id,parent_id,x,y,z"] + + #: Property used to calculate length of neuron + _LENGTH_DATA = "nodes" + + def __init__( + self, + x: Union[pd.DataFrame, BufferedIOBase, str, "TreeNeuron", nx.DiGraph], + units: Union[pint.Unit, str] = None, + **metadata, + ): """Initialize Skeleton Neuron.""" super().__init__() @@ -157,10 +185,12 @@ def __init__(self, if isinstance(x, pd.DataFrame): self.nodes = x elif isinstance(x, pd.Series): - if not hasattr(x, 'nodes'): - raise ValueError('pandas.Series must have `nodes` entry.') + if not hasattr(x, "nodes"): + raise ValueError("pandas.Series must have `nodes` entry.") elif not isinstance(x.nodes, pd.DataFrame): - raise TypeError(f'Nodes must be pandas DataFrame, got "{type(x.nodes)}"') + raise TypeError( + f'Nodes must be pandas DataFrame, got "{type(x.nodes)}"' + ) self.nodes = x.nodes metadata.update(x.to_dict()) elif isinstance(x, nx.Graph): @@ -182,13 +212,15 @@ def __init__(self, elif isinstance(x, tuple): # Tuple of vertices and edges if len(x) != 2: - raise ValueError('Tuple must have 2 elements: vertices and edges.') + raise ValueError("Tuple must have 2 elements: vertices and edges.") self.nodes = graph.edges2neuron(edges=x[1], vertices=x[0]).nodes elif isinstance(x, type(None)): # This is a essentially an empty neuron pass else: - raise utils.ConstructionError(f'Unable to construct TreeNeuron from "{type(x)}"') + raise utils.ConstructionError( + f'Unable to construct TreeNeuron from "{type(x)}"' + ) for k, v in metadata.items(): try: @@ -218,21 +250,23 @@ def __truediv__(self, other, copy=True): if len(set(other)) == 1: other == other[0] elif len(other) != 4: - raise ValueError('Division by list/array requires 4 ' - 'divisors for x/y/z and radius - ' - f'got {len(other)}') + raise ValueError( + "Division by list/array requires 4 " + "divisors for x/y/z and radius - " + f"got {len(other)}" + ) # If a number, consider this an offset for coordinates n = self.copy() if copy else self - n.nodes[['x', 'y', 'z', 'radius']] /= other + n.nodes[["x", "y", "z", "radius"]] /= other # At this point we can ditch any 4th unit if utils.is_iterable(other): other = other[:3] if n.has_connectors: - n.connectors[['x', 'y', 'z']] /= other + n.connectors[["x", "y", "z"]] /= other - if hasattr(n, 'soma_radius'): + if hasattr(n, "soma_radius"): if isinstance(n.soma_radius, numbers.Number): n.soma_radius /= other @@ -243,7 +277,7 @@ def __truediv__(self, other, copy=True): warnings.simplefilter("ignore") n.units = (n.units * other).to_compact() - n._clear_temp_attr(exclude=['classify_nodes']) + n._clear_temp_attr(exclude=["classify_nodes"]) return n return NotImplemented @@ -255,21 +289,23 @@ def __mul__(self, other, copy=True): if len(set(other)) == 1: other == other[0] elif len(other) != 4: - raise ValueError('Multiplication by list/array requires 4' - 'multipliers for x/y/z and radius - ' - f'got {len(other)}') + raise ValueError( + "Multiplication by list/array requires 4" + "multipliers for x/y/z and radius - " + f"got {len(other)}" + ) # If a number, consider this an offset for coordinates n = self.copy() if copy else self - n.nodes[['x', 'y', 'z', 'radius']] *= other + n.nodes[["x", "y", "z", "radius"]] *= other # At this point we can ditch any 4th unit if utils.is_iterable(other): other = other[:3] if n.has_connectors: - n.connectors[['x', 'y', 'z']] *= other + n.connectors[["x", "y", "z"]] *= other - if hasattr(n, 'soma_radius'): + if hasattr(n, "soma_radius"): if isinstance(n.soma_radius, numbers.Number): n.soma_radius *= other @@ -280,7 +316,7 @@ def __mul__(self, other, copy=True): warnings.simplefilter("ignore") n.units = (n.units / other).to_compact() - n._clear_temp_attr(exclude=['classify_nodes']) + n._clear_temp_attr(exclude=["classify_nodes"]) return n return NotImplemented @@ -292,19 +328,21 @@ def __add__(self, other, copy=True): if len(set(other)) == 1: other == other[0] elif len(other) != 3: - raise ValueError('Addition by list/array requires 3' - 'multipliers for x/y/z coordinates ' - f'got {len(other)}') + raise ValueError( + "Addition by list/array requires 3" + "multipliers for x/y/z coordinates " + f"got {len(other)}" + ) # If a number, consider this an offset for coordinates n = self.copy() if copy else self - n.nodes[['x', 'y', 'z']] += other + n.nodes[["x", "y", "z"]] += other # Do the connectors if n.has_connectors: - n.connectors[['x', 'y', 'z']] += other + n.connectors[["x", "y", "z"]] += other - n._clear_temp_attr(exclude=['classify_nodes']) + n._clear_temp_attr(exclude=["classify_nodes"]) return n # If another neuron, return a list of neurons elif isinstance(other, BaseNeuron): @@ -319,19 +357,21 @@ def __sub__(self, other, copy=True): if len(set(other)) == 1: other == other[0] elif len(other) != 3: - raise ValueError('Addition by list/array requires 3' - 'multipliers for x/y/z coordinates ' - f'got {len(other)}') + raise ValueError( + "Addition by list/array requires 3" + "multipliers for x/y/z coordinates " + f"got {len(other)}" + ) # If a number, consider this an offset for coordinates n = self.copy() if copy else self - n.nodes[['x', 'y', 'z']] -= other + n.nodes[["x", "y", "z"]] -= other # Do the connectors if n.has_connectors: - n.connectors[['x', 'y', 'z']] -= other + n.connectors[["x", "y", "z"]] -= other - n._clear_temp_attr(exclude=['classify_nodes']) + n._clear_temp_attr(exclude=["classify_nodes"]) return n return NotImplemented @@ -341,10 +381,10 @@ def __getstate__(self): # Pickling the graphs actually takes longer than regenerating them # from scratch - if '_graph_nx' in state: - _ = state.pop('_graph_nx') - if '_igraph' in state: - _ = state.pop('_igraph') + if "_graph_nx" in state: + _ = state.pop("_graph_nx") + if "_igraph" in state: + _ = state.pop("_igraph") return state @@ -352,7 +392,7 @@ def __getstate__(self): @temp_property def adjacency_matrix(self): """Adjacency matrix of the skeleton.""" - if not hasattr(self, '_adjacency_matrix'): + if not hasattr(self, "_adjacency_matrix"): self._adjacency_matrix = graph.skeleton_adjacency_matrix(self) return self._adjacency_matrix @@ -360,7 +400,7 @@ def adjacency_matrix(self): @requires_nodes def vertices(self) -> np.ndarray: """Vertices of the skeleton.""" - return self.nodes[['x', 'y', 'z']].values + return self.nodes[["x", "y", "z"]].values @property @requires_nodes @@ -374,7 +414,7 @@ def edges(self) -> np.ndarray: """ not_root = self.nodes[self.nodes.parent_id >= 0] - return not_root[['node_id', 'parent_id']].values + return not_root[["node_id", "parent_id"]].values @property @requires_nodes @@ -387,7 +427,7 @@ def edge_coords(self) -> np.ndarray: Same but with node IDs instead of x/y/z coordinates. """ - locs = self.nodes.set_index('node_id')[['x', 'y', 'z']] + locs = self.nodes.set_index("node_id")[["x", "y", "z"]] edges = self.edges edges_co = np.zeros((edges.shape[0], 2, 3)) edges_co[:, 0, :] = locs.loc[edges[:, 0]].values @@ -396,10 +436,10 @@ def edge_coords(self) -> np.ndarray: @property @temp_property - def igraph(self) -> 'igraph.Graph': + def igraph(self) -> "igraph.Graph": """iGraph representation of this neuron.""" # If igraph does not exist, create and return - if not hasattr(self, '_igraph'): + if not hasattr(self, "_igraph"): # This also sets the attribute return self.get_igraph() return self._igraph @@ -409,7 +449,7 @@ def igraph(self) -> 'igraph.Graph': def graph(self) -> nx.DiGraph: """Networkx Graph representation of this neuron.""" # If graph does not exist, create and return - if not hasattr(self, '_graph_nx'): + if not hasattr(self, "_graph_nx"): # This also sets the attribute return self.get_graph_nx() return self._graph_nx @@ -419,7 +459,7 @@ def graph(self) -> nx.DiGraph: def geodesic_matrix(self): """Matrix with geodesic (along-the-arbor) distance between nodes.""" # If matrix has not yet been generated or needs update - if not hasattr(self, '_geodesic_matrix'): + if not hasattr(self, "_geodesic_matrix"): # (Re-)generate matrix self._geodesic_matrix = graph.geodesic_matrix(self) @@ -429,7 +469,7 @@ def geodesic_matrix(self): @requires_nodes def leafs(self) -> pd.DataFrame: """Leaf node table.""" - return self.nodes[self.nodes['type'] == 'end'] + return self.nodes[self.nodes["type"] == "end"] @property @requires_nodes @@ -441,7 +481,7 @@ def ends(self): @requires_nodes def branch_points(self): """Branch node table.""" - return self.nodes[self.nodes['type'] == 'branch'] + return self.nodes[self.nodes["type"] == "branch"] @property def nodes(self) -> pd.DataFrame: @@ -461,20 +501,24 @@ def nodes(self, v): def _set_nodes(self, v): # Redefine this function in subclass to change validation - self._nodes = utils.validate_table(v, - required=[('node_id', 'rowId', 'node', 'treenode_id', 'PointNo'), - ('parent_id', 'link', 'parent', 'Parent'), - ('x', 'X'), - ('y', 'Y'), - ('z', 'Z')], - rename=True, - optional={('radius', 'W'): 0}, - restrict=False) + self._nodes = utils.validate_table( + v, + required=[ + ("node_id", "rowId", "node", "treenode_id", "PointNo"), + ("parent_id", "link", "parent", "Parent"), + ("x", "X"), + ("y", "Y"), + ("z", "Z"), + ], + rename=True, + optional={("radius", "W"): 0}, + restrict=False, + ) # Make sure we don't end up with object dtype anywhere as this can # cause problems - for c in ('node_id', 'parent_id'): - if self._nodes[c].dtype == 'O': + for c in ("node_id", "parent_id"): + if self._nodes[c].dtype == "O": self._nodes[c] = self._nodes[c].astype(int) graph.classify_nodes(self) @@ -504,8 +548,7 @@ def is_tree(self) -> bool: @property def subtrees(self) -> List[List[int]]: """List of subtrees. Sorted by size as sets of node IDs.""" - return sorted(graph._connected_components(self), - key=lambda x: -len(x)) + return sorted(graph._connected_components(self), key=lambda x: -len(x)) @property def connectors(self) -> pd.DataFrame: @@ -514,7 +557,7 @@ def connectors(self) -> pd.DataFrame: def _get_connectors(self) -> pd.DataFrame: # Redefine this function in subclass to change how nodes are retrieved - return getattr(self, '_connectors', None) + return getattr(self, "_connectors", None) @connectors.setter def connectors(self, v): @@ -528,15 +571,19 @@ def _set_connectors(self, v): if isinstance(v, type(None)): self._connectors = None else: - self._connectors = utils.validate_table(v, - required=[('connector_id', 'id'), - ('node_id', 'rowId', 'node', 'treenode_id'), - ('x', 'X'), - ('y', 'Y'), - ('z', 'Z'), - ('type', 'relation', 'label', 'prepost')], - rename=True, - restrict=False) + self._connectors = utils.validate_table( + v, + required=[ + ("connector_id", "id"), + ("node_id", "rowId", "node", "treenode_id"), + ("x", "X"), + ("y", "Y"), + ("z", "Z"), + ("type", "relation", "label", "prepost"), + ], + rename=True, + restrict=False, + ) @property @requires_nodes @@ -550,8 +597,9 @@ def cycles(self) -> Optional[List[int]]: """ try: - c = nx.find_cycle(self.graph, - source=self.nodes[self.nodes.type == 'end'].node_id.values) + c = nx.find_cycle( + self.graph, source=self.nodes[self.nodes.type == "end"].node_id.values + ) return c except nx.exception.NetworkXNoCycle: return None @@ -559,16 +607,16 @@ def cycles(self) -> Optional[List[int]]: raise @property - def simple(self) -> 'TreeNeuron': + def simple(self) -> "TreeNeuron": """Simplified representation consisting only of root, branch points and leafs.""" - if not hasattr(self, '_simple'): + if not hasattr(self, "_simple"): self._simple = self.copy() # Make sure we don't have a soma, otherwise that node will be preserved self._simple.soma = None # Downsample - self._simple.downsample(float('inf'), inplace=True) + self._simple.downsample(float("inf"), inplace=True) return self._simple @property @@ -592,11 +640,11 @@ def soma(self) -> Optional[Union[str, int]]: if all(pd.isnull(soma)): soma = None elif not any(self.nodes.node_id.isin(soma)): - logger.warning(f'Soma(s) {soma} not found in node table.') + logger.warning(f"Soma(s) {soma} not found in node table.") soma = None else: if soma not in self.nodes.node_id.values: - logger.warning(f'Soma {soma} not found in node table.') + logger.warning(f"Soma {soma} not found in node table.") soma = None return soma @@ -604,7 +652,7 @@ def soma(self) -> Optional[Union[str, int]]: @soma.setter def soma(self, value: Union[Callable, int, None]) -> None: """Set soma.""" - if hasattr(value, '__call__'): + if hasattr(value, "__call__"): self._soma = types.MethodType(value, self) elif isinstance(value, type(None)): self._soma = None @@ -614,7 +662,7 @@ def soma(self, value: Union[Callable, int, None]) -> None: if value in self.nodes.node_id.values: self._soma = value else: - raise ValueError('Soma must be function, None or a valid node ID.') + raise ValueError("Soma must be function, None or a valid node ID.") @property def soma_pos(self) -> Optional[Sequence]: @@ -633,7 +681,7 @@ def soma_pos(self) -> Optional[Sequence]: else: soma = utils.make_iterable(soma) - return self.nodes.loc[self.nodes.node_id.isin(soma), ['x', 'y', 'z']].values + return self.nodes.loc[self.nodes.node_id.isin(soma), ["x", "y", "z"]].values @soma_pos.setter def soma_pos(self, value: Sequence) -> None: @@ -645,16 +693,20 @@ def soma_pos(self, value: Sequence) -> None: try: value = np.asarray(value).astype(np.float64).reshape(3) except BaseException: - raise ValueError(f'Unable to convert soma position "{value}" ' - f'to numeric (3, ) numpy array.') + raise ValueError( + f'Unable to convert soma position "{value}" ' + f"to numeric (3, ) numpy array." + ) # Generate tree - id, dist = self.snap(value, to='nodes') + id, dist = self.snap(value, to="nodes") # A sanity check if dist > (self.sampling_resolution * 10): - logger.warning(f'New soma position for {self.id} is suspiciously ' - f'far away from the closest node: {dist}') + logger.warning( + f"New soma position for {self.id} is suspiciously " + f"far away from the closest node: {dist}" + ) self.soma = id @@ -673,26 +725,26 @@ def root(self, value: Union[int, List[int]]) -> None: @property def type(self) -> str: """Neuron type.""" - return 'navis.TreeNeuron' + return "navis.TreeNeuron" @property @requires_nodes def n_branches(self) -> Optional[int]: """Number of branch points.""" - return self.nodes[self.nodes.type == 'branch'].shape[0] + return self.nodes[self.nodes.type == "branch"].shape[0] @property @requires_nodes def n_leafs(self) -> Optional[int]: """Number of leaf nodes.""" - return self.nodes[self.nodes.type == 'end'].shape[0] + return self.nodes[self.nodes.type == "end"].shape[0] @property @temp_property @add_units(compact=True) def cable_length(self) -> Union[int, float]: """Cable length.""" - if not hasattr(self, '_cable_length'): + if not hasattr(self, "_cable_length"): self._cable_length = morpho.cable_length(self) return self._cable_length @@ -700,17 +752,21 @@ def cable_length(self) -> Union[int, float]: @add_units(compact=True, power=2) def surface_area(self) -> float: """Radius-based lateral surface area.""" - if 'radius' not in self.nodes.columns: - raise ValueError(f'Neuron {self.id} does not have radius information') + if "radius" not in self.nodes.columns: + raise ValueError(f"Neuron {self.id} does not have radius information") if any(self.nodes.radius < 0): - logger.warning(f'Neuron {self.id} has negative radii - area will not be correct.') + logger.warning( + f"Neuron {self.id} has negative radii - area will not be correct." + ) if any(self.nodes.radius.isnull()): - logger.warning(f'Neuron {self.id} has NaN radii - area will not be correct.') + logger.warning( + f"Neuron {self.id} has NaN radii - area will not be correct." + ) # Generate radius dict - radii = self.nodes.set_index('node_id').radius.to_dict() + radii = self.nodes.set_index("node_id").radius.to_dict() # Drop root node(s) not_root = self.nodes.parent_id >= 0 # For each cylinder get the height @@ -721,23 +777,27 @@ def surface_area(self) -> float: r1 = nodes.node_id.map(radii).values r2 = nodes.parent_id.map(radii).values - return (np.pi * (r1 + r2) * np.sqrt( (r1-r2)**2 + h**2)).sum() + return (np.pi * (r1 + r2) * np.sqrt((r1 - r2) ** 2 + h**2)).sum() @property @add_units(compact=True, power=3) def volume(self) -> float: """Radius-based volume.""" - if 'radius' not in self.nodes.columns: - raise ValueError(f'Neuron {self.id} does not have radius information') + if "radius" not in self.nodes.columns: + raise ValueError(f"Neuron {self.id} does not have radius information") if any(self.nodes.radius < 0): - logger.warning(f'Neuron {self.id} has negative radii - volume will not be correct.') + logger.warning( + f"Neuron {self.id} has negative radii - volume will not be correct." + ) if any(self.nodes.radius.isnull()): - logger.warning(f'Neuron {self.id} has NaN radii - volume will not be correct.') + logger.warning( + f"Neuron {self.id} has NaN radii - volume will not be correct." + ) # Generate radius dict - radii = self.nodes.set_index('node_id').radius.to_dict() + radii = self.nodes.set_index("node_id").radius.to_dict() # Drop root node(s) not_root = self.nodes.parent_id >= 0 # For each cylinder get the height @@ -748,17 +808,17 @@ def volume(self) -> float: r1 = nodes.node_id.map(radii).values r2 = nodes.parent_id.map(radii).values - return (1/3 * np.pi * (r1**2 + r1 * r2 + r2**2) * h).sum() + return (1 / 3 * np.pi * (r1**2 + r1 * r2 + r2**2) * h).sum() @property def bbox(self) -> np.ndarray: """Bounding box (includes connectors).""" - mn = np.min(self.nodes[['x', 'y', 'z']].values, axis=0) - mx = np.max(self.nodes[['x', 'y', 'z']].values, axis=0) + mn = np.min(self.nodes[["x", "y", "z"]].values, axis=0) + mx = np.max(self.nodes[["x", "y", "z"]].values, axis=0) if self.has_connectors: - cn_mn = np.min(self.connectors[['x', 'y', 'z']].values, axis=0) - cn_mx = np.max(self.connectors[['x', 'y', 'z']].values, axis=0) + cn_mn = np.min(self.connectors[["x", "y", "z"]].values, axis=0) + cn_mx = np.max(self.connectors[["x", "y", "z"]].values, axis=0) mn = np.min(np.vstack((mn, cn_mn)), axis=0) mx = np.max(np.vstack((mx, cn_mx)), axis=0) @@ -780,9 +840,9 @@ def sampling_resolution(self) -> float: def segments(self) -> List[list]: """Neuron broken down into linear segments (see also `.small_segments`).""" # Calculate if required - if not hasattr(self, '_segments'): + if not hasattr(self, "_segments"): # This also sets the attribute - self._segments = self._get_segments(how='length') + self._segments = self._get_segments(how="length") return self._segments @property @@ -790,19 +850,18 @@ def segments(self) -> List[list]: def small_segments(self) -> List[list]: """Neuron broken down into small linear segments (see also `.segments`).""" # Calculate if required - if not hasattr(self, '_small_segments'): + if not hasattr(self, "_small_segments"): # This also sets the attribute - self._small_segments = self._get_segments(how='break') + self._small_segments = self._get_segments(how="break") return self._small_segments - def _get_segments(self, - how: Union[Literal['length'], - Literal['break']] = 'length' - ) -> List[list]: + def _get_segments( + self, how: Union[Literal["length"], Literal["break"]] = "length" + ) -> List[list]: """Generate segments for neuron.""" - if how == 'length': + if how == "length": return graph._generate_segments(self) - elif how == 'break': + elif how == "break": return graph._break_segments(self) else: raise ValueError(f'Unknown method: "{how}"') @@ -830,11 +889,11 @@ def _clear_temp_attr(self, exclude: list = []) -> None: elif self._soma not in self.nodes.node_id.values: self.soma = None - if 'classify_nodes' not in exclude: + if "classify_nodes" not in exclude: # Reclassify nodes graph.classify_nodes(self, inplace=True) - def copy(self, deepcopy: bool = False) -> 'TreeNeuron': + def copy(self, deepcopy: bool = False) -> "TreeNeuron": """Return a copy of the neuron. Parameters @@ -850,17 +909,19 @@ def copy(self, deepcopy: bool = False) -> 'TreeNeuron': TreeNeuron """ - no_copy = ['_lock'] + no_copy = ["_lock"] # Generate new empty neuron x = self.__class__(None) # Populate with this neuron's data - x.__dict__.update({k: copy.copy(v) for k, v in self.__dict__.items() if k not in no_copy}) + x.__dict__.update( + {k: copy.copy(v) for k, v in self.__dict__.items() if k not in no_copy} + ) # Copy graphs only if neuron is not stale if not self.is_stale: - if '_graph_nx' in self.__dict__: + if "_graph_nx" in self.__dict__: x._graph_nx = self._graph_nx.copy(as_view=deepcopy is not True) - if '_igraph' in self.__dict__: + if "_igraph" in self.__dict__: if self._igraph is not None: # This is pretty cheap, so we will always make a deep copy x._igraph = self._igraph.copy() @@ -869,6 +930,22 @@ def copy(self, deepcopy: bool = False) -> 'TreeNeuron': return x + def view(self) -> "TreeNeuron": + """Create a view of the neuron without copying data. + + Be aware that changes to the view may affect the original neuron! + + """ + no_copy = ["_lock"] + + # Generate new empty neuron + x = self.__class__(None) + + # Override with this neuron's data + x.__dict__.update({k: v for k, v in self.__dict__.items() if k not in no_copy}) + + return x + def get_graph_nx(self) -> nx.DiGraph: """Calculate and return networkX representation of neuron. @@ -883,7 +960,7 @@ def get_graph_nx(self) -> nx.DiGraph: self._graph_nx = graph.neuron2nx(self) return self._graph_nx - def get_igraph(self) -> 'igraph.Graph': # type: ignore + def get_igraph(self) -> "igraph.Graph": # type: ignore """Calculate and return iGraph representation of neuron. Once calculated stored as `.igraph`. Call function again to update @@ -901,11 +978,192 @@ def get_igraph(self) -> 'igraph.Graph': # type: ignore self._igraph = graph.neuron2igraph(self, raise_not_installed=False) return self._igraph - @overload - def resample(self, resample_to: int, inplace: Literal[False]) -> 'TreeNeuron': ... + def mask(self, mask, inplace=False, copy=False) -> "TreeNeuron": + """Mask neuron with given mask. + + Parameters + ---------- + mask : np.ndarray + Mask to apply. Can be: + - 1D array with boolean values + - callable that accepts a neuron and returns a mask + - string with column name in nodes table + inplace : bool, optional + Whether to mask the neuron inplace. + copy : bool, optional + Whether to copy data such as the node table after masking. This + is useful if you want to avoid accidentally modifying + the original nodes table. + + Returns + ------- + n : TreeNeuron + The masked neuron. + + See Also + -------- + [`navis.MeshNeuron.unmask`][] + Remove mask from neuron. + [`navis.NeuronMask`][] + Context manager for masking neurons. + + """ + if self.is_masked: + raise ValueError( + "Neuron already masked. Layering multiple masks is currently not supported, please unmask first." + ) + + n = self + if not inplace: + n = self.view() - @overload - def resample(self, resample_to: int, inplace: Literal[True]) -> None: ... + if callable(mask): + mask = mask(n) + elif isinstance(mask, str): + mask = n.nodes[mask].values + + mask = np.asarray(mask) + + if mask.dtype != bool: + raise ValueError("Mask must be boolean array.") + elif mask.shape[0] != n.nodes.shape[0]: + raise ValueError("Mask must have same length as nodes table.") + + n._mask = mask + n._masked_data = {} + n._masked_data["_nodes"] = n.nodes + + # N.B. we're directly setting `._nodes`` to avoid overhead from checks + n._nodes = n._nodes.loc[mask].drop("type", axis=1, errors="ignore") + if copy: + n._nodes = n._nodes.copy() + + # See if any parent IDs have ceased to exist + missing_parents = ~n._nodes.parent_id.isin(n._nodes.node_id) & ( + n._nodes.parent_id >= 0 + ) + if any(missing_parents): + n.nodes.loc[missing_parents, "parent_id"] = -1 + + if hasattr(n, "_connectors"): + n._masked_data["_connectors"] = n.connectors + n._connectors = n._connectors.loc[ + n._connectors.node_id.isin(n.nodes.node_id) + ] + if copy: + n._connectors = n._connectors.copy() + + n._clear_temp_attr() + + return n + + def unmask(self, reset=True): + """Unmask neuron. + + Returns the neuron to its original state before masking. + + Parameters + ---------- + reset : bool + Whether to reset the neuron to its original state before masking. + If False, edits made to the neuron after masking will be kept. + + Returns + ------- + self + + See Also + -------- + [`TreeNeuron.is_masked`][navis.TreeNeuron.is_masked] + Check if neuron is masked. + [`TreeNeuron.mask`][navis.TreeNeuron.mask] + Mask neuron. + [`navis.NeuronMask`][] + Context manager for masking neurons. + + """ + if not self.is_masked: + raise ValueError("Neuron is not masked.") + + if reset: + # Unmask and reset to original state + super().unmask() + return self + + mask = self._mask + + # Combine post-mask data with original data + post_nodes = self.nodes + pre_nodes = self._masked_data["_nodes"].iloc[~mask] + + # We need to take care of a few things: + # 1. Make sure don't have any duplicate node IDs + duplicates = post_nodes.node_id.values[ + post_nodes.node_id.isin(pre_nodes.node_id) + ] + if any(duplicates): + # Start with a new max index + max_id = max(post_nodes.node_id.max(), pre_nodes.node_id.max()) + 1 + # New indices + new_ix = dict(zip(duplicates, range(max_id, max_id + len(duplicates)))) + # Update node IDs and parent IDs + post_nodes["node_id"] = post_nodes.node_id.replace(new_ix) + post_nodes["parent_id"] = post_nodes.parent_id.replace(new_ix) + + # Concatenate + self._nodes = pd.concat([pre_nodes, post_nodes], ignore_index=True) + + # 2. Re-connect the root nodes of the masked data + post_roots = post_nodes[post_nodes.parent_id < 0].node_id.values + pre_parents = ( + self._masked_data["_nodes"].set_index("node_id").parent_id.to_dict() + ) + to_fix = {} + for r in post_roots: + # Skip if this root is not in pre_parents + if r not in pre_parents: + continue + # Skip if this was also a root in the pre-masked data + if pre_parents[r] < 0: + continue + # Skip if the old parent does not exist anymore + if pre_parents[r] not in self.nodes.node_id.values: + continue + # If we made it here, there is a new root that we can connect to the old parent + to_fix[r] = pre_parents[r] + if to_fix: + # Fix parent IDs + self._nodes["parent_id"] = self._nodes["parent_id"].replace(to_fix) + + # 3. See if any parent IDs have ceased to exist + missing_parents = ~self._nodes.parent_id.isin(self._nodes.node_id) & ( + self._nodes.parent_id >= 0 + ) + if any(missing_parents): + self.nodes.loc[missing_parents, "parent_id"] = -1 + + # Force nodes to be re-classified + self.nodes.drop("type", axis=1, errors="ignore", inplace=True) + + # TODO: Make sure that edges have a consistent orientation + # (not sure this is much of a problem) + + # Connectors + # TODO: check `node_id` of connectors too + if "_connectors" in self._masked_data: + if not hasattr(self, "_connectors"): + self._connectors = self._masked_data["_connectors"] + else: + cn = self._masked_data["_connectors"] + cn = cn.loc[cn.node_id.isin(pre_nodes.node_id.values)] + self._connectors = pd.concat([self._connectors, cn], ignore_index=True) + + del self._mask + del self._masked_data + + self._clear_temp_attr() + + return self def resample(self, resample_to, inplace=False): """Resample neuron to given resolution. @@ -939,18 +1197,6 @@ def resample(self, resample_to, inplace=False): return x return None - @overload - def downsample(self, - factor: float, - inplace: Literal[False], - **kwargs) -> 'TreeNeuron': ... - - @overload - def downsample(self, - factor: float, - inplace: Literal[True], - **kwargs) -> None: ... - def downsample(self, factor=5, inplace=False, **kwargs): """Downsample the neuron by given factor. @@ -987,9 +1233,9 @@ def downsample(self, factor=5, inplace=False, **kwargs): return x return None - def reroot(self, - new_root: Union[int, str], - inplace: bool = False) -> Optional['TreeNeuron']: + def reroot( + self, new_root: Union[int, str], inplace: bool = False + ) -> Optional["TreeNeuron"]: """Reroot neuron to given node ID or node tag. Parameters @@ -1020,9 +1266,9 @@ def reroot(self, return x return None - def prune_distal_to(self, - node: Union[str, int], - inplace: bool = False) -> Optional['TreeNeuron']: + def prune_distal_to( + self, node: Union[str, int], inplace: bool = False + ) -> Optional["TreeNeuron"]: """Cut off nodes distal to given nodes. Parameters @@ -1047,7 +1293,7 @@ def prune_distal_to(self, node = utils.make_iterable(node, force_type=None) for n in node: - prox = graph.cut_skeleton(x, n, ret='proximal')[0] + prox = graph.cut_skeleton(x, n, ret="proximal")[0] # Reinitialise with proximal data x.__init__(prox) # type: ignore # Cannot access "__init__" directly # Remove potential "left over" attributes (happens if we use a copy) @@ -1057,9 +1303,9 @@ def prune_distal_to(self, return x return None - def prune_proximal_to(self, - node: Union[str, int], - inplace: bool = False) -> Optional['TreeNeuron']: + def prune_proximal_to( + self, node: Union[str, int], inplace: bool = False + ) -> Optional["TreeNeuron"]: """Remove nodes proximal to given node. Reroots neuron to cut node. Parameters @@ -1084,7 +1330,7 @@ def prune_proximal_to(self, node = utils.make_iterable(node, force_type=None) for n in node: - dist = graph.cut_skeleton(x, n, ret='distal')[0] + dist = graph.cut_skeleton(x, n, ret="distal")[0] # Reinitialise with distal data x.__init__(dist) # type: ignore # Cannot access "__init__" directly # Remove potential "left over" attributes (happens if we use a copy) @@ -1097,9 +1343,9 @@ def prune_proximal_to(self, return x return None - def prune_by_strahler(self, - to_prune: Union[int, List[int], slice], - inplace: bool = False) -> Optional['TreeNeuron']: + def prune_by_strahler( + self, to_prune: Union[int, List[int], slice], inplace: bool = False + ) -> Optional["TreeNeuron"]: """Prune neuron based on [Strahler order](https://en.wikipedia.org/wiki/Strahler_number). Will reroot neuron to soma if possible. @@ -1132,8 +1378,7 @@ def prune_by_strahler(self, else: x = self.copy() - morpho.prune_by_strahler( - x, to_prune=to_prune, reroot_soma=True, inplace=True) + morpho.prune_by_strahler(x, to_prune=to_prune, reroot_soma=True, inplace=True) # No need to call this as morpho.prune_by_strahler does this already # self._clear_temp_attr() @@ -1142,17 +1387,23 @@ def prune_by_strahler(self, return x return None - def prune_twigs(self, - size: float, - inplace: bool = False, - recursive: Union[int, bool, float] = False - ) -> Optional['TreeNeuron']: + def prune_twigs( + self, + size: float, + mask: Optional[np.ndarray] = None, + inplace: bool = False, + recursive: Union[int, bool, float] = False, + ) -> Optional["TreeNeuron"]: """Prune terminal twigs under a given size. Parameters ---------- size : int | float Twigs shorter than this will be pruned. + mask : iterable | callable, optional + Either a boolean mask, a list of node IDs or a callable taking + a neuron as input and returning one of the former. If provided, + only nodes that are in the mask will be considered for pruning. inplace : bool, optional If False, pruning is performed on copy of original neuron which is then returned. @@ -1172,17 +1423,18 @@ def prune_twigs(self, else: x = self.copy() - morpho.prune_twigs(x, size=size, inplace=True) + morpho.prune_twigs(x, size=size, mask=mask, inplace=True) if not inplace: return x return None - def prune_at_depth(self, - depth: Union[float, int], - source: Optional[int] = None, - inplace: bool = False - ) -> Optional['TreeNeuron']: + def prune_at_depth( + self, + depth: Union[float, int], + source: Optional[int] = None, + inplace: bool = False, + ) -> Optional["TreeNeuron"]: """Prune all neurites past a given distance from a source. Parameters @@ -1220,10 +1472,11 @@ def prune_at_depth(self, return x return None - def cell_body_fiber(self, - reroot_soma: bool = True, - inplace: bool = False, - ) -> Optional['TreeNeuron']: + def cell_body_fiber( + self, + reroot_soma: bool = True, + inplace: bool = False, + ) -> Optional["TreeNeuron"]: """Prune neuron to its cell body fiber. Parameters @@ -1255,11 +1508,12 @@ def cell_body_fiber(self, return x return None - def prune_by_longest_neurite(self, - n: int = 1, - reroot_soma: bool = False, - inplace: bool = False, - ) -> Optional['TreeNeuron']: + def prune_by_longest_neurite( + self, + n: int = 1, + reroot_soma: bool = False, + inplace: bool = False, + ) -> Optional["TreeNeuron"]: """Prune neuron down to the longest neurite. Parameters @@ -1284,8 +1538,7 @@ def prune_by_longest_neurite(self, else: x = self.copy() - graph.longest_neurite( - x, n, inplace=True, reroot_soma=reroot_soma) + graph.longest_neurite(x, n, inplace=True, reroot_soma=reroot_soma) # Clear temporary attributes x._clear_temp_attr() @@ -1294,14 +1547,13 @@ def prune_by_longest_neurite(self, return x return None - def prune_by_volume(self, - v: Union[core.Volume, - List[core.Volume], - Dict[str, core.Volume]], - mode: Union[Literal['IN'], Literal['OUT']] = 'IN', - prevent_fragments: bool = False, - inplace: bool = False - ) -> Optional['TreeNeuron']: + def prune_by_volume( + self, + v: Union[core.Volume, List[core.Volume], Dict[str, core.Volume]], + mode: Union[Literal["IN"], Literal["OUT"]] = "IN", + prevent_fragments: bool = False, + inplace: bool = False, + ) -> Optional["TreeNeuron"]: """Prune neuron by intersection with given volume(s). Parameters @@ -1330,9 +1582,9 @@ def prune_by_volume(self, else: x = self.copy() - intersection.in_volume(x, v, inplace=True, - prevent_fragments=prevent_fragments, - mode=mode) + intersection.in_volume( + x, v, inplace=True, prevent_fragments=prevent_fragments, mode=mode + ) # Clear temporary attributes # x._clear_temp_attr() @@ -1341,9 +1593,7 @@ def prune_by_volume(self, return x return None - def to_swc(self, - filename: Optional[str] = None, - **kwargs) -> None: + def to_swc(self, filename: Optional[str] = None, **kwargs) -> None: """Generate SWC file from this neuron. Parameters @@ -1365,9 +1615,10 @@ def to_swc(self, """ return io.write_swc(self, filename, **kwargs) # type: ignore # double import of "io" - def reload(self, - inplace: bool = False, - ) -> Optional['TreeNeuron']: + def reload( + self, + inplace: bool = False, + ) -> Optional["TreeNeuron"]: """Reload neuron. Must have filepath as `.origin` as attribute. Returns @@ -1376,19 +1627,22 @@ def reload(self, If `inplace=False`. """ - if not hasattr(self, 'origin'): - raise AttributeError('To reload TreeNeuron must have `.origin` ' - 'attribute') + if not hasattr(self, "origin"): + raise AttributeError( + "To reload TreeNeuron must have `.origin` " "attribute" + ) - if self.origin in ('DataFrame', 'string'): - raise ValueError('Unable to reload TreeNeuron: it appears to have ' - 'been created from string or DataFrame.') + if self.origin in ("DataFrame", "string"): + raise ValueError( + "Unable to reload TreeNeuron: it appears to have " + "been created from string or DataFrame." + ) kwargs = {} - if hasattr(self, 'soma_label'): - kwargs['soma_label'] = self.soma_label - if hasattr(self, 'connector_labels'): - kwargs['connector_labels'] = self.connector_labels + if hasattr(self, "soma_label"): + kwargs["soma_label"] = self.soma_label + if hasattr(self, "connector_labels"): + kwargs["connector_labels"] = self.connector_labels x = io.read_swc(self.origin, **kwargs) @@ -1403,7 +1657,7 @@ def reload(self, x2._clear_temp_attr() return x - def snap(self, locs, to='nodes'): + def snap(self, locs, to="nodes"): """Snap xyz location(s) to closest node or synapse. Parameters @@ -1431,15 +1685,16 @@ def snap(self, locs, to='nodes'): """ locs = np.asarray(locs).astype(np.float64) - is_single = (locs.ndim == 1 and len(locs) == 3) - is_multi = (locs.ndim == 2 and locs.shape[1] == 3) + is_single = locs.ndim == 1 and len(locs) == 3 + is_multi = locs.ndim == 2 and locs.shape[1] == 3 if not is_single and not is_multi: - raise ValueError('Expected a single (x, y, z) location or a ' - '(N, 3) array of multiple locations') + raise ValueError( + "Expected a single (x, y, z) location or a " + "(N, 3) array of multiple locations" + ) - if to not in ['nodes', 'connectors']: - raise ValueError('`to` must be "nodes" or "connectors", ' - f'got {to}') + if to not in ["nodes", "connectors"]: + raise ValueError('`to` must be "nodes" or "connectors", ' f"got {to}") # Generate tree tree = graph.neuron2KDTree(self, data=to) @@ -1447,7 +1702,7 @@ def snap(self, locs, to='nodes'): # Find the closest node dist, ix = tree.query(locs) - if to == 'nodes': + if to == "nodes": id = self.nodes.node_id.values[ix] else: id = self.connectors.connector_id.values[ix] diff --git a/navis/core/voxel.py b/navis/core/voxel.py index 4a3e5cc3..b8934851 100644 --- a/navis/core/voxel.py +++ b/navis/core/voxel.py @@ -94,6 +94,9 @@ class VoxelNeuron(BaseNeuron): #: Core data table(s) used to calculate hash CORE_DATA = ['_data'] + #: Property used to calculate length of neuron + _LENGTH_DATA = 'voxels' + def __init__(self, x: Union[np.ndarray], offset: Optional[np.ndarray] = None, @@ -216,6 +219,12 @@ def __sub__(self, other, copy=True): return n return NotImplemented + def __len__(self): + """Return number of voxels.""" + # This is the only neuron for which we implement __len__ separately + # to avoid the potential overhead of generating the voxel representation + return np.product(self.shape) + @property def _base_data_type(self) -> str: """Type of data (grid or voxels) underlying this neuron.""" diff --git a/navis/morpho/subset.py b/navis/morpho/subset.py index 82c5b438..7b0d21fd 100644 --- a/navis/morpho/subset.py +++ b/navis/morpho/subset.py @@ -308,13 +308,9 @@ def _subset_treeneuron(x, subset, keep_disc_cn, prevent_fragments): axis=1, ) - # Make sure any new roots or leafs are properly typed - # We won't produce new slabs but roots and leaves might change - x.nodes.loc[x.nodes.parent_id < 0, "type"] = "root" - x.nodes.loc[ - (~x.nodes.node_id.isin(x.nodes.parent_id.values) & (x.nodes.parent_id >= 0)), - "type", - ] = "end" + # Make sure nodes are correctly classified (need to do this because we're not actually clearing + # the temporary attributes) + graph.classify_nodes(x, inplace=True) # Filter connectors if not keep_disc_cn and x.has_connectors: @@ -359,7 +355,7 @@ def _subset_treeneuron(x, subset, keep_disc_cn, prevent_fragments): return x -def submesh(mesh, *, faces_index=None, vertex_index=None): +def submesh(mesh, *, faces_index=None, vertex_index=None, return_map=False): """Re-imlementation of trimesh.submesh that is faster for our use case. Notably we: @@ -382,6 +378,9 @@ def submesh(mesh, *, faces_index=None, vertex_index=None): Indices of faces to keep. vertex_index : array-like Indices of vertices to keep. + return_map : bool, optional + If True, will return a mapping of old to new vertex and + face indices. Returns ------- @@ -389,12 +388,20 @@ def submesh(mesh, *, faces_index=None, vertex_index=None): Vertices of submesh. faces : np.ndarray Faces of submesh. + vert_map : np.ndarray + Only returned if `return_map` is True. Mapping of old vertex indices + to new vertex indices. Vertices that are not in the submesh have a + value of -1. + face_map : np.ndarray + Only returned if `return_map` is True. Mapping of old face indices + to new face indices. Faces that are not in the submesh have a value + of -1. """ if faces_index is None and vertex_index is None: raise ValueError("Either `faces_index` or `vertex_index` must be provided.") elif faces_index is not None and vertex_index is not None: - raise ValueError("Only one of `faces_index` or `vertex_index` can be provided.") + raise ValueError("Must provide either `faces_index` or `vertex_index`, not both.") # First check if we can return either an empty mesh or the original mesh right away if faces_index is not None: @@ -439,4 +446,14 @@ def submesh(mesh, *, faces_index=None, vertex_index=None): # (making a copy to allow `mask` to be garbage collected) faces = mask[faces].copy() - return vertices, faces + if not return_map: + return vertices, faces + else: + face_map = np.full(len(original_faces), -1, dtype=np.int32) + face_map[faces_index] = np.arange(len(faces_index)) + vert_map = np.full(len(original_vertices), -1, dtype=np.int32) + vert_map[unique] = np.arange(len(unique)) + return vertices, faces, vert_map, face_map + + + diff --git a/navis/utils/decorators.py b/navis/utils/decorators.py index 7ce83faf..d509011d 100644 --- a/navis/utils/decorators.py +++ b/navis/utils/decorators.py @@ -35,10 +35,12 @@ from .iterables import is_iterable, make_iterable -def map_neuronlist(desc: str = "", - can_zip: List[Union[str, int]] = [], - must_zip: List[Union[str, int]] = [], - allow_parallel: bool = False): +def map_neuronlist( + desc: str = "", + can_zip: List[Union[str, int]] = [], + must_zip: List[Union[str, int]] = [], + allow_parallel: bool = False, +): """Decorate function to run on all neurons in the NeuronList. This also updates the docstring. @@ -78,6 +80,7 @@ def map_neuronlist(desc: str = "", of cores a can be set using `n_cores` keyword argument. """ + # TODO: # - make can_zip/must_zip work with positional-only argumens to, i.e. let # it work with integers instead of strings @@ -85,6 +88,7 @@ def decorator(function): @wraps(function) def wrapper(*args, **kwargs): from .. import core + # Get the function's signature sig = inspect.signature(function) @@ -93,17 +97,18 @@ def wrapper(*args, **kwargs): except BaseException: fnname = str(function) - parallel = kwargs.pop('parallel', False) + parallel = kwargs.pop("parallel", False) if parallel and not allow_parallel: - raise ValueError(f'Function {fnname} does not support parallel ' - 'processing.') + raise ValueError( + f"Function {fnname} does not support parallel " "processing." + ) # First, we need to extract the neuronlist if args: # If there are positional arguments, the first one is # the input neuron(s) nl = args[0] - nl_key = '__args' + nl_key = "__args" else: # If not, we need to look for the name of the first argument # in the signature @@ -112,14 +117,16 @@ def wrapper(*args, **kwargs): # Complain if we did not get what we expected if isinstance(nl, type(None)): - raise ValueError('Unable to identify the neurons for call' - f'{fnname}:\n {args}\n {kwargs}') + raise ValueError( + "Unable to identify the neurons for call" + f"{fnname}:\n {args}\n {kwargs}" + ) # If we have a neuronlist if isinstance(nl, core.NeuronList): # Pop the neurons from kwargs or args so we don't pass the # neurons twice - if nl_key == '__args': + if nl_key == "__args": args = args[1:] else: _ = kwargs.pop(nl_key) @@ -134,8 +141,9 @@ def wrapper(*args, **kwargs): # If iterable but length does not match: complain le = len(kwargs[p]) if le != len(nl): - raise ValueError(f'Got {le} values of `{p}` for ' - f'{len(nl)} neurons.') + raise ValueError( + f"Got {le} values of `{p}` for " f"{len(nl)} neurons." + ) # Parse "must zip" arguments for p in must_zip: @@ -145,38 +153,43 @@ def wrapper(*args, **kwargs): values = make_iterable(kwargs[p]) if len(values) != len(nl): - raise ValueError(f'Got {len(values)} values of `{p}` for ' - f'{len(nl)} neurons.') + raise ValueError( + f"Got {len(values)} values of `{p}` for " + f"{len(nl)} neurons." + ) # If we use parallel processing it makes sense to modify neurons # "inplace" since they will be copied into the child processes # anyway and that way we can avoid making an additional copy - if 'inplace' in kwargs: + if "inplace" in kwargs: # First check keyword arguments - inplace = kwargs['inplace'] - elif 'inplace' in sig.parameters: + inplace = kwargs["inplace"] + elif "inplace" in sig.parameters: # Next check signatures default - inplace = sig.parameters['inplace'].default + inplace = sig.parameters["inplace"].default else: # All things failing assume it's not inplace inplace = False - if parallel and 'inplace' in sig.parameters: - kwargs['inplace'] = True + if parallel and "inplace" in sig.parameters: + kwargs["inplace"] = True # Prepare processor - n_cores = kwargs.pop('n_cores', os.cpu_count() // 2) - chunksize = kwargs.pop('chunksize', 1) + n_cores = kwargs.pop("n_cores", os.cpu_count() // 2) + chunksize = kwargs.pop("chunksize", 1) excl = list(kwargs.keys()) + list(range(1, len(args) + 1)) - proc = core.NeuronProcessor(nl, function, - parallel=parallel, - desc=desc, - warn_inplace=False, - progress=kwargs.pop('progress', True), - omit_failures=kwargs.pop('omit_failures', False), - chunksize=chunksize, - exclude_zip=excl, - n_cores=n_cores) + proc = core.NeuronProcessor( + nl, + function, + parallel=parallel, + desc=desc, + warn_inplace=False, + progress=kwargs.pop("progress", True), + omit_failures=kwargs.pop("omit_failures", False), + chunksize=chunksize, + exclude_zip=excl, + n_cores=n_cores, + ) # Apply function res = proc(nl, *args, **kwargs) @@ -201,10 +214,12 @@ def wrapper(*args, **kwargs): return decorator -def map_neuronlist_df(desc: str = "", - id_col: str = "neuron", - reset_index: bool = True, - allow_parallel: bool = False): +def map_neuronlist_df( + desc: str = "", + id_col: str = "neuron", + reset_index: bool = True, + allow_parallel: bool = False, +): """Decorate function to run on all neurons in the NeuronList. This version of the decorator is meant for functions that return a @@ -227,6 +242,7 @@ def map_neuronlist_df(desc: str = "", of cores a can be set using `n_cores` keyword argument. """ + # TODO: # - make can_zip/must_zip work with positional-only argumens to, i.e. let # it work with integers instead of strings @@ -235,6 +251,7 @@ def decorator(function): def wrapper(*args, **kwargs): # Lazy import to avoid issues with circular imports and pickling from .. import core + # Get the function's signature sig = inspect.signature(function) @@ -243,17 +260,18 @@ def wrapper(*args, **kwargs): except BaseException: fnname = str(function) - parallel = kwargs.pop('parallel', False) + parallel = kwargs.pop("parallel", False) if parallel and not allow_parallel: - raise ValueError(f'Function {fnname} does not allow parallel ' - 'processing.') + raise ValueError( + f"Function {fnname} does not allow parallel " "processing." + ) # First, we need to extract the neuronlist if args: # If there are positional arguments, the first one is # the input neuron(s) nl = args[0] - nl_key = '__args' + nl_key = "__args" else: # If not, we need to look for the name of the first argument # in the signature @@ -262,31 +280,36 @@ def wrapper(*args, **kwargs): # Complain if we did not get what we expected if isinstance(nl, type(None)): - raise ValueError('Unable to identify the neurons for call' - f'{fnname}:\n {args}\n {kwargs}') + raise ValueError( + "Unable to identify the neurons for call" + f"{fnname}:\n {args}\n {kwargs}" + ) # If we have a neuronlist if isinstance(nl, core.NeuronList): # Pop the neurons from kwargs or args so we don't pass the # neurons twice - if nl_key == '__args': + if nl_key == "__args": args = args[1:] else: _ = kwargs.pop(nl_key) # Prepare processor - n_cores = kwargs.pop('n_cores', os.cpu_count() // 2) - chunksize = kwargs.pop('chunksize', 1) + n_cores = kwargs.pop("n_cores", os.cpu_count() // 2) + chunksize = kwargs.pop("chunksize", 1) excl = list(kwargs.keys()) + list(range(1, len(args) + 1)) - proc = core.NeuronProcessor(nl, function, - parallel=parallel, - desc=desc, - warn_inplace=False, - progress=kwargs.pop('progress', True), - omit_failures=kwargs.pop('omit_failures', False), - chunksize=chunksize, - exclude_zip=excl, - n_cores=n_cores) + proc = core.NeuronProcessor( + nl, + function, + parallel=parallel, + desc=desc, + warn_inplace=False, + progress=kwargs.pop("progress", True), + omit_failures=kwargs.pop("omit_failures", False), + chunksize=chunksize, + exclude_zip=excl, + n_cores=n_cores, + ) # Apply function res = proc(nl, *args, **kwargs) @@ -316,20 +339,20 @@ def wrapper(*args, **kwargs): def map_neuronlist_update_docstring(func, allow_parallel): """Add additional parameters to docstring of function.""" # Parse docstring - lines = func.__doc__.split('\n') + lines = func.__doc__.split("\n") # Find a line with a parameter - pline = [l for l in lines if ' : ' in l][0] + pline = [l for l in lines if " : " in l][0] # Get the leading whitespaces - wspaces = ' ' * re.search('( *)', pline).end(1) + wspaces = " " * re.search("( *)", pline).end(1) # Get the offset for type and description - offset = re.search('( *: *)', pline).end(1) - len(wspaces) + offset = re.search("( *: *)", pline).end(1) - len(wspaces) # Find index of the last parameters (assuming there is a single empty # line between Returns and the last parameter) - lastp = [i for i, l in enumerate(lines) if ' Returns' in l][0] - 1 + lastp = [i for i, l in enumerate(lines) if " Returns" in l][0] - 1 - msg = '' + msg = "" if allow_parallel: msg += dedent(f"""\ parallel :{" " * (offset - 10)}bool @@ -353,7 +376,7 @@ def map_neuronlist_update_docstring(func, allow_parallel): lines.insert(lastp, indent(msg, wspaces)) # Update docstring - func.__doc__ = '\n'.join(lines) + func.__doc__ = "\n".join(lines) return func @@ -365,6 +388,7 @@ def lock_neuron(function): are being made. """ + @wraps(function) def wrapper(*args, **kwargs): # Lazy import to avoid issues with circular imports and pickling @@ -372,7 +396,7 @@ def wrapper(*args, **kwargs): # Lock if first argument is a neuron if isinstance(args[0], core.BaseNeuron): - args[0]._lock = getattr(args[0], '_lock', 0) + 1 + args[0]._lock = getattr(args[0], "_lock", 0) + 1 try: # Execute function res = function(*args, **kwargs) @@ -384,20 +408,25 @@ def wrapper(*args, **kwargs): args[0]._lock -= 1 # Return result return res + return wrapper -def meshneuron_skeleton(method: Union[Literal['subset'], - Literal['split'], - Literal['node_properties'], - Literal['node_to_vertex'], - Literal['pass_through']], - include_connectors: bool = False, - copy_properties: list = [], - disallowed_kwargs: dict = {}, - node_props: list = [], - reroot_soma: bool = False, - heal: bool = False): +def meshneuron_skeleton( + method: Union[ + Literal["subset"], + Literal["split"], + Literal["node_properties"], + Literal["node_to_vertex"], + Literal["pass_through"], + ], + include_connectors: bool = False, + copy_properties: list = [], + disallowed_kwargs: dict = {}, + node_props: list = [], + reroot_soma: bool = False, + heal: bool = False, +): """Decorate function such that MeshNeurons are automatically skeletonized, the function is run on the skeleton and changes are propagated back to the meshe. @@ -435,12 +464,17 @@ def meshneuron_skeleton(method: Union[Literal['subset'], assert isinstance(disallowed_kwargs, dict) assert isinstance(node_props, list) - allowed_methods = ('subset', 'node_to_vertex', 'split', 'node_properties', - 'pass_through') + allowed_methods = ( + "subset", + "node_to_vertex", + "split", + "node_properties", + "pass_through", + ) if method not in allowed_methods: raise ValueError(f'Unknown method "{method}"') - if method == 'node_properties' and not node_props: + if method == "node_properties" and not node_props: raise ValueError('Must provide `node_props` for method "node_properties"') def decorator(function): @@ -460,7 +494,7 @@ def wrapper(*args, **kwargs): # be the input neuron x = args[0] args = args[1:] - x_key = '__args' + x_key = "__args" else: # If not, we need to look for the name of the first argument # in the signature @@ -469,55 +503,67 @@ def wrapper(*args, **kwargs): # Complain if we did not get what we expected if isinstance(x, type(None)): - raise ValueError('Unable to identify the neurons for call' - f'{fnname}:\n {args}\n {kwargs}') + raise ValueError( + "Unable to identify the neurons for call" + f"{fnname}:\n {args}\n {kwargs}" + ) # If input not a MeshNeuron, just pass through # Note delayed import to avoid circular imports and IMPORTANTLY # funky interactions with pickle/dill from .. import core + if not isinstance(x, core.MeshNeuron): return function(x, *args, **kwargs) # Check for disallowed kwargs for k, v in disallowed_kwargs.items(): if k in kwargs and kwargs[k] == v: - raise ValueError(f'{k}={v} is not allowed when input is ' - 'MeshNeuron(s).') + raise ValueError( + f"{k}={v} is not allowed when input is " "MeshNeuron(s)." + ) # See if this is meant to be done inplace - if 'inplace' in kwargs: + if "inplace" in kwargs: # First check keyword arguments - inplace = kwargs['inplace'] - elif 'inplace' in sig.parameters: + inplace = kwargs["inplace"] + elif "inplace" in sig.parameters: # Next check signatures default - inplace = sig.parameters['inplace'].default + inplace = sig.parameters["inplace"].default else: # All things failing assume it's not inplace inplace = False - # Now skeletonize + # Now skeletonize (if the skeleton is not already present) sk = x.skeleton + if method != "pass_through" and not hasattr(sk, "vertex_map"): + raise ValueError( + "MeshNeuron must have a skeleton with a vertex->node mapping " + "as `.vertex_map` property to apply this function." + ) + # Delayed import to avoid circular imports # Note that this HAS to be in the inner function otherwise # we get a weird error when pickling for parallel processing from .. import morpho - if heal: - sk = morpho.heal_skeleton(sk, method='LEAFS') + if heal and len(sk.roots) > 1: + sk = morpho.heal_skeleton(sk, method="LEAFS") if reroot_soma and sk.has_soma: sk = sk.reroot(sk.soma) if include_connectors and x.has_connectors and not sk.has_connectors: sk._connectors = x.connectors.copy() - sk._connectors['node_id'] = sk.snap(sk.connectors[['x', 'y', 'z']].values)[0] + sk._connectors["node_id"] = sk.snap( + sk.connectors[["x", "y", "z"]].values + )[0] # Apply function res = function(sk, *args, **kwargs) - if method == 'subset': + if method == "subset": # See which vertices we need to keep keep = np.isin(sk.vertex_map, res.nodes.node_id.values) @@ -525,7 +571,7 @@ def wrapper(*args, **kwargs): for p in copy_properties: setattr(x, p, getattr(sk, p, None)) - elif method == 'split': + elif method == "split": meshes = [] for n in res: # See which vertices we need to keep @@ -536,14 +582,14 @@ def wrapper(*args, **kwargs): for p in copy_properties: setattr(meshes[-1], p, getattr(n, p, None)) x = core.NeuronList(meshes) - elif method == 'node_to_vertex': + elif method == "node_to_vertex": x = np.where(sk.vertex_map == res)[0] - elif method == 'node_properties': + elif method == "node_properties": for p in node_props: - node_map = sk.nodes.set_index('node_id')[p].to_dict() + node_map = sk.nodes.set_index("node_id")[p].to_dict() vertex_props = np.array([node_map[n] for n in sk.vertex_map]) setattr(x, p, vertex_props) - elif method == 'pass_through': + elif method == "pass_through": return res return x diff --git a/navis/utils/eval.py b/navis/utils/eval.py index 9ce23174..dd3bf84d 100644 --- a/navis/utils/eval.py +++ b/navis/utils/eval.py @@ -172,7 +172,7 @@ def eval_id(x: Union[uuid.UUID, str, 'core.NeuronObject', pd.DataFrame], List containing IDs. """ - if isinstance(x, (uuid.UUID, str, np.str, int, np.integer)): + if isinstance(x, (uuid.UUID, str, np.str_, int, np.integer)): return [x] elif isinstance(x, (list, np.ndarray, set)): uu: List[uuid.UUID] = [] @@ -288,7 +288,7 @@ def eval_node_ids(x: Union[int, str, """ if isinstance(x, (int, np.integer)): return [x] - elif isinstance(x, (str, np.str)): + elif isinstance(x, (str, np.str_)): try: return [int(x)] except BaseException: