diff --git a/nemo_curator/_deduplicator.py b/nemo_curator/_deduplicator.py
new file mode 100644
index 00000000..ad8adfb7
--- /dev/null
+++ b/nemo_curator/_deduplicator.py
@@ -0,0 +1,130 @@
+import warnings
+from abc import ABC
+from typing import Optional
+
+import dask.dataframe as dd
+
+from nemo_curator.datasets.doc_dataset import DocumentDataset
+
+
+def _perform_removal(
+    left: dd.DataFrame,
+    duplicates: dd.DataFrame,
+    id_field: str,
+    group_field: str,
+) -> dd.DataFrame:
+    # Create a new column name for temporary ID storage during the merge
+    new_id_field = f"{id_field}_new"
+
+    duplicates_to_remove = (
+        duplicates
+        # Redistribute data across partitions so that all duplicates are in the same partition
+        .shuffle(on=[group_field], ignore_index=True)
+        # For each partition, keep only the duplicated rows (excluding the first occurrence)
+        .map_partitions(lambda x: x[x[group_field].duplicated(keep="first")]).drop(
+            columns=group_field
+        )
+        # Rename the ID field to avoid conflicts in the upcoming merge
+        .rename(columns={id_field: new_id_field})[[new_id_field]]
+    )
+
+    merge = left.merge(
+        right=duplicates_to_remove,
+        how="left",
+        broadcast=True,  # Broadcast smaller DataFrame to all partitions
+        left_on=id_field,
+        right_on=new_id_field,
+    )
+
+    # Keep only the rows that did not match duplicates_to_remove,
+    # i.e. drop every duplicate except the first occurrence in each group
+    removed_result = merge[merge[new_id_field].isna()].drop(columns=[new_id_field])
+    return removed_result
+
+
+class Deduplicator(ABC):
+    def __init__(
+        self,
+        id_field: str,
+        text_field: str,
+        grouped_field: str,
+        cache_dir: Optional[str] = None,
+        **kwargs,
+    ):
+        self.id_field = id_field
+        self.text_field = text_field
+        self.grouped_field = grouped_field
+        self.cache_dir = cache_dir
+
+    def identify(self, *args, **kwargs):
+        """Abstract method to be implemented by concrete deduplicator classes.
+        Should implement the logic for identifying duplicates in the dataset."""
+        raise NotImplementedError
+
+    def remove(
+        self, dataset: DocumentDataset, duplicates: DocumentDataset
+    ) -> DocumentDataset:
+        """
+        Remove duplicate documents from the dataset based on identified duplicate groups.
+
+        Parameters
+        ----------
+        dataset: DocumentDataset
+            The input dataset to remove duplicates from.
+
+        duplicates: DocumentDataset
+            The dataset containing IDs of all documents and the corresponding duplicate group
+            they belong to. Documents in the same group are considered duplicates.
+            Only the first document in each group is retained.
+
+        Returns
+        -------
+        DocumentDataset of all documents with duplicates removed.
+        """
+        if self.cache_dir is None:
+            msg = "Cache directory should be specified for improved performance during the removal step."
+            warnings.warn(msg)
+
+        left = dataset.df
+        right = duplicates.df
+
+        print(f"{left.npartitions=}, {right.npartitions=}")
+        if left.npartitions < right.npartitions:
+            msg = (
+                "The number of partitions in the dataset to remove duplicates from is less than the number of partitions in the duplicates dataset. "
+                "This may lead to a shuffle join. Please re-read the datasets and call nemo_curator._deduplicator._perform_removal explicitly."
+            )
+            raise ValueError(msg)
+
+        removed_result = _perform_removal(
+            left=left,
+            duplicates=right,
+            id_field=self.id_field,
+            group_field=self.grouped_field,
+        )
+        return DocumentDataset(removed_result)
+
+    def __call__(
+        self, dataset: DocumentDataset, perform_removal: bool = False
+    ) -> DocumentDataset:
+        """
+        Main entry point for the deduplication process.
+
+        Parameters
+        ----------
+        dataset: DocumentDataset
+            The input dataset to remove duplicates from.
+        perform_removal: bool
+            If True, duplicates are removed from the dataset.
+            If False, only the duplicates are identified.
+
+        Returns
+        -------
+        DocumentDataset of all duplicates (id field, group field) if perform_removal is False.
+        DocumentDataset of all documents with duplicates removed if perform_removal is True.
+        """
+        # First identify duplicates
+        duplicates = self.identify(dataset)
+        # Then optionally remove them
+        if perform_removal:
+            return self.remove(dataset, duplicates)
+        return duplicates
diff --git a/nemo_curator/modules/exact_dedup.py b/nemo_curator/modules/exact_dedup.py
index d274d0e3..9516d9ec 100644
--- a/nemo_curator/modules/exact_dedup.py
+++ b/nemo_curator/modules/exact_dedup.py
@@ -18,7 +18,6 @@
 import time
 import warnings
 from contextlib import nullcontext
-from datetime import datetime
 from hashlib import md5
 from typing import Optional, Union
 
@@ -27,13 +26,14 @@
 from dask import dataframe as dd
 
 from nemo_curator._compat import DASK_P2P_ERROR
+from nemo_curator._deduplicator import Deduplicator
 from nemo_curator.datasets import DocumentDataset
 from nemo_curator.log import create_logger
 from nemo_curator.utils.distributed_utils import performance_report_if_with_ts_suffix
 from nemo_curator.utils.gpu_utils import is_cudf_type
 
 
-class ExactDuplicates:
+class ExactDuplicates(Deduplicator):
     """Find exact duplicates in a document corpus"""
 
     SUPPORTED_HASHES = {"md5"}
@@ -64,6 +64,14 @@ def __init__(
             raise ValueError(
                 f"{hash_method} not in supported hash_methods. Choose a hash_method from {self.SUPPORTED_HASHES}"
             )
+
+        super().__init__(
+            id_field=id_field,
+            text_field=text_field,
+            grouped_field="_hashes",
+            cache_dir=cache_dir,
+        )
+
         self.hash_method = hash_method
         self.id_field = id_field
         self.text_field = text_field
@@ -135,7 +143,7 @@ def hash_documents(
         # TODO: Generalize ty using self.hash_method
         return df.apply(lambda x: md5(x.encode()).hexdigest())
 
-    def __call__(self, dataset: DocumentDataset) -> Union[DocumentDataset, str]:
+    def identify(self, dataset: DocumentDataset) -> DocumentDataset:
         """
         Find document ID's for exact duplicates in a given DocumentDataset
         Parameters
@@ -166,10 +174,11 @@ def __call__(self, dataset: DocumentDataset) -> Union[DocumentDataset, str]:
         self._logger.info(
             f"Time taken for Exact Dedup Computation = {time.time() - t0}s and output written at {write_path}"
         )
-        if is_cudf_type(result):
-            import dask_cudf
-
-            result_dataset = dask_cudf.read_parquet(write_path, split_row_groups=False)
-        else:
-            result_dataset = dd.read_parquet(write_path)
-        return DocumentDataset(result_dataset)
+        backend = "cudf" if is_cudf_type(result) else "pandas"
+        return DocumentDataset.read_parquet(
+            write_path,
+            backend=backend,
+            blocksize="512MiB",
+            files_per_partition=None,
+            split_row_groups=False,
+        )
diff --git a/nemo_curator/modules/fuzzy_dedup/connectedcomponents.py b/nemo_curator/modules/fuzzy_dedup/connectedcomponents.py
index 1394ae9a..a346a998 100644
--- a/nemo_curator/modules/fuzzy_dedup/connectedcomponents.py
+++ b/nemo_curator/modules/fuzzy_dedup/connectedcomponents.py
@@ -125,8 +125,6 @@ def _run_connected_components(
             f"# rows in labels_df = {len(labels_df)}"
         )
         assert num_nodes == len(labels_df)
-        # Ensure all docs in the same group are in the same partition
-        labels_df = labels_df.shuffle(on=["group"], ignore_index=True)
         labels_df.to_parquet(output_path, write_index=False, overwrite=True)
         Comms.destroy()
         self._logger.info(
diff --git a/nemo_curator/modules/fuzzy_dedup/fuzzyduplicates.py b/nemo_curator/modules/fuzzy_dedup/fuzzyduplicates.py
index 41ef9c9f..1f126f95 100644
--- a/nemo_curator/modules/fuzzy_dedup/fuzzyduplicates.py
+++ b/nemo_curator/modules/fuzzy_dedup/fuzzyduplicates.py
@@ -19,8 +19,7 @@
 import time
 from typing import Union
 
-import dask_cudf
-
+from nemo_curator._deduplicator import Deduplicator
 from nemo_curator.datasets import DocumentDataset
 from nemo_curator.log import create_logger
 from nemo_curator.modules.config import FuzzyDuplicatesConfig
@@ -35,7 +34,7 @@
 from nemo_curator.utils.distributed_utils import performance_report_if_with_ts_suffix
 
 
-class FuzzyDuplicates:
+class FuzzyDuplicates(Deduplicator):
     def __init__(
         self,
         config: FuzzyDuplicatesConfig,
@@ -63,6 +62,13 @@ def __init__(
         self._logger = logger
 
         self.config = config
+
+        super().__init__(
+            id_field=self.config.id_field,
+            text_field=self.config.text_field,
+            grouped_field="group",
+        )
+
         self.minhash = MinHash(
             seed=self.config.seed,
             num_hashes=self.config.num_hashes,
@@ -129,7 +135,7 @@ def __init__(
             profile_dir=self.config.profile_dir,
         )
 
-    def __call__(self, dataset: DocumentDataset):
+    def identify(self, dataset: DocumentDataset) -> DocumentDataset:
         """
         Parameters
         ----------
@@ -243,4 +249,10 @@ def __call__(self, dataset: DocumentDataset):
         print(f"Stage {stage_num}: Connected Components across buckets complete!")
         stage_num += 1
 
-        return DocumentDataset(dask_cudf.read_parquet(cc_path, split_row_groups=False))
+        return DocumentDataset.read_parquet(
+            cc_path,
+            backend="cudf",
+            blocksize="512MiB",
+            files_per_partition=None,
+            split_row_groups=False,
+        )
diff --git a/tests/test_deduplicator.py b/tests/test_deduplicator.py
new file mode 100644
index 00000000..3b681070
--- /dev/null
+++ b/tests/test_deduplicator.py
@@ -0,0 +1,144 @@
+import random
+
+import pandas as pd
+import pytest
+from dask import dataframe as dd
+
+from nemo_curator._deduplicator import Deduplicator, _perform_removal
+from nemo_curator.datasets import DocumentDataset
+
+
+@pytest.fixture()
+def dummy_deduplicator():
+    class TestDeduplicator(Deduplicator):
+        def __init__(self):
+            super().__init__(
+                id_field="id",
+                text_field="text",
+                grouped_field="group",
+                cache_dir=None,
+            )
+
+        def identify(self, ds: DocumentDataset):
+            """Dummy identify that marks all documents as duplicates of one another"""
+            df = ds.df.drop(columns=[self.text_field])
+            df[self.grouped_field] = 0
+            return DocumentDataset(df[[self.id_field, self.grouped_field]])
+
+    return TestDeduplicator()
+
+
+@pytest.fixture()
+def ids():
+    # Dataset has ids a0...a9, b0...b9, c0...c9, d0...d9
+    l = [f"{group}{i}" for group in ["a", "b", "c", "d"] for i in range(10)]
+    # Shuffle so that duplicates are not all in the same partition
+    random.shuffle(l)
+    return l
+
+
+@pytest.fixture
+def sample_data(ids):
+    df = pd.DataFrame(
+        {
+            "id": ids,
+            "text": [f"text for {_id}" for _id in ids],
+        }
+    )
+    return dd.from_pandas(df, npartitions=4)
+
+
+@pytest.fixture
+def duplicate_data(ids):
+    # Group each id by its first character; only the first occurrence in each group
+    # should be kept during removal
+    df = pd.DataFrame([{"id": _id, "group": _id[0]} for _id in ids])
+    # The ids are already shuffled, so duplicates are spread across partitions
+    return dd.from_pandas(df, npartitions=2)
+
+
+def test_perform_removal_basic(sample_data: dd.DataFrame, duplicate_data: dd.DataFrame):
+    # Test basic duplicate removal functionality
+    result = _perform_removal(
+        left=sample_data, duplicates=duplicate_data, id_field="id", group_field="group"
+    )
+
+    result = result.compute()
+
+    assert list(result.columns) == ["id", "text"]
+    assert len(result) == 4
+    # It's not guaranteed that we'll have a0, b0, c0, d0 in the result,
+    # so we only check the first character of each surviving id
+    assert set(result["id"].apply(lambda x: x[0]).tolist()) == set(["a", "b", "c", "d"])
+
+
+def test_perform_removal_all_duplicates(ids: list[str], sample_data: dd.DataFrame):
+    duplicates = dd.from_pandas(
+        pd.DataFrame({"id": ids, "group": [1] * len(ids)}), npartitions=2
+    )
+
+    result = _perform_removal(
+        left=sample_data, duplicates=duplicates, id_field="id", group_field="group"
+    )
+
+    result = result.compute()
+    assert list(result.columns) == ["id", "text"]
+    # Should keep only one of the occurrences
+    assert len(result) == 1
+
+
+def test_not_remove_unique(ids: list[str], sample_data: dd.DataFrame):
+    # We create a duplicates dataset where the first 30 ids are in one group,
+    # the next 9 ids are in distinct groups,
+    # and the last id is not mentioned in duplicates at all
+    duplicates = dd.from_pandas(
+        pd.DataFrame(
+            {
+                "id": ids[:30] + ids[30:39],
+                "group": ["group0"] * 30 + [f"group{i}" for i in range(1, 10)],
+            }
+        ),
+        npartitions=2,
+    )
+    result = _perform_removal(
+        left=sample_data, duplicates=duplicates, id_field="id", group_field="group"
+    )
+
+    result = result.compute()
+    assert list(result.columns) == ["id", "text"]
+    # The result has 1 row from the first group of 30,
+    # 9 rows from the 9 distinct groups,
+    # and 1 row for the last id, which is not included in the set of duplicates
+    assert len(result) == 1 + 9 + 1
+    # The last 10 ids should all be in the result, plus one more from the first 30
+    assert set(ids[30:]).issubset(set(result["id"].tolist()))
+
+
+def test_deduplicator_class(dummy_deduplicator: Deduplicator):
+    # Create a sample dataframe with a specific partition count
+    df1 = dd.from_pandas(
+        pd.DataFrame({"id": ["a1", "a2", "a3"], "text": ["text1", "text2", "text3"]}),
+        npartitions=2,
+    )  # dataset with 2 partitions
+
+    dataset = DocumentDataset(df1)
+    duplicates = dummy_deduplicator.identify(dataset)
+    assert isinstance(duplicates, DocumentDataset)
+
+    # We are able to perform deduplication successfully
+    result = dummy_deduplicator.remove(dataset, duplicates)
+    assert isinstance(result, DocumentDataset)
+    result = result.df.compute()
+    assert len(result) == 1
+    assert list(result.columns) == ["id", "text"]
+
+    # remove() raises ValueError when duplicates has more partitions than the dataset
+    with pytest.raises(ValueError) as exc_info:
+        dummy_deduplicator.remove(dataset, duplicates.repartition(npartitions=3))
+
+    expected_msg = (
+        "The number of partitions in the dataset to remove duplicates from is less than "
+        "the number of partitions in the duplicates dataset. This may lead to a shuffle "
+        "join. Please re-read the datasets and call nemo_curator._deduplicator._perform_removal explicitly."
+    )
+    assert str(exc_info.value) == expected_msg
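
Reviewer note (not part of the patch): a minimal usage sketch of the refactored API this change introduces. It assumes the classes land where the patch puts them (`nemo_curator.modules.exact_dedup.ExactDuplicates`), that a Dask or Dask-cuDF client is already running, and that `DocumentDataset.read_json` / `to_parquet` behave as in the existing `DocumentDataset` API; the paths and the `id` / `text` field names are placeholders.

```python
# Illustrative sketch only; paths, field names, and the read_json/to_parquet
# helpers are assumptions about the surrounding nemo_curator API.
from nemo_curator.datasets import DocumentDataset
from nemo_curator.modules.exact_dedup import ExactDuplicates

# Load a dataset with "id" and "text" columns (placeholder path).
dataset = DocumentDataset.read_json("data/", backend="pandas")

exact_dedup = ExactDuplicates(
    id_field="id",
    text_field="text",
    hash_method="md5",
    cache_dir="./exact_dedup_cache",  # recommended: remove() warns when this is None
)

# Identify only: returns one row per duplicate document with its "_hashes" group.
duplicates = exact_dedup(dataset, perform_removal=False)

# Identify and remove in one call: keeps the first document of every duplicate group.
deduplicated = exact_dedup(dataset, perform_removal=True)
deduplicated.to_parquet("deduplicated/")
```

The same two-step flow applies to `FuzzyDuplicates`, since both classes now inherit `identify`, `remove`, and `__call__(dataset, perform_removal=...)` from the shared `Deduplicator` base.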