From e6651e96f34ada33fb3b58570e4e3f3252c32c4d Mon Sep 17 00:00:00 2001 From: Colman Glagovich Date: Fri, 1 Nov 2024 13:10:21 -0700 Subject: [PATCH 01/19] #14519: Use FlashDecode in LlamaVision xattn --- .../multimodal/test_llama_cross_block.py | 12 ++-- .../tt/multimodal/llama_cross_attention.py | 58 +++++++++---------- .../llama3/tt/multimodal/llama_cross_block.py | 3 - .../sdpa_decode/device/sdpa_decode_op.cpp | 4 +- 4 files changed, 36 insertions(+), 41 deletions(-) diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_block.py b/models/demos/llama3/tests/multimodal/test_llama_cross_block.py index cdfa5ab19eb..26ebfc73d23 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_block.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_block.py @@ -139,6 +139,8 @@ def test_llama_cross_attention_transformer_block_inference( xattn_mask = xattn_mask * -1e9 xattn_mask_expand = xattn_mask.expand(-1, n_heads // model_args.num_devices, -1, -1) + if mode == "decode": + xattn_mask_expand = xattn_mask_expand.transpose(1, 2).contiguous() tt_xattn_mask = ttnn.from_torch( xattn_mask_expand, device=mesh_device, @@ -151,8 +153,8 @@ def test_llama_cross_attention_transformer_block_inference( tt_xattn_mask = ttnn.reshape( tt_xattn_mask, shape=ttnn.Shape( - [batch, n_heads // model_args.num_devices, seq_len, vision_seq_len], - [batch, n_heads // model_args.num_devices, 32, vision_seq_len], + [1, batch, n_heads // model_args.num_devices, vision_seq_len], + [1, batch, 32, vision_seq_len], ), ) @@ -167,6 +169,8 @@ def test_llama_cross_attention_transformer_block_inference( ) full_text_mask = full_text_mask.unsqueeze(1).unsqueeze(-1) full_text_mask_expand_1NSH = full_text_mask.expand(-1, n_heads // model_args.num_devices, -1, head_dim) + if mode == "decode": + full_text_mask_expand_1NSH = full_text_mask_expand_1NSH.transpose(1, 2).contiguous() tt_full_text_mask_expand_1NSH = ttnn.from_torch( full_text_mask_expand_1NSH, device=mesh_device, @@ -179,8 +183,8 @@ def test_llama_cross_attention_transformer_block_inference( tt_full_text_mask_expand_1NSH = ttnn.reshape( tt_full_text_mask_expand_1NSH, shape=ttnn.Shape( - [batch, n_heads // model_args.num_devices, seq_len, head_dim], - [batch, n_heads // model_args.num_devices, 32, head_dim], + [1, batch, n_heads // model_args.num_devices, head_dim], + [1, batch, 32, head_dim], ), ) diff --git a/models/demos/llama3/tt/multimodal/llama_cross_attention.py b/models/demos/llama3/tt/multimodal/llama_cross_attention.py index b307651aa2c..c890cb9be5a 100644 --- a/models/demos/llama3/tt/multimodal/llama_cross_attention.py +++ b/models/demos/llama3/tt/multimodal/llama_cross_attention.py @@ -47,6 +47,7 @@ def __init__( self.compute_kernel_config_hifi2 = configuration.compute_kernel_config_hifi2 self.compute_kernel_config_hifi4 = configuration.compute_kernel_config_hifi4 + self.compute_kernel_config_sdpa = configuration.compute_kernel_config_sdpa self.configuration = configuration @@ -220,16 +221,16 @@ def forward_decode(self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH, # Below is how we want to reshape. 
It results in poor PCC # 1, B, D -> B, 1, NH, DH -> B, NH, 1, DH - # xq = ttnn.to_layout(xq, layout=ttnn.ROW_MAJOR_LAYOUT) - # xq = ttnn.reshape(xq, (batch, 1, self.n_local_heads, self.head_dim)) - # xq = ttnn.transpose(xq, 1, 2) - # xq = ttnn.to_layout(xq, layout=ttnn.TILE_LAYOUT) - - xq, _, _ = ttnn.experimental.nlp_create_qkv_heads( - xq, xq, num_heads=self.n_local_heads, num_kv_heads=self.n_local_heads // 2, transpose_k_heads=False + xq = ttnn.to_layout(xq, layout=ttnn.ROW_MAJOR_LAYOUT) + # Tell shape about padding + xq = ttnn.reshape( + xq, + shape=ttnn.Shape( + [1, 1, batch, xq.shape[-1]], + [1, 1, xq.shape[-2], xq.shape[-1]], + ), ) - xq = ttnn.transpose(xq, 0, 2) - xq = ttnn.slice(xq, (0, 0, 0, 0), (batch, self.n_local_heads, 1, self.head_dim)) + xq = ttnn.reshape(xq, (1, batch, self.n_local_heads, self.head_dim)) xq = ttnn.to_layout(xq, layout=ttnn.TILE_LAYOUT) xq = self.q_norm(xq, mode="decode") @@ -237,39 +238,32 @@ def forward_decode(self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH, xk, xv = xattn_cache cache_seq_len = xk.shape[-2] - scores = ttnn.matmul( - xq, - ttnn.transpose(xk, -1, -2), - dtype=ttnn.bfloat16, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - compute_kernel_config=self.compute_kernel_config_hifi4, - program_config=self.model_config["VISION_XATTN_SCORE_PROGCFG"](batch, cache_seq_len), + program_config = ttnn.SDPAProgramConfig( + compute_with_storage_grid_size=self.mesh_device.compute_with_storage_grid_size(), + q_chunk_size=32, + k_chunk_size=128, + exp_approx_mode=False, ) - scores = ttnn.multiply(scores, self.scale) - # WARNING: This add is buggy if xattn_mask has to be broadcasted to n_local_heads. Workaround is to broadcast on host side - # Host side must explicitly create this tensor with same padding as input tensor - scores = ttnn.add(scores, xattn_mask) - scores = ttnn.softmax(scores, dim=-1, numeric_stable=True) + # TODO: Can I get rid of the KV repeat_interleave? 
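As a point of reference for this change, the manual score/softmax path being removed in this hunk and the fused FlashDecode call that replaces it compute the same attention; a minimal host-side torch sketch (not part of the patch; shapes and the scale value are illustrative assumptions) is:

import torch
import torch.nn.functional as F

batch, n_heads, kv_len, head_dim = 1, 8, 4224, 128
q = torch.randn(batch, n_heads, 1, head_dim)       # single decode position per user
k = torch.randn(batch, n_heads, kv_len, head_dim)  # cross-attention K cache
v = torch.randn(batch, n_heads, kv_len, head_dim)  # cross-attention V cache
mask = torch.zeros(batch, n_heads, 1, kv_len)      # additive xattn mask
scale = head_dim ** -0.5

# Unfused path (what the removed matmuls did): QK^T, scale, mask add, softmax, AV.
scores = torch.softmax(q @ k.transpose(-1, -2) * scale + mask, dim=-1)
out_unfused = scores @ v

# Fused host-side equivalent of the scaled_dot_product_attention_decode call.
out_fused = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
assert torch.allclose(out_unfused, out_fused, atol=1e-4)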
- output = ttnn.matmul( - scores, + output = ttnn.transformer.scaled_dot_product_attention_decode( + xq, + xk, xv, - dtype=ttnn.bfloat16, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - compute_kernel_config=self.compute_kernel_config_hifi4, - program_config=self.model_config["VISION_XATTN_OUTPUT_PROGCFG"](batch, cache_seq_len), + is_causal=False, + attn_mask=xattn_mask, + scale=self.scale, + program_config=program_config, + compute_kernel_config=self.compute_kernel_config_sdpa, ) # WARNING: this broadcast is also broken, must broadcast on host output = ttnn.mul(output, full_text_row_masked_out_mask_1NSH) - output = ttnn.transpose(output, 0, 2) # B, NH, 1, DH -> 1, NH, B, DH - output = ttnn.slice(output, (0, 0, 0, 0), (1, self.n_local_heads, batch, self.head_dim)) + output = ttnn.to_layout(output, layout=ttnn.ROW_MAJOR_LAYOUT) + output = ttnn.reshape(output, (1, 1, batch, self.n_local_heads * self.head_dim)) output = ttnn.to_layout(output, layout=ttnn.TILE_LAYOUT) - # B, NH, S, DH -> B, S, D - # B, NH, 1, DH -> 1, 1, B, D - output = ttnn.experimental.nlp_concat_heads(output) # 1, NH, B, DH -> 1, 1, B, D output = ttnn.matmul( output, diff --git a/models/demos/llama3/tt/multimodal/llama_cross_block.py b/models/demos/llama3/tt/multimodal/llama_cross_block.py index 4f3d1cf394a..4e00fc384be 100644 --- a/models/demos/llama3/tt/multimodal/llama_cross_block.py +++ b/models/demos/llama3/tt/multimodal/llama_cross_block.py @@ -127,9 +127,6 @@ def forward( xattn_cache, mode, ): - seq_len = x_11SH.shape[-2] - # assert seq_len % 128 == 0 and seq_len > 0, "Seqlen must be divisible by 128" - attn_out = self.attention( x_11SH=self.attention_norm(x_11SH, mode=mode), xattn_mask=xattn_mask, diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.cpp index b192917fe09..efb7ee090bc 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.cpp @@ -52,7 +52,7 @@ void ScaledDotProductAttentionDecode::validate(const std::vector& input_ const auto mask_shape_unpadded = mask_tensor.get_logical_shape(); TT_FATAL(mask_shape[2] == q_shape[2], "Expect same number of padded heads in mask as in Q, got {} and {}", mask_shape[2], q_shape[2]); - TT_FATAL(mask_shape_unpadded[2] == q_shape_unpadded[2], "Expect same number of heads in mask as in Q, got {} and {}", mask_shape_unpadded[3], q_shape_unpadded[2]); + TT_FATAL(mask_shape_unpadded[2] == q_shape_unpadded[2], "Expect same number of heads in mask as in Q, got {} and {}", mask_shape_unpadded[2], q_shape_unpadded[2]); if (! 
this->paged_attention) TT_FATAL(mask_shape[3] == k_shape[2], "Expect same sequence length in mask as in K, got {} and {}", mask_shape[3], k_shape[2]); TT_FATAL(mask_shape[3] % k_chunk_size == 0, "Mask sequence length must be multiple of chunk size, got: {} and {}", mask_shape[3], k_chunk_size); @@ -151,7 +151,7 @@ void ScaledDotProductAttentionDecode::validate(const std::vector& input_ std::vector ScaledDotProductAttentionDecode::compute_output_shapes( const std::vector& input_tensors) const { - return {input_tensors.at(0).get_padded_shape()}; + return {input_tensors.at(0).get_logical_shape()}; } std::vector ScaledDotProductAttentionDecode::create_output_tensors( From 7ed2318a04f9d8246f534cc8cba2e4f488af9aa3 Mon Sep 17 00:00:00 2001 From: Colman Glagovich Date: Mon, 4 Nov 2024 05:26:29 -0800 Subject: [PATCH 02/19] #14519: WIP create simpler interface for LlamaVision --- .../demos/llama3/demo/simple_vision_demo.py | 241 ++++++++++++++++++ .../tt/multimodal/llama_vision_model.py | 14 +- 2 files changed, 247 insertions(+), 8 deletions(-) create mode 100644 models/demos/llama3/demo/simple_vision_demo.py diff --git a/models/demos/llama3/demo/simple_vision_demo.py b/models/demos/llama3/demo/simple_vision_demo.py new file mode 100644 index 00000000000..f691a71c449 --- /dev/null +++ b/models/demos/llama3/demo/simple_vision_demo.py @@ -0,0 +1,241 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +from pathlib import Path +from typing import Optional +from loguru import logger + +from PIL import Image as PIL_Image +from termcolor import cprint + +import llama_models.llama3.reference_impl.generation as llama_reference_generation +from llama_models.llama3.api.tokenizer import Tokenizer +from llama_models.llama3.api.chat_format import ChatFormat, ModelInput + +from llama_models.llama3.api.datatypes import ImageMedia, UserMessage + +from pkg_resources import resource_filename + +IMG_PATH = Path(resource_filename("llama_models", "scripts/resources/")) + +import torch +import pytest +import os +import ttnn + + +class LlamaVision: + def __init__(self, model, model_args, mesh_device, vllm=False): + """ + Creating a LlamaVision wrapper requires only a mesh_device and model_args. + With model_args you have the checkpoint location, can specify max batch size + and max seqlen, and other model specific parameters. + + LlamaVision is general to text and chat. + + For bringup, make this class general to any backend implementation, as long as it takes torch tensors and returns torch tensors. + + """ + self.model = model + self.model_args = model_args + self.mesh_device = mesh_device + self.vllm = vllm + + def get_prefill_inputs(self, model_input): + """ + Responsible for taking model_input: ModelInput and returning vision_images, vision_mask, tokens + """ + images = model_input.vision.images + mask = model_input.visiom.mask + tokens = model_input.tokens + + return images, mask, tokens + + def forward_prefill(self, vision_images, vision_mask, tokens, total_len, text_only_inference=False): + """ + Performs vision encode step then text prefill. 
+ Returns (xattn_caches, cross_attention_masks, full_text_row_masked_out_mask, logits) + """ + xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks( + batch_images=[vision_images], + batch_masks=[vision_mask], + total_len=total_len, + ) + + position_ids = torch.arange(tokens.shape[-1], dtype=torch.long) + + logits = self.model.forward( + position_ids, + tokens, + cross_attention_masks, + full_text_row_masked_out_mask, + xattn_caches, + text_only_inference, + ) + + return xattn_caches, cross_attention_masks, full_text_row_masked_out_mask, logits + + def forward_decode( + self, + position_ids, + tokens, + cross_attention_masks, + full_text_row_masked_out_mask, + xattn_caches, + text_only_inference=False, + ): + """ + Performs text decode step. + Returns logits + """ + pass + + +def get_sampler(temperature, top_p, tokenizer): + def sample(logits): + if temperature > 0: + probs = torch.softmax(logits[:, -1] / temperature, dim=-1) + next_token = llama_reference_generation.sample_top_p(probs, top_p) + else: + next_token = torch.argmax(logits[:, -1], dim=-1) + + next_token = next_token.reshape(-1) + token = next_token[0].item() + text = tokenizer.decode(next_token.tolist()) + return token, text + + return sample + + +def create_multimodal_model(mesh_device, dtype=ttnn.bfloat16): + from models.demos.llama3.tt.multimodal.llama_vision_model import CrossAttentionTransformer + from models.demos.llama3.tt.model_config import TtModelArgs + + tt_model_args = TtModelArgs(mesh_device) + checkpoint = torch.load(tt_model_args.consolidated_weights_path, map_location="cpu", weights_only=True) + model = CrossAttentionTransformer( + mesh_device, + checkpoint, + weight_cache_path=tt_model_args.weight_cache_path(dtype), + dtype=dtype, + configuration=tt_model_args, + ) + model.setup_cache(tt_model_args.max_batch_size, torch.float32) # TODO: is a no-op + return tt_model_args, model + + +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("FAKE_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "target", + ("tt", "cpu"), +) +@pytest.mark.parametrize( + "warmup_iters", + (0, 1), +) +def test_llama_multimodal_demo_text( + mesh_device, + target, + warmup_iters, + temperature: float = 0.5, + top_p: float = 0.9, + max_seq_len: int = 512, + max_batch_size: int = 4, + max_gen_len: Optional[int] = 200, + model_parallel_size: Optional[int] = None, +): + """ + Simple multimodal demo with limited dependence on reference code. 
+ """ + ckpt_dir = os.environ["LLAMA_DIR"] + tokenizer_path = str(Path(ckpt_dir) / "tokenizer.model") + + if target == "cpu": + generator = llama_reference_generation.Llama.build( + ckpt_dir, + tokenizer_path=tokenizer_path, + max_seq_len=max_seq_len, + max_batch_size=max_batch_size, + model_parallel_size=model_parallel_size, + ) + model_args = generator.args + model = LlamaVision(generator.model, model_args, None) + tokenizer = generator.tokenizer + formatter = generator.formatter + else: + mesh_device.enable_program_cache() + mesh_device.enable_async(True) + model_args, model = create_multimodal_model(mesh_device) + model = LlamaVision(model, model_args, mesh_device) + tokenizer = Tokenizer(model_path=tokenizer_path) + formatter = ChatFormat(tokenizer) + + with open(IMG_PATH / "dog.jpg", "rb") as f: + img = PIL_Image.open(f).convert("RGB") + + dialogs = [] + with open(IMG_PATH / "dog.jpg", "rb") as f: + img = PIL_Image.open(f).convert("RGB") + + dialogs = [ + [ + UserMessage( + content=[ + ImageMedia(image=img), + "Describe this image in two sentences", + ], + ) + ], + ] + # text only + dialogs += [ + [UserMessage(content="what is the recipe of mayonnaise in two sentences?")], + ] + + sampler = get_sampler(temperature, top_p, tokenizer) + + print(f"Running text completion on {target}") + for _ in range(warmup_iters + 1): + for dialog in dialogs: + # result = generator.chat_completion( + # dialog, + # max_gen_len=max_gen_len, + # temperature=temperature, + # top_p=top_p, + # ) + for msg in dialog: + print(f"{msg.role.capitalize()}: {msg.content}\n") + + model_input = formatter.encode_dialog_prompt(dialog, tool_prompt_format=False) + + # Do initial prefill + vision_images, vision_mask, tokens = model.get_prefill_inputs(model_input) + total_len = len(tokens) + max_gen_len # Prepares mask for full length of output + # Create tokens tensor + pad_id = tokenizer.pad_id + bsz = 1 + tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long) + tokens[0, : len(tokens)] = torch.tensor(tokens, dtype=torch.long) + xattn_caches, cross_attention_masks, full_text_row_masked_out_mask, logits = model.forward_prefill( + vision_images, vision_mask, tokens, total_len + ) + + next_token, text = sampler(logits) + logger.info(f"Prefill output: {text}") + + # Iterate over decode + # for gen_idx in range(max_gen_len-1): + + # out_message = result.generation + # print(f"> {out_message.role.capitalize()}: {out_message.content}") + # for t in out_message.tool_calls: + # print(f" Tool call: {t.tool_name} ({t.arguments})") + # print("\n==================================\n") diff --git a/models/demos/llama3/tt/multimodal/llama_vision_model.py b/models/demos/llama3/tt/multimodal/llama_vision_model.py index 47bff66e6e8..214bc6ce7c5 100644 --- a/models/demos/llama3/tt/multimodal/llama_vision_model.py +++ b/models/demos/llama3/tt/multimodal/llama_vision_model.py @@ -123,7 +123,6 @@ def _get_xattn_mask( class CrossAttentionTransformer(torch.nn.Module): def __init__( self, - args: llama_reference_model.ModelArgs, mesh_device, state_dict, weight_cache_path, @@ -131,9 +130,8 @@ def __init__( configuration, ) -> None: super().__init__() - self.params = args - self.model_dim = args.dim + self.model_dim = configuration.dim self.mesh_device = mesh_device self.state_dict = state_dict @@ -162,11 +160,11 @@ def __init__( dtype=ttnn.bfloat8_b, configuration=configuration, ) - self.image_res = args.vision_chunk_size - self.max_num_chunks = args.vision_max_num_chunks + self.image_res = configuration.vision_chunk_size + 
self.max_num_chunks = configuration.vision_max_num_chunks self.image_transform = partial( - llama_reference_image_transforms.VariableSizeImageTransform(size=args.vision_chunk_size), - max_num_chunks=args.vision_max_num_chunks, + llama_reference_image_transforms.VariableSizeImageTransform(size=configuration.vision_chunk_size), + max_num_chunks=configuration.vision_max_num_chunks, ) def setup_cache(self, max_batch_size: int, dtype: torch.dtype): @@ -200,7 +198,7 @@ def compute_vision_tokens_masks( stacked_images, num_chunks = _stack_images( transformed_images, max_num_chunks=self.max_num_chunks, - image_res=self.params.vision_chunk_size, + image_res=self.configuration.vision_chunk_size, max_num_images=max_num_images, ) From 42f87a764f250a94606a5d81af5d76a399ec7042 Mon Sep 17 00:00:00 2001 From: Colman Glagovich Date: Mon, 4 Nov 2024 10:30:49 -0800 Subject: [PATCH 03/19] #14519: Update xattn test input shapes since masks with non-causal FlashDecode are different --- .../multimodal/test_llama_class_embedding.py | 2 +- .../multimodal/test_llama_cross_attention.py | 16 ++++++++++------ ...est_llama_cross_attention_transformer_text.py | 15 ++++++++++----- .../tests/multimodal/test_llama_cross_block.py | 14 ++++++++------ 4 files changed, 29 insertions(+), 18 deletions(-) diff --git a/models/demos/llama3/tests/multimodal/test_llama_class_embedding.py b/models/demos/llama3/tests/multimodal/test_llama_class_embedding.py index 663787a18d1..dc395842338 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_class_embedding.py +++ b/models/demos/llama3/tests/multimodal/test_llama_class_embedding.py @@ -63,7 +63,7 @@ def forward(self, x): @pytest.mark.parametrize( "bsz, num_concurrent_media, num_chunks", [ - ((1, 4, 4)), + ((1, 1, 4)), ], ) @pytest.mark.parametrize( diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py b/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py index 383cf3dd8bb..54f972f1400 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py @@ -32,7 +32,8 @@ ], indirect=True, ) -def test_llama_cross_attention_inference(text_seq_len, mesh_device, use_program_cache, reset_seeds, ensure_gc): +@pytest.mark.parametrize("batch", (1,)) +def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset_seeds, ensure_gc): dtype = ttnn.bfloat16 pcc_required = 0.99 @@ -57,7 +58,6 @@ def test_llama_cross_attention_inference(text_seq_len, mesh_device, use_program_ ) reference_model.load_state_dict(partial_state_dict) - batch = 1 num_chunks = 4 vision_seq_len = num_chunks * nearest_32(model_args.vision_chunk_ntok) @@ -151,6 +151,8 @@ def test_llama_cross_attention_inference(text_seq_len, mesh_device, use_program_ xattn_mask = xattn_mask * -1e9 xattn_mask_expand = xattn_mask.expand(-1, n_heads // model_args.num_devices, -1, -1) + if mode == "decode": + xattn_mask_expand = xattn_mask_expand.permute(2, 0, 1, 3).contiguous() tt_xattn_mask = ttnn.from_torch( xattn_mask_expand, device=mesh_device, @@ -163,8 +165,8 @@ def test_llama_cross_attention_inference(text_seq_len, mesh_device, use_program_ tt_xattn_mask = ttnn.reshape( tt_xattn_mask, shape=ttnn.Shape( - [batch, n_heads // model_args.num_devices, seq_len, vision_seq_len], - [batch, n_heads // model_args.num_devices, 32, vision_seq_len], + [1, batch, n_heads // model_args.num_devices, vision_seq_len], + [1, batch, 32, vision_seq_len], ), ) @@ -179,6 +181,8 @@ def 
test_llama_cross_attention_inference(text_seq_len, mesh_device, use_program_ ) full_text_mask = full_text_mask.unsqueeze(1).unsqueeze(-1) full_text_mask_expand = full_text_mask.expand(-1, n_heads // model_args.num_devices, -1, head_dim) + if mode == "decode": + full_text_mask_expand = full_text_mask_expand.permute(2, 0, 1, 3).contiguous() tt_full_text_mask = ttnn.from_torch( full_text_mask_expand, device=mesh_device, @@ -191,8 +195,8 @@ def test_llama_cross_attention_inference(text_seq_len, mesh_device, use_program_ tt_full_text_mask = ttnn.reshape( tt_full_text_mask, shape=ttnn.Shape( - [batch, n_heads // model_args.num_devices, seq_len, head_dim], - [batch, n_heads // model_args.num_devices, 32, head_dim], + [1, batch, n_heads // model_args.num_devices, head_dim], + [1, batch, 32, head_dim], ), ) diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py index 26f30bdef91..3873f9adcb7 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py @@ -39,8 +39,10 @@ ], indirect=True, ) +@pytest.mark.parametrize("batch", (1,)) def test_llama_cross_attention_transformer_text_inference( text_seq_len, + batch, mesh_device, use_program_cache, reset_seeds, @@ -82,7 +84,6 @@ def test_llama_cross_attention_transformer_text_inference( reference_model.setup_cache(model_args.max_batch_size, torch.float32) reference_model.load_state_dict(partial_state_dict) - batch = 1 num_chunks = 4 chunk_length = nearest_32(model_args.vision_chunk_ntok) vision_seq_len = num_chunks * chunk_length @@ -235,6 +236,8 @@ def test_llama_cross_attention_transformer_text_inference( transformation_mats = None xattn_mask_expand = xattn_mask.expand(-1, n_heads // model_args.num_devices, -1, -1) + if mode == "decode": + xattn_mask_expand = xattn_mask_expand.permute(2, 0, 1, 3).contiguous() tt_xattn_mask = ttnn.from_torch( xattn_mask_expand, device=mesh_device, @@ -247,11 +250,13 @@ def test_llama_cross_attention_transformer_text_inference( tt_xattn_mask = ttnn.reshape( tt_xattn_mask, shape=ttnn.Shape( - [batch, n_heads // model_args.num_devices, seq_len, vision_seq_len], - [batch, n_heads // model_args.num_devices, 32, vision_seq_len], + [1, batch, n_heads // model_args.num_devices, vision_seq_len], + [1, batch, 32, vision_seq_len], ), ) full_text_mask_expand_1NSH = full_text_mask.expand(-1, n_heads // model_args.num_devices, -1, head_dim) + if mode == "decode": + full_text_mask_expand_1NSH = full_text_mask_expand_1NSH.permute(2, 0, 1, 3).contiguous() tt_full_text_mask_expand_1NSH = ttnn.from_torch( full_text_mask_expand_1NSH, device=mesh_device, @@ -264,8 +269,8 @@ def test_llama_cross_attention_transformer_text_inference( tt_full_text_mask_expand_1NSH = ttnn.reshape( tt_full_text_mask_expand_1NSH, shape=ttnn.Shape( - [batch, n_heads // model_args.num_devices, seq_len, head_dim], - [batch, n_heads // model_args.num_devices, 32, head_dim], + [1, batch, n_heads // model_args.num_devices, head_dim], + [1, batch, 32, head_dim], ), ) diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_block.py b/models/demos/llama3/tests/multimodal/test_llama_cross_block.py index 26ebfc73d23..7ccf90e0004 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_block.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_block.py @@ -28,8 +28,9 @@ ], indirect=True, ) 
+@pytest.mark.parametrize("batch", (1,)) def test_llama_cross_attention_transformer_block_inference( - text_seq_len, mesh_device, use_program_cache, reset_seeds, ensure_gc + text_seq_len, batch, mesh_device, use_program_cache, reset_seeds, ensure_gc ): dtype = ttnn.bfloat16 pcc_required = 0.99 @@ -53,7 +54,6 @@ def test_llama_cross_attention_transformer_block_inference( reference_model = llama_reference_mod.CrossAttentionTransformerBlock(args=model_args, layer_id=0, no_ffn=False) reference_model.load_state_dict(partial_state_dict) - batch = 1 num_chunks = 4 vision_seq_len = num_chunks * nearest_32(model_args.vision_chunk_ntok) @@ -138,9 +138,9 @@ def test_llama_cross_attention_transformer_block_inference( xattn_mask = xattn_mask.unsqueeze(1) xattn_mask = xattn_mask * -1e9 - xattn_mask_expand = xattn_mask.expand(-1, n_heads // model_args.num_devices, -1, -1) + xattn_mask_expand = xattn_mask.expand(-1, n_heads // model_args.num_devices, -1, -1) # B, NH, St, Sv if mode == "decode": - xattn_mask_expand = xattn_mask_expand.transpose(1, 2).contiguous() + xattn_mask_expand = xattn_mask_expand.permute(2, 0, 1, 3).contiguous() tt_xattn_mask = ttnn.from_torch( xattn_mask_expand, device=mesh_device, @@ -168,9 +168,11 @@ def test_llama_cross_attention_transformer_block_inference( ) ) full_text_mask = full_text_mask.unsqueeze(1).unsqueeze(-1) - full_text_mask_expand_1NSH = full_text_mask.expand(-1, n_heads // model_args.num_devices, -1, head_dim) + full_text_mask_expand_1NSH = full_text_mask.expand( + -1, n_heads // model_args.num_devices, -1, head_dim + ) # B, NH, St, Hd if mode == "decode": - full_text_mask_expand_1NSH = full_text_mask_expand_1NSH.transpose(1, 2).contiguous() + full_text_mask_expand_1NSH = full_text_mask_expand_1NSH.permute(2, 0, 1, 3).contiguous() tt_full_text_mask_expand_1NSH = ttnn.from_torch( full_text_mask_expand_1NSH, device=mesh_device, From 97a9d451467b4533433e60e517afc9091e5d5f0c Mon Sep 17 00:00:00 2001 From: Colman Glagovich Date: Mon, 4 Nov 2024 10:31:52 -0800 Subject: [PATCH 04/19] #14519: Change TMs in xattn. Naive TMs now fail xattn test, so this commit goes back to nlp_tms to create/concat heads. --- .../tt/multimodal/llama_cross_attention.py | 50 ++++++++++++++----- 1 file changed, 37 insertions(+), 13 deletions(-) diff --git a/models/demos/llama3/tt/multimodal/llama_cross_attention.py b/models/demos/llama3/tt/multimodal/llama_cross_attention.py index c890cb9be5a..ff0e0b79223 100644 --- a/models/demos/llama3/tt/multimodal/llama_cross_attention.py +++ b/models/demos/llama3/tt/multimodal/llama_cross_attention.py @@ -163,7 +163,7 @@ def compute_xattn_kv_cache(self, xattn_tokens): xk = ttnn.reshape(xk, [bsz, 1, seqlen_y, -1]) xv = ttnn.reshape(xv, [bsz, 1, seqlen_y, -1]) else: - # 1, B, S, D -> B, NH, S, DH? 
+ # 1, B, S, D -> B, NH, S, DH xk, _, _ = ttnn.experimental.nlp_create_qkv_heads( xk, xk, @@ -178,6 +178,15 @@ def compute_xattn_kv_cache(self, xattn_tokens): num_kv_heads=self.n_local_kv_heads // 2, transpose_k_heads=False, ) + # def create_heads(x): + # x = ttnn.to_layout(x, layout=ttnn.ROW_MAJOR_LAYOUT) + # x = ttnn.reshape(x, [bsz, seqlen_y, self.n_local_kv_heads, self.head_dim]) + # x = ttnn.transpose(x, 1, 2) + # x = ttnn.to_layout(x, layout=ttnn.TILE_LAYOUT) + # return x + + # xk = create_heads(xk) + # xv = create_heads(xv) xk = self.k_norm(xk, mode="decode") @@ -219,18 +228,26 @@ def forward_decode(self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH, program_config=self.model_config["VISION_XATTN_Q_PROGCFG"](batch), ) - # Below is how we want to reshape. It results in poor PCC - # 1, B, D -> B, 1, NH, DH -> B, NH, 1, DH - xq = ttnn.to_layout(xq, layout=ttnn.ROW_MAJOR_LAYOUT) - # Tell shape about padding - xq = ttnn.reshape( - xq, - shape=ttnn.Shape( - [1, 1, batch, xq.shape[-1]], - [1, 1, xq.shape[-2], xq.shape[-1]], - ), + # # Below is how we want to reshape. It results in poor PCC + # # 1, B, D -> B, 1, NH, DH -> B, NH, 1, DH + # xq = ttnn.to_layout(xq, layout=ttnn.ROW_MAJOR_LAYOUT) + # # Tell shape about padding + # xq = ttnn.reshape( + # xq, + # shape=ttnn.Shape( + # [1, 1, batch, xq.shape[-1]], + # [1, 1, xq.shape[-2], xq.shape[-1]], + # ), + # ) + # xq = ttnn.reshape(xq, (1, batch, self.n_local_heads, self.head_dim)) + # xq = ttnn.to_layout(xq, layout=ttnn.TILE_LAYOUT) + + xq, _, _ = ttnn.experimental.nlp_create_qkv_heads( + xq, xq, num_heads=self.n_local_heads, num_kv_heads=self.n_local_heads // 2, transpose_k_heads=False ) - xq = ttnn.reshape(xq, (1, batch, self.n_local_heads, self.head_dim)) + xq = ttnn.to_layout(xq, layout=ttnn.ROW_MAJOR_LAYOUT) + xq = ttnn.slice(xq, (0, 0, 0, 0), (xq.shape[0], xq.shape[1], batch, xq.shape[3])) + xq = ttnn.transpose(xq, 1, 2) xq = ttnn.to_layout(xq, layout=ttnn.TILE_LAYOUT) xq = self.q_norm(xq, mode="decode") @@ -261,9 +278,16 @@ def forward_decode(self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH, # WARNING: this broadcast is also broken, must broadcast on host output = ttnn.mul(output, full_text_row_masked_out_mask_1NSH) + # This is how we should be reshaping + # output = ttnn.to_layout(output, layout=ttnn.ROW_MAJOR_LAYOUT) + # output = ttnn.reshape(output, (1, 1, batch, self.n_local_heads * self.head_dim)) + # output = ttnn.to_layout(output, layout=ttnn.TILE_LAYOUT) + output = ttnn.to_layout(output, layout=ttnn.ROW_MAJOR_LAYOUT) - output = ttnn.reshape(output, (1, 1, batch, self.n_local_heads * self.head_dim)) + output = ttnn.transpose(output, 1, 2) # 1, B, NH, DH -> 1, NH, B, DH + output = ttnn.slice(output, (0, 0, 0, 0), (1, self.n_local_heads, batch, self.head_dim)) output = ttnn.to_layout(output, layout=ttnn.TILE_LAYOUT) + output = ttnn.experimental.nlp_concat_heads(output) output = ttnn.matmul( output, From 236b5f8f8241de827c21a3edff50cd1e59f7337f Mon Sep 17 00:00:00 2001 From: Colman Glagovich Date: Mon, 4 Nov 2024 10:33:12 -0800 Subject: [PATCH 05/19] #14519: Simple vision demo is functional, with llama_vision_model supporting mask shapes required by non-causal FlashDecode --- .../demos/llama3/demo/simple_vision_demo.py | 101 +++++++++++++----- .../tt/multimodal/llama_vision_model.py | 12 ++- 2 files changed, 80 insertions(+), 33 deletions(-) diff --git a/models/demos/llama3/demo/simple_vision_demo.py b/models/demos/llama3/demo/simple_vision_demo.py index f691a71c449..b147747d0f2 100644 --- 
a/models/demos/llama3/demo/simple_vision_demo.py +++ b/models/demos/llama3/demo/simple_vision_demo.py @@ -22,6 +22,7 @@ import pytest import os import ttnn +import time class LlamaVision: @@ -46,12 +47,12 @@ def get_prefill_inputs(self, model_input): Responsible for taking model_input: ModelInput and returning vision_images, vision_mask, tokens """ images = model_input.vision.images - mask = model_input.visiom.mask + mask = model_input.vision.mask tokens = model_input.tokens return images, mask, tokens - def forward_prefill(self, vision_images, vision_mask, tokens, total_len, text_only_inference=False): + def forward_prefill(self, vision_images, vision_mask, tokens, total_len, prefill_len, text_only_inference=False): """ Performs vision encode step then text prefill. Returns (xattn_caches, cross_attention_masks, full_text_row_masked_out_mask, logits) @@ -62,7 +63,7 @@ def forward_prefill(self, vision_images, vision_mask, tokens, total_len, text_on total_len=total_len, ) - position_ids = torch.arange(tokens.shape[-1], dtype=torch.long) + position_ids = torch.arange(prefill_len, dtype=torch.long) logits = self.model.forward( position_ids, @@ -77,7 +78,7 @@ def forward_prefill(self, vision_images, vision_mask, tokens, total_len, text_on def forward_decode( self, - position_ids, + position_id, tokens, cross_attention_masks, full_text_row_masked_out_mask, @@ -88,11 +89,21 @@ def forward_decode( Performs text decode step. Returns logits """ - pass + position_ids = torch.tensor([position_id], dtype=torch.long) + logits = self.model.forward( + position_ids, + tokens, + cross_attention_masks, + full_text_row_masked_out_mask, + xattn_caches, + text_only_inference, + ) + return logits def get_sampler(temperature, top_p, tokenizer): def sample(logits): + logger.info(f"Sampling {logits.shape=}") if temperature > 0: probs = torch.softmax(logits[:, -1] / temperature, dim=-1) next_token = llama_reference_generation.sample_top_p(probs, top_p) @@ -141,15 +152,22 @@ def create_multimodal_model(mesh_device, dtype=ttnn.bfloat16): "warmup_iters", (0, 1), ) +@pytest.mark.parametrize( + "test_case", + [ + "normal", + ], +) def test_llama_multimodal_demo_text( mesh_device, target, warmup_iters, - temperature: float = 0.5, + test_case, + temperature: float = 0, top_p: float = 0.9, max_seq_len: int = 512, max_batch_size: int = 4, - max_gen_len: Optional[int] = 200, + max_gen_len: Optional[int] = 100, model_parallel_size: Optional[int] = None, ): """ @@ -195,47 +213,72 @@ def test_llama_multimodal_demo_text( ) ], ] - # text only - dialogs += [ - [UserMessage(content="what is the recipe of mayonnaise in two sentences?")], - ] sampler = get_sampler(temperature, top_p, tokenizer) print(f"Running text completion on {target}") - for _ in range(warmup_iters + 1): + for iter_num in range(warmup_iters + 1): for dialog in dialogs: - # result = generator.chat_completion( - # dialog, - # max_gen_len=max_gen_len, - # temperature=temperature, - # top_p=top_p, - # ) for msg in dialog: print(f"{msg.role.capitalize()}: {msg.content}\n") + if iter_num <= warmup_iters: + logger.info(f"Warmup iteration {iter_num}") + model_input = formatter.encode_dialog_prompt(dialog, tool_prompt_format=False) # Do initial prefill - vision_images, vision_mask, tokens = model.get_prefill_inputs(model_input) - total_len = len(tokens) + max_gen_len # Prepares mask for full length of output + vision_images, vision_mask, prompt_tokens = model.get_prefill_inputs(model_input) + prefill_len = len(prompt_tokens) + total_len = prefill_len + max_gen_len # 
Prepares mask for full length of output # Create tokens tensor pad_id = tokenizer.pad_id bsz = 1 tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long) - tokens[0, : len(tokens)] = torch.tensor(tokens, dtype=torch.long) + tokens[0, : len(prompt_tokens)] = torch.tensor(prompt_tokens, dtype=torch.long) + prefill_start = time.perf_counter() xattn_caches, cross_attention_masks, full_text_row_masked_out_mask, logits = model.forward_prefill( - vision_images, vision_mask, tokens, total_len + vision_images, vision_mask, tokens, total_len, prefill_len ) + prefill_end = time.perf_counter() next_token, text = sampler(logits) - logger.info(f"Prefill output: {text}") + logger.info(f"Prefill output: {next_token}:{text}") + tokens[0, prefill_len] = next_token + decode_times = [] # Iterate over decode - # for gen_idx in range(max_gen_len-1): - # out_message = result.generation - # print(f"> {out_message.role.capitalize()}: {out_message.content}") - # for t in out_message.tool_calls: - # print(f" Tool call: {t.tool_name} ({t.arguments})") - # print("\n==================================\n") + for gen_idx in range(max_gen_len - 1): + decode_start = time.perf_counter() + position_id = prefill_len + gen_idx + logits = model.forward_decode( + position_id, + tokens, + cross_attention_masks, + full_text_row_masked_out_mask, + xattn_caches, + ) + next_token, text = sampler(logits) + # Update next token + tokens[0, position_id + 1] = next_token + logger.info(f"Decode output {position_id}: {next_token}:{text}") + decode_end = time.perf_counter() + decode_times.append(decode_end - decode_start) + + if text in ["<|eot_id|>", "<|eom_id|>"]: + break + + # Log full text output + vision_tokens = [tokenizer.special_tokens["<|image|>"], 128256] + # Remove <|image|> tokens since they break the tokenizer + tokens_out = [ + t if t not in vision_tokens else tokenizer.pad_id for t in tokens[0].tolist()[: position_id + 2] + ] + text = tokenizer.decode(tokens_out) + logger.info(f"Full text: {text}") + + prefill_time_ms = (prefill_end - prefill_start) * 1000 + logger.info(f"Prefill time: {prefill_time_ms:.2f} ms") + decode_time_ms = sum(decode_times) / (gen_idx + 1) * 1000 + logger.info(f"Decode time: {decode_time_ms:.2f} ms") diff --git a/models/demos/llama3/tt/multimodal/llama_vision_model.py b/models/demos/llama3/tt/multimodal/llama_vision_model.py index 214bc6ce7c5..780925fb242 100644 --- a/models/demos/llama3/tt/multimodal/llama_vision_model.py +++ b/models/demos/llama3/tt/multimodal/llama_vision_model.py @@ -312,6 +312,8 @@ def forward( "constant", get_negative_inf_value(torch.float32), ) + if mode == "decode": + xattn_mask_expand = xattn_mask_expand.transpose(1, 2).contiguous() tt_xattn_mask = ttnn.from_torch( xattn_mask_expand, @@ -331,6 +333,8 @@ def forward( full_text_mask_expand_1NSH = full_text_mask.expand( -1, self.configuration.n_heads // self.configuration.num_devices, -1, self.configuration.head_dim ) + if mode == "decode": + full_text_mask_expand_1NSH = full_text_mask_expand_1NSH.transpose(1, 2).contiguous() tt_full_text_mask_expand_1NSH = ttnn.from_torch( full_text_mask_expand_1NSH, @@ -386,26 +390,26 @@ def forward( tt_xattn_mask, shape=ttnn.Shape( [ + seq_len, batch, self.configuration.n_heads // self.configuration.num_devices, - seq_len, xattn_mask.shape[-1], ], - [batch, self.configuration.n_heads // self.configuration.num_devices, 32, xattn_mask.shape[-1]], + [seq_len, batch, 32, xattn_mask.shape[-1]], ), ) tt_full_text_mask_expand_1NSH = ttnn.reshape( tt_full_text_mask_expand_1NSH, shape=ttnn.Shape( 
[ + seq_len, batch, self.configuration.n_heads // self.configuration.num_devices, - seq_len, self.configuration.head_dim, ], [ + seq_len, batch, - self.configuration.n_heads // self.configuration.num_devices, 32, self.configuration.head_dim, ], From 555f27dd896377982d893aad77a793a7a3feccc9 Mon Sep 17 00:00:00 2001 From: Colman Glagovich Date: Tue, 5 Nov 2024 12:21:48 -0800 Subject: [PATCH 06/19] #14519: unit tests for xattn, xblock, and xtransformer now support batch > 1. WIP, since these changes have now broken the full model and demos --- .../multimodal/test_llama_cross_attention.py | 165 +++++--- ..._llama_cross_attention_transformer_text.py | 389 ++++++++++-------- .../multimodal/test_llama_cross_block.py | 185 +++++---- models/demos/llama3/tt/llama_common.py | 6 +- models/demos/llama3/tt/model_config.py | 9 + .../tt/multimodal/llama_cross_attention.py | 52 ++- .../llama_cross_attention_transformer_text.py | 30 +- .../llama3/tt/multimodal/llama_cross_block.py | 6 +- 8 files changed, 513 insertions(+), 329 deletions(-) diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py b/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py index 54f972f1400..73fc48f44b1 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py @@ -32,7 +32,7 @@ ], indirect=True, ) -@pytest.mark.parametrize("batch", (1,)) +@pytest.mark.parametrize("batch", (1, 2), ids=["batch_1", "batch_2"]) def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset_seeds, ensure_gc): dtype = ttnn.bfloat16 pcc_required = 0.99 @@ -79,10 +79,6 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset pt_xattn_tokens = (torch.rand(batch, vision_seq_len, dim) * 2) - 1 tt_xattn_tokens = pt_xattn_tokens.clone() - tt_xattn_tokens = model_args.prepare_inputs_ttnn_prefill( - tt_xattn_tokens, - force_replicated=True, - ) """ Test compute_xattn_kv_cache @@ -91,7 +87,25 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset pt_xattn_cache_chunks = torch.chunk(pt_xattn_cache, 2, dim=0) pt_xattn_cache_chunks = [x.view(batch, n_heads, vision_seq_len, head_dim) for x in pt_xattn_cache] - tt_xattn_cache = tt_model.compute_xattn_kv_cache(tt_xattn_tokens) + # Iterate over batch + # Preallocate K and V caches + tt_xattn_cache = [ + ttnn.from_torch( + torch.zeros(batch, n_heads, vision_seq_len, head_dim), + device=mesh_device, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + dtype=ttnn.bfloat16, + mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=1), + ) + for _ in range(2) + ] + for b in range(batch): + tt_tensor_xattn_tokens = model_args.prepare_inputs_ttnn_prefill( + tt_xattn_tokens[b : b + 1], + force_replicate=True, + ) + tt_xattn_cache = tt_model.compute_xattn_kv_cache(tt_tensor_xattn_tokens, tt_xattn_cache, user_id=b) tt_xattn_cache_torch = [ ttnn.to_torch(x, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)).view( batch, @@ -123,20 +137,8 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset mode = "prefill" if i == 0 else "decode" pt_x = (torch.rand(batch, seq_len, dim) * 2) - 1 tt_x = pt_x.clone() - if mode == "prefill": - tt_x = model_args.prepare_inputs_ttnn_prefill( - tt_x, - force_replicated=True, - ) - else: - tt_x = model_args.prepare_inputs_ttnn_decode( - tt_x, - ttnn.DRAM_MEMORY_CONFIG, - force_replicated=True, - ) - # TODO Convert to sharded input for decode, since 
that's what attention expects from RMSnorm - tt_x = ttnn.interleaved_to_sharded(tt_x, model_args.model_config["SHARDED_ATTN_INPUT_MEMCFG"]) + # Common mask prep xattn_mask = torch.bernoulli( torch.full( ( @@ -151,24 +153,6 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset xattn_mask = xattn_mask * -1e9 xattn_mask_expand = xattn_mask.expand(-1, n_heads // model_args.num_devices, -1, -1) - if mode == "decode": - xattn_mask_expand = xattn_mask_expand.permute(2, 0, 1, 3).contiguous() - tt_xattn_mask = ttnn.from_torch( - xattn_mask_expand, - device=mesh_device, - dtype=ttnn.bfloat8_b, - layout=ttnn.TILE_LAYOUT, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), - ) - if mode == "decode": - tt_xattn_mask = ttnn.reshape( - tt_xattn_mask, - shape=ttnn.Shape( - [1, batch, n_heads // model_args.num_devices, vision_seq_len], - [1, batch, 32, vision_seq_len], - ), - ) full_text_mask = torch.bernoulli( torch.full( @@ -181,17 +165,80 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset ) full_text_mask = full_text_mask.unsqueeze(1).unsqueeze(-1) full_text_mask_expand = full_text_mask.expand(-1, n_heads // model_args.num_devices, -1, head_dim) - if mode == "decode": - full_text_mask_expand = full_text_mask_expand.permute(2, 0, 1, 3).contiguous() - tt_full_text_mask = ttnn.from_torch( - full_text_mask_expand, - device=mesh_device, - dtype=ttnn.bfloat8_b, - layout=ttnn.TILE_LAYOUT, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + + pt_out = reference_model.forward( + pt_x, xattn_mask=xattn_mask, full_text_row_masked_out_mask=full_text_mask, xattn_cache=pt_xattn_cache ) - if mode == "decode": + + if mode == "prefill": + outputs = [] + for b in range(batch): + tt_tensor_x = model_args.prepare_inputs_ttnn_prefill( + tt_x[b : b + 1], + force_replicated=True, + ) + tt_xattn_mask = ttnn.from_torch( + xattn_mask_expand[b : b + 1], + device=mesh_device, + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + tt_full_text_mask = ttnn.from_torch( + full_text_mask_expand[b : b + 1], + device=mesh_device, + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + tt_out = tt_model( + tt_tensor_x, + xattn_mask=tt_xattn_mask, + full_text_row_masked_out_mask_1NSH=tt_full_text_mask, + xattn_cache=tt_xattn_cache, + mode=mode, + user_id=b, + ) + + tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0)) + tt_output_torch = tt_output_torch[0, ..., :seq_len, :].view(1, seq_len, dim) + outputs.append(tt_output_torch) + tt_output_torch = torch.cat(outputs, dim=0).view(batch, seq_len, dim) + + else: + tt_x = model_args.prepare_inputs_ttnn_decode( + tt_x, + ttnn.DRAM_MEMORY_CONFIG, + force_replicated=True, + ) + xattn_mask_expand = xattn_mask_expand.permute(2, 0, 1, 3).contiguous() + tt_xattn_mask = ttnn.from_torch( + xattn_mask_expand, + device=mesh_device, + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + tt_xattn_mask = ttnn.reshape( + tt_xattn_mask, + shape=ttnn.Shape( + [1, batch, n_heads // model_args.num_devices, vision_seq_len], + [1, batch, 32, vision_seq_len], + ), + ) + + full_text_mask_expand = full_text_mask_expand.permute(2, 0, 1, 
3).contiguous() + tt_full_text_mask = ttnn.from_torch( + full_text_mask_expand, + device=mesh_device, + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) tt_full_text_mask = ttnn.reshape( tt_full_text_mask, shape=ttnn.Shape( @@ -200,23 +247,17 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset ), ) - pt_out = reference_model.forward( - pt_x, xattn_mask=xattn_mask, full_text_row_masked_out_mask=full_text_mask, xattn_cache=pt_xattn_cache - ) + tt_out = tt_model( + tt_x, + xattn_mask=tt_xattn_mask, + full_text_row_masked_out_mask_1NSH=tt_full_text_mask, + xattn_cache=tt_xattn_cache, + mode=mode, + ) - tt_out = tt_model( - tt_x, - xattn_mask=tt_xattn_mask, - full_text_row_masked_out_mask_1NSH=tt_full_text_mask, - xattn_cache=tt_xattn_cache, - mode=mode, - ) + tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0)) + tt_output_torch = tt_output_torch[0, :, :batch, :].reshape(batch, seq_len, dim) - tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1)) - if mode == "prefill": - tt_output_torch = tt_output_torch[0, ..., :seq_len, :].view(batch, seq_len, dim) - else: - tt_output_torch = tt_output_torch[0, ..., :batch, :].transpose(0, 1).view(batch, seq_len, dim) passing, pcc_message = comp_pcc(pt_out, tt_output_torch, pcc_required) logger.info(comp_allclose(pt_out, tt_output_torch)) logger.info(f"PCC: {pcc_message}") diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py index 3873f9adcb7..9548e70fd5f 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py @@ -39,7 +39,8 @@ ], indirect=True, ) -@pytest.mark.parametrize("batch", (1,)) +@pytest.mark.parametrize("batch", (1, 2), ids=["batch_1", "batch_2"]) +@torch.no_grad() def test_llama_cross_attention_transformer_text_inference( text_seq_len, batch, @@ -53,11 +54,13 @@ def test_llama_cross_attention_transformer_text_inference( mesh_device.enable_async(True) - model_args = TtModelArgs(mesh_device) + model_args = TtModelArgs(mesh_device, max_batch_size=batch) # Limit the max seqlen to 4k to avoid OOM on host model_args.max_seq_len = 4096 model_args.kv_seq_len = model_args.max_seq_len model_args.sliding_window = model_args.max_seq_len + model_args.n_layers = 1 + model_args.vision_num_cross_attention_layers = 1 state_dict = torch.load(model_args.consolidated_weights_path, map_location=torch.device("cpu")) # Ref model needs partial state dict, but our models use full state dict keys as cached weight names @@ -66,154 +69,174 @@ def test_llama_cross_attention_transformer_text_inference( k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) } - tt_model = TtLlamaCrossAttentionTransformerText( - mesh_device, - state_dict, - state_dict_prefix=first_layer_prefix, - weight_cache_path=model_args.weight_cache_path(dtype), - dtype=dtype, - configuration=model_args, - ) - dim = model_args.dim head_dim = model_args.head_dim n_heads = model_args.n_heads - n_kv_heads = model_args.n_kv_heads - # norm_eps = model_args.norm_eps reference_model = llama_reference_mod.CrossAttentionTransformerText(args=model_args) 
reference_model.setup_cache(model_args.max_batch_size, torch.float32) - reference_model.load_state_dict(partial_state_dict) + reference_model.load_state_dict(partial_state_dict, strict=False) num_chunks = 4 - chunk_length = nearest_32(model_args.vision_chunk_ntok) - vision_seq_len = num_chunks * chunk_length + vision_seq_len = num_chunks * nearest_32(model_args.vision_chunk_ntok) all_tests_pass = True + tt_model = TtLlamaCrossAttentionTransformerText( + mesh_device, + state_dict, + state_dict_prefix=first_layer_prefix, + weight_cache_path=model_args.weight_cache_path(dtype), + dtype=dtype, + configuration=model_args, + ) vision_tokens = torch.randn((batch, vision_seq_len, dim)) tt_vision_tokens = vision_tokens.clone() - tt_vision_tokens = model_args.prepare_inputs_ttnn_prefill( - tt_vision_tokens, - force_replicated=True, - ) - with torch.no_grad(): - """ - Test compute_xattn_kv_cache - """ - xattn_caches = torch.stack( - [layer.compute_xattn_kv_cache(vision_tokens) for layer in reference_model.cross_attention_layers] + """ + Test compute_xattn_kv_cache + """ + xattn_caches = torch.stack( + [layer.compute_xattn_kv_cache(vision_tokens) for layer in reference_model.cross_attention_layers] + ) + # unstack layers + pt_xattn_cache_chunks = torch.chunk(xattn_caches, len(reference_model.cross_attention_layers), dim=0) + # unstack k/v + pt_xattn_cache_chunks = [torch.chunk(x, 2, dim=1) for x in pt_xattn_cache_chunks] + pt_xattn_cache_chunks = [x for xx in pt_xattn_cache_chunks for x in xx] + # slice out replicated k/v heads + pt_xattn_cache_chunks = [x.view(batch, n_heads, vision_seq_len, head_dim) for x in pt_xattn_cache_chunks] + + # Iterate over batch + # Preallocate K and V caches + tt_xattn_cache = tt_model.setup_cache(max_batch_size=batch) + for b in range(batch): + tt_tensor_vision_tokens = model_args.prepare_inputs_ttnn_prefill( + tt_vision_tokens[b : b + 1], + force_replicate=True, ) - # unstack layers - pt_xattn_cache_chunks = torch.chunk(xattn_caches, len(reference_model.cross_attention_layers), dim=0) - # unstack k/v - pt_xattn_cache_chunks = [torch.chunk(x, 2, dim=1) for x in pt_xattn_cache_chunks] - pt_xattn_cache_chunks = [x for xx in pt_xattn_cache_chunks for x in xx] - # slice out replicated k/v heads - pt_xattn_cache_chunks = [x.view(batch, n_heads, vision_seq_len, head_dim) for x in pt_xattn_cache_chunks] - - tt_xattn_cache = [layer.compute_xattn_kv_cache(tt_vision_tokens) for layer in tt_model.cross_attention_layers] - tt_xattn_cache_torch = [ - ttnn.to_torch(x, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)).view( - batch, - n_heads, - vision_seq_len, - head_dim, - ) - for kv_cache in tt_xattn_cache - for x in kv_cache - ] - for pt, tt in zip(pt_xattn_cache_chunks, tt_xattn_cache_torch): - passing, pcc_message = comp_pcc(pt, tt, prefill_pcc_required) - - logger.info(comp_allclose(pt, tt)) - logger.info(f"PCC: {pcc_message}") - - if passing: - logger.info(f"compute_xattn_kv_cache Passed!") - else: - logger.warning(f"compute_xattn_kv_cache Failed!") - all_tests_pass = False - - assert all_tests_pass - - # Test forward pass of the model - n_iter = 10 - prev_pos = 0 - # tokens = torch.randint(100, 1000, (batch, text_seq_len+n_iter), dtype=torch.long)#, device="cuda" - tokens = torch.randint( - 0, model_args.vocab_size, (batch, text_seq_len + n_iter), dtype=torch.long - ) # , device="cuda" - for i in range(n_iter): - # Test prefill and decode - mode = "prefill" if i == 0 else "decode" - seq_len = text_seq_len if mode == "prefill" else 1 - cur_pos = seq_len + prev_pos - - 
# Prepare pytorch inputs - position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long) # , device="cuda" - - logger.info(f"mode: {mode}, seq_len: {seq_len}, cur_pos: {cur_pos}") - logger.info(f"position_ids: {position_ids}") - xattn_mask = torch.bernoulli( - torch.full( - ( - batch, - seq_len, - vision_seq_len, - ), - 0.25, - ) + tt_xattn_cache = [ + layer.compute_xattn_kv_cache(tt_tensor_vision_tokens, tt_xattn_cache[layer_num], user_id=b) + for layer_num, layer in enumerate(tt_model.cross_attention_layers) + ] + tt_xattn_cache_torch = [ + ttnn.to_torch(x, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)).view( + batch, + n_heads, + vision_seq_len, + head_dim, + ) + for kv_cache in tt_xattn_cache + for x in kv_cache + ] + + for pt, tt in zip(pt_xattn_cache_chunks, tt_xattn_cache_torch): + passing, pcc_message = comp_pcc(pt, tt, prefill_pcc_required) + + logger.info(comp_allclose(pt, tt)) + logger.info(f"PCC: {pcc_message}") + + if passing: + logger.info(f"compute_xattn_kv_cache Passed!") + else: + logger.warning(f"compute_xattn_kv_cache Failed!") + all_tests_pass = False + + assert all_tests_pass + + # Test forward pass of the model + n_iter = 10 + prev_pos = 0 + # tokens = torch.randint(100, 1000, (batch, text_seq_len+n_iter), dtype=torch.long)#, device="cuda" + tokens = torch.randint(0, model_args.vocab_size, (batch, text_seq_len + n_iter), dtype=torch.long) + for i in range(n_iter): + # Test prefill and decode + mode = "prefill" if i == 0 else "decode" + seq_len = text_seq_len if mode == "prefill" else 1 + cur_pos = seq_len + prev_pos + + # Prepare pytorch inputs + position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long) # , device="cuda" + + logger.info(f"mode: {mode}, seq_len: {seq_len}, cur_pos: {cur_pos}") + logger.info(f"position_ids: {position_ids}") + + # Common mask prep + xattn_mask = torch.bernoulli( + torch.full( + ( + batch, + seq_len, + vision_seq_len, + ), + 0.25, ) - xattn_mask = xattn_mask.unsqueeze(1) - xattn_mask = xattn_mask * -1e9 - - full_text_mask = torch.bernoulli( - torch.full( - ( - batch, - seq_len, - ), - 0.75 if seq_len != 1 else 1.0, - ) + ) + xattn_mask = xattn_mask.unsqueeze(1) + xattn_mask = xattn_mask * -1e9 + + xattn_mask_expand = xattn_mask.expand(-1, n_heads // model_args.num_devices, -1, -1) + + full_text_mask = torch.bernoulli( + torch.full( + ( + batch, + seq_len, + ), + 0.75 if seq_len != 1 else 1.0, ) - full_text_mask = full_text_mask.unsqueeze(1).unsqueeze(-1) + ) + full_text_mask = full_text_mask.unsqueeze(1).unsqueeze(-1) + full_text_mask_expand_1NSH = full_text_mask.expand(-1, n_heads // model_args.num_devices, -1, head_dim) - h = reference_model.get_partially_trainable_embedding(tokens[:, position_ids]) + full_text_mask_expand_11SD = full_text_mask.expand(-1, -1, -1, dim) - logits = reference_model.forward( - position_ids, - h, - xattn_mask, - full_text_mask, - xattn_caches, - text_only_inference=True, - ) + h = reference_model.get_partially_trainable_embedding(tokens[:, position_ids]) - # Prepare TT inputs + TEXT_ONLY = False - if mode == "prefill": + logits = reference_model.forward( + position_ids, + h, + xattn_mask, + full_text_mask, + xattn_caches, + text_only_inference=TEXT_ONLY, + ) + + # Prepare TT inputs + if mode == "prefill": + outputs = [] + for b in range(batch): tt_h = model_args.prepare_inputs_ttnn_prefill( - h, + h[b : b + 1], ) - else: - tt_h = model_args.prepare_inputs_ttnn_decode( - h, - ttnn.DRAM_MEMORY_CONFIG, + tt_xattn_mask = ttnn.from_torch( + xattn_mask_expand[b : b + 1], + device=mesh_device, 
+ dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + tt_full_text_mask_expand_1NSH = ttnn.from_torch( + full_text_mask_expand_1NSH[b : b + 1], + device=mesh_device, + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + tt_full_text_mask_expand_11SD = ttnn.from_torch( + full_text_mask_expand_11SD[b : b + 1], + device=mesh_device, + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), ) - tt_position_id = ttnn.from_torch( - position_ids.reshape(batch, seq_len), - device=mesh_device, - dtype=ttnn.int32, - layout=ttnn.ROW_MAJOR_LAYOUT, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), - ) - - if mode == "prefill": rot_mats = get_prefill_rot_mat( model_args.head_dim, model_args.max_seq_len, mesh_device, seq_len=seq_len ) @@ -226,18 +249,50 @@ def test_llama_cross_attention_transformer_text_inference( mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), memory_config=ttnn.DRAM_MEMORY_CONFIG, ) - else: - rot_mats, rot_matrix = get_single_rot_mat( - model_args.head_dim, - mesh_device, - model_args.num_devices, - start_pos=cur_pos - 1, + + tt_out = tt_model( + tt_h, + xattn_mask=tt_xattn_mask, + full_text_row_masked_out_mask_1NSH=tt_full_text_mask_expand_1NSH, + full_text_row_masked_out_mask_11SD=tt_full_text_mask_expand_11SD, + xattn_caches=tt_xattn_cache, + current_pos=None, + rot_mat=rot_mats, + transformation_mats=transformation_mats, + user_id=b, + mode=mode, + text_only_inference=TEXT_ONLY, ) - transformation_mats = None - xattn_mask_expand = xattn_mask.expand(-1, n_heads // model_args.num_devices, -1, -1) - if mode == "decode": - xattn_mask_expand = xattn_mask_expand.permute(2, 0, 1, 3).contiguous() + tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0)) + tt_output_torch = tt_output_torch[0, ..., :seq_len, :].view(1, seq_len, -1) + outputs.append(tt_output_torch) + + tt_out = torch.cat(outputs, dim=0).view(batch, seq_len, -1) + pcc_required = prefill_pcc_required + + else: + tt_h = model_args.prepare_inputs_ttnn_decode( + h, + ttnn.DRAM_MEMORY_CONFIG, + ) + position_ids = position_ids.reshape(1).expand(batch) + tt_position_id = ttnn.from_torch( + position_ids, + device=mesh_device, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + + rot_mats, _ = get_single_rot_mat( + model_args.head_dim, mesh_device, model_args.num_devices, start_pos=cur_pos - 1, batch=batch + ) + + transformation_mats = None + + xattn_mask_expand = xattn_mask_expand.permute(2, 0, 1, 3).contiguous() tt_xattn_mask = ttnn.from_torch( xattn_mask_expand, device=mesh_device, @@ -246,17 +301,15 @@ def test_llama_cross_attention_transformer_text_inference( memory_config=ttnn.DRAM_MEMORY_CONFIG, mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), ) - if mode == "decode": - tt_xattn_mask = ttnn.reshape( - tt_xattn_mask, - shape=ttnn.Shape( - [1, batch, n_heads // model_args.num_devices, vision_seq_len], - [1, batch, 32, vision_seq_len], - ), - ) + tt_xattn_mask = ttnn.reshape( + tt_xattn_mask, + shape=ttnn.Shape( + [1, batch, n_heads // model_args.num_devices, vision_seq_len], + [1, batch, 32, vision_seq_len], + ), + ) full_text_mask_expand_1NSH = 
full_text_mask.expand(-1, n_heads // model_args.num_devices, -1, head_dim) - if mode == "decode": - full_text_mask_expand_1NSH = full_text_mask_expand_1NSH.permute(2, 0, 1, 3).contiguous() + full_text_mask_expand_1NSH = full_text_mask_expand_1NSH.permute(2, 0, 1, 3).contiguous() tt_full_text_mask_expand_1NSH = ttnn.from_torch( full_text_mask_expand_1NSH, device=mesh_device, @@ -265,23 +318,12 @@ def test_llama_cross_attention_transformer_text_inference( memory_config=ttnn.DRAM_MEMORY_CONFIG, mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), ) - if mode == "decode": - tt_full_text_mask_expand_1NSH = ttnn.reshape( - tt_full_text_mask_expand_1NSH, - shape=ttnn.Shape( - [1, batch, n_heads // model_args.num_devices, head_dim], - [1, batch, 32, head_dim], - ), - ) - - full_text_mask_expand_11SD = full_text_mask.expand(-1, -1, -1, dim) - tt_full_text_mask_expand_11SD = ttnn.from_torch( - full_text_mask_expand_11SD, - device=mesh_device, - dtype=ttnn.bfloat8_b, - layout=ttnn.TILE_LAYOUT, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1), + tt_full_text_mask_expand_1NSH = ttnn.reshape( + tt_full_text_mask_expand_1NSH, + shape=ttnn.Shape( + [1, batch, n_heads // model_args.num_devices, head_dim], + [1, batch, 32, head_dim], + ), ) if mode == "decode": tt_full_text_mask_expand_11SD = None @@ -290,25 +332,22 @@ def test_llama_cross_attention_transformer_text_inference( tt_h, xattn_mask=tt_xattn_mask, full_text_row_masked_out_mask_1NSH=tt_full_text_mask_expand_1NSH, - full_text_row_masked_out_mask_11SD=tt_full_text_mask_expand_11SD, + full_text_row_masked_out_mask_11SD=None, xattn_caches=tt_xattn_cache, current_pos=tt_position_id, rot_mat=rot_mats, transformation_mats=transformation_mats, - user_id=0, mode=mode, - text_only_inference=True, + text_only_inference=TEXT_ONLY, ) tt_out = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0)) - if mode == "prefill": - tt_out = tt_out[0].reshape(logits.shape) - pcc_required = prefill_pcc_required - else: - tt_out = tt_out[0, ..., :batch, :].transpose(0, 1).view(logits.shape) - pcc_required = decode_pcc_required - passing, pcc_message = comp_pcc(logits, tt_out, pcc_required) - logger.info(comp_allclose(logits, tt_out)) - logger.info(f"PCC: {pcc_message}") - prev_pos = cur_pos - assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" + + tt_out = tt_out[0, :, :batch, :].reshape(logits.shape) + pcc_required = decode_pcc_required + + passing, pcc_message = comp_pcc(logits, tt_out, pcc_required) + logger.info(comp_allclose(logits, tt_out)) + logger.info(f"PCC: {pcc_message}") + prev_pos = cur_pos + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" 
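The decode path in the test above depends on a host-side relayout of the cross-attention mask: the [batch, n_local_heads, 1, vision_seq_len] mask is permuted so the single-token sequence dimension comes first, and ttnn.reshape then declares a padded shape whose head dimension is rounded up to one 32-wide tile ([1, batch, 32, vision_seq_len]), matching the layout the SDPA decode path consumes. A minimal PyTorch sketch of that relayout follows; the helper name and the example shapes are assumptions for illustration only and are not part of this patch.

import torch

def relayout_decode_xattn_mask(xattn_mask: torch.Tensor, n_local_heads: int) -> torch.Tensor:
    # xattn_mask: [batch, 1, 1, vision_seq_len] additive mask, built as in the test above
    # (n_local_heads corresponds to n_heads // num_devices).
    batch, _, seq_len, vision_seq_len = xattn_mask.shape
    assert seq_len == 1, "decode handles a single new token per user"
    # Broadcast over the per-device heads: [batch, n_local_heads, 1, vision_seq_len]
    mask = xattn_mask.expand(-1, n_local_heads, -1, -1)
    # Sequence-first decode layout: [1, batch, n_local_heads, vision_seq_len]
    return mask.permute(2, 0, 1, 3).contiguous()

# Example (assumed shapes): 2 users, 8 heads per device, 4224 vision tokens
mask = torch.bernoulli(torch.full((2, 1, 1, 4224), 0.25)) * -1e9
out = relayout_decode_xattn_mask(mask, n_local_heads=8)
print(out.shape)  # torch.Size([1, 2, 8, 4224])
# On device, ttnn.reshape then declares the padded shape [1, 2, 32, 4224]: the head
# dimension is padded up to a full 32-element tile, as in the reshape calls above.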
diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_block.py b/models/demos/llama3/tests/multimodal/test_llama_cross_block.py index 7ccf90e0004..e95bc5dc649 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_block.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_block.py @@ -28,7 +28,7 @@ ], indirect=True, ) -@pytest.mark.parametrize("batch", (1,)) +@pytest.mark.parametrize("batch", (1, 2), ids=["batch_1", "batch_2"]) def test_llama_cross_attention_transformer_block_inference( text_seq_len, batch, mesh_device, use_program_cache, reset_seeds, ensure_gc ): @@ -37,7 +37,11 @@ def test_llama_cross_attention_transformer_block_inference( mesh_device.enable_async(True) - model_args = TtModelArgs(mesh_device) + model_args = TtModelArgs(mesh_device, max_batch_size=batch) + # Limit the max seqlen to 4k to avoid OOM on host + model_args.max_seq_len = 4096 + model_args.kv_seq_len = model_args.max_seq_len + model_args.sliding_window = model_args.max_seq_len state_dict = torch.load(model_args.consolidated_weights_path, map_location=torch.device("cpu")) # Ref model needs partial state dict, but our models use full state dict keys as cached weight names @@ -49,8 +53,6 @@ def test_llama_cross_attention_transformer_block_inference( dim = model_args.dim head_dim = model_args.head_dim n_heads = model_args.n_heads - n_kv_heads = model_args.n_kv_heads - norm_eps = model_args.norm_eps reference_model = llama_reference_mod.CrossAttentionTransformerBlock(args=model_args, layer_id=0, no_ffn=False) reference_model.load_state_dict(partial_state_dict) @@ -71,10 +73,6 @@ def test_llama_cross_attention_transformer_block_inference( pt_xattn_tokens = (torch.rand(batch, vision_seq_len, dim) * 2) - 1 tt_xattn_tokens = pt_xattn_tokens.clone() - tt_xattn_tokens = model_args.prepare_inputs_ttnn_prefill( - tt_xattn_tokens, - force_replicated=True, - ) """ Test compute_xattn_kv_cache @@ -83,7 +81,25 @@ def test_llama_cross_attention_transformer_block_inference( pt_xattn_cache_chunks = torch.chunk(pt_xattn_cache, 2, dim=0) pt_xattn_cache_chunks = [x.view(batch, n_heads, vision_seq_len, head_dim) for x in pt_xattn_cache] - tt_xattn_cache = tt_model.compute_xattn_kv_cache(tt_xattn_tokens) + # Iterate over batch + # Preallocate K and V caches + tt_xattn_cache = [ + ttnn.from_torch( + torch.zeros(batch, n_heads, vision_seq_len, head_dim), + device=mesh_device, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + dtype=ttnn.bfloat16, + mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=1), + ) + for _ in range(2) + ] + for b in range(batch): + tt_tensor_xattn_tokens = model_args.prepare_inputs_ttnn_prefill( + tt_xattn_tokens[b : b + 1], + force_replicate=True, + ) + tt_xattn_cache = tt_model.compute_xattn_kv_cache(tt_tensor_xattn_tokens, tt_xattn_cache, user_id=b) tt_xattn_cache_torch = [ ttnn.to_torch(x, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)).view( batch, @@ -110,21 +126,14 @@ def test_llama_cross_attention_transformer_block_inference( """ Test forward, prefill and decode! 
""" - for i in range(10): - seq_len = text_seq_len if i == 0 else 1 + n_iter = 10 + for i in range(n_iter): mode = "prefill" if i == 0 else "decode" + seq_len = text_seq_len if mode == "prefill" else 1 pt_x = (torch.rand(batch, seq_len, dim) * 2) - 1 tt_x = pt_x.clone() - if mode == "prefill": - tt_x = model_args.prepare_inputs_ttnn_prefill( - tt_x, - ) - else: - tt_x = model_args.prepare_inputs_ttnn_decode( - tt_x, - ttnn.DRAM_MEMORY_CONFIG, # TODO for the current configuration the decode input needs to be on DRAM - ) + # Common mask prep xattn_mask = torch.bernoulli( torch.full( ( @@ -138,25 +147,7 @@ def test_llama_cross_attention_transformer_block_inference( xattn_mask = xattn_mask.unsqueeze(1) xattn_mask = xattn_mask * -1e9 - xattn_mask_expand = xattn_mask.expand(-1, n_heads // model_args.num_devices, -1, -1) # B, NH, St, Sv - if mode == "decode": - xattn_mask_expand = xattn_mask_expand.permute(2, 0, 1, 3).contiguous() - tt_xattn_mask = ttnn.from_torch( - xattn_mask_expand, - device=mesh_device, - dtype=ttnn.bfloat8_b, - layout=ttnn.TILE_LAYOUT, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), - ) - if mode == "decode": - tt_xattn_mask = ttnn.reshape( - tt_xattn_mask, - shape=ttnn.Shape( - [1, batch, n_heads // model_args.num_devices, vision_seq_len], - [1, batch, 32, vision_seq_len], - ), - ) + xattn_mask_expand = xattn_mask.expand(-1, n_heads // model_args.num_devices, -1, -1) full_text_mask = torch.bernoulli( torch.full( @@ -168,20 +159,90 @@ def test_llama_cross_attention_transformer_block_inference( ) ) full_text_mask = full_text_mask.unsqueeze(1).unsqueeze(-1) - full_text_mask_expand_1NSH = full_text_mask.expand( - -1, n_heads // model_args.num_devices, -1, head_dim - ) # B, NH, St, Hd - if mode == "decode": - full_text_mask_expand_1NSH = full_text_mask_expand_1NSH.permute(2, 0, 1, 3).contiguous() - tt_full_text_mask_expand_1NSH = ttnn.from_torch( - full_text_mask_expand_1NSH, - device=mesh_device, - dtype=ttnn.bfloat8_b, - layout=ttnn.TILE_LAYOUT, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + full_text_mask_expand_1NSH = full_text_mask.expand(-1, n_heads // model_args.num_devices, -1, head_dim) + + full_text_mask_expand_11SD = full_text_mask.expand(-1, -1, -1, dim) + + pt_out = reference_model.forward( + pt_x, xattn_mask=xattn_mask, full_text_row_masked_out_mask=full_text_mask, xattn_cache=pt_xattn_cache ) - if mode == "decode": + + if mode == "prefill": + outputs = [] + for b in range(batch): + tt_tensor_x = model_args.prepare_inputs_ttnn_prefill( + tt_x[b : b + 1], + ) + tt_xattn_mask = ttnn.from_torch( + xattn_mask_expand[b : b + 1], + device=mesh_device, + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + tt_full_text_mask_expand_1NSH = ttnn.from_torch( + full_text_mask_expand_1NSH[b : b + 1], + device=mesh_device, + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + tt_full_text_mask_expand_11SD = ttnn.from_torch( + full_text_mask_expand_11SD[b : b + 1], + device=mesh_device, + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + tt_out = tt_model( + tt_tensor_x, + xattn_mask=tt_xattn_mask, + full_text_row_masked_out_mask_1NSH=tt_full_text_mask_expand_1NSH, + 
full_text_row_masked_out_mask_11SD=tt_full_text_mask_expand_11SD, + xattn_cache=tt_xattn_cache, + mode=mode, + user_id=b, + ) + + tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0)) + tt_output_torch = tt_output_torch[0, ..., :seq_len, :].view(1, seq_len, dim) + outputs.append(tt_output_torch) + tt_output_torch = torch.cat(outputs, dim=0).view(batch, seq_len, dim) + + else: + tt_x = model_args.prepare_inputs_ttnn_decode( + tt_x, + ttnn.DRAM_MEMORY_CONFIG, + ) + xattn_mask_expand = xattn_mask_expand.permute(2, 0, 1, 3).contiguous() + tt_xattn_mask = ttnn.from_torch( + xattn_mask_expand, + device=mesh_device, + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + tt_xattn_mask = ttnn.reshape( + tt_xattn_mask, + shape=ttnn.Shape( + [1, batch, n_heads // model_args.num_devices, vision_seq_len], + [1, batch, 32, vision_seq_len], + ), + ) + + full_text_mask_expand_1NSH = full_text_mask_expand_1NSH.permute(2, 0, 1, 3).contiguous() + tt_full_text_mask_expand_1NSH = ttnn.from_torch( + full_text_mask_expand_1NSH, + device=mesh_device, + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) tt_full_text_mask_expand_1NSH = ttnn.reshape( tt_full_text_mask_expand_1NSH, shape=ttnn.Shape( @@ -202,25 +263,9 @@ def test_llama_cross_attention_transformer_block_inference( if mode == "decode": tt_full_text_mask_expand_11SD = None - pt_out = reference_model.forward( - pt_x, xattn_mask=xattn_mask, full_text_row_masked_out_mask=full_text_mask, xattn_cache=pt_xattn_cache - ) - - tt_out = tt_model( - tt_x, - xattn_mask=tt_xattn_mask, - full_text_row_masked_out_mask_1NSH=tt_full_text_mask_expand_1NSH, - full_text_row_masked_out_mask_11SD=tt_full_text_mask_expand_11SD, - xattn_cache=tt_xattn_cache, - mode=mode, - ) - - tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1)) + tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0)) + tt_output_torch = tt_output_torch[0, :, :batch, :].reshape(batch, seq_len, dim) - if mode == "prefill": - tt_output_torch = tt_output_torch[0, ..., :seq_len, :].view(batch, seq_len, dim) - else: - tt_output_torch = tt_output_torch[0, ..., :batch, :].transpose(0, 1).view(batch, seq_len, dim) passing, pcc_message = comp_pcc(pt_out, tt_output_torch, pcc_required) logger.info(comp_allclose(pt_out, tt_output_torch)) logger.info(f"PCC: {pcc_message}") diff --git a/models/demos/llama3/tt/llama_common.py b/models/demos/llama3/tt/llama_common.py index 1cb10bd53ce..e8e88222e6e 100644 --- a/models/demos/llama3/tt/llama_common.py +++ b/models/demos/llama3/tt/llama_common.py @@ -185,7 +185,7 @@ def get_rot_transformation_mat(dhead): def get_single_rot_mat( - dhead, mesh_device, num_devices, start_pos=0, theta: float = 500000.0, use_scaled=True, on_host=False + dhead, mesh_device, num_devices, start_pos=0, theta: float = 500000.0, use_scaled=True, on_host=False, batch=1 ): freqs_unscaled = 1.0 / (theta ** (torch.arange(0, dhead, 2)[: (dhead // 2)].float() / dhead)) if use_scaled: @@ -210,13 +210,13 @@ def get_single_rot_mat( current_rot_mat[torch.arange(1, dhead, 2), torch.arange(0, dhead, 2)] = sin_freqs.clone() return ttnn.from_torch( - current_rot_mat.T.unsqueeze(0).unsqueeze(0), # 1,1,head_dim,head_dim + current_rot_mat.T.unsqueeze(0).unsqueeze(0).expand(-1, batch, -1, -1), # 
1,batch,head_dim,head_dim device=mesh_device if not on_host else None, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device) if num_devices > 1 or not on_host else None, ), ttnn.from_torch( - rot_matrix.unsqueeze(0).unsqueeze(0), # 1,1,head_dim,head_dim + rot_matrix.unsqueeze(0).unsqueeze(0).expand(-1, batch, -1, -1), # 1,batch,head_dim,head_dim device=mesh_device if not on_host else None, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, diff --git a/models/demos/llama3/tt/model_config.py b/models/demos/llama3/tt/model_config.py index 3df47bd7105..21355b422a0 100644 --- a/models/demos/llama3/tt/model_config.py +++ b/models/demos/llama3/tt/model_config.py @@ -566,6 +566,15 @@ def find_largest_divisor(n, max_divisor=8): fuse_batch=seq_len <= max_seq, ) + self.model_config["XATTN_KV_PREFILL_MEM_CFG"] = lambda seq_len: ttnn.create_sharded_memory_config( + # using n_heads since xattn repeats KV to match Q + (((self.n_heads // self.num_devices) * seq_len // 64), self.head_dim), + ttnn.CoreGrid(y=8, x=8), + ttnn.ShardStrategy.HEIGHT, + ttnn.ShardOrientation.ROW_MAJOR, + use_height_and_width_as_shard_shape=True, + ) + self.VISION_MAX_MM_SEQ = nearest_32(self.vision_chunk_ntok) # RMS NORM self.model_config["SHARDED_NORM_ATTN_PRGM_CFG"] = self.create_sharded_norm_config(attn_input_grid) diff --git a/models/demos/llama3/tt/multimodal/llama_cross_attention.py b/models/demos/llama3/tt/multimodal/llama_cross_attention.py index ff0e0b79223..787f29248ae 100644 --- a/models/demos/llama3/tt/multimodal/llama_cross_attention.py +++ b/models/demos/llama3/tt/multimodal/llama_cross_attention.py @@ -131,11 +131,12 @@ def __init__( eps=self.norm_eps, ) - def compute_xattn_kv_cache(self, xattn_tokens): - bsz, seqlen_y = xattn_tokens.shape[1], xattn_tokens.shape[2] + def compute_xattn_kv_cache(self, xattn_tokens, xattn_cache, user_id): + # Always runs with batch=1 + B, seqlen_y = 1, xattn_tokens.shape[2] MAX_MM_SEQ_LEN = self.configuration.VISION_MAX_MM_SEQ if seqlen_y > MAX_MM_SEQ_LEN: - xattn_tokens = ttnn.reshape(xattn_tokens, [1, bsz * seqlen_y // MAX_MM_SEQ_LEN, MAX_MM_SEQ_LEN, -1]) + xattn_tokens = ttnn.reshape(xattn_tokens, [1, B * seqlen_y // MAX_MM_SEQ_LEN, MAX_MM_SEQ_LEN, -1]) xk = ttnn.linear( xattn_tokens, @@ -155,13 +156,13 @@ def compute_xattn_kv_cache(self, xattn_tokens): program_config=self.model_config["VISION_XATTN_KV_PROGCFG"](seqlen_y, MAX_MM_SEQ_LEN), ) if seqlen_y > MAX_MM_SEQ_LEN: - xk = ttnn.reshape(xk, [1, bsz, seqlen_y, -1]) - xv = ttnn.reshape(xv, [1, bsz, seqlen_y, -1]) + xk = ttnn.reshape(xk, [1, B, seqlen_y, -1]) + xv = ttnn.reshape(xv, [1, B, seqlen_y, -1]) if self.n_local_kv_heads == 1: # Only a simple reshape required, no need to split - xk = ttnn.reshape(xk, [bsz, 1, seqlen_y, -1]) - xv = ttnn.reshape(xv, [bsz, 1, seqlen_y, -1]) + xk = ttnn.reshape(xk, [B, 1, seqlen_y, -1]) + xv = ttnn.reshape(xv, [B, 1, seqlen_y, -1]) else: # 1, B, S, D -> B, NH, S, DH xk, _, _ = ttnn.experimental.nlp_create_qkv_heads( @@ -180,7 +181,7 @@ def compute_xattn_kv_cache(self, xattn_tokens): ) # def create_heads(x): # x = ttnn.to_layout(x, layout=ttnn.ROW_MAJOR_LAYOUT) - # x = ttnn.reshape(x, [bsz, seqlen_y, self.n_local_kv_heads, self.head_dim]) + # x = ttnn.reshape(x, [B, seqlen_y, self.n_local_kv_heads, self.head_dim]) # x = ttnn.transpose(x, 1, 2) # x = ttnn.to_layout(x, layout=ttnn.TILE_LAYOUT) # return x @@ -193,7 +194,17 @@ def compute_xattn_kv_cache(self, xattn_tokens): # NOTE: Doing repeat in xattn_cache generation to avoid massive overhead in forward xk 
= ttnn.repeat_interleave(xk, self.n_local_heads // self.n_local_kv_heads, dim=1) xv = ttnn.repeat_interleave(xv, self.n_local_heads // self.n_local_kv_heads, dim=1) - return [xk, xv] + + k_cache, v_cache = xattn_cache + + # Work around fill_cache memory constraint by making these sharded + k_fill = ttnn.interleaved_to_sharded(xk, self.model_config["XATTN_KV_PREFILL_MEM_CFG"](seqlen_y)) + v_fill = ttnn.interleaved_to_sharded(xv, self.model_config["XATTN_KV_PREFILL_MEM_CFG"](seqlen_y)) + + ttnn.fill_cache(k_cache, k_fill, user_id) + ttnn.fill_cache(v_cache, v_fill, user_id) + + return xattn_cache ### Below is how I would like to implement TMs, but it results in poor PCC xk = ttnn.to_layout(xk, layout=ttnn.ROW_MAJOR_LAYOUT) @@ -311,7 +322,7 @@ def forward_decode(self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH, else: return output - def forward_prefill(self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH, xattn_cache): + def forward_prefill(self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH, xattn_cache, user_id): seq_len = x_11SH.shape[-2] # B, S, D assert seq_len % 32 == 0 and seq_len > 0, "Seqlen must be divisible by 32" @@ -338,12 +349,16 @@ def forward_prefill(self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH xq = self.q_norm(xq, mode="prefill") - xk, xv = xattn_cache - cache_seq_len = xk.shape[-2] + k_cache, v_cache = xattn_cache + cache_seq_len = k_cache.shape[-2] + + k_cache_user = ttnn.slice( + k_cache, (user_id, 0, 0, 0), (user_id + 1, k_cache.shape[1], k_cache.shape[2], k_cache.shape[3]) + ) scores = ttnn.matmul( xq, - ttnn.transpose(xk, -1, -2), + ttnn.transpose(k_cache_user, -1, -2), dtype=ttnn.bfloat16, memory_config=ttnn.DRAM_MEMORY_CONFIG, compute_kernel_config=self.compute_kernel_config_hifi2, @@ -355,9 +370,12 @@ def forward_prefill(self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH scores = ttnn.add(scores, xattn_mask) scores = ttnn.softmax(scores, dim=-1, numeric_stable=True) + v_cache_user = ttnn.slice( + v_cache, (user_id, 0, 0, 0), (user_id + 1, v_cache.shape[1], v_cache.shape[2], v_cache.shape[3]) + ) output = ttnn.matmul( scores, - xv, + v_cache_user, dtype=ttnn.bfloat16, memory_config=ttnn.DRAM_MEMORY_CONFIG, compute_kernel_config=self.compute_kernel_config_hifi4, @@ -395,8 +413,10 @@ def forward_prefill(self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH else: return output - def forward(self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH, xattn_cache, mode): + def forward(self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH, xattn_cache, mode, user_id=0): if mode == "prefill": - return self.forward_prefill(x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH, xattn_cache) + return self.forward_prefill( + x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH, xattn_cache, user_id=user_id + ) else: return self.forward_decode(x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH, xattn_cache) diff --git a/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py b/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py index edc1379fb24..74392f3a732 100644 --- a/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py +++ b/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py @@ -17,6 +17,10 @@ from models.common.lightweightmodule import LightweightModule from models.demos.llama3.tt.llama_embedding import TtLlamaEmbedding +from models.utility_functions import ( + nearest_32, +) + def 
_get_full_row_masked_out_mask( attn_bias, @@ -215,9 +219,31 @@ def _get_xattn_mask( full_text_row_masked_out_mask, ) - def setup_cache(self, max_batch_size, dtype): + def setup_cache(self, max_batch_size): self.cache_is_setup = True + # Prepare xattn_caches + chunk_length = nearest_32(self.configuration.vision_chunk_ntok) + vision_seq_len = self.configuration.vision_max_num_chunks * chunk_length + xattn_cache = [ + [ + ttnn.from_torch( + torch.zeros( + max_batch_size, self.configuration.n_heads, vision_seq_len, self.configuration.head_dim + ), + device=self.mesh_device, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + dtype=ttnn.bfloat16, + mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=1), + ) + for _ in range(2) + ] + for l in range(len(self.cross_attention_layers)) + ] + + return xattn_cache + def forward( self, h: ttnn.Tensor, @@ -247,6 +273,7 @@ def forward( full_text_row_masked_out_mask_1NSH=full_text_row_masked_out_mask_1NSH, full_text_row_masked_out_mask_11SD=full_text_row_masked_out_mask_11SD, mode=mode, + user_id=user_id, ) h = layer( h, @@ -286,5 +313,6 @@ def forward( outputs.append(output) output = ttnn.concat(outputs, dim=-1) + output = ttnn.reshape(output, [1, 1, seq_len, -1]) return output diff --git a/models/demos/llama3/tt/multimodal/llama_cross_block.py b/models/demos/llama3/tt/multimodal/llama_cross_block.py index 4e00fc384be..7ef3754faeb 100644 --- a/models/demos/llama3/tt/multimodal/llama_cross_block.py +++ b/models/demos/llama3/tt/multimodal/llama_cross_block.py @@ -114,8 +114,8 @@ def __init__( memory_config=ttnn.DRAM_MEMORY_CONFIG, ) - def compute_xattn_kv_cache(self, xattn_tokens): - return self.attention.compute_xattn_kv_cache(xattn_tokens) + def compute_xattn_kv_cache(self, xattn_tokens, xattn_cache, user_id): + return self.attention.compute_xattn_kv_cache(xattn_tokens, xattn_cache, user_id) def forward( self, @@ -126,6 +126,7 @@ def forward( full_text_row_masked_out_mask_1NSH, xattn_cache, mode, + user_id=0, ): attn_out = self.attention( x_11SH=self.attention_norm(x_11SH, mode=mode), @@ -133,6 +134,7 @@ def forward( xattn_cache=xattn_cache, full_text_row_masked_out_mask_1NSH=full_text_row_masked_out_mask_1NSH, mode=mode, + user_id=user_id, ) attn_out = ttnn.mul(attn_out, ttnn.tanh(self.gate_attn)) From ef5314614150e3cbd280bc718d4f7fbc9eb0e29c Mon Sep 17 00:00:00 2001 From: Colman Glagovich Date: Tue, 5 Nov 2024 17:24:18 -0800 Subject: [PATCH 07/19] #14519: Fix up Llama vision model. Simple demo works again with batch=1, nonworking for batch>1. --- .../demos/llama3/demo/simple_vision_demo.py | 80 ++++++++++++++----- ..._llama_cross_attention_transformer_text.py | 5 +- .../tt/multimodal/llama_vision_model.py | 48 ++++++----- 3 files changed, 90 insertions(+), 43 deletions(-) diff --git a/models/demos/llama3/demo/simple_vision_demo.py b/models/demos/llama3/demo/simple_vision_demo.py index b147747d0f2..72f9ebe327e 100644 --- a/models/demos/llama3/demo/simple_vision_demo.py +++ b/models/demos/llama3/demo/simple_vision_demo.py @@ -52,7 +52,17 @@ def get_prefill_inputs(self, model_input): return images, mask, tokens - def forward_prefill(self, vision_images, vision_mask, tokens, total_len, prefill_len, text_only_inference=False): + def forward_prefill( + self, + vision_images, + vision_mask, + tokens, + xattn_caches, + user_id, + total_len, + prefill_len, + text_only_inference=False, + ): """ Performs vision encode step then text prefill. 
Returns (xattn_caches, cross_attention_masks, full_text_row_masked_out_mask, logits) @@ -61,6 +71,8 @@ def forward_prefill(self, vision_images, vision_mask, tokens, total_len, prefill batch_images=[vision_images], batch_masks=[vision_mask], total_len=total_len, + xattn_caches=xattn_caches, + user_id=user_id, ) position_ids = torch.arange(prefill_len, dtype=torch.long) @@ -72,6 +84,7 @@ def forward_prefill(self, vision_images, vision_mask, tokens, total_len, prefill full_text_row_masked_out_mask, xattn_caches, text_only_inference, + user_id=user_id, ) return xattn_caches, cross_attention_masks, full_text_row_masked_out_mask, logits @@ -103,7 +116,6 @@ def forward_decode( def get_sampler(temperature, top_p, tokenizer): def sample(logits): - logger.info(f"Sampling {logits.shape=}") if temperature > 0: probs = torch.softmax(logits[:, -1] / temperature, dim=-1) next_token = llama_reference_generation.sample_top_p(probs, top_p) @@ -118,11 +130,15 @@ def sample(logits): return sample -def create_multimodal_model(mesh_device, dtype=ttnn.bfloat16): +def create_multimodal_model(mesh_device, max_batch_size, max_seq_len, dtype=ttnn.bfloat16): from models.demos.llama3.tt.multimodal.llama_vision_model import CrossAttentionTransformer from models.demos.llama3.tt.model_config import TtModelArgs - tt_model_args = TtModelArgs(mesh_device) + tt_model_args = TtModelArgs(mesh_device, max_batch_size=max_batch_size) + # limit length or we'll run out of space + tt_model_args.max_seq_len = max_seq_len + tt_model_args.kv_seq_len = max_seq_len + tt_model_args.sliding_window = max_seq_len checkpoint = torch.load(tt_model_args.consolidated_weights_path, map_location="cpu", weights_only=True) model = CrossAttentionTransformer( mesh_device, @@ -131,7 +147,6 @@ def create_multimodal_model(mesh_device, dtype=ttnn.bfloat16): dtype=dtype, configuration=tt_model_args, ) - model.setup_cache(tt_model_args.max_batch_size, torch.float32) # TODO: is a no-op return tt_model_args, model @@ -151,6 +166,7 @@ def create_multimodal_model(mesh_device, dtype=ttnn.bfloat16): @pytest.mark.parametrize( "warmup_iters", (0, 1), + ids=["cold", "warm"], ) @pytest.mark.parametrize( "test_case", @@ -166,8 +182,8 @@ def test_llama_multimodal_demo_text( temperature: float = 0, top_p: float = 0.9, max_seq_len: int = 512, - max_batch_size: int = 4, - max_gen_len: Optional[int] = 100, + max_batch_size: int = 1, + max_gen_len: Optional[int] = 200, model_parallel_size: Optional[int] = None, ): """ @@ -191,28 +207,42 @@ def test_llama_multimodal_demo_text( else: mesh_device.enable_program_cache() mesh_device.enable_async(True) - model_args, model = create_multimodal_model(mesh_device) + model_args, model = create_multimodal_model(mesh_device, max_batch_size=max_batch_size, max_seq_len=max_seq_len) model = LlamaVision(model, model_args, mesh_device) tokenizer = Tokenizer(model_path=tokenizer_path) formatter = ChatFormat(tokenizer) - with open(IMG_PATH / "dog.jpg", "rb") as f: - img = PIL_Image.open(f).convert("RGB") + xattn_caches = model.model.setup_cache(model_args.max_batch_size) - dialogs = [] with open(IMG_PATH / "dog.jpg", "rb") as f: img = PIL_Image.open(f).convert("RGB") + with open(IMG_PATH / "pasta.jpeg", "rb") as f: + img2 = PIL_Image.open(f).convert("RGB") + + with open(IMG_PATH / "ocr_image.jpeg", "rb") as f: + ocr_image = PIL_Image.open(f).convert("RGB") + + with open(IMG_PATH / "clutter.jpeg", "rb") as f: + clutter = PIL_Image.open(f).convert("RGB") + dialogs = [ - [ - UserMessage( - content=[ - ImageMedia(image=img), - "Describe this image 
in two sentences", - ], - ) - ], + # image understanding + [UserMessage(content=[ImageMedia(image=img), "Describe this image in two sentences"])], + [UserMessage(content=[ImageMedia(image=img2), "What is for dinner?"])], + [UserMessage(content=[ImageMedia(image=ocr_image), "What is the full text of this image? Do OCR"])], + [UserMessage(content=[ImageMedia(image=clutter), "What objects are in this image?"])], ] + # dialogs = [ + # [ + # UserMessage( + # content=[ + # ImageMedia(image=img), + # "Describe this image in two sentences", + # ], + # ) + # ], + # ] sampler = get_sampler(temperature, top_p, tokenizer) @@ -238,12 +268,18 @@ def test_llama_multimodal_demo_text( tokens[0, : len(prompt_tokens)] = torch.tensor(prompt_tokens, dtype=torch.long) prefill_start = time.perf_counter() xattn_caches, cross_attention_masks, full_text_row_masked_out_mask, logits = model.forward_prefill( - vision_images, vision_mask, tokens, total_len, prefill_len + vision_images, + vision_mask, + tokens, + xattn_caches, + user_id=0, + total_len=total_len, + prefill_len=prefill_len, ) prefill_end = time.perf_counter() next_token, text = sampler(logits) - logger.info(f"Prefill output: {next_token}:{text}") + # logger.info(f"Prefill output: {next_token}:{text}") tokens[0, prefill_len] = next_token decode_times = [] @@ -262,7 +298,7 @@ def test_llama_multimodal_demo_text( next_token, text = sampler(logits) # Update next token tokens[0, position_id + 1] = next_token - logger.info(f"Decode output {position_id}: {next_token}:{text}") + # logger.info(f"Decode output {position_id}: {next_token}:{text}") decode_end = time.perf_counter() decode_times.append(decode_end - decode_start) diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py index 9548e70fd5f..e67a36b19d5 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py @@ -59,8 +59,7 @@ def test_llama_cross_attention_transformer_text_inference( model_args.max_seq_len = 4096 model_args.kv_seq_len = model_args.max_seq_len model_args.sliding_window = model_args.max_seq_len - model_args.n_layers = 1 - model_args.vision_num_cross_attention_layers = 1 + state_dict = torch.load(model_args.consolidated_weights_path, map_location=torch.device("cpu")) # Ref model needs partial state dict, but our models use full state dict keys as cached weight names @@ -74,7 +73,7 @@ def test_llama_cross_attention_transformer_text_inference( n_heads = model_args.n_heads reference_model = llama_reference_mod.CrossAttentionTransformerText(args=model_args) reference_model.setup_cache(model_args.max_batch_size, torch.float32) - reference_model.load_state_dict(partial_state_dict, strict=False) + reference_model.load_state_dict(partial_state_dict) num_chunks = 4 vision_seq_len = num_chunks * nearest_32(model_args.vision_chunk_ntok) diff --git a/models/demos/llama3/tt/multimodal/llama_vision_model.py b/models/demos/llama3/tt/multimodal/llama_vision_model.py index 780925fb242..99a52d5ea16 100644 --- a/models/demos/llama3/tt/multimodal/llama_vision_model.py +++ b/models/demos/llama3/tt/multimodal/llama_vision_model.py @@ -167,14 +167,16 @@ def __init__( max_num_chunks=configuration.vision_max_num_chunks, ) - def setup_cache(self, max_batch_size: int, dtype: torch.dtype): - self.text_model.setup_cache(max_batch_size, dtype) + def setup_cache(self, 
max_batch_size): + return self.text_model.setup_cache(max_batch_size) def compute_vision_tokens_masks( self, batch_images: List[List[PIL_Image.Image]], batch_masks: List[List[List[int]]], total_len: int, + xattn_caches, + user_id, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: skip_vision_encoder = False @@ -243,7 +245,8 @@ def compute_vision_tokens_masks( ) xattn_caches = [ - layer.compute_xattn_kv_cache(vision_tokens_tt) for layer in self.text_model.cross_attention_layers + layer.compute_xattn_kv_cache(vision_tokens_tt, xattn_caches[layer_num], user_id=user_id) + for layer_num, layer in enumerate(self.text_model.cross_attention_layers) ] padded_masks = _pad_masks( # torch.Size([1, 512, 1, 4]) batch_masks, @@ -269,12 +272,16 @@ def compute_vision_tokens_masks( ) return (xattn_caches, cross_attention_masks, full_text_row_masked_out_mask) - def validate_inputs(self, tokens): + def validate_inputs(self, tokens, position_ids): batch, seq_len = tokens.shape[:2] assert batch == 1, f"Only batch 1 is supported, got {batch}" assert ( seq_len <= self.configuration.max_seq_len ), f"Sequence length {seq_len} exceeds max sequence length {self.configuration.max_seq_len}" + assert len(position_ids.shape) == 1, f"Position ids must be 1D, got {len(position_ids.shape)}" + assert ( + batch == self.configuration.max_batch_size + ), f"Batch size must match max batch size. Got {batch}, expected {self.configuration.max_batch_size}" def forward( self, @@ -284,8 +291,10 @@ def forward( full_text_row_masked_out_mask: torch.Tensor, xattn_caches: torch.Tensor, text_only_inference: bool = False, + user_id=0, ) -> torch.Tensor: - self.validate_inputs(tokens) + B = tokens.shape[0] + self.validate_inputs(tokens, position_ids) h = self.text_model.get_partially_trainable_embedding(tokens[:, position_ids]) batch, seq_len = h.shape[:2] padded_seq_len = _get_padded_prefill_seqlen(seq_len) @@ -295,7 +304,7 @@ def forward( mode = "prefill" # Prepare TT inputs for text_model tt_position_id = ttnn.from_torch( - position_ids.reshape(batch, seq_len), + position_ids, device=self.mesh_device, dtype=ttnn.int32, layout=ttnn.ROW_MAJOR_LAYOUT, @@ -346,15 +355,6 @@ def forward( ) tt_full_text_mask_expand_1NSH = ttnn.to_layout(tt_full_text_mask_expand_1NSH, ttnn.TILE_LAYOUT) - full_text_mask_expand_11SD = full_text_mask.expand(-1, -1, -1, self.configuration.dim) - tt_full_text_mask_expand_11SD = ttnn.from_torch( - full_text_mask_expand_11SD, - device=self.mesh_device, - dtype=ttnn.bfloat8_b, - layout=ttnn.TILE_LAYOUT, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), - ) # Check mask shapes, pad if in prefill? 
if mode == "prefill": h = torch.nn.functional.pad(h, (0, 0, 0, padded_seq_len - h.shape[1]), "constant", 0) @@ -373,16 +373,28 @@ def forward( mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), memory_config=ttnn.DRAM_MEMORY_CONFIG, ) + + full_text_mask_expand_11SD = full_text_mask.expand(-1, -1, -1, self.configuration.dim) + tt_full_text_mask_expand_11SD = ttnn.from_torch( + full_text_mask_expand_11SD, + device=self.mesh_device, + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) else: tt_h = self.configuration.prepare_inputs_ttnn_decode( h, ttnn.DRAM_MEMORY_CONFIG, ) - rot_mats, rot_matrix = get_single_rot_mat( + rot_mats, _ = get_single_rot_mat( self.configuration.head_dim, self.mesh_device, self.configuration.num_devices, start_pos=position_ids.item() - 1, # TODO: Change function to support decode batch > 1 + # TODO: B must match max_batch_size, be careful + batch=B, ) transformation_mats = None @@ -427,7 +439,7 @@ def forward( current_pos=tt_position_id, rot_mat=rot_mats, transformation_mats=transformation_mats, - user_id=0, + user_id=user_id, mode=mode, text_only_inference=text_only_inference, ) @@ -437,7 +449,7 @@ def forward( if mode == "prefill": tt_out = tt_out[0].reshape(batch, padded_seq_len, -1)[:, :seq_len, :] # DEBUG: undo padding else: - tt_out = tt_out[0, ..., :batch, :].transpose(0, 1).reshape(batch, seq_len, -1) + tt_out = tt_out[0, :, :batch, :].reshape(batch, seq_len, -1) return tt_out From 7d50e6c55b3db527916cb900317b948556cc3fe8 Mon Sep 17 00:00:00 2001 From: Colman Glagovich Date: Wed, 6 Nov 2024 11:29:40 -0800 Subject: [PATCH 08/19] #14519: Refactor LlamaVision class to clean up separation of input prep and model execution with ttnn tensors. --- .../demos/llama3/demo/simple_vision_demo.py | 111 +++--- .../tt/multimodal/llama_vision_model.py | 373 ++++++++++++------ 2 files changed, 315 insertions(+), 169 deletions(-) diff --git a/models/demos/llama3/demo/simple_vision_demo.py b/models/demos/llama3/demo/simple_vision_demo.py index 72f9ebe327e..7dad35d2098 100644 --- a/models/demos/llama3/demo/simple_vision_demo.py +++ b/models/demos/llama3/demo/simple_vision_demo.py @@ -61,7 +61,6 @@ def forward_prefill( user_id, total_len, prefill_len, - text_only_inference=False, ): """ Performs vision encode step then text prefill. @@ -75,18 +74,33 @@ def forward_prefill( user_id=user_id, ) - position_ids = torch.arange(prefill_len, dtype=torch.long) + ( + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + tt_full_text_mask_expand_11SD, + tt_position_id, + rot_mats, + transformation_mats, + ) = self.model.prepare_inputs_prefill( + tokens, cross_attention_masks, full_text_row_masked_out_mask, prefill_len=prefill_len + ) - logits = self.model.forward( - position_ids, - tokens, - cross_attention_masks, - full_text_row_masked_out_mask, + tt_logits = self.model.ttnn_prefill_forward( + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + tt_full_text_mask_expand_11SD, xattn_caches, - text_only_inference, - user_id=user_id, + tt_position_id, + rot_mats, + transformation_mats, + user_id, ) + B = 1 + logits = self.model.process_output_prefill(tt_logits, B, prefill_len) + return xattn_caches, cross_attention_masks, full_text_row_masked_out_mask, logits def forward_decode( @@ -96,21 +110,40 @@ def forward_decode( cross_attention_masks, full_text_row_masked_out_mask, xattn_caches, - text_only_inference=False, ): """ Performs text decode step. 
Returns logits """ - position_ids = torch.tensor([position_id], dtype=torch.long) - logits = self.model.forward( - position_ids, - tokens, - cross_attention_masks, - full_text_row_masked_out_mask, + + # forward_decode should be traced callable + # decorator does compilation, capture, execute + + ( + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + _, + tt_position_id, + rot_mats, + transformation_mats, + ) = self.model.prepare_inputs_decode( + tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id=position_id + ) + + tt_logits = self.model.ttnn_decode_forward( + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, xattn_caches, - text_only_inference, + tt_position_id, + rot_mats, + transformation_mats, ) + + B = tokens.shape[0] + S = 1 + logits = self.model.process_output_decode(tt_logits, B, S) return logits @@ -159,10 +192,6 @@ def create_multimodal_model(mesh_device, max_batch_size, max_seq_len, dtype=ttnn ], indirect=True, ) -@pytest.mark.parametrize( - "target", - ("tt", "cpu"), -) @pytest.mark.parametrize( "warmup_iters", (0, 1), @@ -176,7 +205,6 @@ def create_multimodal_model(mesh_device, max_batch_size, max_seq_len, dtype=ttnn ) def test_llama_multimodal_demo_text( mesh_device, - target, warmup_iters, test_case, temperature: float = 0, @@ -192,25 +220,12 @@ def test_llama_multimodal_demo_text( ckpt_dir = os.environ["LLAMA_DIR"] tokenizer_path = str(Path(ckpt_dir) / "tokenizer.model") - if target == "cpu": - generator = llama_reference_generation.Llama.build( - ckpt_dir, - tokenizer_path=tokenizer_path, - max_seq_len=max_seq_len, - max_batch_size=max_batch_size, - model_parallel_size=model_parallel_size, - ) - model_args = generator.args - model = LlamaVision(generator.model, model_args, None) - tokenizer = generator.tokenizer - formatter = generator.formatter - else: - mesh_device.enable_program_cache() - mesh_device.enable_async(True) - model_args, model = create_multimodal_model(mesh_device, max_batch_size=max_batch_size, max_seq_len=max_seq_len) - model = LlamaVision(model, model_args, mesh_device) - tokenizer = Tokenizer(model_path=tokenizer_path) - formatter = ChatFormat(tokenizer) + mesh_device.enable_program_cache() + mesh_device.enable_async(True) + model_args, model = create_multimodal_model(mesh_device, max_batch_size=max_batch_size, max_seq_len=max_seq_len) + model = LlamaVision(model, model_args, mesh_device) + tokenizer = Tokenizer(model_path=tokenizer_path) + formatter = ChatFormat(tokenizer) xattn_caches = model.model.setup_cache(model_args.max_batch_size) @@ -228,25 +243,14 @@ def test_llama_multimodal_demo_text( dialogs = [ # image understanding - [UserMessage(content=[ImageMedia(image=img), "Describe this image in two sentences"])], + [UserMessage(content=[ImageMedia(image=img), "Write a haiku for this image."])], [UserMessage(content=[ImageMedia(image=img2), "What is for dinner?"])], [UserMessage(content=[ImageMedia(image=ocr_image), "What is the full text of this image? 
Do OCR"])], [UserMessage(content=[ImageMedia(image=clutter), "What objects are in this image?"])], ] - # dialogs = [ - # [ - # UserMessage( - # content=[ - # ImageMedia(image=img), - # "Describe this image in two sentences", - # ], - # ) - # ], - # ] sampler = get_sampler(temperature, top_p, tokenizer) - print(f"Running text completion on {target}") for iter_num in range(warmup_iters + 1): for dialog in dialogs: for msg in dialog: @@ -283,7 +287,6 @@ def test_llama_multimodal_demo_text( tokens[0, prefill_len] = next_token decode_times = [] - # Iterate over decode for gen_idx in range(max_gen_len - 1): decode_start = time.perf_counter() diff --git a/models/demos/llama3/tt/multimodal/llama_vision_model.py b/models/demos/llama3/tt/multimodal/llama_vision_model.py index 99a52d5ea16..274381d21a3 100644 --- a/models/demos/llama3/tt/multimodal/llama_vision_model.py +++ b/models/demos/llama3/tt/multimodal/llama_vision_model.py @@ -279,30 +279,21 @@ def validate_inputs(self, tokens, position_ids): seq_len <= self.configuration.max_seq_len ), f"Sequence length {seq_len} exceeds max sequence length {self.configuration.max_seq_len}" assert len(position_ids.shape) == 1, f"Position ids must be 1D, got {len(position_ids.shape)}" - assert ( - batch == self.configuration.max_batch_size - ), f"Batch size must match max batch size. Got {batch}, expected {self.configuration.max_batch_size}" - def forward( - self, - position_ids: torch.Tensor, - tokens: torch.Tensor, - cross_attention_masks: torch.Tensor, - full_text_row_masked_out_mask: torch.Tensor, - xattn_caches: torch.Tensor, - text_only_inference: bool = False, - user_id=0, - ) -> torch.Tensor: - B = tokens.shape[0] + def prepare_inputs_common(self, position_ids, tokens): self.validate_inputs(tokens, position_ids) h = self.text_model.get_partially_trainable_embedding(tokens[:, position_ids]) - batch, seq_len = h.shape[:2] - padded_seq_len = _get_padded_prefill_seqlen(seq_len) - if seq_len == 1: - mode = "decode" - else: - mode = "prefill" - # Prepare TT inputs for text_model + return h + + def prepare_inputs_prefill(self, tokens, cross_attention_masks, full_text_row_masked_out_mask, prefill_len): + B = tokens.shape[0] + assert B == 1, f"Only batch 1 is supported, got {B}" + # S = tokens.shape[1] # TODO: Get B, S from tokens when we don't pass full tokens around + S = prefill_len + position_ids = torch.arange(S, dtype=torch.long) + h = self.prepare_inputs_common(position_ids, tokens) + padded_seq_len = _get_padded_prefill_seqlen(S) + tt_position_id = ttnn.from_torch( position_ids, device=self.mesh_device, @@ -314,15 +305,12 @@ def forward( xattn_mask = cross_attention_masks[:, :, position_ids] xattn_mask_expand = xattn_mask.expand(-1, self.configuration.n_heads // self.configuration.num_devices, -1, -1) - if mode == "prefill": - xattn_mask_expand = torch.nn.functional.pad( - xattn_mask_expand, - (0, 0, 0, padded_seq_len - xattn_mask_expand.shape[2]), - "constant", - get_negative_inf_value(torch.float32), - ) - if mode == "decode": - xattn_mask_expand = xattn_mask_expand.transpose(1, 2).contiguous() + xattn_mask_expand = torch.nn.functional.pad( + xattn_mask_expand, + (0, 0, 0, padded_seq_len - xattn_mask_expand.shape[2]), + "constant", + get_negative_inf_value(torch.float32), + ) tt_xattn_mask = ttnn.from_torch( xattn_mask_expand, @@ -335,16 +323,108 @@ def forward( tt_xattn_mask = ttnn.to_layout(tt_xattn_mask, ttnn.TILE_LAYOUT) full_text_mask = full_text_row_masked_out_mask[:, :, position_ids] - if mode == "prefill": - full_text_mask = 
torch.nn.functional.pad( - full_text_mask, (0, 0, 0, padded_seq_len - full_text_mask.shape[2]), "constant", 0 - ) + full_text_mask = torch.nn.functional.pad( + full_text_mask, (0, 0, 0, padded_seq_len - full_text_mask.shape[2]), "constant", 0 + ) full_text_mask_expand_1NSH = full_text_mask.expand( -1, self.configuration.n_heads // self.configuration.num_devices, -1, self.configuration.head_dim ) - if mode == "decode": - full_text_mask_expand_1NSH = full_text_mask_expand_1NSH.transpose(1, 2).contiguous() + tt_full_text_mask_expand_1NSH = ttnn.from_torch( + full_text_mask_expand_1NSH, + device=self.mesh_device, + dtype=ttnn.bfloat16, + layout=ttnn.ROW_MAJOR_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + tt_full_text_mask_expand_1NSH = ttnn.to_layout(tt_full_text_mask_expand_1NSH, ttnn.TILE_LAYOUT) + h = torch.nn.functional.pad(h, (0, 0, 0, padded_seq_len - h.shape[1]), "constant", 0) + tt_h = self.configuration.prepare_inputs_ttnn_prefill( + h, + ) + rot_mats = get_prefill_rot_mat( + self.configuration.head_dim, self.configuration.max_seq_len, self.mesh_device, seq_len=S + ) + transformation_mat_torch = get_rot_transformation_mat(self.configuration.head_dim) + transformation_mats = ttnn.as_tensor( + transformation_mat_torch, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=self.mesh_device, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + + full_text_mask_expand_11SD = full_text_mask.expand(-1, -1, -1, self.configuration.dim) + tt_full_text_mask_expand_11SD = ttnn.from_torch( + full_text_mask_expand_11SD, + device=self.mesh_device, + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + + return ( + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + tt_full_text_mask_expand_11SD, + tt_position_id, + rot_mats, + transformation_mats, + ) + + def prepare_inputs_decode(self, tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id): + B = tokens.shape[0] + assert ( + B == self.configuration.max_batch_size + ), f"Batch size must match max batch size. 
Got {B}, expected {self.configuration.max_batch_size}" + S = 1 + position_ids = torch.tensor([position_id], dtype=torch.long) + h = self.prepare_inputs_common(position_ids, tokens) + + tt_position_id = ttnn.from_torch( + position_ids, + device=self.mesh_device, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + + xattn_mask = cross_attention_masks[:, :, position_ids] + xattn_mask_expand = xattn_mask.expand(-1, self.configuration.n_heads // self.configuration.num_devices, -1, -1) + xattn_mask_expand = xattn_mask_expand.transpose(1, 2).contiguous() + + tt_xattn_mask = ttnn.from_torch( + xattn_mask_expand, + device=self.mesh_device, + dtype=ttnn.bfloat16, + layout=ttnn.ROW_MAJOR_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + tt_xattn_mask = ttnn.to_layout(tt_xattn_mask, ttnn.TILE_LAYOUT) + tt_xattn_mask = ttnn.reshape( + tt_xattn_mask, + shape=ttnn.Shape( + [ + S, + B, + self.configuration.n_heads // self.configuration.num_devices, + xattn_mask.shape[-1], + ], + [S, B, 32, xattn_mask.shape[-1]], + ), + ) + + full_text_mask = full_text_row_masked_out_mask[:, :, position_ids] + full_text_mask_expand_1NSH = full_text_mask.expand( + -1, self.configuration.n_heads // self.configuration.num_devices, -1, self.configuration.head_dim + ) + full_text_mask_expand_1NSH = full_text_mask_expand_1NSH.transpose(1, 2).contiguous() tt_full_text_mask_expand_1NSH = ttnn.from_torch( full_text_mask_expand_1NSH, device=self.mesh_device, @@ -354,81 +434,97 @@ def forward( mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), ) tt_full_text_mask_expand_1NSH = ttnn.to_layout(tt_full_text_mask_expand_1NSH, ttnn.TILE_LAYOUT) + tt_full_text_mask_expand_1NSH = ttnn.reshape( + tt_full_text_mask_expand_1NSH, + shape=ttnn.Shape( + [ + S, + B, + self.configuration.n_heads // self.configuration.num_devices, + self.configuration.head_dim, + ], + [ + S, + B, + 32, + self.configuration.head_dim, + ], + ), + ) - # Check mask shapes, pad if in prefill? 
- if mode == "prefill": - h = torch.nn.functional.pad(h, (0, 0, 0, padded_seq_len - h.shape[1]), "constant", 0) - tt_h = self.configuration.prepare_inputs_ttnn_prefill( - h, - ) - rot_mats = get_prefill_rot_mat( - self.configuration.head_dim, self.configuration.max_seq_len, self.mesh_device, seq_len=seq_len - ) - transformation_mat_torch = get_rot_transformation_mat(self.configuration.head_dim) - transformation_mats = ttnn.as_tensor( - transformation_mat_torch, - dtype=ttnn.bfloat16, - layout=ttnn.TILE_LAYOUT, - device=self.mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), - memory_config=ttnn.DRAM_MEMORY_CONFIG, - ) + tt_h = self.configuration.prepare_inputs_ttnn_decode( + h, + ttnn.DRAM_MEMORY_CONFIG, + ) + rot_mats, _ = get_single_rot_mat( + self.configuration.head_dim, + self.mesh_device, + self.configuration.num_devices, + start_pos=position_ids.item() - 1, # TODO: Change function to support decode batch > 1 + # TODO: B must match max_batch_size, be careful + batch=B, + ) + transformation_mats = None + tt_full_text_mask_expand_11SD = None - full_text_mask_expand_11SD = full_text_mask.expand(-1, -1, -1, self.configuration.dim) - tt_full_text_mask_expand_11SD = ttnn.from_torch( - full_text_mask_expand_11SD, - device=self.mesh_device, - dtype=ttnn.bfloat8_b, - layout=ttnn.TILE_LAYOUT, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), - ) - else: - tt_h = self.configuration.prepare_inputs_ttnn_decode( - h, - ttnn.DRAM_MEMORY_CONFIG, - ) - rot_mats, _ = get_single_rot_mat( - self.configuration.head_dim, - self.mesh_device, - self.configuration.num_devices, - start_pos=position_ids.item() - 1, # TODO: Change function to support decode batch > 1 - # TODO: B must match max_batch_size, be careful - batch=B, - ) - transformation_mats = None - - tt_xattn_mask = ttnn.reshape( - tt_xattn_mask, - shape=ttnn.Shape( - [ - seq_len, - batch, - self.configuration.n_heads // self.configuration.num_devices, - xattn_mask.shape[-1], - ], - [seq_len, batch, 32, xattn_mask.shape[-1]], - ), - ) - tt_full_text_mask_expand_1NSH = ttnn.reshape( - tt_full_text_mask_expand_1NSH, - shape=ttnn.Shape( - [ - seq_len, - batch, - self.configuration.n_heads // self.configuration.num_devices, - self.configuration.head_dim, - ], - [ - seq_len, - batch, - 32, - self.configuration.head_dim, - ], - ), - ) + return ( + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + tt_full_text_mask_expand_11SD, + tt_position_id, + rot_mats, + transformation_mats, + ) + + def process_output_prefill(self, logits, B, S): + padded_seq_len = _get_padded_prefill_seqlen(S) + tt_out = ttnn.to_layout(logits, ttnn.ROW_MAJOR_LAYOUT) + tt_out = ttnn.to_torch(ttnn.get_device_tensors(tt_out)[0]).float() + tt_out = tt_out[0].reshape(B, padded_seq_len, -1)[:, :S, :] + return tt_out + + def process_output_decode(self, logits, B, S): + tt_out = ttnn.to_layout(logits, ttnn.ROW_MAJOR_LAYOUT) + tt_out = ttnn.to_torch(ttnn.get_device_tensors(tt_out)[0]).float() + tt_out = tt_out[:, :, :B, :].reshape(B, S, -1) + return tt_out - tt_full_text_mask_expand_11SD = None + def forward( + self, + position_ids: torch.Tensor, + tokens: torch.Tensor, + cross_attention_masks: torch.Tensor, + full_text_row_masked_out_mask: torch.Tensor, + xattn_caches, # list of ttnn tensors + text_only_inference: bool = False, + user_id=0, + ) -> torch.Tensor: + """ + This method takes torch tensors in, returns torch tensors. + It also determines whether or not to run prefill or decode. 
+ """ + B = tokens.shape[0] + S = position_ids.shape[0] # TODO: Get B, S from tokens when we don't pass full tokens around + mode = "decode" if S == 1 else "prefill" + + # pos_arg is used in preparation in different ways based on mode + pos_arg = S if mode == "prefill" else position_ids.item() + prepare_fn = self.prepare_inputs_decode if mode == "decode" else self.prepare_inputs_prefill + ( + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + tt_full_text_mask_expand_11SD, + tt_position_id, + rot_mats, + transformation_mats, + ) = prepare_fn( + tokens, + cross_attention_masks, + full_text_row_masked_out_mask, + pos_arg, + ) logits = self.text_model.forward( tt_h, @@ -444,14 +540,61 @@ def forward( text_only_inference=text_only_inference, ) - tt_out = ttnn.to_layout(logits, ttnn.ROW_MAJOR_LAYOUT) - tt_out = ttnn.to_torch(ttnn.get_device_tensors(tt_out)[0]).float() - if mode == "prefill": - tt_out = tt_out[0].reshape(batch, padded_seq_len, -1)[:, :seq_len, :] # DEBUG: undo padding - else: - tt_out = tt_out[0, :, :batch, :].reshape(batch, seq_len, -1) + output_fn = self.process_output_decode if mode == "decode" else self.process_output_prefill + return output_fn(logits, B, S) - return tt_out + def ttnn_prefill_forward( + self, + h, + xattn_mask, + full_text_mas_expand_1NSH, + full_text_mask_expand_11SD, + xattn_caches, + position_id, + rot_mats, + transformation_mats, + user_id, + ): + """ + This method runs prefill forward. It takes ttnn tensors in, returns ttnn tensors. + """ + return self.text_model.forward( + h, + xattn_mask=xattn_mask, + full_text_row_masked_out_mask_1NSH=full_text_mas_expand_1NSH, + full_text_row_masked_out_mask_11SD=full_text_mask_expand_11SD, + xattn_caches=xattn_caches, + current_pos=position_id, + rot_mat=rot_mats, + transformation_mats=transformation_mats, + user_id=user_id, + mode="prefill", + ) + + def ttnn_decode_forward( + self, + h, + xattn_mask, + full_text_mas_expand_1NSH, + xattn_caches, + position_id, + rot_mats, + transformation_mats, + ): + """ + This method runs decode forward. It takes ttnn tensors in, returns ttnn tensors. 
+ """ + return self.text_model.forward( + h, + xattn_mask=xattn_mask, + full_text_row_masked_out_mask_1NSH=full_text_mas_expand_1NSH, + full_text_row_masked_out_mask_11SD=None, + xattn_caches=xattn_caches, + current_pos=position_id, + rot_mat=rot_mats, + transformation_mats=transformation_mats, + mode="decode", + ) def _stack_images( From ff2f7baeb1a2d70dbd53ef08d1d29e3f1c8b4e9d Mon Sep 17 00:00:00 2001 From: Colman Glagovich Date: Thu, 7 Nov 2024 06:19:13 -0800 Subject: [PATCH 09/19] #14519: Don't pass full token tensor into decode and prefill --- .../demos/llama3/demo/simple_vision_demo.py | 53 ++++++++++++------- models/demos/llama3/tests/conftest.py | 47 ++++++++++++++++ .../tt/multimodal/llama_vision_model.py | 5 +- 3 files changed, 82 insertions(+), 23 deletions(-) diff --git a/models/demos/llama3/demo/simple_vision_demo.py b/models/demos/llama3/demo/simple_vision_demo.py index 7dad35d2098..c8718256e62 100644 --- a/models/demos/llama3/demo/simple_vision_demo.py +++ b/models/demos/llama3/demo/simple_vision_demo.py @@ -42,17 +42,7 @@ def __init__(self, model, model_args, mesh_device, vllm=False): self.mesh_device = mesh_device self.vllm = vllm - def get_prefill_inputs(self, model_input): - """ - Responsible for taking model_input: ModelInput and returning vision_images, vision_mask, tokens - """ - images = model_input.vision.images - mask = model_input.vision.mask - tokens = model_input.tokens - - return images, mask, tokens - - def forward_prefill( + def prefill_forward_single_user( self, vision_images, vision_mask, @@ -66,6 +56,7 @@ def forward_prefill( Performs vision encode step then text prefill. Returns (xattn_caches, cross_attention_masks, full_text_row_masked_out_mask, logits) """ + B = tokens.shape[0] xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks( batch_images=[vision_images], batch_masks=[vision_mask], @@ -98,12 +89,11 @@ def forward_prefill( user_id, ) - B = 1 logits = self.model.process_output_prefill(tt_logits, B, prefill_len) return xattn_caches, cross_attention_masks, full_text_row_masked_out_mask, logits - def forward_decode( + def decode_forward( self, position_id, tokens, @@ -118,6 +108,9 @@ def forward_decode( # forward_decode should be traced callable # decorator does compilation, capture, execute + # B = 1 # TODO: Only supports batch=1 right now! Might make tokens input a tensor. + # S = 1 + B, S = tokens.shape ( tt_h, @@ -141,11 +134,22 @@ def forward_decode( transformation_mats, ) - B = tokens.shape[0] - S = 1 logits = self.model.process_output_decode(tt_logits, B, S) return logits + def capture_trace( + self, + position_id, + tokens, + cross_attention_masks, + full_text_row_masked_out_mask, + xattn_caches, + ): + """ + Captures a trace for the decode_forward method. 
+ """ + pass + def get_sampler(temperature, top_p, tokenizer): def sample(logits): @@ -262,7 +266,9 @@ def test_llama_multimodal_demo_text( model_input = formatter.encode_dialog_prompt(dialog, tool_prompt_format=False) # Do initial prefill - vision_images, vision_mask, prompt_tokens = model.get_prefill_inputs(model_input) + vision_images = model_input.vision.images + vision_mask = model_input.vision.mask + prompt_tokens = model_input.tokens prefill_len = len(prompt_tokens) total_len = prefill_len + max_gen_len # Prepares mask for full length of output # Create tokens tensor @@ -271,10 +277,16 @@ def test_llama_multimodal_demo_text( tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long) tokens[0, : len(prompt_tokens)] = torch.tensor(prompt_tokens, dtype=torch.long) prefill_start = time.perf_counter() - xattn_caches, cross_attention_masks, full_text_row_masked_out_mask, logits = model.forward_prefill( + prompt_tokens_tensor = torch.tensor(prompt_tokens, dtype=torch.long).reshape(1, -1) # B, S + ( + xattn_caches, + cross_attention_masks, + full_text_row_masked_out_mask, + logits, + ) = model.prefill_forward_single_user( vision_images, vision_mask, - tokens, + prompt_tokens_tensor, xattn_caches, user_id=0, total_len=total_len, @@ -291,9 +303,10 @@ def test_llama_multimodal_demo_text( for gen_idx in range(max_gen_len - 1): decode_start = time.perf_counter() position_id = prefill_len + gen_idx - logits = model.forward_decode( + next_token_tensor = torch.tensor([next_token], dtype=torch.long).reshape(1, 1) # B, S + logits = model.decode_forward( position_id, - tokens, + next_token_tensor, cross_attention_masks, full_text_row_masked_out_mask, xattn_caches, diff --git a/models/demos/llama3/tests/conftest.py b/models/demos/llama3/tests/conftest.py index d15dee818f1..822c9621072 100644 --- a/models/demos/llama3/tests/conftest.py +++ b/models/demos/llama3/tests/conftest.py @@ -8,3 +8,50 @@ @pytest.fixture(autouse=True) def ensure_gc(): gc.collect() + + +def traced(callable): + """ + Test it locally, get it into ttnn mainline + """ + # TODO: release trace on delete or ??? + trace_id = None + args_device = None + kwargs_device = None + outputs = None + + def create_device_inputs(*args, **kwargs): + # allocate device tensors for each arg which is on host + # don't copy + nonlocal args_device + nonlocal kwargs_device + + def copy_inputs_to_device(*args, **kwargs): + # copy any host tensors to device + + # Check that kwargs keys matches kwargs_device keys + # check that args len matches args_device len + nonlocal args_device + nonlocal kwargs_device + pass + + def wrapper(self, *args, **kwargs): + nonlocal trace_id + nonlocal outputs + if not trace_id: + create_device_inputs(args, kwargs) + copy_inputs_to_device(args, kwargs) + ret = callable(self, *args, **kwargs) + outputs = ret + trace_id = ttnn.capture_trace(...) + callable(self, *args, **kwargs) + ttnn.end_trace(...) 
+ return ret + # check that inputs, outputs are host tensors + # or if an input is on device, do nothing + # copy new inputs to inputs, return outputs + copy_inputs_to_device(args, kwargs) + ttnn.execute_trace(trace_id) + return outputs + + return wrapper diff --git a/models/demos/llama3/tt/multimodal/llama_vision_model.py b/models/demos/llama3/tt/multimodal/llama_vision_model.py index 274381d21a3..818c7f62d7b 100644 --- a/models/demos/llama3/tt/multimodal/llama_vision_model.py +++ b/models/demos/llama3/tt/multimodal/llama_vision_model.py @@ -282,14 +282,13 @@ def validate_inputs(self, tokens, position_ids): def prepare_inputs_common(self, position_ids, tokens): self.validate_inputs(tokens, position_ids) - h = self.text_model.get_partially_trainable_embedding(tokens[:, position_ids]) + h = self.text_model.get_partially_trainable_embedding(tokens) return h def prepare_inputs_prefill(self, tokens, cross_attention_masks, full_text_row_masked_out_mask, prefill_len): B = tokens.shape[0] assert B == 1, f"Only batch 1 is supported, got {B}" - # S = tokens.shape[1] # TODO: Get B, S from tokens when we don't pass full tokens around - S = prefill_len + S = tokens.shape[1] position_ids = torch.arange(S, dtype=torch.long) h = self.prepare_inputs_common(position_ids, tokens) padded_seq_len = _get_padded_prefill_seqlen(S) From 51e2c89ab45829b2f0637a881a3cb741422a8a29 Mon Sep 17 00:00:00 2001 From: Colman Glagovich Date: Thu, 7 Nov 2024 10:03:36 -0800 Subject: [PATCH 10/19] #14519: Fix rebase issues --- .../multimodal/test_llama_cross_attention.py | 20 +++++++++++++------ ..._llama_cross_attention_transformer_text.py | 12 ++++++++--- .../multimodal/test_llama_cross_block.py | 20 ++++++++++++------- models/demos/llama3/tt/model_config.py | 13 ++++++++++-- .../tt/multimodal/llama_cross_attention.py | 4 ++-- .../tt/multimodal/llama_vision_model.py | 2 +- 6 files changed, 50 insertions(+), 21 deletions(-) diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py b/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py index 73fc48f44b1..e4830421311 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py @@ -32,7 +32,13 @@ ], indirect=True, ) -@pytest.mark.parametrize("batch", (1, 2), ids=["batch_1", "batch_2"]) +@pytest.mark.parametrize( + "batch", + (1,), + ids=[ + "batch_1", + ], +) def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset_seeds, ensure_gc): dtype = ttnn.bfloat16 pcc_required = 0.99 @@ -103,7 +109,7 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset for b in range(batch): tt_tensor_xattn_tokens = model_args.prepare_inputs_ttnn_prefill( tt_xattn_tokens[b : b + 1], - force_replicate=True, + force_replicated=True, ) tt_xattn_cache = tt_model.compute_xattn_kv_cache(tt_tensor_xattn_tokens, tt_xattn_cache, user_id=b) tt_xattn_cache_torch = [ @@ -202,8 +208,8 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset user_id=b, ) - tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0)) - tt_output_torch = tt_output_torch[0, ..., :seq_len, :].view(1, seq_len, dim) + tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1)) + tt_output_torch = tt_output_torch[..., :seq_len, :].view(1, seq_len, dim) outputs.append(tt_output_torch) tt_output_torch = torch.cat(outputs, dim=0).view(batch, seq_len, dim) @@ 
-213,6 +219,8 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset ttnn.DRAM_MEMORY_CONFIG, force_replicated=True, ) + tt_x = ttnn.interleaved_to_sharded(tt_x, model_args.model_config["SHARDED_ATTN_INPUT_MEMCFG"]) + xattn_mask_expand = xattn_mask_expand.permute(2, 0, 1, 3).contiguous() tt_xattn_mask = ttnn.from_torch( xattn_mask_expand, @@ -255,8 +263,8 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset mode=mode, ) - tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0)) - tt_output_torch = tt_output_torch[0, :, :batch, :].reshape(batch, seq_len, dim) + tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)) + tt_output_torch = tt_output_torch[:, :, :batch, :].reshape(batch, seq_len, dim) passing, pcc_message = comp_pcc(pt_out, tt_output_torch, pcc_required) logger.info(comp_allclose(pt_out, tt_output_torch)) diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py index e67a36b19d5..84a04bfd372 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py @@ -39,7 +39,13 @@ ], indirect=True, ) -@pytest.mark.parametrize("batch", (1, 2), ids=["batch_1", "batch_2"]) +@pytest.mark.parametrize( + "batch", + (1,), + ids=[ + "batch_1", + ], +) @torch.no_grad() def test_llama_cross_attention_transformer_text_inference( text_seq_len, @@ -112,7 +118,7 @@ def test_llama_cross_attention_transformer_text_inference( for b in range(batch): tt_tensor_vision_tokens = model_args.prepare_inputs_ttnn_prefill( tt_vision_tokens[b : b + 1], - force_replicate=True, + force_replicated=True, ) tt_xattn_cache = [ @@ -233,7 +239,7 @@ def test_llama_cross_attention_transformer_text_inference( dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1), ) rot_mats = get_prefill_rot_mat( diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_block.py b/models/demos/llama3/tests/multimodal/test_llama_cross_block.py index e95bc5dc649..04112167154 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_block.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_block.py @@ -28,7 +28,13 @@ ], indirect=True, ) -@pytest.mark.parametrize("batch", (1, 2), ids=["batch_1", "batch_2"]) +@pytest.mark.parametrize( + "batch", + (1,), + ids=[ + "batch_1", + ], +) def test_llama_cross_attention_transformer_block_inference( text_seq_len, batch, mesh_device, use_program_cache, reset_seeds, ensure_gc ): @@ -97,7 +103,7 @@ def test_llama_cross_attention_transformer_block_inference( for b in range(batch): tt_tensor_xattn_tokens = model_args.prepare_inputs_ttnn_prefill( tt_xattn_tokens[b : b + 1], - force_replicate=True, + force_replicated=True, ) tt_xattn_cache = tt_model.compute_xattn_kv_cache(tt_tensor_xattn_tokens, tt_xattn_cache, user_id=b) tt_xattn_cache_torch = [ @@ -195,7 +201,7 @@ def test_llama_cross_attention_transformer_block_inference( dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1), ) tt_out = tt_model( tt_tensor_x, 
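
Note: the hunks above and below move these tests from replicated inputs with a dim-0 gather to sharding activations across the mesh on the hidden dimension (ShardTensorToMesh(dim=-1)) and gathering outputs the same way (ConcatMeshToTensor(dim=-1)). A minimal round-trip sketch of that pattern, assuming an already-opened mesh_device; the helper name and shape are illustrative, not from the patch:

import torch
import ttnn

def shard_roundtrip_example(mesh_device, seq_len=32, dim=4096):
    # Illustrative activation (1, 1, S, D); D is split evenly across the mesh devices.
    x = torch.randn(1, 1, seq_len, dim)
    tt_x = ttnn.from_torch(
        x,
        device=mesh_device,
        dtype=ttnn.bfloat16,
        layout=ttnn.TILE_LAYOUT,
        memory_config=ttnn.DRAM_MEMORY_CONFIG,
        mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1),  # each device holds dim // num_devices
    )
    # Gather the per-device shards back on host along the same dimension.
    x_back = ttnn.to_torch(tt_x, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1))
    assert x_back.shape == x.shape  # values are bfloat16-rounded, shape is preserved
    return x_back
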
@@ -207,8 +213,8 @@ def test_llama_cross_attention_transformer_block_inference( user_id=b, ) - tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0)) - tt_output_torch = tt_output_torch[0, ..., :seq_len, :].view(1, seq_len, dim) + tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1)) + tt_output_torch = tt_output_torch[..., :seq_len, :].view(1, seq_len, dim) outputs.append(tt_output_torch) tt_output_torch = torch.cat(outputs, dim=0).view(batch, seq_len, dim) @@ -263,8 +269,8 @@ def test_llama_cross_attention_transformer_block_inference( if mode == "decode": tt_full_text_mask_expand_11SD = None - tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0)) - tt_output_torch = tt_output_torch[0, :, :batch, :].reshape(batch, seq_len, dim) + tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1)) + tt_output_torch = tt_output_torch[:, :, :batch, :].reshape(batch, seq_len, dim) passing, pcc_message = comp_pcc(pt_out, tt_output_torch, pcc_required) logger.info(comp_allclose(pt_out, tt_output_torch)) diff --git a/models/demos/llama3/tt/model_config.py b/models/demos/llama3/tt/model_config.py index 21355b422a0..4c431643825 100644 --- a/models/demos/llama3/tt/model_config.py +++ b/models/demos/llama3/tt/model_config.py @@ -566,10 +566,19 @@ def find_largest_divisor(n, max_divisor=8): fuse_batch=seq_len <= max_seq, ) + xattn_cache_y_cores = ( + 16 // self.num_devices + ) # Based on seqlen, this formula gives us a valid number of y cores + xattn_cache_x_cores = 8 self.model_config["XATTN_KV_PREFILL_MEM_CFG"] = lambda seq_len: ttnn.create_sharded_memory_config( # using n_heads since xattn repeats KV to match Q - (((self.n_heads // self.num_devices) * seq_len // 64), self.head_dim), - ttnn.CoreGrid(y=8, x=8), + ( + nearest_32( + (self.n_heads // self.num_devices) * seq_len // (xattn_cache_y_cores * xattn_cache_x_cores) + ), + self.head_dim, + ), + ttnn.CoreGrid(y=xattn_cache_y_cores, x=xattn_cache_x_cores), ttnn.ShardStrategy.HEIGHT, ttnn.ShardOrientation.ROW_MAJOR, use_height_and_width_as_shard_shape=True, diff --git a/models/demos/llama3/tt/multimodal/llama_cross_attention.py b/models/demos/llama3/tt/multimodal/llama_cross_attention.py index 787f29248ae..87730dae903 100644 --- a/models/demos/llama3/tt/multimodal/llama_cross_attention.py +++ b/models/demos/llama3/tt/multimodal/llama_cross_attention.py @@ -133,7 +133,8 @@ def __init__( def compute_xattn_kv_cache(self, xattn_tokens, xattn_cache, user_id): # Always runs with batch=1 - B, seqlen_y = 1, xattn_tokens.shape[2] + B, seqlen_y = xattn_tokens.shape[1], xattn_tokens.shape[2] + assert B == 1, "Batch size must be 1" MAX_MM_SEQ_LEN = self.configuration.VISION_MAX_MM_SEQ if seqlen_y > MAX_MM_SEQ_LEN: xattn_tokens = ttnn.reshape(xattn_tokens, [1, B * seqlen_y // MAX_MM_SEQ_LEN, MAX_MM_SEQ_LEN, -1]) @@ -146,7 +147,6 @@ def compute_xattn_kv_cache(self, xattn_tokens, xattn_cache, user_id): compute_kernel_config=self.compute_kernel_config_hifi4, program_config=self.model_config["VISION_XATTN_KV_PROGCFG"](seqlen_y, MAX_MM_SEQ_LEN), ) - xv = ttnn.linear( xattn_tokens, self.wv, diff --git a/models/demos/llama3/tt/multimodal/llama_vision_model.py b/models/demos/llama3/tt/multimodal/llama_vision_model.py index 818c7f62d7b..94f65ea13af 100644 --- a/models/demos/llama3/tt/multimodal/llama_vision_model.py +++ b/models/demos/llama3/tt/multimodal/llama_vision_model.py @@ -362,7 +362,7 @@ def 
prepare_inputs_prefill(self, tokens, cross_attention_masks, full_text_row_ma dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), ) return ( From 0b7807f130d8037bf8dd087342a21641334b2ec5 Mon Sep 17 00:00:00 2001 From: Colman Glagovich Date: Thu, 7 Nov 2024 13:04:23 -0800 Subject: [PATCH 11/19] #14519: Refactored decode input preparation to separate host tensor creation and device tensor transformations. Enabled tracing in simple_vision_demo with an easy trace function --- .../demos/llama3/demo/simple_vision_demo.py | 205 +++++++++++++++++- models/demos/llama3/tt/model_config.py | 6 +- .../tt/multimodal/llama_vision_model.py | 165 +++++++++----- 3 files changed, 318 insertions(+), 58 deletions(-) diff --git a/models/demos/llama3/demo/simple_vision_demo.py b/models/demos/llama3/demo/simple_vision_demo.py index c8718256e62..a5bf099d027 100644 --- a/models/demos/llama3/demo/simple_vision_demo.py +++ b/models/demos/llama3/demo/simple_vision_demo.py @@ -119,7 +119,7 @@ def decode_forward( _, tt_position_id, rot_mats, - transformation_mats, + _, ) = self.model.prepare_inputs_decode( tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id=position_id ) @@ -131,7 +131,6 @@ def decode_forward( xattn_caches, tt_position_id, rot_mats, - transformation_mats, ) logits = self.model.process_output_decode(tt_logits, B, S) @@ -148,7 +147,171 @@ def capture_trace( """ Captures a trace for the decode_forward method. """ - pass + ( + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + _, + tt_position_id, + rot_mats, + _, + ) = self.model.prepare_inputs_decode( + tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id=position_id + ) + + # Compile run + tt_logits_rm = self.model.ttnn_decode_forward( + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + xattn_caches, + tt_position_id, + rot_mats, + ) + + # Get inputs ready for trace run + ( + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + _, + tt_position_id, + rot_mats, + _, + ) = self.model.prepare_decode_inputs_host( + tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id + ) + + ( + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + tt_position_id, + rot_mats, + ) = self.model.copy_host_to_device( + (tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, rot_mats) + ) + + trace_id = ttnn.begin_trace_capture(self.mesh_device, cq_id=0) + B = tokens.shape[0] + # Do on-device transformations of inputs before forward + tt_xattn_mask_transform, tt_full_text_mask_expand_1NSH_transform = self.model.transform_decode_inputs_device( + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + B=B, + ) + + tt_logits_rm = self.model.ttnn_decode_forward( + tt_h, + tt_xattn_mask_transform, + tt_full_text_mask_expand_1NSH_transform, + xattn_caches, + tt_position_id, + rot_mats, + ) + + ttnn.end_trace_capture(self.mesh_device, trace_id, cq_id=0) + + return trace_id, tt_logits_rm, tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, rot_mats + + def decode_forward_trace( + self, + position_id, + tokens, + cross_attention_masks, + full_text_row_masked_out_mask, + xattn_caches, # TODO: unused since captured in trace? 
+ trace_id, + trace_logits_rm, + trace_h, + trace_xattn_mask, + trace_full_text_mask_expand_1NSH, + trace_position_id, + trace_rot_mats, + ): + ( + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + _, + tt_position_id, + rot_mats, + _, + ) = self.model.prepare_decode_inputs_host( + tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id=position_id + ) + + self.model.copy_host_to_device( + host_tensors=(tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, rot_mats), + device_tensors=( + trace_h, + trace_xattn_mask, + trace_full_text_mask_expand_1NSH, + trace_position_id, + trace_rot_mats, + ), + ) + + ttnn.execute_trace(self.mesh_device, trace_id, cq_id=0, blocking=False) + + B, S = tokens.shape + logits = self.model.process_output_decode(trace_logits_rm, B=B, S=S) + + return logits + + def easy_trace( + self, + position_id, + tokens, + cross_attention_masks, + full_text_row_masked_out_mask, + xattn_caches, + ): + """ + Tracing is easy! Just call this method and you'll run traced + """ + if not hasattr(self, "trace_id"): + ( + trace_id, + tt_logits_rm, + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + tt_position_id, + rot_mats, + ) = self.capture_trace( + position_id, + tokens, + cross_attention_masks, + full_text_row_masked_out_mask, + xattn_caches, + ) + self.trace_id = trace_id + self.trace_inputs = { + "tt_h": tt_h, + "tt_xattn_mask": tt_xattn_mask, + "tt_full_text_mask_expand_1NSH": tt_full_text_mask_expand_1NSH, + "tt_position_id": tt_position_id, + "rot_mats": rot_mats, + } + self.trace_outputs = { + "tt_logits_rm": tt_logits_rm, + } + + return self.decode_forward_trace( + position_id, + tokens, + cross_attention_masks, + full_text_row_masked_out_mask, + xattn_caches, + self.trace_id, + self.trace_outputs["tt_logits_rm"], + self.trace_inputs["tt_h"], + self.trace_inputs["tt_xattn_mask"], + self.trace_inputs["tt_full_text_mask_expand_1NSH"], + self.trace_inputs["tt_position_id"], + self.trace_inputs["rot_mats"], + ) def get_sampler(temperature, top_p, tokenizer): @@ -207,6 +370,7 @@ def create_multimodal_model(mesh_device, max_batch_size, max_seq_len, dtype=ttnn "normal", ], ) +@pytest.mark.parametrize("device_params", [{"trace_region_size": 14951424, "num_command_queues": 2}], indirect=True) def test_llama_multimodal_demo_text( mesh_device, warmup_iters, @@ -300,17 +464,48 @@ def test_llama_multimodal_demo_text( decode_times = [] + # Capture trace + # next_token_tensor = torch.tensor([next_token], dtype=torch.long).reshape(1, 1) # B, S + # trace_id, tt_logits_rm, tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, rot_mats = model.capture_trace( + # prefill_len, + # next_token_tensor, + # cross_attention_masks, + # full_text_row_masked_out_mask, + # xattn_caches, + # ) + for gen_idx in range(max_gen_len - 1): decode_start = time.perf_counter() position_id = prefill_len + gen_idx next_token_tensor = torch.tensor([next_token], dtype=torch.long).reshape(1, 1) # B, S - logits = model.decode_forward( + # logits = model.decode_forward( + # position_id, + # next_token_tensor, + # cross_attention_masks, + # full_text_row_masked_out_mask, + # xattn_caches, + # ) + logits = model.easy_trace( position_id, next_token_tensor, cross_attention_masks, full_text_row_masked_out_mask, xattn_caches, ) + # logits = model.decode_forward_trace( + # position_id, + # next_token_tensor, + # cross_attention_masks, + # full_text_row_masked_out_mask, + # xattn_caches, + # trace_id, + # tt_logits_rm, + # tt_h, + # tt_xattn_mask, + # 
tt_full_text_mask_expand_1NSH, + # tt_position_id, + # rot_mats + # ) next_token, text = sampler(logits) # Update next token tokens[0, position_id + 1] = next_token @@ -334,3 +529,5 @@ def test_llama_multimodal_demo_text( logger.info(f"Prefill time: {prefill_time_ms:.2f} ms") decode_time_ms = sum(decode_times) / (gen_idx + 1) * 1000 logger.info(f"Decode time: {decode_time_ms:.2f} ms") + + # ttnn.release_trace(model.mesh_device, trace_id) diff --git a/models/demos/llama3/tt/model_config.py b/models/demos/llama3/tt/model_config.py index 4c431643825..d8e8bf7c4fe 100644 --- a/models/demos/llama3/tt/model_config.py +++ b/models/demos/llama3/tt/model_config.py @@ -625,7 +625,7 @@ def ccl_topology(self): return ttnn.Topology.Linear return None - def prepare_inputs_ttnn_decode(self, x, input_mem_cfg, force_replicated=False): + def prepare_inputs_ttnn_decode(self, x, input_mem_cfg, force_replicated=False, on_host=False): """ Prepare inputs for decode mode. x: (batch, seq, dim) @@ -665,11 +665,11 @@ def prepare_inputs_ttnn_decode(self, x, input_mem_cfg, force_replicated=False): if torch.is_tensor(x): x = ttnn.from_torch( x, - device=self.mesh_device, + device=self.mesh_device if not on_host else None, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, mesh_mapper=mesh_mapper, - memory_config=input_mem_cfg, + memory_config=input_mem_cfg if not on_host else None, ) else: # Convert the row major layout from embedding back to tile layout x = ttnn.to_layout(x, layout=ttnn.TILE_LAYOUT) diff --git a/models/demos/llama3/tt/multimodal/llama_vision_model.py b/models/demos/llama3/tt/multimodal/llama_vision_model.py index 94f65ea13af..80a27df0679 100644 --- a/models/demos/llama3/tt/multimodal/llama_vision_model.py +++ b/models/demos/llama3/tt/multimodal/llama_vision_model.py @@ -376,20 +376,58 @@ def prepare_inputs_prefill(self, tokens, cross_attention_masks, full_text_row_ma ) def prepare_inputs_decode(self, tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id): + ( + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + _tt_full_text_mask_expand_11SD, + tt_position_id, + rot_mats, + _transformation_mats, + ) = self.prepare_decode_inputs_host(tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id) + + ( + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + tt_position_id, + rot_mats, + ) = self.copy_host_to_device((tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, rot_mats)) + + tt_xattn_mask, tt_full_text_mask_expand_1NSH = self.transform_decode_inputs_device( + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + B=tokens.shape[0], + ) + + return ( + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + _tt_full_text_mask_expand_11SD, + tt_position_id, + rot_mats, + _transformation_mats, + ) + + def prepare_decode_inputs_host(self, tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id): B = tokens.shape[0] assert ( B == self.configuration.max_batch_size ), f"Batch size must match max batch size. 
Got {B}, expected {self.configuration.max_batch_size}" - S = 1 position_ids = torch.tensor([position_id], dtype=torch.long) h = self.prepare_inputs_common(position_ids, tokens) + tt_h = self.configuration.prepare_inputs_ttnn_decode( + h, + ttnn.DRAM_MEMORY_CONFIG, + on_host=True, + ) tt_position_id = ttnn.from_torch( position_ids, - device=self.mesh_device, + device=None, dtype=ttnn.int32, layout=ttnn.ROW_MAJOR_LAYOUT, - memory_config=ttnn.DRAM_MEMORY_CONFIG, mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), ) @@ -399,25 +437,11 @@ def prepare_inputs_decode(self, tokens, cross_attention_masks, full_text_row_mas tt_xattn_mask = ttnn.from_torch( xattn_mask_expand, - device=self.mesh_device, + device=None, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, - memory_config=ttnn.DRAM_MEMORY_CONFIG, mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), ) - tt_xattn_mask = ttnn.to_layout(tt_xattn_mask, ttnn.TILE_LAYOUT) - tt_xattn_mask = ttnn.reshape( - tt_xattn_mask, - shape=ttnn.Shape( - [ - S, - B, - self.configuration.n_heads // self.configuration.num_devices, - xattn_mask.shape[-1], - ], - [S, B, 32, xattn_mask.shape[-1]], - ), - ) full_text_mask = full_text_row_masked_out_mask[:, :, position_ids] full_text_mask_expand_1NSH = full_text_mask.expand( @@ -426,35 +450,12 @@ def prepare_inputs_decode(self, tokens, cross_attention_masks, full_text_row_mas full_text_mask_expand_1NSH = full_text_mask_expand_1NSH.transpose(1, 2).contiguous() tt_full_text_mask_expand_1NSH = ttnn.from_torch( full_text_mask_expand_1NSH, - device=self.mesh_device, + device=None, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, - memory_config=ttnn.DRAM_MEMORY_CONFIG, mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), ) - tt_full_text_mask_expand_1NSH = ttnn.to_layout(tt_full_text_mask_expand_1NSH, ttnn.TILE_LAYOUT) - tt_full_text_mask_expand_1NSH = ttnn.reshape( - tt_full_text_mask_expand_1NSH, - shape=ttnn.Shape( - [ - S, - B, - self.configuration.n_heads // self.configuration.num_devices, - self.configuration.head_dim, - ], - [ - S, - B, - 32, - self.configuration.head_dim, - ], - ), - ) - tt_h = self.configuration.prepare_inputs_ttnn_decode( - h, - ttnn.DRAM_MEMORY_CONFIG, - ) rot_mats, _ = get_single_rot_mat( self.configuration.head_dim, self.mesh_device, @@ -462,7 +463,9 @@ def prepare_inputs_decode(self, tokens, cross_attention_masks, full_text_row_mas start_pos=position_ids.item() - 1, # TODO: Change function to support decode batch > 1 # TODO: B must match max_batch_size, be careful batch=B, + on_host=True, ) + transformation_mats = None tt_full_text_mask_expand_11SD = None @@ -476,15 +479,72 @@ def prepare_inputs_decode(self, tokens, cross_attention_masks, full_text_row_mas transformation_mats, ) - def process_output_prefill(self, logits, B, S): + def copy_host_to_device(self, host_tensors, device_tensors=None): + """ + Helper function which copies host tensors to device tensors + """ + if device_tensors is None: + ret = [] + for i in range(len(host_tensors)): + on_device = ttnn.to_device(host_tensors[i], device=self.mesh_device) + ret.append(on_device) + return ret + else: + for i in range(len(host_tensors)): + ttnn.copy_host_to_device_tensor(host_tensors[i], device_tensors[i]) + return device_tensors + + def transform_decode_inputs_device(self, tt_xattn_mask, tt_full_text_mask_expand_1NSH, B): + """ + Does any transformations on device tensors which are necessary before ttnn_decode_forward + """ + print("transforming xattn mask") + assert ( + B == self.configuration.max_batch_size + ), 
f"Batch size must match max batch size. Got {B}, expected {self.configuration.max_batch_size}" + S = 1 + + tt_xattn_mask = ttnn.to_layout(tt_xattn_mask, ttnn.TILE_LAYOUT) + tt_xattn_mask = ttnn.reshape( + tt_xattn_mask, + shape=ttnn.Shape( + [ + S, + B, + self.configuration.n_heads // self.configuration.num_devices, + tt_xattn_mask.shape[-1], + ], + [S, B, 32, tt_xattn_mask.shape[-1]], + ), + ) + tt_full_text_mask_expand_1NSH = ttnn.to_layout(tt_full_text_mask_expand_1NSH, ttnn.TILE_LAYOUT) + tt_full_text_mask_expand_1NSH = ttnn.reshape( + tt_full_text_mask_expand_1NSH, + shape=ttnn.Shape( + [ + S, + B, + self.configuration.n_heads // self.configuration.num_devices, + self.configuration.head_dim, + ], + [ + S, + B, + 32, + self.configuration.head_dim, + ], + ), + ) + + return (tt_xattn_mask, tt_full_text_mask_expand_1NSH) + + def process_output_prefill(self, tt_out, B, S): padded_seq_len = _get_padded_prefill_seqlen(S) - tt_out = ttnn.to_layout(logits, ttnn.ROW_MAJOR_LAYOUT) tt_out = ttnn.to_torch(ttnn.get_device_tensors(tt_out)[0]).float() tt_out = tt_out[0].reshape(B, padded_seq_len, -1)[:, :S, :] return tt_out - def process_output_decode(self, logits, B, S): - tt_out = ttnn.to_layout(logits, ttnn.ROW_MAJOR_LAYOUT) + def process_output_decode(self, tt_out, B, S): tt_out = ttnn.to_torch(ttnn.get_device_tensors(tt_out)[0]).float() tt_out = tt_out[:, :, :B, :].reshape(B, S, -1) return tt_out @@ -538,9 +598,10 @@ def forward( mode=mode, text_only_inference=text_only_inference, ) + tt_out = ttnn.to_layout(logits, ttnn.ROW_MAJOR_LAYOUT) output_fn = self.process_output_decode if mode == "decode" else self.process_output_prefill - return output_fn(logits, B, S) + return output_fn(tt_out, B, S) def ttnn_prefill_forward( self, @@ -557,7 +618,7 @@ def ttnn_prefill_forward( """ This method runs prefill forward. It takes ttnn tensors in, returns ttnn tensors. """ - return self.text_model.forward( + logits = self.text_model.forward( h, xattn_mask=xattn_mask, full_text_row_masked_out_mask_1NSH=full_text_mas_expand_1NSH, @@ -569,6 +630,8 @@ def ttnn_prefill_forward( user_id=user_id, mode="prefill", ) + tt_out = ttnn.to_layout(logits, ttnn.ROW_MAJOR_LAYOUT) + return tt_out def ttnn_decode_forward( self, @@ -578,12 +641,11 @@ def ttnn_decode_forward( xattn_caches, position_id, rot_mats, - transformation_mats, ): """ This method runs decode forward. It takes ttnn tensors in, returns ttnn tensors. """ - return self.text_model.forward( + logits = self.text_model.forward( h, xattn_mask=xattn_mask, full_text_row_masked_out_mask_1NSH=full_text_mas_expand_1NSH, @@ -591,9 +653,10 @@ def ttnn_decode_forward( xattn_caches=xattn_caches, current_pos=position_id, rot_mat=rot_mats, - transformation_mats=transformation_mats, mode="decode", ) + tt_out = ttnn.to_layout(logits, ttnn.ROW_MAJOR_LAYOUT) + return tt_out def _stack_images( From 71f0727dc04bab237162377b7853ddbb3ba22823 Mon Sep 17 00:00:00 2001 From: Colman Glagovich Date: Fri, 8 Nov 2024 05:52:22 -0800 Subject: [PATCH 12/19] #14519: Implement LlamaVision generation class which plugs into existing generation pipelines. 
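
In outline, the new generator class is wired into the demos below roughly as follows. This is a condensed sketch of the demo setup only: the build_generator helper name is illustrative, and the pytest parametrization and mesh-device plumbing are elided.

import os
from pathlib import Path

from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.llama3.api.chat_format import ChatFormat

from models.demos.llama3.demo.simple_vision_demo import create_multimodal_model
from models.demos.llama3.tt.multimodal.vision_generator import LlamaVision

def build_generator(mesh_device, max_batch_size=1, max_seq_len=512):
    # mesh_device is assumed to be an already-opened ttnn mesh device.
    ckpt_dir = os.environ["LLAMA_DIR"]
    tokenizer = Tokenizer(model_path=str(Path(ckpt_dir) / "tokenizer.model"))
    formatter = ChatFormat(tokenizer)
    model_args, model = create_multimodal_model(
        mesh_device, max_batch_size=max_batch_size, max_seq_len=max_seq_len
    )
    return LlamaVision(model, model_args, mesh_device, tokenizer=tokenizer, formatter=formatter)

# Then existing generation pipelines call, e.g.:
#   result = generator.chat_completion(dialog, max_gen_len=200, temperature=0.5, top_p=0.9)
#   result = generator.text_completion(content, max_gen_len=200)
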
--- .../demos/llama3/demo/multimodal_demo_chat.py | 82 ++-- .../demos/llama3/demo/multimodal_demo_text.py | 64 +-- .../demos/llama3/demo/simple_vision_demo.py | 364 ++------------- .../tt/multimodal/llama_vision_model.py | 1 - .../llama3/tt/multimodal/vision_generator.py | 429 ++++++++++++++++++ 5 files changed, 522 insertions(+), 418 deletions(-) create mode 100644 models/demos/llama3/tt/multimodal/vision_generator.py diff --git a/models/demos/llama3/demo/multimodal_demo_chat.py b/models/demos/llama3/demo/multimodal_demo_chat.py index 7b39fb3db61..ca3d5b498e3 100644 --- a/models/demos/llama3/demo/multimodal_demo_chat.py +++ b/models/demos/llama3/demo/multimodal_demo_chat.py @@ -8,19 +8,21 @@ from PIL import Image as PIL_Image from termcolor import cprint -from models.demos.llama3.demo.multimodal_demo_text import create_multimodal_model -import llama_models.llama3.reference_impl.generation as llama_reference_generation +import pytest +import os +import ttnn +import llama_models.llama3.reference_impl.generation as llama_reference_generation +from llama_models.llama3.api.tokenizer import Tokenizer +from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.datatypes import ImageMedia, UserMessage from pkg_resources import resource_filename IMG_PATH = Path(resource_filename("llama_models", "scripts/resources/")) -import torch -import pytest -import os -import ttnn +from models.demos.llama3.tt.multimodal.vision_generator import LlamaVision +from models.demos.llama3.demo.simple_vision_demo import create_multimodal_model @pytest.mark.parametrize( @@ -36,39 +38,36 @@ "target", ("tt", "cpu"), ) -@pytest.mark.parametrize( - "warmup_iters", - (0, 1), -) def test_llama_multimodal_demo_chat( mesh_device, target, - warmup_iters, temperature: float = 0.5, top_p: float = 0.9, max_seq_len: int = 512, - max_batch_size: int = 4, + max_batch_size: int = 1, max_gen_len: Optional[int] = 200, model_parallel_size: Optional[int] = None, ): - mesh_device.enable_program_cache() - mesh_device.enable_async(True) ckpt_dir = os.environ["LLAMA_DIR"] tokenizer_path = str(Path(ckpt_dir) / "tokenizer.model") logger.info(f"Creating reference model from checkpoint in '{ckpt_dir}'") - generator = llama_reference_generation.Llama.build( - ckpt_dir, - tokenizer_path=tokenizer_path, - max_seq_len=max_seq_len, - max_batch_size=max_batch_size, - model_parallel_size=model_parallel_size, - ) - - if target == "tt": + if target == "cpu": + generator = llama_reference_generation.Llama.build( + ckpt_dir, + tokenizer_path=tokenizer_path, + max_seq_len=max_seq_len, + max_batch_size=max_batch_size, + model_parallel_size=model_parallel_size, + ) + else: logger.info(f"Creating TT model on {len(mesh_device.get_devices())} devices") - model = create_multimodal_model(generator.args, mesh_device) - generator.model = model + mesh_device.enable_program_cache() + mesh_device.enable_async(True) + model_args, model = create_multimodal_model(mesh_device, max_batch_size=max_batch_size, max_seq_len=max_seq_len) + tokenizer = Tokenizer(model_path=tokenizer_path) + formatter = ChatFormat(tokenizer) + generator = LlamaVision(model, model_args, mesh_device, tokenizer=tokenizer, formatter=formatter) # image understanding dialogs = [] @@ -85,26 +84,21 @@ def test_llama_multimodal_demo_chat( ) ], ] - # text only - dialogs += [ - [UserMessage(content="what is the recipe of mayonnaise in two sentences?")], - ] print(f"Running text completion on {target}") - for _ in range(warmup_iters + 1): - for dialog in dialogs: - result = 
generator.chat_completion( - dialog, - max_gen_len=max_gen_len, - temperature=temperature, - top_p=top_p, - ) + for dialog in dialogs: + result = generator.chat_completion( + dialog, + max_gen_len=max_gen_len, + temperature=temperature, + top_p=top_p, + ) - for msg in dialog: - print(f"{msg.role.capitalize()}: {msg.content}\n") + for msg in dialog: + print(f"{msg.role.capitalize()}: {msg.content}\n") - out_message = result.generation - print(f"> {out_message.role.capitalize()}: {out_message.content}") - for t in out_message.tool_calls: - print(f" Tool call: {t.tool_name} ({t.arguments})") - print("\n==================================\n") + out_message = result.generation + print(f"> {out_message.role.capitalize()}: {out_message.content}") + for t in out_message.tool_calls: + print(f" Tool call: {t.tool_name} ({t.arguments})") + print("\n==================================\n") diff --git a/models/demos/llama3/demo/multimodal_demo_text.py b/models/demos/llama3/demo/multimodal_demo_text.py index 102b03975e4..2029c43458b 100644 --- a/models/demos/llama3/demo/multimodal_demo_text.py +++ b/models/demos/llama3/demo/multimodal_demo_text.py @@ -8,36 +8,22 @@ from PIL import Image as PIL_Image from termcolor import cprint -import llama_models.llama3.reference_impl.generation as llama_reference_generation +import pytest +import os +import ttnn +import llama_models.llama3.reference_impl.generation as llama_reference_generation from llama_models.llama3.api.datatypes import ImageMedia +from llama_models.llama3.api.tokenizer import Tokenizer +from llama_models.llama3.api.chat_format import ChatFormat + from pkg_resources import resource_filename IMG_PATH = Path(resource_filename("llama_models", "scripts/resources/")) -import torch -import pytest -import os -import ttnn - - -def create_multimodal_model(model_args, mesh_device, dtype=ttnn.bfloat16): - from models.demos.llama3.tt.multimodal.llama_vision_model import CrossAttentionTransformer - from models.demos.llama3.tt.model_config import TtModelArgs - - tt_model_args = TtModelArgs(mesh_device) - checkpoint = torch.load(tt_model_args.consolidated_weights_path, map_location="cpu", weights_only=True) - model = CrossAttentionTransformer( - model_args, - mesh_device, - checkpoint, - weight_cache_path=tt_model_args.weight_cache_path(dtype), - dtype=dtype, - configuration=tt_model_args, - ) - model.setup_cache(model_args.max_batch_size, torch.float32) - return model +from models.demos.llama3.demo.simple_vision_demo import create_multimodal_model +from models.demos.llama3.tt.multimodal.vision_generator import LlamaVision @pytest.mark.parametrize( @@ -64,28 +50,30 @@ def test_llama_multimodal_demo_text( temperature: float = 0.5, top_p: float = 0.9, max_seq_len: int = 512, - max_batch_size: int = 4, + max_batch_size: int = 1, max_gen_len: Optional[int] = 200, model_parallel_size: Optional[int] = None, ): - mesh_device.enable_program_cache() - mesh_device.enable_async(True) ckpt_dir = os.environ["LLAMA_DIR"] tokenizer_path = str(Path(ckpt_dir) / "tokenizer.model") logger.info(f"Creating reference model from checkpoint in '{ckpt_dir}'") - generator = llama_reference_generation.Llama.build( - ckpt_dir, - tokenizer_path=tokenizer_path, - max_seq_len=max_seq_len, - max_batch_size=max_batch_size, - model_parallel_size=model_parallel_size, - ) - - if target == "tt": + if target == "cpu": + generator = llama_reference_generation.Llama.build( + ckpt_dir, + tokenizer_path=tokenizer_path, + max_seq_len=max_seq_len, + max_batch_size=max_batch_size, + 
model_parallel_size=model_parallel_size, + ) + else: logger.info(f"Creating TT model on {len(mesh_device.get_devices())} devices") - model = create_multimodal_model(generator.args, mesh_device) - generator.model = model + mesh_device.enable_program_cache() + mesh_device.enable_async(True) + model_args, model = create_multimodal_model(mesh_device, max_batch_size=max_batch_size, max_seq_len=max_seq_len) + tokenizer = Tokenizer(model_path=tokenizer_path) + formatter = ChatFormat(tokenizer) + generator = LlamaVision(model, model_args, mesh_device, tokenizer=tokenizer, formatter=formatter) with open(IMG_PATH / "dog.jpg", "rb") as f: img = PIL_Image.open(f).convert("RGB") @@ -100,8 +88,6 @@ def test_llama_multimodal_demo_text( clutter = PIL_Image.open(f).convert("RGB") interleaved_contents = [ - # text only - "The color of the sky is blue but sometimes it can also be", # image understanding [ImageMedia(image=img), "If I had to write a haiku for this one"], [ImageMedia(image=img2), "Couting the number of individual spaghetti strands in this image"], diff --git a/models/demos/llama3/demo/simple_vision_demo.py b/models/demos/llama3/demo/simple_vision_demo.py index a5bf099d027..964554280ee 100644 --- a/models/demos/llama3/demo/simple_vision_demo.py +++ b/models/demos/llama3/demo/simple_vision_demo.py @@ -10,8 +10,7 @@ import llama_models.llama3.reference_impl.generation as llama_reference_generation from llama_models.llama3.api.tokenizer import Tokenizer -from llama_models.llama3.api.chat_format import ChatFormat, ModelInput - +from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.datatypes import ImageMedia, UserMessage from pkg_resources import resource_filename @@ -24,294 +23,7 @@ import ttnn import time - -class LlamaVision: - def __init__(self, model, model_args, mesh_device, vllm=False): - """ - Creating a LlamaVision wrapper requires only a mesh_device and model_args. - With model_args you have the checkpoint location, can specify max batch size - and max seqlen, and other model specific parameters. - - LlamaVision is general to text and chat. - - For bringup, make this class general to any backend implementation, as long as it takes torch tensors and returns torch tensors. - - """ - self.model = model - self.model_args = model_args - self.mesh_device = mesh_device - self.vllm = vllm - - def prefill_forward_single_user( - self, - vision_images, - vision_mask, - tokens, - xattn_caches, - user_id, - total_len, - prefill_len, - ): - """ - Performs vision encode step then text prefill. 
- Returns (xattn_caches, cross_attention_masks, full_text_row_masked_out_mask, logits) - """ - B = tokens.shape[0] - xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks( - batch_images=[vision_images], - batch_masks=[vision_mask], - total_len=total_len, - xattn_caches=xattn_caches, - user_id=user_id, - ) - - ( - tt_h, - tt_xattn_mask, - tt_full_text_mask_expand_1NSH, - tt_full_text_mask_expand_11SD, - tt_position_id, - rot_mats, - transformation_mats, - ) = self.model.prepare_inputs_prefill( - tokens, cross_attention_masks, full_text_row_masked_out_mask, prefill_len=prefill_len - ) - - tt_logits = self.model.ttnn_prefill_forward( - tt_h, - tt_xattn_mask, - tt_full_text_mask_expand_1NSH, - tt_full_text_mask_expand_11SD, - xattn_caches, - tt_position_id, - rot_mats, - transformation_mats, - user_id, - ) - - logits = self.model.process_output_prefill(tt_logits, B, prefill_len) - - return xattn_caches, cross_attention_masks, full_text_row_masked_out_mask, logits - - def decode_forward( - self, - position_id, - tokens, - cross_attention_masks, - full_text_row_masked_out_mask, - xattn_caches, - ): - """ - Performs text decode step. - Returns logits - """ - - # forward_decode should be traced callable - # decorator does compilation, capture, execute - # B = 1 # TODO: Only supports batch=1 right now! Might make tokens input a tensor. - # S = 1 - B, S = tokens.shape - - ( - tt_h, - tt_xattn_mask, - tt_full_text_mask_expand_1NSH, - _, - tt_position_id, - rot_mats, - _, - ) = self.model.prepare_inputs_decode( - tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id=position_id - ) - - tt_logits = self.model.ttnn_decode_forward( - tt_h, - tt_xattn_mask, - tt_full_text_mask_expand_1NSH, - xattn_caches, - tt_position_id, - rot_mats, - ) - - logits = self.model.process_output_decode(tt_logits, B, S) - return logits - - def capture_trace( - self, - position_id, - tokens, - cross_attention_masks, - full_text_row_masked_out_mask, - xattn_caches, - ): - """ - Captures a trace for the decode_forward method. 
- """ - ( - tt_h, - tt_xattn_mask, - tt_full_text_mask_expand_1NSH, - _, - tt_position_id, - rot_mats, - _, - ) = self.model.prepare_inputs_decode( - tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id=position_id - ) - - # Compile run - tt_logits_rm = self.model.ttnn_decode_forward( - tt_h, - tt_xattn_mask, - tt_full_text_mask_expand_1NSH, - xattn_caches, - tt_position_id, - rot_mats, - ) - - # Get inputs ready for trace run - ( - tt_h, - tt_xattn_mask, - tt_full_text_mask_expand_1NSH, - _, - tt_position_id, - rot_mats, - _, - ) = self.model.prepare_decode_inputs_host( - tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id - ) - - ( - tt_h, - tt_xattn_mask, - tt_full_text_mask_expand_1NSH, - tt_position_id, - rot_mats, - ) = self.model.copy_host_to_device( - (tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, rot_mats) - ) - - trace_id = ttnn.begin_trace_capture(self.mesh_device, cq_id=0) - B = tokens.shape[0] - # Do on-device transformations of inputs before forward - tt_xattn_mask_transform, tt_full_text_mask_expand_1NSH_transform = self.model.transform_decode_inputs_device( - tt_xattn_mask, - tt_full_text_mask_expand_1NSH, - B=B, - ) - - tt_logits_rm = self.model.ttnn_decode_forward( - tt_h, - tt_xattn_mask_transform, - tt_full_text_mask_expand_1NSH_transform, - xattn_caches, - tt_position_id, - rot_mats, - ) - - ttnn.end_trace_capture(self.mesh_device, trace_id, cq_id=0) - - return trace_id, tt_logits_rm, tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, rot_mats - - def decode_forward_trace( - self, - position_id, - tokens, - cross_attention_masks, - full_text_row_masked_out_mask, - xattn_caches, # TODO: unused since captured in trace? - trace_id, - trace_logits_rm, - trace_h, - trace_xattn_mask, - trace_full_text_mask_expand_1NSH, - trace_position_id, - trace_rot_mats, - ): - ( - tt_h, - tt_xattn_mask, - tt_full_text_mask_expand_1NSH, - _, - tt_position_id, - rot_mats, - _, - ) = self.model.prepare_decode_inputs_host( - tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id=position_id - ) - - self.model.copy_host_to_device( - host_tensors=(tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, rot_mats), - device_tensors=( - trace_h, - trace_xattn_mask, - trace_full_text_mask_expand_1NSH, - trace_position_id, - trace_rot_mats, - ), - ) - - ttnn.execute_trace(self.mesh_device, trace_id, cq_id=0, blocking=False) - - B, S = tokens.shape - logits = self.model.process_output_decode(trace_logits_rm, B=B, S=S) - - return logits - - def easy_trace( - self, - position_id, - tokens, - cross_attention_masks, - full_text_row_masked_out_mask, - xattn_caches, - ): - """ - Tracing is easy! 
Just call this method and you'll run traced - """ - if not hasattr(self, "trace_id"): - ( - trace_id, - tt_logits_rm, - tt_h, - tt_xattn_mask, - tt_full_text_mask_expand_1NSH, - tt_position_id, - rot_mats, - ) = self.capture_trace( - position_id, - tokens, - cross_attention_masks, - full_text_row_masked_out_mask, - xattn_caches, - ) - self.trace_id = trace_id - self.trace_inputs = { - "tt_h": tt_h, - "tt_xattn_mask": tt_xattn_mask, - "tt_full_text_mask_expand_1NSH": tt_full_text_mask_expand_1NSH, - "tt_position_id": tt_position_id, - "rot_mats": rot_mats, - } - self.trace_outputs = { - "tt_logits_rm": tt_logits_rm, - } - - return self.decode_forward_trace( - position_id, - tokens, - cross_attention_masks, - full_text_row_masked_out_mask, - xattn_caches, - self.trace_id, - self.trace_outputs["tt_logits_rm"], - self.trace_inputs["tt_h"], - self.trace_inputs["tt_xattn_mask"], - self.trace_inputs["tt_full_text_mask_expand_1NSH"], - self.trace_inputs["tt_position_id"], - self.trace_inputs["rot_mats"], - ) +from models.demos.llama3.tt.multimodal.vision_generator import LlamaVision def get_sampler(temperature, top_p, tokenizer): @@ -370,11 +82,17 @@ def create_multimodal_model(mesh_device, max_batch_size, max_seq_len, dtype=ttnn "normal", ], ) +@pytest.mark.parametrize( + "enable_trace", + (False, True), + ids=["no_trace", "trace"], +) @pytest.mark.parametrize("device_params", [{"trace_region_size": 14951424, "num_command_queues": 2}], indirect=True) def test_llama_multimodal_demo_text( mesh_device, warmup_iters, test_case, + enable_trace, temperature: float = 0, top_p: float = 0.9, max_seq_len: int = 512, @@ -391,11 +109,11 @@ def test_llama_multimodal_demo_text( mesh_device.enable_program_cache() mesh_device.enable_async(True) model_args, model = create_multimodal_model(mesh_device, max_batch_size=max_batch_size, max_seq_len=max_seq_len) - model = LlamaVision(model, model_args, mesh_device) + generator = LlamaVision(model, model_args, mesh_device) tokenizer = Tokenizer(model_path=tokenizer_path) formatter = ChatFormat(tokenizer) - xattn_caches = model.model.setup_cache(model_args.max_batch_size) + xattn_caches = generator.model.setup_cache(model_args.max_batch_size) with open(IMG_PATH / "dog.jpg", "rb") as f: img = PIL_Image.open(f).convert("RGB") @@ -447,7 +165,7 @@ def test_llama_multimodal_demo_text( cross_attention_masks, full_text_row_masked_out_mask, logits, - ) = model.prefill_forward_single_user( + ) = generator.prefill_forward_single_user( vision_images, vision_mask, prompt_tokens_tensor, @@ -459,57 +177,35 @@ def test_llama_multimodal_demo_text( prefill_end = time.perf_counter() next_token, text = sampler(logits) - # logger.info(f"Prefill output: {next_token}:{text}") tokens[0, prefill_len] = next_token decode_times = [] - # Capture trace - # next_token_tensor = torch.tensor([next_token], dtype=torch.long).reshape(1, 1) # B, S - # trace_id, tt_logits_rm, tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, rot_mats = model.capture_trace( - # prefill_len, - # next_token_tensor, - # cross_attention_masks, - # full_text_row_masked_out_mask, - # xattn_caches, - # ) - for gen_idx in range(max_gen_len - 1): decode_start = time.perf_counter() position_id = prefill_len + gen_idx next_token_tensor = torch.tensor([next_token], dtype=torch.long).reshape(1, 1) # B, S - # logits = model.decode_forward( - # position_id, - # next_token_tensor, - # cross_attention_masks, - # full_text_row_masked_out_mask, - # xattn_caches, - # ) - logits = model.easy_trace( - position_id, - 
next_token_tensor, - cross_attention_masks, - full_text_row_masked_out_mask, - xattn_caches, - ) - # logits = model.decode_forward_trace( - # position_id, - # next_token_tensor, - # cross_attention_masks, - # full_text_row_masked_out_mask, - # xattn_caches, - # trace_id, - # tt_logits_rm, - # tt_h, - # tt_xattn_mask, - # tt_full_text_mask_expand_1NSH, - # tt_position_id, - # rot_mats - # ) + + if enable_trace: + logits = generator.easy_trace( + position_id, + next_token_tensor, + cross_attention_masks, + full_text_row_masked_out_mask, + xattn_caches, + ) + else: + logits = generator.decode_forward( + position_id, + next_token_tensor, + cross_attention_masks, + full_text_row_masked_out_mask, + xattn_caches, + ) + next_token, text = sampler(logits) # Update next token tokens[0, position_id + 1] = next_token - # logger.info(f"Decode output {position_id}: {next_token}:{text}") decode_end = time.perf_counter() decode_times.append(decode_end - decode_start) @@ -530,4 +226,4 @@ def test_llama_multimodal_demo_text( decode_time_ms = sum(decode_times) / (gen_idx + 1) * 1000 logger.info(f"Decode time: {decode_time_ms:.2f} ms") - # ttnn.release_trace(model.mesh_device, trace_id) + # ttnn.release_trace(generator.mesh_device, trace_id) diff --git a/models/demos/llama3/tt/multimodal/llama_vision_model.py b/models/demos/llama3/tt/multimodal/llama_vision_model.py index 80a27df0679..15ec522058a 100644 --- a/models/demos/llama3/tt/multimodal/llama_vision_model.py +++ b/models/demos/llama3/tt/multimodal/llama_vision_model.py @@ -498,7 +498,6 @@ def transform_decode_inputs_device(self, tt_xattn_mask, tt_full_text_mask_expand """ Does any transformations on device tensors which are necessary before ttnn_decode_forward """ - print("transforming xattn mask") assert ( B == self.configuration.max_batch_size ), f"Batch size must match max batch size. Got {B}, expected {self.configuration.max_batch_size}" diff --git a/models/demos/llama3/tt/multimodal/vision_generator.py b/models/demos/llama3/tt/multimodal/vision_generator.py new file mode 100644 index 00000000000..06f32bc160d --- /dev/null +++ b/models/demos/llama3/tt/multimodal/vision_generator.py @@ -0,0 +1,429 @@ +import ttnn +import torch + +from llama_models.llama3.api.datatypes import ( + InterleavedTextMedia, + StopReason, +) + +from llama_models.llama3.reference_impl.generation import ( + ChatPrediction, + CompletionPrediction, + TokenResult, + sample_top_p, +) + + +class LlamaVision: + def __init__(self, model, model_args, mesh_device, vllm=False, tokenizer=None, formatter=None): + """ + Creating a LlamaVision wrapper requires only a mesh_device and model_args. + With model_args you have the checkpoint location, can specify max batch size + and max seqlen, and other model specific parameters. + + LlamaVision is general to text and chat. + + For bringup, make this class general to any backend implementation, as long as it takes torch tensors and returns torch tensors. + + """ + self.model = model + self.model_args = model_args + self.mesh_device = mesh_device + self.vllm = vllm + self.tokenizer = tokenizer + self.formatter = formatter + + def prefill_forward_single_user( + self, + vision_images, + vision_mask, + tokens, + xattn_caches, + user_id, + total_len, + prefill_len, + ): + """ + Performs vision encode step then text prefill. 
+ Returns (xattn_caches, cross_attention_masks, full_text_row_masked_out_mask, logits) + """ + B = tokens.shape[0] + xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks( + batch_images=[vision_images], + batch_masks=[vision_mask], + total_len=total_len, + xattn_caches=xattn_caches, + user_id=user_id, + ) + + ( + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + tt_full_text_mask_expand_11SD, + tt_position_id, + rot_mats, + transformation_mats, + ) = self.model.prepare_inputs_prefill( + tokens, cross_attention_masks, full_text_row_masked_out_mask, prefill_len=prefill_len + ) + + tt_logits = self.model.ttnn_prefill_forward( + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + tt_full_text_mask_expand_11SD, + xattn_caches, + tt_position_id, + rot_mats, + transformation_mats, + user_id, + ) + + logits = self.model.process_output_prefill(tt_logits, B, prefill_len) + + return xattn_caches, cross_attention_masks, full_text_row_masked_out_mask, logits + + def decode_forward( + self, + position_id, + tokens, + cross_attention_masks, + full_text_row_masked_out_mask, + xattn_caches, + ): + """ + Performs text decode step. + Returns logits + """ + + # forward_decode should be traced callable + # decorator does compilation, capture, execute + # B = 1 # TODO: Only supports batch=1 right now! Might make tokens input a tensor. + B, S = tokens.shape + + ( + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + _, + tt_position_id, + rot_mats, + _, + ) = self.model.prepare_inputs_decode( + tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id=position_id + ) + + tt_logits = self.model.ttnn_decode_forward( + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + xattn_caches, + tt_position_id, + rot_mats, + ) + + logits = self.model.process_output_decode(tt_logits, B, S) + return logits + + def capture_trace( + self, + position_id, + tokens, + cross_attention_masks, + full_text_row_masked_out_mask, + xattn_caches, + ): + """ + Captures a trace for the decode_forward method. 
+ """ + ( + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + _, + tt_position_id, + rot_mats, + _, + ) = self.model.prepare_inputs_decode( + tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id=position_id + ) + + # Compile run + tt_logits_rm = self.model.ttnn_decode_forward( + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + xattn_caches, + tt_position_id, + rot_mats, + ) + + # Get inputs ready for trace run + ( + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + _, + tt_position_id, + rot_mats, + _, + ) = self.model.prepare_decode_inputs_host( + tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id + ) + + ( + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + tt_position_id, + rot_mats, + ) = self.model.copy_host_to_device( + (tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, rot_mats) + ) + + trace_id = ttnn.begin_trace_capture(self.mesh_device, cq_id=0) + B = tokens.shape[0] + # Do on-device transformations of inputs before forward + tt_xattn_mask_transform, tt_full_text_mask_expand_1NSH_transform = self.model.transform_decode_inputs_device( + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + B=B, + ) + + tt_logits_rm = self.model.ttnn_decode_forward( + tt_h, + tt_xattn_mask_transform, + tt_full_text_mask_expand_1NSH_transform, + xattn_caches, + tt_position_id, + rot_mats, + ) + + ttnn.end_trace_capture(self.mesh_device, trace_id, cq_id=0) + + return trace_id, tt_logits_rm, tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, rot_mats + + def decode_forward_trace( + self, + position_id, + tokens, + cross_attention_masks, + full_text_row_masked_out_mask, + xattn_caches, # TODO: unused since captured in trace? + trace_id, + trace_logits_rm, + trace_h, + trace_xattn_mask, + trace_full_text_mask_expand_1NSH, + trace_position_id, + trace_rot_mats, + ): + ( + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + _, + tt_position_id, + rot_mats, + _, + ) = self.model.prepare_decode_inputs_host( + tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id=position_id + ) + + self.model.copy_host_to_device( + host_tensors=(tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, rot_mats), + device_tensors=( + trace_h, + trace_xattn_mask, + trace_full_text_mask_expand_1NSH, + trace_position_id, + trace_rot_mats, + ), + ) + + ttnn.execute_trace(self.mesh_device, trace_id, cq_id=0, blocking=False) + + B, S = tokens.shape + logits = self.model.process_output_decode(trace_logits_rm, B=B, S=S) + + return logits + + def easy_trace( + self, + position_id, + tokens, + cross_attention_masks, + full_text_row_masked_out_mask, + xattn_caches, + ): + """ + Tracing is easy! Just call this method and we'll handle tracing for you. 
+ """ + if not hasattr(self, "trace_id"): + ( + trace_id, + tt_logits_rm, + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + tt_position_id, + rot_mats, + ) = self.capture_trace( + position_id, + tokens, + cross_attention_masks, + full_text_row_masked_out_mask, + xattn_caches, + ) + self.trace_id = trace_id + self.trace_inputs = { + "tt_h": tt_h, + "tt_xattn_mask": tt_xattn_mask, + "tt_full_text_mask_expand_1NSH": tt_full_text_mask_expand_1NSH, + "tt_position_id": tt_position_id, + "rot_mats": rot_mats, + } + self.trace_outputs = { + "tt_logits_rm": tt_logits_rm, + } + + return self.decode_forward_trace( + position_id, + tokens, + cross_attention_masks, + full_text_row_masked_out_mask, + xattn_caches, + self.trace_id, + self.trace_outputs["tt_logits_rm"], + self.trace_inputs["tt_h"], + self.trace_inputs["tt_xattn_mask"], + self.trace_inputs["tt_full_text_mask_expand_1NSH"], + self.trace_inputs["tt_position_id"], + self.trace_inputs["rot_mats"], + ) + + def generate( + self, + model_input, + max_gen_len: int, + temperature: float = 0.6, + top_p: float = 0.9, + ): + # Do initial prefill + vision_images = model_input.vision.images + vision_mask = model_input.vision.mask + prompt_tokens = model_input.tokens + prefill_len = len(prompt_tokens) + total_len = prefill_len + max_gen_len # Prepares mask for full length of output + + prompt_tokens_tensor = torch.tensor(prompt_tokens, dtype=torch.long).reshape(1, -1) # B, S + # Suboptimal to allocate caches every time + xattn_caches = self.model.setup_cache(self.model_args.max_batch_size) + ( + xattn_caches, + cross_attention_masks, + full_text_row_masked_out_mask, + logits, + ) = self.prefill_forward_single_user( + vision_images, + vision_mask, + prompt_tokens_tensor, + xattn_caches, + user_id=0, + total_len=total_len, + prefill_len=prefill_len, + ) + + def sample(logits): + if temperature > 0: + probs = torch.softmax(logits[:, -1] / temperature, dim=-1) + next_token = sample_top_p(probs, top_p) + else: + next_token = torch.argmax(logits[:, -1], dim=-1) + next_token = next_token.reshape(-1) + return next_token, self.tokenizer.decode(next_token.tolist()) + + next_token, text = sample(logits) + + yield TokenResult( + token=next_token[0].item(), + text=text, + ) + + for gen_idx in range(max_gen_len - 1): + position_id = prefill_len + gen_idx + next_token_tensor = next_token.reshape(1, 1) # B, S + + logits = self.decode_forward( + position_id, + next_token_tensor, + cross_attention_masks, + full_text_row_masked_out_mask, + xattn_caches, + ) + + next_token, text = sample(logits) + yield TokenResult( + token=next_token[0].item(), + text=text, + ) + + def chat_completion( + self, + messages, + temperature=0.6, + top_p: float = 0.9, + max_gen_len=None, + ): + if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.configuration.max_seq_len: + max_gen_len = self.model.configuration.max_seq_len - 1 + + tokens = [] + + stop_reason = None + breakpoint() + for result in self.generate( + model_input=self.formatter.encode_dialog_prompt(messages, tool_prompt_format=False), + max_gen_len=max_gen_len, + temperature=temperature, + top_p=top_p, + ): + tokens.append(result.token) + if result.text == "<|eot_id|>": + stop_reason = StopReason.end_of_turn + elif result.text == "<|eom_id|>": + stop_reason = StopReason.end_of_message + + if stop_reason is None: + stop_reason = StopReason.out_of_tokens + + message = self.formatter.decode_assistant_message(tokens, stop_reason) + + return ChatPrediction(generation=message) + + def text_completion( + self, 
+ content: InterleavedTextMedia, + temperature: float = 0.6, + top_p: float = 0.9, + max_gen_len=None, + ): + if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.configuration.max_seq_len: + max_gen_len = self.model.configuration.max_seq_len - 1 + + model_input = self.formatter.encode_content(content) + + tokens = [] + + for result in self.generate( + model_input=model_input, + max_gen_len=max_gen_len, + temperature=temperature, + top_p=top_p, + ): + tokens.append(result.token) + + generation = self.tokenizer.decode(tokens) + + return CompletionPrediction(generation=generation) From 81cc9b159f2211b78e7e24df5f60e93bbb802acf Mon Sep 17 00:00:00 2001 From: Colman Glagovich Date: Fri, 8 Nov 2024 07:48:07 -0800 Subject: [PATCH 13/19] #14519: Fix test script now that pytest params changed --- models/demos/llama3/demo/simple_vision_demo.py | 2 +- tests/scripts/t3000/run_t3000_demo_tests.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/models/demos/llama3/demo/simple_vision_demo.py b/models/demos/llama3/demo/simple_vision_demo.py index 964554280ee..673c3bc5a73 100644 --- a/models/demos/llama3/demo/simple_vision_demo.py +++ b/models/demos/llama3/demo/simple_vision_demo.py @@ -85,7 +85,7 @@ def create_multimodal_model(mesh_device, max_batch_size, max_seq_len, dtype=ttnn @pytest.mark.parametrize( "enable_trace", (False, True), - ids=["no_trace", "trace"], + ids=["no_trace", "yes_trace"], ) @pytest.mark.parametrize("device_params", [{"trace_region_size": 14951424, "num_command_queues": 2}], indirect=True) def test_llama_multimodal_demo_text( diff --git a/tests/scripts/t3000/run_t3000_demo_tests.sh b/tests/scripts/t3000/run_t3000_demo_tests.sh index 81fe4693094..27d089acd33 100755 --- a/tests/scripts/t3000/run_t3000_demo_tests.sh +++ b/tests/scripts/t3000/run_t3000_demo_tests.sh @@ -90,7 +90,7 @@ run_t3000_llama3_vision_tests() { pip install -r models/demos/llama3/requirements.txt for fake_device in "$n300" "$t3k"; do - FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/demo/multimodal_demo_chat.py -k "tt and 1" --timeout 600; fail+=$? + FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/demo/multimodal_demo_chat.py -k "tt" --timeout 600; fail+=$? 
echo "LOG_METAL: Llama3 vision tests for $fake_device completed" done From 1149598648ac8c8e2c3aeba441cca039af22f240 Mon Sep 17 00:00:00 2001 From: Colman Glagovich Date: Fri, 8 Nov 2024 10:56:44 -0800 Subject: [PATCH 14/19] #14519: Remove breakpoint --- models/demos/llama3/tt/multimodal/vision_generator.py | 1 - 1 file changed, 1 deletion(-) diff --git a/models/demos/llama3/tt/multimodal/vision_generator.py b/models/demos/llama3/tt/multimodal/vision_generator.py index 06f32bc160d..57f5cc9ef6a 100644 --- a/models/demos/llama3/tt/multimodal/vision_generator.py +++ b/models/demos/llama3/tt/multimodal/vision_generator.py @@ -382,7 +382,6 @@ def chat_completion( tokens = [] stop_reason = None - breakpoint() for result in self.generate( model_input=self.formatter.encode_dialog_prompt(messages, tool_prompt_format=False), max_gen_len=max_gen_len, From 9e5d0b95fe31ae45b30864a5e4dc287eec2dc2e2 Mon Sep 17 00:00:00 2001 From: Colman Glagovich Date: Fri, 8 Nov 2024 12:53:19 -0800 Subject: [PATCH 15/19] #14519: license --- models/demos/llama3/tt/multimodal/vision_generator.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/models/demos/llama3/tt/multimodal/vision_generator.py b/models/demos/llama3/tt/multimodal/vision_generator.py index 57f5cc9ef6a..d0073a8b911 100644 --- a/models/demos/llama3/tt/multimodal/vision_generator.py +++ b/models/demos/llama3/tt/multimodal/vision_generator.py @@ -1,3 +1,6 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 import ttnn import torch From ac7ffcb7fc0651711667c57246734c57fc31f2f7 Mon Sep 17 00:00:00 2001 From: Colman Glagovich Date: Wed, 13 Nov 2024 07:38:17 -0800 Subject: [PATCH 16/19] #14519: Remove trace decorator --- models/demos/llama3/tests/conftest.py | 47 --------------------------- 1 file changed, 47 deletions(-) diff --git a/models/demos/llama3/tests/conftest.py b/models/demos/llama3/tests/conftest.py index 822c9621072..d15dee818f1 100644 --- a/models/demos/llama3/tests/conftest.py +++ b/models/demos/llama3/tests/conftest.py @@ -8,50 +8,3 @@ @pytest.fixture(autouse=True) def ensure_gc(): gc.collect() - - -def traced(callable): - """ - Test it locally, get it into ttnn mainline - """ - # TODO: release trace on delete or ??? - trace_id = None - args_device = None - kwargs_device = None - outputs = None - - def create_device_inputs(*args, **kwargs): - # allocate device tensors for each arg which is on host - # don't copy - nonlocal args_device - nonlocal kwargs_device - - def copy_inputs_to_device(*args, **kwargs): - # copy any host tensors to device - - # Check that kwargs keys matches kwargs_device keys - # check that args len matches args_device len - nonlocal args_device - nonlocal kwargs_device - pass - - def wrapper(self, *args, **kwargs): - nonlocal trace_id - nonlocal outputs - if not trace_id: - create_device_inputs(args, kwargs) - copy_inputs_to_device(args, kwargs) - ret = callable(self, *args, **kwargs) - outputs = ret - trace_id = ttnn.capture_trace(...) - callable(self, *args, **kwargs) - ttnn.end_trace(...) 
- return ret - # check that inputs, outputs are host tensors - # or if an input is on device, do nothing - # copy new inputs to inputs, return outputs - copy_inputs_to_device(args, kwargs) - ttnn.execute_trace(trace_id) - return outputs - - return wrapper From 751e4b108ca1982a54595df9e1082adb1b27a090 Mon Sep 17 00:00:00 2001 From: Colman Glagovich Date: Wed, 13 Nov 2024 08:10:21 -0800 Subject: [PATCH 17/19] #14519: remove batch option from rot mat --- .../test_llama_cross_attention_transformer_text.py | 5 ++++- models/demos/llama3/tt/llama_common.py | 12 +++++++++--- .../demos/llama3/tt/multimodal/llama_vision_model.py | 1 - 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py index 84a04bfd372..211c990dc3a 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py @@ -292,7 +292,10 @@ def test_llama_cross_attention_transformer_text_inference( ) rot_mats, _ = get_single_rot_mat( - model_args.head_dim, mesh_device, model_args.num_devices, start_pos=cur_pos - 1, batch=batch + model_args.head_dim, + mesh_device, + model_args.num_devices, + start_pos=cur_pos - 1, ) transformation_mats = None diff --git a/models/demos/llama3/tt/llama_common.py b/models/demos/llama3/tt/llama_common.py index e8e88222e6e..6368443df4f 100644 --- a/models/demos/llama3/tt/llama_common.py +++ b/models/demos/llama3/tt/llama_common.py @@ -185,7 +185,13 @@ def get_rot_transformation_mat(dhead): def get_single_rot_mat( - dhead, mesh_device, num_devices, start_pos=0, theta: float = 500000.0, use_scaled=True, on_host=False, batch=1 + dhead, + mesh_device, + num_devices, + start_pos=0, + theta: float = 500000.0, + use_scaled=True, + on_host=False, ): freqs_unscaled = 1.0 / (theta ** (torch.arange(0, dhead, 2)[: (dhead // 2)].float() / dhead)) if use_scaled: @@ -210,13 +216,13 @@ def get_single_rot_mat( current_rot_mat[torch.arange(1, dhead, 2), torch.arange(0, dhead, 2)] = sin_freqs.clone() return ttnn.from_torch( - current_rot_mat.T.unsqueeze(0).unsqueeze(0).expand(-1, batch, -1, -1), # 1,batch,head_dim,head_dim + current_rot_mat.T.unsqueeze(0).unsqueeze(0), # 1,1,head_dim,head_dim device=mesh_device if not on_host else None, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device) if num_devices > 1 or not on_host else None, ), ttnn.from_torch( - rot_matrix.unsqueeze(0).unsqueeze(0).expand(-1, batch, -1, -1), # 1,batch,head_dim,head_dim + rot_matrix.unsqueeze(0).unsqueeze(0), # 1,1,head_dim,head_dim device=mesh_device if not on_host else None, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, diff --git a/models/demos/llama3/tt/multimodal/llama_vision_model.py b/models/demos/llama3/tt/multimodal/llama_vision_model.py index 15ec522058a..f40bd7f593e 100644 --- a/models/demos/llama3/tt/multimodal/llama_vision_model.py +++ b/models/demos/llama3/tt/multimodal/llama_vision_model.py @@ -462,7 +462,6 @@ def prepare_decode_inputs_host(self, tokens, cross_attention_masks, full_text_ro self.configuration.num_devices, start_pos=position_ids.item() - 1, # TODO: Change function to support decode batch > 1 # TODO: B must match max_batch_size, be careful - batch=B, on_host=True, ) From 659c111206d52ecb06099e9f2a4a096e3f490777 Mon Sep 17 00:00:00 2001 From: Colman Glagovich Date: Wed, 13 Nov 2024 
08:10:32 -0800 Subject: [PATCH 18/19] #14519: Add traced demo to CI --- tests/scripts/t3000/run_t3000_demo_tests.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/scripts/t3000/run_t3000_demo_tests.sh b/tests/scripts/t3000/run_t3000_demo_tests.sh index 27d089acd33..2abda2c7598 100755 --- a/tests/scripts/t3000/run_t3000_demo_tests.sh +++ b/tests/scripts/t3000/run_t3000_demo_tests.sh @@ -90,7 +90,7 @@ run_t3000_llama3_vision_tests() { pip install -r models/demos/llama3/requirements.txt for fake_device in "$n300" "$t3k"; do - FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/demo/multimodal_demo_chat.py -k "tt" --timeout 600; fail+=$? + FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/demo/simple_vision_demo.py -k "cold and yes_trace" --timeout 600; fail+=$? echo "LOG_METAL: Llama3 vision tests for $fake_device completed" done From 31944dba02caaf1d647dd8598f90d44e4bbcb0a0 Mon Sep 17 00:00:00 2001 From: Colman Glagovich Date: Wed, 13 Nov 2024 08:41:10 -0800 Subject: [PATCH 19/19] #14519: Fix merge bug in xblock test --- .../multimodal/test_llama_cross_block.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_block.py b/models/demos/llama3/tests/multimodal/test_llama_cross_block.py index 04112167154..d977d73e922 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_block.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_block.py @@ -257,17 +257,14 @@ def test_llama_cross_attention_transformer_block_inference( ), ) - full_text_mask_expand_11SD = full_text_mask.expand(-1, -1, -1, dim) - tt_full_text_mask_expand_11SD = ttnn.from_torch( - full_text_mask_expand_11SD, - device=mesh_device, - dtype=ttnn.bfloat8_b, - layout=ttnn.TILE_LAYOUT, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1), - ) - if mode == "decode": - tt_full_text_mask_expand_11SD = None + tt_out = tt_model( + tt_x, + xattn_mask=tt_xattn_mask, + full_text_row_masked_out_mask_1NSH=tt_full_text_mask_expand_1NSH, + full_text_row_masked_out_mask_11SD=None, + xattn_cache=tt_xattn_cache, + mode=mode, + ) tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1)) tt_output_torch = tt_output_torch[:, :, :batch, :].reshape(batch, seq_len, dim)
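
For reference on the generator added in vision_generator.py above: its sample() helper scales the last-position logits by temperature, softmaxes them, and defers to sample_top_p, whose definition is not included in this series. Below is a minimal nucleus (top-p) sampling sketch consistent with how it is called (a batch-by-vocab probability tensor in, token ids out); the helper actually shipped in the repo may differ in detail, so treat it as illustrative.

    import torch

    def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
        # probs: (batch, vocab) probabilities; returns (batch, 1) sampled token ids.
        sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
        cumulative = torch.cumsum(sorted_probs, dim=-1)
        # Zero out every token whose preceding cumulative mass already exceeds p.
        sorted_probs[cumulative - sorted_probs > p] = 0.0
        sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
        # Sample inside the nucleus, then map back to vocabulary indices.
        next_token = torch.multinomial(sorted_probs, num_samples=1)
        return torch.gather(sorted_idx, -1, next_token)

When temperature is 0 the generator bypasses this path entirely and takes torch.argmax of the last-position logits, which is the greedy decode the demo exercises.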
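For reference on the trace-cached decode path in vision_generator.py above (and on the traced() helper deleted from conftest.py in patch 16, which sketched the same idea): the first decode call captures a trace and keeps the device-side input and output tensors alive; every later call only copies fresh host inputs into those tensors and replays the captured graph. A stripped-down sketch of that pattern is below, assuming the ttnn trace API (begin_trace_capture, end_trace_capture, execute_trace, copy_host_to_device_tensor) and a caller-supplied run_step callable; it is not the exact code in the repo.

    import ttnn

    class TracedStep:
        # Capture a single decode step once, then replay it with fresh inputs.
        def __init__(self, device, run_step):
            self.device = device
            self.run_step = run_step
            self.trace_id = None

        def __call__(self, host_inputs):
            if self.trace_id is None:
                # First call: allocate persistent device inputs and run once to compile.
                self.dev_inputs = [ttnn.to_device(t, self.device) for t in host_inputs]
                self.run_step(*self.dev_inputs)
                # Capture one step; its output tensor stays allocated on device.
                self.trace_id = ttnn.begin_trace_capture(self.device, cq_id=0)
                self.output = self.run_step(*self.dev_inputs)
                ttnn.end_trace_capture(self.device, self.trace_id, cq_id=0)
            # Steady state: overwrite the captured inputs in place and replay the graph.
            for host_t, dev_t in zip(host_inputs, self.dev_inputs):
                ttnn.copy_host_to_device_tensor(host_t, dev_t)
            ttnn.execute_trace(self.device, self.trace_id, cq_id=0, blocking=True)
            return self.output

The decode path above does the same thing by hand, caching trace_id, trace_inputs, and trace_outputs on the generator object instead of in a helper class.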
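For reference on the get_single_rot_mat change in patch 17 above: the decode path carries a single head_dim by head_dim block-diagonal rotation matrix and advances it one position per step by multiplying with a fixed one-step matrix, so a size-1 batch dimension broadcasts through the matmul and the explicit batch expansion could be dropped (the TODO left in llama_vision_model.py notes that decode batch > 1 still needs separate handling). The short check below illustrates the recurrence R(pos + 1) = R(pos) @ R(1) in plain torch; the sign and layout conventions of the repo helper may differ, and the use_scaled frequency scaling is omitted.

    import torch

    def rot_mat(dhead, pos, theta=500000.0):
        # Block-diagonal 2x2 rotations, one block per frequency (no rope scaling).
        freqs = 1.0 / (theta ** (torch.arange(0, dhead, 2).float() / dhead))
        cos, sin = torch.cos(freqs * pos), torch.sin(freqs * pos)
        m = torch.zeros(dhead, dhead)
        even, odd = torch.arange(0, dhead, 2), torch.arange(1, dhead, 2)
        m[even, even], m[odd, odd] = cos, cos
        m[even, odd], m[odd, even] = -sin, sin
        return m

    dhead = 64
    # One matmul with the fixed one-step matrix advances the position, so the
    # decode loop never rebuilds sin/cos tables.
    assert torch.allclose(rot_mat(dhead, 6), rot_mat(dhead, 5) @ rot_mat(dhead, 1), atol=1e-5)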