[builder] refactor builder to utilize Dask (#964)
* first cut at fixed budget anndata handling
* memory
* refactor consolidate
* checkpoint refactoring for memory budget
* always have at least one worker
* smaller strides
* improve memory diagnostics
* autoupdate precommit modules
* fix bug in no-consolidate
* update test to match new manifest field requirements
* remove unused code
* further memory budget refinement and tuning
* add missing __len__ to AnnDataProxy
* further memory usage reduction
* preserve column ordering in dataframe loading
* comments and cleanup
* add extra verbose logging level
* back out parallel consolidation for now
* added a todo reminder
* a few more memory tuning tweaks
* simplify open_anndata interface
* pr review
* clean up logger
* lint
* snapshot initial dask explorations
* pr feedback
* additional dask refactoring
* fix empty slice bug
* additional refactoring to use dask
* refine async consolidator
* checkpoint progress
* additional X layer processing refinement
* fix pytest
* fix mocks in test
* update package deps for builder
* comment
* improve dataset shuffle
* tuning
* update to latest tiledb
* update to latest tiledb
* cleanup
* additional scale updates
* fix numpy cast error
* shorten step count for async consolidator
* additional cleanup
* update to latest cellxgene_census
* update tiledbsoma dep
* lint
* tune thread count cap
* update to latest tiledbsoma
* lint
* remove debugging code
Bruce Martin authored Feb 23, 2024
1 parent a4cdcf4 commit 38fff2d
Showing 21 changed files with 1,044 additions and 940 deletions.
10 changes: 6 additions & 4 deletions tools/cellxgene_census_builder/pyproject.toml
@@ -32,11 +32,11 @@ dependencies= [
"numpy==1.24.4",
# IMPORTANT: consider TileDB format compat before advancing this version. It is important that
# IMPORTANT: the tiledbsoma version lag behind that used in the cellxgene-census package.
"tiledbsoma==1.7.0",
"cellxgene-census==1.10.1",
"tiledbsoma==1.7.2",
"cellxgene-census==1.10.2",
"scipy==1.12.0",
"fsspec[http]==2023.12.2",
"s3fs==2023.12.2",
"fsspec[http]==2024.2.0",
"s3fs==2024.2.0",
"requests==2.31.0",
"aiohttp==3.9.3",
"Cython", # required by owlready2
@@ -47,6 +47,8 @@ dependencies= [
"psutil==5.9.8",
"pyyaml==6.0.1",
"numba==0.58.1",
"dask==2024.2.0",
"distributed==2024.2.0",
]

[project.urls]
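The new dask and distributed pins back the parallelized build described in the commit message. Below is a minimal, hypothetical sketch of the dask.distributed pattern those dependencies enable; the cluster sizing, the process_h5ad helper, and the file names are assumptions, not code from this commit:

# Hypothetical sketch only -- not code from this commit. It illustrates the
# general dask.distributed pattern the new "dask"/"distributed" pins enable:
# fan per-H5AD work out across worker processes under a bounded memory budget.
import dask
from dask.distributed import Client, LocalCluster

def process_h5ad(path: str) -> int:
    # Placeholder for per-dataset work (e.g., filtering an H5AD and writing X layers).
    return len(path)

if __name__ == "__main__":
    # Bound worker count and per-worker memory, then submit work as delayed tasks.
    cluster = LocalCluster(n_workers=4, threads_per_worker=1, memory_limit="8GiB")
    with Client(cluster):
        tasks = [dask.delayed(process_h5ad)(p) for p in ["a.h5ad", "b.h5ad"]]
        results = dask.compute(*tasks)
        print(results)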
@@ -9,7 +9,6 @@
import s3fs

from .build_soma import build as build_a_soma
from .build_soma import validate as validate_a_soma
from .build_state import CENSUS_BUILD_CONFIG, CENSUS_BUILD_STATE, CensusBuildArgs, CensusBuildConfig, CensusBuildState
from .util import log_process_resource_status, process_init, start_resource_logger, urlcat

@@ -54,12 +53,17 @@ def main() -> int:
do_prebuild_set_defaults,
do_prebuild_checks,
do_build_soma,
do_validate_soma,
do_create_reports,
do_data_copy,
do_report_copy,
do_log_copy,
],
"test-build": [ # for testing only
do_prebuild_set_defaults,
do_prebuild_checks,
do_build_soma,
do_create_reports,
],
"mock-build": [
do_mock_build,
do_data_copy,
@@ -71,7 +75,6 @@ def main() -> int:
do_prebuild_set_defaults,
do_prebuild_checks,
do_build_soma,
do_validate_soma,
do_create_reports,
do_data_copy,
do_the_release,
@@ -157,16 +160,6 @@ def do_prebuild_checks(args: CensusBuildArgs) -> bool:
def do_build_soma(args: CensusBuildArgs) -> bool:
if (cc := build_a_soma(args)) != 0:
logger.critical(f"Build of census failed with code {cc}.")
return False

return True


def do_validate_soma(args: CensusBuildArgs) -> bool:
if not validate_a_soma(args):
logger.critical("Validation of the census build has failed.")
return False

return True


@@ -1,7 +1,5 @@
from .build_soma import build
from .validate_soma import validate

__all__ = [
"build",
"validate",
]
@@ -1,3 +1,5 @@
"""Dev/test CLI for the build_soma package. This is not used for builds."""

import argparse
import logging
import pathlib
@@ -8,8 +10,6 @@

from ..build_state import CensusBuildArgs, CensusBuildConfig
from ..util import log_process_resource_status, process_init, start_resource_logger
from .build_soma import build
from .validate_soma import validate

logger = logging.getLogger(__name__)

@@ -27,17 +27,25 @@ def main() -> int:
}
)
args = CensusBuildArgs(working_dir=pathlib.PosixPath(cli_args.working_dir), config=default_config)
# enable the dashboard depending on verbosity level
if args.config.verbose:
args.config.dashboard = True

process_init(args)
logger.info(args)

start_resource_logger()

cc = 0
if cli_args.subcommand == "build":
cc = build(args)
from . import build

if cc == 0 and (cli_args.subcommand == "validate" or cli_args.validate):
validate(args)
cc = build(args, validate=cli_args.validate)
elif cli_args.subcommand == "validate":
from .validate_soma import validate

# stand-alone validate - requires previously built objects.
cc = validate(args)

log_process_resource_status(level=logging.INFO)
return cc
@@ -89,6 +97,12 @@ def create_args_parser() -> argparse.ArgumentParser:
default=True,
help="Consolidate TileDB objects after build",
)
build_parser.add_argument(
"--dashboard",
action=argparse.BooleanOptionalAction,
default=False,
help="Start Dask dashboard",
)
build_parser.add_argument(
"--dataset_id_blocklist_uri",
help="Dataset blocklist URI",
@@ -1,6 +1,8 @@
import logging
from contextlib import AbstractContextManager
from functools import cached_property
from os import PathLike
from types import TracebackType
from typing import Any, Protocol, Self, TypedDict, cast

import h5py
@@ -36,11 +38,13 @@ def _slice_index(prev: Index1D, new: Index1D, length: int) -> slice | npt.NDArray[np.int64]:
assert rng.stop >= 0
return slice(rng.start, rng.stop, rng.step)
else:
return np.arange(*prev.indices(length))[new]
idx = np.arange(*prev.indices(length))[new]
return idx if len(idx) else slice(0, 0)
elif isinstance(prev, np.ndarray):
if prev.dtype == np.bool_: # a mask
prev = np.nonzero(prev)[0].astype(np.int64)
return cast(npt.NDArray[np.int64], prev[new])
idx = cast(npt.NDArray[np.int64], prev[new])
return idx if len(idx) else slice(0, 0)

# else confusion
raise IndexError("Unsupported indexing types")
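Per the "fix empty slice bug" item in the commit message, the change above returns slice(0, 0) whenever composing a view index with an empty selection, rather than a zero-length fancy-index array. A tiny, hedged check of that contract follows; it assumes _slice_index is importable from this module and is not part of the commit:

import numpy as np

# Composing a full slice with an empty integer selection should now yield an
# empty slice instead of a zero-length index array.
idx = _slice_index(slice(None), np.array([], dtype=np.int64), length=10)
assert idx == slice(0, 0)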
@@ -57,7 +61,7 @@ def _normed_index(idx: Index) -> tuple[Index1D, Index1D]:
raise IndexError("Indexing supported on two dimensions only")


class AnnDataProxy:
class AnnDataProxy(AbstractContextManager["AnnDataProxy"]):
"""Recommend using `open_anndata()` rather than instantiating this class directly.
AnnData-like proxy for the version 0.1.0 AnnData H5PY file encoding (aka H5AD).
@@ -83,6 +87,7 @@ class AnnDataProxy:
_obs: pd.DataFrame
_var: pd.DataFrame
_X: h5py.Dataset | CSRDataset | CSCDataset
_file: h5py.File | None

def __init__(
self,
@@ -97,16 +102,25 @@ def __init__(
self.filename = filename

if view_of is None:
self._file = h5py.File(self.filename, mode="r")
self._obs, self._var, self._X = self._load_h5ad(obs_column_names, var_column_names)
self._obs_idx: slice | npt.NDArray[np.int64] = slice(None)
self._var_idx: slice | npt.NDArray[np.int64] = slice(None)
else:
self._file = None
self._obs, self._var, self._X = (view_of._obs, view_of._var, view_of._X)
assert obs_idx is not None
assert var_idx is not None
self._obs_idx = obs_idx
self._var_idx = var_idx

def __exit__(
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
) -> None:
if self._file:
self._file.close()
self._file = None

@property
def X(self) -> sparse.spmatrix | npt.NDArray[np.integer[Any] | np.floating[Any]]:
# For CS*Dataset, slice first on the major axis, then on the minor, as
@@ -156,6 +170,19 @@ def __getitem__(self, key: Index) -> "AnnDataProxy":
vdx = _slice_index(self._var_idx, vdx, self.n_vars)
return AnnDataProxy(self.filename, view_of=self, obs_idx=odx, var_idx=vdx)

def get_estimated_density(self) -> float:
"""Return an estimated density for the H5AD, based upon the full file density.
This is NOT the density for any given slice.
Approach: divide the whole file nnz by the product of the shape.
"""
nnz: int
if isinstance(self._X, CSRDataset | CSCDataset):
nnz = self._X.group["data"].size
else:
nnz = self._X.size
return nnz / (self.n_obs * self.n_vars)

def _load_dataframe(self, elem: h5py.Group, column_names: tuple[str, ...] | None) -> pd.DataFrame:
# if reading all, just use the built-in
if not column_names:
@@ -196,27 +223,27 @@ def _load_h5ad(
This code utilizes the AnnData on-disk spec and several experimental APIs (as of 0.10.0).
Spec: https://anndata.readthedocs.io/en/latest/fileformat-prose.html
"""
file = h5py.File(self.filename, mode="r")
assert isinstance(self._file, h5py.File)

# Known to be compatible with this AnnData file encoding
assert (
file.attrs["encoding-type"] == "anndata" and file.attrs["encoding-version"] == "0.1.0"
self._file.attrs["encoding-type"] == "anndata" and self._file.attrs["encoding-version"] == "0.1.0"
), "Unsupported AnnData encoding-type or encoding-version - likely indicates file was created with an unsupported AnnData version"

# Verify we are reading the expected CxG schema version.
schema_version = read_elem(file["uns/schema_version"])
schema_version = read_elem(self._file["uns/schema_version"])
if schema_version != CXG_SCHEMA_VERSION:
raise ValueError(
f"{self.filename} -- incorrect CxG schema version (got {schema_version}, expected {CXG_SCHEMA_VERSION})"
)

obs = self._load_dataframe(file["obs"], obs_column_names)
if "raw" in file:
var = self._load_dataframe(file["raw/var"], var_column_names)
X = file["raw/X"]
obs = self._load_dataframe(self._file["obs"], obs_column_names)
if "raw" in self._file:
var = self._load_dataframe(self._file["raw/var"], var_column_names)
X = self._file["raw/X"]
else:
var = self._load_dataframe(file["var"], var_column_names)
X = file["X"]
var = self._load_dataframe(self._file["var"], var_column_names)
X = self._file["X"]

if isinstance(X, h5py.Group):
X = sparse_dataset(X)
@@ -232,12 +259,13 @@


def open_anndata(
base_path: str,
dataset: Dataset,
dataset: str | Dataset,
*,
base_path: str | None = None,
include_filter_columns: bool = False,
obs_column_names: tuple[str, ...] | None = None,
var_column_names: tuple[str, ...] | None = None,
filter_spec: AnnDataFilterSpec | None = None,
) -> AnnDataProxy:
"""Open the dataset and return an AnnData-like AnnDataProxy object.
@@ -247,15 +275,19 @@
include_filter_columns: if True, ensure that any obs/var columns required for H5AD filtering are included. If
False (default), only load the columns specified by the user.
"""
h5ad_path = dataset.dataset_h5ad_path if isinstance(dataset, Dataset) else dataset
h5ad_path = urlcat(base_path, h5ad_path) if base_path is not None else h5ad_path

include_filter_columns = include_filter_columns or (filter_spec is not None)
if include_filter_columns:
obs_column_names = tuple(set(CXG_OBS_COLUMNS_MINIMUM_READ + (obs_column_names or ())))
var_column_names = tuple(set(CXG_VAR_COLUMNS_MINIMUM_READ + (var_column_names or ())))

return AnnDataProxy(
urlcat(base_path, dataset.dataset_h5ad_path),
obs_column_names=obs_column_names,
var_column_names=var_column_names,
)
adata = AnnDataProxy(h5ad_path, obs_column_names=obs_column_names, var_column_names=var_column_names)
if filter_spec is not None:
adata = make_anndata_cell_filter(filter_spec)(adata)

return adata


class AnnDataFilterFunction(Protocol):
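A hedged usage sketch of the reworked open_anndata interface shown above; the import path, H5AD path, and column names are invented for illustration and are not code from the commit. It exercises opening by plain path, the new context-manager cleanup of the underlying h5py.File, and the whole-file density estimate:

from cellxgene_census_builder.build_soma.anndata import open_anndata  # assumed import path

with open_anndata(
    "example.h5ad",
    obs_column_names=("cell_type", "assay"),
    var_column_names=("feature_id",),
) as adata:
    print(adata.n_obs, adata.n_vars)
    print(adata.get_estimated_density())
    view = adata[0:100, :]  # lazily sliced AnnDataProxy view

When a filter_spec is supplied, open_anndata applies make_anndata_cell_filter(filter_spec) to the proxy before returning it, as the new code above shows.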
(Diffs for the remaining changed files were not loaded.)
