From 5cea7b67d597e51e88662ffe2bf89eb87b0e86a4 Mon Sep 17 00:00:00 2001 From: Sarah Yurick Date: Fri, 3 Jan 2025 12:16:19 -0800 Subject: [PATCH] add changes from https://github.com/NVIDIA/NeMo-Curator/pull/445 Signed-off-by: Sarah Yurick --- nemo_curator/modules/fuzzy_dedup/minhash.py | 36 ++++++++++++++------- 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/nemo_curator/modules/fuzzy_dedup/minhash.py b/nemo_curator/modules/fuzzy_dedup/minhash.py index d000e04d..b38b2268 100644 --- a/nemo_curator/modules/fuzzy_dedup/minhash.py +++ b/nemo_curator/modules/fuzzy_dedup/minhash.py @@ -24,7 +24,7 @@ import dask_cudf import numpy as np -from nemo_curator._compat import MINHASH_PERMUTED_AVAILABLE +from nemo_curator._compat import MINHASH_DEPRECATED_API, MINHASH_PERMUTED_AVAILABLE from nemo_curator.datasets import DocumentDataset from nemo_curator.log import create_logger from nemo_curator.utils.distributed_utils import performance_report_if_with_ts_suffix @@ -64,14 +64,16 @@ def __init__( """ self.num_hashes = num_hashes self.char_ngram = char_ngrams - if MINHASH_PERMUTED_AVAILABLE: + + if MINHASH_DEPRECATED_API: + self.seeds = self.generate_seeds(n_seeds=self.num_hashes, seed=seed) + else: self.seeds = self.generate_hash_permutation_seeds( bit_width=64 if use_64bit_hash else 32, n_permutations=self.num_hashes, seed=seed, ) - else: - self.seeds = self.generate_seeds(n_seeds=self.num_hashes, seed=seed) + self.minhash_method = self.minhash64 if use_64bit_hash else self.minhash32 self.id_field = id_field self.text_field = text_field @@ -137,7 +139,7 @@ def minhash32( if not isinstance(ser, cudf.Series): raise TypeError("Expected data of type cudf.Series") - if not MINHASH_PERMUTED_AVAILABLE: + if MINHASH_DEPRECATED_API: warnings.warn( "Using an outdated minhash implementation, please update to cuDF version 24.12 " "or later for improved performance. " @@ -150,9 +152,14 @@ def minhash32( seeds_a = cudf.Series(seeds[:, 0], dtype="uint32") seeds_b = cudf.Series(seeds[:, 1], dtype="uint32") - return ser.str.minhash_permuted( - a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram - ) + if MINHASH_PERMUTED_AVAILABLE: + return ser.str.minhash_permuted( + a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram + ) + else: + return ser.str.minhash( + a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram + ) def minhash64( self, ser: cudf.Series, seeds: np.ndarray, char_ngram: int @@ -162,7 +169,7 @@ def minhash64( """ if not isinstance(ser, cudf.Series): raise TypeError("Expected data of type cudf.Series") - if not MINHASH_PERMUTED_AVAILABLE: + if MINHASH_DEPRECATED_API: warnings.warn( "Using an outdated minhash implementation, please update to cuDF version 24.12 " "or later for improved performance. " @@ -175,9 +182,14 @@ def minhash64( seeds_a = cudf.Series(seeds[:, 0], dtype="uint64") seeds_b = cudf.Series(seeds[:, 1], dtype="uint64") - return ser.str.minhash64_permuted( - a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram - ) + if MINHASH_PERMUTED_AVAILABLE: + return ser.str.minhash64_permuted( + a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram + ) + else: + return ser.str.minhash64( + a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram + ) def __call__(self, dataset: DocumentDataset) -> Union[str, DocumentDataset]: """