Skip to content

Commit

Permalink
Prevent plugging an allocator twice (NVIDIA#154)
Browse files Browse the repository at this point in the history
* Prevent plugging an allocator twice

Signed-off-by: Vibhu Jawa <[email protected]>

* Remove extra import

Signed-off-by: Vibhu Jawa <[email protected]>

* Fix defaults for RMM-POOL and other style fixes

Signed-off-by: Vibhu Jawa <[email protected]>

* Switch rmm_pytorch off by default

Signed-off-by: Vibhu Jawa <[email protected]>

---------

Signed-off-by: Vibhu Jawa <[email protected]>
  • Loading branch information
VibhuJawa authored Jul 19, 2024
1 parent a79af76 commit fb12646
Show file tree
Hide file tree
Showing 8 changed files with 63 additions and 66 deletions.
9 changes: 8 additions & 1 deletion nemo_curator/scripts/domain_classifier_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,15 @@ def main():
if not os.path.exists(args.output_data_dir):
os.makedirs(args.output_data_dir)

# Sometimes jsonl files are stored as .json
# So to handle that case we can pass the input_file_extension
if args.input_file_extension is not None:
input_file_extension = args.input_file_extension
else:
input_file_extension = args.input_file_type

input_files = get_remaining_files(
args.input_data_dir, args.output_data_dir, args.input_file_type
args.input_data_dir, args.output_data_dir, input_file_extension
)
print(f"Total input files {len(input_files)}", flush=True)

Expand Down
1 change: 0 additions & 1 deletion nemo_curator/scripts/find_exact_duplicates.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ def main(args):
logger.info(f"Starting workflow with args:\n {args}")

assert args.hash_method == "md5", "Currently only md5 hash is supported"
args.set_torch_to_use_rmm = False
client = get_client(**ArgumentHelper.parse_client_args(args))
logger.info(f"Client Created {client}")
if args.device == "gpu":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def main(args):
assert args.hash_bytes in {4, 8}, "Currently only 32bit/64bit hashes are supported"
assert args.device == "gpu"

args.set_torch_to_use_rmm = False
client = get_client(**ArgumentHelper.parse_client_args(args))
logger.info(f"Client Created {client}")
client.run(pre_imports)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def main(args):
"""
st = time.time()
output_path = os.path.join(args.output_dir, "connected_components.parquet")
args.set_torch_to_use_rmm = False
args.enable_spilling = True

client = get_client(**ArgumentHelper.parse_client_args(args))
Expand Down
1 change: 0 additions & 1 deletion nemo_curator/scripts/fuzzy_deduplication/minhash_lsh.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ def main(args):
logger.info(f"Starting workflow with args:\n {args}")

assert args.device == "gpu"
args.set_torch_to_use_rmm = False
client = get_client(**ArgumentHelper.parse_client_args(args))
logger.info(f"Client Created {client}")
client.run(pre_imports)
Expand Down
9 changes: 8 additions & 1 deletion nemo_curator/scripts/quality_classifier_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,15 @@ def main():
if not os.path.exists(args.output_data_dir):
os.makedirs(args.output_data_dir)

# Sometimes jsonl files are stored as .json
# So to handle that case we can pass the input_file_extension
if args.input_file_extension is not None:
input_file_extension = args.input_file_extension
else:
input_file_extension = args.input_file_type

input_files = get_remaining_files(
args.input_data_dir, args.output_data_dir, args.input_file_type
args.input_data_dir, args.output_data_dir, input_file_extension
)
print(f"Total input files {len(input_files)}", flush=True)

Expand Down
15 changes: 11 additions & 4 deletions nemo_curator/utils/distributed_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,14 +105,14 @@ def get_client(
protocol="tcp",
rmm_pool_size="1024M",
enable_spilling=True,
set_torch_to_use_rmm=True,
set_torch_to_use_rmm=False,
) -> Client:
"""
Initializes or connects to a Dask cluster.
The Dask cluster can be CPU-based or GPU-based (if GPUs are available).
The initialization ensures maximum memory efficiency for the GPU by:
1. Ensuring the PyTorch memory pool is the same as the RAPIDS memory pool.
2. Enabling spilling for cuDF.
1. Ensuring the PyTorch memory pool is the same as the RAPIDS memory pool. (If `set_torch_to_use_rmm` is True)
2. Enabling spilling for cuDF. (If `enable_spilling` is True)
Args:
cluster_type: The type of cluster to set up. Either "cpu" or "gpu". Defaults to "cpu".
Expand Down Expand Up @@ -171,11 +171,18 @@ def _set_torch_to_use_rmm():
See article:
https://medium.com/rapids-ai/pytorch-rapids-rmm-maximize-the-memory-efficiency-of-your-workflows-f475107ba4d4
"""

import torch
from rmm.allocators.torch import rmm_torch_allocator

if torch.cuda.get_allocator_backend() == "pluggable":
warnings.warn(
"PyTorch allocator already plugged in, not switching to RMM. "
"Please ensure you have not already swapped it."
)
return

torch.cuda.memory.change_current_allocator(rmm_torch_allocator)


Expand Down
92 changes: 36 additions & 56 deletions nemo_curator/utils/script_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,22 @@ def add_arg_text_ddf_blocksize(self):
help="The block size for chunking jsonl files for text ddf in mb",
)

# Register the `--model-path` command-line option on this helper's parser.
# The option takes a string value and is mandatory (required=True).
# `help` lets a caller override the default help text shown for the option.
def add_arg_model_path(self, help="The path to the model file"):
self.parser.add_argument(
"--model-path",
type=str,
help=help,
required=True,
)

# Register the boolean "autocast" flag on this helper's parser, defaulting to True.
# Delegates to ArgumentHelper.attach_bool_arg; `help` overrides the default help text.
# NOTE(review): the method name is spelled "autocaset" while the flag is "autocast";
# existing callers use this spelling, so a rename would need a backward-compatible alias.
# NOTE(review): the exact CLI spelling produced by attach_bool_arg is not visible here — confirm.
def add_arg_autocaset(self, help="Whether to use autocast or not"):
ArgumentHelper.attach_bool_arg(
parser=self.parser,
flag_name="autocast",
default=True,
help=help,
)

def add_distributed_args(self) -> argparse.ArgumentParser:
"""
Adds default set of arguments that are needed for Dask cluster setup
Expand Down Expand Up @@ -392,67 +408,30 @@ def parse_distributed_classifier_args(
description,
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser = ArgumentHelper(parser).add_distributed_args()
argumentHelper = ArgumentHelper(parser)
argumentHelper.add_distributed_args()
argumentHelper.add_arg_input_data_dir(required=True)
argumentHelper.add_arg_output_data_dir(help="The path of the output files")
argumentHelper.add_arg_input_file_type()
argumentHelper.add_arg_input_file_extension()
argumentHelper.add_arg_output_file_type()
argumentHelper.add_arg_input_text_field()
argumentHelper.add_arg_enable_spilling()
argumentHelper.add_arg_set_torch_to_use_rmm()
argumentHelper.add_arg_batch_size(
help="The batch size to be used for inference"
)
argumentHelper.add_arg_model_path()
argumentHelper.add_arg_autocaset()

# Set low default RMM pool size for classifier
# to allow pytorch to grow its memory usage
# by default
parser.set_defaults(rmm_pool_size="512MB")
parser.add_argument(
"--input-data-dir",
type=str,
help="The path of the input files",
required=True,
)
parser.add_argument(
"--output-data-dir",
type=str,
help="The path of the output files",
required=True,
)
parser.add_argument(
"--model-path",
type=str,
help="The path to the model file",
required=True,
)
parser.add_argument(
"--input-file-type",
type=str,
help="The type of the input files",
required=True,
)
parser.add_argument(
"--output-file-type",
type=str,
default="jsonl",
help="The type of the output files",
)
parser.add_argument(
"--batch-size",
type=int,
default=128,
help="The batch size to be used for inference",
)
ArgumentHelper.attach_bool_arg(
parser, "autocast", default=True, help="Whether to use autocast or not"
)
ArgumentHelper.attach_bool_arg(
parser,
"enable-spilling",
default=True,
help="Whether to enable spilling or not",
)

argumentHelper.parser.set_defaults(rmm_pool_size="512MB")
# Setting to False makes it more stable for long running jobs
# possibly because of memory fragmentation
ArgumentHelper.attach_bool_arg(
parser,
"set-torch-to-use-rmm",
default=False,
help="Whether to set torch to use RMM or not",
)

return parser
argumentHelper.parser.set_defaults(set_torch_to_use_rmm=False)
return argumentHelper.parser

@staticmethod
def parse_gpu_dedup_args(description: str) -> argparse.ArgumentParser:
Expand All @@ -472,6 +451,7 @@ def parse_gpu_dedup_args(description: str) -> argparse.ArgumentParser:

# Set default device to GPU for dedup
argumentHelper.parser.set_defaults(device="gpu")
argumentHelper.parser.set_defaults(set_torch_to_use_rmm=False)
argumentHelper.parser.add_argument(
"--input-data-dirs",
type=str,
Expand Down

0 comments on commit fb12646

Please sign in to comment.