[JAX] Test_multiprocessing_encoder with process spawn in bash #1394

Merged · 5 commits · Jan 11, 2025
7 changes: 7 additions & 0 deletions examples/jax/encoder/common.py
@@ -12,3 +12,10 @@ def is_bf16_supported():
     """Return whether BF16 has hardware support"""
     gpu_arch = get_device_compute_capability(0)
     return gpu_arch >= 80
+
+
+@lru_cache
+def is_fp8_supported():
+    """Return whether FP8 has hardware support"""
+    gpu_arch = get_device_compute_capability(0)
+    return gpu_arch >= 90
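Both helpers gate tests on the device's compute capability: `get_device_compute_capability(0)` returns the capability of GPU 0 as an integer (80 for Ampere, 90 for Hopper), and `@lru_cache` ensures the device query runs only once per process. A minimal sketch of how a test module might consume these helpers (the class and test names below are illustrative, not part of this PR):

```python
# Sketch: skipping tests with the capability helpers above.
# Assumes this file sits next to examples/jax/encoder/common.py.
import unittest

from common import is_bf16_supported, is_fp8_supported


class CapabilityGatedTests(unittest.TestCase):
    """Hypothetical tests gated on device compute capability."""

    @unittest.skipIf(not is_bf16_supported(), "BF16 requires compute capability 8.0+")
    def test_bf16_path(self):
        ...

    @unittest.skipIf(not is_fp8_supported(), "FP8 requires compute capability 9.0+")
    def test_fp8_path(self):
        ...
```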
20 changes: 20 additions & 0 deletions examples/jax/encoder/conftest.py
@@ -0,0 +1,20 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+"""config for test_multiprocessing_encoder"""
+import pytest
+
+
+def pytest_addoption(parser):
+    """Pytest hook for test_multiprocessing_encoder"""
+    parser.addoption("--num-process", action="store", default=0)
+    parser.addoption("--process-id", action="store", default=0)
+
+
+@pytest.fixture(autouse=True)
+def multiprocessing_parses(request):
+    """Fixture for querying num-process and process-id"""
+    if request.cls:
+        request.cls.num_process = int(request.config.getoption("--num-process"))
+        request.cls.process_id = int(request.config.getoption("--process-id"))
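The `autouse=True` fixture runs before every test and, when the test belongs to a class, copies the two command-line options onto that class as integers. Each pytest process therefore learns its world size and rank from its own command line instead of spawning workers itself. A hedged sketch of a consumer (the class below is hypothetical; the real consumer is `TestEncoder` in the diff further down):

```python
# Sketch: how a class-based test picks up --num-process / --process-id.
# HypotheticalRankTest is illustrative only.
import unittest

import pytest


@pytest.mark.usefixtures("multiprocessing_parses")
class HypotheticalRankTest(unittest.TestCase):
    num_process = 0  # overwritten by the autouse fixture before each test
    process_id = 0

    def test_rank_in_range(self):
        """Each spawned pytest process sees its own rank."""
        assert 0 <= self.process_id < max(self.num_process, 1)
```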
17 changes: 17 additions & 0 deletions examples/jax/encoder/run_test_multiprocessing_encoder.sh
@@ -0,0 +1,17 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+NUM_GPUS=${NUM_GPUS:-$(nvidia-smi -L | wc -l)}
+
+for i in $(seq 0 $(($NUM_GPUS-1)))
+do
+    pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_bf16 --num-process=$NUM_GPUS --process-id=$i &
+done
+wait
+
+for i in $(seq 0 $(($NUM_GPUS-1)))
+do
+    pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_fp8 --num-process=$NUM_GPUS --process-id=$i &
+done
+wait
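The launcher spawns one background pytest process per GPU, each carrying the same `--num-process` (world size) and its own `--process-id` (rank), then `wait`s so that all ranks of one test case finish before the next case starts. A rough Python equivalent of the same spawn-and-wait pattern, shown only for illustration (paths and environment variables are taken from the script above):

```python
# Sketch: the shell launcher's spawn-and-wait pattern, in Python.
import os
import subprocess

te_path = os.environ["TE_PATH"]  # assumed to be set, as in the shell script
num_gpus = int(os.environ.get("NUM_GPUS", "1"))
test = (
    f"{te_path}/examples/jax/encoder/test_multiprocessing_encoder.py"
    "::TestEncoder::test_te_bf16"
)

# Launch one pytest process per GPU rank, all in parallel.
procs = [
    subprocess.Popen(
        [
            "pytest", "-c", f"{te_path}/tests/jax/pytest.ini", "-v", test,
            f"--num-process={num_gpus}", f"--process-id={rank}",
        ]
    )
    for rank in range(num_gpus)
]

# Equivalent of the shell `wait`: block until every rank has exited.
for p in procs:
    assert p.wait() == 0, "a rank failed"
```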
70 changes: 22 additions & 48 deletions examples/jax/encoder/test_multiprocessing_encoder.py
@@ -3,10 +3,10 @@
 # See LICENSE for license information.
 """Encoder training with multi-GPU, multiprocessing, and tensor parallelism"""
 import argparse
-import multiprocessing as mp
 import os
 import unittest
 from functools import partial
+import pytest
 
 import flax
 import jax
@@ -21,10 +21,10 @@
 from jax.experimental import mesh_utils
 from jax.sharding import PartitionSpec, NamedSharding
 
+from common import is_bf16_supported, is_fp8_supported
 import transformer_engine.jax as te
 import transformer_engine.jax.flax as te_flax
 
-from common import is_bf16_supported
 
 os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
 DEVICE_DP_AXIS = "data"
@@ -252,7 +252,6 @@ def eval_model(
 
 def data_preprocess(dataset, vocab, word_id, max_seq_len):
     """Convert tokens to numbers."""
-    nltk.download("punkt_tab")
     dataset_size = len(dataset["sentence"])
     output = np.zeros((dataset_size, max_seq_len), dtype=np.int32)
     mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8)
@@ -342,6 +341,9 @@ def replace_params(x):
 def train_and_evaluate(args):
     """Execute model training and evaluation loop."""
     print(args)
+    if args.process_id == 0:
+        nltk.download("punkt_tab")
+
     train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)
 
     jax.distributed.initialize(
@@ -551,69 +553,41 @@ def encoder_parser(args):
     return parser.parse_args(args)
 
 
-def query_gpu(q):
-    """Query GPU info on the system"""
-    gpu_has_fp8, reason = te.fp8.is_fp8_available()
-    gpu_has_bf16 = is_bf16_supported()
-    num_gpu = len(jax.devices())
-    q.put([num_gpu, gpu_has_fp8, gpu_has_bf16, reason])
-
-
-def unittest_query_gpu():
-    r"""
-    It is only used by TestEncoder.
-    The `jax.distributed.initialize` must be called before any other JAX or Flax API,
-    otherwise `jax.local_devices` will be incorrect.
-    Thus, fork another process to query number of GPUs and FP8 capability.
-    """
-    q = mp.Queue()
-    p = mp.Process(target=query_gpu, args=(q,))
-    p.start()
-    num_gpu, gpu_has_fp8, gpu_has_bf16, reason = q.get()
-    p.join()
-    return num_gpu, gpu_has_fp8, gpu_has_bf16, reason
-
-
+@pytest.mark.usefixtures("multiprocessing_parses")
 class TestEncoder(unittest.TestCase):
     """Encoder unittests"""
 
-    num_gpu, gpu_has_fp8, gpu_has_bf16, reason = unittest_query_gpu()
+    gpu_has_fp8 = is_fp8_supported()
+    gpu_has_bf16 = is_bf16_supported()
 
     def exec(self, use_fp8):
         """Run 3 epochs for testing"""
-        num_gpu = self.num_gpu
+        args = encoder_parser([])
+
+        num_gpu = self.num_process
         tp_size = 2 if num_gpu > 1 and num_gpu % 2 == 0 else 1
         dp_size = num_gpu // tp_size
         batch_size = 64 // dp_size
 
-        arg_list = []
-        for i in range(num_gpu):
-            args = encoder_parser([])
-            args.num_process = num_gpu
-            args.use_fp8 = use_fp8
-            args.batch_size = batch_size
-            args.test_batch_size = batch_size
-            args.process_id = i
-            arg_list.append(args)
-
-        with mp.Pool(self.num_gpu) as p:
-            results = p.map(train_and_evaluate, arg_list)
+        args.use_fp8 = use_fp8
+        args.batch_size = batch_size
+        args.test_batch_size = batch_size
+        args.num_process = num_gpu
+        args.process_id = self.process_id
 
-        return results
+        return train_and_evaluate(args)
 
     @unittest.skipIf(not gpu_has_bf16, "Device compute capability 8.0+ is required for BF16")
     def test_te_bf16(self):
         """Test Transformer Engine with BF16"""
-        results = self.exec(False)
-        actual = results[0]
-        assert actual[0] < 0.45 and actual[1] > 0.79
+        result = self.exec(False)
+        assert result[0] < 0.45 and result[1] > 0.79
 
-    @unittest.skipIf(not gpu_has_fp8, reason)
+    @unittest.skipIf(not gpu_has_fp8, "Device compute capability 9.0+ is required for FP8")
     def test_te_fp8(self):
         """Test Transformer Engine with FP8"""
-        results = self.exec(True)
-        actual = results[0]
-        assert actual[0] < 0.45 and actual[1] > 0.79
+        result = self.exec(True)
+        assert result[0] < 0.45 and result[1] > 0.79
 
 
 if __name__ == "__main__":
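The net effect of the Python changes: the old code had to fork a throwaway process (`query_gpu` via `multiprocessing`) just to count GPUs before `jax.distributed.initialize` was called, then ran all ranks through `mp.Pool` inside a single pytest run. Now each pytest invocation is exactly one rank, capability checks come from the cheap `common.py` helpers, only rank 0 downloads the NLTK tokenizer data (avoiding a download race), and each rank asserts the loss and accuracy thresholds on its own result instead of aggregating pool output. For reference, a hedged sketch of the per-rank wiring this relies on (the coordinator address is an assumption; in the test the values come from the parsed CLI arguments):

```python
# Sketch: multi-process JAX initialization, one process per GPU.
import jax

jax.distributed.initialize(
    coordinator_address="127.0.0.1:12345",  # assumed; any host:port all ranks can reach
    num_processes=2,                        # matches --num-process
    process_id=0,                           # matches --process-id, unique per rank
)
# With num_processes=2 this call blocks until rank 1 also connects.

# After initialize(), jax.devices() spans all processes, while
# jax.local_devices() lists only this process's GPU(s).
print(jax.process_index(), len(jax.devices()), len(jax.local_devices()))
```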
2 changes: 1 addition & 1 deletion qa/L0_jax_distributed_unittest/test.sh
@@ -12,4 +12,4 @@ pip install -r $TE_PATH/examples/jax/encoder/requirements.txt
 export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
 pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py
 pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py
-pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
+. $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh