llama rms_norm data parallel (#17020)
Add data parallel support to the llama rms_norm test and parametrize it over (batch, DP, TP) configurations.

pytest models/demos/llama3/tests/test_llama_rms_norm.py

Tested locally on an N300:

PASSED models/demos/llama3/tests/test_llama_rms_norm.py::test_llama_rms_norm_inference[wormhole_b0-True-prefill-128-batch_1_dp_1_tp_2-mesh_device0]
PASSED models/demos/llama3/tests/test_llama_rms_norm.py::test_llama_rms_norm_inference[wormhole_b0-True-prefill-128-batch_32_dp_2_tp_1-mesh_device0]
PASSED models/demos/llama3/tests/test_llama_rms_norm.py::test_llama_rms_norm_inference[wormhole_b0-True-prefill-128-batch_64_dp_2_tp_1-mesh_device0]
PASSED models/demos/llama3/tests/test_llama_rms_norm.py::test_llama_rms_norm_inference[wormhole_b0-True-prefill-128-batch_32_dp_1_tp_2-mesh_device0]
PASSED models/demos/llama3/tests/test_llama_rms_norm.py::test_llama_rms_norm_inference[wormhole_b0-True-prefill-128-batch_2_dp_2_tp_1-mesh_device0]
PASSED models/demos/llama3/tests/test_llama_rms_norm.py::test_llama_rms_norm_inference[wormhole_b0-True-decode-128-batch_1_dp_1_tp_2-mesh_device0]
PASSED models/demos/llama3/tests/test_llama_rms_norm.py::test_llama_rms_norm_inference[wormhole_b0-True-decode-128-batch_32_dp_2_tp_1-mesh_device0]
PASSED models/demos/llama3/tests/test_llama_rms_norm.py::test_llama_rms_norm_inference[wormhole_b0-True-decode-128-batch_64_dp_2_tp_1-mesh_device0]
PASSED models/demos/llama3/tests/test_llama_rms_norm.py::test_llama_rms_norm_inference[wormhole_b0-True-decode-128-batch_32_dp_1_tp_2-mesh_device0]
PASSED models/demos/llama3/tests/test_llama_rms_norm.py::test_llama_rms_norm_inference[wormhole_b0-True-decode-128-batch_2_dp_2_tp_1-mesh_device0]
SKIPPED [2] models/demos/llama3/tests/test_llama_rms_norm.py:66: Skipping test: Tensor parallelism (8) does not match the number of devices (2).
SKIPPED [2] models/demos/llama3/tests/test_llama_rms_norm.py:66: Skipping test: Data parallelism (8) does not match the number of devices (2).
SKIPPED [2] models/demos/llama3/tests/test_llama_rms_norm.py:60: Skipping test: batch size 64 exceeds the maximum supported size for data parallelism factor 1 (max 32)

CI https://github.com/tenstorrent/tt-metal/actions/runs/12930938703
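
The parametrization ids in the log above (batch_1_dp_1_tp_2, batch_32_dp_2_tp_1, ...) come from the ids= lambda added to the test's batch_dp_tp parametrization. A minimal, standalone sketch of that mapping (plain Python, no pytest needed):

# Sketch: reproduce the test ids generated by the ids= lambda in the diff below.
batch_dp_tp = [(1, 1, 8), (8, 8, 1), (1, 1, 2), (32, 2, 1), (64, 2, 1), (32, 1, 2), (64, 1, 2), (2, 2, 1)]
make_id = lambda args: "batch_{}_dp_{}_tp_{}".format(*args)
for args in batch_dp_tp:
    print(make_id(args))
# batch_1_dp_1_tp_8, batch_8_dp_8_tp_1, batch_1_dp_1_tp_2, batch_32_dp_2_tp_1,
# batch_64_dp_2_tp_1, batch_32_dp_1_tp_2, batch_64_dp_1_tp_2, batch_2_dp_2_tp_1
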
mbezuljTT authored Jan 23, 2025
1 parent 62dd556 commit b8d26dc
Showing 4 changed files with 63 additions and 18 deletions.
models/demos/llama3/tests/test_llama_mlp.py (2 additions, 2 deletions)

@@ -10,7 +10,7 @@
 from models.demos.llama3.tt.llama_mlp import TtLlamaMLP
 from models.demos.llama3.tt.model_config import TtModelArgs
 from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import FeedForward
-from models.utility_functions import comp_pcc, comp_allclose, skip_for_parallelism, skip_for_batch_parallism
+from models.utility_functions import comp_pcc, comp_allclose, skip_for_parallelism, skip_for_batch_parallelism
 from models.utility_functions import skip_for_grayskull

@@ -48,7 +48,7 @@
 def test_llama_mlp_inference(seq_len, batch_dp_tp, mesh_device, use_program_cache, reset_seeds, ensure_gc):
     batch_size, data_parallel, tensor_parallel = batch_dp_tp

-    skip, reason = skip_for_batch_parallism(batch_size, data_parallel)
+    skip, reason = skip_for_batch_parallelism(batch_size, data_parallel)
     if skip:
         pytest.skip(reason)
models/demos/llama3/tests/test_llama_rms_norm.py (56 additions, 15 deletions)

@@ -9,10 +9,7 @@
 from models.common.rmsnorm import RMSNorm as TtRMSNorm
 from models.demos.llama3.tt.model_config import TtModelArgs
 from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import RMSNorm as RefRMSNorm
-from models.utility_functions import (
-    comp_pcc,
-    comp_allclose,
-)
+from models.utility_functions import comp_pcc, comp_allclose, skip_for_parallelism, skip_for_batch_parallelism
 from models.utility_functions import skip_for_grayskull
 from models.demos.llama3.tt.distributed_norm import DistributedNorm

@@ -29,28 +26,56 @@
     indirect=True,
 )
 @pytest.mark.parametrize(
-    "batch_size",
-    (1,),
+    "batch_dp_tp",
+    [
+        (1, 1, 8),
+        (8, 8, 1),
+        (1, 1, 2),
+        (32, 2, 1),
+        (64, 2, 1),
+        (32, 1, 2),
+        (64, 1, 2),
+        (2, 2, 1),
+    ],
+    ids=lambda args: "batch_{}_dp_{}_tp_{}".format(*args),
 )
 @pytest.mark.parametrize(
     "max_seq_len",
     (128,),  # For decode-only unit test, there's no need to run with large sequence lengths
 )
 @pytest.mark.parametrize("mode", ["prefill", "decode"])
 def test_llama_rms_norm_inference(
+    batch_dp_tp,
     max_seq_len,
-    batch_size,
     mode,
     mesh_device,
     use_program_cache,
     reset_seeds,
     ensure_gc,
 ):
+    batch_size, data_parallel, tensor_parallel = batch_dp_tp
+
+    skip, reason = skip_for_batch_parallelism(batch_size, data_parallel)
+    if skip:
+        pytest.skip(reason)
+
+    skip, reason = skip_for_parallelism(
+        mesh_device.get_num_devices() if mesh_device else 0, data_parallel, tensor_parallel
+    )
+    if skip:
+        pytest.skip(reason)
+
     dtype = ttnn.bfloat16

     mesh_device.enable_async(True)

-    model_args = TtModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=max_seq_len)
+    model_args = TtModelArgs(
+        mesh_device,
+        max_batch_size=batch_size,
+        max_seq_len=max_seq_len,
+        data_parallel=data_parallel,
+        tensor_parallel=tensor_parallel,
+    )

     model_args.n_layers = 1
     state_dict = model_args.load_state_dict()

@@ -80,30 +105,46 @@ def test_llama_rms_norm_inference(
     reference_model = RefRMSNorm(dim=model_args.dim, eps=model_args.norm_eps)
     reference_model.load_state_dict(partial_state_dict)

-    input = torch.rand(1, 1, 32, model_args.dim)
+    input = torch.rand(model_args.per_chip_batch_dim, 1, 32, model_args.dim)
     reference_output = reference_model(input)

+    if data_parallel > 1:
+        input_shard_dims = (0, None)  # shard across batch dimension
+    else:
+        input_shard_dims = (None, -1)  # shard across width dimension
+
     # DistributedNorm inputs are fractured across devices and interleaved in DRAM (for prefill) and L1 (for decode)
     tt_input = ttnn.from_torch(
         input,
         device=mesh_device,
         dtype=dtype,
         layout=ttnn.TILE_LAYOUT,
-        mesh_mapper=ttnn.ShardTensor2dMesh(mesh_device, dims=(None, -1), mesh_shape=model_args.cluster_shape),
-        memory_config=model_args.get_model_config()["DECODE_RESIDUAL_MEMCFG"]
-        if mode == "decode"
-        else ttnn.DRAM_MEMORY_CONFIG,
+        mesh_mapper=ttnn.ShardTensor2dMesh(mesh_device, dims=input_shard_dims, mesh_shape=model_args.cluster_shape),
+        memory_config=(
+            model_args.get_model_config()["DECODE_RESIDUAL_MEMCFG"] if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG
+        ),
     )

     tt_output = tt_model(tt_input, mode=mode)

     # DistributedNorm outputs are replicated across devices
+    if data_parallel > 1:
+        # Data parallel is not running distributed norm.
+        # Data parallel per chip batch runs on dim 0. dim 3 is not utilized.
+        output_shard_dims = (0, 3)
+    elif model_args.is_galaxy:
+        output_shard_dims = (0, 3)
+    else:
+        output_shard_dims = (3, 0)
+
     tt_output_torch = ttnn.to_torch(
         tt_output,
         mesh_composer=ttnn.ConcatMesh2dToTensor(
-            mesh_device, dims=(0, 3) if model_args.is_galaxy else (3, 0), mesh_shape=model_args.cluster_shape
+            mesh_device, dims=output_shard_dims, mesh_shape=model_args.cluster_shape
         ),
-    )[:1, :, :, :]
+    )
+    if tensor_parallel > 1:
+        tt_output_torch = tt_output_torch[:1, :, :, :]

     passing, pcc_message = comp_pcc(reference_output, tt_output_torch)
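
The core of the test change is the choice of shard dimensions: with data parallelism the input batch (dim 0) is sharded across devices and each chip runs the full, non-distributed norm, while with tensor parallelism the width (dim -1) is sharded. A torch-only sketch of what that sharding looks like on a two-device mesh; split_for_mesh is a hypothetical stand-in for ttnn.ShardTensor2dMesh, and the shapes are illustrative assumptions:

# Torch-only sketch (not ttnn): how input_shard_dims = (0, None) vs (None, -1) splits a host
# tensor across a hypothetical two-device mesh.
import torch

def split_for_mesh(x, dim, num_devices=2):
    # dim=None would mean "replicate"; an integer dim means chunk the tensor across devices.
    if dim is None:
        return [x.clone() for _ in range(num_devices)]
    return list(torch.chunk(x, num_devices, dim=dim))

hidden_dim = 4096  # example width; the test uses model_args.dim

x = torch.rand(2, 1, 32, hidden_dim)

dp_shards = split_for_mesh(x, dim=0)    # data parallel: shard across the batch dimension
tp_shards = split_for_mesh(x, dim=-1)   # tensor parallel: shard across the width dimension

print([tuple(t.shape) for t in dp_shards])  # [(1, 1, 32, 4096), (1, 1, 32, 4096)]
print([tuple(t.shape) for t in tp_shards])  # [(2, 1, 32, 2048), (2, 1, 32, 2048)]

On the output side, the data-parallel shards are concatenated back along the batch dimension (output_shard_dims = (0, 3)), whereas the tensor-parallel output is replicated across devices, so the test keeps only the first slice ([:1, :, :, :]) before comparing against the reference.
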
models/demos/llama3/tt/model_config.py (4 additions, 0 deletions)

@@ -877,6 +877,10 @@ def _get_xattn_kv_prefill_mem_cfg(seq_len):
     def is_distributed_norm(self, mode):
         if not self.is_multichip:
             return False
+        if self.num_devices_dp > 1:
+            # data parallel assumes model fits single chip.
+            assert self.num_devices_tp == 1, "Hybrid mode not supported"
+            return False
         if all([dim > 1 for dim in list(self.mesh_device.shape)]):  # 2D grid
             return True
         elif self.dim >= 8192 and mode == "prefill":  # Somewhere between 4k and 8k WH runs out of L1 if not distributed
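
Read in isolation, the new early return means a data-parallel run never takes the distributed-norm path. A simplified, standalone sketch of the decision; the field names mirror model_config.py, but the tail of the method is elided in the diff above, so the final branches here are assumptions:

# Standalone sketch of is_distributed_norm with the new data-parallel early return.
def is_distributed_norm(is_multichip, num_devices_dp, num_devices_tp, mesh_shape, dim, mode):
    if not is_multichip:
        return False
    if num_devices_dp > 1:
        # Data parallel assumes the model fits on a single chip, so the norm stays local.
        assert num_devices_tp == 1, "Hybrid mode not supported"
        return False
    if all(d > 1 for d in mesh_shape):  # 2D device grid
        return True
    if dim >= 8192 and mode == "prefill":  # large widths run out of L1 on WH if not distributed
        return True
    return False  # assumed default for the remaining cases

# On an N300 (1x2 mesh) running pure data parallel, the norm is never distributed:
print(is_distributed_norm(True, 2, 1, (1, 2), 4096, "decode"))  # False
# On a 2D device mesh the distributed path is taken:
print(is_distributed_norm(True, 1, 32, (4, 8), 8192, "prefill"))  # True
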
models/utility_functions.py (1 addition, 1 deletion)

@@ -1023,7 +1023,7 @@ def get_debug_tensor(num_pages_width, num_pages_height, dtype, page_width=32, pa
     return torch_tensor


-def skip_for_batch_parallism(batch_size: int, data_parallel: int) -> Union[bool, str]:
+def skip_for_batch_parallelism(batch_size: int, data_parallel: int) -> Union[bool, str]:
     if batch_size % data_parallel != 0:
         return (
             True,
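
For context, a plausible reconstruction of the renamed helper and its companion skip_for_parallelism, pieced together from the call sites and the skip messages in the test log above. This is a sketch, not the verbatim repo code: the message wording, the max-batch constant, and the exact checks are assumptions (note also that the repo annotates the return as Union[bool, str] even though a (bool, str) tuple is returned; the sketch uses Tuple[bool, str]).

from typing import Tuple

MAX_BATCH_PER_DEVICE = 32  # assumed; implied by the "(max 32)" skip message in the log above

def skip_for_batch_parallelism(batch_size: int, data_parallel: int) -> Tuple[bool, str]:
    if batch_size % data_parallel != 0:
        # Divisibility check taken from the diff above; the message wording is assumed.
        return True, f"Skipping test: batch size {batch_size} is not divisible by data parallelism factor {data_parallel}"
    max_batch = MAX_BATCH_PER_DEVICE * data_parallel
    if batch_size > max_batch:
        return (
            True,
            f"Skipping test: batch size {batch_size} exceeds the maximum supported size for data parallelism "
            f"factor {data_parallel} (max {max_batch})",
        )
    return False, ""

def skip_for_parallelism(num_devices: int, data_parallel: int, tensor_parallel: int) -> Tuple[bool, str]:
    # The DP or TP factor has to cover exactly the available mesh (hybrid DP x TP is not supported here).
    if tensor_parallel > 1 and tensor_parallel != num_devices:
        return True, f"Skipping test: Tensor parallelism ({tensor_parallel}) does not match the number of devices ({num_devices})."
    if data_parallel > 1 and data_parallel != num_devices:
        return True, f"Skipping test: Data parallelism ({data_parallel}) does not match the number of devices ({num_devices})."
    return False, ""

With two devices, this reproduces the log above: the (1, 1, 8) and (8, 8, 1) cases skip in skip_for_parallelism, and the batch-64, dp-1 cases skip in skip_for_batch_parallelism, while (64, 2, 1) passes because 64 / 2 = 32 stays within the per-device limit.
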
