diff --git a/navis/sampling/resampling.py b/navis/sampling/resampling.py index 63c5922e..a2c9d715 100644 --- a/navis/sampling/resampling.py +++ b/navis/sampling/resampling.py @@ -65,13 +65,17 @@ def resample_neuron(x: 'core.NeuronObject', ) -> Optional['core.NeuronObject']: """Resample neuron(s) to given resolution. - Preserves root, leafs, branchpoints. Connectors (if they exist) are mapped - onto the closest new node. + Preserves root, leafs, branchpoints. Soma, connectors and node tags + (if present) are mapped onto the closest node in the resampled neuron. Important --------- - This generates an entirely new set of node IDs! Those will be unique - within a neuron, but you may encounter duplicates across neurons. + A few things to keep in mind: + - This generates an entirely new set of node IDs! Those will be unique + within a neuron, but you may encounter duplicates across neurons. + - Any non-standard node table columns (e.g. "labels") will be lost. + - Soma(s) will be pinned to the closest node in the resampled neuron. + Also: be aware that high-resolution neurons will use A LOT of memory. @@ -128,15 +132,17 @@ def resample_neuron(x: 'core.NeuronObject', if isinstance(x, core.NeuronList): if not inplace: x = x.copy() - results = [resample_neuron(x[i], resample_to, - method=method, inplace=True, - skip_errors=skip_errors) - for i in config.trange(x.shape[0], - desc='Resampl. neurons', - disable=config.pbar_hide, - leave=config.pbar_leave)] + _ = [resample_neuron(x[i], + resample_to=resample_to, + method=method, inplace=True, + skip_errors=skip_errors) + for i in config.trange(x.shape[0], + desc='Resampl. neurons', + disable=config.pbar_hide, + leave=config.pbar_leave)] + if not inplace: - return core.NeuronList(results) + return x return None elif not isinstance(x, core.TreeNeuron): raise TypeError(f'Unable to resample data of type "{type(x)}"') @@ -237,32 +243,72 @@ def resample_neuron(x: 'core.NeuronObject', # Generate new nodes dataframe new_nodes = pd.DataFrame(data=new_nodes, columns=['node_id', 'parent_id', - 'x', 'y', 'z', 'radius'], - dtype=object - ) + 'x', 'y', 'z', 'radius']) # Convert columns to appropriate dtypes - dtypes = {'node_id': int, 'parent_id': int, 'x': float, 'y': float, - 'z': float, 'radius': float} + dtypes = {k: x.nodes[k].dtype for k in ['node_id', 'parent_id', 'x', 'y', 'z', 'radius']} - for k, v in dtypes.items(): - new_nodes[k] = new_nodes[k].astype(v) + for cols in new_nodes.columns: + new_nodes = new_nodes.astype(dtypes, errors='ignore') # Remove duplicate nodes (branch points) new_nodes = new_nodes[~new_nodes.node_id.duplicated()] + # Generate KDTree + tree = scipy.spatial.cKDTree(new_nodes[['x', 'y', 'z']].values) + # Map soma onto new nodes if required + # Note that if `._soma` is a soma detection function we can't tell + # how to deal with it. Ideally the new soma node will + # be automatically detected but it is possible, for example, that + # the radii of nodes have changed due to interpolation such that more + # than one soma is detected now. Also a "label" column in the node + # table would be lost at this point. + # We will go for the easy option which is to pin the soma at this point. + if np.any(getattr(x, 'soma')): + soma_nodes = utils.make_iterable(x.soma) + old_pos = nodes.loc[soma_nodes, ['x', 'y', 'z']].values + + # Get nearest neighbours + dist, ix = tree.query(old_pos) + node_map = dict(zip(soma_nodes, new_nodes.node_id.values[ix])) + + # Map back onto neuron + if utils.is_iterable(x.soma): + x.soma = [node_map[n] for n in x.soma] + else: + x.soma = node_map[x.soma] + else: + # If `._soma` was (read: is) a function but it didn't detect anything in + # the original neurons, this makes sure that the resampled neuron + # doesn't have a soma either: + x.soma = None + + # Map connectors back if necessary if x.has_connectors: - # Map connectors back: - # 1. Get position of old synapse-bearing nodes - old_tn_position = x.nodes.set_index('node_id', - inplace=False).loc[x.connectors.node_id, - ['x', 'y', 'z']].values - # 2. Get closest neighbours - distances = scipy.spatial.distance.cdist(old_tn_position, - new_nodes[['x', 'y', 'z']].values) - min_ix = np.argmin(distances, axis=1) - # 3. Map back onto neuron - x.connectors['node_id'] = new_nodes.iloc[min_ix].node_id.values + # Get position of old synapse-bearing nodes + old_tn_position = nodes.loc[x.connectors.node_id, ['x', 'y', 'z']].values + + # Get nearest neighbours + dist, ix = tree.query(old_tn_position) + + # Map back onto neuron + x.connectors['node_id'] = new_nodes.node_id.values[ix] + + # Map tags back if necessary + # Expects `tags` to be a dictionary {'tag': [node_id1, node_id2, ...]} + if x.has_tags and isinstance(x.tags, dict): + # Get nodes that need remapping + nodes_to_remap = {n for l in x.tags.values() for n in l} + + # Get position of old tag-bearing nodes + old_tn_position = nodes.loc[nodes_to_remap, ['x', 'y', 'z']].values + + # Get nearest neighbours + dist, ix = tree.query(old_tn_position) + + # Map back onto tags + node_map = dict(zip(nodes_to_remap, new_nodes.node_id.values[ix])) + x.tags = {k: [node_map[n] for n in v] for k, v in x.tags.items()} # Set nodes x.nodes = new_nodes