Improvements for semantic deduplication and DAPT tutorial #564

Merged · 6 commits · Feb 24, 2025
Changes from all commits
20 changes: 11 additions & 9 deletions config/sem_dedup_config.yaml
@@ -3,22 +3,24 @@ cache_dir: "semdedup_cache"
num_files: 16

# Embeddings configuration
embeddings_save_loc: "embeddings"
embedding_model_name_or_path: "sentence-transformers/all-MiniLM-L6-v2"
embedding_batch_size: 128
embeddings_save_loc: "embeddings"
# Options: "mean_pooling", "last_token"
embedding_pooling_strategy: "mean_pooling"
embedding_column: "embeddings"
write_embeddings_to_disk: true
write_to_filename: false

# Clustering configuration
clustering_save_loc: "clustering_results"
n_clusters: 1000
seed: 1234
max_iter: 100
kmeans_with_cos_dist: false

# Semdedup configuration
which_to_keep: "hard"
largest_cluster_size_to_process: 100000
n_clusters: 1000
clustering_save_loc: "clustering_results"
sim_metric: "cosine"
which_to_keep: "hard"
sort_clusters: true
kmeans_with_cos_dist: false
clustering_input_partition_size: "2gb"

# Extract dedup configuration
eps_thresholds:
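For reference, the reordered keys above map one-to-one onto `SemDedupConfig` fields. A minimal sketch of loading the file, assuming the `from_yaml` helper inherited from `BaseConfig` (the path is illustrative):

```python
from nemo_curator import SemDedupConfig

# Parse the YAML into the dataclass; keys must match the field names
# (e.g. embedding_column, sort_clusters, clustering_input_partition_size).
config = SemDedupConfig.from_yaml("config/sem_dedup_config.yaml")
print(config.n_clusters, config.clustering_input_partition_size)
```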
19 changes: 10 additions & 9 deletions docs/user-guide/semdedup.rst
@@ -42,22 +42,23 @@ Semantic deduplication in NeMo Curator can be configured using a YAML file. Here
num_files: -1

# Embeddings configuration
embeddings_save_loc: "embeddings"
embedding_model_name_or_path: "sentence-transformers/all-MiniLM-L6-v2"
embedding_batch_size: 128
embeddings_save_loc: "embeddings"
embedding_pooling_strategy: "mean_pooling"
embedding_column: "embeddings"
write_embeddings_to_disk: true
write_to_filename: false

# Clustering configuration
clustering_save_loc: "clustering_results"
n_clusters: 1000
seed: 1234
max_iter: 100
kmeans_with_cos_dist: false

# Semdedup configuration
which_to_keep: "hard"
largest_cluster_size_to_process: 100000
n_clusters: 1000
clustering_save_loc: "clustering_results"
sim_metric: "cosine"
which_to_keep: "hard"
sort_clusters: true
kmeans_with_cos_dist: false
clustering_input_partition_size: "2gb"

# Extract dedup configuration
eps_thresholds:
71 changes: 46 additions & 25 deletions nemo_curator/modules/config.py
@@ -1,4 +1,4 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -147,51 +147,72 @@ class SemDedupConfig(BaseConfig):
Attributes:
cache_dir (str): Directory to store cache.
profile_dir (Optional[str]): If specified directory to write dask profile. Default is None.
cache_dir (str): Directory to store cache.
profile_dir (Optional[str]): If specified, directory to write Dask profile.
Default is None.
num_files (int): Number of files. Default is -1, meaning all files.
embeddings_save_loc (str): Location to save embeddings.
embedding_model_name_or_path (str): Model name or path for embeddings.
embedding_batch_size (int): Inital Batch size for processing embeddings.
embedding_pooling_strategy (str): Strategy for pooling embeddings, either "mean_pooling" or "last_token". Defaults to "mean_pooling".
write_embeddings_to_disk (bool): If True, saves the embeddings to disk, defaults to True.
Default is "sentence-transformers/all-MiniLM-L6-v2".
embedding_batch_size (int): Initial batch size for processing embeddings.
Default is 128.
embeddings_save_loc (str): Location to save embeddings.
Default is "embeddings".
embedding_max_mem_gb (int): Maximum memory usage in GB for the embedding process.
If None, it defaults to the available GPU memory minus 4 GB.
embedding_pooling_strategy (str): Strategy for pooling embeddings, either
"mean_pooling" or "last_token". Default is "mean_pooling".
embedding_column (str): The column name that stores the embeddings.
Default is "embeddings".
write_embeddings_to_disk (bool): If True, saves the embeddings to disk.
We recommend setting this to False when you have a delayed pipeline.
Setting it to False can lead to more memory overhead.
Setting it to False can lead to more memory overhead. Default is True.
write_to_filename (bool): If True, saves the embeddings to the same filename as input files.
Default False.
max_iter (int): Maximum iterations for clustering. Default is 100.
n_clusters (int): Number of clusters. Default is 1000.
clustering_save_loc (str): Location to save clustering results.
n_clusters (int): Number of clusters.
seed (int): Seed for clustering.
max_iter (int): Maximum iterations for clustering.
kmeans_with_cos_dist (bool): Use KMeans with cosine distance.
which_to_keep (str): Which duplicates to keep.
largest_cluster_size_to_process (int): Largest cluster size to process.
Default is "clustering_results".
sim_metric (str): Similarity metric for deduplication.
eps_thresholds (List[float]): Epsilon thresholds to calculate if semantically similar or not.
Default is "cosine".
which_to_keep (str): Method to determine which duplicates to keep.
Default is "hard".
sort_clusters (bool): Whether to sort clusters. Default is True.
kmeans_with_cos_dist (bool): Whether or not to use KMeans with cosine distance.
Default is False.
clustering_input_partition_size (str): The size of data partition with which to run KMeans.
Default is "2gb".
eps_thresholds (List[float]): Epsilon thresholds to calculate if semantically
similar or not. Default is [0.01, 0.001].
eps_to_extract (float): Epsilon value to extract deduplicated data.
Default is 0.01.
"""

cache_dir: str
profile_dir: Optional[str] = None
num_files: int = -1

# Embeddings
embeddings_save_loc: str = "embeddings"
embedding_model_name_or_path: str = "sentence-transformers/all-MiniLM-L6-v2"
embedding_batch_size: int = 128
embeddings_save_loc: str = "embeddings"
embedding_max_mem_gb: Optional[int] = None
# Options: "mean_pooling", "last_token"
embedding_pooling_strategy: str = "mean_pooling"
embedding_column: str = "embeddings"
write_embeddings_to_disk: bool = True
write_to_filename: bool = False

# Clustering config
clustering_save_loc: str = "clustering_results"
n_clusters: int = 1000
seed: int = 1234
# Clustering
max_iter: int = 100
kmeans_with_cos_dist: bool = False

# Semdedup config
which_to_keep: str = "hard"
largest_cluster_size_to_process: int = 100000
n_clusters: int = 1000
clustering_save_loc: str = "clustering_results"
sim_metric: str = "cosine"
which_to_keep: str = "hard"
sort_clusters: bool = True
kmeans_with_cos_dist: bool = False
clustering_input_partition_size: str = "2gb"

# Extract dedup config
eps_thresholds: List[float] = field(default_factory=lambda: [0.01, 0.001])
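To illustrate the regrouped fields, a sketch of constructing the config directly in Python — `cache_dir` is the only required argument, and the values shown are just the defaults documented above:

```python
from nemo_curator.modules.config import SemDedupConfig

config = SemDedupConfig(
    cache_dir="semdedup_cache",
    # Embeddings
    embedding_model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
    embedding_batch_size=128,
    embedding_pooling_strategy="mean_pooling",
    embedding_column="embeddings",
    # Clustering
    max_iter=100,
    n_clusters=1000,
    clustering_input_partition_size="2gb",
    # Extract dedup
    eps_to_extract=0.01,
)
```

Judging by the field list above, `seed` and `largest_cluster_size_to_process` no longer appear in the dataclass, so older YAML files that still set them may need updating.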
47 changes: 32 additions & 15 deletions nemo_curator/modules/semantic_dedup/clusteringmodel.py
@@ -1,4 +1,4 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -54,12 +54,12 @@ def __init__(
max_iter: int,
n_clusters: int,
clustering_output_dir: str,
embedding_col: str = "embeddings",
embedding_column: str = "embeddings",
sim_metric: str = "cosine",
which_to_keep: str = "hard",
sort_clusters: bool = True,
kmeans_with_cos_dist: bool = False,
partition_size: str = "2gb",
clustering_input_partition_size: str = "2gb",
logger: Union[logging.Logger, str] = "./",
profile_dir: Optional[str] = None,
):
@@ -71,12 +71,12 @@
max_iter (int): Maximum number of iterations for the clustering algorithm.
n_clusters (int): The number of clusters to form.
clustering_output_dir (str): Directory path where clustering results will be saved.
embedding_col (str): Column name where the embeddings are stored.
embedding_column (str): Column name where the embeddings are stored.
sim_metric (str): Similarity metric to use for clustering, default is "cosine".
which_to_keep (str): Strategy to decide which duplicates to keep; default is "hard".
sort_clusters (bool): Whether to sort clusters, default is True.
kmeans_with_cos_dist (bool): Whether to use KMeans with cosine distance, default is False.
partition_size (str): The size of data partition to run kmeans with, default is "2gb".
clustering_input_partition_size (str): The size of data partition to run kmeans with, default is "2gb".
logger (Union[logging.Logger, str]): Logger object or directory path to save logs; default is "./".
profile_dir (str): If specified directory to write dask profile. Default is None.

@@ -86,11 +86,11 @@
self.max_iter = max_iter
self.n_clusters = n_clusters
self.clustering_output_dir = clustering_output_dir
self.embedding_col = embedding_col
self.embedding_column = embedding_column
self.sim_metric = sim_metric
self.keep_hard = which_to_keep == "hard"
self.kmeans_with_cos_dist = kmeans_with_cos_dist
self.partition_size = partition_size
self.clustering_input_partition_size = clustering_input_partition_size
self.sort_clusters = sort_clusters
self.logger = self._setup_logger(logger)
self.profile_dir = profile_dir
@@ -117,22 +117,39 @@ def _setup_logger(self, logger):
def __call__(self, embeddings_dataset: DocumentDataset):
embeddings_df = embeddings_dataset.df

if self.embedding_col not in embeddings_df.columns:
if self.embedding_column not in embeddings_df.columns:
raise ValueError(
f"Expected embedding column '{self.embedding_col}'"
f"Expected embedding column '{self.embedding_column}'"
f" to be in dataset. Only found columns {embeddings_df.columns}"
)

with performance_report_if_with_ts_suffix(self.profile_dir, "clustering-model"):
embeddings_df = embeddings_df[[self.id_col, self.embedding_col]]
embeddings_df = embeddings_df[[self.id_col, self.embedding_column]]
embeddings_df = embeddings_df.repartition(
partition_size=self.partition_size
partition_size=self.clustering_input_partition_size
)
embeddings_df = embeddings_df.to_backend("pandas").persist()

try:
embeddings_df = embeddings_df.to_backend("pandas").persist()
embeddings_length = embeddings_df.shape[0].compute()

if embeddings_length < self.n_clusters:
raise ValueError(
"Number of clusters is greater than the number of documents in your dataset: "
f"dataset length is {embeddings_length} while n_clusters is set to {self.n_clusters}. "
f"Please reduce n_clusters to be less than or equal to {embeddings_length}."
)
except IndexError as e:
raise IndexError(
f'Original error message: "{e}". '
"This could be due to empty partitions in your DocumentDataset. "
"Please check your dataset for empty partitions and remove them if necessary."
)

embeddings_df = embeddings_df.to_backend("cudf")

cupy_darr = embeddings_df.map_partitions(
get_embedding_ar, self.embedding_col, meta=cp.ndarray([1, 1])
get_embedding_ar, self.embedding_column, meta=cp.ndarray([1, 1])
)
cupy_darr.compute_chunk_sizes()
t0 = time.time()
@@ -156,7 +173,7 @@ def __call__(self, embeddings_dataset: DocumentDataset):
meta_df["dist_to_cent"] = cp.zeros(1)
embeddings_df = embeddings_df.map_partitions(
add_dist_to_cents,
embedding_col=self.embedding_col,
embedding_col=self.embedding_column,
centroids=kmeans.cluster_centers_,
meta=meta_df,
)
@@ -198,7 +215,7 @@ def __call__(self, embeddings_dataset: DocumentDataset):
output_sorted_clusters_dir=os.path.join(
self.clustering_output_dir, "sorted"
),
embedding_col=self.embedding_col,
embedding_col=self.embedding_column,
sim_metric=self.sim_metric,
keep_hard=self.keep_hard,
kmeans_with_cos_dist=self.kmeans_with_cos_dist,
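A sketch of instantiating the renamed interface directly — `embedding_column` replaces `embedding_col` and `clustering_input_partition_size` replaces `partition_size`; the `id_column` argument and the paths are assumptions for illustration, not part of this hunk:

```python
from nemo_curator.modules.semantic_dedup.clusteringmodel import ClusteringModel

clustering_model = ClusteringModel(
    id_column="id",
    max_iter=100,
    n_clusters=1000,
    clustering_output_dir="semdedup_cache/clustering_results",
    embedding_column="embeddings",          # formerly embedding_col
    clustering_input_partition_size="2gb",  # formerly partition_size
)

# embeddings_dataset: a DocumentDataset produced by the embedding stage.
# With the new validation, a dataset holding fewer documents than n_clusters
# now raises a clear ValueError instead of failing deep inside KMeans.
clustered = clustering_model(embeddings_dataset)
```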
8 changes: 7 additions & 1 deletion nemo_curator/modules/semantic_dedup/embeddings.py
@@ -1,4 +1,4 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -228,6 +228,12 @@ def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
self.profile_dir, "embedding-creator"
):
embedding_ddf = self.create_embeddings(dataset.df, self.input_column)

# category column dtypes are not supported by the GPU-accelerated Parquet writer
for col in embedding_ddf.columns:
if embedding_ddf[col].dtype.name == "category":
embedding_ddf[col] = embedding_ddf[col].astype("str")
Comment on lines +232 to +235 (Collaborator, Author): Closes #505.


write_to_disk(
embedding_ddf,
self.embedding_output_dir,
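The new cast exists because the GPU-accelerated Parquet writer rejects `category` columns (the issue closed by this change). A standalone sketch of the same pattern on a plain cudf DataFrame — the data is illustrative; the diff applies the loop to the Dask collection just before `write_to_disk`:

```python
import cudf

df = cudf.DataFrame({"file": ["a.jsonl", "b.jsonl"], "text": ["foo", "bar"]})
df["file"] = df["file"].astype("category")

# Cast any category columns to str so to_parquet does not error out.
for col in df.columns:
    if df[col].dtype.name == "category":
        df[col] = df[col].astype("str")

df.to_parquet("embeddings.parquet")
```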
10 changes: 5 additions & 5 deletions nemo_curator/modules/semantic_dedup/semanticclusterleveldedup.py
@@ -1,4 +1,4 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -42,7 +42,7 @@ def __init__(
id_column_type: str,
which_to_keep: str,
output_dir: str,
embedding_col: str = "embeddings",
embedding_column: str = "embeddings",
logger: Union[logging.Logger, str] = "./",
profile_dir: Optional[str] = None,
) -> None:
@@ -57,7 +57,7 @@
id_column_type (str): Data type of the ID column.
which_to_keep (str): Strategy for which duplicate to keep.
output_dir (str): Directory to save output files.
embedding_col (str): Column where the embeddings are stored.
embedding_column (str): Column where the embeddings are stored.
logger (Union[logging.Logger, str]): Logger instance or path to the log file directory.
profile_dir (str): If specified directory to write dask profile. Default is None.
"""
@@ -72,7 +72,7 @@
output_dir, "semdedup_pruning_tables"
)
self.computed_semantic_match_dfs = False
self.embedding_col = embedding_col
self.embedding_column = embedding_column
self.logger = self._setup_logger(logger)
self.profile_dir = profile_dir

@@ -132,7 +132,7 @@ def compute_semantic_match_dfs(
id_col_type=self.id_col_type,
eps_list=eps_list,
output_dir=self.semdedup_pruning_tables_dir,
embedding_col=self.embedding_col,
embedding_col=self.embedding_column,
which_to_keep=self.which_to_keep,
)
)
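For context, a sketch of how the renamed `embedding_column` flows through this class when used directly — constructor arguments such as `emb_by_clust_dir` and `sorted_clusters_dir`, and the directory layout, reflect the wider class signature as I understand it, not this diff:

```python
from nemo_curator.modules.semantic_dedup.semanticclusterleveldedup import (
    SemanticClusterLevelDedup,
)

dedup = SemanticClusterLevelDedup(
    n_clusters=1000,
    emb_by_clust_dir="semdedup_cache/clustering_results/embs_by_nearest_center",
    sorted_clusters_dir="semdedup_cache/clustering_results/sorted",
    id_column="id",
    id_column_type="str",
    which_to_keep="hard",
    output_dir="semdedup_cache/clustering_results",
    embedding_column="embeddings",  # formerly embedding_col
)
dedup.compute_semantic_match_dfs(eps_list=[0.01, 0.001])
duplicates = dedup.extract_dedup_data(eps_to_extract=0.01)
```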
14 changes: 12 additions & 2 deletions nemo_curator/modules/semantic_dedup/semdedup.py
@@ -1,4 +1,4 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -50,10 +50,13 @@ def __init__(
self.embedding_creator = EmbeddingCreator(
embedding_model_name_or_path=config.embedding_model_name_or_path,
embedding_batch_size=config.embedding_batch_size,
embedding_output_dir=os.path.join(cache_dir, config.embeddings_save_loc),
embedding_max_mem_gb=config.embedding_max_mem_gb,
embedding_pooling_strategy=config.embedding_pooling_strategy,
input_column=input_column,
embedding_output_dir=os.path.join(cache_dir, config.embeddings_save_loc),
embedding_column=config.embedding_column,
write_embeddings_to_disk=config.write_embeddings_to_disk,
write_to_filename=config.write_to_filename,
logger=logger,
profile_dir=self.config.profile_dir,
)
@@ -62,6 +65,12 @@
max_iter=config.max_iter,
n_clusters=config.n_clusters,
clustering_output_dir=os.path.join(cache_dir, config.clustering_save_loc),
embedding_column=config.embedding_column,
sim_metric=config.sim_metric,
which_to_keep=config.which_to_keep,
sort_clusters=config.sort_clusters,
kmeans_with_cos_dist=config.kmeans_with_cos_dist,
clustering_input_partition_size=config.clustering_input_partition_size,
logger=logger,
profile_dir=self.config.profile_dir,
)
@@ -77,6 +86,7 @@
id_column_type=id_column_type,
which_to_keep=config.which_to_keep,
output_dir=os.path.join(cache_dir, config.clustering_save_loc),
embedding_column=config.embedding_column,
logger=logger,
profile_dir=self.config.profile_dir,
)
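Putting the pieces together, an end-to-end sketch of the top-level module after this PR — it assumes the public `SemDedup` entry point and a GPU-backed `DocumentDataset`; the input path and column names are illustrative:

```python
from nemo_curator import SemDedup
from nemo_curator.datasets import DocumentDataset
from nemo_curator.modules.config import SemDedupConfig

# Defaults from the dataclass; only cache_dir is required.
config = SemDedupConfig(cache_dir="semdedup_cache")
dataset = DocumentDataset.read_json("input_data/", backend="cudf")

# SemDedup now forwards embedding_column, sim_metric, sort_clusters, and
# clustering_input_partition_size from the config into all three stages.
sem_dedup = SemDedup(config=config, input_column="text", id_column="id")
duplicates = sem_dedup(dataset)
```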