Skip to content

Commit

Permalink
pre-commit
Browse files Browse the repository at this point in the history
Signed-off-by: Praateek <[email protected]>
  • Loading branch information
praateekmahajan committed Jan 30, 2025
1 parent 69c8955 commit de25476
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions tests/test_deduplicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from nemo_curator._deduplicator import Deduplicator, _perform_removal
from nemo_curator.datasets import DocumentDataset


@pytest.fixture()
def dummy_deduplicator():
class TestDeduplicator(Deduplicator):
Expand All @@ -17,12 +18,13 @@ def __init__(self):
grouped_field="group",
cache_dir=None,
)

def identify(self, ds: DocumentDataset):
""" Dummy identify which marks all documents as duplicate """
"""Dummy identify which marks all documents as duplicate"""
df = ds.df.drop(columns=[self.text_field])
df[self.grouped_field] = 0
return DocumentDataset(df[[self.id_field, self.grouped_field]])

return TestDeduplicator()


Expand Down Expand Up @@ -112,13 +114,13 @@ def test_not_remove_unique(ids: list[str], sample_data: dd.DataFrame):
assert set(ids[30:]).issubset(set(result["id"].tolist()))


def test_deduplicator_class(dummy_deduplicator : Deduplicator):
def test_deduplicator_class(dummy_deduplicator: Deduplicator):
# Create sample dataframes with specific partition counts
df1 = dd.from_pandas(pd.DataFrame({
'id': ['a1', 'a2', 'a3'],
'text': ['text1', 'text2', 'text3']
}), npartitions=2) # dataset with 2 partitions
df1 = dd.from_pandas(
pd.DataFrame({"id": ["a1", "a2", "a3"], "text": ["text1", "text2", "text3"]}),
npartitions=2,
) # dataset with 2 partitions

dataset = DocumentDataset(df1)
duplicates = dummy_deduplicator.identify(dataset)
assert isinstance(duplicates, DocumentDataset)
Expand All @@ -128,17 +130,15 @@ def test_deduplicator_class(dummy_deduplicator : Deduplicator):
assert isinstance(result, DocumentDataset)
result = result.df.compute()
assert len(result) == 1
assert list(result.columns) == ['id', 'text']
assert list(result.columns) == ["id", "text"]

# Test that it raises ValueError when right npartitions are greater than left npartitions
with pytest.raises(ValueError) as exc_info:
dummy_deduplicator.remove(dataset, duplicates.repartition(npartitions=3))

expected_msg = (
"The number of partitions in the dataset to remove duplicates from is less than "
"the number of partitions in the duplicates dataset. This may lead to a shuffle "
"join. Please re-read the datasets and call nemo_curator._deduplicat.perform_merge explicitly."
)
assert str(exc_info.value) == expected_msg


0 comments on commit de25476

Please sign in to comment.