From 7de016e9608c586b9f5e2db9b4987b36358b001e Mon Sep 17 00:00:00 2001
From: Philipp Schlegel <observing@web.de>
Date: Thu, 24 Oct 2024 14:35:09 +0100
Subject: [PATCH] a few additional fixes for masking: - fix NeuronMask doctest
 - TreeNeuron.un/mask: make sure to re-classify - TreeNeuron.unmask: fix
 re-connecting

---
 navis/core/masking.py  |  2 +-
 navis/core/skeleton.py | 14 ++++++++++++--
 2 files changed, 13 insertions(+), 3 deletions(-)

diff --git a/navis/core/masking.py b/navis/core/masking.py
index 345de824..c5a78861 100644
--- a/navis/core/masking.py
+++ b/navis/core/masking.py
@@ -63,7 +63,7 @@ class NeuronMask:
     >>> # Grab a few skeletons
     >>> nl = navis.example_neurons(3)
     >>> # Label axon and dendrites
-    >>> navis.split_axon_dendrite(nl, label_only=True)
+    >>> _ = navis.split_axon_dendrite(nl, label_only=True)
     >>> # Mask by axon
     >>> with navis.NeuronMask(nl, lambda x: x.nodes.compartment == 'axon'):
     ...    print("Axon cable length:", nl.cable_length * nl[0].units)
diff --git a/navis/core/skeleton.py b/navis/core/skeleton.py
index 892e7c88..41937e5d 100644
--- a/navis/core/skeleton.py
+++ b/navis/core/skeleton.py
@@ -1009,10 +1009,17 @@ def mask(self, mask, copy=True):
         self._masked_data["_nodes"] = self.nodes
 
         # N.B. we're directly setting `._nodes`` to avoid overhead from checks
-        self._nodes = self._nodes.loc[mask]
+        self._nodes = self._nodes.loc[mask].drop("type", axis=1, errors="ignore")
         if copy:
             self._nodes = self._nodes.copy()
 
+        # See if any parent IDs have ceased to exist
+        missing_parents = ~self._nodes.parent_id.isin(self._nodes.node_id) & (
+            self._nodes.parent_id >= 0
+        )
+        if any(missing_parents):
+            self.nodes.loc[missing_parents, "parent_id"] = -1
+
         if hasattr(self, "_connectors"):
             self._masked_data["_connectors"] = self.connectors
             self._connectors = self._connectors.loc[
@@ -1092,7 +1099,7 @@ def unmask(self, reset=True):
             if r not in pre_parents:
                 continue
             # Skip if this was also a root in the pre-masked data
-            if pre_parents[r] >= 0:
+            if pre_parents[r] < 0:
                 continue
             # Skip if the old parent does not exist anymore
             if pre_parents[r] not in self.nodes.node_id.values:
@@ -1110,6 +1117,9 @@ def unmask(self, reset=True):
         if any(missing_parents):
             self.nodes.loc[missing_parents, "parent_id"] = -1
 
+        # Force nodes to be re-classified
+        self.nodes.drop("type", axis=1, errors="ignore", inplace=True)
+
         # TODO: Make sure that edges have a consistent orientation
         # (not sure this is much of a problem)