Add comments and update example
Signed-off-by: Ayush Dattagupta <[email protected]>
ayushdg committed May 3, 2024
1 parent 065e139 · commit 5a20aed
Showing 2 changed files with 43 additions and 8 deletions.
27 changes: 19 additions & 8 deletions examples/fuzzy_deduplication.py
@@ -33,25 +33,36 @@ def main(args):
    dataset_dir = "/path/to/dataset"
    log_dir = "./"
    cache_dir = "./fuzzy_cache"
    output_dir = "./"
    output_dir = "./output"
    dataset_id_field = "id"
    dataset_text_field = "text"

    filetype = "parquet"

    # Fuzzy dup calculation only supports the cuDF/GPU backend
    backend = "cudf"
    assert args.device == "gpu"

    with dask.config.set({"dataframe.backend": "cudf"}):
    with dask.config.set({"dataframe.backend": backend}):
        client = get_client(args, args.device)
        client.run(pre_imports)

        t0 = time.time()
        input_dataset = DocumentDataset(
            dd.read_parquet(
        if filetype == "parquet":
            input_dataset = DocumentDataset(
                dd.read_parquet(
                    dataset_dir,
                    columns=[dataset_id_field, dataset_text_field],
                    blocksize="256MiB",
                    aggregate_files=True,
                )
            )
        elif filetype == "jsonl":
            input_dataset = DocumentDataset.read_json(
                dataset_dir,
                columns=[dataset_id_field, dataset_text_field],
                blocksize="256MiB",
                aggregate_files=True,
                backend=backend,
            )
        )

        fuzzy_dedup_config = FuzzyDuplicatesConfig(
            cache_dir=cache_dir,
            id_field=dataset_id_field,
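
For orientation, here is a minimal sketch of how the dataset loaded above is typically wired into the fuzzy-deduplication step that the rest of this example configures. It reuses the names defined in the example (cache_dir, log_dir, output_dir, input_dataset, and the id/text field variables); the exact arguments and the final write-out step are assumptions and are not part of this diff.

    # Sketch only: assumed continuation of the example above, not part of this commit.
    from nemo_curator import FuzzyDuplicates, FuzzyDuplicatesConfig

    fuzzy_dedup_config = FuzzyDuplicatesConfig(
        cache_dir=cache_dir,
        id_field=dataset_id_field,
        text_field=dataset_text_field,
    )
    fuzzy_dup = FuzzyDuplicates(logger=log_dir, config=fuzzy_dedup_config)

    # Returns a DocumentDataset mapping each document ID to a duplicate group.
    duplicates = fuzzy_dup(dataset=input_dataset)

    # Persist the duplicate-group assignments alongside the run's other output.
    duplicates.df.to_parquet(output_dir, write_index=False)
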
24 changes: 24 additions & 0 deletions nemo_curator/modules/fuzzy_dedup.py
@@ -389,6 +389,18 @@ def __init__(
        config: FuzzyDuplicatesConfig,
        logger: Union[logging.LoggerAdapter, str] = "./",
    ):
        """
        Parameters
        ----------
        config: FuzzyDuplicatesConfig,
            Config options for finding FuzzyDuplicates
        logger: Existing logger to log to, or a path to a log directory.

        Returns
        -------
        DocumentDataset containing IDs of all documents and the corresponding duplicate group
        they belong to. Documents in the same group are near duplicates.
        """
        if isinstance(logger, str):
            self._logger = create_logger(
                rank=0,
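
As a quick illustration of the two logger options the new docstring describes, here is a hedged sketch; the config values, directory paths, and logger name are placeholders, not taken from this commit.

    import logging

    from nemo_curator import FuzzyDuplicates, FuzzyDuplicatesConfig

    config = FuzzyDuplicatesConfig(cache_dir="./fuzzy_cache", id_field="id", text_field="text")

    # Pass a path to a log directory; a logger is then created internally (see create_logger above).
    dedup = FuzzyDuplicates(config=config, logger="./logs/")

    # Or pass an existing LoggerAdapter to reuse the application's logging setup.
    existing = logging.LoggerAdapter(logging.getLogger("curator"), {})
    dedup = FuzzyDuplicates(config=config, logger=existing)
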
@@ -451,6 +463,17 @@ def __init__(
        )

    def __call__(self, dataset: DocumentDataset):
        """
        Parameters
        ----------
        dataset: DocumentDataset
            The input dataset to compute FuzzyDuplicates. Must contain a text and unique id field.

        Returns
        -------
        DocumentDataset containing IDs of all documents and the corresponding duplicate group
        they belong to. Documents in the same group are near duplicates.
        """
        # Minhash + LSH
        print("Stage1: Starting Minhash + LSH computation")
        minhashLSH = Sequential([self.minhash, self.lsh])
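
Continuing the earlier sketch, this is what the new __call__ docstring promises in practice: each document ID is paired with a duplicate-group ID, which downstream code can use to keep one document per group. The "group" and "id" column names and the keep-first policy are assumptions for illustration.

    # Sketch only: consuming the FuzzyDuplicates output (column names assumed).
    duplicates = fuzzy_dup(dataset=input_dataset)

    # Within each duplicate group, keep the first document and collect the IDs of the
    # rest (assumes members of a group sit in the same partition, e.g. after a
    # shuffle on the group column).
    docs_to_remove = duplicates.df.map_partitions(
        lambda part: part[part["group"].duplicated(keep="first")]
    )
    duplicate_ids = docs_to_remove["id"].compute()

    # duplicate_ids can then be used to filter near duplicates out of the original dataset.
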
@@ -634,6 +657,7 @@ def _get_output_map_based_on_str_bytes(
"""
Add output_partition_id to buckets_ddf
"""
documents_df = documents_df.copy()
documents_df[bytes_column] = documents_df[self.text_field].map_partitions(
lambda s: s.str.byte_count()
)
