Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pearson Residuals: normalization and hvg #2980

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
70 changes: 61 additions & 9 deletions src/scanpy/experimental/pp/_highly_variable_genes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,22 @@
from math import sqrt
from typing import TYPE_CHECKING

import dask.array as da
import numba as nb
import numpy as np
import pandas as pd
import scipy.sparse as sp_sparse
from anndata import AnnData

from scanpy import logging as logg
from scanpy._compat import DaskArray
from scanpy._settings import Verbosity, settings
from scanpy._utils import _doc_params, check_nonnegative_integers, view_to_actual
from scanpy._utils import (
_doc_params,
check_nonnegative_integers,
clip_array,
view_to_actual,
)
from scanpy.experimental._docs import (
doc_adata,
doc_check_values,
Expand All @@ -24,7 +31,7 @@
)
from scanpy.get import _get_obs_rep
from scanpy.preprocessing._distributed import materialize_as_ndarray
from scanpy.preprocessing._utils import _get_mean_var
from scanpy.preprocessing._utils import _get_mean_var, axis_mean, axis_sum

if TYPE_CHECKING:
from typing import Literal
Expand Down Expand Up @@ -92,6 +99,41 @@ def clac_clipped_res_sparse(gene: int, cell: int, value: np.float64) -> np.float
return residuals


def _calculate_res_dense_vectorized(
matrix: np.ndarray[np.float64],
*,
sums_genes: np.ndarray[np.float64],
sums_cells: np.ndarray[np.float64],
sum_total: np.float64,
clip: np.float64,
theta: np.float64,
n_genes: int, # TODO: delete if not used
n_cells: int,
) -> np.ndarray[np.float64]:
# TODO: potentially a lot to rewrite here
# TODO: is there a better/more common way we use? e.g. with dispatching?
if isinstance(matrix, DaskArray):
mu = da.outer(sums_genes, sums_cells) / sum_total
else:
mu = np.outer(sums_genes, sums_cells) / sum_total

values = matrix.T

mu_sum = values - mu
pre_res = mu_sum / np.sqrt(mu + mu * mu / theta)

# np clip doesn't work with sparse-in-dask: although pre_res is not sparse since computed as outer product
clipped_res = clip_array(pre_res, -clip, clip)

mean_clipped_res = axis_mean(clipped_res, axis=1, dtype=np.float64)
var_sum = axis_sum(
(clipped_res.T - mean_clipped_res) ** 2, axis=0, dtype=np.float64
)

residuals = var_sum / n_cells
return residuals


@nb.njit(parallel=True)
def _calculate_res_dense(
matrix,
Expand Down Expand Up @@ -175,9 +217,13 @@ def _highly_variable_pearson_residuals(

# Filter out zero genes
with settings.verbosity.override(Verbosity.error):
nonzero_genes = np.ravel(X_batch_prefilter.sum(axis=0)) != 0
adata_subset = adata_subset_prefilter[:, nonzero_genes]
X_batch = _get_obs_rep(adata_subset, layer=layer)
nonzero_genes = (
np.ravel(axis_sum(X_batch_prefilter, axis=0, dtype=np.float64)) != 0
)
# TODO: a good way of doing that? nonzero_genes is a 1xn_genes array
if isinstance(nonzero_genes, DaskArray):
nonzero_genes = nonzero_genes.compute()
X_batch = X_batch_prefilter[:, nonzero_genes]

# Prepare clipping
if clip is None:
Expand All @@ -186,7 +232,11 @@ def _highly_variable_pearson_residuals(
if clip < 0:
raise ValueError("Pearson residuals require `clip>=0` or `clip=None`.")

if sp_sparse.issparse(X_batch):
if isinstance(X_batch, DaskArray):
# TODO: map_block with modified _calculate_res_sparse and _calculate_res_dense possible?
calculate_res = partial(_calculate_res_dense_vectorized, X_batch)

elif sp_sparse.issparse(X_batch):
X_batch = X_batch.tocsc()
X_batch.eliminate_zeros()
calculate_res = partial(
Expand All @@ -196,13 +246,15 @@ def _highly_variable_pearson_residuals(
X_batch.data.astype(np.float64),
)
else:
X_batch = np.array(X_batch, dtype=np.float64, order="F")
# TODO: why this line needed?
# X_batch = np.array(X_batch, dtype=np.float64, order="F")
calculate_res = partial(_calculate_res_dense, X_batch)

sums_genes = np.array(X_batch.sum(axis=0)).ravel()
sums_cells = np.array(X_batch.sum(axis=1)).ravel()
sums_genes = np.asarray(axis_sum(X_batch, axis=0, dtype=np.float64)).ravel()
sums_cells = np.asarray(axis_sum(X_batch, axis=1, dtype=np.float64)).ravel()
sum_total = np.sum(sums_genes)

# TODO: da.reduction with modified _calculate_res_sparse possible?
residual_gene_var = calculate_res(
sums_genes=sums_genes,
sums_cells=sums_cells,
Expand Down
41 changes: 27 additions & 14 deletions src/scanpy/experimental/pp/_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@

import numpy as np
from anndata import AnnData
from scipy.sparse import issparse

from ... import logging as logg
from ..._compat import DaskArray
from ..._utils import (
_doc_params,
_empty,
axis_sum,
check_nonnegative_integers,
clip_array,
view_to_actual,
)
from ...experimental._docs import (
Expand All @@ -34,8 +36,16 @@

from ..._utils import Empty

from scipy.sparse import spmatrix

def _pearson_residuals(X, theta, clip, check_values, copy: bool = False):

def _pearson_residuals(
X: np.ndarray | spmatrix | DaskArray,
theta: np.float64,
clip: np.float64,
check_values: bool,
copy: bool = False,
) -> np.ndarray | spmatrix | DaskArray:
X = X.copy() if copy else X

# check theta
Expand All @@ -56,21 +66,24 @@ def _pearson_residuals(X, theta, clip, check_values, copy: bool = False):
UserWarning,
)

if issparse(X):
sums_genes = np.sum(X, axis=0)
sums_cells = np.sum(X, axis=1)
sum_total = np.sum(sums_genes).squeeze()
else:
sums_genes = np.sum(X, axis=0, keepdims=True)
sums_cells = np.sum(X, axis=1, keepdims=True)
sum_total = np.sum(sums_genes)
sums_genes = axis_sum(X, axis=0, dtype=np.float64).reshape(1, -1)
sums_cells = axis_sum(X, axis=1, dtype=np.float64).reshape(-1, 1)
sum_total = sums_genes.sum()

mu = np.array(sums_cells @ sums_genes / sum_total)
diff = np.array(X - mu)
# TODO: Consider deduplicating computations below which are similarly required in _highly_variable_genes?
if not isinstance(X, DaskArray):
mu = np.array(sums_cells @ sums_genes / sum_total)
diff = np.array(X - mu)
else:
mu = sums_cells @ sums_genes / sum_total
diff = (
X - mu
) # here, potentially a dask sparse array and a dense sparse array are subtracted from each other
diff = diff.map_blocks(np.array, dtype=np.float64)
residuals = diff / np.sqrt(mu + mu**2 / theta)

# clip
residuals = np.clip(residuals, a_min=-clip, a_max=clip)
# residuals are dense, hence no circumventing for sparse-in-dask needed
residuals = clip_array(residuals, -clip, clip)

return residuals

Expand Down
Loading