Adds writer adapter #83

Merged — merged 1 commit on Feb 3, 2024
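This PR adds an `adapter` argument to `DiskWriter`: a callable that turns a `Document` into the dictionary actually written to disk. The dict construction previously inlined in `JsonlWriter._write` moves into a shared `_default_adapter`, which keeps every non-empty field of the dataclass, and `_write` implementations now receive a plain `dict` instead of a `Document`.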
27 changes: 24 additions & 3 deletions src/datatrove/pipeline/writers/disk_base.py
@@ -1,17 +1,37 @@
+import dataclasses
 from abc import ABC, abstractmethod
 from string import Template
+from typing import Callable
 
 from datatrove.data import Document, DocumentsPipeline
 from datatrove.io import DataFolderLike, get_datafolder
 from datatrove.pipeline.base import PipelineStep
 from datatrove.utils.typeshelper import StatHints
 
 
+def _default_adapter(document: Document) -> dict:
+    """
+    You can create your own adapter that returns a dictionary in your preferred format
+    Args:
+        document: document to format
+
+    Returns: a dictionary to write to disk
+
+    """
+    return {key: val for key, val in dataclasses.asdict(document).items() if val}
+
+
 class DiskWriter(PipelineStep, ABC):
     default_output_filename: str = None
     type = "💽 - WRITER"
 
-    def __init__(self, output_folder: DataFolderLike, output_filename: str = None, compression: str | None = "infer"):
+    def __init__(
+        self,
+        output_folder: DataFolderLike,
+        output_filename: str = None,
+        compression: str | None = "infer",
+        adapter: Callable = None,
+    ):
         super().__init__()
         self.compression = compression
         self.output_folder = get_datafolder(output_folder)
@@ -20,6 +40,7 @@ def __init__(self, output_folder: DataFolderLike, output_filename: str = None, compression: str | None = "infer"):
             output_filename += ".gz"
         self.output_filename = Template(output_filename)
         self.output_mg = self.output_folder.get_output_file_manager(mode="wt", compression=compression)
+        self.adapter = adapter if adapter else _default_adapter
 
     def __enter__(self):
         return self
@@ -36,12 +57,12 @@ def _get_output_filename(self, document: Document, rank: int | str = 0, **kwargs):
         )
 
     @abstractmethod
-    def _write(self, document: Document, file_handler):
+    def _write(self, document: dict, file_handler):
         raise NotImplementedError
 
     def write(self, document: Document, rank: int = 0, **kwargs):
         output_filename = self._get_output_filename(document, rank, **kwargs)
-        self._write(document, self.output_mg.get_file(output_filename))
+        self._write(self.adapter(document), self.output_mg.get_file(output_filename))
         self.stat_update(self._get_output_filename(document, "XXXXX", **kwargs))
         self.stat_update(StatHints.total)
         self.update_doc_stats(document)
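For reference, a minimal sketch of what the shared `_default_adapter` produces, assuming the `Document` dataclass from `datatrove.data` with `text`, `id`, `media`, and `metadata` fields (the sample values below are made up):

```python
import dataclasses

from datatrove.data import Document

doc = Document(text="hello world", id="doc-0", metadata={"source": "example"})

# Same expression as _default_adapter: serialize the dataclass and drop
# falsy fields (empty media list, empty metadata dict, empty strings).
record = {key: val for key, val in dataclasses.asdict(doc).items() if val}
print(record)
# {'text': 'hello world', 'id': 'doc-0', 'metadata': {'source': 'example'}}
```

Because of the falsy-value filter, fields left at their empty defaults never reach disk — the same dict comprehension `JsonlWriter._write` used before, so the default output is unchanged.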
21 changes: 11 additions & 10 deletions src/datatrove/pipeline/writers/jsonl.py
@@ -1,8 +1,6 @@
-import dataclasses
 import json
-from typing import IO
+from typing import IO, Callable
 
-from datatrove.data import Document
 from datatrove.io import DataFolderLike
 from datatrove.pipeline.writers.disk_base import DiskWriter
 
@@ -11,11 +9,14 @@ class JsonlWriter(DiskWriter):
     default_output_filename: str = "${rank}.jsonl"
     name = "🐿 Jsonl"
 
-    def __init__(self, output_folder: DataFolderLike, output_filename: str = None, compression: str | None = "gzip"):
-        super().__init__(output_folder, output_filename=output_filename, compression=compression)
+    def __init__(
+        self,
+        output_folder: DataFolderLike,
+        output_filename: str = None,
+        compression: str | None = "gzip",
+        adapter: Callable = None,
+    ):
+        super().__init__(output_folder, output_filename=output_filename, compression=compression, adapter=adapter)
 
-    def _write(self, document: Document, file: IO):
-        file.write(
-            json.dumps({key: val for key, val in dataclasses.asdict(document).items() if val}, ensure_ascii=False)
-            + "\n"
-        )
+    def _write(self, document: dict, file: IO):
+        file.write(json.dumps(document, ensure_ascii=False) + "\n")
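With the adapter threaded through, callers can override the on-disk format without subclassing. A hedged usage sketch — the adapter, the `"output/"` folder path, and the document values are illustrative assumptions, not part of this PR:

```python
from datatrove.data import Document
from datatrove.pipeline.writers.jsonl import JsonlWriter


def minimal_adapter(document: Document) -> dict:
    # Hypothetical adapter: keep only the text and rename `id` to `doc_id`.
    return {"text": document.text, "doc_id": document.id}


# "output/" is a placeholder DataFolderLike path.
with JsonlWriter("output/", adapter=minimal_adapter) as writer:
    writer.write(Document(text="hello world", id="doc-0"), rank=0)
```

Passing no `adapter` falls back to `_default_adapter`, so existing pipelines keep their current output.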