From 7de016e9608c586b9f5e2db9b4987b36358b001e Mon Sep 17 00:00:00 2001 From: Philipp Schlegel <observing@web.de> Date: Thu, 24 Oct 2024 14:35:09 +0100 Subject: [PATCH] a few additional fixes for masking: - fix NeuronMask doctest - TreeNeuron.un/mask: make sure to re-classify - TreeNeuron.unmask: fix re-connecting --- navis/core/masking.py | 2 +- navis/core/skeleton.py | 14 ++++++++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/navis/core/masking.py b/navis/core/masking.py index 345de824..c5a78861 100644 --- a/navis/core/masking.py +++ b/navis/core/masking.py @@ -63,7 +63,7 @@ class NeuronMask: >>> # Grab a few skeletons >>> nl = navis.example_neurons(3) >>> # Label axon and dendrites - >>> navis.split_axon_dendrite(nl, label_only=True) + >>> _ = navis.split_axon_dendrite(nl, label_only=True) >>> # Mask by axon >>> with navis.NeuronMask(nl, lambda x: x.nodes.compartment == 'axon'): ... print("Axon cable length:", nl.cable_length * nl[0].units) diff --git a/navis/core/skeleton.py b/navis/core/skeleton.py index 892e7c88..41937e5d 100644 --- a/navis/core/skeleton.py +++ b/navis/core/skeleton.py @@ -1009,10 +1009,17 @@ def mask(self, mask, copy=True): self._masked_data["_nodes"] = self.nodes # N.B. we're directly setting `._nodes`` to avoid overhead from checks - self._nodes = self._nodes.loc[mask] + self._nodes = self._nodes.loc[mask].drop("type", axis=1, errors="ignore") if copy: self._nodes = self._nodes.copy() + # See if any parent IDs have ceased to exist + missing_parents = ~self._nodes.parent_id.isin(self._nodes.node_id) & ( + self._nodes.parent_id >= 0 + ) + if any(missing_parents): + self.nodes.loc[missing_parents, "parent_id"] = -1 + if hasattr(self, "_connectors"): self._masked_data["_connectors"] = self.connectors self._connectors = self._connectors.loc[ @@ -1092,7 +1099,7 @@ def unmask(self, reset=True): if r not in pre_parents: continue # Skip if this was also a root in the pre-masked data - if pre_parents[r] >= 0: + if pre_parents[r] < 0: continue # Skip if the old parent does not exist anymore if pre_parents[r] not in self.nodes.node_id.values: @@ -1110,6 +1117,9 @@ def unmask(self, reset=True): if any(missing_parents): self.nodes.loc[missing_parents, "parent_id"] = -1 + # Force nodes to be re-classified + self.nodes.drop("type", axis=1, errors="ignore", inplace=True) + # TODO: Make sure that edges have a consistent orientation # (not sure this is much of a problem)