-
Notifications
You must be signed in to change notification settings - Fork 611
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Pearson Residuals: normalization and hvg #2980
base: main
Are you sure you want to change the base?
Changes from 7 commits
d731b71
f4a439e
a5f7485
306c8f8
eafee9f
a5eca04
6a8f43d
7365b97
7eb31ff
d58e083
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,13 +5,15 @@ | |
from math import sqrt | ||
from typing import TYPE_CHECKING, Literal | ||
|
||
import dask.array as da | ||
import numba as nb | ||
import numpy as np | ||
import pandas as pd | ||
import scipy.sparse as sp_sparse | ||
from anndata import AnnData | ||
|
||
from scanpy import logging as logg | ||
from scanpy._compat import DaskArray | ||
from scanpy._settings import Verbosity, settings | ||
from scanpy._utils import _doc_params, check_nonnegative_integers, view_to_actual | ||
from scanpy.experimental._docs import ( | ||
|
@@ -24,7 +26,7 @@ | |
) | ||
from scanpy.get import _get_obs_rep | ||
from scanpy.preprocessing._distributed import materialize_as_ndarray | ||
from scanpy.preprocessing._utils import _get_mean_var | ||
from scanpy.preprocessing._utils import _get_mean_var, axis_mean, axis_sum | ||
|
||
if TYPE_CHECKING: | ||
from numpy.typing import NDArray | ||
|
@@ -90,6 +92,47 @@ | |
return residuals | ||
|
||
|
||
def _calculate_res_dense_vectorized( | ||
matrix: np.ndarray[np.float64], | ||
*, | ||
sums_genes: np.ndarray[np.float64], | ||
sums_cells: np.ndarray[np.float64], | ||
sum_total: np.float64, | ||
clip: np.float64, | ||
theta: np.float64, | ||
n_genes: int, # TODO: delete if not used | ||
n_cells: int, | ||
) -> np.ndarray[np.float64]: | ||
# TODO: potentially a lot to rewrite here | ||
# TODO: is there a better/more common way we use? e.g. with dispatching? | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. something I've asked myself multiple times |
||
if isinstance(matrix, DaskArray): | ||
mu = da.outer(sums_genes, sums_cells) / sum_total | ||
else: | ||
mu = np.outer(sums_genes, sums_cells) / sum_total | ||
|
||
values = matrix.T | ||
|
||
mu_sum = values - mu | ||
pre_res = mu_sum / np.sqrt(mu + mu * mu / theta) | ||
|
||
def custom_clip(x, clip_val): | ||
x[x < -clip_val] = -clip_val | ||
x[x > clip_val] = clip_val | ||
return x | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Make a utility and use with the other implementation. Also maybe a better name? |
||
|
||
# np clip doesn't work with sparse-in-dask: although pre_res is not sparse since computed as outer product | ||
# TODO: we have such a clip function in multiple places..? | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. not priority here |
||
clipped_res = custom_clip(pre_res, clip) | ||
|
||
mean_clipped_res = axis_mean(clipped_res, axis=1, dtype=np.float64) | ||
var_sum = axis_sum( | ||
(clipped_res.T - mean_clipped_res) ** 2, axis=0, dtype=np.float64 | ||
) | ||
|
||
residuals = var_sum / n_cells | ||
return residuals | ||
|
||
|
||
@nb.njit(parallel=True) | ||
def _calculate_res_dense( | ||
matrix, | ||
|
@@ -173,9 +216,13 @@ | |
|
||
# Filter out zero genes | ||
with settings.verbosity.override(Verbosity.error): | ||
nonzero_genes = np.ravel(X_batch_prefilter.sum(axis=0)) != 0 | ||
adata_subset = adata_subset_prefilter[:, nonzero_genes] | ||
X_batch = _get_obs_rep(adata_subset, layer=layer) | ||
nonzero_genes = ( | ||
np.ravel(axis_sum(X_batch_prefilter, axis=0, dtype=np.float64)) != 0 | ||
) | ||
# TODO: a good way of doing that? nonzero_genes is a 1xn_genes array | ||
if isinstance(nonzero_genes, DaskArray): | ||
nonzero_genes = nonzero_genes.compute() | ||
X_batch = X_batch_prefilter[:, nonzero_genes] | ||
|
||
# Prepare clipping | ||
if clip is None: | ||
|
@@ -184,7 +231,11 @@ | |
if clip < 0: | ||
raise ValueError("Pearson residuals require `clip>=0` or `clip=None`.") | ||
|
||
if sp_sparse.issparse(X_batch): | ||
if isinstance(X_batch, DaskArray): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. having _calculate_res_dense and _calculate_res_sparse was there before, maybe singledispatch cleaner here. the non-jitted dask computation is ~10x slower at the moment. |
||
# TODO: map_block with modified _calculate_res_sparse and _calculate_res_dense possible? | ||
calculate_res = partial(_calculate_res_dense_vectorized, X_batch) | ||
|
||
elif sp_sparse.issparse(X_batch): | ||
X_batch = X_batch.tocsc() | ||
X_batch.eliminate_zeros() | ||
calculate_res = partial( | ||
|
@@ -194,13 +245,15 @@ | |
X_batch.data.astype(np.float64), | ||
) | ||
else: | ||
X_batch = np.array(X_batch, dtype=np.float64, order="F") | ||
# TODO: why this line needed? | ||
# X_batch = np.array(X_batch, dtype=np.float64, order="F") | ||
calculate_res = partial(_calculate_res_dense, X_batch) | ||
|
||
sums_genes = np.array(X_batch.sum(axis=0)).ravel() | ||
sums_cells = np.array(X_batch.sum(axis=1)).ravel() | ||
sums_genes = np.asarray(axis_sum(X_batch, axis=0, dtype=np.float64)).ravel() | ||
sums_cells = np.asarray(axis_sum(X_batch, axis=1, dtype=np.float64)).ravel() | ||
sum_total = np.sum(sums_genes) | ||
|
||
# TODO: da.reduction with modified _calculate_res_sparse possible? | ||
residual_gene_var = calculate_res( | ||
sums_genes=sums_genes, | ||
sums_cells=sums_cells, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,13 +6,15 @@ | |
|
||
import numpy as np | ||
from anndata import AnnData | ||
from scipy.sparse import issparse | ||
from scipy.sparse import issparse, spmatrix | ||
|
||
from ... import logging as logg | ||
from ..._compat import DaskArray | ||
from ..._utils import ( | ||
Empty, | ||
_doc_params, | ||
_empty, | ||
axis_sum, | ||
check_nonnegative_integers, | ||
view_to_actual, | ||
) | ||
|
@@ -33,7 +35,13 @@ | |
from collections.abc import Mapping | ||
|
||
|
||
def _pearson_residuals(X, theta, clip, check_values, copy: bool = False): | ||
def _pearson_residuals( | ||
X: np.ndarray | spmatrix | DaskArray, | ||
theta: np.float64, | ||
clip: np.float64, | ||
check_values: bool, | ||
copy: bool = False, | ||
) -> np.ndarray | spmatrix | DaskArray: | ||
X = X.copy() if copy else X | ||
|
||
# check theta | ||
|
@@ -54,21 +62,34 @@ def _pearson_residuals(X, theta, clip, check_values, copy: bool = False): | |
UserWarning, | ||
) | ||
|
||
if issparse(X): | ||
sums_genes = np.sum(X, axis=0) | ||
sums_cells = np.sum(X, axis=1) | ||
sum_total = np.sum(sums_genes).squeeze() | ||
if not isinstance(X, DaskArray): | ||
if issparse(X): | ||
sums_genes = np.sum(X, axis=0) | ||
sums_cells = np.sum(X, axis=1) | ||
sum_total = np.sum(sums_genes).squeeze() | ||
else: | ||
sums_genes = np.sum(X, axis=0, keepdims=True) | ||
sums_cells = np.sum(X, axis=1, keepdims=True) | ||
sum_total = np.sum(sums_genes).squeeze() | ||
else: | ||
sums_genes = np.sum(X, axis=0, keepdims=True) | ||
sums_cells = np.sum(X, axis=1, keepdims=True) | ||
sum_total = np.sum(sums_genes) | ||
|
||
mu = np.array(sums_cells @ sums_genes / sum_total) | ||
diff = np.array(X - mu) | ||
sums_genes = axis_sum(X, axis=0, dtype=np.float64).reshape(1, -1) | ||
sums_cells = axis_sum(X, axis=1, dtype=np.float64).reshape(-1, 1) | ||
sum_total = sums_genes.sum() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These functions should work across the types, no? They rely on single dispatch of the first argument and just call the other imeplementations |
||
|
||
# TODO: Consider deduplicating computations below which are similarly required in _highly_variable_genes? | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. there's actually quite some duplication, think this might be reduced |
||
if not isinstance(X, DaskArray): | ||
mu = np.array(sums_cells @ sums_genes / sum_total) | ||
diff = np.array(X - mu) | ||
else: | ||
mu = sums_cells @ sums_genes / sum_total | ||
diff = ( | ||
X - mu | ||
) # here, potentially a dask sparse array and a dense sparse array are subtracted from each other | ||
diff = diff.map_blocks(np.array, dtype=np.float64) | ||
residuals = diff / np.sqrt(mu + mu**2 / theta) | ||
|
||
# clip | ||
residuals = np.clip(residuals, a_min=-clip, a_max=clip) | ||
# residuals are dense, hence no circumventing for sparse-in-dask needed | ||
residuals = residuals.clip(-clip, clip) | ||
|
||
return residuals | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,7 @@ | |
from string import ascii_letters | ||
from typing import Callable, Literal | ||
|
||
import dask.array as da | ||
import numpy as np | ||
import pandas as pd | ||
import pytest | ||
|
@@ -17,6 +18,8 @@ | |
from scanpy.testing._pytest.marks import needs | ||
from scanpy.testing._pytest.params import ARRAY_TYPES | ||
|
||
from ..preprocessing._utils import _get_mean_var | ||
|
||
FILE = Path(__file__).parent / Path("_scripts/seurat_hvg.csv") | ||
FILE_V3 = Path(__file__).parent / Path("_scripts/seurat_hvg_v3.csv.gz") | ||
FILE_V3_BATCH = Path(__file__).parent / Path("_scripts/seurat_hvg_v3_batch.csv") | ||
|
@@ -136,14 +139,29 @@ def _check_pearson_hvg_columns(output_df: pd.DataFrame, n_top_genes: int): | |
assert np.nanmax(output_df["highly_variable_rank"].to_numpy()) <= n_top_genes - 1 | ||
|
||
|
||
def test_pearson_residuals_inputchecks(pbmc3k_parametrized_small): | ||
adata = pbmc3k_parametrized_small() | ||
@pytest.mark.parametrize("array_type", ARRAY_TYPES) | ||
@pytest.mark.parametrize("dtype", ["float32", "int64"]) | ||
def test_pearson_residuals_inputchecks(array_type, dtype): | ||
# TODO: do we have a preferred way of making such a small dataset, wich the array types option? | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. copied the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think so? |
||
adata = pbmc3k() | ||
sc.pp.filter_genes(adata, min_cells=10) | ||
adata.X = array_type(adata.X).astype(dtype) | ||
|
||
# depending on check_values, warnings should be raised for non-integer data | ||
if adata.X.dtype == "float32": | ||
adata_noninteger = adata.copy() | ||
x, y = np.nonzero(adata_noninteger.X) | ||
adata_noninteger.X[x[0], y[0]] = 0.5 | ||
|
||
def clip(x, min, max): | ||
x[x < min] = min | ||
x[x > max] = max | ||
return x | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Definitely a feather in the cap for making this a utility. |
||
|
||
if "dask" in array_type.__name__: | ||
adata_noninteger.X = da.map_blocks( | ||
clip, adata.X, 0, 0.5, dtype=adata.X.dtype | ||
) | ||
else: | ||
adata_noninteger.X = clip(adata_noninteger.X, 0, 0.5) | ||
|
||
_check_check_values_warnings( | ||
function=sc.experimental.pp.highly_variable_genes, | ||
|
@@ -170,16 +188,21 @@ def test_pearson_residuals_inputchecks(pbmc3k_parametrized_small): | |
) | ||
|
||
|
||
@pytest.mark.parametrize("array_type", ARRAY_TYPES) | ||
# @pytest.mark.parametrize("dtype", ["float32", "int64"]) | ||
@pytest.mark.parametrize("subset", [True, False], ids=["subset", "full"]) | ||
@pytest.mark.parametrize( | ||
"clip", [None, np.Inf, 30], ids=["noclip", "infclip", "30clip"] | ||
) | ||
@pytest.mark.parametrize("theta", [100, np.Inf], ids=["100theta", "inftheta"]) | ||
@pytest.mark.parametrize("n_top_genes", [100, 200], ids=["100n", "200n"]) | ||
def test_pearson_residuals_general( | ||
pbmc3k_parametrized_small, subset, clip, theta, n_top_genes | ||
): | ||
adata = pbmc3k_parametrized_small() | ||
@pytest.mark.parametrize( | ||
"n_top_genes", [100, 200], ids=["100n", "200n"] | ||
) # TODO: is this necessary? | ||
def test_pearson_residuals_general(array_type, subset, clip, theta, n_top_genes): | ||
adata = pbmc3k() | ||
sc.pp.filter_genes(adata, min_cells=10) | ||
adata.X = array_type(adata.X) | ||
|
||
# cleanup var | ||
del adata.var | ||
|
||
|
@@ -188,7 +211,9 @@ def test_pearson_residuals_general( | |
adata, clip=clip, theta=theta, inplace=False | ||
) | ||
assert isinstance(residuals_res, dict) | ||
residual_variances_reference = np.var(residuals_res["X"], axis=0) | ||
_, residual_variances_reference = _get_mean_var(residuals_res["X"], axis=0) | ||
# if "dask" in array_type.__name__: | ||
# residual_variances_reference = residual_variances_reference#.compute(scheduler='single-threaded') | ||
|
||
if subset: | ||
# lazyly sort by residual variance and take top N | ||
|
@@ -238,10 +263,13 @@ def test_pearson_residuals_general( | |
assert np.allclose( | ||
output_df["residual_variances"].to_numpy()[sort_output_idx], | ||
residual_variances_reference, | ||
rtol=1e-3, | ||
) | ||
else: | ||
assert np.allclose( | ||
output_df["residual_variances"].to_numpy(), residual_variances_reference | ||
output_df["residual_variances"].to_numpy(), | ||
residual_variances_reference, | ||
rtol=1e-3, | ||
) | ||
|
||
# check hvg flag | ||
|
@@ -258,10 +286,17 @@ def test_pearson_residuals_general( | |
_check_pearson_hvg_columns(output_df, n_top_genes) | ||
|
||
|
||
@pytest.mark.parametrize("array_type", ARRAY_TYPES) | ||
@pytest.mark.parametrize("subset", [True, False], ids=["subset", "full"]) | ||
@pytest.mark.parametrize("n_top_genes", [100, 200], ids=["100n", "200n"]) | ||
def test_pearson_residuals_batch(pbmc3k_parametrized_small, subset, n_top_genes): | ||
adata = pbmc3k_parametrized_small() | ||
@pytest.mark.parametrize( | ||
"n_top_genes", [100, 200], ids=["100n", "200n"] | ||
) # TODO: is this necessary? | ||
def test_pearson_residuals_batch(array_type, subset, n_top_genes): | ||
adata = pbmc3k() | ||
sc.pp.filter_genes(adata, min_cells=10) | ||
adata.obs["batch"] = np.random.choice(3, adata.n_obs) | ||
adata.X = array_type(adata.X) | ||
|
||
# cleanup var | ||
del adata.var | ||
n_genes = adata.shape[1] | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is this called
dense
if it operates on sparse inputs? Also why are the type hintsnp.ndarray
then? Also, why the extra dtype? Can we guarantee that?