Commit

Revert to ProcessPool
tongyx361 committed Sep 15, 2024
1 parent 979142a commit a119c6a
Showing 2 changed files with 29 additions and 89 deletions.
4 changes: 2 additions & 2 deletions symeval/_modidx.py

@@ -65,6 +65,6 @@
     'symeval.core.norm_str2bool': ('core.html#norm_str2bool', 'symeval/core.py'),
     'symeval.core.norm_str2weekday': ('core.html#norm_str2weekday', 'symeval/core.py'),
     'symeval.core.parse': ('core.html#parse', 'symeval/core.py'),
-    'symeval.core.process_tasks': ('core.html#process_tasks', 'symeval/core.py'),
     'symeval.core.rm_latex_env': ('core.html#rm_latex_env', 'symeval/core.py'),
-    'symeval.core.run_with_timeout': ('core.html#run_with_timeout', 'symeval/core.py')}}}
+    'symeval.core.run_with_timeout': ('core.html#run_with_timeout', 'symeval/core.py'),
+    'symeval.core.task_wrapper': ('core.html#task_wrapper', 'symeval/core.py')}}}
114 changes: 27 additions & 87 deletions symeval/core.py
@@ -282,9 +282,9 @@ def batch_extract_ans(
         answers: List[str] = batch_exec(
             self.extract_ans,
             [{"resp_str": resp} for resp in resps],
-            self.n_procs,
-            self.timeout,
-            self.use_tqdm,
+            n_procs=self.n_procs,
+            timeout=self.timeout,
+            use_tqdm=self.use_tqdm,
             desc="Extracting",
             def_val="",
         )
@@ -306,9 +306,9 @@ def batch_eq(
                 {"ref_ans": ref_ans, "pred": pred}
                 for ref_ans, pred in uniq_ref_pref_ans_pairs
             ],
-            self.n_procs,
-            self.timeout,
-            self.use_tqdm,
+            n_procs=self.n_procs,
+            timeout=self.timeout,
+            use_tqdm=self.use_tqdm,
             desc="Judging",
             def_val=False,
         )
@@ -377,103 +377,43 @@ def batch_exec(
     use_tqdm: bool = True,
     desc: str = "Processing",
     def_val: Any = None,
+    max_tasks_per_proc: int = 1024,
 ) -> List[Any]:
-    """Execute a function in batch using multiprocessing with efficient per-task timeout."""
+    """Execute a function in batch using `ProcessPool` with efficient per-task timeout."""
     n_samples: int = len(kwargs_list)
     n_procs = min(n_procs, n_samples)
 
-    # Shuffle the tasks
-    shuffled_tasks = list(enumerate(kwargs_list))
-    random.shuffle(shuffled_tasks)
-
-    # Split tasks into chunks for each process
-    chunk_size = -(-n_samples // n_procs)  # Ceiling division
-    task_chunks = [
-        shuffled_tasks[i : i + chunk_size] for i in range(0, n_samples, chunk_size)
-    ]
-
     results: List[Any] = [def_val] * n_samples
-    pbar: Optional[tqdm] = tqdm(total=n_samples, desc=desc) if use_tqdm else None
 
-    with ProcessPool(max_workers=n_procs) as pool:
+    with ProcessPool(max_workers=n_procs, max_tasks=max_tasks_per_proc) as pool:
         future = pool.map(
-            process_tasks,
-            task_chunks,
-            [func] * len(task_chunks),
-            [timeout] * len(task_chunks),
-            [def_val] * len(task_chunks),
+            task_wrapper, [func] * len(kwargs_list), kwargs_list, timeout=timeout
         )
         iterator = future.result()
+        # NOTE: This keeps the order
 
-        for chunk_results in iterator:
-            for idx, result in chunk_results:
-                results[idx] = result
-                if pbar:
-                    pbar.update(1)
-
-    if pbar:
-        pbar.close()
+        pbar: Optional[tqdm] = tqdm(total=n_samples, desc=desc) if use_tqdm else None
 
-    return results
-
-
-def process_tasks(
-    task_list: List[T_Tuple[int, T_Dict[str, Any]]],
-    func: Callable[..., Any],
-    timeout: int,
-    def_val: Any,
-) -> List[T_Tuple[int, Any]]:
-    pid: int = os.getpid()
-    print(f"Process {pid} is processing {len(task_list)} tasks")
-    results: List[T_Tuple[int, Any]] = []
-
-    def task_wrapper(idx: int, kwargs: T_Dict[str, Any]) -> T_Tuple[int, Any]:
-        try:
-            result = func(**kwargs)
-            return idx, result
-        except Exception:
-            return idx, def_val
-
-    with ThreadPool() as pool:
-        futures = [
-            pool.schedule(task_wrapper, args=(idx, kwargs)) for idx, kwargs in task_list
-        ]
-
-        for future in futures:
+        idx: int = 0
+        while idx < n_samples:
             try:
-                idx, result = future.result(timeout=timeout)
-            except TimeoutError:
-                idx = task_list[len(results)][0]
-                result = def_val
+                result: Any = next(iterator)
+                results[idx] = result
+            except StopIteration:
+                break
             except Exception:
-                idx = task_list[len(results)][0]
-                result = def_val
-            results.append((idx, result))
+                pass
+            if pbar is not None:
+                pbar.update(1)
+            idx += 1
+
+        if pbar:
+            pbar.close()
 
-    print(f"Process {pid} finished processing {len(task_list)} tasks")
     return results
-
-
-# def process_tasks(
-#     task_list: List[T_Tuple[int, T_Dict[str, Any]]],
-#     func: Callable[..., Any],
-#     timeout: int,
-#     def_val: Any,
-# ) -> List[T_Tuple[int, Any]]:
-#     pid: int = os.getpid()
-#     print(f"Process {pid} is processing {len(task_list)} tasks")
-#     results = []
-#     for idx, kwargs in task_list:
-#         result: Any
-#         try:
-#             result = run_with_timeout(func, kwargs, timeout)
-#         except TimeoutError:
-#             result = def_val
-#         except Exception:
-#             result = def_val
-#         results.append((idx, result))
-#     print(f"Process {pid} finished processing {len(task_list)} tasks")
-#     return results
+def task_wrapper(func: Callable[..., Any], kwargs: T_Dict[str, Any]) -> Any:
+    return func(**kwargs)
 
-
 def run_with_timeout(

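For reference, below is a minimal, self-contained sketch (not part of the symeval codebase; `slow_square`, the inputs, and the constants are made up for illustration) of the `pebble.ProcessPool.map` pattern this commit reverts to: the pool enforces one timeout per task, `future.result()` yields results in submission order, and a timed-out or failed task is replaced with a default value while iteration continues.

# Minimal sketch of the pebble `ProcessPool.map` pattern used above.
# `slow_square`, the inputs, and the constants are illustrative only.
import time
from concurrent.futures import TimeoutError

from pebble import ProcessPool


def slow_square(x: int) -> int:
    time.sleep(x)  # simulate work; sleeping past the timeout triggers a TimeoutError
    return x * x


if __name__ == "__main__":
    inputs = [0, 1, 5, 2]
    def_val = -1  # fallback for tasks that time out or raise
    results = [def_val] * len(inputs)

    # `timeout` applies per task; `max_tasks` recycles a worker after 1024 tasks.
    with ProcessPool(max_workers=2, max_tasks=1024) as pool:
        future = pool.map(slow_square, inputs, timeout=3)
        iterator = future.result()  # yields results in input order
        idx = 0
        while idx < len(inputs):
            try:
                results[idx] = next(iterator)
            except StopIteration:
                break
            except TimeoutError:
                pass  # task exceeded the per-task timeout; keep def_val
            except Exception:
                pass  # task raised; keep def_val
            idx += 1

    print(results)  # expected: [0, 1, -1, 4]

Recycling workers via `max_tasks` (exposed here as the new `max_tasks_per_proc` argument) restarts a worker process after it has handled that many tasks, which keeps long-running workers from accumulating memory.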