Skip to content

Commit

Permalink
improve r.nblast and fix potential bug
Browse files Browse the repository at this point in the history
  • Loading branch information
schlegelp committed Dec 22, 2020
1 parent 5818658 commit 32aa702
Showing 1 changed file with 27 additions and 6 deletions.
33 changes: 27 additions & 6 deletions navis/interfaces/r.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,10 @@ def nblast_allbyall(x: 'core.NeuronList', # type: ignore # doesn't like n_core

def nblast(query: Union['core.TreeNeuron', 'core.NeuronList', 'core.Dotprops'], # type: ignore # doesn't like n_cores default
target: Optional[str] = None,
mean_scores: bool = False,
scores: Union[Literal['forward'],
Literal['mean'],
Literal['min'],
Literal['max']] = 'forward',
n_cores: int = os.cpu_count() - 2,
normalized: bool = True,
use_alpha: bool = False,
Expand All @@ -643,9 +646,15 @@ def nblast(query: Union['core.TreeNeuron', 'core.NeuronList', 'core.Dotprops'],
3. an R file object (e.g. ``robjects.r("load('.../gmrdps.rds')")``)
4. a URL to load the list from (e.g. ``'http://.../gmrdps.rds'``)
5. "flycircuit"
mean_scores : bool
If True, will run query->target NBLAST followed by a
target->query NBLAST and return the mean scores.
scores : 'forward' | 'mean' | 'min' | 'max'
Determines the final scores:
- 'forward' (default) returns query->target scores
- 'mean' returns the mean of query->target and target->query scores
- 'min' returns the minium between query->target and target->query scores
- 'max' returns the maximum between query->target and target->query scores
n_cores : int, optional
Number of cores to use for nblasting. Default is
``os.cpu_count() - 2``.
Expand Down Expand Up @@ -685,6 +694,10 @@ def nblast(query: Union['core.TreeNeuron', 'core.NeuronList', 'core.Dotprops'],
>>> scores = r.nblast(nl_um)
"""
utils.eval_param(scores,
name='scores',
allowed_values=('forward', 'mean', 'min', 'max'))

if n_cores > 1:
doMC = importr('doMC')
doMC.registerDoMC(int(n_cores))
Expand Down Expand Up @@ -749,7 +762,7 @@ def nblast(query: Union['core.TreeNeuron', 'core.NeuronList', 'core.Dotprops'],
res = res.T

# Calculate reverse scores
if mean_scores:
if scores != 'forward':
logger.info('Running reverse nblast.')
scr = r_nblast.nblast(target_dps, query_dps,
**{'normalised': normalized,
Expand All @@ -761,7 +774,15 @@ def nblast(query: Union['core.TreeNeuron', 'core.NeuronList', 'core.Dotprops'],
if isinstance(scr, dict):
scr = _parse_nblast_dict(scr)

res = (res + scr) / 2
# Make 100% sure that the order is correct
scr = scr.loc[res.index, res.columns]

if scores == 'mean':
res = (res + scr) / 2
elif scores == 'min':
res.loc[:, :] = np.dstack((res.values, scr.values)).min(axis=2)
elif scores == 'max':
res.loc[:, :] = np.dstack((res.values, scr.values)).max(axis=2)

return res

Expand Down

0 comments on commit 32aa702

Please sign in to comment.