Skip to content

Commit

Permalink
more tests for class
Browse files Browse the repository at this point in the history
Signed-off-by: Praateek <[email protected]>
  • Loading branch information
praateekmahajan committed Jan 30, 2025
1 parent 37f6bee commit 69c8955
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 2 deletions.
1 change: 1 addition & 0 deletions nemo_curator/_deduplicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def remove(
left = dataset.df
right = duplicates.df

print(f"{left.npartitions=}, {right.npartitions=}")
if left.npartitions < right.npartitions:
msg = (
"The number of partitions in the dataset to remove duplicates from is less than the number of partitions in the duplicates dataset. "
Expand Down
54 changes: 52 additions & 2 deletions tests/test_deduplicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,27 @@
import pandas as pd
import pytest
from dask import dataframe as dd
from dask.dataframe.utils import assert_eq

from nemo_curator._deduplicator import _perform_removal
from nemo_curator._deduplicator import Deduplicator, _perform_removal
from nemo_curator.datasets import DocumentDataset

@pytest.fixture()
def dummy_deduplicator():
    """Return a minimal Deduplicator whose identify() marks every document
    as a duplicate, collapsing the whole dataset into a single group."""

    class MarkAllDuplicates(Deduplicator):
        def __init__(self):
            super().__init__(
                id_field="id",
                text_field="text",
                grouped_field="group",
                cache_dir=None,
            )

        def identify(self, ds: DocumentDataset):
            """Assign every row to group 0, i.e. flag all documents as duplicates."""
            frame = ds.df.drop(columns=[self.text_field])
            frame[self.grouped_field] = 0
            return DocumentDataset(frame[[self.id_field, self.grouped_field]])

    return MarkAllDuplicates()


@pytest.fixture()
Expand Down Expand Up @@ -92,3 +110,35 @@ def test_not_remove_unique(ids: list[str], sample_data: dd.DataFrame):
assert len(result) == 1 + 9 + 1
# The last 10 ids should be in the result, there would be one more from the first 30
assert set(ids[30:]).issubset(set(result["id"].tolist()))


def test_deduplicator_class(dummy_deduplicator: Deduplicator):
    """Exercise the identify -> remove flow and the partition-count guard."""
    # Two-partition input; the dummy deduplicator groups every row together.
    source_df = dd.from_pandas(
        pd.DataFrame(
            {
                "id": ["a1", "a2", "a3"],
                "text": ["text1", "text2", "text3"],
            }
        ),
        npartitions=2,
    )
    dataset = DocumentDataset(source_df)

    duplicates = dummy_deduplicator.identify(dataset)
    assert isinstance(duplicates, DocumentDataset)

    # Removal succeeds and keeps exactly one representative of the single group.
    removed = dummy_deduplicator.remove(dataset, duplicates)
    assert isinstance(removed, DocumentDataset)
    materialized = removed.df.compute()
    assert len(materialized) == 1
    assert list(materialized.columns) == ["id", "text"]

    # remove() must raise when the duplicates frame has MORE partitions than
    # the dataset (that layout would force a shuffle join).
    # NOTE(review): "nemo_curator._deduplicat" below looks like a typo carried
    # over from the production error message — confirm it matches remove()'s
    # actual text rather than "fixing" it only here.
    expected_msg = (
        "The number of partitions in the dataset to remove duplicates from is less than "
        "the number of partitions in the duplicates dataset. This may lead to a shuffle "
        "join. Please re-read the datasets and call nemo_curator._deduplicat.perform_merge explicitly."
    )
    with pytest.raises(ValueError) as excinfo:
        dummy_deduplicator.remove(dataset, duplicates.repartition(npartitions=3))
    assert str(excinfo.value) == expected_msg


0 comments on commit 69c8955

Please sign in to comment.