Commit 76a243d
add writer adapter (#83)
guipenedo authored Feb 3, 2024
1 parent c4291ae commit 76a243d
Showing 2 changed files with 35 additions and 13 deletions.
27 changes: 24 additions & 3 deletions src/datatrove/pipeline/writers/disk_base.py
@@ -1,17 +1,37 @@
+import dataclasses
 from abc import ABC, abstractmethod
 from string import Template
+from typing import Callable
 
 from datatrove.data import Document, DocumentsPipeline
 from datatrove.io import DataFolderLike, get_datafolder
 from datatrove.pipeline.base import PipelineStep
 from datatrove.utils.typeshelper import StatHints
 
 
+def _default_adapter(document: Document) -> dict:
+    """
+    You can create your own adapter that returns a dictionary in your preferred format
+    Args:
+        document: document to format
+
+    Returns: a dictionary to write to disk
+
+    """
+    return {key: val for key, val in dataclasses.asdict(document).items() if val}
+
+
 class DiskWriter(PipelineStep, ABC):
     default_output_filename: str = None
     type = "💽 - WRITER"
 
-    def __init__(self, output_folder: DataFolderLike, output_filename: str = None, compression: str | None = "infer"):
+    def __init__(
+        self,
+        output_folder: DataFolderLike,
+        output_filename: str = None,
+        compression: str | None = "infer",
+        adapter: Callable = None,
+    ):
         super().__init__()
         self.compression = compression
         self.output_folder = get_datafolder(output_folder)
@@ -20,6 +40,7 @@ def __init__(self, output_folder: DataFolderLike, output_filename: str = None, c
             output_filename += ".gz"
         self.output_filename = Template(output_filename)
         self.output_mg = self.output_folder.get_output_file_manager(mode="wt", compression=compression)
+        self.adapter = adapter if adapter else _default_adapter
 
     def __enter__(self):
         return self
@@ -36,12 +57,12 @@ def _get_output_filename(self, document: Document, rank: int | str = 0, **kwargs
         )
 
     @abstractmethod
-    def _write(self, document: Document, file_handler):
+    def _write(self, document: dict, file_handler):
        raise NotImplementedError
 
     def write(self, document: Document, rank: int = 0, **kwargs):
         output_filename = self._get_output_filename(document, rank, **kwargs)
-        self._write(document, self.output_mg.get_file(output_filename))
+        self._write(self.adapter(document), self.output_mg.get_file(output_filename))
         self.stat_update(self._get_output_filename(document, "XXXXX", **kwargs))
         self.stat_update(StatHints.total)
         self.update_doc_stats(document)
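With this change, the base writer funnels every Document through self.adapter before calling _write, so subclasses receive a plain dict. A custom adapter is just a callable from Document to dict; a minimal sketch follows (the content_key_adapter name and the "text" → "content" renaming are illustrative, not part of this commit):

import dataclasses

from datatrove.data import Document


def content_key_adapter(document: Document) -> dict:
    # like _default_adapter, drop empty/falsy fields...
    data = {key: val for key, val in dataclasses.asdict(document).items() if val}
    # ...but write the document text under a "content" key instead of "text"
    if "text" in data:
        data["content"] = data.pop("text")
    return data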
21 changes: 11 additions & 10 deletions src/datatrove/pipeline/writers/jsonl.py
@@ -1,8 +1,6 @@
-import dataclasses
 import json
-from typing import IO
+from typing import IO, Callable
 
-from datatrove.data import Document
 from datatrove.io import DataFolderLike
 from datatrove.pipeline.writers.disk_base import DiskWriter
 
@@ -11,11 +9,14 @@ class JsonlWriter(DiskWriter):
     default_output_filename: str = "${rank}.jsonl"
     name = "🐿 Jsonl"
 
-    def __init__(self, output_folder: DataFolderLike, output_filename: str = None, compression: str | None = "gzip"):
-        super().__init__(output_folder, output_filename=output_filename, compression=compression)
+    def __init__(
+        self,
+        output_folder: DataFolderLike,
+        output_filename: str = None,
+        compression: str | None = "gzip",
+        adapter: Callable = None,
+    ):
+        super().__init__(output_folder, output_filename=output_filename, compression=compression, adapter=adapter)
 
-    def _write(self, document: Document, file: IO):
-        file.write(
-            json.dumps({key: val for key, val in dataclasses.asdict(document).items() if val}, ensure_ascii=False)
-            + "\n"
-        )
+    def _write(self, document: dict, file: IO):
+        file.write(json.dumps(document, ensure_ascii=False) + "\n")
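Putting the two files together, JsonlWriter now simply serializes whatever dict the adapter returns, one JSON object per line. A usage sketch (the output path is a placeholder, and content_key_adapter is the hypothetical helper from above):

from datatrove.pipeline.writers.jsonl import JsonlWriter

# each line of the output .jsonl.gz is the dict produced by the custom adapter
writer = JsonlWriter("my_output_folder", adapter=content_key_adapter)

# omitting adapter falls back to _default_adapter, preserving the previous behavior
default_writer = JsonlWriter("my_output_folder")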
