#4: remove unnecessary ops from c2psa, detect sub-modules
vguduruTT committed Feb 28, 2025
1 parent 86f2997 commit 883e415
Showing 3 changed files with 33 additions and 27 deletions.
@@ -62,7 +62,7 @@ def attempt_load(weights, map_location=None):

@pytest.mark.parametrize("device_params", [{"l1_small_size": 79104}], indirect=True)
def test_yolov11(device, use_program_cache, reset_seeds):
- torch_input, ttnn_input = create_yolov11_input_tensors(device, input_channels=3, input_height=224, input_width=224)
+ torch_input, ttnn_input = create_yolov11_input_tensors(device, input_channels=3, input_height=640, input_width=640)

torch_model = attempt_load("yolov11n.pt", map_location="cpu")
state_dict = torch_model.state_dict()
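
For reference, a minimal sketch of what an input pair at the new resolution could look like, assuming create_yolov11_input_tensors simply builds a random NCHW torch activation and a matching bfloat16 TT-NN tensor (the real helper may differ, e.g. in layout or channel handling):

    import torch
    import ttnn

    def make_yolov11_inputs(device, input_channels=3, input_height=640, input_width=640):
        # Hypothetical stand-in for create_yolov11_input_tensors: a random NCHW
        # activation at YOLOv11's native 640x640 resolution, plus a bfloat16
        # tile-layout copy on the target device.
        torch_input = torch.randn(1, input_channels, input_height, input_width)
        ttnn_input = ttnn.from_torch(
            torch_input, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device
        )
        return torch_input, ttnn_input
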
@@ -39,8 +39,12 @@ def make_anchors(device, feats, strides, grid_cell_offset=0.5):
b = torch.cat(stride_tensor).transpose(0, 1)

return (
- ttnn.from_torch(a, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device),
- ttnn.from_torch(b, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device),
+ ttnn.from_torch(
+ a, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.L1_MEMORY_CONFIG
+ ),
+ ttnn.from_torch(
+ b, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.L1_MEMORY_CONFIG
+ ),
)
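
Creating the anchor and stride tensors directly in L1 means the Detect head no longer needs per-call ttnn.to_memory_config conversions (they are commented out in the last hunk below). A minimal sketch of the pattern, assuming t is any host-side torch tensor such as the concatenated anchor points or strides built above:

    import torch
    import ttnn

    def to_l1_tile_tensor(device, t: torch.Tensor):
        # Hypothetical helper: convert a host torch tensor into a bfloat16,
        # tile-layout device tensor allocated in L1 up front, so no follow-up
        # memory-config move is required at inference time.
        return ttnn.from_torch(
            t,
            dtype=ttnn.bfloat16,
            layout=ttnn.TILE_LAYOUT,
            device=device,
            memory_config=ttnn.L1_MEMORY_CONFIG,
        )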


50 changes: 26 additions & 24 deletions models/experimental/functional_yolov11/tt/ttnn_yolov11.py
@@ -357,15 +357,9 @@ def __call__(self, device, x, batch_size=1):
qkv = self.qkv(device, x) # [1, 1, 49[64], 256]
qkv = ttnn.sharded_to_interleaved(qkv, memory_config=ttnn.L1_MEMORY_CONFIG)
qkv = ttnn.permute(qkv, (0, 3, 1, 2)) # [1,256,1,49]
print("before qkv details", qkv.shape, qkv.layout, qkv.dtype)
# qkv = ttnn.to_torch(qkv)
# qkv = ttnn.from_torch(
# qkv, dtype=ttnn.bfloat16, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.L1_MEMORY_CONFIG
# )
qkv = ttnn.to_layout(qkv, layout=ttnn.ROW_MAJOR_LAYOUT)
qkv = ttnn.to_dtype(qkv, ttnn.bfloat16)
qkv = ttnn.to_layout(qkv, layout=ttnn.TILE_LAYOUT)
print("after qkv details", qkv.shape, qkv.layout, qkv.dtype)
qkv = ttnn.reshape(
qkv, (batch_size, self.num_heads, self.key_dim * 2 + self.head_dim, qkv.shape[-1])
) # [1,2,128,49]
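
The hunk above only strips debug output: the two prints and the already commented-out host round-trip through ttnn.to_torch/ttnn.from_torch are removed, while the on-device to_layout/to_dtype normalisation of qkv is kept unchanged.
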
@@ -386,10 +380,10 @@ def __call__(self, device, x, batch_size=1):
v = ttnn.reshape(v, (1, 1, (v.shape[0] * v.shape[1] * v.shape[2]), v.shape[3])) # [1,1,128, 49[64]]
v = ttnn.permute(v, (0, 1, 3, 2))
x2 = self.pe(device=device, x=v)
- x2 = ttnn.sharded_to_interleaved(x2, memory_config=ttnn.L1_MEMORY_CONFIG)

- x = x1 + x2
+ # x2 = ttnn.sharded_to_interleaved(x2, memory_config=ttnn.L1_MEMORY_CONFIG)

+ # x = x1 + x2
+ x = ttnn.add(x1, x2, memory_config=x2.memory_config())
x = self.proj(device=device, x=x)
ttnn.deallocate(x1)
ttnn.deallocate(qkv)
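
Two things change here: the extra sharded-to-interleaved conversion of x2 is dropped, and the Python `+` operator is replaced by an explicit ttnn.add whose output memory config is pinned to that of one operand; the same replacement is applied to both residual additions in the PSA block below. A small self-contained sketch of the add pattern, using dummy tensors rather than the module's activations:

    import torch
    import ttnn

    def residual_add(device, a_host: torch.Tensor, b_host: torch.Tensor):
        # Hypothetical demo: two bfloat16 tile-layout tensors in L1, combined
        # with ttnn.add so the output memory config is stated explicitly
        # (here: the same config as the second operand) instead of relying on
        # whatever the `+` operator would pick.
        a = ttnn.from_torch(
            a_host, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device,
            memory_config=ttnn.L1_MEMORY_CONFIG,
        )
        b = ttnn.from_torch(
            b_host, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device,
            memory_config=ttnn.L1_MEMORY_CONFIG,
        )
        return ttnn.add(a, b, memory_config=b.memory_config())
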
@@ -446,11 +440,13 @@ def __init__(self, device, parameter, conv_pt):
def __call__(self, device, x):
x1 = x
x = self.attn(device, x)
- x = x1 + x
+ # x = x1 + x
+ x = ttnn.add(x1, x, memory_config=x.memory_config())
x1 = x
x = self.ffn_conv1(device, x)
x = self.ffn_conv2(device, x)
- return x + x1
+ x = ttnn.add(x, x1, memory_config=x1.memory_config())
+ return x


class C2PSA:
@@ -555,7 +551,8 @@ def __call__(self, device, y1, y2, y3): # 0.9934, #0.987, #0.9803
y3_reshaped = ttnn.reshape(y3, (y3.shape[0], y3.shape[2], y3.shape[-1])) # 0.993

# y_all = [y1_reshaped, y2_reshaped, y3_reshaped]
- y = ttnn.concat((y1_reshaped, y2_reshaped, y3_reshaped), dim=1, memory_config=ttnn.L1_MEMORY_CONFIG) # 0.998
+ y = ttnn.concat((y1, y2, y3), dim=2, memory_config=ttnn.L1_MEMORY_CONFIG) # 0.998
+ y = ttnn.squeeze(y, dim=0)
ya, yb = y[:, :, :64], y[:, :, 64:144] # 0.991, 0.97
ttnn.deallocate(y1)
ttnn.deallocate(y2)
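
The three per-scale outputs are now concatenated in their original 4-D shape along dim=2 and the leading unit dimension is squeezed away, instead of reshaping each map to 3-D first and concatenating along dim=1. A torch-only sketch of why the two routes agree; the (1, 1, anchors, 144) shapes are assumptions consistent with a 640x640 input and the 64/144 channel slicing below, not values taken from the file:

    import torch

    # Assumed per-scale shapes for a 640x640 input: 80x80, 40x40 and 20x20 grids,
    # 144 channels (64 box-distribution + 80 class channels).
    y1 = torch.randn(1, 1, 6400, 144)
    y2 = torch.randn(1, 1, 1600, 144)
    y3 = torch.randn(1, 1, 400, 144)

    # Old route: reshape each map to 3-D, then concatenate along the anchor dim.
    old = torch.cat(
        [t.reshape(t.shape[0], t.shape[2], t.shape[-1]) for t in (y1, y2, y3)], dim=1
    )

    # New route: concatenate the 4-D maps along dim=2, then drop the leading unit dim.
    new = torch.cat((y1, y2, y3), dim=2).squeeze(0)

    assert old.shape == new.shape == (1, 8400, 144)
    assert torch.equal(old, new)
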
@@ -566,19 +563,22 @@ def __call__(self, device, y1, y2, y3): # 0.9934, #0.987, #0.9803
ttnn.deallocate(x4)
ttnn.deallocate(x5)
ttnn.deallocate(x6)
- ttnn.deallocate(y1_reshaped)
- ttnn.deallocate(y2_reshaped)
- ttnn.deallocate(y3_reshaped)
+ # ttnn.deallocate(y1_reshaped)
+ # ttnn.deallocate(y2_reshaped)
+ # ttnn.deallocate(y3_reshaped)
ttnn.deallocate(y)
ya = ttnn.reallocate(ya)
yb = ttnn.reallocate(yb)
- ya = ttnn.permute(ya, (0, 2, 1))
- print("before reshape", ya.shape, ya.layout, ya.memory_config(), ya.dtype)
- ya = ttnn.reshape(ya, (ya.shape[0], 4, 16, ya.shape[2]))
- ya = ttnn.permute(ya, (0, 2, 1, 3)) # 0.991
- ya = ttnn.to_layout(ya, ttnn.TILE_LAYOUT)
- ya = ttnn.softmax(ya, dim=1)
- ya = ttnn.permute(ya, (0, 2, 3, 1))
+ # ya = ttnn.permute(ya, (0, 2, 1))
+ # print("before reshape", ya.shape, ya.layout, ya.memory_config(), ya.dtype)
+ # ya = ttnn.reshape(ya, (ya.shape[0], 4, 16, ya.shape[2]))
+ # ya = ttnn.permute(ya, (0, 2, 1, 3)) # 0.991
+ # ya = ttnn.to_layout(ya, ttnn.TILE_LAYOUT)
+ # ya = ttnn.softmax(ya, dim=1)
+ # ya = ttnn.permute(ya, (0, 2, 3, 1))
+ ya = ttnn.reshape(ya, (ya.shape[0], y.shape[1], 4, 16))
+ ya = ttnn.softmax(ya, dim=-1)
+ ya = ttnn.permute(ya, (0, 2, 1, 3))
c = self.dfl(ya) # 0.968
ttnn.deallocate(ya)
c = ttnn.sharded_to_interleaved(c, memory_config=ttnn.L1_MEMORY_CONFIG)
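
Both decode paths apply a softmax over the 16 DFL bins of each of the 4 box sides before the dfl projection; the new one reaches the same (batch, 4, anchors, 16) layout with a single reshape and permute instead of three permutes plus a layout conversion. A torch-only sketch of the equivalence, with the (1, 8400, 64) shape of ya assumed from a 640x640 input:

    import torch

    ya = torch.randn(1, 8400, 64)  # assumed: 8400 anchors, 4 sides x 16 DFL bins

    # Old route: channels-first, split into (4, 16, anchors), softmax over the
    # 16 bins (dim=1 after the permute), then rearrange to (1, 4, 8400, 16).
    old = ya.permute(0, 2, 1).reshape(1, 4, 16, 8400).permute(0, 2, 1, 3)
    old = old.softmax(dim=1).permute(0, 2, 3, 1)

    # New route: reshape straight to (1, anchors, 4, 16), softmax over the last
    # dim, then move the 4 box sides in front of the anchor dim.
    new = ya.reshape(1, 8400, 4, 16).softmax(dim=-1).permute(0, 2, 1, 3)

    assert old.shape == new.shape == (1, 4, 8400, 16)
    assert torch.allclose(old, new)
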
@@ -587,11 +587,13 @@ def __call__(self, device, y1, y2, y3): # 0.9934, #0.987, #0.9803
c = ttnn.permute(c, (0, 3, 1, 2))
c = ttnn.reshape(c, (c.shape[0], 1, 4, int(c.shape[3] / 4)))
c = ttnn.reshape(c, (c.shape[0], c.shape[1] * c.shape[2], c.shape[3]))
+ # c = ttnn.reshape(c,(c.shape[0],c.shape[1],4,int(c.shape[2] / 4)))
+ # c = ttnn.squeeze(c,dim=0)
c1, c2 = c[:, :2, :], c[:, 2:4, :]

anchor, strides = self.anchors, self.strides
- anchor = ttnn.to_memory_config(anchor, memory_config=ttnn.L1_MEMORY_CONFIG)
- strides = ttnn.to_memory_config(strides, memory_config=ttnn.L1_MEMORY_CONFIG)
+ # anchor = ttnn.to_memory_config(anchor, memory_config=ttnn.L1_MEMORY_CONFIG)
+ # strides = ttnn.to_memory_config(strides, memory_config=ttnn.L1_MEMORY_CONFIG)
c1 = ttnn.to_layout(c1, layout=ttnn.TILE_LAYOUT)
c2 = ttnn.to_layout(c2, layout=ttnn.TILE_LAYOUT)
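
Since make_anchors (second hunk above) now allocates the anchor and stride tensors in L1 at creation time, the per-call ttnn.to_memory_config conversions here are redundant and are commented out; only the tile-layout conversion of the two slices c1 and c2 is kept before the code that follows in the collapsed part of the function.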

