diff --git a/tests/test_deduplicator.py b/tests/test_deduplicator.py index 42da28fa..3b681070 100644 --- a/tests/test_deduplicator.py +++ b/tests/test_deduplicator.py @@ -7,6 +7,7 @@ from nemo_curator._deduplicator import Deduplicator, _perform_removal from nemo_curator.datasets import DocumentDataset + @pytest.fixture() def dummy_deduplicator(): class TestDeduplicator(Deduplicator): @@ -17,12 +18,13 @@ def __init__(self): grouped_field="group", cache_dir=None, ) + def identify(self, ds: DocumentDataset): - """ Dummy identify which marks all documents as duplicate """ + """Dummy identify which marks all documents as duplicate""" df = ds.df.drop(columns=[self.text_field]) df[self.grouped_field] = 0 return DocumentDataset(df[[self.id_field, self.grouped_field]]) - + return TestDeduplicator() @@ -112,13 +114,13 @@ def test_not_remove_unique(ids: list[str], sample_data: dd.DataFrame): assert set(ids[30:]).issubset(set(result["id"].tolist())) -def test_deduplicator_class(dummy_deduplicator : Deduplicator): +def test_deduplicator_class(dummy_deduplicator: Deduplicator): # Create sample dataframes with specific partition counts - df1 = dd.from_pandas(pd.DataFrame({ - 'id': ['a1', 'a2', 'a3'], - 'text': ['text1', 'text2', 'text3'] - }), npartitions=2) # dataset with 2 partitions - + df1 = dd.from_pandas( + pd.DataFrame({"id": ["a1", "a2", "a3"], "text": ["text1", "text2", "text3"]}), + npartitions=2, + ) # dataset with 2 partitions + dataset = DocumentDataset(df1) duplicates = dummy_deduplicator.identify(dataset) assert isinstance(duplicates, DocumentDataset) @@ -128,17 +130,15 @@ def test_deduplicator_class(dummy_deduplicator : Deduplicator): assert isinstance(result, DocumentDataset) result = result.df.compute() assert len(result) == 1 - assert list(result.columns) == ['id', 'text'] + assert list(result.columns) == ["id", "text"] # Test that it raises ValueError when right npartitions are greater than left npartitions with pytest.raises(ValueError) as exc_info: dummy_deduplicator.remove(dataset, duplicates.repartition(npartitions=3)) - + expected_msg = ( "The number of partitions in the dataset to remove duplicates from is less than " "the number of partitions in the duplicates dataset. This may lead to a shuffle " "join. Please re-read the datasets and call nemo_curator._deduplicat.perform_merge explicitly." ) assert str(exc_info.value) == expected_msg - -