diff --git a/models/demos/ttnn_resnet/tests/resnet50_test_infra.py b/models/demos/ttnn_resnet/tests/resnet50_test_infra.py
index c7a25d71e09..2866840ad8d 100644
--- a/models/demos/ttnn_resnet/tests/resnet50_test_infra.py
+++ b/models/demos/ttnn_resnet/tests/resnet50_test_infra.py
@@ -14,6 +14,7 @@
     preprocess_model_parameters,
 )
 from models.utility_functions import (
+    is_blackhole,
     is_wormhole_b0,
     is_grayskull,
     divup,
@@ -141,6 +142,23 @@ def load_resnet50_model(model_location_generator):
         ttnn.bfloat8_b,
     ): 0.884609,  # Max ATOL Delta: 6.455164909362793, Max RTOL Delta: inf, PCC: 0.8846098380419433
     },
+    32: {
+        (
+            ttnn.MathFidelity.HiFi4,
+            ttnn.bfloat8_b,
+            ttnn.bfloat8_b,
+        ): 0.97,
+        (
+            ttnn.MathFidelity.HiFi2,
+            ttnn.bfloat8_b,
+            ttnn.bfloat8_b,
+        ): 0.95,
+        (
+            ttnn.MathFidelity.LoFi,
+            ttnn.bfloat8_b,
+            ttnn.bfloat8_b,
+        ): 0.88,
+    },
 }
 
 golden_pcc = {
@@ -255,10 +273,12 @@ def setup_l1_sharded_input(self, device, torch_input_tensor=None):
         if self.batch_size == 16:
             core_grid = ttnn.CoreGrid(y=8, x=6)
         elif self.batch_size == 20:
-            if is_grayskull():
+            if is_grayskull() or is_blackhole():
                 core_grid = ttnn.CoreGrid(y=8, x=10)
             elif is_wormhole_b0():
                 core_grid = ttnn.CoreGrid(y=5, x=6)  # untested due to unsupported batch20 on WH
+        elif self.batch_size == 32:
+            core_grid = ttnn.CoreGrid(y=10, x=13)
         num_devices = 1 if isinstance(device, ttnn.Device) else device.get_num_devices()
         # torch tensor
         torch_input_tensor = self.torch_input_tensor if torch_input_tensor is None else torch_input_tensor
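The batch-32 PCC thresholds above are keyed the same way as the existing per-batch tables: batch size first, then the (math fidelity, weight dtype, activation dtype) tuple the test was parametrized with. A minimal lookup sketch; `lookup_golden_pcc` and `golden_pcc_blackhole` are illustrative names, not the test infra's real API, and the tuple field order is assumed from the entries above:

```python
import ttnn

# Thresholds copied verbatim from the batch-32 block added above.
golden_pcc_blackhole = {
    32: {
        (ttnn.MathFidelity.HiFi4, ttnn.bfloat8_b, ttnn.bfloat8_b): 0.97,
        (ttnn.MathFidelity.HiFi2, ttnn.bfloat8_b, ttnn.bfloat8_b): 0.95,
        (ttnn.MathFidelity.LoFi, ttnn.bfloat8_b, ttnn.bfloat8_b): 0.88,
    },
}


def lookup_golden_pcc(batch_size, math_fidelity, weight_dtype, act_dtype):
    # A KeyError here is the desired failure mode: it flags a batch/fidelity/dtype
    # combination that has never been characterized.
    return golden_pcc_blackhole[batch_size][(math_fidelity, weight_dtype, act_dtype)]
```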
"layer4_module3") + ): + conv_kwargs_2["conv_config"].act_block_h_override = 0 + elif ( + batch_size == 16 + and layer_module + and (layer_module == "layer1_module2" or layer_module == "layer1_module3") + ): + conv_kwargs_2["conv_config"].act_block_h_override = 0 if not ttnn.is_tensor_storage_on_device(self.conv2_weight_tensor): self.conv2_weight_tensor = ttnn.prepare_conv_weights( @@ -449,6 +479,7 @@ def __call__( return_output_dim=True, return_weights_and_bias=False, ) + if layer_module and layer_module == "layer4_module1": if ops_parallel_config and "layer4_module1_input" not in ops_parallel_config: x_memory_config = ttnn.get_memory_config(out) @@ -485,8 +516,6 @@ def __call__( transpose_shards=transpose_shards, ), } - if is_blackhole(): - conv_kwargs_3["conv_config"].enable_split_reader = False if not ttnn.is_tensor_storage_on_device(self.conv3_weight_tensor): self.conv3_weight_tensor = ttnn.prepare_conv_weights( @@ -598,9 +627,6 @@ def __init__( self.conv1_output_channels = self.conv1_weight_tensor.shape[0] assert self.conv1_weight_tensor.shape[2] == 4 - self.max_pool_reader_patterns_cache = {} - max_pool_parallel_config_override = {} - self.layer1 = self._make_layer( parameters=parameters.layer1, planes=64, @@ -666,12 +692,18 @@ def __init__( ) # num_classes = 1000 self.transpose_shards = True - act_block_h_override = 1568 - if is_wormhole_b0() or is_blackhole(): + act_block_h_override = 0 + + if is_wormhole_b0(): + self.transpose_shards = False + act_block_h_override = 1568 + + if is_blackhole() and self.batch_size < 20: self.transpose_shards = False - else: - act_block_h_override = 0 + + if is_blackhole() and self.batch_size == 32: + act_block_h_override = 49 * 32 input_channels_alignment = 16 self.conv1_config = ttnn.Conv2dConfig( @@ -699,10 +731,6 @@ def __init__( self.conv1_config.act_block_h_override = 64 else: self.conv1_config.act_block_h_override = 49 * 32 - if is_blackhole(): - # self.conv1_config.act_block_h_override = 7 * 32 - # self.conv1_config.act_block_h_override = 2 * 32 - self.conv1_config.enable_split_reader = False self.conv1_kernel_size = (4, 4) self.conv1_stride = (1, 1) @@ -738,6 +766,9 @@ def __init__( if self.batch_size == 16: num_cores_x = 8 num_cores_y = 8 + self.fold_compute_grid_size = ttnn.CoreRangeSet( + {ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(num_cores_x - 1, num_cores_y - 1))} + ) elif self.batch_size == 20: if is_grayskull(): num_cores_x = 10 @@ -746,9 +777,19 @@ def __init__( num_cores_x = 8 num_cores_y = 5 elif is_blackhole(): - num_cores_x = 8 - num_cores_y = 10 - self.fold_compute_grid_size = (num_cores_x, num_cores_y) + num_cores_x = 10 + num_cores_y = 8 + self.fold_compute_grid_size = ttnn.CoreRangeSet( + {ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(num_cores_x - 1, num_cores_y - 1))} + ) + elif self.batch_size == 32: + core_grid = ttnn.CoreRangeSet( + { + ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(12, 8)), + ttnn.CoreRange(ttnn.CoreCoord(0, 9), ttnn.CoreCoord(10, 9)), + } + ) + self.fold_compute_grid_size = core_grid conv_dummy_tensor = torch.rand((self.fold_output_shape), dtype=torch.bfloat16) conv_dummy_tensor = ttnn.from_torch(conv_dummy_tensor, layout=ttnn.ROW_MAJOR_LAYOUT) @@ -763,7 +804,7 @@ def __init__( self.conv1_output_width, device.compute_with_storage_grid_size(), self.conv1_config.input_channels_alignment, - is_grayskull(), + is_grayskull() or is_blackhole(), ) def __del__(self): @@ -871,6 +912,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt ) 
@@ -871,6 +912,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
             )
             self.conv1_weight_tensor = ttnn.to_device(self.conv1_weight_tensor, device)
             self.conv1_bias_tensor = ttnn.to_device(self.conv1_bias_tensor, device)
+
         x, [x_height, x_width] = ttnn.conv2d(
             input_tensor=fold_output_tensor,
             weight_tensor=self.conv1_weight_tensor,
@@ -881,6 +923,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
             return_output_dim=True,
             return_weights_and_bias=False,
         )
+
         # Relu is fused with conv1
         if self.batch_size == 20:
             x = ttnn.reallocate(x)
@@ -902,6 +945,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
         x = ttnn.reshape(x, (1, 1, x_height * x_width * self.batch_size, 64))
 
         if is_blackhole():
+            ## 112
             core_range_set = ttnn.CoreRangeSet(
                 {
                     ttnn.CoreRange(
@@ -927,15 +971,16 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
             )
             x = ttnn.to_memory_config(x, mem_config)
 
-        x = ttnn.to_layout(x, ttnn.TILE_LAYOUT, dtype=self.model_config["ACTIVATIONS_DTYPE"])
-
         if self.batch_size == 20 and is_grayskull():
             x = ttnn.reallocate(x)
 
+        if not is_blackhole():
+            x = ttnn.to_layout(x, ttnn.TILE_LAYOUT, dtype=self.model_config["ACTIVATIONS_DTYPE"])
+
         logger.debug(f"==== Running layer 1 module 1")
         layer1_module1_input_shape = ttnn.Shape(x.padded_shape)
 
-        reshard = False
+        reshard = is_blackhole()
         height_shard = True
 
         x, x_height, x_width = self.layer1_module1(
@@ -975,6 +1020,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
             enable_act_double_buffer=False,
             enable_split_reader=True,
             enable_subblock_padding=not is_grayskull(),
+            layer_module="layer1_module2",
         )
 
         logger.debug(f"==== Running layer 1 module 3")
@@ -989,14 +1035,15 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
             enable_act_double_buffer=False,
             enable_split_reader=True,
             enable_subblock_padding=not is_grayskull(),
+            layer_module="layer1_module3",
         )
 
         layer2_module1_input_shape = ttnn.Shape(x.padded_shape)
 
-        reshard = not (is_wormhole_b0() or is_grayskull())
+        reshard = is_blackhole() or not (is_wormhole_b0() or is_grayskull())
         height_shard = True
 
-        if is_blackhole():
+        if is_blackhole() and self.batch_size < 20:
             ## 98
             core_range_set = ttnn.CoreRangeSet(
                 {
@@ -1031,8 +1078,9 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
             height_sharding=height_shard,
             transpose_shards=self.transpose_shards,
             enable_act_double_buffer=True,
-            enable_split_reader=False,
+            enable_split_reader=True,
             enable_subblock_padding=False,
+            layer_module="layer2_module1",
         )
 
         if is_first_run:
@@ -1055,8 +1103,9 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
             conv_op_cache,
             transpose_shards=self.transpose_shards,
             enable_act_double_buffer=True,
-            enable_split_reader=False,
+            enable_split_reader=True,
             enable_subblock_padding=False,
+            layer_module="layer2_module2",
         )
 
         logger.debug(f"==== Running layer 2 module 3")
@@ -1069,8 +1118,9 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
             conv_op_cache,
             transpose_shards=self.transpose_shards,
             enable_act_double_buffer=True,
-            enable_split_reader=False,
+            enable_split_reader=True,
             enable_subblock_padding=False,
+            layer_module="layer2_module3",
        )
 
         logger.debug(f"==== Running layer 2 module 4")
@@ -1083,8 +1133,9 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
             conv_op_cache,
             transpose_shards=self.transpose_shards,
             enable_act_double_buffer=True,
-            enable_split_reader=False,
+            enable_split_reader=True,
             enable_subblock_padding=False,
+            layer_module="layer2_module4",
         )
 
         layer3_module1_input_shape = ttnn.Shape(x.padded_shape)
@@ -1211,11 +1262,11 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
             enable_subblock_padding=False,
         )
 
-        reshard = is_grayskull()
+        reshard = is_grayskull() or (is_blackhole() and self.batch_size == 20)
         height_shard = False
 
         layer4_module1_input_shape = ttnn.Shape(x.padded_shape)
-        if is_blackhole():
+        if is_blackhole() and self.batch_size != 20:
             # 104
             grid_size = (13, 8)
             core_range_set = ttnn.CoreRangeSet(
@@ -1275,6 +1326,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
             enable_act_double_buffer=True,
             enable_split_reader=False,
             enable_subblock_padding=False,
+            layer_module="layer4_module2",
         )
 
         logger.debug(f"==== Running layer 4 module 3")
@@ -1289,9 +1341,12 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
             enable_act_double_buffer=True,
             enable_split_reader=False,
             enable_subblock_padding=False,
+            layer_module="layer4_module3",
         )
 
         grid_size = (8, 4)
+        if self.batch_size > 16:
+            grid_size = (8, 8)
         shard_grid = ttnn.CoreRangeSet(
             {
                 ttnn.CoreRange(
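The `layer_module` tags threaded through the calls above exist so the bottleneck's second conv can special-case its activation block height on Blackhole. A condensed restatement of that branch logic; the set literals are a compaction of the if/elif chain in the diff, the values themselves are unchanged:

```python
# batch_size -> modules whose act_block_h_override is reset to 0 on Blackhole.
ACT_BLOCK_H_RESET = {
    32: {"layer1_module2", "layer1_module3", "layer2_module2", "layer2_module3", "layer2_module4"},
    20: {"layer4_module2", "layer4_module3"},
    16: {"layer1_module2", "layer1_module3"},
}


def blackhole_conv2_act_block_h(batch_size, layer_module):
    # Blackhole default stays 2 * 32; the listed (batch, module) pairs return 0,
    # which lets conv2d derive its own activation block height.
    if layer_module in ACT_BLOCK_H_RESET.get(batch_size, set()):
        return 0
    return 2 * 32
```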
diff --git a/models/demos/wormhole/resnet50/tests/test_perf_e2e_resnet50.py b/models/demos/wormhole/resnet50/tests/test_perf_e2e_resnet50.py
index 568a923c6b3..5fb5ea7942e 100644
--- a/models/demos/wormhole/resnet50/tests/test_perf_e2e_resnet50.py
+++ b/models/demos/wormhole/resnet50/tests/test_perf_e2e_resnet50.py
@@ -68,7 +68,6 @@ def test_perf_trace(
     )
 
 
-@run_for_wormhole_b0()
 @pytest.mark.models_performance_bare_metal
 @pytest.mark.parametrize("device_params", [{"l1_small_size": 32768, "num_command_queues": 2}], indirect=True)
 @pytest.mark.parametrize(
@@ -95,10 +94,9 @@ def test_perf_2cqs(
     )
 
 
-@run_for_wormhole_b0()
 @pytest.mark.models_performance_bare_metal
 @pytest.mark.parametrize(
-    "device_params", [{"l1_small_size": 32768, "num_command_queues": 2, "trace_region_size": 1332224}], indirect=True
+    "device_params", [{"l1_small_size": 32768, "num_command_queues": 2, "trace_region_size": 1470464}], indirect=True
 )
 @pytest.mark.parametrize(
     "batch_size, expected_inference_time, expected_compile_time",
diff --git a/models/demos/wormhole/resnet50/tests/test_resnet50_performant.py b/models/demos/wormhole/resnet50/tests/test_resnet50_performant.py
index 5f33ad884b8..f51755ad140 100644
--- a/models/demos/wormhole/resnet50/tests/test_resnet50_performant.py
+++ b/models/demos/wormhole/resnet50/tests/test_resnet50_performant.py
@@ -27,7 +27,7 @@ def test_run_resnet50_inference(
 
 
 @run_for_wormhole_b0()
-@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576, "trace_region_size": 803016}], indirect=True)
+@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576, "trace_region_size": 845824}], indirect=True)
 @pytest.mark.parametrize(
     "batch_size, act_dtype, weight_dtype, math_fidelity",
     ((16, ttnn.bfloat8_b, ttnn.bfloat8_b, ttnn.MathFidelity.LoFi),),
@@ -67,7 +67,7 @@ def test_run_resnet50_2cqs_inference(
 
 @run_for_wormhole_b0()
 @pytest.mark.parametrize(
-    "device_params", [{"l1_small_size": 24576, "trace_region_size": 803016, "num_command_queues": 2}], indirect=True
+    "device_params", [{"l1_small_size": 24576, "trace_region_size": 845824, "num_command_queues": 2}], indirect=True
 )
 @pytest.mark.parametrize(
     "batch_size, act_dtype, weight_dtype, math_fidelity",
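The two perf files only grow the preallocated regions: the 2CQ trace region goes from 1332224 to 1470464 bytes and the performant-test trace region from 803016 to 845824 bytes, since the reworked conv configs capture a larger trace. A sketch of how such `device_params` reach the device, assuming the `device_params`/`device` fixtures from the shared conftest:

```python
import pytest


@pytest.mark.parametrize(
    "device_params",
    [{"l1_small_size": 24576, "trace_region_size": 845824, "num_command_queues": 2}],
    indirect=True,  # handed to the device fixture, which reserves the trace region at device open
)
def test_sketch_trace_capture(device):
    # If trace_region_size is smaller than the captured command stream,
    # trace capture fails at runtime, hence the bump alongside the conv changes.
    ...
```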
diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_transpose.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_transpose.py
index 3cd7f275927..da00379d3eb 100644
--- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_transpose.py
+++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_transpose.py
@@ -11,7 +11,7 @@
 from loguru import logger
 
 from models.utility_functions import is_grayskull, is_blackhole, torch_random
 from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_pcc, comp_equal
-from models.utility_functions import skip_for_grayskull, skip_for_blackhole
+from models.utility_functions import skip_for_grayskull, skip_for_blackhole, run_for_blackhole
 from tests.ttnn.utils_for_testing import assert_with_pcc
 
@@ -56,6 +56,29 @@ def transpose(
     assert device.num_program_cache_entries() == expected_program_cache_size
 
 
+@run_for_blackhole()
+def test_fold_transpose(device, use_program_cache):
+    N = 32
+    C = 4
+    H = 256
+    W = 224
+    input_shape = (N, C, H, W)
+    ## 128
+    grid = ttnn.CoreRangeSet(
+        {
+            ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(12, 8)),
+            ttnn.CoreRange(ttnn.CoreCoord(0, 9), ttnn.CoreCoord(10, 9)),
+        }
+    )
+    sharded_config = ttnn.create_sharded_memory_config_(
+        input_shape,
+        grid,
+        ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
+        ttnn.ShardOrientation.ROW_MAJOR,
+    )
+    transpose(input_shape, device, dim0=2, dim1=3, input_mem_config=sharded_config, output_mem_config=sharded_config)
+
+
 @pytest.mark.parametrize(
     "dtype",
     (ttnn.bfloat16, ttnn.float32),
diff --git a/tests/ttnn/integration_tests/resnet/test_ttnn_functional_resnet50.py b/tests/ttnn/integration_tests/resnet/test_ttnn_functional_resnet50.py
index bca11c9e2e8..776e74b8bfa 100644
--- a/tests/ttnn/integration_tests/resnet/test_ttnn_functional_resnet50.py
+++ b/tests/ttnn/integration_tests/resnet/test_ttnn_functional_resnet50.py
@@ -20,6 +20,8 @@
     (
         (16, ttnn.bfloat8_b, ttnn.bfloat8_b, ttnn.MathFidelity.HiFi2),
         (16, ttnn.bfloat8_b, ttnn.bfloat8_b, ttnn.MathFidelity.LoFi),
+        (20, ttnn.bfloat8_b, ttnn.bfloat8_b, ttnn.MathFidelity.LoFi),
+        (32, ttnn.bfloat8_b, ttnn.bfloat8_b, ttnn.MathFidelity.LoFi),
     ),
 )
 @pytest.mark.parametrize(
@@ -43,6 +45,9 @@ def test_resnet_50(
     if (device.compute_with_storage_grid_size().x, device.compute_with_storage_grid_size().y) == (8, 7):
         pytest.skip("Test is not supported on n300 (8,7) grid")
 
+    if batch_size > 16 and not is_blackhole():
+        pytest.skip("Batch size > 16 is not supported on non-blackhole devices")
+
     if is_blackhole() and use_pretrained_weight:
         pytest.skip(
             "Skipping pretrained weight test on blackhole due to PCC error: https://github.com/tenstorrent/tt-metal/issues/17558"
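The grid in `test_fold_transpose` above is deliberately non-rectangular: a 13x9 block plus an 11-core row, 128 cores total (the `## 128` comment). The shard arithmetic for the (32, 4, 256, 224) input works out evenly:

```python
# Worked shard math for the height-sharded (32, 4, 256, 224) input above.
cores_block = (12 - 0 + 1) * (8 - 0 + 1)  # CoreRange (0,0)-(12,8): 13 x 9 = 117 cores
cores_row = (10 - 0 + 1) * (9 - 9 + 1)    # CoreRange (0,9)-(10,9): 11 x 1 = 11 cores
total_cores = cores_block + cores_row     # 128, matching the "## 128" comment

tensor_height = 32 * 4 * 256              # N * C * H = 32768 rows when height-sharding
shard_height = tensor_height // total_cores  # 32768 / 128 = 256 rows per core, no padding
shard_width = 224                         # W stays whole in each shard
assert (total_cores, shard_height, shard_width) == (128, 256, 224)
```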
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations_2d_mcast_padded_with_halo_3x3_weights_v2.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations_2d_mcast_padded_with_halo_3x3_weights_v2.cpp
index cb70977b246..4033fca3b98 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations_2d_mcast_padded_with_halo_3x3_weights_v2.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations_2d_mcast_padded_with_halo_3x3_weights_v2.cpp
@@ -273,4 +273,5 @@ void kernel_main() {
         }  // act_w_num_outer
         cb_pop_front(tilized_in0_cb_id, act_block_num_tiles);
     }
+    noc_async_write_barrier();
 }
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations_padded_with_halo_3x3_weights_v2.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations_padded_with_halo_3x3_weights_v2.cpp
index dc9ea03e78c..679e5e4a4b5 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations_padded_with_halo_3x3_weights_v2.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations_padded_with_halo_3x3_weights_v2.cpp
@@ -155,4 +155,5 @@ void kernel_main() {
         start_reader_idx += act_block_h_datums_second_reader_read;
 #endif
     }
+    noc_async_write_barrier();
 }
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_writer_tiled_out_1d_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_writer_tiled_out_1d_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp
index a88ed27882a..0a7f4048902 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_writer_tiled_out_1d_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_writer_tiled_out_1d_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp
@@ -268,4 +268,5 @@ void kernel_main() {
 #ifdef SHARDED_OUT
     cb_wait_front(cb_id_out0, output_rows_tiles);
 #endif
+    noc_async_write_barrier();
 }
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_2d_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_2d_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp
index 37c8edb7701..b62b83d070b 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_2d_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_2d_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp
@@ -192,4 +192,5 @@ void kernel_main() {
         cb_id_out0,
         out_subblock_tile_count * out_num_subblocks_h * out_num_subblocks_w * out_num_blocks_w * out_num_blocks_h);
 #endif
+    noc_async_write_barrier();
 }
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_2d_mcast_sender_conv_weights_tiled_col_to_rm_blocks.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_2d_mcast_sender_conv_weights_tiled_col_to_rm_blocks.cpp
index 88744e90369..9e19f9eb8ff 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_2d_mcast_sender_conv_weights_tiled_col_to_rm_blocks.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_2d_mcast_sender_conv_weights_tiled_col_to_rm_blocks.cpp
@@ -335,4 +335,6 @@ void kernel_main() {
         cb_id_out0,
         out_subblock_tile_count * out_num_subblocks_h * out_num_subblocks_w * out_num_blocks_w * out_num_blocks_h);
 #endif
+
+    noc_async_write_barrier();
 }
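The fold changes that follow replace the op's `CoreCoord` grid argument with a `CoreRangeSet`, which is what allows the model to hand it the non-rectangular 128-core grid. On the Python side that grid is built exactly as in the model code above; the actual `ttnn.fold` call is left commented out because this diff only shows the binding's parameter type, not the full Python signature:

```python
import ttnn

# The non-rectangular 128-core grid used for batch 32 in the model above.
core_grid = ttnn.CoreRangeSet(
    {
        ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(12, 8)),
        ttnn.CoreRange(ttnn.CoreCoord(0, 9), ttnn.CoreCoord(10, 9)),
    }
)

# Hypothetical invocation; every keyword except grid_size is an assumption here.
# folded = ttnn.fold(x, stride_h=2, stride_w=2, use_transpose_as_fold=True, grid_size=core_grid)
```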
diff --git a/ttnn/cpp/ttnn/operations/data_movement/fold/fold.cpp b/ttnn/cpp/ttnn/operations/data_movement/fold/fold.cpp
index ca3d56d8f77..72cf59fbd4f 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/fold/fold.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/fold/fold.cpp
@@ -119,14 +119,14 @@ std::vector<Tensor> fold_with_transpose_(
 
 ttnn::MemoryConfig create_sharded_memory_config(
     ttnn::Shape tensor_shape,
-    CoreCoord grid_size,
-    ShardOrientation orientation,
+    const CoreRangeSet& grid_size,
+    const ShardOrientation orientation,
     const std::optional<MemoryConfig>& override_memory_config = std::nullopt) {
     if (override_memory_config.has_value()) {
         return override_memory_config.value();
     }
 
-    uint32_t total_cores = grid_size.x * grid_size.y;
+    uint32_t total_cores = grid_size.num_cores();
 
     uint32_t tensor_height = tensor_shape[-2] * tensor_shape[-3] * tensor_shape[-4];
     uint32_t tensor_width = tensor_shape[-1];
@@ -136,10 +136,9 @@ ttnn::MemoryConfig create_sharded_memory_config(
     auto sharded_memory_config = ttnn::MemoryConfig{
         .memory_layout = ttnn::TensorMemoryLayout::HEIGHT_SHARDED,
         .buffer_type = ttnn::BufferType::L1,
-        .shard_spec = ShardSpec{
-            CoreRangeSet{std::set<CoreRange>{CoreRange{CoreCoord{0, 0}, CoreCoord{grid_size.x - 1, grid_size.y - 1}}}},
-            {shard_height, shard_width},
-            orientation}};
+        .shard_spec = ShardSpec{grid_size, {shard_height, shard_width}, orientation}};
+
+    tt::log_debug(tt::LogOp, "sharded_memory_config: {}", sharded_memory_config);
 
     return sharded_memory_config;
 }
@@ -153,7 +152,7 @@ std::vector<Tensor> fold_with_transpose_sharded_(
     uint32_t pad_c,
     uint32_t pad_h,
     uint32_t pad_w,
-    CoreCoord grid_size,
+    const CoreRangeSet& grid_size,
     const std::optional<MemoryConfig>& override_memory_config) {
     using namespace tt::constants;
     IDevice* device;
@@ -292,7 +291,7 @@ Tensor FoldOperation::invoke(
     uint32_t pad_c,
     uint32_t pad_h,
     uint32_t pad_w,
-    const std::optional<CoreCoord> grid_size,
+    const std::optional<CoreRangeSet>& core_grid,
     const std::optional<MemoryConfig>& override_memory_config) {
     if (use_transpose_as_fold) {
         if (input_tensor.is_sharded()) {
@@ -306,7 +305,7 @@ Tensor FoldOperation::invoke(
                 pad_c,
                 pad_h,
                 pad_w,
-                grid_size.value_or(CoreCoord(1, 1)),
+                core_grid.value_or(CoreRangeSet{CoreRange{CoreCoord{0, 0}, CoreCoord{1, 1}}}),
                 override_memory_config)
                 .at(0);
         } else {
@@ -329,7 +328,7 @@ Tensor FoldOperation::invoke(
     uint32_t pad_c,
     uint32_t pad_h,
     uint32_t pad_w,
-    const std::optional<CoreCoord> grid_size,
+    const std::optional<CoreRangeSet>& core_grid,
     const std::optional<MemoryConfig>& override_memory_config) {
     QueueId queue_id = DefaultQueueId;
     return invoke(
@@ -342,6 +341,6 @@ Tensor FoldOperation::invoke(
         pad_c,
         pad_h,
         pad_w,
-        grid_size);
+        core_grid);
 }
 }  // namespace ttnn::operations::data_movement
diff --git a/ttnn/cpp/ttnn/operations/data_movement/fold/fold.hpp b/ttnn/cpp/ttnn/operations/data_movement/fold/fold.hpp
index 7b52bd73666..87604a7f5e0 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/fold/fold.hpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/fold/fold.hpp
@@ -27,7 +27,7 @@ struct FoldOperation {
         uint32_t pad_c = 0,
         uint32_t pad_h = 0,
         uint32_t pad_w = 0,
-        const std::optional<CoreCoord> grid_size = std::nullopt,
+        const std::optional<CoreRangeSet>& core_grid = std::nullopt,
         const std::optional<MemoryConfig>& override_memory_config = std::nullopt);
     static ttnn::Tensor invoke(
         QueueId queue_id,
@@ -39,7 +39,7 @@ struct FoldOperation {
         uint32_t pad_c = 0,
         uint32_t pad_h = 0,
         uint32_t pad_w = 0,
-        const std::optional<CoreCoord> grid_size = std::nullopt,
+        const std::optional<CoreRangeSet>& core_grid = std::nullopt,
         const std::optional<MemoryConfig>& override_memory_config = std::nullopt);
 };
diff --git a/ttnn/cpp/ttnn/operations/data_movement/fold/fold_pybind.cpp b/ttnn/cpp/ttnn/operations/data_movement/fold/fold_pybind.cpp
index 4980a8e11e5..43383a60c53 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/fold/fold_pybind.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/fold/fold_pybind.cpp
@@ -37,7 +37,7 @@ void bind_fold_operation(py::module& module) {
            uint32_t pad_c,
            uint32_t pad_h,
            uint32_t pad_w,
-           std::optional<CoreCoord> grid_size,
+           std::optional<CoreRangeSet> grid_size,
            std::optional<MemoryConfig> override_memory_config,
            QueueId queue_id) -> ttnn::Tensor {
             return op(
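`create_sharded_memory_config` in fold.cpp now takes the core count from `grid_size.num_cores()` instead of `grid_size.x * grid_size.y`, so disjoint range sets divide the tensor height correctly. Restated as a small Python helper; the hunk above ends before the shard-height line, so the round-up division is an assumption, and `divup` is the utility the test infra already imports:

```python
from models.utility_functions import divup


def fold_shard_shape(tensor_shape, num_cores):
    # Mirrors fold.cpp: height collapses dims -4..-2, width is the last dim,
    # and height is split across every core in the range set.
    n, c, h, w = tensor_shape
    return divup(n * c * h, num_cores), w


# For the batch-32 fold input on the 128-core grid:
assert fold_shard_shape((32, 4, 256, 224), 128) == (256, 224)
```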
diff --git a/ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_program_factory.cpp
index a009d7d00aa..350dc0c2b88 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_program_factory.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_program_factory.cpp
@@ -1143,8 +1143,6 @@ inline std::vector<std::pair<std::vector<uint32_t>, std::vector<uint32_t>>> get_
     std::vector<uint32_t> reader_kernel_args;
     reader_kernel_args.push_back(core_stick_map.size());  // num_cores
 
-    tt::log_debug("num_cores: {}", core_stick_map.size());
-
     for (const auto& core_stick_pair : core_stick_map) {
         auto xy_pair = core_stick_pair.first;
         if (row_major) {
@@ -1154,9 +1152,6 @@ inline std::vector<std::pair<std::vector<uint32_t>, std::vector<uint32_t>>> get_
             reader_kernel_args.push_back((std::uint32_t)xy_pair.first);   // noc x
             reader_kernel_args.push_back((std::uint32_t)xy_pair.second);  // noc y
         }
-
-        tt::log_debug("xy_pair.first: {}", xy_pair.first);
-        tt::log_debug("xy_pair.second: {}", xy_pair.second);
     }
 
     // coalesce the sticks into chunks
@@ -1164,17 +1159,12 @@ inline std::vector<std::pair<std::vector<uint32_t>, std::vector<uint32_t>>> get_
     for (auto core_stick_pair : core_stick_map) {
         auto stick_chunks = group_contiguous_and_repeated_values(core_stick_pair.second);
         stick_chunks_per_core.push_back(stick_chunks);
-
         reader_kernel_args.push_back(stick_chunks.size());  // num_chunks for current core
-        tt::log_debug("chunk_size: {}", stick_chunks.size());
     }
 
    for (const auto& stick_chunks : stick_chunks_per_core) {
         for (auto chunk : stick_chunks) {
-            reader_kernel_args.push_back(chunk[0]);  // start id of a chunk
-            tt::log_debug("chunk_start_id: {}", chunk[0]);
-
+            reader_kernel_args.push_back(chunk[0]);      // start id of a chunk
             reader_kernel_args.push_back(chunk.size());  // length of a chunk
-            tt::log_debug("chunk_length: {}", chunk.size());
         }
     }
diff --git a/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_program_factory.cpp
index 31ceef8f604..d42273701e2 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_program_factory.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_program_factory.cpp
@@ -513,8 +513,6 @@ inline std::vector<std::pair<std::vector<uint32_t>, std::vector<uint32_t>>> get_
     std::vector<uint32_t> reader_kernel_args;
     reader_kernel_args.push_back(core_stick_map.size());  // num_cores
 
-    tt::log_debug("num_cores: {}", core_stick_map.size());
-
     for (const auto& core_stick_pair : core_stick_map) {
         auto xy_pair = core_stick_pair.first;
         if (row_major) {
@@ -524,9 +522,6 @@ inline std::vector<std::pair<std::vector<uint32_t>, std::vector<uint32_t>>> get_
             reader_kernel_args.push_back(xy_pair.first);   // noc x
             reader_kernel_args.push_back(xy_pair.second);  // noc y
         }
-
-        tt::log_debug("xy_pair.first: {}", xy_pair.first);
-        tt::log_debug("xy_pair.second: {}", xy_pair.second);
     }
 
     // coalesce the sticks into chunks
@@ -536,15 +531,11 @@ inline std::vector<std::pair<std::vector<uint32_t>, std::vector<uint32_t>>> get_
         auto stick_chunks = group_contiguous_and_repeated_values(core_stick_pair.second);
         stick_chunks_per_core.push_back(stick_chunks);
 
         reader_kernel_args.push_back(stick_chunks.size());  // num_chunks for current core
-        tt::log_debug("chunk_size: {}", stick_chunks.size());
     }
     for (const auto& stick_chunks : stick_chunks_per_core) {
         for (auto chunk : stick_chunks) {
-            reader_kernel_args.push_back(chunk[0]);  // start id of a chunk
-            tt::log_debug("chunk_start_id: {}", chunk[0]);
-
+            reader_kernel_args.push_back(chunk[0]);      // start id of a chunk
             reader_kernel_args.push_back(chunk.size());  // length of a chunk
-            tt::log_debug("chunk_length: {}", chunk.size());
         }
     }
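For context before the matmul kernel change: the batch-32 path ends in the 1D mcast classifier matmul registered at the top of the model file, one tile per core on an 8x8 grid. Copied from that hunk; the fields past `per_core_M` are not visible for the 32 entry and are assumed to match the batch-20 entry:

```python
import ttnn

# Classifier matmul program config for batch 32 (batch 20 now uses the same shape).
matmul_config_batch32 = ttnn.MatmulMultiCoreReuseMultiCast1DProgramConfig(
    compute_with_storage_grid_size=(8, 8),  # matches grid_size = (8, 8) for batch_size > 16
    in0_block_w=1,
    out_subblock_h=1,
    out_subblock_w=1,
    per_core_M=1,
    per_core_N=1,           # assumed, continuation not shown in the hunk
    fuse_batch=True,        # assumed
    fused_activation=None,  # assumed
    mcast_in0=True,         # assumed
)
```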
diff --git a/ttnn/cpp/ttnn/operations/matmul/device/kernels/compute/bmm_large_block_zm_fused_bias_activation.cpp b/ttnn/cpp/ttnn/operations/matmul/device/kernels/compute/bmm_large_block_zm_fused_bias_activation.cpp
index 73ef8d67cfb..79d25ce9c5c 100644
--- a/ttnn/cpp/ttnn/operations/matmul/device/kernels/compute/bmm_large_block_zm_fused_bias_activation.cpp
+++ b/ttnn/cpp/ttnn/operations/matmul/device/kernels/compute/bmm_large_block_zm_fused_bias_activation.cpp
@@ -198,14 +198,6 @@ void MAIN {
                 // accumulation is done by iterating matmul_block across inner dim
                 // in0_block_w is passed as innder dim (kt) to matmul_block, interally used to stride
                 // in0
-
-#ifdef ARCH_BLACKHOLE
-                // FIXME: This is a temporary workaround to avoid hangs on blackhole.
-                // https://github.com/tenstorrent/tt-metal/issues/16439
-                for (uint32_t i = 0; i < 10; i++) {
-                    asm volatile("nop");
-                }
-#endif
                 matmul_block(
                     in0_cb_id,
                     in1_cb_id,
diff --git a/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_sender_receiver_padding_block_sharded.cpp b/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_sender_receiver_padding_block_sharded.cpp
index a56029d83bb..606e3c6235b 100644
--- a/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_sender_receiver_padding_block_sharded.cpp
+++ b/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_sender_receiver_padding_block_sharded.cpp
@@ -284,4 +284,6 @@ void kernel_main() {
             in0_tensor_current_h_dim_block_start_addr += in0_tensor_next_h_dim_block_stride;
         }
     }
+
+    noc_async_write_barrier();
 }
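Taken together, the checks close the loop the same way the existing batches do: run the model, fetch the characterized PCC for the parametrization, compare against torch. A closing sketch using the `assert_with_pcc` helper the transpose test already imports; the function and tensor names are placeholders:

```python
import ttnn
from tests.ttnn.utils_for_testing import assert_with_pcc


def check_resnet50_output(torch_output, tt_output_tensor, expected_pcc=0.88):
    # 0.88 is the characterized LoFi/bfloat8_b threshold for batch 32 above;
    # HiFi2 and HiFi4 raise it to 0.95 and 0.97.
    tt_output = ttnn.to_torch(tt_output_tensor)
    assert_with_pcc(torch_output, tt_output.reshape(torch_output.shape), expected_pcc)
```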