#5: optimisation in YOLO main module (ops removal, sharding)
vguduruTT committed Feb 28, 2025
1 parent 883e415 commit bf59ace
Showing 3 changed files with 132 additions and 139 deletions.
@@ -62,7 +62,7 @@ def attempt_load(weights, map_location=None):

@pytest.mark.parametrize("device_params", [{"l1_small_size": 79104}], indirect=True)
def test_yolov11(device, use_program_cache, reset_seeds):
torch_input, ttnn_input = create_yolov11_input_tensors(device, input_channels=3, input_height=640, input_width=640)
torch_input, ttnn_input = create_yolov11_input_tensors(device, input_channels=3, input_height=224, input_width=224)

torch_model = attempt_load("yolov11n.pt", map_location="cpu")
state_dict = torch_model.state_dict()
@@ -79,11 +79,7 @@ def test_yolov11(device, use_program_cache, reset_seeds):
parameters = create_yolov11_model_parameters(torch_model, torch_input, device=device)
ttnn_model = ttnn_yolov11.YoloV11(device, parameters)
ttnn_output = ttnn_model(ttnn_input)
# l1 = torch.load("/home/ubuntu/venkatesh_yolov11/tt-metal/models/experimental/functional_yolov11/dumps/torch_out.pth")
# l1 = torch.load("/home/ubuntu/venkatesh_yolov11/tt-metal/models/experimental/functional_yolov11/dumps/tt_out.pth")
# assert_with_pcc(l1, l2, 0.99)
ttnn_output = ttnn.to_torch(ttnn_output)
# ttnn_output = ttnn_output.permute(0, 2, 1)
print(ttnn_output.shape, torch_output.shape)

ttnn_output = ttnn_output.reshape(torch_output.shape)
assert_with_pcc(torch_output, ttnn_output, 0.99999)
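
For context, assert_with_pcc compares the two outputs by Pearson correlation. The real assert_with_pcc comes from the repo's test utilities; the standalone version below is only a sketch of the metric the 0.99999 threshold is checked against.

import torch

def pcc(expected: torch.Tensor, actual: torch.Tensor) -> float:
    # Flatten both tensors and compute the Pearson correlation coefficient.
    e = expected.flatten().float()
    a = actual.flatten().float()
    return torch.corrcoef(torch.stack([e, a]))[0, 1].item()

# Usage mirroring the assertion above: pcc(torch_output, ttnn_output) >= 0.99999
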
184 changes: 80 additions & 104 deletions models/experimental/functional_yolov11/tt/ttnn_yolov11.py
Expand Up @@ -266,7 +266,6 @@ def __call__(self, device, x):
ttnn.deallocate(m1)
ttnn.deallocate(m2)
ttnn.deallocate(m3)
# x = ttnn.reallocate(x)
return x


@@ -354,35 +353,30 @@ def __init__(self, device, parameter, conv_pt):
self.scale = self.key_dim**-0.5

def __call__(self, device, x, batch_size=1):
qkv = self.qkv(device, x) # [1, 1, 49[64], 256]
qkv = self.qkv(device, x)
qkv = ttnn.sharded_to_interleaved(qkv, memory_config=ttnn.L1_MEMORY_CONFIG)
qkv = ttnn.permute(qkv, (0, 3, 1, 2)) # [1,256,1,49]
qkv = ttnn.permute(qkv, (0, 3, 1, 2))
qkv = ttnn.to_layout(qkv, layout=ttnn.ROW_MAJOR_LAYOUT)
qkv = ttnn.to_dtype(qkv, ttnn.bfloat16)
qkv = ttnn.to_layout(qkv, layout=ttnn.TILE_LAYOUT)
qkv = ttnn.reshape(
qkv, (batch_size, self.num_heads, self.key_dim * 2 + self.head_dim, qkv.shape[-1])
) # [1,2,128,49]
qkv = ttnn.reshape(qkv, (batch_size, self.num_heads, self.key_dim * 2 + self.head_dim, qkv.shape[-1]))
q, k, v = (
qkv[:, :, : self.key_dim, :],
qkv[:, :, self.key_dim : self.head_dim, :],
qkv[:, :, self.head_dim :, :],
) # ttnn.Shape([1, 2, 32, 49[64]]) ttnn.Shape([1, 2, 32, 49[64]]) ttnn.Shape([1, 2, 64, 49[64]])
)

q_permuted = ttnn.permute(q, (0, 1, 3, 2)) # ttnn.Shape([1, 2, 49[64]],32)
q_permuted = ttnn.permute(q, (0, 1, 3, 2))
attn = ttnn.matmul(q_permuted, k, memory_config=ttnn.L1_MEMORY_CONFIG)
attn = ttnn.multiply(attn, self.scale) # ([1, 2, 49, 49])
attn = ttnn.multiply(attn, self.scale)
attn = ttnn.softmax(attn, dim=-1)
attn = ttnn.permute(attn, (0, 1, 3, 2))
x1 = ttnn.matmul(v, attn, memory_config=ttnn.L1_MEMORY_CONFIG) # [1, 2, 64, 49[64]]
x1 = ttnn.matmul(v, attn, memory_config=ttnn.L1_MEMORY_CONFIG)
x1 = ttnn.reshape(x1, (1, 1, (x1.shape[0] * x1.shape[1] * x1.shape[2]), x1.shape[3]))
x1 = ttnn.permute(x1, (0, 1, 3, 2))
v = ttnn.reshape(v, (1, 1, (v.shape[0] * v.shape[1] * v.shape[2]), v.shape[3])) # [1,1,128, 49[64]]
v = ttnn.reshape(v, (1, 1, (v.shape[0] * v.shape[1] * v.shape[2]), v.shape[3]))
v = ttnn.permute(v, (0, 1, 3, 2))
x2 = self.pe(device=device, x=v)
# x2 = ttnn.sharded_to_interleaved(x2, memory_config=ttnn.L1_MEMORY_CONFIG)

# x = x1 + x2
x = ttnn.add(x1, x2, memory_config=x2.memory_config())
x = self.proj(device=device, x=x)
ttnn.deallocate(x1)
@@ -393,7 +387,6 @@ def __call__(self, device, x, batch_size=1):
ttnn.deallocate(k)
ttnn.deallocate(v)
ttnn.deallocate(x2)
# x = ttnn.reallocate(x)
return x
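
For readers following the ttnn ops above, here is a rough PyTorch reference of the attention math they implement. Shapes are assumed from the shape comments removed in this commit (qkv of [1, 2, 128, 49] with num_heads=2, key_dim=32, head_dim=64); the positional-encoding conv (self.pe) and output projection (self.proj) are omitted, so this is only a sketch of the q/k/v split and the scaled dot-product step, not the full block.

import torch

def attention_reference(qkv, key_dim=32, head_dim=64):
    # qkv: [batch, num_heads, 2*key_dim + head_dim, N], matching the reshape above
    scale = key_dim ** -0.5
    q = qkv[:, :, :key_dim, :]            # [B, H, key_dim, N]
    k = qkv[:, :, key_dim:head_dim, :]    # [B, H, key_dim, N] (same slicing as the ttnn code)
    v = qkv[:, :, head_dim:, :]           # [B, H, head_dim, N]
    attn = (q.transpose(-2, -1) @ k) * scale    # [B, H, N, N]
    attn = attn.softmax(dim=-1)
    out = v @ attn.transpose(-2, -1)            # [B, H, head_dim, N]
    return out  # the real module then adds pe(v) and applies proj
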


@@ -440,7 +433,6 @@ def __init__(self, device, parameter, conv_pt):
def __call__(self, device, x):
x1 = x
x = self.attn(device, x)
# x = x1 + x
x = ttnn.add(x1, x, memory_config=x.memory_config())
x1 = x
x = self.ffn_conv1(device, x)
@@ -457,15 +449,14 @@ def __init__(self, device, parameter, conv_pt):
self.psablock = PSABlock(device, parameter.m[0], conv_pt.m[0])

def __call__(self, device, x):
x = self.cv1(device, x) # (1,1,49,256)
x = self.cv1(device, x)
a, b = x[:, :, :, : int(self.out_channel_0 / 2)], x[:, :, :, int(self.out_channel_0 / 2) :]
x = self.psablock(device, b)
x = ttnn.sharded_to_interleaved(x, memory_config=ttnn.L1_MEMORY_CONFIG)
x = ttnn.concat((a, x), dim=-1, memory_config=ttnn.L1_MEMORY_CONFIG)
x = self.cv2(device, x)
ttnn.deallocate(a)
ttnn.deallocate(b)
# x = ttnn.reallocate(x)
return x


@@ -505,35 +496,35 @@ def __init__(self, device, parameter, conv_pt):
self.anchors = conv_pt.anchors
self.strides = conv_pt.strides

def __call__(self, device, y1, y2, y3): # 0.9934, #0.987, #0.9803
def __call__(self, device, y1, y2, y3):
x1 = self.cv2_0_0(device, y1)
x1 = self.cv2_0_1(device, x1)
x1 = self.cv2_0_2(x1) # 0.98
x1 = self.cv2_0_2(x1)
x2 = self.cv2_1_0(device, y2)
x2 = self.cv2_1_1(device, x2)
x2 = self.cv2_1_2(x2) # 0.993
x2 = self.cv2_1_2(x2)

x3 = self.cv2_2_0(device, y3)
x3 = self.cv2_2_1(device, x3)
x3 = self.cv2_2_2(x3) # 0.998
x3 = self.cv2_2_2(x3)

x4 = self.cv3_0_0_0(device, y1)
x4 = self.cv3_0_0_1(device, x4)
x4 = self.cv3_0_1_0(device, x4)
x4 = self.cv3_0_1_1(device, x4)
x4 = self.cv3_0_2_0(x4) # 0.986
x4 = self.cv3_0_2_0(x4)

x5 = self.cv3_1_0_0(device, y2)
x5 = self.cv3_1_0_1(device, x5)
x5 = self.cv3_1_1_0(device, x5)
x5 = self.cv3_1_1_1(device, x5)
x5 = self.cv3_1_2_0(x5) # 0.961
x5 = self.cv3_1_2_0(x5)

x6 = self.cv3_2_0_0(device, y3)
x6 = self.cv3_2_0_1(device, x6)
x6 = self.cv3_2_1_0(device, x6)
x6 = self.cv3_2_1_1(device, x6)
x6 = self.cv3_2_2_0(x6) # 0.986
x6 = self.cv3_2_2_0(x6)

x1 = ttnn.sharded_to_interleaved(x1, memory_config=ttnn.L1_MEMORY_CONFIG)
x2 = ttnn.sharded_to_interleaved(x2, memory_config=ttnn.L1_MEMORY_CONFIG)
@@ -546,14 +537,12 @@ def __call__(self, device, y1, y2, y3): # 0.9934, #0.987, #0.9803
y2 = ttnn.concat((x2, x5), -1, memory_config=ttnn.L1_MEMORY_CONFIG)
y3 = ttnn.concat((x3, x6), -1, memory_config=ttnn.L1_MEMORY_CONFIG)

y1_reshaped = ttnn.reshape(y1, (y1.shape[0], y1.shape[2], y1.shape[-1])) # 0.9992
y2_reshaped = ttnn.reshape(y2, (y2.shape[0], y2.shape[2], y2.shape[-1])) # 0.99908
y3_reshaped = ttnn.reshape(y3, (y3.shape[0], y3.shape[2], y3.shape[-1])) # 0.993

# y_all = [y1_reshaped, y2_reshaped, y3_reshaped]
y = ttnn.concat((y1, y2, y3), dim=2, memory_config=ttnn.L1_MEMORY_CONFIG) # 0.998
y1_reshaped = ttnn.reshape(y1, (y1.shape[0], y1.shape[2], y1.shape[-1]))
y2_reshaped = ttnn.reshape(y2, (y2.shape[0], y2.shape[2], y2.shape[-1]))
y3_reshaped = ttnn.reshape(y3, (y3.shape[0], y3.shape[2], y3.shape[-1]))
y = ttnn.concat((y1, y2, y3), dim=2, memory_config=ttnn.L1_MEMORY_CONFIG)
y = ttnn.squeeze(y, dim=0)
ya, yb = y[:, :, :64], y[:, :, 64:144] # 0.991, 0.97
ya, yb = y[:, :, :64], y[:, :, 64:144]
ttnn.deallocate(y1)
ttnn.deallocate(y2)
ttnn.deallocate(y3)
@@ -563,49 +552,35 @@ def __call__(self, device, y1, y2, y3): # 0.9934, #0.987, #0.9803
ttnn.deallocate(x4)
ttnn.deallocate(x5)
ttnn.deallocate(x6)
# ttnn.deallocate(y1_reshaped)
# ttnn.deallocate(y2_reshaped)
# ttnn.deallocate(y3_reshaped)
ttnn.deallocate(y)
ya = ttnn.reallocate(ya)
yb = ttnn.reallocate(yb)
# ya = ttnn.permute(ya, (0, 2, 1))
# print("before reshape", ya.shape, ya.layout, ya.memory_config(), ya.dtype)
# ya = ttnn.reshape(ya, (ya.shape[0], 4, 16, ya.shape[2]))
# ya = ttnn.permute(ya, (0, 2, 1, 3)) # 0.991
# ya = ttnn.to_layout(ya, ttnn.TILE_LAYOUT)
# ya = ttnn.softmax(ya, dim=1)
# ya = ttnn.permute(ya, (0, 2, 3, 1))
ya = ttnn.reshape(ya, (ya.shape[0], y.shape[1], 4, 16))
ya = ttnn.softmax(ya, dim=-1)
ya = ttnn.permute(ya, (0, 2, 1, 3))
c = self.dfl(ya) # 0.968
c = self.dfl(ya)
ttnn.deallocate(ya)
c = ttnn.sharded_to_interleaved(c, memory_config=ttnn.L1_MEMORY_CONFIG)

c = ttnn.to_layout(c, layout=ttnn.ROW_MAJOR_LAYOUT)
c = ttnn.permute(c, (0, 3, 1, 2))
c = ttnn.reshape(c, (c.shape[0], 1, 4, int(c.shape[3] / 4)))
c = ttnn.reshape(c, (c.shape[0], c.shape[1] * c.shape[2], c.shape[3]))
# c = ttnn.reshape(c,(c.shape[0],c.shape[1],4,int(c.shape[2] / 4)))
# c = ttnn.squeeze(c,dim=0)
c1, c2 = c[:, :2, :], c[:, 2:4, :]

anchor, strides = self.anchors, self.strides
# anchor = ttnn.to_memory_config(anchor, memory_config=ttnn.L1_MEMORY_CONFIG)
# strides = ttnn.to_memory_config(strides, memory_config=ttnn.L1_MEMORY_CONFIG)
c1 = ttnn.to_layout(c1, layout=ttnn.TILE_LAYOUT)
c2 = ttnn.to_layout(c2, layout=ttnn.TILE_LAYOUT)

c1 = anchor - c1 # 0.998
c2 = anchor + c2 # 0.997
c1 = anchor - c1
c2 = anchor + c2

z1 = c2 - c1
z2 = c1 + c2
z2 = ttnn.div(z2, 2) # 0.9995
z2 = ttnn.div(z2, 2)

z = ttnn.concat((z2, z1), dim=1, memory_config=ttnn.L1_MEMORY_CONFIG)
z = ttnn.multiply(z, strides) # 0.9998
z = ttnn.multiply(z, strides)
yb = ttnn.permute(yb, (0, 2, 1))
yb = ttnn.sigmoid(yb)
ttnn.deallocate(c)
@@ -620,12 +595,8 @@ def __call__(self, device, y1, y2, y3): # 0.9934, #0.987, #0.9803
z = ttnn.to_layout(z, layout=ttnn.ROW_MAJOR_LAYOUT)
yb = ttnn.to_layout(yb, layout=ttnn.ROW_MAJOR_LAYOUT)
out = ttnn.concat((z, yb), dim=1, memory_config=ttnn.L1_MEMORY_CONFIG)

ttnn.deallocate(yb)

ttnn.deallocate(z)

# out = ttnn.reallocate(out)
return out
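
The anchor/stride arithmetic above (c1 = anchor - c1, c2 = anchor + c2, then z1 and z2) is the usual distance-to-bbox decode. A minimal PyTorch sketch of the same math, with shapes assumed from the slicing in this hunk (dist is the [B, 4, N] DFL output, anchors is [2, N], strides broadcasts over N):

import torch

def decode_boxes(dist, anchors, strides):
    # dist holds per-anchor distances to the box sides: left-top, right-bottom
    lt, rb = dist[:, :2, :], dist[:, 2:4, :]
    x1y1 = anchors - lt          # c1 in the ttnn code
    x2y2 = anchors + rb          # c2 in the ttnn code
    wh = x2y2 - x1y1             # z1
    cxcy = (x1y1 + x2y2) / 2     # z2
    return torch.cat((cxcy, wh), dim=1) * strides
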


@@ -673,41 +644,23 @@ def __init__(self, device, parameters):
self.detect = Detect(device, parameters.model_args.model[23], parameters.model[23])

def __call__(self, x):
x = self.conv1(self.device, x) # 0.9997768588395117
x = self.conv2(self.device, x) # 0.9995776757085278
x = self.c3k2_1(self.device, x) # 0.9981333502789589
x = self.conv3(self.device, x) # 0.9969896145281048
x = self.c3k2_2(self.device, x) # -.0.07, WITH FROM AND TO - 0.9947373339421013
# torch.save(ttnn.to_torch(x).reshape(x.shape[0],int(torch.sqrt(torch.tensor(x.shape[2], dtype=torch.float32))),int(torch.sqrt(torch.tensor(x.shape[2], dtype=torch.float32))),x.shape[-1]).permute(0,3,1,2),"/home/ubuntu/tt-metal/models/experimental/functional_yolov11/dumps/ttnn_out.pth")
# return x
x = self.conv1(self.device, x)
x = self.conv2(self.device, x)
x = self.c3k2_1(self.device, x)
x = self.conv3(self.device, x)
x = self.c3k2_2(self.device, x)
x4 = x
x = self.conv5(self.device, x) # 0.9958605201742154
x = self.c3k2_3(self.device, x) # 6

x = self.conv5(self.device, x)
x = self.c3k2_3(self.device, x)
x6 = x
x = self.conv6(self.device, x) # 7
x = self.c3k2_4(self.device, x) # 8 #0.25

x = self.sppf(self.device, x) # 9 #0.986
x = self.c2psa(self.device, x) # 10
print("before x details", x.shape, x.layout, x.dtype)
# x = ttnn.to_torch(x)
# x = ttnn.from_torch(
# x,
# dtype=ttnn.bfloat16,
# device=self.device,
# layout=ttnn.ROW_MAJOR_LAYOUT,
# memory_config=ttnn.L1_MEMORY_CONFIG,
# )
x = ttnn.to_layout(x, layout=ttnn.ROW_MAJOR_LAYOUT)
x = ttnn.to_dtype(x, ttnn.bfloat16)
# torch.save(ttnn.to_torch(x).reshape(1,7,7,256).permute(0,3,1,2),"/home/ubuntu/venkatesh_yolov11/tt-metal/models/experimental/functional_yolov11/dumps/tt_out.pth")
print("after x details", x.shape, x.layout, x.dtype)
x = self.conv6(self.device, x)
x = self.c3k2_4(self.device, x)
x = self.sppf(self.device, x)
x = self.c2psa(self.device, x)
x10 = x
x = ttnn.to_layout(x, layout=ttnn.ROW_MAJOR_LAYOUT)
# x = ttnn.to_dtype(x, ttnn.bfloat16)
x = ttnn.reshape(x, (x.shape[0], int(math.sqrt(x.shape[2])), int(math.sqrt(x.shape[2])), x.shape[3]))
print("ttnn input to upsample1 is ", x.shape, x.layout, x.dtype)
# x = Yolov11_shard_upsample(self.device, x)
# x = ttnn.upsample(x, scale_factor=2)
nhw = x.shape[0] * x.shape[1] * x.shape[2]
num_cores = determine_num_cores_for_upsample(nhw, x.shape[2])
core_grid = get_core_grid_from_num_cores(num_cores)
@@ -722,63 +675,86 @@ def __call__(self, x):
x = ttnn.upsample(x, scale_factor=2, memory_config=x.memory_config()) # 11
if x.is_sharded():
x = ttnn.sharded_to_interleaved(x, memory_config=ttnn.L1_MEMORY_CONFIG)
print("output of 1st upsample", x.shape)
x = ttnn.reshape(x, (1, 1, x.shape[0] * x.shape[1] * x.shape[2], x.shape[3]))
x6 = ttnn.to_layout(x6, layout=ttnn.ROW_MAJOR_LAYOUT)
print("x and x6 and x4 oconfig is ", x.memory_config(), x6.memory_config(), x4.memory_config())
x = ttnn.concat((x, x6), -1, memory_config=ttnn.L1_MEMORY_CONFIG) # 12

# x = sharded_concat([x,x6]) # unequal channels( sharded_concat is not applicable)
shard_height = (x[0].shape[2] + 64 - 1) // 64
print("shard height is ", shard_height)
print("x and x6 sahpes are", x.shape, x6.shape)
input_sharded_memory_config_1 = ttnn.create_sharded_memory_config(
(shard_height, x.shape[-1]),
core_grid=ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 7))}),
strategy=ttnn.ShardStrategy.HEIGHT,
use_height_and_width_as_shard_shape=True,
)
input_sharded_memory_config_2 = ttnn.create_sharded_memory_config(
(shard_height, x6.shape[-1]),
core_grid=ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 7))}),
strategy=ttnn.ShardStrategy.HEIGHT,
use_height_and_width_as_shard_shape=True,
)
# x = ttnn.to_memory_config(x,input_sharded_memory_config_1)
# x6 = ttnn.to_memory_config(x6,input_sharded_memory_config_2)
out_sharded_memory_config_ = ttnn.create_sharded_memory_config(
(shard_height, x.shape[-1] + x6.shape[-1]),
core_grid=ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 7))}),
strategy=ttnn.ShardStrategy.HEIGHT,
use_height_and_width_as_shard_shape=True,
)
x = ttnn.concat((x, x6), -1, memory_config=ttnn.L1_MEMORY_CONFIG)

ttnn.deallocate(x6)
# x = ttnn.reallocate(x)
x = ttnn.to_layout(x, layout=ttnn.TILE_LAYOUT)
# if x.shape[2]==196:
# x = ttnn.sharded_to_interleaved(x, memory_config=ttnn.L1_MEMORY_CONFIG)
# x = ttnn.to_layout(x, layout=ttnn.TILE_LAYOUT)
print(" after x and x6 concat", x.shape)
# return x
x = self.c3k2_5(self.device, x) # 13
x13 = x
x = ttnn.to_layout(x, layout=ttnn.ROW_MAJOR_LAYOUT)
x = ttnn.reshape(x, (x.shape[0], int(math.sqrt(x.shape[2])), int(math.sqrt(x.shape[2])), x.shape[3]))
print("ttnn input to upsample2 is ", x.shape, x.layout, x.dtype)
nhw = x.shape[0] * x.shape[1] * x.shape[2]
num_cores = determine_num_cores_for_upsample(nhw, x.shape[2])
core_grid = get_core_grid_from_num_cores(num_cores)
shardspec = ttnn.create_sharded_memory_config_(
x.shape, core_grid, ttnn.ShardStrategy.HEIGHT, orientation=ttnn.ShardOrientation.ROW_MAJOR
)

if x.is_sharded():
x = ttnn.reshard(x, shardspec)
else:
x = ttnn.interleaved_to_sharded(x, shardspec)
x = ttnn.upsample(x, scale_factor=2, memory_config=x.memory_config())
if x.is_sharded():
x = ttnn.sharded_to_interleaved(x, memory_config=ttnn.L1_MEMORY_CONFIG)
# x = ttnn.upsample(x, scale_factor=2) # 14
# x = Yolov11_shard_upsample(self.device, x)
print("output of 2nd upsample", x.shape)
x = ttnn.reshape(x, (1, 1, x.shape[0] * x.shape[1] * x.shape[2], x.shape[3]))
x4 = ttnn.to_layout(x4, layout=ttnn.ROW_MAJOR_LAYOUT)
x = ttnn.concat((x, x4), -1, memory_config=ttnn.L1_MEMORY_CONFIG) # 15
x = sharded_concat([x, x4])
# x = ttnn.concat((x, x4), -1, memory_config=ttnn.L1_MEMORY_CONFIG) # 15
ttnn.deallocate(x4)
x = ttnn.to_layout(x, layout=ttnn.TILE_LAYOUT)
# x = ttnn.to_layout(x, layout=ttnn.TILE_LAYOUT)
x = self.c3k2_6(self.device, x) # 16
x16 = x
x = self.conv7(self.device, x) # 17
# x = ttnn.to_layout(x, layout=ttnn.ROW_MAJOR_LAYOUT)
# x = ttnn.to_dtype(x, ttnn.bfloat16)
# x = ttnn.to_layout(x, layout=ttnn.TILE_LAYOUT)
print("x and x13 shapes are", x.shape, x13.shape, x.dtype, x13.dtype, x.layout, x13.layout)
# print("x and x13 shapes are", x.shape, x13.shape, x.dtype, x13.dtype, x.layout, x13.layout)
x = ttnn.concat((x, x13), -1, memory_config=ttnn.L1_MEMORY_CONFIG) # 18
ttnn.deallocate(x13)
x = self.c3k2_7(self.device, x) # 19
x19 = x
x = self.conv8(self.device, x) # 20 #16
x = ttnn.to_layout(x, layout=ttnn.ROW_MAJOR_LAYOUT)
x = ttnn.to_dtype(x, ttnn.bfloat16)
# x = ttnn.to_layout(x, layout=ttnn.ROW_MAJOR_LAYOUT)
# x = ttnn.to_dtype(x, ttnn.bfloat16)
print("x and x10 shapes are", x.shape, x10.shape, x.dtype, x10.dtype, x.layout, x10.layout)
x = ttnn.concat((x, x10), -1, memory_config=ttnn.L1_MEMORY_CONFIG) # 21
print("output cncat shape is", x.shape)
ttnn.deallocate(x10)
x = self.c3k2_8(self.device, x) # 22
x22 = x
x = self.detect(self.device, x16, x19, x22)
# ttnn.deallocate(x16)
# ttnn.deallocate(x19)
# ttnn.deallocate(x22)
ttnn.deallocate(x16)
ttnn.deallocate(x19)
ttnn.deallocate(x22)
return x
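
The commit's sharding change shows up twice in __call__ as the same upsample pattern: reshape the flattened [1, 1, H*W, C] activation back to NHWC, build a height-sharded memory config from the computed core count, reshard (or interleave-to-shard), upsample, then hand back an interleaved L1 tensor. A condensed sketch of that pattern, using only calls visible in this diff; determine_num_cores_for_upsample and get_core_grid_from_num_cores are the repo's own helpers, and their import location is not shown in this commit.

import math
import ttnn

def sharded_upsample_x2(x):
    # Restore the spatial layout expected by ttnn.upsample
    x = ttnn.to_layout(x, layout=ttnn.ROW_MAJOR_LAYOUT)
    side = int(math.sqrt(x.shape[2]))
    x = ttnn.reshape(x, (x.shape[0], side, side, x.shape[3]))

    # Height-shard across a core grid sized for this NHW extent
    nhw = x.shape[0] * x.shape[1] * x.shape[2]
    num_cores = determine_num_cores_for_upsample(nhw, x.shape[2])
    core_grid = get_core_grid_from_num_cores(num_cores)
    shardspec = ttnn.create_sharded_memory_config_(
        x.shape, core_grid, ttnn.ShardStrategy.HEIGHT, orientation=ttnn.ShardOrientation.ROW_MAJOR
    )
    x = ttnn.reshard(x, shardspec) if x.is_sharded() else ttnn.interleaved_to_sharded(x, shardspec)

    # Upsample in the sharded layout, then return an interleaved L1 tensor
    x = ttnn.upsample(x, scale_factor=2, memory_config=x.memory_config())
    if x.is_sharded():
        x = ttnn.sharded_to_interleaved(x, memory_config=ttnn.L1_MEMORY_CONFIG)
    return ttnn.reshape(x, (1, 1, x.shape[0] * x.shape[1] * x.shape[2], x.shape[3]))
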