#14519: Fix rebase issues
cglagovichTT committed Nov 7, 2024
1 parent a7aeba6 commit a413e64
Showing 8 changed files with 52 additions and 23 deletions.
2 changes: 1 addition & 1 deletion models/demos/llama3/lt
@@ -195,7 +195,7 @@ def main(stdscr):
for c in commands
for m in models
for d in devices
-if not (m in ["11b", "11b-b"] and d == "n150") or not (m == "70b" and d in ["n150", "n300"])
+if not ((m in ["11b", "11b-b"] and d == "n150") or (m == "70b" and d in ["n150", "n300"]))
]

# Create output entries
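
Note on the lt filter fix above: the old predicate has the form `not (A) or not (B)`, which is always true here because A and B can never hold for the same (model, device) pair, so nothing was actually filtered out; the fix negates the whole disjunction. A minimal standalone sketch of the behaviour, using hypothetical model/device lists rather than the full set lt iterates over:

# The old condition `not A or not B` only fails when A and B are both true,
# which never happens here, so nothing was excluded. The new condition
# `not (A or B)` excludes both invalid combinations.
models = ["1b", "11b", "11b-b", "70b"]     # hypothetical subset
devices = ["n150", "n300", "t3k"]          # hypothetical subset

old = [
    (m, d)
    for m in models
    for d in devices
    if not (m in ["11b", "11b-b"] and d == "n150") or not (m == "70b" and d in ["n150", "n300"])
]
new = [
    (m, d)
    for m in models
    for d in devices
    if not ((m in ["11b", "11b-b"] and d == "n150") or (m == "70b" and d in ["n150", "n300"]))
]

assert ("11b", "n150") in old        # bug: the invalid combination slipped through
assert ("11b", "n150") not in new    # fixed: excluded
assert ("70b", "n300") not in new    # fixed: excluded
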
20 changes: 14 additions & 6 deletions models/demos/llama3/tests/multimodal/test_llama_cross_attention.py
@@ -32,7 +32,13 @@
],
indirect=True,
)
@pytest.mark.parametrize("batch", (1, 2), ids=["batch_1", "batch_2"])
@pytest.mark.parametrize(
"batch",
(1,),
ids=[
"batch_1",
],
)
def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset_seeds, ensure_gc):
dtype = ttnn.bfloat16
pcc_required = 0.99
@@ -103,7 +109,7 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset
for b in range(batch):
tt_tensor_xattn_tokens = model_args.prepare_inputs_ttnn_prefill(
tt_xattn_tokens[b : b + 1],
-force_replicate=True,
+force_replicated=True,
)
tt_xattn_cache = tt_model.compute_xattn_kv_cache(tt_tensor_xattn_tokens, tt_xattn_cache, user_id=b)
tt_xattn_cache_torch = [
@@ -202,8 +208,8 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset
user_id=b,
)

-tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0))
-tt_output_torch = tt_output_torch[0, ..., :seq_len, :].view(1, seq_len, dim)
+tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1))
+tt_output_torch = tt_output_torch[..., :seq_len, :].view(1, seq_len, dim)
outputs.append(tt_output_torch)
tt_output_torch = torch.cat(outputs, dim=0).view(batch, seq_len, dim)

@@ -213,6 +219,8 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset
ttnn.DRAM_MEMORY_CONFIG,
force_replicated=True,
)
+tt_x = ttnn.interleaved_to_sharded(tt_x, model_args.model_config["SHARDED_ATTN_INPUT_MEMCFG"])

xattn_mask_expand = xattn_mask_expand.permute(2, 0, 1, 3).contiguous()
tt_xattn_mask = ttnn.from_torch(
xattn_mask_expand,
@@ -255,8 +263,8 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset
mode=mode,
)

-tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0))
-tt_output_torch = tt_output_torch[0, :, :batch, :].reshape(batch, seq_len, dim)
+tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1))
+tt_output_torch = tt_output_torch[:, :, :batch, :].reshape(batch, seq_len, dim)

passing, pcc_message = comp_pcc(pt_out, tt_output_torch, pcc_required)
logger.info(comp_allclose(pt_out, tt_output_torch))
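
Note on the ConcatMeshToTensor changes in this file: when the module output is fractured across the mesh (each device holding a slice of the head/hidden dimension), the host-side gather needs to concatenate the per-device shards along that sliced dimension; concatenating on dim 0 and keeping only index 0 is correct only for replicated outputs. A torch-only sketch of the difference, with illustrative shapes rather than the ttnn API:

import torch

num_devices = 2
seq_len, dim = 32, 128
full = torch.randn(1, 1, seq_len, dim)

# Each device holds a slice of the last dim, as with ShardTensorToMesh(dim=-1).
shards = list(torch.chunk(full, num_devices, dim=-1))

# Old gather: stack shards on dim 0 and keep index 0 -> half the hidden dim is lost.
stacked = torch.cat(shards, dim=0)   # shape [2, 1, seq_len, dim // 2]
wrong = stacked[0]                   # shape [1, seq_len, dim // 2]

# New gather: concatenate along the sharded (last) dim -> full tensor restored.
right = torch.cat(shards, dim=-1)    # shape [1, 1, seq_len, dim]

assert wrong.shape[-1] == dim // num_devices
assert torch.equal(right, full)
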
12 changes: 9 additions & 3 deletions models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py
@@ -39,7 +39,13 @@
],
indirect=True,
)
@pytest.mark.parametrize("batch", (1, 2), ids=["batch_1", "batch_2"])
@pytest.mark.parametrize(
"batch",
(1,),
ids=[
"batch_1",
],
)
@torch.no_grad()
def test_llama_cross_attention_transformer_text_inference(
text_seq_len,
@@ -112,7 +118,7 @@ def test_llama_cross_attention_transformer_text_inference(
for b in range(batch):
tt_tensor_vision_tokens = model_args.prepare_inputs_ttnn_prefill(
tt_vision_tokens[b : b + 1],
-force_replicate=True,
+force_replicated=True,
)

tt_xattn_cache = [
@@ -233,7 +239,7 @@ def test_llama_cross_attention_transformer_text_inference(
dtype=ttnn.bfloat8_b,
layout=ttnn.TILE_LAYOUT,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
-mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device),
+mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1),
)

rot_mats = get_prefill_rot_mat(
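
Note on the ReplicateTensorToMesh -> ShardTensorToMesh(dim=-1) changes: replication places the full activation on every device, while sharding on the last dim gives each device only its slice of the hidden dimension, matching weights that are column-fractured across the mesh. A torch-only sketch of the two input distributions, with hypothetical shapes:

import torch

num_devices = 2
x = torch.randn(1, 1, 32, 2048)   # hypothetical decode-mode activation

replicated = [x.clone() for _ in range(num_devices)]   # like ReplicateTensorToMesh
sharded = list(torch.chunk(x, num_devices, dim=-1))    # like ShardTensorToMesh(dim=-1)

assert replicated[0].shape[-1] == 2048                 # every device sees the full width
assert sharded[0].shape[-1] == 2048 // num_devices     # each device sees only its slice
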
20 changes: 13 additions & 7 deletions models/demos/llama3/tests/multimodal/test_llama_cross_block.py
@@ -28,7 +28,13 @@
],
indirect=True,
)
@pytest.mark.parametrize("batch", (1, 2), ids=["batch_1", "batch_2"])
@pytest.mark.parametrize(
"batch",
(1,),
ids=[
"batch_1",
],
)
def test_llama_cross_attention_transformer_block_inference(
text_seq_len, batch, mesh_device, use_program_cache, reset_seeds, ensure_gc
):
@@ -97,7 +103,7 @@ def test_llama_cross_attention_transformer_block_inference(
for b in range(batch):
tt_tensor_xattn_tokens = model_args.prepare_inputs_ttnn_prefill(
tt_xattn_tokens[b : b + 1],
-force_replicate=True,
+force_replicated=True,
)
tt_xattn_cache = tt_model.compute_xattn_kv_cache(tt_tensor_xattn_tokens, tt_xattn_cache, user_id=b)
tt_xattn_cache_torch = [
@@ -195,7 +201,7 @@ def test_llama_cross_attention_transformer_block_inference(
dtype=ttnn.bfloat8_b,
layout=ttnn.TILE_LAYOUT,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
-mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device),
+mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1),
)
tt_out = tt_model(
tt_tensor_x,
user_id=b,
)

-tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0))
-tt_output_torch = tt_output_torch[0, ..., :seq_len, :].view(1, seq_len, dim)
+tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1))
+tt_output_torch = tt_output_torch[..., :seq_len, :].view(1, seq_len, dim)
outputs.append(tt_output_torch)
tt_output_torch = torch.cat(outputs, dim=0).view(batch, seq_len, dim)

@@ -259,8 +265,8 @@ def test_llama_cross_attention_transformer_block_inference(
mode=mode,
)

-tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0))
-tt_output_torch = tt_output_torch[0, :, :batch, :].reshape(batch, seq_len, dim)
+tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1))
+tt_output_torch = tt_output_torch[:, :, :batch, :].reshape(batch, seq_len, dim)

passing, pcc_message = comp_pcc(pt_out, tt_output_torch, pcc_required)
logger.info(comp_allclose(pt_out, tt_output_torch))
13 changes: 11 additions & 2 deletions models/demos/llama3/tt/model_config.py
@@ -566,10 +566,19 @@ def find_largest_divisor(n, max_divisor=8):
fuse_batch=seq_len <= max_seq,
)

+xattn_cache_y_cores = (
+16 // self.num_devices
+)  # Based on seqlen, this formula gives us a valid number of y cores
+xattn_cache_x_cores = 8
self.model_config["XATTN_KV_PREFILL_MEM_CFG"] = lambda seq_len: ttnn.create_sharded_memory_config(
# using n_heads since xattn repeats KV to match Q
-(((self.n_heads // self.num_devices) * seq_len // 64), self.head_dim),
-ttnn.CoreGrid(y=8, x=8),
+(
+nearest_32(
+(self.n_heads // self.num_devices) * seq_len // (xattn_cache_y_cores * xattn_cache_x_cores)
+),
+self.head_dim,
+),
+ttnn.CoreGrid(y=xattn_cache_y_cores, x=xattn_cache_x_cores),
ttnn.ShardStrategy.HEIGHT,
ttnn.ShardOrientation.ROW_MAJOR,
use_height_and_width_as_shard_shape=True,
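
Note on the XATTN_KV_PREFILL_MEM_CFG change above: the shard height is now derived from the per-device row count, (n_heads / num_devices) * seq_len, spread over a num_devices-dependent core grid and padded up to the 32-row tile height, instead of the fixed // 64 divisor on an 8x8 grid. A small sketch of the arithmetic with hypothetical model values; nearest_32 here is a local stand-in assumed to round up to the next multiple of 32, like the repo's utility helper:

def nearest_32(n: int) -> int:
    # Round up to the next multiple of the 32-row tile height.
    return ((n + 31) // 32) * 32

n_heads, num_devices, head_dim = 32, 2, 128   # hypothetical values
seq_len = 1056                                # deliberately not tile-aligned per core

xattn_cache_y_cores = 16 // num_devices       # 8
xattn_cache_x_cores = 8
num_cores = xattn_cache_y_cores * xattn_cache_x_cores

shard_height = nearest_32((n_heads // num_devices) * seq_len // num_cores)
shard_shape = (shard_height, head_dim)
print(shard_shape)                            # (288, 128) for these values
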
4 changes: 2 additions & 2 deletions models/demos/llama3/tt/multimodal/llama_cross_attention.py
@@ -133,7 +133,8 @@ def __init__(

def compute_xattn_kv_cache(self, xattn_tokens, xattn_cache, user_id):
# Always runs with batch=1
-B, seqlen_y = 1, xattn_tokens.shape[2]
+B, seqlen_y = xattn_tokens.shape[1], xattn_tokens.shape[2]
+assert B == 1, "Batch size must be 1"
MAX_MM_SEQ_LEN = self.configuration.VISION_MAX_MM_SEQ
if seqlen_y > MAX_MM_SEQ_LEN:
xattn_tokens = ttnn.reshape(xattn_tokens, [1, B * seqlen_y // MAX_MM_SEQ_LEN, MAX_MM_SEQ_LEN, -1])
@@ -146,7 +147,6 @@ def compute_xattn_kv_cache(self, xattn_tokens, xattn_cache, user_id):
compute_kernel_config=self.compute_kernel_config_hifi4,
program_config=self.model_config["VISION_XATTN_KV_PROGCFG"](seqlen_y, MAX_MM_SEQ_LEN),
)

xv = ttnn.linear(
xattn_tokens,
self.wv,
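
Note on the compute_xattn_kv_cache change above: the batch dimension is now read from the input tensor and asserted to be 1, and a long vision-token sequence is folded into MAX_MM_SEQ_LEN-sized chunks before the K/V projections so each matmul slice stays within the configured maximum. A torch-only sketch of that reshape with hypothetical sizes (the real code operates on ttnn tensors):

import torch

MAX_MM_SEQ_LEN = 1024                  # stand-in for configuration.VISION_MAX_MM_SEQ
B, seqlen_y, hidden = 1, 4096, 1280    # hypothetical shapes

xattn_tokens = torch.randn(1, B, seqlen_y, hidden)
assert B == 1, "Batch size must be 1"

if seqlen_y > MAX_MM_SEQ_LEN:
    # Fold the sequence into MAX_MM_SEQ_LEN-long chunks; the element count is unchanged.
    xattn_tokens = xattn_tokens.reshape(1, B * seqlen_y // MAX_MM_SEQ_LEN, MAX_MM_SEQ_LEN, -1)

print(xattn_tokens.shape)              # torch.Size([1, 4, 1024, 1280])
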
2 changes: 1 addition & 1 deletion models/demos/llama3/tt/multimodal/llama_cross_block.py
@@ -140,7 +140,7 @@ def forward(
attn_out = ttnn.mul(attn_out, ttnn.tanh(self.gate_attn))

res = ttnn.add(x_11SH, attn_out)
-mlp_out = self.feed_forward(self.ffn_norm(res), mode=mode)
+mlp_out = self.feed_forward(self.ffn_norm(res, mode=mode), mode=mode)
if mode == "prefill":
# Making the assumption that you never mask decode rows
mlp_out = ttnn.mul(mlp_out, full_text_row_masked_out_mask_11SD)
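
Note on the llama_cross_block change above: mode is now threaded through ffn_norm as well as feed_forward. A torch-only sketch of the block wiring visible in this hunk; the feed-forward gate and the final residual add are assumptions carried over from the reference Llama 3.2 cross-attention block and are not shown in the diff:

import torch

def cross_block_sketch(x, attn_out, mlp, ffn_norm, gate_attn, gate_ffwd, row_mask, mode):
    attn_out = attn_out * torch.tanh(gate_attn)   # gated attention branch (shown in the diff)
    res = x + attn_out
    mlp_out = mlp(ffn_norm(res))                  # the fix: ffn_norm is also told the mode
    if mode == "prefill":
        mlp_out = mlp_out * row_mask              # zero rows with no cross-attended text
    return res + torch.tanh(gate_ffwd) * mlp_out  # assumed gated residual (reference block)

y = cross_block_sketch(
    torch.randn(1, 1, 32, 2048), torch.randn(1, 1, 32, 2048),
    mlp=torch.nn.Identity(), ffn_norm=torch.nn.Identity(),
    gate_attn=torch.zeros(1), gate_ffwd=torch.zeros(1),
    row_mask=torch.ones(1, 1, 32, 1), mode="prefill",
)
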
2 changes: 1 addition & 1 deletion models/demos/llama3/tt/multimodal/llama_vision_model.py
@@ -362,7 +362,7 @@ def prepare_inputs_prefill(self, tokens, cross_attention_masks, full_text_row_ma
dtype=ttnn.bfloat8_b,
layout=ttnn.TILE_LAYOUT,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
-mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device),
+mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1),
)

return (
