From 48d48bf0d94e5263cd773da05d8a7666406865c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?oliver=20k=C3=B6nig?= Date: Mon, 24 Feb 2025 18:56:18 +0100 Subject: [PATCH 1/2] ci: Fix cherry-pick workflow (#567) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: oliver könig --- .github/workflows/cherry-pick-release-commit.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/cherry-pick-release-commit.yml b/.github/workflows/cherry-pick-release-commit.yml index 77699ee8..2430e828 100644 --- a/.github/workflows/cherry-pick-release-commit.yml +++ b/.github/workflows/cherry-pick-release-commit.yml @@ -7,7 +7,7 @@ on: jobs: cherry-pick: - uses: NVIDIA/NeMo-FW-CI-templates/.github/workflows/_cherry_pick.yml@v0.12.0 + uses: NVIDIA/NeMo-FW-CI-templates/.github/workflows/_cherry_pick.yml@v0.22.7 secrets: PAT: ${{ secrets.PAT }} SLACK_WEBHOOK_ADMIN: ${{ secrets.SLACK_WEBHOOK_ADMIN }} From 119edd4b93ac37ae5267c20d827fd8d993d87887 Mon Sep 17 00:00:00 2001 From: Sarah Yurick <53962159+sarahyurick@users.noreply.github.com> Date: Mon, 24 Feb 2025 13:37:11 -0800 Subject: [PATCH 2/2] Improvements for semantic deduplication and DAPT tutorial (#564) * push fixes Signed-off-by: Sarah Yurick * run black Signed-off-by: Sarah Yurick * fix cache? Signed-off-by: Sarah Yurick * add vibhu's suggestions Signed-off-by: Sarah Yurick * update config to include all params Signed-off-by: Sarah Yurick * add vibhu's comments Signed-off-by: Sarah Yurick --------- Signed-off-by: Sarah Yurick --- config/sem_dedup_config.yaml | 20 ++--- docs/user-guide/semdedup.rst | 19 ++--- nemo_curator/modules/config.py | 71 ++++++++++------ .../modules/semantic_dedup/clusteringmodel.py | 47 +++++++---- .../modules/semantic_dedup/embeddings.py | 8 +- .../semanticclusterleveldedup.py | 10 +-- .../modules/semantic_dedup/semdedup.py | 14 +++- nemo_curator/scripts/semdedup/clustering.py | 3 +- .../scripts/semdedup/extract_dedup_data.py | 1 - tests/test_semdedup.py | 81 ++++++++++++++++--- .../configs/text_semantic_dedupe_config.yaml | 21 ++--- tutorials/dapt-curation/code/main.py | 7 +- tutorials/dapt-curation/code/utils.py | 15 ++-- .../config/sem_dedup_config.yaml | 20 ++--- 14 files changed, 229 insertions(+), 108 deletions(-) diff --git a/config/sem_dedup_config.yaml b/config/sem_dedup_config.yaml index 08366d43..dc69f352 100644 --- a/config/sem_dedup_config.yaml +++ b/config/sem_dedup_config.yaml @@ -3,22 +3,24 @@ cache_dir: "semdedup_cache" num_files: 16 # Embeddings configuration -embeddings_save_loc: "embeddings" embedding_model_name_or_path: "sentence-transformers/all-MiniLM-L6-v2" embedding_batch_size: 128 +embeddings_save_loc: "embeddings" +# Options: "mean_pooling", "last_token" +embedding_pooling_strategy: "mean_pooling" +embedding_column: "embeddings" write_embeddings_to_disk: true +write_to_filename: false # Clustering configuration -clustering_save_loc: "clustering_results" -n_clusters: 1000 -seed: 1234 max_iter: 100 -kmeans_with_cos_dist: false - -# Semdedup configuration -which_to_keep: "hard" -largest_cluster_size_to_process: 100000 +n_clusters: 1000 +clustering_save_loc: "clustering_results" sim_metric: "cosine" +which_to_keep: "hard" +sort_clusters: true +kmeans_with_cos_dist: false +clustering_input_partition_size: "2gb" # Extract dedup configuration eps_thresholds: diff --git a/docs/user-guide/semdedup.rst b/docs/user-guide/semdedup.rst index 172b79d0..c9093877 100644 --- a/docs/user-guide/semdedup.rst +++ b/docs/user-guide/semdedup.rst @@ 
-42,22 +42,23 @@ Semantic deduplication in NeMo Curator can be configured using a YAML file. Here num_files: -1 # Embeddings configuration - embeddings_save_loc: "embeddings" embedding_model_name_or_path: "sentence-transformers/all-MiniLM-L6-v2" embedding_batch_size: 128 + embeddings_save_loc: "embeddings" + embedding_pooling_strategy: "mean_pooling" + embedding_column: "embeddings" write_embeddings_to_disk: true + write_to_filename: false # Clustering configuration - clustering_save_loc: "clustering_results" - n_clusters: 1000 - seed: 1234 max_iter: 100 - kmeans_with_cos_dist: false - - # Semdedup configuration - which_to_keep: "hard" - largest_cluster_size_to_process: 100000 + n_clusters: 1000 + clustering_save_loc: "clustering_results" sim_metric: "cosine" + which_to_keep: "hard" + sort_clusters: true + kmeans_with_cos_dist: false + clustering_input_partition_size: "2gb" # Extract dedup configuration eps_thresholds: diff --git a/nemo_curator/modules/config.py b/nemo_curator/modules/config.py index 67bf06af..d3261bef 100644 --- a/nemo_curator/modules/config.py +++ b/nemo_curator/modules/config.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -147,26 +147,46 @@ class SemDedupConfig(BaseConfig): Attributes: cache_dir (str): Directory to store cache. - profile_dir (Optional[str]): If specified directory to write dask profile. Default is None. - cache_dir (str): Directory to store cache. + profile_dir (Optional[str]): If specified, directory to write Dask profile. + Default is None. num_files (int): Number of files. Default is -1, meaning all files. - embeddings_save_loc (str): Location to save embeddings. + embedding_model_name_or_path (str): Model name or path for embeddings. - embedding_batch_size (int): Inital Batch size for processing embeddings. - embedding_pooling_strategy (str): Strategy for pooling embeddings, either "mean_pooling" or "last_token". Defaults to "mean_pooling". - write_embeddings_to_disk (bool): If True, saves the embeddings to disk, defaults to True. + Default is "sentence-transformers/all-MiniLM-L6-v2". + embedding_batch_size (int): Initial batch size for processing embeddings. + Default is 128. + embeddings_save_loc (str): Location to save embeddings. + Default is "embeddings". + embedding_max_mem_gb (int): Maximum memory usage in GB for the embedding process. + If None, it defaults to the available GPU memory minus 4 GB. + embedding_pooling_strategy (str): Strategy for pooling embeddings, either + "mean_pooling" or "last_token". Default is "mean_pooling". + embedding_column (str): The column name that stores the embeddings. + Default is "embeddings". + write_embeddings_to_disk (bool): If True, saves the embeddings to disk. We recommend setting this to False when you have a delayed pipeline. - Setting it to False can lead to more memory overhead. + Setting it to False can lead to more memory overhead. Default is True. + write_to_filename (bool): If True, saves the embeddings to the same filename as input files. + Default False. + + max_iter (int): Maximum iterations for clustering. Default is 100. + n_clusters (int): Number of clusters. Default is 1000. clustering_save_loc (str): Location to save clustering results. - n_clusters (int): Number of clusters. - seed (int): Seed for clustering. 
- max_iter (int): Maximum iterations for clustering. - kmeans_with_cos_dist (bool): Use KMeans with cosine distance. - which_to_keep (str): Which duplicates to keep. - largest_cluster_size_to_process (int): Largest cluster size to process. + Default is "clustering_results". sim_metric (str): Similarity metric for deduplication. - eps_thresholds (List[float]): Epsilon thresholds to calculate if semantically similar or not. + Default is "cosine". + which_to_keep (str): Method to determine which duplicates to keep. + Default is "hard". + sort_clusters (bool): Whether to sort clusters. Default is True. + kmeans_with_cos_dist (bool): Whether or not to use KMeans with cosine distance. + Default is False. + clustering_input_partition_size (str): The size of data partition with which to run KMeans. + Default is "2gb". + + eps_thresholds (List[float]): Epsilon thresholds to calculate if semantically + similar or not. Default is [0.01, 0.001]. eps_to_extract (float): Epsilon value to extract deduplicated data. + Default is 0.01. """ cache_dir: str @@ -174,24 +194,25 @@ class SemDedupConfig(BaseConfig): num_files: int = -1 # Embeddings - embeddings_save_loc: str = "embeddings" embedding_model_name_or_path: str = "sentence-transformers/all-MiniLM-L6-v2" embedding_batch_size: int = 128 + embeddings_save_loc: str = "embeddings" + embedding_max_mem_gb: Optional[int] = None # Options: "mean_pooling", "last_token" embedding_pooling_strategy: str = "mean_pooling" + embedding_column: str = "embeddings" write_embeddings_to_disk: bool = True + write_to_filename: bool = False - # Clustering config - clustering_save_loc: str = "clustering_results" - n_clusters: int = 1000 - seed: int = 1234 + # Clustering max_iter: int = 100 - kmeans_with_cos_dist: bool = False - - # Semdedup config - which_to_keep: str = "hard" - largest_cluster_size_to_process: int = 100000 + n_clusters: int = 1000 + clustering_save_loc: str = "clustering_results" sim_metric: str = "cosine" + which_to_keep: str = "hard" + sort_clusters: bool = True + kmeans_with_cos_dist: bool = False + clustering_input_partition_size: str = "2gb" # Extract dedup config eps_thresholds: List[float] = field(default_factory=lambda: [0.01, 0.001]) diff --git a/nemo_curator/modules/semantic_dedup/clusteringmodel.py b/nemo_curator/modules/semantic_dedup/clusteringmodel.py index 18714ff7..8dcc5481 100644 --- a/nemo_curator/modules/semantic_dedup/clusteringmodel.py +++ b/nemo_curator/modules/semantic_dedup/clusteringmodel.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -54,12 +54,12 @@ def __init__( max_iter: int, n_clusters: int, clustering_output_dir: str, - embedding_col: str = "embeddings", + embedding_column: str = "embeddings", sim_metric: str = "cosine", which_to_keep: str = "hard", sort_clusters: bool = True, kmeans_with_cos_dist: bool = False, - partition_size: str = "2gb", + clustering_input_partition_size: str = "2gb", logger: Union[logging.Logger, str] = "./", profile_dir: Optional[str] = None, ): @@ -71,12 +71,12 @@ def __init__( max_iter (int): Maximum number of iterations for the clustering algorithm. n_clusters (int): The number of clusters to form. clustering_output_dir (str): Directory path where clustering results will be saved. - embedding_col (str): Column name where the embeddings are stored. 
+ embedding_column (str): Column name where the embeddings are stored. sim_metric (str): Similarity metric to use for clustering, default is "cosine". which_to_keep (str): Strategy to decide which duplicates to keep; default is "hard". sort_clusters (bool): Whether to sort clusters, default is True. kmeans_with_cos_dist (bool): Whether to use KMeans with cosine distance, default is False. - partition_size (str): The size of data partition to run kmeans with, default is "2gb". + clustering_input_partition_size (str): The size of data partition to run kmeans with, default is "2gb". logger (Union[logging.Logger, str]): Logger object or directory path to save logs; default is "./". profile_dir (str): If specified directory to write dask profile. Default is None. @@ -86,11 +86,11 @@ def __init__( self.max_iter = max_iter self.n_clusters = n_clusters self.clustering_output_dir = clustering_output_dir - self.embedding_col = embedding_col + self.embedding_column = embedding_column self.sim_metric = sim_metric self.keep_hard = which_to_keep == "hard" self.kmeans_with_cos_dist = kmeans_with_cos_dist - self.partition_size = partition_size + self.clustering_input_partition_size = clustering_input_partition_size self.sort_clusters = sort_clusters self.logger = self._setup_logger(logger) self.profile_dir = profile_dir @@ -117,22 +117,39 @@ def _setup_logger(self, logger): def __call__(self, embeddings_dataset: DocumentDataset): embeddings_df = embeddings_dataset.df - if self.embedding_col not in embeddings_df.columns: + if self.embedding_column not in embeddings_df.columns: raise ValueError( - f"Expected embedding column '{self.embedding_col}'" + f"Expected embedding column '{self.embedding_column}'" f" to be in dataset. Only found columns {embeddings_df.columns}" ) with performance_report_if_with_ts_suffix(self.profile_dir, "clustering-model"): - embeddings_df = embeddings_df[[self.id_col, self.embedding_col]] + embeddings_df = embeddings_df[[self.id_col, self.embedding_column]] embeddings_df = embeddings_df.repartition( - partition_size=self.partition_size + partition_size=self.clustering_input_partition_size ) - embeddings_df = embeddings_df.to_backend("pandas").persist() + + try: + embeddings_df = embeddings_df.to_backend("pandas").persist() + embeddings_length = embeddings_df.shape[0].compute() + + if embeddings_length < self.n_clusters: + raise ValueError( + "Number of clusters is greater than the number of documents in your dataset: " + f"dataset length is {embeddings_length} while n_clusters is set to {self.n_clusters}. " + f"Please reduce n_clusters to be less than or equal to {embeddings_length}." + ) + except IndexError as e: + raise IndexError( + f'Original error message: "{e}". ' + "This could be due to empty partitions in your DocumentDataset. " + "Please check your dataset for empty partitions and remove them if necessary." 
+ ) + embeddings_df = embeddings_df.to_backend("cudf") cupy_darr = embeddings_df.map_partitions( - get_embedding_ar, self.embedding_col, meta=cp.ndarray([1, 1]) + get_embedding_ar, self.embedding_column, meta=cp.ndarray([1, 1]) ) cupy_darr.compute_chunk_sizes() t0 = time.time() @@ -156,7 +173,7 @@ def __call__(self, embeddings_dataset: DocumentDataset): meta_df["dist_to_cent"] = cp.zeros(1) embeddings_df = embeddings_df.map_partitions( add_dist_to_cents, - embedding_col=self.embedding_col, + embedding_col=self.embedding_column, centroids=kmeans.cluster_centers_, meta=meta_df, ) @@ -198,7 +215,7 @@ def __call__(self, embeddings_dataset: DocumentDataset): output_sorted_clusters_dir=os.path.join( self.clustering_output_dir, "sorted" ), - embedding_col=self.embedding_col, + embedding_col=self.embedding_column, sim_metric=self.sim_metric, keep_hard=self.keep_hard, kmeans_with_cos_dist=self.kmeans_with_cos_dist, diff --git a/nemo_curator/modules/semantic_dedup/embeddings.py b/nemo_curator/modules/semantic_dedup/embeddings.py index 7f6315e5..f4abe987 100644 --- a/nemo_curator/modules/semantic_dedup/embeddings.py +++ b/nemo_curator/modules/semantic_dedup/embeddings.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -228,6 +228,12 @@ def __call__(self, dataset: DocumentDataset) -> DocumentDataset: self.profile_dir, "embedding-creator" ): embedding_ddf = self.create_embeddings(dataset.df, self.input_column) + + # category column dtypes are not supported by the GPU-accelerated Parquet writer + for col in embedding_ddf.columns: + if embedding_ddf[col].dtype.name == "category": + embedding_ddf[col] = embedding_ddf[col].astype("str") + write_to_disk( embedding_ddf, self.embedding_output_dir, diff --git a/nemo_curator/modules/semantic_dedup/semanticclusterleveldedup.py b/nemo_curator/modules/semantic_dedup/semanticclusterleveldedup.py index 4329c2b0..65ff9b68 100644 --- a/nemo_curator/modules/semantic_dedup/semanticclusterleveldedup.py +++ b/nemo_curator/modules/semantic_dedup/semanticclusterleveldedup.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -42,7 +42,7 @@ def __init__( id_column_type: str, which_to_keep: str, output_dir: str, - embedding_col: str = "embeddings", + embedding_column: str = "embeddings", logger: Union[logging.Logger, str] = "./", profile_dir: Optional[str] = None, ) -> None: @@ -57,7 +57,7 @@ def __init__( id_column_type (str): Data type of the ID column. which_to_keep (str): Strategy for which duplicate to keep. output_dir (str): Directory to save output files. - embedding_col (str): Column where the embeddings are stored. + embedding_column (str): Column where the embeddings are stored. logger (Union[logging.Logger, str]): Logger instance or path to the log file directory. profile_dir (str): If specified directory to write dask profile. Default is None. 
""" @@ -72,7 +72,7 @@ def __init__( output_dir, "semdedup_pruning_tables" ) self.computed_semantic_match_dfs = False - self.embedding_col = embedding_col + self.embedding_column = embedding_column self.logger = self._setup_logger(logger) self.profile_dir = profile_dir @@ -132,7 +132,7 @@ def compute_semantic_match_dfs( id_col_type=self.id_col_type, eps_list=eps_list, output_dir=self.semdedup_pruning_tables_dir, - embedding_col=self.embedding_col, + embedding_col=self.embedding_column, which_to_keep=self.which_to_keep, ) ) diff --git a/nemo_curator/modules/semantic_dedup/semdedup.py b/nemo_curator/modules/semantic_dedup/semdedup.py index a8c66e31..99a9d16c 100644 --- a/nemo_curator/modules/semantic_dedup/semdedup.py +++ b/nemo_curator/modules/semantic_dedup/semdedup.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -50,10 +50,13 @@ def __init__( self.embedding_creator = EmbeddingCreator( embedding_model_name_or_path=config.embedding_model_name_or_path, embedding_batch_size=config.embedding_batch_size, + embedding_output_dir=os.path.join(cache_dir, config.embeddings_save_loc), + embedding_max_mem_gb=config.embedding_max_mem_gb, embedding_pooling_strategy=config.embedding_pooling_strategy, input_column=input_column, - embedding_output_dir=os.path.join(cache_dir, config.embeddings_save_loc), + embedding_column=config.embedding_column, write_embeddings_to_disk=config.write_embeddings_to_disk, + write_to_filename=config.write_to_filename, logger=logger, profile_dir=self.config.profile_dir, ) @@ -62,6 +65,12 @@ def __init__( max_iter=config.max_iter, n_clusters=config.n_clusters, clustering_output_dir=os.path.join(cache_dir, config.clustering_save_loc), + embedding_column=config.embedding_column, + sim_metric=config.sim_metric, + which_to_keep=config.which_to_keep, + sort_clusters=config.sort_clusters, + kmeans_with_cos_dist=config.kmeans_with_cos_dist, + clustering_input_partition_size=config.clustering_input_partition_size, logger=logger, profile_dir=self.config.profile_dir, ) @@ -77,6 +86,7 @@ def __init__( id_column_type=id_column_type, which_to_keep=config.which_to_keep, output_dir=os.path.join(cache_dir, config.clustering_save_loc), + embedding_column=config.embedding_column, logger=logger, profile_dir=self.config.profile_dir, ) diff --git a/nemo_curator/scripts/semdedup/clustering.py b/nemo_curator/scripts/semdedup/clustering.py index db4885c3..7f970336 100644 --- a/nemo_curator/scripts/semdedup/clustering.py +++ b/nemo_curator/scripts/semdedup/clustering.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -93,7 +93,6 @@ def attach_args(): " cache_dir for the directory to store cache," " clustering_save_loc for the location to save clustering results," " n_clusters for the number of clusters," - " seed for the seed for clustering," " max_iter for the maximum iterations for clustering," " kmeans_with_cos_dist for using K-Means with cosine distance." 
         ),
diff --git a/nemo_curator/scripts/semdedup/extract_dedup_data.py b/nemo_curator/scripts/semdedup/extract_dedup_data.py
index b6ffaebc..788c02bd 100755
--- a/nemo_curator/scripts/semdedup/extract_dedup_data.py
+++ b/nemo_curator/scripts/semdedup/extract_dedup_data.py
@@ -72,7 +72,6 @@ def attach_args():
         "Important configuration parameters include:"
         " cache_dir for the directory to store cache"
         " which_to_keep for specifying which duplicates to keep,"
-        " largest_cluster_size_to_process for the largest cluster size to process,"
         " sim_metric for the similarity metric for deduplication,"
         " eps_thresholds for epsilon thresholds to calculate if semantically similar or not"
         " and eps_to_extract for the epsilon value to extract deduplicated data."
diff --git a/tests/test_semdedup.py b/tests/test_semdedup.py
index 44f7f555..344e5469 100644
--- a/tests/test_semdedup.py
+++ b/tests/test_semdedup.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,7 +18,6 @@
 import torch
 import torch.nn.functional as F
 from dask.dataframe.utils import assert_eq
-from distributed import Client
 from transformers import AutoConfig, AutoModel, AutoTokenizer
 
 from nemo_curator import SemDedup, SemDedupConfig
@@ -52,34 +51,98 @@ def dedup_data():
     return DocumentDataset(df)
 
 
+@pytest.fixture
+def non_dedup_data():
+    df = cudf.DataFrame(
+        {
+            "doc_id": ["doc_1", "doc_2"],
+            "text": [
+                "The quick brown fox jumps over the lazy dog",
+                "A test string",
+            ],
+        }
+    )
+    df = dask_cudf.from_cudf(df, 2)
+    return DocumentDataset(df)
+
+
 @pytest.mark.gpu
 class TestSemDuplicates:
+    @pytest.mark.parametrize("n_clusters", [3, 10])
     def test_sem_dedup(
         self,
         dedup_data,
         tmpdir,
+        n_clusters,
         gpu_client,
     ):
         print("client", gpu_client)
+
         cache_dir = os.path.join(tmpdir, "test_sem_dedup_cache")
         config = SemDedupConfig(
             cache_dir=cache_dir,
-            seed=42,
-            n_clusters=3,
+            n_clusters=n_clusters,
             eps_thresholds=[0.10],
             eps_to_extract=0.10,
         )
+
         sem_duplicates = SemDedup(
             config=config,
             input_column="text",
             id_column="id",
             id_column_type="int",
         )
-        result = sem_duplicates(dedup_data)
-        result_df = result.df.compute()
-        duplicate_docs = [2, 3, 4, 200, 300]
-        expected_df = cudf.Series(duplicate_docs, name="id")
-        assert_eq(result_df["id"].sort_values(), expected_df, check_index=False)
+
+        dedup_data_len = dedup_data.df.shape[0].compute()
+        if n_clusters > dedup_data_len:
+            # Number of records in the dataset should never be less than n_clusters
+            with pytest.raises(ValueError):
+                result = sem_duplicates(dedup_data)
+        else:
+            # Semantic deduplication runs and returns the expected set of document IDs
+            result = sem_duplicates(dedup_data)
+            result_df = result.df.compute()
+            duplicate_docs = [2, 3, 4, 200, 300]
+            expected_df = cudf.Series(duplicate_docs, name="id")
+            assert_eq(result_df["id"].sort_values(), expected_df, check_index=False)
+
+    @pytest.mark.parametrize("n_clusters", [2, 3])
+    def test_no_sem_dedup(
+        self,
+        non_dedup_data,
+        tmpdir,
+        n_clusters,
+        gpu_client,
+    ):
+        print("client", gpu_client)
+
+        cache_dir = os.path.join(tmpdir, "test_no_sem_dedup")
+        config = SemDedupConfig(
+            cache_dir=cache_dir,
+            n_clusters=n_clusters,
+            eps_thresholds=[0.10],
+            eps_to_extract=0.10,
+        )
+
+        sem_duplicates = SemDedup(
+            config=config,
+            input_column="text",
+            id_column="doc_id",
+            id_column_type="str",
+        )
+
+        non_dedup_data_len = non_dedup_data.df.shape[0].compute()
+        if n_clusters > non_dedup_data_len:
+            # Number of records in the dataset should never be less than n_clusters
+            with pytest.raises(ValueError):
+                result = sem_duplicates(non_dedup_data)
+        else:
+            # Correctly returns the original dataset with no duplicates removed
+            result = sem_duplicates(non_dedup_data)
+            result_df = result.df.compute()
+            duplicate_docs = ["doc_1", "doc_2"]
+            expected_df = cudf.Series(duplicate_docs, name="doc_id")
+            assert_eq(result_df["doc_id"].sort_values(), expected_df, check_index=False)
 
     @pytest.mark.parametrize("pooling_strategy", ["last_token", "mean_pooling"])
     def test_embedding_creator_pooling_strategies(self, tmpdir, pooling_strategy):
diff --git a/tutorials/dapt-curation/code/configs/text_semantic_dedupe_config.yaml b/tutorials/dapt-curation/code/configs/text_semantic_dedupe_config.yaml
index 5b8e63b7..baf2ef70 100644
--- a/tutorials/dapt-curation/code/configs/text_semantic_dedupe_config.yaml
+++ b/tutorials/dapt-curation/code/configs/text_semantic_dedupe_config.yaml
@@ -3,22 +3,23 @@ cache_dir: "workspace/semdedup_cache/text"
 num_files: 16
 
 # Embeddings configuration
-embeddings_save_loc: "embeddings"
 embedding_model_name_or_path: "sentence-transformers/all-MiniLM-L6-v2"
 embedding_batch_size: 128
-write_embeddings_to_disk: false
+embeddings_save_loc: "embeddings"
+embedding_pooling_strategy: "mean_pooling"
+embedding_column: "embeddings"
+write_embeddings_to_disk: true
+write_to_filename: false
 
 # Clustering configuration
-clustering_save_loc: "clustering_results"
-n_clusters: 20
-seed: 1234
 max_iter: 100
-kmeans_with_cos_dist: false
-
-# Semdedup configuration
-which_to_keep: "hard"
-largest_cluster_size_to_process: 100000
+n_clusters: 20
+clustering_save_loc: "clustering_results"
 sim_metric: "cosine"
+which_to_keep: "hard"
+sort_clusters: true
+kmeans_with_cos_dist: false
+clustering_input_partition_size: "2gb"
 
 # Extract dedup configuration
 eps_thresholds:
diff --git a/tutorials/dapt-curation/code/main.py b/tutorials/dapt-curation/code/main.py
index 5f51ead8..44679f9b 100755
--- a/tutorials/dapt-curation/code/main.py
+++ b/tutorials/dapt-curation/code/main.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -188,7 +188,6 @@ def run_curation_pipeline(args: Any, text_files: str, code_files: str) -> None: duplicates = semantic_dedupe( dataset=gpu_dataset_text, sem_dedupe_config_yaml_path=sem_dedupe_config_yaml_path, - cache_dir=CACHE_DIR, ) unique_ids = duplicates.df.to_backend("pandas").compute()["id"] semantic_dataset_text = DocumentDataset( @@ -200,11 +199,11 @@ def run_curation_pipeline(args: Any, text_files: str, code_files: str) -> None: CACHE_DIR = os.path.join(SCRIPT_DIR_PATH, "cache", "fuzzy_dedupe", "text") rm_dir(CACHE_DIR) fuzzy_dataset_text = fuzzy_dedupe( - dataset=semantic_dataset_text, cache=CACHE_DIR + dataset=semantic_dataset_text, cache_dir=CACHE_DIR ) CACHE_DIR = os.path.join(SCRIPT_DIR_PATH, "cache", "fuzzy_dedupe", "code") rm_dir(CACHE_DIR) - fuzzy_dataset_code = fuzzy_dedupe(dataset=gpu_dataset_code, cache=CACHE_DIR) + fuzzy_dataset_code = fuzzy_dedupe(dataset=gpu_dataset_code, cache_dir=CACHE_DIR) dataset_text.df = fuzzy_dataset_text.df.to_backend("pandas") dataset_code.df = fuzzy_dataset_code.df.to_backend("pandas") diff --git a/tutorials/dapt-curation/code/utils.py b/tutorials/dapt-curation/code/utils.py index c8163794..25df1dda 100755 --- a/tutorials/dapt-curation/code/utils.py +++ b/tutorials/dapt-curation/code/utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -284,19 +284,19 @@ def exact_dedupe(dataset: DocumentDataset) -> DocumentDataset: return DocumentDataset(deduped) -def fuzzy_dedupe(dataset: DocumentDataset, cache: str) -> DocumentDataset: +def fuzzy_dedupe(dataset: DocumentDataset, cache_dir: str) -> DocumentDataset: """ Removes near-duplicate documents and code lines Args: dataset (DocumentDataset): The dataset containing documents. - type (str): Document type to process. + cache_dir (str): Directory for storing intermediate results. Returns: DocumentDataset: The deduplicated dataset. """ fuzzy_dedup_config = FuzzyDuplicatesConfig( - cache_dir=cache, + cache_dir=cache_dir, id_field="id", text_field="text", seed=42, @@ -322,14 +322,15 @@ def fuzzy_dedupe(dataset: DocumentDataset, cache: str) -> DocumentDataset: def semantic_dedupe( - dataset: DocumentDataset, sem_dedupe_config_yaml_path: str, cache_dir: str + dataset: DocumentDataset, + sem_dedupe_config_yaml_path: str, ): """ Perform semantic deduplication on the given dataset. Args: dataset (DocumentDataset): The dataset containing documents. - type (str): Document type to process. + sem_dedupe_config_yaml_path (str): The path to the semantic dedupe configuration file. Returns: The deduplicated DocumentDataset. 
@@ -390,6 +391,6 @@ def keep_document(self, score) -> bool: return score -def rm_dir(cache_dir): +def rm_dir(cache_dir: str): if os.path.isdir(cache_dir): os.system(f"rm -rf {cache_dir}") diff --git a/tutorials/peft-curation-with-sdg/config/sem_dedup_config.yaml b/tutorials/peft-curation-with-sdg/config/sem_dedup_config.yaml index 93ec29cb..82532b15 100644 --- a/tutorials/peft-curation-with-sdg/config/sem_dedup_config.yaml +++ b/tutorials/peft-curation-with-sdg/config/sem_dedup_config.yaml @@ -3,21 +3,23 @@ cache_dir: "_temp/semdedup_cache" num_files: 16 # Embeddings configuration -embeddings_save_loc: "embeddings" embedding_model_name_or_path: "sentence-transformers/all-MiniLM-L6-v2" embedding_batch_size: 128 +embeddings_save_loc: "embeddings" +embedding_pooling_strategy: "mean_pooling" +embedding_column: "embeddings" +write_embeddings_to_disk: true +write_to_filename: false # Clustering configuration -clustering_save_loc: "clustering_results" -n_clusters: 20 -seed: 1234 max_iter: 100 -kmeans_with_cos_dist: false - -# Semdedup configuration -which_to_keep: "hard" -largest_cluster_size_to_process: 100000 +n_clusters: 20 +clustering_save_loc: "clustering_results" sim_metric: "cosine" +which_to_keep: "hard" +sort_clusters: true +kmeans_with_cos_dist: false +clustering_input_partition_size: "2gb" # Extract dedup configuration eps_thresholds: