llama3 embedding data parallel (#17358)
#### Code changes
Adds data parallelism for the llama3 embedding module (plus some CI fixes).
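
In rough terms, "data parallel" here means the token batch is split across the mesh instead of being replicated to every device. A minimal sketch of the new input path (the helper name is hypothetical; it assumes a `mesh_device` handle like the one the test fixtures provide, and the real test passes `model_args.cluster_shape` as the mesh shape):

```python
import torch
import ttnn


def shard_tokens_for_dp(mesh_device, tokens: torch.Tensor, data_parallel: int):
    """Place a [batch, seq] token tensor on the mesh.

    With data_parallel > 1 the mesh is reshaped to (num_devices, 1) and the
    batch is sharded along dim 0; otherwise tokens are replicated to every
    device, as the tensor-parallel-only path did before this change.
    """
    if data_parallel > 1:
        mesh_device.reshape(ttnn.MeshShape(mesh_device.get_num_devices(), 1))
        mesh_mapper = ttnn.ShardTensor2dMesh(
            mesh_device,
            dims=(0, None),  # shard on batch across mesh rows, no column split
            mesh_shape=(mesh_device.get_num_devices(), 1),  # cluster_shape in the real test
        )
    else:
        mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device)

    return ttnn.from_torch(
        tokens,
        device=mesh_device,
        mesh_mapper=mesh_mapper,
        dtype=ttnn.uint32,
        layout=ttnn.ROW_MAJOR_LAYOUT,
    )
```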
#### Testing
Tested locally on N300 with `pytest models/demos/llama3/tests/test_llama_embedding.py`; all tests pass.
#### CI runs

- [(Single card) Models perf](https://github.com/tenstorrent/tt-metal/actions/runs/13011663398)
- [(Single card) Demo + Nightly](https://github.com/tenstorrent/tt-metal/actions/runs/13011403767)
- [(T3K) Choose your pipeline](https://github.com/tenstorrent/tt-metal/actions/runs/13028347394)
ipotkonjak-tt committed Feb 11, 2025
1 parent d63212c commit fd7bb64
Showing 18 changed files with 166 additions and 85 deletions.
45 changes: 38 additions & 7 deletions models/demos/llama3/tests/test_llama_embedding.py
@@ -11,6 +9,9 @@
from models.utility_functions import (
comp_pcc,
comp_allclose,
skip_for_batch_parallelism,
skip_for_parallelism,
skip_for_model_parallelism,
)
from models.utility_functions import skip_for_grayskull

@@ -27,18 +30,42 @@
indirect=True,
)
@pytest.mark.parametrize(
"batch_size",
(1,),
"batch_dp_tp",
[(1, 1, 2), (2, 2, 1), (4, 2, 1)],
ids=lambda args: "batch_{}_dp_{}_tp_{}".format(*args),
)
@pytest.mark.parametrize(
"max_seq_len",
(128,), # For decode-only unit test, there's no need to run with large sequence lengths
)
def test_llama_embedding(max_seq_len, batch_size, mesh_device, use_program_cache, reset_seeds, ensure_gc):
def test_llama_embedding(max_seq_len, batch_dp_tp, mesh_device, use_program_cache, reset_seeds, ensure_gc):
batch_size, data_parallel, tensor_parallel = batch_dp_tp

skip, reason = skip_for_batch_parallelism(batch_size, data_parallel)
if skip:
pytest.skip(reason)
skip, reason = skip_for_parallelism(
mesh_device.get_num_devices() if mesh_device else 0, data_parallel, tensor_parallel
)
if skip:
pytest.skip(reason)
skip, reason = skip_for_model_parallelism(data_parallel)
if skip:
pytest.skip(reason)

dtype = ttnn.bfloat16
mesh_device.enable_async(True)

model_args = TtModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=max_seq_len)
if data_parallel > 1:
mesh_device.reshape(ttnn.MeshShape(mesh_device.get_num_devices(), 1))

model_args = TtModelArgs(
mesh_device,
max_batch_size=batch_size,
data_parallel=data_parallel,
tensor_parallel=tensor_parallel,
max_seq_len=max_seq_len,
)
model_args.n_layers = 1

state_dict = model_args.load_state_dict()
@@ -59,15 +86,19 @@ def test_llama_embedding(max_seq_len, batch_size, mesh_device, use_program_cache
dtype=dtype,
)

prompts = ["Joy"] * 32
pt_input = torch.tensor([model_args.encode_prompt(prompt, instruct=False) for prompt in prompts])
prompts = ["Joy"] * batch_size # 32
pt_input = torch.tensor([tokenizer.encode(prompt, bos=False, eos=False) for prompt in prompts])
reference_output = reference_emb(pt_input)
logger.info(f"reference_output: {reference_output.shape}")

tt_input = ttnn.from_torch(
pt_input.squeeze(1),
device=mesh_device,
mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device),
mesh_mapper=ttnn.ShardTensor2dMesh(
mesh_device,
dims=(0, None) if model_args.num_devices_dp > 1 else (None, None),
mesh_shape=model_args.cluster_shape,
),
dtype=ttnn.uint32,
layout=ttnn.ROW_MAJOR_LAYOUT,
)
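
The test now guards itself with the three skip helpers imported at the top. Their real implementations live in `models/utility_functions.py` and are not part of this diff; the bodies below are only a plausible sketch inferred from how the test calls them (`skip_for_model_parallelism` is omitted because its criteria are not visible here):

```python
# Illustrative assumptions only -- not the actual utility_functions code.
from typing import Tuple


def skip_for_batch_parallelism(batch_size: int, data_parallel: int) -> Tuple[bool, str]:
    # A data-parallel run needs the batch to split evenly across the DP groups.
    if batch_size % data_parallel != 0:
        return True, f"Batch {batch_size} is not divisible by data_parallel={data_parallel}"
    return False, ""


def skip_for_parallelism(num_devices: int, data_parallel: int, tensor_parallel: int) -> Tuple[bool, str]:
    # Assumed check: the mesh must supply exactly data_parallel * tensor_parallel devices.
    if num_devices != data_parallel * tensor_parallel:
        return True, f"Need {data_parallel * tensor_parallel} devices, have {num_devices}"
    return False, ""
```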
12 changes: 11 additions & 1 deletion models/demos/llama3/tests/test_llama_mlp.py
@@ -10,7 +10,13 @@
from models.demos.llama3.tt.llama_mlp import TtLlamaMLP
from models.demos.llama3.tt.model_config import TtModelArgs
from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import FeedForward
from models.utility_functions import comp_pcc, comp_allclose, skip_for_parallelism, skip_for_batch_parallelism
from models.utility_functions import (
comp_pcc,
comp_allclose,
skip_for_parallelism,
skip_for_batch_parallelism,
skip_for_model_parallelism,
)
from models.utility_functions import skip_for_grayskull


@@ -58,6 +64,10 @@ def test_llama_mlp_inference(seq_len, batch_dp_tp, mesh_device, use_program_cach
if skip:
pytest.skip(reason)

skip, reason = skip_for_model_parallelism(data_parallel)
if skip:
pytest.skip(reason)

dtype = ttnn.bfloat8_b
mode = "decode" if seq_len <= 32 else "prefill"

12 changes: 11 additions & 1 deletion models/demos/llama3/tests/test_llama_rms_norm.py
@@ -9,7 +9,13 @@
from models.common.rmsnorm import RMSNorm as TtRMSNorm
from models.demos.llama3.tt.model_config import TtModelArgs
from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import RMSNorm as RefRMSNorm
from models.utility_functions import comp_pcc, comp_allclose, skip_for_parallelism, skip_for_batch_parallelism
from models.utility_functions import (
comp_pcc,
comp_allclose,
skip_for_parallelism,
skip_for_batch_parallelism,
skip_for_model_parallelism,
)
from models.utility_functions import skip_for_grayskull
from models.demos.llama3.tt.distributed_norm import DistributedNorm

@@ -65,6 +71,10 @@ def test_llama_rms_norm_inference(
if skip:
pytest.skip(reason)

skip, reason = skip_for_model_parallelism(data_parallel)
if skip:
pytest.skip(reason)

dtype = ttnn.bfloat16

mesh_device.enable_async(True)
6 changes: 5 additions & 1 deletion models/demos/llama3/tt/llama_embedding.py
@@ -28,7 +28,11 @@ def __init__(
torch_weight,
dtype=dtype,
device=self.mesh_device,
mesh_mapper=ttnn.ShardTensor2dMesh(mesh_device=mesh_device, dims=(None, 3), mesh_shape=args.cluster_shape),
mesh_mapper=ttnn.ShardTensor2dMesh(
mesh_device=mesh_device,
dims=(None, 3) if args.num_devices_tp > 1 else (None, None),
mesh_shape=args.cluster_shape,
),
layout=ttnn.ROW_MAJOR_LAYOUT,
memory_config=args.get_model_config()["EMB_WEIGHTS_MEMCFG"],
cache_file_name=cache_name,
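
A torch-level view of the weight layout chosen above, assuming the cached embedding weight is kept as a 4D `[1, 1, vocab_size, dim]` tensor so that dim 3 is the hidden dimension tensor parallelism splits (sizes are illustrative):

```python
import torch

vocab_size, dim = 128_256, 4096                 # illustrative Llama-3-like sizes
weight = torch.randn(1, 1, vocab_size, dim)

# Tensor parallel (dims=(None, 3)): each device holds a column slice of dim 3.
num_devices_tp = 2
tp_shards = torch.chunk(weight, num_devices_tp, dim=3)
assert tp_shards[0].shape == (1, 1, vocab_size, dim // num_devices_tp)

# Data parallel only (dims=(None, None)): every device keeps the full weight,
# which is what the new `args.num_devices_tp > 1` guard falls back to.
dp_replicas = [weight for _ in range(2)]
assert all(r.shape == weight.shape for r in dp_replicas)
```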
96 changes: 56 additions & 40 deletions models/demos/llama3/tt/model_config.py
@@ -244,7 +244,9 @@ def __init__(
self.optimizations = optimizations

self.dummy_weights = dummy_weights
self.tile_padded_batch_rows = self.tile_size * int(math.ceil(self.max_batch_size / self.tile_size))
self.tile_padded_batch_rows = self.tile_size * int(
math.ceil((self.max_batch_size / self.num_devices_dp) / self.tile_size)
)

# Enable workarounds by default until di/dt issues are fixed
self.di_dt_workaround = os.getenv("DISABLE_DI_DT_WORKAROUND") != "1"
@@ -499,28 +501,29 @@ def __init__(
)

# Useful core grid based on batch size
if self.max_batch_size == 32:
grid_by_batch = (8, 4)
elif self.max_batch_size == 16:
grid_by_batch = (8, 2)
elif self.max_batch_size == 8:
grid_by_batch = (8, 1)
elif self.max_batch_size == 4:
grid_by_batch = (4, 1)
elif self.max_batch_size == 2:
grid_by_batch = (2, 1)
elif self.max_batch_size == 1:
grid_by_batch = (1, 1)
else:
raise ValueError(f"Batch size {self.max_batch_size} not supported")
core_range_set_by_batch = ttnn.CoreRangeSet(
{
ttnn.CoreRange(
ttnn.CoreCoord(0, 0),
ttnn.CoreCoord(grid_by_batch[0] - 1, grid_by_batch[1] - 1),
),
}
)
# if self.max_batch_size == 32:
# grid_by_batch = (8, 4)
# elif self.max_batch_size == 16:
# grid_by_batch = (8, 2)
# elif self.max_batch_size == 8:
# grid_by_batch = (8, 1)
# elif self.max_batch_size == 4:
# grid_by_batch = (4, 1)
# elif self.max_batch_size == 2:
# grid_by_batch = (2, 1)
# elif self.max_batch_size == 1:
# grid_by_batch = (1, 1)
# else:
# raise ValueError(f"Batch size {self.max_batch_size} not supported")
# core_grid_by_batch = ttnn.CoreGrid(y=grid_by_batch[1], x=grid_by_batch[0])
# core_range_set_by_batch = ttnn.CoreRangeSet(
# {
# ttnn.CoreRange(
# ttnn.CoreCoord(0, 0),
# ttnn.CoreCoord(grid_by_batch[0] - 1, grid_by_batch[1] - 1),
# ),
# }
# )

self.model_config[
"SCORES_BATCHED_MM_OUTPUT_MEMCFG"
@@ -531,18 +534,18 @@ def __init__(
orientation=ttnn.ShardOrientation.ROW_MAJOR,
use_height_and_width_as_shard_shape=True,
)
self.model_config["ROT_MAT_MEMCONFIG"] = ttnn.MemoryConfig(
ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
ttnn.BufferType.L1,
ttnn.ShardSpec(
core_range_set_by_batch,
[
128,
128,
],
ttnn.ShardOrientation.ROW_MAJOR,
),
)
# self.model_config["ROT_MAT_MEMCONFIG"] = ttnn.MemoryConfig(
# ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
# ttnn.BufferType.L1,
# ttnn.ShardSpec(
# core_range_set_by_batch,
# [
# 128,
# 128,
# ],
# ttnn.ShardOrientation.ROW_MAJOR,
# ),
# )

# MLP configs
mlp_core_grid = (
@@ -931,13 +934,18 @@ def ccl_topology(self):
return ttnn.Topology.Linear
return None

def prepare_residual_tensor_decode(self, x, input_mem_cfg, force_replicated=False, on_host=False):
def prepare_residual_tensor_decode(self, x, input_mem_cfg, force_replicated=False, on_host=False, data_parallel=1):
"""
Prepare inputs for decode mode.
x: (batch, seq, dim)
"""
dims = (None, None) if force_replicated else (None, -1)
mesh_mapper = ttnn.ShardTensor2dMesh(self.mesh_device, dims=dims, mesh_shape=self.cluster_shape)
if data_parallel > 1:
dims = (2, None)
else:
dims = (None, None) if force_replicated else (None, -1)
mesh_mapper = ttnn.ShardTensor2dMesh(
self.mesh_device, dims=dims, mesh_shape=self.cluster_shape
) # DP: shard on batch

if len(x.shape) == 3:
batch = x.shape[0]
@@ -950,15 +958,23 @@ def prepare_residual_tensor_decode(self, x, input_mem_cfg, force_replicated=Fals
assert x.shape[3] == self.dim

assert seq_len == 1, "Only supporting decode mode"
assert batch % data_parallel == 0, "Only supporting data parallel with batch divisible by number of devices"

# Support input on device
if torch.is_tensor(x): # Input on host -> Use torch
x = x.transpose(0, 1).unsqueeze(1) # [seq_len, 1, batch, dim]
# Pad small batches to 32
if batch < 32:
zeros = torch.zeros(1, seq_len, 32, self.dim)
zeros[:, :, :batch, :] = x
x = zeros
if data_parallel > 1:
padded_batch = []
for chunk in torch.chunk(x, data_parallel, 2):
zeros[:, :, : (batch // data_parallel), :] = chunk
padded_batch.append(zeros.clone())
x = torch.cat(padded_batch, dim=2)
else:
zeros[:, :, :batch, :] = x
x = zeros
elif len(x.shape) == 3: # Input on device -> Use ttnn
x = ttnn.reshape(x, (batch, seq_len, 1, self.dim)) # [batch, seqlen, dim] -> [batch, seqlen, 1, dim]
x = ttnn.permute(x, (1, 2, 0, 3)) # [seq_len, 1, batch, dim]
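
A standalone torch sketch of the decode-input handling added above: each data-parallel chunk of the batch is padded to 32 users before the chunks are re-concatenated and the result is sharded on the batch dimension (shapes are illustrative):

```python
import torch

batch, dim, data_parallel = 8, 4096, 2
x = torch.randn(1, 1, batch, dim)                   # decode input: [seq_len=1, 1, batch, dim]

padded_batch = []
for chunk in torch.chunk(x, data_parallel, dim=2):  # split users across DP groups
    zeros = torch.zeros(1, 1, 32, dim)              # pad each group up to 32 rows
    zeros[:, :, : batch // data_parallel, :] = chunk
    padded_batch.append(zeros)
x = torch.cat(padded_batch, dim=2)                  # [1, 1, 32 * data_parallel, dim]

assert x.shape == (1, 1, 32 * data_parallel, dim)
```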
4 changes: 2 additions & 2 deletions models/demos/llama3/tt/multimodal/llama_cross_attention.py
@@ -65,8 +65,8 @@ def __init__(
wo_str = f"{state_dict_prefix}wo.weight"

# when splitting the devices, we need to make sure that the number of heads is divisible by the number of devices
assert self.n_heads % configuration.num_devices == 0
assert self.n_kv_heads % configuration.num_devices == 0
assert self.n_heads % self.num_devices == 0
assert self.n_kv_heads % self.num_devices == 0

# TODO DRAM Shard the weights (see llama3 text)
self.wq = ttnn.as_tensor(
@@ -85,7 +85,7 @@ def __init__(
# TODO: Generalize LMHead, maybe use llama_model's single-tile-sequence LMHead
lm_head_torch = self.state_dict[f"{state_dict_prefix}output.weight"].transpose(-1, -2)
total_splits = 8 # Arbitrary value which allows whole-tile splits in LM Head
num_splits = total_splits // self.configuration.num_devices
num_splits = total_splits // self.configuration.num_devices_tp
lm_head_torch = torch.chunk(lm_head_torch, num_splits, dim=-1)

cache_name = lambda name, suffix, split: weight_cache_path / (state_dict_prefix + f"{name}{suffix}{split}")
@@ -341,7 +341,7 @@ def forward(
memory_config=ttnn.DRAM_MEMORY_CONFIG,
)

if self.configuration.num_devices > 1:
if self.configuration.num_devices_tp > 1:
output = ttnn.all_gather(output, dim=3, num_links=1, topology=ttnn.Topology.Linear)
outputs.append(output)

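
Torch-level illustration of the LM-head split above: the transposed `output.weight` is cut into `total_splits // num_devices_tp` whole-tile column blocks; how each block is then placed on the mesh happens further down and is not shown in this hunk (sizes below are illustrative):

```python
import torch

dim, vocab_size = 4096, 128_256
lm_head = torch.randn(vocab_size, dim).transpose(-1, -2)    # output.weight -> [dim, vocab]

total_splits = 8                 # as above: keeps every split whole tiles wide
num_devices_tp = 2
num_splits = total_splits // num_devices_tp
splits = torch.chunk(lm_head, num_splits, dim=-1)           # column blocks of the vocab dim

assert len(splits) == num_splits
assert splits[0].shape == (dim, vocab_size // num_splits)
```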
8 changes: 4 additions & 4 deletions models/demos/llama3/tt/multimodal/llama_image_attention.py
@@ -60,8 +60,8 @@ def __init__(
wo_str = f"{state_dict_prefix}wo.weight"

# when splitting the devices, we need to make sure that the number of heads is divisible by the number of devices
assert self.n_heads % configuration.num_devices == 0
assert self.n_kv_heads % configuration.num_devices == 0
assert self.n_heads % self.num_devices == 0
assert self.n_kv_heads % self.num_devices == 0

# Pad head_dim to multiple of 32
def pad_head_dim(weight, heads_out=True):
@@ -87,7 +87,7 @@ def pad_head_dim(weight, heads_out=True):
wv_padded = pad_head_dim(self.state_dict[wv_str])
wo_padded = pad_head_dim(self.state_dict[wo_str], heads_out=False)
wq_chunked, wk_chunked, wv_chunked = (
torch.chunk(w, configuration.num_devices) for w in [wq_padded, wk_padded, wv_padded]
torch.chunk(w, self.num_devices) for w in [wq_padded, wk_padded, wv_padded]
)

self.wqkv = ttnn.as_tensor(
Expand All @@ -113,7 +113,7 @@ def pad_head_dim(weight, heads_out=True):
],
dim=-1,
)
for i in range(configuration.num_devices)
for i in range(self.num_devices)
],
dim=-1,
),
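
A small torch sketch of the per-device head split this module relies on, which is why the asserts above now require `n_heads` and `n_kv_heads` to divide evenly by `self.num_devices` (sizes are illustrative):

```python
import torch

n_heads, head_dim, dim, num_devices = 16, 64, 1024, 2
assert n_heads % num_devices == 0                  # mirrors the assert in the module

wq = torch.randn(n_heads * head_dim, dim)          # padded Q weight: [heads * head_dim, dim]
wq_chunked = torch.chunk(wq, num_devices)          # one contiguous block of heads per device

assert wq_chunked[0].shape == ((n_heads // num_devices) * head_dim, dim)
```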