Skip to content

Commit

Permalink
Revert to ProcessPool
Browse files Browse the repository at this point in the history
  • Loading branch information
tongyx361 committed Sep 15, 2024
1 parent 979142a commit cc37381
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 98 deletions.
4 changes: 2 additions & 2 deletions symeval/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,6 @@
'symeval.core.norm_str2bool': ('core.html#norm_str2bool', 'symeval/core.py'),
'symeval.core.norm_str2weekday': ('core.html#norm_str2weekday', 'symeval/core.py'),
'symeval.core.parse': ('core.html#parse', 'symeval/core.py'),
'symeval.core.process_tasks': ('core.html#process_tasks', 'symeval/core.py'),
'symeval.core.rm_latex_env': ('core.html#rm_latex_env', 'symeval/core.py'),
'symeval.core.run_with_timeout': ('core.html#run_with_timeout', 'symeval/core.py')}}}
'symeval.core.run_with_timeout': ('core.html#run_with_timeout', 'symeval/core.py'),
'symeval.core.task_wrapper': ('core.html#task_wrapper', 'symeval/core.py')}}}
132 changes: 36 additions & 96 deletions symeval/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def get_maj_ans_from_votes(self, ans_votes: T_Counter[str]) -> str:


DEF_TIMEOUT: int = 5
DEF_N_PROC: int = 2
DEF_N_WORKER: int = 2


class EvaluatorBatchBase(EvaluatorBase):
Expand All @@ -245,7 +245,7 @@ class EvaluatorBatchBase(EvaluatorBase):
Whether to extract answers strictly. If `False`, speculate the answer from the last number if needed.
timeout : int, default: DEF_TIMEOUT:=5
The timeout for each evaluation in seconds.
n_procs: int, default: 2
n_workers: int, default: 2
The number of processes to use for multiprocessing.
use_tqdm: bool, default: True
Whether to use tqdm for progress bar.
Expand All @@ -255,13 +255,13 @@ def __init__(
self,
strict_extract: bool = True,
timeout: int = DEF_TIMEOUT,
n_procs: int = DEF_N_PROC,
n_workers: int = DEF_N_WORKER,
use_tqdm: bool = True,
):
EvaluatorBase.__init__(self)
self.strict_extract: bool = strict_extract
self.timeout: int = timeout
self.n_procs: int = n_procs
self.n_workers: int = n_workers
self.use_tqdm: bool = use_tqdm

def batch_eval(
Expand All @@ -282,9 +282,9 @@ def batch_extract_ans(
answers: List[str] = batch_exec(
self.extract_ans,
[{"resp_str": resp} for resp in resps],
self.n_procs,
self.timeout,
self.use_tqdm,
n_workers=self.n_workers,
timeout=self.timeout,
use_tqdm=self.use_tqdm,
desc="Extracting",
def_val="",
)
Expand All @@ -306,9 +306,9 @@ def batch_eq(
{"ref_ans": ref_ans, "pred": pred}
for ref_ans, pred in uniq_ref_pref_ans_pairs
],
self.n_procs,
self.timeout,
self.use_tqdm,
n_workers=self.n_workers,
timeout=self.timeout,
use_tqdm=self.use_tqdm,
desc="Judging",
def_val=False,
)
Expand Down Expand Up @@ -372,108 +372,48 @@ def batch_get_maj_answers(
def batch_exec(
func: Callable[..., Any],
kwargs_list: List[T_Dict[str, Any]],
n_procs: int = DEF_N_PROC,
n_workers: int = DEF_N_WORKER,
timeout: int = DEF_TIMEOUT,
use_tqdm: bool = True,
desc: str = "Processing",
def_val: Any = None,
max_tasks_per_worker: int = 1024,
) -> List[Any]:
"""Execute a function in batch using multiprocessing with efficient per-task timeout."""
"""Execute a function in batch using `ProcessPool` with efficient per-task timeout."""
n_samples: int = len(kwargs_list)
n_procs = min(n_procs, n_samples)

# Shuffle the tasks
shuffled_tasks = list(enumerate(kwargs_list))
random.shuffle(shuffled_tasks)

# Split tasks into chunks for each process
chunk_size = -(-n_samples // n_procs) # Ceiling division
task_chunks = [
shuffled_tasks[i : i + chunk_size] for i in range(0, n_samples, chunk_size)
]
n_workers = min(n_workers, n_samples)

results: List[Any] = [def_val] * n_samples
pbar: Optional[tqdm] = tqdm(total=n_samples, desc=desc) if use_tqdm else None

with ProcessPool(max_workers=n_procs) as pool:
with ProcessPool(max_workers=n_workers, max_tasks=max_tasks_per_worker) as pool:
future = pool.map(
process_tasks,
task_chunks,
[func] * len(task_chunks),
[timeout] * len(task_chunks),
[def_val] * len(task_chunks),
task_wrapper, [func] * len(kwargs_list), kwargs_list, timeout=timeout
)
iterator = future.result()
# NOTE: This keeps the order

for chunk_results in iterator:
for idx, result in chunk_results:
results[idx] = result
if pbar:
pbar.update(1)

if pbar:
pbar.close()
pbar: Optional[tqdm] = tqdm(total=n_samples, desc=desc) if use_tqdm else None

return results


def process_tasks(
task_list: List[T_Tuple[int, T_Dict[str, Any]]],
func: Callable[..., Any],
timeout: int,
def_val: Any,
) -> List[T_Tuple[int, Any]]:
pid: int = os.getpid()
print(f"Process {pid} is processing {len(task_list)} tasks")
results: List[T_Tuple[int, Any]] = []

def task_wrapper(idx: int, kwargs: T_Dict[str, Any]) -> T_Tuple[int, Any]:
try:
result = func(**kwargs)
return idx, result
except Exception:
return idx, def_val

with ThreadPool() as pool:
futures = [
pool.schedule(task_wrapper, args=(idx, kwargs)) for idx, kwargs in task_list
]

for future in futures:
idx: int = 0
while idx < n_samples:
try:
idx, result = future.result(timeout=timeout)
except TimeoutError:
idx = task_list[len(results)][0]
result = def_val
result: Any = next(iterator)
results[idx] = result
except StopIteration:
break
except Exception:
idx = task_list[len(results)][0]
result = def_val
results.append((idx, result))
pass
if pbar is not None:
pbar.update(1)
idx += 1

if pbar:
pbar.close()

print(f"Process {pid} finished processing {len(task_list)} tasks")
return results


# def process_tasks(
# task_list: List[T_Tuple[int, T_Dict[str, Any]]],
# func: Callable[..., Any],
# timeout: int,
# def_val: Any,
# ) -> List[T_Tuple[int, Any]]:
# pid: int = os.getpid()
# print(f"Process {pid} is processing {len(task_list)} tasks")
# results = []
# for idx, kwargs in task_list:
# result: Any
# try:
# result = run_with_timeout(func, kwargs, timeout)
# except TimeoutError:
# result = def_val
# except Exception:
# result = def_val
# results.append((idx, result))
# print(f"Process {pid} finished processing {len(task_list)} tasks")
# return results
def task_wrapper(func: Callable[..., Any], kwargs: T_Dict[str, Any]) -> Any:
return func(**kwargs)


def run_with_timeout(
Expand Down Expand Up @@ -1633,7 +1573,7 @@ class EvaluatorMathBatch(EvaluatorMath, EvaluatorBatchBase):
Only allowing ASCII characters
timeout : int, default: DEF_TIMEOUT:=5
The timeout for each evaluation.
n_procs: int, default: 2
n_workers: int, default: 2
The number of processes to use for multiprocessing.
use_tqdm: bool, default: True
Whether to use tqdm for progress bar.
Expand All @@ -1648,7 +1588,7 @@ def __init__(
percent_rel_tol: float = DEF_PERCENT_REL_TOL,
ascii_only: bool = True,
timeout: int = DEF_TIMEOUT,
n_procs: int = DEF_N_PROC,
n_workers: int = DEF_N_WORKER,
use_tqdm: bool = True,
):

Expand All @@ -1661,6 +1601,6 @@ def __init__(
ascii_only=ascii_only,
)
EvaluatorBatchBase.__init__(
self, timeout=timeout, n_procs=n_procs, use_tqdm=use_tqdm
self, timeout=timeout, n_workers=n_workers, use_tqdm=use_tqdm
)
self.strict_extract = strict_extract

0 comments on commit cc37381

Please sign in to comment.