From 49384c156d6692510e944a12edfd75408a50a8de Mon Sep 17 00:00:00 2001
From: Jack Cai
Date: Fri, 27 Sep 2024 13:36:15 +0000
Subject: [PATCH] #13180: change attention to use default sdpa decode to
 simplify args (removed cur_pos_attn)

---
 models/demos/wormhole/llama31_8b/demo/demo.py      |  5 +----
 .../llama31_8b/demo/demo_continuous_batching.py    |  3 +--
 .../demos/wormhole/llama31_8b/demo/demo_trace.py   | 12 ++----------
 .../wormhole/llama31_8b/demo/demo_with_prefill.py  |  7 +------
 .../llama31_8b/tests/test_llama_attention.py       |  5 +----
 .../tests/test_llama_attention_prefill.py          |  2 +-
 .../llama31_8b/tests/test_llama_decoder.py         |  5 +----
 .../tests/test_llama_decoder_prefill.py            |  2 +-
 .../wormhole/llama31_8b/tests/test_llama_model.py  |  5 +----
 .../llama31_8b/tests/test_llama_model_prefill.py   |  2 +-
 .../wormhole/llama31_8b/tests/test_llama_perf.py   |  5 +----
 .../wormhole/llama31_8b/tt/llama_attention.py      | 15 +++++----------
 .../demos/wormhole/llama31_8b/tt/llama_decoder.py  |  2 --
 .../demos/wormhole/llama31_8b/tt/llama_model.py    |  3 +--
 14 files changed, 18 insertions(+), 55 deletions(-)

diff --git a/models/demos/wormhole/llama31_8b/demo/demo.py b/models/demos/wormhole/llama31_8b/demo/demo.py
index daf0f5f91848..b7377b9417de 100644
--- a/models/demos/wormhole/llama31_8b/demo/demo.py
+++ b/models/demos/wormhole/llama31_8b/demo/demo.py
@@ -161,12 +161,9 @@ def run_llama_demo(user_input, batch_size, device, instruct_mode, is_ci_env):
         )
         current_pos_tensor = ttnn.from_torch(torch.tensor([curr_pos] * batch_size), device=device, dtype=ttnn.int32)
-        current_pos_attn_tensor = ttnn.from_torch(
-            torch.tensor([curr_pos] * batch_size * 8), device=device, dtype=ttnn.int32
-        )
 
         # Run ttnn llama model
-        tt_out = tt_model(decode_input, current_pos_tensor, current_pos_attn_tensor, rot_mat=current_rot_mat)
+        tt_out = tt_model(decode_input, current_pos_tensor, rot_mat=current_rot_mat)
 
         # Get model output
         tt_out_rm = ttnn.untilize(tt_out, use_multicore=True)  # Row-major layout
diff --git a/models/demos/wormhole/llama31_8b/demo/demo_continuous_batching.py b/models/demos/wormhole/llama31_8b/demo/demo_continuous_batching.py
index 39dbc1446467..fec61f0ff393 100644
--- a/models/demos/wormhole/llama31_8b/demo/demo_continuous_batching.py
+++ b/models/demos/wormhole/llama31_8b/demo/demo_continuous_batching.py
@@ -322,7 +322,6 @@ def run_decode(
         tt_out = model(
             prefill_input,
             0,  # Current position
-            None,
             rot_mats_prefill,
             transformation_mats,
             user_id=batch_idx,
@@ -351,7 +350,7 @@ def run_decode(
             model_args, tt_args.device, tokens_tensor, tt_embed, host_embed, indices_tensor
         )
         logger.info(f"Decoding batch with indices {batch_token_indices}")
-        tt_out = model(decode_input, current_pos_tensor, current_pos_tensor, rot_mat=current_rot_mat)
+        tt_out = model(decode_input, current_pos_tensor, rot_mat=current_rot_mat)
         tt_out_rm = ttnn.untilize(tt_out, use_multicore=True)
         ttnn.deallocate(tt_out)
         ttnn.deallocate(current_rot_mat)
diff --git a/models/demos/wormhole/llama31_8b/demo/demo_trace.py b/models/demos/wormhole/llama31_8b/demo/demo_trace.py
index cc42ebab512e..1fbc2e0fd10e 100644
--- a/models/demos/wormhole/llama31_8b/demo/demo_trace.py
+++ b/models/demos/wormhole/llama31_8b/demo/demo_trace.py
@@ -245,7 +245,6 @@ def run_llama_demo(user_input, batch_size, device, instruct_mode, is_ci_env, num
         tt_out = tt_model(
             prefill_input,
             None,  # Current position
-            None,  # Current position for attention
             rot_mats_prefill,
             transformation_mats,
             user_id=batch_id,
@@ -287,13 +286,10 @@ def run_llama_demo(user_input, batch_size, device, instruct_mode, is_ci_env, num
     write_event = ttnn.create_event(device)
 
     current_pos = ttnn.from_torch(torch.tensor(decoding_pos, dtype=torch.int32), device=device, dtype=ttnn.int32)
-    current_pos_attn = ttnn.from_torch(
-        torch.tensor(decoding_pos * 8, dtype=torch.int32), device=device, dtype=ttnn.int32
-    )
 
     # Compile the trace (dry run of the model)
     decode_input = ttnn.unsqueeze_to_4D(tt_embd(tt_out_tok))
-    tt_out = tt_model(decode_input, current_pos, current_pos_attn, rot_mat=current_rot_mat)
+    tt_out = tt_model(decode_input, current_pos, rot_mat=current_rot_mat)
     tt_out_rm = ttnn.untilize(tt_out, use_multicore=True)
     ttnn.deallocate(tt_out)
     tt_out_tok = ttnn.argmax(tt_out_rm, dim=3, use_multicore=True, output_tensor=tt_out_tok)
@@ -301,13 +297,12 @@ def run_llama_demo(user_input, batch_size, device, instruct_mode, is_ci_env, num
     new_rot_mat = ttnn.linear(rot_matrix, current_rot_mat)
     current_rot_mat = ttnn.copy(new_rot_mat, current_rot_mat)
     ttnn.plus_one(current_pos)
-    ttnn.plus_one(current_pos_attn)
 
     # Capture Trace
     trace_id = ttnn.begin_trace_capture(device, cq_id=0)
 
     decode_input = ttnn.unsqueeze_to_4D(tt_embd(tt_out_tok))
-    tt_out = tt_model(decode_input, current_pos, current_pos_attn, rot_mat=current_rot_mat)
+    tt_out = tt_model(decode_input, current_pos, rot_mat=current_rot_mat)
     tt_out_rm = ttnn.untilize(tt_out, use_multicore=True)
     ttnn.deallocate(tt_out)
     tt_out_tok = ttnn.argmax(tt_out_rm, dim=3, use_multicore=True, output_tensor=tt_out_tok)
@@ -315,13 +310,11 @@ def run_llama_demo(user_input, batch_size, device, instruct_mode, is_ci_env, num
     new_rot_mat = ttnn.linear(rot_matrix, current_rot_mat)
     current_rot_mat = ttnn.copy(new_rot_mat, current_rot_mat)
     ttnn.plus_one(current_pos)
-    ttnn.plus_one(current_pos_attn)
 
     ttnn.end_trace_capture(device, trace_id, cq_id=0)
 
     # Reset the decoding position for the proper run of the model
     current_pos_reset = ttnn.from_torch(torch.tensor(decoding_pos, dtype=torch.int32), dtype=ttnn.int32)
-    current_pos_attn_reset = ttnn.from_torch(torch.tensor(decoding_pos * 8, dtype=torch.int32), dtype=ttnn.int32)
     tt_out_tok_reset = ttnn.from_torch(
         torch.nn.functional.pad(pt_out_batched.unsqueeze(0).unsqueeze(0).unsqueeze(0), (0, 31), "constant", 0),
         dtype=ttnn.uint32,
@@ -329,7 +322,6 @@ def run_llama_demo(user_input, batch_size, device, instruct_mode, is_ci_env, num
 
     # Update the resetted tensors on device
     ttnn.copy_host_to_device_tensor(current_pos_reset, current_pos)
-    ttnn.copy_host_to_device_tensor(current_pos_attn_reset, current_pos_attn)
     ttnn.copy_host_to_device_tensor(tt_out_tok_reset, tt_out_tok)
 
     # Start decoding
diff --git a/models/demos/wormhole/llama31_8b/demo/demo_with_prefill.py b/models/demos/wormhole/llama31_8b/demo/demo_with_prefill.py
index 5b09d07b0af9..4103d8b097f5 100644
--- a/models/demos/wormhole/llama31_8b/demo/demo_with_prefill.py
+++ b/models/demos/wormhole/llama31_8b/demo/demo_with_prefill.py
@@ -273,7 +273,6 @@ def run_llama_demo(user_input, batch_size, device, instruct_mode, is_ci_env, num
             tt_out = tt_model(
                 prefill_input,
                 0,  # Current position
-                None,
                 rot_mats_prefill,
                 transformation_mats,
                 user_id=batch_id,
@@ -293,7 +292,6 @@ def run_llama_demo(user_input, batch_size, device, instruct_mode, is_ci_env, num
             tt_out = tt_model(
                 prefill_input,
                 0,  # Current position
-                0,
                 rot_mats_prefill,
                 transformation_mats,
                 user_id=batch_id,
@@ -344,15 +342,12 @@ def run_llama_demo(user_input, batch_size, device, instruct_mode, is_ci_env, num
             tt_model.device,
         )
         current_pos_tensor = ttnn.from_torch(torch.tensor([curr_pos] * batch_size), device=device, dtype=ttnn.int32)
-        current_pos_attn_tensor = ttnn.from_torch(
-            torch.tensor([curr_pos] * batch_size * 8), device=device, dtype=ttnn.int32
-        )
 
         profiler.end(f"prepare_input_decode", iteration=batch_idx)
 
         profiler.start(f"decode_and_argmax", iteration=batch_idx)
         # Run ttnn llama3.1 model
-        tt_out = tt_model(decode_input, current_pos_tensor, current_pos_attn_tensor, rot_mat=current_rot_mat)
+        tt_out = tt_model(decode_input, current_pos_tensor, rot_mat=current_rot_mat)
 
         # Get model output
         tt_out_rm = ttnn.untilize(tt_out, use_multicore=True)
diff --git a/models/demos/wormhole/llama31_8b/tests/test_llama_attention.py b/models/demos/wormhole/llama31_8b/tests/test_llama_attention.py
index 0ebc7bbc08be..3dabb7329d71 100644
--- a/models/demos/wormhole/llama31_8b/tests/test_llama_attention.py
+++ b/models/demos/wormhole/llama31_8b/tests/test_llama_attention.py
@@ -62,9 +62,6 @@ def test_llama_attention_inference(device, use_program_cache, reset_seeds):
         tt_attention_input = pt_attention_input.clone()
         current_pos = generation_start_pos + i
         current_pos_tensor = ttnn.from_torch(torch.tensor([current_pos] * batch), device=device, dtype=ttnn.int32)
-        current_pos_attn_tensor = ttnn.from_torch(
-            torch.tensor([current_pos] * batch * 8), device=device, dtype=ttnn.int32
-        )
 
         attention_input = prepare_inputs_ttnn(
             tt_attention_input,
@@ -72,7 +69,7 @@ def test_llama_attention_inference(device, use_program_cache, reset_seeds):
             device,
         )
 
-        tt_out = tt_model([attention_input], current_pos_tensor, current_pos_attn_tensor, rot_mats=current_rot_mat)
+        tt_out = tt_model([attention_input], current_pos_tensor, rot_mats=current_rot_mat)
         # multi-device attention module returns replicated output
         assert isinstance(tt_out, list)
         tt_out = tt_out[0]
diff --git a/models/demos/wormhole/llama31_8b/tests/test_llama_attention_prefill.py b/models/demos/wormhole/llama31_8b/tests/test_llama_attention_prefill.py
index b265777aba3d..2ae7ac1c6f29 100644
--- a/models/demos/wormhole/llama31_8b/tests/test_llama_attention_prefill.py
+++ b/models/demos/wormhole/llama31_8b/tests/test_llama_attention_prefill.py
@@ -70,7 +70,7 @@ def test_llama_attention_inference(seq_len, device, use_program_cache, reset_see
         device,
     )
 
-    tt_out = tt_model([attention_input], 0, None, rot_mats, transformation_mats, user_id=0, mode="prefill")
+    tt_out = tt_model([attention_input], 0, rot_mats, transformation_mats, user_id=0, mode="prefill")
     # multi-device attention module returns replicated output
    assert isinstance(tt_out, list)
    tt_out = tt_out[0]
diff --git a/models/demos/wormhole/llama31_8b/tests/test_llama_decoder.py b/models/demos/wormhole/llama31_8b/tests/test_llama_decoder.py
index 30638c951253..b3c21e156f5b 100644
--- a/models/demos/wormhole/llama31_8b/tests/test_llama_decoder.py
+++ b/models/demos/wormhole/llama31_8b/tests/test_llama_decoder.py
@@ -68,9 +68,6 @@ def test_llama_decoder_inference(device, use_program_cache, reset_seeds):
         tt_decode_input = pt_decode_input.clone()
         current_pos = generation_start_pos + i
         current_pos_tensor = ttnn.from_torch(torch.tensor([current_pos] * batch), device=device, dtype=ttnn.int32)
-        current_pos_attn_tensor = ttnn.from_torch(
-            torch.tensor([current_pos] * batch * 8), device=device, dtype=ttnn.int32
-        )
 
         decode_input = prepare_inputs_ttnn(
             tt_decode_input,
@@ -78,7 +75,7 @@ def test_llama_decoder_inference(device, use_program_cache, reset_seeds):
            device,
        )
        # Run TT model
-        tt_out = tt_model(decode_input, current_pos_tensor, current_pos_attn_tensor, rot_mat=current_rot_mat)
+        tt_out = tt_model(decode_input, current_pos_tensor, rot_mat=current_rot_mat)
         tt_output_torch = (
             ttnn.to_torch(tt_out).permute(2, 1, 0, 3).squeeze(1)[: model_args.max_batch_size, :, :]
         )  # [seq, batch, hidden_dim]
diff --git a/models/demos/wormhole/llama31_8b/tests/test_llama_decoder_prefill.py b/models/demos/wormhole/llama31_8b/tests/test_llama_decoder_prefill.py
index 80b39dd6c6e3..5a7d3aca2a8c 100644
--- a/models/demos/wormhole/llama31_8b/tests/test_llama_decoder_prefill.py
+++ b/models/demos/wormhole/llama31_8b/tests/test_llama_decoder_prefill.py
@@ -86,7 +86,7 @@ def test_llama_decoder_inference(device, seq_len, use_program_cache, reset_seeds
     attn_mask_torch = torch.triu(attn_mask, diagonal=1)
     ref_output = reference_model(pt_decode_input, positions[0], freqs_cis_i, mask=attn_mask_torch)
     # Run TT model
-    tt_out = tt_model(decode_input, None, None, rot_mats, transformation_mats, user_id=0, mode="prefill")
+    tt_out = tt_model(decode_input, None, rot_mats, transformation_mats, user_id=0, mode="prefill")
     tt_output_torch = ttnn.to_torch(tt_out).view(batch, seq_len, -1)  # [seq, batch, hidden_dim]
 
     passing, pcc_message = comp_pcc(ref_output, tt_output_torch)
diff --git a/models/demos/wormhole/llama31_8b/tests/test_llama_model.py b/models/demos/wormhole/llama31_8b/tests/test_llama_model.py
index 9cf4da67e2ae..16caa4c4145f 100644
--- a/models/demos/wormhole/llama31_8b/tests/test_llama_model.py
+++ b/models/demos/wormhole/llama31_8b/tests/test_llama_model.py
@@ -124,12 +124,9 @@ def test_llama_model_inference(device, weights, layers, use_program_cache, reset
             tt_model.device,
         )
         current_pos_tensor = ttnn.from_torch(torch.tensor([current_pos] * batch), device=device, dtype=ttnn.int32)
-        current_pos_attn_tensor = ttnn.from_torch(
-            torch.tensor([current_pos] * batch * 8), device=device, dtype=ttnn.int32
-        )
 
         # Run TT model
-        tt_out = tt_model(decode_input, current_pos_tensor, current_pos_attn_tensor, rot_mat=current_rot_mat)
+        tt_out = tt_model(decode_input, current_pos_tensor, rot_mat=current_rot_mat)
         # Convert ttnn tensor to torch tensor
         tt_output_torch = (
             ttnn.to_torch(tt_out).permute(2, 1, 0, 3).squeeze(1)[: model_args.max_batch_size, :, :]
diff --git a/models/demos/wormhole/llama31_8b/tests/test_llama_model_prefill.py b/models/demos/wormhole/llama31_8b/tests/test_llama_model_prefill.py
index 086c10c15396..4628cae5be80 100644
--- a/models/demos/wormhole/llama31_8b/tests/test_llama_model_prefill.py
+++ b/models/demos/wormhole/llama31_8b/tests/test_llama_model_prefill.py
@@ -123,7 +123,7 @@ def test_llama_model_inference(device, seq_len, use_program_cache, reset_seeds):
     for i in range(1):
         start_pos = 0
         # Run TT model
-        tt_out = tt_model(decode_input, None, None, rot_mats, transformation_mats, user_id=i, mode="prefill")
+        tt_out = tt_model(decode_input, None, rot_mats, transformation_mats, user_id=i, mode="prefill")
         # Convert ttnn tensor to torch tensor
         tt_output_torch = ttnn.to_torch(tt_out).view(batch, seq_len, -1)  # [seq, batch, hidden_dim]
 
diff --git a/models/demos/wormhole/llama31_8b/tests/test_llama_perf.py b/models/demos/wormhole/llama31_8b/tests/test_llama_perf.py
index dce87212d110..81f1a03d11f6 100644
--- a/models/demos/wormhole/llama31_8b/tests/test_llama_perf.py
+++ b/models/demos/wormhole/llama31_8b/tests/test_llama_perf.py
@@ -148,13 +148,10 @@ def run_inference(device, tt_model, tt_embd, embd, encoded_prompts, generation_s
         )
         current_pos_tensor = ttnn.from_torch(torch.tensor([current_pos] * batch), device=device, dtype=ttnn.int32)
-        current_pos_attn_tensor = ttnn.from_torch(
-            torch.tensor([current_pos] * batch * 8), device=device, dtype=ttnn.int32
-        )
 
         # Run TT model
         profiler.start(f"model_run_for_inference_{i}")
-        tt_out = tt_model(decode_input, current_pos_tensor, current_pos_attn_tensor, rot_mat=current_rot_mat)
+        tt_out = tt_model(decode_input, current_pos_tensor, rot_mat=current_rot_mat)
 
         # Convert ttnn tensor to torch tensor
         profiler.start(f"result_wait_for_inference_{i}")
diff --git a/models/demos/wormhole/llama31_8b/tt/llama_attention.py b/models/demos/wormhole/llama31_8b/tt/llama_attention.py
index 134730d58e27..0d4d1c788e9b 100644
--- a/models/demos/wormhole/llama31_8b/tt/llama_attention.py
+++ b/models/demos/wormhole/llama31_8b/tt/llama_attention.py
@@ -219,13 +219,11 @@ def forward_decode(
         self,
         xs: List[ttnn.Tensor],
         current_pos,
-        current_pos_attn,
         rot_mat=None,
     ) -> ttnn.Tensor:
         """
         x: (seq_len, 1, batch, hidden_dim)
         current_pos: (batch_size), current token position in the sequence for each user
-        current_pos_attn: (batch_size * kv_heads[8]), current token position in the sequence for each KV_head (Required for SDPA_decode)
         """
         dense_outputs = []
         for i in range(self.num_devices):
@@ -279,7 +277,7 @@ def forward_decode(
                 q_heads_pre_rot_1BQD,
                 rot_mat,
                 program_config=self.model_config["ROT_MAT_BMM_PROGCFG"],
-                memory_config=ttnn.DRAM_MEMORY_CONFIG,
+                memory_config=q_heads_pre_rot_1BQD.memory_config(),
                 compute_kernel_config=self.compute_kernel_config_hifi2,
                 dtype=ttnn.bfloat16,
             )
@@ -312,12 +310,11 @@ def forward_decode(
             ttnn.deallocate(k_heads_1BKD)
             ttnn.deallocate(v_heads_1BKD)
 
-            attn_output_1G4D = ttnn.transformer.scaled_dot_product_attention_decode_gqa(
+            attn_output_1G4D = ttnn.transformer.scaled_dot_product_attention_decode(
                 q_heads_1BQD,
                 keys,
                 values,
-                cur_pos_tensor=current_pos_attn,
-                transpose_q=False,
+                cur_pos_tensor=current_pos,
                 scale=self.scale,
                 program_config=self.model_config["SDPA_DECODE_PROGCFG"],
                 compute_kernel_config=self.model_config["SDPA_DECODE_COMPUTE_PROGCFG"],
@@ -482,10 +479,8 @@ def forward_prefill(self, xs_11SH, rot_mats, transformation_mats, user_id: int =
         attn_output_11SH.deallocate(True)
         return [output_11SH]
 
-    def forward(
-        self, xs, current_pos, current_pos_attn=None, rot_mats=None, transformation_mats=None, user_id=0, mode="decode"
-    ):
+    def forward(self, xs, current_pos, rot_mats=None, transformation_mats=None, user_id=0, mode="decode"):
         if mode == "prefill":
             return self.forward_prefill(xs[0], rot_mats, transformation_mats, user_id)
         else:
-            return self.forward_decode(xs, current_pos, current_pos_attn, rot_mats)
+            return self.forward_decode(xs, current_pos, rot_mats)
diff --git a/models/demos/wormhole/llama31_8b/tt/llama_decoder.py b/models/demos/wormhole/llama31_8b/tt/llama_decoder.py
index a40a9e9696c4..71e0333b06f4 100644
--- a/models/demos/wormhole/llama31_8b/tt/llama_decoder.py
+++ b/models/demos/wormhole/llama31_8b/tt/llama_decoder.py
@@ -73,7 +73,6 @@ def forward(
         self,
         x: ttnn.Tensor,
         current_pos,
-        current_pos_attn,
         rot_mat=None,
         transformation_mats=None,
         user_id=0,
@@ -88,7 +87,6 @@ def forward(
         r = self.attention.forward(
             [attn_norm],
             current_pos,
-            current_pos_attn,
             rot_mat,
             transformation_mats,
             user_id,
diff --git a/models/demos/wormhole/llama31_8b/tt/llama_model.py b/models/demos/wormhole/llama31_8b/tt/llama_model.py
index 5e2582e41105..21a9a792cf1b 100644
--- a/models/demos/wormhole/llama31_8b/tt/llama_model.py
+++ b/models/demos/wormhole/llama31_8b/tt/llama_model.py
@@ -68,7 +68,6 @@ def forward(
         self,
         x: ttnn.Tensor,
         current_pos,
-        current_pos_attn,
         rot_mat=None,
         transformation_mats=None,
         user_id=0,
@@ -76,7 +75,7 @@ def forward(
         get_last_token=-1,
     ):
         for layer in self.layers:
-            x = layer(x, current_pos, current_pos_attn, rot_mat, transformation_mats, user_id, mode)
+            x = layer(x, current_pos, rot_mat, transformation_mats, user_id, mode)
 
         if mode == "prefill" and get_last_token == -1:
             return x
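
Note: as a usage reference, here is a minimal sketch of the decode-path call after this patch. It assumes `device`, `tt_model`, `decode_input`, and `current_rot_mat` are already initialized as in the demos above; the batch size and position values are illustrative only.

    import torch
    import ttnn

    batch_size = 8   # illustrative batch size
    curr_pos = 127   # illustrative generation position, same for every user

    # One int32 position per user. The old companion tensor replicated per
    # KV head (torch.tensor([curr_pos] * batch_size * 8)) is no longer built.
    current_pos_tensor = ttnn.from_torch(
        torch.tensor([curr_pos] * batch_size), device=device, dtype=ttnn.int32
    )

    # The forward call now takes only the input activations, the per-user
    # positions, and the rotary matrix; SDPA decode reads the positions from
    # the same tensor via cur_pos_tensor.
    tt_out = tt_model(decode_input, current_pos_tensor, rot_mat=current_rot_mat)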