From 86f299793e5b077d23f668ee04b7ea3745415243 Mon Sep 17 00:00:00 2001
From: vguduruTT
Date: Thu, 13 Feb 2025 06:39:02 +0000
Subject: [PATCH] #3: add generic helper for concat sharding (RM, fp16)

---
 .../functional_yolov11/tt/ttnn_yolov11.py | 174 ++++--------------
 1 file changed, 35 insertions(+), 139 deletions(-)

diff --git a/models/experimental/functional_yolov11/tt/ttnn_yolov11.py b/models/experimental/functional_yolov11/tt/ttnn_yolov11.py
index 9f49f4a1e9e..491499dc9d8 100644
--- a/models/experimental/functional_yolov11/tt/ttnn_yolov11.py
+++ b/models/experimental/functional_yolov11/tt/ttnn_yolov11.py
@@ -157,6 +157,32 @@ def Yolov11_shard_upsample(device, x):
     return x
 
 
+def sharded_concat(input_tensors, num_cores=64, dim=3):  # expects bfloat16 (fp16), ROW_MAJOR inputs with the same n*h*w
+    shard_grid = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 7))})  # 8x8 grid = num_cores
+    in_shard_width = input_tensors[0].shape[-1]
+    shard_height = (input_tensors[0].shape[2] + num_cores - 1) // num_cores  # ceil-divide n*h*w across the cores
+    input_sharded_memory_config = ttnn.create_sharded_memory_config(
+        (shard_height, in_shard_width),
+        core_grid=shard_grid,
+        strategy=ttnn.ShardStrategy.HEIGHT,
+        use_height_and_width_as_shard_shape=True,
+    )
+    out_shard_width = 0
+    for i in range(len(input_tensors)):
+        out_shard_width += input_tensors[i].shape[-1]
+        input_tensors[i] = ttnn.to_memory_config(input_tensors[i], input_sharded_memory_config)
+    output_sharded_memory_config = ttnn.create_sharded_memory_config(
+        (shard_height, out_shard_width),
+        core_grid=shard_grid,
+        strategy=ttnn.ShardStrategy.HEIGHT,
+        use_height_and_width_as_shard_shape=True,
+    )
+    output = ttnn.concat(input_tensors, dim, memory_config=output_sharded_memory_config)
+    output = ttnn.sharded_to_interleaved(output, memory_config=ttnn.L1_MEMORY_CONFIG)
+
+    return output
+
+
 class Conv:
     def __init__(self, device, parameter, conv_pt, enable_act=True, is_detect=False):
         self.enable_act = enable_act
@@ -232,52 +258,15 @@ def __call__(self, device, x):
         )
         use_sharded_concat = True
         if use_sharded_concat:
-            shard_grid = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 7))})
-            if m1.shape[2] == 49:  # 224
-                shard_height = 1  # (n*h*w + num_cores - 1)//num_cores
-            elif m1.shape[2] == 400:  # 640
-                shard_height = 7
-            else:
-                print("invalid shard spec")
-            in_shard_width = x1.shape[-1]
-            out_shard_width = x1.shape[-1] + m1.shape[-1] + m2.shape[-1] + m3.shape[-1]
-            input_sharded_memory_config = ttnn.create_sharded_memory_config(
-                (shard_height, in_shard_width),
-                core_grid=shard_grid,
-                strategy=ttnn.ShardStrategy.HEIGHT,
-                use_height_and_width_as_shard_shape=True,
-            )
-            output_sharded_memory_config = ttnn.create_sharded_memory_config(
-                (shard_height, out_shard_width),
-                core_grid=shard_grid,
-                strategy=ttnn.ShardStrategy.HEIGHT,
-                use_height_and_width_as_shard_shape=True,
-            )
-            print("inpit is", input_sharded_memory_config)
-            print("outpt is", output_sharded_memory_config)
-            x1 = ttnn.to_memory_config(x1, memory_config=input_sharded_memory_config)
-            m1 = ttnn.to_memory_config(m1, memory_config=input_sharded_memory_config)
-            m2 = ttnn.to_memory_config(m2, memory_config=input_sharded_memory_config)
-            m3 = ttnn.to_memory_config(m3, memory_config=input_sharded_memory_config)
-            memory_config_used = output_sharded_memory_config
+            y = sharded_concat([x1, m1, m2, m3])
         else:
-            if m2.is_sharded():
-                m2 = ttnn.sharded_to_interleaved(m2, ttnn.L1_MEMORY_CONFIG)
-            if m3.is_sharded():
-                m3 = ttnn.sharded_to_interleaved(m3, ttnn.L1_MEMORY_CONFIG)
-            if m1.is_sharded():
-                m1 = ttnn.sharded_to_interleaved(m1, ttnn.L1_MEMORY_CONFIG)
-            memory_config_used = ttnn.L1_MEMORY_CONFIG
-        y = ttnn.concat([x1, m1, m2, m3], dim=-1, memory_config=memory_config_used)
-        if y.is_sharded():
-            y = ttnn.sharded_to_interleaved(y, ttnn.L1_MEMORY_CONFIG)
+            y = ttnn.concat([x1, m1, m2, m3], dim=-1, memory_config=ttnn.L1_MEMORY_CONFIG)
         x = self.cv2(device, y)
         ttnn.deallocate(x1)
         ttnn.deallocate(m1)
         ttnn.deallocate(m2)
         ttnn.deallocate(m3)
-        x = ttnn.reallocate(x)
-        print("output shape is ", x.shape)
+        # x = ttnn.reallocate(x)
         return x
@@ -295,62 +284,20 @@ def __call__(self, device, x):
 
         k1 = self.k1(device, x1)
         k2 = self.k2(device, k1)
-        print(
-            "input to concat",
-            x2.shape,
-            x2.layout,
-            x2.dtype,
-            x2.memory_config(),
-            k2.shape,
-            k2.layout,
-            k2.dtype,
-            k2.memory_config(),
-        )
         use_shard_concat = False  # fps drop due to layout conversion
         if use_shard_concat:
-            shard_grid = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 7))})
-            if x2.shape[2] == 49:  # 224
-                shard_height = 1  # (n*h*w + num_cores - 1)//num_cores
-            elif x2.shape[2] == 196:  # 224
-                shard_height = 4
-            elif x2.shape[2] == 400:  # 640
-                shard_height = 7
-            elif x2.shape[2] == 1600:  # 640
-                shard_height = 13
-            else:
-                print("invalid shard spec")
-            in_shard_width = x2.shape[-1]
-            out_shard_width = x2.shape[-1] + k2.shape[-1]
-            input_sharded_memory_config = ttnn.create_sharded_memory_config(
-                (shard_height, in_shard_width),
-                core_grid=shard_grid,
-                strategy=ttnn.ShardStrategy.HEIGHT,
-                use_height_and_width_as_shard_shape=True,
-            )
             x2 = ttnn.to_layout(x2, ttnn.ROW_MAJOR_LAYOUT)
             x2 = ttnn.to_dtype(x2, ttnn.bfloat16)
-            x2 = ttnn.to_memory_config(x2, memory_config=input_sharded_memory_config)
             k2 = ttnn.to_layout(k2, ttnn.ROW_MAJOR_LAYOUT)
             k2 = ttnn.to_dtype(k2, ttnn.bfloat16)
-            k2 = ttnn.to_memory_config(k2, memory_config=input_sharded_memory_config)
-            output_sharded_memory_config = ttnn.create_sharded_memory_config(
-                (shard_height, out_shard_width),
-                core_grid=shard_grid,
-                strategy=ttnn.ShardStrategy.HEIGHT,
-                use_height_and_width_as_shard_shape=True,
-            )
-            memory_config_used = output_sharded_memory_config
+            x = sharded_concat([k2, x2])
         else:
-            memory_config_used = ttnn.L1_MEMORY_CONFIG
-        x = ttnn.concat((k2, x2), 3, memory_config=memory_config_used)
-        if x.is_sharded():
-            x = ttnn.sharded_to_interleaved(x, memory_config=ttnn.L1_MEMORY_CONFIG)
+            x = ttnn.concat((k2, x2), 3, memory_config=ttnn.L1_MEMORY_CONFIG)
         x = self.cv3(device, x)
         ttnn.deallocate(x1)
         ttnn.deallocate(x2)
         ttnn.deallocate(k1)
         ttnn.deallocate(k2)
-        x = ttnn.reallocate(x)
         return x
@@ -369,14 +316,10 @@ def __init__(self, device, parameter, conv_pt, is_bk_enabled=False):
             self.c3k = C3K(device, parameter[0], conv_pt.m[0])
 
     def __call__(self, device, x):
-        # if self.is_bk_enabled:
         x = self.cv1(device, x)
         x = ttnn.to_layout(x, layout=ttnn.ROW_MAJOR_LAYOUT)
-        print("x.sha[e is ]", x.shape)
         y1 = x[:, :, :, : x.shape[-1] // 2]
         y2 = x[:, :, :, x.shape[-1] // 2 : x.shape[-1]]
-        # x = ttnn.reallocate(x)
-        # y2 = ttnn.to_layout(y2, layout=ttnn.TILE_LAYOUT)
         if self.is_bk_enabled:
             y2 = ttnn.to_layout(y2, layout=ttnn.TILE_LAYOUT)
             y3 = self.k(device, y2)
@@ -387,63 +330,16 @@ def __call__(self, device, x):
             y2 = ttnn.to_layout(y2, ttnn.ROW_MAJOR_LAYOUT)
         if y3.get_layout() != ttnn.ROW_MAJOR_LAYOUT:
             y3 = ttnn.to_layout(y3, ttnn.ROW_MAJOR_LAYOUT)
-
-        print("y1,y2,y3", y1.shape, y2.shape, y3.shape, y1.dtype, y2.dtype, y3.dtype)
         use_shard_concat = True
         if use_shard_concat:
-            shard_grid = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 7))})
-            if y1.shape[2] == 25600:  # 640
-                shard_height = 400  # (n*h*w + num_cores - 1)//num_cores
-            elif y1.shape[2] == 6400:  # 640
-                shard_height = 102
-            elif y1.shape[2] == 1600:  # 640
-                shard_height = 25
-            elif y1.shape[2] == 400:  # 640
-                shard_height = 7
-            elif y1.shape[2] == 3136:  # 224
-                shard_height = 49
-            elif y1.shape[2] == 784:  # 224
-                shard_height = 13
-            elif y1.shape[2] == 196:  # 224
-                shard_height = 4
-            elif y1.shape[2] == 49:  # 224
-                shard_height = 1
-            else:
-                print("invalid shard spec")
-            in_shard_width = y1.shape[-1]
-            out_shard_width = y1.shape[-1] + y2.shape[-1] + y3.shape[-1]
-            input_sharded_memory_config = ttnn.create_sharded_memory_config(
-                (shard_height, in_shard_width),
-                core_grid=shard_grid,
-                strategy=ttnn.ShardStrategy.HEIGHT,
-                use_height_and_width_as_shard_shape=True,
-            )
-            output_sharded_memory_config = ttnn.create_sharded_memory_config(
-                (shard_height, out_shard_width),
-                core_grid=shard_grid,
-                strategy=ttnn.ShardStrategy.HEIGHT,
-                use_height_and_width_as_shard_shape=True,
-            )
-            y1 = ttnn.to_memory_config(y1, memory_config=input_sharded_memory_config)
-            y2 = ttnn.to_memory_config(y2, memory_config=input_sharded_memory_config)
-            y3 = ttnn.to_memory_config(y3, memory_config=input_sharded_memory_config)
-            memory_config_used = output_sharded_memory_config
+            x = sharded_concat([y1, y2, y3])
         else:
-            memory_config_used = ttnn.L1_MEMORY_CONFIG
-        x = ttnn.concat((y1, y2, y3), 3, memory_config=memory_config_used)
-
-        if use_shard_concat:
-            x = ttnn.sharded_to_interleaved(x, memory_config=ttnn.L1_MEMORY_CONFIG)
-
-        # if x.get_layout() == ttnn.ROW_MAJOR_LAYOUT:
-        #     x = ttnn.to_layout(x, ttnn.TILE_LAYOUT)
-
+            x = ttnn.concat((y1, y2, y3), 3, memory_config=ttnn.L1_MEMORY_CONFIG)
         x = self.cv2(device, x)
         ttnn.deallocate(y1)
         ttnn.deallocate(y2)
         ttnn.deallocate(y3)
-        x = ttnn.reallocate(x)
         return x
@@ -503,7 +399,7 @@ def __call__(self, device, x, batch_size=1):
         ttnn.deallocate(k)
         ttnn.deallocate(v)
         ttnn.deallocate(x2)
-        x = ttnn.reallocate(x)
+        # x = ttnn.reallocate(x)
         return x
@@ -573,7 +469,7 @@ def __call__(self, device, x):
         x = self.cv2(device, x)
         ttnn.deallocate(a)
         ttnn.deallocate(b)
-        x = ttnn.reallocate(x)
+        # x = ttnn.reallocate(x)
         return x
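
The per-resolution shard-height ladders deleted above collapse into a single ceil division inside sharded_concat. A minimal standalone sketch of that shard-shape arithmetic (plain Python; the channel widths below are illustrative, not taken from the model):

def concat_shard_shapes(nhw: int, widths: list[int], num_cores: int = 64) -> tuple:
    """Return (input_shard_shape, output_shard_shape) for a height-sharded concat."""
    shard_height = (nhw + num_cores - 1) // num_cores  # ceil-divide flattened n*h*w across the cores
    return (shard_height, widths[0]), (shard_height, sum(widths))


if __name__ == "__main__":
    print(concat_shard_shapes(400, [128, 128, 128, 128]))  # ((7, 128), (7, 512))
    print(concat_shard_shapes(25600, [32, 32, 32]))  # ((400, 32), (400, 96))

For example, n*h*w = 400 yields a shard height of 7 and 25600 yields 400, matching the values previously hard-coded at the SPPF and C3k2 call sites.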
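
A hypothetical call-site sketch (not part of this patch) of the preconditions the helper expects, mirroring the layout and dtype conversions the C3K branch performs before calling it; the import assumes sharded_concat stays in the patched module:

import ttnn

from models.experimental.functional_yolov11.tt.ttnn_yolov11 import sharded_concat


def concat_branches(tensors):
    """Height-shard and concatenate same-(n*h*w) branch outputs along the last dim."""
    prepared = []
    for t in tensors:
        if t.get_layout() != ttnn.ROW_MAJOR_LAYOUT:
            t = ttnn.to_layout(t, ttnn.ROW_MAJOR_LAYOUT)  # helper expects ROW_MAJOR inputs
        if t.dtype != ttnn.bfloat16:
            t = ttnn.to_dtype(t, ttnn.bfloat16)  # and bfloat16 data
        prepared.append(t)
    return sharded_concat(prepared)  # returns an L1 interleaved tensor, ready for the next conv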