Skip to content

Commit

Permalink
[python] Fix HVG crasher (#821)
Browse files Browse the repository at this point in the history
* suppress invalid value exception in meanvar

* fix crash when n_cells per batch was one

* improve unit tests for hvg

* speed up standard unit tests

* lower memory use

* add additional unit test comments

* lint

* add comment
  • Loading branch information
Bruce Martin authored Oct 25, 2023
1 parent 8a7a341 commit e81ea42
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,11 @@ def _highly_variable_genes_seurat_v3(
N = n_samples[batch]

not_const = v > 0
if N == 1 or not not_const.any():
reg_std[batch].fill(1)
clip_val[batch, :] = u
continue

y = np.log10(v[not_const])
x = np.log10(u[not_const])

Expand Down Expand Up @@ -185,9 +190,12 @@ def _highly_variable_genes_seurat_v3(
acc.update(var_dim, data)

counts_sum, squared_counts_sum = acc.finalize()
norm_gene_vars = (1 / ((n_samples - 1) * np.square(reg_std.T))).T * (
(n_samples * np.square(batches_u.T)).T + squared_counts_sum - 2 * counts_sum * batches_u
)
# Don't raise python errors for 0/0, etc, just generate Inf/NaN
with np.errstate(divide="ignore", invalid="ignore"):
norm_gene_vars = (1 / ((n_samples - 1) * np.square(reg_std.T))).T * (
(n_samples * np.square(batches_u.T)).T + squared_counts_sum - 2 * counts_sum * batches_u
)
norm_gene_vars[np.isnan(norm_gene_vars)] = 0
del acc, counts_sum, squared_counts_sum

ranked_norm_gene_vars = np.argsort(np.argsort(-norm_gene_vars, axis=1), axis=1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def finalize(

# Note: if N-ddof is less than or equal to 0, we will return Inf - this is consistent
# with the numpy.var behavior.
with np.errstate(divide="ignore"):
with np.errstate(divide="ignore", invalid="ignore"):
batches_var = (self.M2.T / np.maximum(0, (self.n_samples - self.ddof))).T

# accum all batches using Chan's
Expand Down
69 changes: 58 additions & 11 deletions api/python/cellxgene_census/tests/experimental/pp/test_hvg.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,13 @@
"experiment_name,obs_value_filter",
[
(
"mus_musculus",
'is_primary_data == True and tissue_general == "liver"',
),
pytest.param(
"mus_musculus",
'is_primary_data == True and tissue_general == "skin of body"',
marks=pytest.mark.expensive,
),
pytest.param(
"mus_musculus",
Expand All @@ -31,26 +36,46 @@
),
pytest.param(
"mus_musculus",
'is_primary_data == True and assay == "Smart-seq"',
'is_primary_data == True and assay == "Smart-seq2"',
marks=pytest.mark.expensive,
),
],
)
@pytest.mark.parametrize("n_top_genes", (5, 500))
@pytest.mark.parametrize("n_top_genes", (50, 500))
@pytest.mark.parametrize(
"batch_key",
(
None,
"dataset_id",
["suspension_type", "assay_ontology_term_id"],
("suspension_type", "assay_ontology_term_id", "dataset_id"),
["dataset_id", "assay_ontology_term_id", "suspension_type", "donor_id"],
pytest.param(
("suspension_type", "assay_ontology_term_id", "dataset_id"),
marks=pytest.mark.expensive,
),
pytest.param(
["dataset_id", "assay_ontology_term_id", "suspension_type", "donor_id"],
marks=pytest.mark.expensive,
),
),
)
@pytest.mark.parametrize(
"span",
(
pytest.param(None, marks=pytest.mark.expensive),
0.5,
),
)
@pytest.mark.parametrize(
"version",
(
"latest",
pytest.param("stable", marks=pytest.mark.expensive),
),
)
@pytest.mark.parametrize("span", (None, 0.5))
def test_hvg_vs_scanpy(
n_top_genes: int,
obs_value_filter: str,
version: str,
experiment_name: str,
batch_key: Optional[Union[str, tuple[str], list[str]]],
span: float,
Expand All @@ -66,7 +91,7 @@ def test_hvg_vs_scanpy(
if span is not None:
kwargs["span"] = span

with cellxgene_census.open_soma(census_version="stable", context=small_mem_context) as census:
with cellxgene_census.open_soma(census_version=version, context=small_mem_context) as census:
# Get the highly variable genes
with census["census_data"][experiment_name].axis_query(
measurement_name="RNA",
Expand All @@ -83,7 +108,16 @@ def test_hvg_vs_scanpy(
)
kwargs["batch_key"] = "the_batch_key"

scanpy_hvg = sc.pp.highly_variable_genes(adata, inplace=False, **kwargs)
try:
scanpy_hvg = sc.pp.highly_variable_genes(adata, inplace=False, **kwargs)
except ZeroDivisionError:
# There are test cases where ScanPy will fail, rendering this "compare vs scanpy"
# test moot. The known cases involve overly partitioned data that results in batches
# with a very small number of samples (which manifest as a divide by zero error).
# In these known cases, go ahead and perform the HVG (above), but skip the compare
# assertions below.
pytest.skip("ScanPy generated an error, likely due to batches with 1 sample")

scanpy_hvg.index.name = "soma_joinid"
scanpy_hvg.index = scanpy_hvg.index.astype(int)
assert len(scanpy_hvg) == len(hvg)
Expand All @@ -104,16 +138,13 @@ def test_hvg_vs_scanpy(
rtol=1e-2,
equal_nan=True,
)
if "highly_variable_nbatches" in scanpy_hvg.keys() or "highly_variable_nbatches" in hvg.keys():
assert (hvg.highly_variable_nbatches == scanpy_hvg.highly_variable_nbatches).all()
assert np.allclose(
hvg.variances_norm.to_numpy(),
scanpy_hvg.variances_norm.to_numpy(),
atol=1e-5,
rtol=1e-2,
equal_nan=True,
)
assert (hvg.highly_variable == scanpy_hvg.highly_variable).all()

# Online calculation of normalized variance will differ slightly from ScanPy's calculation,
# so look for rank of HVGs to be close, but not identical. Don't worry about the non-HVGs
Expand All @@ -126,6 +157,22 @@ def test_hvg_vs_scanpy(
/ n_top_genes
) < 0.01

# Ranking will also have some noise, so check that ranking is close in the highly variable subset
scanpy_rank = scanpy_hvg.highly_variable_rank.copy()
hvg_rank = hvg.highly_variable_rank.copy()
hvg_rank[pd.isna(hvg_rank)] = n_top_genes
scanpy_rank[pd.isna(scanpy_rank)] = n_top_genes
rank_diff = (hvg_rank - scanpy_rank)[hvg.highly_variable]
# +/- 5 in ranking, choosen arbitrarily
assert rank_diff.min() >= -5 and rank_diff.max() <= 5

if "highly_variable_nbatches" in scanpy_hvg.keys() or "highly_variable_nbatches" in hvg.keys():
# Also subject to noise, so look for "close" match
nbatches_diff = hvg.highly_variable_nbatches - scanpy_hvg.highly_variable_nbatches
assert nbatches_diff.min() >= -2 and nbatches_diff.max() <= 2

assert (hvg.highly_variable == scanpy_hvg.highly_variable).all()


@pytest.mark.experimental
@pytest.mark.live_corpus
Expand Down Expand Up @@ -248,7 +295,7 @@ def batch_key_func(srs: pd.Series[Any]) -> str:
assert set(srs.keys()) == keys
return "batch0"

with cellxgene_census.open_soma(census_version="stable", context=small_mem_context) as census:
with cellxgene_census.open_soma(census_version="latest", context=small_mem_context) as census:
# Get the highly variable genes
with census["census_data"]["mus_musculus"].axis_query(
measurement_name="RNA",
Expand Down

0 comments on commit e81ea42

Please sign in to comment.