diff --git a/symeval/_modidx.py b/symeval/_modidx.py
index 4ffc97e..3f54d3d 100644
--- a/symeval/_modidx.py
+++ b/symeval/_modidx.py
@@ -65,6 +65,6 @@
                 'symeval.core.norm_str2bool': ('core.html#norm_str2bool', 'symeval/core.py'),
                 'symeval.core.norm_str2weekday': ('core.html#norm_str2weekday', 'symeval/core.py'),
                 'symeval.core.parse': ('core.html#parse', 'symeval/core.py'),
-                'symeval.core.process_tasks': ('core.html#process_tasks', 'symeval/core.py'),
                 'symeval.core.rm_latex_env': ('core.html#rm_latex_env', 'symeval/core.py'),
-                'symeval.core.run_with_timeout': ('core.html#run_with_timeout', 'symeval/core.py')}}}
+                'symeval.core.run_with_timeout': ('core.html#run_with_timeout', 'symeval/core.py'),
+                'symeval.core.task_wrapper': ('core.html#task_wrapper', 'symeval/core.py')}}}
diff --git a/symeval/core.py b/symeval/core.py
index 4d02e95..a31463e 100644
--- a/symeval/core.py
+++ b/symeval/core.py
@@ -282,9 +282,9 @@ def batch_extract_ans(
         answers: List[str] = batch_exec(
             self.extract_ans,
             [{"resp_str": resp} for resp in resps],
-            self.n_procs,
-            self.timeout,
-            self.use_tqdm,
+            n_procs=self.n_procs,
+            timeout=self.timeout,
+            use_tqdm=self.use_tqdm,
             desc="Extracting",
             def_val="",
         )
@@ -306,9 +306,9 @@ def batch_eq(
                 {"ref_ans": ref_ans, "pred": pred}
                 for ref_ans, pred in uniq_ref_pref_ans_pairs
             ],
-            self.n_procs,
-            self.timeout,
-            self.use_tqdm,
+            n_procs=self.n_procs,
+            timeout=self.timeout,
+            use_tqdm=self.use_tqdm,
             desc="Judging",
             def_val=False,
         )
@@ -377,103 +377,43 @@ def batch_exec(
     use_tqdm: bool = True,
     desc: str = "Processing",
     def_val: Any = None,
+    max_tasks_per_proc: int = 1024,
 ) -> List[Any]:
-    """Execute a function in batch using multiprocessing with efficient per-task timeout."""
+    """Execute a function in batch using `ProcessPool` with efficient per-task timeout."""
     n_samples: int = len(kwargs_list)
     n_procs = min(n_procs, n_samples)
 
-    # Shuffle the tasks
-    shuffled_tasks = list(enumerate(kwargs_list))
-    random.shuffle(shuffled_tasks)
-
-    # Split tasks into chunks for each process
-    chunk_size = -(-n_samples // n_procs)  # Ceiling division
-    task_chunks = [
-        shuffled_tasks[i : i + chunk_size] for i in range(0, n_samples, chunk_size)
-    ]
-
     results: List[Any] = [def_val] * n_samples
-    pbar: Optional[tqdm] = tqdm(total=n_samples, desc=desc) if use_tqdm else None
-
-    with ProcessPool(max_workers=n_procs) as pool:
+    with ProcessPool(max_workers=n_procs, max_tasks=max_tasks_per_proc) as pool:
         future = pool.map(
-            process_tasks,
-            task_chunks,
-            [func] * len(task_chunks),
-            [timeout] * len(task_chunks),
-            [def_val] * len(task_chunks),
+            task_wrapper, [func] * len(kwargs_list), kwargs_list, timeout=timeout
        )
         iterator = future.result()
+        # NOTE: This keeps the order
 
-        for chunk_results in iterator:
-            for idx, result in chunk_results:
-                results[idx] = result
-                if pbar:
-                    pbar.update(1)
-
-    if pbar:
-        pbar.close()
+        pbar: Optional[tqdm] = tqdm(total=n_samples, desc=desc) if use_tqdm else None
 
-    return results
-
-
-def process_tasks(
-    task_list: List[T_Tuple[int, T_Dict[str, Any]]],
-    func: Callable[..., Any],
-    timeout: int,
-    def_val: Any,
-) -> List[T_Tuple[int, Any]]:
-    pid: int = os.getpid()
-    print(f"Process {pid} is processing {len(task_list)} tasks")
-    results: List[T_Tuple[int, Any]] = []
-
-    def task_wrapper(idx: int, kwargs: T_Dict[str, Any]) -> T_Tuple[int, Any]:
-        try:
-            result = func(**kwargs)
-            return idx, result
-        except Exception:
-            return idx, def_val
-
-    with ThreadPool() as pool:
-        futures = [
-            pool.schedule(task_wrapper, args=(idx, kwargs)) for idx, kwargs in task_list
-        ]
-
-        for future in futures:
+        idx: int = 0
+        while idx < n_samples:
             try:
-                idx, result = future.result(timeout=timeout)
-            except TimeoutError:
-                idx = task_list[len(results)][0]
-                result = def_val
+                result: Any = next(iterator)
+                results[idx] = result
+            except StopIteration:
+                break
             except Exception:
-                idx = task_list[len(results)][0]
-                result = def_val
-            results.append((idx, result))
+                pass
+            if pbar is not None:
+                pbar.update(1)
+            idx += 1
+
+    if pbar:
+        pbar.close()
 
-    print(f"Process {pid} finished processing {len(task_list)} tasks")
     return results
 
 
-# def process_tasks(
-#     task_list: List[T_Tuple[int, T_Dict[str, Any]]],
-#     func: Callable[..., Any],
-#     timeout: int,
-#     def_val: Any,
-# ) -> List[T_Tuple[int, Any]]:
-#     pid: int = os.getpid()
-#     print(f"Process {pid} is processing {len(task_list)} tasks")
-#     results = []
-#     for idx, kwargs in task_list:
-#         result: Any
-#         try:
-#             result = run_with_timeout(func, kwargs, timeout)
-#         except TimeoutError:
-#             result = def_val
-#         except Exception:
-#             result = def_val
-#         results.append((idx, result))
-#     print(f"Process {pid} finished processing {len(task_list)} tasks")
-#     return results
+def task_wrapper(func: Callable[..., Any], kwargs: T_Dict[str, Any]) -> Any:
+    return func(**kwargs)
 
 
 def run_with_timeout(