From eca3573abebd8144ed7d6b49de593316eff8cd0f Mon Sep 17 00:00:00 2001
From: guipenedo
Date: Fri, 2 Feb 2024 17:43:50 +0100
Subject: [PATCH] add writer adapter

---
 src/datatrove/pipeline/writers/disk_base.py | 27 ++++++++++++++++++++++++---
 src/datatrove/pipeline/writers/jsonl.py     | 21 +++++++++++----------
 2 files changed, 35 insertions(+), 13 deletions(-)

diff --git a/src/datatrove/pipeline/writers/disk_base.py b/src/datatrove/pipeline/writers/disk_base.py
index 60809d92..09f583d8 100644
--- a/src/datatrove/pipeline/writers/disk_base.py
+++ b/src/datatrove/pipeline/writers/disk_base.py
@@ -1,5 +1,7 @@
+import dataclasses
 from abc import ABC, abstractmethod
 from string import Template
+from typing import Callable
 
 from datatrove.data import Document, DocumentsPipeline
 from datatrove.io import DataFolderLike, get_datafolder
@@ -7,11 +9,29 @@
 from datatrove.utils.typeshelper import StatHints
 
 
+def _default_adapter(document: Document) -> dict:
+    """
+    Default adapter: converts a Document into a dictionary, dropping empty fields. You can provide your own adapter that returns a dictionary in your preferred format.
+    Args:
+        document: document to format
+
+    Returns: a dictionary to write to disk
+
+    """
+    return {key: val for key, val in dataclasses.asdict(document).items() if val}
+
+
 class DiskWriter(PipelineStep, ABC):
     default_output_filename: str = None
     type = "💽 - WRITER"
 
-    def __init__(self, output_folder: DataFolderLike, output_filename: str = None, compression: str | None = "infer"):
+    def __init__(
+        self,
+        output_folder: DataFolderLike,
+        output_filename: str = None,
+        compression: str | None = "infer",
+        adapter: Callable = None,
+    ):
         super().__init__()
         self.compression = compression
         self.output_folder = get_datafolder(output_folder)
@@ -20,6 +40,7 @@ def __init__(self, output_folder: DataFolderLike, output_filename: str = None, c
             output_filename += ".gz"
         self.output_filename = Template(output_filename)
         self.output_mg = self.output_folder.get_output_file_manager(mode="wt", compression=compression)
+        self.adapter = adapter if adapter else _default_adapter
 
     def __enter__(self):
         return self
@@ -36,12 +57,12 @@ def _get_output_filename(self, document: Document, rank: int | str = 0, **kwargs
         )
 
     @abstractmethod
-    def _write(self, document: Document, file_handler):
+    def _write(self, document: dict, file_handler):
         raise NotImplementedError
 
     def write(self, document: Document, rank: int = 0, **kwargs):
         output_filename = self._get_output_filename(document, rank, **kwargs)
-        self._write(document, self.output_mg.get_file(output_filename))
+        self._write(self.adapter(document), self.output_mg.get_file(output_filename))
         self.stat_update(self._get_output_filename(document, "XXXXX", **kwargs))
         self.stat_update(StatHints.total)
         self.update_doc_stats(document)
diff --git a/src/datatrove/pipeline/writers/jsonl.py b/src/datatrove/pipeline/writers/jsonl.py
index bd5da835..10fff3bf 100644
--- a/src/datatrove/pipeline/writers/jsonl.py
+++ b/src/datatrove/pipeline/writers/jsonl.py
@@ -1,8 +1,6 @@
-import dataclasses
 import json
-from typing import IO
+from typing import IO, Callable
 
-from datatrove.data import Document
 from datatrove.io import DataFolderLike
 from datatrove.pipeline.writers.disk_base import DiskWriter
 
@@ -11,11 +9,14 @@ class JsonlWriter(DiskWriter):
     default_output_filename: str = "${rank}.jsonl"
     name = "🐿 Jsonl"
 
-    def __init__(self, output_folder: DataFolderLike, output_filename: str = None, compression: str | None = "gzip"):
-        super().__init__(output_folder, output_filename=output_filename, compression=compression)
+    def __init__(
+        self,
+        output_folder: DataFolderLike,
+        output_filename: str = None,
+        compression: str | None = "gzip",
+        adapter: Callable = None,
+    ):
+        super().__init__(output_folder, output_filename=output_filename, compression=compression, adapter=adapter)
 
-    def _write(self, document: Document, file: IO):
-        file.write(
-            json.dumps({key: val for key, val in dataclasses.asdict(document).items() if val}, ensure_ascii=False)
-            + "\n"
-        )
+    def _write(self, document: dict, file: IO):
+        file.write(json.dumps(document, ensure_ascii=False) + "\n")
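
Usage note: with this patch, any DiskWriter subclass accepts an optional
adapter callable that maps a Document to the dictionary that actually gets
written; when none is passed, _default_adapter keeps every non-empty field
of the dataclass, matching JsonlWriter's previous output. A minimal sketch
of a custom adapter (the output keys "content" and "doc_id" and the output
path are illustrative, not part of this patch):

    from datatrove.data import Document
    from datatrove.pipeline.writers.jsonl import JsonlWriter

    def my_adapter(document: Document) -> dict:
        # write only the fields we care about, under custom key names
        return {"content": document.text, "doc_id": document.id}

    # pass the adapter when constructing the writer (illustrative path)
    writer = JsonlWriter("output_folder", adapter=my_adapter)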