bugfixes with logging + save a dump of the executor in json format
guipenedo committed Nov 10, 2023
1 parent 30af237 commit dc111d8
Showing 9 changed files with 105 additions and 66 deletions.
32 changes: 14 additions & 18 deletions examples/other sources/tokenize_the_pile.py
@@ -1,36 +1,32 @@
import os.path

from datatrove.executor.base import PipelineExecutor
from datatrove.executor.slurm import SlurmPipelineExecutor
from datatrove.io import LocalOutputDataFolder, S3InputDataFolder
from datatrove.executor.local import LocalPipelineExecutor
from datatrove.io import LocalOutputDataFolder, S3InputDataFolder, S3OutputDataFolder
from datatrove.pipeline.readers import JsonlReader
from datatrove.pipeline.tokens.tokenizer import DocumentTokenizer


def format_adapter(d: dict, path: str, li: int):
return {
"content": d["text"],
"data_id": f"{os.path.splitext(os.path.basename(path))[0]}_{li}",
"metadata": d["meta"],
}


pipeline = [
JsonlReader(
S3InputDataFolder("s3://bigcode-experiments/the-pile-sharded/", stream=True),
gzip=False,
adapter=format_adapter,
content_key="text",
limit=100,
),
DocumentTokenizer(
LocalOutputDataFolder(path="/fsx/guilherme/the-pile/tokenized"), save_filename="the_pile_tokenized"
LocalOutputDataFolder(path="/home/gui/hf_dev/datatrove/examples_test/piletokenized-test/tokenized"),
save_filename="the_pile_tokenized",
),
]

executor: PipelineExecutor = SlurmPipelineExecutor(
executor: PipelineExecutor = LocalPipelineExecutor(
pipeline=pipeline,
tasks=20,
time="12:00:00",
partition="dev-cluster",
logging_dir="/fsx/guilherme/logs/tokenize_the_pile",
workers=20,
skip_completed=False,
logging_dir=S3OutputDataFolder(
"s3://extreme-scale-dp-temp/logs/tests/piletokenized/tokenized",
local_path="/home/gui/hf_dev/datatrove/examples_test/piletokenized-test/logs",
cleanup=False,
),
)
executor.run()
57 changes: 45 additions & 12 deletions src/datatrove/executor/base.py
@@ -1,3 +1,5 @@
import dataclasses
import json
from abc import ABC, abstractmethod
from collections import deque
from collections.abc import Sequence
@@ -7,20 +9,26 @@

from datatrove.io import BaseOutputDataFolder, LocalOutputDataFolder
from datatrove.pipeline.base import PipelineStep
from datatrove.utils.logging import add_task_logger, get_random_str, get_timestamp
from datatrove.utils.logging import add_task_logger, close_task_logger, get_random_str, get_timestamp
from datatrove.utils.stats import PipelineStats


class PipelineExecutor(ABC):
@abstractmethod
def __init__(self, pipeline: list[PipelineStep | Callable], logging_dir: BaseOutputDataFolder = None):
def __init__(
self,
pipeline: list[PipelineStep | Callable],
logging_dir: BaseOutputDataFolder = None,
skip_completed: bool = True,
):
self.pipeline: list[PipelineStep | Callable] = pipeline
self.logging_dir = (
logging_dir if logging_dir else LocalOutputDataFolder(f"logs/{get_timestamp()}_{get_random_str()}")
)
self.skip_completed = skip_completed

pipeline = "\n".join([pipe.__repr__() if callable(pipe) else "Sequence..." for pipe in self.pipeline])
print(f"--- 🛠️PIPELINE 🛠\n{pipeline}")
# pipeline = "\n".join([pipe.__repr__() if callable(pipe) else "Sequence..." for pipe in self.pipeline])
# print(f"--- 🛠️PIPELINE 🛠\n{pipeline}")

@abstractmethod
def run(self):
@@ -36,8 +44,8 @@ def _run_for_rank(self, rank: int, local_rank: int = 0) -> PipelineStats:
logger.info(f"Skipping {rank=} as it has already been completed.")
return PipelineStats() # todo: fetch the original stats file (?)
add_task_logger(self.logging_dir, rank, local_rank)
# pipe data from one step to the next
try:
# pipe data from one step to the next
pipelined_data = None
for pipeline_step in self.pipeline:
if callable(pipeline_step):
@@ -48,19 +56,44 @@ def _run_for_rank(self, rank: int, local_rank: int = 0) -> PipelineStats:
raise ValueError
if pipelined_data:
deque(pipelined_data, maxlen=0)

logger.info(f"Processing done for {rank=}")
# stats
stats = PipelineStats(
[pipeline_step.stats for pipeline_step in self.pipeline if isinstance(pipeline_step, PipelineStep)]
)
stats.save_to_disk(self.logging_dir.open(f"stats/{rank:05d}.json"))
# completed
self.mark_rank_as_completed(rank)
except Exception as e:
logger.exception(e)
raise e
logger.info(f"Processing done for {rank=}")
stats = PipelineStats(
[pipeline_step.stats for pipeline_step in self.pipeline if isinstance(pipeline_step, PipelineStep)]
)
stats.save_to_disk(self.logging_dir.open(f"stats/{rank:05d}.json"))
self.mark_rank_as_completed(rank)
finally:
close_task_logger(self.logging_dir, rank)
return stats

def is_rank_completed(self, rank: int):
return self.logging_dir.to_input_folder().file_exists(f"completions/{rank:05d}")
return self.skip_completed and self.logging_dir.to_input_folder().file_exists(f"completions/{rank:05d}")

def mark_rank_as_completed(self, rank: int):
self.logging_dir.open(f"completions/{rank:05d}").close()

def to_json(self, indent=4):
data = self.__dict__
data["pipeline"] = [{a: b for a, b in x.__dict__.items() if a != "stats"} for x in data["pipeline"]]
return json.dumps(data, indent=indent)

def save_executor_as_json(self, indent: int = 4):
with self.logging_dir.open("executor.json") as f:
json.dump(self, f, cls=ExecutorJSONEncoder, indent=indent)


class ExecutorJSONEncoder(json.JSONEncoder):
def default(self, o):
if dataclasses.is_dataclass(o):
return dataclasses.asdict(o)
if isinstance(o, PipelineExecutor):
return o.__dict__
if isinstance(o, PipelineStep):
return {a: b for a, b in o.__dict__.items() if a != "stats"}
return str(o)
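
Aside (not part of the diff): a minimal sketch of how the new executor dump added above could be exercised, assuming the LocalPipelineExecutor changes from this commit; the passthrough step and the logging path are hypothetical, and a plain callable is used as the pipeline step since _run_for_rank accepts those.

from datatrove.executor.local import LocalPipelineExecutor
from datatrove.io import LocalOutputDataFolder


def passthrough(data, rank, world_size):
    # a bare callable is a valid pipeline step: the executor calls it as step(data, rank, world_size)
    yield from data or []


executor = LocalPipelineExecutor(
    pipeline=[passthrough],
    tasks=1,
    workers=1,
    skip_completed=False,  # new flag from this commit: do not skip ranks with a completion marker
    logging_dir=LocalOutputDataFolder(path="logs/executor-dump-test"),  # hypothetical path
)
executor.save_executor_as_json()  # writes logging_dir/executor.json via ExecutorJSONEncoder
executor.logging_dir.close()      # flush and close the dump file
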
29 changes: 18 additions & 11 deletions src/datatrove/executor/local.py
@@ -9,13 +9,14 @@
from datatrove.pipeline.base import PipelineStep


download_semaphore, upload_semaphore = None, None
download_semaphore, upload_semaphore, ranks_queue = None, None, None


def init_pool_processes(dl_sem, up_sem):
global download_semaphore, upload_semaphore
def init_pool_processes(dl_sem, up_sem, ranks_q):
global download_semaphore, upload_semaphore, ranks_queue
download_semaphore = dl_sem
upload_semaphore = up_sem
ranks_queue = ranks_q


class LocalPipelineExecutor(PipelineExecutor):
@@ -27,9 +28,9 @@ def __init__(
max_concurrent_uploads: int = 20,
max_concurrent_downloads: int = 50,
logging_dir: BaseOutputDataFolder = None,
skip_completed: bool = True,
):
super().__init__(pipeline, logging_dir)
self.local_ranks = None
super().__init__(pipeline, logging_dir, skip_completed)
self.tasks = tasks
self.workers = workers if workers != -1 else tasks
self.max_concurrent_uploads = max_concurrent_uploads
@@ -40,16 +41,17 @@ def _run_for_rank(self, rank: int, local_rank: int = -1):
for pipeline_step in self.pipeline:
if isinstance(pipeline_step, PipelineStep):
pipeline_step.set_up_dl_locks(download_semaphore, upload_semaphore)
local_rank = self.local_ranks.get()
local_rank = ranks_queue.get()
try:
return super()._run_for_rank(rank, local_rank)
finally:
self.local_ranks.put(local_rank) # free up used rank
ranks_queue.put(local_rank) # free up used rank

def run(self):
self.local_ranks = Queue()
self.save_executor_as_json()
ranks_q = Queue()
for i in range(self.workers):
self.local_ranks.put(i)
ranks_q.put(i)

if self.workers == 1:
pipeline = self.pipeline
@@ -61,9 +63,14 @@ def run(self):
dl_sem = Semaphore(self.max_concurrent_downloads)
up_sem = Semaphore(self.max_concurrent_uploads)

with multiprocess.Pool(self.workers, initializer=init_pool_processes, initargs=(dl_sem, up_sem)) as pool:
with multiprocess.Pool(
self.workers, initializer=init_pool_processes, initargs=(dl_sem, up_sem, ranks_q)
) as pool:
stats = list(pool.map(self._run_for_rank, range(self.tasks)))
return sum(stats)
stats = sum(stats)
stats.save_to_disk(self.logging_dir.open("stats.json"))
self.logging_dir.close()
return stats

@property
def world_size(self):
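
Aside (not part of the diff): the semaphores and the ranks queue above are handed to workers through the pool initializer, presumably because multiprocessing synchronization primitives cannot be pickled into individual tasks. A standalone sketch of that pattern follows; init_worker and work are hypothetical names, not datatrove APIs.

import multiprocess
from multiprocess import Queue

ranks_queue = None  # set in each worker process by init_worker


def init_worker(ranks_q):
    global ranks_queue
    ranks_queue = ranks_q


def work(task_id):
    local_rank = ranks_queue.get()  # borrow a local rank slot
    try:
        return f"task {task_id} ran as local rank {local_rank}"
    finally:
        ranks_queue.put(local_rank)  # free up the used rank


if __name__ == "__main__":
    q = Queue()
    for i in range(4):  # one slot per worker
        q.put(i)
    with multiprocess.Pool(4, initializer=init_worker, initargs=(q,)) as pool:
        print(pool.map(work, range(10)))
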
4 changes: 3 additions & 1 deletion src/datatrove/executor/slurm.py
@@ -38,6 +38,7 @@ def __init__(
max_array_size: int = 1001,
depends: SlurmPipelineExecutor | None = None,
logging_dir: BaseOutputDataFolder = None,
skip_completed: bool = True,
):
"""
:param tasks: total number of tasks to run
@@ -51,7 +52,7 @@
"""
if isinstance(logging_dir, S3OutputDataFolder):
logging_dir.cleanup = False # if the files are removed from disk job launch will fail
super().__init__(pipeline, logging_dir)
super().__init__(pipeline, logging_dir, skip_completed)
self.tasks = tasks
self.workers = workers
self.partition = partition
@@ -125,6 +126,7 @@ def launch_job(self):
# pickle
with self.logging_dir.open("executor.pik", "wb") as executor_f:
dill.dump(self, executor_f, fmode=CONTENTS_FMODE)
self.save_executor_as_json()

with self.logging_dir.open("launch_script.slurm") as launchscript_f:
launchscript_f.write(
15 changes: 5 additions & 10 deletions src/datatrove/io/base.py
@@ -11,6 +11,7 @@
import numpy as np
import zstandard
from loguru import logger
from multiprocess.synchronize import SemLock

from datatrove.io.utils.fsspec import valid_fsspec_path

@@ -28,7 +29,7 @@ class BaseInputDataFile:

path: str
relative_path: str
folder: "BaseInputDataFolder" = None
_lock: SemLock | contextlib.nullcontext = field(default_factory=contextlib.nullcontext)

@classmethod
def from_path(cls, path: str, **kwargs):
@@ -133,7 +134,7 @@ class BaseInputDataFolder(ABC):
extension: str | list[str] = None
recursive: bool = True
match_pattern: str = None
_lock = contextlib.nullcontext()
_lock: SemLock | contextlib.nullcontext = field(default_factory=contextlib.nullcontext)

@classmethod
def from_path(cls, path, **kwargs):
@@ -284,9 +285,9 @@ class BaseOutputDataFile(ABC):

path: str
relative_path: str
folder: "BaseOutputDataFolder" = None
_file_handler = None
_mode: str = None
_lock: SemLock | contextlib.nullcontext = field(default_factory=contextlib.nullcontext)

@classmethod
def from_path(cls, path: str, **kwargs):
@@ -371,8 +372,6 @@ def delete(self):
Does not call close() on the OutputDataFile directly to avoid uploading/saving it externally.
:return:
"""
if self.folder:
self.folder.pop_file(self.relative_path)
if self._file_handler:
self._file_handler.close()
self._file_handler = None
@@ -395,7 +394,7 @@ class BaseOutputDataFolder(ABC):

path: str
_output_files: dict[str, BaseOutputDataFile] = field(default_factory=dict)
_lock = contextlib.nullcontext()
_lock: SemLock | contextlib.nullcontext = field(default_factory=contextlib.nullcontext)

def close(self):
"""
@@ -454,10 +453,6 @@ def open(self, relative_path: str, mode: str = "w", gzip: bool = False) -> BaseO
self._output_files[relative_path] = self._create_new_file(relative_path)
return self._output_files[relative_path].open(mode, gzip)

def pop_file(self, relative_path):
if relative_path in self._output_files:
self._output_files.pop(relative_path)

@abstractmethod
def to_input_folder(self) -> BaseInputDataFolder:
raise NotImplementedError
4 changes: 2 additions & 2 deletions src/datatrove/io/fsspec.py
@@ -51,7 +51,7 @@ def __post_init__(self):

def _create_new_file(self, relative_path: str):
return FSSpecOutputDataFile(
path=os.path.join(self.path, relative_path), relative_path=relative_path, _fs=self._fs, folder=self
path=os.path.join(self.path, relative_path), relative_path=relative_path, _fs=self._fs, _lock=self._lock
)

def to_input_folder(self) -> BaseInputDataFolder:
@@ -88,7 +88,7 @@ def list_files(self, extension: str | list[str] = None, suffix: str = "") -> lis

def _unchecked_get_file(self, relative_path: str):
return FSSpecInputDataFile(
path=os.path.join(self.path, relative_path), relative_path=relative_path, _fs=self._fs, folder=self
path=os.path.join(self.path, relative_path), relative_path=relative_path, _fs=self._fs, _lock=self._lock
)

def file_exists(self, relative_path: str) -> bool:
4 changes: 2 additions & 2 deletions src/datatrove/io/local.py
@@ -38,7 +38,7 @@ class LocalOutputDataFolder(BaseOutputDataFolder):

def _create_new_file(self, relative_path: str) -> LocalOutputDataFile:
return LocalOutputDataFile(
path=os.path.join(self.path, relative_path), relative_path=relative_path, folder=self
path=os.path.join(self.path, relative_path), relative_path=relative_path, _lock=self._lock
)

def to_input_folder(self) -> BaseInputDataFolder:
@@ -83,7 +83,7 @@ def _unchecked_get_file(self, relative_path: str) -> LocalInputDataFile:
:return: an input file
"""
return LocalInputDataFile(
path=os.path.join(self.path, relative_path), relative_path=relative_path, folder=self
path=os.path.join(self.path, relative_path), relative_path=relative_path, _lock=self._lock
)

