diff --git a/models/demos/llama3/demo/simple_vision_demo.py b/models/demos/llama3/demo/simple_vision_demo.py index c8718256e62..a5bf099d027 100644 --- a/models/demos/llama3/demo/simple_vision_demo.py +++ b/models/demos/llama3/demo/simple_vision_demo.py @@ -119,7 +119,7 @@ def decode_forward( _, tt_position_id, rot_mats, - transformation_mats, + _, ) = self.model.prepare_inputs_decode( tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id=position_id ) @@ -131,7 +131,6 @@ def decode_forward( xattn_caches, tt_position_id, rot_mats, - transformation_mats, ) logits = self.model.process_output_decode(tt_logits, B, S) @@ -148,7 +147,171 @@ def capture_trace( """ Captures a trace for the decode_forward method. """ - pass + ( + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + _, + tt_position_id, + rot_mats, + _, + ) = self.model.prepare_inputs_decode( + tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id=position_id + ) + + # Compile run + tt_logits_rm = self.model.ttnn_decode_forward( + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + xattn_caches, + tt_position_id, + rot_mats, + ) + + # Get inputs ready for trace run + ( + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + _, + tt_position_id, + rot_mats, + _, + ) = self.model.prepare_decode_inputs_host( + tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id + ) + + ( + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + tt_position_id, + rot_mats, + ) = self.model.copy_host_to_device( + (tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, rot_mats) + ) + + trace_id = ttnn.begin_trace_capture(self.mesh_device, cq_id=0) + B = tokens.shape[0] + # Do on-device transformations of inputs before forward + tt_xattn_mask_transform, tt_full_text_mask_expand_1NSH_transform = self.model.transform_decode_inputs_device( + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + B=B, + ) + + tt_logits_rm = self.model.ttnn_decode_forward( + tt_h, + tt_xattn_mask_transform, + tt_full_text_mask_expand_1NSH_transform, + xattn_caches, + tt_position_id, + rot_mats, + ) + + ttnn.end_trace_capture(self.mesh_device, trace_id, cq_id=0) + + return trace_id, tt_logits_rm, tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, rot_mats + + def decode_forward_trace( + self, + position_id, + tokens, + cross_attention_masks, + full_text_row_masked_out_mask, + xattn_caches, # TODO: unused since captured in trace? + trace_id, + trace_logits_rm, + trace_h, + trace_xattn_mask, + trace_full_text_mask_expand_1NSH, + trace_position_id, + trace_rot_mats, + ): + ( + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + _, + tt_position_id, + rot_mats, + _, + ) = self.model.prepare_decode_inputs_host( + tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id=position_id + ) + + self.model.copy_host_to_device( + host_tensors=(tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, rot_mats), + device_tensors=( + trace_h, + trace_xattn_mask, + trace_full_text_mask_expand_1NSH, + trace_position_id, + trace_rot_mats, + ), + ) + + ttnn.execute_trace(self.mesh_device, trace_id, cq_id=0, blocking=False) + + B, S = tokens.shape + logits = self.model.process_output_decode(trace_logits_rm, B=B, S=S) + + return logits + + def easy_trace( + self, + position_id, + tokens, + cross_attention_masks, + full_text_row_masked_out_mask, + xattn_caches, + ): + """ + Tracing is easy! Just call this method and you'll run traced + """ + if not hasattr(self, "trace_id"): + ( + trace_id, + tt_logits_rm, + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + tt_position_id, + rot_mats, + ) = self.capture_trace( + position_id, + tokens, + cross_attention_masks, + full_text_row_masked_out_mask, + xattn_caches, + ) + self.trace_id = trace_id + self.trace_inputs = { + "tt_h": tt_h, + "tt_xattn_mask": tt_xattn_mask, + "tt_full_text_mask_expand_1NSH": tt_full_text_mask_expand_1NSH, + "tt_position_id": tt_position_id, + "rot_mats": rot_mats, + } + self.trace_outputs = { + "tt_logits_rm": tt_logits_rm, + } + + return self.decode_forward_trace( + position_id, + tokens, + cross_attention_masks, + full_text_row_masked_out_mask, + xattn_caches, + self.trace_id, + self.trace_outputs["tt_logits_rm"], + self.trace_inputs["tt_h"], + self.trace_inputs["tt_xattn_mask"], + self.trace_inputs["tt_full_text_mask_expand_1NSH"], + self.trace_inputs["tt_position_id"], + self.trace_inputs["rot_mats"], + ) def get_sampler(temperature, top_p, tokenizer): @@ -207,6 +370,7 @@ def create_multimodal_model(mesh_device, max_batch_size, max_seq_len, dtype=ttnn "normal", ], ) +@pytest.mark.parametrize("device_params", [{"trace_region_size": 14951424, "num_command_queues": 2}], indirect=True) def test_llama_multimodal_demo_text( mesh_device, warmup_iters, @@ -300,17 +464,48 @@ def test_llama_multimodal_demo_text( decode_times = [] + # Capture trace + # next_token_tensor = torch.tensor([next_token], dtype=torch.long).reshape(1, 1) # B, S + # trace_id, tt_logits_rm, tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, rot_mats = model.capture_trace( + # prefill_len, + # next_token_tensor, + # cross_attention_masks, + # full_text_row_masked_out_mask, + # xattn_caches, + # ) + for gen_idx in range(max_gen_len - 1): decode_start = time.perf_counter() position_id = prefill_len + gen_idx next_token_tensor = torch.tensor([next_token], dtype=torch.long).reshape(1, 1) # B, S - logits = model.decode_forward( + # logits = model.decode_forward( + # position_id, + # next_token_tensor, + # cross_attention_masks, + # full_text_row_masked_out_mask, + # xattn_caches, + # ) + logits = model.easy_trace( position_id, next_token_tensor, cross_attention_masks, full_text_row_masked_out_mask, xattn_caches, ) + # logits = model.decode_forward_trace( + # position_id, + # next_token_tensor, + # cross_attention_masks, + # full_text_row_masked_out_mask, + # xattn_caches, + # trace_id, + # tt_logits_rm, + # tt_h, + # tt_xattn_mask, + # tt_full_text_mask_expand_1NSH, + # tt_position_id, + # rot_mats + # ) next_token, text = sampler(logits) # Update next token tokens[0, position_id + 1] = next_token @@ -334,3 +529,5 @@ def test_llama_multimodal_demo_text( logger.info(f"Prefill time: {prefill_time_ms:.2f} ms") decode_time_ms = sum(decode_times) / (gen_idx + 1) * 1000 logger.info(f"Decode time: {decode_time_ms:.2f} ms") + + # ttnn.release_trace(model.mesh_device, trace_id) diff --git a/models/demos/llama3/tt/model_config.py b/models/demos/llama3/tt/model_config.py index 4c431643825..d8e8bf7c4fe 100644 --- a/models/demos/llama3/tt/model_config.py +++ b/models/demos/llama3/tt/model_config.py @@ -625,7 +625,7 @@ def ccl_topology(self): return ttnn.Topology.Linear return None - def prepare_inputs_ttnn_decode(self, x, input_mem_cfg, force_replicated=False): + def prepare_inputs_ttnn_decode(self, x, input_mem_cfg, force_replicated=False, on_host=False): """ Prepare inputs for decode mode. x: (batch, seq, dim) @@ -665,11 +665,11 @@ def prepare_inputs_ttnn_decode(self, x, input_mem_cfg, force_replicated=False): if torch.is_tensor(x): x = ttnn.from_torch( x, - device=self.mesh_device, + device=self.mesh_device if not on_host else None, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, mesh_mapper=mesh_mapper, - memory_config=input_mem_cfg, + memory_config=input_mem_cfg if not on_host else None, ) else: # Convert the row major layout from embedding back to tile layout x = ttnn.to_layout(x, layout=ttnn.TILE_LAYOUT) diff --git a/models/demos/llama3/tt/multimodal/llama_vision_model.py b/models/demos/llama3/tt/multimodal/llama_vision_model.py index 94f65ea13af..80a27df0679 100644 --- a/models/demos/llama3/tt/multimodal/llama_vision_model.py +++ b/models/demos/llama3/tt/multimodal/llama_vision_model.py @@ -376,20 +376,58 @@ def prepare_inputs_prefill(self, tokens, cross_attention_masks, full_text_row_ma ) def prepare_inputs_decode(self, tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id): + ( + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + _tt_full_text_mask_expand_11SD, + tt_position_id, + rot_mats, + _transformation_mats, + ) = self.prepare_decode_inputs_host(tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id) + + ( + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + tt_position_id, + rot_mats, + ) = self.copy_host_to_device((tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, rot_mats)) + + tt_xattn_mask, tt_full_text_mask_expand_1NSH = self.transform_decode_inputs_device( + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + B=tokens.shape[0], + ) + + return ( + tt_h, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + _tt_full_text_mask_expand_11SD, + tt_position_id, + rot_mats, + _transformation_mats, + ) + + def prepare_decode_inputs_host(self, tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id): B = tokens.shape[0] assert ( B == self.configuration.max_batch_size ), f"Batch size must match max batch size. Got {B}, expected {self.configuration.max_batch_size}" - S = 1 position_ids = torch.tensor([position_id], dtype=torch.long) h = self.prepare_inputs_common(position_ids, tokens) + tt_h = self.configuration.prepare_inputs_ttnn_decode( + h, + ttnn.DRAM_MEMORY_CONFIG, + on_host=True, + ) tt_position_id = ttnn.from_torch( position_ids, - device=self.mesh_device, + device=None, dtype=ttnn.int32, layout=ttnn.ROW_MAJOR_LAYOUT, - memory_config=ttnn.DRAM_MEMORY_CONFIG, mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), ) @@ -399,25 +437,11 @@ def prepare_inputs_decode(self, tokens, cross_attention_masks, full_text_row_mas tt_xattn_mask = ttnn.from_torch( xattn_mask_expand, - device=self.mesh_device, + device=None, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, - memory_config=ttnn.DRAM_MEMORY_CONFIG, mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), ) - tt_xattn_mask = ttnn.to_layout(tt_xattn_mask, ttnn.TILE_LAYOUT) - tt_xattn_mask = ttnn.reshape( - tt_xattn_mask, - shape=ttnn.Shape( - [ - S, - B, - self.configuration.n_heads // self.configuration.num_devices, - xattn_mask.shape[-1], - ], - [S, B, 32, xattn_mask.shape[-1]], - ), - ) full_text_mask = full_text_row_masked_out_mask[:, :, position_ids] full_text_mask_expand_1NSH = full_text_mask.expand( @@ -426,35 +450,12 @@ def prepare_inputs_decode(self, tokens, cross_attention_masks, full_text_row_mas full_text_mask_expand_1NSH = full_text_mask_expand_1NSH.transpose(1, 2).contiguous() tt_full_text_mask_expand_1NSH = ttnn.from_torch( full_text_mask_expand_1NSH, - device=self.mesh_device, + device=None, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, - memory_config=ttnn.DRAM_MEMORY_CONFIG, mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), ) - tt_full_text_mask_expand_1NSH = ttnn.to_layout(tt_full_text_mask_expand_1NSH, ttnn.TILE_LAYOUT) - tt_full_text_mask_expand_1NSH = ttnn.reshape( - tt_full_text_mask_expand_1NSH, - shape=ttnn.Shape( - [ - S, - B, - self.configuration.n_heads // self.configuration.num_devices, - self.configuration.head_dim, - ], - [ - S, - B, - 32, - self.configuration.head_dim, - ], - ), - ) - tt_h = self.configuration.prepare_inputs_ttnn_decode( - h, - ttnn.DRAM_MEMORY_CONFIG, - ) rot_mats, _ = get_single_rot_mat( self.configuration.head_dim, self.mesh_device, @@ -462,7 +463,9 @@ def prepare_inputs_decode(self, tokens, cross_attention_masks, full_text_row_mas start_pos=position_ids.item() - 1, # TODO: Change function to support decode batch > 1 # TODO: B must match max_batch_size, be careful batch=B, + on_host=True, ) + transformation_mats = None tt_full_text_mask_expand_11SD = None @@ -476,15 +479,72 @@ def prepare_inputs_decode(self, tokens, cross_attention_masks, full_text_row_mas transformation_mats, ) - def process_output_prefill(self, logits, B, S): + def copy_host_to_device(self, host_tensors, device_tensors=None): + """ + Helper function which copies host tensors to device tensors + """ + if device_tensors is None: + ret = [] + for i in range(len(host_tensors)): + on_device = ttnn.to_device(host_tensors[i], device=self.mesh_device) + ret.append(on_device) + return ret + else: + for i in range(len(host_tensors)): + ttnn.copy_host_to_device_tensor(host_tensors[i], device_tensors[i]) + return device_tensors + + def transform_decode_inputs_device(self, tt_xattn_mask, tt_full_text_mask_expand_1NSH, B): + """ + Does any transformations on device tensors which are necessary before ttnn_decode_forward + """ + print("transforming xattn mask") + assert ( + B == self.configuration.max_batch_size + ), f"Batch size must match max batch size. Got {B}, expected {self.configuration.max_batch_size}" + S = 1 + + tt_xattn_mask = ttnn.to_layout(tt_xattn_mask, ttnn.TILE_LAYOUT) + tt_xattn_mask = ttnn.reshape( + tt_xattn_mask, + shape=ttnn.Shape( + [ + S, + B, + self.configuration.n_heads // self.configuration.num_devices, + tt_xattn_mask.shape[-1], + ], + [S, B, 32, tt_xattn_mask.shape[-1]], + ), + ) + tt_full_text_mask_expand_1NSH = ttnn.to_layout(tt_full_text_mask_expand_1NSH, ttnn.TILE_LAYOUT) + tt_full_text_mask_expand_1NSH = ttnn.reshape( + tt_full_text_mask_expand_1NSH, + shape=ttnn.Shape( + [ + S, + B, + self.configuration.n_heads // self.configuration.num_devices, + self.configuration.head_dim, + ], + [ + S, + B, + 32, + self.configuration.head_dim, + ], + ), + ) + + return (tt_xattn_mask, tt_full_text_mask_expand_1NSH) + + def process_output_prefill(self, tt_out, B, S): padded_seq_len = _get_padded_prefill_seqlen(S) - tt_out = ttnn.to_layout(logits, ttnn.ROW_MAJOR_LAYOUT) tt_out = ttnn.to_torch(ttnn.get_device_tensors(tt_out)[0]).float() tt_out = tt_out[0].reshape(B, padded_seq_len, -1)[:, :S, :] return tt_out - def process_output_decode(self, logits, B, S): - tt_out = ttnn.to_layout(logits, ttnn.ROW_MAJOR_LAYOUT) + def process_output_decode(self, tt_out, B, S): tt_out = ttnn.to_torch(ttnn.get_device_tensors(tt_out)[0]).float() tt_out = tt_out[:, :, :B, :].reshape(B, S, -1) return tt_out @@ -538,9 +598,10 @@ def forward( mode=mode, text_only_inference=text_only_inference, ) + tt_out = ttnn.to_layout(logits, ttnn.ROW_MAJOR_LAYOUT) output_fn = self.process_output_decode if mode == "decode" else self.process_output_prefill - return output_fn(logits, B, S) + return output_fn(tt_out, B, S) def ttnn_prefill_forward( self, @@ -557,7 +618,7 @@ def ttnn_prefill_forward( """ This method runs prefill forward. It takes ttnn tensors in, returns ttnn tensors. """ - return self.text_model.forward( + logits = self.text_model.forward( h, xattn_mask=xattn_mask, full_text_row_masked_out_mask_1NSH=full_text_mas_expand_1NSH, @@ -569,6 +630,8 @@ def ttnn_prefill_forward( user_id=user_id, mode="prefill", ) + tt_out = ttnn.to_layout(logits, ttnn.ROW_MAJOR_LAYOUT) + return tt_out def ttnn_decode_forward( self, @@ -578,12 +641,11 @@ def ttnn_decode_forward( xattn_caches, position_id, rot_mats, - transformation_mats, ): """ This method runs decode forward. It takes ttnn tensors in, returns ttnn tensors. """ - return self.text_model.forward( + logits = self.text_model.forward( h, xattn_mask=xattn_mask, full_text_row_masked_out_mask_1NSH=full_text_mas_expand_1NSH, @@ -591,9 +653,10 @@ def ttnn_decode_forward( xattn_caches=xattn_caches, current_pos=position_id, rot_mat=rot_mats, - transformation_mats=transformation_mats, mode="decode", ) + tt_out = ttnn.to_layout(logits, ttnn.ROW_MAJOR_LAYOUT) + return tt_out def _stack_images(