diff --git a/src/kbmod/region_search.py b/src/kbmod/region_search.py
index 16f9300a3..97a895571 100644
--- a/src/kbmod/region_search.py
+++ b/src/kbmod/region_search.py
@@ -14,6 +14,7 @@ def _chunked_data_ids(dataIds, chunk_size=200):
     for i in range(0, len(dataIds), chunk_size):
         yield dataIds[i : i + chunk_size]
 
+
 class RegionSearch:
     """
     A class for searching through a dataset for data suitable for KBMOD processing,
@@ -34,7 +35,7 @@ def __init__(
         dataset_types,
         butler=None,
         visit_info_str="Exposure.visitInfo",
-        parallel=False,
+        max_workers=None,
         fetch_data=False,
     ):
         """
@@ -50,9 +51,9 @@
             The Butler object to use for data access. If None, a new Butler object will be created from `repo_path`.
         visit_info_str : `str`
             The name used when querying the butler for VisitInfo for exposures. Default is "Exposure.visitInfo".
-        parallel : `bool`
-            If True, use parallel processing where possible. Note that each parallel worker
-            will instantiate its own Butler objects, Default is False.
+        max_workers : `int`, optional
+            The maximum number of workers to use in parallel processing. Note that each parallel
+            worker will instantiate its own Butler object. If not provided, parallel processing is disabled.
         fetch_data: `bool`, optional
             If True, fetch the VDR data when the object is created. Default is True.
         """
@@ -65,7 +66,7 @@ def __init__(
         self.collections = collections
         self.dataset_types = dataset_types
         self.visit_info_str = visit_info_str
-        self.parallel = parallel
+        self.max_workers = max_workers
 
         # Create an empty table to store the VDR (Visit, Detector, Region) data from the butler.
         self.vdr_data = Table()
@@ -128,6 +129,10 @@ def get_dataset_type_freq(butler=None, repo_path=None, collections=None):
 
         return ref_freq
 
+    def is_parallel(self):
+        """Returns True if parallel processing is enabled, i.e. `max_workers` was provided."""
+        return self.max_workers is not None
+
     def new_butler(self):
         """Instantiates a new Butler object from the repo_path."""
         return dafButler.Butler(self.repo_path)
@@ -294,7 +298,7 @@ def get_uris(self, data_ids, dataset_types=None, collections=None):
                 raise ValueError("No collections specified")
             collections = self.collections
 
-        if not self.parallel:
+        if not self.is_parallel():
            return self.get_uris_serial(data_ids, dataset_types, collections)
 
         # Divide the data_ids into chunks to be processed in parallel
@@ -302,7 +306,7 @@
 
         # Use a ProcessPoolExecutor to fetch URIs in parallel
         uris = []
-        with ProcessPoolExecutor(max_workers=multiprocessing.cpu_count()) as executor:
+        with ProcessPoolExecutor(max_workers=self.max_workers) as executor:
             futures = [
                 executor.submit(
                     self.get_uris_serial,
diff --git a/tests/test_region_search.py b/tests/test_region_search.py
index fbfb758df..f8197b503 100644
--- a/tests/test_region_search.py
+++ b/tests/test_region_search.py
@@ -43,16 +43,13 @@ def setUp(self):
             self.default_collections,
             self.default_datasetTypes,
             butler=self.butler,
-            parallel=False,
         )
 
     def test_init(self):
         """
         Test that the region search object can be initialized.
         """
-        rs = region_search.RegionSearch(
-            MOCK_REPO_PATH, [], [], butler=self.butler, parallel=False, fetch_data=False
-        )
+        rs = region_search.RegionSearch(MOCK_REPO_PATH, [], [], butler=self.butler, fetch_data=False)
         self.assertTrue(rs is not None)
         self.assertEqual(0, len(rs.vdr_data))  # No data should be fetched
 
@@ -190,7 +187,7 @@ def func(repo_path):
             self.default_collections,
             self.default_datasetTypes,
             butler=self.butler,
-            parallel=False,  # TODO Turn on after fixing pickle issue for mocked objects
+            # TODO Set max_workers to enable parallel processing after fixing the pickle issue for mocked objects
        )
 
         uris = parallel_rs.get_uris(data_ids)