-
Notifications
You must be signed in to change notification settings - Fork 112
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Praateek <[email protected]>
- Loading branch information
1 parent
fe41ac1
commit 89b9005
Showing
3 changed files
with
147 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
import warnings | ||
from abc import ABC | ||
from typing import Optional | ||
|
||
import dask.dataframe as dd | ||
|
||
from nemo_curator.datasets.doc_dataset import DocumentDataset | ||
|
||
|
||
def _perform_removal(
    left: dd.DataFrame,
    duplicates: dd.DataFrame,
    id_field: str,
    group_field: str,
) -> dd.DataFrame:
    """
    Drop every document from ``left`` that is a non-first member of a
    duplicate group listed in ``duplicates``.

    Parameters
    ----------
    left: dd.DataFrame
        The dataset to filter; must contain ``id_field``.
    duplicates: dd.DataFrame
        Mapping of document IDs (``id_field``) to duplicate groups
        (``group_field``). The first document seen per group is kept.
    id_field: str
        Name of the document-ID column present in both frames.
    group_field: str
        Name of the duplicate-group column in ``duplicates``.

    Returns
    -------
    dd.DataFrame with the non-first duplicate rows removed.
    """
    # Renamed join key so it cannot collide with a column already in ``left``.
    marker_field = f"{id_field}_new"

    def _non_first_in_group(part):
        # Keep only rows whose group value already appeared earlier in the
        # partition: these are the documents slated for removal.
        return part[part[group_field].duplicated(keep="first")]

    to_remove = duplicates.map_partitions(_non_first_in_group)
    to_remove = to_remove.drop(columns=group_field)
    to_remove = to_remove.rename(columns={id_field: marker_field})
    to_remove = to_remove[[marker_field]]

    # Left join marks removable rows with a non-null marker column;
    # broadcast=True hints dask to ship the (small) removal list to every
    # partition of ``left`` instead of shuffling.
    joined = left.merge(
        right=to_remove,
        how="left",
        broadcast=True,
        left_on=id_field,
        right_on=marker_field,
    )
    # Rows with a null marker found no match in the removal list -> survivors.
    return joined[joined[marker_field].isna()].drop(columns=[marker_field])
|
||
|
||
class Deduplicator(ABC):
    """
    Abstract base for deduplication pipelines.

    Subclasses implement ``identify`` to assign each document to a duplicate
    group; ``remove`` then filters a dataset so that only the first document
    of each group survives.
    """

    def __init__(
        self,
        id_field: str,
        text_field: str,
        grouped_field: str,
        cache_dir: Optional[str] = None,
        **kwargs,
    ):
        # Column holding each document's unique identifier.
        self.id_field = id_field
        # Column holding the document text (consumed by identify implementations).
        self.text_field = text_field
        # Column (produced by identify) naming each document's duplicate group.
        self.grouped_field = grouped_field
        # Optional directory for caching intermediates; recommended for removal.
        self.cache_dir = cache_dir

    def identify(self, *args, **kwargs):
        """Compute duplicate groups for a dataset. Must be overridden by subclasses."""
        raise NotImplementedError

    def remove(
        self, dataset: DocumentDataset, duplicates: DocumentDataset
    ) -> DocumentDataset:
        """
        Remove duplicate documents from ``dataset``.

        Parameters
        ----------
        dataset: DocumentDataset
            The input dataset to remove duplicates from.
        duplicates: DocumentDataset
            The dataset containing IDs of all documents and the corresponding duplicate group
            they belong to. Documents in the same group are considered duplicates.
            Only the first document in each group is retained.

        Returns
        -------
        DocumentDataset of all documents with duplicates removed.

        Raises
        ------
        ValueError
            If ``dataset`` has fewer partitions than ``duplicates`` (the merge
            would degrade into a shuffle join).
        """
        if self.cache_dir is None:
            msg = "Cache directory should be specified for improved performance for removal step."
            warnings.warn(msg)

        left = dataset.df
        right = duplicates.df

        if left.npartitions < right.npartitions:
            # NOTE(review): the module path in this message
            # ("nemo_curator._deduplicat.perform_merge") looks like a typo --
            # this file defines `_perform_removal`; confirm the intended reference.
            msg = (
                "The number of partitions in the dataset to remove duplicates from is less than the number of partitions in the duplicates dataset. "
                "This may lead to a shuffle join. Please re-read the datasets and call nemo_curator._deduplicat.perform_merge explicitly."
            )
            raise ValueError(msg)

        removed_result = _perform_removal(
            left=left,
            duplicates=right,
            id_field=self.id_field,
            group_field=self.grouped_field,
        )
        return DocumentDataset(removed_result)

    def __call__(
        self, dataset: DocumentDataset, perform_removal: bool = False
    ) -> DocumentDataset:
        """
        Identify duplicates in ``dataset`` and optionally remove them.

        Parameters
        ----------
        dataset: DocumentDataset
            The input dataset to remove duplicates from.
        perform_removal: bool
            If True, duplicates are removed from the dataset. If False, only the duplicates are identified.

        Returns
        -------
        DocumentDataset of all duplicates (id field, group field) if perform_removal is False.
        DocumentDataset of all documents with duplicates removed if perform_removal is True.
        """
        duplicates = self.identify(dataset)
        if perform_removal:
            return self.remove(dataset, duplicates)
        return duplicates
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters