Commit: fc
Signed-off-by: Praateek <[email protected]>
praateekmahajan committed Jan 28, 2025
1 parent fe41ac1 commit 89b9005
Showing 3 changed files with 147 additions and 15 deletions.
111 changes: 111 additions & 0 deletions nemo_curator/_deduplicator.py
@@ -0,0 +1,111 @@
import warnings
from abc import ABC
from typing import Optional

import dask.dataframe as dd

from nemo_curator.datasets.doc_dataset import DocumentDataset


def _perform_removal(
    left: dd.DataFrame,
    duplicates: dd.DataFrame,
    id_field: str,
    group_field: str,
) -> dd.DataFrame:
    new_id_field = f"{id_field}_new"  # renamed to avoid a column-name clash during the merge

    # Within each partition, keep the first document of every duplicate group and
    # collect the IDs of the remaining documents, i.e. the ones to remove.
    duplicates_to_remove = (
        duplicates.map_partitions(lambda x: x[x[group_field].duplicated(keep="first")])
        .drop(columns=group_field)
        .rename(columns={id_field: new_id_field})[[new_id_field]]
    )
    # Left join against the removal list (broadcast, since it is typically much
    # smaller than the dataset) and keep only the rows that found no match.
    merge = left.merge(
        right=duplicates_to_remove,
        how="left",
        broadcast=True,
        left_on=id_field,
        right_on=new_id_field,
    )
    removed_result = merge[merge[new_id_field].isna()].drop(columns=[new_id_field])
    return removed_result


class Deduplicator(ABC):
def __init__(
self,
id_field: str,
text_field: str,
grouped_field: str,
cache_dir: Optional[str] = None,
**kwargs,
):
self.id_field = id_field
self.text_field = text_field
self.grouped_field = grouped_field
self.cache_dir = cache_dir

def identify(self, *args, **kwargs):
raise NotImplementedError

def remove(
self, dataset: DocumentDataset, duplicates: DocumentDataset
) -> DocumentDataset:
"""
Parameters
----------
dataset: DocumentDataset
The input datset to remove duplicates from.
duplicates: DocumentDataset
The dataset containing IDs of all documents and the corresponding duplicate group
they belong to. Documents in the same group are considered duplicates.
Only the first document in each group is retained.
Returns
-------
DocumentDataset of all documents with duplicates removed.
"""
if self.cache_dir is None:
msg = "Cache directory should be specified for improved performance for removal step."
warnings.warn(msg)

left = dataset.df
right = duplicates.df

if left.npartitions < right.npartitions:
msg = (
"The number of partitions in the dataset to remove duplicates from is less than the number of partitions in the duplicates dataset. "
"This may lead to a shuffle join. Please re-read the datasets and call nemo_curator._deduplicat.perform_merge explicitly."
)
raise ValueError(msg)

removed_result = _perform_removal(
left=left,
duplicates=right,
id_field=self.id_field,
group_field=self.grouped_field,
)
return DocumentDataset(removed_result)

def __call__(
self, dataset: DocumentDataset, perform_removal: bool = False
) -> DocumentDataset:
"""
Parameters
----------
dataset: DocumentDataset
The input datset to remove duplicates from.
perform_removal: bool
If True, duplicates are removed from the dataset. If False, only the duplicates are identified.
Returns
-------
DocumentDataset of all duplicates (id field, group field) if if perform_removal is False.
DocumentDataset of all documents with duplicates removed if perform_removal is True.
"""
duplicates = self.identify(dataset)
if perform_removal:
return self.remove(dataset, duplicates)
return duplicates
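
To make the behaviour of the new helper concrete, here is a minimal, self-contained sketch (not part of the commit) that runs _perform_removal on a toy pandas-backed Dask DataFrame. Only the import path and the function signature come from the file above; the IDs, texts, and hash values are invented for illustration.

import pandas as pd
import dask.dataframe as dd

from nemo_curator._deduplicator import _perform_removal

# Toy inputs: documents "a" and "b" were flagged as the same duplicate group ("h1");
# only the first member of each group is kept, so "b" should be removed.
docs = dd.from_pandas(
    pd.DataFrame({"id": ["a", "b", "c"], "text": ["foo", "foo", "bar"]}),
    npartitions=1,
)
dupes = dd.from_pandas(
    pd.DataFrame({"id": ["a", "b"], "_hashes": ["h1", "h1"]}),
    npartitions=1,
)

result = _perform_removal(
    left=docs, duplicates=dupes, id_field="id", group_field="_hashes"
)
print(result.compute())  # keeps "a" (first in its group) and "c"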
29 changes: 19 additions & 10 deletions nemo_curator/modules/exact_dedup.py
@@ -18,7 +18,6 @@
import time
import warnings
from contextlib import nullcontext
from datetime import datetime
from hashlib import md5
from typing import Optional, Union

@@ -27,13 +26,14 @@
from dask import dataframe as dd

from nemo_curator._compat import DASK_P2P_ERROR
from nemo_curator._deduplicator import Deduplicator
from nemo_curator.datasets import DocumentDataset
from nemo_curator.log import create_logger
from nemo_curator.utils.distributed_utils import performance_report_if_with_ts_suffix
from nemo_curator.utils.gpu_utils import is_cudf_type


class ExactDuplicates:
class ExactDuplicates(Deduplicator):
"""Find exact duplicates in a document corpus"""

SUPPORTED_HASHES = {"md5"}
Expand Down Expand Up @@ -64,6 +64,14 @@ def __init__(
raise ValueError(
f"{hash_method} not in supported hash_methods. Choose a hash_method from {self.SUPPORTED_HASHES}"
)

super().__init__(
id_field=id_field,
text_field=text_field,
grouped_field="_hashes",
cache_dir=cache_dir,
)

self.hash_method = hash_method
self.id_field = id_field
self.text_field = text_field
Expand Down Expand Up @@ -135,7 +143,7 @@ def hash_documents(
        # TODO: Generalize by using self.hash_method
return df.apply(lambda x: md5(x.encode()).hexdigest())

def __call__(self, dataset: DocumentDataset) -> Union[DocumentDataset, str]:
def identify(self, dataset: DocumentDataset) -> DocumentDataset:
"""
        Find document IDs for exact duplicates in a given DocumentDataset
Parameters
@@ -166,10 +174,11 @@ def __call__(self, dataset: DocumentDataset) -> Union[DocumentDataset, str]:
self._logger.info(
f"Time taken for Exact Dedup Computation = {time.time() - t0}s and output written at {write_path}"
)
if is_cudf_type(result):
import dask_cudf

result_dataset = dask_cudf.read_parquet(write_path, split_row_groups=False)
else:
result_dataset = dd.read_parquet(write_path)
return DocumentDataset(result_dataset)
backend = "cudf" if is_cudf_type(result) else "pandas"
return DocumentDataset.read_parquet(
write_path,
backend=backend,
blocksize="512MiB",
files_per_partition=None,
split_row_groups=False,
)
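
For context, a hedged sketch (not part of the commit) of how the reworked ExactDuplicates is expected to be driven through the inherited __call__. The keyword arguments id_field, text_field, hash_method, and cache_dir are the ones visible in this diff; the cache path and toy data are invented, and any other constructor options are left at their defaults.

import pandas as pd
import dask.dataframe as dd

from nemo_curator.datasets import DocumentDataset
from nemo_curator.modules.exact_dedup import ExactDuplicates

df = pd.DataFrame({"id": [1, 2, 3], "text": ["hello world", "hello world", "something else"]})
dataset = DocumentDataset(dd.from_pandas(df, npartitions=1))

exact_dedup = ExactDuplicates(
    id_field="id",
    text_field="text",
    hash_method="md5",
    cache_dir="./exact_dedup_cache",  # invented path; duplicate IDs and hashes are written here as parquet
)

duplicates = exact_dedup(dataset)                     # identify only: id plus "_hashes" group column
deduped = exact_dedup(dataset, perform_removal=True)  # identify, then keep one document per group
print(deduped.df.compute())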
22 changes: 17 additions & 5 deletions nemo_curator/modules/fuzzy_dedup/fuzzyduplicates.py
@@ -19,8 +19,7 @@
import time
from typing import Union

import dask_cudf

from nemo_curator._deduplicator import Deduplicator
from nemo_curator.datasets import DocumentDataset
from nemo_curator.log import create_logger
from nemo_curator.modules.config import FuzzyDuplicatesConfig
@@ -35,7 +34,7 @@
from nemo_curator.utils.distributed_utils import performance_report_if_with_ts_suffix


class FuzzyDuplicates:
class FuzzyDuplicates(Deduplicator):
def __init__(
self,
config: FuzzyDuplicatesConfig,
@@ -63,6 +62,13 @@ def __init__(
self._logger = logger

self.config = config

super().__init__(
id_field=self.config.id_field,
text_field=self.config.text_field,
grouped_field="group",
)

self.minhash = MinHash(
seed=self.config.seed,
num_hashes=self.config.num_hashes,
@@ -129,7 +135,7 @@ def __init__(
profile_dir=self.config.profile_dir,
)

def __call__(self, dataset: DocumentDataset):
def identify(self, dataset: DocumentDataset) -> DocumentDataset:
"""
Parameters
----------
@@ -243,4 +249,10 @@ def __call__(self, dataset: DocumentDataset):
print(f"Stage {stage_num}: Connected Components across buckets complete!")
stage_num += 1

return DocumentDataset(dask_cudf.read_parquet(cc_path, split_row_groups=False))
return DocumentDataset.read_parquet(
cc_path,
backend="cudf",
blocksize="512MiB",
files_per_partition=None,
split_row_groups=False,
)
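
Likewise, a sketch (not part of the commit) of the fuzzy path. The config fields id_field, text_field, seed, and num_hashes appear in this diff; cache_dir and the assumption that the remaining FuzzyDuplicatesConfig options have workable defaults are not confirmed by the diff, the paths and toy data are invented, and the input is assumed to need the GPU (cudf) backend.

import cudf
import dask_cudf

from nemo_curator.datasets import DocumentDataset
from nemo_curator.modules.config import FuzzyDuplicatesConfig
from nemo_curator.modules.fuzzy_dedup.fuzzyduplicates import FuzzyDuplicates

# Assumption: FuzzyDuplicatesConfig accepts these fields and provides defaults for the rest.
config = FuzzyDuplicatesConfig(
    cache_dir="./fuzzy_dedup_cache",  # assumed field: intermediate minhash/LSH/CC results are written here
    id_field="id",
    text_field="text",
    seed=42,
    num_hashes=260,
)

# Toy GPU-backed dataset; real inputs would be read from disk.
df = cudf.DataFrame(
    {
        "id": [1, 2, 3],
        "text": ["the quick brown fox", "the quick brown fox!", "something else entirely"],
    }
)
dataset = DocumentDataset(dask_cudf.from_cudf(df, npartitions=1))

fuzzy_dedup = FuzzyDuplicates(config=config)
duplicates = fuzzy_dedup(dataset)                     # identify only: (id field, "group") pairs
deduped = fuzzy_dedup(dataset, perform_removal=True)  # identify, then remove in one call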
