diff --git a/navis/interfaces/r.py b/navis/interfaces/r.py index 6571eadf..f3fb08a0 100644 --- a/navis/interfaces/r.py +++ b/navis/interfaces/r.py @@ -620,7 +620,10 @@ def nblast_allbyall(x: 'core.NeuronList', # type: ignore # doesn't like n_core def nblast(query: Union['core.TreeNeuron', 'core.NeuronList', 'core.Dotprops'], # type: ignore # doesn't like n_cores default target: Optional[str] = None, - mean_scores: bool = False, + scores: Union[Literal['forward'], + Literal['mean'], + Literal['min'], + Literal['max']] = 'forward', n_cores: int = os.cpu_count() - 2, normalized: bool = True, use_alpha: bool = False, @@ -643,9 +646,15 @@ def nblast(query: Union['core.TreeNeuron', 'core.NeuronList', 'core.Dotprops'], 3. an R file object (e.g. ``robjects.r("load('.../gmrdps.rds')")``) 4. a URL to load the list from (e.g. ``'http://.../gmrdps.rds'``) 5. "flycircuit" - mean_scores : bool - If True, will run query->target NBLAST followed by a - target->query NBLAST and return the mean scores. + + scores : 'forward' | 'mean' | 'min' | 'max' + Determines the final scores: + + - 'forward' (default) returns query->target scores + - 'mean' returns the mean of query->target and target->query scores + - 'min' returns the minium between query->target and target->query scores + - 'max' returns the maximum between query->target and target->query scores + n_cores : int, optional Number of cores to use for nblasting. Default is ``os.cpu_count() - 2``. @@ -685,6 +694,10 @@ def nblast(query: Union['core.TreeNeuron', 'core.NeuronList', 'core.Dotprops'], >>> scores = r.nblast(nl_um) """ + utils.eval_param(scores, + name='scores', + allowed_values=('forward', 'mean', 'min', 'max')) + if n_cores > 1: doMC = importr('doMC') doMC.registerDoMC(int(n_cores)) @@ -749,7 +762,7 @@ def nblast(query: Union['core.TreeNeuron', 'core.NeuronList', 'core.Dotprops'], res = res.T # Calculate reverse scores - if mean_scores: + if scores != 'forward': logger.info('Running reverse nblast.') scr = r_nblast.nblast(target_dps, query_dps, **{'normalised': normalized, @@ -761,7 +774,15 @@ def nblast(query: Union['core.TreeNeuron', 'core.NeuronList', 'core.Dotprops'], if isinstance(scr, dict): scr = _parse_nblast_dict(scr) - res = (res + scr) / 2 + # Make 100% sure that the order is correct + scr = scr.loc[res.index, res.columns] + + if scores == 'mean': + res = (res + scr) / 2 + elif scores == 'min': + res.loc[:, :] = np.dstack((res.values, scr.values)).min(axis=2) + elif scores == 'max': + res.loc[:, :] = np.dstack((res.values, scr.values)).max(axis=2) return res