Skip to content

Commit

Permalink
Fix issues from RPV2 Tutorial (NVIDIA#196)
Browse files Browse the repository at this point in the history
* Add worker adjustment based on memory for add_id

Signed-off-by: Ryan Wolf <[email protected]>

* Fix map buckets text field

Signed-off-by: Ryan Wolf <[email protected]>

* Fix mismatch metadata in add_id

Signed-off-by: Ryan Wolf <[email protected]>

* Address Ayush's review

Signed-off-by: Ryan Wolf <[email protected]>

---------

Signed-off-by: Ryan Wolf <[email protected]>
Signed-off-by: Yang Yu <[email protected]>
  • Loading branch information
ryantwolf authored and yyu22 committed Oct 9, 2024
1 parent a93d3fc commit ccd3833
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 0 deletions.
1 change: 1 addition & 0 deletions nemo_curator/modules/add_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def _add_id_fast(self, dataset: DocumentDataset) -> DocumentDataset:
self._add_id_fast_partition,
partition_zero_padding,
meta=meta,
enforce_metadata=False,
)

return DocumentDataset(id_df)
Expand Down
1 change: 1 addition & 0 deletions nemo_curator/scripts/add_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def attach_args(
"Useful for creating a copy dataset with different IDs"
)
argumentHelper.add_distributed_args()
argumentHelper.set_default_n_workers(2.5)
parser.add_argument(
"--id-field-name",
type=str,
Expand Down
1 change: 1 addition & 0 deletions nemo_curator/scripts/fuzzy_deduplication/map_buckets.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def get_anchor_and_output_map_info(
map_buckets = _MapBuckets(
id_fields=["dataset_id", "doc_id"],
bucket_field=input_bucket_field,
text_field=input_text_field,
)
ddf_anchor_docs_with_bk = map_buckets.map_buckets_with_anchors(
documents_df=ddf_text, buckets_df=ddf_bk, shuffle_type=shuffle_type
Expand Down
21 changes: 21 additions & 0 deletions nemo_curator/utils/script_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import argparse
import os

import psutil


class ArgumentHelper:
"""
Expand Down Expand Up @@ -389,6 +391,25 @@ def add_distributed_args(self) -> argparse.ArgumentParser:

return self.parser

def set_default_n_workers(self, max_mem_gb_per_worker: float):
    """
    Sets the default --n-workers for a script to maximize parallelization while
    ensuring we don't trigger an out of memory error. Like --n-workers, this
    only applies when running the script locally.

    Args:
        max_mem_gb_per_worker (float): The maximum memory that each worker usually achieves for a script
            in units of gigabytes. It can be determined by watching the Dask dashboard. This value may
            change based on the size of each shard, so use a jsonl shard size of about 100 MB.
    """
    # Cap workers at the number of available CPU cores.
    cpu_worker_limit = os.cpu_count()

    # Cap workers by total system memory so each worker has room for its
    # typical peak usage. NOTE(review): uses total (not available) memory;
    # a heavily loaded machine may still overcommit — confirm acceptable.
    memory_gb = psutil.virtual_memory().total / (1024**3)
    # Floor division on floats returns a float (e.g. 12.0); cast to int so
    # argparse/Dask receive an integer worker count.
    mem_worker_limit = int(memory_gb // max_mem_gb_per_worker)

    # Clamp to at least 1 so machines with less memory than one worker's
    # budget still get a runnable default instead of 0 workers.
    n_workers = max(1, min(cpu_worker_limit, mem_worker_limit))
    self.parser.set_defaults(n_workers=n_workers)

@staticmethod
def parse_client_args(args: argparse.Namespace):
"""
Expand Down

0 comments on commit ccd3833

Please sign in to comment.