From b80a975f1d480c1deb5cf2e520c9797c058c0903 Mon Sep 17 00:00:00 2001
From: Saad Jameel <163029024+sjameelTT@users.noreply.github.com>
Date: Tue, 17 Dec 2024 14:21:38 -0500
Subject: [PATCH] Add transpose WH sharded, generalize row major permute when N > 4, and do a minor refactor of ttnn::permute (#15881)

### Ticket
#14790 add transpose wh sharded implementation when shard shape < height dimension
#15165 add N-d permute with width dimension
#15589 correct permute dimensionality when less than 4D
#15750 remove the composite flag from permute
#12550 re-enable some permute tests for blackhole
#12349 re-enable working transpose tests for blackhole
#16066 disable test_uniform as it's stochastic

### Problem description
This PR addresses several permute and transpose problems at once:
- Transpose WH sharded does not currently work when the shard shape is less than the height
- Permute on more than 4 dimensions does not work when moving the width dimension (for both tiled and RM layouts)
- The permute kernel is single-core when the width dimension doesn't change
- Permute has an unclean API: a composite flag that is not generically applicable
- Permute on fewer than 4 dimensions produces an incorrect output shape when the permutation is a no-op
- Permute tests are disabled for BH due to LLK issues
- Transpose tests are disabled for BH due to LLK issues

### What's changed
- Add a transpose WH sharded implementation for when the shard shape is less than the height dim (outputs a block-sharded result)
- Add an N-d permute kernel that works generically on any row major input. We have to call a global init on each loop of the compute kernel because transpose sets some registers that aren't cleared (there's no transpose_uninit); otherwise pcc is bad when there's more than one loop. For GS/BH, even the global init doesn't solve the problem, so that test is disabled. For tiled layouts we need 5D untilize/tilize. This increases sweeps coverage for permute from **50%** to **86%**.
- For the optimized case where permute's width dimension is not shuffled, make the kernel multicore
- Remove the composite flag, which was enabled by default and made permute non-generic. This has caused forge models to have bad pcc since they were not aware of this optional argument.
- Refactor ttnn::permute to add no-op checks and correct shape calculations
- Re-enable permute and transpose tests for blackhole

When replacing variants of transpose with this RM permute kernel, a lot of tests on BH/GS failed, so I will address that in a follow-up; the LLK issues are causing pain there. Once we get N-d untilize/tilize support and the LLK issues are fixed, permute should be able to become fully generic. The remaining issue for the pytorch 2.0 sweeps after the untilize/tilize fix is the CB overflow on transpose WH, which should be fixed out of the box when we replace the kernel that is used (which I am not doing in this PR since it doesn't work for GS/BH atm).
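For reference, the call site after this change looks like the following. This is a minimal sketch distilled from the new unit tests in this PR; the open `device` handle is assumed (the tests get it from a pytest fixture), and the shape/permutation are just the test values:

```python
import torch
import ttnn

# Assumes an already-open ttnn `device`; shape and permutation are illustrative.
torch_input = torch.randn((1, 6, 256, 20, 50), dtype=torch.bfloat16)
perm = (4, 0, 3, 2, 1)

tt_input = ttnn.from_torch(torch_input, layout=ttnn.ROW_MAJOR_LAYOUT, device=device)
tt_output = ttnn.permute(tt_input, perm)  # no composite flag any more
result = ttnn.to_torch(tt_output)
assert result.shape == torch.permute(torch_input, perm).shape

# Rank < 4 no-op permutes now keep the input shape instead of mis-reshaping it.
small = ttnn.from_torch(torch.ones(1, 1, 3, dtype=torch.bfloat16),
                        layout=ttnn.ROW_MAJOR_LAYOUT, device=device)
assert ttnn.to_torch(ttnn.permute(small, (0, 1, 2))).shape == (1, 1, 3)
```

The generic RM kernel treats the tensor as rows × x_blocks × w_blocks work items and decomposes each linear block index on device. A host-side Python sketch of that indexing scheme (names illustrative, mirroring the reader/writer kernels in this patch):

```python
def decompose_block(block, w_blocks, x_blocks, non_x_rows):
    """Illustrative mirror of the kernel's linear-block decomposition."""
    w_block = block % w_blocks      # which W-dimension chunk this block covers
    block //= w_blocks
    x_block = block % x_blocks      # which X (new width) chunk this block covers
    block //= x_blocks
    xw_block = block % non_x_rows   # which combination of the remaining dims
    return w_block, x_block, xw_block

# num_blocks_total = (num_rows // X) * x_blocks * w_blocks, and the host splits
# [0, num_blocks_total) across cores via split_work_to_cores.
```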
### Checklist - [x] Post commit CI passes https://github.com/tenstorrent/tt-metal/actions/runs/12367177499/job/34547311782 (failing test is failing on main) - [x] Blackhole Post commit (if applicable) https://github.com/tenstorrent/tt-metal/actions/runs/12367175575 - [x] Model regression CI testing passes (if applicable) https://github.com/tenstorrent/tt-metal/actions/runs/12357119737 - [x] Device performance regression CI testing passes (if applicable) https://github.com/tenstorrent/tt-metal/actions/runs/12357115316 - [ ] **(For models and ops writers)** Full [new models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml) tests passes - [ ] New/Existing tests provide coverage for changes --- .../pytests/tt_dnn/test_permute.py | 1 - .../unit_testing/misc/test_transpose.py | 84 ++++- .../unit_tests/operations/test_permute.py | 101 +++++- .../unit_tests/operations/test_uniform.py | 1 + .../data_movement/common/kernels/common.hpp | 39 +++ .../transpose_xw_rm_single_tile_size.cpp | 65 ++++ ...permute_interleaved_rm_blocked_generic.cpp | 124 ++++++++ ..._permute_interleaved_rm_row_invariant.cpp} | 7 +- ...permute_interleaved_rm_blocked_generic.cpp | 167 ++++++++++ ..._permute_interleaved_rm_row_invariant.cpp} | 13 +- .../device/permute_device_operation.cpp | 13 +- .../device/permute_device_operation.hpp | 28 +- .../device/permute_program_factory.cpp | 288 ++++++++++++++++-- .../data_movement/permute/permute.cpp | 105 ++----- .../data_movement/permute/permute.hpp | 1 - .../data_movement/permute/permute_pybind.cpp | 2 +- .../transpose/device/transpose_op.cpp | 88 ++++-- .../device/transpose_program_factory.cpp | 75 +++-- .../device/attn_matmul_device_operation.cpp | 4 +- ttnn/cpp/ttnn/tensor/shape/shape.cpp | 12 + ttnn/cpp/ttnn/tensor/shape/shape.hpp | 2 + 21 files changed, 1041 insertions(+), 179 deletions(-) create mode 100644 ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/compute/transpose_xw_rm_single_tile_size.cpp create mode 100644 ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm_blocked_generic.cpp rename ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/{reader_permute_interleaved_rm.cpp => reader_permute_interleaved_rm_row_invariant.cpp} (78%) create mode 100644 ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_blocked_generic.cpp rename ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/{writer_permute_interleaved_rm.cpp => writer_permute_interleaved_rm_row_invariant.cpp} (80%) diff --git a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_permute.py b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_permute.py index d9ab2571a58..98699e2d4f2 100644 --- a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_permute.py +++ b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_permute.py @@ -20,7 +20,6 @@ ] -@skip_for_blackhole("Mismatching on BH, see #12349") @pytest.mark.parametrize("input_shapes, permute_args", params) def test_run_permute_test(input_shapes, permute_args, device, function_level_defaults): datagen_func = [ diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_transpose.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_transpose.py index 489b25ba5e9..a8f8385c059 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_transpose.py +++ 
b/tests/tt_eager/python_api_testing/unit_testing/misc/test_transpose.py @@ -299,7 +299,6 @@ def test_transpose_wh_sharded_program_cache(dtype, device, use_program_cache): ) -@skip_for_blackhole("Mismatching on BH, see #12349") @skip_for_grayskull("Grayskull has pcc issue when transpose used untilize") @pytest.mark.parametrize("n", [1]) @pytest.mark.parametrize("c", [1]) @@ -333,7 +332,6 @@ def test_tranpose_hw_rm_with_padding(device, n, c, h, w): assert_with_pcc(torch_output_tensor, activation_pyt_padded_out, 0.9999) -@skip_for_blackhole("Mismatching on BH, see #12349") @skip_for_grayskull("Grayskull has pcc issue when transpose used untilize") @pytest.mark.parametrize("n", [16]) @pytest.mark.parametrize("c", [128]) @@ -369,13 +367,10 @@ def run_tranpose_hw_rm_program_cache(device, n, c, h, w, use_program_cache): memory_config=ttnn.L1_MEMORY_CONFIG, ) activation_pyt_padded = ttnn.transpose(activation_pyt_padded, 2, 3, memory_config=ttnn.L1_MEMORY_CONFIG) - activation_pyt_padded_out = ttnn.to_memory_config(activation_pyt_padded, ttnn.L1_MEMORY_CONFIG) - activation_pyt_padded_out = ttnn.from_device(activation_pyt_padded_out) - activation_pyt_padded_out = ttnn.to_torch(activation_pyt_padded_out) + activation_pyt_padded_out = ttnn.to_torch(activation_pyt_padded) assert_with_pcc(torch_output_tensor, activation_pyt_padded_out, 0.9999) -@skip_for_blackhole("Mismatching on BH, see #12349") @skip_for_grayskull("Grayskull has pcc issue when transpose used untilize") @pytest.mark.parametrize("n", [16]) @pytest.mark.parametrize("c", [128]) @@ -402,7 +397,7 @@ def test_tranpose_hw_rm_with_program_cache(device, n, c, h, w, use_program_cache @pytest.mark.parametrize("c", [224]) @pytest.mark.parametrize("h", [16]) @pytest.mark.parametrize("w", [112]) -def test_tranpose_hw_sharded_rm(device, n, c, h, w): +def test_transpose_hw_sharded_rm(device, n, c, h, w): torch.manual_seed(2005) torch_input_tensor = torch.rand((n, c, h, w), dtype=torch.bfloat16) torch_output_tensor = torch_input_tensor.transpose(2, 3) @@ -469,7 +464,6 @@ def run_tranpose_hw_sharded_rm_with_program_cache(device, n, c, h, w): assert_with_pcc(torch_output_tensor, tt_output_tensor, 0.9999) -@skip_for_blackhole("Mismatching on BH, see #12349") @pytest.mark.parametrize("n", [16]) @pytest.mark.parametrize("c", [128]) @pytest.mark.parametrize("h", [128]) @@ -581,7 +575,6 @@ def run_tranpose_hc_sharded(device, n, c, h, w, grid_size): assert_with_pcc(torch_output_tensor, tt_output_tensor, 0.9999) -@skip_for_blackhole("Mismatching on BH, see #12349") @pytest.mark.parametrize( "n, c, h, w, grid_size", [ @@ -1011,3 +1004,76 @@ def test_transpose_forge_hc(device, b, h, w, dim0, dim1): output_tensor = ttnn.to_torch(output_tensor) assert_with_pcc(torch_output_tensor, output_tensor) + + +@pytest.mark.parametrize("n", [1]) +@pytest.mark.parametrize("c", [1]) +@pytest.mark.parametrize("h", [256]) +@pytest.mark.parametrize("w", [32]) +def test_tranpose_hw_sharded_tiled_8_cores(device, n, c, h, w): + torch.manual_seed(2005) + torch_input_tensor = torch.rand((n, c, h, w), dtype=torch.bfloat16) + torch_output_tensor = torch_input_tensor.transpose(2, 3) + tt_input_tensor = ttnn.from_torch( + torch_input_tensor, + dtype=ttnn.DataType.BFLOAT16, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + + sharded_mem_config = ttnn.create_sharded_memory_config( + (32, 32), + core_grid=ttnn.CoreRangeSet( + { + ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, 6)), + ttnn.CoreRange(ttnn.CoreCoord(1, 0), ttnn.CoreCoord(1, 0)), + } + 
), + strategy=ttnn.ShardStrategy.HEIGHT, + orientation=ttnn.ShardOrientation.COL_MAJOR, + use_height_and_width_as_shard_shape=True, + ) + tt_input_tensor = ttnn.to_memory_config(tt_input_tensor, sharded_mem_config) + + tt_output_tensor = ttnn.transpose(tt_input_tensor, 2, 3, memory_config=sharded_mem_config) + tt_output_tensor = ttnn.to_torch(tt_output_tensor) + + assert_with_pcc(torch_output_tensor, tt_output_tensor, 0.9999) + + +@pytest.mark.parametrize("n", [1]) +@pytest.mark.parametrize("c", [1]) +@pytest.mark.parametrize("h", [224]) +@pytest.mark.parametrize("w", [32]) +def test_tranpose_hw_sharded_tiled_n_cores(device, n, c, h, w): + torch.manual_seed(2005) + torch_input_tensor = torch.rand((n, c, h, w), dtype=torch.bfloat16) + torch_output_tensor = torch_input_tensor.transpose(2, 3) + tt_input_tensor = ttnn.from_torch( + torch_input_tensor, + dtype=ttnn.DataType.BFLOAT16, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + + sharded_mem_config = ttnn.create_sharded_memory_config( + (32, 32), + core_grid=ttnn.CoreRangeSet( + { + ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, h // 32 - 1)), + } + ), + strategy=ttnn.ShardStrategy.HEIGHT, + orientation=ttnn.ShardOrientation.COL_MAJOR, + use_height_and_width_as_shard_shape=True, + ) + tt_input_tensor = ttnn.to_memory_config(tt_input_tensor, sharded_mem_config) + + tt_output_tensor = ttnn.transpose(tt_input_tensor, 2, 3, memory_config=sharded_mem_config) + tt_output_tensor = ttnn.to_memory_config(tt_output_tensor, ttnn.L1_MEMORY_CONFIG) + tt_output_tensor = ttnn.from_device(tt_output_tensor) + tt_output_tensor = ttnn.to_torch(tt_output_tensor) + + assert_with_pcc(torch_output_tensor, tt_output_tensor, 0.9999) diff --git a/tests/ttnn/unit_tests/operations/test_permute.py b/tests/ttnn/unit_tests/operations/test_permute.py index 40a57515f56..cc09f7d7d5e 100644 --- a/tests/ttnn/unit_tests/operations/test_permute.py +++ b/tests/ttnn/unit_tests/operations/test_permute.py @@ -7,9 +7,10 @@ import torch import ttnn +import itertools from tests.ttnn.utils_for_testing import assert_with_pcc -from models.utility_functions import is_blackhole +from models.utility_functions import is_blackhole, is_grayskull, skip_for_grayskull, skip_for_blackhole @pytest.mark.parametrize("h", [32]) @@ -171,3 +172,101 @@ def test_permute_pad_value(device, pad_value): assert ttnn.to_torch(a) == float("-inf") tt_output = ttnn.to_torch(tt_output) assert_with_pcc(torch_output, tt_output, 0.9999) + + +def generate_permutations(N): + """ + Generator function that yields all permutations of tuples with values 0 to N-1. + + :param N: The number defining the range of values (0 to N-1). + :yield: Tuples representing each permutation. 
+ """ + for perm in itertools.permutations(range(N)): + yield perm + + +@skip_for_blackhole("tilize_block gives bad pcc after second iteration") +@skip_for_grayskull("tilize_block gives bad pcc after second iteration") +@pytest.mark.parametrize("shape", [(7, 7, 7, 7, 7)]) +@pytest.mark.parametrize("perm", generate_permutations(5)) +@pytest.mark.parametrize("memory_config", [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG]) +@pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.float32]) +def test_permute_5d_width(shape, perm, memory_config, dtype, device): + torch.manual_seed(2005) + input_a = torch.randn(shape) + torch_output = torch.permute(input_a, perm) + + tt_input = ttnn.from_torch( + input_a, device=device, layout=ttnn.ROW_MAJOR_LAYOUT, dtype=dtype, memory_config=memory_config + ) + + tt_output = ttnn.permute(tt_input, perm) + tt_output = ttnn.to_torch(tt_output) + assert_with_pcc(torch_output, tt_output, 0.9999) + + +@skip_for_blackhole("tilize_block gives bad pcc after second iteration") +@skip_for_grayskull("tilize_block gives bad pcc after second iteration") +@pytest.mark.parametrize("shape", [(3, 65, 3, 3, 65), (1, 6, 256, 20, 50), (6, 20, 50, 1, 256)]) +@pytest.mark.parametrize("perm", [(4, 0, 3, 2, 1), (1, 3, 4, 0, 2), (3, 0, 4, 1, 2)]) +@pytest.mark.parametrize("memory_config", [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG]) +@pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.float32]) +def test_permute_5d_blocked(shape, perm, memory_config, dtype, device): + torch.manual_seed(520) + input_a = torch.randn(shape) + + torch_output = torch.permute(input_a, perm) + + tt_input = ttnn.from_torch( + input_a, device=device, layout=ttnn.ROW_MAJOR_LAYOUT, dtype=dtype, memory_config=memory_config + ) + + tt_output = ttnn.permute(tt_input, perm) + tt_output = ttnn.to_torch(tt_output) + + assert_with_pcc(torch_output, tt_output, 0.9999) + + +@skip_for_blackhole("tilize_block gives bad pcc after second iteration") +@skip_for_grayskull("tilize_block gives bad pcc after second iteration") +def test_permute_nd(device): + torch_tensor = torch.rand((1, 3, 16, 16, 16, 16), dtype=torch.bfloat16) + input_tensor = ttnn.from_torch(torch_tensor, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) + output_tensor = ttnn.permute(input_tensor, (0, 2, 4, 3, 5, 1)) + output_tensor = ttnn.to_torch(output_tensor) + torch_output = torch.permute(torch_tensor, (0, 2, 4, 3, 5, 1)) + assert_with_pcc(torch_output, output_tensor, 0.9999) + + +def test_permute_squeeze(device): + ones = ttnn.ones((1, 1, 3)) + tensor = ttnn.to_device(ones, device) + out = ttnn.permute(tensor, (0, 1, 2)) + assert_with_pcc(ttnn.to_torch(out), ttnn.to_torch(ones), 0.9999) + + +@pytest.mark.parametrize("shape", [(1, 49, 768)]) +@pytest.mark.parametrize("perm", generate_permutations(3)) +@pytest.mark.parametrize("layout", [ttnn.TILE_LAYOUT]) +@pytest.mark.parametrize("memory_config", [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG]) +@pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.float32]) +def test_permute_3D(shape, perm, layout, memory_config, dtype, device): + if is_grayskull() and dtype == ttnn.float32: + pytest.skip("Grayskull doesn't support float32") + torch_tensor = torch.rand(shape, dtype=torch.bfloat16) + input_tensor = ttnn.from_torch(torch_tensor, layout=layout, device=device, dtype=dtype, memory_config=memory_config) + output_tensor = ttnn.permute(input_tensor, perm) + output_tensor = ttnn.to_torch(output_tensor) + torch_output = torch.permute(torch_tensor, perm) + assert torch_output.shape == output_tensor.shape + 
assert_with_pcc(torch_output, output_tensor, 0.9999) + + +def test_nil_volume_permute(device): + torch_tensor = torch.rand([1, 0, 30, 32], dtype=torch.bfloat16) + input_tensor = ttnn.from_torch(torch_tensor, layout=ttnn.TILE_LAYOUT, device=device) + output_tensor = ttnn.permute(input_tensor, (0, 1, 3, 2)) + output_tensor = ttnn.to_torch(output_tensor) + torch_output = torch.permute(torch_tensor, (0, 1, 3, 2)) + assert torch_output.shape == output_tensor.shape + assert_with_pcc(torch_output, output_tensor, 0.9999) diff --git a/tests/ttnn/unit_tests/operations/test_uniform.py b/tests/ttnn/unit_tests/operations/test_uniform.py index 9c3f05a6a6a..abdfd9aaa31 100644 --- a/tests/ttnn/unit_tests/operations/test_uniform.py +++ b/tests/ttnn/unit_tests/operations/test_uniform.py @@ -94,6 +94,7 @@ def run_uniform(shape, rand_range, dtype, device, compute_kernel_options=None, m ) +@pytest.mark.skip("#16066: Undefined behaviour. It will fail on some runs and pass on others since it's stochastic.") @skip_for_grayskull("Requires wormhole_b0 to run") @pytest.mark.parametrize( "shape", diff --git a/ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp b/ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp index 27c68f53b18..34cf4e3eb3b 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp @@ -137,4 +137,43 @@ template FORCE_INLINE constexpr uint32_t round_up() { return b * div_up(); } + +// Function template to swap two elements in a uint32_t array +template +FORCE_INLINE void swap_elements(uint32_t (&array)[N], size_t i, size_t j) { + // Perform the swap + uint32_t temp = array[i]; + array[i] = array[j]; + array[j] = temp; +} + +// 2D Transpose function for debug use in reader/writer kernels +FORCE_INLINE void transpose_2d( + uint32_t input_l1_addr, + uint32_t output_l1_addr, + uint32_t X, + uint32_t W, + uint32_t element_size, + uint32_t input_page_size, + uint32_t output_page_size) { + volatile tt_l1_ptr uint8_t* input_ptr = reinterpret_cast(input_l1_addr); + volatile tt_l1_ptr uint8_t* output_ptr = reinterpret_cast(output_l1_addr); + // transpose from XW, where X is outer and W inner, to WX, where W is outer and X is inner + // each element is element_size bytes + // each row is W elements, and each row is separated by input_page_size bytes + // each output row is X elements, and each row is separated by output_page_size bytes + + for (uint32_t x = 0; x < X; ++x) { + for (uint32_t w = 0; w < W; ++w) { + // Compute the input and output addresses + uint32_t input_addr = x * input_page_size + w * element_size; + uint32_t output_addr = w * output_page_size + x * element_size; + // Copy the element - do we have memcpy? use this for now + for (uint32_t i = 0; i < element_size; ++i) { + output_ptr[output_addr + i] = input_ptr[input_addr + i]; + } + } + } +} + } // namespace tt::data_movement::common diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/compute/transpose_xw_rm_single_tile_size.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/compute/transpose_xw_rm_single_tile_size.cpp new file mode 100644 index 00000000000..41151070064 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/compute/transpose_xw_rm_single_tile_size.cpp @@ -0,0 +1,65 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "compute_kernel_api/eltwise_unary/eltwise_unary.h" +#include "compute_kernel_api/transpose_wh.h" +#include "compute_kernel_api/tilize.h" +#include "compute_kernel_api/untilize.h" +#include "compute_kernel_api/pack_untilize.h" + +namespace NAMESPACE { +void MAIN { + constexpr uint32_t x_block_size = get_compile_time_arg_val(0); + constexpr uint32_t w_block_size = get_compile_time_arg_val(1); + + uint32_t num_blocks = get_arg_val(0); + + constexpr auto cb_in = tt::CBIndex::c_0; + constexpr auto cb_tilize = tt::CBIndex::c_1; + constexpr auto cb_out = tt::CBIndex::c_2; + + unary_op_init_common(cb_in, cb_out); + + for (uint32_t n = 0; n < num_blocks; n++) { + // tilize input via unpack and then pack + tilize_init_short(cb_in, 1); + + cb_wait_front(cb_in, x_block_size); + cb_reserve_back(cb_tilize, 1); + + tilize_block(cb_in, 1, cb_tilize); // tilize and pack into cb_tilize + + // tile slice according to unpacker is garbage after tilize_block in the second iteration, missing an uninit? + cb_push_back(cb_tilize, 1); + cb_pop_front(cb_in, x_block_size); + + tilize_uninit(cb_in); + + // transpose input + cb_wait_front(cb_tilize, 1); + transpose_wh_init_short(cb_tilize); + pack_untilize_dst_init_short<1>(cb_out); + + tile_regs_acquire(); + transpose_wh_tile(cb_tilize, 0, 0); // transpose call + tile_regs_commit(); + + // pack and untilize + cb_reserve_back(cb_out, w_block_size); + + tile_regs_wait(); + pack_untilize_dst<1>(cb_out); // pack call + tile_regs_release(); + + cb_push_back(cb_out, w_block_size); + + cb_wait_front(cb_out, w_block_size); + pack_untilize_uninit(cb_out); + + cb_pop_front(cb_tilize, 1); + } +} +} // namespace NAMESPACE diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm_blocked_generic.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm_blocked_generic.cpp new file mode 100644 index 00000000000..f63aaab6d09 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm_blocked_generic.cpp @@ -0,0 +1,124 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include "dataflow_api.h" + +void kernel_main() { + constexpr bool src0_is_dram = (bool)get_compile_time_arg_val(0); + constexpr uint32_t N = get_compile_time_arg_val(1); + constexpr uint32_t input_cb_page_size = get_compile_time_arg_val(2); + constexpr uint32_t num_rows = get_compile_time_arg_val(3); + constexpr uint32_t x_dim = get_compile_time_arg_val(4); + constexpr uint32_t num_blocks_total = get_compile_time_arg_val(5); + constexpr uint32_t x_blocks = get_compile_time_arg_val(6); + constexpr uint32_t w_blocks = get_compile_time_arg_val(7); + constexpr uint32_t x_block_size = get_compile_time_arg_val(8); + constexpr uint32_t w_block_size = get_compile_time_arg_val(9); + constexpr uint32_t element_size = get_compile_time_arg_val(10); + constexpr uint32_t input_tensor_page_size = get_compile_time_arg_val(11); + + // Precomputed constants: size of a 32 element block along the W dimension (measured in bytes) + constexpr uint32_t w_block_size_bytes = w_block_size * element_size; + + const uint32_t src_addr = get_arg_val(0); + uint32_t start_block = get_arg_val(1); + uint32_t end_block = get_arg_val(2); + + // Input shape and strides (excluding W dimension and measured in rows, not bytes) + // start at runtime arg 3 since address/start_block/end_block make up the first 3 args + uint32_t input_shape[N], src_strides[N]; + for (uint32_t i = 3; i < N + 3; i++) { + input_shape[i - 3] = get_arg_val(i); + src_strides[i - 3] = get_arg_val(i + N); + } + + /** + * We have a multidimensional tensor: + * - num_blocks_total = (rows * x_blocks * w_blocks) where rows = num_rows / X + * Here, 'rows' represent the combination of all rows before and after X dimension. + * So: rows * X * W_dimension = total number of elements (conceptually). + * + * For each 'block': + * - Compute which w_block and x_block this corresponds to. + * - Then compute which row set (xw_block) we are in. + */ + + // x_dim is the dimension along which we are reading the tensor, as it's the new W dimension in the output tensor + uint32_t X = input_shape[x_dim]; + uint32_t X_stride = src_strides[x_dim]; + + const InterleavedAddrGen s0 = {.bank_base_address = src_addr, .page_size = input_tensor_page_size}; + + uint32_t idxs[N]; + idxs[N - 1] = 0; + uint32_t non_x_rows = num_rows / X; + + for (uint32_t block = start_block; block < end_block; ++block) { + // Decompose block into w_block, x_block, and xw_block indices + uint32_t rem = block; + const uint32_t w_block = rem % w_blocks; // Which W block are we in? + rem /= w_blocks; + + const uint32_t x_block = rem % x_blocks; // Which X block? + rem /= x_blocks; + + uint32_t xw_block = rem % (non_x_rows); // Which row set (beyond X dimension)? 
+ uint32_t remainder = xw_block; + + // Compute X block boundaries + uint32_t x_start = x_block * x_block_size; + uint32_t x_end = min(x_start + x_block_size, X); + + // Compute W block boundaries + uint32_t w_start = w_block * w_block_size; + uint32_t w_end = min(w_start + w_block_size, input_shape[N - 1]); + uint32_t w_offset = w_start * element_size; + + uint32_t w_read_size_bytes = (w_end - w_start) * element_size; + + // Map linear index i to multidimensional indices idxs[] + // We skip x_dim when doing this mapping and set it separately later + for (int32_t d = N - 2; d >= 0; --d) { // Exclude W dimension + if (d == (int32_t)x_dim) { + idxs[d] = 0; // Initialize x_dim to zero (will be set in inner loop) + continue; // Skip x_dim during mapping + } + idxs[d] = remainder % input_shape[d]; + remainder /= input_shape[d]; + } + idxs[N - 1] = 0; // Initialize W dimension index to zero if not already set + + // Precompute the base address offset (excluding x_dim) + uint64_t base_addr_offset = 0; + for (uint32_t d = 0; d < N; ++d) { + if (d != x_dim) { + base_addr_offset += idxs[d] * src_strides[d]; + } + } + + // Reserve space in the circular buffer for the X-block length + cb_reserve_back(tt::CBIndex::c_0, x_block_size); + uint32_t src_buffer_l1_addr = get_write_ptr(tt::CBIndex::c_0); + + // We read in 'x_block_len' chunks along the X dimension + uint32_t page_offset = 0; + // Read along the X dimension + for (uint32_t x = x_start; x < x_end; ++x) { + // Compute the address offset for this index + uint64_t addr_offset = base_addr_offset + x * X_stride; + uint64_t src_noc_addr = get_noc_addr(addr_offset, s0, w_offset); + + // Perform async read of the current line (w_block_len elements) into L1 + noc_async_read(src_noc_addr, src_buffer_l1_addr + page_offset, w_read_size_bytes); + + // Advance output pointer by one page size for next row + page_offset += input_cb_page_size; + } + // Wait for all async reads to complete before proceeding + noc_async_read_barrier(); + // Push the filled block into the circular buffer + cb_push_back(tt::CBIndex::c_0, x_block_size); + } +} diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm_row_invariant.cpp similarity index 78% rename from ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm.cpp rename to ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm_row_invariant.cpp index 73241cb9703..93a42f81325 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm_row_invariant.cpp @@ -12,16 +12,17 @@ void kernel_main() { constexpr uint32_t num_rows = get_compile_time_arg_val(3); const uint32_t src_addr = get_arg_val(0); + const uint32_t start_row = get_arg_val(1); + const uint32_t end_row = get_arg_val(2); const InterleavedAddrGen s0 = {.bank_base_address = src_addr, .page_size = page_size}; uint32_t curr_addr = src_addr; - for (uint32_t i = 0; i < num_rows; ++i) { + for (uint32_t row = start_row; row < end_row; ++row) { cb_reserve_back(tt::CBIndex::c_0, 1); uint32_t src_buffer_l1_addr = get_write_ptr(tt::CBIndex::c_0); - noc_async_read_page(i, s0, src_buffer_l1_addr); + noc_async_read_page(row, s0, src_buffer_l1_addr); 
noc_async_read_barrier(); - volatile tt_l1_ptr uint16_t* out_stick = reinterpret_cast(src_buffer_l1_addr); cb_push_back(tt::CBIndex::c_0, 1); } } diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_blocked_generic.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_blocked_generic.cpp new file mode 100644 index 00000000000..5af2edb379f --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_blocked_generic.cpp @@ -0,0 +1,167 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include "dataflow_api.h" +#include "ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp" + +void kernel_main() { + // Compile-time constants + constexpr bool dst_is_dram = (bool)get_compile_time_arg_val(0); + constexpr uint32_t N = get_compile_time_arg_val(1); + constexpr uint32_t output_cb_page_size = get_compile_time_arg_val(2); + constexpr uint32_t num_rows = get_compile_time_arg_val(3); + + constexpr uint32_t X = get_compile_time_arg_val(4); + constexpr uint32_t X_stride = get_compile_time_arg_val(5); + constexpr uint32_t x_dim = get_compile_time_arg_val(6); + + constexpr uint32_t W_stride = get_compile_time_arg_val(7); + constexpr uint32_t input_cb_page_size = get_compile_time_arg_val(8); + constexpr uint32_t element_size = get_compile_time_arg_val(9); + + constexpr uint32_t num_blocks_total = get_compile_time_arg_val(10); + constexpr uint32_t x_blocks = get_compile_time_arg_val(11); + constexpr uint32_t w_blocks = get_compile_time_arg_val(12); + constexpr uint32_t x_block_size = get_compile_time_arg_val(13); + constexpr uint32_t w_block_size = get_compile_time_arg_val(14); + constexpr uint32_t W = get_compile_time_arg_val(15); + constexpr uint32_t output_tensor_page_size = get_compile_time_arg_val(16); + + constexpr uint32_t cb_id_in = tt::CBIndex::c_2; + + // Precompute bytes-per-block along X + constexpr uint32_t x_block_size_bytes = x_block_size * element_size; + + // W dimension is always the last dimension + constexpr uint32_t w_dim = N - 1; + + // Calculate how many "non_x_rows" we have (these are the combinations of all dimensions except X) + constexpr uint32_t non_x_rows = num_rows / X; + + // Destination base address + const uint32_t dst_addr = get_arg_val(0); + const uint32_t start_block = get_arg_val(1); + const uint32_t end_block = get_arg_val(2); + + // Interleaved address configuration for the destination + const InterleavedAddrGen s0 = {.bank_base_address = dst_addr, .page_size = output_tensor_page_size}; + + // Input shape, permutation, and destination strides + // start at runtime arg 3 since address/start_block/end_block make up the first 3 args + uint32_t input_shape[N], perm[N], dest_strides[N]; + for (uint32_t i = 3; i < N + 3; i++) { + input_shape[i - 3] = get_arg_val(i); + perm[i - 3] = get_arg_val(i + N); + dest_strides[i - 3] = get_arg_val(i + 2 * N); + } + + // The source data was transposed between W and X by the previous kernel. + // Adjust input_shape and perm to reflect that swap. 
+ tt::data_movement::common::swap_elements(input_shape, x_dim, w_dim); + for (uint32_t i = 0; i < N; i++) { + if (perm[i] == x_dim) { + perm[i] = w_dim; + } else if (perm[i] == w_dim) { + perm[i] = x_dim; + } + } + + // Find where the original X dimension ended up in the permuted output + uint32_t x_dim_in_dest = N; // Will hold the position of x_dim in the permuted array + for (uint32_t i = 0; i < N; ++i) { + if (perm[i] == x_dim) { + x_dim_in_dest = i; + break; + } + } + + uint32_t src_multi_idx[N] = {0}; + uint32_t dest_multi_idx[N] = {0}; + + // Process each block of data from start_block to end_block + for (uint32_t block = start_block; block < end_block; ++block) { + // Decompose linear block index into w_block, x_block, and xw_block + uint32_t rem = block; + + // w_block: portion of the W dimension handled by this block + const uint32_t w_block = rem % w_blocks; + rem /= w_blocks; + + // x_block: portion of the X dimension handled by this block + const uint32_t x_block = rem % x_blocks; + rem /= x_blocks; + + // xw_block: which "non-X row set" we are in + const uint32_t xw_block = rem % non_x_rows; + + // Compute start/end boundaries for the current X and W blocks + const uint32_t x_start = x_block * x_block_size; + const uint32_t x_end = min(x_start + x_block_size, X); + + const uint32_t w_start = w_block * w_block_size; + const uint32_t w_end = min(w_start + w_block_size, W); + + // Compute the read size for the X dimension + const uint32_t x_read_size_bytes = (x_end - x_start) * element_size; + const uint32_t x_offset = x_start * element_size; + + // Decode xw_block into multi-dimensional indices excluding the W dimension and X dimension + uint32_t remainder = xw_block; + for (int32_t d = N - 2; d >= 0; --d) { + if (d == (int32_t)x_dim) { + // Skip the original X dimension index during this mapping + continue; + } + src_multi_idx[d] = remainder % input_shape[d]; + remainder /= input_shape[d]; + } + + // Compute dest_multi_idx (excluding W dimension), and a base linear index + // for all dimensions except W and X. We'll add W and X offsets later. 
+ uint32_t dest_linear_idx_base = 0; + for (uint32_t i = 0; i < N; ++i) { + uint32_t src_idx = perm[i]; + if (src_idx != x_dim) { + dest_multi_idx[i] = src_multi_idx[src_idx]; + // Accumulate partial index product for all dimensions except W + if (i < w_dim) { + dest_linear_idx_base += dest_multi_idx[i] * dest_strides[i]; + } + } + } + + // Wait for the transposed block data to be ready in the input CB + cb_wait_front(cb_id_in, w_block_size); + uint32_t transposed_buffer_read_addr = get_read_ptr(cb_id_in); + + // Iterate over the W dimension elements + for (uint32_t w = w_start; w < w_end; ++w) { + // Update indices for the current W + src_multi_idx[x_dim] = w; + dest_multi_idx[x_dim_in_dest] = w; + + // Compute final linear index for the current W + uint32_t dest_linear_idx = dest_linear_idx_base; + if (x_dim_in_dest < w_dim) { + dest_linear_idx += dest_multi_idx[x_dim_in_dest] * dest_strides[x_dim_in_dest]; + } + + // Compute the NoC address for the output + uint64_t dst_noc_addr = get_noc_addr(dest_linear_idx, s0, x_offset); + + // Compute the L1 address from which to write (offset by W-block pages) + uint32_t l1_addr = transposed_buffer_read_addr + (w - w_start) * output_cb_page_size; + + // Perform an asynchronous write of the X-block to the destination + noc_async_write(l1_addr, dst_noc_addr, x_read_size_bytes); + } + + // Wait until all writes are completed before proceeding to the next block + noc_async_write_barrier(); + + // Pop the block from the input circular buffer, as we're done writing it + cb_pop_front(cb_id_in, w_block_size); + } +} diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_row_invariant.cpp similarity index 80% rename from ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm.cpp rename to ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_row_invariant.cpp index 34be75dfdf4..a06e5d56892 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_row_invariant.cpp @@ -12,19 +12,22 @@ void kernel_main() { constexpr uint32_t num_rows = get_compile_time_arg_val(3); const uint32_t dst_addr = get_arg_val(0); + const uint32_t start_row = get_arg_val(1); + const uint32_t end_row = get_arg_val(2); const InterleavedAddrGen s0 = {.bank_base_address = dst_addr, .page_size = page_size}; + // start at runtime arg 3 since address/start_block/end_block make up the first 3 args uint32_t input_shape[N], perm[N], dest_strides[N]; - for (uint32_t i = 1; i <= N; i++) { - input_shape[i - 1] = get_arg_val(i); - perm[i - 1] = get_arg_val(i + N); - dest_strides[i - 1] = get_arg_val(i + 2 * N); + for (uint32_t i = 3; i < N + 3; i++) { + input_shape[i - 3] = get_arg_val(i); + perm[i - 3] = get_arg_val(i + N); + dest_strides[i - 3] = get_arg_val(i + 2 * N); } uint32_t src_buffer_l1_addr = get_write_ptr(tt::CBIndex::c_0); uint32_t curr_addr = dst_addr; - for (uint32_t row = 0; row < num_rows; ++row) { + for (uint32_t row = start_row; row < end_row; ++row) { // Compute multi-dimensional index for the source row uint32_t src_multi_idx[N]; size_t remaining = row; diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.cpp 
b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.cpp index bbe319681bb..8bc4bece3b0 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.cpp @@ -12,7 +12,12 @@ namespace ttnn::operations::data_movement { PermuteDeviceOperation::program_factory_t PermuteDeviceOperation::select_program_factory( const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { - return SingleCore{}; + // If the last dimension is not permuted, we can use the row-invariant kernel + if (operation_attributes.dims.back() == tensor_args.input_tensor.get_logical_shape().rank() - 1) { + return MultiCoreRowInvariant{}; + } + // Otherwise, we need to use the blocked generic, row moving kernel + return MultiCoreBlockedGeneric{}; } void PermuteDeviceOperation::validate_on_program_cache_miss( @@ -20,10 +25,6 @@ void PermuteDeviceOperation::validate_on_program_cache_miss( TT_FATAL( attributes.dims.size() == tensor_args.input_tensor.get_logical_shape().rank(), "Permute dimensions must match input tensor rank"); - TT_FATAL( - attributes.dims.back() == tensor_args.input_tensor.get_logical_shape().rank() - 1, - "Last dimension of permute must be the last dimension of the input tensor as page-breaking is not supported at " - "the moment"); TT_FATAL(tensor_args.input_tensor.is_sharded() == false, "Permute operation does not support sharded input tensor"); TT_FATAL( tensor_args.input_tensor.get_layout() == Layout::ROW_MAJOR, "Permute operation only supports row-major layout"); @@ -34,7 +35,7 @@ void PermuteDeviceOperation::validate_on_program_cache_hit( PermuteDeviceOperation::shape_return_value_t PermuteDeviceOperation::compute_output_shapes( const operation_attributes_t& attributes, const tensor_args_t& tensor_args) { - SmallVector shape, padded_shape; + SmallVector shape; auto input_shape = tensor_args.input_tensor.get_logical_shape(); shape.reserve(input_shape.rank()); for (auto dim : attributes.dims) { diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.hpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.hpp index 2f9481feb8c..05e251e8ca8 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.hpp @@ -30,11 +30,12 @@ struct PermuteDeviceOperation { using tensor_return_value_t = Tensor; - struct SingleCore { + struct MultiCoreRowInvariant { // Shared variables are the variables that are shared between the create and override_runtime_arguments methods struct shared_variables_t { KernelHandle unary_reader_kernel_id; KernelHandle unary_writer_kernel_id; + CoreRangeSet core_range; }; using cached_program_t = ttnn::device_operation::CachedProgram; @@ -49,7 +50,30 @@ struct PermuteDeviceOperation { const tensor_args_t& tensor_args, tensor_return_value_t& tensor_return_value); }; - using program_factory_t = std::variant; + + struct MultiCoreBlockedGeneric { + // Shared variables are the variables that are shared between the create and override_runtime_arguments methods + struct shared_variables_t { + KernelHandle unary_reader_kernel_id; + KernelHandle unary_writer_kernel_id; + KernelHandle compute_kernel_id; + CoreRangeSet core_range; + }; + using cached_program_t = ttnn::device_operation::CachedProgram; + + static cached_program_t create( + const 
operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& tensor_return_value); + + static void override_runtime_arguments( + cached_program_t& cached_program, + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& tensor_return_value); + }; + + using program_factory_t = std::variant; // Mandatory methods diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_program_factory.cpp index 29f6065cb5b..56bfe893f5d 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_program_factory.cpp @@ -4,18 +4,20 @@ #include "ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.hpp" #include "tt_metal/common/work_split.hpp" +#include "noc/noc_parameters.h" // DRAM_ALIGNMENT namespace ttnn::operations::data_movement { namespace detail { uint32_t num_pages(const ttnn::Tensor& input_tensor) { - const auto& padded_shape = input_tensor.get_logical_shape(); - return padded_shape.volume() / padded_shape[-1]; + const auto& shape = input_tensor.get_logical_shape(); + return shape.volume() / shape[-1]; } uint32_t page_size(const ttnn::Tensor& input_tensor) { - const auto& padded_shape = input_tensor.get_logical_shape(); // in anticipation of RM padding - return padded_shape[-1] * input_tensor.element_size(); + auto BUFFER_ALIGNMENT = input_tensor.buffer()->buffer_type() == tt::tt_metal::BufferType::DRAM ? DRAM_ALIGNMENT : L1_ALIGNMENT; + const auto& shape = input_tensor.get_logical_shape(); // in anticipation of RM padding + return tt::round_up(shape[-1] * input_tensor.element_size(), BUFFER_ALIGNMENT); } std::vector get_row_strides(const ttnn::SimpleShape& shape) { @@ -27,9 +29,10 @@ std::vector get_row_strides(const ttnn::SimpleShape& shape) { } return strides; } + } // namespace detail -PermuteDeviceOperation::SingleCore::cached_program_t PermuteDeviceOperation::SingleCore::create( +PermuteDeviceOperation::MultiCoreRowInvariant::cached_program_t PermuteDeviceOperation::MultiCoreRowInvariant::create( const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args, tensor_return_value_t& tensor_return_value) { @@ -55,56 +58,80 @@ PermuteDeviceOperation::SingleCore::cached_program_t PermuteDeviceOperation::Sin tt::tt_metal::Device* device = input_tensor.device(); uint32_t src0_cb_index = tt::CBIndex::c_0; - uint32_t num_input_pages_to_read = 1; + uint32_t num_input_pages_to_read = 2; + + uint32_t num_rows = input_tensor.volume() / input_tensor.get_logical_shape()[-1]; + + auto compute_with_storage_grid_size = input_tensor.device()->compute_with_storage_grid_size(); + auto [num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2] = + tt::tt_metal::split_work_to_cores(compute_with_storage_grid_size, num_rows); - CoreRange core({0, 0}, {0, 0}); tt::tt_metal::CircularBufferConfig cb_src0_config = tt::tt_metal::CircularBufferConfig( num_input_pages_to_read * input_rm_page_size, {{src0_cb_index, cb_data_format}}) .set_page_size(src0_cb_index, input_rm_page_size); - auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, core, cb_src0_config); + auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, all_cores, cb_src0_config); uint32_t N = operation_attributes.dims.size(); - uint32_t num_rows = 
input_tensor.volume() / input_tensor.get_logical_shape()[-1]; bool src_is_dram = src_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; std::vector reader_compile_time_args = {(uint32_t)src_is_dram, N, input_rm_page_size, num_rows}; tt::tt_metal::KernelHandle unary_reader_kernel_id = tt::tt_metal::CreateKernel( program, - "ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm.cpp", - core, + "ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/" + "reader_permute_interleaved_rm_row_invariant.cpp", + all_cores, tt::tt_metal::ReaderDataMovementConfig(reader_compile_time_args)); bool dst_is_dram = dst_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; std::vector writer_compile_time_args = {(std::uint32_t)dst_is_dram, N, output_rm_page_size, num_rows}; tt::tt_metal::KernelHandle unary_writer_kernel_id = tt::tt_metal::CreateKernel( program, - "ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm.cpp", - core, + "ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/" + "writer_permute_interleaved_rm_row_invariant.cpp", + all_cores, tt::tt_metal::WriterDataMovementConfig(writer_compile_time_args)); - std::vector reader_runtime_args = {src_buffer->address()}; - - tt::tt_metal::SetRuntimeArgs(program, unary_reader_kernel_id, core, reader_runtime_args); + std::vector reader_runtime_args = {src_buffer->address(), 0, 0}; auto input_shape_view = input_tensor.get_logical_shape().view(); auto output_strides = detail::get_row_strides(output_tensor.get_logical_shape()); // in anticipation of RM padding - std::vector writer_runtime_args = {dst_buffer->address()}; + std::vector writer_runtime_args = {dst_buffer->address(), 0, 0}; writer_runtime_args.insert(writer_runtime_args.end(), input_shape_view.begin(), input_shape_view.end()); writer_runtime_args.insert( writer_runtime_args.end(), operation_attributes.dims.begin(), operation_attributes.dims.end()); writer_runtime_args.insert(writer_runtime_args.end(), output_strides.begin(), output_strides.end()); - tt::tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, writer_runtime_args); + auto cores = corerange_to_cores(all_cores, std::nullopt); + uint32_t start_row = 0; + uint32_t num_rows_per_core = 0; + for (const auto& core : cores) { + if (core_group_1.contains(core)) { + num_rows_per_core = num_tiles_per_core_group_1; + } else if (core_group_2.contains(core)) { + num_rows_per_core = num_tiles_per_core_group_2; + } else { + // no-op + num_rows_per_core = 0; + } + uint32_t end_row = start_row + num_rows_per_core; + reader_runtime_args[1] = start_row; + reader_runtime_args[2] = end_row; + writer_runtime_args[1] = start_row; + writer_runtime_args[2] = end_row; + tt::tt_metal::SetRuntimeArgs(program, unary_reader_kernel_id, core, reader_runtime_args); + tt::tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, writer_runtime_args); + start_row = end_row; + } return { std::move(program), {.unary_reader_kernel_id = unary_reader_kernel_id, .unary_writer_kernel_id = unary_writer_kernel_id}}; } -void PermuteDeviceOperation::SingleCore::override_runtime_arguments( +void PermuteDeviceOperation::MultiCoreRowInvariant::override_runtime_arguments( cached_program_t& cached_program, const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args, @@ -118,15 +145,228 @@ void PermuteDeviceOperation::SingleCore::override_runtime_arguments( auto src_buffer = 
input_tensor.buffer(); auto dst_buffer = output_tensor.buffer(); + auto& all_cores = cached_program.shared_variables.core_range; - { - auto& runtime_args = tt::tt_metal::GetRuntimeArgs(program, unary_reader_kernel_id, CoreCoord{0, 0}); + auto cores = corerange_to_cores(all_cores, std::nullopt); + for (const auto& core : cores) { + auto& runtime_args = tt::tt_metal::GetRuntimeArgs(program, unary_reader_kernel_id, core); runtime_args[0] = src_buffer->address(); + auto& runtime_args_writer = tt::tt_metal::GetRuntimeArgs(program, unary_writer_kernel_id, core); + runtime_args_writer[0] = dst_buffer->address(); } +} + +PermuteDeviceOperation::MultiCoreBlockedGeneric::cached_program_t +PermuteDeviceOperation::MultiCoreBlockedGeneric::create( + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& tensor_return_value) { + using namespace tt; + using namespace tt::tt_metal; + + const auto& input_tensor = tensor_args.input_tensor; + auto& output_tensor = tensor_return_value; + + auto src_buffer = input_tensor.buffer(); + auto dst_buffer = output_tensor.buffer(); + + tt::tt_metal::Program program{}; + + tt::DataFormat cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(input_tensor.get_dtype()); + uint32_t w_block_size = constants::TILE_WIDTH; + uint32_t input_cb_page_size = w_block_size * input_tensor.element_size(); + + tt::DataFormat cb_data_format_output = tt::tt_metal::datatype_to_dataformat_converter(output_tensor.get_dtype()); + uint32_t x_block_size = constants::TILE_HEIGHT; + uint32_t output_cb_page_size = x_block_size * input_tensor.element_size(); + + tt::tt_metal::Device* device = input_tensor.device(); + + uint32_t src0_cb_index = tt::CBIndex::c_0; + uint32_t src1_cb_index = tt::CBIndex::c_2; + uint32_t src2_cb_index = tt::CBIndex::c_1; + uint32_t num_input_pages_to_read = 2; + + // we are focused on reading one row at a time, in a pattern that allows us to write an entire output row at a time + // if W is being swapped with another dim X (e.g. 
H), then we need to read X rows at a time (X is the new row + // dimension) CB is thus X pages in size (X*W*element_size) we read in X input rows of size W, and write out W + // output rows of size X find the new row dimension (X) + + uint32_t x_dim = operation_attributes.dims.back(); + uint32_t X = input_tensor.get_logical_shape()[x_dim]; + // stride from one row to the next for each dim in the input tensor + auto input_strides = detail::get_row_strides(input_tensor.get_logical_shape()); + uint32_t X_stride = input_strides[x_dim]; + + auto output_strides = detail::get_row_strides(output_tensor.get_logical_shape()); + // after we transpose X and W, we need to stride from one row to the next for each dim in the output tensor + uint32_t W = input_tensor.get_logical_shape()[-1]; + uint32_t W_stride = output_strides[x_dim]; + + uint32_t N = operation_attributes.dims.size(); + uint32_t num_rows = input_tensor.volume() / input_tensor.get_logical_shape()[-1]; + + // treat the input tensor as 3D with rows * x_blocks * w_blocks + uint32_t x_blocks = tt::div_up(X, x_block_size); + uint32_t w_blocks = tt::div_up(W, w_block_size); + uint32_t num_blocks_total = (num_rows / X) * x_blocks * w_blocks; + + auto compute_with_storage_grid_size = input_tensor.device()->compute_with_storage_grid_size(); + auto [num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2] = + tt::tt_metal::split_work_to_cores(compute_with_storage_grid_size, num_blocks_total); + + tt::tt_metal::CircularBufferConfig cb_src0_config = + tt::tt_metal::CircularBufferConfig( + num_input_pages_to_read * input_cb_page_size * x_block_size, {{src0_cb_index, cb_data_format}}) + .set_page_size(src0_cb_index, input_cb_page_size); + auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, all_cores, cb_src0_config); + + tt::tt_metal::CircularBufferConfig cb_src1_config = + tt::tt_metal::CircularBufferConfig( + num_input_pages_to_read * output_cb_page_size * w_block_size, {{src1_cb_index, cb_data_format}}) + .set_page_size(src1_cb_index, output_cb_page_size); + auto cb_src1 = tt::tt_metal::CreateCircularBuffer(program, all_cores, cb_src1_config); + + tt::tt_metal::CircularBufferConfig cb_src2_config = + tt::tt_metal::CircularBufferConfig( + num_input_pages_to_read * x_block_size * w_block_size * input_tensor.element_size(), + {{src2_cb_index, cb_data_format}}) + .set_page_size(src2_cb_index, x_block_size * w_block_size * input_tensor.element_size()); + auto cb_src2 = tt::tt_metal::CreateCircularBuffer(program, all_cores, cb_src2_config); + + bool src_is_dram = src_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; + std::vector reader_compile_time_args = { + (uint32_t)src_is_dram, + N, + input_cb_page_size, + num_rows, + x_dim, + num_blocks_total, + x_blocks, + w_blocks, + x_block_size, + w_block_size, + input_tensor.element_size(), + input_tensor.get_logical_shape()[-1] * input_tensor.element_size()}; - { - auto& runtime_args = tt::tt_metal::GetRuntimeArgs(program, unary_writer_kernel_id, CoreCoord{0, 0}); - runtime_args[0] = dst_buffer->address(); + tt::tt_metal::KernelHandle unary_reader_kernel_id = tt::tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/" + "reader_permute_interleaved_rm_blocked_generic.cpp", + all_cores, + tt::tt_metal::ReaderDataMovementConfig(reader_compile_time_args)); + + bool dst_is_dram = dst_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 
1 : 0; + std::vector writer_compile_time_args = { + (std::uint32_t)dst_is_dram, + N, + output_cb_page_size, + num_rows, + + X, + X_stride, + x_dim, + + W_stride, + input_cb_page_size, + input_tensor.element_size(), + + num_blocks_total, + x_blocks, + w_blocks, + x_block_size, + w_block_size, + + W, + output_tensor.get_logical_shape()[-1] * output_tensor.element_size()}; + tt::tt_metal::KernelHandle unary_writer_kernel_id = tt::tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/" + "writer_permute_interleaved_rm_blocked_generic.cpp", + all_cores, + tt::tt_metal::WriterDataMovementConfig(writer_compile_time_args)); + + std::vector compute_kernel_args = {x_block_size, w_block_size}; + bool fp32_dest_acc_en = cb_data_format_output == tt::DataFormat::Float32; + auto compute_kernel_id = tt::tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/compute/transpose_xw_rm_single_tile_size.cpp", + all_cores, + tt::tt_metal::ComputeConfig{ + .fp32_dest_acc_en = fp32_dest_acc_en, + .compile_args = compute_kernel_args, + }); + + auto input_shape_view = input_tensor.get_logical_shape().view(); + + std::vector reader_runtime_args = {src_buffer->address(), 0, 0}; + reader_runtime_args.insert(reader_runtime_args.end(), input_shape_view.begin(), input_shape_view.end()); + reader_runtime_args.insert(reader_runtime_args.end(), input_strides.begin(), input_strides.end()); + + std::vector writer_runtime_args = {dst_buffer->address(), 0, 0}; + + writer_runtime_args.insert(writer_runtime_args.end(), input_shape_view.begin(), input_shape_view.end()); + writer_runtime_args.insert( + writer_runtime_args.end(), operation_attributes.dims.begin(), operation_attributes.dims.end()); + writer_runtime_args.insert(writer_runtime_args.end(), output_strides.begin(), output_strides.end()); + auto cores = corerange_to_cores(all_cores, std::nullopt); + + std::vector compute_runtime_args = {dst_buffer->address(), 0, 0}; + + uint32_t start_block = 0; + uint32_t num_blocks_per_core = 0; + for (const auto& core : cores) { + if (core_group_1.contains(core)) { + num_blocks_per_core = num_tiles_per_core_group_1; + } else if (core_group_2.contains(core)) { + num_blocks_per_core = num_tiles_per_core_group_2; + } else { + // no-op + num_blocks_per_core = 0; + } + compute_runtime_args[0] = num_blocks_per_core; + uint32_t end_block = start_block + num_blocks_per_core; + reader_runtime_args[1] = start_block; + reader_runtime_args[2] = end_block; + writer_runtime_args[1] = start_block; + writer_runtime_args[2] = end_block; + tt::tt_metal::SetRuntimeArgs(program, unary_reader_kernel_id, core, reader_runtime_args); + tt::tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, writer_runtime_args); + tt::tt_metal::SetRuntimeArgs(program, compute_kernel_id, core, compute_runtime_args); + start_block = end_block; + } + + return { + std::move(program), + {.unary_reader_kernel_id = unary_reader_kernel_id, + .unary_writer_kernel_id = unary_writer_kernel_id, + .compute_kernel_id = compute_kernel_id, + .core_range = all_cores}}; +} + +void PermuteDeviceOperation::MultiCoreBlockedGeneric::override_runtime_arguments( + cached_program_t& cached_program, + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& tensor_return_value) { + auto& program = cached_program.program; + auto& unary_reader_kernel_id = cached_program.shared_variables.unary_reader_kernel_id; + auto& unary_writer_kernel_id 
= cached_program.shared_variables.unary_writer_kernel_id; + auto& compute_kernel_id = cached_program.shared_variables.compute_kernel_id; + + const auto& input_tensor = tensor_args.input_tensor; + auto& output_tensor = tensor_return_value; + + auto src_buffer = input_tensor.buffer(); + auto dst_buffer = output_tensor.buffer(); + auto& all_cores = cached_program.shared_variables.core_range; + + auto cores = corerange_to_cores(all_cores, std::nullopt); + for (const auto& core : cores) { + auto& runtime_args = tt::tt_metal::GetRuntimeArgs(program, unary_reader_kernel_id, core); + runtime_args[0] = src_buffer->address(); + auto& runtime_args_writer = tt::tt_metal::GetRuntimeArgs(program, unary_writer_kernel_id, core); + runtime_args_writer[0] = dst_buffer->address(); } } diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/permute.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/permute.cpp index 288f5b5a101..00d622a15fb 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/permute.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/permute.cpp @@ -25,30 +25,15 @@ inline bool is_on_device(const Tensor& t) { ttnn::has_storage_type_of(t, ttnn::StorageType::MULTI_DEVICE); } -inline bool has_tile_padding(const Tensor& t) { - if (t.get_logical_shape().rank() > 1) { - auto the_shape = t.get_logical_shape(); - auto the_shape_with_padding = t.get_padded_shape(); - return the_shape[-1] != the_shape_with_padding[-1] or the_shape[-2] != the_shape_with_padding[-2]; - } - return false; -} - ttnn::Tensor permute_impl( const ttnn::Tensor& a, - const SmallVector& dims, + const ttnn::SmallVector& dims, const MemoryConfig& output_mem_config, const std::optional& pad_value) { using ttnn::operations::experimental::auto_format::AutoFormat; - Device* device; // Get the device - if (a.storage_type() != StorageType::DEVICE) { - device = AutoFormat::GetDefaultDevice(); - TT_ASSERT(device != nullptr, "Requires setting default device if no inputs to op are on device"); - } else { - device = a.device(); - } + Device* device = a.device(); if (a.get_shape().rank() > 4) { auto input = a.get_layout() == Layout::TILE @@ -57,16 +42,14 @@ ttnn::Tensor permute_impl( TT_FATAL( !(pad_value.has_value() && pad_value.value() != 0.0f), "Non-zero padding is not supported for permute on tensors with rank > 4."); - input = ttnn::prim::permute(input, dims, output_mem_config, std::nullopt); + SmallVector permute_dims(dims.begin(), dims.end()); + input = ttnn::prim::permute(input, permute_dims, output_mem_config, std::nullopt); return ttnn::to_layout(input, a.get_layout(), std::nullopt, std::nullopt, (Device*)nullptr); } TT_FATAL(dims.size() == 4, "Only 4D tensor are supported for permute."); uint32_t N = dims[0], C = dims[1], H = dims[2], W = dims[3]; - // Convert tensor back to original - auto input_shape = a.get_logical_shape(); - auto formatted_input_tensor = a; // WH and CN should be supported without typecast bool wh = N == 0 && C == 1 && H == 3 && W == 2; @@ -142,13 +125,14 @@ ttnn::Tensor permute_impl( } else { TT_ASSERT(false, "Illegal permute args"); } + // Convert tensor back to original dtype if typecast was performed output = typecast ? 
ttnn::typecast(output, DataType::BFLOAT8_B) : output; return output; } ttnn::Tensor permute_launch( const ttnn::Tensor& a, - tt::stl::Span dims, + const ttnn::SmallVector& dims, const MemoryConfig& output_mem_config, const std::optional& pad_value) { std::vector output_tensors = {ttnn::Tensor(operation::get_workers_for_op_output({a}))}; @@ -159,31 +143,21 @@ ttnn::Tensor permute_launch( const std::vector>& optional_output_tensors) mutable -> std::vector { auto& a = input_tensors.at(0); - SmallVector normalized_dims(dims.size()); - std::transform(dims.begin(), dims.end(), normalized_dims.begin(), [a](std::int64_t idx) { - return a.get_legacy_shape().get_normalized_index(idx); - }); - SmallVector seq_dims(dims.size()); - std::iota(seq_dims.begin(), seq_dims.end(), 0); - if (normalized_dims == seq_dims) { - return {ttnn::operations::experimental::auto_format::AutoFormat::move_tensor_to_mem_config( - a, output_mem_config)}; - } - return {permute_impl(a, normalized_dims, output_mem_config, pad_value)}; + return {permute_impl(a, dims, output_mem_config, pad_value)}; }, {a}, output_tensors); return output_tensors.at(0); } -Tensor composite_invoke( - const ttnn::Tensor& input_tensor, - tt::stl::Span dims, - const std::optional& memory_config, - const std::optional& pad_value) { - auto output_tensor = - permute_launch(input_tensor, dims, memory_config.value_or(input_tensor.memory_config()), pad_value); - return output_tensor; +bool is_permute_nop(const ttnn::Tensor& a, tt::stl::Span dims) { + if (a.get_shape().rank() <= 1) { + return true; + } + auto normalized_dims = ttnn::SmallVector(dims.begin(), dims.end()); + ttnn::SmallVector seq_dims(dims.size()); + std::iota(seq_dims.begin(), seq_dims.end(), 0); + return normalized_dims == seq_dims; } } // namespace detail @@ -193,23 +167,24 @@ ttnn::Tensor ExecutePermute::invoke( const ttnn::Tensor& input_tensor, tt::stl::Span dims, const std::optional& memory_config, - bool composite, const std::optional& pad_value) { - if (composite) { - return detail::composite_invoke(input_tensor, dims, memory_config, pad_value); - } - - const bool initial_input_tensor_on_device = detail::is_on_device(input_tensor); - const auto input_layout = input_tensor.get_layout(); const auto input_rank = input_tensor.get_logical_shape().rank(); - TT_FATAL( input_rank == dims.size(), "The number of dimensions in the tensor input does not match the length of the desired ordering"); + TT_FATAL(detail::is_on_device(input_tensor), "Tensor must already be on device"); + + SmallVector normalized_dims(dims.size()); + std::transform(dims.begin(), dims.end(), normalized_dims.begin(), [input_tensor](std::int64_t idx) { + return input_tensor.get_logical_shape().get_normalized_index(idx); + }); + if (detail::is_permute_nop(input_tensor, normalized_dims)) { + return ttnn::to_memory_config(input_tensor, memory_config.value_or(input_tensor.memory_config())); + } - auto adjust_order = [](tt::stl::Span dims) { - ttnn::SmallVector new_order; - TT_FATAL(dims.size() <= 4, "Error"); + auto adjust_order = [](tt::stl::Span dims) { + ttnn::SmallVector new_order; + TT_FATAL(dims.size() <= 4, "Minimum rank of tensor required is 4"); int additional_ranks = 4 - dims.size(); for (int i = 0; i < additional_ranks; i++) { new_order.push_back(i); @@ -220,33 +195,15 @@ ttnn::Tensor ExecutePermute::invoke( return new_order; }; auto itensor = (input_tensor.get_logical_shape().rank() < 4) ? ttnn::unsqueeze_to_4D(input_tensor) : input_tensor; - auto iorder = - dims.size() < 4 ? 
adjust_order(dims) : dims; // internals of permute_impl already adjust negative indices + auto iorder = normalized_dims.size() < 4 ? adjust_order(normalized_dims) : normalized_dims; - TT_FATAL(detail::is_on_device(itensor), "Error"); + const auto input_layout = input_tensor.get_layout(); auto output_tensor = detail::permute_launch(itensor, iorder, memory_config.value_or(input_tensor.memory_config()), pad_value); output_tensor = ttnn::to_layout(output_tensor, input_layout, std::nullopt, std::nullopt, (Device*)nullptr); if (input_rank < 4) { - const auto shape = output_tensor.get_shape(); - const auto full_shape = output_tensor.get_shape().with_tile_padding(); - SmallVector shape_vec{}; - SmallVector full_shape_vec{}; - int i = 0; - while (i < 3 and shape[i] == 1) { - i++; - } - for (; i < shape.rank(); i++) { - shape_vec.push_back(shape[i]); - full_shape_vec.push_back(full_shape[i]); - } - output_tensor = ttnn::reshape(output_tensor, ttnn::Shape(shape_vec, full_shape_vec)); - } - - if (initial_input_tensor_on_device and not detail::is_on_device(output_tensor)) { - output_tensor = - ttnn::to_device(output_tensor, input_tensor.device(), memory_config.value_or(input_tensor.memory_config())); + output_tensor = ttnn::squeeze_from_4D(output_tensor, input_rank); } return output_tensor; @@ -257,7 +214,7 @@ ttnn::Tensor ExecutePermute::invoke( tt::stl::Span dims, const std::optional& memory_config, const std::optional& pad_value) { - return invoke(DefaultQueueId, input_tensor, dims, memory_config, true, pad_value); + return invoke(DefaultQueueId, input_tensor, dims, memory_config, pad_value); } ttnn::Tensor ExecutePermute::invoke( diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/permute.hpp b/ttnn/cpp/ttnn/operations/data_movement/permute/permute.hpp index 7f9301b696b..2f13b9c2845 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/permute.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/permute.hpp @@ -15,7 +15,6 @@ struct ExecutePermute { const ttnn::Tensor& input_tensor, tt::stl::Span dims, const std::optional& memory_config, - bool composite = true, const std::optional& pad_value = 0.0f); static ttnn::Tensor invoke( diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/permute_pybind.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/permute_pybind.cpp index 2fbb5c0bcd0..be6adbf880b 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/permute_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/permute_pybind.cpp @@ -44,7 +44,7 @@ void bind_permute(py::module& module) { const std::optional& memory_config, uint8_t queue_id, const std::optional& pad_value) { - return self(queue_id, input_tensor, dims, memory_config, false, pad_value); + return self(queue_id, input_tensor, dims, memory_config, pad_value); }, py::arg("input_tensor").noconvert(), py::arg("dims"), diff --git a/ttnn/cpp/ttnn/operations/data_movement/transpose/device/transpose_op.cpp b/ttnn/cpp/ttnn/operations/data_movement/transpose/device/transpose_op.cpp index 1776e28adc8..c560cd17ebe 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/transpose/device/transpose_op.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/transpose/device/transpose_op.cpp @@ -23,14 +23,29 @@ void Transpose::validate(const std::vector& input_tensors) const { TT_FATAL(input_tensor.buffer() != nullptr, "Operands to transpose need to be allocated in buffers on device!"); TT_FATAL( !(this->dim != TransposeOpDim::HC && this->pad_value.has_value() && this->pad_value != 0.0f), - "Non-zero padding is not supported 
for any transpose other than HC."); + "Non-zero padding {} is not supported for any transpose other than HC.", + this->pad_value.value()); + TT_FATAL( + this->dim == TransposeOpDim::HC || this->dim == TransposeOpDim::WH || this->dim == TransposeOpDim::CN, + "Transpose HC, WH, CN are the only supported transpose operations. Transpose {} is not supported.", + (int)this->dim); const auto shape = input_tensor.get_padded_shape(); bool row_major = input_tensor.get_layout() == Layout::ROW_MAJOR; uint32_t W = shape[3], H = shape[2], C = shape[1], N = shape[0]; uint32_t HW = H * W; if (not row_major) { - TT_FATAL(W % TILE_WIDTH == 0 && H % TILE_HEIGHT == 0, "Error"); - TT_FATAL(input_tensor.volume() % TILE_HW == 0, "Error"); + TT_FATAL( + W % TILE_WIDTH == 0 && H % TILE_HEIGHT == 0, + "Tiled tensor H {} W {} must be a multiple of TILE HEIGHT {} and TILE WIDTH {}", + H, + W, + TILE_HEIGHT, + TILE_WIDTH); + TT_FATAL( + input_tensor.volume() % TILE_HW == 0, + "Tiled tensor volume {} must be a multiple of TILE HEIGHT * TILE WIDTH {}", + input_tensor.volume(), + TILE_HW); } uint32_t ROW_MAJOR_STICK_WIDTH = 16; if (this->dim == TransposeOpDim::WH) { if (row_major) { TT_FATAL( (W * input_tensor.element_size()) % ROW_MAJOR_STICK_WIDTH == 0 && (H * input_tensor.element_size()) % ROW_MAJOR_STICK_WIDTH == 0, - "Error"); + "Row major tensor W {} H {} must be a multiple of ROW_MAJOR_STICK_WIDTH {} for transpose wh", + W, + H, + ROW_MAJOR_STICK_WIDTH); } if (input_tensor.is_sharded()) { - TT_FATAL(input_tensor.memory_config().memory_layout != TensorMemoryLayout::WIDTH_SHARDED, "Error"); + TT_FATAL( + input_tensor.memory_config().memory_layout != TensorMemoryLayout::WIDTH_SHARDED, + "Only height and block sharding is supported for transpose wh"); const auto shard_spec = input_tensor.shard_spec().value(); - TT_FATAL(shard_spec.shape[1] == W, "Error"); - TT_FATAL(shard_spec.shape[0] % H == 0, "Error"); - TT_FATAL(this->output_mem_config.is_sharded(), "Error"); - TT_FATAL(this->output_mem_config.memory_layout != TensorMemoryLayout::WIDTH_SHARDED, "Error"); + TT_FATAL( + (shard_spec.shape[0] % H == 0) || (H % shard_spec.shape[0] == 0), + "Only a multiple of H {} or a factor of H is allowed for the shard height {} for transpose WH", + H, + shard_spec.shape[0]); + TT_FATAL(shard_spec.shape[1] == W, "Only height sharding is supported"); + TT_FATAL(this->output_mem_config.is_sharded(), "Output must be sharded for transpose WH"); + TT_FATAL( + this->output_mem_config.memory_layout != TensorMemoryLayout::WIDTH_SHARDED, + "Only height and block sharding is supported for transpose wh"); } else { - TT_FATAL(!this->output_mem_config.is_sharded(), "Error"); + TT_FATAL(!this->output_mem_config.is_sharded(), "Interleaved input tensors cannot produce sharded outputs"); } } else { if (input_tensor.is_sharded()) { - TT_FATAL(input_tensor.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED, "Error"); + TT_FATAL( + input_tensor.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED, + "Only height sharding is supported for transpose hc"); const auto shard_spec = input_tensor.shard_spec().value(); - TT_FATAL(shard_spec.shape[1] == W, "Error"); - TT_FATAL(this->output_mem_config.is_sharded(), "Error"); - TT_FATAL(this->output_mem_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED, "Error"); + TT_FATAL(shard_spec.shape[1] == W, "Block/Width sharding is not supported"); + TT_FATAL( + this->output_mem_config.is_sharded(), "Sharded input can only
output sharded tensors for transpose hc"); + TT_FATAL( + this->output_mem_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED, + "Only height sharding is supported for the output of sharded transpose hc"); } else { - TT_FATAL(!this->output_mem_config.is_sharded(), "Error"); + TT_FATAL(!this->output_mem_config.is_sharded(), "Interleaved inputs cannot produce sharded outputs"); } } if (this->dim == TransposeOpDim::HC) { @@ -78,19 +109,8 @@ void Transpose::validate(const std::vector<Tensor>& input_tensors) const { "HC transpose does not support sharded+tilized inputs"); TT_FATAL( !(input_tensor.is_sharded() && pad_value.has_value() && pad_value.value() != 0.0f), - "Sharded HC transpose does not support non-zero padding"); - } else if (this->dim == TransposeOpDim::CW) { - TT_FATAL(C % TILE_WIDTH == 0, "Error"); - TT_FATAL( - input_tensor.get_dtype() == DataType::BFLOAT16 || input_tensor.get_dtype() == DataType::FLOAT32, "Error"); - } else if (this->dim == TransposeOpDim::NH) { - TT_FATAL(N % TILE_HEIGHT == 0, "Error"); - TT_FATAL( - input_tensor.get_dtype() == DataType::BFLOAT16 || input_tensor.get_dtype() == DataType::FLOAT32, "Error"); - } else if (this->dim == TransposeOpDim::NW) { - TT_FATAL(N % TILE_WIDTH == 0, "Error"); - TT_FATAL( - input_tensor.get_dtype() == DataType::BFLOAT16 || input_tensor.get_dtype() == DataType::FLOAT32, "Error"); + "Sharded HC transpose does not support non-zero padding {}", + pad_value.value()); } } @@ -147,9 +167,15 @@ std::vector Transpose::compute_output_specs(const std::vector< if (this->dim == TransposeOpDim::WH) { const auto& input_padded_shape = input_tensor.get_padded_shape(); ShardSpec shard_spec = input_tensor.shard_spec().value(); - shard_spec.shape[0] = shard_spec.shape[0] / input_padded_shape[-2] * input_padded_shape[-1]; - shard_spec.shape[1] = input_padded_shape[-2]; - output_mem_config.shard_spec = shard_spec; + if (shard_spec.shape[0] >= input_padded_shape[-2]) { + shard_spec.shape[0] = shard_spec.shape[0] / input_padded_shape[-2] * input_padded_shape[-1]; + shard_spec.shape[1] = input_padded_shape[-2]; + output_mem_config.shard_spec = shard_spec; + } else { + std::swap(shard_spec.shape[0], shard_spec.shape[1]); + output_mem_config.shard_spec = shard_spec; + output_mem_config.memory_layout = TensorMemoryLayout::BLOCK_SHARDED; + } } else if (this->dim == TransposeOpDim::HC) { output_mem_config.shard_spec = input_tensor.shard_spec().value(); } else { diff --git a/ttnn/cpp/ttnn/operations/data_movement/transpose/device/transpose_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/transpose/device/transpose_program_factory.cpp index 2dbdcf7dc8a..02bc49406d8 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/transpose/device/transpose_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/transpose/device/transpose_program_factory.cpp @@ -1776,8 +1776,9 @@ operation::ProgramWithCallbacks transpose_wh_multi_core_sharded(const Tensor& a, uint32_t dst_single_tile_size = tt::tt_metal::detail::TileSize(dst_cb_data_format); tt::tt_metal::Buffer* src0_buffer = a.buffer(); - - int32_t num_tiles = a.volume() / TILE_HW; + const auto tile = a.get_tensor_spec().tile(); + const uint32_t tile_hw = tile.get_tile_hw(); + int32_t num_tiles = a.volume() / tile_hw; tt::tt_metal::Device* device = a.device(); @@ -1793,7 +1794,7 @@ operation::ProgramWithCallbacks transpose_wh_multi_core_sharded(const Tensor& a, auto& all_cores = shard_spec.grid; uint32_t num_cores = all_cores.num_cores(); - uint32_t num_tiles_per_shard = shard_spec.numel() / TILE_HW; + 
uint32_t num_tiles_per_shard = shard_spec.numel() / tile_hw; tt::tt_metal::LegacyShape output_shape = output.get_legacy_shape(); @@ -1848,11 +1849,22 @@ operation::ProgramWithCallbacks transpose_wh_multi_core_sharded(const Tensor& a, total_cores, tt::tt_metal::ComputeConfig{.fp32_dest_acc_en = fp32_dest_acc_en, .compile_args = compute_compile_time_args}); - uint32_t Wt = shard_spec.shape[1] / TILE_WIDTH; - uint32_t Ht = a.get_legacy_shape()[-2] / TILE_HEIGHT; - uint32_t HtWt = Ht * Wt; - uint32_t N = shard_spec.shape[0] / a.get_legacy_shape()[-2]; - uint32_t NHtWt = N * HtWt; + auto padded_shape = a.get_padded_shape(); + auto shard_shape = shard_spec.shape; + + uint32_t H = padded_shape[2], W = padded_shape[3]; + uint32_t Hs = shard_shape[0], Ws = shard_shape[1]; + + uint32_t Hts = Hs / tile.tile_shape[0]; + uint32_t Wts = Ws / tile.tile_shape[1]; + + uint32_t Ht = H / tile.tile_shape[0]; + uint32_t Ht_per_shard = std::min(Ht, Hts); + + uint32_t num_hw_blocks_per_shard = Hts > Ht ? Hts / Ht : 1; + + uint32_t HtWt_tile_size = Ht_per_shard * Wts; + uint32_t num_blocks = num_hw_blocks_per_shard * HtWt_tile_size; auto bbox = all_cores.bounding_box(); std::vector cores = @@ -1862,13 +1874,17 @@ operation::ProgramWithCallbacks transpose_wh_multi_core_sharded(const Tensor& a, std::vector> unary_compute_args = {cores.size(), std::vector(5)}; std::vector> unary_writer_args = {cores.size(), std::vector(1)}; std::fill( - unary_reader_args.begin(), unary_reader_args.begin() + all_cores.num_cores(), std::vector{NHtWt}); + unary_reader_args.begin(), + unary_reader_args.begin() + all_cores.num_cores(), + std::vector{num_blocks}); std::fill( unary_compute_args.begin(), unary_compute_args.begin() + all_cores.num_cores(), - std::vector{NHtWt, HtWt, N, Ht, Wt}); + std::vector{num_blocks, HtWt_tile_size, num_hw_blocks_per_shard, Ht_per_shard, Wts}); std::fill( - unary_writer_args.begin(), unary_writer_args.begin() + all_cores.num_cores(), std::vector{NHtWt}); + unary_writer_args.begin(), + unary_writer_args.begin() + all_cores.num_cores(), + std::vector{num_blocks}); tt::tt_metal::SetRuntimeArgs(program, reader_kernel_id, cores, unary_reader_args); tt::tt_metal::SetRuntimeArgs(program, compute_kernel_id, cores, unary_compute_args); @@ -1899,7 +1915,11 @@ operation::ProgramWithCallbacks transpose_wh_multi_core_sharded(const Tensor& a, auto shard_spec = src_tensor.shard_spec().value(); - uint32_t num_tiles_per_shard = shard_spec.numel() / TILE_HW; + const auto tile = src_tensor.get_tensor_spec().tile(); + const uint32_t tile_hw = tile.get_tile_hw(); + int32_t num_tiles = src_tensor.volume() / tile_hw; + + uint32_t num_tiles_per_shard = shard_spec.numel() / tile_hw; if (src0_sharded) { UpdateDynamicCircularBufferAddressAndTotalSize( @@ -1911,11 +1931,22 @@ operation::ProgramWithCallbacks transpose_wh_multi_core_sharded(const Tensor& a, program, cb_output, *dst_buffer, num_tiles_per_shard * dst_single_tile_size); } - uint32_t Wt = shard_spec.shape[1] / TILE_WIDTH; - uint32_t Ht = src_tensor.get_legacy_shape()[-2] / TILE_HEIGHT; - uint32_t HtWt = Ht * Wt; - uint32_t N = shard_spec.shape[0] / src_tensor.get_legacy_shape()[-2]; - uint32_t NHtWt = N * HtWt; + auto padded_shape = src_tensor.get_padded_shape(); + auto shard_shape = shard_spec.shape; + + uint32_t H = padded_shape[2], W = padded_shape[3]; + uint32_t Hs = shard_shape[0], Ws = shard_shape[1]; + + uint32_t Hts = Hs / tile.tile_shape[0]; + uint32_t Wts = Ws / tile.tile_shape[1]; + + uint32_t Ht = H / tile.tile_shape[0]; + uint32_t Ht_per_shard = std::min(Ht, 
Hts); + + uint32_t num_hw_blocks_per_shard = Hts > Ht ? Hts / Ht : 1; + + uint32_t HtWt_tile_size = Ht_per_shard * Wts; + uint32_t num_blocks = num_hw_blocks_per_shard * HtWt_tile_size; const auto& all_cores = shard_spec.grid; bool row_major = shard_spec.orientation == ShardOrientation::ROW_MAJOR; @@ -1927,13 +1958,17 @@ operation::ProgramWithCallbacks transpose_wh_multi_core_sharded(const Tensor& a, std::vector<std::vector<uint32_t>> unary_compute_args = {cores.size(), std::vector<uint32_t>(5)}; std::vector<std::vector<uint32_t>> unary_writer_args = {cores.size(), std::vector<uint32_t>(1)}; std::fill( - unary_reader_args.begin(), unary_reader_args.begin() + all_cores.num_cores(), std::vector<uint32_t>{NHtWt}); + unary_reader_args.begin(), + unary_reader_args.begin() + all_cores.num_cores(), + std::vector<uint32_t>{num_blocks}); std::fill( unary_compute_args.begin(), unary_compute_args.begin() + all_cores.num_cores(), - std::vector<uint32_t>{NHtWt, HtWt, N, Ht, Wt}); + std::vector<uint32_t>{num_blocks, HtWt_tile_size, num_hw_blocks_per_shard, Ht_per_shard, Wts}); std::fill( - unary_writer_args.begin(), unary_writer_args.begin() + all_cores.num_cores(), std::vector<uint32_t>{NHtWt}); + unary_writer_args.begin(), + unary_writer_args.begin() + all_cores.num_cores(), + std::vector<uint32_t>{num_blocks}); tt::tt_metal::SetRuntimeArgs(program, reader_kernel_id, cores, unary_reader_args); tt::tt_metal::SetRuntimeArgs(program, compute_kernel_id, cores, unary_compute_args); diff --git a/ttnn/cpp/ttnn/operations/experimental/matmul/attn_matmul/device/attn_matmul_device_operation.cpp b/ttnn/cpp/ttnn/operations/experimental/matmul/attn_matmul/device/attn_matmul_device_operation.cpp index 26de812b441..092481bdaee 100644 --- a/ttnn/cpp/ttnn/operations/experimental/matmul/attn_matmul/device/attn_matmul_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/matmul/attn_matmul/device/attn_matmul_device_operation.cpp @@ -60,7 +60,9 @@ void AttnMatmulDeviceOperation::validate(const std::vector<Tensor>& input_tensor } else { TT_FATAL( ashape[3] == bshape[2], - "Dimension K (A.shape[3] and B.shape[2]) must match for A and B in attn_matmul op"); // A.K == B.K + "Dimension K (A.shape[3] and B.shape[2]) must match for A shape: {} and B shape: {} in attn_matmul op", + ashape, + bshape); // A.K == B.K } } diff --git a/ttnn/cpp/ttnn/tensor/shape/shape.cpp b/ttnn/cpp/ttnn/tensor/shape/shape.cpp index 7dee5428526..d4a54500c46 100644 --- a/ttnn/cpp/ttnn/tensor/shape/shape.cpp +++ b/ttnn/cpp/ttnn/tensor/shape/shape.cpp @@ -7,6 +7,7 @@ #include #include #include "ttnn/tensor/shape/small_vector.hpp" +#include "tt_metal/common/assert.hpp" namespace tt::tt_metal { @@ -20,6 +21,17 @@ uint64_t SimpleShape::volume() const { return std::accumulate(cbegin(), cend(), uint64_t{1}, std::multiplies<uint64_t>()); } +const uint32_t SimpleShape::get_normalized_index(std::int64_t index) const { + std::int64_t rank = static_cast<std::int64_t>(this->rank()); + std::uint64_t normalized_index = index >= 0 ?
index : rank + index; + TT_FATAL( + normalized_index >= 0 and normalized_index < rank, + "Index is out of bounds for the rank, should be between 0 and {} but is {}", + rank - 1, + normalized_index); + return normalized_index; +} + std::ostream& operator<<(std::ostream& os, const tt::tt_metal::SimpleShape& shape) { os << "SimpleShape(["; for (size_t i = 0; i < shape.rank(); ++i) { diff --git a/ttnn/cpp/ttnn/tensor/shape/shape.hpp b/ttnn/cpp/ttnn/tensor/shape/shape.hpp index b1661578927..f6f78d35fd5 100644 --- a/ttnn/cpp/ttnn/tensor/shape/shape.hpp +++ b/ttnn/cpp/ttnn/tensor/shape/shape.hpp @@ -30,6 +30,8 @@ class SimpleShape final : protected ShapeBase { [[nodiscard]] size_t rank() const; [[nodiscard]] uint64_t volume() const; + const uint32_t get_normalized_index(std::int64_t index) const; + // Needed for reflect / fmt static constexpr auto attribute_names = std::forward_as_tuple("value"); auto attribute_values() const { return std::forward_as_tuple(this->value_); }
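For illustration, a minimal sketch (not part of the diff) of how the new SimpleShape::get_normalized_index is expected to behave when normalizing permute dims; the shape construction below is illustrative only and assumes SimpleShape can be built from a dims list.

    // Illustrative only: a rank-4 shape {1, 2, 32, 64}.
    tt::tt_metal::SimpleShape shape({1, 2, 32, 64});
    uint32_t last = shape.get_normalized_index(-1);   // negative dims wrap to rank + index -> 3
    uint32_t first = shape.get_normalized_index(0);    // non-negative dims pass through -> 0
    // Out-of-range dims such as 4 or -5 trip the TT_FATAL bounds check added above.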