Configure max_workers
wilsonbb committed Apr 3, 2024
1 parent 224d0ed commit 045308b
Showing 2 changed files with 13 additions and 12 deletions.
18 changes: 11 additions & 7 deletions src/kbmod/region_search.py
@@ -14,6 +14,7 @@ def _chunked_data_ids(dataIds, chunk_size=200):
     for i in range(0, len(dataIds), chunk_size):
         yield dataIds[i : i + chunk_size]
 
+
 class RegionSearch:
     """
     A class for searching through a dataset for data suitable for KBMOD processing,
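
The `_chunked_data_ids` helper shown as context above is a plain generator and can be exercised on its own; a minimal sketch, with a toy list standing in for the real Butler data IDs:

    def _chunked_data_ids(dataIds, chunk_size=200):
        for i in range(0, len(dataIds), chunk_size):
            yield dataIds[i : i + chunk_size]

    # Toy IDs stand in for the Butler DataCoordinate objects used in practice.
    chunks = list(_chunked_data_ids(list(range(5)), chunk_size=2))
    print(chunks)  # [[0, 1], [2, 3], [4]]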
@@ -34,7 +35,7 @@ def __init__(
         dataset_types,
         butler=None,
         visit_info_str="Exposure.visitInfo",
-        parallel=False,
+        max_workers=None,
         fetch_data=False,
     ):
         """
@@ -50,9 +51,9 @@ def __init__(
             The Butler object to use for data access. If None, a new Butler object will be created from `repo_path`.
         visit_info_str : `str`
             The name used when querying the butler for VisitInfo for exposures. Default is "Exposure.visitInfo".
-        parallel : `bool`
-            If True, use parallel processing where possible. Note that each parallel worker
-            will instantiate its own Butler objects, Default is False.
+        max_workers : `int`, optional
+            The maximum number of workers to use in parallel processing. Note that each parallel worker will instantiate its own Butler
+            objects. If not provided, parallel processing is disabled.
         fetch_data: `bool`, optional
             If True, fetch the VDR data when the object is created. Default is True.
         """
@@ -65,7 +66,7 @@ def __init__(
         self.collections = collections
         self.dataset_types = dataset_types
         self.visit_info_str = visit_info_str
-        self.parallel = parallel
+        self.max_workers = max_workers
 
         # Create an empty table to store the VDR (Visit, Detector, Region) data from the butler.
         self.vdr_data = Table()
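
With this change, parallelism is opted into by passing a worker count at construction time rather than a boolean flag; a hedged sketch, where the repo path, collection, and dataset-type names are placeholders (a real call needs an existing Butler repo):

    from kbmod import region_search

    # Placeholder arguments for illustration only.
    serial_rs = region_search.RegionSearch(
        "/path/to/repo", ["my_collection"], ["calexp"], fetch_data=False
    )
    parallel_rs = region_search.RegionSearch(
        "/path/to/repo", ["my_collection"], ["calexp"], fetch_data=False, max_workers=8
    )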
@@ -128,6 +129,9 @@ def get_dataset_type_freq(butler=None, repo_path=None, collections=None):
 
         return ref_freq
 
+    def is_parallel(self):
+        return self.max_workers is not None
+
     def new_butler(self):
         """Instantiates a new Butler object from the repo_path."""
         return dafButler.Butler(self.repo_path)
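
The new `is_parallel` predicate simply distinguishes `None` from any configured worker count; illustrated standalone with a stub in place of the full class:

    class _Stub:
        def __init__(self, max_workers=None):
            self.max_workers = max_workers

        def is_parallel(self):
            return self.max_workers is not None

    assert not _Stub().is_parallel()           # max_workers=None -> serial path
    assert _Stub(max_workers=4).is_parallel()  # any worker count -> parallel path

Note that `ProcessPoolExecutor` itself treats `max_workers=None` as "use the number of processors", so the explicit `None` check is what reserves that value to mean "run serially" instead.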
@@ -294,15 +298,15 @@ def get_uris(self, data_ids, dataset_types=None, collections=None):
                 raise ValueError("No collections specified")
             collections = self.collections
 
-        if not self.parallel:
+        if not self.is_parallel():
             return self.get_uris_serial(data_ids, dataset_types, collections)
 
         # Divide the data_ids into chunks to be processed in parallel
         data_id_chunks = list(_chunked_data_ids(data_ids))
 
         # Use a ProcessPoolExecutor to fetch URIs in parallel
         uris = []
-        with ProcessPoolExecutor(max_workers=multiprocessing.cpu_count()) as executor:
+        with ProcessPoolExecutor(max_workers=self.max_workers) as executor:
             futures = [
                 executor.submit(
                     self.get_uris_serial,
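
The chunk-and-submit shape around that executor is standard `concurrent.futures`; a self-contained sketch of the same pattern, with a trivial worker standing in for `get_uris_serial` (whether the original gathers results with `as_completed` or in submission order is not visible in this hunk):

    from concurrent.futures import ProcessPoolExecutor, as_completed

    def _work(chunk):  # stands in for RegionSearch.get_uris_serial
        return [x * 2 for x in chunk]

    if __name__ == "__main__":  # required when worker processes are spawned
        items = list(range(1000))
        chunks = [items[i : i + 200] for i in range(0, len(items), 200)]
        results = []
        with ProcessPoolExecutor(max_workers=4) as executor:
            futures = [executor.submit(_work, chunk) for chunk in chunks]
            for future in as_completed(futures):
                results.extend(future.result())
        print(len(results))  # 1000

Because the serial early-return above guarantees `self.max_workers` is not `None` inside the parallel branch, the executor no longer needs the hard-coded `multiprocessing.cpu_count()` default.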
7 changes: 2 additions & 5 deletions tests/test_region_search.py
@@ -43,16 +43,13 @@ def setUp(self):
             self.default_collections,
             self.default_datasetTypes,
             butler=self.butler,
-            parallel=False,
         )
 
     def test_init(self):
         """
         Test that the region search object can be initialized.
         """
-        rs = region_search.RegionSearch(
-            MOCK_REPO_PATH, [], [], butler=self.butler, parallel=False, fetch_data=False
-        )
+        rs = region_search.RegionSearch(MOCK_REPO_PATH, [], [], butler=self.butler, fetch_data=False)
         self.assertTrue(rs is not None)
         self.assertEqual(0, len(rs.vdr_data))  # No data should be fetched
 
@@ -190,7 +187,7 @@ def func(repo_path):
             self.default_collections,
             self.default_datasetTypes,
             butler=self.butler,
-            parallel=False,  # TODO Turn on after fixing pickle issue for mocked objects
+            # TODO Turn on after fixing pickle issue for mocked objects
         )
 
         uris = parallel_rs.get_uris(data_ids)
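
For context on that TODO: `ProcessPoolExecutor` must pickle the submitted bound method together with its `self`, so a `RegionSearch` holding a mocked Butler cannot cross the process boundary. A minimal reproduction of the underlying limitation (the exact exception type varies across Python versions):

    import pickle
    from unittest.mock import MagicMock

    try:
        pickle.dumps(MagicMock())
    except Exception as err:  # typically pickle.PicklingError or TypeError
        print(f"mock objects are not picklable: {err}")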
