Skip to content

Commit

Permalink
add changes from #445
Browse files Browse the repository at this point in the history
Signed-off-by: Sarah Yurick <[email protected]>
  • Loading branch information
sarahyurick committed Jan 3, 2025
1 parent 05d5a65 commit 5cea7b6
Showing 1 changed file with 24 additions and 12 deletions.
36 changes: 24 additions & 12 deletions nemo_curator/modules/fuzzy_dedup/minhash.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import dask_cudf
import numpy as np

from nemo_curator._compat import MINHASH_PERMUTED_AVAILABLE
from nemo_curator._compat import MINHASH_DEPRECATED_API, MINHASH_PERMUTED_AVAILABLE
from nemo_curator.datasets import DocumentDataset
from nemo_curator.log import create_logger
from nemo_curator.utils.distributed_utils import performance_report_if_with_ts_suffix
Expand Down Expand Up @@ -64,14 +64,16 @@ def __init__(
"""
self.num_hashes = num_hashes
self.char_ngram = char_ngrams
if MINHASH_PERMUTED_AVAILABLE:

if MINHASH_DEPRECATED_API:
self.seeds = self.generate_seeds(n_seeds=self.num_hashes, seed=seed)
else:
self.seeds = self.generate_hash_permutation_seeds(
bit_width=64 if use_64bit_hash else 32,
n_permutations=self.num_hashes,
seed=seed,
)
else:
self.seeds = self.generate_seeds(n_seeds=self.num_hashes, seed=seed)

self.minhash_method = self.minhash64 if use_64bit_hash else self.minhash32
self.id_field = id_field
self.text_field = text_field
Expand Down Expand Up @@ -137,7 +139,7 @@ def minhash32(
if not isinstance(ser, cudf.Series):
raise TypeError("Expected data of type cudf.Series")

if not MINHASH_PERMUTED_AVAILABLE:
if MINHASH_DEPRECATED_API:
warnings.warn(
"Using an outdated minhash implementation, please update to cuDF version 24.12 "
"or later for improved performance. "
Expand All @@ -150,9 +152,14 @@ def minhash32(
seeds_a = cudf.Series(seeds[:, 0], dtype="uint32")
seeds_b = cudf.Series(seeds[:, 1], dtype="uint32")

return ser.str.minhash_permuted(
a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram
)
if MINHASH_PERMUTED_AVAILABLE:
return ser.str.minhash_permuted(
a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram
)
else:
return ser.str.minhash(
a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram
)

def minhash64(
self, ser: cudf.Series, seeds: np.ndarray, char_ngram: int
Expand All @@ -162,7 +169,7 @@ def minhash64(
"""
if not isinstance(ser, cudf.Series):
raise TypeError("Expected data of type cudf.Series")
if not MINHASH_PERMUTED_AVAILABLE:
if MINHASH_DEPRECATED_API:
warnings.warn(
"Using an outdated minhash implementation, please update to cuDF version 24.12 "
"or later for improved performance. "
Expand All @@ -175,9 +182,14 @@ def minhash64(
seeds_a = cudf.Series(seeds[:, 0], dtype="uint64")
seeds_b = cudf.Series(seeds[:, 1], dtype="uint64")

return ser.str.minhash64_permuted(
a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram
)
if MINHASH_PERMUTED_AVAILABLE:
return ser.str.minhash64_permuted(
a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram
)
else:
return ser.str.minhash64(
a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram
)

def __call__(self, dataset: DocumentDataset) -> Union[str, DocumentDataset]:
"""
Expand Down

0 comments on commit 5cea7b6

Please sign in to comment.