add vibhu's suggestions
Signed-off-by: Sarah Yurick <[email protected]>
sarahyurick committed Feb 23, 2025
1 parent b6d239c commit 6a75a1a
Showing 5 changed files with 64 additions and 55 deletions.
15 changes: 11 additions & 4 deletions nemo_curator/modules/semantic_dedup/clusteringmodel.py
@@ -131,10 +131,17 @@ def __call__(self, embeddings_dataset: DocumentDataset):
 
         try:
             embeddings_df = embeddings_df.to_backend("pandas").persist()
-        except IndexError:
-            raise RuntimeError(
-                "DocumentDataset contains empty partitions. "
-                "Please remove empty partitions from the dataset before running semantic deduplication."
+
+            if embeddings_df.shape[0].compute() < self.n_clusters:
+                raise ValueError(
+                    "Number of clusters is greater than the number of documents in your dataset. "
+                    "Please reduce n_clusters to be less than or equal."
+                )
+        except IndexError as e:
+            raise IndexError(
+                f'Original error message: "{e}". '
+                "This could be due to empty partitions in your DocumentDataset. "
+                "Please check your dataset for empty partitions and remove them if necessary."
             )
 
         embeddings_df = embeddings_df.to_backend("cudf")
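For readers skimming the hunk above: shape[0] on a Dask DataFrame is a lazy scalar, so it has to be computed before it can be compared with n_clusters, which is exactly what the added guard does. Below is a minimal, CPU-only sketch of the same pattern using a pandas-backed Dask DataFrame; the helper name check_cluster_count and the toy data are illustrative and not part of NeMo Curator.

import dask.dataframe as dd
import pandas as pd


def check_cluster_count(df: dd.DataFrame, n_clusters: int) -> None:
    # shape[0] on a Dask DataFrame is lazy; compute() materializes the row count.
    n_rows = df.shape[0].compute()
    if n_rows < n_clusters:
        raise ValueError(
            "Number of clusters is greater than the number of documents in your dataset. "
            "Please reduce n_clusters to be less than or equal."
        )


df = dd.from_pandas(pd.DataFrame({"id": [1, 2, 3]}), npartitions=2)
check_cluster_count(df, n_clusters=2)  # passes
# check_cluster_count(df, n_clusters=10)  # would raise ValueError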
8 changes: 2 additions & 6 deletions nemo_curator/modules/semantic_dedup/semanticclusterleveldedup.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -163,8 +163,7 @@ def extract_dedup_data(self, eps_to_extract: float) -> DocumentDataset:
         output_parquet_path = os.path.join(
             self.output_dir, f"unique_ids_{eps_to_extract}.parquet"
         )
-
-        result = extract_dedup_data(
+        extract_dedup_data(
             eps=eps_to_extract,
             n_clusters=self.n_clusters,
             id_col=self.id_col,
@@ -176,9 +175,6 @@ def extract_dedup_data(self, eps_to_extract: float) -> DocumentDataset:
             logger=self.logger,
             profile_dir=self.profile_dir,
         )
-        # Result is None if there are no duplicates
-        if result is None:
-            return None
 
         fps = [
             os.path.join(output_parquet_path, file_name)
11 changes: 2 additions & 9 deletions nemo_curator/modules/semantic_dedup/semdedup.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -96,13 +96,6 @@ def call(self, dataset: DocumentDataset) -> DocumentDataset:
         embeddings_dataset = self.embedding_creator(dataset)
         self.clustering_model(embeddings_dataset)
         self.semantic_cluster_dedup.compute_semantic_match_dfs(self.eps_thresholds)
-
-        result = self.semantic_cluster_dedup.extract_dedup_data(
+        return self.semantic_cluster_dedup.extract_dedup_data(
             eps_to_extract=self.eps_to_extract
         )
-
-        # If no duplicates are found, return the original dataset
-        if result is not None:
-            return result
-        else:
-            return dataset
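With the None check gone, SemDedup.call always returns the DocumentDataset produced by extract_dedup_data. A rough usage sketch assembled from the parameters visible in the tests below; the toy documents, cache_dir, and cluster count are made up, and actually running it requires a GPU Dask client plus a download of the default embedding model.

import cudf
import dask_cudf

from nemo_curator import SemDedup, SemDedupConfig
from nemo_curator.datasets import DocumentDataset

# Tiny cuDF-backed dataset; real inputs would typically come from DocumentDataset.read_parquet.
df = cudf.DataFrame(
    {
        "id": [1, 2, 3, 4, 5],
        "text": [
            "The quick brown fox jumps over the lazy dog",
            "The quick brown foxes jumps over the lazy dog",
            "The quick brown wolf jumps over the lazy dog",
            "A completely different sentence about GPUs",
            "Yet another unrelated document",
        ],
    }
)
dataset = DocumentDataset(dask_cudf.from_cudf(df, npartitions=1))

config = SemDedupConfig(
    cache_dir="./sem_dedup_cache",  # hypothetical path
    n_clusters=2,  # must not exceed the number of documents
    eps_thresholds=[0.10],
    eps_to_extract=0.10,
)
sem_dedup = SemDedup(config=config, input_column="text", id_column="id", id_column_type="int")

# As of this commit, the call returns the extracted duplicate IDs directly.
duplicates = sem_dedup(dataset)
print(duplicates.df.compute())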
34 changes: 12 additions & 22 deletions nemo_curator/utils/semdedup_utils.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -395,11 +395,6 @@ def extract_pruned_data(
 
     t0 = time.time()
 
-    np_files = [
-        os.path.join(sorted_clusters_dir, f"cluster_{i}.npy") for i in range(n_clusters)
-    ]
-    total_records = sum(get_num_records(file_path) for file_path in np_files)
-
     with performance_report_if_with_ts_suffix(
         profile_dir,
         "extracting-pruned-from-clusters",
@@ -417,20 +412,18 @@
         results_df = results_df.persist()
         progress(results_df)
 
-        try:
-            results_df.to_parquet(output_parquet_path)
-        except TypeError:
-            # Catching "Implicit conversion to a host PyArrow object via __arrow_array__ is not allowed"
-            logger.info("No semantic duplicates found")
-            return total_records, 0, total_records
-
+        results_df.to_parquet(output_parquet_path)
     if logger:
         logger.info(
             f"Time taken for Extracting Pruned Data : {time.time() - t0} and output written at {output_parquet_path}"
         )
 
     total_kept = len(results_df)
 
+    np_files = [
+        os.path.join(sorted_clusters_dir, f"cluster_{i}.npy") for i in range(n_clusters)
+    ]
+    total_records = sum(get_num_records(file_path) for file_path in np_files)
     # Aggregate results
     total_removed = total_records - total_kept
     return total_kept, total_removed, total_records
@@ -479,12 +472,9 @@ def extract_dedup_data(
     df = pd.DataFrame(result_dict)
     df.to_csv(output_summary_file, index=False)
 
-    if removed > 0:
-        fps = [
-            os.path.join(output_parquet_path, file_name)
-            for file_name in os.listdir(output_parquet_path)
-        ]
-        ids_to_keep_df = dd.from_map(cudf.read_parquet, fps)
-        return ids_to_keep_df
-    else:
-        return None
+    fps = [
+        os.path.join(output_parquet_path, file_name)
+        for file_name in os.listdir(output_parquet_path)
+    ]
+    ids_to_keep_df = dd.from_map(cudf.read_parquet, fps)
+    return ids_to_keep_df
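The now-unconditional return relies on dd.from_map, which turns a list of parquet files into a Dask DataFrame with one partition per file. A small CPU-only sketch of the same pattern, with pandas standing in for cudf and an illustrative directory name (pyarrow is assumed for parquet I/O).

import os

import dask.dataframe as dd
import pandas as pd

output_parquet_path = "unique_ids_example.parquet"
os.makedirs(output_parquet_path, exist_ok=True)

# Write two small parquet files, mimicking the per-cluster layout that extract_dedup_data produces.
pd.DataFrame({"id": [2, 3]}).to_parquet(os.path.join(output_parquet_path, "part_0.parquet"))
pd.DataFrame({"id": [200, 300]}).to_parquet(os.path.join(output_parquet_path, "part_1.parquet"))

fps = [
    os.path.join(output_parquet_path, file_name)
    for file_name in os.listdir(output_parquet_path)
]
# One partition per file; swap pd.read_parquet for cudf.read_parquet on GPU.
ids_to_keep_df = dd.from_map(pd.read_parquet, fps)
print(ids_to_keep_df.compute())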
51 changes: 37 additions & 14 deletions tests/test_semdedup.py
@@ -18,7 +18,6 @@
 import torch
 import torch.nn.functional as F
 from dask.dataframe.utils import assert_eq
-from distributed import Client
 from transformers import AutoConfig, AutoModel, AutoTokenizer
 
 from nemo_curator import SemDedup, SemDedupConfig
@@ -69,59 +68,83 @@ def non_dedup_data():
 
 @pytest.mark.gpu
 class TestSemDuplicates:
+    @pytest.mark.parametrize("n_clusters", [3, 10])
     def test_sem_dedup(
         self,
         dedup_data,
         tmpdir,
+        n_clusters,
         gpu_client,
     ):
         print("client", gpu_client)
+
         cache_dir = os.path.join(tmpdir, "test_sem_dedup_cache")
         config = SemDedupConfig(
             cache_dir=cache_dir,
             seed=42,
-            n_clusters=3,
+            n_clusters=n_clusters,
             eps_thresholds=[0.10],
             eps_to_extract=0.10,
         )
+
         sem_duplicates = SemDedup(
             config=config,
             input_column="text",
             id_column="id",
             id_column_type="int",
         )
-        result = sem_duplicates(dedup_data)
-        result_df = result.df.compute()
-        duplicate_docs = [2, 3, 4, 200, 300]
-        expected_df = cudf.Series(duplicate_docs, name="id")
-        assert_eq(result_df["id"].sort_values(), expected_df, check_index=False)
+
+        dedup_data_len = dedup_data.df.shape[0].compute()
+        if n_clusters > dedup_data_len:
+            # Number of records in the dataset should never be less than n_clusters
+            with pytest.raises(ValueError):
+                result = sem_duplicates(dedup_data)
+        else:
+            # Correctly returns the original dataset with no duplicates removed
+            result = sem_duplicates(dedup_data)
+            result_df = result.df.compute()
+            duplicate_docs = [2, 3, 4, 200, 300]
+            expected_df = cudf.Series(duplicate_docs, name="id")
+            assert_eq(result_df["id"].sort_values(), expected_df, check_index=False)
 
+    @pytest.mark.parametrize("n_clusters", [2, 3])
     def test_no_sem_dedup(
         self,
         non_dedup_data,
         tmpdir,
+        n_clusters,
         gpu_client,
     ):
         print("client", gpu_client)
-        cache_dir = os.path.join(tmpdir, "test_no_sem_dedup_cache")
+
+        cache_dir = os.path.join(tmpdir, "test_no_sem_dedup")
         config = SemDedupConfig(
             cache_dir=cache_dir,
             seed=42,
-            n_clusters=3,
+            n_clusters=n_clusters,
             eps_thresholds=[0.10],
             eps_to_extract=0.10,
         )
+
         sem_duplicates = SemDedup(
             config=config,
             input_column="text",
             id_column="doc_id",
             id_column_type="str",
         )
-        result = sem_duplicates(non_dedup_data)
-        result_df = result.df.compute()
-        duplicate_docs = ["doc_1", "doc_2"]
-        expected_df = cudf.Series(duplicate_docs, name="doc_id")
-        assert_eq(result_df["doc_id"].sort_values(), expected_df, check_index=False)
+
+        non_dedup_data_len = non_dedup_data.df.shape[0].compute()
+        if n_clusters > non_dedup_data_len:
+            # Number of records in the dataset should never be less than n_clusters
+            with pytest.raises(ValueError):
+                result = sem_duplicates(non_dedup_data)
+        else:
+            # Correctly returns the original dataset with no duplicates removed
+            result = sem_duplicates(non_dedup_data)
+            result_df = result.df.compute()
+            duplicate_docs = ["doc_1", "doc_2"]
+            expected_df = cudf.Series(duplicate_docs, name="doc_id")
+            assert_eq(result_df["doc_id"].sort_values(), expected_df, check_index=False)
 
     @pytest.mark.parametrize("pooling_strategy", ["last_token", "mean_pooling"])
     def test_embedding_creator_pooling_strategies(self, tmpdir, pooling_strategy):
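To exercise the new n_clusters parametrization locally, the GPU-marked tests can also be driven from Python; a sketch that assumes a CUDA-capable machine and the repository's gpu_client fixture.

import pytest

# "-m gpu" selects the GPU-marked tests; "-k sem_dedup" matches both
# test_sem_dedup and test_no_sem_dedup.
exit_code = pytest.main(["-m", "gpu", "-k", "sem_dedup", "tests/test_semdedup.py"])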
