recursive was not taken into account in fsspec #38

Closed · wants to merge 3 commits
2 changes: 1 addition & 1 deletion src/datatrove/io/__init__.py
@@ -1,4 +1,4 @@
from .base import BaseInputDataFile, BaseInputDataFolder, BaseOutputDataFile, BaseOutputDataFolder
from .fsspec import FSSpecInputDataFile, FSSpecInputDataFolder, FSSpecOutputDataFile, FSSpecOutputDataFolder
from .local import LocalInputDataFile, LocalInputDataFolder, LocalOutputDataFile, LocalOutputDataFolder
from .s3 import S3InputDataFile, S3InputDataFolder, S3OutputDataFile, S3OutputDataFolder
from .s3 import S3InputDataFile, S3InputDataFolder, S3InputFileListFolder, S3OutputDataFile, S3OutputDataFolder
112 changes: 112 additions & 0 deletions src/datatrove/io/base.py
@@ -280,6 +280,118 @@ def to_output_folder(self) -> "BaseOutputDataFolder":
raise NotImplementedError


@dataclass
class BaseInputFileListFolder(ABC):
"""Base input data folder class. Specific implementations should override its relevant methods.

Args:
file_list (list(str)): list of files, respecting the format of specific implementations
(i.e. s3 paths should start with s3://) – all files have to be in the same file system
"""

file_list: list[str]
_lock: SemLock | contextlib.nullcontext = field(default_factory=contextlib.nullcontext)

@classmethod
def from_path(cls, file_list, **kwargs):
"""InputDataFolder factory: get the correct instance from a path. Additional arguments will be passed to the
constructor of the matched implementation.

Args:
file_list: list of files
**kwargs: any additional argument passed to the matched
implementation

Returns:
an instance of the matching BaseInputFileListFolder implementation
"""
from datatrove.io import S3InputFileListFolder

if len(file_list) == 0:
raise ValueError("file_list must contain at least one file")
if file_list[0].startswith("s3://"):
return S3InputFileListFolder(file_list, **kwargs)
else:
raise NotImplementedError

@abstractmethod
def list_files(self) -> list[BaseInputDataFile]:
"""Retrieves a list of InputDataFile for all files in this folder matching both the dataclass properties and the
function parameters.

Returns:

"""
logger.error(
"Do not instantiate BaseInputFileListFolder directly, "
"use a S3InputFileListFolder or call"
"BaseInputFileListFolder.from_path(path)"
)
raise NotImplementedError

def set_lock(self, lock):
"""Pass a synchronization primitive to limit concurrent downloads

Args:
lock: synchronization primitive (Semaphore, Lock)

Returns:

"""
self._lock = lock

def get_files_shard(self, rank: int, world_size: int) -> list[BaseInputDataFile]:
"""Fetch a shard (set of files) for a given rank, assuming there are a total of `world_size` shards.
This should be deterministic to not have any overlap among different ranks.

Args:
rank: rank of the shard to fetch
world_size: total number of shards

Returns:
a list of input files
"""
return self.list_files()[rank::world_size]

def get_file(self, file_path: str) -> BaseInputDataFile | None:
"""Get a file directly by its name.

Args:
file_path: The file name/path.

Returns:
an input file if it exists or None
"""
if self.file_exists(file_path):
return self._unchecked_get_file(file_path)

@abstractmethod
def _unchecked_get_file(self, file_path: str) -> BaseInputDataFile:
"""Get a file directly by its name, without checking if it exists.
Subclasses should override this method (and not the one above).
Should not be called directly. Instead, call `get_file`

Args:
file_path: The file name/path.

Returns:
an input file
"""
raise NotImplementedError

@abstractmethod
def file_exists(self, file_path: str) -> bool:
"""Should be overriden by subclasses. Check if a given file exists.

Args:
file_path: The file name/path.

Returns:
if the file exists
"""
return True


def get_file_extension(filepath, depth=None) -> str:
"""Get full file extension (example: .jsonl.gz)
Optionally only get last `depth` extensions (depth=1: .gz)
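
For context, a minimal sketch of how the new BaseInputFileListFolder added above is meant to be used. The bucket and file names are placeholders, not part of this PR:

```python
from datatrove.io.base import BaseInputFileListFolder

# Explicit list of input files; all must be on the same file system (s3 here).
files = [
    "s3://my-bucket/data/part-000.jsonl.gz",
    "s3://my-bucket/data/part-001.jsonl.gz",
    "s3://my-bucket/data/part-002.jsonl.gz",
]

# from_path dispatches on the protocol of the first path and returns an
# S3InputFileListFolder for s3:// paths.
folder = BaseInputFileListFolder.from_path(files)

# Deterministic sharding: rank 0 of 2 gets files 0 and 2, rank 1 gets file 1.
shard = folder.get_files_shard(rank=0, world_size=2)
for input_file in shard:
    print(input_file.path)
```
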
18 changes: 13 additions & 5 deletions src/datatrove/io/fsspec.py
@@ -1,3 +1,4 @@
import os
import os.path
from contextlib import contextmanager
from dataclasses import dataclass
@@ -73,11 +74,18 @@ def __post_init__(self):
self._fs = fsspec.filesystem(protocol, **(self.storage_options if self.storage_options else {}))

def list_files(self, extension: str | list[str] = None, suffix: str = "") -> list[BaseInputDataFile]:
return [
self._unchecked_get_file(os.path.relpath(path, self.path))
for path in self._fs.ls(os.path.join(self.path, suffix), detail=False)
if self._match_file(path, extension)
]
if self.recursive:
return [
self._unchecked_get_file(os.path.relpath(path, self.path))
for path in self._fs.glob(os.path.join(self.path, "**"), detail=False)
if self._match_file(os.path.relpath(path, self.path), extension)
]
else:
return [
self._unchecked_get_file(os.path.relpath(path, self.path))
for path in self._fs.ls(os.path.join(self.path, suffix), detail=False)
if self._match_file(path, extension)
]

def _unchecked_get_file(self, relative_path: str):
return FSSpecInputDataFile(
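
To illustrate the fix this PR is about: `fs.ls` only returns the direct children of a path, while `fs.glob(path + "/**")` also walks into subdirectories. A rough standalone sketch with a local fsspec filesystem, assuming a hypothetical directory `/tmp/demo_dataset` with files both at the top level and in subfolders (the `detail=False` call mirrors the patch above):

```python
import os

import fsspec

fs = fsspec.filesystem("file")
root = "/tmp/demo_dataset"  # hypothetical directory with nested .jsonl files

# Old behavior: only direct children of `root` are listed.
top_level_only = fs.ls(root, detail=False)

# New behavior when recursive=True: glob with "**" descends into
# subdirectories (note that it can also return directory entries).
everything = fs.glob(os.path.join(root, "**"), detail=False)

print(len(top_level_only), len(everything))
```
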
45 changes: 44 additions & 1 deletion src/datatrove/io/s3.py
@@ -7,7 +7,7 @@
from loguru import logger

from datatrove.io import BaseInputDataFile, BaseOutputDataFolder
from datatrove.io.base import BaseInputDataFolder, BaseOutputDataFile
from datatrove.io.base import BaseInputDataFolder, BaseInputFileListFolder, BaseOutputDataFile
from datatrove.io.utils.s3 import (
s3_download_file,
s3_file_exists,
@@ -177,3 +177,46 @@ def list_files(self, extension: str | list[str] = None, suffix: str = "") -> list[BaseInputDataFile]:

def to_output_folder(self) -> BaseOutputDataFolder:
return S3OutputDataFolder(path=self.path, local_path=self.local_path, cleanup=self.cleanup)


@dataclass
class S3InputFileListFolder(BaseInputFileListFolder):
"""S3 input file list folder
Args:
local_path (str): where to download the files to (if `stream=False`)
stream (bool): stream the files directly from s3, without saving them to disk
cleanup (bool): remove downloaded files from disk after they are closed

"""

local_path: str = None
stream: bool = None
cleanup: bool = True

def __post_init__(self):
if not self.file_list[0].startswith("s3://"):
raise ValueError("S3InputFileListFolder paths must start with s3://")
self._tmpdir = None
if self.stream is None:
self.stream = self.local_path is None
if not self.local_path:
self._tmpdir = tempfile.TemporaryDirectory()
self.local_path = self._tmpdir.name

def _unchecked_get_file(self, file_path: str):
return S3InputDataFile(
path=file_path,
local_path=os.path.join(self.local_path, file_path.replace("s3://", "")),
relative_path=file_path.replace("s3://", ""),
stream=self.stream,
cleanup=self.cleanup,
_lock=self._lock,
)

def file_exists(self, file_path: str) -> bool:
if self.local_path and os.path.exists(os.path.join(self.local_path, file_path.replace("s3://", ""))):
return True
return s3_file_exists(file_path)

def list_files(self) -> list[BaseInputDataFile]:
return [self._unchecked_get_file(file_path) for file_path in self.file_list]
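
A short, hypothetical example of the new class (bucket names are placeholders; with no `local_path` given, `stream` defaults to True and a temporary directory backs any downloads):

```python
from datatrove.io import S3InputFileListFolder

folder = S3InputFileListFolder(
    file_list=[
        "s3://my-bucket/corpus/shard-000.jsonl.gz",
        "s3://my-bucket/corpus/shard-001.jsonl.gz",
    ],
)

# list_files() wraps every entry of file_list into an S3InputDataFile;
# no bucket listing happens here.
for f in folder.list_files():
    print(f.path, f.relative_path)
```
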
14 changes: 6 additions & 8 deletions src/datatrove/io/utils/s3.py
@@ -106,14 +106,12 @@ def s3_get_file_list(cloud_path, match_pattern=None, recursive=True):
prefix = prefixes.popleft()
for resp in paginator.paginate(Bucket=bucket, Prefix=prefix, Delimiter="/"):
if recursive:
prefixes.extend(
[
next_prefix["Prefix"]
for next_prefix in resp.get("CommonPrefixes", [])
if _match_prefix(main_prefix, next_prefix["Prefix"], match_pattern)
]
)
prefixes.extend([next_prefix["Prefix"] for next_prefix in resp.get("CommonPrefixes", [])])
objects.extend(
[os.path.relpath(x["Key"], main_prefix) for x in resp.get("Contents", []) if x["Key"] != prefix]
[
os.path.relpath(x["Key"], main_prefix)
for x in resp.get("Contents", [])
if x["Key"] != prefix and _match_prefix(main_prefix, x["Key"], match_pattern)
]
)
return sorted(objects)
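
With this change, `match_pattern` is checked against the object keys themselves (relative to the main prefix) instead of against the intermediate "directory" prefixes, so recursion is no longer cut short. A rough approximation of what the matching now does, assuming `_match_prefix` is fnmatch-style (its exact implementation is not shown in this diff):

```python
import fnmatch
import os

def match_relative(main_prefix: str, key: str, pattern: str | None) -> bool:
    # Stand-in for _match_prefix: match the key relative to the listing root.
    return pattern is None or fnmatch.fnmatch(os.path.relpath(key, main_prefix), pattern)

keys = ["dumps/2023/part-000.jsonl.gz", "dumps/2023/logs/run.txt"]
print([k for k in keys if match_relative("dumps/", k, "*.jsonl.gz")])
# -> ['dumps/2023/part-000.jsonl.gz']: nested files can now match the pattern
```
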
7 changes: 4 additions & 3 deletions src/datatrove/pipeline/base.py
@@ -16,9 +16,10 @@ def stat_update(self, *labels, value: int = 1, unit: str = None):
self.stats[label].update(value, unit)

def update_doc_stats(self, document: Document):
self.stats["doc_len"] += len(document.content)
if token_count := document.metadata.get("token_count", None):
self.stats["doc_len_tokens"] += token_count
if document is not None:
self.stats["doc_len"] += len(document.content)
if token_count := document.metadata.get("token_count", None):
self.stats["doc_len_tokens"] += token_count

def track_time(self, unit: str = None):
if unit:
1 change: 1 addition & 0 deletions src/datatrove/pipeline/filters/__init__.py
@@ -1,3 +1,4 @@
from .c4_quality_filter import C4ParagraphFilter, C4QualityFilter
from .gopher_quality_filter import GopherQualityFilter
from .gopher_repetition_filter import GopherRepetitionFilter
from .lambda_filter import LambdaFilter
39 changes: 39 additions & 0 deletions src/datatrove/pipeline/filters/c4_quality_filter.py
@@ -1,3 +1,4 @@
import heapq
import re

from nltk import load
@@ -61,3 +62,41 @@ def filter(self, doc: Document) -> bool | tuple[bool, str]:
# sent_dedup's remove_dup_sentences may be adapted for this
lines = [line for line in lines if self.line_filter(line)]
doc.content = " ".join(lines)


class C4ParagraphFilter(BaseFilter):
"""Applies paragraph filtering from mC4

https://github.com/tensorflow/datasets/blob/master/tensorflow_datasets/text/c4_utils.py#L551
"""

name = "⛰ C4 Paragraph"

def __init__(self, exclusion_writer: DiskWriter = None):
super().__init__(exclusion_writer)

self.min_paragraphs = 3
self.min_paragraph_len = 200
self.line_delimiter = "\n"

def paragraph_filter(self, page):
"""Returns False iff a page has too few or too short paragraphs."""
lines = page.split(self.line_delimiter)
# Filter out docs that don't have at least three "paragraphs"
# (lines >= `min_paragraph_len` chars).
if (
len(lines) < self.min_paragraphs
or min(heapq.nlargest(3, [len(l) for l in lines])) < self.min_paragraph_len
):
return False
return True

def filter(self, doc: Document) -> bool | tuple[bool, str]:
"""Args:
doc

Returns:

"""
if not self.paragraph_filter(doc.content):
return False, f"< {self.min_paragraphs} paragraphs"
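
As a quick sanity check of the new heuristic, a sketch that calls paragraph_filter directly on raw text (it assumes C4ParagraphFilter can be constructed without an exclusion writer, as the default argument suggests):

```python
from datatrove.pipeline.filters import C4ParagraphFilter

f = C4ParagraphFilter()

too_short = "one line\nanother line"
ok = "\n".join(["x" * 250, "y" * 250, "z" * 250])  # three paragraphs of >= 200 chars

print(f.paragraph_filter(too_short))  # False: fewer than 3 qualifying paragraphs
print(f.paragraph_filter(ok))         # True
```
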
1 change: 1 addition & 0 deletions src/datatrove/pipeline/readers/__init__.py
@@ -1,3 +1,4 @@
from .base import BaseReader
from .csv import CSVReader
from .jsonl import JsonlReader
from .parquet import ParquetReader
10 changes: 6 additions & 4 deletions src/datatrove/pipeline/readers/base.py
@@ -46,8 +46,8 @@ def get_document_from_dict(self, data: dict, source_file: BaseInputDataFile, id_
if not parsed_data.get("content", None):
if not self._empty_warning:
self._empty_warning = True
logger.warning("Found document without content, skipping.")
return None
logger.warning("Found document without content.")
# return None # This is not a good idea anymore, because it breaks the pipeline
document = Document(**parsed_data)
document.metadata.setdefault("file_path", source_file.path)
if self.default_metadata:
@@ -65,7 +65,8 @@ def read_files_shard(self, shard):
li = 0
with tqdm(total=self.limit if self.limit != -1 else None) if self.progress else nullcontext() as pbar:
for datafile in shard:
self.stat_update("input_files")
self.stats["input_files"] += 1
# self.stat_update("input_files")
logger.info(f"Reading input file {datafile.path}")
di = 0
for di, document in enumerate(self.read_file(datafile)):
@@ -75,7 +76,8 @@
if self.progress:
pbar.update()
li += 1
self.stat_update("documents", value=di, unit="input_file")
self.stats["documents"] += di
# self.stat_update("documents", value=di, unit="input_file")
if self.limit != -1 and li >= self.limit:
break
