From baced94f539d829a3afe6e635e5f4aa41789f474 Mon Sep 17 00:00:00 2001
From: Thomas Wolf
Date: Wed, 6 Dec 2023 23:44:47 +0000
Subject: [PATCH 1/3] recursive was not taken into account in fsspec

---
 src/datatrove/io/fsspec.py | 17 ++++++++++++-----
 1 file changed, 12 insertions(+), 5 deletions(-)

diff --git a/src/datatrove/io/fsspec.py b/src/datatrove/io/fsspec.py
index 16d971a8..a95c1db1 100644
--- a/src/datatrove/io/fsspec.py
+++ b/src/datatrove/io/fsspec.py
@@ -73,11 +73,18 @@ def __post_init__(self):
         self._fs = fsspec.filesystem(protocol, **(self.storage_options if self.storage_options else {}))
 
     def list_files(self, extension: str | list[str] = None, suffix: str = "") -> list[BaseInputDataFile]:
-        return [
-            self._unchecked_get_file(os.path.relpath(path, self.path))
-            for path in self._fs.ls(os.path.join(self.path, suffix), detail=False)
-            if self._match_file(path, extension)
-        ]
+        if self.recursive:
+            return [
+                self._unchecked_get_file(os.path.relpath(path, self.path))
+                for path in self._fs.glob(os.path.join(self.path, "**"), detail=False)
+                if self._match_file(os.path.relpath(path, self.path), extension)
+            ]
+        else:
+            return [
+                self._unchecked_get_file(os.path.relpath(path, self.path))
+                for path in self._fs.ls(os.path.join(self.path, suffix), detail=False)
+                if self._match_file(path, extension)
+            ]
 
     def _unchecked_get_file(self, relative_path: str):
         return FSSpecInputDataFile(

From 126a0b9dafbfe65d7ccdeb2f7b81e186c87d86a4 Mon Sep 17 00:00:00 2001
From: Thomas Wolf
Date: Sun, 10 Dec 2023 01:44:44 +0000
Subject: [PATCH 2/3] updates

---
 src/datatrove/io/__init__.py               |   2 +-
 src/datatrove/io/base.py                   | 112 +++++++++++++++++++++
 src/datatrove/io/fsspec.py                 |   1 +
 src/datatrove/io/s3.py                     |  45 ++++++++-
 src/datatrove/io/utils/s3.py               |  14 ++-
 src/datatrove/pipeline/base.py             |   7 +-
 src/datatrove/pipeline/readers/__init__.py |   1 +
 src/datatrove/pipeline/readers/base.py     |   4 +-
 src/datatrove/utils/stats.py               |   1 +
 9 files changed, 172 insertions(+), 15 deletions(-)

diff --git a/src/datatrove/io/__init__.py b/src/datatrove/io/__init__.py
index 0ba94486..bd679213 100644
--- a/src/datatrove/io/__init__.py
+++ b/src/datatrove/io/__init__.py
@@ -1,4 +1,4 @@
 from .base import BaseInputDataFile, BaseInputDataFolder, BaseOutputDataFile, BaseOutputDataFolder
 from .fsspec import FSSpecInputDataFile, FSSpecInputDataFolder, FSSpecOutputDataFile, FSSpecOutputDataFolder
 from .local import LocalInputDataFile, LocalInputDataFolder, LocalOutputDataFile, LocalOutputDataFolder
-from .s3 import S3InputDataFile, S3InputDataFolder, S3OutputDataFile, S3OutputDataFolder
+from .s3 import S3InputDataFile, S3InputDataFolder, S3InputFileListFolder, S3OutputDataFile, S3OutputDataFolder
diff --git a/src/datatrove/io/base.py b/src/datatrove/io/base.py
index 2155fa89..94689256 100644
--- a/src/datatrove/io/base.py
+++ b/src/datatrove/io/base.py
@@ -280,6 +280,118 @@ def to_output_folder(self) -> "BaseOutputDataFolder":
         raise NotImplementedError
 
 
+@dataclass
+class BaseInputFileListFolder(ABC):
+    """Base input file list folder class. Specific implementations should override its relevant methods.
+
+    Args:
+        file_list (list(str)): list of files, respecting the format of specific implementations
+            (e.g. s3 paths should start with s3://) – all files have to be in the same file system
+    """
+
+    file_list: list[str]
+    _lock: SemLock | contextlib.nullcontext = field(default_factory=contextlib.nullcontext)
+
+    @classmethod
+    def from_path(cls, file_list, **kwargs):
+        """BaseInputFileListFolder factory: get the correct instance from a list of file paths.
+        Additional arguments will be passed to the constructor of the matched implementation.
+
+        Args:
+            file_list: list of files
+            **kwargs: any additional argument passed to the matched
+                implementation
+
+        Returns:
+
+        """
+        from datatrove.io import S3InputFileListFolder
+
+        if len(file_list) == 0:
+            raise ValueError("file_list must contain at least one file")
+        if file_list[0].startswith("s3://"):
+            return S3InputFileListFolder(file_list, **kwargs)
+        else:
+            raise NotImplementedError
+
+    @abstractmethod
+    def list_files(self) -> list[BaseInputDataFile]:
+        """Retrieves a list of InputDataFile for all files in this folder matching both the dataclass properties and the
+        function parameters.
+
+        Returns:
+
+        """
+        logger.error(
+            "Do not instantiate BaseInputFileListFolder directly, "
+            "use an S3InputFileListFolder or call "
+            "BaseInputFileListFolder.from_path(path)"
+        )
+        raise NotImplementedError
+
+    def set_lock(self, lock):
+        """Pass a synchronization primitive to limit concurrent downloads
+
+        Args:
+            lock: synchronization primitive (Semaphore, Lock)
+
+        Returns:
+
+        """
+        self._lock = lock
+
+    def get_files_shard(self, rank: int, world_size: int) -> list[BaseInputDataFile]:
+        """Fetch a shard (set of files) for a given rank, assuming there are a total of `world_size` shards.
+        This should be deterministic so that there is no overlap between different ranks.
+
+        Args:
+            rank: rank of the shard to fetch
+            world_size: total number of shards
+
+        Returns:
+            a list of input files
+        """
+        return self.list_files()[rank::world_size]
+
+    def get_file(self, file_path: str) -> BaseInputDataFile | None:
+        """Get a file directly by its name.
+
+        Args:
+            file_path: The file name/path.
+
+        Returns:
+            an input file if it exists or None
+        """
+        if self.file_exists(file_path):
+            return self._unchecked_get_file(file_path)
+
+    @abstractmethod
+    def _unchecked_get_file(self, file_path: str) -> BaseInputDataFile:
+        """Get a file directly by its name, without checking if it exists.
+        Subclasses should override this method (and not the one above).
+        Should not be called directly. Instead, call `get_file`.
+
+        Args:
+            file_path: The file name/path.
+
+        Returns:
+            an input file
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def file_exists(self, file_path: str) -> bool:
+        """Should be overridden by subclasses. Check if a given file exists.
+
+        Args:
+            file_path: The file name/path.
+
+        Returns:
+            if the file exists
+        """
+        return True
+
+
 def get_file_extension(filepath, depth=None) -> str:
     """Get full file extension (example: .jsonl.gz)
     Optionally only get last `depth` extensions (depth=1: .gz)
diff --git a/src/datatrove/io/fsspec.py b/src/datatrove/io/fsspec.py
index a95c1db1..55871377 100644
--- a/src/datatrove/io/fsspec.py
+++ b/src/datatrove/io/fsspec.py
@@ -1,3 +1,4 @@
+import os
 import os.path
 from contextlib import contextmanager
 from dataclasses import dataclass
diff --git a/src/datatrove/io/s3.py b/src/datatrove/io/s3.py
index 6bd9703f..39c5926d 100644
--- a/src/datatrove/io/s3.py
+++ b/src/datatrove/io/s3.py
@@ -7,7 +7,7 @@
 from loguru import logger
 
 from datatrove.io import BaseInputDataFile, BaseOutputDataFolder
-from datatrove.io.base import BaseInputDataFolder, BaseOutputDataFile
+from datatrove.io.base import BaseInputDataFolder, BaseInputFileListFolder, BaseOutputDataFile
 from datatrove.io.utils.s3 import (
     s3_download_file,
     s3_file_exists,
@@ -177,3 +177,46 @@ def list_files(self, extension: str | list[str] = None, suffix: str = "") -> lis
 
     def to_output_folder(self) -> BaseOutputDataFolder:
         return S3OutputDataFolder(path=self.path, local_path=self.local_path, cleanup=self.cleanup)
+
+
+@dataclass
+class S3InputFileListFolder(BaseInputFileListFolder):
+    """S3 input file list folder
+    Args:
+        local_path (str): where to download the files to (if `stream=False`)
+        stream (bool): stream the file directly from s3, without saving to disk
+        cleanup (bool): remove downloaded files from disk after they are closed
+
+    """
+
+    local_path: str = None
+    stream: bool = None
+    cleanup: bool = True
+
+    def __post_init__(self):
+        if not self.file_list[0].startswith("s3://"):
+            raise ValueError("S3InputFileListFolder paths must start with s3://")
+        self._tmpdir = None
+        if self.stream is None:
+            self.stream = self.local_path is None
+        if not self.local_path:
+            self._tmpdir = tempfile.TemporaryDirectory()
+            self.local_path = self._tmpdir.name
+
+    def _unchecked_get_file(self, file_path: str):
+        return S3InputDataFile(
+            path=file_path,
+            local_path=os.path.join(self.local_path, file_path.replace("s3://", "")),
+            relative_path=file_path.replace("s3://", ""),
+            stream=self.stream,
+            cleanup=self.cleanup,
+            _lock=self._lock,
+        )
+
+    def file_exists(self, file_path: str) -> bool:
+        if self.local_path and os.path.exists(os.path.join(self.local_path, file_path.replace("s3://", ""))):
+            return True
+        return s3_file_exists(file_path)
+
+    def list_files(self) -> list[BaseInputDataFile]:
+        return [self._unchecked_get_file(file_path) for file_path in self.file_list]
diff --git a/src/datatrove/io/utils/s3.py b/src/datatrove/io/utils/s3.py
index a6ed0348..8ec333ed 100644
--- a/src/datatrove/io/utils/s3.py
+++ b/src/datatrove/io/utils/s3.py
@@ -106,14 +106,12 @@ def s3_get_file_list(cloud_path, match_pattern=None, recursive=True):
         prefix = prefixes.popleft()
         for resp in paginator.paginate(Bucket=bucket, Prefix=prefix, Delimiter="/"):
             if recursive:
-                prefixes.extend(
-                    [
-                        next_prefix["Prefix"]
-                        for next_prefix in resp.get("CommonPrefixes", [])
-                        if _match_prefix(main_prefix, next_prefix["Prefix"], match_pattern)
-                    ]
-                )
+                prefixes.extend([next_prefix["Prefix"] for next_prefix in resp.get("CommonPrefixes", [])])
             objects.extend(
-                [os.path.relpath(x["Key"], main_prefix) for x in resp.get("Contents", []) if x["Key"] != prefix]
+                [
+                    os.path.relpath(x["Key"], main_prefix)
+                    for x in resp.get("Contents", [])
+                    if x["Key"] != prefix and _match_prefix(main_prefix, x["Key"], match_pattern)
+                ]
             )
     return sorted(objects)
diff --git a/src/datatrove/pipeline/base.py b/src/datatrove/pipeline/base.py
index df8cd250..e72d4dc2 100644
--- a/src/datatrove/pipeline/base.py
+++ b/src/datatrove/pipeline/base.py
@@ -16,9 +16,10 @@ def stat_update(self, *labels, value: int = 1, unit: str = None):
             self.stats[label].update(value, unit)
 
     def update_doc_stats(self, document: Document):
-        self.stats["doc_len"] += len(document.content)
-        if token_count := document.metadata.get("token_count", None):
-            self.stats["doc_len_tokens"] += token_count
+        if document is not None:
+            self.stats["doc_len"] += len(document.content)
+            if token_count := document.metadata.get("token_count", None):
+                self.stats["doc_len_tokens"] += token_count
 
     def track_time(self, unit: str = None):
         if unit:
diff --git a/src/datatrove/pipeline/readers/__init__.py b/src/datatrove/pipeline/readers/__init__.py
index 190baabc..75761f0b 100644
--- a/src/datatrove/pipeline/readers/__init__.py
+++ b/src/datatrove/pipeline/readers/__init__.py
@@ -1,3 +1,4 @@
+from .base import BaseReader
 from .csv import CSVReader
 from .jsonl import JsonlReader
 from .parquet import ParquetReader
diff --git a/src/datatrove/pipeline/readers/base.py b/src/datatrove/pipeline/readers/base.py
index 71bc63f0..11ddb7a9 100644
--- a/src/datatrove/pipeline/readers/base.py
+++ b/src/datatrove/pipeline/readers/base.py
@@ -46,8 +46,8 @@ def get_document_from_dict(self, data: dict, source_file: BaseInputDataFile, id_
         if not parsed_data.get("content", None):
             if not self._empty_warning:
                 self._empty_warning = True
-                logger.warning("Found document without content, skipping.")
-            return None
+                logger.warning("Found document without content.")
+            # return None  # This is not a good idea anymore, because it breaks the pipeline
         document = Document(**parsed_data)
         document.metadata.setdefault("file_path", source_file.path)
         if self.default_metadata:
diff --git a/src/datatrove/utils/stats.py b/src/datatrove/utils/stats.py
index 32551dfb..f245088c 100644
--- a/src/datatrove/utils/stats.py
+++ b/src/datatrove/utils/stats.py
@@ -39,6 +39,7 @@ def __repr__(self):
         return ", ".join(f"{key}: {stats}" for key, stats in self.items())
 
     def to_dict(self):
+        print(f"to_dict {self} in {list(self.items())}")
         return {a: b.to_dict() for a, b in self.items()}
 

From 4f5cd009f79352c63c99b912a2e131caf90550a0 Mon Sep 17 00:00:00 2001
From: Thomas Wolf
Date: Sun, 17 Dec 2023 09:16:47 +0000
Subject: [PATCH 3/3] batched tokenization

---
 src/datatrove/pipeline/filters/__init__.py |  1 +
 .../pipeline/filters/c4_quality_filter.py  | 39 +++++++++++++++
 src/datatrove/pipeline/readers/base.py     |  6 ++-
 src/datatrove/pipeline/tokens/tokenizer.py | 49 +++++++++++++------
 4 files changed, 78 insertions(+), 17 deletions(-)

diff --git a/src/datatrove/pipeline/filters/__init__.py b/src/datatrove/pipeline/filters/__init__.py
index bbc724e7..664f730e 100644
--- a/src/datatrove/pipeline/filters/__init__.py
+++ b/src/datatrove/pipeline/filters/__init__.py
@@ -1,3 +1,4 @@
+from .c4_quality_filter import C4ParagraphFilter, C4QualityFilter
 from .gopher_quality_filter import GopherQualityFilter
 from .gopher_repetition_filter import GopherRepetitionFilter
 from .lambda_filter import LambdaFilter
diff --git a/src/datatrove/pipeline/filters/c4_quality_filter.py b/src/datatrove/pipeline/filters/c4_quality_filter.py
index 9feabc8b..3a6407cd 100644
--- a/src/datatrove/pipeline/filters/c4_quality_filter.py
+++ b/src/datatrove/pipeline/filters/c4_quality_filter.py
@@ -1,3 +1,4 @@
+import heapq
 import re
 
 from nltk import load
@@ -61,3 +62,41 @@ def filter(self, doc: Document) -> bool | tuple[bool, str]:
         # sent_dedup's remove_dup_sentences may be adapted for this
         lines = [line for line in lines if self.line_filter(line)]
         doc.content = " ".join(lines)
+
+
+class C4ParagraphFilter(BaseFilter):
+    """Applies paragraph filtering from mC4
+
+    https://github.com/tensorflow/datasets/blob/master/tensorflow_datasets/text/c4_utils.py#L551
+    """
+
+    name = "⛰ C4 Paragraph"
+
+    def __init__(self, exclusion_writer: DiskWriter = None):
+        super().__init__(exclusion_writer)
+
+        self.min_paragraphs = 3
+        self.min_paragraph_len = 200
+        self.line_delimiter = "\n"
+
+    def paragraph_filter(self, page):
+        """Returns False iff a page has too few or too short paragraphs."""
+        lines = page.split(self.line_delimiter)
+        # Filter out docs that don't have at least three "paragraphs"
+        # (lines >= `min_paragraph_len` chars).
+        if (
+            len(lines) < self.min_paragraphs
+            or min(heapq.nlargest(3, [len(l) for l in lines])) < self.min_paragraph_len
+        ):
+            return False
+        return True
+
+    def filter(self, doc: Document) -> bool | tuple[bool, str]:
+        """Args:
+            doc
+
+        Returns:
+
+        """
+        if not self.paragraph_filter(doc.content):
+            return False, f"< {self.min_paragraphs} paragraphs"
diff --git a/src/datatrove/pipeline/readers/base.py b/src/datatrove/pipeline/readers/base.py
index 11ddb7a9..80da1fb7 100644
--- a/src/datatrove/pipeline/readers/base.py
+++ b/src/datatrove/pipeline/readers/base.py
@@ -65,7 +65,8 @@ def read_files_shard(self, shard):
         li = 0
         with tqdm(total=self.limit if self.limit != -1 else None) if self.progress else nullcontext() as pbar:
             for datafile in shard:
-                self.stat_update("input_files")
+                self.stats["input_files"] += 1
+                # self.stat_update("input_files")
                 logger.info(f"Reading input file {datafile.path}")
                 di = 0
                 for di, document in enumerate(self.read_file(datafile)):
@@ -75,7 +76,8 @@ def read_files_shard(self, shard):
                     if self.progress:
                         pbar.update()
                     li += 1
-                self.stat_update("documents", value=di, unit="input_file")
+                self.stats["documents"] += di
+                # self.stat_update("documents", value=di, unit="input_file")
                 if self.limit != -1 and li >= self.limit:
                     break
diff --git a/src/datatrove/pipeline/tokens/tokenizer.py b/src/datatrove/pipeline/tokens/tokenizer.py
index 401ccdd7..a9c11492 100644
--- a/src/datatrove/pipeline/tokens/tokenizer.py
+++ b/src/datatrove/pipeline/tokens/tokenizer.py
@@ -1,3 +1,4 @@
+import itertools
 import mmap
 import struct
 
@@ -19,6 +20,19 @@
     TOKENIZERS_INSTALLED = False
 
 
+def batched(iterable, n):
+    """In python 3.12+ we could use itertools.batched instead
+
+    One difference with itertools.batched: we return a list instead of a tuple
+    """
+    # batched('ABCDEFG', 3) --> ABC DEF G
+    if n < 1:
+        raise ValueError("n must be at least one")
+    it = iter(iterable)
+    while batch := list(itertools.islice(it, n)):
+        yield batch
+
+
 class TokenizedFile:
     def __init__(
         self,
@@ -125,6 +139,7 @@ def __init__(
         eos_token: str = "<|endoftext|>",  # whether to add the EOS token after each document
         save_loss_metadata: bool = False,  # save the loss information
         shuffle: bool = True,  # whether to shuffle documents in the dataset,
+        batch_size: int = 1000,  # batch size for tokenization
         seed: int = None,
         save_final_metadata: bool = True,
     ):
@@ -135,6 +150,7 @@ def __init__(
         self.eos_token = eos_token
        self.save_loss_metadata = save_loss_metadata
         self.shuffle = shuffle
+        self.batch_size = batch_size
         self._tokenizer = None
         self.rand = default_rng(seed)
         self.save_final_metadata = save_final_metadata
@@ -142,7 +158,7 @@ def set_up_dl_locks(self, dl_lock, up_lock):
         self.output_folder.set_lock(up_lock)
 
-    def get_loss_values(self, document: Document, encoded):
+    def get_loss_values(self, document: Document, encoded: Encoding):
         loss_values = None
         if self.save_loss_metadata:
             loss_values = np.ones((len(encoded.ids)))
@@ -156,23 +172,26 @@ def get_loss_values(self, document: Document, encoded: Encoding):
             loss_values = loss_values[:t_start]
         return loss_values
 
-    def write_unshuffled(self, data, filename):
+    def write_unshuffled(self, data: DocumentsPipeline, filename: str):
         unshuff = TokenizedFile(
             self.output_folder, filename, save_index=not self.shuffle, save_loss_metadata=self.save_loss_metadata
         )
-        # tokenize each document's text and write its tokens sequentially to the output .ds
-        for document in data:
+        # tokenize the documents' text in batches to go faster – we compute loss values independently if needed
+        for batch in batched(data, self.batch_size):
             with self.track_time():
-                encoded: Encoding = self.tokenizer.encode(document.content)
-                tokens = encoded.ids
-                # loss values
-                loss_values = self.get_loss_values(document, encoded)
-                if loss_values is not None and len(loss_values) < len(tokens):
-                    tokens = tokens[: len(loss_values)]
-                # write bytes to disk
-                unshuff.write(tokens, loss_values)
-                # save stats
-                self.stat_update("tokens", value=len(tokens))
+                encoded_batch: list[Encoding] = self.tokenizer.encode_batch([document.content for document in batch])
+                loss_values_batch = [
+                    self.get_loss_values(document, encoded) for document, encoded in zip(batch, encoded_batch)
+                ]
+            for encoded, loss_values in zip(encoded_batch, loss_values_batch):
+                with self.track_time():
+                    tokens = encoded.ids
+                    if loss_values is not None and len(loss_values) < len(tokens):
+                        tokens = tokens[: len(loss_values)]
+                    # write bytes to disk
+                    unshuff.write(tokens, loss_values)
+                    # save stats
+                    self.stat_update("tokens", value=len(tokens))
         return unshuff
 
     def get_output_filename(self, rank, name):
@@ -199,7 +218,7 @@ def run(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1) -> Do
             self.output_folder.close()
 
     @property
-    def tokenizer(self):
+    def tokenizer(self) -> Tokenizer:
         if not self._tokenizer:
             if not TOKENIZERS_INSTALLED:
                 logger.error("`tokenizers` is required to run DocumentTokenizer")
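
For orientation, here is a minimal usage sketch (Python) of the pieces this series adds; it is not part of the patches. `S3InputFileListFolder`, `get_files_shard` and `batched` come from the diffs above, while the bucket/file names and the rank/world_size values are hypothetical placeholders.

# Minimal sketch, assuming a datatrove install that includes the patches above.
# The s3:// paths below are hypothetical placeholders.
from datatrove.io import S3InputFileListFolder
from datatrove.pipeline.tokens.tokenizer import batched

# PATCH 2: build an input "folder" from an explicit list of S3 files and take the
# deterministic shard for rank 0 out of 2 ranks (i.e. files [0::2]).
folder = S3InputFileListFolder(
    file_list=[
        "s3://example-bucket/data/part-00.jsonl.gz",
        "s3://example-bucket/data/part-01.jsonl.gz",
        "s3://example-bucket/data/part-02.jsonl.gz",
    ]
)
shard = folder.get_files_shard(rank=0, world_size=2)  # part-00 and part-02

# PATCH 3: group any iterable into lists of at most n items, as write_unshuffled
# now does before calling tokenizer.encode_batch on each group of documents.
for batch in batched(range(7), 3):
    print(batch)  # [0, 1, 2] then [3, 4, 5] then [6]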