From b8d26dccc61acfca928e790cedd1ea54f8b2e716 Mon Sep 17 00:00:00 2001
From: Marko Bezulj <156311081+mbezuljTT@users.noreply.github.com>
Date: Thu, 23 Jan 2025 15:11:40 +0100
Subject: [PATCH] llama rms_norm data parallel (#17020)

rms_norm data parallel

pytest models/demos/llama3/tests/test_llama_rms_norm.py, tested locally on N300:

PASSED models/demos/llama3/tests/test_llama_rms_norm.py::test_llama_rms_norm_inference[wormhole_b0-True-prefill-128-batch_1_dp_1_tp_2-mesh_device0]
PASSED models/demos/llama3/tests/test_llama_rms_norm.py::test_llama_rms_norm_inference[wormhole_b0-True-prefill-128-batch_32_dp_2_tp_1-mesh_device0]
PASSED models/demos/llama3/tests/test_llama_rms_norm.py::test_llama_rms_norm_inference[wormhole_b0-True-prefill-128-batch_64_dp_2_tp_1-mesh_device0]
PASSED models/demos/llama3/tests/test_llama_rms_norm.py::test_llama_rms_norm_inference[wormhole_b0-True-prefill-128-batch_32_dp_1_tp_2-mesh_device0]
PASSED models/demos/llama3/tests/test_llama_rms_norm.py::test_llama_rms_norm_inference[wormhole_b0-True-prefill-128-batch_2_dp_2_tp_1-mesh_device0]
PASSED models/demos/llama3/tests/test_llama_rms_norm.py::test_llama_rms_norm_inference[wormhole_b0-True-decode-128-batch_1_dp_1_tp_2-mesh_device0]
PASSED models/demos/llama3/tests/test_llama_rms_norm.py::test_llama_rms_norm_inference[wormhole_b0-True-decode-128-batch_32_dp_2_tp_1-mesh_device0]
PASSED models/demos/llama3/tests/test_llama_rms_norm.py::test_llama_rms_norm_inference[wormhole_b0-True-decode-128-batch_64_dp_2_tp_1-mesh_device0]
PASSED models/demos/llama3/tests/test_llama_rms_norm.py::test_llama_rms_norm_inference[wormhole_b0-True-decode-128-batch_32_dp_1_tp_2-mesh_device0]
PASSED models/demos/llama3/tests/test_llama_rms_norm.py::test_llama_rms_norm_inference[wormhole_b0-True-decode-128-batch_2_dp_2_tp_1-mesh_device0]
SKIPPED [2] models/demos/llama3/tests/test_llama_rms_norm.py:66: Skipping test: Tensor parallelism (8) does not match the number of devices (2).
SKIPPED [2] models/demos/llama3/tests/test_llama_rms_norm.py:66: Skipping test: Data parallelism (8) does not match the number of devices (2).
SKIPPED [2] models/demos/llama3/tests/test_llama_rms_norm.py:60: Skipping test: batch size 64 exceeds the maximum supported size for data parallelism factor 1 (max 32)

CI: https://github.com/tenstorrent/tt-metal/actions/runs/12930938703
---
 models/demos/llama3/tests/test_llama_mlp.py   |  4 +-
 .../demos/llama3/tests/test_llama_rms_norm.py | 71 +++++++++++++++----
 models/demos/llama3/tt/model_config.py        |  4 ++
 models/utility_functions.py                   |  2 +-
 4 files changed, 63 insertions(+), 18 deletions(-)

diff --git a/models/demos/llama3/tests/test_llama_mlp.py b/models/demos/llama3/tests/test_llama_mlp.py
index 8b951034af1..f6aac567efa 100644
--- a/models/demos/llama3/tests/test_llama_mlp.py
+++ b/models/demos/llama3/tests/test_llama_mlp.py
@@ -10,7 +10,7 @@
 from models.demos.llama3.tt.llama_mlp import TtLlamaMLP
 from models.demos.llama3.tt.model_config import TtModelArgs
 from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import FeedForward
-from models.utility_functions import comp_pcc, comp_allclose, skip_for_parallelism, skip_for_batch_parallism
+from models.utility_functions import comp_pcc, comp_allclose, skip_for_parallelism, skip_for_batch_parallelism
 from models.utility_functions import skip_for_grayskull
 
@@ -48,7 +48,7 @@ def test_llama_mlp_inference(seq_len, batch_dp_tp, mesh_device, use_program_cache, reset_seeds, ensure_gc):
     batch_size, data_parallel, tensor_parallel = batch_dp_tp
 
-    skip, reason = skip_for_batch_parallism(batch_size, data_parallel)
+    skip, reason = skip_for_batch_parallelism(batch_size, data_parallel)
     if skip:
         pytest.skip(reason)
 
diff --git a/models/demos/llama3/tests/test_llama_rms_norm.py b/models/demos/llama3/tests/test_llama_rms_norm.py
index 5fdc99ee14d..993a1201eed 100644
--- a/models/demos/llama3/tests/test_llama_rms_norm.py
+++ b/models/demos/llama3/tests/test_llama_rms_norm.py
@@ -9,10 +9,7 @@
 from models.common.rmsnorm import RMSNorm as TtRMSNorm
 from models.demos.llama3.tt.model_config import TtModelArgs
 from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import RMSNorm as RefRMSNorm
-from models.utility_functions import (
-    comp_pcc,
-    comp_allclose,
-)
+from models.utility_functions import comp_pcc, comp_allclose, skip_for_parallelism, skip_for_batch_parallelism
 from models.utility_functions import skip_for_grayskull
 from models.demos.llama3.tt.distributed_norm import DistributedNorm
 
@@ -29,8 +26,18 @@
     indirect=True,
 )
 @pytest.mark.parametrize(
-    "batch_size",
-    (1,),
+    "batch_dp_tp",
+    [
+        (1, 1, 8),
+        (8, 8, 1),
+        (1, 1, 2),
+        (32, 2, 1),
+        (64, 2, 1),
+        (32, 1, 2),
+        (64, 1, 2),
+        (2, 2, 1),
+    ],
+    ids=lambda args: "batch_{}_dp_{}_tp_{}".format(*args),
 )
 @pytest.mark.parametrize(
     "max_seq_len",
@@ -38,19 +45,37 @@
 )
 @pytest.mark.parametrize("mode", ["prefill", "decode"])
 def test_llama_rms_norm_inference(
+    batch_dp_tp,
     max_seq_len,
-    batch_size,
     mode,
     mesh_device,
     use_program_cache,
     reset_seeds,
     ensure_gc,
 ):
+    batch_size, data_parallel, tensor_parallel = batch_dp_tp
+
+    skip, reason = skip_for_batch_parallelism(batch_size, data_parallel)
+    if skip:
+        pytest.skip(reason)
+
+    skip, reason = skip_for_parallelism(
+        mesh_device.get_num_devices() if mesh_device else 0, data_parallel, tensor_parallel
+    )
+    if skip:
+        pytest.skip(reason)
+
     dtype = ttnn.bfloat16
     mesh_device.enable_async(True)
-    model_args = TtModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=max_seq_len)
+    model_args = TtModelArgs(
+        mesh_device,
+        max_batch_size=batch_size,
+        max_seq_len=max_seq_len,
+        data_parallel=data_parallel,
+        tensor_parallel=tensor_parallel,
+    )
     model_args.n_layers = 1
     state_dict = model_args.load_state_dict()
 
@@ -80,30 +105,46 @@ def test_llama_rms_norm_inference(
     reference_model = RefRMSNorm(dim=model_args.dim, eps=model_args.norm_eps)
     reference_model.load_state_dict(partial_state_dict)
 
-    input = torch.rand(1, 1, 32, model_args.dim)
+    input = torch.rand(model_args.per_chip_batch_dim, 1, 32, model_args.dim)
     reference_output = reference_model(input)
 
+    if data_parallel > 1:
+        input_shard_dims = (0, None)  # shard across batch dimension
+    else:
+        input_shard_dims = (None, -1)  # shard across width dimension
+
     # DistributedNorm inputs are fractured across devices and interleaved in DRAM (for prefill) and L1 (for decode)
     tt_input = ttnn.from_torch(
         input,
         device=mesh_device,
         dtype=dtype,
         layout=ttnn.TILE_LAYOUT,
-        mesh_mapper=ttnn.ShardTensor2dMesh(mesh_device, dims=(None, -1), mesh_shape=model_args.cluster_shape),
-        memory_config=model_args.get_model_config()["DECODE_RESIDUAL_MEMCFG"]
-        if mode == "decode"
-        else ttnn.DRAM_MEMORY_CONFIG,
+        mesh_mapper=ttnn.ShardTensor2dMesh(mesh_device, dims=input_shard_dims, mesh_shape=model_args.cluster_shape),
+        memory_config=(
+            model_args.get_model_config()["DECODE_RESIDUAL_MEMCFG"] if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG
+        ),
     )
 
     tt_output = tt_model(tt_input, mode=mode)
 
     # DistributedNorm outputs are replicated across devices
+    if data_parallel > 1:
+        # Data parallel does not run distributed norm.
+        # The per-chip data-parallel batch runs on dim 0; dim 3 is not utilized.
+        output_shard_dims = (0, 3)
+    elif model_args.is_galaxy:
+        output_shard_dims = (0, 3)
+    else:
+        output_shard_dims = (3, 0)
+
     tt_output_torch = ttnn.to_torch(
         tt_output,
         mesh_composer=ttnn.ConcatMesh2dToTensor(
-            mesh_device, dims=(0, 3) if model_args.is_galaxy else (3, 0), mesh_shape=model_args.cluster_shape
+            mesh_device, dims=output_shard_dims, mesh_shape=model_args.cluster_shape
         ),
-    )[:1, :, :, :]
+    )
+    if tensor_parallel > 1:
+        tt_output_torch = tt_output_torch[:1, :, :, :]
 
     passing, pcc_message = comp_pcc(reference_output, tt_output_torch)
diff --git a/models/demos/llama3/tt/model_config.py b/models/demos/llama3/tt/model_config.py
index b88b86aa2d6..2d4d0fefb95 100644
--- a/models/demos/llama3/tt/model_config.py
+++ b/models/demos/llama3/tt/model_config.py
@@ -877,6 +877,10 @@ def _get_xattn_kv_prefill_mem_cfg(seq_len):
     def is_distributed_norm(self, mode):
         if not self.is_multichip:
             return False
+        if self.num_devices_dp > 1:
+            # data parallel assumes the model fits on a single chip
+            assert self.num_devices_tp == 1, "Hybrid mode not supported"
+            return False
         if all([dim > 1 for dim in list(self.mesh_device.shape)]):  # 2D grid
             return True
         elif self.dim >= 8192 and mode == "prefill":  # Somewhere between 4k and 8k WH runs out of L1 if not distributed
diff --git a/models/utility_functions.py b/models/utility_functions.py
index 280b7cfdf75..3b3edd6d578 100644
--- a/models/utility_functions.py
+++ b/models/utility_functions.py
@@ -1023,7 +1023,7 @@ def get_debug_tensor(num_pages_width, num_pages_height, dtype, page_width=32, pa
     return torch_tensor
 
 
-def skip_for_batch_parallism(batch_size: int, data_parallel: int) -> Union[bool, str]:
+def skip_for_batch_parallelism(batch_size: int, data_parallel: int) -> Union[bool, str]:
     if batch_size % data_parallel != 0:
         return (
             True,
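
The final hunk cuts off inside skip_for_batch_parallelism, so the shape of the two gating helpers is easiest to see spelled out. Below is a minimal sketch reconstructed from the truncated hunk, the call sites in the tests, and the skip messages in the log above; the exact message strings and the per-chip batch cap of 32 are inferences from that log, not a copy of the shipped models/utility_functions.py:

```python
from typing import Tuple

# Assumed constant: the "(max 32)" in the skip message suggests a 32-user
# per-chip batch limit; the real value lives in models/utility_functions.py.
MAX_BATCH_PER_CHIP = 32


def skip_for_batch_parallelism(batch_size: int, data_parallel: int) -> Tuple[bool, str]:
    # Each data-parallel replica serves batch_size / data_parallel users, so the
    # batch must divide evenly and the per-replica share must fit on one chip.
    if batch_size % data_parallel != 0:
        return (
            True,
            f"Skipping test: batch size {batch_size} is not divisible by "
            f"data parallelism factor {data_parallel}",
        )
    if batch_size // data_parallel > MAX_BATCH_PER_CHIP:
        return (
            True,
            f"Skipping test: batch size {batch_size} exceeds the maximum supported "
            f"size for data parallelism factor {data_parallel} (max {MAX_BATCH_PER_CHIP})",
        )
    return (False, "")


def skip_for_parallelism(num_devices: int, data_parallel: int, tensor_parallel: int) -> Tuple[bool, str]:
    # The parametrized dp/tp factors must exactly cover the mesh the test runs on.
    if tensor_parallel > 1 and tensor_parallel != num_devices:
        return (
            True,
            f"Skipping test: Tensor parallelism ({tensor_parallel}) does not match "
            f"the number of devices ({num_devices})",
        )
    if data_parallel > 1 and data_parallel != num_devices:
        return (
            True,
            f"Skipping test: Data parallelism ({data_parallel}) does not match "
            f"the number of devices ({num_devices})",
        )
    return (False, "")
```

Note that the upstream signature annotates the return as Union[bool, str], but both call sites unpack a (bool, str) tuple, so the sketch uses Tuple[bool, str].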
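The test's sharding setup hinges on the (dims, mesh_shape) convention of ttnn.ShardTensor2dMesh and ttnn.ConcatMesh2dToTensor: dims[0] selects the tensor dimension fractured across mesh rows, dims[1] the one fractured across mesh columns, and None means replicate along that mesh axis. A torch-only sketch of that convention under those assumptions (shard_2d is a hypothetical stand-in, not a ttnn API):

```python
import torch


def shard_2d(tensor, dims, mesh_shape):
    """Hypothetical stand-in for ttnn.ShardTensor2dMesh: split `tensor` along
    dims[0] across mesh rows and dims[1] across mesh columns; a dim of None
    replicates the tensor along that mesh axis."""
    rows, cols = mesh_shape
    row_shards = torch.chunk(tensor, rows, dim=dims[0]) if dims[0] is not None else [tensor] * rows
    return [
        torch.chunk(r, cols, dim=dims[1])[c] if dims[1] is not None else r
        for r in row_shards
        for c in range(cols)
    ]


# Data parallel on a 2x1 mesh (dims=(0, None), as in the test): batch 64
# splits into two per-chip batches of 32, each chip running the full model.
x = torch.rand(64, 1, 32, 4096)
dp_shards = shard_2d(x, dims=(0, None), mesh_shape=(2, 1))
assert dp_shards[0].shape == (32, 1, 32, 4096)

# Tensor parallel on a 1x2 mesh (dims=(None, -1)): the width/hidden dim is
# fractured across the two chips instead.
y = torch.rand(1, 1, 32, 4096)
tp_shards = shard_2d(y, dims=(None, -1), mesh_shape=(1, 2))
assert tp_shards[0].shape == (1, 1, 32, 2048)
```

This is also why the test concatenates outputs with dims=(0, 3) in the data-parallel case but keeps only the first tensor-parallel replica: dp shards hold distinct batches, while DistributedNorm replicates its output across tp devices.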