#13180: change attention to use default sdpa decode to simplify args (removed cur_pos_attn)
caixunshiren committed Sep 27, 2024
1 parent af54a28 commit 49384c1
Showing 14 changed files with 18 additions and 55 deletions.
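Summary of the change: decode call sites previously had to build a second position tensor, replicated once per KV head (8 for this model), solely to drive the GQA-specific SDPA decode op. With the default SDPA decode op, the single per-user position tensor is enough. A minimal before/after sketch, condensed from the demo diffs below (curr_pos, batch_size, device, tt_model, decode_input, and current_rot_mat as defined in demo.py):

import torch
import ttnn

# Before: a per-user position tensor, plus a copy replicated per KV head
current_pos_tensor = ttnn.from_torch(torch.tensor([curr_pos] * batch_size), device=device, dtype=ttnn.int32)
current_pos_attn_tensor = ttnn.from_torch(
    torch.tensor([curr_pos] * batch_size * 8), device=device, dtype=ttnn.int32
)
tt_out = tt_model(decode_input, current_pos_tensor, current_pos_attn_tensor, rot_mat=current_rot_mat)

# After: the per-user position tensor alone drives the decode path
current_pos_tensor = ttnn.from_torch(torch.tensor([curr_pos] * batch_size), device=device, dtype=ttnn.int32)
tt_out = tt_model(decode_input, current_pos_tensor, rot_mat=current_rot_mat)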
5 changes: 1 addition & 4 deletions models/demos/wormhole/llama31_8b/demo/demo.py
@@ -161,12 +161,9 @@ def run_llama_demo(user_input, batch_size, device, instruct_mode, is_ci_env):
)

current_pos_tensor = ttnn.from_torch(torch.tensor([curr_pos] * batch_size), device=device, dtype=ttnn.int32)
-current_pos_attn_tensor = ttnn.from_torch(
-    torch.tensor([curr_pos] * batch_size * 8), device=device, dtype=ttnn.int32
-)

# Run ttnn llama model
-tt_out = tt_model(decode_input, current_pos_tensor, current_pos_attn_tensor, rot_mat=current_rot_mat)
+tt_out = tt_model(decode_input, current_pos_tensor, rot_mat=current_rot_mat)

# Get model output
tt_out_rm = ttnn.untilize(tt_out, use_multicore=True) # Row-major layout
@@ -322,7 +322,6 @@ def run_decode(
tt_out = model(
prefill_input,
0, # Current position
-None,
rot_mats_prefill,
transformation_mats,
user_id=batch_idx,
@@ -351,7 +350,7 @@ def run_decode(
model_args, tt_args.device, tokens_tensor, tt_embed, host_embed, indices_tensor
)
logger.info(f"Decoding batch with indices {batch_token_indices}")
-tt_out = model(decode_input, current_pos_tensor, current_pos_tensor, rot_mat=current_rot_mat)
+tt_out = model(decode_input, current_pos_tensor, rot_mat=current_rot_mat)
tt_out_rm = ttnn.untilize(tt_out, use_multicore=True)
ttnn.deallocate(tt_out)
ttnn.deallocate(current_rot_mat)
12 changes: 2 additions & 10 deletions models/demos/wormhole/llama31_8b/demo/demo_trace.py
@@ -245,7 +245,6 @@ def run_llama_demo(user_input, batch_size, device, instruct_mode, is_ci_env, num
tt_out = tt_model(
prefill_input,
None, # Current position
-None, # Current position for attention
rot_mats_prefill,
transformation_mats,
user_id=batch_id,
@@ -287,49 +286,42 @@ def run_llama_demo(user_input, batch_size, device, instruct_mode, is_ci_env, num
write_event = ttnn.create_event(device)

current_pos = ttnn.from_torch(torch.tensor(decoding_pos, dtype=torch.int32), device=device, dtype=ttnn.int32)
-current_pos_attn = ttnn.from_torch(
-    torch.tensor(decoding_pos * 8, dtype=torch.int32), device=device, dtype=ttnn.int32
-)

# Compile the trace (dry run of the model)
decode_input = ttnn.unsqueeze_to_4D(tt_embd(tt_out_tok))
-tt_out = tt_model(decode_input, current_pos, current_pos_attn, rot_mat=current_rot_mat)
+tt_out = tt_model(decode_input, current_pos, rot_mat=current_rot_mat)
tt_out_rm = ttnn.untilize(tt_out, use_multicore=True)
ttnn.deallocate(tt_out)
tt_out_tok = ttnn.argmax(tt_out_rm, dim=3, use_multicore=True, output_tensor=tt_out_tok)
ttnn.deallocate(tt_out_rm)
new_rot_mat = ttnn.linear(rot_matrix, current_rot_mat)
current_rot_mat = ttnn.copy(new_rot_mat, current_rot_mat)
ttnn.plus_one(current_pos)
-ttnn.plus_one(current_pos_attn)

# Capture Trace
trace_id = ttnn.begin_trace_capture(device, cq_id=0)

decode_input = ttnn.unsqueeze_to_4D(tt_embd(tt_out_tok))
-tt_out = tt_model(decode_input, current_pos, current_pos_attn, rot_mat=current_rot_mat)
+tt_out = tt_model(decode_input, current_pos, rot_mat=current_rot_mat)
tt_out_rm = ttnn.untilize(tt_out, use_multicore=True)
ttnn.deallocate(tt_out)
tt_out_tok = ttnn.argmax(tt_out_rm, dim=3, use_multicore=True, output_tensor=tt_out_tok)
ttnn.deallocate(tt_out_rm)
new_rot_mat = ttnn.linear(rot_matrix, current_rot_mat)
current_rot_mat = ttnn.copy(new_rot_mat, current_rot_mat)
ttnn.plus_one(current_pos)
-ttnn.plus_one(current_pos_attn)

ttnn.end_trace_capture(device, trace_id, cq_id=0)

# Reset the decoding position for the proper run of the model
current_pos_reset = ttnn.from_torch(torch.tensor(decoding_pos, dtype=torch.int32), dtype=ttnn.int32)
-current_pos_attn_reset = ttnn.from_torch(torch.tensor(decoding_pos * 8, dtype=torch.int32), dtype=ttnn.int32)
tt_out_tok_reset = ttnn.from_torch(
torch.nn.functional.pad(pt_out_batched.unsqueeze(0).unsqueeze(0).unsqueeze(0), (0, 31), "constant", 0),
dtype=ttnn.uint32,
)

# Update the reset tensors on device
ttnn.copy_host_to_device_tensor(current_pos_reset, current_pos)
-ttnn.copy_host_to_device_tensor(current_pos_attn_reset, current_pos_attn)
ttnn.copy_host_to_device_tensor(tt_out_tok_reset, tt_out_tok)

# Start decoding
7 changes: 1 addition & 6 deletions models/demos/wormhole/llama31_8b/demo/demo_with_prefill.py
@@ -273,7 +273,6 @@ def run_llama_demo(user_input, batch_size, device, instruct_mode, is_ci_env, num
tt_out = tt_model(
prefill_input,
0, # Current position
-None,
rot_mats_prefill,
transformation_mats,
user_id=batch_id,
@@ -293,7 +292,6 @@ def run_llama_demo(user_input, batch_size, device, instruct_mode, is_ci_env, num
tt_out = tt_model(
prefill_input,
0, # Current position
-0,
rot_mats_prefill,
transformation_mats,
user_id=batch_id,
@@ -344,15 +342,12 @@ def run_llama_demo(user_input, batch_size, device, instruct_mode, is_ci_env, num
tt_model.device,
)
current_pos_tensor = ttnn.from_torch(torch.tensor([curr_pos] * batch_size), device=device, dtype=ttnn.int32)
-current_pos_attn_tensor = ttnn.from_torch(
-    torch.tensor([curr_pos] * batch_size * 8), device=device, dtype=ttnn.int32
-)

profiler.end(f"prepare_input_decode", iteration=batch_idx)

profiler.start(f"decode_and_argmax", iteration=batch_idx)
# Run ttnn llama3.1 model
-tt_out = tt_model(decode_input, current_pos_tensor, current_pos_attn_tensor, rot_mat=current_rot_mat)
+tt_out = tt_model(decode_input, current_pos_tensor, rot_mat=current_rot_mat)

# Get model output
tt_out_rm = ttnn.untilize(tt_out, use_multicore=True)
@@ -62,17 +62,14 @@ def test_llama_attention_inference(device, use_program_cache, reset_seeds):
tt_attention_input = pt_attention_input.clone()
current_pos = generation_start_pos + i
current_pos_tensor = ttnn.from_torch(torch.tensor([current_pos] * batch), device=device, dtype=ttnn.int32)
-current_pos_attn_tensor = ttnn.from_torch(
-    torch.tensor([current_pos] * batch * 8), device=device, dtype=ttnn.int32
-)

attention_input = prepare_inputs_ttnn(
tt_attention_input,
model_args.dim,
device,
)

-tt_out = tt_model([attention_input], current_pos_tensor, current_pos_attn_tensor, rot_mats=current_rot_mat)
+tt_out = tt_model([attention_input], current_pos_tensor, rot_mats=current_rot_mat)
# multi-device attention module returns replicated output
assert isinstance(tt_out, list)
tt_out = tt_out[0]
@@ -70,7 +70,7 @@ def test_llama_attention_inference(seq_len, device, use_program_cache, reset_see
device,
)

-tt_out = tt_model([attention_input], 0, None, rot_mats, transformation_mats, user_id=0, mode="prefill")
+tt_out = tt_model([attention_input], 0, rot_mats, transformation_mats, user_id=0, mode="prefill")
# multi-device attention module returns replicated output
assert isinstance(tt_out, list)
tt_out = tt_out[0]
5 changes: 1 addition & 4 deletions models/demos/wormhole/llama31_8b/tests/test_llama_decoder.py
@@ -68,9 +68,6 @@ def test_llama_decoder_inference(device, use_program_cache, reset_seeds):
tt_decode_input = pt_decode_input.clone()
current_pos = generation_start_pos + i
current_pos_tensor = ttnn.from_torch(torch.tensor([current_pos] * batch), device=device, dtype=ttnn.int32)
-current_pos_attn_tensor = ttnn.from_torch(
-    torch.tensor([current_pos] * batch * 8), device=device, dtype=ttnn.int32
-)

decode_input = prepare_inputs_ttnn(
tt_decode_input,
@@ -79,7 +76,7 @@ def test_llama_decoder_inference(device, use_program_cache, reset_seeds):
)

# Run TT model
-tt_out = tt_model(decode_input, current_pos_tensor, current_pos_attn_tensor, rot_mat=current_rot_mat)
+tt_out = tt_model(decode_input, current_pos_tensor, rot_mat=current_rot_mat)
tt_output_torch = (
ttnn.to_torch(tt_out).permute(2, 1, 0, 3).squeeze(1)[: model_args.max_batch_size, :, :]
) # [seq, batch, hidden_dim]
@@ -86,7 +86,7 @@ def test_llama_decoder_inference(device, seq_len, use_program_cache, reset_seeds
attn_mask_torch = torch.triu(attn_mask, diagonal=1)
ref_output = reference_model(pt_decode_input, positions[0], freqs_cis_i, mask=attn_mask_torch)
# Run TT model
-tt_out = tt_model(decode_input, None, None, rot_mats, transformation_mats, user_id=0, mode="prefill")
+tt_out = tt_model(decode_input, None, rot_mats, transformation_mats, user_id=0, mode="prefill")
tt_output_torch = ttnn.to_torch(tt_out).view(batch, seq_len, -1) # [seq, batch, hidden_dim]
passing, pcc_message = comp_pcc(ref_output, tt_output_torch)

5 changes: 1 addition & 4 deletions models/demos/wormhole/llama31_8b/tests/test_llama_model.py
@@ -124,12 +124,9 @@ def test_llama_model_inference(device, weights, layers, use_program_cache, reset
tt_model.device,
)
current_pos_tensor = ttnn.from_torch(torch.tensor([current_pos] * batch), device=device, dtype=ttnn.int32)
-current_pos_attn_tensor = ttnn.from_torch(
-    torch.tensor([current_pos] * batch * 8), device=device, dtype=ttnn.int32
-)

# Run TT model
-tt_out = tt_model(decode_input, current_pos_tensor, current_pos_attn_tensor, rot_mat=current_rot_mat)
+tt_out = tt_model(decode_input, current_pos_tensor, rot_mat=current_rot_mat)
# Convert ttnn tensor to torch tensor
tt_output_torch = (
ttnn.to_torch(tt_out).permute(2, 1, 0, 3).squeeze(1)[: model_args.max_batch_size, :, :]
@@ -123,7 +123,7 @@ def test_llama_model_inference(device, seq_len, use_program_cache, reset_seeds):
for i in range(1):
start_pos = 0
# Run TT model
-tt_out = tt_model(decode_input, None, None, rot_mats, transformation_mats, user_id=i, mode="prefill")
+tt_out = tt_model(decode_input, None, rot_mats, transformation_mats, user_id=i, mode="prefill")
# Convert ttnn tensor to torch tensor
tt_output_torch = ttnn.to_torch(tt_out).view(batch, seq_len, -1) # [seq, batch, hidden_dim]

5 changes: 1 addition & 4 deletions models/demos/wormhole/llama31_8b/tests/test_llama_perf.py
@@ -148,13 +148,10 @@ def run_inference(device, tt_model, tt_embd, embd, encoded_prompts, generation_s
)

current_pos_tensor = ttnn.from_torch(torch.tensor([current_pos] * batch), device=device, dtype=ttnn.int32)
-current_pos_attn_tensor = ttnn.from_torch(
-    torch.tensor([current_pos] * batch * 8), device=device, dtype=ttnn.int32
-)

# Run TT model
profiler.start(f"model_run_for_inference_{i}")
-tt_out = tt_model(decode_input, current_pos_tensor, current_pos_attn_tensor, rot_mat=current_rot_mat)
+tt_out = tt_model(decode_input, current_pos_tensor, rot_mat=current_rot_mat)

# Convert ttnn tensor to torch tensor
profiler.start(f"result_wait_for_inference_{i}")
15 changes: 5 additions & 10 deletions models/demos/wormhole/llama31_8b/tt/llama_attention.py
@@ -219,13 +219,11 @@ def forward_decode(
self,
xs: List[ttnn.Tensor],
current_pos,
-current_pos_attn,
rot_mat=None,
) -> ttnn.Tensor:
"""
x: (seq_len, 1, batch, hidden_dim)
current_pos: (batch_size), current token position in the sequence for each user
-current_pos_attn: (batch_size * kv_heads[8]), current token position in the sequence for each KV_head (Required for SDPA_decode)
"""
dense_outputs = []
for i in range(self.num_devices):
@@ -279,7 +277,7 @@ def forward_decode(
q_heads_pre_rot_1BQD,
rot_mat,
program_config=self.model_config["ROT_MAT_BMM_PROGCFG"],
-memory_config=ttnn.DRAM_MEMORY_CONFIG,
+memory_config=q_heads_pre_rot_1BQD.memory_config(),
compute_kernel_config=self.compute_kernel_config_hifi2,
dtype=ttnn.bfloat16,
)
@@ -312,12 +310,11 @@ def forward_decode(
ttnn.deallocate(k_heads_1BKD)
ttnn.deallocate(v_heads_1BKD)

-attn_output_1G4D = ttnn.transformer.scaled_dot_product_attention_decode_gqa(
+attn_output_1G4D = ttnn.transformer.scaled_dot_product_attention_decode(
q_heads_1BQD,
keys,
values,
-cur_pos_tensor=current_pos_attn,
-transpose_q=False,
+cur_pos_tensor=current_pos,
scale=self.scale,
program_config=self.model_config["SDPA_DECODE_PROGCFG"],
compute_kernel_config=self.model_config["SDPA_DECODE_COMPUTE_PROGCFG"],
@@ -482,10 +479,8 @@ def forward_prefill(self, xs_11SH, rot_mats, transformation_mats, user_id: int =
attn_output_11SH.deallocate(True)
return [output_11SH]

-def forward(
-    self, xs, current_pos, current_pos_attn=None, rot_mats=None, transformation_mats=None, user_id=0, mode="decode"
-):
+def forward(self, xs, current_pos, rot_mats=None, transformation_mats=None, user_id=0, mode="decode"):
if mode == "prefill":
return self.forward_prefill(xs[0], rot_mats, transformation_mats, user_id)
else:
-return self.forward_decode(xs, current_pos, current_pos_attn, rot_mats)
+return self.forward_decode(xs, current_pos, rot_mats)
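The attention-side change in one place (a sketch of the new decode path, reusing the tensor names and configs from the diff above rather than a standalone runnable snippet): the default SDPA decode op consumes the same per-user position tensor as the rest of the decode path, so the per-KV-head cur_pos replication and the transpose_q flag are gone.

# Decode-mode SDPA call after this commit (sketch; names as in llama_attention.py)
attn_output_1G4D = ttnn.transformer.scaled_dot_product_attention_decode(
    q_heads_1BQD,  # queries, laid out (1, batch, q_heads, head_dim) per the 1BQD naming
    keys,          # KV-cache keys
    values,        # KV-cache values
    cur_pos_tensor=current_pos,  # shape (batch,), one position per user
    scale=self.scale,
    program_config=self.model_config["SDPA_DECODE_PROGCFG"],
    compute_kernel_config=self.model_config["SDPA_DECODE_COMPUTE_PROGCFG"],
)

Prefill call sites only lose the extra argument: they keep passing 0 or None for current_pos, which forward never reads in prefill mode.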
2 changes: 0 additions & 2 deletions models/demos/wormhole/llama31_8b/tt/llama_decoder.py
@@ -73,7 +73,6 @@ def forward(
self,
x: ttnn.Tensor,
current_pos,
-current_pos_attn,
rot_mat=None,
transformation_mats=None,
user_id=0,
@@ -88,7 +87,6 @@ def forward(
r = self.attention.forward(
[attn_norm],
current_pos,
-current_pos_attn,
rot_mat,
transformation_mats,
user_id,
3 changes: 1 addition & 2 deletions models/demos/wormhole/llama31_8b/tt/llama_model.py
@@ -68,15 +68,14 @@ def forward(
self,
x: ttnn.Tensor,
current_pos,
-current_pos_attn,
rot_mat=None,
transformation_mats=None,
user_id=0,
mode="decode",
get_last_token=-1,
):
for layer in self.layers:
-x = layer(x, current_pos, current_pos_attn, rot_mat, transformation_mats, user_id, mode)
+x = layer(x, current_pos, rot_mat, transformation_mats, user_id, mode)
if mode == "prefill" and get_last_token == -1:
return x
