numerical precision improvements in Census builder (#716)
* add vscode dir to gitignore

* update to latest typing extensions version

* intentional handling of non-canonical sparse matrices - see issue 715

* numerical stability and precision improvements - see issues 706 and 714

* add retries if HTTP fetch of dataset fails - see issue 311

* improve logging
Bruce Martin authored Aug 16, 2023
1 parent c82b2d2 commit 10cefb4
Showing 8 changed files with 78 additions and 22 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -140,3 +140,5 @@ temp
 .Rproj.user
 .Rhistory
 
+# vscode
+.vscode
2 changes: 1 addition & 1 deletion tools/cellxgene_census_builder/pyproject.toml
@@ -26,7 +26,7 @@ classifiers = [
     "Programming Language :: Python :: 3.11",
 ]
 dependencies= [
-    "typing_extensions==4.6.3",
+    "typing_extensions==4.7.1",
     "pyarrow==12.0.1",
     "pandas[performance]==2.0.3",
     "anndata==0.8",
@@ -4,6 +4,7 @@
 import anndata
 import numpy as np
 import pandas as pd
+import scipy.sparse as sparse
 
 from ..util import urlcat
 from .datasets import Dataset
@@ -83,9 +84,17 @@ def open_anndata(
     X = ad.X
     var = ad.var
 
-    ad = anndata.AnnData(X=X if need_X else None, obs=ad.obs, var=var, raw=None, uns=ad.uns, dtype=np.float32)
+    if need_X and isinstance(X, (sparse.csr_matrix, sparse.csc_matrix)) and not X.has_canonical_format:
+        logging.warning(f"H5AD with non-canonical X matrix at {path}")
+        X.sum_duplicates()
+
+    assert (
+        not isinstance(X, (sparse.csr_matrix, sparse.csc_matrix)) or X.has_canonical_format
+    ), f"Found H5AD with non-canonical X matrix in {path}"
+
+    ad = anndata.AnnData(X=X if need_X else None, obs=ad.obs, var=var, raw=None, uns=ad.uns, dtype=np.float32)
     assert not need_X or ad.X.shape == (len(ad.obs), len(ad.var))
 
     # TODO: In principle, we could look up missing feature_name, but for now, just assert they exist
     assert ((ad.var.feature_name != "") & (ad.var.feature_name != None)).all()  # noqa: E711
 
@@ -141,10 +150,11 @@ def _filter(ad: anndata.AnnData, need_X: Optional[bool] = True) -> anndata.AnnData:
     assert ad.raw is None
 
     # This discards all other ancillary state, eg, obsm/varm/....
-    if not need_X:
-        ad = anndata.AnnData(X=None, obs=obs, var=var, dtype=np.float32)
-    else:
-        ad = anndata.AnnData(X=X, obs=obs, var=var, dtype=np.float32)
+    ad = anndata.AnnData(X=X, obs=obs, var=var, dtype=np.float32)
+
+    assert (
+        X is None or isinstance(X, np.ndarray) or X.has_canonical_format
+    ), "Found H5AD with non-canonical X matrix"
 
     return ad
 
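Context on the non-canonical handling above (issue 715): a scipy CSR/CSC matrix is canonical only when its indices are sorted and free of duplicates, and sum_duplicates() canonicalizes it in place. A minimal standalone sketch — the tiny matrix here is hypothetical, for illustration only:

import numpy as np
import scipy.sparse as sparse

# Row 0 stores column 3 twice -- structurally valid CSR, but non-canonical.
data = np.array([1.0, 2.0], dtype=np.float32)
indices = np.array([3, 3], dtype=np.int32)
indptr = np.array([0, 2, 2], dtype=np.int32)
X = sparse.csr_matrix((data, indices, indptr), shape=(2, 4))

print(X.has_canonical_format)  # False: duplicate coordinates present
X.sum_duplicates()             # collapse duplicates in place, as open_anndata now does
print(X.has_canonical_format)  # True
print(X.data)                  # [3.]: the two entries were summed into one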
@@ -36,6 +36,7 @@
     CENSUS_OBS_TERM_COLUMNS,
     CENSUS_VAR_PLATFORM_CONFIG,
     CENSUS_VAR_TERM_COLUMNS,
+    CENSUS_X_LAYER_NORMALIZED_FLOAT_SCALE_FACTOR,
     CENSUS_X_LAYERS,
     CENSUS_X_LAYERS_PLATFORM_CONFIG,
     CXG_OBS_TERM_COLUMNS,
@@ -773,6 +774,23 @@ def _write_X_normalized(args: Tuple[str, int, int, npt.NDArray[np.float32]]) ->
     """
     experiment_uri, obs_joinid_start, n, raw_sum = args
     logging.info(f"Write X normalized - starting block {obs_joinid_start}")
 
+    """
+    Adjust the normalized layer to never encode zero-valued cells where the raw count
+    value is greater than zero. In our current schema configuration, FloatScaleFilter
+    reduces the precision of each value, storing ``round((raw_float - offset) / factor)``
+    as a four-byte int.
+
+    To ensure that non-zero raw values, which would _normally_ scale to zero under
+    these conditions, remain non-zero, we add the smallest possible sigma to each value
+    (note that zero-valued coordinates are not stored, as this is a sparse array).
+
+    Working backward through the above transformation, and assuming float32 values,
+    the smallest sigma is 1/2 of the scale factor (bits of precision). Accounting for
+    IEEE float precision, this reduces to:
+    """
+    sigma = 0.5 * (CENSUS_X_LAYER_NORMALIZED_FLOAT_SCALE_FACTOR + np.finfo(np.float32).epsneg)
+
     with soma.open(
         urlcat(experiment_uri, "ms", MEASUREMENT_RNA_NAME, "X", "raw"), mode="r", context=SOMA_TileDB_Context()
     ) as X_raw:
@@ -791,13 +809,14 @@ def _write_X_normalized(args: Tuple[str, int, int, npt.NDArray[np.float32]]) ->
             (
                 X_tbl["soma_dim_0"],
                 X_tbl["soma_dim_1"],
-                X_tbl["soma_data"].to_numpy() / raw_sum[X_tbl["soma_dim_0"]],
+                X_tbl["soma_data"].to_numpy() / raw_sum[X_tbl["soma_dim_0"]] + sigma,
             )
             for X_tbl in lazy_reader
         ),
         pool=pool,
     )
     for soma_dim_0, soma_dim_1, soma_data in lazy_divider:
+        assert np.all(soma_data > 0.0), "Found unexpected zero value in raw layer data"
         X_normalized.write(
             pa.Table.from_arrays(
                 [soma_dim_0, soma_dim_1, soma_data],
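A worked example of the sigma reasoning above, simulating the round((raw_float - offset) / factor) encode/decode described in the docstring. This is an illustrative approximation, not the actual TileDB FloatScaleFilter implementation:

import numpy as np

FACTOR = 1.0 / 2**18  # CENSUS_X_LAYER_NORMALIZED_FLOAT_SCALE_FACTOR
OFFSET = 0.5          # the "offset" in the platform config

def roundtrip(v: float) -> float:
    stored = round((v - OFFSET) / FACTOR)  # the four-byte int that is persisted
    return stored * FACTOR + OFFSET        # the value a reader decodes

v = 1.0e-6  # e.g. one raw count in a cell with a total count of 1e6
print(roundtrip(v))  # 0.0 -- the non-zero value would vanish from the sparse layer

sigma = 0.5 * (FACTOR + np.finfo(np.float32).epsneg)
print(roundtrip(v + sigma))  # 3.814697265625e-06 -- smallest positive code; the entry survives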
@@ -81,10 +81,10 @@
 }
 CENSUS_OBS_STATS_COLUMNS = {
     # Columns computed during the Census build and written to the Census obs dataframe.
-    "raw_sum": pa.float32(),
+    "raw_sum": pa.float64(),
     "nnz": pa.int64(),
-    "raw_mean_nnz": pa.float32(),
-    "raw_variance_nnz": pa.float32(),
+    "raw_mean_nnz": pa.float64(),
+    "raw_variance_nnz": pa.float64(),
     "n_measured_vars": pa.int64(),
 }
 CENSUS_OBS_TERM_COLUMNS = {
@@ -265,6 +265,7 @@
         },
     }
 }
+CENSUS_X_LAYER_NORMALIZED_FLOAT_SCALE_FACTOR = 1.0 / 2**18
 CENSUS_X_LAYERS_PLATFORM_CONFIG = {
     "raw": {
         **CENSUS_DEFAULT_X_LAYERS_PLATFORM_CONFIG,
@@ -278,7 +279,7 @@
             "filters": [
                 {
                     "_type": "FloatScaleFilter",
-                    "factor": 1.0 / 2**18,
+                    "factor": CENSUS_X_LAYER_NORMALIZED_FLOAT_SCALE_FACTOR,
                     "offset": 0.5,
                     "bytewidth": 4,
                 },
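A back-of-the-envelope reading of this filter configuration, derived from the encode/decode described in _write_X_normalized above and assuming signed 32-bit integer codes — not taken from TileDB documentation:

FACTOR = 1.0 / 2**18
BYTEWIDTH = 4
resolution = FACTOR                    # ~3.8e-06 gap between adjacent decoded values
code_range = 2 ** (8 * BYTEWIDTH - 1)  # assumed signed 32-bit code range
max_magnitude = code_range * FACTOR    # ~8192 either side of the 0.5 offset
print(resolution, max_magnitude)       # normalized values in (0, 1] fit comfortably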
@@ -1,5 +1,6 @@
 import logging
 import os
+import time
 import urllib.parse
 from typing import List, Tuple, cast
 
@@ -46,7 +47,20 @@ def _copy_file(n: int, dataset: Dataset, asset_dir: str, N: int) -> str:
     dataset_path = f"{asset_dir}/{dataset_file_name}"
 
     logging.info(f"Staging {dataset.dataset_id} ({n} of {N}) to {dataset_path}")
-    fs.get_file(dataset.dataset_asset_h5ad_uri, dataset_path)
+
+    sleep_for_secs = 10
+    last_error: aiohttp.ClientPayloadError | None = None
+    for attempt in range(4):
+        try:
+            fs.get_file(dataset.dataset_asset_h5ad_uri, dataset_path)
+            break
+        except aiohttp.ClientPayloadError as e:
+            logging.error(f"Fetch of {dataset.dataset_id} at {dataset_path} failed: {str(e)}")
+            last_error = e
+            time.sleep(2**attempt * sleep_for_secs)
+    else:
+        assert last_error is not None
+        raise last_error
 
     # verify file size is as expected, if we know the size a priori
     assert (dataset.asset_h5ad_filesize == -1) or (dataset.asset_h5ad_filesize == os.path.getsize(dataset_path))
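The retry loop added above (issue 311) backs off exponentially — 2**attempt * 10 seconds, i.e. 10s, 20s, 40s, 80s — then re-raises the last error. The same pattern as a standalone, hypothetical helper; fetch_with_backoff is not part of the builder, and OSError stands in for aiohttp.ClientPayloadError so the sketch carries no aiohttp dependency:

import logging
import time
from typing import Callable

def fetch_with_backoff(fetch: Callable[[], None], attempts: int = 4, base_secs: float = 10.0) -> None:
    # Hypothetical helper mirroring the inlined loop above.
    last_error: Exception | None = None
    for attempt in range(attempts):
        try:
            fetch()
            return
        except OSError as e:
            logging.error(f"fetch failed on attempt {attempt + 1} of {attempts}: {e}")
            last_error = e
            time.sleep(2**attempt * base_secs)  # 10s, 20s, 40s, 80s
    assert last_error is not None
    raise last_error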
@@ -17,11 +17,11 @@ def get_obs_stats(
     if not isinstance(raw_X, sparse.csr_matrix) and not isinstance(raw_X, sparse.csc_matrix):
         raise NotImplementedError(f"get_obs_stats: unsupported type {type(raw_X)}")
 
-    raw_sum = raw_X.sum(axis=1).A1
+    raw_sum = raw_X.sum(axis=1, dtype=np.float64).A1
     nnz = raw_X.getnnz(axis=1)
     raw_mean_nnz = raw_sum / nnz
     raw_variance_nnz = _var(raw_X, axis=1, ddof=1)
-    n_measured_vars = np.full((raw_X.shape[0],), (raw_X.sum(axis=0) > 0).sum(), dtype=np.int64)
+    n_measured_vars = np.full((raw_X.shape[0],), (raw_X.sum(axis=0, dtype=np.float64) > 0).sum(), dtype=np.int64)
 
     return pd.DataFrame(
         data={
@@ -46,7 +46,7 @@ def get_var_stats(
     else:
         raise NotImplementedError(f"get_var_stats: unsupported array type {type(raw_X)}")
 
-    n_measured_obs = raw_X.shape[0] * (raw_X.sum(axis=0) > 0).A1
+    n_measured_obs = raw_X.shape[0] * (raw_X.sum(axis=0, dtype=np.float64) > 0).A1
 
     return pd.DataFrame(
         data={
@@ -87,13 +87,13 @@ def _var_ndarray(data: npt.NDArray[np.float32], ddof: int) -> float:
             numba.types.Array(numba.float32, 1, "C", readonly=True),
             numba.types.Array(numba.int32, 1, "C", readonly=True),
             numba.int64,
-            numba.float32[:],
+            numba.float64[:],
         ),
         numba.void(
             numba.types.Array(numba.float32, 1, "C", readonly=True),
             numba.types.Array(numba.int64, 1, "C", readonly=True),
             numba.int64,
-            numba.float32[:],
+            numba.float64[:],
         ),
     ],
     nopython=True,
@@ -103,7 +103,7 @@ def _var_matrix(
     data: npt.NDArray[np.float32],
     indptr: npt.NDArray[np.int32],
     ddof: int,
-    out: npt.NDArray[np.float32],
+    out: npt.NDArray[np.float64],
 ) -> None:
     n_elem = len(indptr) - 1
     for i in range(n_elem):
@@ -114,14 +114,14 @@ def _var(
     matrix: Union[sparse.csr_matrix, sparse.csc_matrix],
     axis: int = 0,
     ddof: int = 1,
-) -> npt.NDArray[np.float32]:
+) -> npt.NDArray[np.float64]:
     if axis == 0:
         n_elem, axis_len = matrix.shape
         matrix = matrix.tocsc()
     else:
         axis_len, n_elem = matrix.shape
         matrix = matrix.tocsr()
 
-    out = np.empty((axis_len,), dtype=np.float32)
+    out = np.empty((axis_len,), dtype=np.float64)
     _var_matrix(matrix.data, matrix.indptr, ddof, out)
     return out
@@ -391,7 +391,7 @@ def _validate_Xraw_contents_by_dataset(args: Tuple[str, str, Dataset, List[Exper
             continue
 
         assert _validate_X_obs_axis_stats(eb, dataset, obs_df, ad)
-        obs_df = obs_df[["soma_joinid"]]  # save some memory
+        obs_df = obs_df[["soma_joinid", "raw_sum"]]  # save some memory
 
         # get the joinids for the var axis
         var_df = (
@@ -421,6 +421,14 @@
         rows_by_position = pd.Index(obs_joinids_split).get_indexer(X_raw_obs_joinids)  # type: ignore[no-untyped-call]
         del X_raw_obs_joinids
 
+        # Check that raw_sum stat matches raw layer
+        raw_sum = np.zeros((len(obs_joinids_split),), dtype=np.float64)  # 64 bit for numerical stability
+        np.add.at(raw_sum, rows_by_position, X_raw_data)
+        raw_sum = raw_sum.astype(
+            CENSUS_OBS_STATS_COLUMNS["raw_sum"].to_pandas_dtype()
+        )  # back to the storage type
+        assert np.allclose(raw_sum, obs_df.raw_sum.iloc[idx : idx + STRIDE].to_numpy())
+
         # Assertion 1 - the contents of the X matrix are EQUAL for all var values present in the AnnData
         assert (
             sparse.coo_matrix(
@@ -539,7 +547,9 @@ def _validate_Xnorm_layer(args: Tuple[ExperimentSpecification, str, int, int]) ->
 
     assert np.array_equal(raw["soma_dim_0"].to_numpy(), norm["soma_dim_0"].to_numpy())
     assert np.array_equal(raw["soma_dim_1"].to_numpy(), norm["soma_dim_1"].to_numpy())
-    assert np.all(norm["soma_data"].to_numpy() >= 0.0)
+    # If we wrote a value, it MUST be larger than zero (i.e., represents a raw count value of 1 or greater)
+    assert np.all(raw["soma_data"].to_numpy() > 0.0), "Found zero value in raw layer"
+    assert np.all(norm["soma_data"].to_numpy() > 0.0), "Found zero value in normalized layer"
 
     dim0 = norm["soma_dim_0"].to_numpy()
     dim1 = norm["soma_dim_1"].to_numpy()
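A note on the raw_sum validation above: it scatter-adds with np.add.at rather than fancy-indexed +=, because the buffered += keeps only one contribution per duplicated index — and rows_by_position repeats a row index once per X entry in that row. A minimal demonstration of the difference:

import numpy as np

idx = np.array([0, 0, 1])  # duplicated row index, as in rows_by_position
vals = np.array([1.0, 2.0, 3.0])

out = np.zeros(3)
out[idx] += vals           # buffered: only the last write to index 0 lands
print(out)                 # [2. 3. 0.]

out = np.zeros(3)
np.add.at(out, idx, vals)  # unbuffered scatter-add: duplicates accumulate
print(out)                 # [3. 3. 0.]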
