From 816545d56e4343edbbd87f03e4fe9f3cdde10e4f Mon Sep 17 00:00:00 2001
From: Aswinmcw
Date: Fri, 22 Nov 2024 06:40:44 +0000
Subject: [PATCH] Apply changes to models

---
 models/demos/llama3/tt/llama_attention.py     |  4 +-
 models/demos/llama3/tt/llama_mlp.py           |  2 +-
 .../tt/multimodal/llama_cross_attention.py    |  4 +-
 models/demos/qwen/tt/qwen_attention.py        |  4 +-
 models/demos/qwen/tt/qwen_mlp.py              |  2 +-
 models/demos/t3000/falcon40b/tt/falcon_mlp.py |  4 +-
 .../llama2_70b/tt/llama_mlp_optimized.py      |  4 +-
 models/demos/tg/llama3_70b/tt/llama_common.py |  2 +-
 .../operations/ccl/perf/test_ccl_perf.py      | 12 ++--
 .../test_reduce_scatter_N300_post_commit.py   |  6 +-
 .../ccl/test_reduce_scatter_TG_nightly.py     |  4 +-
 .../ccl/test_reduce_scatter_nightly.py        | 12 ++--
 .../ccl/test_reduce_scatter_post_commit.py    | 56 +++++++++----------
 13 files changed, 58 insertions(+), 58 deletions(-)
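
Reviewer note, not part of the commit message (text between the "---" marker and
the first "diff --git" line is dropped when the patch is applied): every hunk
below makes the same mechanical change, renaming the scatter_dim argument of
ttnn.reduce_scatter to dim and renaming the matching scatter_dim test parameters
and variables. A minimal before/after sketch of an affected call site follows;
mesh_tensor is an illustrative placeholder for a tile-layout tensor already
distributed across a mesh device, not a name taken from this patch.

    import ttnn

    # Pre-patch call site: the dimension to scatter along was passed as scatter_dim.
    out = ttnn.reduce_scatter(
        mesh_tensor,  # placeholder: a tensor sharded across the mesh (hypothetical)
        scatter_dim=3,
        math_op=ttnn.ReduceType.Sum,
        num_links=1,
        memory_config=ttnn.L1_MEMORY_CONFIG,
    )

    # Post-patch call site: same arguments, keyword renamed to dim.
    out = ttnn.reduce_scatter(
        mesh_tensor,
        dim=3,
        math_op=ttnn.ReduceType.Sum,
        num_links=1,
        memory_config=ttnn.L1_MEMORY_CONFIG,
    )
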
diff --git a/models/demos/llama3/tt/llama_attention.py b/models/demos/llama3/tt/llama_attention.py
index 9a51aad2a743..927c0a6ed826 100644
--- a/models/demos/llama3/tt/llama_attention.py
+++ b/models/demos/llama3/tt/llama_attention.py
@@ -357,7 +357,7 @@ def forward_decode(
         if self.is_multichip and not self.use_fused_all_gather_matmul:
             dense_out_reduced = ttnn.reduce_scatter(
                 dense_out,
-                scatter_dim=3,
+                dim=3,
                 math_op=ttnn.ReduceType.Sum,
                 num_links=1,
                 memory_config=ttnn.L1_MEMORY_CONFIG,
@@ -530,7 +530,7 @@ def forward_prefill(self, x_11SH, rot_mats, transformation_mats, user_id: int =
         if self.is_multichip and not self.use_fused_all_gather_matmul:
             dense_out_reduced = ttnn.reduce_scatter(
                 output_11SH,
-                scatter_dim=3,
+                dim=3,
                 math_op=ttnn.ReduceType.Sum,
                 num_links=1,
                 memory_config=ttnn.DRAM_MEMORY_CONFIG,
diff --git a/models/demos/llama3/tt/llama_mlp.py b/models/demos/llama3/tt/llama_mlp.py
index f06e2ff63f1f..88b449277158 100644
--- a/models/demos/llama3/tt/llama_mlp.py
+++ b/models/demos/llama3/tt/llama_mlp.py
@@ -137,7 +137,7 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor:
         if self.args.is_multichip:
             w2_out_reduced = ttnn.reduce_scatter(
                 w2_out,
-                scatter_dim=3,
+                dim=3,
                 math_op=ttnn.ReduceType.Sum,
                 num_links=1,
                 memory_config=ttnn.DRAM_MEMORY_CONFIG if mode == "prefill" else ttnn.L1_MEMORY_CONFIG,
diff --git a/models/demos/llama3/tt/multimodal/llama_cross_attention.py b/models/demos/llama3/tt/multimodal/llama_cross_attention.py
index 63f87fbeb731..fb9266c23a18 100644
--- a/models/demos/llama3/tt/multimodal/llama_cross_attention.py
+++ b/models/demos/llama3/tt/multimodal/llama_cross_attention.py
@@ -271,7 +271,7 @@ def forward_decode(self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH,
         if self.is_multichip:
             dense_out_reduced = ttnn.reduce_scatter(
                 output,
-                scatter_dim=3,
+                dim=3,
                 math_op=ttnn.ReduceType.Sum,
                 num_links=1,
                 memory_config=ttnn.L1_MEMORY_CONFIG,
@@ -358,7 +358,7 @@ def forward_prefill(
         if self.is_multichip:  # TODO use_fused_all_gather_matmul
             dense_out_reduced = ttnn.reduce_scatter(
                 output,
-                scatter_dim=3,
+                dim=3,
                 math_op=ttnn.ReduceType.Sum,
                 num_links=1,
                 memory_config=ttnn.DRAM_MEMORY_CONFIG,
diff --git a/models/demos/qwen/tt/qwen_attention.py b/models/demos/qwen/tt/qwen_attention.py
index ba598cc96c1d..0e80c47b228d 100644
--- a/models/demos/qwen/tt/qwen_attention.py
+++ b/models/demos/qwen/tt/qwen_attention.py
@@ -414,7 +414,7 @@ def forward_decode(
         if self.is_multichip and not self.use_fused_all_gather_matmul:
             dense_out_reduced = ttnn.reduce_scatter(
                 dense_out,
-                scatter_dim=3,
+                dim=3,
                 math_op=ttnn.ReduceType.Sum,
                 num_links=1,
                 memory_config=ttnn.L1_MEMORY_CONFIG,
@@ -598,7 +598,7 @@ def forward_prefill(self, x_11SH, rot_mats, transformation_mats, user_id: int =
         if self.is_multichip and not self.use_fused_all_gather_matmul:
             dense_out_reduced = ttnn.reduce_scatter(
                 output_11SH,
-                scatter_dim=3,
+                dim=3,
                 math_op=ttnn.ReduceType.Sum,
                 num_links=1,
                 memory_config=ttnn.DRAM_MEMORY_CONFIG,
diff --git a/models/demos/qwen/tt/qwen_mlp.py b/models/demos/qwen/tt/qwen_mlp.py
index e07d4943d1c6..ad5008539200 100644
--- a/models/demos/qwen/tt/qwen_mlp.py
+++ b/models/demos/qwen/tt/qwen_mlp.py
@@ -142,7 +142,7 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor:
         if self.args.is_multichip:
             w2_out_reduced = ttnn.reduce_scatter(
                 w2_out,
-                scatter_dim=3,
+                dim=3,
                 math_op=ttnn.ReduceType.Sum,
                 num_links=1,
                 memory_config=ttnn.DRAM_MEMORY_CONFIG if mode == "prefill" else ttnn.L1_MEMORY_CONFIG,
diff --git a/models/demos/t3000/falcon40b/tt/falcon_mlp.py b/models/demos/t3000/falcon40b/tt/falcon_mlp.py
index 1788c3ac6b62..5101b309d4dc 100644
--- a/models/demos/t3000/falcon40b/tt/falcon_mlp.py
+++ b/models/demos/t3000/falcon40b/tt/falcon_mlp.py
@@ -124,7 +124,7 @@ def fwd_decode(self, x: List[ttnn.Tensor]) -> List[ttnn.Tensor]:
         hidden_states = ttnn.get_device_tensors(
             ttnn.reduce_scatter(
                 ttnn.aggregate_as_tensor(hidden_states),
-                scatter_dim=3,
+                dim=3,
                 math_op=ttnn.ReduceType.Sum,
                 num_links=1,  # only unidirectional supported for now
                 memory_config=self.model_config["DEFAULT_MEMCFG"],
@@ -200,7 +200,7 @@ def fwd_prefill(self, x: List[ttnn.Tensor]) -> List[ttnn.Tensor]:
         hidden_states = ttnn.get_device_tensors(
             ttnn.reduce_scatter(
                 ttnn.aggregate_as_tensor(hidden_states),
-                scatter_dim=3,
+                dim=3,
                 math_op=ttnn.ReduceType.Sum,
                 num_links=1,  # only one link supported for now
                 memory_config=self.model_config["DEFAULT_MEMCFG"],
diff --git a/models/demos/t3000/llama2_70b/tt/llama_mlp_optimized.py b/models/demos/t3000/llama2_70b/tt/llama_mlp_optimized.py
index 2861253da1a9..a185fc605f0f 100644
--- a/models/demos/t3000/llama2_70b/tt/llama_mlp_optimized.py
+++ b/models/demos/t3000/llama2_70b/tt/llama_mlp_optimized.py
@@ -219,7 +219,7 @@ def prefill_forward(self, x: List[ttnn.Tensor]) -> List[ttnn.Tensor]:
 
         hidden_states_reduced = ttnn.reduce_scatter(
             hidden_states_mm,
-            scatter_dim=3,
+            dim=3,
             math_op=ttnn.ReduceType.Sum,
             num_links=1,
             memory_config=ttnn.DRAM_MEMORY_CONFIG,
@@ -268,7 +268,7 @@ def decode_forward(self, x: List[ttnn.Tensor]) -> List[ttnn.Tensor]:
 
         hidden_states_reduced = ttnn.reduce_scatter(
             hidden_states,
-            scatter_dim=3,
+            dim=3,
             math_op=ttnn.ReduceType.Sum,
             num_links=1,
             memory_config=self.model_config["RESIDUAL_16_CORES_OUTPUT_MEMCFG"],
diff --git a/models/demos/tg/llama3_70b/tt/llama_common.py b/models/demos/tg/llama3_70b/tt/llama_common.py
index 1b16fde6a60b..9824afbc44c0 100644
--- a/models/demos/tg/llama3_70b/tt/llama_common.py
+++ b/models/demos/tg/llama3_70b/tt/llama_common.py
@@ -93,7 +93,7 @@ def tt_composite_sharded_all_reduce(
     input_mem_cfg = input_tensor.memory_config()
     reduce_scattered_tensor = ttnn.reduce_scatter(
         input_tensor,
-        scatter_dim=dim,
+        dim=dim,
         math_op=ttnn.ReduceType.Sum,
         num_links=num_links,
         cluster_axis=cluster_axis,
diff --git a/tests/ttnn/unit_tests/operations/ccl/perf/test_ccl_perf.py b/tests/ttnn/unit_tests/operations/ccl/perf/test_ccl_perf.py
index 1429eb0fce12..800d25befb8d 100644
--- a/tests/ttnn/unit_tests/operations/ccl/perf/test_ccl_perf.py
+++ b/tests/ttnn/unit_tests/operations/ccl/perf/test_ccl_perf.py
@@ -141,7 +141,7 @@ def test_all_gather_on_t3000(
     ],
 )
 @pytest.mark.parametrize(
-    "per_chip_output_shape, scatter_dim, layout",
+    "per_chip_output_shape, dim, layout",
     [
         ([1, 8, 1024, 1024], 3, ttnn.TILE_LAYOUT),
         ([1, 4, 1024, 1024], 3, ttnn.TILE_LAYOUT),
@@ -171,7 +171,7 @@ def test_reduce_scatter_on_t3000(
     t3k_mesh_device,
     num_devices,
     per_chip_output_shape,
-    scatter_dim,
+    dim,
     num_links,
     math_op,
     input_dtype,
@@ -187,7 +187,7 @@ def test_reduce_scatter_on_t3000(
         t3k_mesh_device,
         num_devices,
         per_chip_output_shape,
-        scatter_dim,
+        dim,
         num_links,
         math_op,
         input_dtype,
@@ -210,7 +210,7 @@ def test_reduce_scatter_on_t3000(
     ],
 )
 @pytest.mark.parametrize(
-    "per_chip_output_shape, scatter_dim, layout",
+    "per_chip_output_shape, dim, layout",
     [
         ([1, 1, 32, 4096], 3, ttnn.TILE_LAYOUT),
         ([1, 1, 32, 2048], 3, ttnn.TILE_LAYOUT),
@@ -239,7 +239,7 @@ def test_reduce_scatter_on_n300(
     n300_mesh_device,
     num_devices,
     per_chip_output_shape,
-    scatter_dim,
+    dim,
     num_links,
     math_op,
     input_dtype,
@@ -254,7 +254,7 @@ def test_reduce_scatter_on_n300(
         n300_mesh_device,
         num_devices,
         per_chip_output_shape,
-        scatter_dim,
+        dim,
         num_links,
         math_op,
         input_dtype,
diff --git a/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_N300_post_commit.py b/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_N300_post_commit.py
index c34c4fd61913..086efb1d534c 100644
--- a/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_N300_post_commit.py
+++ b/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_N300_post_commit.py
@@ -20,7 +20,7 @@
     ],
 )
 @pytest.mark.parametrize(
-    "per_chip_output_shape, scatter_dim, layout",
+    "per_chip_output_shape, dim, layout",
     [
         ([1, 1, 32, 4096], 3, ttnn.TILE_LAYOUT),
         ([1, 1, 32, 2048], 3, ttnn.TILE_LAYOUT),
@@ -50,7 +50,7 @@ def test_ring_reduce_scatter_n300_post_commit(
     n300_mesh_device,
     num_devices,
     per_chip_output_shape,
-    scatter_dim,
+    dim,
     num_links,
     math_op,
     input_dtype,
@@ -65,7 +65,7 @@ def test_ring_reduce_scatter_n300_post_commit(
         n300_mesh_device,
        num_devices,
         per_chip_output_shape,
-        scatter_dim,
+        dim,
         num_links,
         math_op,
         input_dtype,
diff --git a/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_TG_nightly.py b/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_TG_nightly.py
index 9e9fbf479f5d..1b5bfe8f6724 100644
--- a/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_TG_nightly.py
+++ b/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_TG_nightly.py
@@ -145,7 +145,7 @@ def run_line_reduce_scatter_on_TG_with_mesh_tensor_along_rows(
     # ttnn.visualize_mesh_device(mesh_device, tensor=ttnn_tensor)
     ttnn_tensor_out = ttnn.reduce_scatter(
         ttnn_tensor,
-        scatter_dim=dim,
+        dim=dim,
         cluster_axis=cluster_axis,
         mesh_device=mesh_device,
         math_op=math_op,
@@ -158,7 +158,7 @@ def run_line_reduce_scatter_on_TG_with_mesh_tensor_along_rows(
     for _ in range(num_iters):
         ttnn_tensor_out = ttnn.reduce_scatter(
             ttnn_tensor,
-            scatter_dim=dim,
+            dim=dim,
             cluster_axis=cluster_axis,
             mesh_device=mesh_device,
             math_op=math_op,
diff --git a/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_nightly.py b/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_nightly.py
index 5a00f7883ab2..aaf8e21fc10e 100644
--- a/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_nightly.py
+++ b/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_nightly.py
@@ -19,7 +19,7 @@
     ],
 )
 @pytest.mark.parametrize(
-    "per_chip_output_shape, scatter_dim, layout",
+    "per_chip_output_shape, dim, layout",
     [
         ([1, 8, 1024, 1024], 3, ttnn.TILE_LAYOUT),
         ([1, 4, 1024, 1024], 3, ttnn.TILE_LAYOUT),
@@ -58,7 +58,7 @@ def test_reduce_scatter_t3k_8chip_nightly(
     t3k_mesh_device,
     num_devices,
     per_chip_output_shape,
-    scatter_dim,
+    dim,
     num_links,
     math_op,
     input_dtype,
@@ -73,7 +73,7 @@ def test_reduce_scatter_t3k_8chip_nightly(
         t3k_mesh_device,
         num_devices,
         per_chip_output_shape,
-        scatter_dim,
+        dim,
         num_links,
         math_op,
         input_dtype,
@@ -95,7 +95,7 @@ def test_reduce_scatter_t3k_8chip_nightly(
     ],
 )
 @pytest.mark.parametrize(
-    "per_chip_output_shape, scatter_dim, layout",
+    "per_chip_output_shape, dim, layout",
     [
         ([1, 8, 1024, 1024], 3, ttnn.TILE_LAYOUT),
         ([1, 4, 1024, 1024], 3, ttnn.TILE_LAYOUT),
@@ -136,7 +136,7 @@ def test_reduce_scatter_t3k_4chip_nightly(
     pcie_mesh_device,
     num_devices,
     per_chip_output_shape,
-    scatter_dim,
+    dim,
     num_links,
     math_op,
     input_dtype,
@@ -151,7 +151,7 @@ def test_reduce_scatter_t3k_4chip_nightly(
         pcie_mesh_device,
         num_devices,
         per_chip_output_shape,
-        scatter_dim,
+        dim,
         num_links,
         math_op,
         input_dtype,
diff --git a/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_post_commit.py b/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_post_commit.py
index 916682dd84e4..4efe5152448c 100644
--- a/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_post_commit.py
+++ b/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_post_commit.py
@@ -10,7 +10,7 @@
 from models.utility_functions import skip_for_grayskull
 
 
-def is_unsupported_case(input_shape, scatter_dim, math_op, mem_config, num_devices, num_links, input_dtype, layout):
+def is_unsupported_case(input_shape, dim, math_op, mem_config, num_devices, num_links, input_dtype, layout):
     elem_size = 2 if input_dtype == ttnn.bfloat16 else 1
     tensor_size_bytes = elem_size
     for i in input_shape:
@@ -19,7 +19,7 @@ def is_unsupported_case(input_shape, scatter_dim, math_op, mem_config, num_devic
     if mem_config.buffer_type == ttnn.BufferType.L1 and tensor_size_bytes > num_l1_banks * 50 * 1024:
         return True, "L1 buffer can't support large tensor sizes"
 
-    # if input_dtype == ttnn.bfloat8_b and tuple(input_shape) == (1, 1, 2048, 1024) and scatter_dim == 3:
+    # if input_dtype == ttnn.bfloat8_b and tuple(input_shape) == (1, 1, 2048, 1024) and dim == 3:
     #     return True, "Known failure with bfp8_b data format"
 
     return False, ""
@@ -28,7 +28,7 @@ def is_unsupported_case(input_shape, scatter_dim, math_op, mem_config, num_devic
 def run_with_trace(
     t3k_mesh_device,
     input_tensor_mesh,
-    scatter_dim,
+    dim,
     num_links,
     math_op,
     output_mem_config,
@@ -41,7 +41,7 @@ def run_with_trace(
     logger.info("Compiling model")
     output_tensor_mesh = ttnn.reduce_scatter(
         input_tensor_mesh,
-        scatter_dim=scatter_dim,
+        dim=dim,
         math_op=math_op,
         num_links=num_links,
         memory_config=output_mem_config,
@@ -58,7 +58,7 @@ def run_with_trace(
     for i in range(num_iters):
         output_tensor_mesh = ttnn.reduce_scatter(
             input_tensor_mesh,
-            scatter_dim=scatter_dim,
+            dim=dim,
             math_op=math_op,
             num_links=num_links,
             memory_config=output_mem_config,
@@ -84,7 +84,7 @@ def run_reduce_scatter_test(
     mesh_device,
     num_devices,
     per_chip_output_shape,
-    scatter_dim,
+    dim,
     num_links,
     math_op,
     input_dtype,
@@ -105,7 +105,7 @@ def run_reduce_scatter_test(
     debug = False
 
     (is_known_failure, message) = is_unsupported_case(
-        per_chip_output_shape, scatter_dim, math_op, mem_config, num_devices, num_links, input_dtype, layout
+        per_chip_output_shape, dim, math_op, mem_config, num_devices, num_links, input_dtype, layout
     )
     if is_known_failure:
         pytest.skip(f"Skipping unsupported case {message}.")
@@ -114,11 +114,11 @@ def run_reduce_scatter_test(
     if enable_async:
         logger.info(f"Using Async Mode for Reduce Scatter Op Dispatch")
 
-    logger.info(f"Per chip output shape: {per_chip_output_shape}, devices: {num_devices}, scatter_dim: {scatter_dim}")
+    logger.info(f"Per chip output shape: {per_chip_output_shape}, devices: {num_devices}, dim: {dim}")
 
     # Generate input tensors
     canonical_input_shape = per_chip_output_shape.copy()
-    canonical_input_shape[scatter_dim] *= num_devices
+    canonical_input_shape[dim] *= num_devices
     tt_input_tensors = []
 
     numel = canonical_input_shape[0] * canonical_input_shape[1] * canonical_input_shape[2] * canonical_input_shape[3]
@@ -143,7 +143,7 @@ def run_reduce_scatter_test(
         output_tensor_mesh = run_with_trace(
             mesh_device,
             input_tensor_mesh,
-            scatter_dim,
+            dim,
             num_links,
             math_op,
             mem_config,
@@ -154,7 +154,7 @@ def run_reduce_scatter_test(
         for i in range(num_iters):
             output_tensor_mesh = ttnn.reduce_scatter(
                 input_tensor_mesh,
-                scatter_dim=scatter_dim,
+                dim=dim,
                 math_op=math_op,
                 num_links=num_links,
                 memory_config=mem_config,
@@ -172,7 +172,7 @@ def run_reduce_scatter_test(
     for i, t in enumerate(input_tensors):
         golden_canonical_out_tensor = torch.add(golden_canonical_out_tensor, t).bfloat16()
 
-    golden_output_tensors = torch.chunk(golden_canonical_out_tensor, num_devices, scatter_dim)
+    golden_output_tensors = torch.chunk(golden_canonical_out_tensor, num_devices, dim)
 
     tt_out_tensors = ttnn.get_device_tensors(output_tensor_mesh)
     logger.info(f"Compare")
@@ -211,7 +211,7 @@ def run_reduce_scatter_test(
     ],
 )
 @pytest.mark.parametrize(
-    "per_chip_output_shape, scatter_dim, layout",
+    "per_chip_output_shape, dim, layout",
     [
         ([1, 2, 256, 32 * 8], 3, ttnn.TILE_LAYOUT),  # Input tensor is (16*32) x (64*32) = 8 * input tensor shape
         ([1, 1, 32, 32 * 8], 3, ttnn.TILE_LAYOUT),
@@ -241,7 +241,7 @@ def test_ring_reduce_scatter_post_commit(
     t3k_mesh_device,
     num_devices,
     per_chip_output_shape,
-    scatter_dim,
+    dim,
     num_links,
     math_op,
     input_dtype,
@@ -256,7 +256,7 @@ def test_ring_reduce_scatter_post_commit(
         t3k_mesh_device,
         num_devices,
         per_chip_output_shape,
-        scatter_dim,
+        dim,
         num_links,
         math_op,
         input_dtype,
@@ -279,7 +279,7 @@ def test_ring_reduce_scatter_post_commit(
     ],
 )
 @pytest.mark.parametrize(
-    "per_chip_output_shape, scatter_dim, layout",
+    "per_chip_output_shape, dim, layout",
     [
         ([1, 1, 32, 32 * 8], 3, ttnn.TILE_LAYOUT),
         ([1, 2, 224, 32 * 8], 3, ttnn.TILE_LAYOUT),
@@ -306,7 +306,7 @@ def test_line_reduce_scatter_post_commit(
     t3k_mesh_device,
     num_devices,
     per_chip_output_shape,
-    scatter_dim,
+    dim,
     num_links,
     math_op,
     input_dtype,
@@ -321,7 +321,7 @@ def test_line_reduce_scatter_post_commit(
         t3k_mesh_device,
         num_devices,
         per_chip_output_shape,
-        scatter_dim,
+        dim,
         num_links,
         math_op,
         input_dtype,
@@ -345,7 +345,7 @@ def test_line_reduce_scatter_post_commit(
     ],
 )
 @pytest.mark.parametrize(
-    "per_chip_output_shape, scatter_dim, layout",
+    "per_chip_output_shape, dim, layout",
     [
         ([1, 1, 32, 1280], 1, ttnn.TILE_LAYOUT),
         ([1, 1, 32, 1024], 1, ttnn.TILE_LAYOUT),
@@ -369,7 +369,7 @@ def test_line_reduce_scatter_post_commit_4chip(
     pcie_mesh_device,
     num_devices,
     per_chip_output_shape,
-    scatter_dim,
+    dim,
     num_links,
     math_op,
     input_dtype,
@@ -384,7 +384,7 @@ def test_line_reduce_scatter_post_commit_4chip(
         pcie_mesh_device,
         num_devices,
         per_chip_output_shape,
-        scatter_dim,
+        dim,
         num_links,
         math_op,
         input_dtype,
@@ -403,7 +403,7 @@ def run_reduce_scatter_sharded_test(
     num_devices,
     per_chip_output_shape,
     output_shard_shape,
-    scatter_dim,
+    dim,
     num_links,
     math_op,
     shard_grid,
@@ -427,7 +427,7 @@ def run_reduce_scatter_sharded_test(
             f"Not enough devices on machine to implement test case. Wanted {num_devices} but found {len(t3k_mesh_device.get_device_ids())}"
         )
 
-    logger.info(f"Per chip output shape: {per_chip_output_shape}, devices: {num_devices}, scatter_dim: {scatter_dim}")
+    logger.info(f"Per chip output shape: {per_chip_output_shape}, devices: {num_devices}, dim: {dim}")
 
     debug = False
 
@@ -438,7 +438,7 @@ def run_reduce_scatter_sharded_test(
         assert in_shard_override is None
         in_shard_grid = shard_grid
         input_shard_shape = list(output_shard_shape)
-        if scatter_dim == 3:
+        if dim == 3:
             input_shard_shape[1] *= num_devices
         else:
             input_shard_shape[0] *= num_devices
@@ -468,7 +468,7 @@ def run_reduce_scatter_sharded_test(
     )
 
     canonical_input_shape = list(per_chip_output_shape)
-    canonical_input_shape[scatter_dim] *= num_devices
+    canonical_input_shape[dim] *= num_devices
 
     numel = canonical_input_shape[0] * canonical_input_shape[1] * canonical_input_shape[2] * canonical_input_shape[3]
     input_tensors = [
@@ -492,7 +492,7 @@ def run_reduce_scatter_sharded_test(
         output_tensor_mesh = run_with_trace(
             t3k_mesh_device,
             input_tensor_mesh,
-            scatter_dim,
+            dim,
             num_links,
             math_op,
             output_mem_config,
@@ -504,7 +504,7 @@ def run_reduce_scatter_sharded_test(
         for i in range(num_iters):
             output_tensor_mesh = ttnn.reduce_scatter(
                 input_tensor_mesh,
-                scatter_dim=scatter_dim,
+                dim=dim,
                 math_op=math_op,
                 num_links=num_links,
                 memory_config=output_mem_config,
@@ -521,7 +521,7 @@ def run_reduce_scatter_sharded_test(
     for i, t in enumerate(input_tensors):
         golden_canonical_out_tensor = torch.add(golden_canonical_out_tensor, t).bfloat16()
 
-    golden_output_tensors = torch.chunk(golden_canonical_out_tensor, num_devices, scatter_dim)
+    golden_output_tensors = torch.chunk(golden_canonical_out_tensor, num_devices, dim)
 
     tt_out_tensors = ttnn.get_device_tensors(output_tensor_mesh)
     logger.info(f"Compare")