From a65ad37e622ad89837b15520b9f2b6c7232d3423 Mon Sep 17 00:00:00 2001
From: Phuong Nguyen
Date: Fri, 10 Jan 2025 16:53:31 -0800
Subject: [PATCH] [JAX] Test_multiprocessing_encoder with process spawn in
 bash (#1394)

* add test_multiprocessing_encoder with process spawning in bash

  `jax.distributed.initialize` must be called before any other JAX or Flax
  API, so each rank now runs as its own pytest process spawned from bash
  instead of being forked via `multiprocessing` inside a single pytest run.

---------

Signed-off-by: Phuong Nguyen
---
 examples/jax/encoder/common.py                |  7 ++
 examples/jax/encoder/conftest.py              | 20 ++++++
 .../run_test_multiprocessing_encoder.sh       | 17 +++++
 .../encoder/test_multiprocessing_encoder.py   | 70 ++++++-------------
 qa/L0_jax_distributed_unittest/test.sh        |  2 +-
 5 files changed, 67 insertions(+), 49 deletions(-)
 create mode 100644 examples/jax/encoder/conftest.py
 create mode 100644 examples/jax/encoder/run_test_multiprocessing_encoder.sh

diff --git a/examples/jax/encoder/common.py b/examples/jax/encoder/common.py
index c79fa45239..93dbd408ea 100644
--- a/examples/jax/encoder/common.py
+++ b/examples/jax/encoder/common.py
@@ -12,3 +12,10 @@ def is_bf16_supported():
     """Return if BF16 has hardware supported"""
     gpu_arch = get_device_compute_capability(0)
     return gpu_arch >= 80
+
+
+@lru_cache
+def is_fp8_supported():
+    """Return whether FP8 is supported by the hardware"""
+    gpu_arch = get_device_compute_capability(0)
+    return gpu_arch >= 90
diff --git a/examples/jax/encoder/conftest.py b/examples/jax/encoder/conftest.py
new file mode 100644
index 0000000000..b1648892aa
--- /dev/null
+++ b/examples/jax/encoder/conftest.py
@@ -0,0 +1,20 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+"""Pytest configuration for test_multiprocessing_encoder"""
+import pytest
+
+
+def pytest_addoption(parser):
+    """Pytest hook for test_multiprocessing_encoder"""
+    parser.addoption("--num-process", action="store", default=0)
+    parser.addoption("--process-id", action="store", default=0)
+
+
+@pytest.fixture(autouse=True)
+def multiprocessing_parses(request):
+    """Fixture for querying num-process and process-id"""
+    if request.cls:
+        request.cls.num_process = int(request.config.getoption("--num-process"))
+        request.cls.process_id = int(request.config.getoption("--process-id"))
diff --git a/examples/jax/encoder/run_test_multiprocessing_encoder.sh b/examples/jax/encoder/run_test_multiprocessing_encoder.sh
new file mode 100644
index 0000000000..6a1dd96739
--- /dev/null
+++ b/examples/jax/encoder/run_test_multiprocessing_encoder.sh
@@ -0,0 +1,17 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+NUM_GPUS=${NUM_GPUS:-$(nvidia-smi -L | wc -l)}
+
+for i in $(seq 0 $(($NUM_GPUS-1)))
+do
+    pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_bf16 --num-process=$NUM_GPUS --process-id=$i &
+done
+wait
+
+for i in $(seq 0 $(($NUM_GPUS-1)))
+do
+    pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_fp8 --num-process=$NUM_GPUS --process-id=$i &
+done
+wait
diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py
index ff6fd4d167..7d2df77b7d 100644
--- a/examples/jax/encoder/test_multiprocessing_encoder.py
+++ b/examples/jax/encoder/test_multiprocessing_encoder.py
@@ -3,10 +3,10 @@
 # See LICENSE for license information.
"""Encoder training with multi-GPU, multiprocessing, and tensor parallelism""" import argparse -import multiprocessing as mp import os import unittest from functools import partial +import pytest import flax import jax @@ -21,10 +21,10 @@ from jax.experimental import mesh_utils from jax.sharding import PartitionSpec, NamedSharding +from common import is_bf16_supported, is_fp8_supported import transformer_engine.jax as te import transformer_engine.jax.flax as te_flax -from common import is_bf16_supported os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" DEVICE_DP_AXIS = "data" @@ -252,7 +252,6 @@ def eval_model( def data_preprocess(dataset, vocab, word_id, max_seq_len): """Convert tokens to numbers.""" - nltk.download("punkt_tab") dataset_size = len(dataset["sentence"]) output = np.zeros((dataset_size, max_seq_len), dtype=np.int32) mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8) @@ -342,6 +341,9 @@ def replace_params(x): def train_and_evaluate(args): """Execute model training and evaluation loop.""" print(args) + if args.process_id == 0: + nltk.download("punkt_tab") + train_ds, test_ds, num_embed = get_datasets(args.max_seq_len) jax.distributed.initialize( @@ -551,69 +553,41 @@ def encoder_parser(args): return parser.parse_args(args) -def query_gpu(q): - """Query GPU info on the system""" - gpu_has_fp8, reason = te.fp8.is_fp8_available() - gpu_has_bf16 = is_bf16_supported() - num_gpu = len(jax.devices()) - q.put([num_gpu, gpu_has_fp8, gpu_has_bf16, reason]) - - -def unittest_query_gpu(): - r""" - It is only used by TestEncoder. - The `jax.distributed.initialize` must be called before any other JAX or Flax API, - otherwise `jax.local_devices` will be incorrect. - Thus, fork another process to query number of GPUs and FP8 capability. 
- """ - q = mp.Queue() - p = mp.Process(target=query_gpu, args=(q,)) - p.start() - num_gpu, gpu_has_fp8, gpu_has_bf16, reason = q.get() - p.join() - return num_gpu, gpu_has_fp8, gpu_has_bf16, reason - - +@pytest.mark.usefixtures("multiprocessing_parses") class TestEncoder(unittest.TestCase): """Encoder unittests""" - num_gpu, gpu_has_fp8, gpu_has_bf16, reason = unittest_query_gpu() + gpu_has_fp8 = is_fp8_supported() + gpu_has_bf16 = is_bf16_supported() def exec(self, use_fp8): """Run 3 epochs for testing""" - num_gpu = self.num_gpu + args = encoder_parser([]) + + num_gpu = self.num_process tp_size = 2 if num_gpu > 1 and num_gpu % 2 == 0 else 1 dp_size = num_gpu // tp_size batch_size = 64 // dp_size - arg_list = [] - for i in range(num_gpu): - args = encoder_parser([]) - args.num_process = num_gpu - args.use_fp8 = use_fp8 - args.batch_size = batch_size - args.test_batch_size = batch_size - args.process_id = i - arg_list.append(args) - - with mp.Pool(self.num_gpu) as p: - results = p.map(train_and_evaluate, arg_list) + args.use_fp8 = use_fp8 + args.batch_size = batch_size + args.test_batch_size = batch_size + args.num_process = num_gpu + args.process_id = self.process_id - return results + return train_and_evaluate(args) @unittest.skipIf(not gpu_has_bf16, "Device compute capability 8.0+ is required for BF16") def test_te_bf16(self): """Test Transformer Engine with BF16""" - results = self.exec(False) - actual = results[0] - assert actual[0] < 0.45 and actual[1] > 0.79 + result = self.exec(False) + assert result[0] < 0.45 and result[1] > 0.79 - @unittest.skipIf(not gpu_has_fp8, reason) + @unittest.skipIf(not gpu_has_fp8, "Device compute capability 9.0+ is required for FP8") def test_te_fp8(self): """Test Transformer Engine with FP8""" - results = self.exec(True) - actual = results[0] - assert actual[0] < 0.45 and actual[1] > 0.79 + result = self.exec(True) + assert result[0] < 0.45 and result[1] > 0.79 if __name__ == "__main__": diff --git a/qa/L0_jax_distributed_unittest/test.sh b/qa/L0_jax_distributed_unittest/test.sh index f1d1c06d38..947796b029 100644 --- a/qa/L0_jax_distributed_unittest/test.sh +++ b/qa/L0_jax_distributed_unittest/test.sh @@ -12,4 +12,4 @@ pip install -r $TE_PATH/examples/jax/encoder/requirements.txt export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py -pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py +. $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh