llama embedding data parallel
ipotkonjak-tt committed Jan 27, 2025
1 parent b8d26dc commit 7e6ddab
Showing 3 changed files with 83 additions and 48 deletions.
28 changes: 22 additions & 6 deletions models/demos/llama3/tests/test_llama_embedding.py
@@ -29,18 +29,30 @@
indirect=True,
)
@pytest.mark.parametrize(
"batch_size",
(1,),
"batch_dp_tp",
[(1, 1, 2), (2, 2, 1), (4, 2, 1)],
ids=lambda args: "batch_{}_dp_{}_tp_{}".format(*args),
)
@pytest.mark.parametrize(
"max_seq_len",
(128,), # For decode-only unit test, there's no need to run with large sequence lengths
)
def test_llama_embedding(max_seq_len, batch_size, mesh_device, use_program_cache, reset_seeds, ensure_gc):
def test_llama_embedding(max_seq_len, batch_dp_tp, mesh_device, use_program_cache, reset_seeds, ensure_gc):
batch_size, data_parallel, tensor_parallel = batch_dp_tp

dtype = ttnn.bfloat16
mesh_device.enable_async(True)

model_args = TtModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=max_seq_len)
if data_parallel > 1:
mesh_device.reshape(ttnn.MeshShape(2, 1))

model_args = TtModelArgs(
mesh_device,
max_batch_size=batch_size,
data_parallel=data_parallel,
tensor_parallel=tensor_parallel,
max_seq_len=max_seq_len,
)
model_args.n_layers = 1

state_dict = model_args.load_state_dict()
@@ -61,15 +73,19 @@ def test_llama_embedding(max_seq_len, batch_size, mesh_device, use_program_cache
dtype=dtype,
)

prompts = ["Joy"] * 32
prompts = ["Joy"] * batch_size # 32
pt_input = torch.tensor([tokenizer.encode(prompt, bos=False, eos=False) for prompt in prompts])
reference_output = reference_emb(pt_input)
logger.info(f"reference_output: {reference_output.shape}")

tt_input = ttnn.from_torch(
pt_input.squeeze(1),
device=mesh_device,
mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device),
mesh_mapper=ttnn.ShardTensor2dMesh(
mesh_device,
dims=(0, None) if model_args.num_devices_dp > 1 else (None, None),
mesh_shape=model_args.cluster_shape,
),
dtype=ttnn.uint32,
layout=ttnn.ROW_MAJOR_LAYOUT,
)
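Note: the test now parametrizes batch, data-parallel, and tensor-parallel factors together and, when DP > 1, shards the token tensor along the batch dimension instead of replicating it across the mesh. Below is a minimal torch-only sketch of what a dims=(0, None) layout over an assumed 2x1 mesh implies; the shard_2d helper is hypothetical and only stands in for ttnn's 2D mesh mapper, not its implementation.

import torch

def shard_2d(tensor, mesh_shape, dims):
    """Split along dims[0] across mesh rows and dims[1] across mesh columns;
    a None entry means replicate along that mesh axis."""
    rows, cols = mesh_shape
    row_chunks = torch.chunk(tensor, rows, dim=dims[0]) if dims[0] is not None else [tensor] * rows
    shards = []
    for rc in row_chunks:
        col_chunks = torch.chunk(rc, cols, dim=dims[1]) if dims[1] is not None else [rc] * cols
        shards.extend(col_chunks)
    return shards  # one shard per device, row-major over the mesh

# batch_dp_tp = (2, 2, 1): two users split across a 2x1 mesh, one user per device
tokens = torch.randint(0, 128000, (2, 1))             # [batch, seq_len]
per_device = shard_2d(tokens, mesh_shape=(2, 1), dims=(0, None))
assert len(per_device) == 2 and per_device[0].shape == (1, 1)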
6 changes: 5 additions & 1 deletion models/demos/llama3/tt/llama_embedding.py
@@ -28,7 +28,11 @@ def __init__(
torch_weight,
dtype=dtype,
device=self.mesh_device,
mesh_mapper=ttnn.ShardTensor2dMesh(mesh_device=mesh_device, dims=(None, 3), mesh_shape=args.cluster_shape),
mesh_mapper=ttnn.ShardTensor2dMesh(
mesh_device=mesh_device,
dims=(None, 3) if args.num_devices_tp > 1 else (None, None),
mesh_shape=args.cluster_shape,
),
layout=ttnn.ROW_MAJOR_LAYOUT,
memory_config=args.get_model_config()["EMB_WEIGHTS_MEMCFG"],
cache_file_name=cache_name,
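Note: the embedding weight mapper now falls back to full replication when there is no tensor parallelism, and only shards the hidden dimension when num_devices_tp > 1. A small illustrative sketch of the two layouts, with assumed Llama-like sizes that are not taken from the commit:

import torch

vocab_size, hidden_dim, num_devices_tp = 128256, 4096, 2
weight = torch.randn(1, 1, vocab_size, hidden_dim)    # 4D view matching dims=(None, 3)

if num_devices_tp > 1:                                 # dims=(None, 3): split the hidden dim
    weight_shards = torch.chunk(weight, num_devices_tp, dim=3)
else:                                                  # dims=(None, None): keep a full copy
    weight_shards = [weight]

assert weight_shards[0].shape[-1] == hidden_dim // num_devices_tp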
97 changes: 56 additions & 41 deletions models/demos/llama3/tt/model_config.py
@@ -229,7 +229,9 @@ def __init__(
if "instruct" in self.DEFAULT_CACHE_PATH.lower():
self.instruct = True
self.dummy_weights = dummy_weights
self.tile_padded_batch_rows = self.tile_size * int(math.ceil(self.max_batch_size / self.tile_size))
self.tile_padded_batch_rows = self.tile_size * int(
math.ceil((self.max_batch_size / self.num_devices_dp) / self.tile_size)
)

# Enable workarounds by default until di/dt issues are fixed
self.di_dt_workaround = os.getenv("DISABLE_DI_DT_WORKAROUND") != "1"
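Note: a worked example of the updated tile_padded_batch_rows formula above: each data-parallel group now pads only its own share of the batch (max_batch_size / num_devices_dp) up to a multiple of the tile size. The values below are assumptions for illustration only.

import math

tile_size, max_batch_size, num_devices_dp = 32, 4, 2   # illustrative values

per_dp_batch = max_batch_size / num_devices_dp                       # 2 users per DP group
tile_padded_batch_rows = tile_size * int(math.ceil(per_dp_batch / tile_size))
assert tile_padded_batch_rows == 32                                  # rounded up to one tile of rows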
@@ -471,29 +473,29 @@ def find_largest_divisor(n, max_divisor=8):
)

# Useful core grid based on batch size
if self.max_batch_size == 32:
grid_by_batch = (8, 4)
elif self.max_batch_size == 16:
grid_by_batch = (8, 2)
elif self.max_batch_size == 8:
grid_by_batch = (8, 1)
elif self.max_batch_size == 4:
grid_by_batch = (4, 1)
elif self.max_batch_size == 2:
grid_by_batch = (2, 1)
elif self.max_batch_size == 1:
grid_by_batch = (1, 1)
else:
raise ValueError(f"Batch size {self.max_batch_size} not supported")
core_grid_by_batch = ttnn.CoreGrid(y=grid_by_batch[1], x=grid_by_batch[0])
core_range_set_by_batch = ttnn.CoreRangeSet(
{
ttnn.CoreRange(
ttnn.CoreCoord(0, 0),
ttnn.CoreCoord(grid_by_batch[0] - 1, grid_by_batch[1] - 1),
),
}
)
# if self.max_batch_size == 32:
# grid_by_batch = (8, 4)
# elif self.max_batch_size == 16:
# grid_by_batch = (8, 2)
# elif self.max_batch_size == 8:
# grid_by_batch = (8, 1)
# elif self.max_batch_size == 4:
# grid_by_batch = (4, 1)
# elif self.max_batch_size == 2:
# grid_by_batch = (2, 1)
# elif self.max_batch_size == 1:
# grid_by_batch = (1, 1)
# else:
# raise ValueError(f"Batch size {self.max_batch_size} not supported")
# core_grid_by_batch = ttnn.CoreGrid(y=grid_by_batch[1], x=grid_by_batch[0])
# core_range_set_by_batch = ttnn.CoreRangeSet(
# {
# ttnn.CoreRange(
# ttnn.CoreCoord(0, 0),
# ttnn.CoreCoord(grid_by_batch[0] - 1, grid_by_batch[1] - 1),
# ),
# }
# )

self.model_config[
"SCORES_BATCHED_MM_OUTPUT_MEMCFG"
@@ -504,18 +506,18 @@ def find_largest_divisor(n, max_divisor=8):
orientation=ttnn.ShardOrientation.ROW_MAJOR,
use_height_and_width_as_shard_shape=True,
)
self.model_config["ROT_MAT_MEMCONFIG"] = ttnn.MemoryConfig(
ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
ttnn.BufferType.L1,
ttnn.ShardSpec(
core_range_set_by_batch,
[
128,
128,
],
ttnn.ShardOrientation.ROW_MAJOR,
),
)
# self.model_config["ROT_MAT_MEMCONFIG"] = ttnn.MemoryConfig(
# ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
# ttnn.BufferType.L1,
# ttnn.ShardSpec(
# core_range_set_by_batch,
# [
# 128,
# 128,
# ],
# ttnn.ShardOrientation.ROW_MAJOR,
# ),
# )

# MLP configs
mlp_core_grid = (
@@ -894,13 +896,18 @@ def ccl_topology(self):
return ttnn.Topology.Linear
return None

def prepare_residual_tensor_decode(self, x, input_mem_cfg, force_replicated=False, on_host=False):
def prepare_residual_tensor_decode(self, x, input_mem_cfg, force_replicated=False, on_host=False, data_parallel=1):
"""
Prepare inputs for decode mode.
x: (batch, seq, dim)
"""
dims = (None, None) if force_replicated else (None, -1)
mesh_mapper = ttnn.ShardTensor2dMesh(self.mesh_device, dims=dims, mesh_shape=self.cluster_shape)
if data_parallel > 1:
dims = (2, None)
else:
dims = (None, None) if force_replicated else (None, -1)
mesh_mapper = ttnn.ShardTensor2dMesh(
self.mesh_device, dims=dims, mesh_shape=self.cluster_shape
) # DP: shard on batch

if len(x.shape) == 3:
batch = x.shape[0]
@@ -913,15 +920,23 @@ def prepare_residual_tensor_decode(self, x, input_mem_cfg, force_replicated=Fals
assert x.shape[3] == self.dim

assert seq_len == 1, "Only supporting decode mode"
assert batch % data_parallel == 0, "Only supporting data parallel with batch divisible by number of devices"

# Support input on device
if torch.is_tensor(x): # Input on host -> Use torch
x = x.transpose(0, 1).unsqueeze(1) # [seq_len, 1, batch, dim]
# Pad small batches to 32
if batch < 32:
zeros = torch.zeros(1, seq_len, 32, self.dim)
zeros[:, :, :batch, :] = x
x = zeros
if data_parallel > 1:
padded_batch = []
for chunk in torch.chunk(x, data_parallel, 2):
zeros[:, :, : (batch // data_parallel), :] = chunk
padded_batch.append(zeros.clone())
x = torch.cat(padded_batch, dim=2)
else:
zeros[:, :, :batch, :] = x
x = zeros
elif len(x.shape) == 3: # Input on device -> Use ttnn
x = ttnn.reshape(x, (batch, seq_len, 1, self.dim)) # [batch, seqlen, dim] -> [batch, seqlen, 1, dim]
x = ttnn.permute(x, (1, 2, 0, 3)) # [seq_len, 1, batch, dim]
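Note: prepare_residual_tensor_decode now pads each data-parallel chunk of the batch to 32 users before sharding on the batch dimension. A self-contained torch sketch of just that padding step, with assumed shapes; the subsequent ttnn sharding and layout conversion are omitted.

import torch

seq_len, batch, dim, data_parallel = 1, 4, 4096, 2      # illustrative shapes
x = torch.randn(seq_len, 1, batch, dim)                  # [seq_len, 1, batch, dim]

if batch < 32:
    zeros = torch.zeros(1, seq_len, 32, dim)
    if data_parallel > 1:
        padded_batch = []
        for chunk in torch.chunk(x, data_parallel, 2):   # one chunk of users per DP group
            zeros[:, :, : (batch // data_parallel), :] = chunk
            padded_batch.append(zeros.clone())
        x = torch.cat(padded_batch, dim=2)               # [1, seq_len, 32 * DP, dim]
    else:
        zeros[:, :, :batch, :] = x
        x = zeros

assert x.shape == (1, seq_len, 32 * data_parallel, dim)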
