Switch all IO to fsspec #49

Merged · 34 commits · Jan 18, 2024
(The diff below reflects changes from 19 of the 34 commits.)

Commits
5b1d453
batched tokenization
thomwolf Dec 17, 2023
e381c34
revert stats change
guipenedo Dec 18, 2023
4dc7af6
rewrite to only use fsspec
guipenedo Jan 10, 2024
3576bcd
fixed tests
guipenedo Jan 11, 2024
0661ab3
fixed minhash s3 support
guipenedo Jan 11, 2024
e3ea72d
bugfix tests
guipenedo Jan 11, 2024
d6e463d
removed mmaps from tokenizer - now streaming with seek only
guipenedo Jan 12, 2024
5c281e6
try to optimize caching on the streaming a bit
guipenedo Jan 12, 2024
5d517fb
nit
guipenedo Jan 12, 2024
9a612f2
small bugfix
guipenedo Jan 12, 2024
0ea10ad
some qol improvements
guipenedo Jan 12, 2024
244dd06
fixing bugs in slurm with new io
guipenedo Jan 12, 2024
5cf9de5
bugfix fetching list of completed tasks
guipenedo Jan 12, 2024
0cb380e
more executor bugfixes
guipenedo Jan 12, 2024
93bb0e8
change absolute paths of local files
guipenedo Jan 12, 2024
9d59e8b
fix creating log folder
guipenedo Jan 12, 2024
bc4119b
fix creating stats dir
guipenedo Jan 12, 2024
5e8f6b9
more executor io bugfixes
guipenedo Jan 12, 2024
230fc55
bugfix merge_stats
guipenedo Jan 12, 2024
4c57ed8
ensure list_files determinism
guipenedo Jan 16, 2024
e42f152
Update src/datatrove/datafolder.py
guipenedo Jan 16, 2024
ab75d88
address review comments
guipenedo Jan 16, 2024
d1fe562
local_working_dir to improve tokenization shuffling performance
guipenedo Jan 16, 2024
11a2d9a
add (str path, fully initialized fs object) option to `get_datafolder`
guipenedo Jan 17, 2024
11804dc
implement auto_mkdir
guipenedo Jan 17, 2024
5d62aed
changed resolve_paths and added a first io test
guipenedo Jan 17, 2024
e07c67e
added missing test dependency
guipenedo Jan 17, 2024
ca79c08
auto_mkdirs changes and test
guipenedo Jan 17, 2024
de3508a
added some more tests. changed local executor to use forkserver as ot…
guipenedo Jan 18, 2024
5adc77b
don't pin dependencies
guipenedo Jan 18, 2024
65f2041
Merge branch 'main' into io-fsspec
guipenedo Jan 18, 2024
8a4a2d9
some more dependency changes
guipenedo Jan 18, 2024
d9a16ff
add bogus auth credentials
guipenedo Jan 18, 2024
e6c2369
explicitly remove directories from list_files
guipenedo Jan 18, 2024
123 changes: 123 additions & 0 deletions src/datatrove/datafolder.py
@@ -0,0 +1,123 @@
from typing import IO, TypeAlias

from fsspec import AbstractFileSystem
from fsspec import open as fsspec_open
from fsspec.core import url_to_fs
from fsspec.implementations.dirfs import DirFileSystem
from fsspec.implementations.local import LocalFileSystem


class OutputFileManager:
    def __init__(self, fs, mode: str = "wt", compression: str | None = "infer"):
        self.fs = fs
        self.mode = mode
        self.compression = compression
        self._output_files = {}

    def get_file(self, filename):
        if filename not in self._output_files:
            self._output_files[filename] = self.fs.open(filename, mode=self.mode, compression=self.compression)
        return self._output_files[filename]

    def write(self, filename, data):
        self.get_file(filename).write(data)

    def __enter__(self):
        return self

    def close(self):
        for file in self._output_files.values():
            file.close()
        self._output_files.clear()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()


class DataFolder(DirFileSystem):
    def __init__(
        self,
        path: str,
        fs: AbstractFileSystem | None = None,
        recursive: bool = True,
        glob: str | None = None,
        **storage_options,
    ):
        """
        `recursive` and `glob` will only affect the listing operation
        Args:
            path:
            fs:
            recursive:
            glob:
            **storage_options:
        """
        super().__init__(path=path, fs=fs if fs else url_to_fs(path, **storage_options)[0])
        self.recursive = recursive
        self.pattern = glob
        if not self.isdir("/"):
            self.mkdirs("/", exist_ok=True)

    def list_files(self, suffix: str = "", extension: str | list[str] = None) -> list[str]:
        if extension and isinstance(extension, str):
            extension = [extension]
        return [
            f
            for f in (
                self.find(suffix, maxdepth=0 if not self.recursive else None)
                if not self.pattern
                else self.glob(self.fs.sep.join([self.pattern, suffix]), maxdepth=0 if not self.recursive else None)
            )
            if not extension or any(f.endswith(ext) for ext in extension)
        ]

    def get_shard(self, rank: int, world_size: int, **kwargs) -> list[str]:
        """Fetch a shard (set of files) for a given rank, assuming there are a total of `world_size` shards.
        This should be deterministic to not have any overlap among different ranks.

        Args:
            rank: rank of the shard to fetch
            world_size: total number of shards
            other parameters will be passed to list_files

        Returns:
            a list of file paths
        """
        return self.list_files(**kwargs)[rank::world_size]

    def unstrip_protocol(self, name: str) -> str:
        if isinstance(self.fs, LocalFileSystem):
            return name
        return super().unstrip_protocol(name)

    def to_absolute_paths(self, paths) -> list[str] | str:
        if isinstance(paths, str):
            return self.unstrip_protocol(self._join(paths))
        return list(map(self.unstrip_protocol, self._join(paths)))

    def get_outputfile_manager(self, **kwargs) -> OutputFileManager:
        return OutputFileManager(self, **kwargs)

    def bulk_open_files(self, paths, mode="rb", **kwargs):
        return [self.open(path, mode=mode, **kwargs) for path in paths]


def get_datafolder(data: DataFolder | str) -> DataFolder | None:
    if data is None:
        return None
    if isinstance(data, str):
        return DataFolder(data)
    if isinstance(data, DataFolder):
        return data
    if isinstance(data, tuple) and list(map(type, data)) == [str, dict]:
        return DataFolder(data[0], **data[1])
    raise ValueError("You must pass a DataFolder or a str path")


def get_file(file: IO | str, mode="rt", **kwargs):
    if isinstance(file, str):
        return fsspec_open(file, mode, **kwargs)
    return file


ParsableDataFolder: TypeAlias = str | tuple[str, dict] | DataFolder
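A minimal usage sketch of the DataFolder API added in this file (the module is imported as `datatrove.datafolder` in the executor diffs below). The paths, bucket name, storage options, and file names are illustrative placeholders, not taken from the PR:

from datatrove.datafolder import get_datafolder

# get_datafolder accepts any ParsableDataFolder: a str path, a (path, storage_options) tuple, or a DataFolder
local = get_datafolder("data/my_dataset")
remote = get_datafolder(("s3://my-bucket/dataset", {"anon": False}))  # hypothetical bucket and options

# get_shard slices the file listing so that rank r of world_size w gets every w-th file
files = remote.get_shard(rank=0, world_size=4, extension=".jsonl.gz")

# OutputFileManager keeps one open handle per output filename and closes them all on exit
with local.get_outputfile_manager(mode="wt", compression="infer") as output_mgr:
    output_mgr.write("00000.jsonl.gz", '{"text": "hello"}\n')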
28 changes: 12 additions & 16 deletions src/datatrove/executor/base.py
@@ -7,7 +7,7 @@

 from loguru import logger

-from datatrove.io import BaseOutputDataFolder, LocalOutputDataFolder
+from datatrove.datafolder import ParsableDataFolder, get_datafolder
 from datatrove.pipeline.base import PipelineStep
 from datatrove.utils.logging import add_task_logger, close_task_logger, get_random_str, get_timestamp, log_pipeline
 from datatrove.utils.stats import PipelineStats
@@ -18,7 +18,7 @@ class PipelineExecutor(ABC):
     def __init__(
         self,
         pipeline: list[PipelineStep | Callable],
-        logging_dir: BaseOutputDataFolder = None,
+        logging_dir: ParsableDataFolder = None,
         skip_completed: bool = True,
     ):
         """
@@ -31,11 +31,7 @@ def __init__(
                 previous runs. default: True
         """
         self.pipeline: list[PipelineStep | Callable] = pipeline
-        if isinstance(logging_dir, str):
-            logging_dir = BaseOutputDataFolder.from_path(logging_dir)
-        self.logging_dir = (
-            logging_dir if logging_dir else LocalOutputDataFolder(f"logs/{get_timestamp()}_{get_random_str()}")
-        )
+        self.logging_dir = get_datafolder(logging_dir if logging_dir else f"logs/{get_timestamp()}_{get_random_str()}")
         self.skip_completed = skip_completed

     @abstractmethod
@@ -51,7 +47,7 @@ def _run_for_rank(self, rank: int, local_rank: int = 0) -> PipelineStats:
         if self.is_rank_completed(rank):
             logger.info(f"Skipping {rank=} as it has already been completed.")
             return PipelineStats()
-        add_task_logger(self.logging_dir, rank, local_rank)
+        logfile = add_task_logger(self.logging_dir, rank, local_rank)
         log_pipeline(self.pipeline)
         try:
             # pipe data from one step to the next
@@ -70,27 +66,27 @@ def _run_for_rank(self, rank: int, local_rank: int = 0) -> PipelineStats:

             # stats
             stats = PipelineStats(self.pipeline)
-            stats.save_to_disk(self.logging_dir.open(f"stats/{rank:05d}.json"))
+            self.logging_dir.makedirs("stats", exist_ok=True)
+            with self.logging_dir.open(f"stats/{rank:05d}.json", "w") as f:
+                stats.save_to_disk(f)
             logger.info(stats.get_repr(f"Task {rank}"))
             # completed
             self.mark_rank_as_completed(rank)
         except Exception as e:
             logger.exception(e)
             raise e
         finally:
-            close_task_logger(self.logging_dir, rank)
+            close_task_logger(logfile)
         return stats

     def is_rank_completed(self, rank: int):
-        return self.skip_completed and self.logging_dir.to_input_folder().file_exists(f"completions/{rank:05d}")
+        return self.skip_completed and self.logging_dir.isfile(f"completions/{rank:05d}")

     def mark_rank_as_completed(self, rank: int):
-        self.logging_dir.open(f"completions/{rank:05d}").close()
+        self.logging_dir.touch(f"completions/{rank:05d}")

     def get_incomplete_ranks(self):
-        completed = {
-            file.relative_path for file in self.logging_dir.to_input_folder().list_files(suffix="completions")
-        }
+        completed = set(self.logging_dir.list_files("completions"))
         return list(
             filter(
                 lambda rank: not self.skip_completed or f"completions/{rank:05d}" not in completed,
@@ -104,7 +100,7 @@ def to_json(self, indent=4):
         return json.dumps(data, indent=indent)

     def save_executor_as_json(self, indent: int = 4):
-        with self.logging_dir.open("executor.json") as f:
+        with self.logging_dir.open("executor.json", "w") as f:
             json.dump(self, f, cls=ExecutorJSONEncoder, indent=indent)
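The executor changes above replace the old BaseOutputDataFolder logging handling with direct fsspec-style calls on a DataFolder. A small sketch of the completion-tracking pattern they rely on (folder path and rank value are illustrative only):

from datatrove.datafolder import get_datafolder

logging_dir = get_datafolder("logs/my_run")
rank = 3

# mark_rank_as_completed: write an empty marker file for this rank
logging_dir.makedirs("completions", exist_ok=True)  # ensure the parent dir exists; later commits in this PR add auto_mkdir handling
logging_dir.touch(f"completions/{rank:05d}")

# is_rank_completed / get_incomplete_ranks: later runs only check for the marker
print(logging_dir.isfile(f"completions/{rank:05d}"))  # True
print(logging_dir.list_files("completions"))          # e.g. ['completions/00003']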
38 changes: 10 additions & 28 deletions src/datatrove/executor/local.py
@@ -2,24 +2,21 @@
 from multiprocessing import Value
 from typing import Callable

-import multiprocess.pool
 from loguru import logger
-from multiprocess import Queue, Semaphore
+from multiprocess import Pool, Queue

+from datatrove.datafolder import ParsableDataFolder
 from datatrove.executor.base import PipelineExecutor
-from datatrove.io import BaseOutputDataFolder
 from datatrove.pipeline.base import PipelineStep
 from datatrove.utils.stats import PipelineStats


 # multiprocessing vars
-download_semaphore, upload_semaphore, ranks_queue, completed = None, None, None, None
+ranks_queue, completed = None, None


-def init_pool_processes(dl_sem, up_sem, ranks_q, completed_counter):
-    global download_semaphore, upload_semaphore, ranks_queue, completed
-    download_semaphore = dl_sem
-    upload_semaphore = up_sem
+def init_pool_processes(ranks_q, completed_counter):
+    global ranks_queue, completed
     ranks_queue = ranks_q
     completed = completed_counter

@@ -30,9 +27,7 @@ def __init__(
         pipeline: list[PipelineStep | Callable],
         tasks: int = 1,
         workers: int = -1,
-        max_concurrent_uploads: int = 20,
-        max_concurrent_downloads: int = 50,
-        logging_dir: BaseOutputDataFolder | str = None,
+        logging_dir: ParsableDataFolder = None,
         skip_completed: bool = True,
     ):
         """Execute a pipeline locally
@@ -44,10 +39,6 @@ def __init__(
             tasks: total number of tasks to run the pipeline on
             workers: how many tasks to run simultaneously. -1 for no
                 limit
-            max_concurrent_uploads: limit the number of files that may
-                be uploaded simultaneously to avoid rate limits
-            max_concurrent_downloads: limit the number of files that may
-                be downloaded simultaneously to avoid rate limits
             logging_dir: where to save logs, stats, etc. Should be an
                 OutputDataFolder or a str. If str, BaseOutputDataFolder.from_path(value) will be used to convert it
             skip_completed: whether to skip tasks that were completed in
@@ -56,14 +47,8 @@ def __init__(
         super().__init__(pipeline, logging_dir, skip_completed)
         self.tasks = tasks
         self.workers = workers if workers != -1 else tasks
-        self.max_concurrent_uploads = max_concurrent_uploads
-        self.max_concurrent_downloads = max_concurrent_downloads

     def _run_for_rank(self, rank: int, local_rank: int = -1):
-        if self.workers > 1:
-            for pipeline_step in self.pipeline:
-                if isinstance(pipeline_step, PipelineStep):
-                    pipeline_step.set_up_dl_locks(download_semaphore, upload_semaphore)
         local_rank = ranks_queue.get()
         try:
             return super()._run_for_rank(rank, local_rank)
@@ -97,19 +82,16 @@ def run(self):
                 self.pipeline = deepcopy(pipeline)
                 stats.append(self._run_for_rank(rank))
         else:
-            dl_sem = Semaphore(self.max_concurrent_downloads)
-            up_sem = Semaphore(self.max_concurrent_uploads)
             completed_counter = Value("i", skipped)

-            with multiprocess.Pool(
-                self.workers, initializer=init_pool_processes, initargs=(dl_sem, up_sem, ranks_q, completed_counter)
-            ) as pool:
+            with Pool(self.workers, initializer=init_pool_processes, initargs=(ranks_q, completed_counter)) as pool:
                 stats = list(pool.imap_unordered(self._run_for_rank, ranks_to_run))
         # merged stats
         stats = sum(stats, start=PipelineStats())
-        stats.save_to_disk(self.logging_dir.open("stats.json"))
+        self.logging_dir.makedirs("stats", exist_ok=True)
+        with self.logging_dir.open("stats.json", "wt") as statsfile:
+            stats.save_to_disk(statsfile)
         logger.success(stats.get_repr(f"All {self.tasks} tasks"))
-        self.logging_dir.close()
         return stats

     @property
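The LocalPipelineExecutor diff above drops the upload/download semaphores and now only shares the rank queue and the completed counter with the workers, via the pool initializer. A standalone sketch of that initializer/initargs pattern, written with the standard library's multiprocessing (the executor itself uses the multiprocess package, as shown in the diff); the toy run_task function is illustrative, not datatrove code:

from multiprocessing import Pool, Queue, Value

# module-level globals, filled in per worker by the initializer
ranks_queue, completed = None, None


def init_pool_processes(ranks_q, completed_counter):
    # runs once in every worker process and stashes the shared objects in globals
    global ranks_queue, completed
    ranks_queue = ranks_q
    completed = completed_counter


def run_task(rank):
    local_rank = ranks_queue.get()  # take a free local slot
    try:
        return rank * rank  # stand-in for the real per-rank work
    finally:
        with completed.get_lock():
            completed.value += 1
        ranks_queue.put(local_rank)  # hand the slot back for the next task


if __name__ == "__main__":
    workers = 4
    ranks_q = Queue()
    for i in range(workers):
        ranks_q.put(i)
    completed_counter = Value("i", 0)
    with Pool(workers, initializer=init_pool_processes, initargs=(ranks_q, completed_counter)) as pool:
        results = list(pool.imap_unordered(run_task, range(10)))
    print(sorted(results), completed_counter.value)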