Skip to content

Commit

Permalink
new function: nblast_prime
Browse files Browse the repository at this point in the history
  • Loading branch information
schlegelp committed Jul 27, 2023
1 parent 25bb1e1 commit 11c6f42
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,7 @@ Utilities for NBLAST
navis.nbl.update_scores
navis.nbl.compress_scores
navis.nbl.extract_matches
navis.nbl.nblast_prime


Polarity metrics
Expand Down
51 changes: 50 additions & 1 deletion navis/nbl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import pandas as pd
import scipy.cluster.hierarchy as sch

from scipy.spatial.distance import squareform
from scipy.spatial.distance import squareform, pdist

from .. import config

Expand Down Expand Up @@ -366,3 +366,52 @@ def make_clusters(x, t, criterion='n_clusters', method='ward', **kwargs):
cl = sch.fcluster(Z, t=t, criterion=criterion, **kwargs)

return cl


def nblast_prime(scores, n_dim=.2, metric='euclidean'):
    """Generate a smoothed version of the NBLAST scores.

    In brief:

     1. Run PCA on the NBLAST scores and extract the first N components.
     2. From that calculate a new similarity matrix.

    Requires scikit-learn.

    Parameters
    ----------
    scores :    pandas.DataFrame
                The all-by-all NBLAST scores.
    n_dim :     float | int
                The number of dimensions to use. If float (0 < n_dim < 1) will
                use `scores.shape[1] * n_dim` (rounded down, but always at
                least 1).
    metric :    str
                Which distance metric to use. Directly passed through to the
                `scipy.spatial.distance.pdist` function.

    Returns
    -------
    scores_new :    pandas.DataFrame
                    Matrix with the same index and columns as `scores`.
                    Values are `1 - distance` where the distance is
                    calculated between rows in the reduced PCA space; the
                    diagonal is therefore 1.

    Raises
    ------
    TypeError
                If `scores` is not a pandas DataFrame.
    ValueError
                If `n_dim` is not greater than zero.
    ImportError
                If scikit-learn is not installed.

    """
    # Validate the cheap stuff first so that bad input is reported
    # regardless of whether the optional dependency is installed
    if not isinstance(scores, pd.DataFrame):
        raise TypeError(f'`scores` must be pandas DataFrame, got "{type(scores)}"')

    if n_dim <= 0:
        raise ValueError(f'`n_dim` must be greater than 0, got {n_dim}')

    try:
        from sklearn.decomposition import PCA
    except ImportError as err:
        raise ImportError('Please install scikit-learn to use `nblast_prime`:\n'
                          '  pip3 install scikit-learn -U') from err

    # The shape check must come first: comparing index objects of different
    # lengths would raise instead of returning a boolean array
    is_square = scores.shape[0] == scores.shape[1]
    if not is_square or not np.all(scores.columns == scores.index):
        logger.warning('NBLAST matrix is not symmetric - are you sure this is '
                       'an all-by-all matrix?')

    if n_dim < 1:
        # Fraction of the number of columns; clamp to 1 because small
        # matrices would otherwise truncate to 0 components
        n_dim = max(1, int(scores.shape[1] * n_dim))

    pca = PCA(n_components=n_dim)
    X_new = pca.fit_transform(scores.values)

    # Condensed pairwise distances between rows in the reduced space
    dist = pdist(X_new, metric=metric)

    # Turn distances into similarities; `squareform` has zeros on the diagonal
    return pd.DataFrame(1 - squareform(dist),
                        index=scores.index,
                        columns=scores.columns)

0 comments on commit 11c6f42

Please sign in to comment.