Skip to content

Commit

Permalink
handle processpool breaking
Browse files Browse the repository at this point in the history
  • Loading branch information
guipenedo committed Jan 2, 2025
1 parent 6e9af63 commit e1c7cec
Showing 1 changed file with 28 additions and 8 deletions.
36 changes: 28 additions & 8 deletions src/datatrove/pipeline/extractors/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from abc import abstractmethod
from concurrent.futures import ProcessPoolExecutor
from concurrent.futures.process import BrokenProcessPool

from datatrove.data import DocumentsPipeline
from datatrove.pipeline.base import PipelineStep
Expand Down Expand Up @@ -46,25 +47,42 @@ def run(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1) -> Do
Returns:
"""
with ProcessPoolExecutor(
max_workers=1
) as executor: # Single process. We use ProcessPool instead of ThreadPool
# to be able to kill timed out tasks, lest them continue running for minutes and OOM
executor = ProcessPoolExecutor(max_workers=1)
try:
for doc in data:
self.stat_update(StatHints.total)
with self.track_time():
future = executor.submit(self.extract, doc.text)
# If submit fails, the pool was already broken from previous task
try:
future = executor.submit(self.extract, doc.text)
except BrokenProcessPool:
logger.warning(
"Found broken process pool, creating new executor and retrying current document"
)
executor.shutdown(wait=False)
executor = ProcessPoolExecutor(max_workers=1)
self.stat_update("broken_pool")
try:
future = executor.submit(self.extract, doc.text)
except BrokenProcessPool:
logger.error("New pool also broke, skipping document")
continue

try:
doc.text = future.result(timeout=self.timeout)
self.stat_update("extracted")
except TimeoutError:
future.cancel() # Kill the timed-out process
future.cancel()
logger.warning("⏰ Timeout while cleaning record text. Skipping record.")
self.stat_update("timeout")
continue
except Exception as e:
future.cancel() # Good practice to cancel on other errors too
future.cancel()
self.stat_update("clean_error")
if not self._warned_error:
logger.warning(
f'❌ Error "{e}" while cleaning record text. Skipping record. This message will only appear once.'
f'❌ Error "{e}" while cleaning record text. Skipping record. This message will only '
f"appear once."
)
self._warned_error = True
continue
Expand All @@ -75,3 +93,5 @@ def run(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1) -> Do
yield doc
else:
self.stat_update(StatHints.dropped)
finally:
executor.shutdown(wait=False, cancel_futures=True)

0 comments on commit e1c7cec

Please sign in to comment.