From 6c566aa27ceaeeefc7b73f7c6a6418da7680b361 Mon Sep 17 00:00:00 2001 From: Sean Nijjar Date: Sun, 18 Aug 2024 16:14:30 -0400 Subject: [PATCH] Enable multi-buffer per channel in EDM (#11387) #6300: Add multi-buffering per EDM channel Adds the option to add multiple buffers (e.g. double buffered) per EDM channel. This is useful for improving performance of CCL operations. To simplify the worker <-> EDM interface to allow a kernel to automatically support multi-buffered channels, new adapter components are added: - WorkerToEdmReader: for a worker pulling data from EDM - WorkerToEdmSender: for a worker pushing data to the EDM These hide details such as buffer offsets in the channel and any other details that may only be relevant to the EDM. Additionally, their use encapsulates the worker <-> EDM data movement protocol, allowing future low level changes to buffer layouts and allocations on the EDM without requiring worker kernel changes. As a coinciding required step to enable this functionality, the EDM channel count limit has been lifted to unlimited (limited only to as many buffers as can fit into L1). This provides additional flexibility for op writers and lets the `erisc_info::channels` be shrunk back to a single entry. Note that this commit only adds this feature, but does not yet enable it for CCL ops. 
--- .github/workflows/ttnn-post-commit.yaml | 2 + tests/CMakeLists.txt | 2 +- tests/scripts/t3000/run_t3000_unit_tests.sh | 1 + tests/tt_eager/ops/ccl/test_ccl_helpers.cpp | 3 +- ...erisc_datamover_receiver_worker_reader.cpp | 89 -- ...erisc_datamover_receiver_worker_sender.cpp | 43 - .../erisc_datamover_sender_worker_sender.cpp | 93 -- .../erisc/ethernet_bidirectional_ubench.cpp | 1 - .../ethernet_ping_latency_ubench_receiver.cpp | 2 - .../ethernet_ping_latency_ubench_sender.cpp | 2 - tests/ttnn/unit_tests/gtests/CMakeLists.txt | 5 + ...erisc_datamover_receiver_worker_reader.cpp | 44 + ...erisc_datamover_receiver_worker_sender.cpp | 34 + .../erisc_datamover_sender_worker_reader.cpp | 26 +- .../erisc_datamover_sender_worker_sender.cpp | 59 + .../test_erisc_data_mover_with_workers.cpp | 1178 +++++++++++++++++ .../test_reduce_scatter_post_commit.py | 8 +- tt_metal/hw/inc/ethernet/dataflow_api.h | 1 - .../ccl/all_gather/device/all_gather_op.cpp | 15 +- .../ccl/all_gather/device/all_gather_op.hpp | 9 +- .../dataflow/worker_ring_gather_utils.hpp | 3 +- .../multi_core/all_gather_op_multi_core.cpp | 140 +- ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp | 21 +- ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp | 1 + .../ccl/ccl_host_datastructures.cpp | 45 +- .../ccl/ccl_host_datastructures.hpp | 68 +- .../ccl/kernel_common/worker_edm_adapters.hpp | 142 ++ .../ccl/kernel_common/worker_edm_utils.hpp | 1 - .../ccl/kernels/edm/erisc_async_datamover.hpp | 222 ++-- .../ccl/kernels/edm/erisc_datamover.cpp | 81 +- .../host/reduce_scatter_full_worker_grid.cpp | 34 +- ...interleaved_ring_reduce_scatter_reader.cpp | 2 +- .../hetergeneous_data_structs.hpp | 22 +- .../sharded_tensor_addr_gen.hpp | 1 + 34 files changed, 1900 insertions(+), 500 deletions(-) delete mode 100644 tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/erisc_datamover_receiver_worker_reader.cpp delete mode 100644 
tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/erisc_datamover_receiver_worker_sender.cpp delete mode 100644 tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/erisc_datamover_sender_worker_sender.cpp create mode 100644 tests/ttnn/unit_tests/gtests/ccl/kernels/erisc_datamover_receiver_worker_reader.cpp create mode 100644 tests/ttnn/unit_tests/gtests/ccl/kernels/erisc_datamover_receiver_worker_sender.cpp rename tests/{tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc => ttnn/unit_tests/gtests/ccl/kernels}/erisc_datamover_sender_worker_reader.cpp (62%) create mode 100644 tests/ttnn/unit_tests/gtests/ccl/kernels/erisc_datamover_sender_worker_sender.cpp create mode 100644 tests/ttnn/unit_tests/gtests/ccl/test_erisc_data_mover_with_workers.cpp create mode 100644 ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_adapters.hpp diff --git a/.github/workflows/ttnn-post-commit.yaml b/.github/workflows/ttnn-post-commit.yaml index 31a8b635cea..378f53732ee 100644 --- a/.github/workflows/ttnn-post-commit.yaml +++ b/.github/workflows/ttnn-post-commit.yaml @@ -60,6 +60,8 @@ jobs: fast_runtime_mode_off: true - name: ttnn examples and cpp tests cmd: ./build/test/ttnn/unit_tests_ttnn && ./tests/scripts/run_ttnn_examples.sh + - name: ttnn ccl cpp unit tests + cmd: ./build/test/ttnn/unit_tests_ttnn_ccl name: ${{ matrix.test-group.name }} ${{ inputs.arch }} ${{ inputs.runner-label }} env: TT_METAL_ENV: ${{ vars.TT_METAL_ENV }} diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index f26ba585ef7..a2500c51947 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -8,5 +8,5 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tt_metal/tt_metal) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tt_eager) # this should go away and be replaced with link to ttnn add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ttnn/unit_tests/gtests) -set(TESTS_DEPENDS_LIST metal_tests eager_tests unit_tests_ttnn test_multi_device galaxy_unit_tests_ttnn ttnn watcher_dump) 
+set(TESTS_DEPENDS_LIST metal_tests eager_tests unit_tests_ttnn unit_tests_ttnn_ccl test_multi_device galaxy_unit_tests_ttnn ttnn watcher_dump) add_custom_target(tests DEPENDS ${TESTS_DEPENDS_LIST}) diff --git a/tests/scripts/t3000/run_t3000_unit_tests.sh b/tests/scripts/t3000/run_t3000_unit_tests.sh index 06f237638c7..d7a8ce42166 100755 --- a/tests/scripts/t3000/run_t3000_unit_tests.sh +++ b/tests/scripts/t3000/run_t3000_unit_tests.sh @@ -34,6 +34,7 @@ run_t3000_ttnn_tests() { echo "LOG_METAL: Running run_t3000_ttnn_tests" WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml ./build/test/ttnn/test_multi_device WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml ./build/test/ttnn/unit_tests_ttnn + ./build/test/ttnn/unit_tests_ttnn_ccl WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest tests/ttnn/unit_tests/test_multi_device_trace.py ; fail+=$? WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest tests/ttnn/unit_tests/test_multi_device_events.py ; fail+=$? pytest -n auto tests/ttnn/unit_tests/test_multi_device.py ; fail+=$? 
diff --git a/tests/tt_eager/ops/ccl/test_ccl_helpers.cpp b/tests/tt_eager/ops/ccl/test_ccl_helpers.cpp index ba32f967434..3cd023cfa67 100644 --- a/tests/tt_eager/ops/ccl/test_ccl_helpers.cpp +++ b/tests/tt_eager/ops/ccl/test_ccl_helpers.cpp @@ -14,7 +14,8 @@ TEST(CclHelpers, CreateEriscDatamoverBuilder_Chan4_PageSize2048_RRBufferSharingM ttnn::ccl::EriscDataMoverBufferSharingMode buffer_sharing_mode = ttnn::ccl::EriscDataMoverBufferSharingMode::ROUND_ROBIN; ttnn::ccl::EriscDataMoverTerminationMode termination_mode = ttnn::ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED; - auto edm_builder = create_erisc_datamover_builder(num_channels, page_size, buffer_sharing_mode, termination_mode); + std::size_t num_buffers_per_channel = 1; + auto edm_builder = create_erisc_datamover_builder(num_channels, page_size, num_buffers_per_channel, buffer_sharing_mode, termination_mode); std::vector worker_semaphore_ids = {0, 1, 2, 3}; std::vector message_counts = {256, 512, 24, 1}; std::vector> const& worker_coords = { diff --git a/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/erisc_datamover_receiver_worker_reader.cpp b/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/erisc_datamover_receiver_worker_reader.cpp deleted file mode 100644 index d29943843c6..00000000000 --- a/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/erisc_datamover_receiver_worker_reader.cpp +++ /dev/null @@ -1,89 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
-// -// SPDX-License-Identifier: Apache-2.0 - -#include - -#include "dataflow_api.h" -#include "debug/dprint.h" - -FORCE_INLINE void fetch_chunk( - const uint32_t max_pages_per_chunk, - const uint32_t total_pages_to_read, - uint32_t& num_pages_read, - const uint32_t& cb_id, - const uint32_t& page_size, - uint64_t remote_l1_read_addr) { - const uint32_t num_pages_this_chunk = std::min(total_pages_to_read - num_pages_read, max_pages_per_chunk); - - for (uint32_t i = 0; i < num_pages_this_chunk; ++i) { - cb_reserve_back(cb_id, 1); - uint32_t l1_write_addr = get_write_ptr(cb_id); - noc_async_read(remote_l1_read_addr, l1_write_addr, page_size); - remote_l1_read_addr += page_size; - noc_async_read_barrier(); - cb_push_back(cb_id, 1); - } - - num_pages_read += num_pages_this_chunk; -} - -void kernel_main() { - const uint32_t eth_receiver_l1_base_addr = get_compile_time_arg_val(0); - const uint32_t eth_receiver_l1_sem_addr = get_compile_time_arg_val(1); - const uint32_t num_pages_per_read_chunk = get_arg_val(0); - const uint32_t total_pages_to_read = get_arg_val(1); - const uint32_t page_size = get_arg_val(2); - const uint32_t receiver_erisc_datamover_noc_x = get_arg_val(3); - const uint32_t receiver_erisc_datamover_noc_y = get_arg_val(4); - // Worker local L1 semaphore that erisc datamover signals to - const uint32_t receiver_read_sem_addr = get_semaphore(get_arg_val(5)); - - DPRINT << " rwr: args: eth_receiver_l1_base_addr="<< - eth_receiver_l1_base_addr<< - "\n\teth_receiver_l1_sem_addr="<(receiver_read_sem_addr); - - // Address of the buffer on the eth receiver, this is different per receiver worker core - const uint64_t eth_receiver_l1_base_noc_addr = - get_noc_addr(receiver_erisc_datamover_noc_x, receiver_erisc_datamover_noc_y, eth_receiver_l1_base_addr); - // Address of the semaphore on the eth receiver, this is the same per receiver worker core - const uint64_t eth_receiver_l1_semaphore_noc_addr = - get_noc_addr(receiver_erisc_datamover_noc_x, 
receiver_erisc_datamover_noc_y, eth_receiver_l1_sem_addr); - - DPRINT << " rwr: noc_index " << (uint32_t)noc_index << "\n"; - DPRINT << " rwr: my_x[0],my_y[0] " << (uint32_t)my_x[0] << "," << (uint32_t)my_y[0] << "\n"; - DPRINT << " rwr: my_x[1],my_y[1] " << (uint32_t)my_x[1] << "," << (uint32_t)my_y[1] << "\n"; - uint32_t num_pages_read = 0; - while (num_pages_read < total_pages_to_read) { - DPRINT << " rwr: page " << num_pages_read << " waiting for semaphore at " << (uint32_t)receiver_read_sem_addr << "\n"; - noc_semaphore_wait(receiver_read_semaphore_addr_ptr, 1); - DPRINT << " rwr: got semaphore signal from sender erisc\n"; - noc_semaphore_set(receiver_read_semaphore_addr_ptr, 0); - // Read page by page so that writer can be kicked off instead of being blocked waiting for full chunk to be read - // Look into perf/optimizations for this - DPRINT << " rwr: fetch chunk\n"; - fetch_chunk( - num_pages_per_read_chunk, - total_pages_to_read, - num_pages_read, - cb_id_in0, - page_size, - eth_receiver_l1_base_noc_addr); - DPRINT << " rwr: increment semaphore on eth core at address " << eth_receiver_l1_sem_addr << "\n"; - noc_semaphore_inc(eth_receiver_l1_semaphore_noc_addr, 1); - } - -} diff --git a/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/erisc_datamover_receiver_worker_sender.cpp b/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/erisc_datamover_receiver_worker_sender.cpp deleted file mode 100644 index ab0f6a92f5d..00000000000 --- a/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/erisc_datamover_receiver_worker_sender.cpp +++ /dev/null @@ -1,43 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
-// -// SPDX-License-Identifier: Apache-2.0 - -#include -#include "dataflow_api.h" -#include "debug/dprint.h" - -void kernel_main() { - const uint32_t dst_addr = get_arg_val(0); - constexpr bool dst_is_dram = get_compile_time_arg_val(0) == 1; - constexpr uint32_t num_pages_total = get_compile_time_arg_val(1); - constexpr uint32_t page_size = get_compile_time_arg_val(2); - - constexpr uint32_t cb_id_in0 = tt::CB::c_in0; - InterleavedAddrGen dest_addr_generator = { - .bank_base_address = dst_addr, .page_size = page_size}; - DPRINT << " rws: args: " << - "\n\tdst_addr="<(l1_read_addr) << "\n"; - uint64_t dst_noc_addr = get_noc_addr(p, dest_addr_generator); - noc_async_write(l1_read_addr, dst_noc_addr, page_size); - DPRINT << "rws: write barrier complete\n"; - noc_async_write_barrier(); - DPRINT << "rws: cb_pop_front\n"; - cb_pop_front(cb_id_in0, 1); - } - - // DPRINT << "rws: DONE\n"; - // ncrisc_noc_full_sync(); - // DPRINT << "rws: DONE DONE\n"; -} diff --git a/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/erisc_datamover_sender_worker_sender.cpp b/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/erisc_datamover_sender_worker_sender.cpp deleted file mode 100644 index de99988c729..00000000000 --- a/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/erisc_datamover_sender_worker_sender.cpp +++ /dev/null @@ -1,93 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
-// -// SPDX-License-Identifier: Apache-2.0 - -#include - -#include "dataflow_api.h" -#include "debug/dprint.h" -#include "noc_nonblocking_api.h" -#include "noc_parameters.h" - -FORCE_INLINE void send_chunk( - const uint32_t max_pages_per_chunk, - const uint32_t total_pages_to_send, - uint32_t& num_pages_sent, - const uint32_t& cb_id, - const uint32_t& page_size, - uint64_t remote_l1_write_addr, - volatile tt_l1_ptr uint32_t* writer_send_semaphore_addr_ptr - ) { - - const uint32_t num_pages_this_chunk = std::min(total_pages_to_send - num_pages_sent, max_pages_per_chunk); - for (uint32_t i = 0; i < num_pages_this_chunk; ++i) { - cb_wait_front(cb_id, 1); - uint32_t l1_read_addr = get_read_ptr(cb_id); - noc_async_write(l1_read_addr, remote_l1_write_addr, page_size); - remote_l1_write_addr += page_size; - noc_async_write_barrier(); - cb_pop_front(cb_id, 1); - } - num_pages_sent += num_pages_this_chunk; -} - -// Worker core - Data Movement Writer -> Sends to Erisc Data Mover (sender side). -// -> takes input from local cb and pushes to erisc L1 -void kernel_main() { - const uint32_t eth_l1_base_addr = get_arg_val(0); - // erisc l1 semaphore address - const uint32_t eth_sender_l1_sem_addr = get_arg_val(1); - const uint32_t writer_send_sem_addr = get_semaphore(get_arg_val(2)); - const uint32_t eth_sender_noc_x = get_arg_val(3); - const uint32_t eth_sender_noc_y = get_arg_val(4); - - constexpr uint32_t num_pages_per_send = get_compile_time_arg_val(0); - constexpr uint32_t total_pages_to_send = get_compile_time_arg_val(1); - constexpr uint32_t page_size = get_compile_time_arg_val(2); - - DPRINT << " sws: args:" << - "\n\teth_sender_l1_base_addr="<(writer_send_sem_addr); - - // This is different per writer core - const uint64_t eth_l1_sender_base_noc_addr = - get_noc_addr(eth_sender_noc_x, eth_sender_noc_y, eth_l1_base_addr); - // Used to signal eth sender that data is available. 
This is different per writer core - const uint64_t eth_l1_sender_semaphore_addr = - get_noc_addr(eth_sender_noc_x, eth_sender_noc_y, eth_sender_l1_sem_addr); - - // num_transfers = num_devices - 1 - uint32_t num_pages_sent = 0; - DPRINT << " sws: noc_index " << (uint32_t)noc_index << "\n"; - DPRINT << " sws: my_x[0],my_y[0] " << (uint32_t)my_x[0] << "," << (uint32_t)my_y[0] << "\n"; - DPRINT << " sws: my_x[1],my_y[1] " << (uint32_t)my_x[1] << "," << (uint32_t)my_y[1] << "\n"; - - uint32_t old_val_NIU_SLV_CMD_ACCEPTED = NOC_STATUS_READ_REG(noc_index, NIU_SLV_REQ_ACCEPTED); - uint32_t old_val_NIU_SLV_ATOMIC_RESP_SENT = NOC_STATUS_READ_REG(noc_index, NIU_SLV_ATOMIC_RESP_SENT); - uint32_t old_val_NIU_SLV_POSTED_ATOMIC_RECEIVED = NOC_STATUS_READ_REG(noc_index, NIU_SLV_POSTED_ATOMIC_RECEIVED); - uint32_t old_val_NIU_SLV_NONPOSTED_ATOMIC_SENT = NOC_STATUS_READ_REG(noc_index, NIU_SLV_NONPOSTED_ATOMIC_RECEIVED); - - bool diffed_NIU_SLV_CMD_ACCEPTED = false; - bool diffed_NIU_SLV_ATOMIC_RESP_SENT = false; - bool diffed_NIU_SLV_POSTED_ATOMIC_RECEIVED = false; - bool diffed_NIU_SLV_NONPOSTED_ATOMIC_SENT = false; - while (num_pages_sent < total_pages_to_send) { - noc_semaphore_wait(writer_send_semaphore_addr_ptr, 1); - - noc_semaphore_set(writer_send_semaphore_addr_ptr, 0); - send_chunk(num_pages_per_send, total_pages_to_send, num_pages_sent, cb_id_in0, page_size, eth_l1_sender_base_noc_addr, writer_send_semaphore_addr_ptr); - noc_semaphore_inc(eth_l1_sender_semaphore_addr, 1); - } -} diff --git a/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/ethernet_bidirectional_ubench.cpp b/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/ethernet_bidirectional_ubench.cpp index e0936a79645..b2ec93643a7 100644 --- a/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/ethernet_bidirectional_ubench.cpp +++ b/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/ethernet_bidirectional_ubench.cpp @@ -120,7 +120,6 @@ void kernel_main() { 
channel_addrs[sender_channel], channel_addrs[sender_channel], message_size_payload, - sender_channel, message_size_payload, message_size_payload_eth_words + 1); ready_to_send_payload &= ~(1 << s_i); diff --git a/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/ethernet_ping_latency_ubench_receiver.cpp b/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/ethernet_ping_latency_ubench_receiver.cpp index 6777713fc3c..2328352269f 100644 --- a/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/ethernet_ping_latency_ubench_receiver.cpp +++ b/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/ethernet_ping_latency_ubench_receiver.cpp @@ -46,7 +46,6 @@ FORCE_INLINE void run_loop_iteration( reinterpret_cast(channel_sync_addrs[i]), reinterpret_cast(channel_sync_addrs[i]), sizeof(eth_channel_sync_t), - i, // remove this field - it's superfluous sizeof(eth_channel_sync_t), sizeof(eth_channel_sync_t) >> 4); } @@ -67,7 +66,6 @@ FORCE_INLINE void run_loop_iteration( reinterpret_cast(channel_sync_addrs[i]), reinterpret_cast(channel_sync_addrs[i]), sizeof(eth_channel_sync_t), - i, // remove this field - it's superfluous sizeof(eth_channel_sync_t), sizeof(eth_channel_sync_t) >> 4); } diff --git a/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/ethernet_ping_latency_ubench_sender.cpp b/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/ethernet_ping_latency_ubench_sender.cpp index 72aa3da064c..ed9923692a5 100644 --- a/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/ethernet_ping_latency_ubench_sender.cpp +++ b/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/ethernet_ping_latency_ubench_sender.cpp @@ -40,7 +40,6 @@ FORCE_INLINE void run_loop_iteration( channel_addrs[i], channel_addrs[i], full_payload_size, - i, full_payload_size, full_payload_size_eth_words); } @@ -60,7 +59,6 @@ FORCE_INLINE void run_loop_iteration( channel_addrs[i], channel_addrs[i], full_payload_size, - i, 
full_payload_size, full_payload_size_eth_words); } diff --git a/tests/ttnn/unit_tests/gtests/CMakeLists.txt b/tests/ttnn/unit_tests/gtests/CMakeLists.txt index 5e5be0e16bd..070659e6431 100644 --- a/tests/ttnn/unit_tests/gtests/CMakeLists.txt +++ b/tests/ttnn/unit_tests/gtests/CMakeLists.txt @@ -8,8 +8,12 @@ set(TTNN_UNIT_TESTS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/test_reflect.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test_to_and_from_json.cpp ) +set(TTNN_CCL_UNIT_TESTS_SRC + ${CMAKE_CURRENT_SOURCE_DIR}/ccl/test_erisc_data_mover_with_workers.cpp +) add_executable(unit_tests_ttnn ${TTNN_UNIT_TESTS_SRC}) +add_executable(unit_tests_ttnn_ccl ${TTNN_CCL_UNIT_TESTS_SRC}) add_executable(test_multi_device ${CMAKE_CURRENT_SOURCE_DIR}/test_multi_device.cpp) add_executable(galaxy_unit_tests_ttnn ${CMAKE_CURRENT_SOURCE_DIR}/test_ccl_on_tg.cpp) @@ -28,5 +32,6 @@ endfunction() # Set up properties for both targets setup_ttnn_test_target(unit_tests_ttnn) +setup_ttnn_test_target(unit_tests_ttnn_ccl) setup_ttnn_test_target(test_multi_device) setup_ttnn_test_target(galaxy_unit_tests_ttnn) diff --git a/tests/ttnn/unit_tests/gtests/ccl/kernels/erisc_datamover_receiver_worker_reader.cpp b/tests/ttnn/unit_tests/gtests/ccl/kernels/erisc_datamover_receiver_worker_reader.cpp new file mode 100644 index 00000000000..481457ac8cd --- /dev/null +++ b/tests/ttnn/unit_tests/gtests/ccl/kernels/erisc_datamover_receiver_worker_reader.cpp @@ -0,0 +1,44 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include + +#include "dataflow_api.h" +#include "ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_adapters.hpp" + +void kernel_main() { + constexpr uint32_t eth_receiver_l1_base_addr = get_compile_time_arg_val(0); + constexpr uint32_t eth_receiver_l1_sem_addr = get_compile_time_arg_val(1); + constexpr uint32_t num_buffers_per_channel = get_compile_time_arg_val(2); + constexpr ttnn::ccl::EriscDataMoverTerminationMode termination_mode = static_cast(get_compile_time_arg_val(3)); + const uint32_t num_pages_per_read_chunk = get_arg_val(0); + const uint32_t total_pages_to_read = get_arg_val(1); + const uint32_t page_size = get_arg_val(2); + const uint32_t receiver_erisc_datamover_noc_x = get_arg_val(3); + const uint32_t receiver_erisc_datamover_noc_y = get_arg_val(4); + // Worker local L1 semaphore that erisc datamover signals to + volatile uint32_t* const receiver_read_sem_addr = reinterpret_cast(get_semaphore(get_arg_val(5))); + const uint32_t num_buffers_per_edm_channel = get_arg_val(6); + + ccl::edm::WorkerToEdmReader reader( + ttnn::ccl::WorkerXY(receiver_erisc_datamover_noc_x, receiver_erisc_datamover_noc_y), + eth_receiver_l1_base_addr, + num_buffers_per_channel, + eth_receiver_l1_sem_addr, + num_pages_per_read_chunk * page_size, + receiver_read_sem_addr); + + constexpr uint32_t cb_id_in0 = tt::CB::c_in0; + + for (uint32_t i = 0; i < total_pages_to_read; i += num_pages_per_read_chunk) { + bool last_message = (i + num_pages_per_read_chunk) >= total_pages_to_read; + uint32_t num_pages_to_read = std::min(total_pages_to_read - i, num_pages_per_read_chunk); + reader.wait_for_payload_available(); + reader.fetch_payload_blocking(cb_id_in0, num_pages_to_read, page_size, last_message); + } + + reader.close(); +} diff --git a/tests/ttnn/unit_tests/gtests/ccl/kernels/erisc_datamover_receiver_worker_sender.cpp 
b/tests/ttnn/unit_tests/gtests/ccl/kernels/erisc_datamover_receiver_worker_sender.cpp new file mode 100644 index 00000000000..fecb458407b --- /dev/null +++ b/tests/ttnn/unit_tests/gtests/ccl/kernels/erisc_datamover_receiver_worker_sender.cpp @@ -0,0 +1,34 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include "dataflow_api.h" + +void kernel_main() { + const uint32_t dst_addr = get_arg_val(0); + constexpr bool dst_is_dram = get_compile_time_arg_val(0) == 1; + constexpr uint32_t num_pages_total = get_compile_time_arg_val(1); + constexpr uint32_t page_size = get_compile_time_arg_val(2); + constexpr uint32_t pages_per_edm_buffer = get_compile_time_arg_val(3); + + constexpr uint32_t cb_id_in0 = tt::CB::c_in0; + InterleavedAddrGen dest_addr_generator = { + .bank_base_address = dst_addr, .page_size = page_size}; + + for (uint32_t p = 0; p < num_pages_total; p += pages_per_edm_buffer) { + uint32_t num_pages_to_send = std::min(pages_per_edm_buffer, num_pages_total - p); + cb_wait_front(cb_id_in0, num_pages_to_send); + uint32_t l1_read_addr = get_read_ptr(cb_id_in0); + + for (uint32_t i = 0; i < num_pages_to_send; ++i) { + uint64_t dst_noc_addr = get_noc_addr(p + i, dest_addr_generator); + noc_async_write(l1_read_addr, dst_noc_addr, page_size); + l1_read_addr += page_size; + } + noc_async_write_barrier(); + + cb_pop_front(cb_id_in0, num_pages_to_send); + } + +} diff --git a/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/erisc_datamover_sender_worker_reader.cpp b/tests/ttnn/unit_tests/gtests/ccl/kernels/erisc_datamover_sender_worker_reader.cpp similarity index 62% rename from tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/erisc_datamover_sender_worker_reader.cpp rename to tests/ttnn/unit_tests/gtests/ccl/kernels/erisc_datamover_sender_worker_reader.cpp index fc875d64f87..41d453e2793 100644 --- 
a/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/erisc_datamover_sender_worker_reader.cpp +++ b/tests/ttnn/unit_tests/gtests/ccl/kernels/erisc_datamover_sender_worker_reader.cpp @@ -7,11 +7,14 @@ #include "debug/dprint.h" void kernel_main() { - const uint32_t src_addr = get_arg_val(0); constexpr bool src_is_dram = get_compile_time_arg_val(0) == 1; constexpr uint32_t num_pages_to_read_total = get_compile_time_arg_val(1); constexpr uint32_t page_size = get_compile_time_arg_val(2); + constexpr uint32_t pages_per_edm_buffer = get_compile_time_arg_val(3); constexpr uint32_t cb_id_in0 = tt::CB::c_in0; + + const uint32_t src_addr = get_arg_val(0); + const InterleavedAddrGen source_address_generator = { .bank_base_address = src_addr, .page_size = page_size}; @@ -21,19 +24,22 @@ void kernel_main() { "\n\tnum_pages_to_read_total="<(pages_per_edm_buffer, num_pages_to_read_total - num_pages_read); + cb_reserve_back(cb_id_in0, pages_to_read); uint32_t local_l1_read_addr = get_write_ptr(cb_id_in0); - uint64_t src_noc_addr = get_noc_addr(num_pages_read, source_address_generator); - DPRINT << "swr: async_read\n"; - noc_async_read(src_noc_addr, local_l1_read_addr, page_size); - DPRINT << "swr: read_barrier\n"; + for (uint32_t p = 0; p < pages_to_read; ++p) { + + uint64_t src_noc_addr = get_noc_addr(num_pages_read + p, source_address_generator); + noc_async_read(src_noc_addr, local_l1_read_addr, page_size); + local_l1_read_addr += page_size; + } noc_async_read_barrier(); - DPRINT << "swr: cb_push_back\n"; - cb_push_back(cb_id_in0, 1); + cb_push_back(cb_id_in0, pages_to_read); + // DPRINT << "SR " << num_pages_read << "\n"; } + DPRINT << "SR DONE\n"; } diff --git a/tests/ttnn/unit_tests/gtests/ccl/kernels/erisc_datamover_sender_worker_sender.cpp b/tests/ttnn/unit_tests/gtests/ccl/kernels/erisc_datamover_sender_worker_sender.cpp new file mode 100644 index 00000000000..4cff4c2ec51 --- /dev/null +++ 
b/tests/ttnn/unit_tests/gtests/ccl/kernels/erisc_datamover_sender_worker_sender.cpp @@ -0,0 +1,59 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include + +#include "dataflow_api.h" +#include "ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_adapters.hpp" + + + +// Worker core - Data Movement Writer -> Sends to Erisc Data Mover (sender side). +// -> takes input from local cb and pushes to erisc L1 +void kernel_main() { + const uint32_t eth_l1_base_addr = get_arg_val(0); + // erisc l1 semaphore address + const uint32_t eth_sender_l1_sem_addr = get_arg_val(1); + volatile uint32_t* const writer_send_sem_addr = reinterpret_cast(get_semaphore(get_arg_val(2))); + const uint32_t eth_sender_noc_x = get_arg_val(3); + const uint32_t eth_sender_noc_y = get_arg_val(4); + const uint32_t num_buffers_per_edm_channel = get_arg_val(5); + + constexpr uint32_t num_pages_per_send = get_compile_time_arg_val(0); + constexpr uint32_t total_pages_to_send = get_compile_time_arg_val(1); + constexpr uint32_t page_size = get_compile_time_arg_val(2); + constexpr uint32_t num_buffers_per_channel = get_compile_time_arg_val(3); + constexpr ttnn::ccl::EriscDataMoverTerminationMode termination_mode = static_cast(get_compile_time_arg_val(4)); + + ccl::edm::WorkerToEdmSender sender( + ttnn::ccl::WorkerXY(eth_sender_noc_x, eth_sender_noc_y), + eth_l1_base_addr, + num_buffers_per_channel, + eth_sender_l1_sem_addr, + num_pages_per_send * page_size, + writer_send_sem_addr); + + std::array eth_buffer_addresses; + for (uint32_t i = 0; i < num_buffers_per_channel; i++) { + eth_buffer_addresses[i] = get_noc_addr( + eth_sender_noc_x, + eth_sender_noc_y, + eth_l1_base_addr + (i * ((num_pages_per_send * page_size) + 16)));//sizeof(eth_channel_sync_t)))); + } + + + constexpr uint32_t cb_id_in0 = tt::CB::c_in0; + + + uint32_t buffer_index = 0; + for (uint32_t p = 0; p 
< total_pages_to_send; p += num_pages_per_send) { + uint32_t num_pages_to_send = std::min(num_pages_per_send, total_pages_to_send - p); + sender.wait_for_empty_write_slot(); + sender.send_payload_blocking(cb_id_in0, num_pages_to_send, page_size); + } + + sender.close(); +} diff --git a/tests/ttnn/unit_tests/gtests/ccl/test_erisc_data_mover_with_workers.cpp b/tests/ttnn/unit_tests/gtests/ccl/test_erisc_data_mover_with_workers.cpp new file mode 100644 index 00000000000..f78c198258c --- /dev/null +++ b/tests/ttnn/unit_tests/gtests/ccl/test_erisc_data_mover_with_workers.cpp @@ -0,0 +1,1178 @@ + +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include +#include +#include + +#include "gtest/gtest.h" + +#include "device/tt_arch_types.h" +// #include "tt_backend_api_types.hpp" +#include "tt_metal/common/core_coord.h" +#include "tt_metal/common/math.hpp" +#include "tt_metal/detail/tt_metal.hpp" +#include "tt_metal/host_api.hpp" +#include "tt_metal/impl/kernels/kernel.hpp" +#include "tt_metal/test_utils/comparison.hpp" +#include "tt_metal/test_utils/df/df.hpp" +#include "tt_metal/test_utils/env_vars.hpp" +#include "tt_metal/test_utils/print_helpers.hpp" +#include "tt_metal/test_utils/stimulus.hpp" + +#include "ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp" + +// #include "impl/kernels/kernel_types.hpp" + +using namespace tt; +using namespace tt::test_utils; +using namespace tt::test_utils::df; + +// Taken from ccl_common... some dependency annoyance to deal with so just copying it here for now... 
resolve before merging +namespace ttnn { +namespace ccl { +void set_edm_runtime_args( + tt_metal::Program& program, + KernelHandle edm_kernel_handle, + ccl::EriscDatamoverBuilder const& edm_builder, + CoreCoord const& eth_core +) { + std::vector const& edm_clockwise_kernel_rt_args = edm_builder.emit_runtime_args(); + tt_metal::SetRuntimeArgs(program, edm_kernel_handle, eth_core, edm_clockwise_kernel_rt_args); + + std::stringstream ss; + ss << "EDM ARGS:\n"; + for (auto const& s : edm_clockwise_kernel_rt_args) { + ss << "\t" << s << "\n"; + } + log_info(tt::LogOp, "{}", ss.str()); +} + +} // namespace ccl +} // namespace ttnn + + +class N300TestDevice { + public: + N300TestDevice() : device_open(false) { + arch_ = tt::get_arch_from_string(tt::test_utils::get_env_arch_name()); + + num_devices_ = tt::tt_metal::GetNumAvailableDevices(); + if (arch_ == tt::ARCH::WORMHOLE_B0 and tt::tt_metal::GetNumAvailableDevices() >= 2 and + tt::tt_metal::GetNumPCIeDevices() >= 1) { + std::vector ids(num_devices_,0); + std::iota(ids.begin(), ids.end(), 0); + devices_ = tt::tt_metal::detail::CreateDevices(ids); + + } else { + TT_THROW("This suite can only be run on N300 Wormhole devices"); + } + device_open = true; + } + ~N300TestDevice() { + if (device_open) { + TearDown(); + } + } + + void TearDown() { + device_open = false; + for (auto [device_id, device_ptr] : devices_) { + tt::tt_metal::CloseDevice(device_ptr); + } + } + + std::map devices_; + tt::ARCH arch_; + size_t num_devices_; + + private: + bool device_open; +}; + +struct BankedConfig { + size_t num_pages; + size_t size_bytes; + size_t page_size_bytes; + BufferType input_buffer_type; // = BufferType::L1; + BufferType output_buffer_type; // = BufferType::L1; + tt::DataFormat l1_data_format; // = tt::DataFormat::Float16_b; +}; + +struct KernelXY { + uint16_t x; + uint16_t y; + + uint32_t to_uint32() const { return y << 16 | x; } +}; + +void generate_receiver_worker_kernels( + Program &program, + Device *device, + CoreCoord 
const& worker_core, + CoreCoord const& edm_core, + ttnn::ccl::EriscDatamoverBuilder::ChannelBufferInterface const& edm_channel, + uint32_t page_size, + uint32_t num_pages, + std::size_t num_buffers_per_edm_channel, + uint32_t num_pages_per_edm_buffer, + uint32_t worker_semaphore_address, + uint32_t dram_output_buffer_base_addr, // remote_output_buffers.at(i)->address(); + bool dest_is_dram, + ttnn::ccl::EriscDataMoverTerminationMode edm_termination_mode +) { + // Just want a dummy DF + uint32_t src0_cb_index = CB::c_in0; + tt::DataFormat df = page_size == 1024 ? tt::DataFormat::Bfp8 : + page_size == 2048 ? tt::DataFormat::Float16 : + tt::DataFormat::Float32; + tt_metal::CircularBufferConfig cb_src0_config = tt_metal::CircularBufferConfig(2 * num_pages_per_edm_buffer * page_size, {{src0_cb_index, df}}) + .set_page_size(src0_cb_index, page_size); + + CBHandle receiver_workers_cb = CreateCircularBuffer(program, worker_core, cb_src0_config); + std::vector receiver_worker_writer_compile_args{ + dest_is_dram, // + num_pages, // + page_size, + num_pages_per_edm_buffer}; + std::vector receiver_worker_writer_runtime_args{dram_output_buffer_base_addr}; + log_info(tt::LogTest, "\tReceiverWriter CT Args"); + for (auto const& arg : receiver_worker_writer_compile_args) { + log_info(tt::LogTest, "\t\t{}", arg); + } + log_info(tt::LogTest, "\tReceiverWriter RT Args"); + for (auto const& arg : receiver_worker_writer_runtime_args) { + log_info(tt::LogTest, "\t\t{}", arg); + } + + + std::vector receiver_worker_receiver_compile_args{ + edm_channel.eth_buffer_l1_address, + edm_channel.eth_semaphore_l1_address, + num_buffers_per_edm_channel, + edm_termination_mode + }; + std::vector receiver_worker_receiver_runtime_args{ + num_pages_per_edm_buffer, + num_pages, + page_size, + (uint32_t)device->ethernet_core_from_logical_core(edm_core).x, + (uint32_t)device->ethernet_core_from_logical_core(edm_core).y, + worker_semaphore_address, + num_buffers_per_edm_channel}; + log_info(tt::LogTest, 
"\tReceiverReader CT Args"); + for (auto const& arg : receiver_worker_receiver_compile_args) { + log_info(tt::LogTest, "\t\t{}", arg); + } + log_info(tt::LogTest, "\tReceiverReader RT Args"); + for (auto const& arg : receiver_worker_receiver_runtime_args) { + log_info(tt::LogTest, "\t\t{}", arg); + } + + + auto receiver_worker_receiver_kernel = tt_metal::CreateKernel( + program, + "tests/ttnn/unit_tests/gtests/ccl/kernels/erisc_datamover_receiver_worker_reader.cpp", + worker_core, + tt_metal::DataMovementConfig{ + .processor = tt_metal::DataMovementProcessor::RISCV_0, + .noc = tt_metal::NOC::RISCV_0_default, + .compile_args = receiver_worker_receiver_compile_args}); + auto receiver_worker_writer_kernel = tt_metal::CreateKernel( + program, + "tests/ttnn/unit_tests/gtests/ccl/kernels/erisc_datamover_receiver_worker_sender.cpp", + worker_core, + tt_metal::DataMovementConfig{ + .processor = tt_metal::DataMovementProcessor::RISCV_1, + .noc = tt_metal::NOC::RISCV_1_default, + .compile_args = receiver_worker_writer_compile_args}); + tt_metal::SetRuntimeArgs( + program, + receiver_worker_receiver_kernel, + worker_core, + receiver_worker_receiver_runtime_args); + tt_metal::SetRuntimeArgs( + program, + receiver_worker_writer_kernel, + worker_core, + receiver_worker_writer_runtime_args); +} + +void generate_sender_worker_kernels( + Program &program, + Device *device, + CoreCoord const& worker_core, + CoreCoord const& edm_core, + ttnn::ccl::EriscDatamoverBuilder::ChannelBufferInterface const& edm_channel, + uint32_t page_size, + uint32_t num_pages_total, + std::size_t num_buffers_per_edm_channel, + uint32_t num_pages_per_edm_buffer, + uint32_t worker_semaphore_address, + uint32_t dram_output_buffer_base_addr, // remote_output_buffers.at(i)->address(); + bool src_is_dram, + ttnn::ccl::EriscDataMoverTerminationMode edm_termination_mode +) { + std::vector sender_worker_reader_compile_args{ + src_is_dram, // + num_pages_total, // + page_size, + num_pages_per_edm_buffer}; + 
std::vector sender_worker_reader_runtime_args{dram_output_buffer_base_addr}; + + log_info(tt::LogTest, "\tSenderReader CT Args"); + for (auto const& arg : sender_worker_reader_compile_args) { + log_info(tt::LogTest, "\t\t{}", arg); + } + log_info(tt::LogTest, "\tSenderReader RT Args"); + for (auto const& arg : sender_worker_reader_runtime_args) { + log_info(tt::LogTest, "\t\t{}", arg); + } + + std::vector sender_worker_writer_compile_args{ + num_pages_per_edm_buffer, + num_pages_total, + page_size, + num_buffers_per_edm_channel, + edm_termination_mode + }; + std::vector sender_worker_writer_runtime_args{ + edm_channel.eth_buffer_l1_address, + edm_channel.eth_semaphore_l1_address, + worker_semaphore_address, + (uint32_t)device->ethernet_core_from_logical_core(edm_core).x, + (uint32_t)device->ethernet_core_from_logical_core(edm_core).y, + num_buffers_per_edm_channel + }; + uint32_t src0_cb_index = CB::c_in0; + log_info(tt::LogTest, "\tSenderWriter CT Args"); + for (auto const& arg : sender_worker_writer_compile_args) { + log_info(tt::LogTest, "\t\t{}", arg); + } + log_info(tt::LogTest, "\tSenderWriter RT Args"); + for (auto const& arg : sender_worker_writer_runtime_args) { + log_info(tt::LogTest, "\t\t{}", arg); + } + // Just want a dummy DF + tt::DataFormat df = page_size == 1024 ? tt::DataFormat::Bfp8 : + page_size == 2048 ? 
tt::DataFormat::Float16 : + tt::DataFormat::Float32; + tt_metal::CircularBufferConfig cb_src0_config = tt_metal::CircularBufferConfig(2 * num_pages_per_edm_buffer * page_size, {{src0_cb_index, df}}) + .set_page_size(src0_cb_index, page_size); + CBHandle sender_workers_cb = CreateCircularBuffer(program, worker_core, cb_src0_config); + auto sender_worker_reader_kernel = tt_metal::CreateKernel( + program, + "tests/ttnn/unit_tests/gtests/ccl/kernels/erisc_datamover_sender_worker_reader.cpp", + worker_core, + tt_metal::DataMovementConfig{ + .processor = tt_metal::DataMovementProcessor::RISCV_0, + .noc = tt_metal::NOC::RISCV_0_default, + .compile_args = sender_worker_reader_compile_args}); + auto sender_worker_writer_kernel = tt_metal::CreateKernel( + program, + "tests/ttnn/unit_tests/gtests/ccl/kernels/erisc_datamover_sender_worker_sender.cpp", + worker_core, + tt_metal::DataMovementConfig{ + .processor = tt_metal::DataMovementProcessor::RISCV_1, + .noc = tt_metal::NOC::RISCV_1_default, + .compile_args = sender_worker_writer_compile_args}); + tt_metal::SetRuntimeArgs( + program, + sender_worker_reader_kernel, + worker_core, + sender_worker_reader_runtime_args); + tt_metal::SetRuntimeArgs( + program, + sender_worker_writer_kernel, + worker_core, + sender_worker_writer_runtime_args); +} + + + +bool RunWriteBWTest( + tt_metal::Device* sender_device, + tt_metal::Device* receiver_device, + + const CoreCoord& eth_sender_core, + const CoreCoord& eth_receiver_core, + + const uint32_t num_local_sender_channels, + const uint32_t num_remote_sender_channels, + + // default is 1. + // 2 means channel is double buffered + // 3 means channel is triple buffered + // ... 
and so on + std::size_t num_buffers_per_edm_channel, + + const uint32_t page_size, + const uint32_t num_pages_total, + bool src_is_dram, + bool dest_is_dram, + + ttnn::ccl::EriscDataMoverTerminationMode edm_termination_mode +) { + + std::size_t tensor_size_bytes = num_pages_total * page_size; + + tt_metal::Program sender_program{}; + tt_metal::Program receiver_program{}; + + std::vector worker_cores; + { + std::size_t row = 0; + std::size_t col = 0; + for (uint32_t i = 0; i < num_local_sender_channels + num_remote_sender_channels; i++) { + worker_cores.push_back(CoreCoord(col, row)); + col++; + if (col == 8) { + col = 0; + row++; + } + } + } + + std::vector local_worker_semaphore_addresses; + std::vector remote_worker_semaphore_addresses; + for (auto const& worker_core : worker_cores) { + local_worker_semaphore_addresses.push_back(tt::tt_metal::CreateSemaphore(sender_program, worker_core, 0)); + remote_worker_semaphore_addresses.push_back(tt::tt_metal::CreateSemaphore(receiver_program, worker_core, 0)); + log_info(tt::LogTest, "worker_core=(x={},y={}), local_worker_semaphore_address={}, remote_worker_semaphore_address={}", + worker_core.x, worker_core.y, local_worker_semaphore_addresses.back(), remote_worker_semaphore_addresses.back()); + } + + // Generate inputs + //////////////////////////////////////////////////////////////////////////// + // SETUP THE INPUT CB + //////////////////////////////////////////////////////////////////////////// + auto inputs = generate_uniform_random_vector(0, 100, tensor_size_bytes / sizeof(uint32_t)); + std::iota(inputs.begin(), inputs.end(), 0); + + BankedConfig test_config = BankedConfig{ + .num_pages = num_pages_total, + .size_bytes = tensor_size_bytes, + .page_size_bytes = page_size, + .input_buffer_type = src_is_dram ? BufferType::DRAM : BufferType::L1, + .output_buffer_type = dest_is_dram ? 
BufferType::DRAM : BufferType::L1, + .l1_data_format = tt::DataFormat::Float16_b}; + + auto local_input_buffer = CreateBuffer(InterleavedBufferConfig{ + sender_device, test_config.size_bytes, test_config.page_size_bytes, test_config.input_buffer_type}); + auto remote_input_buffer = CreateBuffer(InterleavedBufferConfig{ + receiver_device, test_config.size_bytes, test_config.page_size_bytes, test_config.input_buffer_type}); + bool input_is_dram = test_config.input_buffer_type == BufferType::DRAM; + + tt_metal::detail::WriteToBuffer(local_input_buffer, inputs); + tt_metal::detail::WriteToBuffer(remote_input_buffer, inputs); + + std::vector local_input_buffer_addresses(num_local_sender_channels, local_input_buffer->address()); + std::vector remote_input_buffer_addresses(num_remote_sender_channels, remote_input_buffer->address()); + + //////////////////////////////////////////////////////////////////////////// + // EMPTY INITIALIZE THE OUTPUT CB + //////////////////////////////////////////////////////////////////////////// + + // Clear expected value at ethernet L1 address + std::vector all_zeros(inputs.size(), 0); + + std::vector> local_output_buffers; + std::vector> remote_output_buffers; + + for (std::size_t i = 0; i < num_local_sender_channels; i++) { + auto output_buffer = CreateBuffer(InterleavedBufferConfig{ + receiver_device, test_config.size_bytes, test_config.page_size_bytes, test_config.output_buffer_type}); + remote_output_buffers.push_back(output_buffer); + } + for (std::size_t i = 0; i < num_remote_sender_channels; i++) { + auto output_buffer = CreateBuffer(InterleavedBufferConfig{ + sender_device, test_config.size_bytes, test_config.page_size_bytes, test_config.output_buffer_type}); + local_output_buffers.push_back(output_buffer); + } + + bool output_is_dram = test_config.output_buffer_type == BufferType::DRAM; + for (auto buffer_id : local_output_buffers) { + tt_metal::detail::WriteToBuffer(buffer_id, all_zeros); + } + for (auto buffer_id : 
remote_output_buffers) { + tt_metal::detail::WriteToBuffer(buffer_id, all_zeros); + } + + uint32_t erisc_handshake_address = eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE; + + uint32_t chip0_next_buffer_address = erisc_handshake_address + 16; + std::vector chip0_edm_args = {erisc_handshake_address}; + uint32_t chip0_sender_channels_offset = 0; + uint32_t chip0_arg_sender_num_channels = 1; + + //////////////////////////////////////////////////////////////////////////// + // EDM Builder Setup + //////////////////////////////////////////////////////////////////////////// + + ttnn::ccl::EriscDataMoverBufferSharingMode buffer_sharing_mode = ttnn::ccl::EriscDataMoverBufferSharingMode::NOT_SHARED; + + const std::size_t num_edm_channels = num_local_sender_channels + num_remote_sender_channels; + // TODO: Allow an override of EDM buffer size + auto local_chip_edm_builder = ttnn::ccl::create_erisc_datamover_builder( + num_edm_channels, page_size, num_buffers_per_edm_channel, buffer_sharing_mode, edm_termination_mode); + auto remote_chip_edm_builder = ttnn::ccl::create_erisc_datamover_builder( + num_edm_channels, page_size, num_buffers_per_edm_channel, buffer_sharing_mode, edm_termination_mode); + + const uint32_t num_bytes_per_send = local_chip_edm_builder.get_eth_buffer_size_bytes(); + const uint32_t pages_per_send = num_bytes_per_send / page_size; + TT_ASSERT(num_bytes_per_send > 0); + TT_ASSERT(num_bytes_per_send >= page_size); + TT_ASSERT(pages_per_send > 0); + const uint32_t num_messages_to_send = (((num_pages_total * page_size) - 1) / num_bytes_per_send) + 1; + log_info(tt::LogTest, "num_bytes_per_send={}", num_bytes_per_send); + log_info(tt::LogTest, "page_size={}", page_size); + log_info(tt::LogTest, "pages_per_send={}", pages_per_send); + log_info(tt::LogTest, "num_messages_to_send={}", num_messages_to_send); + std::vector num_messages_to_send_over_channel(num_edm_channels, num_messages_to_send); + + std::vector local_sender_workers; + std::vector
remote_receiver_workers; + std::vector remote_sender_workers; + std::vector local_receiver_workers; + + // setup edm channels + std::vector local_edm_channels; + std::vector remote_edm_channels; + for (uint32_t i = 0; i < num_local_sender_channels; i++) { + auto const& worker_core_local_chip = ttnn::ccl::WorkerXY( + sender_device->worker_core_from_logical_core(worker_cores.at(i)).x, + sender_device->worker_core_from_logical_core(worker_cores.at(i)).y); + auto const& worker_core_remote_chip = ttnn::ccl::WorkerXY( + receiver_device->worker_core_from_logical_core(worker_cores.at(i)).x, + receiver_device->worker_core_from_logical_core(worker_cores.at(i)).y); + ttnn::ccl::EriscDatamoverBuilder::ChannelBufferInterface const& local_sender_channel_buffer = local_chip_edm_builder.add_sender_channel( + local_worker_semaphore_addresses.at(i), + num_messages_to_send_over_channel.at(i), + {worker_core_local_chip}); + local_edm_channels.push_back(local_sender_channel_buffer); + ttnn::ccl::EriscDatamoverBuilder::ChannelBufferInterface const& remote_receiver_channel_buffer = remote_chip_edm_builder.add_receiver_channel( + remote_worker_semaphore_addresses.at(i), + num_messages_to_send_over_channel.at(i), + {worker_core_remote_chip}); + remote_edm_channels.push_back(remote_receiver_channel_buffer); + } + for (uint32_t i = num_local_sender_channels; i < num_local_sender_channels + num_remote_sender_channels; i++) { + auto const& worker_core_remote_chip = ttnn::ccl::WorkerXY( + receiver_device->worker_core_from_logical_core(worker_cores.at(i)).x, + receiver_device->worker_core_from_logical_core(worker_cores.at(i)).y); + auto const& worker_core_local_chip = ttnn::ccl::WorkerXY( + sender_device->worker_core_from_logical_core(worker_cores.at(i)).x, + sender_device->worker_core_from_logical_core(worker_cores.at(i)).y); + ttnn::ccl::EriscDatamoverBuilder::ChannelBufferInterface const& local_receiver_channel_buffer = local_chip_edm_builder.add_receiver_channel( + 
local_worker_semaphore_addresses.at(i), + num_messages_to_send_over_channel.at(i), + {worker_core_remote_chip}); + local_edm_channels.push_back(local_receiver_channel_buffer); + ttnn::ccl::EriscDatamoverBuilder::ChannelBufferInterface const& remote_sender_channel_buffer = remote_chip_edm_builder.add_sender_channel( + remote_worker_semaphore_addresses.at(i), + num_messages_to_send_over_channel.at(i), + {worker_core_local_chip}); + remote_edm_channels.push_back(remote_sender_channel_buffer); + } + + //////////////////////////////////////////////////////////////////////////// + // Build Workers + //////////////////////////////////////////////////////////////////////////// + log_info(tt::LogTest, "Generating local_sender -> remote_receiver workers"); + for (uint32_t i = 0; i < num_local_sender_channels; i++) { + auto const& worker_core = worker_cores.at(i); + log_info(tt::LogTest, "Worker {}. On Core x={},y={}", i, worker_core.x, worker_core.y); + generate_sender_worker_kernels( + sender_program, + sender_device, + worker_core, + eth_sender_core, + local_edm_channels.at(i), + page_size, + num_pages_total, + num_buffers_per_edm_channel, + pages_per_send, + local_worker_semaphore_addresses.at(i), + local_input_buffer_addresses.at(i), + src_is_dram, + edm_termination_mode + ); + generate_receiver_worker_kernels( + receiver_program, + receiver_device, + worker_core, + eth_receiver_core, + remote_edm_channels.at(i), + page_size, + num_pages_total, + num_buffers_per_edm_channel, + pages_per_send, + remote_worker_semaphore_addresses.at(i), + remote_output_buffers.at(i)->address(), + dest_is_dram, + edm_termination_mode + ); + + } + log_info(tt::LogTest, "Generating remote_sender -> local_receiver workers"); + for (uint32_t i = 0; i < num_remote_sender_channels; i++) { + log_info(tt::LogTest, "Worker {}", i); + auto const& worker_core = worker_cores.at(i + num_local_sender_channels); + generate_sender_worker_kernels( + receiver_program, + receiver_device, + worker_core, + 
eth_receiver_core, + remote_edm_channels.at(i + num_local_sender_channels), + page_size, + num_pages_total, + num_buffers_per_edm_channel, + pages_per_send, + remote_worker_semaphore_addresses.at(i + num_local_sender_channels), + remote_input_buffer_addresses.at(i), + src_is_dram, + edm_termination_mode + ); + + generate_receiver_worker_kernels( + sender_program, + sender_device, + worker_core, + eth_sender_core, + local_edm_channels.at(i + num_local_sender_channels), + page_size, + num_pages_total, + num_buffers_per_edm_channel, + pages_per_send, + local_worker_semaphore_addresses.at(i + num_local_sender_channels), + local_output_buffers.at(i)->address(), + dest_is_dram, + edm_termination_mode + ); + } + + //////////////////////////////////////////////////////////////////////////// + // Build EDMs + //////////////////////////////////////////////////////////////////////////// + auto local_edm_kernel = ttnn::ccl::generate_edm_kernel( + sender_program, + sender_device, + local_chip_edm_builder, + eth_sender_core, + NOC::NOC_0); + set_edm_runtime_args( + sender_program, + local_edm_kernel, + local_chip_edm_builder, + eth_sender_core + ); + + auto remote_edm_kernel = ttnn::ccl::generate_edm_kernel( + receiver_program, + receiver_device, + remote_chip_edm_builder, + eth_receiver_core, + NOC::NOC_0); + set_edm_runtime_args( + receiver_program, + remote_edm_kernel, + remote_chip_edm_builder, + eth_receiver_core + ); + + //////////////////////////////////////////////////////////////////////////// + // Compile and Execute Application + //////////////////////////////////////////////////////////////////////////// + + try { + tt::tt_metal::detail::CompileProgram(sender_device, sender_program); + tt::tt_metal::detail::CompileProgram(receiver_device, receiver_program); + } catch (std::exception& e) { + log_error("Failed compile: {}", e.what()); + throw; + } + + log_info(tt::LogTest, "Running..."); + + if (std::getenv("TT_METAL_SLOW_DISPATCH_MODE")) { + std::thread th2 =
std::thread([&] { tt_metal::detail::LaunchProgram(sender_device, sender_program); }); + std::thread th1 = std::thread([&] { tt_metal::detail::LaunchProgram(receiver_device, receiver_program); }); + + th2.join(); + th1.join(); + } else { + tt_metal::EnqueueProgram(sender_device->command_queue(), sender_program, false); + tt_metal::EnqueueProgram(receiver_device->command_queue(), receiver_program, false); + + log_debug(tt::LogTest, "Calling Finish"); + tt_metal::Finish(sender_device->command_queue()); + tt_metal::Finish(receiver_device->command_queue()); + } + // tt::tt_metal::detail::DumpDeviceProfileResults(receiver_device); + // tt::tt_metal::detail::DumpDeviceProfileResults(sender_device); + log_info(tt::LogTest, "Reading back outputs"); + + auto is_output_correct = [&all_zeros, &inputs](std::shared_ptr output_buffer) { + constexpr bool debug_mode = false; + std::vector readback_data_vec; // init to 0 data for easier debug + readback_data_vec.reserve(all_zeros.size()); + std::fill(readback_data_vec.begin(), readback_data_vec.end(), 0); + + tt_metal::detail::ReadFromBuffer(output_buffer, readback_data_vec); + log_info(tt::LogTest, "Checking outputs"); + if (readback_data_vec.size() != inputs.size()) { + log_error(tt::LogTest, "Output size mismatch: expected {} got {}", inputs.size(), readback_data_vec.size()); + return false; + } + bool pass = (readback_data_vec == inputs); + TT_ASSERT( + std::any_of(inputs.begin(), inputs.end(), [](uint32_t x) { return x != 0; }), + "Input buffer expected to not be all 0"); + if (not pass) { + log_error("Output mismatch"); + if (debug_mode) { + std::size_t num_printed_mismatches = 0; + for (size_t i = 0; i < readback_data_vec.size() && num_printed_mismatches < 64; i++) { + if (readback_data_vec[i] != inputs[i]) { + log_error("[{}]: expected {} got {}", i, inputs[i], readback_data_vec[i]); + num_printed_mismatches++; + } + } + log_error("... 
(remaining mismatches omitted)"); + } + } + return pass; + }; + + bool pass = true; + constexpr bool enable_check = true; + if constexpr(enable_check) { + for (auto const& output_buffer : local_output_buffers) { + pass &= is_output_correct(output_buffer); + } + for (auto const& output_buffer : remote_output_buffers) { + pass &= is_output_correct(output_buffer); + } + } + + + return pass; +} + +int TestEntrypoint( + const uint32_t num_local_sender_channels, + const uint32_t num_remote_sender_channels, + // default is 1. + // 2 means channel is double buffered + // 3 means channel is triple buffered + // ... and so on + std::size_t num_buffers_per_edm_channel, + const uint32_t page_size, + const uint32_t num_pages_total, + const bool src_is_dram, + const bool dest_is_dram, + ttnn::ccl::EriscDataMoverTerminationMode termination_mode +) { + // argv[0]: program + // argv[1]: buffer_size_bytes + // argv[2]: num_loops + + auto arch = tt::get_arch_from_string(tt::test_utils::get_env_arch_name()); + auto num_devices = tt::tt_metal::GetNumAvailableDevices(); + if (num_devices < 2) { + log_info("This test can only be run on n300 devices"); + return 0; + } + if (arch == tt::ARCH::GRAYSKULL) { + log_info("Test must be run on WH"); + return 0; + } + + N300TestDevice test_fixture; + + const auto& device_0 = test_fixture.devices_.at(0); + + auto const& active_eth_cores = device_0->get_active_ethernet_cores(true); + auto eth_sender_core_iter = active_eth_cores.begin(); + auto eth_sender_core_iter_end = active_eth_cores.end(); + chip_id_t device_id = std::numeric_limits::max(); + tt_xy_pair eth_receiver_core; + bool initialized = false; + tt_xy_pair eth_sender_core; + do { + TT_FATAL(eth_sender_core_iter != eth_sender_core_iter_end); + std::tie(device_id, eth_receiver_core) = device_0->get_connected_ethernet_core(*eth_sender_core_iter); + eth_sender_core = *eth_sender_core_iter; + eth_sender_core_iter++; + } while (device_id != 1); + TT_ASSERT(device_id == 1); + const auto& device_1 
= test_fixture.devices_.at(device_id); + + bool success = false; + try { + success = RunWriteBWTest( + device_0, + device_1, + + eth_sender_core, + eth_receiver_core, + + num_local_sender_channels, // from args + num_remote_sender_channels, // from args + num_buffers_per_edm_channel, // from args + + page_size, + num_pages_total, + src_is_dram, + dest_is_dram, + + termination_mode); + } catch (std::exception& e) { + log_error("Caught exception: {}", e.what()); + test_fixture.TearDown(); + return -1; + } + + test_fixture.TearDown(); + + return success ? 0 : -1; +} + +//////////////////////////////////////////////////////////////////// +/// MESSAGE COUNT TERMINATION MODE +//////////////////////////////////////////////////////////////////// + +TEST(WorkerEdmDatapath, DISABLED_MergedPayloadAndSignal_1ChannelForward_0ChannelsReverse_1BufferPerChannel_2048PageSize_100kPages_MessageCountTermination) { + const uint32_t num_local_sender_channels = 1; + const uint32_t num_remote_sender_channels = 0; + const uint32_t num_buffers_per_edm_channel = 1; + const uint32_t page_size = 2048; + const uint32_t num_pages_total = 100000; + const bool src_is_dram = true; + const bool dest_is_dram = true; + const bool merge_message_and_signal = true; + auto termination_mode = ttnn::ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED; + + auto result = TestEntrypoint( + num_local_sender_channels, + num_remote_sender_channels, + // default is 1. + // 2 means channel is double buffered + // 3 means channel is triple buffered + // ... 
and so on + num_buffers_per_edm_channel, + page_size, + num_pages_total, + src_is_dram, + dest_is_dram, + termination_mode + ); + ASSERT_EQ(result, 0); +} + +TEST(WorkerEdmDatapath, DISABLED_MergedPayloadAndSignal_1ChannelForward_1ChannelsReverse_1BufferPerChannel_2048PageSize_100kPages_MessageCountTermination) { + const uint32_t num_local_sender_channels = 1; + const uint32_t num_remote_sender_channels = 1; + const uint32_t num_buffers_per_edm_channel = 1; + const uint32_t page_size = 2048; + const uint32_t num_pages_total = 100000; + const bool src_is_dram = true; + const bool dest_is_dram = true; + const bool merge_message_and_signal = true; + auto termination_mode = ttnn::ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED; + + auto result = TestEntrypoint( + num_local_sender_channels, + num_remote_sender_channels, + // default is 1. + // 2 means channel is double buffered + // 3 means channel is triple buffered + // ... and so on + num_buffers_per_edm_channel, + page_size, + num_pages_total, + src_is_dram, + dest_is_dram, + termination_mode + ); + ASSERT_EQ(result, 0); +} + +TEST(WorkerEdmDatapath, DISABLED_MergedPayloadAndSignal_0ChannelForward_1ChannelsReverse_2BufferPerChannel_2048PageSize_100kPages_MessageCountTermination) { + const uint32_t num_local_sender_channels = 0; + const uint32_t num_remote_sender_channels = 1; + const uint32_t num_buffers_per_edm_channel = 2; + const uint32_t page_size = 2048; + const uint32_t num_pages_total = 100000; + const bool src_is_dram = true; + const bool dest_is_dram = true; + const bool merge_message_and_signal = true; + auto termination_mode = ttnn::ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED; + + auto result = TestEntrypoint( + num_local_sender_channels, + num_remote_sender_channels, + // default is 1. + // 2 means channel is double buffered + // 3 means channel is triple buffered + // ... 
and so on + num_buffers_per_edm_channel, + page_size, + num_pages_total, + src_is_dram, + dest_is_dram, + termination_mode + ); + ASSERT_EQ(result, 0); +} + +TEST(WorkerEdmDatapath, DISABLED_MergedPayloadAndSignal_1ChannelForward_0ChannelsReverse_2BufferPerChannel_2048PageSize_100kPages_MessageCountTermination) { + const uint32_t num_local_sender_channels = 1; + const uint32_t num_remote_sender_channels = 0; + const uint32_t num_buffers_per_edm_channel = 2; + const uint32_t page_size = 2048; + const uint32_t num_pages_total = 100000; + const bool src_is_dram = true; + const bool dest_is_dram = true; + const bool merge_message_and_signal = true; + auto termination_mode = ttnn::ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED; + + auto result = TestEntrypoint( + num_local_sender_channels, + num_remote_sender_channels, + // default is 1. + // 2 means channel is double buffered + // 3 means channel is triple buffered + // ... and so on + num_buffers_per_edm_channel, + page_size, + num_pages_total, + src_is_dram, + dest_is_dram, + termination_mode + ); + ASSERT_EQ(result, 0); +} + + +TEST(WorkerEdmDatapath, DISABLED_MergedPayloadAndSignal_1ChannelForward_1ChannelsReverse_2BufferPerChannel_2048PageSize_100kPages_MessageCountTermination) { + const uint32_t num_local_sender_channels = 1; + const uint32_t num_remote_sender_channels = 1; + const uint32_t num_buffers_per_edm_channel = 2; + const uint32_t page_size = 2048; + const uint32_t num_pages_total = 100000; + const bool src_is_dram = true; + const bool dest_is_dram = true; + const bool merge_message_and_signal = true; + auto termination_mode = ttnn::ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED; + + auto result = TestEntrypoint( + num_local_sender_channels, + num_remote_sender_channels, + // default is 1. + // 2 means channel is double buffered + // 3 means channel is triple buffered + // ... 
and so on + num_buffers_per_edm_channel, + page_size, + num_pages_total, + src_is_dram, + dest_is_dram, + termination_mode + ); + ASSERT_EQ(result, 0); +} + + +TEST(WorkerEdmDatapath, DISABLED_MergedPayloadAndSignal_1ChannelForward_0ChannelsReverse_3BufferPerChannel_2048PageSize_100kPages_MessageCountTermination) { + const uint32_t num_local_sender_channels = 1; + const uint32_t num_remote_sender_channels = 0; + const uint32_t num_buffers_per_edm_channel = 3; + const uint32_t page_size = 2048; + const uint32_t num_pages_total = 100000; + const bool src_is_dram = true; + const bool dest_is_dram = true; + const bool merge_message_and_signal = true; + auto termination_mode = ttnn::ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED; + + auto result = TestEntrypoint( + num_local_sender_channels, + num_remote_sender_channels, + // default is 1. + // 2 means channel is double buffered + // 3 means channel is triple buffered + // ... and so on + num_buffers_per_edm_channel, + page_size, + num_pages_total, + src_is_dram, + dest_is_dram, + termination_mode + ); + ASSERT_EQ(result, 0); +} + +TEST(WorkerEdmDatapath, DISABLED_MergedPayloadAndSignal_2ChannelForward_1ChannelsReverse_2BufferPerChannel_2048PageSize_100kPages_MessageCountTermination) { + const uint32_t num_local_sender_channels = 2; + const uint32_t num_remote_sender_channels = 1; + const uint32_t num_buffers_per_edm_channel = 2; + const uint32_t page_size = 2048; + const uint32_t num_pages_total = 100000; + const bool src_is_dram = true; + const bool dest_is_dram = true; + const bool merge_message_and_signal = true; + auto termination_mode = ttnn::ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED; + + auto result = TestEntrypoint( + num_local_sender_channels, + num_remote_sender_channels, + // default is 1. + // 2 means channel is double buffered + // 3 means channel is triple buffered + // ...
and so on + num_buffers_per_edm_channel, + page_size, + num_pages_total, + src_is_dram, + dest_is_dram, + termination_mode + ); + ASSERT_EQ(result, 0); +} + + +TEST(WorkerEdmDatapath, DISABLED_MergedPayloadAndSignal_4ChannelForward_4ChannelsReverse_1BufferPerChannel_2048PageSize_100kPages_MessageCountTermination) { + const uint32_t num_local_sender_channels = 4; + const uint32_t num_remote_sender_channels = 4; + const uint32_t num_buffers_per_edm_channel = 1; + const uint32_t page_size = 2048; + const uint32_t num_pages_total = 100000; + const bool src_is_dram = true; + const bool dest_is_dram = true; + const bool merge_message_and_signal = true; + auto termination_mode = ttnn::ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED; + + auto result = TestEntrypoint( + num_local_sender_channels, + num_remote_sender_channels, + // default is 1. + // 2 means channel is double buffered + // 3 means channel is triple buffered + // ... and so on + num_buffers_per_edm_channel, + page_size, + num_pages_total, + src_is_dram, + dest_is_dram, + termination_mode + ); + ASSERT_EQ(result, 0); +} + +TEST(WorkerEdmDatapath, DISABLED_MergedPayloadAndSignal_4ChannelForward_4ChannelsReverse_2BufferPerChannel_2048PageSize_100kPages_MessageCountTermination) { + const uint32_t num_local_sender_channels = 4; + const uint32_t num_remote_sender_channels = 4; + const uint32_t num_buffers_per_edm_channel = 2; + const uint32_t page_size = 2048; + const uint32_t num_pages_total = 100000; + const bool src_is_dram = true; + const bool dest_is_dram = true; + const bool merge_message_and_signal = true; + auto termination_mode = ttnn::ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED; + + auto result = TestEntrypoint( + num_local_sender_channels, + num_remote_sender_channels, + // default is 1. + // 2 means channel is double buffered + // 3 means channel is triple buffered + // ... 
and so on + num_buffers_per_edm_channel, + page_size, + num_pages_total, + src_is_dram, + dest_is_dram, + termination_mode + ); + ASSERT_EQ(result, 0); +} + + + +//////////////////////////////////////////////////////////////////// +/// WORKER_INITIATED_TERMINATION_MODE +//////////////////////////////////////////////////////////////////// + +TEST(WorkerEdmDatapath, DISABLED_MergedPayloadAndSignal_1ChannelForward_0ChannelsReverse_1BufferPerChannel_2048PageSize_100kPages_WorkerInitiatedTermination) { + const uint32_t num_local_sender_channels = 1; + const uint32_t num_remote_sender_channels = 0; + const uint32_t num_buffers_per_edm_channel = 1; + const uint32_t page_size = 2048; + const uint32_t num_pages_total = 100000; + const bool src_is_dram = true; + const bool dest_is_dram = true; + const bool merge_message_and_signal = true; + auto termination_mode = ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED; + + auto result = TestEntrypoint( + num_local_sender_channels, + num_remote_sender_channels, + // default is 1. + // 2 means channel is double buffered + // 3 means channel is triple buffered + // ... and so on + num_buffers_per_edm_channel, + page_size, + num_pages_total, + src_is_dram, + dest_is_dram, + termination_mode + ); + ASSERT_EQ(result, 0); +} + +TEST(WorkerEdmDatapath, DISABLED_MergedPayloadAndSignal_1ChannelForward_1ChannelsReverse_1BufferPerChannel_2048PageSize_100kPages_WorkerInitiatedTermination) { + const uint32_t num_local_sender_channels = 1; + const uint32_t num_remote_sender_channels = 1; + const uint32_t num_buffers_per_edm_channel = 1; + const uint32_t page_size = 2048; + const uint32_t num_pages_total = 100000; + const bool src_is_dram = true; + const bool dest_is_dram = true; + const bool merge_message_and_signal = true; + auto termination_mode = ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED; + + auto result = TestEntrypoint( + num_local_sender_channels, + num_remote_sender_channels, + // default is 1. 
+ // 2 means channel is double buffered + // 3 means channel is triple buffered + // ... and so on + num_buffers_per_edm_channel, + page_size, + num_pages_total, + src_is_dram, + dest_is_dram, + termination_mode + ); + ASSERT_EQ(result, 0); +} + +TEST(WorkerEdmDatapath, DISABLED_MergedPayloadAndSignal_1ChannelForward_0ChannelsReverse_2BufferPerChannel_2048PageSize_100kPages_WorkerInitiatedTermination) { + const uint32_t num_local_sender_channels = 1; + const uint32_t num_remote_sender_channels = 0; + const uint32_t num_buffers_per_edm_channel = 2; + const uint32_t page_size = 2048; + const uint32_t num_pages_total = 100000; + const bool src_is_dram = true; + const bool dest_is_dram = true; + const bool merge_message_and_signal = true; + auto termination_mode = ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED; + + auto result = TestEntrypoint( + num_local_sender_channels, + num_remote_sender_channels, + // default is 1. + // 2 means channel is double buffered + // 3 means channel is triple buffered + // ... and so on + num_buffers_per_edm_channel, + page_size, + num_pages_total, + src_is_dram, + dest_is_dram, + termination_mode + ); + ASSERT_EQ(result, 0); +} + +TEST(WorkerEdmDatapath, DISABLED_MergedPayloadAndSignal_1ChannelForward_0ChannelsReverse_3BufferPerChannel_2048PageSize_100kPages_WorkerInitiatedTermination) { + const uint32_t num_local_sender_channels = 1; + const uint32_t num_remote_sender_channels = 0; + const uint32_t num_buffers_per_edm_channel = 3; + const uint32_t page_size = 2048; + const uint32_t num_pages_total = 100000; + const bool src_is_dram = true; + const bool dest_is_dram = true; + const bool merge_message_and_signal = true; + auto termination_mode = ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED; + + auto result = TestEntrypoint( + num_local_sender_channels, + num_remote_sender_channels, + // default is 1. + // 2 means channel is double buffered + // 3 means channel is triple buffered + // ... 
and so on + num_buffers_per_edm_channel, + page_size, + num_pages_total, + src_is_dram, + dest_is_dram, + termination_mode + ); + ASSERT_EQ(result, 0); +} + +TEST(WorkerEdmDatapath, DISABLED_MergedPayloadAndSignal_1ChannelForward_1ChannelsReverse_2BufferPerChannel_2048PageSize_100kPages_WorkerInitiatedTermination) { + const uint32_t num_local_sender_channels = 1; + const uint32_t num_remote_sender_channels = 1; + const uint32_t num_buffers_per_edm_channel = 2; + const uint32_t page_size = 2048; + const uint32_t num_pages_total = 100000; + const bool src_is_dram = true; + const bool dest_is_dram = true; + const bool merge_message_and_signal = true; + auto termination_mode = ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED; + + auto result = TestEntrypoint( + num_local_sender_channels, + num_remote_sender_channels, + // default is 1. + // 2 means channel is double buffered + // 3 means channel is triple buffered + // ... and so on + num_buffers_per_edm_channel, + page_size, + num_pages_total, + src_is_dram, + dest_is_dram, + termination_mode + ); + ASSERT_EQ(result, 0); +} + +TEST(WorkerEdmDatapath, DISABLED_MergedPayloadAndSignal_4ChannelForward_4ChannelsReverse_2BufferPerChannel_2048PageSize_100kPages_WorkerInitiatedTermination) { + const uint32_t num_local_sender_channels = 4; + const uint32_t num_remote_sender_channels = 4; + const uint32_t num_buffers_per_edm_channel = 2; + const uint32_t page_size = 2048; + const uint32_t num_pages_total = 100000; + const bool src_is_dram = true; + const bool dest_is_dram = true; + const bool merge_message_and_signal = true; + auto termination_mode = ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED; + + auto result = TestEntrypoint( + num_local_sender_channels, + num_remote_sender_channels, + // default is 1. + // 2 means channel is double buffered + // 3 means channel is triple buffered + // ... 
and so on + num_buffers_per_edm_channel, + page_size, + num_pages_total, + src_is_dram, + dest_is_dram, + termination_mode + ); + ASSERT_EQ(result, 0); +} + +// EnablePersistentKernelCache diff --git a/tests/ttnn/unit_tests/operations/test_reduce_scatter_post_commit.py b/tests/ttnn/unit_tests/operations/test_reduce_scatter_post_commit.py index 7698269fee8..811e1dce457 100644 --- a/tests/ttnn/unit_tests/operations/test_reduce_scatter_post_commit.py +++ b/tests/ttnn/unit_tests/operations/test_reduce_scatter_post_commit.py @@ -41,7 +41,7 @@ def run_reduce_scatter_test( mem_config, use_program_cache, function_level_defaults, - enable_async=False, + enable_async=True, num_iters=1, ): if len(t3k_device_mesh.get_device_ids()) != 8: @@ -214,7 +214,7 @@ def run_reduce_scatter_sharded_test( tensor_mem_layout, use_program_cache, function_level_defaults, - enable_async=False, + enable_async=True, num_iters=1, ): if len(t3k_device_mesh.get_device_ids()) != 8: @@ -286,10 +286,6 @@ def run_reduce_scatter_sharded_test( ttl.device.Synchronize(t3k_device_mesh.get_device(device_id)) logger.info(f"Done iteration {i}") - for device_id in t3k_device_mesh.get_device_ids(): - ttl.device.Synchronize(t3k_device_mesh.get_device(device_id)) - logger.info(f"Done iteration {i}") - # Compute golden # TODO: Make it model how reduce scatter actually works for numerical correctness/ordering golden_canonical_out_tensor = torch.zeros(canonical_input_shape).bfloat16() diff --git a/tt_metal/hw/inc/ethernet/dataflow_api.h b/tt_metal/hw/inc/ethernet/dataflow_api.h index 94538d3f65f..d4f62477dd2 100644 --- a/tt_metal/hw/inc/ethernet/dataflow_api.h +++ b/tt_metal/hw/inc/ethernet/dataflow_api.h @@ -149,7 +149,6 @@ void eth_send_bytes_over_channel_payload_only( uint32_t src_addr, uint32_t dst_addr, uint32_t num_bytes, - uint32_t channel, uint32_t num_bytes_per_send = 16, uint32_t num_bytes_per_send_word_size = 1) { // assert(channel < 4); diff --git 
a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp index 4f73c9160e3..a5a2316ac77 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp @@ -25,7 +25,7 @@ AllGatherBidirectionalMode AllGatherConfig::choose_bidirectional_mode(Tensor con return AllGatherBidirectionalMode::FULL_TENSOR; } -AllGatherConfig::AllGatherConfig(Tensor const& input_tensor, Tensor const& output_tensor, uint32_t dim, uint32_t ring_size, uint32_t num_links, all_gather_op::Topology topology) : +AllGatherConfig::AllGatherConfig(Tensor const& input_tensor, Tensor const& output_tensor, uint32_t dim, uint32_t ring_size, uint32_t num_links, all_gather_op::Topology topology, std::size_t num_buffers_per_worker) : num_links(num_links), semaphore_size(32), ring_size(ring_size), @@ -37,8 +37,11 @@ AllGatherConfig::AllGatherConfig(Tensor const& input_tensor, Tensor const& outpu input_is_dram(input_tensor.buffer()->buffer_type() == BufferType::DRAM), output_is_dram(output_tensor.buffer()->buffer_type() == BufferType::DRAM), - bidirectional_mode(choose_bidirectional_mode(input_tensor)) + bidirectional_mode(choose_bidirectional_mode(input_tensor)), + enable_merged_payload_and_channel_sync(true), + num_buffers_per_worker(num_buffers_per_worker) { + TT_FATAL(num_buffers_per_worker > 0, "num_buffers_per_worker must be > 0"); TT_ASSERT(erisc_handshake_address >= eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE); TT_ASSERT(erisc_handshake_address < eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE + 16); TT_ASSERT((erisc_handshake_address & (16-1)) == 0); @@ -68,15 +71,17 @@ AllGatherConfig::AllGatherConfig(Tensor const& input_tensor, Tensor const& outpu this->num_workers_per_link = this->num_eth_buffers; this->eth_sems_l1_base_byte_address = this->erisc_handshake_address + 16 * 3;//16; + // Really should be called 
offset_after_semaphore_region this->semaphore_offset = this->semaphore_size * this->num_eth_buffers * num_duplicate_directions; // TODO: Remove this once dedicated semaphore space for user kernels are added this->eth_buffers_l1_base_byte_address = this->eth_sems_l1_base_byte_address + this->semaphore_offset; + std::size_t channel_sync_bytes_overhead = (enable_merged_payload_and_channel_sync * 16); uint32_t const page_size = input_tensor.buffer()->page_size(); - this->eth_buffer_size = tt::round_down((total_l1_buffer_space - this->semaphore_offset) / (this->num_eth_buffers * num_duplicate_directions), page_size); + std::size_t l1_per_buffer_region = ((total_l1_buffer_space - this->semaphore_offset) / (this->num_eth_buffers * num_duplicate_directions * this->num_buffers_per_worker)) - channel_sync_bytes_overhead; + this->eth_buffer_size = tt::round_down(l1_per_buffer_region, page_size); + TT_FATAL((this->eth_buffer_size + channel_sync_bytes_overhead) * (this->num_eth_buffers * num_duplicate_directions * this->num_buffers_per_worker) + this->semaphore_offset <= total_l1_buffer_space); TT_FATAL(eth_buffer_size == 0 or (this->num_eth_buffers * num_duplicate_directions) <= eth_l1_mem::address_map::MAX_NUM_CONCURRENT_TRANSACTIONS); - TT_FATAL(this->eth_buffer_size * (this->num_eth_buffers * num_duplicate_directions) + this->semaphore_offset <= total_l1_buffer_space); - } diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp index 058225db49d..f3a0956b310 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp @@ -41,13 +41,13 @@ class AllGatherConfig { static AllGatherBidirectionalMode choose_bidirectional_mode(Tensor const& input_tensor); public: - AllGatherConfig(Tensor const& input_tensor, Tensor const& output_tensor, uint32_t dim, uint32_t ring_size, uint32_t num_links, 
all_gather_op::Topology topology); + AllGatherConfig(Tensor const& input_tensor, Tensor const& output_tensor, uint32_t dim, uint32_t ring_size, uint32_t num_links, all_gather_op::Topology topology, std::size_t num_buffers_per_worker); uint32_t get_erisc_handshake_address() const { return this->erisc_handshake_address; } - uint32_t get_semaphores_offset() const { return this->semaphore_offset; } uint32_t get_num_eth_buffers_per_edm() const { return this->num_eth_buffers; } uint32_t get_num_workers_per_link() const { return this->num_workers_per_link; } + uint32_t get_num_buffers_per_worker() const { return this->num_buffers_per_worker; } uint32_t get_num_workers() const { return this->num_workers_per_link * this->num_links; } uint32_t get_eth_buffer_size() const { return this->eth_buffer_size; } @@ -57,6 +57,7 @@ class AllGatherConfig { uint32_t get_eth_buffers_l1_base_byte_address() const { return this->eth_buffers_l1_base_byte_address; } uint32_t get_semaphore_size() const { return this->semaphore_size; } + std::size_t get_num_buffers_per_channel() const { return this->num_buffers_per_worker; } uint32_t get_num_edm_channels_in_clockwise_direction() const { return this->enable_bidirectional ? @@ -64,6 +65,7 @@ class AllGatherConfig { this->num_workers_per_link; } uint32_t get_ring_size() const { return this->ring_size; } + bool is_payload_and_channel_sync_merged() const { return enable_merged_payload_and_channel_sync;} bool is_buffer_in_clockwise_ring(const uint32_t buffer_index) const { // For now we split it as lower half => clockwise, upper half => counter-clockwise // This is slightly suboptimal since the non-full-chunks go to the upper half. 
@@ -89,6 +91,7 @@ class AllGatherConfig { log_trace(tt::LogOp, "\terisc_handshake_address: {}", erisc_handshake_address); log_trace(tt::LogOp, "\tnum_buffers: {}", num_eth_buffers); log_trace(tt::LogOp, "\tnum_workers_per_link: {}", num_workers_per_link); + log_trace(tt::LogOp, "\tnum_buffers_per_worker: {}", num_buffers_per_worker); log_trace(tt::LogOp, "\teth_buffer_size: {}", eth_buffer_size); log_trace(tt::LogOp, "\tsemaphore_size: {}", semaphore_size); log_trace(tt::LogOp, "\tsemaphore_offset: {}", semaphore_offset); @@ -104,6 +107,7 @@ class AllGatherConfig { uint32_t num_links; uint32_t num_eth_buffers; uint32_t num_workers_per_link; + uint32_t num_buffers_per_worker; uint32_t eth_buffer_size; uint32_t semaphore_size; uint32_t semaphore_offset; @@ -115,6 +119,7 @@ class AllGatherConfig { bool enable_bidirectional; const bool input_is_dram; const bool output_is_dram; + const bool enable_merged_payload_and_channel_sync; }; struct AllGather { diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp index b8bc34b2a4e..b4bd94c4938 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp @@ -1,6 +1,7 @@ // SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
// // SPDX-License-Identifier: Apache-2.0 +#pragma once #include "dataflow_api.h" #include "debug/assert.h" @@ -33,7 +34,7 @@ FORCE_INLINE void write_and_send_chunk( uint32_t l1_read_addr = get_read_ptr(cb_id); noc_async_write(l1_read_addr, remote_l1_write_addr, page_size * num_pages); noc_semaphore_inc(eth_l1_sender_semaphore_addr, 1); - // TODO: do eth semaphore inc here + for (uint32_t i = 0; i < num_pages; ++i) { #ifdef ROW_MAJOR_LAYOUT #ifdef INTERLEAVED_MEM_LAYOUT diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core.cpp index c3e69e13104..244dc097a29 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core.cpp @@ -28,14 +28,14 @@ using namespace ccl; static std::tuple select_worker_cores(AllGatherConfig const& all_gather_config, uint32_t num_links, uint32_t link, uint32_t full_send_direction) { constexpr uint32_t worker_grid_width = 8; - const bool fit_sender_and_receiver_workers_on_same_row = (worker_grid_width / 2) >= all_gather_config.get_num_eth_buffers_per_edm(); + const bool fit_sender_and_receiver_workers_on_same_row = (worker_grid_width / 2) >= all_gather_config.get_num_workers_per_link(); std::set receiver_worker_cores = {}; std::set sender_worker_cores = {}; uint32_t max_cols = 8; - uint32_t curr_row = link * (((all_gather_config.get_num_eth_buffers_per_edm() * 2 - 1) / max_cols) + 1) + - (full_send_direction * num_links * (((all_gather_config.get_num_eth_buffers_per_edm() * 2 - 1) / max_cols) + 1)); + uint32_t curr_row = link * (((all_gather_config.get_num_workers_per_link() * 2 - 1) / max_cols) + 1) + + (full_send_direction * num_links * (((all_gather_config.get_num_workers_per_link() * 2 - 1) / max_cols) + 1)); uint32_t curr_col = 0; - for (uint32_t r = 0; r < 
all_gather_config.get_num_eth_buffers_per_edm(); r++) { + for (uint32_t r = 0; r < all_gather_config.get_num_workers_per_link(); r++) { receiver_worker_cores.insert(CoreRange(CoreCoord(curr_col, curr_row))); curr_col ++; if (curr_col == max_cols) { @@ -43,7 +43,7 @@ static std::tuple select_worker_cores(AllGatherConfig curr_row++; } } - for (uint32_t s = 0; s < all_gather_config.get_num_eth_buffers_per_edm(); s++) { + for (uint32_t s = 0; s < all_gather_config.get_num_workers_per_link(); s++) { sender_worker_cores.insert(CoreRange(CoreCoord(curr_col, curr_row))); curr_col ++; if (curr_col == max_cols) { @@ -61,8 +61,8 @@ static std::vector> compute_worker_sender_num_transfers( std::vector> worker_sender_num_transfers; worker_sender_num_transfers.reserve(num_links); for (uint32_t l = 0; l < num_links; ++l) { - worker_sender_num_transfers.emplace_back(all_gather_config.get_num_eth_buffers_per_edm()); - for(uint32_t b = 0; b < all_gather_config.get_num_eth_buffers_per_edm(); ++b) { + worker_sender_num_transfers.emplace_back(all_gather_config.get_num_workers_per_link()); + for(uint32_t b = 0; b < all_gather_config.get_num_workers_per_link(); ++b) { uint32_t &worker_num_transfers = worker_sender_num_transfers.at(l).at(b); switch (topology) { case all_gather_op::Topology::Linear: @@ -99,8 +99,8 @@ static std::vector> compute_worker_receiver_num_transfers( std::vector> worker_sender_num_transfers; worker_sender_num_transfers.reserve(num_links); for (uint32_t l = 0; l < num_links; ++l) { - worker_sender_num_transfers.emplace_back(all_gather_config.get_num_eth_buffers_per_edm()); - for(uint32_t b = 0; b < all_gather_config.get_num_eth_buffers_per_edm(); ++b) { + worker_sender_num_transfers.emplace_back(all_gather_config.get_num_workers_per_link()); + for(uint32_t b = 0; b < all_gather_config.get_num_workers_per_link(); ++b) { uint32_t &worker_num_transfers = worker_sender_num_transfers.at(l).at(b); switch (topology) { case all_gather_op::Topology::Linear: @@ -177,11 +177,12 
@@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& std::unique_ptr input_tensor_config = ttnn::ccl::CclOpTensorConfig::build_all_gather_tensor_config(input_tensor); std::unique_ptr output_tensor_config = ttnn::ccl::CclOpTensorConfig::build_all_gather_tensor_config(output_tensor); + std::size_t num_edm_buffers_per_channel = 1; tt::tt_metal::Program program{}; // Issue #10978: CCLs need to be tagged as having multi-device dependencies, when running on Galaxy. program.capture_multi_device_dependencies(); const auto& device = input_tensor.device(); - auto const& all_gather_config = AllGatherConfig(input_tensor, output_tensor, dim, ring_size, num_links, topology); + auto const& all_gather_config = AllGatherConfig(input_tensor, output_tensor, dim, ring_size, num_links, topology, num_edm_buffers_per_channel); auto const& topology_config = ttnn::ccl::RingTopology(device, topology, sender_device_id, receiver_device_id, num_links, ring_size, ring_index); bool enable_print = false; @@ -214,8 +215,6 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& bool width = input_tensor.get_legacy_shape().rank() - 1 == dim; tt::DataFormat df = datatype_to_dataformat_converter(input_tensor.get_dtype()); - uint32_t global_num_workers = all_gather_config.get_num_eth_buffers_per_edm() * num_links; - std::map worker_defines; if (rm) { worker_defines["ROW_MAJOR_LAYOUT"] = "1"; @@ -235,7 +234,8 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& const uint32_t num_full_send_directions = full_send_both_directions ? 
2 : 1; constexpr uint32_t max_num_full_send_directions = 2; // number of worker cores is 2x this since there is 1 worker for the sender buffer and 1 worker for the receiver buffer - uint32_t total_worker_core_pairs_used = num_links * all_gather_config.get_num_eth_buffers_per_edm() * num_full_send_directions; + uint32_t global_num_workers = num_links * all_gather_config.get_num_eth_buffers_per_edm() * num_full_send_directions; + uint32_t total_worker_core_pairs_used = global_num_workers; uint32_t num_input_pages = input_tensor.buffer()->size() / input_page_size; uint32_t min_pages_per_link = num_input_pages / num_links; @@ -257,26 +257,41 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& auto edm_sem_addrs_per_link = std::vector>(num_links); auto edm_buffer_addrs_per_link = std::vector>(num_links); for (uint32_t link = 0; link < num_links; link++) { - edm_sem_addrs_per_link.at(link).reserve(all_gather_config.get_num_eth_buffers_per_edm() * num_full_send_directions); - edm_buffer_addrs_per_link.at(link).reserve(all_gather_config.get_num_eth_buffers_per_edm() * num_full_send_directions); + edm_sem_addrs_per_link.at(link).reserve(all_gather_config.get_num_workers_per_link() * num_full_send_directions); + edm_buffer_addrs_per_link.at(link).reserve(all_gather_config.get_num_workers_per_link() * num_full_send_directions); uint32_t edm_sem_addr = all_gather_config.get_eth_sems_l1_base_byte_address(); uint32_t edm_buffer_addr = all_gather_config.get_eth_buffers_l1_base_byte_address(); for (uint32_t direction = 0; direction < num_full_send_directions; direction++) { - for (uint32_t b = 0; b < all_gather_config.get_num_eth_buffers_per_edm(); ++b) { + for (uint32_t b = 0; b < all_gather_config.get_num_workers_per_link(); ++b) { edm_sem_addrs_per_link.at(link).push_back(edm_sem_addr); edm_sem_addr += all_gather_config.get_semaphore_size(); edm_buffer_addrs_per_link.at(link).push_back(edm_buffer_addr); - edm_buffer_addr += 
all_gather_config.get_eth_buffer_size(); + edm_buffer_addr += ((all_gather_config.get_eth_buffer_size() + + (all_gather_config.is_payload_and_channel_sync_merged() > 0 ? EriscDatamoverConfig::get_eth_word_size() : 0)) * all_gather_config.get_num_buffers_per_channel()); TT_ASSERT((direction == 0 && b == 0) || (edm_buffer_addrs_per_link.at(link).back() != edm_buffer_addrs_per_link.at(link).front())); TT_ASSERT((direction == 0 && b == 0) || (edm_sem_addrs_per_link.at(link).back() != edm_sem_addrs_per_link.at(link).front())); } } clockwise_edm_builders.emplace_back( - all_gather_config.get_eth_buffer_size(), all_gather_config.get_erisc_handshake_address(), edm_sem_addrs_per_link.at(link), edm_buffer_addrs_per_link.at(link), ttnn::ccl::EriscDataMoverBufferSharingMode::NOT_SHARED, ttnn::ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED); + all_gather_config.get_eth_buffer_size(), + all_gather_config.get_erisc_handshake_address(), + edm_sem_addrs_per_link.at(link), + edm_buffer_addrs_per_link.at(link), + ccl::EriscDataMoverBufferSharingMode::NOT_SHARED, + ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED, + all_gather_config.get_num_buffers_per_channel(), + input_tensor.device()->id()); counter_clockwise_edm_builders.emplace_back( - all_gather_config.get_eth_buffer_size(), all_gather_config.get_erisc_handshake_address(), edm_sem_addrs_per_link.at(link), edm_buffer_addrs_per_link.at(link), ttnn::ccl::EriscDataMoverBufferSharingMode::NOT_SHARED, ttnn::ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED); + all_gather_config.get_eth_buffer_size(), + all_gather_config.get_erisc_handshake_address(), + edm_sem_addrs_per_link.at(link), + edm_buffer_addrs_per_link.at(link), + ccl::EriscDataMoverBufferSharingMode::NOT_SHARED, + ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED, + all_gather_config.get_num_buffers_per_channel(), + input_tensor.device()->id()); } for (uint32_t direction = 0; direction < num_full_send_directions; direction++) { @@ -358,8 
+373,6 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& // We can't have overlap between the mcast grid for worker cores for different links since mcasting the semaphore in receiver would corrupt other link semaphores // We can have overlap between a link's sender and receiver worker grids if we have the semaphores at different addresses auto const& [receiver_workers, sender_workers] = select_worker_cores(all_gather_config, num_links, i, direction); - uint32_t worker_index = 0; - uint32_t workers_per_link = all_gather_config.get_num_workers_per_link() / all_gather_config.get_num_eth_buffers_per_edm(); // Circular Buffer Setup uint32_t cb_page_size = input_page_size; @@ -386,17 +399,17 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& // number of pages that can fit in a single ethernet L1 buffer (not the number of pages sent to this channel) std::vector pages_per_eth_l1_buffer; - pages_per_buffer.reserve(all_gather_config.get_num_eth_buffers_per_edm()); + pages_per_buffer.reserve(all_gather_config.get_num_workers_per_link()); uint32_t max_pages_per_eth_l1_sender_buffer = all_gather_config.get_eth_buffer_size() / input_page_size; - for(uint32_t b = 0; b < all_gather_config.get_num_eth_buffers_per_edm(); ++b) { - pages_per_buffer.push_back((pages_per_link.at(i) / all_gather_config.get_num_eth_buffers_per_edm())); + for(uint32_t w = 0; w < all_gather_config.get_num_workers_per_link(); ++w) { + pages_per_buffer.push_back((pages_per_link.at(i) / all_gather_config.get_num_workers_per_link())); pages_per_eth_l1_buffer.push_back(max_pages_per_eth_l1_sender_buffer); - if (b < pages_per_link.at(i) % all_gather_config.get_num_eth_buffers_per_edm()) { + if (w < pages_per_link.at(i) % all_gather_config.get_num_workers_per_link()) { pages_per_buffer.back()++; } log_trace(tt::LogOp, "pages_per_link[{}]: {}", i, pages_per_link.at(i)); - log_trace(tt::LogOp, "pages_per_buffer[{}]: {}", b, pages_per_buffer.at(b)); 
+ log_trace(tt::LogOp, "pages_per_buffer[{}]: {}", w, pages_per_buffer.at(w)); log_trace(tt::LogOp, "max_pages_per_eth_l1_sender_buffer: {}",max_pages_per_eth_l1_sender_buffer); } TT_ASSERT(std::accumulate(pages_per_buffer.begin(), pages_per_buffer.end(), 0) == pages_per_link.at(i)); @@ -420,41 +433,41 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& TT_ASSERT(rem_pages < pages_per_chunk || num_full_chunks == 0); TT_ASSERT(rem_pages <= max_pages_per_chunk); - std::vector num_full_chunks_per_worker(all_gather_config.get_num_eth_buffers_per_edm(),0); - std::vector rem_pages_per_worker(all_gather_config.get_num_eth_buffers_per_edm(), 0); - std::vector is_channel_shrinkable(all_gather_config.get_num_eth_buffers_per_edm(), false); - std::vector largest_packets_per_channel(all_gather_config.get_num_eth_buffers_per_edm(), 0); + std::vector num_full_chunks_per_worker(all_gather_config.get_num_workers_per_link(),0); + std::vector rem_pages_per_worker(all_gather_config.get_num_workers_per_link(), 0); + std::vector is_channel_shrinkable(all_gather_config.get_num_workers_per_link(), false); + std::vector largest_packets_per_channel(all_gather_config.get_num_workers_per_link(), 0); std::vector clockwise_link_buffer_num_messages_to_send; std::vector counter_clockwise_link_buffer_num_messages_to_send; std::vector edm_semaphores_base_address; std::vector link_buffer_sender_addresses; - clockwise_link_buffer_num_messages_to_send.reserve(all_gather_config.get_num_eth_buffers_per_edm()); - counter_clockwise_link_buffer_num_messages_to_send.reserve(all_gather_config.get_num_eth_buffers_per_edm()); - edm_semaphores_base_address.reserve(all_gather_config.get_num_eth_buffers_per_edm()); - link_buffer_sender_addresses.reserve(all_gather_config.get_num_eth_buffers_per_edm()); + clockwise_link_buffer_num_messages_to_send.reserve(all_gather_config.get_num_workers_per_link()); + 
counter_clockwise_link_buffer_num_messages_to_send.reserve(all_gather_config.get_num_workers_per_link()); + edm_semaphores_base_address.reserve(all_gather_config.get_num_workers_per_link()); + link_buffer_sender_addresses.reserve(all_gather_config.get_num_workers_per_link()); { - for (std::size_t b = 0; b < all_gather_config.get_num_eth_buffers_per_edm(); b++) { - num_full_chunks_per_worker.at(b) = num_full_chunks / all_gather_config.get_num_eth_buffers_per_edm(); + for (std::size_t b = 0; b < all_gather_config.get_num_workers_per_link(); b++) { + num_full_chunks_per_worker.at(b) = num_full_chunks / all_gather_config.get_num_workers_per_link(); } uint32_t worker_idx = 0; - for (worker_idx = 0; worker_idx < num_full_chunks % all_gather_config.get_num_eth_buffers_per_edm(); ++worker_idx) { + for (worker_idx = 0; worker_idx < num_full_chunks % all_gather_config.get_num_workers_per_link(); ++worker_idx) { num_full_chunks_per_worker.at(worker_idx)++; } if (rem_pages != 0) { - rem_pages_per_worker.at(worker_idx % all_gather_config.get_num_eth_buffers_per_edm()) = rem_pages; - TT_ASSERT(rem_pages_per_worker.at(worker_idx % all_gather_config.get_num_eth_buffers_per_edm()) * 2 <= cb_num_pages); + rem_pages_per_worker.at(worker_idx % all_gather_config.get_num_workers_per_link()) = rem_pages; + TT_ASSERT(rem_pages_per_worker.at(worker_idx % all_gather_config.get_num_workers_per_link()) <= cb_num_pages); } { // Logging log_trace(tt::LogOp, "num_full_chunks, remaining pages per worker (clockwise):"); - for (std::size_t b = 0; b < all_gather_config.get_num_eth_buffers_per_edm(); b++) { + for (std::size_t b = 0; b < all_gather_config.get_num_workers_per_link(); b++) { if (is_buffer_in_clockwise_direction(b)) { log_trace(tt::LogOp, "\tworker {}: {}, {}", b, num_full_chunks_per_worker.at(b), rem_pages_per_worker.at(b)); } } log_trace(tt::LogOp, "num_full_chunks, remaining pages per worker (counter-clockwise):"); - for (std::size_t b = 0; b < 
all_gather_config.get_num_eth_buffers_per_edm(); b++) { + for (std::size_t b = 0; b < all_gather_config.get_num_workers_per_link(); b++) { if (!is_buffer_in_clockwise_direction(b)) { log_trace(tt::LogOp, "\tworker {}: {}, {}", b, num_full_chunks_per_worker.at(b), rem_pages_per_worker.at(b)); } @@ -467,7 +480,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& is_channel_shrinkable.at(b) = shrinkable; largest_packets_per_channel.at(b) = shrinkable ? rem_pages_per_worker.at(b) * input_page_size : all_gather_config.get_eth_buffer_size(); } - for(uint32_t b = 0; b < all_gather_config.get_num_eth_buffers_per_edm(); ++b) { + for(uint32_t b = 0; b < all_gather_config.get_num_workers_per_link(); ++b) { // link num messages clockwise_link_buffer_num_messages_to_send.push_back( (num_full_chunks_per_worker.at(b) + (rem_pages_per_worker.at(b) > 0 ? 1 : 0)) * @@ -476,7 +489,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& (num_full_chunks_per_worker.at(b) + (rem_pages_per_worker.at(b) > 0 ? 
1 : 0)) * receiver_worker_num_transfers.at(i).at(b)); } - for(uint32_t b = 0; b < all_gather_config.get_num_eth_buffers_per_edm(); ++b) { + for(uint32_t b = 0; b < all_gather_config.get_num_workers_per_link(); ++b) { log_trace(tt::LogOp, "rem_pages_per_worker[{}]: {}", b, rem_pages_per_worker.at(b)); log_trace(tt::LogOp, "num_full_chunks_per_worker[{}]: {}", b, num_full_chunks_per_worker.at(b)); log_trace(tt::LogOp, "clockwise_link_buffer_num_messages_to_send[{}]: {}", b, clockwise_link_buffer_num_messages_to_send.at(b)); @@ -485,31 +498,27 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& std::vector receiver_semaphores_base_address; std::vector link_buffer_receiver_addresses; - receiver_semaphores_base_address.reserve(all_gather_config.get_num_eth_buffers_per_edm()); - link_buffer_receiver_addresses.reserve(all_gather_config.get_num_eth_buffers_per_edm()); - for(uint32_t b = 0; b < all_gather_config.get_num_eth_buffers_per_edm(); ++b) { + receiver_semaphores_base_address.reserve(all_gather_config.get_num_workers_per_link()); + link_buffer_receiver_addresses.reserve(all_gather_config.get_num_workers_per_link()); + for(uint32_t b = 0; b < all_gather_config.get_num_workers_per_link(); ++b) { receiver_semaphores_base_address.push_back(all_gather_config.get_eth_sems_l1_base_byte_address() + b * all_gather_config.get_semaphore_size()); link_buffer_receiver_addresses.push_back(all_gather_config.get_eth_buffers_l1_base_byte_address() + b * all_gather_config.get_eth_buffer_size()); } - std::vector sender_eth_sem_addrs; sender_eth_sem_addrs.reserve(all_gather_config.get_num_eth_buffers_per_edm()); - std::vector sender_eth_buffer_addrs; sender_eth_buffer_addrs.reserve(all_gather_config.get_num_eth_buffers_per_edm()); - std::vector receiver_eth_sem_addrs; receiver_eth_sem_addrs.reserve(all_gather_config.get_num_eth_buffers_per_edm()); - std::vector receiver_eth_buffer_addrs; 
receiver_eth_buffer_addrs.reserve(all_gather_config.get_num_eth_buffers_per_edm()); - for (uint32_t b = 0; b < all_gather_config.get_num_eth_buffers_per_edm(); ++b) { - uint32_t num_workers_per_eth_buffer = std::min(workers_per_link, all_gather_config.get_num_eth_buffers_per_edm() - worker_index); - + std::vector sender_eth_sem_addrs; sender_eth_sem_addrs.reserve(all_gather_config.get_num_workers_per_link()); + std::vector sender_eth_buffer_addrs; sender_eth_buffer_addrs.reserve(all_gather_config.get_num_workers_per_link()); + std::vector receiver_eth_sem_addrs; receiver_eth_sem_addrs.reserve(all_gather_config.get_num_workers_per_link()); + std::vector receiver_eth_buffer_addrs; receiver_eth_buffer_addrs.reserve(all_gather_config.get_num_workers_per_link()); + for (uint32_t b = 0; b < all_gather_config.get_num_workers_per_link(); ++b) { std::vector sender_worker_coords; std::vector receiver_worker_coords; - for (uint32_t w = b * num_workers_per_eth_buffer; w < (b + 1) * num_workers_per_eth_buffer; ++w) { - sender_worker_coords.push_back( - ttnn::ccl::WorkerXY( - device->worker_core_from_logical_core(sender_worker_cores.at(w)).x, - device->worker_core_from_logical_core(sender_worker_cores.at(w)).y)); - receiver_worker_coords.push_back( - ttnn::ccl::WorkerXY( - device->worker_core_from_logical_core(receiver_worker_cores.at(w)).x, - device->worker_core_from_logical_core(receiver_worker_cores.at(w)).y)); - } + sender_worker_coords.push_back( + ttnn::ccl::WorkerXY( + device->worker_core_from_logical_core(sender_worker_cores.at(b)).x, + device->worker_core_from_logical_core(sender_worker_cores.at(b)).y)); + receiver_worker_coords.push_back( + ttnn::ccl::WorkerXY( + device->worker_core_from_logical_core(receiver_worker_cores.at(b)).x, + device->worker_core_from_logical_core(receiver_worker_cores.at(b)).y)); bool sender_enabled = (!is_linear || !is_last_chip_in_chain); if (sender_enabled) { @@ -543,9 +552,8 @@ operation::ProgramWithCallbacks 
all_gather_multi_core_with_workers(const Tensor& } - // 1 Worker per buffer - for (uint32_t b = 0; b < all_gather_config.get_num_eth_buffers_per_edm(); ++b) { - uint32_t global_worker_index = all_gather_config.get_num_eth_buffers_per_edm() * i + b; + for (uint32_t b = 0; b < all_gather_config.get_num_workers_per_link(); ++b) { + uint32_t global_worker_index = all_gather_config.get_num_workers_per_link() * i + b; bool is_clockwise_direction = is_buffer_in_clockwise_direction(b); diff --git a/ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp b/ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp index a26499b9693..7d55b8ba911 100644 --- a/ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp @@ -256,6 +256,7 @@ void generate_edm_kernels_for_ring_or_linear_topology( } } + KernelHandle generate_edm_kernel( tt::tt_metal::Program& program, Device const* device, @@ -294,6 +295,7 @@ KernelHandle generate_edm_kernel( ccl::EriscDatamoverBuilder create_erisc_datamover_builder( std::size_t num_channels, uint32_t page_size, + std::size_t num_buffers_per_channel, ccl::EriscDataMoverBufferSharingMode buffer_sharing_mode, ccl::EriscDataMoverTerminationMode termination_mode) { TT_ASSERT(num_channels > 0); @@ -304,23 +306,26 @@ ccl::EriscDatamoverBuilder create_erisc_datamover_builder( uint32_t edm_buffer_addr = ccl::EriscDatamoverConfig::get_buffers_base_address(num_channels); TT_ASSERT(edm_sem_addr > 0); TT_ASSERT(edm_buffer_addr > 0); - const uint32_t buffer_size = ccl::EriscDatamoverConfig::compute_buffer_size(num_channels, page_size); + const uint32_t channel_buffer_size = ccl::EriscDatamoverConfig::compute_buffer_size(num_channels, num_buffers_per_channel, page_size); for (std::size_t c = 0; c < num_channels; ++c) { edm_sem_addresses.at(c) = edm_sem_addr; edm_sem_addr += ccl::EriscDatamoverConfig::semaphore_size; + TT_ASSERT(edm_buffer_addr % EriscDatamoverConfig::get_eth_word_size() == 0); edm_buffer_addresses.at(c) = edm_buffer_addr; - 
edm_buffer_addr += buffer_size; + log_trace(tt::LogOp, " edm_buffer_addresses({}) = {}", c, edm_buffer_addr); + edm_buffer_addr += num_buffers_per_channel * (channel_buffer_size + (ccl::EriscDatamoverConfig::enable_merged_payload_and_channel_sync ? ccl::EriscDatamoverConfig::get_eth_channel_sync_size_bytes() : 0)); TT_ASSERT((c == 0) || (edm_buffer_addresses.back() != edm_buffer_addresses.front())); TT_ASSERT((c == 0) || (edm_sem_addresses.back() != edm_sem_addresses.front())); } return ccl::EriscDatamoverBuilder( - buffer_size, + channel_buffer_size, ccl::EriscDatamoverConfig::get_edm_handshake_address(), edm_sem_addresses, edm_buffer_addresses, buffer_sharing_mode, - termination_mode); + termination_mode, + num_buffers_per_channel); } template @@ -340,12 +345,8 @@ RingReduceScatterBaseTensorSlicer::RingReduceScatterBaseTensor this->slice_dim_is_width = input_tensor.get_legacy_shape().rank() - 1 == slice_dim; this->is_sharded = input_tensor.is_sharded(); - int32_t shard_size_in_bytes = - is_sharded ? (input_tensor.buffer()->page_size() * input_tensor.buffer()->shard_spec().tensor2d_shape[0] * - input_tensor.buffer()->shard_spec().tensor2d_shape[1]) / - input_tensor.shard_spec()->num_cores() - : -1; - this->input_page_size = is_sharded ? 
shard_size_in_bytes : input_tensor.buffer()->page_size(); + this->input_page_size = input_tensor.buffer()->page_size(); + log_trace(tt::LogOp, "input_page_size={}", input_page_size); if (row_major) { this->num_cols = input_tensor.get_legacy_shape()[-1]; auto input_shape = input_tensor.get_legacy_shape(); diff --git a/ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp b/ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp index 6719203d8df..a47e7fd0ea0 100644 --- a/ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp @@ -455,6 +455,7 @@ void generate_edm_kernels_for_ring_or_linear_topology( ccl::EriscDatamoverBuilder create_erisc_datamover_builder( std::size_t num_channels, uint32_t page_size, + std::size_t num_buffers_per_channel, ccl::EriscDataMoverBufferSharingMode buffer_sharing_mode, EriscDataMoverTerminationMode termination_mode); diff --git a/ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.cpp b/ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.cpp index 3d7a556a8f0..fd9e061b248 100644 --- a/ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.cpp @@ -3,12 +3,55 @@ // SPDX-License-Identifier: Apache-2.0 -#include "ttnn/tensor/tensor_impl.hpp" +#include "ttnn/cpp/ttnn/tensor/tensor_impl.hpp" #include "ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.hpp" namespace ttnn { namespace ccl { +std::size_t EriscDatamoverConfig::get_eth_channel_sync_size_bytes() { return eth_channel_sync_size_bytes; } + +uint32_t EriscDatamoverConfig::get_edm_handshake_address() { return usable_l1_base_address; } + +std::size_t EriscDatamoverConfig::get_semaphores_region_size(std::size_t num_edm_channels) { + return (num_edm_channels * semaphore_size); +} +std::size_t EriscDatamoverConfig::get_semaphores_region_start_offset(std::size_t num_edm_channels) { + return handshake_location_size + edm_receiver_first_level_ack_source_word_size; +} +uint32_t 
EriscDatamoverConfig::get_semaphores_base_address(std::size_t num_edm_channels) { + return usable_l1_base_address + get_semaphores_region_start_offset(num_edm_channels); +} +uint32_t EriscDatamoverConfig::get_buffers_region_start_offset(std::size_t num_edm_channels) { + return get_semaphores_region_start_offset(num_edm_channels) + get_semaphores_region_size(num_edm_channels); +} +std::size_t EriscDatamoverConfig::get_eth_word_size() { return eth_word_size_bytes; } +uint32_t EriscDatamoverConfig::get_buffers_base_address(std::size_t num_edm_channels) { + uint32_t base_address = tt::round_up(usable_l1_base_address + get_buffers_region_start_offset(num_edm_channels), eth_word_size_bytes); + TT_ASSERT(base_address % eth_word_size_bytes == 0); + return base_address; +} +uint32_t EriscDatamoverConfig::compute_buffer_size(std::size_t num_edm_channels, std::size_t num_buffers_per_channel, uint32_t page_size) { + page_size = std::max(page_size, eth_word_size_bytes); + TT_ASSERT(num_edm_channels > 0); + std::size_t channel_sync_bytes_overhead = (enable_merged_payload_and_channel_sync * 16); + std::size_t total_usable_space = total_l1_buffer_space - get_buffers_region_start_offset(num_edm_channels); + std::size_t l1_per_buffer_region = (total_usable_space / (num_edm_channels * num_buffers_per_channel)) - channel_sync_bytes_overhead; + uint32_t buffer_size = tt::round_down(l1_per_buffer_region, page_size); + log_trace(tt::LogOp, "total_l1_buffer_space: {}", total_l1_buffer_space); + log_trace( + tt::LogOp, "get_buffers_base_address(num_edm_channels): {}", get_buffers_base_address(num_edm_channels)); + log_trace( + tt::LogOp, "usable buffer space: {}", total_l1_buffer_space - get_buffers_base_address(num_edm_channels)); + log_trace(tt::LogOp, "num_edm_channels: {}", num_edm_channels); + log_trace(tt::LogOp, "page_size: {}", page_size); + + log_trace(tt::LogOp, "Buffer size: {}", buffer_size); + + TT_ASSERT(buffer_size > 0 && buffer_size % page_size == 0); + return buffer_size; 
+} + CCLOpConfig::CCLOpConfig( std::vector& input_tensors, const std::vector& output_tensors, Topology topology) : input_tensors(&input_tensors), diff --git a/ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.hpp b/ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.hpp index 0529c10e805..af3dfae46af 100644 --- a/ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.hpp @@ -5,7 +5,7 @@ #pragma once #include "eth_l1_address_map.h" -#include "ttnn/tensor/tensor_impl.hpp" +#include "ttnn/cpp/ttnn/tensor/tensor_impl.hpp" #include "ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" #include @@ -21,42 +21,26 @@ struct EriscDatamoverConfig { static constexpr std::size_t semaphore_size = 32; static constexpr std::size_t handshake_location_size = 16; // ethernet word size + static constexpr std::size_t handshake_padding_multiple = 3; // ethernet word size // The EDM uses this fixed address as a source for a first level ack sent from receiver -> sender // side. We have this dedicated source address to avoid a race between first and second level ack // where second level ack overwrites the first level ack in L1 before the first one is sent out. // The memory contents in L1 will be {1, 1, x, x}. 
By having this dedicated source memory, we // avoid the race static constexpr std::size_t edm_receiver_first_level_ack_source_word_size = 16; // ethernet word size + static constexpr std::size_t eth_channel_sync_size_bytes = 16; static constexpr std::size_t eth_word_size_bytes = 16; + static constexpr bool enable_merged_payload_and_channel_sync = true; + static std::size_t get_eth_channel_sync_size_bytes(); + static uint32_t get_edm_handshake_address(); + static std::size_t get_semaphores_region_size(std::size_t num_edm_channels); + static std::size_t get_semaphores_region_start_offset(std::size_t num_edm_channels); + static uint32_t get_semaphores_base_address(std::size_t num_edm_channels); + static uint32_t get_buffers_region_start_offset(std::size_t num_edm_channels); + static std::size_t get_eth_word_size(); + static uint32_t get_buffers_base_address(std::size_t num_edm_channels); + static uint32_t compute_buffer_size(std::size_t num_edm_channels, std::size_t num_buffers_per_channel = 1, uint32_t page_size = eth_word_size_bytes); - static uint32_t get_edm_handshake_address() { return usable_l1_base_address; } - static uint32_t get_semaphores_base_address(std::size_t num_edm_channels) { - return usable_l1_base_address + (handshake_location_size * 3) + edm_receiver_first_level_ack_source_word_size; - } - static uint32_t get_buffers_base_address(std::size_t num_edm_channels) { - uint32_t base_address =tt::round_up( - get_semaphores_base_address(num_edm_channels) + num_edm_channels * semaphore_size, eth_word_size_bytes); - TT_ASSERT(base_address % eth_word_size_bytes == 0); - return base_address; - } - static uint32_t compute_buffer_size(std::size_t num_edm_channels, uint32_t page_size = eth_word_size_bytes) { - page_size = std::max(page_size, eth_word_size_bytes); - TT_ASSERT(num_edm_channels > 0); - uint32_t buffer_size =tt::round_down( - (total_l1_buffer_space - get_buffers_base_address(num_edm_channels)) / (num_edm_channels), page_size); - log_trace(tt::LogOp, 
"total_l1_buffer_space: {}", total_l1_buffer_space); - log_trace( - tt::LogOp, "get_buffers_base_address(num_edm_channels): {}", get_buffers_base_address(num_edm_channels)); - log_trace( - tt::LogOp, "usable buffer space: {}", total_l1_buffer_space - get_buffers_base_address(num_edm_channels)); - log_trace(tt::LogOp, "num_edm_channels: {}", num_edm_channels); - log_trace(tt::LogOp, "page_size: {}", page_size); - - log_trace(tt::LogOp, "Buffer size: {}", buffer_size); - - TT_ASSERT(buffer_size > 0 && buffer_size % page_size == 0); - return buffer_size; - } }; struct CCLOpConfig { @@ -97,6 +81,7 @@ class EriscDatamoverBuilder { uint32_t worker_semaphore_id, uint32_t num_eth_messages_to_forward, uint32_t channel, + uint32_t num_buffers, std::vector const& worker_coords, uint32_t largest_message_size_bytes = 0) : worker_coords(worker_coords), @@ -111,6 +96,7 @@ class EriscDatamoverBuilder { uint32_t num_eth_messages_to_forward; uint32_t channel; uint32_t largest_message_size_bytes; + uint32_t num_buffers; bool is_sender; }; @@ -143,6 +129,8 @@ class EriscDatamoverBuilder { ccl::EriscDataMoverTerminationMode const termination_mode; uint32_t num_senders; uint32_t num_receivers; + std::size_t num_buffers_per_channel; + chip_id_t chip_id; bool enable_sender; bool enable_receiver; @@ -160,19 +148,24 @@ class EriscDatamoverBuilder { std::vector const& local_semaphore_addresses, std::vector const& local_buffer_addresses, ccl::EriscDataMoverBufferSharingMode buffer_sharing_mode, - ccl::EriscDataMoverTerminationMode termination_mode = - ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED) : + ccl::EriscDataMoverTerminationMode termination_mode = ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED, + std::size_t num_buffers_per_channel = 1, + chip_id_t chip_id = -1) : local_semaphore_addresses(local_semaphore_addresses), local_buffer_addresses(local_buffer_addresses), eth_buffer_size_bytes(eth_buffer_size), handshake_addr(handshake_addr), 
num_channel_buffers(local_buffer_addresses.size()), buffer_sharing_mode(buffer_sharing_mode), + num_buffers_per_channel(num_buffers_per_channel), termination_mode(termination_mode), enable_sender(false), enable_receiver(false), num_senders(0), - num_receivers(0) { + num_receivers(0), + chip_id(chip_id) { + + TT_ASSERT(num_buffers_per_channel > 0); TT_ASSERT(local_buffer_addresses.size() == local_semaphore_addresses.size()); active_channels.reserve(num_channel_buffers); TT_ASSERT(eth_buffer_size_bytes < 163000); @@ -199,7 +192,7 @@ class EriscDatamoverBuilder { this->num_senders++; auto channel = active_channels.size(); active_channels.emplace_back( - true, worker_semaphore_id, num_eth_messages_to_forward, channel, worker_coords, expected_message_size_bytes); + true, worker_semaphore_id, num_eth_messages_to_forward, channel, this->num_buffers_per_channel, worker_coords, expected_message_size_bytes); log_trace(tt::LogOp, "Adding sender channel:"); log_trace(tt::LogOp, "\tworker_semaphore_id: {}", active_channels.back().worker_semaphore_id); log_trace(tt::LogOp, "\tnum_eth_messages_to_forward: {}", active_channels.back().num_eth_messages_to_forward); @@ -229,7 +222,7 @@ class EriscDatamoverBuilder { this->num_receivers++; auto channel = active_channels.size(); active_channels.emplace_back( - false, worker_semaphore_id, num_eth_messages_to_forward, channel, worker_coords, expected_message_size_bytes); + false, worker_semaphore_id, num_eth_messages_to_forward, channel, this->num_buffers_per_channel, worker_coords, expected_message_size_bytes); log_trace(tt::LogOp, "Adding receiver channel:"); log_trace(tt::LogOp, "\tworker_semaphore_id: {}", active_channels.back().worker_semaphore_id); log_trace(tt::LogOp, "\tnum_eth_messages_to_forward: {}", active_channels.back().num_eth_messages_to_forward); @@ -247,7 +240,12 @@ class EriscDatamoverBuilder { this->num_senders, this->num_receivers, this->buffer_sharing_mode, - this->termination_mode}; + this->termination_mode, + 1, + 
static_cast(this->num_senders > 0 && active_channels.at(0).is_sender), + this->num_buffers_per_channel, + chip_id + }; } [[nodiscard]] diff --git a/ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_adapters.hpp b/ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_adapters.hpp new file mode 100644 index 00000000000..50692a5f4cb --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_adapters.hpp @@ -0,0 +1,142 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include "dataflow_api.h" + +#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp" +#include "tt_metal/hw/inc/ethernet/dataflow_api.h" + +namespace ccl { +namespace edm { + +template +struct WorkerToEdmReader{ + constexpr WorkerToEdmReader ( + ttnn::ccl::WorkerXY edm_worker_xy, + std::size_t edm_buffer_base_addr, + std::size_t num_buffers_per_channel, + std::size_t edm_l1_sem_addr, + std::size_t buffer_size_bytes, + volatile uint32_t * const worker_sem_addr + ) : + edm_buffer_addr(get_noc_addr(edm_worker_xy.x, edm_worker_xy.y, edm_buffer_base_addr)), + edm_semaphore_addr(get_noc_addr(edm_worker_xy.x, edm_worker_xy.y, edm_l1_sem_addr)), + worker_sem_addr(worker_sem_addr), + edm_buffer_base_addr(edm_buffer_base_addr), + num_buffers_per_channel(num_buffers_per_channel), + last_buffer_index(num_buffers_per_channel - 1), + edm_l1_sem_addr(edm_l1_sem_addr), + buffer_size_bytes(buffer_size_bytes), + buffer_index(0) + {} + + FORCE_INLINE void wait_for_payload_available() const { + noc_semaphore_wait_min(worker_sem_addr, 1); + if (*worker_sem_addr > 1) { + DPRINT << "ERROR!!!!!!!!!!!!!!!!!!!!!!!!!\n"; + ASSERT(false); + } + noc_semaphore_set(worker_sem_addr, 0); + } + + FORCE_INLINE void fetch_payload_blocking(uint32_t cb_id, uint32_t num_pages, uint32_t page_size, bool last_message) { + uint64_t buffer_address 
= edm_buffer_addr + (buffer_index * (buffer_size_bytes + sizeof(eth_channel_sync_t))); + fetch_chunk(cb_id, num_pages, page_size, buffer_address); + if constexpr (termination_mode == ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED) { + if (!last_message) { + DPRINT << "fetch_payload_blocking: incrementing semaphore to " << (uint32_t)(edm_semaphore_addr & 0xFFFFFFFF) << "\n"; + noc_semaphore_inc(edm_semaphore_addr, ttnn::ccl::EriscDataMoverWorkerSignal::NEXT_MESSAGE_AVAILABLE); + } + } else { + noc_semaphore_inc(edm_semaphore_addr, ttnn::ccl::EriscDataMoverWorkerSignal::NEXT_MESSAGE_AVAILABLE); + } + buffer_index = (buffer_index == last_buffer_index) ? 0 : buffer_index + 1; + } + + FORCE_INLINE void fetch_payload_blocking(uint32_t cb_id, uint32_t num_pages, uint32_t page_size) { + // With worker initiated termination mode, we must always specify if we are sending the last message or not + ASSERT(termination_mode != ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED); + fetch_payload_blocking(cb_id, num_pages, page_size, false); + } + + FORCE_INLINE void close() { + if constexpr (termination_mode == ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED) { + noc_semaphore_inc(edm_semaphore_addr, ttnn::ccl::EriscDataMoverWorkerSignal::TERMINATE_IMMEDIATELY); + } + } + + uint64_t edm_buffer_addr; + uint64_t edm_semaphore_addr; + volatile uint32_t * const worker_sem_addr; + std::size_t edm_buffer_base_addr; + std::size_t num_buffers_per_channel; + std::size_t last_buffer_index; + std::size_t edm_l1_sem_addr; + std::size_t buffer_size_bytes; + std::size_t buffer_index; +}; + + +template +struct WorkerToEdmSender{ + constexpr WorkerToEdmSender ( + ttnn::ccl::WorkerXY edm_worker_xy, + std::size_t edm_buffer_base_addr, + std::size_t num_buffers_per_channel, + std::size_t edm_l1_sem_addr, + std::size_t buffer_size_bytes, + volatile uint32_t * const worker_sem_addr + ) : + edm_buffer_addr(get_noc_addr(edm_worker_xy.x, edm_worker_xy.y, 
edm_buffer_base_addr)), + edm_semaphore_addr(get_noc_addr(edm_worker_xy.x, edm_worker_xy.y, edm_l1_sem_addr)), + worker_sem_addr(worker_sem_addr), + edm_buffer_base_addr(edm_buffer_base_addr), + num_buffers_per_channel(num_buffers_per_channel), + last_buffer_index(num_buffers_per_channel - 1), + edm_l1_sem_addr(edm_l1_sem_addr), + buffer_size_bytes(buffer_size_bytes), + buffer_index(0) + { + ASSERT(buffer_size_bytes > 0); + } + + FORCE_INLINE void wait_for_empty_write_slot() const { + noc_semaphore_wait(worker_sem_addr, 1); + noc_semaphore_set(worker_sem_addr, 0); + } + + FORCE_INLINE void send_payload_blocking(uint32_t cb_id, uint32_t num_pages, uint32_t page_size) { + uint64_t buffer_address = edm_buffer_addr + (buffer_index * (this->buffer_size_bytes + sizeof(eth_channel_sync_t))); + DPRINT << "SENDER SEND buffer_size_bytes = " << (uint32_t)(this->buffer_size_bytes) << "\n"; + DPRINT << "SENDER SEND " << (uint32_t)(buffer_address & 0xffffffff) << " -> " << (uint32_t)((buffer_address & 0xffffffff) + (page_size * num_pages)) << "\n"; + send_chunk(cb_id, num_pages, page_size, buffer_address); + noc_semaphore_inc(edm_semaphore_addr, 1); + buffer_index = (buffer_index == last_buffer_index) ? 
0 : buffer_index + 1; + } + + FORCE_INLINE void close() { + if constexpr (termination_mode == ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED) { + this->wait_for_empty_write_slot(); + noc_semaphore_inc(edm_semaphore_addr, ttnn::ccl::EriscDataMoverWorkerSignal::TERMINATE_IMMEDIATELY); + } + } + + uint64_t edm_buffer_addr; + uint64_t edm_semaphore_addr; + volatile uint32_t * const worker_sem_addr; + std::size_t edm_buffer_base_addr; + std::size_t num_buffers_per_channel; + std::size_t last_buffer_index; + std::size_t edm_l1_sem_addr; + std::size_t buffer_size_bytes; + std::size_t buffer_index; +}; + + +} // namespace edm +} // namespace ccl diff --git a/ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp b/ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp index d1037849dd8..d865141dc42 100644 --- a/ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp @@ -6,7 +6,6 @@ #include "dataflow_api.h" #include "debug/assert.h" -#include "debug/dprint.h" #include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" using ttnn::ccl::ShardType; diff --git a/ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_async_datamover.hpp b/ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_async_datamover.hpp index 6c45db409de..a0232abc8b9 100644 --- a/ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_async_datamover.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_async_datamover.hpp @@ -2,16 +2,16 @@ // // SPDX-License-Identifier: Apache-2.0 +#pragma once + #include #include -#include #include "dataflow_api.h" #include "debug/assert.h" #include "eth_l1_address_map.h" #include "ethernet/dataflow_api.h" #include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" -#include "tt_metal/hw/inc/wormhole/noc/noc.h" using ttnn::ccl::EriscDataMoverBufferSharingMode; using ttnn::ccl::EriscDataMoverTerminationMode; @@ -20,10 +20,11 @@ 
using ttnn::ccl::EriscDataMoverWorkerSignal; namespace erisc { namespace datamover { -template +template struct EriscDatamoverConfig { static constexpr EriscDataMoverBufferSharingMode BUFFER_SHARING_MODE = buffer_sharing_mode; static constexpr EriscDataMoverTerminationMode TERMINATION_MODE = termination_mode; + static constexpr uint8_t NUM_BUFFERS_PER_CHANNEL = num_buffers_per_channel; }; template @@ -56,43 +57,46 @@ class ChannelBuffer final { enum STATE : uint8_t { DONE = 0, - // For sender: means we are ready to tell the worker(s) that the buffer is available for writing into - // - SIGNALING_WORKER, + // we are ready to tell the worker(s) that the buffer is available for writing into + SENDER_SIGNALING_WORKER, - // For sender: we are waiting for the payload to arrive in L1; we are checking local semaphore for worker - // completion For receiver: we are waiting for worker to complete pull of payload from L1; we are checking local - // semaphore for worker completion - WAITING_FOR_WORKER, + // we are waiting for the payload to arrive in L1; we are checking local semaphore for worker + // completion + SENDER_WAITING_FOR_WORKER, + + // means workers have signalled (via semaphores) that the buffer payload is + SENDER_READY_FOR_ETH_TRANSFER, + + // means we are waiting for ack from receiver that payload was received + SENDER_WAITING_FOR_ETH, + + // We received a packet from ethernet and we can signal the downstream worker to signal + // packet availability + RECEIVER_SIGNALING_WORKER, - // For sender: means workers have signalled (via semaphores) that the buffer payload is - // ready in L1 - // For receiver: - READY_FOR_ETH_TRANSFER, + // we are waiting for worker to complete pull of payload from L1; we are checking local + // semaphore for worker completion + RECEIVER_WAITING_FOR_WORKER, - // For sender: means we are waiting for ack from receiver that payload was received - // For receiver: means we are waitinf for a payload from sender - WAITING_FOR_ETH, + // means 
we are waiting for a payload from sender + RECEIVER_WAITING_FOR_ETH, }; // for default initialization in arrays ChannelBuffer() : local_semaphore_address(0), worker_coords(0), - address(0), size_in_bytes(0), worker_semaphore_l1_address(0), num_workers(0), num_messages_moved(0), - channel_bytes_sent_address(0), - channel_bytes_acked_address(0), total_num_messages_to_move(0), state(STATE::DONE) {} ChannelBuffer( uint32_t eth_transaction_channel, size_t address, - size_t size_in_bytes, + size_t payload_size_in_bytes, uint32_t worker_semaphore_l1_address, uint32_t num_workers, uint32_t total_num_messages_to_move, @@ -102,19 +106,38 @@ class ChannelBuffer final { eth_transaction_channel(eth_transaction_channel), local_semaphore_address(local_semaphore_address), worker_coords(worker_coords), - address(address), - size_in_bytes(size_in_bytes), + size_in_bytes(payload_size_in_bytes + sizeof(eth_channel_sync_t)), worker_semaphore_l1_address(worker_semaphore_l1_address), num_workers(num_workers), num_messages_moved(0), - channel_bytes_sent_address(&erisc_info->channels[eth_transaction_channel].bytes_sent), - channel_bytes_acked_address(&erisc_info->channels[eth_transaction_channel].receiver_ack), total_num_messages_to_move(total_num_messages_to_move), - state(is_sender_side ? STATE::WAITING_FOR_WORKER : STATE::WAITING_FOR_ETH), + state( + is_sender_side ? TERMINATION_MODE == ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED + ? STATE::SENDER_WAITING_FOR_WORKER + : STATE::SENDER_WAITING_FOR_WORKER + : TERMINATION_MODE == ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED + ? 
STATE::RECEIVER_WAITING_FOR_ETH + : STATE::RECEIVER_WAITING_FOR_ETH), + + buffer_index(0), is_sender_completion_pending(false), is_sender_side(is_sender_side) { clear_local_semaphore(); + for (uint8_t i = 0; i < EDM_CONFIG::NUM_BUFFERS_PER_CHANNEL; i++) { + this->addresses[i] = address + i * (this->size_in_bytes); + + uint32_t channel_sync_addr = this->addresses[i] + payload_size_in_bytes; + volatile uint32_t* bytes_sent_addr = &(reinterpret_cast(channel_sync_addr)->bytes_sent); + volatile uint32_t* bytes_acked_addr = &(reinterpret_cast(channel_sync_addr)->receiver_ack); + channel_bytes_sent_addresses[i] = bytes_sent_addr; + channel_bytes_acked_addresses[i] = bytes_acked_addr; + + ASSERT((uint32_t)channel_bytes_acked_addresses[i] != (uint32_t)channel_bytes_sent_addresses[i]); + *(channel_bytes_sent_addresses[i]) = 0; + *(channel_bytes_acked_addresses[i]) = 0; + } + if (TERMINATION_MODE != ttnn::ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED || total_num_messages_to_move != 0) { if (is_sender_side) { // Tell the sender side workers that we're ready to accept data on this channel @@ -126,7 +149,9 @@ class ChannelBuffer final { } } // Resets the semaphore in local L1, which workers write to remotely. 
- FORCE_INLINE void clear_local_semaphore() { noc_semaphore_set(local_semaphore_address, 0); } + FORCE_INLINE void clear_local_semaphore() { + noc_semaphore_set(local_semaphore_address, 0); + } // Increment the semaphore in the remote L1s of every worker associated with this ChannelBuffer FORCE_INLINE void increment_worker_semaphores() { @@ -185,15 +210,15 @@ class ChannelBuffer final { [[nodiscard]] FORCE_INLINE bool is_done() const { return this->state == STATE::DONE; } [[nodiscard]] FORCE_INLINE uint32_t get_eth_transaction_channel() const { - ASSERT(this->eth_transaction_channel < eth_l1_mem::address_map::MAX_NUM_CONCURRENT_TRANSACTIONS); return this->eth_transaction_channel; } - [[nodiscard]] FORCE_INLINE std::size_t get_remote_eth_buffer_address() const { return this->address; } [[nodiscard]] FORCE_INLINE std::size_t get_size_in_bytes() const { return this->size_in_bytes; } [[nodiscard]] FORCE_INLINE std::size_t get_current_payload_size() const { return this->get_size_in_bytes(); } - [[nodiscard]] FORCE_INLINE std::size_t get_buffer_address() const { return this->address; } + [[nodiscard]] FORCE_INLINE std::size_t get_buffer_address() const { + return this->addresses[buffer_index]; } + [[nodiscard]] FORCE_INLINE std::size_t get_remote_eth_buffer_address() const { return this->get_buffer_address(); } FORCE_INLINE uint32_t get_messages_moved() { return this->num_messages_moved; } FORCE_INLINE void increment_messages_moved() { this->num_messages_moved++; } @@ -204,27 +229,72 @@ class ChannelBuffer final { FORCE_INLINE void set_send_completion_pending(bool value) { this->is_sender_completion_pending = value; } [[nodiscard]] FORCE_INLINE bool is_send_completion_pending() const { return this->is_sender_completion_pending; } - FORCE_INLINE bool eth_is_receiver_channel_send_done() const { return *this->channel_bytes_sent_address == 0; } - FORCE_INLINE bool eth_bytes_are_available_on_channel() const { return *this->channel_bytes_sent_address != 0; } - FORCE_INLINE 
bool eth_is_receiver_channel_send_acked() const { return *this->channel_bytes_acked_address != 0; } - volatile tt_l1_ptr uint32_t *const get_channel_bytes_sent_address() { return this->channel_bytes_sent_address; } - volatile tt_l1_ptr uint32_t *const get_channel_bytes_acked_address() { return this->channel_bytes_acked_address; } + FORCE_INLINE bool eth_is_receiver_channel_send_done() const { + ASSERT(buffer_index < EDM_CONFIG::NUM_BUFFERS_PER_CHANNEL); + return *(this->channel_bytes_sent_addresses[buffer_index]) == 0; } + FORCE_INLINE bool eth_bytes_are_available_on_channel() const { + ASSERT(buffer_index < EDM_CONFIG::NUM_BUFFERS_PER_CHANNEL); + return *(this->channel_bytes_sent_addresses[buffer_index]) != 0; + } + FORCE_INLINE bool eth_is_receiver_channel_send_acked() const { + return *(this->channel_bytes_acked_addresses[buffer_index]) != 0; } + FORCE_INLINE void eth_clear_sender_channel_ack() const { *(this->channel_bytes_acked_addresses[buffer_index]) = 0; } + FORCE_INLINE void eth_receiver_channel_ack(uint32_t eth_transaction_ack_word_addr) const { + ASSERT(reinterpret_cast(eth_transaction_ack_word_addr)[0] == 1); + reinterpret_cast(eth_transaction_ack_word_addr)[1] = 1; + // Make sure we don't alias the erisc_info eth_channel_sync_t + ASSERT(eth_transaction_ack_word_addr != ((uint32_t)(this->channel_bytes_acked_addresses[buffer_index])) >> 4); + ASSERT(reinterpret_cast(eth_transaction_ack_word_addr)->bytes_sent != 0); + ASSERT(reinterpret_cast(eth_transaction_ack_word_addr)->receiver_ack == 1); + internal_::eth_send_packet( + 0, + eth_transaction_ack_word_addr >> 4, + ((uint32_t)(this->channel_bytes_sent_addresses[buffer_index])) >> 4, + 1); + } + FORCE_INLINE void eth_receiver_channel_done() const { + *(this->channel_bytes_sent_addresses[buffer_index]) = 0; + *(this->channel_bytes_acked_addresses[buffer_index]) = 0; + internal_::eth_send_packet( + 0, + ((uint32_t)(this->channel_bytes_sent_addresses[buffer_index])) >> 4, + 
((uint32_t)(this->channel_bytes_sent_addresses[buffer_index])) >> 4, + 1); + } + + FORCE_INLINE void advance_buffer_index() { + if constexpr (EDM_CONFIG::NUM_BUFFERS_PER_CHANNEL == 1) { + return; + } else if constexpr (EDM_CONFIG::NUM_BUFFERS_PER_CHANNEL == 2) { + this->buffer_index = 1 - this->buffer_index; + } else if constexpr (((EDM_CONFIG::NUM_BUFFERS_PER_CHANNEL) & (EDM_CONFIG::NUM_BUFFERS_PER_CHANNEL - 1)) == 0) { + this->buffer_index = (buffer_index + 1) & (EDM_CONFIG::NUM_BUFFERS_PER_CHANNEL - 1); + } else { + this->buffer_index = (buffer_index == EDM_CONFIG::NUM_BUFFERS_PER_CHANNEL - 1) ? 0 : buffer_index + 1; + } + + ASSERT(this->buffer_index < EDM_CONFIG::NUM_BUFFERS_PER_CHANNEL); + } + + volatile tt_l1_ptr uint32_t *const get_channel_bytes_sent_address() { return this->channel_bytes_sent_addresses[buffer_index]; } + volatile tt_l1_ptr uint32_t *const get_channel_bytes_acked_address() { return this->channel_bytes_acked_addresses[buffer_index]; } public: uint32_t eth_transaction_channel; // volatile tt_l1_ptr uint32_t *const local_semaphore_address; WorkerXY const *const worker_coords; - std::size_t const address; + std::array addresses; std::size_t const size_in_bytes; // Even for multiple workers, this address will be the same std::size_t const worker_semaphore_l1_address; uint32_t const num_workers; uint32_t num_messages_moved; - volatile tt_l1_ptr uint32_t *const channel_bytes_sent_address; - volatile tt_l1_ptr uint32_t *const channel_bytes_acked_address; + std::array channel_bytes_sent_addresses; + std::array channel_bytes_acked_addresses; const uint32_t total_num_messages_to_move; STATE state; edm_worker_index worker_index; + uint8_t buffer_index; bool is_sender_completion_pending; bool is_sender_side; }; @@ -280,6 +350,13 @@ class QueueIndexPointer { uint8_t wrap_around; }; + +/* + * Before any payload messages can be exchanged over the link, we must ensure that the other end + * of the link is ready to start sending/receiving messages. 
We perform a handshake to ensure that's the + * case. Before handshaking, we make sure to clear any of the channel sync datastructures local + * to our core. + */ FORCE_INLINE void eth_setup_handshake(std::uint32_t handshake_register_address, bool is_sender) { reinterpret_cast(handshake_register_address)[4] = 1; reinterpret_cast(handshake_register_address)[5] = 1; @@ -318,32 +395,27 @@ FORCE_INLINE void initialize_transaction_buffer_addresses( ///////////////////////////////////////////// // SENDER SIDE HELPERS ///////////////////////////////////////////// - template FORCE_INLINE bool sender_eth_send_data_sequence(ChannelBuffer &sender_buffer_channel) { bool did_something = false; if (sender_buffer_channel.eth_is_receiver_channel_send_done()) { bool need_to_send_completion = sender_buffer_channel.is_send_completion_pending(); - if (!sender_buffer_channel.is_send_completion_pending() && !eth_txq_is_busy()) { + if (!eth_txq_is_busy()) { static constexpr std::size_t ETH_BYTES_TO_WORDS_SHIFT = 4; + ASSERT((uint32_t)sender_buffer_channel.get_channel_bytes_sent_address() == + ((uint32_t)sender_buffer_channel.get_buffer_address() + (uint32_t)sender_buffer_channel.get_current_payload_size() - (uint32_t)sizeof(eth_channel_sync_t))); + *sender_buffer_channel.get_channel_bytes_sent_address() = sender_buffer_channel.get_current_payload_size(); + *sender_buffer_channel.get_channel_bytes_acked_address() = 0; + eth_send_bytes_over_channel_payload_only( sender_buffer_channel.get_buffer_address(), sender_buffer_channel.get_remote_eth_buffer_address(), sender_buffer_channel.get_current_payload_size(), - sender_buffer_channel.get_eth_transaction_channel(), sender_buffer_channel.get_current_payload_size(), sender_buffer_channel.get_current_payload_size() >> ETH_BYTES_TO_WORDS_SHIFT); - sender_buffer_channel.set_send_completion_pending(true); - need_to_send_completion = true; - did_something = true; - } - - if (need_to_send_completion && !eth_txq_is_busy()) { - 
eth_send_payload_complete_signal_over_channel( - sender_buffer_channel.get_eth_transaction_channel(), sender_buffer_channel.get_current_payload_size()); - sender_buffer_channel.set_send_completion_pending(false); - sender_buffer_channel.goto_state(ChannelBuffer::WAITING_FOR_ETH); + sender_buffer_channel.advance_buffer_index(); + sender_buffer_channel.goto_state(ChannelBuffer::SENDER_WAITING_FOR_ETH); did_something = true; } } @@ -354,6 +426,7 @@ FORCE_INLINE bool sender_eth_send_data_sequence(ChannelBuffer &sende template FORCE_INLINE bool sender_notify_workers_if_buffer_available_sequence( ChannelBuffer &sender_buffer_channel, uint32_t &num_senders_complete) { + bool channel_done = false; if constexpr (EDM_CONFIG::TERMINATION_MODE == EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED) { channel_done = sender_buffer_channel.all_messages_moved(); @@ -367,7 +440,7 @@ FORCE_INLINE bool sender_notify_workers_if_buffer_available_sequence( sender_buffer_channel.increment_worker_semaphores(); if (!channel_done) { - sender_buffer_channel.goto_state(ChannelBuffer::WAITING_FOR_WORKER); + sender_buffer_channel.goto_state(ChannelBuffer::SENDER_WAITING_FOR_WORKER); } else { sender_buffer_channel.goto_state(ChannelBuffer::DONE); num_senders_complete++; @@ -384,19 +457,19 @@ FORCE_INLINE bool sender_eth_check_receiver_ack_sequence( bool transimission_acked_by_receiver = sender_buffer_channel.eth_is_receiver_channel_send_acked() || sender_buffer_channel.eth_is_receiver_channel_send_done(); if (transimission_acked_by_receiver) { - eth_clear_sender_channel_ack(sender_buffer_channel.get_eth_transaction_channel()); + sender_buffer_channel.eth_clear_sender_channel_ack(); sender_buffer_channel.increment_messages_moved(); - sender_buffer_channel.goto_state(ChannelBuffer::SIGNALING_WORKER); + sender_buffer_channel.goto_state(ChannelBuffer::SENDER_SIGNALING_WORKER); + + // Don't need to guard as we can unconditionally notify the workers right away now that + // we're in the current state 
sender_notify_workers_if_buffer_available_sequence(sender_buffer_channel, num_senders_complete); - did_something = true; } return did_something; } -/* - * - */ + template FORCE_INLINE bool sender_noc_receive_payload_ack_check_sequence( ChannelBuffer &sender_channel_buffer, uint32_t &num_senders_complete) { @@ -413,15 +486,14 @@ FORCE_INLINE bool sender_noc_receive_payload_ack_check_sequence( bool read_finished = sender_channel_buffer.is_local_semaphore_full(); if (read_finished) { - // We can clear the semaphore, and wait for space on receiver - // sender_channel_buffer.clear_local_semaphore(); - sender_channel_buffer.goto_state(ChannelBuffer::READY_FOR_ETH_TRANSFER); - did_something = true; + sender_channel_buffer.goto_state(ChannelBuffer::SENDER_READY_FOR_ETH_TRANSFER); erisc::datamover::sender_eth_send_data_sequence(sender_channel_buffer); + did_something = true; } return did_something; + } ///////////////////////////////////////////// @@ -434,9 +506,10 @@ FORCE_INLINE bool sender_noc_receive_payload_ack_check_sequence( template FORCE_INLINE bool receiver_eth_notify_workers_payload_available_sequence(ChannelBuffer &buffer_channel) { buffer_channel.clear_local_semaphore(); + uint32_t worker_semaphore_address = buffer_channel.worker_semaphore_l1_address; buffer_channel.increment_worker_semaphores(); - buffer_channel.goto_state(ChannelBuffer::WAITING_FOR_WORKER); + buffer_channel.goto_state(ChannelBuffer::RECEIVER_WAITING_FOR_WORKER); return true; } @@ -453,8 +526,8 @@ FORCE_INLINE bool receiver_eth_accept_payload_sequence( if (buffer_channel.eth_bytes_are_available_on_channel()) { if (!eth_txq_is_busy()) { - eth_receiver_channel_ack(buffer_channel.get_eth_transaction_channel(), eth_transaction_ack_word_addr); - buffer_channel.goto_state(ChannelBuffer::SIGNALING_WORKER); + buffer_channel.eth_receiver_channel_ack(eth_transaction_ack_word_addr); + buffer_channel.goto_state(ChannelBuffer::RECEIVER_SIGNALING_WORKER); did_something = true; // FIXME: Decouple these so we 
can still signal workers even if eth command queue is busy @@ -467,6 +540,7 @@ FORCE_INLINE bool receiver_eth_accept_payload_sequence( return did_something; } + /* * Does something if we are waiting for workers to complete their read and the read is complete: * - increment messages moved (that transfer is done) @@ -476,8 +550,7 @@ FORCE_INLINE bool receiver_eth_accept_payload_sequence( template FORCE_INLINE bool receiver_noc_read_worker_completion_check_sequence( ChannelBuffer &buffer_channel, - uint32_t &num_receivers_complete, - uint32_t eth_transaction_complete_addr) { + uint32_t &num_receivers_complete) { bool did_something = false; bool workers_are_finished_reading = buffer_channel.is_local_semaphore_full(); @@ -492,9 +565,11 @@ FORCE_INLINE bool receiver_noc_read_worker_completion_check_sequence( bool can_notify_sender_of_buffer_available = workers_are_finished_reading; if (can_notify_sender_of_buffer_available) { if (!eth_txq_is_busy()) { - eth_receiver_channel_done(buffer_channel.get_eth_transaction_channel()); + buffer_channel.eth_receiver_channel_done(); buffer_channel.increment_messages_moved(); + buffer_channel.advance_buffer_index(); + bool channel_done = false; if constexpr (EDM_CONFIG::TERMINATION_MODE == EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED) { channel_done = buffer_channel.all_messages_moved(); @@ -505,12 +580,11 @@ FORCE_INLINE bool receiver_noc_read_worker_completion_check_sequence( } if (!channel_done) { - buffer_channel.goto_state(ChannelBuffer::WAITING_FOR_ETH); + buffer_channel.goto_state(ChannelBuffer::RECEIVER_WAITING_FOR_ETH); } else { buffer_channel.goto_state(ChannelBuffer::DONE); num_receivers_complete++; } - did_something = true; } } @@ -518,6 +592,7 @@ FORCE_INLINE bool receiver_noc_read_worker_completion_check_sequence( return did_something; } + //////////////////////////// // DEPRECATED //////////////////////////// @@ -658,7 +733,6 @@ bool receiver_eth_accept_payload_sequence( if (!receive_pointers_full) { if 
(eth_bytes_are_available_on_channel(eth_receiver_ptr.index())) { - // DPRINT << "rx: accepting payload, sending receive ack on channel " << (uint32_t)eth_receiver_ptr << "\n"; eth_receiver_channel_ack(eth_receiver_ptr.index(), eth_channel_sync_ack_addr); eth_receiver_ptr.increment(); did_something = true; @@ -683,8 +757,6 @@ FORCE_INLINE bool receiver_noc_read_worker_completion_check_sequence( bool writes_finished = ncrisc_noc_nonposted_writes_sent(noc_index); #endif if (writes_finished) { - // DPRINT << "rx: accepting payload, sending receive ack on channel " << (uint32_t)noc_writer_buffer_ackptr - // << "\n"; noc_writer_buffer_ackptr.increment(); did_something = true; @@ -708,13 +780,9 @@ FORCE_INLINE bool receiver_eth_send_ack_to_sender_sequence( bool buffer_writes_flushed = ncrisc_noc_nonposted_writes_sent(noc_index); // bool buffer_writes_flushed = ncrisc_noc_nonposted_writes_flushed(noc_index); if (buffer_writes_flushed) { - // DPRINT << "rx: accepting payload, sending receive ack on channel " << (uint32_t)noc_writer_buffer_wrptr - // << "\n"; eth_receiver_channel_done(eth_receiver_ackptr.index()); num_eth_sends_acked++; eth_receiver_ackptr.increment(); - // DPRINT << "rx: Sending eth ack. 
ackptr incrementing to " << (uint32_t)eth_receiver_ackptr.index() << - // "\n"; did_something = true; } diff --git a/ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_datamover.cpp b/ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_datamover.cpp index 098103f5004..6017a9a4ef1 100644 --- a/ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_datamover.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_datamover.cpp @@ -109,8 +109,6 @@ struct sender_receiver_index_t { void kernel_main() { - // COMPILE TIME ARGS - // If true, will enable this erisc's sender functionality constexpr bool enable_sender_side = get_compile_time_arg_val(0) != 0; // If true, will enable this erisc's receiver functionality @@ -119,13 +117,22 @@ void kernel_main() { constexpr uint32_t num_senders = get_compile_time_arg_val(2); constexpr uint32_t num_receivers = get_compile_time_arg_val(3); - constexpr ttnn::ccl::EriscDataMoverBufferSharingMode edm_buffer_sharing_mode = + static constexpr ttnn::ccl::EriscDataMoverBufferSharingMode edm_buffer_sharing_mode = static_cast(get_compile_time_arg_val(4)); - constexpr ttnn::ccl::EriscDataMoverTerminationMode terminate_on_worker_signal = + static constexpr ttnn::ccl::EriscDataMoverTerminationMode terminate_on_worker_signal = static_cast(get_compile_time_arg_val(5)); - using EDM_CONFIG_T = erisc::datamover::EriscDatamoverConfig; + static constexpr bool use_compile_time_designated_handshake_sender = false;//get_compile_time_arg_val(6) != 0; + static constexpr bool is_handshake_sender = get_compile_time_arg_val(7) != 0; + + static constexpr uint32_t num_buffers_per_channel = get_compile_time_arg_val(8); + static constexpr uint32_t chip_id = get_compile_time_arg_val(9); + + static_assert(num_buffers_per_channel > 0, "compile time argument [9]: num_buffers_per_channel must be > 0"); + + using EDM_CONFIG_T = erisc::datamover:: + EriscDatamoverConfig; using ChannelBufferT = erisc::datamover::ChannelBuffer; std::array buffer_channels; @@ -233,23 +240,23 @@ void 
kernel_main() { if constexpr (enable_sender_side) { ChannelBufferT ¤t_sender = buffer_channels[send_recv_index.real_index.sender]; switch (current_sender.get_state()) { - case ChannelBufferT::STATE::WAITING_FOR_WORKER: + case ChannelBufferT::STATE::SENDER_WAITING_FOR_WORKER: did_something_sender = erisc::datamover::sender_noc_receive_payload_ack_check_sequence(current_sender, num_senders_complete); senders_in_progress = senders_in_progress && num_senders_complete != sender_num_channels; break; - case ChannelBufferT::STATE::READY_FOR_ETH_TRANSFER: + case ChannelBufferT::STATE::SENDER_READY_FOR_ETH_TRANSFER: did_something_sender = erisc::datamover::sender_eth_send_data_sequence(current_sender); break; - case ChannelBufferT::STATE::SIGNALING_WORKER: + case ChannelBufferT::STATE::SENDER_SIGNALING_WORKER: did_something_sender = erisc::datamover::sender_notify_workers_if_buffer_available_sequence( current_sender, num_senders_complete); senders_in_progress = senders_in_progress && num_senders_complete != sender_num_channels; break; - case ChannelBufferT::STATE::WAITING_FOR_ETH: + case ChannelBufferT::STATE::SENDER_WAITING_FOR_ETH: did_something_sender = erisc::datamover::sender_eth_check_receiver_ack_sequence(current_sender, num_senders_complete); senders_in_progress = senders_in_progress && num_senders_complete != sender_num_channels; @@ -266,19 +273,19 @@ void kernel_main() { ChannelBufferT ¤t_receiver = buffer_channels[send_recv_index.real_index.receiver]; switch (current_receiver.get_state()) { - case ChannelBufferT::STATE::WAITING_FOR_ETH: + case ChannelBufferT::STATE::RECEIVER_WAITING_FOR_ETH: did_something_receiver = erisc::datamover::receiver_eth_accept_payload_sequence(current_receiver, num_receivers_complete, eth_transaction_ack_word_addr); receivers_in_progress = receivers_in_progress && num_receivers_complete != receiver_num_channels; break; - case ChannelBufferT::STATE::SIGNALING_WORKER: + case ChannelBufferT::STATE::RECEIVER_SIGNALING_WORKER: 
did_something_receiver = erisc::datamover::receiver_eth_notify_workers_payload_available_sequence(current_receiver); break; - case ChannelBufferT::STATE::WAITING_FOR_WORKER: + case ChannelBufferT::STATE::RECEIVER_WAITING_FOR_WORKER: did_something_receiver = erisc::datamover::receiver_noc_read_worker_completion_check_sequence( - current_receiver, num_receivers_complete, eth_transaction_complete_addr); + current_receiver, num_receivers_complete); receivers_in_progress = receivers_in_progress && num_receivers_complete != receiver_num_channels; break; @@ -301,24 +308,38 @@ void kernel_main() { } } - for (uint32_t s = 0; s < num_senders + num_receivers; s++ ) { - auto const& channel = buffer_channels[s]; - // We need to explicitly check for channel send done because we may - // advance sender channel state as soon as we receive an ack. Since we - // may be the last active channel, and advance to done state just from ack - // from the receiver ("I got a payload"), then we need to wait for done - // at the very end here. Otherise if we invoke another erisc op back-to-back, - // we may mess up transaction state because it's possible for receiver of this - // op to send the completion done after that one has already started. - uint32_t wait_count = 0; - uint32_t wait_max = 50000; - while(!channel.eth_is_receiver_channel_send_done()) { - wait_count++; - if (wait_count > wait_max) { - - DEBUG_STATUS("STK"); - run_routing(); + { + for (uint32_t s = 0; s < num_senders + num_receivers; s++) { + auto &channel = buffer_channels[s]; + // We need to explicitly check for channel send done because we may + // advance sender channel state as soon as we receive an ack. Since we + // may be the last active channel, and advance to done state just from ack + // from the receiver ("I got a payload"), then we need to wait for done + // at the very end here. 
Otherwise if we invoke another erisc op back-to-back, + // we may mess up transaction state because it's possible for receiver of this + // op to send the completion done after that one has already started. + uint32_t wait_count = 0; + uint32_t wait_max = 5000000; + for (uint8_t buffer_index = 0; buffer_index < num_buffers_per_channel; buffer_index++) { wait_count = 0; + channel.buffer_index = buffer_index; + if (!channel.is_sender_side) { + if (!channel.eth_is_receiver_channel_send_done()) { + channel.eth_receiver_channel_done(); + } + } + } + for (uint8_t buffer_index = 0; buffer_index < num_buffers_per_channel; buffer_index++) { + if (channel.is_sender_side) { + while (!channel.eth_is_receiver_channel_send_done()) { + wait_count++; + if (wait_count > wait_max) { + DEBUG_STATUS("STK"); + run_routing(); + wait_count = 0; + } + } + } } } } diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp index 0a6ff1a1fc0..4248fdae0e0 100644 --- a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp @@ -58,10 +58,7 @@ struct WorkerTransferInfo { static std::size_t decide_number_of_edm_channels( ttnn::ccl::CCLOpConfig const& ccl_op_config, std::size_t max_num_workers, bool enable_bidirectional) { - return ccl_op_config.is_input_sharded() ? std::min( - ccl_op_config.get_shard_grid_size(), - std::min(max_num_workers, enable_bidirectional ? 8 : 4)) - : std::min(max_num_workers, enable_bidirectional ? 
8 : 4); } struct ReduceScatterWorkerArgBuilder { @@ -377,7 +374,7 @@ static void add_worker_config_to_edm_builders( device->worker_core_from_logical_core(worker_cores.at(w)).y)); } - // Get the expected message size in bytes for this worker + // Get the maximum message size we'd like to use. Not the actual packet size uint32_t expected_message_size_bytes = tensor_slicer.get_worker_slice_size_bytes(global_worker_idx); bool sender_enabled = true; // (!is_linear || !is_last_chip_in_chain); // update for linear @@ -385,7 +382,7 @@ static void add_worker_config_to_edm_builders( auto& sender_edm_builder = is_buffer_in_clockwise_direction_fn(c) ? clockwise_edm_builders.at(link) : counter_clockwise_edm_builders.at(link); log_trace(tt::LogOp, "Adding sender EDM channel"); - ttnn::ccl::EriscDatamoverBuilder::ChannelBufferInterface const& sender_channel_buffer_info = + ttnn::ccl::EriscDatamoverBuilder::ChannelBufferInterface const& sender_channel_buffer_info = sender_edm_builder.add_sender_channel( worker_sender_semaphore_id, 1, // cw_edm_channel_num_messages_to_send_per_transfer.at(c) * (ring_size - 1), @@ -403,7 +400,7 @@ static void add_worker_config_to_edm_builders( ? 
counter_clockwise_edm_builders.at(link) : clockwise_edm_builders.at(link); log_trace(tt::LogOp, "Adding receiver EDM channel"); - ttnn::ccl::EriscDatamoverBuilder::ChannelBufferInterface const& receiver_channel_buffer_info = + ttnn::ccl::EriscDatamoverBuilder::ChannelBufferInterface const& receiver_channel_buffer_info = receiver_edm_builder.add_receiver_channel( worker_receiver_semaphore_id, // Since we are in worker signal EDM termination mode, we don't need to set the actual number of @@ -726,17 +723,21 @@ operation::ProgramWithCallbacks reduce_scatter_with_workers( ttnn::ccl::CclOpTensorConfig::build_all_gather_tensor_config(input_tensor); std::unique_ptr output_tensor_config = ttnn::ccl::CclOpTensorConfig::build_all_gather_tensor_config(output_tensor); - uint32_t per_step_dim_size = input_tensor.get_legacy_shape()[scatter_split_dim] / ring_size; - uint32_t input_tensor_num_units_per_scatter_dim = - per_step_dim_size / tt::constants::TILE_WIDTH; // TODO: find the divisibility based on layout - TT_ASSERT(input_tensor_num_units_per_scatter_dim > 0); - uint32_t max_num_workers = std::min(8, input_tensor_num_units_per_scatter_dim); + // The input tensor is fractured by ring_size, so we divide its volume by ring_size to get the per-slice element count + std::size_t input_tensor_n_elems_per_slice = input_tensor.volume() / ring_size; + uint32_t input_tensor_num_units_per_tensor_slice = + input_tensor_n_elems_per_slice / (tt::constants::TILE_WIDTH * tt::constants::TILE_HEIGHT); + + TT_ASSERT(input_tensor_num_units_per_tensor_slice > 0); + uint32_t max_num_workers = std::min(8, input_tensor_num_units_per_tensor_slice); bool enable_bidirectional = true; auto num_edm_channels = decide_number_of_edm_channels(op_config, max_num_workers, enable_bidirectional); log_trace(tt::LogOp, "num_edm_channels: {}", num_edm_channels); - auto edm_termination_mode =ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED; + auto edm_termination_mode = ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED; + + constexpr std::size_t 
num_buffers_per_channel = 1; // enable double buffering later auto const& edm_builder = create_erisc_datamover_builder( - num_edm_channels, op_config.get_page_size(), buffer_sharing_mode, edm_termination_mode); + num_edm_channels, op_config.get_page_size(), num_buffers_per_channel, buffer_sharing_mode, edm_termination_mode); TT_ASSERT(num_edm_channels > 0); Tensor const& local_chip_tensor = input_tensor; @@ -757,9 +758,6 @@ operation::ProgramWithCallbacks reduce_scatter_with_workers( auto const& topology_config = ttnn::ccl::RingTopology(device, topology, sender_device_id, receiver_device_id, num_links, ring_size, ring_index); - auto dim_slice_factors = tt::tt_metal::Shape(std::vector(local_chip_tensor.get_legacy_shape().rank(), 1)); - dim_slice_factors[-1] = ring_size; - CoreRangeSet const& worker_core_range = select_worker_cores(op_config, num_links, num_edm_channels); auto const& worker_cores = corerange_to_cores(worker_core_range, std::nullopt, true); @@ -870,7 +868,7 @@ operation::ProgramWithCallbacks reduce_scatter_with_workers( worker_receiver_kernels.push_back(receiver_kernel_id); worker_sender_kernels.push_back(sender_kernel_id); - TT_ASSERT(is_cb_buffering_sufficient_to_avoid_deadlock( + TT_FATAL(is_cb_buffering_sufficient_to_avoid_deadlock( worker_slice, cb_num_pages, cb_num_pages, diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/kernels/worker_interleaved_ring_reduce_scatter_reader.cpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/kernels/worker_interleaved_ring_reduce_scatter_reader.cpp index 7454a5bb0a3..ee5b8fc3fd8 100644 --- a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/kernels/worker_interleaved_ring_reduce_scatter_reader.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/kernels/worker_interleaved_ring_reduce_scatter_reader.cpp @@ -9,7 +9,6 @@ #include "dataflow_api.h" #include "debug/assert.h" #include "impl/buffers/buffer_constants.hpp" -#include "tensix_types.h" #include 
"ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp" #include "ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp" #include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" @@ -319,6 +318,7 @@ void kernel_main() { uint32_t n_pages = std::min(args.full_chunk_num_pages, worker_slice_n_pages - p); ASSERT(n_pages > 0); // Fetch from input tensor + read_wrapped_chunk_from_output_tensor( curr_tile_id, offset_into_worker_slice, diff --git a/ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp b/ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp index 3b933b91882..8e92d5dc568 100644 --- a/ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp @@ -9,6 +9,20 @@ #include #include +/* + * ------ ATTENTION ATTENTION ATTENTION ATTENTION ATTENTION ------ + * This file is intended to be usable across both host and device code. Therefore: + * + * DO NOT include any headers that are not host/device agnostic. + * DO NOT use any types that do not have fixed sizes across host and device. + * e.g. int32_t -> good (always 32 bits), int -> bad (size depends on platform) + * + * The reason for dual inclusion across host/device is because this code is used + * on device, but is further tested on host through gtests. This enables us to + * sweep functionality quickly and easily without involving end-to-end device kernel + * invocations and program creation. 
+ */ + namespace ttnn { namespace ccl { @@ -38,12 +52,12 @@ struct WorkerXY { uint16_t x; uint16_t y; - WorkerXY(uint16_t x, uint16_t y) : x(x), y(y) {} + constexpr WorkerXY(uint16_t x, uint16_t y) : x(x), y(y) {} - uint32_t to_uint32() const { return (y << 16) | x; } + constexpr uint32_t to_uint32() const { return (y << 16) | x; } - bool operator==(const WorkerXY &rhs) const { return x == rhs.x && y == rhs.y; } - bool operator!=(const WorkerXY &rhs) const { return !(*this == rhs); } + constexpr bool operator==(const WorkerXY &rhs) const { return x == rhs.x && y == rhs.y; } + constexpr bool operator!=(const WorkerXY &rhs) const { return !(*this == rhs); } }; struct coord_t { diff --git a/ttnn/cpp/ttnn/operations/ccl/shared_with_host/sharded_tensor_addr_gen.hpp b/ttnn/cpp/ttnn/operations/ccl/shared_with_host/sharded_tensor_addr_gen.hpp index b98fff16766..ce0730f32da 100644 --- a/ttnn/cpp/ttnn/operations/ccl/shared_with_host/sharded_tensor_addr_gen.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/shared_with_host/sharded_tensor_addr_gen.hpp @@ -1,6 +1,7 @@ // SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. // // SPDX-License-Identifier: Apache-2.0 +#pragma once #include #include