#2: sharded c3k2 concat
vguduruTT committed Feb 28, 2025
1 parent 61088da commit 0eed687
Showing 4 changed files with 211 additions and 59 deletions.
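For context, the main change in this commit is that the concat at the tail of the C3k2 block (y1/y2/y3 from the cv1 split and the bottleneck/C3k branch) is performed on height-sharded L1 tensors instead of interleaved L1 memory, with the result converted back to interleaved afterwards. The sketch below is only a simplified illustration of that pattern; the function name, the 8x8 core grid, and computing the shard height from the ceil formula in the diff's comment (rather than the hardcoded per-resolution values) are assumptions, not the exact code in ttnn_yolov11.py.

import ttnn

def sharded_concat_sketch(y1, y2, y3):
    # Assumed 8x8 worker grid (64 cores), matching the grid used in the diff.
    shard_grid = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 7))})
    num_cores = 64
    # Inputs are [1, 1, N*H*W, C]; shard height = ceil(N*H*W / num_cores),
    # the formula the diff notes next to its hardcoded per-resolution values.
    shard_height = (y1.shape[2] + num_cores - 1) // num_cores
    input_cfg = ttnn.create_sharded_memory_config(
        (shard_height, y1.shape[-1]),  # the diff uses y1's width for all three inputs
        core_grid=shard_grid,
        strategy=ttnn.ShardStrategy.HEIGHT,
        use_height_and_width_as_shard_shape=True,
    )
    output_cfg = ttnn.create_sharded_memory_config(
        (shard_height, y1.shape[-1] + y2.shape[-1] + y3.shape[-1]),
        core_grid=shard_grid,
        strategy=ttnn.ShardStrategy.HEIGHT,
        use_height_and_width_as_shard_shape=True,
    )
    # Re-shard the inputs, concat along the channel (last) dim, then return to interleaved L1.
    y1 = ttnn.to_memory_config(y1, memory_config=input_cfg)
    y2 = ttnn.to_memory_config(y2, memory_config=input_cfg)
    y3 = ttnn.to_memory_config(y3, memory_config=input_cfg)
    x = ttnn.concat((y1, y2, y3), 3, memory_config=output_cfg)
    return ttnn.sharded_to_interleaved(x, memory_config=ttnn.L1_MEMORY_CONFIG)

In the commit itself this path sits behind a use_shard_concat flag; the x2/k2 concat earlier in ttnn_yolov11.py keeps use_shard_concat = False because, per the inline comment, the extra layout conversions cost FPS.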
106 changes: 98 additions & 8 deletions models/experimental/functional_yolov11/test/test_ttnn_c3k2.py
@@ -16,6 +16,96 @@
@pytest.mark.parametrize(
"in_channel, out_channel, kernel, stride, padding, dilation, groups,is_bk_enabled,fwd_input_shape",
[
# 224
# (
# [32, 48, 16, 8],
# [32, 64, 8, 16],
# [1, 1, 3, 3],
# [1, 1, 1, 1],
# [0, 0, 1, 1],
# [1, 1, 1, 1],
# [1, 1, 1, 1],
# True,
# [1, 32, 56, 56],
# ),
# (
# [64, 96, 32, 16],
# [64, 128, 16, 32],
# [1, 1, 3, 3],
# [1, 1, 1, 1],
# [0, 0, 1, 1],
# [1, 1, 1, 1],
# [1, 1, 1, 1],
# True,
# [1, 64, 28, 28],
# ),
# (
# [128, 192, 64, 64, 64, 32, 32, 32, 32],
# [128, 128, 32, 32, 64, 32, 32, 32, 32],
# [1, 1, 1, 1, 1, 3, 3, 3, 3],
# [1, 1, 1, 1, 1, 1, 1, 1, 1],
# [0, 0, 0, 0, 0, 1, 1, 1, 1],
# [1, 1, 1, 1, 1, 1, 1, 1, 1],
# [1, 1, 1, 1, 1, 1, 1, 1, 1],
# False,
# [1, 128, 14, 14],
# ),
# (
# [256, 384, 128, 128, 128, 64, 64, 64, 64],
# [256, 256, 64, 64, 128, 64, 64, 64, 64],
# [1, 1, 1, 1, 1, 3, 3, 3, 3],
# [1, 1, 1, 1, 1, 1, 1, 1, 1],
# [0, 0, 0, 0, 0, 1, 1, 1, 1],
# [1, 1, 1, 1, 1, 1, 1, 1, 1],
# [1, 1, 1, 1, 1, 1, 1, 1, 1],
# False,
# [1, 256, 7, 7],
# ),
# (
# [384, 192, 64, 32],
# [128, 128, 32, 64],
# [1, 1, 3, 3],
# [1, 1, 1, 1],
# [0, 0, 1, 1],
# [1, 1, 1, 1],
# [1, 1, 1, 1],
# True,
# [1, 384, 14, 14],
# ),
# (
# [256, 96, 32, 16],
# [64, 64, 16, 32],
# [1, 1, 3, 3],
# [1, 1, 1, 1],
# [0, 0, 1, 1],
# [1, 1, 1, 1],
# [1, 1, 1, 1],
# True,
# [1, 256, 28, 28],
# ),
# (
# [192, 192, 64, 32],
# [128, 128, 32, 64],
# [1, 1, 3, 3],
# [1, 1, 1, 1],
# [0, 0, 1, 1],
# [1, 1, 1, 1],
# [1, 1, 1, 1],
# True,
# [1, 192, 14, 14],
# ),
# (
# [384, 384, 128, 128, 128, 64, 64, 64, 64],
# [256, 256, 64, 64, 128, 64, 64, 64, 64],
# [1, 1, 1, 1, 1, 3, 3, 3, 3],
# [1, 1, 1, 1, 1, 1, 1, 1, 1],
# [0, 0, 0, 0, 0, 1, 1, 1, 1],
# [1, 1, 1, 1, 1, 1, 1, 1, 1],
# [1, 1, 1, 1, 1, 1, 1, 1, 1],
# False,
# [1, 384, 7, 7],
# ),
# #640
(
[32, 48, 16, 8],
[32, 64, 8, 16],
@@ -25,7 +115,7 @@
[1, 1, 1, 1],
[1, 1, 1, 1],
True,
[1, 32, 56, 56],
[1, 32, 160, 160],
),
(
[64, 96, 32, 16],
@@ -36,7 +126,7 @@
[1, 1, 1, 1],
[1, 1, 1, 1],
True,
[1, 64, 28, 28],
[1, 64, 80, 80],
),
(
[128, 192, 64, 64, 64, 32, 32, 32, 32],
@@ -47,7 +137,7 @@
[1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1],
False,
[1, 128, 14, 14],
[1, 128, 40, 40],
),
(
[256, 384, 128, 128, 128, 64, 64, 64, 64],
@@ -58,7 +148,7 @@
[1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1],
False,
[1, 256, 7, 7],
[1, 256, 20, 20],
),
(
[384, 192, 64, 32],
@@ -69,7 +159,7 @@
[1, 1, 1, 1],
[1, 1, 1, 1],
True,
[1, 384, 14, 14],
[1, 384, 40, 40],
),
(
[256, 96, 32, 16],
@@ -80,7 +170,7 @@
[1, 1, 1, 1],
[1, 1, 1, 1],
True,
[1, 256, 28, 28],
[1, 256, 80, 80],
),
(
[192, 192, 64, 32],
@@ -91,7 +181,7 @@
[1, 1, 1, 1],
[1, 1, 1, 1],
True,
[1, 192, 14, 14],
[1, 192, 40, 40],
),
(
[384, 384, 128, 128, 128, 64, 64, 64, 64],
@@ -102,7 +192,7 @@
[1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1],
False,
[1, 384, 7, 7],
[1, 384, 20, 20],
),
],
)
@@ -62,7 +62,7 @@ def attempt_load(weights, map_location=None):

@pytest.mark.parametrize("device_params", [{"l1_small_size": 79104}], indirect=True)
def test_yolov11(device, use_program_cache, reset_seeds):
torch_input, ttnn_input = create_yolov11_input_tensors(device, input_channels=3, input_height=640, input_width=640)
torch_input, ttnn_input = create_yolov11_input_tensors(device, input_channels=3, input_height=224, input_width=224)

torch_model = attempt_load("yolov11n.pt", map_location="cpu")
state_dict = torch_model.state_dict()
99 changes: 59 additions & 40 deletions models/experimental/functional_yolov11/tt/ttnn_yolov11.py
@@ -277,6 +277,7 @@ def __call__(self, device, x):
ttnn.deallocate(m2)
ttnn.deallocate(m3)
x = ttnn.reallocate(x)
print("output shape is ", x.shape)
return x


@@ -294,12 +295,6 @@ def __call__(self, device, x):

k1 = self.k1(device, x1)
k2 = self.k2(device, k1)
# input to concat Shape([1, 1, 49, 64]) Layout.TILE DataType.BFLOAT8_B MemoryConfig(memory_layout=TensorMemoryLayout::INTERLEAVED,buffer_type=BufferType::L1,shard_spec=std::nullopt) Shape([1, 1, 49, 64]) Layout.TILE DataType.BFLOAT8_B MemoryConfig(memory_layout=TensorMemoryLayout::INTERLEAVED,buffer_type=BufferType::L1,shard_spec=std::nullopt)
# input to concat Shape([1, 1, 196, 32]) Layout.TILE DataType.BFLOAT8_B MemoryConfig(memory_layout=TensorMemoryLayout::INTERLEAVED,buffer_type=BufferType::L1,shard_spec=std::nullopt) Shape([1, 1, 196, 32]) Layout.TILE DataType.BFLOAT8_B MemoryConfig(memory_layout=TensorMemoryLayout::INTERLEAVED,buffer_type=BufferType::L1,shard_spec=std::nullopt)
# if x2.is_sharded():
# x2 = ttnn.sharded_to_interleaved(x2, ttnn.L1_MEMORY_CONFIG)
# if k2.is_sharded():
# k2 = ttnn.sharded_to_interleaved(k2, ttnn.L1_MEMORY_CONFIG)
print(
"input to concat",
x2.shape,
@@ -311,7 +306,7 @@
k2.dtype,
k2.memory_config(),
)
use_shard_concat = True
use_shard_concat = False # fps drop due to layout conversion
if use_shard_concat:
shard_grid = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 7))})
if x2.shape[2] == 49: # 224
@@ -377,49 +372,73 @@ def __call__(self, device, x):
# if self.is_bk_enabled:
x = self.cv1(device, x)
x = ttnn.to_layout(x, layout=ttnn.ROW_MAJOR_LAYOUT)
print("x.sha[e is ]", x.shape)
y1 = x[:, :, :, : x.shape[-1] // 2]
y2 = x[:, :, :, x.shape[-1] // 2 : x.shape[-1]]
x = ttnn.reallocate(x)
y2 = ttnn.to_layout(y2, layout=ttnn.TILE_LAYOUT)
# x = ttnn.reallocate(x)
# y2 = ttnn.to_layout(y2, layout=ttnn.TILE_LAYOUT)
if self.is_bk_enabled:
y2 = ttnn.to_layout(y2, layout=ttnn.TILE_LAYOUT)
y3 = self.k(device, y2)
else:
y3 = self.c3k(device, y2)

if y2.get_layout() != ttnn.ROW_MAJOR_LAYOUT:
y2 = ttnn.to_layout(y2, ttnn.ROW_MAJOR_LAYOUT)
if y3.get_layout() != ttnn.ROW_MAJOR_LAYOUT:
y3 = ttnn.to_layout(y3, ttnn.ROW_MAJOR_LAYOUT)

x = ttnn.concat((y1, y2, y3), 3, memory_config=ttnn.L1_MEMORY_CONFIG)
print("y1,y2,y3", y1.shape, y2.shape, y3.shape, y1.dtype, y2.dtype, y3.dtype)
use_shard_concat = True
if use_shard_concat:
shard_grid = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 7))})
if y1.shape[2] == 25600: # 640
shard_height = 400 # (n*h*w + num_cores - 1)//num_cores
elif y1.shape[2] == 6400: # 640
shard_height = 102
elif y1.shape[2] == 1600: # 640
shard_height = 25
elif y1.shape[2] == 400: # 640
shard_height = 7
elif y1.shape[2] == 3136: # 224
shard_height = 49
elif y1.shape[2] == 784: # 224
shard_height = 13
elif y1.shape[2] == 196: # 224
shard_height = 4
elif y1.shape[2] == 49: # 224
shard_height = 1
else:
print("invalid shard spec")
in_shard_width = y1.shape[-1]
out_shard_width = y1.shape[-1] + y2.shape[-1] + y3.shape[-1]
input_sharded_memory_config = ttnn.create_sharded_memory_config(
(shard_height, in_shard_width),
core_grid=shard_grid,
strategy=ttnn.ShardStrategy.HEIGHT,
use_height_and_width_as_shard_shape=True,
)
output_sharded_memory_config = ttnn.create_sharded_memory_config(
(shard_height, out_shard_width),
core_grid=shard_grid,
strategy=ttnn.ShardStrategy.HEIGHT,
use_height_and_width_as_shard_shape=True,
)
y1 = ttnn.to_memory_config(y1, memory_config=input_sharded_memory_config)
y2 = ttnn.to_memory_config(y2, memory_config=input_sharded_memory_config)
y3 = ttnn.to_memory_config(y3, memory_config=input_sharded_memory_config)
memory_config_used = output_sharded_memory_config
else:
memory_config_used = ttnn.L1_MEMORY_CONFIG
x = ttnn.concat((y1, y2, y3), 3, memory_config=memory_config_used)

if use_shard_concat:
x = ttnn.sharded_to_interleaved(x, memory_config=ttnn.L1_MEMORY_CONFIG)

# if x.get_layout() == ttnn.ROW_MAJOR_LAYOUT:
# x = ttnn.to_layout(x, ttnn.TILE_LAYOUT)

if x.get_layout() == ttnn.ROW_MAJOR_LAYOUT:
x = ttnn.to_layout(x, ttnn.TILE_LAYOUT)
x = self.cv2(device, x)
# else:
# x = self.cv1(device, x)
# x = ttnn.to_layout(x, layout=ttnn.ROW_MAJOR_LAYOUT)
# y1 = x[:, :, :, : x.shape[-1] // 2]
# y2 = x[:, :, :, x.shape[-1] // 2 : x.shape[-1]]
# y2 = ttnn.to_layout(y2, layout=ttnn.TILE_LAYOUT)

# y3 = self.c3k(device, y2)

# # if y1.is_sharded():
# # y1 = ttnn.sharded_to_interleaved(y1, ttnn.L1_MEMORY_CONFIG)
# # if y2.is_sharded():
# # y2 = ttnn.sharded_to_interleaved(y2, ttnn.L1_MEMORY_CONFIG)
# # if y3.is_sharded():
# # y3 = ttnn.sharded_to_interleaved(y3, ttnn.L1_MEMORY_CONFIG)

# if y2.get_layout() != ttnn.ROW_MAJOR_LAYOUT:
# y2 = ttnn.to_layout(y2, ttnn.ROW_MAJOR_LAYOUT)
# if y3.get_layout() != ttnn.ROW_MAJOR_LAYOUT:
# y3 = ttnn.to_layout(y3, ttnn.ROW_MAJOR_LAYOUT)

# x = ttnn.concat((y1, y2, y3), 3, memory_config=ttnn.L1_MEMORY_CONFIG)
# if x.get_layout() == ttnn.ROW_MAJOR_LAYOUT:
# x = ttnn.to_layout(x, ttnn.TILE_LAYOUT)
# x = self.cv2(device, x)

ttnn.deallocate(y1)
ttnn.deallocate(y2)
@@ -843,9 +862,9 @@ def __call__(self, x):
x = self.c3k2_6(self.device, x) # 16
x16 = x
x = self.conv7(self.device, x) # 17
x = ttnn.to_layout(x, layout=ttnn.ROW_MAJOR_LAYOUT)
x = ttnn.to_dtype(x, ttnn.bfloat16)
x = ttnn.to_layout(x, layout=ttnn.TILE_LAYOUT)
# x = ttnn.to_layout(x, layout=ttnn.ROW_MAJOR_LAYOUT)
# x = ttnn.to_dtype(x, ttnn.bfloat16)
# x = ttnn.to_layout(x, layout=ttnn.TILE_LAYOUT)
print("x and x13 shapes are", x.shape, x13.shape, x.dtype, x13.dtype, x.layout, x13.layout)
x = ttnn.concat((x, x13), -1, memory_config=ttnn.L1_MEMORY_CONFIG) # 18
ttnn.deallocate(x13)
