diff --git a/nemo_curator/_deduplicator.py b/nemo_curator/_deduplicator.py
new file mode 100644
index 00000000..5be65e8d
--- /dev/null
+++ b/nemo_curator/_deduplicator.py
@@ -0,0 +1,111 @@
+import warnings
+from abc import ABC
+from typing import Optional
+
+import dask.dataframe as dd
+
+from nemo_curator.datasets.doc_dataset import DocumentDataset
+
+
+def _perform_removal(
+    left: dd.DataFrame,
+    duplicates: dd.DataFrame,
+    id_field: str,
+    group_field: str,
+) -> dd.DataFrame:
+    new_id_field = f"{id_field}_new"
+
+    duplicates_to_remove = (
+        duplicates.map_partitions(lambda x: x[x[group_field].duplicated(keep="first")])
+        .drop(columns=group_field)
+        .rename(columns={id_field: new_id_field})[[new_id_field]]
+    )
+    merge = left.merge(
+        right=duplicates_to_remove,
+        how="left",
+        broadcast=True,
+        left_on=id_field,
+        right_on=new_id_field,
+    )
+    removed_result = merge[merge[new_id_field].isna()].drop(columns=[new_id_field])
+    return removed_result
+
+
+class Deduplicator(ABC):
+    def __init__(
+        self,
+        id_field: str,
+        text_field: str,
+        grouped_field: str,
+        cache_dir: Optional[str] = None,
+        **kwargs,
+    ):
+        self.id_field = id_field
+        self.text_field = text_field
+        self.grouped_field = grouped_field
+        self.cache_dir = cache_dir
+
+    def identify(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def remove(
+        self, dataset: DocumentDataset, duplicates: DocumentDataset
+    ) -> DocumentDataset:
+        """
+        Parameters
+        ----------
+        dataset: DocumentDataset
+            The input dataset to remove duplicates from.
+
+        duplicates: DocumentDataset
+            The dataset containing IDs of all documents and the corresponding duplicate group
+            they belong to. Documents in the same group are considered duplicates.
+            Only the first document in each group is retained.
+
+        Returns
+        -------
+        DocumentDataset of all documents with duplicates removed.
+        """
+        if self.cache_dir is None:
+            msg = "Cache directory should be specified for improved performance during the removal step."
+            warnings.warn(msg)
+
+        left = dataset.df
+        right = duplicates.df
+
+        if left.npartitions < right.npartitions:
+            msg = (
+                "The number of partitions in the dataset to remove duplicates from is less than the number of partitions in the duplicates dataset. "
+                "This may lead to a shuffle join. Please re-read the datasets and call nemo_curator._deduplicator._perform_removal explicitly."
+            )
+            raise ValueError(msg)
+
+        removed_result = _perform_removal(
+            left=left,
+            duplicates=right,
+            id_field=self.id_field,
+            group_field=self.grouped_field,
+        )
+        return DocumentDataset(removed_result)
+
+    def __call__(
+        self, dataset: DocumentDataset, perform_removal: bool = False
+    ) -> DocumentDataset:
+        """
+        Parameters
+        ----------
+        dataset: DocumentDataset
+            The input dataset to remove duplicates from.
+        perform_removal: bool
+            If True, duplicates are removed from the dataset. If False, only the duplicates are identified.
+
+        Returns
+        -------
+        DocumentDataset of all duplicates (id field, group field) if perform_removal is False.
+        DocumentDataset of all documents with duplicates removed if perform_removal is True.
+ + """ + duplicates = self.identify(dataset) + if perform_removal: + return self.remove(dataset, duplicates) + return duplicates diff --git a/nemo_curator/modules/exact_dedup.py b/nemo_curator/modules/exact_dedup.py index d274d0e3..9516d9ec 100644 --- a/nemo_curator/modules/exact_dedup.py +++ b/nemo_curator/modules/exact_dedup.py @@ -18,7 +18,6 @@ import time import warnings from contextlib import nullcontext -from datetime import datetime from hashlib import md5 from typing import Optional, Union @@ -27,13 +26,14 @@ from dask import dataframe as dd from nemo_curator._compat import DASK_P2P_ERROR +from nemo_curator._deduplicator import Deduplicator from nemo_curator.datasets import DocumentDataset from nemo_curator.log import create_logger from nemo_curator.utils.distributed_utils import performance_report_if_with_ts_suffix from nemo_curator.utils.gpu_utils import is_cudf_type -class ExactDuplicates: +class ExactDuplicates(Deduplicator): """Find exact duplicates in a document corpus""" SUPPORTED_HASHES = {"md5"} @@ -64,6 +64,14 @@ def __init__( raise ValueError( f"{hash_method} not in supported hash_methods. Choose a hash_method from {self.SUPPORTED_HASHES}" ) + + super().__init__( + id_field=id_field, + text_field=text_field, + grouped_field="_hashes", + cache_dir=cache_dir, + ) + self.hash_method = hash_method self.id_field = id_field self.text_field = text_field @@ -135,7 +143,7 @@ def hash_documents( # TODO: Generalize ty using self.hash_method return df.apply(lambda x: md5(x.encode()).hexdigest()) - def __call__(self, dataset: DocumentDataset) -> Union[DocumentDataset, str]: + def identify(self, dataset: DocumentDataset) -> DocumentDataset: """ Find document ID's for exact duplicates in a given DocumentDataset Parameters @@ -166,10 +174,11 @@ def __call__(self, dataset: DocumentDataset) -> Union[DocumentDataset, str]: self._logger.info( f"Time taken for Exact Dedup Computation = {time.time() - t0}s and output written at {write_path}" ) - if is_cudf_type(result): - import dask_cudf - - result_dataset = dask_cudf.read_parquet(write_path, split_row_groups=False) - else: - result_dataset = dd.read_parquet(write_path) - return DocumentDataset(result_dataset) + backend = "cudf" if is_cudf_type(result) else "pandas" + return DocumentDataset.read_parquet( + write_path, + backend=backend, + blocksize="512MiB", + files_per_partition=None, + split_row_groups=False, + ) diff --git a/nemo_curator/modules/fuzzy_dedup/fuzzyduplicates.py b/nemo_curator/modules/fuzzy_dedup/fuzzyduplicates.py index 41ef9c9f..1f126f95 100644 --- a/nemo_curator/modules/fuzzy_dedup/fuzzyduplicates.py +++ b/nemo_curator/modules/fuzzy_dedup/fuzzyduplicates.py @@ -19,8 +19,7 @@ import time from typing import Union -import dask_cudf - +from nemo_curator._deduplicator import Deduplicator from nemo_curator.datasets import DocumentDataset from nemo_curator.log import create_logger from nemo_curator.modules.config import FuzzyDuplicatesConfig @@ -35,7 +34,7 @@ from nemo_curator.utils.distributed_utils import performance_report_if_with_ts_suffix -class FuzzyDuplicates: +class FuzzyDuplicates(Deduplicator): def __init__( self, config: FuzzyDuplicatesConfig, @@ -63,6 +62,13 @@ def __init__( self._logger = logger self.config = config + + super().__init__( + id_field=self.config.id_field, + text_field=self.config.text_field, + grouped_field="group", + ) + self.minhash = MinHash( seed=self.config.seed, num_hashes=self.config.num_hashes, @@ -129,7 +135,7 @@ def __init__( profile_dir=self.config.profile_dir, ) - def __call__(self, 
+    def identify(self, dataset: DocumentDataset) -> DocumentDataset:
         """
         Parameters
         ----------
@@ -243,4 +249,10 @@ def __call__(self, dataset: DocumentDataset):
         print(f"Stage {stage_num}: Connected Components across buckets complete!")
         stage_num += 1
 
-        return DocumentDataset(dask_cudf.read_parquet(cc_path, split_row_groups=False))
+        return DocumentDataset.read_parquet(
+            cc_path,
+            backend="cudf",
+            blocksize="512MiB",
+            files_per_partition=None,
+            split_row_groups=False,
+        )
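
Usage sketch (not part of the patch): a minimal example of how the identify/remove flow introduced by the new Deduplicator base class could be driven. The input path, field names, and cache directory below are illustrative assumptions, not values taken from the diff.

from nemo_curator.datasets import DocumentDataset
from nemo_curator.modules.exact_dedup import ExactDuplicates

# Hypothetical input corpus with "id" and "text" columns.
dataset = DocumentDataset.read_json("input_data/*.jsonl", backend="pandas")

exact_dedup = ExactDuplicates(
    id_field="id",
    text_field="text",
    hash_method="md5",
    # Per the new warning in Deduplicator.remove, a cache_dir is recommended
    # so the removal step does not recompute/reshuffle the identified duplicates.
    cache_dir="./exact_dedup_cache",
)

# Identify only: returns document IDs together with their duplicate group
# (the "_hashes" column used as grouped_field for exact dedup).
duplicates = exact_dedup(dataset, perform_removal=False)

# Identify and remove in one call: only the first document in each
# duplicate group is retained, as described in the remove() docstring.
deduplicated = exact_dedup(dataset, perform_removal=True)
deduplicated.to_json("deduped_output/")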