#3: add generic script for concat sharding (RM, fp16)
vguduruTT committed Feb 28, 2025
1 parent 0eed687 commit 86f2997
Showing 1 changed file with 35 additions and 139 deletions.
174 changes: 35 additions & 139 deletions models/experimental/functional_yolov11/tt/ttnn_yolov11.py
@@ -157,6 +157,32 @@ def Yolov11_shard_upsample(device, x):
return x


def sharded_concat(input_tensors, num_cores=64, dim=3):  # expects inputs in fp16, ROW_MAJOR layout, with the same (h*w); every input is resharded to the config derived from the first tensor
shard_grid = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 7))})
in_shard_width = input_tensors[0].shape[-1]
shard_height = (input_tensors[0].shape[2] + num_cores - 1) // num_cores
input_sharded_memory_config = ttnn.create_sharded_memory_config(
(shard_height, in_shard_width),
core_grid=shard_grid,
strategy=ttnn.ShardStrategy.HEIGHT,
use_height_and_width_as_shard_shape=True,
)
out_shard_width = 0
for i in range(len(input_tensors)):
out_shard_width += input_tensors[i].shape[-1]
input_tensors[i] = ttnn.to_memory_config(input_tensors[i], input_sharded_memory_config)
output_sharded_memory_config = ttnn.create_sharded_memory_config(
(shard_height, out_shard_width),
core_grid=shard_grid,
strategy=ttnn.ShardStrategy.HEIGHT,
use_height_and_width_as_shard_shape=True,
)
output = ttnn.concat(input_tensors, dim, memory_config=output_sharded_memory_config)
output = ttnn.sharded_to_interleaved(output, memory_config=ttnn.L1_MEMORY_CONFIG)

return output


class Conv:
def __init__(self, device, parameter, conv_pt, enable_act=True, is_detect=False):
self.enable_act = enable_act
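
For context on the new helper above, here is a minimal usage sketch. The tensor shapes and the open_device/close_device scaffolding are illustrative assumptions, not part of this change; the helper itself expects row-major bfloat16 inputs with matching (h*w) and channel width, resident in L1.

import torch
import ttnn

from models.experimental.functional_yolov11.tt.ttnn_yolov11 import sharded_concat

device = ttnn.open_device(device_id=0)

# Two activations with the same (h*w) = 400 and the same channel width 64,
# already row-major bfloat16 in L1 -- the layout/dtype the helper expects.
a = ttnn.from_torch(
    torch.randn(1, 1, 400, 64),
    dtype=ttnn.bfloat16,
    layout=ttnn.ROW_MAJOR_LAYOUT,
    device=device,
    memory_config=ttnn.L1_MEMORY_CONFIG,
)
b = ttnn.from_torch(
    torch.randn(1, 1, 400, 64),
    dtype=ttnn.bfloat16,
    layout=ttnn.ROW_MAJOR_LAYOUT,
    device=device,
    memory_config=ttnn.L1_MEMORY_CONFIG,
)

# Height-sharded concat on the channel dim over the 8x8 core grid;
# the result comes back L1 interleaved with shape (1, 1, 400, 128).
out = sharded_concat([a, b])

ttnn.close_device(device)
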
@@ -232,52 +258,15 @@ def __call__(self, device, x):
)
use_sharded_concat = True
if use_sharded_concat:
shard_grid = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 7))})
if m1.shape[2] == 49: # 224
shard_height = 1 # (n*h*w + num_cores - 1)//num_cores
elif m1.shape[2] == 400: # 640
shard_height = 7
else:
print("invalid shard spec")
in_shard_width = x1.shape[-1]
out_shard_width = x1.shape[-1] + m1.shape[-1] + m2.shape[-1] + m3.shape[-1]
input_sharded_memory_config = ttnn.create_sharded_memory_config(
(shard_height, in_shard_width),
core_grid=shard_grid,
strategy=ttnn.ShardStrategy.HEIGHT,
use_height_and_width_as_shard_shape=True,
)
output_sharded_memory_config = ttnn.create_sharded_memory_config(
(shard_height, out_shard_width),
core_grid=shard_grid,
strategy=ttnn.ShardStrategy.HEIGHT,
use_height_and_width_as_shard_shape=True,
)
print("inpit is", input_sharded_memory_config)
print("outpt is", output_sharded_memory_config)
x1 = ttnn.to_memory_config(x1, memory_config=input_sharded_memory_config)
m1 = ttnn.to_memory_config(m1, memory_config=input_sharded_memory_config)
m2 = ttnn.to_memory_config(m2, memory_config=input_sharded_memory_config)
m3 = ttnn.to_memory_config(m3, memory_config=input_sharded_memory_config)
memory_config_used = output_sharded_memory_config
y = sharded_concat([x1, m1, m2, m3])
else:
if m2.is_sharded():
m2 = ttnn.sharded_to_interleaved(m2, ttnn.L1_MEMORY_CONFIG)
if m3.is_sharded():
m3 = ttnn.sharded_to_interleaved(m3, ttnn.L1_MEMORY_CONFIG)
if m1.is_sharded():
m1 = ttnn.sharded_to_interleaved(m1, ttnn.L1_MEMORY_CONFIG)
memory_config_used = ttnn.L1_MEMORY_CONFIG
y = ttnn.concat([x1, m1, m2, m3], dim=-1, memory_config=memory_config_used)
if y.is_sharded():
y = ttnn.sharded_to_interleaved(y, ttnn.L1_MEMORY_CONFIG)
y = ttnn.concat([x1, m1, m2, m3], dim=-1, memory_config=ttnn.L1_MEMORY_CONFIG)
x = self.cv2(device, y)
ttnn.deallocate(x1)
ttnn.deallocate(m1)
ttnn.deallocate(m2)
ttnn.deallocate(m3)
x = ttnn.reallocate(x)
print("output shape is ", x.shape)
# x = ttnn.reallocate(x)
return x


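To make the shard shapes at the call site above concrete, here is the arithmetic sharded_concat performs for the four-way concat of x1, m1, m2, m3 at a 640x640 input. The per-tensor channel width of 128 is an illustrative assumption; the (h*w) of 400 and the 64-core grid come from the removed hard-coded branch.

# Shard-shape arithmetic inside sharded_concat for the x1/m1/m2/m3 concat,
# assuming a 640x640 input where each tensor is (1, 1, 400, 128); the channel
# width 128 is illustrative, not taken from the model.
num_cores = 64                                     # 8x8 core grid
hw = 400                                           # h*w per tensor (20 * 20)
widths = [128, 128, 128, 128]                      # x1, m1, m2, m3

shard_height = (hw + num_cores - 1) // num_cores   # ceil(400 / 64) = 7
in_shard_shape = (shard_height, widths[0])         # (7, 128) per core, per input
out_shard_shape = (shard_height, sum(widths))      # (7, 512) per core after concat
print(in_shard_shape, out_shard_shape)
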
@@ -295,62 +284,20 @@ def __call__(self, device, x):

k1 = self.k1(device, x1)
k2 = self.k2(device, k1)
print(
"input to concat",
x2.shape,
x2.layout,
x2.dtype,
x2.memory_config(),
k2.shape,
k2.layout,
k2.dtype,
k2.memory_config(),
)
use_shard_concat = False # fps drop due to layout conversion
if use_shard_concat:
shard_grid = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 7))})
if x2.shape[2] == 49: # 224
shard_height = 1 # (n*h*w + num_cores - 1)//num_cores
elif x2.shape[2] == 196: # 224
shard_height = 4
elif x2.shape[2] == 400: # 640
shard_height = 7
elif x2.shape[2] == 1600: # 640
shard_height = 13
else:
print("invalid shard spec")
in_shard_width = x2.shape[-1]
out_shard_width = x2.shape[-1] + k2.shape[-1]
input_sharded_memory_config = ttnn.create_sharded_memory_config(
(shard_height, in_shard_width),
core_grid=shard_grid,
strategy=ttnn.ShardStrategy.HEIGHT,
use_height_and_width_as_shard_shape=True,
)
x2 = ttnn.to_layout(x2, ttnn.ROW_MAJOR_LAYOUT)
x2 = ttnn.to_dtype(x2, ttnn.bfloat16)
x2 = ttnn.to_memory_config(x2, memory_config=input_sharded_memory_config)
k2 = ttnn.to_layout(k2, ttnn.ROW_MAJOR_LAYOUT)
k2 = ttnn.to_dtype(k2, ttnn.bfloat16)
k2 = ttnn.to_memory_config(k2, memory_config=input_sharded_memory_config)
output_sharded_memory_config = ttnn.create_sharded_memory_config(
(shard_height, out_shard_width),
core_grid=shard_grid,
strategy=ttnn.ShardStrategy.HEIGHT,
use_height_and_width_as_shard_shape=True,
)
memory_config_used = output_sharded_memory_config
x = sharded_concat([k2, x2])
else:
memory_config_used = ttnn.L1_MEMORY_CONFIG
x = ttnn.concat((k2, x2), 3, memory_config=memory_config_used)
if x.is_sharded():
x = ttnn.sharded_to_interleaved(x, memory_config=ttnn.L1_MEMORY_CONFIG)
x = ttnn.concat((k2, x2), 3, memory_config=ttnn.L1_MEMORY_CONFIG)
x = self.cv3(device, x)
ttnn.deallocate(x1)
ttnn.deallocate(x2)
ttnn.deallocate(k1)
ttnn.deallocate(k2)
x = ttnn.reallocate(x)
return x


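The use_shard_concat = False branch above stays interleaved because its inputs are still tiled: before sharded_concat could be used, each tensor would need the layout and dtype conversion shown in the removed code, and that round-trip is what the "fps drop due to layout conversion" note refers to. A minimal sketch of that conversion, with placeholder tensor names:

import ttnn

def to_rm_bf16(t):
    # t stands in for a conv output that is still in TILE_LAYOUT;
    # sharded_concat expects row-major bfloat16 inputs.
    t = ttnn.to_layout(t, ttnn.ROW_MAJOR_LAYOUT)
    t = ttnn.to_dtype(t, ttnn.bfloat16)
    return t

# Sharded path (needs the conversions above):
#     x = sharded_concat([to_rm_bf16(k2), to_rm_bf16(x2)])
# Path chosen here instead:
#     x = ttnn.concat((k2, x2), 3, memory_config=ttnn.L1_MEMORY_CONFIG)
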
@@ -369,14 +316,10 @@ def __init__(self, device, parameter, conv_pt, is_bk_enabled=False):
self.c3k = C3K(device, parameter[0], conv_pt.m[0])

def __call__(self, device, x):
# if self.is_bk_enabled:
x = self.cv1(device, x)
x = ttnn.to_layout(x, layout=ttnn.ROW_MAJOR_LAYOUT)
print("x.sha[e is ]", x.shape)
y1 = x[:, :, :, : x.shape[-1] // 2]
y2 = x[:, :, :, x.shape[-1] // 2 : x.shape[-1]]
# x = ttnn.reallocate(x)
# y2 = ttnn.to_layout(y2, layout=ttnn.TILE_LAYOUT)
if self.is_bk_enabled:
y2 = ttnn.to_layout(y2, layout=ttnn.TILE_LAYOUT)
y3 = self.k(device, y2)
@@ -387,63 +330,16 @@ def __call__(self, device, x):
y2 = ttnn.to_layout(y2, ttnn.ROW_MAJOR_LAYOUT)
if y3.get_layout() != ttnn.ROW_MAJOR_LAYOUT:
y3 = ttnn.to_layout(y3, ttnn.ROW_MAJOR_LAYOUT)

print("y1,y2,y3", y1.shape, y2.shape, y3.shape, y1.dtype, y2.dtype, y3.dtype)
use_shard_concat = True
if use_shard_concat:
shard_grid = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 7))})
if y1.shape[2] == 25600: # 640
shard_height = 400 # (n*h*w + num_cores - 1)//num_cores
elif y1.shape[2] == 6400: # 640
shard_height = 102
elif y1.shape[2] == 1600: # 640
shard_height = 25
elif y1.shape[2] == 400: # 640
shard_height = 7
elif y1.shape[2] == 3136: # 224
shard_height = 49
elif y1.shape[2] == 784: # 224
shard_height = 13
elif y1.shape[2] == 196: # 224
shard_height = 4
elif y1.shape[2] == 49: # 224
shard_height = 1
else:
print("invalid shard spec")
in_shard_width = y1.shape[-1]
out_shard_width = y1.shape[-1] + y2.shape[-1] + y3.shape[-1]
input_sharded_memory_config = ttnn.create_sharded_memory_config(
(shard_height, in_shard_width),
core_grid=shard_grid,
strategy=ttnn.ShardStrategy.HEIGHT,
use_height_and_width_as_shard_shape=True,
)
output_sharded_memory_config = ttnn.create_sharded_memory_config(
(shard_height, out_shard_width),
core_grid=shard_grid,
strategy=ttnn.ShardStrategy.HEIGHT,
use_height_and_width_as_shard_shape=True,
)
y1 = ttnn.to_memory_config(y1, memory_config=input_sharded_memory_config)
y2 = ttnn.to_memory_config(y2, memory_config=input_sharded_memory_config)
y3 = ttnn.to_memory_config(y3, memory_config=input_sharded_memory_config)
memory_config_used = output_sharded_memory_config
x = sharded_concat([y1, y2, y3])
else:
memory_config_used = ttnn.L1_MEMORY_CONFIG
x = ttnn.concat((y1, y2, y3), 3, memory_config=memory_config_used)

if use_shard_concat:
x = ttnn.sharded_to_interleaved(x, memory_config=ttnn.L1_MEMORY_CONFIG)

# if x.get_layout() == ttnn.ROW_MAJOR_LAYOUT:
# x = ttnn.to_layout(x, ttnn.TILE_LAYOUT)

x = ttnn.concat((y1, y2, y3), 3, memory_config=ttnn.L1_MEMORY_CONFIG)
x = self.cv2(device, x)

ttnn.deallocate(y1)
ttnn.deallocate(y2)
ttnn.deallocate(y3)
x = ttnn.reallocate(x)
return x


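The generic ceil-division in sharded_concat replaces the per-resolution ladder of hard-coded shard heights removed above. As a quick check with the 64-core grid, the formula evaluates as follows for the h*w sizes that appear in this file:

num_cores = 64  # 8x8 grid used by sharded_concat
for hw in (49, 196, 400, 784, 1600, 3136, 6400, 25600):
    shard_height = (hw + num_cores - 1) // num_cores
    print(hw, "->", shard_height)
# 49 -> 1, 196 -> 4, 400 -> 7, 784 -> 13,
# 1600 -> 25, 3136 -> 49, 6400 -> 100, 25600 -> 400
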
@@ -503,7 +399,7 @@ def __call__(self, device, x, batch_size=1):
ttnn.deallocate(k)
ttnn.deallocate(v)
ttnn.deallocate(x2)
x = ttnn.reallocate(x)
# x = ttnn.reallocate(x)
return x


@@ -573,7 +469,7 @@ def __call__(self, device, x):
x = self.cv2(device, x)
ttnn.deallocate(a)
ttnn.deallocate(b)
x = ttnn.reallocate(x)
# x = ttnn.reallocate(x)
return x

