diff --git a/docs/user-guide/api/dask.rst b/docs/user-guide/api/dask.rst
index 8a549eba..b09d3715 100644
--- a/docs/user-guide/api/dask.rst
+++ b/docs/user-guide/api/dask.rst
@@ -4,4 +4,7 @@ Dask Cluster Functions
 
 .. autofunction:: nemo_curator.get_client
 
-.. autofunction:: nemo_curator.get_network_interfaces
\ No newline at end of file
+.. autofunction:: nemo_curator.get_network_interfaces
+
+.. autoclass:: nemo_curator.ToBackend
+    :members:
\ No newline at end of file
diff --git a/docs/user-guide/cpuvsgpu.rst b/docs/user-guide/cpuvsgpu.rst
index 0ee26baf..337c6e63 100644
--- a/docs/user-guide/cpuvsgpu.rst
+++ b/docs/user-guide/cpuvsgpu.rst
@@ -85,6 +85,46 @@ To read a dataset into GPU memory, one could use the following function call.
 Even if you start a GPU dask cluster, you can't operate on datasets that use a ``pandas`` backend.
 The ``DocumentDataset`` must either have been originally read in with a ``cudf`` backend, or it must be transferred during the script.
 
+-----------------------------------------
+Moving data between CPU and GPU
+-----------------------------------------
+
+The ``ToBackend`` module provides a way to move data between CPU memory and GPU memory by swapping between pandas and cuDF backends for your dataset.
+To see how it works, take a look at this example.
+
+.. code-block:: python
+
+    from nemo_curator import Sequential, ToBackend, ScoreFilter, get_client
+    from nemo_curator.datasets import DocumentDataset
+    from nemo_curator.classifiers import DomainClassifier
+    from nemo_curator.filters import RepeatingTopNGramsFilter, NonAlphaNumericFilter
+
+    def main():
+        client = get_client(cluster_type="gpu")
+
+        dataset = DocumentDataset.read_json("books.jsonl")
+        curation_pipeline = Sequential([
+            ScoreFilter(RepeatingTopNGramsFilter(n=5)),
+            ToBackend("cudf"),
+            DomainClassifier(),
+            ToBackend("pandas"),
+            ScoreFilter(NonAlphaNumericFilter()),
+        ])
+
+        curated_dataset = curation_pipeline(dataset)
+
+        curated_dataset.to_json("curated_books.jsonl")
+
+    if __name__ == "__main__":
+        main()
+
+Let's highlight some of the important parts of this example.
+
+* ``client = get_client(cluster_type="gpu")``: Creates a local Dask cluster with access to the GPUs. To use or swap to a cuDF dataframe backend, make sure you are running on a GPU Dask cluster.
+* ``dataset = DocumentDataset.read_json("books.jsonl")``: Reads in the dataset to a pandas (CPU) backend by default.
+* ``curation_pipeline = ...``: Defines a curation pipeline consisting of a CPU filtering step, a GPU classifier step, and another CPU filtering step. The ``ToBackend("cudf")`` moves the dataset from CPU to GPU for the classifier, and the ``ToBackend("pandas")`` moves the dataset back to the CPU from the GPU for the last filter.
+* ``curated_dataset.to_json("curated_books.jsonl")``: Writes the dataset directly to disk from the GPU. There is no need to transfer back to the CPU before writing to disk.
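+
+If you write your own curation modules, you can opt into the same backend checking by subclassing ``BaseModule`` and declaring which dataframe backend the module expects. When a dataset arrives on the wrong backend, the pipeline raises a ``ValueError`` that suggests adding a ``ToBackend`` step. The snippet below is a minimal sketch (the module name and the ``text_length`` column are only illustrative):
+
+.. code-block:: python
+
+    from nemo_curator import BaseModule
+    from nemo_curator.datasets import DocumentDataset
+
+    class TextLength(BaseModule):
+        def __init__(self):
+            # Declare the dataframe backend this module needs: "pandas", "cudf", or "any"
+            super().__init__(input_backend="pandas")
+
+        def call(self, dataset: DocumentDataset) -> DocumentDataset:
+            # Subclasses implement call(); BaseModule.__call__ validates the backend first
+            dataset.df["text_length"] = dataset.df["text"].str.len()
+            return dataset
+
+Such a module can then be dropped into a ``Sequential`` pipeline alongside ``ToBackend`` steps, exactly as in the example above.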
+ ----------------------------------------- Dask with Slurm ----------------------------------------- diff --git a/nemo_curator/classifiers/base.py b/nemo_curator/classifiers/base.py index 4f8cdc25..585bdbc5 100644 --- a/nemo_curator/classifiers/base.py +++ b/nemo_curator/classifiers/base.py @@ -15,7 +15,7 @@ from dataclasses import dataclass os.environ["RAPIDS_NO_INITIALIZE"] = "1" -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import List, Optional, Union import torch @@ -25,10 +25,11 @@ from transformers import AutoModel from nemo_curator.datasets import DocumentDataset +from nemo_curator.modules.base import BaseModule from nemo_curator.utils.distributed_utils import get_gpu_memory_info -class DistributedDataClassifier(ABC): +class DistributedDataClassifier(BaseModule): """Abstract class for running multi-node multi-GPU data classification""" def __init__( @@ -43,6 +44,7 @@ def __init__( device_type: str, autocast: bool, ): + super().__init__(input_backend="cudf") self.model = model self.labels = labels self.filter_by = filter_by @@ -53,7 +55,7 @@ def __init__( self.device_type = device_type self.autocast = autocast - def __call__(self, dataset: DocumentDataset) -> DocumentDataset: + def call(self, dataset: DocumentDataset) -> DocumentDataset: result_doc_dataset = self._run_classifier(dataset) if self.filter_by is not None: return self._filter_documents(result_doc_dataset) diff --git a/nemo_curator/filters/doc_filter.py b/nemo_curator/filters/doc_filter.py index e84765e4..e3f3a7de 100644 --- a/nemo_curator/filters/doc_filter.py +++ b/nemo_curator/filters/doc_filter.py @@ -14,7 +14,7 @@ import importlib from abc import ABC, abstractmethod -from typing import Any, Union +from typing import Any, Literal, Union from nemo_curator.filters.bitext_filter import BitextFilter @@ -81,6 +81,16 @@ def keep_document(self, scores: Any) -> bool: "keep_document method must be implemented by subclasses" ) + @property + def backend(self) -> Literal["pandas", "cudf", "any"]: + """ + The dataframe backend the filter operates on. + Can be 'pandas', 'cudf', or 'any'. Defaults to 'pandas'. + Returns: + str: A string representing the dataframe backend the filter needs as input + """ + return "pandas" + @property def name(self): return self._name diff --git a/nemo_curator/modifiers/doc_modifier.py b/nemo_curator/modifiers/doc_modifier.py index 1bcde8f5..b9b900b5 100644 --- a/nemo_curator/modifiers/doc_modifier.py +++ b/nemo_curator/modifiers/doc_modifier.py @@ -13,6 +13,7 @@ # limitations under the License. from abc import ABC, abstractmethod +from typing import Literal class DocumentModifier(ABC): @@ -26,3 +27,13 @@ def __init__(self): @abstractmethod def modify_document(self, text): pass + + @property + def backend(self) -> Literal["pandas", "cudf", "any"]: + """ + The dataframe backend the modifier operates on. + Can be 'pandas', 'cudf', or 'any'. Defaults to 'pandas'. 
+ Returns: + str: A string representing the dataframe backend the modifier needs as input + """ + return "pandas" diff --git a/nemo_curator/modules/__init__.py b/nemo_curator/modules/__init__.py index 897e5402..54e0377b 100644 --- a/nemo_curator/modules/__init__.py +++ b/nemo_curator/modules/__init__.py @@ -22,12 +22,14 @@ from nemo_curator.utils.import_utils import gpu_only_import_from from .add_id import AddId +from .base import BaseModule from .config import FuzzyDuplicatesConfig, SemDedupConfig from .dataset_ops import blend_datasets, Shuffle from .exact_dedup import ExactDuplicates from .meta import Sequential from .modify import Modify from .task import TaskDecontamination +from .to_backend import ToBackend # GPU packages MinHash = gpu_only_import_from("nemo_curator.modules.fuzzy_dedup.minhash", "MinHash") @@ -88,4 +90,6 @@ "ClusteringModel", "SemanticClusterLevelDedup", "SemDedup", + "BaseModule", + "ToBackend", ] diff --git a/nemo_curator/modules/add_id.py b/nemo_curator/modules/add_id.py index 24416391..e7e733fc 100644 --- a/nemo_curator/modules/add_id.py +++ b/nemo_curator/modules/add_id.py @@ -19,18 +19,20 @@ from dask import delayed from nemo_curator.datasets import DocumentDataset +from nemo_curator.modules.base import BaseModule from nemo_curator.utils.module_utils import count_digits -class AddId: +class AddId(BaseModule): def __init__( self, id_field, id_prefix: str = "doc_id", start_index: Optional[int] = None ) -> None: + super().__init__(input_backend="pandas") self.id_field = id_field self.id_prefix = id_prefix self.start_index = start_index - def __call__(self, dataset: DocumentDataset) -> DocumentDataset: + def call(self, dataset: DocumentDataset) -> DocumentDataset: if self.start_index is None: return self._add_id_fast(dataset) else: diff --git a/nemo_curator/modules/base.py b/nemo_curator/modules/base.py new file mode 100644 index 00000000..ba2f3dbc --- /dev/null +++ b/nemo_curator/modules/base.py @@ -0,0 +1,84 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from typing import Literal, Optional + +import dask.dataframe as dd + +from nemo_curator.datasets import DocumentDataset +from nemo_curator.utils.gpu_utils import is_cudf_type + + +class BaseModule(ABC): + """ + Base class for all NeMo Curator modules. + + Handles validating that data lives on the correct device for each module + """ + + SUPPORTED_BACKENDS = ["pandas", "cudf", "any"] + + def __init__( + self, + input_backend: Literal["pandas", "cudf", "any"], + name: Optional[str] = None, + ) -> None: + """ + Constructs a Module + + Args: + input_backend (Literal["pandas", "cudf", "any"]): The backend the input dataframe must be on for the module to work + name (str, Optional): The name of the module. 
If None, defaults to self.__class__.__name__ + """ + super().__init__() + self.name = name or self.__class__.__name__ + + if input_backend not in self.SUPPORTED_BACKENDS: + raise ValueError( + f"{input_backend} not one of the supported backends {self.SUPPORTED_BACKENDS}" + ) + self.input_backend = input_backend + + @abstractmethod + def call(self, dataset: DocumentDataset): + """ + Performs an arbitrary operation on a dataset + + Args: + dataset (DocumentDataset): The dataset to operate on + """ + raise NotImplementedError("call method must be implemented by subclasses") + + def _validate_correct_backend(self, ddf: dd.DataFrame): + if self.input_backend == "any": + return + + backend = "cudf" if is_cudf_type(ddf) else "pandas" + if backend != self.input_backend: + raise ValueError( + f"Module {self.name} requires dataset to have backend {self.input_backend} but got backend {backend}." + "Try using nemo_curator.ToBackend to swap dataframe backends before running this module." + ) + + def __call__(self, dataset: DocumentDataset): + """ + Validates the dataset is on the right backend, and performs an arbitrary operation on it + + Args: + dataset (DocumentDataset): The dataset to operate on + """ + self._validate_correct_backend(dataset.df) + + return self.call(dataset) diff --git a/nemo_curator/modules/dataset_ops.py b/nemo_curator/modules/dataset_ops.py index f11184a5..8240e0cf 100644 --- a/nemo_curator/modules/dataset_ops.py +++ b/nemo_curator/modules/dataset_ops.py @@ -5,13 +5,14 @@ import numpy as np from nemo_curator.datasets.doc_dataset import DocumentDataset +from nemo_curator.modules.base import BaseModule def default_filename(partition_num: int) -> str: return f"file_{partition_num:010d}.jsonl" -class Shuffle: +class Shuffle(BaseModule): def __init__( self, seed: Optional[int] = None, @@ -32,13 +33,14 @@ def __init__( will look like given the partition number. The default method names the partition f'file_{partition_num:010d}.jsonl' and should be changed if the user is not using a .jsonl format. """ + super().__init__(input_backend="pandas") self.seed = seed self.npartitions = npartitions self.partition_to_filename = partition_to_filename self.rand_col = "_shuffle_rand" self.filename_col = filename_col - def __call__(self, dataset: DocumentDataset) -> DocumentDataset: + def call(self, dataset: DocumentDataset) -> DocumentDataset: if self.seed is None: return self.shuffle_nondeterministic(dataset) else: diff --git a/nemo_curator/modules/exact_dedup.py b/nemo_curator/modules/exact_dedup.py index d274d0e3..5d65ff1b 100644 --- a/nemo_curator/modules/exact_dedup.py +++ b/nemo_curator/modules/exact_dedup.py @@ -29,11 +29,12 @@ from nemo_curator._compat import DASK_P2P_ERROR from nemo_curator.datasets import DocumentDataset from nemo_curator.log import create_logger +from nemo_curator.modules.base import BaseModule from nemo_curator.utils.distributed_utils import performance_report_if_with_ts_suffix from nemo_curator.utils.gpu_utils import is_cudf_type -class ExactDuplicates: +class ExactDuplicates(BaseModule): """Find exact duplicates in a document corpus""" SUPPORTED_HASHES = {"md5"} @@ -59,6 +60,7 @@ def __init__( cache_dir: str, Default None If specified, will compute & write duplicate id's to cache directory. 
""" + super().__init__(input_backend="any") if hash_method not in self.SUPPORTED_HASHES: raise ValueError( @@ -135,7 +137,7 @@ def hash_documents( # TODO: Generalize ty using self.hash_method return df.apply(lambda x: md5(x.encode()).hexdigest()) - def __call__(self, dataset: DocumentDataset) -> Union[DocumentDataset, str]: + def call(self, dataset: DocumentDataset) -> Union[DocumentDataset, str]: """ Find document ID's for exact duplicates in a given DocumentDataset Parameters diff --git a/nemo_curator/modules/filter.py b/nemo_curator/modules/filter.py index 1006ee25..88825219 100644 --- a/nemo_curator/modules/filter.py +++ b/nemo_curator/modules/filter.py @@ -22,6 +22,7 @@ from nemo_curator.datasets import DocumentDataset from nemo_curator.datasets.parallel_dataset import ParallelDataset from nemo_curator.filters import DocumentFilter +from nemo_curator.modules.base import BaseModule from nemo_curator.utils.module_utils import is_batched # Override so that pd.NA is not passed during the metadata inference @@ -31,7 +32,7 @@ ) -class Score: +class Score(BaseModule): """ The module responsible for adding metadata to records based on statistics about the text. It accepts an arbitrary scoring function that accepts a text field and returns a score. @@ -56,12 +57,13 @@ def __init__( text_field (str): The field the documents will be read from. score_type (Union[type, str]): The datatype of the score that will be made for each document. """ + super().__init__(input_backend="pandas") self.score_fn = score_fn self.score_field = score_field self.text_field = text_field self.score_type = score_type - def __call__(self, dataset: DocumentDataset) -> DocumentDataset: + def call(self, dataset: DocumentDataset) -> DocumentDataset: """ Applies the scoring to a dataset @@ -89,7 +91,7 @@ def __call__(self, dataset: DocumentDataset) -> DocumentDataset: return dataset -class Filter: +class Filter(BaseModule): """ The module responsible for filtering records based on a metadata field. It accepts an arbitrary filter function that accepts a metadata field and returns True if the field should be kept. @@ -107,6 +109,7 @@ def __init__(self, filter_fn: Callable, filter_field: str, invert: bool = False) filter_field (str): The field(s) to be passed into the filter function. invert (bool): Whether to invert the filter condition. """ + super().__init__(input_backend="pandas") self.filter_fn = filter_fn self.filter_field = filter_field self.invert = invert @@ -134,7 +137,7 @@ def compute_filter_mask(self, dataset: DocumentDataset): return bool_mask - def __call__(self, dataset: DocumentDataset) -> DocumentDataset: + def call(self, dataset: DocumentDataset) -> DocumentDataset: """ Applies the filtering to a dataset @@ -148,7 +151,7 @@ def __call__(self, dataset: DocumentDataset) -> DocumentDataset: return DocumentDataset(dataset.df[bool_mask]) -class ScoreFilter: +class ScoreFilter(BaseModule): """ The module responsible for applying a filter to all documents in a DocumentDataset. It accepts an arbitrary DocumentFilter and first computes the score for a document. @@ -176,6 +179,7 @@ def __init__( score_type (Union[type, str]): The datatype of the score that will be made for each document. invert (bool): If True, will keep all documents that are normally discarded. 
""" + super().__init__(input_backend=filter_obj.backend) self.filter_obj = filter_obj self.text_field = text_field self.score_field = score_field @@ -219,7 +223,7 @@ def compute_filter_mask(self, dataset: DocumentDataset): return bool_mask - def __call__(self, dataset: DocumentDataset) -> DocumentDataset: + def call(self, dataset: DocumentDataset) -> DocumentDataset: """ Scores and filters all records in the dataset @@ -233,7 +237,7 @@ def __call__(self, dataset: DocumentDataset) -> DocumentDataset: return DocumentDataset(dataset.df[bool_mask]) -class ParallelScoreFilter: +class ParallelScoreFilter(BaseModule): def __init__( self, src_filter_obj, @@ -263,7 +267,7 @@ def __init__( score_type (Optional[str]): The datatype of the score that will be made for each document. Defaults to None. invert (bool, optional): If True, will keep all documents that are normally discarded. Defaults to False. """ - + super().__init__(input_backend=src_filter_obj.backend) self.source_score_filter = ScoreFilter( src_filter_obj, src_field, src_score, score_type, invert ) @@ -271,7 +275,7 @@ def __init__( tgt_filter_obj, tgt_field, tgt_score, score_type, invert ) - def __call__(self, dataset: ParallelDataset): + def call(self, dataset: ParallelDataset): src_bool_mask = self.source_score_filter.compute_filter_mask(dataset) tgt_bool_mask = self.target_score_filter.compute_filter_mask(dataset) diff --git a/nemo_curator/modules/fuzzy_dedup/fuzzyduplicates.py b/nemo_curator/modules/fuzzy_dedup/fuzzyduplicates.py index 41ef9c9f..516fe9c4 100644 --- a/nemo_curator/modules/fuzzy_dedup/fuzzyduplicates.py +++ b/nemo_curator/modules/fuzzy_dedup/fuzzyduplicates.py @@ -23,6 +23,7 @@ from nemo_curator.datasets import DocumentDataset from nemo_curator.log import create_logger +from nemo_curator.modules.base import BaseModule from nemo_curator.modules.config import FuzzyDuplicatesConfig from nemo_curator.modules.fuzzy_dedup._mapbuckets import _MapBuckets from nemo_curator.modules.fuzzy_dedup._shuffle import _Shuffle @@ -35,7 +36,7 @@ from nemo_curator.utils.distributed_utils import performance_report_if_with_ts_suffix -class FuzzyDuplicates: +class FuzzyDuplicates(BaseModule): def __init__( self, config: FuzzyDuplicatesConfig, @@ -53,6 +54,7 @@ def __init__( DocumentDataset containing IDs of all documents and the corresponding duplicate group they belong to. Documents in the same group are near duplicates. """ + super().__init__(input_backend="cudf") if isinstance(logger, str): self._logger = create_logger( rank=0, @@ -129,7 +131,7 @@ def __init__( profile_dir=self.config.profile_dir, ) - def __call__(self, dataset: DocumentDataset): + def call(self, dataset: DocumentDataset): """ Parameters ---------- diff --git a/nemo_curator/modules/modify.py b/nemo_curator/modules/modify.py index 1307ab17..93f31cc9 100644 --- a/nemo_curator/modules/modify.py +++ b/nemo_curator/modules/modify.py @@ -11,18 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- from nemo_curator.datasets import DocumentDataset from nemo_curator.modifiers import DocumentModifier +from nemo_curator.modules.base import BaseModule from nemo_curator.utils.module_utils import is_batched -class Modify: +class Modify(BaseModule): def __init__(self, modifier: DocumentModifier, text_field="text"): + super().__init__(input_backend=modifier.backend) self.modifier = modifier self.text_field = text_field - def __call__(self, dataset: DocumentDataset) -> DocumentDataset: + def call(self, dataset: DocumentDataset) -> DocumentDataset: if is_batched(self.modifier.modify_document): dataset.df[self.text_field] = dataset.df[self.text_field].map_partitions( self.modifier.modify_document, meta=(None, str) diff --git a/nemo_curator/modules/semantic_dedup/semdedup.py b/nemo_curator/modules/semantic_dedup/semdedup.py index eff5e2ec..b86a468b 100644 --- a/nemo_curator/modules/semantic_dedup/semdedup.py +++ b/nemo_curator/modules/semantic_dedup/semdedup.py @@ -18,6 +18,7 @@ from typing import Union from nemo_curator.datasets import DocumentDataset +from nemo_curator.modules.base import BaseModule from nemo_curator.modules.config import SemDedupConfig from nemo_curator.modules.semantic_dedup.clusteringmodel import ClusteringModel from nemo_curator.modules.semantic_dedup.embeddings import EmbeddingCreator @@ -26,7 +27,7 @@ ) -class SemDedup: +class SemDedup(BaseModule): def __init__( self, config: SemDedupConfig, @@ -42,6 +43,7 @@ def __init__( config (SemDedupConfig): Configuration for SemDedup. logger (Union[logging.Logger, str]): Logger instance or path to the log file directory. """ + super().__init__(input_backend="cudf") self.config = config self.logger = logger cache_dir = config.cache_dir @@ -80,7 +82,7 @@ def __init__( self.eps_thresholds = config.eps_thresholds self.eps_to_extract = config.eps_to_extract - def __call__(self, dataset: DocumentDataset) -> DocumentDataset: + def call(self, dataset: DocumentDataset) -> DocumentDataset: """ Execute the SemDedup process. diff --git a/nemo_curator/modules/task.py b/nemo_curator/modules/task.py index 2571b6a8..06b12af1 100644 --- a/nemo_curator/modules/task.py +++ b/nemo_curator/modules/task.py @@ -20,12 +20,13 @@ from dask import delayed from nemo_curator.datasets import DocumentDataset +from nemo_curator.modules.base import BaseModule from nemo_curator.tasks.downstream_task import DownstreamTask from nemo_curator.utils.distributed_utils import single_partition_write_with_filename from nemo_curator.utils.text_utils import get_words -class TaskDecontamination: +class TaskDecontamination(BaseModule): def __init__( self, tasks: Union[DownstreamTask, Iterable[DownstreamTask]], @@ -47,6 +48,7 @@ def __init__( max_splits: The maximum number of times a document may be split before being entirely discarded. removed_dir: If not None, the documents split too many times will be written to this directory using the filename in the dataset. 
""" + super().__init__(input_backend="pandas") if isinstance(tasks, DownstreamTask): tasks = [tasks] self.tasks = tasks @@ -58,7 +60,7 @@ def __init__( self.max_splits = max_splits self.removed_dir = removed_dir - def __call__(self, dataset: DocumentDataset) -> DocumentDataset: + def call(self, dataset: DocumentDataset) -> DocumentDataset: # Convert the dataframe to delayed objects for complex operations original_meta = dataset.df.dtypes.to_dict() diff --git a/nemo_curator/modules/to_backend.py b/nemo_curator/modules/to_backend.py new file mode 100644 index 00000000..0bfcafc6 --- /dev/null +++ b/nemo_curator/modules/to_backend.py @@ -0,0 +1,40 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Literal + +from nemo_curator.datasets.doc_dataset import DocumentDataset +from nemo_curator.modules.base import BaseModule + + +class ToBackend(BaseModule): + """ + A module for moving dataframes between backends. + """ + + def __init__(self, backend: Literal["pandas", "cudf"]) -> None: + """ + Constructs a ToBackend module + + Args: + backend (str): The backend to transfer the dataset to. Can be "pandas" or "cudf" + """ + super().__init__(input_backend="any") + if backend not in BaseModule.SUPPORTED_BACKENDS: + raise ValueError( + f"Backend must be either 'pandas' or 'cudf', but got {backend}" + ) + self.backend = backend + + def call(self, dataset: DocumentDataset) -> DocumentDataset: + return DocumentDataset(dataset.df.to_backend(self.backend)) diff --git a/tests/test_backends.py b/tests/test_backends.py new file mode 100644 index 00000000..6acd87b5 --- /dev/null +++ b/tests/test_backends.py @@ -0,0 +1,428 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import pandas as pd +import pytest +from dask.dataframe.utils import assert_eq +from distributed import Client + +from nemo_curator import ( + BaseModule, + FuzzyDuplicates, + FuzzyDuplicatesConfig, + ScoreFilter, + Sequential, + ToBackend, +) +from nemo_curator.datasets import DocumentDataset +from nemo_curator.filters import MeanWordLengthFilter +from nemo_curator.utils.import_utils import gpu_only_import, gpu_only_import_from + +cudf = gpu_only_import("cudf") +dask_cudf = gpu_only_import("dask_cudf") +LocalCUDACluster = gpu_only_import_from("dask_cuda", "LocalCUDACluster") + + +class CPUModule(BaseModule): + def __init__(self): + super().__init__(input_backend="pandas") + + def call(self, dataset: DocumentDataset): + dataset.df["cpu_lengths"] = dataset.df["text"].str.len() + return dataset + + +class GPUModule(BaseModule): + def __init__(self): + super().__init__(input_backend="cudf") + + def call(self, dataset: DocumentDataset): + dataset.df["gpu_lengths"] = dataset.df["text"].str.len() + return dataset + + +class AnyModule(BaseModule): + def __init__(self): + super().__init__(input_backend="any") + + def call(self, dataset: DocumentDataset): + dataset.df["any_lengths"] = dataset.df["text"].str.len() + return dataset + + +@pytest.fixture +def raw_data(): + base_data = { + "id": [1, 2, 3, 4, 100, 200, 300], + "text": [ + "The quick brown fox jumps over the lazy dog", + "The quick brown foxes jumps over the lazy dog", + "The quick brown wolf jumps over the lazy dog", + "The quick black cat jumps over the lazy dog", + "A test string", + "Another test string", + "A different object", + ], + } + gt_results = [43, 45, 44, 43, 13, 19, 18] + + return base_data, gt_results + + +@pytest.fixture +def cpu_data(raw_data): + base_data, gt_results = raw_data + df = pd.DataFrame(base_data) + gt_lengths = pd.Series(gt_results, name="cpu_lengths") + return DocumentDataset.from_pandas(df), gt_lengths + + +@pytest.fixture +def gpu_data(raw_data): + base_data, gt_results = raw_data + df = cudf.DataFrame(base_data) + df = dask_cudf.from_cudf(df, 2) + gt_lengths = cudf.Series(gt_results, name="gpu_lengths", dtype="int32") + return DocumentDataset(df), gt_lengths + + +@pytest.mark.gpu +class TestBackendSupport: + @pytest.fixture(autouse=True, scope="class") + def gpu_client(self, request): + with LocalCUDACluster(n_workers=1) as cluster, Client(cluster) as client: + request.cls.client = client + request.cls.cluster = cluster + yield + + def test_pandas_backend( + self, + cpu_data, + ): + print("client", self.client) + dataset, gt_lengths = cpu_data + pipeline = CPUModule() + result = pipeline(dataset) + result_df = result.df.compute() + assert_eq(result_df["cpu_lengths"], gt_lengths) + + def test_cudf_backend( + self, + gpu_data, + ): + print("client", self.client) + dataset, gt_lengths = gpu_data + pipeline = GPUModule() + result = pipeline(dataset) + result_df = result.df.compute() + assert_eq(result_df["gpu_lengths"], gt_lengths) + + def test_any_backend( + self, + cpu_data, + gpu_data, + ): + print("client", self.client) + cpu_dataset, gt_cpu_lengths = cpu_data + gt_cpu_lengths = gt_cpu_lengths.rename("any_lengths") + gpu_dataset, gt_gpu_lengths = gpu_data + gt_gpu_lengths = gt_gpu_lengths.rename("any_lengths") + pipeline = AnyModule() + + cpu_result = pipeline(cpu_dataset) + cpu_result_df = cpu_result.df.compute() + assert_eq(cpu_result_df["any_lengths"], gt_cpu_lengths) + gpu_result = pipeline(gpu_dataset) + gpu_result_df = gpu_result.df.compute() + assert_eq(gpu_result_df["any_lengths"], 
gt_gpu_lengths) + + def test_pandas_to_cudf( + self, + cpu_data, + gpu_data, + ): + print("client", self.client) + dataset, gt_cpu_lengths = cpu_data + _, gt_gpu_lengths = gpu_data + pipeline = Sequential( + [ + CPUModule(), + ToBackend("cudf"), + GPUModule(), + ] + ) + result = pipeline(dataset) + result_df = result.df.compute() + assert_eq(result_df["cpu_lengths"], gt_cpu_lengths) + assert_eq(result_df["gpu_lengths"], gt_gpu_lengths) + + def test_cudf_to_pandas( + self, + cpu_data, + gpu_data, + ): + print("client", self.client) + _, gt_cpu_lengths = cpu_data + dataset, gt_gpu_lengths = gpu_data + pipeline = Sequential( + [ + GPUModule(), + ToBackend("pandas"), + CPUModule(), + ] + ) + result = pipeline(dataset) + result_df = result.df.compute() + assert_eq(result_df["cpu_lengths"], gt_cpu_lengths) + assert_eq(result_df["gpu_lengths"], gt_gpu_lengths) + + def test_5x_switch( + self, + cpu_data, + gpu_data, + ): + print("client", self.client) + dataset, gt_cpu_lengths = cpu_data + _, gt_gpu_lengths = gpu_data + pipeline = Sequential( + [ + CPUModule(), + ToBackend("cudf"), + GPUModule(), + ToBackend("pandas"), + CPUModule(), + ToBackend("cudf"), + GPUModule(), + ToBackend("pandas"), + CPUModule(), + ToBackend("cudf"), + GPUModule(), + ] + ) + result = pipeline(dataset) + result_df = result.df.compute() + assert sorted(list(result_df.columns)) == [ + "cpu_lengths", + "gpu_lengths", + "id", + "text", + ] + assert_eq(result_df["cpu_lengths"], gt_cpu_lengths) + assert_eq(result_df["gpu_lengths"], gt_gpu_lengths) + + def test_wrong_backend_cpu_data(self, cpu_data): + with pytest.raises(ValueError): + print("client", self.client) + dataset, _ = cpu_data + pipeline = GPUModule() + result = pipeline(dataset) + _ = result.df.compute() + + def test_wrong_backend_gpu_data(self, gpu_data): + with pytest.raises(ValueError): + print("client", self.client) + dataset, _ = gpu_data + pipeline = CPUModule() + result = pipeline(dataset) + _ = result.df.compute() + + def test_unsupported_to_backend(self, cpu_data): + with pytest.raises(ValueError): + print("client", self.client) + dataset, _ = cpu_data + pipeline = ToBackend("fake_backend") + result = pipeline(dataset) + _ = result.df.compute() + + +@pytest.fixture +def real_module_raw_data(): + base_data = { + "id": [1, 2, 3, 4, 100, 200, 300], + "text": [ + "The quick brown fox jumps over the lazy dog", + "The quick brown foxes jumps over the lazy dog", + "The quick brown wolf jumps over the lazy dog", + "The quick black cat jumps over the lazy dog", + "A test string", + "Another test string", + "A different object", + ], + } + return base_data + + +@pytest.fixture +def real_module_cpu_data(real_module_raw_data): + df = pd.DataFrame(real_module_raw_data) + gt_results = pd.Series( + [35 / 9, 37 / 9, 4.0, 35 / 9, 33 / 9, 51 / 9, 48 / 9], name="mean_lengths" + ) + return DocumentDataset.from_pandas(df), gt_results + + +@pytest.fixture +def real_module_gpu_data(real_module_raw_data): + df = cudf.DataFrame(real_module_raw_data) + df = dask_cudf.from_cudf(df, 2) + gt_results = cudf.Series([[1, 2, 3, 4], [100, 200]], name="id") + return DocumentDataset(df), gt_results + + +@pytest.mark.gpu +class TestRealModules: + @pytest.fixture(autouse=True, scope="class") + def gpu_client(self, request): + with LocalCUDACluster(n_workers=1) as cluster, Client(cluster) as client: + request.cls.client = client + request.cls.cluster = cluster + yield + + def test_score_filter( + self, + real_module_cpu_data, + ): + print("client", self.client) + dataset, gt_results = 
real_module_cpu_data + pipeline = ScoreFilter( + MeanWordLengthFilter(), score_field="mean_lengths", score_type=float + ) + result = pipeline(dataset) + result_df = result.df.compute() + assert_eq(result_df["mean_lengths"], gt_results) + + def test_score_filter_wrong_backend( + self, + real_module_gpu_data, + ): + with pytest.raises(ValueError): + print("client", self.client) + dataset, _ = real_module_gpu_data + pipeline = ScoreFilter( + MeanWordLengthFilter(), score_field="mean_lengths", score_type=float + ) + result = pipeline(dataset) + _ = result.df.compute() + + def test_fuzzy_dedup( + self, + real_module_gpu_data, + tmpdir, + ): + print(self.client) + dataset, gt_results = real_module_gpu_data + # Dedup might fail when indices per partition do not start from 0 + dataset.df = dataset.df.reset_index(drop=True) + config = FuzzyDuplicatesConfig( + cache_dir=tmpdir, + id_field="id", + text_field="text", + seed=42, + char_ngrams=5, + num_buckets=15, + hashes_per_bucket=1, + use_64_bit_hash=False, + buckets_per_shuffle=3, + false_positive_check=True, + num_anchors=2, + jaccard_threshold=0.3, + ) + fuzzy_duplicates = FuzzyDuplicates(config=config) + result = fuzzy_duplicates(dataset) + result_df = result.df.compute() + # Drop non duplicated docs + result_df = result_df[result_df.group.duplicated(keep=False)] + result_df = result_df.groupby("group").id.agg(list) + # Sort to maintain uniform ordering + + result_df = result_df.list.sort_values() + result_df = result_df.sort_values() + gt_results = gt_results.list.sort_values() + gt_results = gt_results.sort_values() + assert_eq(gt_results, result_df, check_index=False) + + def test_fuzzy_dedup_wrong_backend( + self, + real_module_cpu_data, + tmpdir, + ): + with pytest.raises(ValueError): + print(self.client) + dataset, _ = real_module_cpu_data + # Dedup might fail when indices per partition do not start from 0 + dataset.df = dataset.df.reset_index(drop=True) + config = FuzzyDuplicatesConfig( + cache_dir=tmpdir, + id_field="id", + text_field="text", + seed=42, + char_ngrams=5, + num_buckets=15, + hashes_per_bucket=1, + use_64_bit_hash=False, + buckets_per_shuffle=3, + false_positive_check=True, + num_anchors=2, + jaccard_threshold=0.3, + ) + fuzzy_duplicates = FuzzyDuplicates(config=config) + result = fuzzy_duplicates(dataset) + _ = result.df.compute() + + def test_score_filter_and_fuzzy( + self, + real_module_cpu_data, + real_module_gpu_data, + tmpdir, + ): + print("client", self.client) + dataset, _ = real_module_cpu_data + _, gt_results = real_module_gpu_data + dataset.df = dataset.df.reset_index(drop=True) + config = FuzzyDuplicatesConfig( + cache_dir=tmpdir, + id_field="id", + text_field="text", + seed=42, + char_ngrams=5, + num_buckets=15, + hashes_per_bucket=1, + use_64_bit_hash=False, + buckets_per_shuffle=3, + false_positive_check=True, + num_anchors=2, + jaccard_threshold=0.3, + ) + pipeline = Sequential( + [ + ScoreFilter( + MeanWordLengthFilter(), score_field="mean_lengths", score_type=float + ), + ToBackend("cudf"), + FuzzyDuplicates(config=config), + ] + ) + + result = pipeline(dataset) + result_df = result.df.compute() + # Right now the output of FuzzyDuplicates does not retain the original metadata + # so we simply check the output of fuzzy dedupe to ensure accuracy + # Drop non duplicated docs + result_df = result_df[result_df.group.duplicated(keep=False)] + result_df = result_df.groupby("group").id.agg(list) + # Sort to maintain uniform ordering + result_df = result_df.list.sort_values() + result_df = result_df.sort_values() + 
gt_results = gt_results.list.sort_values() + gt_results = gt_results.sort_values() + assert_eq(gt_results, result_df, check_index=False)