Skip to content

Commit

Permalink
Fix validation for presence matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
ivirshup committed Jan 28, 2025
1 parent 093d4ca commit 4a7ed8f
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ def _validate_X_layers_has_unique_coords(


def validate_X_layers_presence(
soma_path: str, datasets: list[Dataset], experiment_specifications: list[ExperimentSpecification]
soma_path: str, datasets: list[Dataset], experiment_specifications: list[ExperimentSpecification], assets_path: str
) -> Delayed[bool]:
"""Validate that the presence matrix accurately summarizes X[raw] for each experiment.
Expand Down Expand Up @@ -579,33 +579,30 @@ def _validate_X_layers_presence_general(experiment_specifications: list[Experime

@logit(logger, msg="{0.dataset_id}")
def _validate_X_layers_presence(
dataset: Dataset, experiment_specifications: list[ExperimentSpecification], soma_path: str
dataset: Dataset,
experiment_specifications: list[ExperimentSpecification],
soma_path: str,
assets_path: str,
) -> bool:
"""For a given dataset and experiment, confirm that the presence matrix matches contents of X[raw]."""
for es in experiment_specifications:
with open_experiment(soma_path, es) as exp:
obs_df = (
exp.obs.read(
value_filter=f"dataset_id == '{dataset.soma_joinid}'",
value_filter=f"dataset_id == '{dataset.dataset_id}'",
column_names=["soma_joinid", "n_measured_vars"],
)
.concat()
.to_pandas()
)
if len(obs_df) > 0: # skip empty experiments
X_raw = exp.ms[MEASUREMENT_RNA_NAME].X["raw"]
feature_ids = pd.Index(
exp.ms[MEASUREMENT_RNA_NAME].var.read(column_names=["feature_id"]).concat().to_pandas()
exp.ms[MEASUREMENT_RNA_NAME]
.var.read(column_names=["feature_id"])
.concat()
.to_pandas()["feature_id"]
)

presence_accumulator = np.zeros((X_raw.shape[1]), dtype=np.bool_)
for block, _ in (
X_raw.read(coords=(obs_df.soma_joinids.to_numpy(), slice(None)))
.blockwise(axis=0, size=2**20, eager=False, reindex_disable_on_axis=[0, 1])
.tables()
):
presence_accumulator[block["soma_dim_1"].to_numpy()] = 1

presence = (
exp.ms[MEASUREMENT_RNA_NAME][FEATURE_DATASET_PRESENCE_MATRIX_NAME]
.read((dataset.soma_joinid,))
Expand All @@ -614,22 +611,21 @@ def _validate_X_layers_presence(
)

# Get soma_joinids for feature in the original h5ad
orig_feature_ids = _read_var_names(dataset.dataset_h5ad_path)
orig_feature_ids = _read_var_names(f"{assets_path}/{dataset.dataset_h5ad_path}")
orig_indices = np.sort(feature_ids.get_indexer(feature_ids.intersection(orig_feature_ids)))

np.testing.assert_array_equal(presence["soma_dim_1"], orig_indices)

assert np.array_equal(presence_accumulator, presence), "Presence value does not match X[raw]"

assert (
obs_df.n_measured_vars.to_numpy() == presence_accumulator.sum()
).all(), f"{es.name}:{dataset.dataset_id} obs.n_measured_vars incorrect."

return True

check_presence_values = (
dask.bag.from_sequence(datasets, partition_size=8)
.map(_validate_X_layers_presence, soma_path=soma_path, experiment_specifications=experiment_specifications)
.map(
_validate_X_layers_presence,
soma_path=soma_path,
experiment_specifications=experiment_specifications,
assets_path=assets_path,
)
.reduction(all, all)
.to_delayed()
)
Expand Down Expand Up @@ -1114,7 +1110,7 @@ def validate_soma(args: CensusBuildArgs, client: dask.distributed.Client) -> das
dask.delayed(validate_X_layers_schema)(soma_path, experiment_specifications, eb_info),
validate_X_layers_normalized(soma_path, experiment_specifications),
validate_X_layers_has_unique_coords(soma_path, experiment_specifications),
validate_X_layers_presence(soma_path, datasets, experiment_specifications),
validate_X_layers_presence(soma_path, datasets, experiment_specifications, assets_path),
)
)
],
Expand Down
9 changes: 5 additions & 4 deletions tools/cellxgene_census_builder/tests/anndata/test_anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,30 +20,31 @@
from ..conftest import GENE_IDS, ORGANISMS, get_anndata


def test_open_anndata(datasets: list[Dataset], census_build_args: CensusBuildArgs) -> None:
    """`open_anndata` should open the h5ads for each of the dataset in the argument,
    and yield both the dataset and the corresponding AnnData object.
    This test does not involve additional filtering steps.
    The `datasets` used here have no raw layer.
    """
    # Dataset.dataset_h5ad_path is relative; resolve it against the build's h5ads directory.
    assets_path = census_build_args.h5ads_path.as_posix()

    def _todense(X: npt.NDArray[np.float32] | sparse.spmatrix) -> npt.NDArray[np.float32]:
        # Normalize dense/sparse X to a dense ndarray so matrices compare uniformly.
        if isinstance(X, np.ndarray):
            return X
        else:
            return cast(npt.NDArray[np.float32], X.todense())

    result = [(d, open_anndata(d, base_path=assets_path)) for d in datasets]
    assert len(result) == len(datasets) and len(datasets) > 0
    for i, (dataset, anndata_obj) in enumerate(result):
        assert dataset == datasets[i]
        # Re-read the h5ad directly and confirm open_anndata yielded identical contents.
        opened_anndata = anndata.read_h5ad(f"{assets_path}/{dataset.dataset_h5ad_path}")
        assert opened_anndata.obs.equals(anndata_obj.obs)
        assert opened_anndata.var.equals(anndata_obj.var)
        assert np.array_equal(_todense(opened_anndata.X), _todense(anndata_obj.X))

    # also check context manager
    with open_anndata(datasets[0], base_path=assets_path) as ad:
        assert ad.n_obs == len(ad.obs)


Expand Down
6 changes: 3 additions & 3 deletions tools/cellxgene_census_builder/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,8 @@ def datasets(census_build_args: CensusBuildArgs) -> list[Dataset]:
h5ad = get_anndata(
organism, GENE_IDS[i], no_zero_counts=False, assay_ontology_term_id=ASSAY_IDS[i], X_format=X_FORMAT[i]
)
h5ad_path = f"{assets_path}/{organism.name}_{i}.h5ad"
h5ad.write_h5ad(h5ad_path)
h5ad_name = f"{organism.name}_{i}.h5ad"
h5ad.write_h5ad(f"{assets_path}/{h5ad_name}")
datasets.append(
Dataset(
dataset_id=f"{organism.name}_{i}",
Expand All @@ -170,7 +170,7 @@ def datasets(census_build_args: CensusBuildArgs) -> list[Dataset]:
collection_id=f"id_{organism.name}",
collection_name=f"collection_{organism.name}",
dataset_asset_h5ad_uri="mock",
dataset_h5ad_path=h5ad_path,
dataset_h5ad_path=h5ad_name,
dataset_version_id=f"{organism.name}_{i}_v0",
),
)
Expand Down

0 comments on commit 4a7ed8f

Please sign in to comment.