refactor dask shutdown
bkmartinjr committed Mar 14, 2024
1 parent afade1c commit 08f4e37
Showing 5 changed files with 70 additions and 69 deletions.
@@ -73,68 +73,61 @@ def build(args: CensusBuildArgs, *, validate: bool = True) -> int:
 
     prepare_file_system(args)
 
-    try:
-        n_workers = clamp(cpu_count(), 1, args.config.max_worker_processes)
-        with create_dask_client(args, n_workers=n_workers, threads_per_worker=1, memory_limit=0) as client:
-            # Step 1 - get all source datasets
-            datasets = build_step1_get_source_datasets(args)
+    n_workers = clamp(cpu_count(), 1, args.config.max_worker_processes)
+    with create_dask_client(args, n_workers=n_workers, threads_per_worker=1, memory_limit=0) as client:
+        # Step 1 - get all source datasets
+        datasets = build_step1_get_source_datasets(args)
 
-            # Step 2 - create root collection, and all child objects, but do not populate any dataframes or matrices
-            root_collection = build_step2_create_root_collection(args.soma_path.as_posix(), experiment_builders)
+        # Step 2 - create root collection, and all child objects, but do not populate any dataframes or matrices
+        root_collection = build_step2_create_root_collection(args.soma_path.as_posix(), experiment_builders)
 
-            # Step 3 - populate axes
-            filtered_datasets = build_step3_populate_obs_and_var_axes(
-                args.h5ads_path.as_posix(), datasets, experiment_builders, args
-            )
+        # Step 3 - populate axes
+        filtered_datasets = build_step3_populate_obs_and_var_axes(
+            args.h5ads_path.as_posix(), datasets, experiment_builders, args
+        )
 
-            # Constraining parallelism is critical at this step, as each worker utilizes (max) ~64GiB+ of memory to
-            # process the X array (partitions are large to reduce TileDB fragment count, which reduces consolidation time).
-            #
-            # TODO: when global order writes are supported, processing of much smaller slices will be
-            # possible, and this budget should drop considerably. When that is implemented, n_workers should
-            # be much larger (e.g., use the default value of #CPUs or some such).
-            # https://github.com/single-cell-data/TileDB-SOMA/issues/2054
-            MEM_BUDGET = 64 * 1024**3
-            n_workers = clamp(int(psutil.virtual_memory().total // MEM_BUDGET), 1, args.config.max_worker_processes)
-            logger.info(f"Scaling cluster to {n_workers} workers.")
-            client.cluster.scale(n_workers)
-
-            # Step 4 - populate X layers
-            build_step4_populate_X_layers(args.h5ads_path.as_posix(), filtered_datasets, experiment_builders, args)
-
-            # Prune datasets that we will not use, and do not want to include in the build
-            prune_unused_datasets(args.h5ads_path, datasets, filtered_datasets)
-
-            # Step 5 - write out dataset manifest and summary information
-            build_step5_save_axis_and_summary_info(
-                root_collection, experiment_builders, filtered_datasets, args.config.build_tag
-            )
+        # Constraining parallelism is critical at this step, as each worker utilizes (max) ~64GiB+ of memory to
+        # process the X array (partitions are large to reduce TileDB fragment count, which reduces consolidation time).
+        #
+        # TODO: when global order writes are supported, processing of much smaller slices will be
+        # possible, and this budget should drop considerably. When that is implemented, n_workers should
+        # be much larger (e.g., use the default value of #CPUs or some such).
+        # https://github.com/single-cell-data/TileDB-SOMA/issues/2054
+        MEM_BUDGET = 64 * 1024**3
+        n_workers = clamp(int(psutil.virtual_memory().total // MEM_BUDGET), 1, args.config.max_worker_processes)
+        logger.info(f"Scaling cluster to {n_workers} workers.")
+        client.cluster.scale(n_workers)
+
+        # Step 4 - populate X layers
+        build_step4_populate_X_layers(args.h5ads_path.as_posix(), filtered_datasets, experiment_builders, args)
+
+        # Prune datasets that we will not use, and do not want to include in the build
+        prune_unused_datasets(args.h5ads_path, datasets, filtered_datasets)
+
+        # Step 5 - write out dataset manifest and summary information
+        build_step5_save_axis_and_summary_info(
+            root_collection, experiment_builders, filtered_datasets, args.config.build_tag
+        )
 
-            # Temporary work-around. Can be removed when single-cell-data/TileDB-SOMA#1969 is fixed.
-            tiledb_soma_1969_work_around(root_collection.uri)
+        # Temporary work-around. Can be removed when single-cell-data/TileDB-SOMA#1969 is fixed.
+        tiledb_soma_1969_work_around(root_collection.uri)
 
-            # Scale the cluster up as we are no longer memory constrained in the following phases
-            n_workers = clamp(cpu_count(), 1, args.config.max_worker_processes)
-            logger.info(f"Scaling cluster to {n_workers} workers.")
-            client.cluster.scale(n=n_workers)
-
-            if args.config.consolidate:
-                for f in dask.distributed.as_completed(
-                    submit_consolidate(root_collection.uri, pool=client, vacuum=True)
-                ):
-                    assert f.result()
-            if validate:
-                for f in dask.distributed.as_completed(validate_soma(args, client)):
-                    assert f.result()
-            if args.config.consolidate and validate:
-                validate_consolidation(args)
-            logger.info("Validation & consolidation complete.")
-
-            shutdown_dask_cluster(client)
-
-    except TimeoutError:
-        # quiet tornado race conditions (harmless) on shutdown
-        pass
+        # Scale the cluster up as we are no longer memory constrained in the following phases
+        n_workers = clamp(cpu_count(), 1, args.config.max_worker_processes)
+        logger.info(f"Scaling cluster to {n_workers} workers.")
+        client.cluster.scale(n=n_workers)
+
+        if args.config.consolidate:
+            for f in dask.distributed.as_completed(submit_consolidate(root_collection.uri, pool=client, vacuum=True)):
+                assert f.result()
+        if validate:
+            for f in dask.distributed.as_completed(validate_soma(args, client)):
+                assert f.result()
+        if args.config.consolidate and validate:
+            validate_consolidation(args)
+        logger.info("Validation & consolidation complete.")
+
+        shutdown_dask_cluster(client)
 
     return 0
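
Aside: the memory-budget arithmetic in the hunk above drives the build-phase scaling, sizing the worker pool by host RAM rather than CPU count. A minimal sketch of that computation, assuming clamp is the usual min/max bound (its definition is not part of this diff) and using a hypothetical max_worker_processes value:

    import psutil

    def clamp(value: int, low: int, high: int) -> int:
        # Assumed definition: bound value to the inclusive range [low, high].
        return max(low, min(value, high))

    MEM_BUDGET = 64 * 1024**3  # ~64 GiB per X-processing worker
    max_worker_processes = 16  # hypothetical config value, for illustration only
    n_workers = clamp(int(psutil.virtual_memory().total // MEM_BUDGET), 1, max_worker_processes)
    print(f"Scaling cluster to {n_workers} workers.")

On a 256 GiB host this yields 4 workers; the later client.cluster.scale(n=n_workers) call re-expands the pool once the memory-hungry X phase is done.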

@@ -71,7 +71,17 @@ def create_dask_client(

 def shutdown_dask_cluster(client: dask.distributed.Client) -> None:
     """Clean-ish shutdown, designed to prevent hangs and error messages in log."""
-    client.retire_workers()
+    try:
+        client.retire_workers()
+    except TimeoutError:
+        # Quiet Tornado errors
+        pass
 
     time.sleep(1)
-    client.shutdown()
+    try:
+        client.shutdown()
+    except TimeoutError:
+        # Quiet Tornado errors
+        pass
 
     logger.info("Dask cluster shut down")
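
For context, a sketch of how the refactored helper is intended to be called: the TimeoutError suppression now lives inside shutdown_dask_cluster, so callers no longer wrap whole build phases in try/except. The local Client construction below is an assumption for a self-contained demo, not the builder's create_dask_client:

    import dask.distributed

    from cellxgene_census_builder.build_soma.mp import shutdown_dask_cluster

    with dask.distributed.Client(n_workers=2, threads_per_worker=1) as client:
        futures = client.map(lambda x: x * 2, range(4))
        assert client.gather(futures) == [0, 2, 4, 6]
        # Tornado TimeoutError raised during retire_workers/shutdown is swallowed inside the helper.
        shutdown_dask_cluster(client)

Keeping the except blocks narrow (around retire_workers and shutdown only) is the point of the refactor: the old outer try/except could also mask a TimeoutError raised by real work.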
@@ -1108,15 +1108,10 @@ def validate(args: CensusBuildArgs) -> int:
     logger.info("Validating correct consolidation and vacuuming - start")
     n_workers = clamp(cpu_count(), 1, args.config.max_worker_processes)
 
-    try:
-        with create_dask_client(args, n_workers=n_workers, threads_per_worker=1, memory_limit=None) as client:
-            assert all(r.result() for r in distributed.wait(validate_soma(args, client)).done)
-            logging.info("Validation complete.")
-
-            shutdown_dask_cluster(client)
-
-    except TimeoutError:
-        pass
+    with create_dask_client(args, n_workers=n_workers, threads_per_worker=1, memory_limit=None) as client:
+        assert all(r.result() for r in distributed.wait(validate_soma(args, client)).done)
+        shutdown_dask_cluster(client)
+        logging.info("Validation complete.")
 
     assert validate_consolidation(args)
     logger.info("Validating correct consolidation and vacuuming - complete")
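
The assert in the hunk above leans on distributed.wait, which blocks until the futures finish and returns a DoneAndNotDoneFutures named tuple whose .done field is the set of completed futures. A small self-contained illustration, with a toy predicate standing in for validate_soma:

    from dask import distributed

    with distributed.Client(n_workers=2, threads_per_worker=1) as client:
        futures = client.map(lambda i: i >= 0, range(8))  # stand-in for validate_soma futures
        # wait() returns DoneAndNotDoneFutures(done, not_done); all futures should land in .done.
        assert all(r.result() for r in distributed.wait(futures).done)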
1 change: 1 addition & 0 deletions tools/cellxgene_census_builder/tests/test_builder.py
@@ -151,6 +151,7 @@ def test_build_step1_get_source_datasets(tmp_path: pathlib.Path, census_build_args: CensusBuildArgs) -> None:
     census_build_args.h5ads_path.mkdir(parents=True, exist_ok=True)
 
     # Call the function
+    process_init(census_build_args)
     with create_dask_client(census_build_args) as client:
         datasets = build_step1_get_source_datasets(census_build_args)
         shutdown_dask_cluster(client)
2 changes: 2 additions & 0 deletions tools/cellxgene_census_builder/tests/test_source_assets.py
@@ -6,6 +6,7 @@
 from cellxgene_census_builder.build_soma.mp import create_dask_client, shutdown_dask_cluster
 from cellxgene_census_builder.build_soma.source_assets import stage_source_assets
 from cellxgene_census_builder.build_state import CensusBuildArgs
+from cellxgene_census_builder.process_init import process_init


@@ -23,6 +23,7 @@ def test_source_assets(tmp_path: pathlib.Path, census_build_args: CensusBuildArgs) -> None:
         datasets.append(dataset)
 
     # Call the function
+    process_init(census_build_args)
     with create_dask_client(census_build_args) as client:
         stage_source_assets(datasets, census_build_args)
         shutdown_dask_cluster(client)