#10899: fix BERT memory overflow and layernorm shape mismatch issue
sjameelTT committed Jul 30, 2024
1 parent 29b6071 · commit 80330e1
Showing 2 changed files with 11 additions and 4 deletions.
models/demos/metal_BERT_large_11/tt/embeddings.py (13 changes: 9 additions & 4 deletions)

@@ -154,10 +154,6 @@ def __call__(
     embeddings_type=ttnn.EmbeddingsType.BINARY,
     memory_config=self.model_config["OUTPUT_EMBEDDINGS_MEMCFG"],
 )
-token_type_embeddings = ttnn.reshape(
-    token_type_embeddings,
-    [token_type_embeddings.shape[0], 1, token_type_embeddings.shape[1], token_type_embeddings.shape[2]],
-)
 token_type_ids.deallocate()

 if self.position_embedding_type == "absolute":
@@ -184,6 +180,15 @@ def __call__(
         position_embeddings_tt_tensor.shape[2],
     ],
 )
+inputs_plus_token_type_embeddings_tt_tensor = ttnn.reshape(
+    inputs_plus_token_type_embeddings_tt_tensor,
+    [
+        inputs_plus_token_type_embeddings_tt_tensor.shape[0],
+        1,
+        inputs_plus_token_type_embeddings_tt_tensor.shape[1],
+        inputs_plus_token_type_embeddings_tt_tensor.shape[2],
+    ],
+)
 # Deallocate inputs_embeds and token_type_embeddings here to avoid having to move final output
 if self.model_config["DEALLOC_INPUT_EMBEDS_AFTER_POSITION_EMBEDS"]:
     inputs_embeds.deallocate()
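The embeddings.py change moves the rank-3 to rank-4 reshape off the intermediate token_type_embeddings and onto the summed inputs_plus_token_type_embeddings_tt_tensor, inserting the size-1 dimension on the combined tensor right before later ops consume it; per the commit title this is what fixes the layernorm shape mismatch and the memory overflow. A minimal plain-Python sketch of the shape change performed by the added ttnn.reshape (the helper name and example sizes are illustrative, not part of the model code):

# Illustrative only: compute the rank-4 shape used by the ttnn.reshape added
# above, inserting a size-1 dimension after the batch dimension.
def to_rank4(shape_3d):
    batch, seq_len, hidden = shape_3d
    return [batch, 1, seq_len, hidden]

# Hypothetical BERT-large activation: batch 8, sequence 384, hidden 1024.
assert to_rank4([8, 384, 1024]) == [8, 1, 384, 1024]
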
models/demos/metal_BERT_large_11/tt/mha.py (2 changes: 2 additions & 0 deletions)

@@ -48,6 +48,8 @@ def op1_qkv_fused(activation, qkv_weight, qkv_bias):
     return qkv

 grid_size = model_config.get("GRID_SIZE", device.compute_with_storage_grid_size())
+if type(grid_size) == list:
+    grid_size = tt_lib.tensor.CoreCoord(tuple(grid_size))

 def op2_create_qkv_heads(qkv):
     (
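The mha.py change normalizes GRID_SIZE: when model_config supplies it as a plain list, it is converted to the CoreCoord type that device.compute_with_storage_grid_size() returns, so the rest of the attention setup sees a single type. A minimal sketch of that normalization pattern, with a namedtuple standing in for tt_lib.tensor.CoreCoord (the helper and variable names are illustrative, and the real constructor's argument handling may differ):

from collections import namedtuple

# Stand-in for tt_lib.tensor.CoreCoord, used only to illustrate the pattern.
CoreCoord = namedtuple("CoreCoord", ["x", "y"])

def normalize_grid_size(configured, device_default):
    # Configs loaded from JSON hand back plain [x, y] lists, while the device
    # query returns a CoreCoord-like object; normalize both to one type.
    if isinstance(configured, list):
        return CoreCoord(configured[0], configured[1])
    return configured if configured is not None else device_default

print(normalize_grid_size([12, 9], CoreCoord(12, 9)))  # CoreCoord(x=12, y=9)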
