diff --git a/models/demos/llama3/demo/simple_vision_demo.py b/models/demos/llama3/demo/simple_vision_demo.py index cda3c2ed957..1d5654efba2 100644 --- a/models/demos/llama3/demo/simple_vision_demo.py +++ b/models/demos/llama3/demo/simple_vision_demo.py @@ -189,22 +189,14 @@ def test_llama_multimodal_demo_text( position_id = prefill_lens + gen_idx next_token_tensor = next_tokens.reshape(max_batch_size, 1) - if enable_trace: - logits = generator.easy_trace( - position_id, - next_token_tensor, - batch_xattn_masks, - batch_text_masks, - xattn_caches, - ) - else: - logits = generator.decode_forward( - position_id, - next_token_tensor, - batch_xattn_masks, - batch_text_masks, - xattn_caches, - ) + logits = generator.decode_forward( + position_id, + next_token_tensor, + batch_xattn_masks, + batch_text_masks, + xattn_caches, + enable_trace=enable_trace, + ) next_tokens, next_texts = sampler(logits) # Update next token diff --git a/models/demos/llama3/tt/generator.py b/models/demos/llama3/tt/generator.py index c42450e48d3..89fde136d6d 100644 --- a/models/demos/llama3/tt/generator.py +++ b/models/demos/llama3/tt/generator.py @@ -167,7 +167,7 @@ def decode_forward_trace_text( return logits - def prefill_forward_single_user( + def _prefill_forward_single_user( self, vision_images, vision_mask, @@ -178,6 +178,7 @@ def prefill_forward_single_user( prefill_len, page_table=None, kv_cache=None, + cross_page_table=None, ): """ Performs vision encode step then text prefill. @@ -194,6 +195,10 @@ def prefill_forward_single_user( if page_table is not None: page_table = self._get_prefill_user_page_table(page_table, kv_cache, prefill_len) + if cross_page_table is not None: + num_vision_tokens = vision_tokens.shape[2] + cross_page_table = self._get_prefill_user_page_table(cross_page_table, kv_cache, num_vision_tokens) + ( tt_h, tt_xattn_mask, @@ -201,12 +206,14 @@ def prefill_forward_single_user( tt_full_text_mask_expand_11SD, rot_mats, tt_page_table, + tt_cross_page_table, ) = self.model.prepare_inputs_prefill( tokens, cross_attention_masks, full_text_row_masked_out_mask, prefill_len=prefill_len, page_table=page_table, + cross_page_table=cross_page_table, ) tt_logits = self.model.ttnn_prefill_forward( @@ -221,9 +228,11 @@ def prefill_forward_single_user( page_table=tt_page_table, kv_cache=kv_cache, get_last_token=(last_token_idx // 32) * 32, + cross_page_table=tt_cross_page_table, ) del tt_page_table + del tt_cross_page_table logits = self.model.process_output_prefill(tt_logits, B, last_token_idx=(last_token_idx % 32)) @@ -239,9 +248,10 @@ def prefill_forward( prompt_lens, page_table=None, kv_cache=None, + cross_page_table=None, ): """ - Batched version of prefill_forward_single_user for vision model. + Batched version of _prefill_forward_single_user for vision model. 
""" batch, batch_seq_len = tokens.shape output_logits = torch.zeros(batch, 1, self.model_args.vocab_size) @@ -256,7 +266,7 @@ def prefill_forward( cross_attention_masks, full_text_row_masked_out_mask, logits, - ) = self.prefill_forward_single_user( + ) = self._prefill_forward_single_user( vision_images=vision_images[user_id], vision_mask=vision_masks[user_id], tokens=tokens[user_id : user_id + 1, :seq_len], # Keep batch dimension @@ -266,6 +276,7 @@ def prefill_forward( prefill_len=seq_len, page_table=page_table, kv_cache=kv_cache, + cross_page_table=cross_page_table, ) output_logits[user_id] = logits output_xattn_masks.append(cross_attention_masks) @@ -281,14 +292,51 @@ def decode_forward( tokens, cross_attention_masks, full_text_row_masked_out_mask, - xattn_caches, + xattn_caches=None, + page_table=None, + kv_cache=None, + cross_page_table=None, + enable_trace=True, + read_from_device=True, + ): + decode_kwargs = { + "position_id": start_pos, + "tokens": tokens, + "cross_attention_masks": cross_attention_masks, + "full_text_row_masked_out_mask": full_text_row_masked_out_mask, + "xattn_caches": xattn_caches, + "page_table": page_table, + "kv_cache": kv_cache, + "cross_page_table": cross_page_table, + } + if enable_trace: + tt_logits = self._easy_trace(**decode_kwargs) + else: + tt_logits = self._decode_forward_no_trace(**decode_kwargs) + + if read_from_device: + return self.read_decode_output(tt_logits, tokens.shape[0]) + else: + return tt_logits + + def read_decode_output(self, tt_logits, unpadded_batch): + logits = self.model.process_output_decode(tt_logits, B=unpadded_batch, S=1) + return logits + + def _decode_forward_no_trace( + self, + position_id, + tokens, + cross_attention_masks, + full_text_row_masked_out_mask, + xattn_caches=None, page_table=None, kv_cache=None, - prompt_lens=None, + cross_page_table=None, ): """ Performs text decode step. - Returns logits + Returns tt_logits on device """ # forward_decode should be traced callable @@ -303,8 +351,14 @@ def decode_forward( tt_position_id, tt_rot_mats, tt_page_table, + tt_cross_page_table, ) = self.model.prepare_inputs_decode( - tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id=start_pos, page_table=page_table + tokens, + cross_attention_masks, + full_text_row_masked_out_mask, + position_id=position_id, + page_table=page_table, + cross_page_table=cross_page_table, ) tt_logits = self.model.ttnn_decode_forward( @@ -316,18 +370,21 @@ def decode_forward( tt_rot_mats, page_table=tt_page_table, kv_cache=kv_cache, + cross_page_table=tt_cross_page_table, ) - logits = self.model.process_output_decode(tt_logits, B, S) - return logits + return tt_logits - def capture_trace( + def _capture_trace( self, position_id, tokens, cross_attention_masks, full_text_row_masked_out_mask, xattn_caches, + page_table=None, + kv_cache=None, + cross_page_table=None, ): """ Captures a trace for the decode_forward method. 
@@ -339,8 +396,14 @@ def capture_trace( tt_position_id, tt_rot_mats, tt_page_table, + tt_cross_page_table, ) = self.model.prepare_inputs_decode( - tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id=position_id + tokens, + cross_attention_masks, + full_text_row_masked_out_mask, + position_id=position_id, + page_table=page_table, + cross_page_table=cross_page_table, ) # Compile run @@ -351,7 +414,11 @@ def capture_trace( xattn_caches, tt_position_id, tt_rot_mats, + page_table=tt_page_table, + kv_cache=kv_cache, + cross_page_table=tt_cross_page_table, ) + logger.info("Done Compiling Model") # Get inputs ready for trace run ( @@ -360,9 +427,15 @@ def capture_trace( tt_full_text_mask_expand_1NSH, tt_position_id, tt_rope_id, - _, + tt_page_table, + tt_cross_page_table, ) = self.model.prepare_decode_inputs_host( - tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id + tokens, + cross_attention_masks, + full_text_row_masked_out_mask, + position_id, + page_table=page_table, + cross_page_table=cross_page_table, ) ( @@ -371,8 +444,18 @@ def capture_trace( tt_full_text_mask_expand_1NSH, tt_position_id, tt_rope_id, + tt_page_table, + tt_cross_page_table, ) = copy_host_to_device( - (tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, tt_rope_id), + ( + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + tt_position_id, + tt_rope_id, + tt_page_table, + tt_cross_page_table, + ), mesh_device=self.mesh_device, ) @@ -400,19 +483,34 @@ def capture_trace( xattn_caches, tt_position_id, tt_rot_mats, + page_table=tt_page_table, + kv_cache=kv_cache, + cross_page_table=tt_cross_page_table, ) ttnn.end_trace_capture(self.mesh_device, trace_id, cq_id=0) + logger.info("Done Capturing Decode Trace") - return trace_id, tt_logits_rm, tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, tt_rope_id + return ( + trace_id, + tt_logits_rm, + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + tt_position_id, + tt_rope_id, + tt_page_table, + tt_cross_page_table, + ) - def decode_forward_trace( + def _decode_forward_trace( self, position_id, tokens, cross_attention_masks, full_text_row_masked_out_mask, - xattn_caches, # TODO: unused since captured in trace? + page_table, + cross_page_table, trace_id, trace_logits_rm, trace_h, @@ -420,43 +518,64 @@ def decode_forward_trace( trace_full_text_mask_expand_1NSH, trace_position_id, trace_rope_id, + trace_page_table, + trace_cross_page_table, ): + """ + Executes the trace for the decode_forward method but does not read back outputs. 
+ """ ( tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, tt_rope_id, - _, + tt_page_table, + tt_cross_page_table, ) = self.model.prepare_decode_inputs_host( - tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id=position_id + tokens, + cross_attention_masks, + full_text_row_masked_out_mask, + position_id=position_id, + page_table=page_table, + cross_page_table=cross_page_table, ) copy_host_to_device( - host_tensors=(tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, tt_rope_id), + host_tensors=( + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + tt_position_id, + tt_rope_id, + tt_page_table, + tt_cross_page_table, + ), device_tensors=( trace_h, trace_xattn_mask, trace_full_text_mask_expand_1NSH, trace_position_id, trace_rope_id, + trace_page_table, + trace_cross_page_table, ), ) ttnn.execute_trace(self.mesh_device, trace_id, cq_id=0, blocking=False) - B, S = tokens.shape - logits = self.model.process_output_decode(trace_logits_rm, B=B, S=S) - - return logits + return trace_logits_rm - def easy_trace( + def _easy_trace( self, position_id, tokens, cross_attention_masks, full_text_row_masked_out_mask, - xattn_caches, + xattn_caches=None, + page_table=None, + kv_cache=None, + cross_page_table=None, ): """ Tracing is easy! Just call this method and we'll handle tracing for you. @@ -470,12 +589,17 @@ def easy_trace( tt_full_text_mask_expand_1NSH, tt_position_id, tt_rope_id, - ) = self.capture_trace( + tt_page_table, + tt_cross_page_table, + ) = self._capture_trace( position_id, tokens, cross_attention_masks, full_text_row_masked_out_mask, xattn_caches, + page_table=page_table, + kv_cache=kv_cache, + cross_page_table=cross_page_table, ) self.trace_id = trace_id self.trace_inputs = { @@ -484,17 +608,20 @@ def easy_trace( "tt_full_text_mask_expand_1NSH": tt_full_text_mask_expand_1NSH, "tt_position_id": tt_position_id, "tt_rope_id": tt_rope_id, + "tt_page_table": tt_page_table, + "tt_cross_page_table": tt_cross_page_table, } self.trace_outputs = { "tt_logits_rm": tt_logits_rm, } - return self.decode_forward_trace( + trace_logits_rm = self._decode_forward_trace( position_id, tokens, cross_attention_masks, full_text_row_masked_out_mask, - xattn_caches, + page_table, + cross_page_table, self.trace_id, self.trace_outputs["tt_logits_rm"], self.trace_inputs["tt_h"], @@ -502,8 +629,12 @@ def easy_trace( self.trace_inputs["tt_full_text_mask_expand_1NSH"], self.trace_inputs["tt_position_id"], self.trace_inputs["tt_rope_id"], + self.trace_inputs["tt_page_table"], + self.trace_inputs["tt_cross_page_table"], ) + return trace_logits_rm + def generate( self, model_input, @@ -526,7 +657,7 @@ def generate( cross_attention_masks, full_text_row_masked_out_mask, logits, - ) = self.prefill_forward_single_user( + ) = self._prefill_forward_single_user( vision_images, vision_mask, prompt_tokens_tensor, @@ -536,7 +667,7 @@ def generate( prefill_len=prefill_len, ) - logits = logits.view(1, 1, self.model_args.max_vocab_size) + logits = logits.view(1, 1, self.model_args.vocab_size) def sample(logits): if temperature > 0: @@ -564,6 +695,7 @@ def sample(logits): [cross_attention_masks], [full_text_row_masked_out_mask], xattn_caches, + enable_trace=False, ) next_token, text = sample(logits) diff --git a/models/demos/llama3/tt/generator_vllm.py b/models/demos/llama3/tt/generator_vllm.py index f962b2801b1..7989aba9547 100644 --- a/models/demos/llama3/tt/generator_vllm.py +++ b/models/demos/llama3/tt/generator_vllm.py @@ -9,14 +9,17 @@ from 
models.demos.llama3.tt.generator import LlamaGenerator from models.demos.llama3.demo.simple_vision_demo import create_multimodal_model +from models.utility_functions import nearest_32 from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, EncoderDecoderInputs, InputContext +from vllm.model_executor.models.interfaces import SupportsMultiModal +from vllm.model_executor.models.mllama import MLLAMA_IMAGE_TOKEN_ID, MLLAMA_IMAGE_TOKEN def input_processor_for_mllama(ctx: InputContext, inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs]): """ Based on vllm.model_executor.models.mllama.py::input_processor_for_mllama(). - Note that vLLM's input_processor_for_mllama performs additional processing to handle chunking which we do not yet support. + Note that vLLM's input_processor_for_mllama performs additional processing to compute num_tiles while here it is fixed. """ # Move encoder_prompt to prompt. If the user does not explicitly provide separate @@ -27,11 +30,32 @@ def input_processor_for_mllama(ctx: InputContext, inputs: Union[DecoderOnlyInput inputs["prompt"] = inputs["encoder_prompt"] inputs["prompt_token_ids"] = inputs["encoder_prompt_token_ids"] + multi_modal_data = inputs.get("encoder_multi_modal_data") + if multi_modal_data is None or "image" not in multi_modal_data or multi_modal_data["image"] is None: + # text-only + inputs["encoder_prompt"] = "" + inputs["encoder_prompt_token_ids"] = [] + inputs["encoder_multi_modal_data"] = {} + return inputs + + # Set encoder prompt length based on the number of vision tokens so block manager allocates enough blocks (cross block tables). + hf_config = ctx.model_config.hf_config + assert hf_config.vision_config.image_size % 14 == 0, "chunk size should be multiple of 14" + token_per_chunk = nearest_32( + (hf_config.vision_config.image_size // 14) ** 2 + 1 + ) # Note: we use nearest 32 while vLLM does not by default + num_vision_tokens = ( + hf_config.vision_config.max_num_tiles * token_per_chunk + ) # Note: we use max_num_tiles while vLLM uses num_tiles by default + inputs["encoder_prompt"] = MLLAMA_IMAGE_TOKEN * num_vision_tokens + inputs["encoder_prompt_token_ids"] = [MLLAMA_IMAGE_TOKEN_ID] * num_vision_tokens + return inputs +# @MULTIMODAL_REGISTRY.register_image_input_mapper() # TODO: Add once model can accept inputs from multi_modal_input_mapper (raw pixel values) @INPUT_REGISTRY.register_input_processor(input_processor_for_mllama) -class TtMllamaForConditionalGeneration(LlamaGenerator): +class TtMllamaForConditionalGeneration(LlamaGenerator, SupportsMultiModal): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -52,11 +76,10 @@ def prefill_forward( self, tokens: torch.Tensor, images: List[PIL.Image.Image], - xattn_caches, - start_pos, - page_table: torch.Tensor = None, - kv_cache=None, - prompt_lens=None, + page_table: torch.Tensor, + kv_cache, + prompt_lens, + cross_page_table: torch.Tensor, ): """ Replaces prefill_forward from LlamaGenerator with a version that supports mask creation. 
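# A self-contained worked sketch (not part of the patch) of the encoder-prompt sizing added in
# input_processor_for_mllama above. nearest_32 below is a local stand-in for
# models.utility_functions.nearest_32 (round up to the next multiple of 32); image_size=448 and
# max_num_tiles=4 are the Hugging Face Mllama vision-config defaults, assumed here for
# illustration rather than read from a real hf_config.
import math

def nearest_32(x: int) -> int:
    return math.ceil(x / 32) * 32

image_size = 448                      # hf_config.vision_config.image_size (assumed default)
max_num_tiles = 4                     # hf_config.vision_config.max_num_tiles (assumed default)
assert image_size % 14 == 0, "chunk size should be multiple of 14"

token_per_chunk = nearest_32((image_size // 14) ** 2 + 1)  # (32*32 patches + 1) = 1025 -> 1056
num_vision_tokens = max_num_tiles * token_per_chunk        # 4 * 1056 = 4224

# The encoder prompt is padded to num_vision_tokens image-token ids so vLLM's block manager
# allocates enough blocks for the cross block table of each request.
print(token_per_chunk, num_vision_tokens)                  # 1056 4224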
@@ -73,5 +96,13 @@ def prefill_forward( total_lens.append(prompt_lens[user_id] + self.max_gen_len) return super().prefill_forward( - vision_images, vision_masks, tokens, xattn_caches, total_lens, prompt_lens, page_table, kv_cache + vision_images, + vision_masks, + tokens, + None, + total_lens, + prompt_lens, + page_table=page_table, + kv_cache=kv_cache, + cross_page_table=cross_page_table, ) diff --git a/models/demos/llama3/tt/llama_attention.py b/models/demos/llama3/tt/llama_attention.py index cba5b063237..a9405a43739 100644 --- a/models/demos/llama3/tt/llama_attention.py +++ b/models/demos/llama3/tt/llama_attention.py @@ -92,7 +92,6 @@ def __init__( self.ccl_topology = configuration.ccl_topology() self.is_multichip = configuration.is_multichip - self.layer_num = layer_num layer_name = configuration.get_state_dict_prefix(self.__class__.__name__, layer_num) if configuration.dummy_weights or (weight_cache_path is None): cache_name = lambda _: None @@ -317,8 +316,8 @@ def forward_decode( # KV update ### if kv_cache: - keys = kv_cache[self.layer_num][0] - values = kv_cache[self.layer_num][1] + keys = kv_cache[0] + values = kv_cache[1] else: keys = self.layer_past[0] values = self.layer_past[1] @@ -536,7 +535,7 @@ def forward_prefill(self, x_11SH, rot_mats, user_id: int = 0, page_table=None, k # Fill KV-Cache if kv_cache: - keys_BKSD, values_BKSD = kv_cache[self.layer_num][0], kv_cache[self.layer_num][1] + keys_BKSD, values_BKSD = kv_cache[0], kv_cache[1] else: keys_BKSD, values_BKSD = self.layer_past[0], self.layer_past[1] diff --git a/models/demos/llama3/tt/llama_model.py b/models/demos/llama3/tt/llama_model.py index bc619737976..f86a0058e6e 100644 --- a/models/demos/llama3/tt/llama_model.py +++ b/models/demos/llama3/tt/llama_model.py @@ -172,7 +172,10 @@ def prepare_decode_inputs_host(self, tokens, current_pos, page_table=None): mesh_mapper=mesh_mapper, ) - rope_idxs = self.rope_setup.get_rot_idxs(current_pos, on_host=True) + rot_current_pos = torch.maximum( + current_pos, torch.tensor(0, dtype=torch.int64) + ) # Ensure position indices are non-negative + rope_idxs = self.rope_setup.get_rot_idxs(rot_current_pos, on_host=True) current_pos_tt = ttnn.from_torch( current_pos, device=None, diff --git a/models/demos/llama3/tt/multimodal/llama_cross_attention.py b/models/demos/llama3/tt/multimodal/llama_cross_attention.py index 5aa338a012b..57bfedecffa 100644 --- a/models/demos/llama3/tt/multimodal/llama_cross_attention.py +++ b/models/demos/llama3/tt/multimodal/llama_cross_attention.py @@ -131,7 +131,7 @@ def __init__( eps=self.norm_eps, ) - def compute_xattn_kv_cache(self, xattn_tokens, user_id, xattn_cache): + def compute_xattn_kv_cache(self, xattn_tokens, user_id, xattn_cache, cross_page_table=None): """ Uses xattn_tokens to compute K, V. Should be run inside of forward_prefill. 
Updates xattn_cache with K, V (TODO: support page table for KV cache) @@ -188,17 +188,28 @@ def compute_xattn_kv_cache(self, xattn_tokens, user_id, xattn_cache): xk = self.k_norm(xk, mode="decode") k_cache, v_cache = xattn_cache - # Work around fill_cache memory constraint by making these sharded - k_fill = ttnn.interleaved_to_sharded(xk, self.model_config["XATTN_KV_PREFILL_MEM_CFG"](seqlen_y)) - v_fill = ttnn.interleaved_to_sharded(xv, self.model_config["XATTN_KV_PREFILL_MEM_CFG"](seqlen_y)) - ttnn.fill_cache(k_cache, k_fill, user_id) - ttnn.fill_cache(v_cache, v_fill, user_id) + if cross_page_table: + ttnn.experimental.paged_fill_cache( + k_cache, ttnn.experimental.typecast(xk, k_cache.dtype), cross_page_table, batch_idx=user_id + ) + ttnn.experimental.paged_fill_cache( + v_cache, ttnn.experimental.typecast(xv, v_cache.dtype), cross_page_table, batch_idx=user_id + ) + else: + # Work around fill_cache memory constraint by making these sharded + k_fill = ttnn.interleaved_to_sharded(xk, self.model_config["XATTN_KV_PREFILL_MEM_CFG"](seqlen_y)) + v_fill = ttnn.interleaved_to_sharded(xv, self.model_config["XATTN_KV_PREFILL_MEM_CFG"](seqlen_y)) + + ttnn.fill_cache(k_cache, k_fill, user_id) + ttnn.fill_cache(v_cache, v_fill, user_id) return xk, xv - def forward_decode(self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH, xattn_cache): - batch = xattn_cache[0].shape[0] + def forward_decode( + self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH, xattn_cache, cross_page_table=None + ): + batch = xattn_mask.shape[1] x_11SH = ttnn.sharded_to_interleaved(x_11SH, ttnn.L1_MEMORY_CONFIG) # TODO support sharded input @@ -232,17 +243,29 @@ def forward_decode(self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH, ) # TODO: Can I get rid of the KV repeat_interleave? - - output = ttnn.transformer.scaled_dot_product_attention_decode( - xq, - xk, - xv, - is_causal=False, - attn_mask=xattn_mask, - scale=self.scale, - program_config=program_config, - compute_kernel_config=self.compute_kernel_config_sdpa, - ) + if cross_page_table: + output = ttnn.transformer.paged_scaled_dot_product_attention_decode( + xq, + xk, + xv, + is_causal=False, + attn_mask=xattn_mask, + page_table_tensor=cross_page_table, + scale=self.scale, + program_config=program_config, + compute_kernel_config=self.compute_kernel_config_sdpa, + ) + else: + output = ttnn.transformer.scaled_dot_product_attention_decode( + xq, + xk, + xv, + is_causal=False, + attn_mask=xattn_mask, + scale=self.scale, + program_config=program_config, + compute_kernel_config=self.compute_kernel_config_sdpa, + ) # WARNING: this broadcast is also broken, must broadcast on host output = ttnn.mul(output, full_text_row_masked_out_mask_1NSH) @@ -275,14 +298,23 @@ def forward_decode(self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH, return ttnn.to_memory_config(output, self.model_config["DECODE_RESIDUAL_MEMCFG"]) def forward_prefill( - self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH, xattn_cache, user_id, vision_tokens + self, + x_11SH, + xattn_mask, + full_text_row_masked_out_mask_1NSH, + xattn_cache, + user_id, + vision_tokens, + cross_page_table=None, ): seq_len = x_11SH.shape[-2] # B, S, D assert seq_len % 32 == 0 and seq_len > 0, "Seqlen must be divisible by 32" # Compute cross attention cache. 
Return contiguous caches - k_cache_user, v_cache_user = self.compute_xattn_kv_cache(vision_tokens, user_id, xattn_cache) + k_cache_user, v_cache_user = self.compute_xattn_kv_cache( + vision_tokens, user_id, xattn_cache, cross_page_table=cross_page_table + ) cache_seq_len = k_cache_user.shape[-2] if seq_len > 1024: @@ -357,7 +389,15 @@ def forward_prefill( return output def forward( - self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH, xattn_cache, mode, user_id=0, vision_tokens=None + self, + x_11SH, + xattn_mask, + full_text_row_masked_out_mask_1NSH, + xattn_cache, + mode, + user_id=0, + vision_tokens=None, + cross_page_table=None, ): if mode == "prefill": return self.forward_prefill( @@ -367,6 +407,9 @@ def forward( xattn_cache, user_id=user_id, vision_tokens=vision_tokens, + cross_page_table=cross_page_table, ) else: - return self.forward_decode(x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH, xattn_cache) + return self.forward_decode( + x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH, xattn_cache, cross_page_table=cross_page_table + ) diff --git a/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py b/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py index 1c05ec06e1c..162f6dc6da7 100644 --- a/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py +++ b/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py @@ -276,10 +276,12 @@ def forward( mode="decode", page_table=None, kv_cache=None, + cross_page_table=None, text_only_inference=False, vision_tokens=None, get_last_token=-1, ): + total_layer_idx = 0 # Used to track the total layer index for accessing the paged kv cache for idx, ( layer, xattn_layer, @@ -289,13 +291,18 @@ def forward( h = xattn_layer( h, xattn_mask=xattn_mask, - xattn_cache=xattn_caches[xattn_layer_idx], + xattn_cache=xattn_caches[xattn_layer_idx] + if cross_page_table is None + else kv_cache[total_layer_idx], full_text_row_masked_out_mask_1NSH=full_text_row_masked_out_mask_1NSH, full_text_row_masked_out_mask_11SD=full_text_row_masked_out_mask_11SD, mode=mode, user_id=user_id, vision_tokens=vision_tokens, + cross_page_table=cross_page_table, ) + if idx in self.fusion_schedule: + total_layer_idx += 1 h = layer( h, current_pos, @@ -303,8 +310,9 @@ def forward( user_id=user_id, mode=mode, page_table=page_table, - kv_cache=kv_cache, + kv_cache=kv_cache[total_layer_idx] if kv_cache is not None else None, ) + total_layer_idx += 1 if get_last_token != -1: h = ttnn.slice(h, (0, 0, get_last_token, 0), (1, 1, get_last_token + 32, h.shape[-1])) diff --git a/models/demos/llama3/tt/multimodal/llama_cross_block.py b/models/demos/llama3/tt/multimodal/llama_cross_block.py index 9d8c3760af0..90c693b9567 100644 --- a/models/demos/llama3/tt/multimodal/llama_cross_block.py +++ b/models/demos/llama3/tt/multimodal/llama_cross_block.py @@ -125,6 +125,7 @@ def forward( mode, user_id=0, vision_tokens=None, + cross_page_table=None, ): skip_mem_cfg = self.model_config["DECODE_RESIDUAL_MEMCFG"] if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG assert ( @@ -139,6 +140,7 @@ def forward( mode=mode, user_id=user_id, vision_tokens=vision_tokens, + cross_page_table=cross_page_table, ) # FIXME: DRAM workaround for No circular buffer with id error attn_out = ttnn.to_memory_config(attn_out, memory_config=ttnn.DRAM_MEMORY_CONFIG) diff --git a/models/demos/llama3/tt/multimodal/llama_vision_model.py b/models/demos/llama3/tt/multimodal/llama_vision_model.py index ef400f99275..e879f2d2342 100644 
--- a/models/demos/llama3/tt/multimodal/llama_vision_model.py +++ b/models/demos/llama3/tt/multimodal/llama_vision_model.py @@ -282,7 +282,13 @@ def prepare_inputs_common(self, position_ids, tokens): return h def prepare_inputs_prefill( - self, tokens, cross_attention_masks, full_text_row_masked_out_mask, prefill_len, page_table=None + self, + tokens, + cross_attention_masks, + full_text_row_masked_out_mask, + prefill_len, + page_table=None, + cross_page_table=None, ): B = tokens.shape[0] assert B == 1, f"Only batch 1 is supported, got {B}" @@ -361,6 +367,17 @@ def prepare_inputs_prefill( mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), ) + if isinstance(cross_page_table, torch.Tensor): + # Support vLLM tensor cross_page_table input + cross_page_table = ttnn.as_tensor( + cross_page_table, + device=self.mesh_device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + return ( tt_h, tt_xattn_mask, @@ -368,10 +385,17 @@ def prepare_inputs_prefill( tt_full_text_mask_expand_11SD, rot_mats, page_table, + cross_page_table, ) def prepare_inputs_decode( - self, tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id, page_table=None + self, + tokens, + cross_attention_masks, + full_text_row_masked_out_mask, + position_id, + page_table=None, + cross_page_table=None, ): ( tt_h, @@ -380,8 +404,14 @@ def prepare_inputs_decode( tt_position_id, tt_rope_id, tt_page_table, + tt_cross_page_table, ) = self.prepare_decode_inputs_host( - tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id, page_table=page_table + tokens, + cross_attention_masks, + full_text_row_masked_out_mask, + position_id, + page_table=page_table, + cross_page_table=cross_page_table, ) ( @@ -391,8 +421,17 @@ def prepare_inputs_decode( tt_position_id, tt_rope_id, tt_page_table, + tt_cross_page_table, ) = copy_host_to_device( - (tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, tt_rope_id, tt_page_table), + ( + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + tt_position_id, + tt_rope_id, + tt_page_table, + tt_cross_page_table, + ), mesh_device=self.mesh_device, ) @@ -411,15 +450,26 @@ def prepare_inputs_decode( tt_position_id, tt_rot_mats, tt_page_table, + tt_cross_page_table, ) def prepare_decode_inputs_host( - self, tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id, page_table=None + self, + tokens, + cross_attention_masks, + full_text_row_masked_out_mask, + position_id, + page_table=None, + cross_page_table=None, ): B = tokens.shape[0] assert ( B == self.configuration.max_batch_size ), f"Batch size must match max batch size. 
Got {B}, expected {self.configuration.max_batch_size}" + unpadded_batch_size = len(cross_attention_masks) + assert unpadded_batch_size == len( + full_text_row_masked_out_mask + ), f"cross_attention_masks batch dim ({unpadded_batch_size}) does not match full_text_row_masked_out_mask batch dim ({len(full_text_row_masked_out_mask)})" h = self.prepare_inputs_common(position_id, tokens) tt_h = self.configuration.prepare_residual_tensor_decode( h, @@ -435,8 +485,20 @@ def prepare_decode_inputs_host( mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), ) - tt_rope_id = self.text_model.rope_setup.get_rot_idxs(position_id, on_host=True) - xattn_mask = torch.cat([cross_attention_masks[i][:, :, position_id[i]] for i in range(B)], dim=1).unsqueeze(0) + rot_position_id = torch.maximum( + position_id, torch.tensor(0, dtype=torch.int64) + ) # Ensure position indices are non-negative + tt_rope_id = self.text_model.rope_setup.get_rot_idxs(rot_position_id, on_host=True) + + xattn_mask = torch.cat( + [cross_attention_masks[i][:, :, position_id[i]] for i in range(unpadded_batch_size)], dim=1 + ).unsqueeze(0) + # Pad xattn_mask along batch if tokens have been padded + if B > unpadded_batch_size: + xattn_mask = torch.cat( + [xattn_mask, torch.zeros(1, 1, B - unpadded_batch_size, xattn_mask.shape[-1])], dim=2 + ) + xattn_mask_expand = xattn_mask.expand(-1, self.configuration.n_heads // self.configuration.num_devices, -1, -1) xattn_mask_expand = xattn_mask_expand.transpose(1, 2).contiguous() @@ -449,8 +511,13 @@ def prepare_decode_inputs_host( ) full_text_mask = torch.cat( - [full_text_row_masked_out_mask[i][:, :, position_id[i]] for i in range(B)], dim=1 + [full_text_row_masked_out_mask[i][:, :, position_id[i]] for i in range(unpadded_batch_size)], dim=1 ).unsqueeze(0) + # Pad full_text_mask along batch if tokens have been padded + if B > unpadded_batch_size: + full_text_mask = torch.cat( + [full_text_mask, torch.zeros(1, 1, B - unpadded_batch_size, full_text_mask.shape[-1])], dim=2 + ) full_text_mask_expand_1NSH = full_text_mask.expand( -1, self.configuration.n_heads // self.configuration.num_devices, -1, self.configuration.head_dim ) @@ -472,6 +539,15 @@ def prepare_decode_inputs_host( mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), ) + if isinstance(cross_page_table, torch.Tensor): + # Support vLLM tensor cross_page_table input + cross_page_table = ttnn.as_tensor( + cross_page_table, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + return ( tt_h, tt_xattn_mask, @@ -479,6 +555,7 @@ def prepare_decode_inputs_host( tt_position_id, tt_rope_id, page_table, + cross_page_table, ) def transform_decode_inputs_device(self, tt_h, tt_rope_id, tt_xattn_mask, tt_full_text_mask_expand_1NSH, B): @@ -550,6 +627,7 @@ def forward( vision_tokens=None, page_table=None, kv_cache=None, + cross_page_table=None, ) -> torch.Tensor: """ This method takes torch tensors in, returns torch tensors. 
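# A small, self-contained torch sketch (not part of the patch) of the host-side batch padding
# added to prepare_decode_inputs_host above: when the token batch is padded to max_batch_size
# but only unpadded_batch_size users have masks, the per-position mask slices are concatenated
# for the real users and zero-padded up to B, and position indices are clamped to be
# non-negative before building rope indices. All shapes and values below are illustrative.
import torch

B = 4                                       # padded (max) batch size
unpadded_batch_size = 3                     # users that actually have cross-attention masks
S_text, S_vision = 128, 4224
position_id = torch.tensor([5, 9, 2, -1])   # -1 marks an empty/padded slot

# One mask per real user: [1, 1, S_text, S_vision]
cross_attention_masks = [torch.randn(1, 1, S_text, S_vision) for _ in range(unpadded_batch_size)]

xattn_mask = torch.cat(
    [cross_attention_masks[i][:, :, position_id[i]] for i in range(unpadded_batch_size)], dim=1
).unsqueeze(0)                                            # [1, 1, unpadded_batch_size, S_vision]
if B > unpadded_batch_size:
    xattn_mask = torch.cat(
        [xattn_mask, torch.zeros(1, 1, B - unpadded_batch_size, xattn_mask.shape[-1])], dim=2
    )                                                     # [1, 1, B, S_vision]

# Rotary position indices must be non-negative even for padded slots.
rot_position_id = torch.maximum(position_id, torch.tensor(0, dtype=torch.int64))
print(xattn_mask.shape, rot_position_id.tolist())         # torch.Size([1, 1, 4, 4224]) [5, 9, 2, 0]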
@@ -570,12 +648,14 @@ def forward( tt_position_id, rot_mats, tt_page_table, + tt_cross_page_table, ) = prepare_fn( tokens, cross_attention_masks, full_text_row_masked_out_mask, pos_arg, page_table=page_table, + cross_page_table=cross_page_table, ) logits = self.text_model.forward( @@ -592,6 +672,7 @@ def forward( kv_cache=kv_cache, text_only_inference=text_only_inference, vision_tokens=vision_tokens, + cross_page_table=tt_cross_page_table, ) tt_out = ttnn.to_layout(logits, ttnn.ROW_MAJOR_LAYOUT) @@ -611,10 +692,17 @@ def ttnn_prefill_forward( page_table=None, kv_cache=None, get_last_token=-1, + cross_page_table=None, ): """ This method runs prefill forward. It takes ttnn tensors in, returns ttnn tensors. """ + + if cross_page_table is not None: + assert ( + xattn_caches is None and kv_cache is not None + ), "no separate xattn_caches should be allocated when using cross_page_table with paged kv cache" + logits = self.text_model.forward( h, xattn_mask=xattn_mask, @@ -627,6 +715,7 @@ def ttnn_prefill_forward( mode="prefill", page_table=page_table, kv_cache=kv_cache, + cross_page_table=cross_page_table, vision_tokens=vision_tokens, get_last_token=get_last_token, ) @@ -643,10 +732,17 @@ def ttnn_decode_forward( rot_mats, page_table=None, kv_cache=None, + cross_page_table=None, ): """ This method runs decode forward. It takes ttnn tensors in, returns ttnn tensors. """ + + if cross_page_table is not None: + assert ( + xattn_caches is None and kv_cache is not None + ), "no separate xattn_caches should be allocated when using cross_page_table with paged kv cache" + logits = self.text_model.forward( h, xattn_mask=xattn_mask, @@ -658,6 +754,7 @@ def ttnn_decode_forward( mode="decode", page_table=page_table, kv_cache=kv_cache, + cross_page_table=cross_page_table, ) tt_out = ttnn.to_layout(logits, ttnn.ROW_MAJOR_LAYOUT) return tt_out
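# A pure-Python sketch (not part of the patch) of the flat kv_cache indexing introduced in
# llama_cross_attention_transformer_text.py. When a paged cross-attention cache is used
# (cross_page_table set and xattn_caches None, per the assertions in ttnn_prefill_forward /
# ttnn_decode_forward above), a single kv_cache list covers both the fused cross-attention
# blocks and the self-attention layers; total_layer_idx walks that list, consuming an extra
# slot whenever the layer index is in fusion_schedule. The counts below (8 text layers,
# cross-attention fused at layers 3 and 7) are assumptions for illustration, not the real
# Llama 3.2 fusion schedule.
n_text_layers = 8
fusion_schedule = [3, 7]                                      # layers that own a cross-attention block
kv_cache = [f"paged_cache_slot_{i}" for i in range(n_text_layers + len(fusion_schedule))]

slots = []
total_layer_idx = 0
for idx in range(n_text_layers):
    if idx in fusion_schedule:
        slots.append(("xattn", idx, kv_cache[total_layer_idx]))    # cross-attention K/V slot first
        total_layer_idx += 1
    slots.append(("self_attn", idx, kv_cache[total_layer_idx]))    # then the text layer's K/V slot
    total_layer_idx += 1

assert total_layer_idx == len(kv_cache)   # every slot consumed exactly once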