diff --git a/vectordb_bench/backend/runner/mp_runner.py b/vectordb_bench/backend/runner/mp_runner.py index 1013f0ff1..a380da351 100644 --- a/vectordb_bench/backend/runner/mp_runner.py +++ b/vectordb_bench/backend/runner/mp_runner.py @@ -40,7 +40,12 @@ def __init__( self.test_data = utils.SharedNumpyArray(test_data) log.debug(f"test dataset columns: {len(test_data)}") - def search(self, test_np: utils.SharedNumpyArray) -> tuple[int, float]: + def search(self, test_np: utils.SharedNumpyArray, q: mp.Queue, cond: mp.Condition) -> tuple[int, float]: + # sync all process + q.put(1) + with cond: + cond.wait() + with self.db.init(): test_data = test_np.read().tolist() num, idx = len(test_data), 0 @@ -77,7 +82,7 @@ def search(self, test_np: utils.SharedNumpyArray) -> tuple[int, float]: @staticmethod def get_mp_context(): - mp_start_method = "forkserver" if "forkserver" in mp.get_all_start_methods() else "spawn" + mp_start_method = "spawn" log.debug(f"MultiProcessingSearchRunner get multiprocessing start method: {mp_start_method}") return mp.get_context(mp_start_method) @@ -85,15 +90,26 @@ def _run_all_concurrencies_mem_efficient(self) -> float: max_qps = 0 try: for conc in self.concurrencies: - with concurrent.futures.ProcessPoolExecutor(mp_context=self.get_mp_context(), max_workers=conc) as executor: - start = time.perf_counter() - log.info(f"start search {self.duration}s in concurrency {conc}, filters: {self.filters}") - future_iter = executor.map(self.search, [self.test_data for i in range(conc)]) - all_count = sum([r[0] for r in future_iter]) - - cost = time.perf_counter() - start - qps = round(all_count / cost, 4) - log.info(f"end search in concurrency {conc}: dur={cost}s, total_count={all_count}, qps={qps}") + with mp.Manager() as m: + q, cond = m.Queue(), m.Condition() + with concurrent.futures.ProcessPoolExecutor(mp_context=self.get_mp_context(), max_workers=conc) as executor: + log.info(f"Start search {self.duration}s in concurrency {conc}, filters: {self.filters}") + future_iter = [executor.submit(self.search, self.test_data, q, cond) for i in range(conc)] + # Sync all processes + while q.qsize() < conc: + sleep_t = conc if conc < 10 else 10 + time.sleep(sleep_t) + + with cond: + cond.notify_all() + log.info(f"Syncing all process and start concurency search, concurency={conc}") + + start = time.perf_counter() + all_count = sum([r.result()[0] for r in future_iter]) + cost = time.perf_counter() - start + + qps = round(all_count / cost, 4) + log.info(f"End search in concurrency {conc}: dur={cost}s, total_count={all_count}, qps={qps}") if qps > max_qps: max_qps = qps