From 3d14b0d2b501fff9087138a0d0c7f7e12f4daaa7 Mon Sep 17 00:00:00 2001 From: shuoyangd Date: Wed, 27 Nov 2024 17:46:29 -0500 Subject: [PATCH] Add support for parallel data curation (#193) * add data interface to read simple bitext Signed-off-by: Shuoyang Ding * adding ParallelScoreFilter Signed-off-by: Shuoyang Ding * add test for ParallelScoreFilter, small style change for ParallelDataset test, fix a few data and import bugs Signed-off-by: Shuoyang Ding * allow ParallelScoreFilter to take different filters for source and target Signed-off-by: Shuoyang Ding * add JointScoreFilter and LengthRatioFilter Signed-off-by: Shuoyang Ding * [WIP] add heuristic filter w/o test Signed-off-by: Shuoyang Ding * merge with main Signed-off-by: Shuoyang Ding * add test for histogram filter, fix a few bugs Signed-off-by: Shuoyang Ding * length ratio, joint score filter testing Signed-off-by: Shuoyang Ding * fix typing in joint test Signed-off-by: Shuoyang Ding * add a fake comet qe filter as an initial step Signed-off-by: Shuoyang Ding * [WIP] adding bitext cleaning tutorial Signed-off-by: Shuoyang Ding * [WIP] fixing example Signed-off-by: Shuoyang Ding * fix slow histogram filter, fix faulty bitext loading Signed-off-by: Shuoyang Ding * tutorial running Signed-off-by: Shuoyang Ding * [WIP] documentation of bitext tutorial Signed-off-by: Shuoyang Ding * add tested version of comet-qe filter Signed-off-by: Shuoyang Ding * fix ParallelDataset bug where single file name is not accepted, and dataset is sometimes turned into its parent class by mistake, add write to simple bitext functionality, update bitext tutorial Signed-off-by: Shuoyang Ding * add docstring to explain simple bitext format, fix a bug where file extensions are removed twice before writing Signed-off-by: Shuoyang Ding * remove print line for debug Signed-off-by: Shuoyang Ding * add comet filter to tutorial Signed-off-by: Shuoyang Ding * refactor COMET QE filter to decouple model from filter, make sure JointScoreFilter can take more than one fields for source and target Signed-off-by: Shuoyang Ding * use refactored qe filter Signed-off-by: Shuoyang Ding * wrap_qe_input should be a static method Signed-off-by: Shuoyang Ding * use conditional import for comet, formatting changes Signed-off-by: Shuoyang Ding * [WIP] add cometoid Signed-off-by: Shuoyang Ding * [WIP] attempt to resolve device conflict but is failing Signed-off-by: Shuoyang Ding * [WIP] playing with cometoid arguments Signed-off-by: Shuoyang Ding * [WIP] -d 0 doesn't look necessary Signed-off-by: Shuoyang Ding * tested arguments for Cometoid Signed-off-by: Shuoyang Ding * use proper safe import, make sure test doesn't crash sans comet/pymarian Signed-off-by: Shuoyang Ding * falling back to comet for tutorial since that's easier to set up, uppdate README Signed-off-by: Shuoyang Ding * give credit to original fairseq implementation of histogram filtering, run black formatter Signed-off-by: Shuoyang Ding * fix pre-commit complaint Signed-off-by: Shuoyang Ding * fix small bug Signed-off-by: Shuoyang Ding * fix another occurrence of the same bug Signed-off-by: Shuoyang Ding * introduce shard limit to a single PyMarian API call to avoid memory leakage Signed-off-by: Shuoyang Ding * repartition after reading simple bitext data Signed-off-by: Shuoyang Ding * -d 0 is actually needed for pymarian Signed-off-by: Shuoyang Ding * remove duplicate LengthRatioFilter definition Signed-off-by: Shuoyang Ding * refactor repeated code segment in file writing, change classifier to accomodate custom field 
names, pause doc repartition since it causes problems Signed-off-by: Shuoyang Ding * [WIP] addressed comments in #193 apart from resolving .iloc pattern, test currently failing Signed-off-by: Shuoyang Ding * refactor to resolve .loc pattern, test passing Signed-off-by: Shuoyang Ding * add missing file Signed-off-by: Shuoyang Ding * revert changes in setup.py Signed-off-by: Shuoyang Ding * fix a small bug in parallel dataset, explain why repartition is disabled, fix tutorial Signed-off-by: Shuoyang Ding * add api guide, small change on bitext/parallel score filter docstring Signed-off-by: Shuoyang Ding * fix read_simple_bitext test issues Signed-off-by: Shuoyang Ding * reinstate dependencies lost during merging Signed-off-by: Shuoyang Ding * re-enable multiple partitions for simple bitext, add parallel write Signed-off-by: Shuoyang Ding * take care of the case where filename is not supplied in dataframe, make logic clearer Signed-off-by: Shuoyang Ding * address other minor comments in the PR, fix segment order scrambling Signed-off-by: Shuoyang Ding * fix test errors, add bitext dependencies Signed-off-by: Shuoyang Ding * add back more missing imports Signed-off-by: Shuoyang Ding * add bitext to [all] in .toml, add platformdirs as dependency Signed-off-by: Shuoyang Ding * merge upstream, remove old bitext requirement list Signed-off-by: Shuoyang Ding * delete requirement file again Signed-off-by: Shuoyang Ding --------- Signed-off-by: Shuoyang Ding Co-authored-by: nverma1 --- docs/user-guide/api/datasets.rst | 4 +- docs/user-guide/api/filters.rst | 20 +++ nemo_curator/datasets/__init__.py | 3 +- nemo_curator/datasets/parallel_dataset.py | 167 ++++++++++++++++++ nemo_curator/filters/__init__.py | 13 +- nemo_curator/filters/bitext_filter.py | 145 ++++++++++++++++ nemo_curator/filters/classifier_filter.py | 153 ++++++++++++++++- nemo_curator/filters/doc_filter.py | 13 +- nemo_curator/filters/heuristic_filter.py | 144 +++++++++++++++- nemo_curator/filters/models/__init__.py | 0 nemo_curator/filters/models/qe_models.py | 167 ++++++++++++++++++ nemo_curator/modules/__init__.py | 3 +- nemo_curator/modules/filter.py | 97 +++++++++-- nemo_curator/utils/distributed_utils.py | 113 ++++++++++++- nemo_curator/utils/file_utils.py | 5 + pyproject.toml | 9 + tests/bitext_data/toy.de | 4 + tests/bitext_data/toy.en | 4 + tests/test_filters.py | 195 +++++++++++++++++++++- tests/test_read_simple_bitext.py | 68 ++++++++ tutorials/bitext_cleaning/README.md | 21 +++ tutorials/bitext_cleaning/docbuilder.py | 39 +++++ tutorials/bitext_cleaning/main.py | 133 +++++++++++++++ 23 files changed, 1490 insertions(+), 30 deletions(-) create mode 100644 nemo_curator/datasets/parallel_dataset.py create mode 100644 nemo_curator/filters/bitext_filter.py create mode 100644 nemo_curator/filters/models/__init__.py create mode 100644 nemo_curator/filters/models/qe_models.py create mode 100644 tests/bitext_data/toy.de create mode 100644 tests/bitext_data/toy.en create mode 100644 tests/test_read_simple_bitext.py create mode 100644 tutorials/bitext_cleaning/README.md create mode 100644 tutorials/bitext_cleaning/docbuilder.py create mode 100644 tutorials/bitext_cleaning/main.py diff --git a/docs/user-guide/api/datasets.rst b/docs/user-guide/api/datasets.rst index c8dba791..53b64b65 100644 --- a/docs/user-guide/api/datasets.rst +++ b/docs/user-guide/api/datasets.rst @@ -9,10 +9,12 @@ DocumentDataset .. autoclass:: nemo_curator.datasets.DocumentDataset :members: +.. 
autoclass:: nemo_curator.datasets.ParallelDataset + :members: ------------------------------- ImageTextPairDataset ------------------------------- .. autoclass:: nemo_curator.datasets.ImageTextPairDataset - :members: \ No newline at end of file + :members: diff --git a/docs/user-guide/api/filters.rst b/docs/user-guide/api/filters.rst index b705a184..55b78ed7 100644 --- a/docs/user-guide/api/filters.rst +++ b/docs/user-guide/api/filters.rst @@ -10,6 +10,10 @@ Base Class :members: :member-order: bysource +.. autoclass:: nemo_curator.filters.BitextFilter + :members: + :member-order: bysource + .. autofunction:: nemo_curator.filters.import_filter ------------------------------ @@ -40,6 +44,14 @@ FastText Filters :members: :member-order: bysource +------------------------------ +Quality Estimation Filters +------------------------------ + +.. autoclass:: nemo_curator.filters.QualityEstimationFilter + :members: + :member-order: bysource + ------------------------------ Heuristic Filters ------------------------------ @@ -132,6 +144,14 @@ Heuristic Filters :members: :member-order: bysource +.. autoclass:: nemo_curator.filters.HistogramFilter + :members: + :member-order: bysource + +.. autoclass:: nemo_curator.filters.LengthRatioFilter + :members: + :member-order: bysource + ------------------------------ Code Filters ------------------------------ diff --git a/nemo_curator/datasets/__init__.py b/nemo_curator/datasets/__init__.py index 16f4343a..38954502 100644 --- a/nemo_curator/datasets/__init__.py +++ b/nemo_curator/datasets/__init__.py @@ -15,9 +15,10 @@ from nemo_curator.utils.import_utils import image_only_import_from from .doc_dataset import DocumentDataset +from .parallel_dataset import ParallelDataset ImageTextPairDataset = image_only_import_from( "nemo_curator.datasets.image_text_pair_dataset", "ImageTextPairDataset" ) -__all__ = ["DocumentDataset", "ImageTextPairDataset"] +__all__ = ["DocumentDataset", "ImageTextPairDataset", "ParallelDataset"] diff --git a/nemo_curator/datasets/parallel_dataset.py b/nemo_curator/datasets/parallel_dataset.py new file mode 100644 index 00000000..280a6b0c --- /dev/null +++ b/nemo_curator/datasets/parallel_dataset.py @@ -0,0 +1,167 @@ +import csv +from typing import List, Optional, Tuple, Union + +import dask.dataframe as dd +import pandas as pd + +from nemo_curator.datasets.doc_dataset import DocumentDataset +from nemo_curator.utils.distributed_utils import write_to_disk +from nemo_curator.utils.file_utils import remove_path_extension +from nemo_curator.utils.import_utils import gpu_only_import + +cudf = gpu_only_import("cudf") + + +class ParallelDataset(DocumentDataset): + """ + An extension of the standard `DocumentDataset` with a special method that loads simple bitext. + + For data with more complicated metadata, please convert your data into jsonl/parquet/pickle format + and use interfaces defined in `DocumentDataset`. + """ + + def persist(self): + return ParallelDataset(self.df.persist()) + + @classmethod + def read_simple_bitext( + cls, + src_input_files: Union[str, List[str]], + tgt_input_files: Union[str, List[str]], + src_lang: str, + tgt_lang: str, + backend: str = "pandas", + add_filename: bool = False, + npartitions: int = 16, + ): + """See `read_single_simple_bitext_file_pair` docstring for what "simple_bitext" means and usage of other parameters. 
+ + Args: + src_input_files (Union[str, List[str]]): one or several input files, in source language + tgt_input_files (Union[str, List[str]]): one or several input files, in target language + + Raises: + TypeError: If types of `src_input_files` and `tgt_input_files` doesn't agree. + + Returns: + ParallelDataset: A `ParallelDataset` object with `self.df` holding the ingested simple bitext. + """ + + if isinstance(src_input_files, str) and isinstance(tgt_input_files, str): + src_input_files = [src_input_files] + tgt_input_files = [tgt_input_files] + elif not isinstance(src_input_files, list) or not isinstance( + tgt_input_files, list + ): + raise TypeError("Both file inputs must be strings or lists.") + + # use default doc id for now + # but in the future it might be useful to allow customizing doc id by passing a prefix + df_files = [] + # We do not use `dd.from_map` because an individual file could be pretty large, + # hence, it's not appropriate to partition based on individual files. + # What we do is that we concatenate all the individual files and perform repartition. + for src_input_file, tgt_input_file in zip(src_input_files, tgt_input_files): + df_file = ParallelDataset.read_single_simple_bitext_file_pair( + (src_input_file, tgt_input_file), + src_lang=src_lang, + tgt_lang=tgt_lang, + backend=backend, + add_filename=add_filename, + ) + df_files.append(df_file) + + if backend == "cudf": + df = cudf + else: + df = pd + + data = dd.from_pandas(df.concat(df_files), npartitions=npartitions) + return cls(data) + + def to_bitext( + self, + output_file_dir, + write_to_filename=False, + ): + """See `nemo_curator.utils.distributed_utils.write_to_disk` docstring for parameter usage.""" + write_to_disk( + df=self.df, + output_file_dir=output_file_dir, + write_to_filename=write_to_filename, + output_type="bitext", + ) + + @staticmethod + def read_single_simple_bitext_file_pair( + input_file_pair: Tuple[str], + src_lang: str, + tgt_lang: str, + doc_id: str = None, + backend: str = "cudf", + add_filename: bool = False, + ) -> Union[dd.DataFrame, "dask_cudf.DataFrame"]: + """This function reads a pair of "simple bitext" files into a pandas DataFrame. + A simple bitext is a commonly data format in machine translation. + It consists of two plain text files with the same number of lines, each line pair being translations of each other. For example: + + data.de: + + ``` + Wir besitzen keine Reisetaschen aus Leder. + Die Firma produziert Computer für den deutschen Markt. + ... + ``` + + data.en: + + ``` + We don't own duffel bags made of leather. + The company produces computers for the German market. + ... + ``` + + For simplicity, we also assume that the names of the two text files have the same prefix, except for different language code at the end as file extensions. + + Args: + input_file_pair (Tuple[str]): A pair of file paths pointing to the input files + src_lang (str): Source language, in ISO-639-1 (two character) format (e.g. 'en') + tgt_lang (str): Target language, in ISO-639-1 (two character) format (e.g. 'en') + doc_id (str, optional): A string document id to assign to every segment in the file. Defaults to None. + backend (str, optional): Backend of the data frame. Defaults to "cudf". + add_filename (bool, optional): Add filename as an extra field to every segment in the file. Defaults to False. 
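# A minimal sketch of loading a simple bitext pair with the reader above.
# The corpus paths are hypothetical; per the convention described in the
# docstring above, the two files must share a common prefix and differ only
# in the language-code extension.
from nemo_curator.datasets import ParallelDataset

train_ds = ParallelDataset.read_simple_bitext(
    src_input_files="corpus/train.en",  # hypothetical source-side file
    tgt_input_files="corpus/train.de",  # hypothetical target-side file
    src_lang="en",
    tgt_lang="de",
    backend="pandas",
    add_filename=True,
    npartitions=4,
)
# Each resulting row carries "src", "tgt", "doc_id", "src_lang", "tgt_lang",
# and (because add_filename=True) a "filename" column.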
+ + Returns: + Union[dd.DataFrame, dask_cudf.DataFrame] + """ + src_input_file, tgt_input_file = input_file_pair + assert remove_path_extension(src_input_file) == remove_path_extension( + tgt_input_file + ), f"Assuming source and target filenames would have common prefix before language code, but got {src_input_file} and {tgt_input_file}." + + if not doc_id: + doc_id = "▁".join([src_input_file, tgt_input_file]) + + if backend == "cudf": + df = cudf + else: + df = pd + + df_src = df.read_csv( + src_input_file, names=["src"], sep="\t", quoting=csv.QUOTE_NONE + ) + df_tgt = df.read_csv( + tgt_input_file, names=["tgt"], sep="\t", quoting=csv.QUOTE_NONE + ) + assert len(df_src) == len( + df_tgt + ), f"We assume the source and target file would have the same number of lines, but got {len(df_src)} and {len(df_tgt)}." + df_combined = df.concat([df_src, df_tgt], axis=1) + df_combined["doc_id"] = doc_id + df_combined["src_lang"] = src_lang + df_combined["tgt_lang"] = tgt_lang + + if add_filename: + df_combined["filename"] = remove_path_extension(src_input_file) + + return df_combined diff --git a/nemo_curator/filters/__init__.py b/nemo_curator/filters/__init__.py index 5ca7a2a2..9905c837 100644 --- a/nemo_curator/filters/__init__.py +++ b/nemo_curator/filters/__init__.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .classifier_filter import FastTextLangId, FastTextQualityFilter +from .bitext_filter import BitextFilter +from .classifier_filter import ( + FastTextLangId, + FastTextQualityFilter, + QualityEstimationFilter, +) from .code import ( AlphaFilter, GeneralCommentToCodeFilter, @@ -29,6 +34,8 @@ BulletsFilter, CommonEnglishWordsFilter, EllipsisFilter, + HistogramFilter, + LengthRatioFilter, LongWordFilter, MeanWordLengthFilter, NonAlphaNumericFilter, @@ -51,6 +58,7 @@ from .synthetic import AnswerabilityFilter, EasinessFilter __all__ = [ + "BitextFilter", "DocumentFilter", "import_filter", "FastTextLangId", @@ -85,6 +93,9 @@ "AlphaFilter", "HTMLBoilerplateFilter", "PerExtensionFilter", + "LengthRatioFilter", + "HistogramFilter", + "QualityEstimationFilter", "AnswerabilityFilter", "EasinessFilter", ] diff --git a/nemo_curator/filters/bitext_filter.py b/nemo_curator/filters/bitext_filter.py new file mode 100644 index 00000000..f85e6840 --- /dev/null +++ b/nemo_curator/filters/bitext_filter.py @@ -0,0 +1,145 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from typing import Dict, List, Optional, Union + +from dask.typing import no_default + +from nemo_curator.datasets.parallel_dataset import ParallelDataset +from nemo_curator.utils.module_utils import is_batched + + +class BitextFilter(ABC): + """A base class for bitext filter objects (such as length ratio, QE filter) on bitext. + Different from `ParallelScoreFilter`, these filters require looking at both source AND target side of the bitext to compute a score. 
+ + This is roughly equivalent to a `ScoreFilter` wrapping over a `DocumentFilter` object. + But aside from operating on `ParallelDataset` instead of `DocumentDataset`, it comes with some other differences: + + - It discarded the ScoreFilter/DocumentFilter hierarchy. So filter classes can directly be used instead of being wrapped by ScoreFilter. + - Unlike an DocumentFilter object, it allows passing extra metadata information into the scoring function. + """ + + def __init__( + self, + src_field: str = "src", + tgt_field: str = "tgt", + metadata_fields: Union[List[str], str] = [], + metadata_field_name_mapping: Dict[str, str] = {}, + score_field: Optional[str] = None, + score_type: Union[type, str] = None, + invert=False, + ): + """Args: + src_field (str, optional): The field the source documents will be read from. Defaults to "src". + tgt_field (str, optional): The field the target documents will be read from. Defaults to "tgt". + metadata_fields (Union[List[str], str], optional): Name of the metadata fields in case fields other than source and target documents need to be accessed. Defaults to []. + metadata_field_name_mapping (Dict[str, str], optional): Mapping of field names in the data to argument names in `_score_bitext` function, in case they are different. + For example, if a field is called "src" in the data but should be passed to an argument called "source" in `_score_bitext` function, + you should add an entry `{"src": "source"}`. Identity map is assumed if a mapping is not specified for a field name. Default to {}. + score_field (Optional[str], optional): The field to which the scores will be written. If None, scores will be immediately discarded after use. Defaults to None. + score_type (Union[type, str], optional): The datatype of the score that will be made for each document. Defaults to None. + invert (bool, optional): If True, will keep all documents that are normally discarded. Defaults to False. + + Raises: + ValueError: If length of source and target fields are different. + """ + + self.src_field = src_field + self.tgt_field = tgt_field + self.metadata_fields = metadata_fields + self.metadata_field_name_mapping = metadata_field_name_mapping + self.score_field = score_field + self.score_type = score_type + self.invert = invert + + def __call__( + self, + dataset: ParallelDataset, + ) -> ParallelDataset: + """Scores and filters all records in the dataset + + Args: + dataset (ParallelDataset): The dataset to apply the module to. 
+ + Returns: + ParallelDataset: A dataset with the score and filter applied + """ + # Set the metadata for the function calls if provided + if self.score_type: + meta = (None, self.score_type) + else: + meta = no_default + + # support multiple fields if supplied + fields = [] + fields.append(self.src_field) + fields.append(self.tgt_field) + fields.extend(self.metadata_fields) + + if is_batched(self.score_bitext): + scores = dataset.df[fields].map_partitions( + self._score_bitext_wrapper, + metadata_field_name_mapping=self.metadata_field_name_mapping, + meta=meta, + ) + else: + scores = dataset.df[fields].apply( + self._score_bitext_wrapper, + metadata_field_name_mapping=self.metadata_field_name_mapping, + axis=1, + meta=meta, + ) + + if self.score_field is not None: + dataset.df[self.score_field] = scores + + if is_batched(self.keep_bitext): + bool_mask = scores.map_partitions(self.keep_bitext, meta=(None, bool)) + else: + bool_mask = scores.apply(self.keep_bitext, meta=(None, bool)) + if self.invert: + bool_mask = ~bool_mask + + return ParallelDataset(dataset.df[bool_mask]) + + def _score_bitext_wrapper( + self, + df, + metadata_field_name_mapping: Dict[str, str] = {}, + ): + """In the batch mode, pass fields in a data frame to arguments of a function, according to a name mapping. + + Args: + df (DataFrame): data frame to perform the mapping on. + metadata_field_name_mapping (Dict[str, str]): see `__call__` function for details. + """ + kwargs = {} + kwargs["src"] = df[self.src_field] + kwargs["tgt"] = df[self.tgt_field] + for field_name in self.metadata_fields: + arg_name = metadata_field_name_mapping.get(field_name, field_name) + kwargs[arg_name] = df[field_name] + + return self.score_bitext(**kwargs) + + @abstractmethod + def score_bitext(self, src, tgt, **kwargs): + """Scoring function for the bitext.""" + pass + + @abstractmethod + def keep_bitext(self, **kwargs): + pass diff --git a/nemo_curator/filters/classifier_filter.py b/nemo_curator/filters/classifier_filter.py index 741df964..d3414077 100644 --- a/nemo_curator/filters/classifier_filter.py +++ b/nemo_curator/filters/classifier_filter.py @@ -12,12 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -import dask +from typing import List + import fasttext import numpy as np import pandas as pd +from nemo_curator.filters.bitext_filter import BitextFilter from nemo_curator.filters.doc_filter import DocumentFilter +from nemo_curator.filters.models.qe_models import COMETQEModel, PyMarianQEModel from nemo_curator.utils.decorators import batched from nemo_curator.utils.distributed_utils import NoWorkerError, load_object_on_worker @@ -99,3 +102,151 @@ def keep_document(self, score): def _load_model(self): return fasttext.load_model(self._model_path) + + +class QualityEstimationFilter(BitextFilter): + """(Bitext filter) Use a Quality Estimation (QE) model to score individual segments and filter based on estimated quality score. + (reference: https://arxiv.org/pdf/2311.05350) + """ + + # a mapping from supported model names to their corresponding model class + SUPPORTED_MODELS = { + "comet-qe": COMETQEModel, + "cometoid-wmt23": PyMarianQEModel, + "cometoid-wmt23-mqm": PyMarianQEModel, + } + + def __init__(self, model_name, cutoff, mode="always_en_x", gpu=False, **kwargs): + """Args: + model_name (_type_): Name of the model, as listed in the `SUPPORTED_MODELS` variable. + cutoff (_type_): A cut-off threshold for filtering. 
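# An illustrative, minimal BitextFilter subclass built on the base class
# above: it drops pairs whose source and target are identical, a cheap way
# to catch untranslated segments. The class name and the commented usage
# are examples only, not part of the NeMo Curator API.
from nemo_curator.filters import BitextFilter


class UntranslatedPairFilter(BitextFilter):
    def score_bitext(self, src, tgt):
        # 1.0 when both sides are (case-insensitively) the same string
        return float(src.strip().lower() == tgt.strip().lower())

    def keep_bitext(self, score):
        return score == 0.0


# keep_translated = UntranslatedPairFilter(score_field="untranslated")
# cleaned = keep_translated(parallel_dataset)  # parallel_dataset: a ParallelDataset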
All segments with scores lower than this threshold will be tossed away. + mode (str, optional): See `_score_document_with_qe` for definition. Defaults to "always_en_x". + gpu (bool, optional): Whether to use GPU. Defaults to False. + + Raises: + NotImplementedError: If a model name outside the supported model list is passed. + """ + super().__init__(**kwargs) + + if model_name in self.SUPPORTED_MODELS: + self._name = model_name + else: + raise NotImplementedError( + f"Only the following models are currently supported: {str(self.SUPPORTED_MODELS.keys())}" + ) + + self._model_path = None + self._mode = mode + self._cutoff = cutoff + self._gpu = gpu + + def _score_bitext_with_qe( + self, + model, + src: pd.Series, + tgt: pd.Series, + src_lang: pd.Series, + tgt_lang: pd.Series, + mode: str = "always_en_x", + ) -> List[float]: + """Arrange the documents according to the inference mode, call the model to estimate translation quality. + + Args: + model (_type_): QE model object to be called. + src (pd.Series): data frame holding the source document. + tgt (pd.Series): data frame holding the target document. + src_lang (pd.Series): data frame holding the list of source languages. + tgt_lang (pd.Series): data frame holding the list of target languages. + mode (str, optional): Currently three inference modes are supported: + + - `simple`: Maintain the translation direction as specified in the data and + simply pass the corresponding fields to the quality estimation model. + - `always_en_x`: Always pass the English side as the source and non-English side as the target. + This is the strategy used by the referenced paper: https://arxiv.org/pdf/2311.05350. + - `bidi`: Estimate quality on both directions, then average the score. Potentially more accurate + when original translation direction is uncertain (note that "original" translation direction + might have been flipped while building the data), but also twice as expensive computationally. + + Defaults to "always_en_x". + + Returns: + List[float]: A list of float scores corresponding to the individual score of each documents. 
+ """ + + def _is_en_x(src_lang: str, tgt_lang: str): + return src_lang == "en" and tgt_lang != "en" + + def _has_en(src_lang: str, tgt_lang: str): + return src_lang == "en" and tgt_lang == "en" + + model_class = self.SUPPORTED_MODELS[self._name] + + if mode == "simple": + input = [model_class.wrap_qe_input(s, t) for s, t in zip(src, tgt)] + return model.predict(input) + elif mode == "always_en_x": + # if English is included but it's on the target side, flip to make sure we are scoring with en-x + # this strategy was proposed in: https://aclanthology.org/2023.wmt-1.50.pdf + input = [ + model_class.wrap_qe_input( + s, t, reverse=(_has_en(sl, tl) and not _is_en_x(sl, tl)) + ) + for s, t, sl, tl in zip(src, tgt, src_lang, tgt_lang) + ] + return model.predict(input) + elif mode == "bidi": + # score twice -- once forward and once backward + fwd_input = [model_class.wrap_qe_input(s, t) for s, t in zip(src, tgt)] + rev_input = [ + model_class.wrap_qe_input(s, t, reverse=True) for s, t in zip(src, tgt) + ] + scores = model.predict( + fwd_input + rev_input + ) # making one call to take advantage of batching + # first half is forward score, second half is reverse score -- now we unpack and average + fwd_scores = scores[: len(src)] + rev_scores = scores[len(src) :] + return [(fs + rs) / 2 for fs, rs in zip(fwd_scores, rev_scores)] + else: + raise NotImplementedError + + @batched + def score_bitext( + self, src: pd.Series, tgt: pd.Series, src_lang: pd.Series, tgt_lang: pd.Series + ) -> pd.Series: + """Wrapper function that scores documents in a data frame. Most work is done in `_score_document_with_qe`. + + Args: + Takes two metadata fields: `src_lang` and `tgt_lang`. Refer to `_score_bitext_with_qe` function for details. + + Raises: + RuntimeError: If input data frame arguments doesn't have the same length. + + Returns: + pd.Series: A list of float scores corresponding to the individual score of each documents. 
+ """ + + if not len(src) == len(tgt) == len(src_lang) == len(tgt_lang): + raise RuntimeError( + "Different fields of the data frame should have the same length" + ) + + model_attr = f"{self._name}_{self._model_path}" + try: + model = load_object_on_worker( + model_attr, + self.SUPPORTED_MODELS[self._name].load_model, + {"model_name": self._name, "gpu": self._gpu}, + ) + except NoWorkerError: + return pd.Series([-1.0 for _ in range(len(src))]) + + scores = self._score_bitext_with_qe( + model, src, tgt, src_lang, tgt_lang, self._mode + ) + + return pd.Series(scores, index=src.index) + + def keep_bitext(self, score): + """Decides whether a single document should be retained according to a threshold of estimated quality score.""" + return score >= self._cutoff diff --git a/nemo_curator/filters/doc_filter.py b/nemo_curator/filters/doc_filter.py index 742ecf7b..e84765e4 100644 --- a/nemo_curator/filters/doc_filter.py +++ b/nemo_curator/filters/doc_filter.py @@ -14,7 +14,9 @@ import importlib from abc import ABC, abstractmethod -from typing import Any +from typing import Any, Union + +from nemo_curator.filters.bitext_filter import BitextFilter class DocumentFilter(ABC): @@ -108,7 +110,7 @@ def ngrams(self, ngrams): self._ngrams = ngrams -def import_filter(filter_path: str) -> DocumentFilter: +def import_filter(filter_path: str) -> Union[DocumentFilter, BitextFilter]: """ Imports a filter under nemo_curator.filters given the module path @@ -124,9 +126,12 @@ def import_filter(filter_path: str) -> DocumentFilter: module_path, filter_name = filter_path.rsplit(".", 1) filter_module = importlib.import_module(module_path) filter_class = getattr(filter_module, filter_name) - if not issubclass(filter_class, DocumentFilter): + if not issubclass(filter_class, DocumentFilter) and not issubclass( + filter_class, BitextFilter + ): raise ValueError( f"Input filter {filter_class.__name__} must be derived " - "from DocumentFilter defined in nemo_curator.filters.doc_filter" + "from DocumentFilter defined in nemo_curator.filters.doc_filter or" + "from BitextFilter defined in nemo_curator.filters.bitext_filter" ) return filter_class diff --git a/nemo_curator/filters/heuristic_filter.py b/nemo_curator/filters/heuristic_filter.py index ebd6f8d3..c17e4e9a 100644 --- a/nemo_curator/filters/heuristic_filter.py +++ b/nemo_curator/filters/heuristic_filter.py @@ -12,7 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from nemo_curator.filters.doc_filter import DocumentFilter +import os.path +import tarfile + +import requests +from platformdirs import user_cache_dir + +from nemo_curator.filters.bitext_filter import BitextFilter +from nemo_curator.filters.doc_filter import DocumentFilter, import_filter from nemo_curator.utils.constants import ( bullet_list, common_english_words, @@ -662,3 +669,138 @@ def score_document(self, text): def keep_document(self, score): return score != 1 + + +class HistogramFilter(DocumentFilter): + """Histogram filter used by the NLLB paper (https://arxiv.org/pdf/2207.04672). See p30 for details. + + The high-level idea of histogram filter can be described as a cheap version of language ID. + Basically, it checks what ratio of characters in the data instance are included in the character historgrams collected from trusted data in the corresponding language. + If the ratio is too low, then there is a good chance that there is a language ID mismatch and the data instance should be discarded. 
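# A sketch of applying the QE filter defined above to a ParallelDataset `ds`.
# It assumes the optional `unbabel-comet` dependency is installed (Cometoid
# via PyMarian works analogously), and the cutoff is only an illustrative
# threshold for the comet-qe score range.
from nemo_curator.filters import QualityEstimationFilter
from nemo_curator.utils.distributed_utils import get_client

client = get_client(n_workers=1)
qe_filter = QualityEstimationFilter(
    "comet-qe",
    cutoff=-0.25,
    mode="always_en_x",
    gpu=False,
    metadata_fields=["src_lang", "tgt_lang"],  # forwarded to score_bitext
    score_type=float,
)
high_quality = qe_filter(ds)
client.close()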
+ + Written with reference to the original fairseq implementation at: + https://github.com/facebookresearch/fairseq/blob/main/examples/m2m_100/process_data/clean_histogram.py. + """ + + def __init__(self, lang="en", threshold=0.8, cache_dir="", threshold_char="]"): + """Args: + lang (str, optional): Expected language of the segment. This will decide which histogram will be loaded. Defaults to "en". + threshold (float, optional): Threshold for ratio of characters in the histogram. Defaults to 0.8. + cache_dir (str, optional): Cache dir download histogram files. Defaults to "". + threshold_char (str, optional): Formatter character of the histogram files. You should not change this unless you rebuilt your own histogram. Defaults to "]". + """ + super().__init__() + self._lang = lang + self._threshold = threshold + self._cache_dir = cache_dir if cache_dir else user_cache_dir() + self._threshold_char = threshold_char + self._name = "histogram" + + if not os.path.isdir(os.path.join(self._cache_dir, "histograms")): + self._download_histograms() + + self._read_hist() + + def _download_histograms(self): + """Download and process histograms from default repo. + + Raises: + requests.exceptions.RequestException: If download fails. + """ + + # Send a GET request to the URL + response = requests.get( + "https://dl.fbaipublicfiles.com/m2m_100/histograms.tar.gz" + ) + + # Check if the request was successful + if response.status_code != 200: + raise requests.exceptions.RequestException( + f"Failed to download histogram file. Status code: {response.status_code}" + ) + + # Open a file to write the content + os.makedirs(self._cache_dir, exist_ok=True) + download_dest_path = os.path.join(self._cache_dir, "histograms.tar.gz") + with open(download_dest_path, "wb") as file: + file.write(response.content) + + extract_path = os.path.join(self._cache_dir, "histograms") + with tarfile.open(download_dest_path, "r:gz") as tar: + # Extract all the contents into the specified directory + tar.extractall(path=extract_path) + + def _read_hist(self): + """Load histogram files.""" + + self._histogram = [] + with open( + os.path.join( + self._cache_dir, + "histograms", + "checkpoint", + "edunov", + "cc60_multilingual", + "clean_hists", + self._lang, + ) + ) as f: + for line in f: + c = line[0] + if c == self._threshold_char: + break + self._histogram.append(c) + self._histogram = set(self._histogram) + + def score_document(self, text: str) -> float: + """Compute histogram token ratio of a text data instance according to the loaded histogram. + + Args: + text (str): Text data instance. + + Returns: + float: Ratio of tokens included in the histogram. + """ + cnt = len([c for c in text.strip() if c in self._histogram]) + return 1 if cnt / len(text) > self._threshold else 0 + + def keep_document(self, score): + return score == 1 + + +class LengthRatioFilter(BitextFilter): + """(Bitext filter) Length ratio filter for bitext, similar to the one implemented in Moses toolkit (`https://github.com/moses-smt/mosesdecoder/blob/master/scripts/training/clean-corpus-n.perl`). + + If the ratio between source and target tokens is not within a specified range then discard. Either direction (src/tgt, tgt/src) is considered. + """ + + def __init__(self, max_ratio=3.0, src_lang="en", tgt_lang="en", **kwargs): + """Args: + max_ratio (float, optional): Maximum allowed length ratio between either direction of the bitext. Defaults to 3.0. + src_lang (str, optional): Language of the source data (needed for tokenization). Defaults to "en". 
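# A short sketch of the histogram-based language check above, applied as an
# ordinary ScoreFilter over a monolingual DocumentDataset `docs`. On first
# use the M2M-100 histograms are downloaded into the user cache (or into
# `cache_dir` if one is supplied).
from nemo_curator.filters import HistogramFilter
from nemo_curator.modules import ScoreFilter

english_only = ScoreFilter(HistogramFilter(lang="en", threshold=0.8))
docs_en = english_only(docs)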
+ tgt_lang (str, optional): Language of the target data (needed for tokenization). Defaults to "en". + """ + + super().__init__(**kwargs) + self._max_ratio = float(max_ratio) + self._src_word_splitter = get_word_splitter(src_lang) + self._tgt_word_splitter = get_word_splitter(tgt_lang) + self._name = "length_ratio" + + def score_bitext(self, src: str, tgt: str) -> float: + """Tokenize the source and target sentences and compute length ratio. + + Args: + src (str): Source document string. + tgt (str): Target document string. + + Returns: + float: The maximum ratio among the two translation directions of the bitext. + """ + src_len = len(self._src_word_splitter(src.strip())) + tgt_len = len(self._tgt_word_splitter(tgt.strip())) + return max(src_len / tgt_len, tgt_len / src_len) + + def keep_bitext(self, score): + """Decides whether a single document should be retained according to the computed length ratio.""" + return score < self._max_ratio diff --git a/nemo_curator/filters/models/__init__.py b/nemo_curator/filters/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/nemo_curator/filters/models/qe_models.py b/nemo_curator/filters/models/qe_models.py new file mode 100644 index 00000000..6a06cb8f --- /dev/null +++ b/nemo_curator/filters/models/qe_models.py @@ -0,0 +1,167 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from typing import List + +from huggingface_hub import hf_hub_download + +from nemo_curator.utils.import_utils import safe_import + +COMET_IMPORT_MSG = ( + "To run QE filtering with COMET, you need to install from PyPI with: `pip install unbabel-comet`. " + + "More information at https://github.com/Unbabel/COMET." +) +PYMARIAN_IMPORT_MSG = ( + "To run QE filtering with Cometoid/PyMarian, you need to install PyMarian. " + + "More information at https://github.com/marian-nmt/wmt23-metrics?tab=readme-ov-file#setup." +) +comet = safe_import("comet", msg=COMET_IMPORT_MSG) +pymarian = safe_import("pymarian", msg=PYMARIAN_IMPORT_MSG) + + +class QEModel(ABC): + """Abstract model for all quality estimation models for bitext.""" + + def __init__(self, name: str, model, gpu=False): + """Args: + name (str): A string named of the model. Not directly tied to `MODEL_NAME_TO_HF_PATH` as defined in some subclasses but it is suggested. + model: A loaded model object. The type of the object depends on the loaded model type. + gpu (bool, optional): Whether inference is on GPU. Defaults to False. + """ + self._name = name + self._model = model + self._gpu = gpu + + @classmethod + @abstractmethod + def load_model(cls, model_name: str): + """An abstract method that loads the model according to a model name. + + Args: + model_name (str): The name of the model to be loaded. + Could be a huggingface model name, a path, or something else, depending on the implementation. 
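# A sketch of the Moses-style length-ratio check above on a ParallelDataset
# `ds`; max_ratio mirrors the default, and the language codes only select
# the word splitter used for tokenization.
from nemo_curator.filters import LengthRatioFilter

ratio_filter = LengthRatioFilter(
    max_ratio=3.0,
    src_lang="en",
    tgt_lang="de",
    score_field="length_ratio",  # optional: keep the score for inspection
    score_type=float,
)
ds_balanced = ratio_filter(ds)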
+ """ + pass + + @staticmethod + @abstractmethod + def wrap_qe_input(src: str, tgt: str, reverse=False, **kwargs): + """An abstract method that implements the following: given the individual source and target string of the bitext, + wrap them into proper format that can be accepted by the underlying model. + + Args: + src (str): Source side string of the bitext. + tgt (str): Target side string of the bitext. + reverse (bool, optional): Whether to reverse the source and target side of the bitext. Defaults to False. + """ + pass + + @abstractmethod + def predict(self, **kwargs) -> List[float]: + """An abstract method that calls the underlying model to produce estimated quality scores. + + Returns: + List[float]: List of quality scores. + """ + pass + + +class COMETQEModel(QEModel): + """Wrapper class for any COMET quality estimation models (https://github.com/Unbabel/COMET).""" + + MODEL_NAME_TO_HF_PATH = { + "comet-qe": "Unbabel/wmt20-comet-qe-da", + } + + @classmethod + def load_model(cls, model_name: str, gpu: bool = False): + """See parent class docstring for details on functionality and arguments.""" + path = comet.download_model(cls.MODEL_NAME_TO_HF_PATH[model_name]) + return cls(model_name, comet.load_from_checkpoint(path), gpu) + + @staticmethod + def wrap_qe_input(src: str, tgt: str, reverse=False, **kwargs): + """See parent class docstring for details on functionality and arguments.""" + return {"src": src, "mt": tgt} if not reverse else {"src": tgt, "mt": src} + + def predict(self, input: List, **kwargs) -> List[float]: + """Implements quality estimation score prediction for COMET model. + + Args: + input (List): A list of bitext pairs wrapped as dictionaries. + + Returns: + List[float]: List of quality scores. + """ + return self._model.predict( + input, gpus=int(self._gpu), num_workers=0 + ).scores # it's critical to set num_workers=0 to avoid spawning new processes within a dask worker + + +class PyMarianQEModel(QEModel): + + MODEL_NAME_TO_HF_PATH = { + "cometoid-wmt23": "marian-nmt/cometoid22-wmt23", + "cometoid-wmt23-mqm": "marian-nmt/cometoid22-wmt23", + } + # Because PyMarian depends on its own deep learning library rather than PyTorch/Huggingface + # there is unfortunately no model configuration interface that can automatically adapt to + # individual systems (like hf `AutoConfig`). + # Those should work on most systems, but if not please adjust as needed. + MARIAN_GPU_ARGS = " -w 8000 --mini-batch 32 -d 0" + MARIAN_CPU_ARGS = " --cpu-threads 1 -w 2000" + # PyMarian has memory leakage when a very large input is passed. + # Hence we limit the size of input passed into PyMarian within one API call. 
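# The QE model wrappers can also be driven directly, outside a filter. This
# sketch assumes `unbabel-comet` is installed; the COMET checkpoint
# (Unbabel/wmt20-comet-qe-da) is downloaded on first use.
from nemo_curator.filters.models.qe_models import COMETQEModel

model = COMETQEModel.load_model("comet-qe", gpu=False)
pairs = [
    COMETQEModel.wrap_qe_input(
        "The company produces computers for the German market.",
        "Die Firma produziert Computer für den deutschen Markt.",
    )
]
scores = model.predict(pairs)  # one float quality score per segment pair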
+ SHARD_SIZE = 5000 + + @classmethod + def load_model(cls, model_name: str, gpu: bool = False): + """See parent class docstring for details on functionality and arguments.""" + repo_id = cls.MODEL_NAME_TO_HF_PATH[model_name] + model_path = hf_hub_download(repo_id, filename="checkpoints/marian.model.bin") + vocab_path = hf_hub_download(repo_id, filename="vocab.spm") + marian_args = f"-m {model_path} -v {vocab_path} {vocab_path} --like comet-qe" + if gpu: + marian_args += cls.MARIAN_GPU_ARGS + else: + marian_args += cls.MARIAN_CPU_ARGS + return cls(model_name, pymarian.Evaluator(marian_args), gpu) + + @staticmethod + def wrap_qe_input(src: str, tgt: str, reverse=False, **kwargs): + """See parent class docstring for details on functionality and arguments.""" + return [src, tgt] if not reverse else [tgt, src] + + def predict(self, input: List, **kwargs) -> List[float]: + """Implements quality estimation score prediction for Cometoid/PyMarian model. + + Args: + input (List): A list of bitext pairs wrapped as dictionaries. + + Returns: + List[float]: List of quality scores. + """ + scores = [] + for start_idx in range(0, len(input), self.SHARD_SIZE): + scores.extend( + self._model.evaluate(input[start_idx : start_idx + self.SHARD_SIZE]) + ) + + if not self._name.endswith("mqm"): + # using DA+SQM score by default + # choice made based on paper: https://aclanthology.org/2023.wmt-1.62.pdf + return [score[1] for score in scores] + else: + return [score[0] for score in scores] diff --git a/nemo_curator/modules/__init__.py b/nemo_curator/modules/__init__.py index 87d82402..bc565931 100644 --- a/nemo_curator/modules/__init__.py +++ b/nemo_curator/modules/__init__.py @@ -25,7 +25,7 @@ from .config import FuzzyDuplicatesConfig, SemDedupConfig from .dataset_ops import blend_datasets, Shuffle from .exact_dedup import ExactDuplicates -from .filter import Filter, Score, ScoreFilter +from .filter import Filter, Score, ScoreFilter, ParallelScoreFilter from .meta import Sequential from .modify import Modify from .task import TaskDecontamination @@ -64,6 +64,7 @@ "Modify", "Score", "ScoreFilter", + "ParallelScoreFilter", "Sequential", "TaskDecontamination", "AddId", diff --git a/nemo_curator/modules/filter.py b/nemo_curator/modules/filter.py index b09a7c5c..1006ee25 100644 --- a/nemo_curator/modules/filter.py +++ b/nemo_curator/modules/filter.py @@ -11,13 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Optional, Union + +from typing import Callable, List, Optional, Union import pandas as pd +from dask.array import logical_and from dask.dataframe.extensions import make_array_nonempty from dask.typing import no_default from nemo_curator.datasets import DocumentDataset +from nemo_curator.datasets.parallel_dataset import ParallelDataset from nemo_curator.filters import DocumentFilter from nemo_curator.utils.module_utils import is_batched @@ -108,15 +111,14 @@ def __init__(self, filter_fn: Callable, filter_field: str, invert: bool = False) self.filter_field = filter_field self.invert = invert - def __call__(self, dataset: DocumentDataset) -> DocumentDataset: - """ - Applies the filtering to a dataset + def compute_filter_mask(self, dataset: DocumentDataset): + """Compute the bool mask to filter the dataset. Args: - dataset (DocumentDataset): The dataset to apply the module to + dataset (DocumentDataset): The dataset to compute filter mask on. 
Returns: - DocumentDataset: A dataset with entries removed according to the filter + Series or DataFrame: A mask corresponding to each data instance indicating whether it will be retained. """ if is_batched(self.filter_fn): bool_mask = dataset.df[self.filter_field].map_partitions( @@ -130,6 +132,19 @@ def __call__(self, dataset: DocumentDataset) -> DocumentDataset: if self.invert: bool_mask = ~bool_mask + return bool_mask + + def __call__(self, dataset: DocumentDataset) -> DocumentDataset: + """ + Applies the filtering to a dataset + + Args: + dataset (DocumentDataset): The dataset to apply the module to + + Returns: + DocumentDataset: A dataset with entries removed according to the filter + """ + bool_mask = self.compute_filter_mask(dataset) return DocumentDataset(dataset.df[bool_mask]) @@ -167,15 +182,14 @@ def __init__( self.score_type = score_type self.invert = invert - def __call__(self, dataset: DocumentDataset) -> DocumentDataset: - """ - Scores and filters all records in the dataset + def compute_filter_mask(self, dataset: DocumentDataset): + """Compute the bool mask to filter the dataset. Args: - dataset (DocumentDataset): The dataset to apply the module to + dataset (DocumentDataset): The dataset to compute filter mask on. Returns: - DocumentDataset: A dataset with the score and filter applied + Series or DataFrame: A mask corresponding to each data instance indicating whether it will be retained. """ if self.score_type: meta = (None, self.score_type) @@ -203,4 +217,65 @@ def __call__(self, dataset: DocumentDataset) -> DocumentDataset: if self.invert: bool_mask = ~bool_mask + return bool_mask + + def __call__(self, dataset: DocumentDataset) -> DocumentDataset: + """ + Scores and filters all records in the dataset + + Args: + dataset (DocumentDataset): The dataset to apply the module to + + Returns: + DocumentDataset: A dataset with the score and filter applied + """ + bool_mask = self.compute_filter_mask(dataset) return DocumentDataset(dataset.df[bool_mask]) + + +class ParallelScoreFilter: + def __init__( + self, + src_filter_obj, + tgt_filter_obj, + src_field="src", + tgt_field="tgt", + src_score=None, + tgt_score=None, + score_type=None, + invert=False, + ): + """A filter object wrapper class for applying *monolingual* filter objects on bitext. + If either side of the bitext is discarded, the whole bitext pair is discarded. + If you want to apply a *bitext* filter that takes both the source and target as input, checkout `BitextFilter` class. + + Note that the goal of this wrapper class is to group the same/similar filters on bitext thus making the logic clearer, + which is why we force the `score_type` and `invert` to be the same among source/target filters. + If you need the extra flexibility, you should fall back to applying two filters one after the other. + + Args: + src_filter_obj (_type_): The score function that takes in a document string and outputs a score for the source document. + tgt_filter_obj (_type_): The score function that takes in a document string and outputs a score for the target document. + src_field (str, optional): The field the source documents will be read from. Defaults to "src". + tgt_field (str, optional): The field the target documents will be read from. Defaults to "tgt". + src_score (str, optional): The field to which the source scores will be written. If None, scores will be immediately discarded after use. Defaults to None. + tgt_score (str, optional): The field to which the target scores will be written. 
If None, scores will be immediately discarded after use. Defaults to None. + score_type (Optional[str]): The datatype of the score that will be made for each document. Defaults to None. + invert (bool, optional): If True, will keep all documents that are normally discarded. Defaults to False. + """ + + self.source_score_filter = ScoreFilter( + src_filter_obj, src_field, src_score, score_type, invert + ) + self.target_score_filter = ScoreFilter( + tgt_filter_obj, tgt_field, tgt_score, score_type, invert + ) + + def __call__(self, dataset: ParallelDataset): + src_bool_mask = self.source_score_filter.compute_filter_mask(dataset) + tgt_bool_mask = self.target_score_filter.compute_filter_mask(dataset) + + # remove lines together if one of them is filtered + bool_mask = logical_and(src_bool_mask, tgt_bool_mask) + + return ParallelDataset(dataset.df[bool_mask]) diff --git a/nemo_curator/utils/distributed_utils.py b/nemo_curator/utils/distributed_utils.py index 4e695baf..27d41679 100644 --- a/nemo_curator/utils/distributed_utils.py +++ b/nemo_curator/utils/distributed_utils.py @@ -15,12 +15,14 @@ import ast import os +import shutil os.environ["RAPIDS_NO_INITIALIZE"] = "1" import random import warnings from contextlib import nullcontext from datetime import datetime +from itertools import zip_longest from pathlib import Path from typing import Dict, List, Literal, Optional, Union @@ -568,7 +570,9 @@ def single_partition_write_with_filename( if not keep_filename_column: out_df = out_df.drop("filename", axis=1) - filename = Path(filename).stem + filename = ( + Path(filename).stem if output_type != "bitext" else Path(filename).name + ) output_file_path = os.path.join(output_file_dir, filename) if output_type == "jsonl": @@ -596,13 +600,79 @@ def single_partition_write_with_filename( elif output_type == "parquet": output_file_path = output_file_path + ".parquet" out_df.to_parquet(output_file_path) - + elif output_type == "bitext": + raise RuntimeError( + "You shouldn't call this function to write to simple bitext." + ) else: raise ValueError(f"Unknown output type: {output_type}") return success_ser +def _single_partition_write_to_simple_bitext( + out_df, output_file_path, partition_info=None +): + if len(out_df) > 0: + empty_partition = False + else: + warnings.warn(f"Empty partition found") + empty_partition = True + + if is_cudf_type(out_df): + import cudf + + success_ser = cudf.Series([empty_partition]) + else: + success_ser = pd.Series([empty_partition]) + + src_output_file_path = output_file_path + f".{out_df['src_lang'].iloc[0]}" + tgt_output_file_path = output_file_path + f".{out_df['tgt_lang'].iloc[0]}" + partition_id = partition_info["number"] if partition_info else 0 + with ( + open(f"{src_output_file_path}.{partition_id}", "w") as src_out, + open(f"{tgt_output_file_path}.{partition_id}", "w") as tgt_out, + ): + for src, tgt in zip(out_df["src"], out_df["tgt"]): + src_out.write(src + os.linesep) + tgt_out.write(tgt + os.linesep) + + return success_ser + + +def _merge_tmp_simple_bitext_partitions(tmp_output_dir: str, output_dir: str): + """Merge partitions of simple bitext files in `tmp_output_dir` into files at `output_file_dir`. 
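# A sketch of wrapping two monolingual filters over a bitext with the
# ParallelScoreFilter above: a pair is dropped if either side fails its own
# per-language HistogramFilter. `ds` is assumed to be a ParallelDataset with
# English "src" and German "tgt" columns.
from nemo_curator.filters import HistogramFilter
from nemo_curator.modules import ParallelScoreFilter

lang_check = ParallelScoreFilter(
    HistogramFilter(lang="en"),
    HistogramFilter(lang="de"),
    src_score="src_hist",  # optional: persist per-side scores
    tgt_score="tgt_hist",
)
ds_checked = lang_check(ds)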
+ + Args: + tmp_output_dir (str): temporary directory that has all the simple bitext output partitions, + with suffixes that looks like "file.1", "file.2" that shows the merging order + output_file_path (str): dir to write output files + """ + + sorted_tmp_files = sorted( + os.listdir(tmp_output_dir), key=lambda x: int(x.split(".")[-1]) + ) + unique_file_handles = {} + # Loop through the sorted files and concatenate their contents + for f in sorted_tmp_files: + input_file_path = os.path.join(tmp_output_dir, f) + output_file_name = ".".join(f.split(".")[:-1]) + + # this is where current file will be concatenated into + output_file_path = os.path.join(output_dir, output_file_name) + + # create the output file if we haven't yet + if output_file_path not in unique_file_handles: + unique_file_handles[output_file_path] = open(output_file_path, "w") + + with open(input_file_path, "r") as infile: + unique_file_handles[output_file_path].write(infile.read()) + + # close all dangling file handles + for handle in unique_file_handles.values(): + handle.close() + + def write_to_disk( df, output_file_dir: str, @@ -628,14 +698,14 @@ def write_to_disk( "write_using_filename is True but no filename column found in df" ) - if write_to_filename: - if is_cudf_type(df): - import cudf + if is_cudf_type(df): + import cudf - output_meta = cudf.Series([True]) - else: - output_meta = pd.Series([True], dtype="bool") + output_meta = cudf.Series([True]) + else: + output_meta = pd.Series([True], dtype="bool") + if write_to_filename and output_type != "bitext": os.makedirs(output_file_dir, exist_ok=True) output = df.map_partitions( single_partition_write_with_filename, @@ -661,6 +731,33 @@ def write_to_disk( ) elif output_type == "parquet": df.to_parquet(output_file_dir, write_index=False) + elif output_type == "bitext": + if write_to_filename: + os.makedirs(output_file_dir, exist_ok=True) + tmp_output_file_dir = os.path.join(output_file_dir, ".tmp") + os.makedirs(tmp_output_file_dir, exist_ok=True) + file_name = os.path.basename(list(df.filename.unique())[0]) + else: + tmp_output_file_dir = os.path.join(output_file_dir, ".tmp") + os.makedirs(tmp_output_file_dir, exist_ok=True) + file_name = os.path.basename(output_file_dir) + + output = df.map_partitions( + _single_partition_write_to_simple_bitext, + os.path.join(tmp_output_file_dir, file_name), + meta=output_meta, + enforce_metadata=False, + ) + output = output.compute() + _merge_tmp_simple_bitext_partitions( + tmp_output_file_dir, + ( + output_file_dir + if write_to_filename + else os.path.dirname(output_file_dir) + ), + ) + shutil.rmtree(tmp_output_file_dir) else: raise ValueError(f"Unknown output type: {output_type}") diff --git a/nemo_curator/utils/file_utils.py b/nemo_curator/utils/file_utils.py index e919c88b..f1c957b4 100644 --- a/nemo_curator/utils/file_utils.py +++ b/nemo_curator/utils/file_utils.py @@ -456,3 +456,8 @@ def reshard_jsonl( # Save to balanced files _save_jsonl(b, output_dir, start_index=start_index, prefix=file_prefix) + + +def remove_path_extension(path: str): + p = pathlib.Path(path) + return os.path.join(p.parent, p.stem) diff --git a/pyproject.toml b/pyproject.toml index a4cb3d95..75f5b6b9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,6 +58,7 @@ dependencies = [ "numpy<2", "openai", "peft", + "platformdirs", "presidio-analyzer==2.2.351", "presidio-anonymizer==2.2.351", "pycld2", @@ -104,9 +105,17 @@ image_nightly = [ "timm>=1.0.8", "nemo_curator[cuda12x_nightly]", ] +# Installs bitext curation modules +bitext = [ + 
"huggingface-hub", + "tqdm", + "transformers", + "nemo_curator[cuda12x]", +] # Installs all of the above with Stable RAPIDS all = [ "nemo_curator[image]", + "nemo_curator[bitext]", ] # Installs all of the above with RAPIDS Nightlies all_nightly = [ diff --git a/tests/bitext_data/toy.de b/tests/bitext_data/toy.de new file mode 100644 index 00000000..3c5cd2d2 --- /dev/null +++ b/tests/bitext_data/toy.de @@ -0,0 +1,4 @@ +Wir besitzen keine Reisetaschen aus Leder. + Wir besitzen keine Reisetaschen aus Leder. +Die Firma produziert Computer für den deutschen Markt. + Die Firma produziert Computer für den deutschen Markt. diff --git a/tests/bitext_data/toy.en b/tests/bitext_data/toy.en new file mode 100644 index 00000000..8617394b --- /dev/null +++ b/tests/bitext_data/toy.en @@ -0,0 +1,4 @@ +We don't own duffel bags made of leather. + We don't own duffel bags made of leather. +The company produces computers for the German market. + The company produces computers for the German market. diff --git a/tests/test_filters.py b/tests/test_filters.py index 66bc7dc0..2f8bd00d 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -21,6 +21,7 @@ from dask import dataframe as dd from nemo_curator.datasets import DocumentDataset +from nemo_curator.datasets.parallel_dataset import ParallelDataset from nemo_curator.filters import ( AlphaFilter, BoilerPlateStringFilter, @@ -29,7 +30,9 @@ DocumentFilter, EllipsisFilter, GeneralCommentToCodeFilter, + HistogramFilter, HTMLBoilerplateFilter, + LengthRatioFilter, LongWordFilter, MeanWordLengthFilter, NonAlphaNumericFilter, @@ -54,8 +57,19 @@ WordsWithoutAlphabetsFilter, XMLHeaderFilter, ) -from nemo_curator.modules import Filter, Score, ScoreFilter, Sequential +from nemo_curator.filters.models.qe_models import COMET_IMPORT_MSG, PYMARIAN_IMPORT_MSG +from nemo_curator.modules import ( + Filter, + ParallelScoreFilter, + Score, + ScoreFilter, + Sequential, +) from nemo_curator.utils.decorators import batched +from nemo_curator.utils.import_utils import is_unavailable, safe_import + +comet = safe_import("comet", msg=COMET_IMPORT_MSG) +pymarian = safe_import("pymarian", msg=PYMARIAN_IMPORT_MSG) class LetterCountFilter(DocumentFilter): @@ -107,6 +121,28 @@ def list_to_dataset(documents, col_name="text", npartitions=2): return DocumentDataset(dd.from_pandas(pdf, npartitions=npartitions)) +def two_lists_to_parallel_dataset( + src_documents, + tgt_documents, + src_lang, + tgt_lang, + src_col_name="src", + tgt_col_name="tgt", + npartitions=2, +): + src_langs = [src_lang] * len(src_documents) + tgt_langs = [tgt_lang] * len(src_documents) + data = { + src_col_name: src_documents, + "src_lang": src_langs, + tgt_col_name: tgt_documents, + "tgt_lang": tgt_langs, + } + pdf = pd.DataFrame(data) + + return ParallelDataset(dd.from_pandas(pdf, npartitions=npartitions)) + + @pytest.fixture def letter_count_data(): return list_to_dataset( @@ -114,6 +150,28 @@ def letter_count_data(): ) +@pytest.fixture +def parallel_letter_count_data(): + return two_lists_to_parallel_dataset( + ["Einsa", "Zwei aaa", "a Drei a", "Fünf aaa a", "aaaSieben aaaa"], + ["aOne", "Two aa", "a a Three a", "Five aaa aa", "aaaSeven aaaa"], + src_lang="de", + tgt_lang="en", + src_col_name="src", + tgt_col_name="tgt", + ) + + +@pytest.fixture +def length_ratio_data(): + return two_lists_to_parallel_dataset( + ["Test", "test", "Test Test ", "Test Test"], + ["Prueba", "prueba prueba prueba", "Prueba Prueba", "Prueba Prueba Prueba "], + src_lang="en", + tgt_lang="es", + ) + + class TestFilterModule: def 
         letter_filter = LetterCountFilter()
@@ -299,6 +357,38 @@ def test_chain_filter(self, letter_count_data):
             expected_data, filtered_data
         ), f"Expected {expected_data} but got {filtered_data}"
 
+    def test_parallel_score_filter(self, parallel_letter_count_data):
+        src_letter_count_filter = LetterCountFilter(min_count=2)
+        tgt_letter_count_filter = LetterCountFilter(min_count=3)
+        filter_step = ParallelScoreFilter(
+            src_letter_count_filter, tgt_letter_count_filter
+        )
+        filtered_data = filter_step(parallel_letter_count_data)
+
+        expected_indices = [2, 3, 4]
+        expected_data = ParallelDataset(
+            parallel_letter_count_data.df.loc[expected_indices]
+        )
+        assert all_equal(
+            expected_data, filtered_data
+        ), f"Expected {expected_data} but got {filtered_data}"
+
+    def test_joint_score_filter(self, length_ratio_data):
+        filter_ = LengthRatioFilter(
+            max_ratio=1.5,
+            src_lang="en",
+            tgt_lang="de",
+            score_field="ratio",
+            score_type=float,
+        )
+        filtered_data = filter_(length_ratio_data)
+
+        expected_indices = [0, 2]
+        expected_data = ParallelDataset(length_ratio_data.df.loc[expected_indices])
+        assert all_equal(
+            expected_data, filtered_data
+        ), f"Expected {expected_data} but got {filtered_data}"
+
 
 class TestHeuristicFilters:
     def test_nonalpha(self):
@@ -648,6 +738,34 @@ def test_pornographicurls(self):
             expected_data, filtered_data
         ), f"Expected {expected_data} but got {filtered_data}"
 
+    def test_histogram(self):
+        dataset = list_to_dataset(
+            [
+                "This is a perfectly fine English document.",
+                "But if you insist that this is written in Chinese,",
+                "it's likely that something is fishy.",
+                "另一方面,这是一个好的中文文档,",
+                "但你一定要说这是英文文档,",
+                "那很可能有些地方出了差错。",
+            ]
+        )
+        filter1 = ScoreFilter(HistogramFilter(lang="en"))
+        filter2 = ScoreFilter(HistogramFilter(lang="zh"))
+
+        expected_indices1 = [0, 1, 2]
+        expected_indices2 = [3, 4, 5]
+        expected_data1 = DocumentDataset(dataset.df.loc[expected_indices1])
+        expected_data2 = DocumentDataset(dataset.df.loc[expected_indices2])
+
+        filtered_data1 = filter1(dataset)
+        filtered_data2 = filter2(dataset)
+        assert all_equal(
+            expected_data1, filtered_data1
+        ), f"Expected {expected_data1} but got {filtered_data1}"
+        assert all_equal(
+            expected_data2, filtered_data2
+        ), f"Expected {expected_data2} but got {filtered_data2}"
+
 
 class TestCodeFilters:
     def test_python_comment_to_code(self):
@@ -883,3 +1001,78 @@ def test_fake_langid_filter(self):
         assert all_equal(
             expected_data, filtered_data
         ), f"Expected {expected_data} but got {filtered_data}"
+
+    @pytest.mark.skipif(
+        is_unavailable(comet), reason="Test depends on COMET but it's not installed."
+    )
+    def test_comet_qe_filter(self):
+        dataset = two_lists_to_parallel_dataset(
+            [
+                "This sentence will be translated on the Chinese side.",
+                "This sentence will have something irrelevant on the Chinese side.",
+            ],
+            [
+                "这句话在中文一侧会被翻译。",
+                "至尊戒,驭众戒;至尊戒,寻众戒;魔戒至尊引众戒,禁锢众戒黑暗中。",
+            ],
+            "en",
+            "zh",
+        )
+
+        from nemo_curator.filters import QualityEstimationFilter
+        from nemo_curator.utils.distributed_utils import get_client
+
+        client = get_client(n_workers=1)
+        filter_ = QualityEstimationFilter(
+            "comet-qe",
+            cutoff=-0.25,
+            mode="bidi",
+            score_type=float,
+            metadata_fields=["src_lang", "tgt_lang"],
+        )
+        filtered_data = filter_(dataset)
+
+        expected_indices = [0]
+        expected_data = ParallelDataset(dataset.df.loc[expected_indices])
+        assert all_equal(
+            expected_data, filtered_data
+        ), f"Expected {expected_data} but got {filtered_data}"
+        client.close()
+
+    @pytest.mark.skipif(
+        is_unavailable(pymarian),
+        reason="Test depends on PyMarian but it's not installed.",
+    )
+    def test_cometoid_qe_filter(self):
+        dataset = two_lists_to_parallel_dataset(
+            [
+                "This sentence will be translated on the Chinese side.",
+                "This sentence will have something irrelevant on the Chinese side.",
+            ],
+            [
+                "这句话在中文一侧会被翻译。",
+                "至尊戒,驭众戒;至尊戒,寻众戒;魔戒至尊引众戒,禁锢众戒黑暗中。",
+            ],
+            "en",
+            "zh",
+        )
+
+        from nemo_curator.filters import QualityEstimationFilter
+        from nemo_curator.utils.distributed_utils import get_client
+
+        client = get_client(n_workers=1)
+        filter_ = QualityEstimationFilter(
+            "cometoid-wmt23",
+            cutoff=0.75,
+            mode="bidi",
+            score_type=float,
+            metadata_fields=["src_lang", "tgt_lang"],
+        )  # enable GPU by passing gpu=True
+        filtered_data = filter_(dataset)
+
+        expected_indices = [0]
+        expected_data = ParallelDataset(dataset.df.loc[expected_indices])
+        assert all_equal(
+            expected_data, filtered_data
+        ), f"Expected {expected_data} but got {filtered_data}"
+        client.close()
diff --git a/tests/test_read_simple_bitext.py b/tests/test_read_simple_bitext.py
new file mode 100644
index 00000000..f9581a36
--- /dev/null
+++ b/tests/test_read_simple_bitext.py
@@ -0,0 +1,68 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pathlib import Path
+
+import pytest
+
+from nemo_curator.datasets.parallel_dataset import ParallelDataset
+
+
+# The source/target file paths will be concatenated into the document ID in `ParallelDataset`, so we can't directly pass a `Path` object.
+@pytest.fixture
+def src_file():
+    return Path("tests/bitext_data/toy.de").absolute().as_posix()
+
+
+@pytest.fixture
+def tgt_file():
+    return Path("tests/bitext_data/toy.en").absolute().as_posix()
+
+
+class TestReadSimpleBitext:
+
+    def test_pandas_read_simple_bitext(self, src_file, tgt_file):
+        ds = ParallelDataset.read_simple_bitext(
+            src_input_files=[src_file],
+            tgt_input_files=[tgt_file],
+            src_lang="de",
+            tgt_lang="en",
+            backend="pandas",
+        )
+
+        for idx, (src_line, tgt_line) in enumerate(
+            zip(open(src_file, encoding="utf-8"), open(tgt_file, encoding="utf-8"))
+        ):
+            assert ds.df["src"].compute()[idx] == src_line.rstrip("\n")
+            assert ds.df["tgt"].compute()[idx] == tgt_line.rstrip("\n")
+            assert ds.df["src_lang"].compute()[idx] == "de"
+            assert ds.df["tgt_lang"].compute()[idx] == "en"
+
+    @pytest.mark.gpu
+    def test_cudf_read_simple_bitext(self, src_file, tgt_file):
+        ds = ParallelDataset.read_simple_bitext(
+            src_input_files=[src_file],
+            tgt_input_files=[tgt_file],
+            src_lang="de",
+            tgt_lang="en",
+            backend="cudf",
+        )
+
+        for idx, (src_line, tgt_line) in enumerate(
+            zip(open(src_file, encoding="utf-8"), open(tgt_file, encoding="utf-8"))
+        ):
+            assert ds.df["src"].compute()[idx] == src_line.rstrip("\n")
+            assert ds.df["tgt"].compute()[idx] == tgt_line.rstrip("\n")
+            assert ds.df["src_lang"].compute()[idx] == "de"
+            assert ds.df["tgt_lang"].compute()[idx] == "en"
diff --git a/tutorials/bitext_cleaning/README.md b/tutorials/bitext_cleaning/README.md
new file mode 100644
index 00000000..63859570
--- /dev/null
+++ b/tutorials/bitext_cleaning/README.md
@@ -0,0 +1,21 @@
+# Bitext Cleaning
+
+This tutorial demonstrates how to use the bitext-specific functionality in NeMo Curator's API to load and filter the English-German split of the [Multi-Target TED Talks](https://www.cs.jhu.edu/~kevinduh/a/multitarget-tedtalks/) dataset. These files are based on [TED Talks](https://www.ted.com/).
+
+The dataset contains a small training set of ~150k sentences and dev and test sets of ~2k sentences each. This tutorial only downloads and filters the training set.
+
+## Walkthrough
+
+This tutorial highlights several bitext-specific functionalities within NeMo Curator's API, including:
+1. The `ParallelDataset` class
+2. Length ratio filtering via the `LengthRatioFilter` and `JointScoreFilter` classes
+3. Histogram-based language ID filtering via the `HistogramFilter` and `ParallelScoreFilter` classes
+4. (Optional, GPU only) Model-based filtering using a quality estimation model (`Unbabel/wmt20-comet-qe-da`) and the `JointScoreFilter` class
+
+## Usage

+After installing the NeMo Curator package, you can simply run the following command:
+```
+python tutorials/bitext_cleaning/main.py
+```
+This will download the English-German training data and run both length ratio filtering and histogram-based language ID filtering. The filtered data will be saved in the `data/` directory inside this tutorial folder.
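+
+## Using the API directly
+
+If you prefer to drive the same pipeline from Python, the minimal sketch below mirrors what `main.py` does; the file paths are placeholders for the files the tutorial downloads, and only the CPU-friendly filters are shown:
+
+```python
+from nemo_curator import ParallelScoreFilter, Sequential
+from nemo_curator.datasets.parallel_dataset import ParallelDataset
+from nemo_curator.filters import HistogramFilter, LengthRatioFilter
+from nemo_curator.utils.distributed_utils import get_client
+
+client = get_client(n_workers=1)
+
+# Placeholder paths: point these at the downloaded TED Talks files.
+dataset = ParallelDataset.read_simple_bitext(
+    "data/ted_train_en-de.raw.en",
+    "data/ted_train_en-de.raw.de",
+    "en",
+    "de",
+    add_filename=True,
+)
+
+# The same two CPU filters used by the tutorial script.
+filters = Sequential(
+    [
+        LengthRatioFilter(
+            max_ratio=2,
+            src_lang="en",
+            tgt_lang="de",
+            score_field="length_ratio",
+            score_type=float,
+        ),
+        ParallelScoreFilter(
+            HistogramFilter(lang="en"),
+            HistogramFilter(lang="de"),
+            src_score="src_hist",
+            tgt_score="tgt_hist",
+            score_type=int,
+        ),
+    ]
+)
+
+filtered = filters(dataset)
+filtered.to_bitext("data/curated", write_to_filename=True)
+client.close()
+```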
diff --git a/tutorials/bitext_cleaning/docbuilder.py b/tutorials/bitext_cleaning/docbuilder.py
new file mode 100644
index 00000000..0c7ab693
--- /dev/null
+++ b/tutorials/bitext_cleaning/docbuilder.py
@@ -0,0 +1,39 @@
+import os
+
+import requests
+
+from nemo_curator.download.doc_builder import DocumentDownloader
+
+
+class TedTalksDownloader(DocumentDownloader):
+    def __init__(self, download_dir: str):
+        super().__init__()
+
+        if not os.path.isdir(download_dir):
+            os.makedirs(download_dir)
+
+        self._download_dir = download_dir
+        print("Download directory: ", self._download_dir)
+
+    def download(self, url_src: str, url_tgt: str, force=False) -> dict:
+        src_filename = os.path.basename(url_src)
+        src_out = os.path.join(self._download_dir, src_filename)
+        tgt_filename = os.path.basename(url_tgt)
+        tgt_out = os.path.join(self._download_dir, tgt_filename)
+
+        sides = ["src", "tgt"]
+        urls = [url_src, url_tgt]
+        output_files = {"src": src_out, "tgt": tgt_out}
+        for side, url, out_file_key in zip(sides, urls, output_files):
+            if os.path.exists(output_files[out_file_key]) and force is False:
+                print(
+                    f"File '{output_files[out_file_key]}' already exists, skipping download."
+                )
+            else:
+                print(f"Downloading TED Talks dataset from '{url}'...")
+                response = requests.get(url)
+
+                with open(output_files[out_file_key], "wb") as file:
+                    file.write(response.content)
+
+        return output_files
diff --git a/tutorials/bitext_cleaning/main.py b/tutorials/bitext_cleaning/main.py
new file mode 100644
index 00000000..58e52ed5
--- /dev/null
+++ b/tutorials/bitext_cleaning/main.py
@@ -0,0 +1,133 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
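+
+# This tutorial script downloads the English-German TED Talks training bitext, filters it with a
+# length-ratio filter and histogram-based language ID filters (and, optionally on GPU, a COMET
+# quality estimation filter), and writes the curated bitext to the data/curated/ directory.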
+
+import argparse
+import os
+import shutil
+from functools import partial
+from typing import Any
+
+from docbuilder import TedTalksDownloader
+
+from nemo_curator import ParallelScoreFilter, Sequential
+from nemo_curator.datasets.parallel_dataset import ParallelDataset
+from nemo_curator.filters import (
+    HistogramFilter,
+    LengthRatioFilter,
+    QualityEstimationFilter,
+)
+from nemo_curator.utils.distributed_utils import get_client
+from nemo_curator.utils.script_utils import ArgumentHelper
+
+TED_DE_URL = "https://www.cs.jhu.edu/~kevinduh/a/multitarget-tedtalks/t/en-de/raw/ted_train_en-de.raw.de"
+TED_EN_URL = "https://www.cs.jhu.edu/~kevinduh/a/multitarget-tedtalks/t/en-de/raw/ted_train_en-de.raw.en"
+SRC_LANG = "en"
+TGT_LANG = "de"
+
+SCRIPT_DIR_PATH = os.path.dirname(os.path.abspath(__file__))
+DATA_DIR = os.path.join(SCRIPT_DIR_PATH, "data")
+
+
+def download_files() -> tuple[str, str]:
+    downloader = TedTalksDownloader(DATA_DIR)
+
+    # ted_files is a dict with 'src' and 'tgt' keys and the corresponding downloaded file paths as values
+    ted_files = downloader.download(TED_EN_URL, TED_DE_URL, force=False)
+    return ted_files["src"], ted_files["tgt"]
+
+
+def filter_dataset(dataset: ParallelDataset, gpu: bool = False) -> ParallelDataset:
+    filters = Sequential(
+        [
+            LengthRatioFilter(
+                max_ratio=2,
+                src_lang=SRC_LANG,
+                tgt_lang=TGT_LANG,
+                score_field="length_ratio",
+                score_type=float,
+            ),
+            ParallelScoreFilter(
+                HistogramFilter(lang=SRC_LANG),
+                HistogramFilter(lang=TGT_LANG),
+                src_score="src_hist",
+                tgt_score="tgt_hist",
+                score_type=int,
+            ),
+        ]
+    )
+
+    if gpu:
+        filters.modules.append(
+            QualityEstimationFilter(
+                "comet-qe",
+                cutoff=-0.25,
+                gpu=gpu,
+                metadata_fields=["src_lang", "tgt_lang"],
+                score_type=float,
+            )
+        )
+    else:
+        print("Running on CPU, so skipping QE filtering to save time")
+
+    filtered_dataset = filters(dataset)
+    return filtered_dataset
+
+
+def run_curation_pipeline(args: Any, src_file: str, tgt_file: str) -> None:
+    # Initialize the Dask cluster.
+    client = get_client(**ArgumentHelper.parse_client_args(args))
+    print(f"Running curation pipeline on '{src_file}' and '{tgt_file}'...")
+
+    print("Reading the data...")
+
+    bitext_dataset = ParallelDataset.read_simple_bitext(
+        src_file, tgt_file, SRC_LANG, TGT_LANG, add_filename=True, npartitions=16
+    )
+    curation_steps = Sequential(
+        [
+            partial(filter_dataset, gpu=(args.device == "gpu")),
+        ]
+    )
+
+    dataset = curation_steps(bitext_dataset)
+    print("Executing the pipeline...")
+    dataset = dataset.persist()
+
+    print(f"Original dataset length: {len(bitext_dataset.df)}")
+    print(f"After dataprep: {len(dataset.df)}")
+    print("Writing the results to disk...")
+
+    # Overwrite existing files in the curated directory.
+    out_path = os.path.join(DATA_DIR, "curated")
+
+    if os.path.isdir(out_path):
+        shutil.rmtree(out_path)
+
+    os.makedirs(out_path)
+    dataset.to_bitext(out_path, write_to_filename=True)
+    client.close()
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    args = ArgumentHelper(parser).add_distributed_args().parse_args()
+    # Limit the total number of workers to ensure we don't run out of memory.
+    args.n_workers = min(args.n_workers, 8)
+
+    src_file, tgt_file = download_files()
+    run_curation_pipeline(args, src_file, tgt_file)
+
+
+if __name__ == "__main__":
+    main()