diff --git a/.github/workflows/ttnn-post-commit.yaml b/.github/workflows/ttnn-post-commit.yaml index 31a8b635cea..378f53732ee 100644 --- a/.github/workflows/ttnn-post-commit.yaml +++ b/.github/workflows/ttnn-post-commit.yaml @@ -60,6 +60,8 @@ jobs: fast_runtime_mode_off: true - name: ttnn examples and cpp tests cmd: ./build/test/ttnn/unit_tests_ttnn && ./tests/scripts/run_ttnn_examples.sh + - name: ttnn ccl cpp unit tests + cmd: ./build/test/ttnn/unit_tests_ttnn_ccl name: ${{ matrix.test-group.name }} ${{ inputs.arch }} ${{ inputs.runner-label }} env: TT_METAL_ENV: ${{ vars.TT_METAL_ENV }} diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index f26ba585ef7..a2500c51947 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -8,5 +8,5 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tt_metal/tt_metal) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tt_eager) # this should go away and be replaced with link to ttnn add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ttnn/unit_tests/gtests) -set(TESTS_DEPENDS_LIST metal_tests eager_tests unit_tests_ttnn test_multi_device galaxy_unit_tests_ttnn ttnn watcher_dump) +set(TESTS_DEPENDS_LIST metal_tests eager_tests unit_tests_ttnn unit_tests_ttnn_ccl test_multi_device galaxy_unit_tests_ttnn ttnn watcher_dump) add_custom_target(tests DEPENDS ${TESTS_DEPENDS_LIST}) diff --git a/tests/scripts/t3000/run_t3000_unit_tests.sh b/tests/scripts/t3000/run_t3000_unit_tests.sh index 06f237638c7..d7a8ce42166 100755 --- a/tests/scripts/t3000/run_t3000_unit_tests.sh +++ b/tests/scripts/t3000/run_t3000_unit_tests.sh @@ -34,6 +34,7 @@ run_t3000_ttnn_tests() { echo "LOG_METAL: Running run_t3000_ttnn_tests" WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml ./build/test/ttnn/test_multi_device WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml ./build/test/ttnn/unit_tests_ttnn + ./build/test/ttnn/unit_tests_ttnn_ccl WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest tests/ttnn/unit_tests/test_multi_device_trace.py ; fail+=$? WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest tests/ttnn/unit_tests/test_multi_device_events.py ; fail+=$? pytest -n auto tests/ttnn/unit_tests/test_multi_device.py ; fail+=$? diff --git a/tests/tt_eager/ops/ccl/test_ccl_helpers.cpp b/tests/tt_eager/ops/ccl/test_ccl_helpers.cpp index ba32f967434..3cd023cfa67 100644 --- a/tests/tt_eager/ops/ccl/test_ccl_helpers.cpp +++ b/tests/tt_eager/ops/ccl/test_ccl_helpers.cpp @@ -14,7 +14,8 @@ TEST(CclHelpers, CreateEriscDatamoverBuilder_Chan4_PageSize2048_RRBufferSharingM ttnn::ccl::EriscDataMoverBufferSharingMode buffer_sharing_mode = ttnn::ccl::EriscDataMoverBufferSharingMode::ROUND_ROBIN; ttnn::ccl::EriscDataMoverTerminationMode termination_mode = ttnn::ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED; - auto edm_builder = create_erisc_datamover_builder(num_channels, page_size, buffer_sharing_mode, termination_mode); + std::size_t num_buffers_per_channel = 1; + auto edm_builder = create_erisc_datamover_builder(num_channels, page_size, num_buffers_per_channel, buffer_sharing_mode, termination_mode); std::vector worker_semaphore_ids = {0, 1, 2, 3}; std::vector message_counts = {256, 512, 24, 1}; std::vector> const& worker_coords = { diff --git a/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/erisc_datamover_receiver_worker_reader.cpp b/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/erisc_datamover_receiver_worker_reader.cpp deleted file mode 100644 index d29943843c6..00000000000 --- a/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/erisc_datamover_receiver_worker_reader.cpp +++ /dev/null @@ -1,89 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include - -#include "dataflow_api.h" -#include "debug/dprint.h" - -FORCE_INLINE void fetch_chunk( - const uint32_t max_pages_per_chunk, - const uint32_t total_pages_to_read, - uint32_t& num_pages_read, - const uint32_t& cb_id, - const uint32_t& page_size, - uint64_t remote_l1_read_addr) { - const uint32_t num_pages_this_chunk = std::min(total_pages_to_read - num_pages_read, max_pages_per_chunk); - - for (uint32_t i = 0; i < num_pages_this_chunk; ++i) { - cb_reserve_back(cb_id, 1); - uint32_t l1_write_addr = get_write_ptr(cb_id); - noc_async_read(remote_l1_read_addr, l1_write_addr, page_size); - remote_l1_read_addr += page_size; - noc_async_read_barrier(); - cb_push_back(cb_id, 1); - } - - num_pages_read += num_pages_this_chunk; -} - -void kernel_main() { - const uint32_t eth_receiver_l1_base_addr = get_compile_time_arg_val(0); - const uint32_t eth_receiver_l1_sem_addr = get_compile_time_arg_val(1); - const uint32_t num_pages_per_read_chunk = get_arg_val(0); - const uint32_t total_pages_to_read = get_arg_val(1); - const uint32_t page_size = get_arg_val(2); - const uint32_t receiver_erisc_datamover_noc_x = get_arg_val(3); - const uint32_t receiver_erisc_datamover_noc_y = get_arg_val(4); - // Worker local L1 semaphore that erisc datamover signals to - const uint32_t receiver_read_sem_addr = get_semaphore(get_arg_val(5)); - - DPRINT << " rwr: args: eth_receiver_l1_base_addr="<< - eth_receiver_l1_base_addr<< - "\n\teth_receiver_l1_sem_addr="<(receiver_read_sem_addr); - - // Address of the buffer on the eth receiver, this is different per receiver worker core - const uint64_t eth_receiver_l1_base_noc_addr = - get_noc_addr(receiver_erisc_datamover_noc_x, receiver_erisc_datamover_noc_y, eth_receiver_l1_base_addr); - // Address of the semaphore on the eth receiver, this is the same per receiver worker core - const uint64_t eth_receiver_l1_semaphore_noc_addr = - get_noc_addr(receiver_erisc_datamover_noc_x, receiver_erisc_datamover_noc_y, eth_receiver_l1_sem_addr); - - DPRINT << " rwr: noc_index " << (uint32_t)noc_index << "\n"; - DPRINT << " rwr: my_x[0],my_y[0] " << (uint32_t)my_x[0] << "," << (uint32_t)my_y[0] << "\n"; - DPRINT << " rwr: my_x[1],my_y[1] " << (uint32_t)my_x[1] << "," << (uint32_t)my_y[1] << "\n"; - uint32_t num_pages_read = 0; - while (num_pages_read < total_pages_to_read) { - DPRINT << " rwr: page " << num_pages_read << " waiting for semaphore at " << (uint32_t)receiver_read_sem_addr << "\n"; - noc_semaphore_wait(receiver_read_semaphore_addr_ptr, 1); - DPRINT << " rwr: got semaphore signal from sender erisc\n"; - noc_semaphore_set(receiver_read_semaphore_addr_ptr, 0); - // Read page by page so that writer can be kicked off instead of being blocked waiting for full chunk to be read - // Look into perf/optimizations for this - DPRINT << " rwr: fetch chunk\n"; - fetch_chunk( - num_pages_per_read_chunk, - total_pages_to_read, - num_pages_read, - cb_id_in0, - page_size, - eth_receiver_l1_base_noc_addr); - DPRINT << " rwr: increment semaphore on eth core at address " << eth_receiver_l1_sem_addr << "\n"; - noc_semaphore_inc(eth_receiver_l1_semaphore_noc_addr, 1); - } - -} diff --git a/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/erisc_datamover_receiver_worker_sender.cpp b/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/erisc_datamover_receiver_worker_sender.cpp deleted file mode 100644 index ab0f6a92f5d..00000000000 --- a/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/erisc_datamover_receiver_worker_sender.cpp +++ /dev/null @@ -1,43 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include -#include "dataflow_api.h" -#include "debug/dprint.h" - -void kernel_main() { - const uint32_t dst_addr = get_arg_val(0); - constexpr bool dst_is_dram = get_compile_time_arg_val(0) == 1; - constexpr uint32_t num_pages_total = get_compile_time_arg_val(1); - constexpr uint32_t page_size = get_compile_time_arg_val(2); - - constexpr uint32_t cb_id_in0 = tt::CB::c_in0; - InterleavedAddrGen dest_addr_generator = { - .bank_base_address = dst_addr, .page_size = page_size}; - DPRINT << " rws: args: " << - "\n\tdst_addr="<(l1_read_addr) << "\n"; - uint64_t dst_noc_addr = get_noc_addr(p, dest_addr_generator); - noc_async_write(l1_read_addr, dst_noc_addr, page_size); - DPRINT << "rws: write barrier complete\n"; - noc_async_write_barrier(); - DPRINT << "rws: cb_pop_front\n"; - cb_pop_front(cb_id_in0, 1); - } - - // DPRINT << "rws: DONE\n"; - // ncrisc_noc_full_sync(); - // DPRINT << "rws: DONE DONE\n"; -} diff --git a/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/erisc_datamover_sender_worker_sender.cpp b/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/erisc_datamover_sender_worker_sender.cpp deleted file mode 100644 index de99988c729..00000000000 --- a/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/erisc_datamover_sender_worker_sender.cpp +++ /dev/null @@ -1,93 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include - -#include "dataflow_api.h" -#include "debug/dprint.h" -#include "noc_nonblocking_api.h" -#include "noc_parameters.h" - -FORCE_INLINE void send_chunk( - const uint32_t max_pages_per_chunk, - const uint32_t total_pages_to_send, - uint32_t& num_pages_sent, - const uint32_t& cb_id, - const uint32_t& page_size, - uint64_t remote_l1_write_addr, - volatile tt_l1_ptr uint32_t* writer_send_semaphore_addr_ptr - ) { - - const uint32_t num_pages_this_chunk = std::min(total_pages_to_send - num_pages_sent, max_pages_per_chunk); - for (uint32_t i = 0; i < num_pages_this_chunk; ++i) { - cb_wait_front(cb_id, 1); - uint32_t l1_read_addr = get_read_ptr(cb_id); - noc_async_write(l1_read_addr, remote_l1_write_addr, page_size); - remote_l1_write_addr += page_size; - noc_async_write_barrier(); - cb_pop_front(cb_id, 1); - } - num_pages_sent += num_pages_this_chunk; -} - -// Worker core - Data Movement Writer -> Sends to Erisc Data Mover (sender side). -// -> takes input from local cb and pushes to erisc L1 -void kernel_main() { - const uint32_t eth_l1_base_addr = get_arg_val(0); - // erisc l1 semaphore address - const uint32_t eth_sender_l1_sem_addr = get_arg_val(1); - const uint32_t writer_send_sem_addr = get_semaphore(get_arg_val(2)); - const uint32_t eth_sender_noc_x = get_arg_val(3); - const uint32_t eth_sender_noc_y = get_arg_val(4); - - constexpr uint32_t num_pages_per_send = get_compile_time_arg_val(0); - constexpr uint32_t total_pages_to_send = get_compile_time_arg_val(1); - constexpr uint32_t page_size = get_compile_time_arg_val(2); - - DPRINT << " sws: args:" << - "\n\teth_sender_l1_base_addr="<(writer_send_sem_addr); - - // This is different per writer core - const uint64_t eth_l1_sender_base_noc_addr = - get_noc_addr(eth_sender_noc_x, eth_sender_noc_y, eth_l1_base_addr); - // Used to signal eth sender that data is available. This is different per writer core - const uint64_t eth_l1_sender_semaphore_addr = - get_noc_addr(eth_sender_noc_x, eth_sender_noc_y, eth_sender_l1_sem_addr); - - // num_transfers = num_devices - 1 - uint32_t num_pages_sent = 0; - DPRINT << " sws: noc_index " << (uint32_t)noc_index << "\n"; - DPRINT << " sws: my_x[0],my_y[0] " << (uint32_t)my_x[0] << "," << (uint32_t)my_y[0] << "\n"; - DPRINT << " sws: my_x[1],my_y[1] " << (uint32_t)my_x[1] << "," << (uint32_t)my_y[1] << "\n"; - - uint32_t old_val_NIU_SLV_CMD_ACCEPTED = NOC_STATUS_READ_REG(noc_index, NIU_SLV_REQ_ACCEPTED); - uint32_t old_val_NIU_SLV_ATOMIC_RESP_SENT = NOC_STATUS_READ_REG(noc_index, NIU_SLV_ATOMIC_RESP_SENT); - uint32_t old_val_NIU_SLV_POSTED_ATOMIC_RECEIVED = NOC_STATUS_READ_REG(noc_index, NIU_SLV_POSTED_ATOMIC_RECEIVED); - uint32_t old_val_NIU_SLV_NONPOSTED_ATOMIC_SENT = NOC_STATUS_READ_REG(noc_index, NIU_SLV_NONPOSTED_ATOMIC_RECEIVED); - - bool diffed_NIU_SLV_CMD_ACCEPTED = false; - bool diffed_NIU_SLV_ATOMIC_RESP_SENT = false; - bool diffed_NIU_SLV_POSTED_ATOMIC_RECEIVED = false; - bool diffed_NIU_SLV_NONPOSTED_ATOMIC_SENT = false; - while (num_pages_sent < total_pages_to_send) { - noc_semaphore_wait(writer_send_semaphore_addr_ptr, 1); - - noc_semaphore_set(writer_send_semaphore_addr_ptr, 0); - send_chunk(num_pages_per_send, total_pages_to_send, num_pages_sent, cb_id_in0, page_size, eth_l1_sender_base_noc_addr, writer_send_semaphore_addr_ptr); - noc_semaphore_inc(eth_l1_sender_semaphore_addr, 1); - } -} diff --git a/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/ethernet_bidirectional_ubench.cpp b/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/ethernet_bidirectional_ubench.cpp index e0936a79645..b2ec93643a7 100644 --- a/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/ethernet_bidirectional_ubench.cpp +++ b/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/ethernet_bidirectional_ubench.cpp @@ -120,7 +120,6 @@ void kernel_main() { channel_addrs[sender_channel], channel_addrs[sender_channel], message_size_payload, - sender_channel, message_size_payload, message_size_payload_eth_words + 1); ready_to_send_payload &= ~(1 << s_i); diff --git a/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/ethernet_ping_latency_ubench_receiver.cpp b/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/ethernet_ping_latency_ubench_receiver.cpp index 6777713fc3c..2328352269f 100644 --- a/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/ethernet_ping_latency_ubench_receiver.cpp +++ b/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/ethernet_ping_latency_ubench_receiver.cpp @@ -46,7 +46,6 @@ FORCE_INLINE void run_loop_iteration( reinterpret_cast(channel_sync_addrs[i]), reinterpret_cast(channel_sync_addrs[i]), sizeof(eth_channel_sync_t), - i, // remove this field - it's superfluous sizeof(eth_channel_sync_t), sizeof(eth_channel_sync_t) >> 4); } @@ -67,7 +66,6 @@ FORCE_INLINE void run_loop_iteration( reinterpret_cast(channel_sync_addrs[i]), reinterpret_cast(channel_sync_addrs[i]), sizeof(eth_channel_sync_t), - i, // remove this field - it's superfluous sizeof(eth_channel_sync_t), sizeof(eth_channel_sync_t) >> 4); } diff --git a/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/ethernet_ping_latency_ubench_sender.cpp b/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/ethernet_ping_latency_ubench_sender.cpp index 72aa3da064c..ed9923692a5 100644 --- a/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/ethernet_ping_latency_ubench_sender.cpp +++ b/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/ethernet_ping_latency_ubench_sender.cpp @@ -40,7 +40,6 @@ FORCE_INLINE void run_loop_iteration( channel_addrs[i], channel_addrs[i], full_payload_size, - i, full_payload_size, full_payload_size_eth_words); } @@ -60,7 +59,6 @@ FORCE_INLINE void run_loop_iteration( channel_addrs[i], channel_addrs[i], full_payload_size, - i, full_payload_size, full_payload_size_eth_words); } diff --git a/tests/ttnn/unit_tests/gtests/CMakeLists.txt b/tests/ttnn/unit_tests/gtests/CMakeLists.txt index 5e5be0e16bd..070659e6431 100644 --- a/tests/ttnn/unit_tests/gtests/CMakeLists.txt +++ b/tests/ttnn/unit_tests/gtests/CMakeLists.txt @@ -8,8 +8,12 @@ set(TTNN_UNIT_TESTS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/test_reflect.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test_to_and_from_json.cpp ) +set(TTNN_CCL_UNIT_TESTS_SRC + ${CMAKE_CURRENT_SOURCE_DIR}/ccl/test_erisc_data_mover_with_workers.cpp +) add_executable(unit_tests_ttnn ${TTNN_UNIT_TESTS_SRC}) +add_executable(unit_tests_ttnn_ccl ${TTNN_CCL_UNIT_TESTS_SRC}) add_executable(test_multi_device ${CMAKE_CURRENT_SOURCE_DIR}/test_multi_device.cpp) add_executable(galaxy_unit_tests_ttnn ${CMAKE_CURRENT_SOURCE_DIR}/test_ccl_on_tg.cpp) @@ -28,5 +32,6 @@ endfunction() # Set up properties for both targets setup_ttnn_test_target(unit_tests_ttnn) +setup_ttnn_test_target(unit_tests_ttnn_ccl) setup_ttnn_test_target(test_multi_device) setup_ttnn_test_target(galaxy_unit_tests_ttnn) diff --git a/tests/ttnn/unit_tests/gtests/ccl/kernels/erisc_datamover_receiver_worker_reader.cpp b/tests/ttnn/unit_tests/gtests/ccl/kernels/erisc_datamover_receiver_worker_reader.cpp new file mode 100644 index 00000000000..481457ac8cd --- /dev/null +++ b/tests/ttnn/unit_tests/gtests/ccl/kernels/erisc_datamover_receiver_worker_reader.cpp @@ -0,0 +1,44 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include + +#include "dataflow_api.h" +#include "ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_adapters.hpp" + +void kernel_main() { + constexpr uint32_t eth_receiver_l1_base_addr = get_compile_time_arg_val(0); + constexpr uint32_t eth_receiver_l1_sem_addr = get_compile_time_arg_val(1); + constexpr uint32_t num_buffers_per_channel = get_compile_time_arg_val(2); + constexpr ttnn::ccl::EriscDataMoverTerminationMode termination_mode = static_cast(get_compile_time_arg_val(3)); + const uint32_t num_pages_per_read_chunk = get_arg_val(0); + const uint32_t total_pages_to_read = get_arg_val(1); + const uint32_t page_size = get_arg_val(2); + const uint32_t receiver_erisc_datamover_noc_x = get_arg_val(3); + const uint32_t receiver_erisc_datamover_noc_y = get_arg_val(4); + // Worker local L1 semaphore that erisc datamover signals to + volatile uint32_t* const receiver_read_sem_addr = reinterpret_cast(get_semaphore(get_arg_val(5))); + const uint32_t num_buffers_per_edm_channel = get_arg_val(6); + + ccl::edm::WorkerToEdmReader reader( + ttnn::ccl::WorkerXY(receiver_erisc_datamover_noc_x, receiver_erisc_datamover_noc_y), + eth_receiver_l1_base_addr, + num_buffers_per_channel, + eth_receiver_l1_sem_addr, + num_pages_per_read_chunk * page_size, + receiver_read_sem_addr); + + constexpr uint32_t cb_id_in0 = tt::CB::c_in0; + + for (uint32_t i = 0; i < total_pages_to_read; i += num_pages_per_read_chunk) { + bool last_message = (i + num_pages_per_read_chunk) >= total_pages_to_read; + uint32_t num_pages_to_read = std::min(total_pages_to_read - i, num_pages_per_read_chunk); + reader.wait_for_payload_available(); + reader.fetch_payload_blocking(cb_id_in0, num_pages_to_read, page_size, last_message); + } + + reader.close(); +} diff --git a/tests/ttnn/unit_tests/gtests/ccl/kernels/erisc_datamover_receiver_worker_sender.cpp b/tests/ttnn/unit_tests/gtests/ccl/kernels/erisc_datamover_receiver_worker_sender.cpp new file mode 100644 index 00000000000..fecb458407b --- /dev/null +++ b/tests/ttnn/unit_tests/gtests/ccl/kernels/erisc_datamover_receiver_worker_sender.cpp @@ -0,0 +1,34 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include "dataflow_api.h" + +void kernel_main() { + const uint32_t dst_addr = get_arg_val(0); + constexpr bool dst_is_dram = get_compile_time_arg_val(0) == 1; + constexpr uint32_t num_pages_total = get_compile_time_arg_val(1); + constexpr uint32_t page_size = get_compile_time_arg_val(2); + constexpr uint32_t pages_per_edm_buffer = get_compile_time_arg_val(3); + + constexpr uint32_t cb_id_in0 = tt::CB::c_in0; + InterleavedAddrGen dest_addr_generator = { + .bank_base_address = dst_addr, .page_size = page_size}; + + for (uint32_t p = 0; p < num_pages_total; p += pages_per_edm_buffer) { + uint32_t num_pages_to_send = std::min(pages_per_edm_buffer, num_pages_total - p); + cb_wait_front(cb_id_in0, num_pages_to_send); + uint32_t l1_read_addr = get_read_ptr(cb_id_in0); + + for (uint32_t i = 0; i < num_pages_to_send; ++i) { + uint64_t dst_noc_addr = get_noc_addr(p + i, dest_addr_generator); + noc_async_write(l1_read_addr, dst_noc_addr, page_size); + l1_read_addr += page_size; + } + noc_async_write_barrier(); + + cb_pop_front(cb_id_in0, num_pages_to_send); + } + +} diff --git a/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/erisc_datamover_sender_worker_reader.cpp b/tests/ttnn/unit_tests/gtests/ccl/kernels/erisc_datamover_sender_worker_reader.cpp similarity index 62% rename from tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/erisc_datamover_sender_worker_reader.cpp rename to tests/ttnn/unit_tests/gtests/ccl/kernels/erisc_datamover_sender_worker_reader.cpp index fc875d64f87..41d453e2793 100644 --- a/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/erisc_datamover_sender_worker_reader.cpp +++ b/tests/ttnn/unit_tests/gtests/ccl/kernels/erisc_datamover_sender_worker_reader.cpp @@ -7,11 +7,14 @@ #include "debug/dprint.h" void kernel_main() { - const uint32_t src_addr = get_arg_val(0); constexpr bool src_is_dram = get_compile_time_arg_val(0) == 1; constexpr uint32_t num_pages_to_read_total = get_compile_time_arg_val(1); constexpr uint32_t page_size = get_compile_time_arg_val(2); + constexpr uint32_t pages_per_edm_buffer = get_compile_time_arg_val(3); constexpr uint32_t cb_id_in0 = tt::CB::c_in0; + + const uint32_t src_addr = get_arg_val(0); + const InterleavedAddrGen source_address_generator = { .bank_base_address = src_addr, .page_size = page_size}; @@ -21,19 +24,22 @@ void kernel_main() { "\n\tnum_pages_to_read_total="<(pages_per_edm_buffer, num_pages_to_read_total - num_pages_read); + cb_reserve_back(cb_id_in0, pages_to_read); uint32_t local_l1_read_addr = get_write_ptr(cb_id_in0); - uint64_t src_noc_addr = get_noc_addr(num_pages_read, source_address_generator); - DPRINT << "swr: async_read\n"; - noc_async_read(src_noc_addr, local_l1_read_addr, page_size); - DPRINT << "swr: read_barrier\n"; + for (uint32_t p = 0; p < pages_to_read; ++p) { + + uint64_t src_noc_addr = get_noc_addr(num_pages_read + p, source_address_generator); + noc_async_read(src_noc_addr, local_l1_read_addr, page_size); + local_l1_read_addr += page_size; + } noc_async_read_barrier(); - DPRINT << "swr: cb_push_back\n"; - cb_push_back(cb_id_in0, 1); + cb_push_back(cb_id_in0, pages_to_read); + // DPRINT << "SR " << num_pages_read << "\n"; } + DPRINT << "SR DONE\n"; } diff --git a/tests/ttnn/unit_tests/gtests/ccl/kernels/erisc_datamover_sender_worker_sender.cpp b/tests/ttnn/unit_tests/gtests/ccl/kernels/erisc_datamover_sender_worker_sender.cpp new file mode 100644 index 00000000000..4cff4c2ec51 --- /dev/null +++ b/tests/ttnn/unit_tests/gtests/ccl/kernels/erisc_datamover_sender_worker_sender.cpp @@ -0,0 +1,59 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include + +#include "dataflow_api.h" +#include "ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_adapters.hpp" + + + +// Worker core - Data Movement Writer -> Sends to Erisc Data Mover (sender side). +// -> takes input from local cb and pushes to erisc L1 +void kernel_main() { + const uint32_t eth_l1_base_addr = get_arg_val(0); + // erisc l1 semaphore address + const uint32_t eth_sender_l1_sem_addr = get_arg_val(1); + volatile uint32_t* const writer_send_sem_addr = reinterpret_cast(get_semaphore(get_arg_val(2))); + const uint32_t eth_sender_noc_x = get_arg_val(3); + const uint32_t eth_sender_noc_y = get_arg_val(4); + const uint32_t num_buffers_per_edm_channel = get_arg_val(5); + + constexpr uint32_t num_pages_per_send = get_compile_time_arg_val(0); + constexpr uint32_t total_pages_to_send = get_compile_time_arg_val(1); + constexpr uint32_t page_size = get_compile_time_arg_val(2); + constexpr uint32_t num_buffers_per_channel = get_compile_time_arg_val(3); + constexpr ttnn::ccl::EriscDataMoverTerminationMode termination_mode = static_cast(get_compile_time_arg_val(4)); + + ccl::edm::WorkerToEdmSender sender( + ttnn::ccl::WorkerXY(eth_sender_noc_x, eth_sender_noc_y), + eth_l1_base_addr, + num_buffers_per_channel, + eth_sender_l1_sem_addr, + num_pages_per_send * page_size, + writer_send_sem_addr); + + std::array eth_buffer_addresses; + for (uint32_t i = 0; i < num_buffers_per_channel; i++) { + eth_buffer_addresses[i] = get_noc_addr( + eth_sender_noc_x, + eth_sender_noc_y, + eth_l1_base_addr + (i * ((num_pages_per_send * page_size) + 16)));//sizeof(eth_channel_sync_t)))); + } + + + constexpr uint32_t cb_id_in0 = tt::CB::c_in0; + + + uint32_t buffer_index = 0; + for (uint32_t p = 0; p < total_pages_to_send; p += num_pages_per_send) { + uint32_t num_pages_to_send = std::min(num_pages_per_send, total_pages_to_send - p); + sender.wait_for_empty_write_slot(); + sender.send_payload_blocking(cb_id_in0, num_pages_to_send, page_size); + } + + sender.close(); +} diff --git a/tests/ttnn/unit_tests/gtests/ccl/test_erisc_data_mover_with_workers.cpp b/tests/ttnn/unit_tests/gtests/ccl/test_erisc_data_mover_with_workers.cpp new file mode 100644 index 00000000000..f78c198258c --- /dev/null +++ b/tests/ttnn/unit_tests/gtests/ccl/test_erisc_data_mover_with_workers.cpp @@ -0,0 +1,1178 @@ + +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include +#include +#include + +#include "gtest/gtest.h" + +#include "device/tt_arch_types.h" +// #include "tt_backend_api_types.hpp" +#include "tt_metal/common/core_coord.h" +#include "tt_metal/common/math.hpp" +#include "tt_metal/detail/tt_metal.hpp" +#include "tt_metal/host_api.hpp" +#include "tt_metal/impl/kernels/kernel.hpp" +#include "tt_metal/test_utils/comparison.hpp" +#include "tt_metal/test_utils/df/df.hpp" +#include "tt_metal/test_utils/env_vars.hpp" +#include "tt_metal/test_utils/print_helpers.hpp" +#include "tt_metal/test_utils/stimulus.hpp" + +#include "ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp" + +// #include "impl/kernels/kernel_types.hpp" + +using namespace tt; +using namespace tt::test_utils; +using namespace tt::test_utils::df; + +// Taken from ccl_common... some dependency annoyance to deal with so just copying it here for now... resolve before merging +namespace ttnn { +namespace ccl { +void set_edm_runtime_args( + tt_metal::Program& program, + KernelHandle edm_kernel_handle, + ccl::EriscDatamoverBuilder const& edm_builder, + CoreCoord const& eth_core +) { + std::vector const& edm_clockwise_kernel_rt_args = edm_builder.emit_runtime_args(); + tt_metal::SetRuntimeArgs(program, edm_kernel_handle, eth_core, edm_clockwise_kernel_rt_args); + + std::stringstream ss; + ss << "EDM ARGS:\n"; + for (auto const& s : edm_clockwise_kernel_rt_args) { + ss << "\t" << s << "\n"; + } + log_info(tt::LogOp, "{}", ss.str()); +} + +} // namespace ccl +} // namespace ttnn + + +class N300TestDevice { + public: + N300TestDevice() : device_open(false) { + arch_ = tt::get_arch_from_string(tt::test_utils::get_env_arch_name()); + + num_devices_ = tt::tt_metal::GetNumAvailableDevices(); + if (arch_ == tt::ARCH::WORMHOLE_B0 and tt::tt_metal::GetNumAvailableDevices() >= 2 and + tt::tt_metal::GetNumPCIeDevices() >= 1) { + std::vector ids(num_devices_,0); + std::iota(ids.begin(), ids.end(), 0); + devices_ = tt::tt_metal::detail::CreateDevices(ids); + + } else { + TT_THROW("This suite can only be run on N300 Wormhole devices"); + } + device_open = true; + } + ~N300TestDevice() { + if (device_open) { + TearDown(); + } + } + + void TearDown() { + device_open = false; + for (auto [device_id, device_ptr] : devices_) { + tt::tt_metal::CloseDevice(device_ptr); + } + } + + std::map devices_; + tt::ARCH arch_; + size_t num_devices_; + + private: + bool device_open; +}; + +struct BankedConfig { + size_t num_pages; + size_t size_bytes; + size_t page_size_bytes; + BufferType input_buffer_type; // = BufferType::L1; + BufferType output_buffer_type; // = BufferType::L1; + tt::DataFormat l1_data_format; // = tt::DataFormat::Float16_b; +}; + +struct KernelXY { + uint16_t x; + uint16_t y; + + uint32_t to_uint32() const { return y << 16 | x; } +}; + +void generate_receiver_worker_kernels( + Program &program, + Device *device, + CoreCoord const& worker_core, + CoreCoord const& edm_core, + ttnn::ccl::EriscDatamoverBuilder::ChannelBufferInterface const& edm_channel, + uint32_t page_size, + uint32_t num_pages, + std::size_t num_buffers_per_edm_channel, + uint32_t num_pages_per_edm_buffer, + uint32_t worker_semaphore_address, + uint32_t dram_output_buffer_base_addr, // remote_output_buffers.at(i)->address(); + bool dest_is_dram, + ttnn::ccl::EriscDataMoverTerminationMode edm_termination_mode +) { + // Just want a dummy DF + uint32_t src0_cb_index = CB::c_in0; + tt::DataFormat df = page_size == 1024 ? tt::DataFormat::Bfp8 : + page_size == 2048 ? tt::DataFormat::Float16 : + tt::DataFormat::Float32; + tt_metal::CircularBufferConfig cb_src0_config = tt_metal::CircularBufferConfig(2 * num_pages_per_edm_buffer * page_size, {{src0_cb_index, df}}) + .set_page_size(src0_cb_index, page_size); + + CBHandle receiver_workers_cb = CreateCircularBuffer(program, worker_core, cb_src0_config); + std::vector receiver_worker_writer_compile_args{ + dest_is_dram, // + num_pages, // + page_size, + num_pages_per_edm_buffer}; + std::vector receiver_worker_writer_runtime_args{dram_output_buffer_base_addr}; + log_info(tt::LogTest, "\tReceiverWriter CT Args"); + for (auto const& arg : receiver_worker_writer_compile_args) { + log_info(tt::LogTest, "\t\t{}", arg); + } + log_info(tt::LogTest, "\tReceiverWriter RT Args"); + for (auto const& arg : receiver_worker_writer_runtime_args) { + log_info(tt::LogTest, "\t\t{}", arg); + } + + + std::vector receiver_worker_receiver_compile_args{ + edm_channel.eth_buffer_l1_address, + edm_channel.eth_semaphore_l1_address, + num_buffers_per_edm_channel, + edm_termination_mode + }; + std::vector receiver_worker_receiver_runtime_args{ + num_pages_per_edm_buffer, + num_pages, + page_size, + (uint32_t)device->ethernet_core_from_logical_core(edm_core).x, + (uint32_t)device->ethernet_core_from_logical_core(edm_core).y, + worker_semaphore_address, + num_buffers_per_edm_channel}; + log_info(tt::LogTest, "\tReceiverReader CT Args"); + for (auto const& arg : receiver_worker_receiver_compile_args) { + log_info(tt::LogTest, "\t\t{}", arg); + } + log_info(tt::LogTest, "\tReceiverReader RT Args"); + for (auto const& arg : receiver_worker_receiver_runtime_args) { + log_info(tt::LogTest, "\t\t{}", arg); + } + + + auto receiver_worker_receiver_kernel = tt_metal::CreateKernel( + program, + "tests/ttnn/unit_tests/gtests/ccl/kernels/erisc_datamover_receiver_worker_reader.cpp", + worker_core, + tt_metal::DataMovementConfig{ + .processor = tt_metal::DataMovementProcessor::RISCV_0, + .noc = tt_metal::NOC::RISCV_0_default, + .compile_args = receiver_worker_receiver_compile_args}); + auto receiver_worker_writer_kernel = tt_metal::CreateKernel( + program, + "tests/ttnn/unit_tests/gtests/ccl/kernels/erisc_datamover_receiver_worker_sender.cpp", + worker_core, + tt_metal::DataMovementConfig{ + .processor = tt_metal::DataMovementProcessor::RISCV_1, + .noc = tt_metal::NOC::RISCV_1_default, + .compile_args = receiver_worker_writer_compile_args}); + tt_metal::SetRuntimeArgs( + program, + receiver_worker_receiver_kernel, + worker_core, + receiver_worker_receiver_runtime_args); + tt_metal::SetRuntimeArgs( + program, + receiver_worker_writer_kernel, + worker_core, + receiver_worker_writer_runtime_args); +} + +void generate_sender_worker_kernels( + Program &program, + Device *device, + CoreCoord const& worker_core, + CoreCoord const& edm_core, + ttnn::ccl::EriscDatamoverBuilder::ChannelBufferInterface const& edm_channel, + uint32_t page_size, + uint32_t num_pages_total, + std::size_t num_buffers_per_edm_channel, + uint32_t num_pages_per_edm_buffer, + uint32_t worker_semaphore_address, + uint32_t dram_output_buffer_base_addr, // remote_output_buffers.at(i)->address(); + bool src_is_dram, + ttnn::ccl::EriscDataMoverTerminationMode edm_termination_mode +) { + std::vector sender_worker_reader_compile_args{ + src_is_dram, // + num_pages_total, // + page_size, + num_pages_per_edm_buffer}; + std::vector sender_worker_reader_runtime_args{dram_output_buffer_base_addr}; + + log_info(tt::LogTest, "\tSenderReader CT Args"); + for (auto const& arg : sender_worker_reader_compile_args) { + log_info(tt::LogTest, "\t\t{}", arg); + } + log_info(tt::LogTest, "\tSenderReader RT Args"); + for (auto const& arg : sender_worker_reader_runtime_args) { + log_info(tt::LogTest, "\t\t{}", arg); + } + + std::vector sender_worker_writer_compile_args{ + num_pages_per_edm_buffer, + num_pages_total, + page_size, + num_buffers_per_edm_channel, + edm_termination_mode + }; + std::vector sender_worker_writer_runtime_args{ + edm_channel.eth_buffer_l1_address, + edm_channel.eth_semaphore_l1_address, + worker_semaphore_address, + (uint32_t)device->ethernet_core_from_logical_core(edm_core).x, + (uint32_t)device->ethernet_core_from_logical_core(edm_core).y, + num_buffers_per_edm_channel + }; + uint32_t src0_cb_index = CB::c_in0; + log_info(tt::LogTest, "\tSenderWriter CT Args"); + for (auto const& arg : sender_worker_writer_compile_args) { + log_info(tt::LogTest, "\t\t{}", arg); + } + log_info(tt::LogTest, "\tSenderWriter RT Args"); + for (auto const& arg : sender_worker_writer_runtime_args) { + log_info(tt::LogTest, "\t\t{}", arg); + } + // Just want a dummy DF + tt::DataFormat df = page_size == 1024 ? tt::DataFormat::Bfp8 : + page_size == 2048 ? tt::DataFormat::Float16 : + tt::DataFormat::Float32; + tt_metal::CircularBufferConfig cb_src0_config = tt_metal::CircularBufferConfig(2 * num_pages_per_edm_buffer * page_size, {{src0_cb_index, df}}) + .set_page_size(src0_cb_index, page_size); + CBHandle sender_workers_cb = CreateCircularBuffer(program, worker_core, cb_src0_config); + auto sender_worker_reader_kernel = tt_metal::CreateKernel( + program, + "tests/ttnn/unit_tests/gtests/ccl/kernels/erisc_datamover_sender_worker_reader.cpp", + worker_core, + tt_metal::DataMovementConfig{ + .processor = tt_metal::DataMovementProcessor::RISCV_0, + .noc = tt_metal::NOC::RISCV_0_default, + .compile_args = sender_worker_reader_compile_args}); + auto sender_worker_writer_kernel = tt_metal::CreateKernel( + program, + "tests/ttnn/unit_tests/gtests/ccl/kernels/erisc_datamover_sender_worker_sender.cpp", + worker_core, + tt_metal::DataMovementConfig{ + .processor = tt_metal::DataMovementProcessor::RISCV_1, + .noc = tt_metal::NOC::RISCV_1_default, + .compile_args = sender_worker_writer_compile_args}); + tt_metal::SetRuntimeArgs( + program, + sender_worker_reader_kernel, + worker_core, + sender_worker_reader_runtime_args); + tt_metal::SetRuntimeArgs( + program, + sender_worker_writer_kernel, + worker_core, + sender_worker_writer_runtime_args); +} + + + +bool RunWriteBWTest( + tt_metal::Device* sender_device, + tt_metal::Device* receiver_device, + + const CoreCoord& eth_sender_core, + const CoreCoord& eth_receiver_core, + + const uint32_t num_local_sender_channels, + const uint32_t num_remote_sender_channels, + + // default is 1. + // 2 means channel is double buffered + // 3 means channel is triple buffered + // ... and so on + std::size_t num_buffers_per_edm_channel, + + const uint32_t page_size, + const uint32_t num_pages_total, + bool src_is_dram, + bool dest_is_dram, + + ttnn::ccl::EriscDataMoverTerminationMode edm_termination_mode +) { + + std::size_t tensor_size_bytes = num_pages_total * page_size; + + tt_metal::Program sender_program{}; + tt_metal::Program receiver_program{}; + + std::vector worker_cores; + { + std::size_t row = 0; + std::size_t col = 0; + for (uint32_t i = 0; i < num_local_sender_channels + num_remote_sender_channels; i++) { + worker_cores.push_back(CoreCoord(col, row)); + col++; + if (col == 8) { + col = 0; + row++; + } + } + } + + std::vector local_worker_semaphore_addresses; + std::vector remote_worker_semaphore_addresses; + for (auto const& worker_core : worker_cores) { + local_worker_semaphore_addresses.push_back(tt::tt_metal::CreateSemaphore(sender_program, worker_core, 0)); + remote_worker_semaphore_addresses.push_back(tt::tt_metal::CreateSemaphore(receiver_program, worker_core, 0)); + log_info(tt::LogTest, "worker_core=(x={},y={}), local_worker_semaphore_address={}, remote_worker_semaphore_address={}", + worker_core.x, worker_core.y, local_worker_semaphore_addresses.back(), remote_worker_semaphore_addresses.back()); + } + + // Generate inputs + //////////////////////////////////////////////////////////////////////////// + // SETUP THE INPUT CB + //////////////////////////////////////////////////////////////////////////// + auto inputs = generate_uniform_random_vector(0, 100, tensor_size_bytes / sizeof(uint32_t)); + std::iota(inputs.begin(), inputs.end(), 0); + + BankedConfig test_config = BankedConfig{ + .num_pages = num_pages_total, + .size_bytes = tensor_size_bytes, + .page_size_bytes = page_size, + .input_buffer_type = src_is_dram ? BufferType::DRAM : BufferType::L1, + .output_buffer_type = dest_is_dram ? BufferType::DRAM : BufferType::L1, + .l1_data_format = tt::DataFormat::Float16_b}; + + auto local_input_buffer = CreateBuffer(InterleavedBufferConfig{ + sender_device, test_config.size_bytes, test_config.page_size_bytes, test_config.input_buffer_type}); + auto remote_input_buffer = CreateBuffer(InterleavedBufferConfig{ + receiver_device, test_config.size_bytes, test_config.page_size_bytes, test_config.input_buffer_type}); + bool input_is_dram = test_config.input_buffer_type == BufferType::DRAM; + + tt_metal::detail::WriteToBuffer(local_input_buffer, inputs); + tt_metal::detail::WriteToBuffer(remote_input_buffer, inputs); + + std::vector local_input_buffer_addresses(num_local_sender_channels, local_input_buffer->address()); + std::vector remote_input_buffer_addresses(num_remote_sender_channels, remote_input_buffer->address()); + + //////////////////////////////////////////////////////////////////////////// + // EMPTY INITIALIZE THE OUTPUT CB + //////////////////////////////////////////////////////////////////////////// + + // Clear expected value at ethernet L1 address + std::vector all_zeros(inputs.size(), 0); + + std::vector> local_output_buffers; + std::vector> remote_output_buffers; + + for (std::size_t i = 0; i < num_local_sender_channels; i++) { + auto output_buffer = CreateBuffer(InterleavedBufferConfig{ + receiver_device, test_config.size_bytes, test_config.page_size_bytes, test_config.output_buffer_type}); + remote_output_buffers.push_back(output_buffer); + } + for (std::size_t i = 0; i < num_remote_sender_channels; i++) { + auto output_buffer = CreateBuffer(InterleavedBufferConfig{ + sender_device, test_config.size_bytes, test_config.page_size_bytes, test_config.output_buffer_type}); + local_output_buffers.push_back(output_buffer); + } + + bool output_is_dram = test_config.output_buffer_type == BufferType::DRAM; + for (auto buffer_id : local_output_buffers) { + tt_metal::detail::WriteToBuffer(buffer_id, all_zeros); + } + for (auto buffer_id : remote_output_buffers) { + tt_metal::detail::WriteToBuffer(buffer_id, all_zeros); + } + + uint32_t erisc_handshake_address = eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE; + + uint32_t chip0_next_buffer_address = erisc_handshake_address + 16; + std::vector chip0_edm_args = {erisc_handshake_address}; + uint32_t chip0_sender_channels_offset = 0; + uint32_t chip0_arg_sender_num_channels = 1; + + //////////////////////////////////////////////////////////////////////////// + // EDM Builder Setup + //////////////////////////////////////////////////////////////////////////// + + ttnn::ccl::EriscDataMoverBufferSharingMode buffer_sharing_mode = ttnn::ccl::EriscDataMoverBufferSharingMode::NOT_SHARED; + + const std::size_t num_edm_channels = num_local_sender_channels + num_remote_sender_channels; + // TODO: Allow an override of EDM buffer size + auto local_chip_edm_builder = ttnn::ccl::create_erisc_datamover_builder( + num_edm_channels, page_size, num_buffers_per_edm_channel, buffer_sharing_mode, edm_termination_mode); + auto remote_chip_edm_builder = ttnn::ccl::create_erisc_datamover_builder( + num_edm_channels, page_size, num_buffers_per_edm_channel, buffer_sharing_mode, edm_termination_mode); + + const uint32_t num_bytes_per_send = local_chip_edm_builder.get_eth_buffer_size_bytes(); + const uint32_t pages_per_send = num_bytes_per_send / page_size; + TT_ASSERT(num_bytes_per_send > 0); + TT_ASSERT(num_bytes_per_send >= page_size); + TT_ASSERT(num_bytes_per_send >= page_size); + const uint32_t num_messages_to_send = (((num_pages_total * page_size) - 1) / num_bytes_per_send) + 1; + log_info(tt::LogTest, "num_bytes_per_send={}", num_bytes_per_send); + log_info(tt::LogTest, "page_size={}", page_size); + log_info(tt::LogTest, "pages_per_send={}", pages_per_send); + log_info(tt::LogTest, "num_messages_to_send={}", num_messages_to_send); + std::vector num_messages_to_send_over_channel(num_edm_channels, num_messages_to_send); + + std::vector local_sender_workers; + std::vector remote_receiver_workers; + std::vector remote_sender_workers; + std::vector local_receiver_workers; + + // setup edm channels + std::vector local_edm_channels; + std::vector remote_edm_channels; + for (uint32_t i = 0; i < num_local_sender_channels; i++) { + auto const& worker_core_local_chip = ttnn::ccl::WorkerXY( + sender_device->worker_core_from_logical_core(worker_cores.at(i)).x, + sender_device->worker_core_from_logical_core(worker_cores.at(i)).y); + auto const& worker_core_remote_chip = ttnn::ccl::WorkerXY( + receiver_device->worker_core_from_logical_core(worker_cores.at(i)).x, + receiver_device->worker_core_from_logical_core(worker_cores.at(i)).y); + ttnn::ccl::EriscDatamoverBuilder::ChannelBufferInterface const& local_sender_channel_buffer = local_chip_edm_builder.add_sender_channel( + local_worker_semaphore_addresses.at(i), + num_messages_to_send_over_channel.at(i), + {worker_core_local_chip}); + local_edm_channels.push_back(local_sender_channel_buffer); + ttnn::ccl::EriscDatamoverBuilder::ChannelBufferInterface const& remote_receiver_channel_buffer = remote_chip_edm_builder.add_receiver_channel( + remote_worker_semaphore_addresses.at(i), + num_messages_to_send_over_channel.at(i), + {worker_core_remote_chip}); + remote_edm_channels.push_back(remote_receiver_channel_buffer); + } + for (uint32_t i = num_local_sender_channels; i < num_local_sender_channels + num_remote_sender_channels; i++) { + auto const& worker_core_remote_chip = ttnn::ccl::WorkerXY( + receiver_device->worker_core_from_logical_core(worker_cores.at(i)).x, + receiver_device->worker_core_from_logical_core(worker_cores.at(i)).y); + auto const& worker_core_local_chip = ttnn::ccl::WorkerXY( + sender_device->worker_core_from_logical_core(worker_cores.at(i)).x, + sender_device->worker_core_from_logical_core(worker_cores.at(i)).y); + ttnn::ccl::EriscDatamoverBuilder::ChannelBufferInterface const& local_receiver_channel_buffer = local_chip_edm_builder.add_receiver_channel( + local_worker_semaphore_addresses.at(i), + num_messages_to_send_over_channel.at(i), + {worker_core_remote_chip}); + local_edm_channels.push_back(local_receiver_channel_buffer); + ttnn::ccl::EriscDatamoverBuilder::ChannelBufferInterface const& remote_sender_channel_buffer = remote_chip_edm_builder.add_sender_channel( + remote_worker_semaphore_addresses.at(i), + num_messages_to_send_over_channel.at(i), + {worker_core_local_chip}); + remote_edm_channels.push_back(remote_sender_channel_buffer); + } + + //////////////////////////////////////////////////////////////////////////// + // Build Workers + //////////////////////////////////////////////////////////////////////////// + log_info(tt::LogTest, "Generating local_sender -> remote_receiver workers"); + for (uint32_t i = 0; i < num_local_sender_channels; i++) { + auto const& worker_core = worker_cores.at(i); + log_info(tt::LogTest, "Worker {}. On Core x={},y={}", i, worker_core.x, worker_core.y); + generate_sender_worker_kernels( + sender_program, + sender_device, + worker_core, + eth_sender_core, + local_edm_channels.at(i), + page_size, + num_pages_total, + num_buffers_per_edm_channel, + pages_per_send, + local_worker_semaphore_addresses.at(i), + local_input_buffer_addresses.at(i), + src_is_dram, + edm_termination_mode + ); + generate_receiver_worker_kernels( + receiver_program, + receiver_device, + worker_core, + eth_receiver_core, + remote_edm_channels.at(i), + page_size, + num_pages_total, + num_buffers_per_edm_channel, + pages_per_send, + remote_worker_semaphore_addresses.at(i), + remote_output_buffers.at(i)->address(), + dest_is_dram, + edm_termination_mode + ); + + } + log_info(tt::LogTest, "Generating remote_sender -> local_receiver workers"); + for (uint32_t i = 0; i < num_remote_sender_channels; i++) { + log_info(tt::LogTest, "Worker {}", i); + auto const& worker_core = worker_cores.at(i + num_local_sender_channels); + generate_sender_worker_kernels( + receiver_program, + receiver_device, + worker_core, + eth_receiver_core, + remote_edm_channels.at(i + num_local_sender_channels), + page_size, + num_pages_total, + num_buffers_per_edm_channel, + pages_per_send, + remote_worker_semaphore_addresses.at(i + num_local_sender_channels), + remote_input_buffer_addresses.at(i), + src_is_dram, + edm_termination_mode + ); + + generate_receiver_worker_kernels( + sender_program, + sender_device, + worker_core, + eth_sender_core, + local_edm_channels.at(i + num_local_sender_channels), + page_size, + num_pages_total, + num_buffers_per_edm_channel, + pages_per_send, + local_worker_semaphore_addresses.at(i + num_local_sender_channels), + local_output_buffers.at(i)->address(), + dest_is_dram, + edm_termination_mode + ); + } + + //////////////////////////////////////////////////////////////////////////// + // Build EDMs + //////////////////////////////////////////////////////////////////////////// + auto local_edm_kernel = ttnn::ccl::generate_edm_kernel( + sender_program, + sender_device, + local_chip_edm_builder, + eth_sender_core, + NOC::NOC_0); + set_edm_runtime_args( + sender_program, + local_edm_kernel, + local_chip_edm_builder, + eth_sender_core + ); + + auto remote_edm_kernel = ttnn::ccl::generate_edm_kernel( + receiver_program, + receiver_device, + remote_chip_edm_builder, + eth_receiver_core, + NOC::NOC_0); + set_edm_runtime_args( + receiver_program, + remote_edm_kernel, + remote_chip_edm_builder, + eth_receiver_core + ); + + //////////////////////////////////////////////////////////////////////////// + // Compile and Execute Application + //////////////////////////////////////////////////////////////////////////// + + try { + tt::tt_metal::detail::CompileProgram(sender_device, sender_program); + tt::tt_metal::detail::CompileProgram(receiver_device, receiver_program); + } catch (std::exception& e) { + log_error("Failed compile: {}", e.what()); + throw e; + } + + log_info(tt::LogTest, "Running..."); + + if (std::getenv("TT_METAL_SLOW_DISPATCH_MODE")) { + std::thread th2 = std::thread([&] { tt_metal::detail::LaunchProgram(sender_device, sender_program); }); + std::thread th1 = std::thread([&] { tt_metal::detail::LaunchProgram(receiver_device, receiver_program); }); + + th2.join(); + th1.join(); + } else { + tt_metal::EnqueueProgram(sender_device->command_queue(), sender_program, false); + tt_metal::EnqueueProgram(receiver_device->command_queue(), receiver_program, false); + + log_debug(tt::LogTest, "Calling Finish"); + tt_metal::Finish(sender_device->command_queue()); + tt_metal::Finish(receiver_device->command_queue()); + } + // tt::tt_metal::detail::DumpDeviceProfileResults(receiver_device); + // tt::tt_metal::detail::DumpDeviceProfileResults(sender_device); + log_info(tt::LogTest, "Reading back outputs"); + + auto is_output_correct = [&all_zeros, &inputs](std::shared_ptr output_buffer) { + constexpr bool debug_mode = false; + std::vector readback_data_vec; // init to 0 data for easier debug + readback_data_vec.reserve(all_zeros.size()); + std::fill(readback_data_vec.begin(), readback_data_vec.end(), 0); + + tt_metal::detail::ReadFromBuffer(output_buffer, readback_data_vec); + log_info(tt::LogTest, "Checking outputs"); + if (readback_data_vec.size() != inputs.size()) { + log_error(tt::LogTest, "Output size mismatch: expected {} got {}", inputs.size(), readback_data_vec.size()); + return false; + } + bool pass = (readback_data_vec == inputs); + TT_ASSERT( + std::any_of(inputs.begin(), inputs.end(), [](uint32_t x) { return x != 0; }), + "Input buffer expected to not be all 0"); + if (not pass) { + log_error("Output mismatch"); + if (debug_mode) { + std::size_t num_printed_mismatches = 0; + for (size_t i = 0; i < readback_data_vec.size() && num_printed_mismatches < 64; i++) { + if (readback_data_vec[i] != inputs[i]) { + log_error("[{}]: expected {} got {}", i, inputs[i], readback_data_vec[i]); + num_printed_mismatches++; + } + } + log_error("... (remaining mismatches omitted)"); + } + } + return pass; + }; + + bool pass = true; + constexpr bool enable_check = true; + if constexpr(enable_check) { + for (auto const& output_buffer : local_output_buffers) { + pass &= is_output_correct(output_buffer); + } + for (auto const& output_buffer : remote_output_buffers) { + pass &= is_output_correct(output_buffer); + } + } + + + return pass; +} + +int TestEntrypoint( + const uint32_t num_local_sender_channels, + const uint32_t num_remote_sender_channels, + // default is 1. + // 2 means channel is double buffered + // 3 means channel is triple buffered + // ... and so on + std::size_t num_buffers_per_edm_channel, + const uint32_t page_size, + const uint32_t num_pages_total, + const bool src_is_dram, + const bool dest_is_dram, + ttnn::ccl::EriscDataMoverTerminationMode termination_mode +) { + // argv[0]: program + // argv[1]: buffer_size_bytes + // argv[2]: num_loops + + auto arch = tt::get_arch_from_string(tt::test_utils::get_env_arch_name()); + auto num_devices = tt::tt_metal::GetNumAvailableDevices(); + if (num_devices < 2) { + log_info("This test can only be run on n300 devices"); + return 0; + } + if (arch == tt::ARCH::GRAYSKULL) { + log_info("Test must be run on WH"); + return 0; + } + + N300TestDevice test_fixture; + + const auto& device_0 = test_fixture.devices_.at(0); + + auto const& active_eth_cores = device_0->get_active_ethernet_cores(true); + auto eth_sender_core_iter = active_eth_cores.begin(); + auto eth_sender_core_iter_end = active_eth_cores.end(); + chip_id_t device_id = std::numeric_limits::max(); + tt_xy_pair eth_receiver_core; + bool initialized = false; + tt_xy_pair eth_sender_core; + do { + TT_FATAL(eth_sender_core_iter != eth_sender_core_iter_end); + std::tie(device_id, eth_receiver_core) = device_0->get_connected_ethernet_core(*eth_sender_core_iter); + eth_sender_core = *eth_sender_core_iter; + eth_sender_core_iter++; + } while (device_id != 1); + TT_ASSERT(device_id == 1); + const auto& device_1 = test_fixture.devices_.at(device_id); + + bool success = false; + try { + success = RunWriteBWTest( + device_0, + device_1, + + eth_sender_core, + eth_receiver_core, + + num_local_sender_channels, // from args + num_remote_sender_channels, // from args + num_buffers_per_edm_channel, // from args + + page_size, + num_pages_total, + src_is_dram, + dest_is_dram, + + termination_mode); + } catch (std::exception& e) { + log_error("Caught exception: {}", e.what()); + test_fixture.TearDown(); + return -1; + } + + test_fixture.TearDown(); + + return success ? 0 : -1; +} + +//////////////////////////////////////////////////////////////////// +/// MESSAGE COUNT TERMINATION MODE +//////////////////////////////////////////////////////////////////// + +TEST(WorkerEdmDatapath, DISABLED_MergedPayloadAndSignal_1ChannelForward_0ChannelsReverse_1BufferPerChannel_2048PageSize_100kPages_MessageCountTermination) { + const uint32_t num_local_sender_channels = 1; + const uint32_t num_remote_sender_channels = 0; + const uint32_t num_buffers_per_edm_channel = 1; + const uint32_t page_size = 2048; + const uint32_t num_pages_total = 100000; + const bool src_is_dram = true; + const bool dest_is_dram = true; + const bool merge_message_and_signal = true; + auto termination_mode = ttnn::ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED; + + auto result = TestEntrypoint( + num_local_sender_channels, + num_remote_sender_channels, + // default is 1. + // 2 means channel is double buffered + // 3 means channel is triple buffered + // ... and so on + num_buffers_per_edm_channel, + page_size, + num_pages_total, + src_is_dram, + dest_is_dram, + termination_mode + ); + ASSERT_EQ(result, 0); +} + +TEST(WorkerEdmDatapath, DISABLED_MergedPayloadAndSignal_1ChannelForward_1ChannelsReverse_1BufferPerChannel_2048PageSize_100kPages_MessageCountTermination) { + const uint32_t num_local_sender_channels = 1; + const uint32_t num_remote_sender_channels = 1; + const uint32_t num_buffers_per_edm_channel = 1; + const uint32_t page_size = 2048; + const uint32_t num_pages_total = 100000; + const bool src_is_dram = true; + const bool dest_is_dram = true; + const bool merge_message_and_signal = true; + auto termination_mode = ttnn::ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED; + + auto result = TestEntrypoint( + num_local_sender_channels, + num_remote_sender_channels, + // default is 1. + // 2 means channel is double buffered + // 3 means channel is triple buffered + // ... and so on + num_buffers_per_edm_channel, + page_size, + num_pages_total, + src_is_dram, + dest_is_dram, + termination_mode + ); + ASSERT_EQ(result, 0); +} + +TEST(WorkerEdmDatapath, DISABLED_MergedPayloadAndSignal_0ChannelForward_1ChannelsReverse_2BufferPerChannel_2048PageSize_100kPages_MessageCountTermination) { + const uint32_t num_local_sender_channels = 0; + const uint32_t num_remote_sender_channels = 1; + const uint32_t num_buffers_per_edm_channel = 2; + const uint32_t page_size = 2048; + const uint32_t num_pages_total = 100000; + const bool src_is_dram = true; + const bool dest_is_dram = true; + const bool merge_message_and_signal = true; + auto termination_mode = ttnn::ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED; + + auto result = TestEntrypoint( + num_local_sender_channels, + num_remote_sender_channels, + // default is 1. + // 2 means channel is double buffered + // 3 means channel is triple buffered + // ... and so on + num_buffers_per_edm_channel, + page_size, + num_pages_total, + src_is_dram, + dest_is_dram, + termination_mode + ); + ASSERT_EQ(result, 0); +} + +TEST(WorkerEdmDatapath, DISABLED_MergedPayloadAndSignal_1ChannelForward_0ChannelsReverse_2BufferPerChannel_2048PageSize_100kPages_MessageCountTermination) { + const uint32_t num_local_sender_channels = 1; + const uint32_t num_remote_sender_channels = 0; + const uint32_t num_buffers_per_edm_channel = 2; + const uint32_t page_size = 2048; + const uint32_t num_pages_total = 100000; + const bool src_is_dram = true; + const bool dest_is_dram = true; + const bool merge_message_and_signal = true; + auto termination_mode = ttnn::ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED; + + auto result = TestEntrypoint( + num_local_sender_channels, + num_remote_sender_channels, + // default is 1. + // 2 means channel is double buffered + // 3 means channel is triple buffered + // ... and so on + num_buffers_per_edm_channel, + page_size, + num_pages_total, + src_is_dram, + dest_is_dram, + termination_mode + ); + ASSERT_EQ(result, 0); +} + + +TEST(WorkerEdmDatapath, DISABLED_MergedPayloadAndSignal_1ChannelForward_1ChannelsReverse_2BufferPerChannel_2048PageSize_100kPages_MessageCountTermination) { + const uint32_t num_local_sender_channels = 1; + const uint32_t num_remote_sender_channels = 1; + const uint32_t num_buffers_per_edm_channel = 2; + const uint32_t page_size = 2048; + const uint32_t num_pages_total = 100000; + const bool src_is_dram = true; + const bool dest_is_dram = true; + const bool merge_message_and_signal = true; + auto termination_mode = ttnn::ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED; + + auto result = TestEntrypoint( + num_local_sender_channels, + num_remote_sender_channels, + // default is 1. + // 2 means channel is double buffered + // 3 means channel is triple buffered + // ... and so on + num_buffers_per_edm_channel, + page_size, + num_pages_total, + src_is_dram, + dest_is_dram, + termination_mode + ); + ASSERT_EQ(result, 0); +} + + +TEST(WorkerEdmDatapath, DISABLED_MergedPayloadAndSignal_1ChannelForward_0ChannelsReverse_3BufferPerChannel_2048PageSize_100kPages_MessageCountTermination) { + const uint32_t num_local_sender_channels = 1; + const uint32_t num_remote_sender_channels = 0; + const uint32_t num_buffers_per_edm_channel = 3; + const uint32_t page_size = 2048; + const uint32_t num_pages_total = 100000; + const bool src_is_dram = true; + const bool dest_is_dram = true; + const bool merge_message_and_signal = true; + auto termination_mode = ttnn::ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED; + + auto result = TestEntrypoint( + num_local_sender_channels, + num_remote_sender_channels, + // default is 1. + // 2 means channel is double buffered + // 3 means channel is triple buffered + // ... and so on + num_buffers_per_edm_channel, + page_size, + num_pages_total, + src_is_dram, + dest_is_dram, + termination_mode + ); + ASSERT_EQ(result, 0); +} + +TEST(WorkerEdmDatapath, DISABLED_MergedPayloadAndSignal_2ChannelForward_2ChannelsReverse_2BufferPerChannel_2048PageSize_100kPages_MessageCountTermination) { + const uint32_t num_local_sender_channels = 2; + const uint32_t num_remote_sender_channels = 1; + const uint32_t num_buffers_per_edm_channel = 2; + const uint32_t page_size = 2048; + const uint32_t num_pages_total = 100000; + const bool src_is_dram = true; + const bool dest_is_dram = true; + const bool merge_message_and_signal = true; + auto termination_mode = ttnn::ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED; + + auto result = TestEntrypoint( + num_local_sender_channels, + num_remote_sender_channels, + // default is 1. + // 2 means channel is double buffered + // 3 means channel is triple buffered + // ... and so on + num_buffers_per_edm_channel, + page_size, + num_pages_total, + src_is_dram, + dest_is_dram, + termination_mode + ); + ASSERT_EQ(result, 0); +} + + +TEST(WorkerEdmDatapath, DISABLED_MergedPayloadAndSignal_4ChannelForward_4ChannelsReverse_1BufferPerChannel_2048PageSize_100kPages_MessageCountTermination) { + const uint32_t num_local_sender_channels = 4; + const uint32_t num_remote_sender_channels = 4; + const uint32_t num_buffers_per_edm_channel = 1; + const uint32_t page_size = 2048; + const uint32_t num_pages_total = 100000; + const bool src_is_dram = true; + const bool dest_is_dram = true; + const bool merge_message_and_signal = true; + auto termination_mode = ttnn::ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED; + + auto result = TestEntrypoint( + num_local_sender_channels, + num_remote_sender_channels, + // default is 1. + // 2 means channel is double buffered + // 3 means channel is triple buffered + // ... and so on + num_buffers_per_edm_channel, + page_size, + num_pages_total, + src_is_dram, + dest_is_dram, + termination_mode + ); + ASSERT_EQ(result, 0); +} + +TEST(WorkerEdmDatapath, DISABLED_MergedPayloadAndSignal_4ChannelForward_4ChannelsReverse_2BufferPerChannel_2048PageSize_100kPages_MessageCountTermination) { + const uint32_t num_local_sender_channels = 4; + const uint32_t num_remote_sender_channels = 4; + const uint32_t num_buffers_per_edm_channel = 2; + const uint32_t page_size = 2048; + const uint32_t num_pages_total = 100000; + const bool src_is_dram = true; + const bool dest_is_dram = true; + const bool merge_message_and_signal = true; + auto termination_mode = ttnn::ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED; + + auto result = TestEntrypoint( + num_local_sender_channels, + num_remote_sender_channels, + // default is 1. + // 2 means channel is double buffered + // 3 means channel is triple buffered + // ... and so on + num_buffers_per_edm_channel, + page_size, + num_pages_total, + src_is_dram, + dest_is_dram, + termination_mode + ); + ASSERT_EQ(result, 0); +} + + + +//////////////////////////////////////////////////////////////////// +/// WORKER_INITIATED_TERMINATION_MODE +//////////////////////////////////////////////////////////////////// + +TEST(WorkerEdmDatapath, DISABLED_MergedPayloadAndSignal_1ChannelForward_0ChannelsReverse_1BufferPerChannel_2048PageSize_100kPages_WorkerInitiatedTermination) { + const uint32_t num_local_sender_channels = 1; + const uint32_t num_remote_sender_channels = 0; + const uint32_t num_buffers_per_edm_channel = 1; + const uint32_t page_size = 2048; + const uint32_t num_pages_total = 100000; + const bool src_is_dram = true; + const bool dest_is_dram = true; + const bool merge_message_and_signal = true; + auto termination_mode = ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED; + + auto result = TestEntrypoint( + num_local_sender_channels, + num_remote_sender_channels, + // default is 1. + // 2 means channel is double buffered + // 3 means channel is triple buffered + // ... and so on + num_buffers_per_edm_channel, + page_size, + num_pages_total, + src_is_dram, + dest_is_dram, + termination_mode + ); + ASSERT_EQ(result, 0); +} + +TEST(WorkerEdmDatapath, DISABLED_MergedPayloadAndSignal_1ChannelForward_1ChannelsReverse_1BufferPerChannel_2048PageSize_100kPages_WorkerInitiatedTermination) { + const uint32_t num_local_sender_channels = 1; + const uint32_t num_remote_sender_channels = 1; + const uint32_t num_buffers_per_edm_channel = 1; + const uint32_t page_size = 2048; + const uint32_t num_pages_total = 100000; + const bool src_is_dram = true; + const bool dest_is_dram = true; + const bool merge_message_and_signal = true; + auto termination_mode = ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED; + + auto result = TestEntrypoint( + num_local_sender_channels, + num_remote_sender_channels, + // default is 1. + // 2 means channel is double buffered + // 3 means channel is triple buffered + // ... and so on + num_buffers_per_edm_channel, + page_size, + num_pages_total, + src_is_dram, + dest_is_dram, + termination_mode + ); + ASSERT_EQ(result, 0); +} + +TEST(WorkerEdmDatapath, DISABLED_MergedPayloadAndSignal_1ChannelForward_0ChannelsReverse_2BufferPerChannel_2048PageSize_100kPages_WorkerInitiatedTermination) { + const uint32_t num_local_sender_channels = 1; + const uint32_t num_remote_sender_channels = 0; + const uint32_t num_buffers_per_edm_channel = 2; + const uint32_t page_size = 2048; + const uint32_t num_pages_total = 100000; + const bool src_is_dram = true; + const bool dest_is_dram = true; + const bool merge_message_and_signal = true; + auto termination_mode = ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED; + + auto result = TestEntrypoint( + num_local_sender_channels, + num_remote_sender_channels, + // default is 1. + // 2 means channel is double buffered + // 3 means channel is triple buffered + // ... and so on + num_buffers_per_edm_channel, + page_size, + num_pages_total, + src_is_dram, + dest_is_dram, + termination_mode + ); + ASSERT_EQ(result, 0); +} + +TEST(WorkerEdmDatapath, DISABLED_MergedPayloadAndSignal_1ChannelForward_0ChannelsReverse_3BufferPerChannel_2048PageSize_100kPages_WorkerInitiatedTermination) { + const uint32_t num_local_sender_channels = 1; + const uint32_t num_remote_sender_channels = 0; + const uint32_t num_buffers_per_edm_channel = 3; + const uint32_t page_size = 2048; + const uint32_t num_pages_total = 100000; + const bool src_is_dram = true; + const bool dest_is_dram = true; + const bool merge_message_and_signal = true; + auto termination_mode = ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED; + + auto result = TestEntrypoint( + num_local_sender_channels, + num_remote_sender_channels, + // default is 1. + // 2 means channel is double buffered + // 3 means channel is triple buffered + // ... and so on + num_buffers_per_edm_channel, + page_size, + num_pages_total, + src_is_dram, + dest_is_dram, + termination_mode + ); + ASSERT_EQ(result, 0); +} + +TEST(WorkerEdmDatapath, DISABLED_MergedPayloadAndSignal_1ChannelForward_1ChannelsReverse_2BufferPerChannel_2048PageSize_100kPages_WorkerInitiatedTermination) { + const uint32_t num_local_sender_channels = 1; + const uint32_t num_remote_sender_channels = 1; + const uint32_t num_buffers_per_edm_channel = 2; + const uint32_t page_size = 2048; + const uint32_t num_pages_total = 100000; + const bool src_is_dram = true; + const bool dest_is_dram = true; + const bool merge_message_and_signal = true; + auto termination_mode = ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED; + + auto result = TestEntrypoint( + num_local_sender_channels, + num_remote_sender_channels, + // default is 1. + // 2 means channel is double buffered + // 3 means channel is triple buffered + // ... and so on + num_buffers_per_edm_channel, + page_size, + num_pages_total, + src_is_dram, + dest_is_dram, + termination_mode + ); + ASSERT_EQ(result, 0); +} + +TEST(WorkerEdmDatapath, DISABLED_MergedPayloadAndSignal_4ChannelForward_4ChannelsReverse_2BufferPerChannel_2048PageSize_100kPages_WorkerInitiatedTermination) { + const uint32_t num_local_sender_channels = 4; + const uint32_t num_remote_sender_channels = 4; + const uint32_t num_buffers_per_edm_channel = 2; + const uint32_t page_size = 2048; + const uint32_t num_pages_total = 100000; + const bool src_is_dram = true; + const bool dest_is_dram = true; + const bool merge_message_and_signal = true; + auto termination_mode = ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED; + + auto result = TestEntrypoint( + num_local_sender_channels, + num_remote_sender_channels, + // default is 1. + // 2 means channel is double buffered + // 3 means channel is triple buffered + // ... and so on + num_buffers_per_edm_channel, + page_size, + num_pages_total, + src_is_dram, + dest_is_dram, + termination_mode + ); + ASSERT_EQ(result, 0); +} + +// EnablePersistentKernelCache diff --git a/tests/ttnn/unit_tests/operations/test_reduce_scatter_post_commit.py b/tests/ttnn/unit_tests/operations/test_reduce_scatter_post_commit.py index 7698269fee8..811e1dce457 100644 --- a/tests/ttnn/unit_tests/operations/test_reduce_scatter_post_commit.py +++ b/tests/ttnn/unit_tests/operations/test_reduce_scatter_post_commit.py @@ -41,7 +41,7 @@ def run_reduce_scatter_test( mem_config, use_program_cache, function_level_defaults, - enable_async=False, + enable_async=True, num_iters=1, ): if len(t3k_device_mesh.get_device_ids()) != 8: @@ -214,7 +214,7 @@ def run_reduce_scatter_sharded_test( tensor_mem_layout, use_program_cache, function_level_defaults, - enable_async=False, + enable_async=True, num_iters=1, ): if len(t3k_device_mesh.get_device_ids()) != 8: @@ -286,10 +286,6 @@ def run_reduce_scatter_sharded_test( ttl.device.Synchronize(t3k_device_mesh.get_device(device_id)) logger.info(f"Done iteration {i}") - for device_id in t3k_device_mesh.get_device_ids(): - ttl.device.Synchronize(t3k_device_mesh.get_device(device_id)) - logger.info(f"Done iteration {i}") - # Compute golden # TODO: Make it model how reduce scatter actually works for numerical correctness/ordering golden_canonical_out_tensor = torch.zeros(canonical_input_shape).bfloat16() diff --git a/tt_metal/hw/inc/ethernet/dataflow_api.h b/tt_metal/hw/inc/ethernet/dataflow_api.h index 94538d3f65f..d4f62477dd2 100644 --- a/tt_metal/hw/inc/ethernet/dataflow_api.h +++ b/tt_metal/hw/inc/ethernet/dataflow_api.h @@ -149,7 +149,6 @@ void eth_send_bytes_over_channel_payload_only( uint32_t src_addr, uint32_t dst_addr, uint32_t num_bytes, - uint32_t channel, uint32_t num_bytes_per_send = 16, uint32_t num_bytes_per_send_word_size = 1) { // assert(channel < 4); diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp index 4f73c9160e3..a5a2316ac77 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp @@ -25,7 +25,7 @@ AllGatherBidirectionalMode AllGatherConfig::choose_bidirectional_mode(Tensor con return AllGatherBidirectionalMode::FULL_TENSOR; } -AllGatherConfig::AllGatherConfig(Tensor const& input_tensor, Tensor const& output_tensor, uint32_t dim, uint32_t ring_size, uint32_t num_links, all_gather_op::Topology topology) : +AllGatherConfig::AllGatherConfig(Tensor const& input_tensor, Tensor const& output_tensor, uint32_t dim, uint32_t ring_size, uint32_t num_links, all_gather_op::Topology topology, std::size_t num_buffers_per_worker) : num_links(num_links), semaphore_size(32), ring_size(ring_size), @@ -37,8 +37,11 @@ AllGatherConfig::AllGatherConfig(Tensor const& input_tensor, Tensor const& outpu input_is_dram(input_tensor.buffer()->buffer_type() == BufferType::DRAM), output_is_dram(output_tensor.buffer()->buffer_type() == BufferType::DRAM), - bidirectional_mode(choose_bidirectional_mode(input_tensor)) + bidirectional_mode(choose_bidirectional_mode(input_tensor)), + enable_merged_payload_and_channel_sync(true), + num_buffers_per_worker(num_buffers_per_worker) { + TT_FATAL(num_buffers_per_worker > 0, "num_buffers_per_worker must be > 0"); TT_ASSERT(erisc_handshake_address >= eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE); TT_ASSERT(erisc_handshake_address < eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE + 16); TT_ASSERT((erisc_handshake_address & (16-1)) == 0); @@ -68,15 +71,17 @@ AllGatherConfig::AllGatherConfig(Tensor const& input_tensor, Tensor const& outpu this->num_workers_per_link = this->num_eth_buffers; this->eth_sems_l1_base_byte_address = this->erisc_handshake_address + 16 * 3;//16; + // Really should be called offset_after_semaphore_region this->semaphore_offset = this->semaphore_size * this->num_eth_buffers * num_duplicate_directions; // TODO: Remove this once dedicated semaphore space for user kernels are added this->eth_buffers_l1_base_byte_address = this->eth_sems_l1_base_byte_address + this->semaphore_offset; + std::size_t channel_sync_bytes_overhead = (enable_merged_payload_and_channel_sync * 16); uint32_t const page_size = input_tensor.buffer()->page_size(); - this->eth_buffer_size = tt::round_down((total_l1_buffer_space - this->semaphore_offset) / (this->num_eth_buffers * num_duplicate_directions), page_size); + std::size_t l1_per_buffer_region = ((total_l1_buffer_space - this->semaphore_offset) / (this->num_eth_buffers * num_duplicate_directions * this->num_buffers_per_worker)) - channel_sync_bytes_overhead; + this->eth_buffer_size = tt::round_down(l1_per_buffer_region, page_size); + TT_FATAL((this->eth_buffer_size + channel_sync_bytes_overhead) * (this->num_eth_buffers * num_duplicate_directions * this->num_buffers_per_worker) + this->semaphore_offset <= total_l1_buffer_space); TT_FATAL(eth_buffer_size == 0 or (this->num_eth_buffers * num_duplicate_directions) <= eth_l1_mem::address_map::MAX_NUM_CONCURRENT_TRANSACTIONS); - TT_FATAL(this->eth_buffer_size * (this->num_eth_buffers * num_duplicate_directions) + this->semaphore_offset <= total_l1_buffer_space); - } diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp index 058225db49d..f3a0956b310 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp @@ -41,13 +41,13 @@ class AllGatherConfig { static AllGatherBidirectionalMode choose_bidirectional_mode(Tensor const& input_tensor); public: - AllGatherConfig(Tensor const& input_tensor, Tensor const& output_tensor, uint32_t dim, uint32_t ring_size, uint32_t num_links, all_gather_op::Topology topology); + AllGatherConfig(Tensor const& input_tensor, Tensor const& output_tensor, uint32_t dim, uint32_t ring_size, uint32_t num_links, all_gather_op::Topology topology, std::size_t num_buffers_per_worker); uint32_t get_erisc_handshake_address() const { return this->erisc_handshake_address; } - uint32_t get_semaphores_offset() const { return this->semaphore_offset; } uint32_t get_num_eth_buffers_per_edm() const { return this->num_eth_buffers; } uint32_t get_num_workers_per_link() const { return this->num_workers_per_link; } + uint32_t get_num_buffers_per_worker() const { return this->num_buffers_per_worker; } uint32_t get_num_workers() const { return this->num_workers_per_link * this->num_links; } uint32_t get_eth_buffer_size() const { return this->eth_buffer_size; } @@ -57,6 +57,7 @@ class AllGatherConfig { uint32_t get_eth_buffers_l1_base_byte_address() const { return this->eth_buffers_l1_base_byte_address; } uint32_t get_semaphore_size() const { return this->semaphore_size; } + std::size_t get_num_buffers_per_channel() const { return this->num_buffers_per_worker; } uint32_t get_num_edm_channels_in_clockwise_direction() const { return this->enable_bidirectional ? @@ -64,6 +65,7 @@ class AllGatherConfig { this->num_workers_per_link; } uint32_t get_ring_size() const { return this->ring_size; } + bool is_payload_and_channel_sync_merged() const { return enable_merged_payload_and_channel_sync;} bool is_buffer_in_clockwise_ring(const uint32_t buffer_index) const { // For now we split it as lower half => clockwise, upper half => counter-clockwise // This is slightly suboptimal since the non-full-chunks go to the upper half. @@ -89,6 +91,7 @@ class AllGatherConfig { log_trace(tt::LogOp, "\terisc_handshake_address: {}", erisc_handshake_address); log_trace(tt::LogOp, "\tnum_buffers: {}", num_eth_buffers); log_trace(tt::LogOp, "\tnum_workers_per_link: {}", num_workers_per_link); + log_trace(tt::LogOp, "\tnum_buffers_per_worker: {}", num_buffers_per_worker); log_trace(tt::LogOp, "\teth_buffer_size: {}", eth_buffer_size); log_trace(tt::LogOp, "\tsemaphore_size: {}", semaphore_size); log_trace(tt::LogOp, "\tsemaphore_offset: {}", semaphore_offset); @@ -104,6 +107,7 @@ class AllGatherConfig { uint32_t num_links; uint32_t num_eth_buffers; uint32_t num_workers_per_link; + uint32_t num_buffers_per_worker; uint32_t eth_buffer_size; uint32_t semaphore_size; uint32_t semaphore_offset; @@ -115,6 +119,7 @@ class AllGatherConfig { bool enable_bidirectional; const bool input_is_dram; const bool output_is_dram; + const bool enable_merged_payload_and_channel_sync; }; struct AllGather { diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp index b8bc34b2a4e..b4bd94c4938 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp @@ -1,6 +1,7 @@ // SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. // // SPDX-License-Identifier: Apache-2.0 +#pragma once #include "dataflow_api.h" #include "debug/assert.h" @@ -33,7 +34,7 @@ FORCE_INLINE void write_and_send_chunk( uint32_t l1_read_addr = get_read_ptr(cb_id); noc_async_write(l1_read_addr, remote_l1_write_addr, page_size * num_pages); noc_semaphore_inc(eth_l1_sender_semaphore_addr, 1); - // TODO: do eth semaphore inc here + for (uint32_t i = 0; i < num_pages; ++i) { #ifdef ROW_MAJOR_LAYOUT #ifdef INTERLEAVED_MEM_LAYOUT diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core.cpp index c3e69e13104..244dc097a29 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core.cpp @@ -28,14 +28,14 @@ using namespace ccl; static std::tuple select_worker_cores(AllGatherConfig const& all_gather_config, uint32_t num_links, uint32_t link, uint32_t full_send_direction) { constexpr uint32_t worker_grid_width = 8; - const bool fit_sender_and_receiver_workers_on_same_row = (worker_grid_width / 2) >= all_gather_config.get_num_eth_buffers_per_edm(); + const bool fit_sender_and_receiver_workers_on_same_row = (worker_grid_width / 2) >= all_gather_config.get_num_workers_per_link(); std::set receiver_worker_cores = {}; std::set sender_worker_cores = {}; uint32_t max_cols = 8; - uint32_t curr_row = link * (((all_gather_config.get_num_eth_buffers_per_edm() * 2 - 1) / max_cols) + 1) + - (full_send_direction * num_links * (((all_gather_config.get_num_eth_buffers_per_edm() * 2 - 1) / max_cols) + 1)); + uint32_t curr_row = link * (((all_gather_config.get_num_workers_per_link() * 2 - 1) / max_cols) + 1) + + (full_send_direction * num_links * (((all_gather_config.get_num_workers_per_link() * 2 - 1) / max_cols) + 1)); uint32_t curr_col = 0; - for (uint32_t r = 0; r < all_gather_config.get_num_eth_buffers_per_edm(); r++) { + for (uint32_t r = 0; r < all_gather_config.get_num_workers_per_link(); r++) { receiver_worker_cores.insert(CoreRange(CoreCoord(curr_col, curr_row))); curr_col ++; if (curr_col == max_cols) { @@ -43,7 +43,7 @@ static std::tuple select_worker_cores(AllGatherConfig curr_row++; } } - for (uint32_t s = 0; s < all_gather_config.get_num_eth_buffers_per_edm(); s++) { + for (uint32_t s = 0; s < all_gather_config.get_num_workers_per_link(); s++) { sender_worker_cores.insert(CoreRange(CoreCoord(curr_col, curr_row))); curr_col ++; if (curr_col == max_cols) { @@ -61,8 +61,8 @@ static std::vector> compute_worker_sender_num_transfers( std::vector> worker_sender_num_transfers; worker_sender_num_transfers.reserve(num_links); for (uint32_t l = 0; l < num_links; ++l) { - worker_sender_num_transfers.emplace_back(all_gather_config.get_num_eth_buffers_per_edm()); - for(uint32_t b = 0; b < all_gather_config.get_num_eth_buffers_per_edm(); ++b) { + worker_sender_num_transfers.emplace_back(all_gather_config.get_num_workers_per_link()); + for(uint32_t b = 0; b < all_gather_config.get_num_workers_per_link(); ++b) { uint32_t &worker_num_transfers = worker_sender_num_transfers.at(l).at(b); switch (topology) { case all_gather_op::Topology::Linear: @@ -99,8 +99,8 @@ static std::vector> compute_worker_receiver_num_transfers( std::vector> worker_sender_num_transfers; worker_sender_num_transfers.reserve(num_links); for (uint32_t l = 0; l < num_links; ++l) { - worker_sender_num_transfers.emplace_back(all_gather_config.get_num_eth_buffers_per_edm()); - for(uint32_t b = 0; b < all_gather_config.get_num_eth_buffers_per_edm(); ++b) { + worker_sender_num_transfers.emplace_back(all_gather_config.get_num_workers_per_link()); + for(uint32_t b = 0; b < all_gather_config.get_num_workers_per_link(); ++b) { uint32_t &worker_num_transfers = worker_sender_num_transfers.at(l).at(b); switch (topology) { case all_gather_op::Topology::Linear: @@ -177,11 +177,12 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& std::unique_ptr input_tensor_config = ttnn::ccl::CclOpTensorConfig::build_all_gather_tensor_config(input_tensor); std::unique_ptr output_tensor_config = ttnn::ccl::CclOpTensorConfig::build_all_gather_tensor_config(output_tensor); + std::size_t num_edm_buffers_per_channel = 1; tt::tt_metal::Program program{}; // Issue #10978: CCLs need to be tagged as having multi-device dependencies, when running on Galaxy. program.capture_multi_device_dependencies(); const auto& device = input_tensor.device(); - auto const& all_gather_config = AllGatherConfig(input_tensor, output_tensor, dim, ring_size, num_links, topology); + auto const& all_gather_config = AllGatherConfig(input_tensor, output_tensor, dim, ring_size, num_links, topology, num_edm_buffers_per_channel); auto const& topology_config = ttnn::ccl::RingTopology(device, topology, sender_device_id, receiver_device_id, num_links, ring_size, ring_index); bool enable_print = false; @@ -214,8 +215,6 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& bool width = input_tensor.get_legacy_shape().rank() - 1 == dim; tt::DataFormat df = datatype_to_dataformat_converter(input_tensor.get_dtype()); - uint32_t global_num_workers = all_gather_config.get_num_eth_buffers_per_edm() * num_links; - std::map worker_defines; if (rm) { worker_defines["ROW_MAJOR_LAYOUT"] = "1"; @@ -235,7 +234,8 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& const uint32_t num_full_send_directions = full_send_both_directions ? 2 : 1; constexpr uint32_t max_num_full_send_directions = 2; // number of worker cores is 2x this since there is 1 worker for the sender buffer and 1 worker for the receiver buffer - uint32_t total_worker_core_pairs_used = num_links * all_gather_config.get_num_eth_buffers_per_edm() * num_full_send_directions; + uint32_t global_num_workers = num_links * all_gather_config.get_num_eth_buffers_per_edm() * num_full_send_directions; + uint32_t total_worker_core_pairs_used = global_num_workers; uint32_t num_input_pages = input_tensor.buffer()->size() / input_page_size; uint32_t min_pages_per_link = num_input_pages / num_links; @@ -257,26 +257,41 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& auto edm_sem_addrs_per_link = std::vector>(num_links); auto edm_buffer_addrs_per_link = std::vector>(num_links); for (uint32_t link = 0; link < num_links; link++) { - edm_sem_addrs_per_link.at(link).reserve(all_gather_config.get_num_eth_buffers_per_edm() * num_full_send_directions); - edm_buffer_addrs_per_link.at(link).reserve(all_gather_config.get_num_eth_buffers_per_edm() * num_full_send_directions); + edm_sem_addrs_per_link.at(link).reserve(all_gather_config.get_num_workers_per_link() * num_full_send_directions); + edm_buffer_addrs_per_link.at(link).reserve(all_gather_config.get_num_workers_per_link() * num_full_send_directions); uint32_t edm_sem_addr = all_gather_config.get_eth_sems_l1_base_byte_address(); uint32_t edm_buffer_addr = all_gather_config.get_eth_buffers_l1_base_byte_address(); for (uint32_t direction = 0; direction < num_full_send_directions; direction++) { - for (uint32_t b = 0; b < all_gather_config.get_num_eth_buffers_per_edm(); ++b) { + for (uint32_t b = 0; b < all_gather_config.get_num_workers_per_link(); ++b) { edm_sem_addrs_per_link.at(link).push_back(edm_sem_addr); edm_sem_addr += all_gather_config.get_semaphore_size(); edm_buffer_addrs_per_link.at(link).push_back(edm_buffer_addr); - edm_buffer_addr += all_gather_config.get_eth_buffer_size(); + edm_buffer_addr += ((all_gather_config.get_eth_buffer_size() + + (all_gather_config.is_payload_and_channel_sync_merged() > 0 ? EriscDatamoverConfig::get_eth_word_size() : 0)) * all_gather_config.get_num_buffers_per_channel()); TT_ASSERT((direction == 0 && b == 0) || (edm_buffer_addrs_per_link.at(link).back() != edm_buffer_addrs_per_link.at(link).front())); TT_ASSERT((direction == 0 && b == 0) || (edm_sem_addrs_per_link.at(link).back() != edm_sem_addrs_per_link.at(link).front())); } } clockwise_edm_builders.emplace_back( - all_gather_config.get_eth_buffer_size(), all_gather_config.get_erisc_handshake_address(), edm_sem_addrs_per_link.at(link), edm_buffer_addrs_per_link.at(link), ttnn::ccl::EriscDataMoverBufferSharingMode::NOT_SHARED, ttnn::ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED); + all_gather_config.get_eth_buffer_size(), + all_gather_config.get_erisc_handshake_address(), + edm_sem_addrs_per_link.at(link), + edm_buffer_addrs_per_link.at(link), + ccl::EriscDataMoverBufferSharingMode::NOT_SHARED, + ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED, + all_gather_config.get_num_buffers_per_channel(), + input_tensor.device()->id()); counter_clockwise_edm_builders.emplace_back( - all_gather_config.get_eth_buffer_size(), all_gather_config.get_erisc_handshake_address(), edm_sem_addrs_per_link.at(link), edm_buffer_addrs_per_link.at(link), ttnn::ccl::EriscDataMoverBufferSharingMode::NOT_SHARED, ttnn::ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED); + all_gather_config.get_eth_buffer_size(), + all_gather_config.get_erisc_handshake_address(), + edm_sem_addrs_per_link.at(link), + edm_buffer_addrs_per_link.at(link), + ccl::EriscDataMoverBufferSharingMode::NOT_SHARED, + ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED, + all_gather_config.get_num_buffers_per_channel(), + input_tensor.device()->id()); } for (uint32_t direction = 0; direction < num_full_send_directions; direction++) { @@ -358,8 +373,6 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& // We can't have overlap between the mcast grid for worker cores for different links since mcasting the semaphore in receiver would corrupt other link semaphores // We can have overlap between a link's sender and receiver worker grids if we have the semaphores at different addresses auto const& [receiver_workers, sender_workers] = select_worker_cores(all_gather_config, num_links, i, direction); - uint32_t worker_index = 0; - uint32_t workers_per_link = all_gather_config.get_num_workers_per_link() / all_gather_config.get_num_eth_buffers_per_edm(); // Circular Buffer Setup uint32_t cb_page_size = input_page_size; @@ -386,17 +399,17 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& // number of pages that can fit in a single ethernet L1 buffer (not the number of pages sent to this channel) std::vector pages_per_eth_l1_buffer; - pages_per_buffer.reserve(all_gather_config.get_num_eth_buffers_per_edm()); + pages_per_buffer.reserve(all_gather_config.get_num_workers_per_link()); uint32_t max_pages_per_eth_l1_sender_buffer = all_gather_config.get_eth_buffer_size() / input_page_size; - for(uint32_t b = 0; b < all_gather_config.get_num_eth_buffers_per_edm(); ++b) { - pages_per_buffer.push_back((pages_per_link.at(i) / all_gather_config.get_num_eth_buffers_per_edm())); + for(uint32_t w = 0; w < all_gather_config.get_num_workers_per_link(); ++w) { + pages_per_buffer.push_back((pages_per_link.at(i) / all_gather_config.get_num_workers_per_link())); pages_per_eth_l1_buffer.push_back(max_pages_per_eth_l1_sender_buffer); - if (b < pages_per_link.at(i) % all_gather_config.get_num_eth_buffers_per_edm()) { + if (w < pages_per_link.at(i) % all_gather_config.get_num_workers_per_link()) { pages_per_buffer.back()++; } log_trace(tt::LogOp, "pages_per_link[{}]: {}", i, pages_per_link.at(i)); - log_trace(tt::LogOp, "pages_per_buffer[{}]: {}", b, pages_per_buffer.at(b)); + log_trace(tt::LogOp, "pages_per_buffer[{}]: {}", w, pages_per_buffer.at(w)); log_trace(tt::LogOp, "max_pages_per_eth_l1_sender_buffer: {}",max_pages_per_eth_l1_sender_buffer); } TT_ASSERT(std::accumulate(pages_per_buffer.begin(), pages_per_buffer.end(), 0) == pages_per_link.at(i)); @@ -420,41 +433,41 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& TT_ASSERT(rem_pages < pages_per_chunk || num_full_chunks == 0); TT_ASSERT(rem_pages <= max_pages_per_chunk); - std::vector num_full_chunks_per_worker(all_gather_config.get_num_eth_buffers_per_edm(),0); - std::vector rem_pages_per_worker(all_gather_config.get_num_eth_buffers_per_edm(), 0); - std::vector is_channel_shrinkable(all_gather_config.get_num_eth_buffers_per_edm(), false); - std::vector largest_packets_per_channel(all_gather_config.get_num_eth_buffers_per_edm(), 0); + std::vector num_full_chunks_per_worker(all_gather_config.get_num_workers_per_link(),0); + std::vector rem_pages_per_worker(all_gather_config.get_num_workers_per_link(), 0); + std::vector is_channel_shrinkable(all_gather_config.get_num_workers_per_link(), false); + std::vector largest_packets_per_channel(all_gather_config.get_num_workers_per_link(), 0); std::vector clockwise_link_buffer_num_messages_to_send; std::vector counter_clockwise_link_buffer_num_messages_to_send; std::vector edm_semaphores_base_address; std::vector link_buffer_sender_addresses; - clockwise_link_buffer_num_messages_to_send.reserve(all_gather_config.get_num_eth_buffers_per_edm()); - counter_clockwise_link_buffer_num_messages_to_send.reserve(all_gather_config.get_num_eth_buffers_per_edm()); - edm_semaphores_base_address.reserve(all_gather_config.get_num_eth_buffers_per_edm()); - link_buffer_sender_addresses.reserve(all_gather_config.get_num_eth_buffers_per_edm()); + clockwise_link_buffer_num_messages_to_send.reserve(all_gather_config.get_num_workers_per_link()); + counter_clockwise_link_buffer_num_messages_to_send.reserve(all_gather_config.get_num_workers_per_link()); + edm_semaphores_base_address.reserve(all_gather_config.get_num_workers_per_link()); + link_buffer_sender_addresses.reserve(all_gather_config.get_num_workers_per_link()); { - for (std::size_t b = 0; b < all_gather_config.get_num_eth_buffers_per_edm(); b++) { - num_full_chunks_per_worker.at(b) = num_full_chunks / all_gather_config.get_num_eth_buffers_per_edm(); + for (std::size_t b = 0; b < all_gather_config.get_num_workers_per_link(); b++) { + num_full_chunks_per_worker.at(b) = num_full_chunks / all_gather_config.get_num_workers_per_link(); } uint32_t worker_idx = 0; - for (worker_idx = 0; worker_idx < num_full_chunks % all_gather_config.get_num_eth_buffers_per_edm(); ++worker_idx) { + for (worker_idx = 0; worker_idx < num_full_chunks % all_gather_config.get_num_workers_per_link(); ++worker_idx) { num_full_chunks_per_worker.at(worker_idx)++; } if (rem_pages != 0) { - rem_pages_per_worker.at(worker_idx % all_gather_config.get_num_eth_buffers_per_edm()) = rem_pages; - TT_ASSERT(rem_pages_per_worker.at(worker_idx % all_gather_config.get_num_eth_buffers_per_edm()) * 2 <= cb_num_pages); + rem_pages_per_worker.at(worker_idx % all_gather_config.get_num_workers_per_link()) = rem_pages; + TT_ASSERT(rem_pages_per_worker.at(worker_idx % all_gather_config.get_num_workers_per_link()) <= cb_num_pages); } { // Logging log_trace(tt::LogOp, "num_full_chunks, remaining pages per worker (clockwise):"); - for (std::size_t b = 0; b < all_gather_config.get_num_eth_buffers_per_edm(); b++) { + for (std::size_t b = 0; b < all_gather_config.get_num_workers_per_link(); b++) { if (is_buffer_in_clockwise_direction(b)) { log_trace(tt::LogOp, "\tworker {}: {}, {}", b, num_full_chunks_per_worker.at(b), rem_pages_per_worker.at(b)); } } log_trace(tt::LogOp, "num_full_chunks, remaining pages per worker (counter-clockwise):"); - for (std::size_t b = 0; b < all_gather_config.get_num_eth_buffers_per_edm(); b++) { + for (std::size_t b = 0; b < all_gather_config.get_num_workers_per_link(); b++) { if (!is_buffer_in_clockwise_direction(b)) { log_trace(tt::LogOp, "\tworker {}: {}, {}", b, num_full_chunks_per_worker.at(b), rem_pages_per_worker.at(b)); } @@ -467,7 +480,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& is_channel_shrinkable.at(b) = shrinkable; largest_packets_per_channel.at(b) = shrinkable ? rem_pages_per_worker.at(b) * input_page_size : all_gather_config.get_eth_buffer_size(); } - for(uint32_t b = 0; b < all_gather_config.get_num_eth_buffers_per_edm(); ++b) { + for(uint32_t b = 0; b < all_gather_config.get_num_workers_per_link(); ++b) { // link num messages clockwise_link_buffer_num_messages_to_send.push_back( (num_full_chunks_per_worker.at(b) + (rem_pages_per_worker.at(b) > 0 ? 1 : 0)) * @@ -476,7 +489,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& (num_full_chunks_per_worker.at(b) + (rem_pages_per_worker.at(b) > 0 ? 1 : 0)) * receiver_worker_num_transfers.at(i).at(b)); } - for(uint32_t b = 0; b < all_gather_config.get_num_eth_buffers_per_edm(); ++b) { + for(uint32_t b = 0; b < all_gather_config.get_num_workers_per_link(); ++b) { log_trace(tt::LogOp, "rem_pages_per_worker[{}]: {}", b, rem_pages_per_worker.at(b)); log_trace(tt::LogOp, "num_full_chunks_per_worker[{}]: {}", b, num_full_chunks_per_worker.at(b)); log_trace(tt::LogOp, "clockwise_link_buffer_num_messages_to_send[{}]: {}", b, clockwise_link_buffer_num_messages_to_send.at(b)); @@ -485,31 +498,27 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& std::vector receiver_semaphores_base_address; std::vector link_buffer_receiver_addresses; - receiver_semaphores_base_address.reserve(all_gather_config.get_num_eth_buffers_per_edm()); - link_buffer_receiver_addresses.reserve(all_gather_config.get_num_eth_buffers_per_edm()); - for(uint32_t b = 0; b < all_gather_config.get_num_eth_buffers_per_edm(); ++b) { + receiver_semaphores_base_address.reserve(all_gather_config.get_num_workers_per_link()); + link_buffer_receiver_addresses.reserve(all_gather_config.get_num_workers_per_link()); + for(uint32_t b = 0; b < all_gather_config.get_num_workers_per_link(); ++b) { receiver_semaphores_base_address.push_back(all_gather_config.get_eth_sems_l1_base_byte_address() + b * all_gather_config.get_semaphore_size()); link_buffer_receiver_addresses.push_back(all_gather_config.get_eth_buffers_l1_base_byte_address() + b * all_gather_config.get_eth_buffer_size()); } - std::vector sender_eth_sem_addrs; sender_eth_sem_addrs.reserve(all_gather_config.get_num_eth_buffers_per_edm()); - std::vector sender_eth_buffer_addrs; sender_eth_buffer_addrs.reserve(all_gather_config.get_num_eth_buffers_per_edm()); - std::vector receiver_eth_sem_addrs; receiver_eth_sem_addrs.reserve(all_gather_config.get_num_eth_buffers_per_edm()); - std::vector receiver_eth_buffer_addrs; receiver_eth_buffer_addrs.reserve(all_gather_config.get_num_eth_buffers_per_edm()); - for (uint32_t b = 0; b < all_gather_config.get_num_eth_buffers_per_edm(); ++b) { - uint32_t num_workers_per_eth_buffer = std::min(workers_per_link, all_gather_config.get_num_eth_buffers_per_edm() - worker_index); - + std::vector sender_eth_sem_addrs; sender_eth_sem_addrs.reserve(all_gather_config.get_num_workers_per_link()); + std::vector sender_eth_buffer_addrs; sender_eth_buffer_addrs.reserve(all_gather_config.get_num_workers_per_link()); + std::vector receiver_eth_sem_addrs; receiver_eth_sem_addrs.reserve(all_gather_config.get_num_workers_per_link()); + std::vector receiver_eth_buffer_addrs; receiver_eth_buffer_addrs.reserve(all_gather_config.get_num_workers_per_link()); + for (uint32_t b = 0; b < all_gather_config.get_num_workers_per_link(); ++b) { std::vector sender_worker_coords; std::vector receiver_worker_coords; - for (uint32_t w = b * num_workers_per_eth_buffer; w < (b + 1) * num_workers_per_eth_buffer; ++w) { - sender_worker_coords.push_back( - ttnn::ccl::WorkerXY( - device->worker_core_from_logical_core(sender_worker_cores.at(w)).x, - device->worker_core_from_logical_core(sender_worker_cores.at(w)).y)); - receiver_worker_coords.push_back( - ttnn::ccl::WorkerXY( - device->worker_core_from_logical_core(receiver_worker_cores.at(w)).x, - device->worker_core_from_logical_core(receiver_worker_cores.at(w)).y)); - } + sender_worker_coords.push_back( + ttnn::ccl::WorkerXY( + device->worker_core_from_logical_core(sender_worker_cores.at(b)).x, + device->worker_core_from_logical_core(sender_worker_cores.at(b)).y)); + receiver_worker_coords.push_back( + ttnn::ccl::WorkerXY( + device->worker_core_from_logical_core(receiver_worker_cores.at(b)).x, + device->worker_core_from_logical_core(receiver_worker_cores.at(b)).y)); bool sender_enabled = (!is_linear || !is_last_chip_in_chain); if (sender_enabled) { @@ -543,9 +552,8 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& } - // 1 Worker per buffer - for (uint32_t b = 0; b < all_gather_config.get_num_eth_buffers_per_edm(); ++b) { - uint32_t global_worker_index = all_gather_config.get_num_eth_buffers_per_edm() * i + b; + for (uint32_t b = 0; b < all_gather_config.get_num_workers_per_link(); ++b) { + uint32_t global_worker_index = all_gather_config.get_num_workers_per_link() * i + b; bool is_clockwise_direction = is_buffer_in_clockwise_direction(b); diff --git a/ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp b/ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp index a26499b9693..7d55b8ba911 100644 --- a/ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp @@ -256,6 +256,7 @@ void generate_edm_kernels_for_ring_or_linear_topology( } } + KernelHandle generate_edm_kernel( tt::tt_metal::Program& program, Device const* device, @@ -294,6 +295,7 @@ KernelHandle generate_edm_kernel( ccl::EriscDatamoverBuilder create_erisc_datamover_builder( std::size_t num_channels, uint32_t page_size, + std::size_t num_buffers_per_channel, ccl::EriscDataMoverBufferSharingMode buffer_sharing_mode, ccl::EriscDataMoverTerminationMode termination_mode) { TT_ASSERT(num_channels > 0); @@ -304,23 +306,26 @@ ccl::EriscDatamoverBuilder create_erisc_datamover_builder( uint32_t edm_buffer_addr = ccl::EriscDatamoverConfig::get_buffers_base_address(num_channels); TT_ASSERT(edm_sem_addr > 0); TT_ASSERT(edm_buffer_addr > 0); - const uint32_t buffer_size = ccl::EriscDatamoverConfig::compute_buffer_size(num_channels, page_size); + const uint32_t channel_buffer_size = ccl::EriscDatamoverConfig::compute_buffer_size(num_channels, num_buffers_per_channel, page_size); for (std::size_t c = 0; c < num_channels; ++c) { edm_sem_addresses.at(c) = edm_sem_addr; edm_sem_addr += ccl::EriscDatamoverConfig::semaphore_size; + TT_ASSERT(edm_buffer_addr % EriscDatamoverConfig::get_eth_word_size() == 0); edm_buffer_addresses.at(c) = edm_buffer_addr; - edm_buffer_addr += buffer_size; + log_trace(tt::LogOp, " edm_buffer_addresses({}) = {}", c, edm_buffer_addr); + edm_buffer_addr += num_buffers_per_channel * (channel_buffer_size + (ccl::EriscDatamoverConfig::enable_merged_payload_and_channel_sync ? ccl::EriscDatamoverConfig::get_eth_channel_sync_size_bytes() : 0)); TT_ASSERT((c == 0) || (edm_buffer_addresses.back() != edm_buffer_addresses.front())); TT_ASSERT((c == 0) || (edm_sem_addresses.back() != edm_sem_addresses.front())); } return ccl::EriscDatamoverBuilder( - buffer_size, + channel_buffer_size, ccl::EriscDatamoverConfig::get_edm_handshake_address(), edm_sem_addresses, edm_buffer_addresses, buffer_sharing_mode, - termination_mode); + termination_mode, + num_buffers_per_channel); } template @@ -340,12 +345,8 @@ RingReduceScatterBaseTensorSlicer::RingReduceScatterBaseTensor this->slice_dim_is_width = input_tensor.get_legacy_shape().rank() - 1 == slice_dim; this->is_sharded = input_tensor.is_sharded(); - int32_t shard_size_in_bytes = - is_sharded ? (input_tensor.buffer()->page_size() * input_tensor.buffer()->shard_spec().tensor2d_shape[0] * - input_tensor.buffer()->shard_spec().tensor2d_shape[1]) / - input_tensor.shard_spec()->num_cores() - : -1; - this->input_page_size = is_sharded ? shard_size_in_bytes : input_tensor.buffer()->page_size(); + this->input_page_size = input_tensor.buffer()->page_size(); + log_trace(tt::LogOp, "input_page_size={}", input_page_size); if (row_major) { this->num_cols = input_tensor.get_legacy_shape()[-1]; auto input_shape = input_tensor.get_legacy_shape(); diff --git a/ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp b/ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp index 6719203d8df..a47e7fd0ea0 100644 --- a/ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp @@ -455,6 +455,7 @@ void generate_edm_kernels_for_ring_or_linear_topology( ccl::EriscDatamoverBuilder create_erisc_datamover_builder( std::size_t num_channels, uint32_t page_size, + std::size_t num_buffers_per_channel, ccl::EriscDataMoverBufferSharingMode buffer_sharing_mode, EriscDataMoverTerminationMode termination_mode); diff --git a/ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.cpp b/ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.cpp index 3d7a556a8f0..fd9e061b248 100644 --- a/ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.cpp @@ -3,12 +3,55 @@ // SPDX-License-Identifier: Apache-2.0 -#include "ttnn/tensor/tensor_impl.hpp" +#include "ttnn/cpp/ttnn/tensor/tensor_impl.hpp" #include "ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.hpp" namespace ttnn { namespace ccl { +std::size_t EriscDatamoverConfig::get_eth_channel_sync_size_bytes() { return eth_channel_sync_size_bytes; } + +uint32_t EriscDatamoverConfig::get_edm_handshake_address() { return usable_l1_base_address; } + +std::size_t EriscDatamoverConfig::get_semaphores_region_size(std::size_t num_edm_channels) { + return (num_edm_channels * semaphore_size); +} +std::size_t EriscDatamoverConfig::get_semaphores_region_start_offset(std::size_t num_edm_channels) { + return handshake_location_size + edm_receiver_first_level_ack_source_word_size; +} +uint32_t EriscDatamoverConfig::get_semaphores_base_address(std::size_t num_edm_channels) { + return usable_l1_base_address + get_semaphores_region_start_offset(num_edm_channels); +} +uint32_t EriscDatamoverConfig::get_buffers_region_start_offset(std::size_t num_edm_channels) { + return get_semaphores_region_start_offset(num_edm_channels) + get_semaphores_region_size(num_edm_channels); +} +std::size_t EriscDatamoverConfig::get_eth_word_size() { return eth_word_size_bytes; } +uint32_t EriscDatamoverConfig::get_buffers_base_address(std::size_t num_edm_channels) { + uint32_t base_address = tt::round_up(usable_l1_base_address + get_buffers_region_start_offset(num_edm_channels), eth_word_size_bytes); + TT_ASSERT(base_address % eth_word_size_bytes == 0); + return base_address; +} +uint32_t EriscDatamoverConfig::compute_buffer_size(std::size_t num_edm_channels, std::size_t num_buffers_per_channel, uint32_t page_size) { + page_size = std::max(page_size, eth_word_size_bytes); + TT_ASSERT(num_edm_channels > 0); + std::size_t channel_sync_bytes_overhead = (enable_merged_payload_and_channel_sync * 16); + std::size_t total_usable_space = total_l1_buffer_space - get_buffers_region_start_offset(num_edm_channels); + std::size_t l1_per_buffer_region = (total_usable_space / (num_edm_channels * num_buffers_per_channel)) - channel_sync_bytes_overhead; + uint32_t buffer_size = tt::round_down(l1_per_buffer_region, page_size); + log_trace(tt::LogOp, "total_l1_buffer_space: {}", total_l1_buffer_space); + log_trace( + tt::LogOp, "get_buffers_base_address(num_edm_channels): {}", get_buffers_base_address(num_edm_channels)); + log_trace( + tt::LogOp, "usable buffer space: {}", total_l1_buffer_space - get_buffers_base_address(num_edm_channels)); + log_trace(tt::LogOp, "num_edm_channels: {}", num_edm_channels); + log_trace(tt::LogOp, "page_size: {}", page_size); + + log_trace(tt::LogOp, "Buffer size: {}", buffer_size); + + TT_ASSERT(buffer_size > 0 && buffer_size % page_size == 0); + return buffer_size; +} + CCLOpConfig::CCLOpConfig( std::vector& input_tensors, const std::vector& output_tensors, Topology topology) : input_tensors(&input_tensors), diff --git a/ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.hpp b/ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.hpp index 0529c10e805..af3dfae46af 100644 --- a/ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.hpp @@ -5,7 +5,7 @@ #pragma once #include "eth_l1_address_map.h" -#include "ttnn/tensor/tensor_impl.hpp" +#include "ttnn/cpp/ttnn/tensor/tensor_impl.hpp" #include "ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" #include @@ -21,42 +21,26 @@ struct EriscDatamoverConfig { static constexpr std::size_t semaphore_size = 32; static constexpr std::size_t handshake_location_size = 16; // ethernet word size + static constexpr std::size_t handshake_padding_multiple = 3; // ethernet word size // The EDM uses this fixed address as a source for a first level ack sent from receiver -> sender // side. We have this dedicated source address to avoid a race between first and second level ack // where second level ack overwrites the first level ack in L1 before the first one is sent out. // The memory contents in L1 will be {1, 1, x, x}. By having this dedicated source memory, we // avoid the race static constexpr std::size_t edm_receiver_first_level_ack_source_word_size = 16; // ethernet word size + static constexpr std::size_t eth_channel_sync_size_bytes = 16; static constexpr std::size_t eth_word_size_bytes = 16; + static constexpr bool enable_merged_payload_and_channel_sync = true; + static std::size_t get_eth_channel_sync_size_bytes(); + static uint32_t get_edm_handshake_address(); + static std::size_t get_semaphores_region_size(std::size_t num_edm_channels); + static std::size_t get_semaphores_region_start_offset(std::size_t num_edm_channels); + static uint32_t get_semaphores_base_address(std::size_t num_edm_channels); + static uint32_t get_buffers_region_start_offset(std::size_t num_edm_channels); + static std::size_t get_eth_word_size(); + static uint32_t get_buffers_base_address(std::size_t num_edm_channels); + static uint32_t compute_buffer_size(std::size_t num_edm_channels, std::size_t num_buffers_per_channel = 1, uint32_t page_size = eth_word_size_bytes); - static uint32_t get_edm_handshake_address() { return usable_l1_base_address; } - static uint32_t get_semaphores_base_address(std::size_t num_edm_channels) { - return usable_l1_base_address + (handshake_location_size * 3) + edm_receiver_first_level_ack_source_word_size; - } - static uint32_t get_buffers_base_address(std::size_t num_edm_channels) { - uint32_t base_address =tt::round_up( - get_semaphores_base_address(num_edm_channels) + num_edm_channels * semaphore_size, eth_word_size_bytes); - TT_ASSERT(base_address % eth_word_size_bytes == 0); - return base_address; - } - static uint32_t compute_buffer_size(std::size_t num_edm_channels, uint32_t page_size = eth_word_size_bytes) { - page_size = std::max(page_size, eth_word_size_bytes); - TT_ASSERT(num_edm_channels > 0); - uint32_t buffer_size =tt::round_down( - (total_l1_buffer_space - get_buffers_base_address(num_edm_channels)) / (num_edm_channels), page_size); - log_trace(tt::LogOp, "total_l1_buffer_space: {}", total_l1_buffer_space); - log_trace( - tt::LogOp, "get_buffers_base_address(num_edm_channels): {}", get_buffers_base_address(num_edm_channels)); - log_trace( - tt::LogOp, "usable buffer space: {}", total_l1_buffer_space - get_buffers_base_address(num_edm_channels)); - log_trace(tt::LogOp, "num_edm_channels: {}", num_edm_channels); - log_trace(tt::LogOp, "page_size: {}", page_size); - - log_trace(tt::LogOp, "Buffer size: {}", buffer_size); - - TT_ASSERT(buffer_size > 0 && buffer_size % page_size == 0); - return buffer_size; - } }; struct CCLOpConfig { @@ -97,6 +81,7 @@ class EriscDatamoverBuilder { uint32_t worker_semaphore_id, uint32_t num_eth_messages_to_forward, uint32_t channel, + uint32_t num_buffers, std::vector const& worker_coords, uint32_t largest_message_size_bytes = 0) : worker_coords(worker_coords), @@ -111,6 +96,7 @@ class EriscDatamoverBuilder { uint32_t num_eth_messages_to_forward; uint32_t channel; uint32_t largest_message_size_bytes; + uint32_t num_buffers; bool is_sender; }; @@ -143,6 +129,8 @@ class EriscDatamoverBuilder { ccl::EriscDataMoverTerminationMode const termination_mode; uint32_t num_senders; uint32_t num_receivers; + std::size_t num_buffers_per_channel; + chip_id_t chip_id; bool enable_sender; bool enable_receiver; @@ -160,19 +148,24 @@ class EriscDatamoverBuilder { std::vector const& local_semaphore_addresses, std::vector const& local_buffer_addresses, ccl::EriscDataMoverBufferSharingMode buffer_sharing_mode, - ccl::EriscDataMoverTerminationMode termination_mode = - ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED) : + ccl::EriscDataMoverTerminationMode termination_mode = ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED, + std::size_t num_buffers_per_channel = 1, + chip_id_t chip_id = -1) : local_semaphore_addresses(local_semaphore_addresses), local_buffer_addresses(local_buffer_addresses), eth_buffer_size_bytes(eth_buffer_size), handshake_addr(handshake_addr), num_channel_buffers(local_buffer_addresses.size()), buffer_sharing_mode(buffer_sharing_mode), + num_buffers_per_channel(num_buffers_per_channel), termination_mode(termination_mode), enable_sender(false), enable_receiver(false), num_senders(0), - num_receivers(0) { + num_receivers(0), + chip_id(chip_id) { + + TT_ASSERT(num_buffers_per_channel > 0); TT_ASSERT(local_buffer_addresses.size() == local_semaphore_addresses.size()); active_channels.reserve(num_channel_buffers); TT_ASSERT(eth_buffer_size_bytes < 163000); @@ -199,7 +192,7 @@ class EriscDatamoverBuilder { this->num_senders++; auto channel = active_channels.size(); active_channels.emplace_back( - true, worker_semaphore_id, num_eth_messages_to_forward, channel, worker_coords, expected_message_size_bytes); + true, worker_semaphore_id, num_eth_messages_to_forward, channel, this->num_buffers_per_channel, worker_coords, expected_message_size_bytes); log_trace(tt::LogOp, "Adding sender channel:"); log_trace(tt::LogOp, "\tworker_semaphore_id: {}", active_channels.back().worker_semaphore_id); log_trace(tt::LogOp, "\tnum_eth_messages_to_forward: {}", active_channels.back().num_eth_messages_to_forward); @@ -229,7 +222,7 @@ class EriscDatamoverBuilder { this->num_receivers++; auto channel = active_channels.size(); active_channels.emplace_back( - false, worker_semaphore_id, num_eth_messages_to_forward, channel, worker_coords, expected_message_size_bytes); + false, worker_semaphore_id, num_eth_messages_to_forward, channel, this->num_buffers_per_channel, worker_coords, expected_message_size_bytes); log_trace(tt::LogOp, "Adding receiver channel:"); log_trace(tt::LogOp, "\tworker_semaphore_id: {}", active_channels.back().worker_semaphore_id); log_trace(tt::LogOp, "\tnum_eth_messages_to_forward: {}", active_channels.back().num_eth_messages_to_forward); @@ -247,7 +240,12 @@ class EriscDatamoverBuilder { this->num_senders, this->num_receivers, this->buffer_sharing_mode, - this->termination_mode}; + this->termination_mode, + 1, + static_cast(this->num_senders > 0 && active_channels.at(0).is_sender), + this->num_buffers_per_channel, + chip_id + }; } [[nodiscard]] diff --git a/ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_adapters.hpp b/ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_adapters.hpp new file mode 100644 index 00000000000..50692a5f4cb --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_adapters.hpp @@ -0,0 +1,142 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include "dataflow_api.h" + +#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp" +#include "tt_metal/hw/inc/ethernet/dataflow_api.h" + +namespace ccl { +namespace edm { + +template +struct WorkerToEdmReader{ + constexpr WorkerToEdmReader ( + ttnn::ccl::WorkerXY edm_worker_xy, + std::size_t edm_buffer_base_addr, + std::size_t num_buffers_per_channel, + std::size_t edm_l1_sem_addr, + std::size_t buffer_size_bytes, + volatile uint32_t * const worker_sem_addr + ) : + edm_buffer_addr(get_noc_addr(edm_worker_xy.x, edm_worker_xy.y, edm_buffer_base_addr)), + edm_semaphore_addr(get_noc_addr(edm_worker_xy.x, edm_worker_xy.y, edm_l1_sem_addr)), + worker_sem_addr(worker_sem_addr), + edm_buffer_base_addr(edm_buffer_base_addr), + num_buffers_per_channel(num_buffers_per_channel), + last_buffer_index(num_buffers_per_channel - 1), + edm_l1_sem_addr(edm_l1_sem_addr), + buffer_size_bytes(buffer_size_bytes), + buffer_index(0) + {} + + FORCE_INLINE void wait_for_payload_available() const { + noc_semaphore_wait_min(worker_sem_addr, 1); + if (*worker_sem_addr > 1) { + DPRINT << "ERROR!!!!!!!!!!!!!!!!!!!!!!!!!\n"; + ASSERT(false); + } + noc_semaphore_set(worker_sem_addr, 0); + } + + FORCE_INLINE void fetch_payload_blocking(uint32_t cb_id, uint32_t num_pages, uint32_t page_size, bool last_message) { + uint64_t buffer_address = edm_buffer_addr + (buffer_index * (buffer_size_bytes + sizeof(eth_channel_sync_t))); + fetch_chunk(cb_id, num_pages, page_size, buffer_address); + if constexpr (termination_mode == ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED) { + if (!last_message) { + DPRINT << "fetch_payload_blocking: incrementing semaphore to " << (uint32_t)(edm_semaphore_addr & 0xFFFFFFFF) << "\n"; + noc_semaphore_inc(edm_semaphore_addr, ttnn::ccl::EriscDataMoverWorkerSignal::NEXT_MESSAGE_AVAILABLE); + } + } else { + noc_semaphore_inc(edm_semaphore_addr, ttnn::ccl::EriscDataMoverWorkerSignal::NEXT_MESSAGE_AVAILABLE); + } + buffer_index = (buffer_index == last_buffer_index) ? 0 : buffer_index + 1; + } + + FORCE_INLINE void fetch_payload_blocking(uint32_t cb_id, uint32_t num_pages, uint32_t page_size) { + // With worker initiated termination mode, we must always specify if we are sending the last message or not + ASSERT(termination_mode != ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED); + fetch_payload_blocking(cb_id, num_pages, page_size, false); + } + + FORCE_INLINE void close() { + if constexpr (termination_mode == ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED) { + noc_semaphore_inc(edm_semaphore_addr, ttnn::ccl::EriscDataMoverWorkerSignal::TERMINATE_IMMEDIATELY); + } + } + + uint64_t edm_buffer_addr; + uint64_t edm_semaphore_addr; + volatile uint32_t * const worker_sem_addr; + std::size_t edm_buffer_base_addr; + std::size_t num_buffers_per_channel; + std::size_t last_buffer_index; + std::size_t edm_l1_sem_addr; + std::size_t buffer_size_bytes; + std::size_t buffer_index; +}; + + +template +struct WorkerToEdmSender{ + constexpr WorkerToEdmSender ( + ttnn::ccl::WorkerXY edm_worker_xy, + std::size_t edm_buffer_base_addr, + std::size_t num_buffers_per_channel, + std::size_t edm_l1_sem_addr, + std::size_t buffer_size_bytes, + volatile uint32_t * const worker_sem_addr + ) : + edm_buffer_addr(get_noc_addr(edm_worker_xy.x, edm_worker_xy.y, edm_buffer_base_addr)), + edm_semaphore_addr(get_noc_addr(edm_worker_xy.x, edm_worker_xy.y, edm_l1_sem_addr)), + worker_sem_addr(worker_sem_addr), + edm_buffer_base_addr(edm_buffer_base_addr), + num_buffers_per_channel(num_buffers_per_channel), + last_buffer_index(num_buffers_per_channel - 1), + edm_l1_sem_addr(edm_l1_sem_addr), + buffer_size_bytes(buffer_size_bytes), + buffer_index(0) + { + ASSERT(buffer_size_bytes > 0); + } + + FORCE_INLINE void wait_for_empty_write_slot() const { + noc_semaphore_wait(worker_sem_addr, 1); + noc_semaphore_set(worker_sem_addr, 0); + } + + FORCE_INLINE void send_payload_blocking(uint32_t cb_id, uint32_t num_pages, uint32_t page_size) { + uint64_t buffer_address = edm_buffer_addr + (buffer_index * (this->buffer_size_bytes + sizeof(eth_channel_sync_t))); + DPRINT << "SENDER SEND buffer_size_bytes = " << (uint32_t)(this->buffer_size_bytes) << "\n"; + DPRINT << "SENDER SEND " << (uint32_t)(buffer_address & 0xffffffff) << " -> " << (uint32_t)((buffer_address & 0xffffffff) + (page_size * num_pages)) << "\n"; + send_chunk(cb_id, num_pages, page_size, buffer_address); + noc_semaphore_inc(edm_semaphore_addr, 1); + buffer_index = (buffer_index == last_buffer_index) ? 0 : buffer_index + 1; + } + + FORCE_INLINE void close() { + if constexpr (termination_mode == ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED) { + this->wait_for_empty_write_slot(); + noc_semaphore_inc(edm_semaphore_addr, ttnn::ccl::EriscDataMoverWorkerSignal::TERMINATE_IMMEDIATELY); + } + } + + uint64_t edm_buffer_addr; + uint64_t edm_semaphore_addr; + volatile uint32_t * const worker_sem_addr; + std::size_t edm_buffer_base_addr; + std::size_t num_buffers_per_channel; + std::size_t last_buffer_index; + std::size_t edm_l1_sem_addr; + std::size_t buffer_size_bytes; + std::size_t buffer_index; +}; + + +} // namespace edm +} // namespace ccl diff --git a/ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp b/ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp index d1037849dd8..d865141dc42 100644 --- a/ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp @@ -6,7 +6,6 @@ #include "dataflow_api.h" #include "debug/assert.h" -#include "debug/dprint.h" #include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" using ttnn::ccl::ShardType; diff --git a/ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_async_datamover.hpp b/ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_async_datamover.hpp index 6c45db409de..a0232abc8b9 100644 --- a/ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_async_datamover.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_async_datamover.hpp @@ -2,16 +2,16 @@ // // SPDX-License-Identifier: Apache-2.0 +#pragma once + #include #include -#include #include "dataflow_api.h" #include "debug/assert.h" #include "eth_l1_address_map.h" #include "ethernet/dataflow_api.h" #include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" -#include "tt_metal/hw/inc/wormhole/noc/noc.h" using ttnn::ccl::EriscDataMoverBufferSharingMode; using ttnn::ccl::EriscDataMoverTerminationMode; @@ -20,10 +20,11 @@ using ttnn::ccl::EriscDataMoverWorkerSignal; namespace erisc { namespace datamover { -template +template struct EriscDatamoverConfig { static constexpr EriscDataMoverBufferSharingMode BUFFER_SHARING_MODE = buffer_sharing_mode; static constexpr EriscDataMoverTerminationMode TERMINATION_MODE = termination_mode; + static constexpr uint8_t NUM_BUFFERS_PER_CHANNEL = num_buffers_per_channel; }; template @@ -56,43 +57,46 @@ class ChannelBuffer final { enum STATE : uint8_t { DONE = 0, - // For sender: means we are ready to tell the worker(s) that the buffer is available for writing into - // - SIGNALING_WORKER, + // we are ready to tell the worker(s) that the buffer is available for writing into + SENDER_SIGNALING_WORKER, - // For sender: we are waiting for the payload to arrive in L1; we are checking local semaphore for worker - // completion For receiver: we are waiting for worker to complete pull of payload from L1; we are checking local - // semaphore for worker completion - WAITING_FOR_WORKER, + // we are waiting for the payload to arrive in L1; we are checking local semaphore for worker + // completion + SENDER_WAITING_FOR_WORKER, + + // means workers have signalled (via semaphores) that the buffer payload is + SENDER_READY_FOR_ETH_TRANSFER, + + // means we are waiting for ack from receiver that payload was received + SENDER_WAITING_FOR_ETH, + + // We received a packet from ethernet and we can signal the downstream worker to signal + // packet availability + RECEIVER_SIGNALING_WORKER, - // For sender: means workers have signalled (via semaphores) that the buffer payload is - // ready in L1 - // For receiver: - READY_FOR_ETH_TRANSFER, + // we are waiting for worker to complete pull of payload from L1; we are checking local + // semaphore for worker completion + RECEIVER_WAITING_FOR_WORKER, - // For sender: means we are waiting for ack from receiver that payload was received - // For receiver: means we are waitinf for a payload from sender - WAITING_FOR_ETH, + // means we are waitinf for a payload from sender + RECEIVER_WAITING_FOR_ETH, }; // for default initialization in arrays ChannelBuffer() : local_semaphore_address(0), worker_coords(0), - address(0), size_in_bytes(0), worker_semaphore_l1_address(0), num_workers(0), num_messages_moved(0), - channel_bytes_sent_address(0), - channel_bytes_acked_address(0), total_num_messages_to_move(0), state(STATE::DONE) {} ChannelBuffer( uint32_t eth_transaction_channel, size_t address, - size_t size_in_bytes, + size_t payload_size_in_bytes, uint32_t worker_semaphore_l1_address, uint32_t num_workers, uint32_t total_num_messages_to_move, @@ -102,19 +106,38 @@ class ChannelBuffer final { eth_transaction_channel(eth_transaction_channel), local_semaphore_address(local_semaphore_address), worker_coords(worker_coords), - address(address), - size_in_bytes(size_in_bytes), + size_in_bytes(payload_size_in_bytes + sizeof(eth_channel_sync_t)), worker_semaphore_l1_address(worker_semaphore_l1_address), num_workers(num_workers), num_messages_moved(0), - channel_bytes_sent_address(&erisc_info->channels[eth_transaction_channel].bytes_sent), - channel_bytes_acked_address(&erisc_info->channels[eth_transaction_channel].receiver_ack), total_num_messages_to_move(total_num_messages_to_move), - state(is_sender_side ? STATE::WAITING_FOR_WORKER : STATE::WAITING_FOR_ETH), + state( + is_sender_side ? TERMINATION_MODE == ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED + ? STATE::SENDER_WAITING_FOR_WORKER + : STATE::SENDER_WAITING_FOR_WORKER + : TERMINATION_MODE == ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED + ? STATE::RECEIVER_WAITING_FOR_ETH + : STATE::RECEIVER_WAITING_FOR_ETH), + + buffer_index(0), is_sender_completion_pending(false), is_sender_side(is_sender_side) { clear_local_semaphore(); + for (uint8_t i = 0; i < EDM_CONFIG::NUM_BUFFERS_PER_CHANNEL; i++) { + this->addresses[i] = address + i * (this->size_in_bytes); + + uint32_t channel_sync_addr = this->addresses[i] + payload_size_in_bytes; + volatile uint32_t* bytes_sent_addr = &(reinterpret_cast(channel_sync_addr)->bytes_sent); + volatile uint32_t* bytes_acked_addr = &(reinterpret_cast(channel_sync_addr)->receiver_ack); + channel_bytes_sent_addresses[i] = bytes_sent_addr; + channel_bytes_acked_addresses[i] = bytes_acked_addr; + + ASSERT((uint32_t)channel_bytes_acked_addresses[i] != (uint32_t)channel_bytes_sent_addresses[i]); + *(channel_bytes_sent_addresses[i]) = 0; + *(channel_bytes_acked_addresses[i]) = 0; + } + if (TERMINATION_MODE != ttnn::ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED || total_num_messages_to_move != 0) { if (is_sender_side) { // Tell the sender side workers that we're ready to accept data on this channel @@ -126,7 +149,9 @@ class ChannelBuffer final { } } // Resets the semaphore in local L1, which workers write to remotely. - FORCE_INLINE void clear_local_semaphore() { noc_semaphore_set(local_semaphore_address, 0); } + FORCE_INLINE void clear_local_semaphore() { + noc_semaphore_set(local_semaphore_address, 0); + } // Increment the semaphore in the remote L1s of every worker associated with this ChannelBuffer FORCE_INLINE void increment_worker_semaphores() { @@ -185,15 +210,15 @@ class ChannelBuffer final { [[nodiscard]] FORCE_INLINE bool is_done() const { return this->state == STATE::DONE; } [[nodiscard]] FORCE_INLINE uint32_t get_eth_transaction_channel() const { - ASSERT(this->eth_transaction_channel < eth_l1_mem::address_map::MAX_NUM_CONCURRENT_TRANSACTIONS); return this->eth_transaction_channel; } - [[nodiscard]] FORCE_INLINE std::size_t get_remote_eth_buffer_address() const { return this->address; } [[nodiscard]] FORCE_INLINE std::size_t get_size_in_bytes() const { return this->size_in_bytes; } [[nodiscard]] FORCE_INLINE std::size_t get_current_payload_size() const { return this->get_size_in_bytes(); } - [[nodiscard]] FORCE_INLINE std::size_t get_buffer_address() const { return this->address; } + [[nodiscard]] FORCE_INLINE std::size_t get_buffer_address() const { + return this->addresses[buffer_index]; } + [[nodiscard]] FORCE_INLINE std::size_t get_remote_eth_buffer_address() const { return this->get_buffer_address(); } FORCE_INLINE uint32_t get_messages_moved() { return this->num_messages_moved; } FORCE_INLINE void increment_messages_moved() { this->num_messages_moved++; } @@ -204,27 +229,72 @@ class ChannelBuffer final { FORCE_INLINE void set_send_completion_pending(bool value) { this->is_sender_completion_pending = value; } [[nodiscard]] FORCE_INLINE bool is_send_completion_pending() const { return this->is_sender_completion_pending; } - FORCE_INLINE bool eth_is_receiver_channel_send_done() const { return *this->channel_bytes_sent_address == 0; } - FORCE_INLINE bool eth_bytes_are_available_on_channel() const { return *this->channel_bytes_sent_address != 0; } - FORCE_INLINE bool eth_is_receiver_channel_send_acked() const { return *this->channel_bytes_acked_address != 0; } - volatile tt_l1_ptr uint32_t *const get_channel_bytes_sent_address() { return this->channel_bytes_sent_address; } - volatile tt_l1_ptr uint32_t *const get_channel_bytes_acked_address() { return this->channel_bytes_acked_address; } + FORCE_INLINE bool eth_is_receiver_channel_send_done() const { + ASSERT(buffer_index < EDM_CONFIG::NUM_BUFFERS_PER_CHANNEL); + return *(this->channel_bytes_sent_addresses[buffer_index]) == 0; } + FORCE_INLINE bool eth_bytes_are_available_on_channel() const { + ASSERT(buffer_index < EDM_CONFIG::NUM_BUFFERS_PER_CHANNEL); + return *(this->channel_bytes_sent_addresses[buffer_index]) != 0; + } + FORCE_INLINE bool eth_is_receiver_channel_send_acked() const { + return *(this->channel_bytes_acked_addresses[buffer_index]) != 0; } + FORCE_INLINE void eth_clear_sender_channel_ack() const { *(this->channel_bytes_acked_addresses[buffer_index]) = 0; } + FORCE_INLINE void eth_receiver_channel_ack(uint32_t eth_transaction_ack_word_addr) const { + ASSERT(reinterpret_cast(eth_transaction_ack_word_addr)[0] == 1); + reinterpret_cast(eth_transaction_ack_word_addr)[1] = 1; + // Make sure we don't alias the erisc_info eth_channel_sync_t + ASSERT(eth_transaction_ack_word_addr != ((uint32_t)(this->channel_bytes_acked_addresses[buffer_index])) >> 4); + ASSERT(reinterpret_cast(eth_transaction_ack_word_addr)->bytes_sent != 0); + ASSERT(reinterpret_cast(eth_transaction_ack_word_addr)->receiver_ack == 1); + internal_::eth_send_packet( + 0, + eth_transaction_ack_word_addr >> 4, + ((uint32_t)(this->channel_bytes_sent_addresses[buffer_index])) >> 4, + 1); + } + FORCE_INLINE void eth_receiver_channel_done() const { + *(this->channel_bytes_sent_addresses[buffer_index]) = 0; + *(this->channel_bytes_acked_addresses[buffer_index]) = 0; + internal_::eth_send_packet( + 0, + ((uint32_t)(this->channel_bytes_sent_addresses[buffer_index])) >> 4, + ((uint32_t)(this->channel_bytes_sent_addresses[buffer_index])) >> 4, + 1); + } + + FORCE_INLINE void advance_buffer_index() { + if constexpr (EDM_CONFIG::NUM_BUFFERS_PER_CHANNEL == 1) { + return; + } else if constexpr (EDM_CONFIG::NUM_BUFFERS_PER_CHANNEL == 2) { + this->buffer_index = 1 - this->buffer_index; + } else if constexpr (((EDM_CONFIG::NUM_BUFFERS_PER_CHANNEL) & (EDM_CONFIG::NUM_BUFFERS_PER_CHANNEL - 1)) == 0) { + this->buffer_index = (buffer_index + 1) & (EDM_CONFIG::NUM_BUFFERS_PER_CHANNEL - 1); + } else { + this->buffer_index = (buffer_index == EDM_CONFIG::NUM_BUFFERS_PER_CHANNEL - 1) ? 0 : buffer_index + 1; + } + + ASSERT(this->buffer_index < EDM_CONFIG::NUM_BUFFERS_PER_CHANNEL); + } + + volatile tt_l1_ptr uint32_t *const get_channel_bytes_sent_address() { return this->channel_bytes_sent_addresses[buffer_index]; } + volatile tt_l1_ptr uint32_t *const get_channel_bytes_acked_address() { return this->channel_bytes_acked_addresses[buffer_index]; } public: uint32_t eth_transaction_channel; // volatile tt_l1_ptr uint32_t *const local_semaphore_address; WorkerXY const *const worker_coords; - std::size_t const address; + std::array addresses; std::size_t const size_in_bytes; // Even for multiple workers, this address will be the same std::size_t const worker_semaphore_l1_address; uint32_t const num_workers; uint32_t num_messages_moved; - volatile tt_l1_ptr uint32_t *const channel_bytes_sent_address; - volatile tt_l1_ptr uint32_t *const channel_bytes_acked_address; + std::array channel_bytes_sent_addresses; + std::array channel_bytes_acked_addresses; const uint32_t total_num_messages_to_move; STATE state; edm_worker_index worker_index; + uint8_t buffer_index; bool is_sender_completion_pending; bool is_sender_side; }; @@ -280,6 +350,13 @@ class QueueIndexPointer { uint8_t wrap_around; }; + +/* + * Before any payload messages can be exchanged over the link, we must ensure that the other end + * of the link is ready to start sending/receiving messages. We perform a handshake to ensure that's + * case. Before handshaking, we make sure to clear any of the channel sync datastructures local + * to our core. + */ FORCE_INLINE void eth_setup_handshake(std::uint32_t handshake_register_address, bool is_sender) { reinterpret_cast(handshake_register_address)[4] = 1; reinterpret_cast(handshake_register_address)[5] = 1; @@ -318,32 +395,27 @@ FORCE_INLINE void initialize_transaction_buffer_addresses( ///////////////////////////////////////////// // SENDER SIDE HELPERS ///////////////////////////////////////////// - template FORCE_INLINE bool sender_eth_send_data_sequence(ChannelBuffer &sender_buffer_channel) { bool did_something = false; if (sender_buffer_channel.eth_is_receiver_channel_send_done()) { bool need_to_send_completion = sender_buffer_channel.is_send_completion_pending(); - if (!sender_buffer_channel.is_send_completion_pending() && !eth_txq_is_busy()) { + if (!eth_txq_is_busy()) { static constexpr std::size_t ETH_BYTES_TO_WORDS_SHIFT = 4; + ASSERT((uint32_t)sender_buffer_channel.get_channel_bytes_sent_address() == + ((uint32_t)sender_buffer_channel.get_buffer_address() + (uint32_t)sender_buffer_channel.get_current_payload_size() - (uint32_t)sizeof(eth_channel_sync_t))); + *sender_buffer_channel.get_channel_bytes_sent_address() = sender_buffer_channel.get_current_payload_size(); + *sender_buffer_channel.get_channel_bytes_acked_address() = 0; + eth_send_bytes_over_channel_payload_only( sender_buffer_channel.get_buffer_address(), sender_buffer_channel.get_remote_eth_buffer_address(), sender_buffer_channel.get_current_payload_size(), - sender_buffer_channel.get_eth_transaction_channel(), sender_buffer_channel.get_current_payload_size(), sender_buffer_channel.get_current_payload_size() >> ETH_BYTES_TO_WORDS_SHIFT); - sender_buffer_channel.set_send_completion_pending(true); - need_to_send_completion = true; - did_something = true; - } - - if (need_to_send_completion && !eth_txq_is_busy()) { - eth_send_payload_complete_signal_over_channel( - sender_buffer_channel.get_eth_transaction_channel(), sender_buffer_channel.get_current_payload_size()); - sender_buffer_channel.set_send_completion_pending(false); - sender_buffer_channel.goto_state(ChannelBuffer::WAITING_FOR_ETH); + sender_buffer_channel.advance_buffer_index(); + sender_buffer_channel.goto_state(ChannelBuffer::SENDER_WAITING_FOR_ETH); did_something = true; } } @@ -354,6 +426,7 @@ FORCE_INLINE bool sender_eth_send_data_sequence(ChannelBuffer &sende template FORCE_INLINE bool sender_notify_workers_if_buffer_available_sequence( ChannelBuffer &sender_buffer_channel, uint32_t &num_senders_complete) { + bool channel_done = false; if constexpr (EDM_CONFIG::TERMINATION_MODE == EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED) { channel_done = sender_buffer_channel.all_messages_moved(); @@ -367,7 +440,7 @@ FORCE_INLINE bool sender_notify_workers_if_buffer_available_sequence( sender_buffer_channel.increment_worker_semaphores(); if (!channel_done) { - sender_buffer_channel.goto_state(ChannelBuffer::WAITING_FOR_WORKER); + sender_buffer_channel.goto_state(ChannelBuffer::SENDER_WAITING_FOR_WORKER); } else { sender_buffer_channel.goto_state(ChannelBuffer::DONE); num_senders_complete++; @@ -384,19 +457,19 @@ FORCE_INLINE bool sender_eth_check_receiver_ack_sequence( bool transimission_acked_by_receiver = sender_buffer_channel.eth_is_receiver_channel_send_acked() || sender_buffer_channel.eth_is_receiver_channel_send_done(); if (transimission_acked_by_receiver) { - eth_clear_sender_channel_ack(sender_buffer_channel.get_eth_transaction_channel()); + sender_buffer_channel.eth_clear_sender_channel_ack(); sender_buffer_channel.increment_messages_moved(); - sender_buffer_channel.goto_state(ChannelBuffer::SIGNALING_WORKER); + sender_buffer_channel.goto_state(ChannelBuffer::SENDER_SIGNALING_WORKER); + + // Don't need to guard as we can unconditionally notify the workers right away now that + // we're in the current state sender_notify_workers_if_buffer_available_sequence(sender_buffer_channel, num_senders_complete); - did_something = true; } return did_something; } -/* - * - */ + template FORCE_INLINE bool sender_noc_receive_payload_ack_check_sequence( ChannelBuffer &sender_channel_buffer, uint32_t &num_senders_complete) { @@ -413,15 +486,14 @@ FORCE_INLINE bool sender_noc_receive_payload_ack_check_sequence( bool read_finished = sender_channel_buffer.is_local_semaphore_full(); if (read_finished) { - // We can clear the semaphore, and wait for space on receiver - // sender_channel_buffer.clear_local_semaphore(); - sender_channel_buffer.goto_state(ChannelBuffer::READY_FOR_ETH_TRANSFER); - did_something = true; + sender_channel_buffer.goto_state(ChannelBuffer::SENDER_READY_FOR_ETH_TRANSFER); erisc::datamover::sender_eth_send_data_sequence(sender_channel_buffer); + did_something = true; } return did_something; + } ///////////////////////////////////////////// @@ -434,9 +506,10 @@ FORCE_INLINE bool sender_noc_receive_payload_ack_check_sequence( template FORCE_INLINE bool receiver_eth_notify_workers_payload_available_sequence(ChannelBuffer &buffer_channel) { buffer_channel.clear_local_semaphore(); + uint32_t worker_semaphore_address = buffer_channel.worker_semaphore_l1_address; buffer_channel.increment_worker_semaphores(); - buffer_channel.goto_state(ChannelBuffer::WAITING_FOR_WORKER); + buffer_channel.goto_state(ChannelBuffer::RECEIVER_WAITING_FOR_WORKER); return true; } @@ -453,8 +526,8 @@ FORCE_INLINE bool receiver_eth_accept_payload_sequence( if (buffer_channel.eth_bytes_are_available_on_channel()) { if (!eth_txq_is_busy()) { - eth_receiver_channel_ack(buffer_channel.get_eth_transaction_channel(), eth_transaction_ack_word_addr); - buffer_channel.goto_state(ChannelBuffer::SIGNALING_WORKER); + buffer_channel.eth_receiver_channel_ack(eth_transaction_ack_word_addr); + buffer_channel.goto_state(ChannelBuffer::RECEIVER_SIGNALING_WORKER); did_something = true; // FIXME: Decouple these so we can still signal workers even if eth command queue is busy @@ -467,6 +540,7 @@ FORCE_INLINE bool receiver_eth_accept_payload_sequence( return did_something; } + /* * Does something if we are waiting for workers to complete their read and the read is complete: * - increment messages moved (that transfer is done) @@ -476,8 +550,7 @@ FORCE_INLINE bool receiver_eth_accept_payload_sequence( template FORCE_INLINE bool receiver_noc_read_worker_completion_check_sequence( ChannelBuffer &buffer_channel, - uint32_t &num_receivers_complete, - uint32_t eth_transaction_complete_addr) { + uint32_t &num_receivers_complete) { bool did_something = false; bool workers_are_finished_reading = buffer_channel.is_local_semaphore_full(); @@ -492,9 +565,11 @@ FORCE_INLINE bool receiver_noc_read_worker_completion_check_sequence( bool can_notify_sender_of_buffer_available = workers_are_finished_reading; if (can_notify_sender_of_buffer_available) { if (!eth_txq_is_busy()) { - eth_receiver_channel_done(buffer_channel.get_eth_transaction_channel()); + buffer_channel.eth_receiver_channel_done(); buffer_channel.increment_messages_moved(); + buffer_channel.advance_buffer_index(); + bool channel_done = false; if constexpr (EDM_CONFIG::TERMINATION_MODE == EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED) { channel_done = buffer_channel.all_messages_moved(); @@ -505,12 +580,11 @@ FORCE_INLINE bool receiver_noc_read_worker_completion_check_sequence( } if (!channel_done) { - buffer_channel.goto_state(ChannelBuffer::WAITING_FOR_ETH); + buffer_channel.goto_state(ChannelBuffer::RECEIVER_WAITING_FOR_ETH); } else { buffer_channel.goto_state(ChannelBuffer::DONE); num_receivers_complete++; } - did_something = true; } } @@ -518,6 +592,7 @@ FORCE_INLINE bool receiver_noc_read_worker_completion_check_sequence( return did_something; } + //////////////////////////// // DEPRECATED //////////////////////////// @@ -658,7 +733,6 @@ bool receiver_eth_accept_payload_sequence( if (!receive_pointers_full) { if (eth_bytes_are_available_on_channel(eth_receiver_ptr.index())) { - // DPRINT << "rx: accepting payload, sending receive ack on channel " << (uint32_t)eth_receiver_ptr << "\n"; eth_receiver_channel_ack(eth_receiver_ptr.index(), eth_channel_sync_ack_addr); eth_receiver_ptr.increment(); did_something = true; @@ -683,8 +757,6 @@ FORCE_INLINE bool receiver_noc_read_worker_completion_check_sequence( bool writes_finished = ncrisc_noc_nonposted_writes_sent(noc_index); #endif if (writes_finished) { - // DPRINT << "rx: accepting payload, sending receive ack on channel " << (uint32_t)noc_writer_buffer_ackptr - // << "\n"; noc_writer_buffer_ackptr.increment(); did_something = true; @@ -708,13 +780,9 @@ FORCE_INLINE bool receiver_eth_send_ack_to_sender_sequence( bool buffer_writes_flushed = ncrisc_noc_nonposted_writes_sent(noc_index); // bool buffer_writes_flushed = ncrisc_noc_nonposted_writes_flushed(noc_index); if (buffer_writes_flushed) { - // DPRINT << "rx: accepting payload, sending receive ack on channel " << (uint32_t)noc_writer_buffer_wrptr - // << "\n"; eth_receiver_channel_done(eth_receiver_ackptr.index()); num_eth_sends_acked++; eth_receiver_ackptr.increment(); - // DPRINT << "rx: Sending eth ack. ackptr incrementing to " << (uint32_t)eth_receiver_ackptr.index() << - // "\n"; did_something = true; } diff --git a/ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_datamover.cpp b/ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_datamover.cpp index 098103f5004..6017a9a4ef1 100644 --- a/ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_datamover.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_datamover.cpp @@ -109,8 +109,6 @@ struct sender_receiver_index_t { void kernel_main() { - // COMPILE TIME ARGS - // If true, will enable this erisc's sender functionality constexpr bool enable_sender_side = get_compile_time_arg_val(0) != 0; // If true, will enable this erisc's receiver functionality @@ -119,13 +117,22 @@ void kernel_main() { constexpr uint32_t num_senders = get_compile_time_arg_val(2); constexpr uint32_t num_receivers = get_compile_time_arg_val(3); - constexpr ttnn::ccl::EriscDataMoverBufferSharingMode edm_buffer_sharing_mode = + static constexpr ttnn::ccl::EriscDataMoverBufferSharingMode edm_buffer_sharing_mode = static_cast(get_compile_time_arg_val(4)); - constexpr ttnn::ccl::EriscDataMoverTerminationMode terminate_on_worker_signal = + static constexpr ttnn::ccl::EriscDataMoverTerminationMode terminate_on_worker_signal = static_cast(get_compile_time_arg_val(5)); - using EDM_CONFIG_T = erisc::datamover::EriscDatamoverConfig; + static constexpr bool use_compile_time_designated_handshake_sender = false;//get_compile_time_arg_val(6) != 0; + static constexpr bool is_handshake_sender = get_compile_time_arg_val(7) != 0; + + static constexpr uint32_t num_buffers_per_channel = get_compile_time_arg_val(8); + static constexpr uint32_t chip_id = get_compile_time_arg_val(9); + + static_assert(num_buffers_per_channel > 0, "compile time argument [9]: num_buffers_per_channel must be > 0"); + + using EDM_CONFIG_T = erisc::datamover:: + EriscDatamoverConfig; using ChannelBufferT = erisc::datamover::ChannelBuffer; std::array buffer_channels; @@ -233,23 +240,23 @@ void kernel_main() { if constexpr (enable_sender_side) { ChannelBufferT ¤t_sender = buffer_channels[send_recv_index.real_index.sender]; switch (current_sender.get_state()) { - case ChannelBufferT::STATE::WAITING_FOR_WORKER: + case ChannelBufferT::STATE::SENDER_WAITING_FOR_WORKER: did_something_sender = erisc::datamover::sender_noc_receive_payload_ack_check_sequence(current_sender, num_senders_complete); senders_in_progress = senders_in_progress && num_senders_complete != sender_num_channels; break; - case ChannelBufferT::STATE::READY_FOR_ETH_TRANSFER: + case ChannelBufferT::STATE::SENDER_READY_FOR_ETH_TRANSFER: did_something_sender = erisc::datamover::sender_eth_send_data_sequence(current_sender); break; - case ChannelBufferT::STATE::SIGNALING_WORKER: + case ChannelBufferT::STATE::SENDER_SIGNALING_WORKER: did_something_sender = erisc::datamover::sender_notify_workers_if_buffer_available_sequence( current_sender, num_senders_complete); senders_in_progress = senders_in_progress && num_senders_complete != sender_num_channels; break; - case ChannelBufferT::STATE::WAITING_FOR_ETH: + case ChannelBufferT::STATE::SENDER_WAITING_FOR_ETH: did_something_sender = erisc::datamover::sender_eth_check_receiver_ack_sequence(current_sender, num_senders_complete); senders_in_progress = senders_in_progress && num_senders_complete != sender_num_channels; @@ -266,19 +273,19 @@ void kernel_main() { ChannelBufferT ¤t_receiver = buffer_channels[send_recv_index.real_index.receiver]; switch (current_receiver.get_state()) { - case ChannelBufferT::STATE::WAITING_FOR_ETH: + case ChannelBufferT::STATE::RECEIVER_WAITING_FOR_ETH: did_something_receiver = erisc::datamover::receiver_eth_accept_payload_sequence(current_receiver, num_receivers_complete, eth_transaction_ack_word_addr); receivers_in_progress = receivers_in_progress && num_receivers_complete != receiver_num_channels; break; - case ChannelBufferT::STATE::SIGNALING_WORKER: + case ChannelBufferT::STATE::RECEIVER_SIGNALING_WORKER: did_something_receiver = erisc::datamover::receiver_eth_notify_workers_payload_available_sequence(current_receiver); break; - case ChannelBufferT::STATE::WAITING_FOR_WORKER: + case ChannelBufferT::STATE::RECEIVER_WAITING_FOR_WORKER: did_something_receiver = erisc::datamover::receiver_noc_read_worker_completion_check_sequence( - current_receiver, num_receivers_complete, eth_transaction_complete_addr); + current_receiver, num_receivers_complete); receivers_in_progress = receivers_in_progress && num_receivers_complete != receiver_num_channels; break; @@ -301,24 +308,38 @@ void kernel_main() { } } - for (uint32_t s = 0; s < num_senders + num_receivers; s++ ) { - auto const& channel = buffer_channels[s]; - // We need to explicitly check for channel send done because we may - // advance sender channel state as soon as we receive an ack. Since we - // may be the last active channel, and advance to done state just from ack - // from the receiver ("I got a payload"), then we need to wait for done - // at the very end here. Otherise if we invoke another erisc op back-to-back, - // we may mess up transaction state because it's possible for receiver of this - // op to send the completion done after that one has already started. - uint32_t wait_count = 0; - uint32_t wait_max = 50000; - while(!channel.eth_is_receiver_channel_send_done()) { - wait_count++; - if (wait_count > wait_max) { - - DEBUG_STATUS("STK"); - run_routing(); + { + for (uint32_t s = 0; s < num_senders + num_receivers; s++) { + auto &channel = buffer_channels[s]; + // We need to explicitly check for channel send done because we may + // advance sender channel state as soon as we receive an ack. Since we + // may be the last active channel, and advance to done state just from ack + // from the receiver ("I got a payload"), then we need to wait for done + // at the very end here. Otherise if we invoke another erisc op back-to-back, + // we may mess up transaction state because it's possible for receiver of this + // op to send the completion done after that one has already started. + uint32_t wait_count = 0; + uint32_t wait_max = 5000000; + for (uint8_t buffer_index = 0; buffer_index < num_buffers_per_channel; buffer_index++) { wait_count = 0; + channel.buffer_index = buffer_index; + if (!channel.is_sender_side) { + if (!channel.eth_is_receiver_channel_send_done()) { + channel.eth_receiver_channel_done(); + } + } + } + for (uint8_t buffer_index = 0; buffer_index < num_buffers_per_channel; buffer_index++) { + if (channel.is_sender_side) { + while (!channel.eth_is_receiver_channel_send_done()) { + wait_count++; + if (wait_count > wait_max) { + DEBUG_STATUS("STK"); + run_routing(); + wait_count = 0; + } + } + } } } } diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp index 0a6ff1a1fc0..4248fdae0e0 100644 --- a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp @@ -58,10 +58,7 @@ struct WorkerTransferInfo { static std::size_t decide_number_of_edm_channels( ttnn::ccl::CCLOpConfig const& ccl_op_config, std::size_t max_num_workers, bool enable_bidirectional) { - return ccl_op_config.is_input_sharded() ? std::min( - ccl_op_config.get_shard_grid_size(), - std::min(max_num_workers, enable_bidirectional ? 8 : 4)) - : std::min(max_num_workers, enable_bidirectional ? 8 : 4); + return std::min(max_num_workers, enable_bidirectional ? 8 : 4); } struct ReduceScatterWorkerArgBuilder { @@ -377,7 +374,7 @@ static void add_worker_config_to_edm_builders( device->worker_core_from_logical_core(worker_cores.at(w)).y)); } - // Get the expected message size in bytes for this worker + // Get the maximum message size we'd like to use. Not the actual packet size uint32_t expected_message_size_bytes = tensor_slicer.get_worker_slice_size_bytes(global_worker_idx); bool sender_enabled = true; // (!is_linear || !is_last_chip_in_chain); // update for linear @@ -385,7 +382,7 @@ static void add_worker_config_to_edm_builders( auto& sender_edm_builder = is_buffer_in_clockwise_direction_fn(c) ? clockwise_edm_builders.at(link) : counter_clockwise_edm_builders.at(link); log_trace(tt::LogOp, "Adding sender EDM channel"); - ttnn::ccl::EriscDatamoverBuilder::ChannelBufferInterface const& sender_channel_buffer_info = + ttnn::ccl::EriscDatamoverBuilder::ChannelBufferInterface const& sender_channel_buffer_info = sender_edm_builder.add_sender_channel( worker_sender_semaphore_id, 1, // cw_edm_channel_num_messages_to_send_per_transfer.at(c) * (ring_size - 1), @@ -403,7 +400,7 @@ static void add_worker_config_to_edm_builders( ? counter_clockwise_edm_builders.at(link) : clockwise_edm_builders.at(link); log_trace(tt::LogOp, "Adding receiver EDM channel"); - ttnn::ccl::EriscDatamoverBuilder::ChannelBufferInterface const& receiver_channel_buffer_info = + ttnn::ccl::EriscDatamoverBuilder::ChannelBufferInterface const& receiver_channel_buffer_info = receiver_edm_builder.add_receiver_channel( worker_receiver_semaphore_id, // Since we are in worker signal EDM termination mode, we don't need to set the actual number of @@ -726,17 +723,21 @@ operation::ProgramWithCallbacks reduce_scatter_with_workers( ttnn::ccl::CclOpTensorConfig::build_all_gather_tensor_config(input_tensor); std::unique_ptr output_tensor_config = ttnn::ccl::CclOpTensorConfig::build_all_gather_tensor_config(output_tensor); - uint32_t per_step_dim_size = input_tensor.get_legacy_shape()[scatter_split_dim] / ring_size; - uint32_t input_tensor_num_units_per_scatter_dim = - per_step_dim_size / tt::constants::TILE_WIDTH; // TODO: find the divisibility based on layout - TT_ASSERT(input_tensor_num_units_per_scatter_dim > 0); - uint32_t max_num_workers = std::min(8, input_tensor_num_units_per_scatter_dim); + // // The input tensor is fractured by ring_size so we divi + std::size_t input_tensor_n_elems_per_slice = input_tensor.volume() / ring_size; + uint32_t input_tensor_num_units_per_tensor_slice = + input_tensor_n_elems_per_slice / (tt::constants::TILE_WIDTH * tt::constants::TILE_HEIGHT); + + TT_ASSERT(input_tensor_num_units_per_tensor_slice > 0); + uint32_t max_num_workers = std::min(8, input_tensor_num_units_per_tensor_slice); bool enable_bidirectional = true; auto num_edm_channels = decide_number_of_edm_channels(op_config, max_num_workers, enable_bidirectional); log_trace(tt::LogOp, "num_edm_channels: {}", num_edm_channels); - auto edm_termination_mode =ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED; + auto edm_termination_mode = ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED; + + constexpr std::size_t num_buffers_per_channel = 1; // enable double buffering later auto const& edm_builder = create_erisc_datamover_builder( - num_edm_channels, op_config.get_page_size(), buffer_sharing_mode, edm_termination_mode); + num_edm_channels, op_config.get_page_size(), num_buffers_per_channel, buffer_sharing_mode, edm_termination_mode); TT_ASSERT(num_edm_channels > 0); Tensor const& local_chip_tensor = input_tensor; @@ -757,9 +758,6 @@ operation::ProgramWithCallbacks reduce_scatter_with_workers( auto const& topology_config = ttnn::ccl::RingTopology(device, topology, sender_device_id, receiver_device_id, num_links, ring_size, ring_index); - auto dim_slice_factors = tt::tt_metal::Shape(std::vector(local_chip_tensor.get_legacy_shape().rank(), 1)); - dim_slice_factors[-1] = ring_size; - CoreRangeSet const& worker_core_range = select_worker_cores(op_config, num_links, num_edm_channels); auto const& worker_cores = corerange_to_cores(worker_core_range, std::nullopt, true); @@ -870,7 +868,7 @@ operation::ProgramWithCallbacks reduce_scatter_with_workers( worker_receiver_kernels.push_back(receiver_kernel_id); worker_sender_kernels.push_back(sender_kernel_id); - TT_ASSERT(is_cb_buffering_sufficient_to_avoid_deadlock( + TT_FATAL(is_cb_buffering_sufficient_to_avoid_deadlock( worker_slice, cb_num_pages, cb_num_pages, diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/kernels/worker_interleaved_ring_reduce_scatter_reader.cpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/kernels/worker_interleaved_ring_reduce_scatter_reader.cpp index 7454a5bb0a3..ee5b8fc3fd8 100644 --- a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/kernels/worker_interleaved_ring_reduce_scatter_reader.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/kernels/worker_interleaved_ring_reduce_scatter_reader.cpp @@ -9,7 +9,6 @@ #include "dataflow_api.h" #include "debug/assert.h" #include "impl/buffers/buffer_constants.hpp" -#include "tensix_types.h" #include "ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp" #include "ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp" #include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" @@ -319,6 +318,7 @@ void kernel_main() { uint32_t n_pages = std::min(args.full_chunk_num_pages, worker_slice_n_pages - p); ASSERT(n_pages > 0); // Fetch from input tensor + read_wrapped_chunk_from_output_tensor( curr_tile_id, offset_into_worker_slice, diff --git a/ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp b/ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp index 3b933b91882..8e92d5dc568 100644 --- a/ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp @@ -9,6 +9,20 @@ #include #include +/* + * ------ ATTENTION ATTENTION ATTENTION ATTENTION ATTENTION ------ + * This file is intended to be useable across both host and device code. Therefore. + * + * DO NOT include any headers that are not host/device agnostic. + * DO NOT use any types that do not have fixed sizes across host and device. + * e.g. int32_t -> good (always 32 bits), int -> bad (size depends on platform) + * + * The reason for dual inclusion across host/device is because this code is used + * on device, but is further tested on host through gtests. This enables us to + * sweep functionality quickly and easily without involving end-to-end device kernel + * invocations and program creation. + */ + namespace ttnn { namespace ccl { @@ -38,12 +52,12 @@ struct WorkerXY { uint16_t x; uint16_t y; - WorkerXY(uint16_t x, uint16_t y) : x(x), y(y) {} + constexpr WorkerXY(uint16_t x, uint16_t y) : x(x), y(y) {} - uint32_t to_uint32() const { return (y << 16) | x; } + constexpr uint32_t to_uint32() const { return (y << 16) | x; } - bool operator==(const WorkerXY &rhs) const { return x == rhs.x && y == rhs.y; } - bool operator!=(const WorkerXY &rhs) const { return !(*this == rhs); } + constexpr bool operator==(const WorkerXY &rhs) const { return x == rhs.x && y == rhs.y; } + constexpr bool operator!=(const WorkerXY &rhs) const { return !(*this == rhs); } }; struct coord_t { diff --git a/ttnn/cpp/ttnn/operations/ccl/shared_with_host/sharded_tensor_addr_gen.hpp b/ttnn/cpp/ttnn/operations/ccl/shared_with_host/sharded_tensor_addr_gen.hpp index b98fff16766..ce0730f32da 100644 --- a/ttnn/cpp/ttnn/operations/ccl/shared_with_host/sharded_tensor_addr_gen.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/shared_with_host/sharded_tensor_addr_gen.hpp @@ -1,6 +1,7 @@ // SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. // // SPDX-License-Identifier: Apache-2.0 +#pragma once #include #include