#14519: Refactored decode input preparation to separate host tensor creation and device tensor transformations. Enabled tracing in simple_vision_demo with an easy trace function

cglagovichTT committed Nov 7, 2024
1 parent a413e64 commit eca2988
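
In short, decode input preparation is now split into a host-only step and a device-only step, so the device-side portion can be captured in a trace. A minimal sketch of the resulting call order, using the method names from the diff below (model here stands for the underlying vision model, i.e. self.model in the demo wrapper; argument names are illustrative, not exact demo code):

# 1) Build all decode inputs on host -- no device work, so this stays outside any trace.
(
    tt_h,
    tt_xattn_mask,
    tt_full_text_mask_expand_1NSH,
    _,
    tt_position_id,
    rot_mats,
    _,
) = model.prepare_decode_inputs_host(
    tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id
)

# 2) Copy the host tensors into device tensors.
(
    tt_h,
    tt_xattn_mask,
    tt_full_text_mask_expand_1NSH,
    tt_position_id,
    rot_mats,
) = model.copy_host_to_device((tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, rot_mats))

# 3) Apply the on-device input transformations -- the part a trace can capture.
tt_xattn_mask, tt_full_text_mask_expand_1NSH = model.transform_decode_inputs_device(
    tt_xattn_mask, tt_full_text_mask_expand_1NSH, B=tokens.shape[0]
)

# 4) Run the forward pass and read the logits back on host.
tt_logits = model.ttnn_decode_forward(
    tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, xattn_caches, tt_position_id, rot_mats
)
logits = model.process_output_decode(tt_logits, tokens.shape[0], tokens.shape[1])
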
Showing 3 changed files with 318 additions and 58 deletions.
205 changes: 201 additions & 4 deletions models/demos/llama3/demo/simple_vision_demo.py
@@ -119,7 +119,7 @@ def decode_forward(
_,
tt_position_id,
rot_mats,
transformation_mats,
_,
) = self.model.prepare_inputs_decode(
tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id=position_id
)
@@ -131,7 +131,6 @@ def decode_forward(
xattn_caches,
tt_position_id,
rot_mats,
transformation_mats,
)

logits = self.model.process_output_decode(tt_logits, B, S)
@@ -148,7 +147,171 @@ def capture_trace(
"""
Captures a trace for the decode_forward method.
"""
pass
(
tt_h,
tt_xattn_mask,
tt_full_text_mask_expand_1NSH,
_,
tt_position_id,
rot_mats,
_,
) = self.model.prepare_inputs_decode(
tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id=position_id
)

# Compile run
tt_logits_rm = self.model.ttnn_decode_forward(
tt_h,
tt_xattn_mask,
tt_full_text_mask_expand_1NSH,
xattn_caches,
tt_position_id,
rot_mats,
)

# Get inputs ready for trace run
(
tt_h,
tt_xattn_mask,
tt_full_text_mask_expand_1NSH,
_,
tt_position_id,
rot_mats,
_,
) = self.model.prepare_decode_inputs_host(
tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id
)

(
tt_h,
tt_xattn_mask,
tt_full_text_mask_expand_1NSH,
tt_position_id,
rot_mats,
) = self.model.copy_host_to_device(
(tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, rot_mats)
)

trace_id = ttnn.begin_trace_capture(self.mesh_device, cq_id=0)
B = tokens.shape[0]
# Do on-device transformations of inputs before forward
tt_xattn_mask_transform, tt_full_text_mask_expand_1NSH_transform = self.model.transform_decode_inputs_device(
tt_xattn_mask,
tt_full_text_mask_expand_1NSH,
B=B,
)

tt_logits_rm = self.model.ttnn_decode_forward(
tt_h,
tt_xattn_mask_transform,
tt_full_text_mask_expand_1NSH_transform,
xattn_caches,
tt_position_id,
rot_mats,
)

ttnn.end_trace_capture(self.mesh_device, trace_id, cq_id=0)

return trace_id, tt_logits_rm, tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, rot_mats
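
The structure above follows the usual two-phase ttnn tracing recipe: one eager run first (presumably so programs are compiled before capture), then the same device-side ops replayed between begin_trace_capture and end_trace_capture against persistent device inputs. Reduced to a sketch, with op standing in for any device-side forward (a placeholder, not demo code):

import ttnn

def capture_op_trace(mesh_device, op, device_inputs):
    # Eager "compile" run before capture.
    op(*device_inputs)
    # Re-run the same ops so they are recorded into the trace, bound to these
    # persistent device tensors.
    trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0)
    out = op(*device_inputs)
    ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0)
    return trace_id, out

# Per decode step: overwrite the persistent device inputs in place (decode_forward_trace
# below does this via copy_host_to_device), then replay without re-dispatching from host:
#     ttnn.execute_trace(mesh_device, trace_id, cq_id=0, blocking=False)
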

def decode_forward_trace(
self,
position_id,
tokens,
cross_attention_masks,
full_text_row_masked_out_mask,
xattn_caches, # TODO: unused since captured in trace?
trace_id,
trace_logits_rm,
trace_h,
trace_xattn_mask,
trace_full_text_mask_expand_1NSH,
trace_position_id,
trace_rot_mats,
):
(
tt_h,
tt_xattn_mask,
tt_full_text_mask_expand_1NSH,
_,
tt_position_id,
rot_mats,
_,
) = self.model.prepare_decode_inputs_host(
tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id=position_id
)

self.model.copy_host_to_device(
host_tensors=(tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, rot_mats),
device_tensors=(
trace_h,
trace_xattn_mask,
trace_full_text_mask_expand_1NSH,
trace_position_id,
trace_rot_mats,
),
)

ttnn.execute_trace(self.mesh_device, trace_id, cq_id=0, blocking=False)

B, S = tokens.shape
logits = self.model.process_output_decode(trace_logits_rm, B=B, S=S)

return logits

def easy_trace(
self,
position_id,
tokens,
cross_attention_masks,
full_text_row_masked_out_mask,
xattn_caches,
):
"""
Tracing is easy! Just call this method and you'll run traced
"""
if not hasattr(self, "trace_id"):
(
trace_id,
tt_logits_rm,
tt_h,
tt_xattn_mask,
tt_full_text_mask_expand_1NSH,
tt_position_id,
rot_mats,
) = self.capture_trace(
position_id,
tokens,
cross_attention_masks,
full_text_row_masked_out_mask,
xattn_caches,
)
self.trace_id = trace_id
self.trace_inputs = {
"tt_h": tt_h,
"tt_xattn_mask": tt_xattn_mask,
"tt_full_text_mask_expand_1NSH": tt_full_text_mask_expand_1NSH,
"tt_position_id": tt_position_id,
"rot_mats": rot_mats,
}
self.trace_outputs = {
"tt_logits_rm": tt_logits_rm,
}

return self.decode_forward_trace(
position_id,
tokens,
cross_attention_masks,
full_text_row_masked_out_mask,
xattn_caches,
self.trace_id,
self.trace_outputs["tt_logits_rm"],
self.trace_inputs["tt_h"],
self.trace_inputs["tt_xattn_mask"],
self.trace_inputs["tt_full_text_mask_expand_1NSH"],
self.trace_inputs["tt_position_id"],
self.trace_inputs["rot_mats"],
)
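
easy_trace captures and caches the trace on the wrapper the first time it is called, so the caller's decode loop stays a single call per step. A minimal usage sketch, mirroring the demo loop further down (sampler, prefill_len, max_gen_len and the token bookkeeping come from the demo, not from this method):

# First call compiles and captures the trace; later calls only refresh host inputs and replay.
for gen_idx in range(max_gen_len - 1):
    position_id = prefill_len + gen_idx
    next_token_tensor = torch.tensor([next_token], dtype=torch.long).reshape(1, 1)  # (B=1, S=1)
    logits = model.easy_trace(
        position_id,
        next_token_tensor,
        cross_attention_masks,
        full_text_row_masked_out_mask,
        xattn_caches,
    )
    next_token, text = sampler(logits)
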


def get_sampler(temperature, top_p, tokenizer):
@@ -207,6 +370,7 @@ def create_multimodal_model(mesh_device, max_batch_size, max_seq_len, dtype=ttnn
"normal",
],
)
@pytest.mark.parametrize("device_params", [{"trace_region_size": 14951424, "num_command_queues": 2}], indirect=True)
def test_llama_multimodal_demo_text(
mesh_device,
warmup_iters,
@@ -300,17 +464,48 @@ def test_llama_multimodal_demo_text(

decode_times = []

# Capture trace
# next_token_tensor = torch.tensor([next_token], dtype=torch.long).reshape(1, 1) # B, S
# trace_id, tt_logits_rm, tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, rot_mats = model.capture_trace(
# prefill_len,
# next_token_tensor,
# cross_attention_masks,
# full_text_row_masked_out_mask,
# xattn_caches,
# )

for gen_idx in range(max_gen_len - 1):
decode_start = time.perf_counter()
position_id = prefill_len + gen_idx
next_token_tensor = torch.tensor([next_token], dtype=torch.long).reshape(1, 1) # B, S
logits = model.decode_forward(
# logits = model.decode_forward(
# position_id,
# next_token_tensor,
# cross_attention_masks,
# full_text_row_masked_out_mask,
# xattn_caches,
# )
logits = model.easy_trace(
position_id,
next_token_tensor,
cross_attention_masks,
full_text_row_masked_out_mask,
xattn_caches,
)
# logits = model.decode_forward_trace(
# position_id,
# next_token_tensor,
# cross_attention_masks,
# full_text_row_masked_out_mask,
# xattn_caches,
# trace_id,
# tt_logits_rm,
# tt_h,
# tt_xattn_mask,
# tt_full_text_mask_expand_1NSH,
# tt_position_id,
# rot_mats
# )
next_token, text = sampler(logits)
# Update next token
tokens[0, position_id + 1] = next_token
@@ -334,3 +529,5 @@ def test_llama_multimodal_demo_text(
logger.info(f"Prefill time: {prefill_time_ms:.2f} ms")
decode_time_ms = sum(decode_times) / (gen_idx + 1) * 1000
logger.info(f"Decode time: {decode_time_ms:.2f} ms")

# ttnn.release_trace(model.mesh_device, trace_id)
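
One caveat when re-enabling the release above: with easy_trace the trace id is cached on the model wrapper rather than returned to the caller, so the cleanup would presumably look like this:

# Assumed cleanup once easy_trace has run at least once (trace_id lives on the wrapper):
if hasattr(model, "trace_id"):
    ttnn.release_trace(model.mesh_device, model.trace_id)
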
6 changes: 3 additions & 3 deletions models/demos/llama3/tt/model_config.py
@@ -625,7 +625,7 @@ def ccl_topology(self):
return ttnn.Topology.Linear
return None

def prepare_inputs_ttnn_decode(self, x, input_mem_cfg, force_replicated=False):
def prepare_inputs_ttnn_decode(self, x, input_mem_cfg, force_replicated=False, on_host=False):
"""
Prepare inputs for decode mode.
x: (batch, seq, dim)
Expand Down Expand Up @@ -665,11 +665,11 @@ def prepare_inputs_ttnn_decode(self, x, input_mem_cfg, force_replicated=False):
if torch.is_tensor(x):
x = ttnn.from_torch(
x,
device=self.mesh_device,
device=self.mesh_device if not on_host else None,
dtype=ttnn.bfloat16,
layout=ttnn.TILE_LAYOUT,
mesh_mapper=mesh_mapper,
memory_config=input_mem_cfg,
memory_config=input_mem_cfg if not on_host else None,
)
else: # Convert the row major layout from embedding back to tile layout
x = ttnn.to_layout(x, layout=ttnn.TILE_LAYOUT)
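
With on_host=True the tensor is created by ttnn.from_torch without a device or memory config, i.e. it stays on host so it can later be copied into a pre-allocated device tensor. A hedged sketch of the intended usage (model_args stands for the object defining this method; moving the tensor with ttnn.to_device is an assumption here -- the demo diff above does it through the model's copy_host_to_device helper):

# Host-only creation (new path): built with ttnn.from_torch, not placed on any device yet.
x_host = model_args.prepare_inputs_ttnn_decode(x, input_mem_cfg, on_host=True)

# Later, push it onto the mesh device with the memory config it would normally have used.
x_dev = ttnn.to_device(x_host, model_args.mesh_device, memory_config=input_mem_cfg)
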