From 9a9021a732f24b268583b98907f9db80c9f10be0 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 7 Jan 2025 08:28:17 -0800 Subject: [PATCH 1/3] add test_multiprocessing_encoder with processing spawning in bash Signed-off-by: Phuong Nguyen --- examples/jax/encoder/common.py | 6 ++ examples/jax/encoder/conftest.py | 17 +++++ .../run_test_multiprocessing_encoder.sh | 18 +++++ .../encoder/test_multiprocessing_encoder.py | 70 ++++++------------- qa/L0_jax_distributed_unittest/test.sh | 2 +- 5 files changed, 64 insertions(+), 49 deletions(-) create mode 100644 examples/jax/encoder/conftest.py create mode 100644 examples/jax/encoder/run_test_multiprocessing_encoder.sh diff --git a/examples/jax/encoder/common.py b/examples/jax/encoder/common.py index c79fa45239..a0d62b41bb 100644 --- a/examples/jax/encoder/common.py +++ b/examples/jax/encoder/common.py @@ -12,3 +12,9 @@ def is_bf16_supported(): """Return if BF16 has hardware supported""" gpu_arch = get_device_compute_capability(0) return gpu_arch >= 80 + +@lru_cache +def is_fp8_supported(): + """Return if FP8 has hardware supported""" + gpu_arch = get_device_compute_capability(0) + return gpu_arch >= 90 diff --git a/examples/jax/encoder/conftest.py b/examples/jax/encoder/conftest.py new file mode 100644 index 0000000000..435f66effb --- /dev/null +++ b/examples/jax/encoder/conftest.py @@ -0,0 +1,17 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +""" config for test_multiprocessing_encoder """ +import pytest + +def pytest_addoption(parser): + """Pytest hook for test_multiprocessing_encoder""" + parser.addoption("--num-process", action="store", default=0) + parser.addoption("--process-id", action="store", default=0) + +@pytest.fixture(autouse=True) +def multiprocessing_parses(request): + """Fixture for querying num-process and process-id""" + if request.cls: + request.cls.num_process = int(request.config.getoption("--num-process")) + request.cls.process_id = int(request.config.getoption("--process-id")) diff --git a/examples/jax/encoder/run_test_multiprocessing_encoder.sh b/examples/jax/encoder/run_test_multiprocessing_encoder.sh new file mode 100644 index 0000000000..29a21b49e3 --- /dev/null +++ b/examples/jax/encoder/run_test_multiprocessing_encoder.sh @@ -0,0 +1,18 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +NUM_GPUS=${NUM_GPUS:-$(nvidia-smi -L | wc -l)} + +for i in $(seq 0 $(($NUM_GPUS-1))) +do + pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_bf16 --num-process=$NUM_GPUS --process-id=$i & +done +wait + +for i in $(seq 0 $(($NUM_GPUS-1))) +do + pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_fp8 --num-process=$NUM_GPUS --process-id=$i & +done +wait + diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index ff6fd4d167..7d2df77b7d 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -3,10 +3,10 @@ # See LICENSE for license information. """Encoder training with multi-GPU, multiprocessing, and tensor parallelism""" import argparse -import multiprocessing as mp import os import unittest from functools import partial +import pytest import flax import jax @@ -21,10 +21,10 @@ from jax.experimental import mesh_utils from jax.sharding import PartitionSpec, NamedSharding +from common import is_bf16_supported, is_fp8_supported import transformer_engine.jax as te import transformer_engine.jax.flax as te_flax -from common import is_bf16_supported os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" DEVICE_DP_AXIS = "data" @@ -252,7 +252,6 @@ def eval_model( def data_preprocess(dataset, vocab, word_id, max_seq_len): """Convert tokens to numbers.""" - nltk.download("punkt_tab") dataset_size = len(dataset["sentence"]) output = np.zeros((dataset_size, max_seq_len), dtype=np.int32) mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8) @@ -342,6 +341,9 @@ def replace_params(x): def train_and_evaluate(args): """Execute model training and evaluation loop.""" print(args) + if args.process_id == 0: + nltk.download("punkt_tab") + train_ds, test_ds, num_embed = get_datasets(args.max_seq_len) jax.distributed.initialize( @@ -551,69 +553,41 @@ def encoder_parser(args): return parser.parse_args(args) -def query_gpu(q): - """Query GPU info on the system""" - gpu_has_fp8, reason = te.fp8.is_fp8_available() - gpu_has_bf16 = is_bf16_supported() - num_gpu = len(jax.devices()) - q.put([num_gpu, gpu_has_fp8, gpu_has_bf16, reason]) - - -def unittest_query_gpu(): - r""" - It is only used by TestEncoder. - The `jax.distributed.initialize` must be called before any other JAX or Flax API, - otherwise `jax.local_devices` will be incorrect. - Thus, fork another process to query number of GPUs and FP8 capability. - """ - q = mp.Queue() - p = mp.Process(target=query_gpu, args=(q,)) - p.start() - num_gpu, gpu_has_fp8, gpu_has_bf16, reason = q.get() - p.join() - return num_gpu, gpu_has_fp8, gpu_has_bf16, reason - - +@pytest.mark.usefixtures("multiprocessing_parses") class TestEncoder(unittest.TestCase): """Encoder unittests""" - num_gpu, gpu_has_fp8, gpu_has_bf16, reason = unittest_query_gpu() + gpu_has_fp8 = is_fp8_supported() + gpu_has_bf16 = is_bf16_supported() def exec(self, use_fp8): """Run 3 epochs for testing""" - num_gpu = self.num_gpu + args = encoder_parser([]) + + num_gpu = self.num_process tp_size = 2 if num_gpu > 1 and num_gpu % 2 == 0 else 1 dp_size = num_gpu // tp_size batch_size = 64 // dp_size - arg_list = [] - for i in range(num_gpu): - args = encoder_parser([]) - args.num_process = num_gpu - args.use_fp8 = use_fp8 - args.batch_size = batch_size - args.test_batch_size = batch_size - args.process_id = i - arg_list.append(args) - - with mp.Pool(self.num_gpu) as p: - results = p.map(train_and_evaluate, arg_list) + args.use_fp8 = use_fp8 + args.batch_size = batch_size + args.test_batch_size = batch_size + args.num_process = num_gpu + args.process_id = self.process_id - return results + return train_and_evaluate(args) @unittest.skipIf(not gpu_has_bf16, "Device compute capability 8.0+ is required for BF16") def test_te_bf16(self): """Test Transformer Engine with BF16""" - results = self.exec(False) - actual = results[0] - assert actual[0] < 0.45 and actual[1] > 0.79 + result = self.exec(False) + assert result[0] < 0.45 and result[1] > 0.79 - @unittest.skipIf(not gpu_has_fp8, reason) + @unittest.skipIf(not gpu_has_fp8, "Device compute capability 9.0+ is required for FP8") def test_te_fp8(self): """Test Transformer Engine with FP8""" - results = self.exec(True) - actual = results[0] - assert actual[0] < 0.45 and actual[1] > 0.79 + result = self.exec(True) + assert result[0] < 0.45 and result[1] > 0.79 if __name__ == "__main__": diff --git a/qa/L0_jax_distributed_unittest/test.sh b/qa/L0_jax_distributed_unittest/test.sh index f1d1c06d38..947796b029 100644 --- a/qa/L0_jax_distributed_unittest/test.sh +++ b/qa/L0_jax_distributed_unittest/test.sh @@ -12,4 +12,4 @@ pip install -r $TE_PATH/examples/jax/encoder/requirements.txt export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py -pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py +. $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh From c8c4fc5cd598723cb5c77a74598067591d1c1de9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 8 Jan 2025 16:09:43 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/jax/encoder/common.py | 1 + examples/jax/encoder/conftest.py | 4 +++- examples/jax/encoder/run_test_multiprocessing_encoder.sh | 1 - 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/jax/encoder/common.py b/examples/jax/encoder/common.py index a0d62b41bb..93dbd408ea 100644 --- a/examples/jax/encoder/common.py +++ b/examples/jax/encoder/common.py @@ -13,6 +13,7 @@ def is_bf16_supported(): gpu_arch = get_device_compute_capability(0) return gpu_arch >= 80 + @lru_cache def is_fp8_supported(): """Return if FP8 has hardware supported""" diff --git a/examples/jax/encoder/conftest.py b/examples/jax/encoder/conftest.py index 435f66effb..53da9d9cb0 100644 --- a/examples/jax/encoder/conftest.py +++ b/examples/jax/encoder/conftest.py @@ -1,14 +1,16 @@ # Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. -""" config for test_multiprocessing_encoder """ +"""config for test_multiprocessing_encoder""" import pytest + def pytest_addoption(parser): """Pytest hook for test_multiprocessing_encoder""" parser.addoption("--num-process", action="store", default=0) parser.addoption("--process-id", action="store", default=0) + @pytest.fixture(autouse=True) def multiprocessing_parses(request): """Fixture for querying num-process and process-id""" diff --git a/examples/jax/encoder/run_test_multiprocessing_encoder.sh b/examples/jax/encoder/run_test_multiprocessing_encoder.sh index 29a21b49e3..269dff5da2 100644 --- a/examples/jax/encoder/run_test_multiprocessing_encoder.sh +++ b/examples/jax/encoder/run_test_multiprocessing_encoder.sh @@ -15,4 +15,3 @@ do pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_fp8 --num-process=$NUM_GPUS --process-id=$i & done wait - From 2a4777db54a967a9a93876e673ef51a3b761ae76 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 8 Jan 2025 08:11:15 -0800 Subject: [PATCH 3/3] changed 2025 license Signed-off-by: Phuong Nguyen --- examples/jax/encoder/conftest.py | 3 ++- examples/jax/encoder/run_test_multiprocessing_encoder.sh | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/jax/encoder/conftest.py b/examples/jax/encoder/conftest.py index 53da9d9cb0..b1648892aa 100644 --- a/examples/jax/encoder/conftest.py +++ b/examples/jax/encoder/conftest.py @@ -1,6 +1,7 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. + """config for test_multiprocessing_encoder""" import pytest diff --git a/examples/jax/encoder/run_test_multiprocessing_encoder.sh b/examples/jax/encoder/run_test_multiprocessing_encoder.sh index 269dff5da2..6a1dd96739 100644 --- a/examples/jax/encoder/run_test_multiprocessing_encoder.sh +++ b/examples/jax/encoder/run_test_multiprocessing_encoder.sh @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information.