From 5434660193a4682756268ce03c55b1040ebff283 Mon Sep 17 00:00:00 2001
From: Sean Nijjar
Date: Sun, 10 Nov 2024 18:54:32 +0000
Subject: [PATCH] add initial fabric erisc data mover (EDM) impl

Note: only line topologies are supported. Fabric mcast is currently
untested and a work in progress. In the meantime, fabric EDM users
doing functional bringup should replace mcast with looped unicasts.

The fabric Erisc Data Mover (EDM) is a component that can be used to
build *very* simple linear topology fabrics. One of these EDMs can be
instantiated on each ethernet link. It is built from 3 "channels"
(though the definition of channel here is a little loose, since two of
the 3 merge traffic, so this setup could also be interpreted as a two
channel setup). This EDM implements packet-based transfers only -
concepts like sockets are not supported.

!! EDM Structure

There are two sender channels and one receiver channel. "Sender" and
"receiver" are relative to the Ethernet link, not the chip: sender
sends over the link and receiver receives from the link.

Each sender channel serves a different purpose:
- Sender channel 0: accepts packets from workers on the local chip
- Sender channel 1: accepts packets from an upstream EDM (i.e. an
  upstream EDM receiver channel on the same chip but a different core)

The receiver channel accepts packets from the Ethernet link and can do
one (or both) of:
- Write the packet to the local chip if it is the intended destination
  (unicast or mcast)
- Forward the packet to the next chip in the line if:
  - Unicast and this is not the target chip
  - Multicast and this chip is in the multicast target range

Sender channels will merge traffic into the remote EDM's receiver
channel.

!! Building a "Fabric"

At present, only linear topologies are supported, with one EDM per
ethernet link along a given line. In a hypothetical 3-chip fabric, the
middle chip hosts two EDMs (one per link), and each EDM's receiver
channel feeds the other EDM's sender channel 1; for longer lines, the
same pattern extends down the line (see the diagrams in
`fabric_erisc_datamover.cpp`).

!! Connecting Workers to Channels

Only one worker can push to a given EDM sender channel at a time. In
order to send to an EDM sender channel, the worker must establish a
connection. The connection protocol is as follows and is started by
the worker (the EDM is a slave in this protocol).

*NOTE*: If multiple workers try to connect to the same EDM sender
channel at the same time, the behaviour is undefined.
*NOTE*: Additionally, if a worker pushes packets to a channel it isn't
connected to, the behaviour is undefined.
*NOTE*: Undefined == likely hang

The `WorkerToFabricEdmSender` from
`ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp`
provides an implementation of the connection protocol.
`WorkerToFabricEdmSender` also acts as a wrapper around that protocol,
so workers can simply call `open()` to execute the connection protocol
without having to manually reimplement it for each kernel.

!!! Protocol

Worker:
- Read from the EDM sender channel buffer_index address
  - Required so that the worker knows where to write its first packet
    (since the channel may already contain packets from a previous
    connection)
- Write the worker core X/Y (NOC 0 based)
- Write the worker flow control semaphore L1 address

EDM Sender Channel:
- Check the local connection-valid semaphore for a newly established
  connection
- When the connection semaphore indicates an active connection, the
  channel assumes all other relevant fields were correctly populated by
  the worker:
  - Worker core_x (on NOC 0)
  - Worker core_y (on NOC 0)
  - Worker flow control semaphore L1 address
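A minimal worker-side sketch of this protocol, mirroring the test
sender kernel added in this patch (the runtime-arg layout shown here
is illustrative, not prescribed):

    #include "dataflow_api.h"
    #include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp"

    void kernel_main() {
        size_t arg_idx = 0;
        const uint32_t eth_l1_base_addr = get_arg_val<uint32_t>(arg_idx++);
        const uint32_t eth_sender_l1_sem_addr = get_arg_val<uint32_t>(arg_idx++);
        volatile uint32_t* const writer_send_sem_addr =
            reinterpret_cast<volatile uint32_t*>(get_semaphore(get_arg_val<uint32_t>(arg_idx++)));
        const uint32_t eth_sender_noc_x = get_arg_val<uint32_t>(arg_idx++);
        const uint32_t eth_sender_noc_y = get_arg_val<uint32_t>(arg_idx++);
        const uint32_t num_buffers_per_channel = get_arg_val<uint32_t>(arg_idx++);
        const size_t edm_connection_handshake_addr = get_semaphore(get_arg_val<uint32_t>(arg_idx++));
        const size_t edm_worker_location_info_addr = get_arg_val<uint32_t>(arg_idx++);
        const size_t edm_buffer_size_bytes = get_arg_val<uint32_t>(arg_idx++);
        const size_t worker_buffer_index_semaphore_addr = get_semaphore(get_arg_val<uint32_t>(arg_idx++));

        auto sender = tt::fabric::WorkerToFabricEdmSender(
            eth_sender_noc_x,
            eth_sender_noc_y,
            eth_l1_base_addr,
            num_buffers_per_channel,
            eth_sender_l1_sem_addr,
            edm_connection_handshake_addr,
            edm_worker_location_info_addr,
            edm_buffer_size_bytes,
            writer_send_sem_addr,
            worker_buffer_index_semaphore_addr);

        sender.open();   // executes the connection protocol described above
        // ... populate packet headers and push packets (see "Sending Packets" below) ...
        sender.close();  // connection teardown (see the next section)
    }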
!! Tearing Down Connections

Every worker is required to explicitly tear down its connection with
the EDM before terminating. To do this, the worker simply writes a `0`
to the EDM sender channel's connection semaphore address. As long as
the worker has sent all of its packets to the EDM before this, the EDM
guarantees to forward the messages correctly. At this point, it is
safe for another kernel to establish a connection.

!! Packet Structure

Workers are responsible for populating packet headers before sending
to the EDM. The packet header structure is defined in
`ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp`.

!! Channel Structure

Each EDM channel is built from one or more buffers. Each buffer is the
same size and can hold at most one packet. Neighbouring packets occupy
neighbouring buffers - with the exception of the last buffer index: the
next packet after a write into the last buffer index will wrap around
to the first buffer index. Even if packets do not occupy the full
buffer, subsequent packets will always be written into the next logical
buffer. A gap will exist in memory, but the EDM will not send that
padded data (unless it is more performant, which is possible in some
special cases).

One detail of the channel structure is omitted from the description
above: the EDM <-> EDM flow control region for each buffer. Each buffer
really looks something like this:

           &header->  |----------------|  channel_base_address
                      |     header     |
          &payload->  |----------------|
                      |                |
                      |    payload     |
                      |                |
     &channel_sync->  |----------------|
                      |  channel_sync  |  // This is new
                      ------------------

The "channel_sync" is an `eth_channel_sync_t`, is internal to the EDM
implementation, and is used to indicate packet transmission state
between the sender and receiver EDMs. The protocol for its use is:

1) The sender updates the fields indicating new data:
   - set `bytes_sent` to a non-zero value indicating new data
   - clear `receiver_ack` to 0
   - set `src_id` to the sender channel id so the receiver knows who
     the sender was (and where the ack should go)
2) The sender sends this channel sync to the corresponding location in
   the receiver channel (either in the same transmission as the packet
   or separately)
3) The receiver sees that `bytes_sent` is non-zero, indicating a new
   packet. It sends back an acknowledgement (first level):
   - set `receiver_ack` to non-zero

   *NOTE* IMPORTANT: To avoid a race, the receiver must be sure to
   send its channel_sync_t from a different address than the one it
   uses for the second level acknowledgement.
3b) When the sender receives an ack, it knows it can overwrite its
    local copy of the packet with new data
4) After the receiver properly writes out its packet, it sends a
   second level acknowledgement, indicating it can receive new data
   into this specific buffer index:
   - clear the `bytes_sent` and `receiver_ack` fields and send the
     `channel_sync` back to the sender

!! Sending Packets

Sending a packet is done as follows:

1) The worker waits for a flow control semaphore increment from the
   EDM sender channel
   - Indicates there is space at the next buffer index for a packet
2) The worker performs a noc write of its packet to the EDM sender
   channel at the buffer index

*NOTE*: !!!ALL PACKETS MUST CONTAIN DESTINATION NOC X/Y AS NOC 0
COORDINATES, REGARDLESS OF THE `noc_index` OF THE SENDER!!!
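To make the channel_sync handshake concrete, below is a minimal sketch
of the field updates on each side. It assumes `bytes_sent` is the
first field of `eth_channel_sync_t` (defined in
`tt_metal/hw/inc/ethernet/tunneling.h`), as that header's comments
suggest; the helper names are hypothetical and only the field updates
are prescribed by the protocol above:

    #include "ethernet/tunneling.h"  // defines eth_channel_sync_t

    // Sender, step 1: mark the local buffer copy as holding new data.
    void sender_mark_new_data(
        volatile eth_channel_sync_t* sync, uint32_t packet_size_bytes, uint32_t sender_channel_id) {
        sync->bytes_sent = packet_size_bytes;  // non-zero => new data
        sync->receiver_ack = 0;                // clear the first level ack
        sync->src_id = sender_channel_id;      // tells the receiver where to ack
        // step 2: send *sync to the matching offset in the receiver channel
    }

    // Receiver, step 3: first level ack - the sender may reuse its buffer.
    // Must be sent from a different L1 address than the step 4 ack (race).
    void receiver_first_level_ack(volatile eth_channel_sync_t* sync) {
        sync->receiver_ack = 1;
    }

    // Receiver, step 4: second level ack - this buffer index is free again.
    void receiver_second_level_ack(volatile eth_channel_sync_t* sync) {
        sync->bytes_sent = 0;
        sync->receiver_ack = 0;
    }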
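And a sketch of a single packet send, following the sender worker
kernel added in this patch. Here `sender` is an open
`WorkerToFabricEdmSender`, `packet_addr` points at a
`[PacketHeader | payload]` region in L1 (e.g. a cb page), and the
destination values are placeholders; note the NOC 0 coordinates:

    // Placeholder values for illustration only:
    const uint8_t hop_distance = 1;         // hops to traverse before writing
    const uint8_t dest_noc_x = 1;           // destination core X (NOC 0)
    const uint8_t dest_noc_y = 0;           // destination core Y (NOC 0)
    const size_t dest_l1_addr = 0x20000;    // hypothetical destination address
    const size_t payload_size_bytes = 2048;

    auto& packet_header = *reinterpret_cast<tt::fabric::PacketHeader*>(packet_addr);
    packet_header.to_write()
        .to_chip_unicast(tt::fabric::UnicastRoutingCommandHeader{hop_distance})
        .to_noc_unicast(tt::fabric::NocUnicastCommandHeader{
            dest_l1_addr,
            payload_size_bytes + sizeof(tt::fabric::PacketHeader),
            dest_noc_x,
            dest_noc_y});

    sender.wait_for_empty_write_slot();         // step 1: wait for flow control credit
    sender.send_payload_blocking_from_address(  // step 2: noc write into the EDM channel
        packet_addr, payload_size_bytes + sizeof(tt::fabric::PacketHeader));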
For more diagrams, see `fabric_erisc_datamover.cpp` --- tests/ttnn/unit_tests/gtests/CMakeLists.txt | 5 +- .../erisc_datamover_sender_worker_reader.cpp | 1 - ...c_erisc_datamover_sender_worker_reader.cpp | 46 + ...c_erisc_datamover_sender_worker_sender.cpp | 196 ++++ ...erisc_data_mover_loopback_with_workers.cpp | 949 ++++++++++++++++++ tt_metal/hw/inc/ethernet/dataflow_api.h | 17 + tt_metal/hw/inc/ethernet/tunneling.h | 15 +- .../hw/inc/wormhole/noc_nonblocking_api.h | 1 + ttnn/CMakeLists.txt | 1 + ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp | 46 +- ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp | 8 + .../ccl/erisc_datamover_builder.cpp | 197 ++++ .../ccl/erisc_datamover_builder.hpp | 146 +++ .../edm_fabric/edm_fabric_worker_adapters.hpp | 195 ++++ .../edm_fabric/fabric_edm_packet_header.hpp | 214 ++++ .../fabric_edm_packet_header_validate.hpp | 18 + .../fabric_edm_packet_transmission.hpp | 228 +++++ .../kernels/edm_fabric/fabric_edm_types.hpp | 56 ++ .../edm_fabric/fabric_erisc_datamover.cpp | 881 ++++++++++++++++ .../fabric_erisc_datamover_channels.hpp | 225 +++++ 20 files changed, 3432 insertions(+), 13 deletions(-) create mode 100644 tests/ttnn/unit_tests/gtests/ccl/kernels/fabric_erisc_datamover_sender_worker_reader.cpp create mode 100644 tests/ttnn/unit_tests/gtests/ccl/kernels/fabric_erisc_datamover_sender_worker_sender.cpp create mode 100644 tests/ttnn/unit_tests/gtests/ccl/test_fabric_erisc_data_mover_loopback_with_workers.cpp create mode 100644 ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.cpp create mode 100644 ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.hpp create mode 100644 ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp create mode 100644 ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp create mode 100644 ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header_validate.hpp create mode 100644 ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_transmission.hpp create mode 100644 ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_types.hpp create mode 100644 ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover.cpp create mode 100644 ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover_channels.hpp diff --git a/tests/ttnn/unit_tests/gtests/CMakeLists.txt b/tests/ttnn/unit_tests/gtests/CMakeLists.txt index 6bf761175437..865b29daefc5 100644 --- a/tests/ttnn/unit_tests/gtests/CMakeLists.txt +++ b/tests/ttnn/unit_tests/gtests/CMakeLists.txt @@ -8,7 +8,10 @@ set(TTNN_UNIT_TESTS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/test_to_and_from_json.cpp ) -set(TTNN_CCL_UNIT_TESTS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/ccl/test_erisc_data_mover_with_workers.cpp) +set(TTNN_CCL_UNIT_TESTS_SRC + ${CMAKE_CURRENT_SOURCE_DIR}/ccl/test_erisc_data_mover_with_workers.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ccl/test_fabric_erisc_data_mover_loopback_with_workers.cpp +) set(TTNN_TENSOR_UNIT_TESTS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/tensor/common_tensor_test_utils.cpp diff --git a/tests/ttnn/unit_tests/gtests/ccl/kernels/erisc_datamover_sender_worker_reader.cpp b/tests/ttnn/unit_tests/gtests/ccl/kernels/erisc_datamover_sender_worker_reader.cpp index 41d453e2793c..66662d02630e 100644 --- a/tests/ttnn/unit_tests/gtests/ccl/kernels/erisc_datamover_sender_worker_reader.cpp +++ b/tests/ttnn/unit_tests/gtests/ccl/kernels/erisc_datamover_sender_worker_reader.cpp @@ -38,7 +38,6 @@ void kernel_main() { } noc_async_read_barrier(); cb_push_back(cb_id_in0, pages_to_read); - // DPRINT << "SR " << 
num_pages_read << "\n"; } DPRINT << "SR DONE\n"; diff --git a/tests/ttnn/unit_tests/gtests/ccl/kernels/fabric_erisc_datamover_sender_worker_reader.cpp b/tests/ttnn/unit_tests/gtests/ccl/kernels/fabric_erisc_datamover_sender_worker_reader.cpp new file mode 100644 index 000000000000..3437c819346a --- /dev/null +++ b/tests/ttnn/unit_tests/gtests/ccl/kernels/fabric_erisc_datamover_sender_worker_reader.cpp @@ -0,0 +1,46 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include "dataflow_api.h" +#include "debug/dprint.h" +#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp" + +void kernel_main() { + constexpr bool src_is_dram = get_compile_time_arg_val(0) == 1; + constexpr uint32_t num_pages_to_read_total = get_compile_time_arg_val(1); + constexpr uint32_t page_size = get_compile_time_arg_val(2); + constexpr uint32_t pages_per_edm_buffer = 1; + constexpr uint32_t cb_id_in0 = tt::CB::c_in0; + + const uint32_t src_addr = get_arg_val(0); + + const InterleavedAddrGen source_address_generator = { + .bank_base_address = src_addr, .page_size = page_size}; + + DPRINT << "swr: args " << + "\n\tsrc_addr="<(pages_per_edm_buffer, num_pages_to_read_total - num_pages_read); + cb_reserve_back(cb_id_in0, pages_to_read); + uint32_t local_l1_read_addr = get_write_ptr(cb_id_in0); + local_l1_read_addr += sizeof(tt::fabric::PacketHeader); + + for (uint32_t p = 0; p < pages_to_read; ++p) { + uint64_t src_noc_addr = get_noc_addr(num_pages_read + p, source_address_generator); + noc_async_read(src_noc_addr, local_l1_read_addr, page_size); + local_l1_read_addr += page_size; + } + noc_async_read_barrier(); + cb_push_back(cb_id_in0, pages_to_read); + } + +} diff --git a/tests/ttnn/unit_tests/gtests/ccl/kernels/fabric_erisc_datamover_sender_worker_sender.cpp b/tests/ttnn/unit_tests/gtests/ccl/kernels/fabric_erisc_datamover_sender_worker_sender.cpp new file mode 100644 index 000000000000..e0cb2f50a172 --- /dev/null +++ b/tests/ttnn/unit_tests/gtests/ccl/kernels/fabric_erisc_datamover_sender_worker_sender.cpp @@ -0,0 +1,196 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "dataflow_api.h" +#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp" + + +struct unicast_mode { + uint8_t distance; +}; +struct mcast_mode { + uint8_t distance; + uint8_t range; +}; + +union transmit_config { + unicast_mode unicast; + mcast_mode mcast; +}; + +// Worker core - Data Movement Writer -> Sends to Erisc Data Mover (sender side). 
+// -> takes input from local cb and pushes to erisc L1 +void kernel_main() { + + // Test doesn't support multiple pages per send yet since we are writing + // to interleaved which will never have subsequent pages on the same core + // (and hence, able to share a packet header) + constexpr uint32_t num_pages_per_send = 1;//get_compile_time_arg_val(0); + constexpr uint32_t total_pages_to_send = get_compile_time_arg_val(1); + constexpr uint32_t page_size = get_compile_time_arg_val(2); + constexpr uint32_t num_buffers_per_channel = get_compile_time_arg_val(3); + constexpr bool dest_is_dram = get_compile_time_arg_val(4) != 0; + constexpr bool mcast_mode = get_compile_time_arg_val(5) == 1; + + size_t arg_idx = 0; + const uint32_t eth_l1_base_addr = get_arg_val(arg_idx++); + // erisc l1 semaphore address + const uint32_t eth_sender_l1_sem_addr = get_arg_val(arg_idx++); + volatile uint32_t* const writer_send_sem_addr = reinterpret_cast(get_semaphore(get_arg_val(arg_idx++))); + const uint32_t eth_sender_noc_x = get_arg_val(arg_idx++); + const uint32_t eth_sender_noc_y = get_arg_val(arg_idx++); + const uint32_t num_buffers_per_edm_channel = get_arg_val(arg_idx++); + + size_t edm_connection_handshake_addr = get_semaphore(get_arg_val(arg_idx++)); + size_t edm_worker_location_info_addr = get_arg_val(arg_idx++); + size_t edm_buffer_size_bytes = get_arg_val(arg_idx++); + size_t dest_addr = get_arg_val(arg_idx++); + volatile uint32_t* const last_message_semaphore_address = reinterpret_cast(get_semaphore(get_arg_val(arg_idx++))); + *last_message_semaphore_address = 0; + auto worker_buffer_index_semaphore_addr = get_semaphore(get_arg_val(arg_idx++)); + ASSERT(worker_buffer_index_semaphore_addr != reinterpret_cast(writer_send_sem_addr)); + ASSERT(worker_buffer_index_semaphore_addr != reinterpret_cast(last_message_semaphore_address)); + + transmit_config config; + if (mcast_mode) { + config.mcast.distance = static_cast(get_arg_val(arg_idx++)); + config.mcast.range = static_cast(get_arg_val(arg_idx++)); + } else { + config.unicast.distance = static_cast(get_arg_val(arg_idx++)); + } + + const InterleavedAddrGen dest_addr_gen = { + .bank_base_address = dest_addr, .page_size = page_size}; + + + ASSERT(num_buffers_per_channel > 0); + auto sender = tt::fabric::WorkerToFabricEdmSender( + eth_sender_noc_x, + eth_sender_noc_y, + eth_l1_base_addr, + num_buffers_per_channel, + eth_sender_l1_sem_addr, + + edm_connection_handshake_addr, + edm_worker_location_info_addr, + edm_buffer_size_bytes, + writer_send_sem_addr, + worker_buffer_index_semaphore_addr + ); + + sender.open(); + + constexpr uint32_t cb_id_in0 = tt::CB::c_in0; + + // We need to normalize all noc addresses to be for a consistent noc ID + // so the remote sender core can correctly send the packet. 
In the future + // we can decide if it's better for the noc index to be embedded in the packet + // header (for now we don't do that) + constexpr size_t NORMALIZED_NOC_INDEX = 0; + + uint32_t buffer_index = 0; + cb_wait_front(cb_id_in0, 1); + auto a_packet_header_addr = get_read_ptr(cb_id_in0); + for (uint32_t p = 0; p < total_pages_to_send; p += num_pages_per_send) { + uint32_t pages_to_send = std::min(num_pages_per_send, total_pages_to_send - p); + sender.wait_for_empty_write_slot(); + cb_wait_front(cb_id_in0, pages_to_send); + + // bit of a hack to extract X/Y + const auto dest_noc_address = get_noc_addr(p, dest_addr_gen, 0, NORMALIZED_NOC_INDEX); + const size_t dest_addr = dest_noc_address & 0xFFFFFFFF; + const size_t dest_noc_x = (dest_noc_address >> NOC_ADDR_LOCAL_BITS) & ((1 << NOC_ADDR_NODE_ID_BITS) - 1); + const size_t dest_noc_y = (dest_noc_address >> (NOC_ADDR_LOCAL_BITS + NOC_ADDR_NODE_ID_BITS)) & ((1 << NOC_ADDR_NODE_ID_BITS) - 1); + const size_t packet_size = page_size + sizeof(tt::fabric::PacketHeader); + + auto packet_addr = get_read_ptr(cb_id_in0); + auto &packet_header = *reinterpret_cast(packet_addr); + if constexpr (mcast_mode) { + packet_header.to_write() + .to_chip_multicast(tt::fabric::MulticastRoutingCommandHeader{config.mcast.distance, config.mcast.range}) + .to_noc_unicast(tt::fabric::NocUnicastCommandHeader{ + dest_addr, + (pages_to_send * page_size) + sizeof(tt::fabric::PacketHeader), + static_cast(dest_noc_x), + static_cast(dest_noc_y) + }); + packet_header.reserved2 = 0x1111; // debug only + } else { + packet_header.to_write() + .to_chip_unicast(tt::fabric::UnicastRoutingCommandHeader{config.unicast.distance}) + .to_noc_unicast(tt::fabric::NocUnicastCommandHeader{ + dest_addr, + (pages_to_send * page_size) + sizeof(tt::fabric::PacketHeader), + static_cast(dest_noc_x), + static_cast(dest_noc_y) + }); + packet_header.reserved2 = 0x1111; // debug only + } + + uint64_t buffer_address = sender.edm_buffer_addr + (*sender.buffer_index_ptr * (sender.buffer_size_bytes + sizeof(eth_channel_sync_t))); + sender.send_payload_blocking_from_address(packet_addr, packet_size); + noc_async_writes_flushed(); + cb_pop_front(cb_id_in0, pages_to_send); + } + + if constexpr (!mcast_mode) { + sender.wait_for_empty_write_slot(); + + auto &packet_header = *reinterpret_cast(a_packet_header_addr); + ASSERT(*last_message_semaphore_address == 0); + packet_header.reserved = 0xE; + packet_header.reserved2 = 0xFFFF; + packet_header.to_atomic_inc(); + packet_header.to_chip_unicast(tt::fabric::UnicastRoutingCommandHeader{1}); + packet_header.to_noc_unicast_atomic_inc(tt::fabric::NocUnicastAtomicIncCommandHeader( + reinterpret_cast(last_message_semaphore_address), + 1, + 32, + my_x[0], + my_y[0] + )); + + sender.send_payload_blocking_from_address(a_packet_header_addr, packet_header.get_payload_size_including_header()); + + noc_semaphore_wait(last_message_semaphore_address, 1); + } + + bool closed = false; + size_t num_endpoints_to_terminate = get_arg_val(arg_idx++); + for (size_t i = 0; i < num_endpoints_to_terminate; i++) { + size_t edm_noc_x = get_arg_val(arg_idx++); + size_t edm_noc_y = get_arg_val(arg_idx++); + size_t distance = get_arg_val(arg_idx++); + size_t termination_addr = get_arg_val(arg_idx++); + + if (!closed && distance == 0) { + closed = true; + sender.close(); + } + if (distance == 0) { + noc_inline_dw_write(get_noc_addr(edm_noc_x, edm_noc_y, termination_addr), tt::fabric::TerminationSignal::IMMEDIATELY_TERMINATE); + } else { + auto &packet_header = 
*reinterpret_cast(a_packet_header_addr); + reinterpret_cast(a_packet_header_addr)[sizeof(tt::fabric::PacketHeader) >> 2] = tt::fabric::TerminationSignal::IMMEDIATELY_TERMINATE; + sender.wait_for_empty_write_slot(); + packet_header.to_write() + .to_chip_unicast(tt::fabric::UnicastRoutingCommandHeader{static_cast(distance - 1)}) + .to_noc_unicast(tt::fabric::NocUnicastCommandHeader{ + termination_addr, + sizeof(tt::fabric::PacketHeader) + sizeof(uint32_t), + static_cast(edm_noc_x), + static_cast(edm_noc_y) + }); + sender.send_payload_blocking_from_address(a_packet_header_addr, packet_header.get_payload_size_including_header()); + noc_async_writes_flushed(); + } + } + if (!closed) { + sender.close(); + } + +} diff --git a/tests/ttnn/unit_tests/gtests/ccl/test_fabric_erisc_data_mover_loopback_with_workers.cpp b/tests/ttnn/unit_tests/gtests/ccl/test_fabric_erisc_data_mover_loopback_with_workers.cpp new file mode 100644 index 000000000000..1cb446d470ea --- /dev/null +++ b/tests/ttnn/unit_tests/gtests/ccl/test_fabric_erisc_data_mover_loopback_with_workers.cpp @@ -0,0 +1,949 @@ + +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include +#include +#include + +#include "device/tt_arch_types.h" +#include "gtest/gtest.h" +// #include "tt_backend_api_types.hpp" +#include "tt_metal/common/core_coord.hpp" +#include "tt_metal/common/math.hpp" +#include "tt_metal/detail/tt_metal.hpp" +#include "tt_metal/host_api.hpp" +#include "tt_metal/impl/kernels/kernel.hpp" +#include "tt_metal/test_utils/comparison.hpp" +#include "tt_metal/test_utils/df/df.hpp" +#include "tt_metal/test_utils/env_vars.hpp" +#include "tt_metal/test_utils/print_helpers.hpp" +#include "tt_metal/test_utils/stimulus.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp" + +using namespace tt; +using namespace tt::test_utils; +using namespace tt::test_utils::df; + +class T3000TestDevice { + public: + T3000TestDevice() : device_open(false) { + arch_ = tt::get_arch_from_string(tt::test_utils::get_umd_arch_name()); + + num_devices_ = tt::tt_metal::GetNumAvailableDevices(); + if (arch_ == tt::ARCH::WORMHOLE_B0 and tt::tt_metal::GetNumAvailableDevices() >= 4 and + tt::tt_metal::GetNumPCIeDevices() >= 1) { + std::vector ids(num_devices_, 0); + std::iota(ids.begin(), ids.end(), 0); + devices_ = tt::tt_metal::detail::CreateDevices(ids); + + } else { + TT_THROW("This suite can only be run on T3000 Wormhole devices"); + } + device_open = true; + } + ~T3000TestDevice() { + if (device_open) { + TearDown(); + } + } + + void TearDown() { + device_open = false; + for (auto [device_id, device_ptr] : devices_) { + tt::tt_metal::CloseDevice(device_ptr); + } + } + + std::map devices_; + tt::ARCH arch_; + size_t num_devices_; + + private: + bool device_open; +}; + +struct BankedConfig { + size_t num_pages; + size_t size_bytes; + size_t page_size_bytes; + BufferType input_buffer_type; // = BufferType::L1; + BufferType output_buffer_type; // = BufferType::L1; + tt::DataFormat l1_data_format; // = tt::DataFormat::Float16_b; +}; + +struct KernelXY { + uint16_t x; + uint16_t y; + + uint32_t to_uint32() const { return y << 16 | x; } +}; + +struct edm_termination_info_t { + uint32_t distance; + uint32_t edm_noc_x; + uint32_t edm_noc_y; + uint32_t termination_addr; +}; + +enum Correctness { Correct, Incorrect }; + +struct EthLinkBuilder { + 
ttnn::ccl::FabricEriscDatamoverBuilder sender_edm_builder; // chip_0_edm_builder, + ttnn::ccl::FabricEriscDatamoverBuilder receiver_edm_builder; // chip_0_edm_builder, + tt_xy_pair sender_core; + tt_xy_pair receiver_core; + size_t downstream_edm_buffer_index_semaphore_id; +}; + +Correctness run_output_check( + std::vector const& all_zeros, + std::vector const& inputs, + std::shared_ptr output_buffer) { + constexpr bool debug_mode = true; + std::vector readback_data_vec; // init to 0 data for easier debug + readback_data_vec.reserve(all_zeros.size()); + std::fill(readback_data_vec.begin(), readback_data_vec.end(), 0); + + tt_metal::detail::ReadFromBuffer(output_buffer, readback_data_vec); + log_info(tt::LogTest, "Checking outputs"); + if (readback_data_vec.size() != inputs.size()) { + log_error(tt::LogTest, "Output size mismatch: expected {} got {}", inputs.size(), readback_data_vec.size()); + return Correctness::Incorrect; + } + bool pass = (readback_data_vec == inputs); + TT_ASSERT( + std::any_of(inputs.begin(), inputs.end(), [](uint32_t x) { return x != 0; }), + "Input buffer expected to not be all 0"); + if (not pass) { + log_error("Output mismatch"); + if (debug_mode) { + std::size_t num_printed_mismatches = 0; + for (size_t i = 0; i < readback_data_vec.size() && num_printed_mismatches < 64; i++) { + if (readback_data_vec[i] != inputs[i]) { + log_error("[{}]: expected {} got {}", i, inputs[i], readback_data_vec[i]); + num_printed_mismatches++; + } + } + log_error("... (remaining mismatches omitted)"); + } + } + return Correctness::Correct; +}; + +void run_programs(std::vector& programs, std::vector const& devices) { + EXPECT_EQ(programs.size(), devices.size()); + const size_t num_programs = programs.size(); + try { + for (size_t i = 0; i < num_programs; i++) { + tt::tt_metal::detail::CompileProgram(devices.at(i), programs.at(i)); + } + } catch (std::exception& e) { + log_error("Failed compile: {}", e.what()); + throw e; + } + + log_info(tt::LogTest, "Running..."); + + std::vector threads; + threads.reserve(num_programs); + if (std::getenv("TT_METAL_SLOW_DISPATCH_MODE")) { + for (size_t i = 0; i < num_programs; i++) { + threads.emplace_back(std::thread([&] { tt_metal::detail::LaunchProgram(devices.at(i), programs.at(i)); })); + } + + std::ranges::for_each(threads, [](std::thread& t) { t.join(); }); + } else { + for (size_t i = 0; i < num_programs; i++) { + tt_metal::EnqueueProgram(devices.at(i)->command_queue(), programs.at(i), false); + } + + log_debug(tt::LogTest, "Calling Finish"); + for (size_t i = 0; i < num_programs; i++) { + tt_metal::Finish(devices.at(i)->command_queue()); + } + } +} + +std::tuple, std::vector> build_input_buffer( + Device* first_device, size_t tensor_size_bytes, BankedConfig const& test_config) { + auto inputs = std::vector(tensor_size_bytes / sizeof(uint32_t), 0); + std::iota(inputs.begin(), inputs.end(), 0); + + // Input buffer + auto local_input_buffer = CreateBuffer(InterleavedBufferConfig{ + first_device, test_config.size_bytes, test_config.page_size_bytes, test_config.input_buffer_type}); + tt_metal::detail::WriteToBuffer(local_input_buffer, inputs); + return {local_input_buffer, inputs}; +} + +struct EthLinkHop { + CoreCoord hop_src; + CoreCoord hop_dest; +}; + +struct ChipConnection { + std::vector links; +}; + +struct unicast_send { + size_t distance; +}; +struct mcast_send { + size_t distance; + size_t range; +}; + + +using mode_variant_t = std::variant; + +static constexpr size_t PACKET_HEADER_SIZE_BYTES = sizeof(tt::fabric::PacketHeader); +void 
generate_sender_worker_kernels( + Program& program, + Device* device, + CoreCoord const& worker_core, + CoreCoord const& edm_core, + ttnn::ccl::SenderWorkerAdapterSpec const& worker_fabric_connection, + mode_variant_t const& mode, + std::size_t edm_buffer_size, + uint32_t page_plus_header_size, + uint32_t num_pages_total, + uint32_t num_pages_per_edm_buffer, + uint32_t local_worker_fabric_semaphore_id, + uint32_t local_worker_last_message_semaphore_id, + uint32_t dram_input_buffer_base_addr, // remote_output_buffers.at(i)->address(); + bool src_is_dram, + uint32_t dram_output_buffer_base_addr, + bool dest_is_dram, + uint32_t worker_buffer_index_semaphore_id, + // farthest to closest + std::vector const& edm_termination_infos) { + std::vector sender_worker_reader_compile_args{ + src_is_dram, // + num_pages_total, // + page_plus_header_size - PACKET_HEADER_SIZE_BYTES, + num_pages_per_edm_buffer}; + std::vector sender_worker_reader_runtime_args{dram_input_buffer_base_addr}; + + log_info(tt::LogTest, "\tSenderReader CT Args"); + for (auto const& arg : sender_worker_reader_compile_args) { + log_info(tt::LogTest, "\t\t{}", arg); + } + log_info(tt::LogTest, "\tSenderReader RT Args"); + for (auto const& arg : sender_worker_reader_runtime_args) { + log_info(tt::LogTest, "\t\t{}", arg); + } + + std::vector sender_worker_writer_compile_args{ + num_pages_per_edm_buffer, + num_pages_total, + page_plus_header_size - PACKET_HEADER_SIZE_BYTES, + worker_fabric_connection.num_buffers_per_channel, + dest_is_dram, + std::holds_alternative(mode) ? 1 : 0}; + log_info(tt::LogTest, "worker_fabric_connection.edm_l1_sem_addr: {}", worker_fabric_connection.edm_l1_sem_addr); + log_info(tt::LogTest, "worker_buffer_index_semaphore_id: {}", worker_buffer_index_semaphore_id); + log_info(tt::LogTest, "last_message_semaphore_address: {}", local_worker_last_message_semaphore_id); + log_info( + tt::LogTest, + "Sender communicating with EDM: x={}, y={}", + (uint32_t)device->ethernet_core_from_logical_core(edm_core).x, + (uint32_t)device->ethernet_core_from_logical_core(edm_core).y); + std::vector sender_worker_writer_runtime_args{ + worker_fabric_connection.edm_buffer_base_addr, + worker_fabric_connection.edm_l1_sem_addr, + local_worker_fabric_semaphore_id, + (uint32_t)device->ethernet_core_from_logical_core(edm_core).x, + (uint32_t)device->ethernet_core_from_logical_core(edm_core).y, + worker_fabric_connection.num_buffers_per_channel, + + worker_fabric_connection.edm_connection_handshake_addr, + worker_fabric_connection.edm_worker_location_info_addr, + edm_buffer_size, + dram_output_buffer_base_addr, + local_worker_last_message_semaphore_id, + worker_buffer_index_semaphore_id}; + + if (std::holds_alternative(mode)) { + sender_worker_writer_runtime_args.push_back(std::get(mode).distance); + sender_worker_writer_runtime_args.push_back(std::get(mode).range); + } else { + sender_worker_writer_runtime_args.push_back(std::get(mode).distance); + } + + sender_worker_writer_runtime_args.push_back(edm_termination_infos.size()); + for (auto const& info : edm_termination_infos) { + sender_worker_writer_runtime_args.push_back(info.edm_noc_x); + sender_worker_writer_runtime_args.push_back(info.edm_noc_y); + sender_worker_writer_runtime_args.push_back(info.distance); + sender_worker_writer_runtime_args.push_back(info.termination_addr); + log_info( + tt::LogTest, + "EDM termination info: x={}, y={}, distance={}, termination_addr={}", + info.edm_noc_x, + info.edm_noc_y, + info.distance, + info.termination_addr); + } + + uint32_t 
src0_cb_index = CB::c_in0; + log_info(tt::LogTest, "\tSenderWriter CT Args"); + for (auto const& arg : sender_worker_writer_compile_args) { + log_info(tt::LogTest, "\t\t{}", arg); + } + log_info(tt::LogTest, "\tSenderWriter RT Args"); + for (auto const& arg : sender_worker_writer_runtime_args) { + log_info(tt::LogTest, "\t\t{}", arg); + } + + // Just want a dummy DF + tt::DataFormat df = (page_plus_header_size - PACKET_HEADER_SIZE_BYTES) == 1024 ? tt::DataFormat::Bfp8 + : (page_plus_header_size - PACKET_HEADER_SIZE_BYTES) == 2048 ? tt::DataFormat::Float16 + : tt::DataFormat::Float32; + tt_metal::CircularBufferConfig cb_src0_config = + tt_metal::CircularBufferConfig(2 * num_pages_per_edm_buffer * page_plus_header_size, {{src0_cb_index, df}}) + .set_page_size(src0_cb_index, page_plus_header_size); + CBHandle sender_workers_cb = CreateCircularBuffer(program, worker_core, cb_src0_config); + auto sender_worker_reader_kernel = tt_metal::CreateKernel( + program, + "tests/ttnn/unit_tests/gtests/ccl/kernels/fabric_erisc_datamover_sender_worker_reader.cpp", + worker_core, + tt_metal::DataMovementConfig{ + .processor = tt_metal::DataMovementProcessor::RISCV_0, + .noc = tt_metal::NOC::RISCV_0_default, + .compile_args = sender_worker_reader_compile_args}); + auto sender_worker_writer_kernel = tt_metal::CreateKernel( + program, + "tests/ttnn/unit_tests/gtests/ccl/kernels/fabric_erisc_datamover_sender_worker_sender.cpp", + worker_core, + tt_metal::DataMovementConfig{ + .processor = tt_metal::DataMovementProcessor::RISCV_1, + .noc = tt_metal::NOC::RISCV_1_default, + .compile_args = sender_worker_writer_compile_args}); + tt_metal::SetRuntimeArgs(program, sender_worker_reader_kernel, worker_core, sender_worker_reader_runtime_args); + tt_metal::SetRuntimeArgs(program, sender_worker_writer_kernel, worker_core, sender_worker_writer_runtime_args); +} + +bool RunLoopbackTest( + tt_metal::Device* sender_device, + tt_metal::Device* receiver_device, + + const CoreCoord& eth_sender_core, + const CoreCoord& eth_receiver_core, + + const uint32_t page_size, + const uint32_t num_pages_total, + bool src_is_dram, + bool dest_is_dram) { + std::size_t page_plus_header_size = page_size + sizeof(tt::fabric::PacketHeader); + std::size_t tensor_size_bytes = num_pages_total * page_size; + + std::vector programs(2); + auto& sender_program = programs.at(0); + auto& receiver_program = programs.at(1); + + std::vector worker_cores = {CoreCoord(0, 0)}; + + auto local_worker_fabric_semaphore_id = tt::tt_metal::CreateSemaphore(sender_program, worker_cores.at(0), 0); + auto local_worker_last_message_semaphore_id = tt::tt_metal::CreateSemaphore(sender_program, worker_cores.at(0), 0); + auto worker_buffer_index_semaphore_id = tt::tt_metal::CreateSemaphore(sender_program, worker_cores.at(0), 0); + + std::optional chip0_receiver_channel_downstream_flow_control_semaphore_id = std::nullopt; + auto chip0_sender_channel_0_flow_control_semaphore_id = + tt::tt_metal::CreateSemaphore(sender_program, eth_sender_core, 0, CoreType::ETH); + auto chip0_sender_channel_1_flow_control_semaphore_id = + tt::tt_metal::CreateSemaphore(sender_program, eth_sender_core, 0, CoreType::ETH); + auto chip0_sender_channel_0_connection_semaphore_id = + tt::tt_metal::CreateSemaphore(sender_program, eth_sender_core, 0, CoreType::ETH); + auto chip0_sender_channel_1_connection_semaphore_id = + tt::tt_metal::CreateSemaphore(sender_program, eth_sender_core, 0, CoreType::ETH); + + std::optional chip1_receiver_channel_downstream_flow_control_semaphore_id = + 
tt::tt_metal::CreateSemaphore(receiver_program, eth_receiver_core, 0, CoreType::ETH); + auto chip1_sender_channel_0_flow_control_semaphore_id = + tt::tt_metal::CreateSemaphore(receiver_program, eth_receiver_core, 0, CoreType::ETH); + auto chip1_sender_channel_1_flow_control_semaphore_id = + tt::tt_metal::CreateSemaphore(receiver_program, eth_receiver_core, 0, CoreType::ETH); + auto chip1_sender_channel_0_connection_semaphore_id = + tt::tt_metal::CreateSemaphore(receiver_program, eth_receiver_core, 0, CoreType::ETH); + auto chip1_sender_channel_1_connection_semaphore_id = + tt::tt_metal::CreateSemaphore(receiver_program, eth_receiver_core, 0, CoreType::ETH); + auto chip1_downstream_edm_buffer_index_semaphore_id = + tt::tt_metal::CreateSemaphore(receiver_program, eth_receiver_core, 0, CoreType::ETH); + + // Generate inputs + //////////////////////////////////////////////////////////////////////////// + // SETUP THE INPUT CB + //////////////////////////////////////////////////////////////////////////// + + BankedConfig test_config = BankedConfig{ + .num_pages = num_pages_total, + .size_bytes = tensor_size_bytes, + .page_size_bytes = page_size, + .input_buffer_type = src_is_dram ? BufferType::DRAM : BufferType::L1, + .output_buffer_type = dest_is_dram ? BufferType::DRAM : BufferType::L1, + .l1_data_format = tt::DataFormat::Float16_b}; + + auto [local_input_buffer, inputs] = build_input_buffer(sender_device, tensor_size_bytes, test_config); + + std::vector all_zeros(inputs.size(), 0); + auto local_output_buffer = CreateBuffer(InterleavedBufferConfig{ + sender_device, test_config.size_bytes, test_config.page_size_bytes, test_config.output_buffer_type}); + + tt_metal::detail::WriteToBuffer(local_output_buffer, all_zeros); + + auto local_input_buffer_address = local_input_buffer->address(); + auto local_output_buffer_address = local_output_buffer->address(); + + //////////////////////////////////////////////////////////////////////////// + // EDM Builder Setup + //////////////////////////////////////////////////////////////////////////// + + static constexpr std::size_t edm_buffer_size = 4096 + PACKET_HEADER_SIZE_BYTES; + const size_t local_chip_id = 0; + const size_t remote_chip_id = 1; + auto const& edm_config = ttnn::ccl::FabricEriscDatamoverConfig(edm_buffer_size, 1, 2); + auto chip_0_edm_builder = ttnn::ccl::FabricEriscDatamoverBuilder( + sender_device->ethernet_core_from_logical_core(eth_sender_core).x, + sender_device->ethernet_core_from_logical_core(eth_sender_core).y, + local_chip_id, + remote_chip_id, + + chip0_receiver_channel_downstream_flow_control_semaphore_id, + chip0_sender_channel_0_flow_control_semaphore_id, + chip0_sender_channel_1_flow_control_semaphore_id, + chip0_sender_channel_0_connection_semaphore_id, + chip0_sender_channel_1_connection_semaphore_id, + + edm_config); + auto chip0_worker_fabric_connection = chip_0_edm_builder.build_connection_to_worker_channel(); + auto chip_1_edm_builder = ttnn::ccl::FabricEriscDatamoverBuilder( + receiver_device->ethernet_core_from_logical_core(eth_receiver_core).x, + receiver_device->ethernet_core_from_logical_core(eth_receiver_core).y, + remote_chip_id, + local_chip_id, + + chip1_receiver_channel_downstream_flow_control_semaphore_id, // this is the receiver channel's local sem for + // flow controlling with downstream fabric sender + chip1_sender_channel_0_flow_control_semaphore_id, + chip1_sender_channel_1_flow_control_semaphore_id, + chip1_sender_channel_0_connection_semaphore_id, + chip1_sender_channel_1_connection_semaphore_id, + + 
edm_config); + // Create the loopback connection on the second device + chip_1_edm_builder.connect_to_downstream_edm(chip_1_edm_builder, chip1_downstream_edm_buffer_index_semaphore_id); + + //////////////////////////////////////////////////////////////////////////// + // Build Workers + //////////////////////////////////////////////////////////////////////////// + log_info(tt::LogTest, "Generating local_sender -> remote_receiver workers"); + const std::size_t pages_per_send = + (chip0_worker_fabric_connection.buffer_size_bytes - PACKET_HEADER_SIZE_BYTES) / page_size; + auto const& worker_core = worker_cores.at(0); + log_info(tt::LogTest, "Worker {}. On Core x={},y={}", 0, worker_core.x, worker_core.y); + + std::vector const& edm_termination_infos = { + {1, + sender_device->ethernet_core_from_logical_core(eth_receiver_core).x, + sender_device->ethernet_core_from_logical_core(eth_receiver_core).y, + ttnn::ccl::FabricEriscDatamoverConfig::termination_signal_address}, + {0, + sender_device->ethernet_core_from_logical_core(eth_sender_core).x, + sender_device->ethernet_core_from_logical_core(eth_sender_core).y, + ttnn::ccl::FabricEriscDatamoverConfig::termination_signal_address}}; + + generate_sender_worker_kernels( + sender_program, + sender_device, + worker_core, + eth_sender_core, + chip0_worker_fabric_connection, + unicast_send{1}, + edm_buffer_size, + page_plus_header_size, + num_pages_total, + pages_per_send, + local_worker_fabric_semaphore_id, + local_worker_last_message_semaphore_id, + local_input_buffer_address, + src_is_dram, + local_output_buffer_address, + dest_is_dram, + worker_buffer_index_semaphore_id, + edm_termination_infos); + + //////////////////////////////////////////////////////////////////////////// + // Build EDMs + //////////////////////////////////////////////////////////////////////////// + auto local_edm_kernel = + ttnn::ccl::generate_edm_kernel(sender_program, sender_device, chip_0_edm_builder, eth_sender_core, NOC::NOC_0); + + auto remote_edm_kernel = ttnn::ccl::generate_edm_kernel( + receiver_program, receiver_device, chip_1_edm_builder, eth_receiver_core, NOC::NOC_0); + + //////////////////////////////////////////////////////////////////////////// + // Compile and Execute Application + //////////////////////////////////////////////////////////////////////////// + run_programs(programs, {sender_device, receiver_device}); + log_info(tt::LogTest, "Reading back outputs"); + + bool pass = true; + constexpr bool enable_check = true; + if constexpr (enable_check) { + pass &= run_output_check(all_zeros, inputs, local_output_buffer) == Correctness::Correct; + } + return pass; +} + +bool RunLineFabricTest( + std::vector devices, + std::vector const& hops, + + const size_t mcast_first_chip, + const size_t mcast_last_chip, + + const uint32_t page_size, + const uint32_t num_pages_total, + bool src_is_dram, + bool dest_is_dram) { + std::size_t page_plus_header_size = page_size + sizeof(tt::fabric::PacketHeader); + std::size_t tensor_size_bytes = num_pages_total * page_size; + + static constexpr std::size_t edm_buffer_size = 4096 + PACKET_HEADER_SIZE_BYTES; + const size_t local_chip_id = 0; + const size_t remote_chip_id = 1; + const size_t num_hops = hops.size(); + auto programs = std::vector(devices.size()); + + std::vector worker_cores = {CoreCoord(0, 0)}; + + // Generate inputs + //////////////////////////////////////////////////////////////////////////// + // SETUP THE INPUT CB + //////////////////////////////////////////////////////////////////////////// + BankedConfig 
test_config = BankedConfig{ + .num_pages = num_pages_total, + .size_bytes = tensor_size_bytes, + .page_size_bytes = page_size, + .input_buffer_type = src_is_dram ? BufferType::DRAM : BufferType::L1, + .output_buffer_type = dest_is_dram ? BufferType::DRAM : BufferType::L1, + .l1_data_format = tt::DataFormat::Float16_b}; + + // Input buffer + auto [local_input_buffer, inputs] = build_input_buffer(devices[0], tensor_size_bytes, test_config); + auto local_input_buffer_address = local_input_buffer->address(); + + std::vector all_zeros(inputs.size(), 0); + // output buffers + TT_ASSERT(mcast_first_chip <= mcast_last_chip, "mcast_first_chip must be less than or equal to mcast_last_chip"); + TT_ASSERT(mcast_last_chip < devices.size(), "mcast_last_chip must be less than the number of devices"); + std::vector> output_buffers; + output_buffers.reserve(mcast_last_chip - mcast_first_chip + 1); + for (size_t i = mcast_first_chip; i <= mcast_last_chip; i++) { + output_buffers.push_back(CreateBuffer(InterleavedBufferConfig{ + devices.at(i), test_config.size_bytes, test_config.page_size_bytes, test_config.output_buffer_type})); + tt_metal::detail::WriteToBuffer(output_buffers.back(), all_zeros); + } + auto local_output_buffer_address = output_buffers[0]->address(); + bool all_same_addr = std::ranges::all_of(output_buffers, [local_output_buffer_address](auto const& buffer) { + return buffer->address() == local_output_buffer_address; + }); + TT_ASSERT(all_same_addr, "All output buffers must have the same address"); + + //////////////////////////////////////////////////////////////////////////// + // Setup Semaphores and Builders + //////////////////////////////////////////////////////////////////////////// + std::vector edm_hop_builders; + edm_hop_builders.reserve(num_hops); + + auto local_worker_fabric_semaphore_id = tt::tt_metal::CreateSemaphore(programs[0], worker_cores.at(0), 0); + auto local_worker_last_message_semaphore_id = tt::tt_metal::CreateSemaphore(programs[0], worker_cores.at(0), 0); + auto worker_buffer_index_semaphore_id = tt::tt_metal::CreateSemaphore(programs[0], worker_cores.at(0), 0); + auto const& edm_config = ttnn::ccl::FabricEriscDatamoverConfig(edm_buffer_size, 1, 2); + + for (size_t i = 0; i < num_hops; i++) { + const auto sender_device = devices.at(i); + const auto receiver_device = devices.at(i + 1); + const auto edm_sender_core = hops.at(i).hop_src; + const auto edm_receiver_core = hops.at(i).hop_dest; + + const std::optional chip0_receiver_channel_downstream_flow_control_semaphore_id = std::nullopt; + const auto chip0_sender_channel_0_flow_control_semaphore_id = + tt::tt_metal::CreateSemaphore(programs.at(i), hops.at(i).hop_src, 0, CoreType::ETH); + const auto chip0_sender_channel_1_flow_control_semaphore_id = + tt::tt_metal::CreateSemaphore(programs.at(i), hops.at(i).hop_src, 0, CoreType::ETH); + const auto chip0_sender_channel_0_connection_semaphore_id = + tt::tt_metal::CreateSemaphore(programs.at(i), hops.at(i).hop_src, 0, CoreType::ETH); + const auto chip0_sender_channel_1_connection_semaphore_id = + tt::tt_metal::CreateSemaphore(programs.at(i), hops.at(i).hop_src, 0, CoreType::ETH); + + std::optional chip1_receiver_channel_downstream_flow_control_semaphore_id = + tt::tt_metal::CreateSemaphore(programs.at(i + 1), hops.at(i).hop_dest, 0, CoreType::ETH); + const auto chip1_sender_channel_0_flow_control_semaphore_id = + tt::tt_metal::CreateSemaphore(programs.at(i + 1), hops.at(i).hop_dest, 0, CoreType::ETH); + const auto chip1_sender_channel_1_flow_control_semaphore_id = + 
tt::tt_metal::CreateSemaphore(programs.at(i + 1), hops.at(i).hop_dest, 0, CoreType::ETH); + const auto chip1_sender_channel_0_connection_semaphore_id = + tt::tt_metal::CreateSemaphore(programs.at(i + 1), hops.at(i).hop_dest, 0, CoreType::ETH); + const auto chip1_sender_channel_1_connection_semaphore_id = + tt::tt_metal::CreateSemaphore(programs.at(i + 1), hops.at(i).hop_dest, 0, CoreType::ETH); + const auto chip1_downstream_edm_buffer_index_semaphore_id = + tt::tt_metal::CreateSemaphore(programs.at(i + 1), hops.at(i).hop_dest, 0, CoreType::ETH); + + auto chip_0_edm_builder = ttnn::ccl::FabricEriscDatamoverBuilder( + sender_device->ethernet_core_from_logical_core(edm_sender_core).x, + sender_device->ethernet_core_from_logical_core(edm_sender_core).y, + local_chip_id, + remote_chip_id, + + chip0_receiver_channel_downstream_flow_control_semaphore_id, + chip0_sender_channel_0_flow_control_semaphore_id, + chip0_sender_channel_1_flow_control_semaphore_id, + chip0_sender_channel_0_connection_semaphore_id, + chip0_sender_channel_1_connection_semaphore_id, + + edm_config); + auto chip_1_edm_builder = ttnn::ccl::FabricEriscDatamoverBuilder( + receiver_device->ethernet_core_from_logical_core(edm_receiver_core).x, + receiver_device->ethernet_core_from_logical_core(edm_receiver_core).y, + remote_chip_id, + local_chip_id, + + chip1_receiver_channel_downstream_flow_control_semaphore_id, // this is the receiver channel's local sem + // for flow controlling with downstream fabric + // sender + chip1_sender_channel_0_flow_control_semaphore_id, + chip1_sender_channel_1_flow_control_semaphore_id, + chip1_sender_channel_0_connection_semaphore_id, + chip1_sender_channel_1_connection_semaphore_id, + + edm_config); + + edm_hop_builders.push_back(EthLinkBuilder{ + .sender_edm_builder = std::move(chip_0_edm_builder), + .receiver_edm_builder = std::move(chip_1_edm_builder), + .sender_core = edm_sender_core, + .receiver_core = edm_receiver_core, + .downstream_edm_buffer_index_semaphore_id = chip1_downstream_edm_buffer_index_semaphore_id}); + } + + for (size_t i = 0; i < num_hops - 1; i++) { + edm_hop_builders.at(i).receiver_edm_builder.connect_to_downstream_edm( + edm_hop_builders.at(i + 1).sender_edm_builder, + edm_hop_builders.at(i).downstream_edm_buffer_index_semaphore_id); + } + + //////////////////////////////////////////////////////////////////////////// + // Build Workers + //////////////////////////////////////////////////////////////////////////// + log_info(tt::LogTest, "Generating local_sender -> remote_receiver workers"); + auto const& worker_core = worker_cores.at(0); + log_info(tt::LogTest, "Worker {}. 
On Core x={},y={}", 0, worker_core.x, worker_core.y); + + std::vector edm_termination_infos; + edm_termination_infos.reserve(num_hops * 2); + for (int i = num_hops - 1; i >= 0; i--) { + const std::size_t distance_receiver = i + 1; + const auto& receiver_core = hops.at(i).hop_dest; + auto receiver_device = devices.at(i + 1); + edm_termination_infos.push_back( + {distance_receiver, + receiver_device->ethernet_core_from_logical_core(receiver_core).x, + receiver_device->ethernet_core_from_logical_core(receiver_core).y, + ttnn::ccl::FabricEriscDatamoverConfig::termination_signal_address}); + const std::size_t distance_sender = i; + const auto& sender_core = hops.at(i).hop_src; + auto sender_device = devices.at(i); + edm_termination_infos.push_back( + {distance_sender, + sender_device->ethernet_core_from_logical_core(sender_core).x, + sender_device->ethernet_core_from_logical_core(sender_core).y, + ttnn::ccl::FabricEriscDatamoverConfig::termination_signal_address}); + }; + + auto chip0_worker_fabric_connection = edm_hop_builders[0].sender_edm_builder.build_connection_to_worker_channel(); + const std::size_t pages_per_send = + (chip0_worker_fabric_connection.buffer_size_bytes - PACKET_HEADER_SIZE_BYTES) / page_size; + generate_sender_worker_kernels( + programs[0], + devices[0], + worker_core, + hops[0].hop_src, + chip0_worker_fabric_connection, + mcast_send{mcast_first_chip - 1, mcast_last_chip - mcast_first_chip}, + edm_buffer_size, + page_plus_header_size, + num_pages_total, + pages_per_send, + local_worker_fabric_semaphore_id, + local_worker_last_message_semaphore_id, + local_input_buffer_address, + src_is_dram, + local_output_buffer_address, + dest_is_dram, + worker_buffer_index_semaphore_id, + edm_termination_infos); + + //////////////////////////////////////////////////////////////////////////// + // Build EDMs + //////////////////////////////////////////////////////////////////////////// + for (std::size_t i = 0; i < num_hops; i++) { + auto local_edm_kernel = ttnn::ccl::generate_edm_kernel( + programs.at(i), // sender_program, + devices.at(i), // sender_device, + edm_hop_builders.at(i).sender_edm_builder, // chip_0_edm_builder, + edm_hop_builders.at(i).sender_core, // eth_sender_core, + NOC::NOC_0); + + auto remote_edm_kernel = ttnn::ccl::generate_edm_kernel( + programs.at(i + 1), + devices.at(i + 1), + edm_hop_builders.at(i).receiver_edm_builder, + edm_hop_builders.at(i).receiver_core, + NOC::NOC_0); + } + + //////////////////////////////////////////////////////////////////////////// + // Compile and Execute Application + //////////////////////////////////////////////////////////////////////////// + + run_programs(programs, devices); + log_info(tt::LogTest, "Reading back outputs"); + + bool pass = true; + constexpr bool enable_check = true; + if constexpr (enable_check) { + + for (size_t i = mcast_first_chip; i <= mcast_last_chip; i++) { + bool compare_with_input = (mcast_first_chip <= i && i <= mcast_last_chip); + auto &golden_tensor = compare_with_input ? 
inputs : all_zeros; + pass &= run_output_check(all_zeros, golden_tensor, output_buffers.at(i)) == Correctness::Correct; + } + } + + return pass; +} + +// RESUME HERE AND IMPLEMENT MCAST TEST +int TestLineFabricEntrypoint( + const size_t mcast_first_chip, + const size_t mcast_last_chip, + const uint32_t page_size, + const uint32_t num_pages_total, + const bool src_is_dram, + const bool dest_is_dram) { + // argv[0]: program + // argv[1]: buffer_size_bytes + // argv[2]: num_loops + + auto arch = tt::get_arch_from_string(tt::test_utils::get_umd_arch_name()); + auto num_devices = tt::tt_metal::GetNumAvailableDevices(); + if (num_devices < 4) { + log_info("This test can only be run on N300 devices"); + return 0; + } + if (arch == tt::ARCH::GRAYSKULL) { + log_info("Test must be run on WH"); + return 0; + } + + T3000TestDevice test_fixture; + + // build a line of devices + static constexpr size_t fabric_line_length = 4; + std::vector devices = { + test_fixture.devices_.at(0), + test_fixture.devices_.at(1), + test_fixture.devices_.at(2), + test_fixture.devices_.at(3)}; + std::vector fabric_hops; + fabric_hops.reserve(fabric_line_length); + + for (size_t hop = 0; hop < fabric_line_length - 1; hop++) { + auto src_device = devices[hop]; + auto dest_device = devices[hop + 1]; + auto target_dest_device_id = devices[hop + 1]->id(); + log_info(tt::LogTest, "Finding links between device {} and {}", src_device->id(), dest_device->id()); + auto const& active_eth_cores = src_device->get_active_ethernet_cores(true); + auto eth_sender_core_iter = active_eth_cores.begin(); + auto eth_sender_core_iter_end = active_eth_cores.end(); + + chip_id_t dest_device_id = std::numeric_limits::max(); + tt_xy_pair eth_receiver_core; + bool initialized = false; + tt_xy_pair eth_sender_core; + do { + TT_FATAL(eth_sender_core_iter != eth_sender_core_iter_end, "Error"); + std::tie(dest_device_id, eth_receiver_core) = + src_device->get_connected_ethernet_core(*eth_sender_core_iter); + eth_sender_core = *eth_sender_core_iter; + eth_sender_core_iter++; + } while (dest_device_id != target_dest_device_id); + TT_ASSERT(dest_device_id == target_dest_device_id); + + fabric_hops.push_back({eth_sender_core, eth_receiver_core}); + } + + bool success = false; + try { + success = RunLineFabricTest( + devices, + fabric_hops, + + mcast_first_chip, + mcast_last_chip, + + page_size, + num_pages_total, + src_is_dram, + dest_is_dram); + + } catch (std::exception& e) { + log_error("Caught exception: {}", e.what()); + test_fixture.TearDown(); + return -1; + } + + test_fixture.TearDown(); + + return success ? 
0 : -1; +} + +int TestLoopbackEntrypoint( + const uint32_t page_size, const uint32_t num_pages_total, const bool src_is_dram, const bool dest_is_dram) { + // argv[0]: program + // argv[1]: buffer_size_bytes + // argv[2]: num_loops + + auto arch = tt::get_arch_from_string(tt::test_utils::get_umd_arch_name()); + auto num_devices = tt::tt_metal::GetNumAvailableDevices(); + if (num_devices < 4) { + log_info("This test can only be run on N300 devices"); + return 0; + } + if (arch == tt::ARCH::GRAYSKULL) { + log_info("Test must be run on WH"); + return 0; + } + + T3000TestDevice test_fixture; + + const auto& device_0 = test_fixture.devices_.at(0); + + auto const& active_eth_cores = device_0->get_active_ethernet_cores(true); + auto eth_sender_core_iter = active_eth_cores.begin(); + auto eth_sender_core_iter_end = active_eth_cores.end(); + chip_id_t device_id = std::numeric_limits::max(); + tt_xy_pair eth_receiver_core; + bool initialized = false; + tt_xy_pair eth_sender_core; + do { + TT_FATAL(eth_sender_core_iter != eth_sender_core_iter_end, "Error"); + std::tie(device_id, eth_receiver_core) = device_0->get_connected_ethernet_core(*eth_sender_core_iter); + eth_sender_core = *eth_sender_core_iter; + eth_sender_core_iter++; + } while (device_id != 1); + TT_ASSERT(device_id == 1); + const auto& device_1 = test_fixture.devices_.at(device_id); + + bool success = false; + try { + success = RunLoopbackTest( + device_0, + device_1, + + eth_sender_core, + eth_receiver_core, + + page_size, + num_pages_total, + src_is_dram, + dest_is_dram); + } catch (std::exception& e) { + log_error("Caught exception: {}", e.what()); + test_fixture.TearDown(); + return -1; + } + + test_fixture.TearDown(); + + return success ? 0 : -1; +} + +//////////////////////////////////////////////////////////////////// +/// MESSAGE COUNT TERMINATION MODE +//////////////////////////////////////////////////////////////////// + +TEST(WorkerFabricEdmDatapath, FabricEDMLoopback_With_Workers_SingleMessage) { + const uint32_t page_size = 2048; + const uint32_t num_pages_total = 1; + const bool src_is_dram = true; + const bool dest_is_dram = true; + + auto result = TestLoopbackEntrypoint(page_size, num_pages_total, src_is_dram, dest_is_dram); + ASSERT_EQ(result, 0); +} + +// Will wrapp sender but not receiver buffers +TEST(WorkerFabricEdmDatapath, FabricEDMLoopback_With_Workers_2_messages) { + const uint32_t page_size = 2048; + const uint32_t num_pages_total = 2; + const bool src_is_dram = true; + const bool dest_is_dram = true; + + auto result = TestLoopbackEntrypoint(page_size, num_pages_total, src_is_dram, dest_is_dram); + ASSERT_EQ(result, 0); +} +// Will wrapp sender but not receiver buffers +TEST(WorkerFabricEdmDatapath, FabricEDMLoopback_With_Workers_10_messages) { + const uint32_t page_size = 2048; + const uint32_t num_pages_total = 10; + const bool src_is_dram = true; + const bool dest_is_dram = true; + + auto result = TestLoopbackEntrypoint(page_size, num_pages_total, src_is_dram, dest_is_dram); + ASSERT_EQ(result, 0); +} + +// Will wrapp sender and receiver buffers +TEST(WorkerFabricEdmDatapath, FabricEDMLoopback_With_Workers_20_messages) { + const uint32_t page_size = 2048; + const uint32_t num_pages_total = 20; + const bool src_is_dram = true; + const bool dest_is_dram = true; + + auto result = TestLoopbackEntrypoint(page_size, num_pages_total, src_is_dram, dest_is_dram); + ASSERT_EQ(result, 0); +} + +TEST(WorkerFabricEdmDatapath, FabricEDMLoopback_With_Workers) { + const uint32_t page_size = 2048; + const uint32_t 
num_pages_total = 100000; + const bool src_is_dram = true; + const bool dest_is_dram = true; + + auto result = TestLoopbackEntrypoint(page_size, num_pages_total, src_is_dram, dest_is_dram); + ASSERT_EQ(result, 0); +} + +// Currently disabled until mcast properly tested/broughtup +TEST(WorkerFabricEdmDatapath, DISABLED_LineFabricMcast) { + const uint32_t page_size = 2048; + const uint32_t num_pages_total = 1; + const bool src_is_dram = true; + const bool dest_is_dram = true; + const size_t mcast_first_chip = 1; + const size_t mcast_last_chip = 3; + + auto result = TestLineFabricEntrypoint( + mcast_first_chip, mcast_last_chip, page_size, num_pages_total, src_is_dram, dest_is_dram); + + ASSERT_EQ(result, 0); +} + +// EnablePersistentKernelCache diff --git a/tt_metal/hw/inc/ethernet/dataflow_api.h b/tt_metal/hw/inc/ethernet/dataflow_api.h index 8901021fac5b..5b0ddafb9958 100644 --- a/tt_metal/hw/inc/ethernet/dataflow_api.h +++ b/tt_metal/hw/inc/ethernet/dataflow_api.h @@ -203,6 +203,23 @@ void eth_send_bytes_over_channel_payload_only( } } +// Calls the unsafe variant of eth_send_packet under the hood which is guaranteed not to context switch +// We want this for code size reasons +FORCE_INLINE +void eth_send_bytes_over_channel_payload_only_unsafe( + uint32_t src_addr, + uint32_t dst_addr, + uint32_t num_bytes, + uint32_t num_bytes_per_send = 16, + uint32_t num_bytes_per_send_word_size = 1) { + uint32_t num_bytes_sent = 0; + while (num_bytes_sent < num_bytes) { + internal_::eth_send_packet_unsafe( + 0, ((num_bytes_sent + src_addr) >> 4), ((num_bytes_sent + dst_addr) >> 4), num_bytes_per_send_word_size); + num_bytes_sent += num_bytes_per_send; + } +} + /* * Sends the write completion signal to the receiver ethernet core, for transfers where the payload was already sent. * The second half of a full ethernet send. diff --git a/tt_metal/hw/inc/ethernet/tunneling.h b/tt_metal/hw/inc/ethernet/tunneling.h index b6e4cdd0bd5b..043a133eeb0e 100644 --- a/tt_metal/hw/inc/ethernet/tunneling.h +++ b/tt_metal/hw/inc/ethernet/tunneling.h @@ -26,7 +26,11 @@ struct eth_channel_sync_t { // First level ack that signals to sender that the payload was received by receiver, // indicating that sender can reuse the sender side buffer safely. volatile uint32_t receiver_ack; - uint32_t reserved_1; + + // Logical channel ID tagged by the sender. 
diff --git a/tt_metal/hw/inc/ethernet/tunneling.h b/tt_metal/hw/inc/ethernet/tunneling.h index b6e4cdd0bd5b..043a133eeb0e 100644 --- a/tt_metal/hw/inc/ethernet/tunneling.h +++ b/tt_metal/hw/inc/ethernet/tunneling.h @@ -26,7 +26,11 @@ struct eth_channel_sync_t { // First level ack that signals to sender that the payload was received by receiver, // indicating that sender can reuse the sender side buffer safely. volatile uint32_t receiver_ack; - uint32_t reserved_1; + + // Logical channel ID tagged by the sender. Not required when channels + // are connected 1:1 (single producer - single consumer) + volatile uint32_t src_id; + uint32_t reserved_2; }; @@ -66,6 +70,15 @@ void eth_send_packet(uint32_t q_num, uint32_t src_word_addr, uint32_t dest_word_ eth_txq_reg_write(q_num, ETH_TXQ_CMD, ETH_TXQ_CMD_START_DATA); } +FORCE_INLINE +void eth_send_packet_unsafe(uint32_t q_num, uint32_t src_word_addr, uint32_t dest_word_addr, uint32_t num_words) { + ASSERT(eth_txq_reg_read(q_num, ETH_TXQ_CMD) == 0); + eth_txq_reg_write(q_num, ETH_TXQ_TRANSFER_START_ADDR, src_word_addr << 4); + eth_txq_reg_write(q_num, ETH_TXQ_DEST_ADDR, dest_word_addr << 4); + eth_txq_reg_write(q_num, ETH_TXQ_TRANSFER_SIZE_BYTES, num_words << 4); + eth_txq_reg_write(q_num, ETH_TXQ_CMD, ETH_TXQ_CMD_START_DATA); +} + FORCE_INLINE void eth_write_remote_reg(uint32_t q_num, uint32_t reg_addr, uint32_t val) { while (eth_txq_reg_read(q_num, ETH_TXQ_CMD) != 0) { diff --git a/tt_metal/hw/inc/wormhole/noc_nonblocking_api.h b/tt_metal/hw/inc/wormhole/noc_nonblocking_api.h index 48b6411911d7..37c19fae91e2 100644 --- a/tt_metal/hw/inc/wormhole/noc_nonblocking_api.h +++ b/tt_metal/hw/inc/wormhole/noc_nonblocking_api.h @@ -292,6 +292,7 @@ inline __attribute__((always_inline)) void noc_fast_write_dw_inline(uint32_t noc uint32_t be32 = be; uint32_t be_shift = (dest_addr & (NOC_WORD_BYTES-1)); + // If we're given a misaligned address, don't write to the bytes in the word below the address be32 = (be32 << be_shift); while (!noc_cmd_buf_ready(noc, cmd_buf)); diff --git a/ttnn/CMakeLists.txt b/ttnn/CMakeLists.txt index bc2b1773cc20..16aa324c5206 100644 --- a/ttnn/CMakeLists.txt +++ b/ttnn/CMakeLists.txt @@ -10,6 +10,7 @@ set(ALL_TTNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/graph/graph_processor.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/graph/graph_trace_utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/graph/graph_pybind.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/erisc_datamover_builder.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/all_gather/all_gather.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp diff --git a/ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp b/ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp index 92e8b46e8055..865f1a7e0bd4 100644 --- a/ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp @@ -198,16 +198,17 @@ void generate_edm_kernels_for_ring_or_linear_topology( } } - -KernelHandle generate_edm_kernel( - tt::tt_metal::Program& program, +template <typename EDMBuilder> +KernelHandle generate_edm_kernel_impl( + tt::tt_metal::Program& program, Device const* device, - ccl::EriscDatamoverBuilder const& edm_builder, + EDMBuilder const& edm_builder, + std::string const& kernel_path, CoreCoord const& eth_core, NOC noc_id) { edm_builder.dump_to_log(); - std::vector<uint32_t> const& edm_clockwise_kernel_rt_args = edm_builder.emit_runtime_args(); + std::vector<uint32_t> const& edm_kernel_rt_args = edm_builder.emit_runtime_args(); // Ethernet Kernels std::vector<uint32_t> eth_sender_ct_args = edm_builder.emit_compile_time_args(); log_trace(tt::LogOp, "EDM core (x={},y={}):", eth_core.x, eth_core.y); @@ -216,17 +217,17 @@ KernelHandle generate_edm_kernel( log_trace(tt::LogOp, "\t{}", s); } - auto eth_sender_kernel =tt::tt_metal::CreateKernel( + auto eth_sender_kernel = tt::tt_metal::CreateKernel( program, - "ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_datamover.cpp", + kernel_path, eth_core, -
tt::tt_metal::EthernetConfig{.noc = noc_id, .compile_args = eth_sender_ct_args}); + tt::tt_metal::EthernetConfig{.noc = noc_id, .compile_args = eth_sender_ct_args}); - tt::tt_metal::SetRuntimeArgs(program, eth_sender_kernel, eth_core, edm_clockwise_kernel_rt_args); + tt::tt_metal::SetRuntimeArgs(program, eth_sender_kernel, eth_core, edm_kernel_rt_args); std::stringstream ss; ss << "EDM ARGS:\n"; - for (auto const& s : edm_clockwise_kernel_rt_args) { + for (auto const& s : edm_kernel_rt_args) { ss << "\t" << s << "\n"; } log_trace(tt::LogOp, "{}", ss.str()); @@ -234,6 +235,31 @@ KernelHandle generate_edm_kernel( return eth_sender_kernel; } +KernelHandle generate_edm_kernel( + tt::tt_metal::Program& program, + Device const* device, + ccl::FabricEriscDatamoverBuilder const& edm_builder, + CoreCoord const& eth_core, + NOC noc_id) { + return generate_edm_kernel_impl( + program, + device, + edm_builder, + "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover.cpp", + eth_core, + noc_id); +} + +KernelHandle generate_edm_kernel( + tt::tt_metal::Program& program, + Device const* device, + ccl::EriscDatamoverBuilder const& edm_builder, + CoreCoord const& eth_core, + NOC noc_id) { + return generate_edm_kernel_impl( + program, device, edm_builder, "ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_datamover.cpp", eth_core, noc_id); +} + ccl::EriscDatamoverBuilder create_erisc_datamover_builder( std::size_t num_channels, uint32_t page_size, diff --git a/ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp b/ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp index 51228970005b..0ad4d35b3f11 100644 --- a/ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp @@ -11,6 +11,7 @@ #include "ttnn/operations/ccl/ccl_host_datastructures.hpp" #include "ttnn/operations/ccl/common/types/ccl_types.hpp" #include "ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.hpp" #include "tt_metal/host_api.hpp" #include "tt_metal/impl/program/program.hpp" #include "ttnn/tensor/types.hpp" @@ -467,6 +468,13 @@ class InterleavedRingAllGatherTensorSlicer : public LegacyCclTensorSlicer { }; +KernelHandle generate_edm_kernel( + tt::tt_metal::Program& program, + Device const* device, + ccl::FabricEriscDatamoverBuilder const& edm_builder, + CoreCoord const& eth_core, + NOC noc_id); + KernelHandle generate_edm_kernel( tt::tt_metal::Program& program, Device const* device, diff --git a/ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.cpp b/ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.cpp new file mode 100644 index 000000000000..2b84cf6bac46 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.cpp @@ -0,0 +1,197 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.hpp" + +#include "common/math.hpp" +#include "eth_l1_address_map.h" +#include "tt_metal/common/assert.hpp" +#include "ttnn/operations/math.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp" +namespace ttnn::ccl { + + +// The channel structure is as follows: +// &header-> |----------------| channel_base_address +// | header | +// &payload-> |----------------| +// | | +// | payload | +// | | +// &channel_sync-> |----------------| +// | channel_sync | +// ------------------ +// + +FabricEriscDatamoverConfig::FabricEriscDatamoverConfig( + std::size_t channel_buffer_size_bytes, std::size_t sender_ratio_size, std::size_t receiver_ratio_size) { + TT_ASSERT(channel_buffer_size_bytes > sizeof(tt::fabric::PacketHeader) + 2 * FabricEriscDatamoverConfig::eth_channel_sync_size); + const std::size_t channel_buffer_size_with_channel_sync = + channel_buffer_size_bytes + sizeof(tt::fabric::PacketHeader); // + 16 // sizeof(tt::fabric::PacketHeader); + + this->channel_buffer_size_bytes = channel_buffer_size_bytes; + this->channel_buffer_size_bytes_with_channel_sync = channel_buffer_size_with_channel_sync; + const std::size_t total_ratio_count = 2 * sender_ratio_size + receiver_ratio_size; + this->sender_0_channel_size_bytes = tt::round_down( + (available_channel_buffering_space / total_ratio_count) * sender_ratio_size, + channel_buffer_size_with_channel_sync); + this->sender_0_num_buffers = this->sender_0_channel_size_bytes / channel_buffer_size_with_channel_sync; + this->sender_1_channel_size_bytes = tt::round_down( + (available_channel_buffering_space / total_ratio_count) * sender_ratio_size, + channel_buffer_size_with_channel_sync); + this->sender_1_num_buffers = this->sender_1_channel_size_bytes / channel_buffer_size_with_channel_sync; + this->receiver_channel_size_bytes = tt::round_down( + (available_channel_buffering_space / total_ratio_count) * receiver_ratio_size, + channel_buffer_size_with_channel_sync); + this->receiver_num_buffers = this->receiver_channel_size_bytes / channel_buffer_size_with_channel_sync; + + this->sender_0_channel_base_address = buffer_region_start; + this->sender_1_channel_base_address = this->sender_0_channel_base_address + this->sender_0_channel_size_bytes; + this->receiver_channel_base_address = this->sender_1_channel_base_address + this->sender_1_channel_size_bytes; + + log_trace(tt::LogOp, "Sender 0 channel_start: {}", this->sender_0_channel_base_address); + log_trace(tt::LogOp, "Sender 1 channel_start: {}", this->sender_1_channel_base_address); + log_trace(tt::LogOp, "Receiver channel_start: {}", this->receiver_channel_base_address); + + TT_ASSERT( + this->sender_0_channel_size_bytes + this->sender_1_channel_size_bytes + this->receiver_channel_size_bytes <= + this->available_channel_buffering_space); + TT_ASSERT( + this->receiver_channel_base_address + this->receiver_channel_size_bytes < + eth_l1_mem::address_map::MAX_L1_LOADING_SIZE); +} + +FabricEriscDatamoverBuilder::FabricEriscDatamoverBuilder( + size_t my_noc_x, + size_t my_noc_y, + size_t my_chip_id, + size_t peer_chip_id, + + std::optional receiver_channel_downstream_flow_control_semaphore_id, + size_t sender_channel_0_flow_control_semaphore_id, + size_t sender_channel_1_flow_control_semaphore_id, + size_t sender_channel_0_connection_semaphore_id, + size_t sender_channel_1_connection_semaphore_id, + + FabricEriscDatamoverConfig const& config) : + 
my_noc_x(my_noc_x), + my_noc_y(my_noc_y), + config(config), + my_chip_id(my_chip_id), + peer_chip_id(peer_chip_id), + handshake_address(tt::round_up(eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE, FabricEriscDatamoverConfig::eth_channel_sync_size)), + channel_buffer_size(config.channel_buffer_size_bytes), + sender_0_num_buffers(config.sender_0_num_buffers), + sender_1_num_buffers(config.sender_1_num_buffers), + receiver_num_buffers(config.receiver_num_buffers), + + // this is the receiver channel's local sem for flow controlling with downstream fabric sender + receiver_channel_downstream_flow_control_semaphore_id(receiver_channel_downstream_flow_control_semaphore_id), + sender_channel_0_flow_control_semaphore_id(sender_channel_0_flow_control_semaphore_id), + sender_channel_1_flow_control_semaphore_id(sender_channel_1_flow_control_semaphore_id), + sender_channel_0_connection_semaphore_id(sender_channel_0_connection_semaphore_id), + sender_channel_1_connection_semaphore_id(sender_channel_1_connection_semaphore_id), + + local_sender_channel_0_buffer_address(config.sender_0_channel_base_address), + local_sender_channel_0_connection_info_addr( + FabricEriscDatamoverConfig::sender_channel_0_worker_connection_info_address), + local_sender_channel_1_buffer_address(config.sender_1_channel_base_address), + local_sender_channel_1_connection_info_addr( + FabricEriscDatamoverConfig::sender_channel_1_worker_connection_info_address), + local_receiver_channel_buffer_address(config.receiver_channel_base_address), + + termination_signal_ptr(FabricEriscDatamoverConfig::termination_signal_address) {} + +std::vector<uint32_t> FabricEriscDatamoverBuilder::emit_compile_time_args() const { + const bool is_handshake_master = this->my_chip_id < this->peer_chip_id; + TT_ASSERT(this->my_chip_id != this->peer_chip_id); + TT_ASSERT( + this->sender_0_num_buffers == this->sender_1_num_buffers, + "Implementation expects sender_0_num_buffers and sender_1_num_buffers to be the same for now"); + return std::vector<uint32_t>{ + is_handshake_master, + this->handshake_address, + this->channel_buffer_size, + + this->sender_0_num_buffers, + this->receiver_num_buffers, + + config.sender_0_channel_base_address, + FabricEriscDatamoverConfig::sender_channel_0_buffer_index_address, + FabricEriscDatamoverConfig::sender_channel_0_worker_connection_info_address, + config.sender_1_channel_base_address, + FabricEriscDatamoverConfig::sender_channel_1_buffer_index_address, + FabricEriscDatamoverConfig::sender_channel_1_worker_connection_info_address, + config.receiver_channel_base_address, + config.receiver_channel_base_address, + + config.sender_0_channel_base_address, + config.sender_1_channel_base_address, + + this->termination_signal_ptr}; +} + +std::vector<uint32_t> FabricEriscDatamoverBuilder::emit_runtime_args() const { + return std::vector<uint32_t>{ + this->sender_channel_0_connection_semaphore_id, + this->sender_channel_1_connection_semaphore_id, + this->downstream_edm_buffer_base_address != std::nullopt, + this->downstream_edm_buffer_base_address.value_or(0), + this->downstream_edm_noc_x.value_or(0), + this->downstream_edm_noc_y.value_or(0), + this->downstream_edm_semaphore_address.value_or(0), + this->downstream_edm_worker_registration_address.value_or(0), + this->downstream_edm_worker_location_info_address.value_or(0), + this->downstream_noc_interface_buffer_index_addr.value_or(0), + // this is the receiver channel's local sem for flow controlling with downstream fabric sender + this->receiver_channel_downstream_flow_control_semaphore_id.value_or(0),
this->sender_channel_0_flow_control_semaphore_id, + this->sender_channel_1_flow_control_semaphore_id, + }; +} + +SenderWorkerAdapterSpec FabricEriscDatamoverBuilder::build_connection_to_worker_channel() const { + return SenderWorkerAdapterSpec { + this->my_noc_x, + this->my_noc_y, + this->local_sender_channel_0_buffer_address, + this->sender_0_num_buffers, + this->sender_channel_0_flow_control_semaphore_id, + this->sender_channel_0_connection_semaphore_id, + FabricEriscDatamoverConfig::sender_channel_0_worker_connection_info_address, + this->config.channel_buffer_size_bytes + }; +} + +SenderWorkerAdapterSpec FabricEriscDatamoverBuilder::build_connection_to_fabric_channel() const { + return SenderWorkerAdapterSpec { + this->my_noc_x, + this->my_noc_y, + this->local_sender_channel_1_buffer_address, + this->sender_1_num_buffers, + this->sender_channel_1_flow_control_semaphore_id, + this->sender_channel_1_connection_semaphore_id, + FabricEriscDatamoverConfig::sender_channel_1_worker_connection_info_address, + this->config.channel_buffer_size_bytes + }; +} + +void FabricEriscDatamoverBuilder::connect_to_downstream_edm(FabricEriscDatamoverBuilder const& downstream_edm, uint32_t downstream_edm_buffer_index_semaphore_id) { + auto const& adapter_spec = downstream_edm.build_connection_to_fabric_channel(); + + log_trace(tt::LogTest, "Connecting to downstream EDM at x={}, y={}", adapter_spec.edm_worker_x, adapter_spec.edm_worker_y); + + this->downstream_edm_noc_x = adapter_spec.edm_worker_x; + this->downstream_edm_noc_y = adapter_spec.edm_worker_y; + this->downstream_edm_buffer_base_address = adapter_spec.edm_buffer_base_addr; + this->downstream_edm_semaphore_address = adapter_spec.edm_l1_sem_addr; + this->downstream_edm_worker_registration_address = adapter_spec.edm_connection_handshake_addr; + this->downstream_edm_worker_location_info_address = adapter_spec.edm_worker_location_info_addr; + this->downstream_noc_interface_buffer_index_addr = downstream_edm_buffer_index_semaphore_id; +} + +} // namespace ttnn::ccl
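Chaining builders along a line is then just a matter of calling `connect_to_downstream_edm` once per hop. A host-side sketch for one direction of a 3-chip line (builder construction and semaphore creation elided; all names are illustrative):

    // Forward direction chip0 -> chip1 -> chip2. Each call points the
    // receiver channel of one EDM at sender channel 1 of the next EDM.
    void connect_forward_direction(
        ttnn::ccl::FabricEriscDatamoverBuilder& edm_chip0,
        ttnn::ccl::FabricEriscDatamoverBuilder& edm_chip1,
        ttnn::ccl::FabricEriscDatamoverBuilder& edm_chip2,
        uint32_t buffer_index_sem_chip1,
        uint32_t buffer_index_sem_chip2) {
        edm_chip0.connect_to_downstream_edm(edm_chip1, buffer_index_sem_chip1);
        edm_chip1.connect_to_downstream_edm(edm_chip2, buffer_index_sem_chip2);
        // edm_chip2 terminates the line; it has no downstream connection.
    }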
diff --git a/ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.hpp b/ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.hpp new file mode 100644 index 000000000000..889c42405a1c --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.hpp @@ -0,0 +1,146 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include <cstdint> +#include <optional> +#include <vector> + +#include "eth_l1_address_map.h" +#include "tt_metal/third_party/umd/device/tt_cluster_descriptor_types.h" +#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_types.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" + +namespace ttnn { +namespace ccl { + +struct FabricEriscDatamoverConfig { + static constexpr std::size_t field_size = 16; + static constexpr std::size_t buffer_alignment = 32; + static_assert(((buffer_alignment - 1) & buffer_alignment) == 0); + + // Global + static constexpr std::size_t eth_channel_sync_size = 16; + static constexpr std::size_t handshake_addr = eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE; + static constexpr std::size_t edm_channel_ack_addr = handshake_addr + eth_channel_sync_size; + static constexpr std::size_t termination_signal_address = + edm_channel_ack_addr + (2 * eth_channel_sync_size); // pad extra bytes to match old EDM so handshake logic will still work + + // Sender Channel 0 + static constexpr std::size_t sender_channel_0_buffer_index_address = termination_signal_address + field_size; + static constexpr std::size_t sender_channel_0_worker_connection_info_address = + sender_channel_0_buffer_index_address + field_size; + static_assert(field_size >= sizeof(tt::fabric::EDMChannelWorkerLocationInfo)); + + // Sender Channel 1 + static constexpr std::size_t sender_channel_1_buffer_index_address = + sender_channel_0_worker_connection_info_address + field_size; + static constexpr std::size_t sender_channel_1_worker_connection_info_address = + sender_channel_1_buffer_index_address + field_size; + + // Channel Allocations + static constexpr std::size_t buffer_region_start = + (sender_channel_1_worker_connection_info_address + field_size + buffer_alignment) & ~(buffer_alignment - 1); // Align + static constexpr std::size_t available_channel_buffering_space = + eth_l1_mem::address_map::MAX_L1_LOADING_SIZE - buffer_region_start; + + FabricEriscDatamoverConfig( + std::size_t channel_buffer_size_bytes, std::size_t sender_ratio_size, std::size_t receiver_ratio_size); + + std::size_t channel_buffer_size_bytes; + std::size_t channel_buffer_size_bytes_with_channel_sync; + std::size_t sender_0_channel_size_bytes; + std::size_t sender_0_num_buffers; + std::size_t sender_1_channel_size_bytes; + std::size_t sender_1_num_buffers; + std::size_t receiver_channel_size_bytes; + std::size_t receiver_num_buffers; + + std::size_t sender_0_channel_base_address; + std::size_t sender_1_channel_base_address; + std::size_t receiver_channel_base_address; +}; + +struct SenderWorkerAdapterSpec { + size_t edm_worker_x; + size_t edm_worker_y; + size_t edm_buffer_base_addr; + size_t num_buffers_per_channel; + size_t edm_l1_sem_addr; + size_t edm_connection_handshake_addr; + size_t edm_worker_location_info_addr; // The EDM's location for `EDMChannelWorkerLocationInfo` + size_t buffer_size_bytes; +}; +class FabricEriscDatamoverBuilder { + public: + FabricEriscDatamoverBuilder( + size_t my_noc_x, + size_t my_noc_y, + size_t my_chip_id, + size_t peer_chip_id, + + std::optional<size_t> receiver_channel_downstream_flow_control_semaphore_id, + size_t sender_channel_0_flow_control_semaphore_id, + size_t sender_channel_1_flow_control_semaphore_id, + size_t sender_channel_0_connection_semaphore_id, + size_t sender_channel_1_connection_semaphore_id, + + FabricEriscDatamoverConfig const& config);
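+
+    // Sender channel 0 accepts packets from workers on the local chip; sender
+    // channel 1 accepts packets forwarded by the upstream EDM's receiver
+    // channel. The two builders below hand out the corresponding adapter specs.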
+    [[nodiscard]] SenderWorkerAdapterSpec build_connection_to_worker_channel() const; + [[nodiscard]] SenderWorkerAdapterSpec build_connection_to_fabric_channel() const; + + [[nodiscard]] std::vector<uint32_t> emit_compile_time_args() const; + + [[nodiscard]] std::vector<uint32_t> emit_runtime_args() const; + + void connect_to_downstream_edm( + FabricEriscDatamoverBuilder const& downstream_edm, uint32_t downstream_edm_buffer_index_semaphore_id); + + void dump_to_log() const { + // TODO + } + + private: + size_t my_noc_x; + size_t my_noc_y; + FabricEriscDatamoverConfig config; + + size_t my_chip_id; + size_t peer_chip_id; + size_t handshake_address; + size_t channel_buffer_size; + + size_t sender_0_num_buffers; + size_t sender_1_num_buffers; + size_t receiver_num_buffers; + + size_t local_sender_channel_0_buffer_address; + size_t local_sender_channel_0_connection_info_addr; + size_t local_sender_channel_1_buffer_address; + size_t local_sender_channel_1_connection_info_addr; + size_t local_receiver_channel_buffer_address; + + size_t termination_signal_ptr; + + // Semaphore IDs + // this is the receiver channel's local sem for flow controlling with downstream fabric sender + std::optional<size_t> receiver_channel_downstream_flow_control_semaphore_id; + size_t sender_channel_0_flow_control_semaphore_id; + size_t sender_channel_1_flow_control_semaphore_id; + size_t sender_channel_0_connection_semaphore_id; + size_t sender_channel_1_connection_semaphore_id; + + std::optional<size_t> downstream_edm_noc_x; + std::optional<size_t> downstream_edm_noc_y; + std::optional<size_t> downstream_edm_buffer_base_address; + std::optional<size_t> downstream_edm_semaphore_address; + std::optional<size_t> downstream_edm_worker_registration_address; + std::optional<size_t> downstream_edm_worker_location_info_address; + std::optional<size_t> downstream_noc_interface_buffer_index_addr; +}; + +} // namespace ccl +} // namespace ttnn
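To make the handoff concrete, below is a device-kernel sketch of one full send through the adapter defined in the next file, `WorkerToFabricEdmSender`. It assumes the `SenderWorkerAdapterSpec` fields arrived via kernel args and that `packet_l1_addr` points at a header plus payload already staged in L1; all variable names are illustrative:

    // Sketch: single unicast write, two hops down the line. Destination
    // coordinates must be NOC-0 based, per the note in the commit message.
    auto sender = tt::fabric::WorkerToFabricEdmSender(
        edm_noc_x, edm_noc_y, edm_buffer_base_addr, num_buffers_per_channel,
        edm_l1_sem_id, edm_connection_handshake_addr,
        edm_worker_location_info_addr, buffer_size_bytes,
        my_flow_control_sem_ptr, buffer_index_l1_addr);
    sender.open();  // runs the worker<->EDM connection protocol

    auto* header = reinterpret_cast<tt::fabric::PacketHeader*>(packet_l1_addr);
    header->to_write()
        .to_chip_unicast(tt::fabric::UnicastRoutingCommandHeader{2})
        .to_noc_unicast(tt::fabric::NocUnicastCommandHeader{
            dest_l1_addr, size_bytes_including_header, dest_noc0_x, dest_noc0_y, 0});

    sender.wait_for_empty_write_slot();
    sender.send_payload_blocking_from_address(
        packet_l1_addr, header->get_payload_size_including_header());
    sender.close();  // every worker must tear down its connection when done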
diff --git a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp new file mode 100644 index 000000000000..720af76ed711 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp @@ -0,0 +1,195 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "dataflow_api.h" + +#include "tt_metal/hw/inc/ethernet/dataflow_api.h" +#include "ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp" + +#include "debug/assert.h" + +#include <cstdint> + + +namespace tt::fabric { + +struct WorkerToFabricEdmSender { + WorkerToFabricEdmSender() : worker_sem_addr(nullptr) {} + + WorkerToFabricEdmSender( + size_t edm_worker_x, + size_t edm_worker_y, + std::size_t edm_buffer_base_addr, + std::size_t num_buffers_per_channel, + std::size_t edm_l1_sem_id, + std::size_t edm_connection_handshake_addr, + std::size_t edm_worker_location_info_addr, // The EDM's location for `EDMChannelWorkerLocationInfo` + std::size_t buffer_size_bytes, + volatile uint32_t * const worker_sem_addr, + uint32_t buffer_index_addr + ) : + edm_buffer_addr(get_noc_addr(edm_worker_x, edm_worker_y, edm_buffer_base_addr)), + edm_semaphore_addr(get_noc_addr(edm_worker_x, edm_worker_y, get_semaphore(edm_l1_sem_id))), + edm_connection_handshake_addr(edm_connection_handshake_addr), + edm_worker_location_info_addr(edm_worker_location_info_addr), + worker_sem_addr(worker_sem_addr), + edm_buffer_base_addr(edm_buffer_base_addr), + num_buffers_per_channel(num_buffers_per_channel), + last_buffer_index(num_buffers_per_channel - 1), + edm_l1_sem_addr(get_semaphore(edm_l1_sem_id)), + buffer_size_bytes(buffer_size_bytes), + buffer_index_ptr(reinterpret_cast<std::size_t*>(buffer_index_addr)) + { + ASSERT(buffer_size_bytes > 0); + } + + [[nodiscard]] FORCE_INLINE bool consumer_has_space() const { + return *this->worker_sem_addr == 1; + } + FORCE_INLINE void clear_flow_control_semaphore() const { + noc_semaphore_set(this->worker_sem_addr, 0); + } + FORCE_INLINE void wait_for_empty_write_slot() const { + noc_semaphore_wait(this->worker_sem_addr, 1); + } + + FORCE_INLINE void send_payload_blocking(uint32_t cb_id, uint32_t num_pages, uint32_t page_size) { + send_payload_impl(cb_id, num_pages, page_size); + } + + // Does not wait for CB. Assumes caller handles CB data availability + FORCE_INLINE void send_payload_non_blocking(uint32_t cb_id, uint32_t num_pages, uint32_t page_size) { + send_payload_impl(cb_id, num_pages, page_size); + } + + /* + * No CB + */ + FORCE_INLINE void send_payload_blocking_from_address(uint32_t source_address, size_t size_bytes) { + send_payload_from_address_impl(source_address, size_bytes); + } + + /* + * No CB + */ + // Does not wait for CB.
Assumes caller handles CB data availability + FORCE_INLINE void send_payload_non_blocking_from_address(uint32_t source_address, size_t size_bytes) { + send_payload_from_address_impl(source_address, size_bytes); + } + + // Layout + // |-----------------------| + // | EDM Handshake | 16B + // |-----------------------| + // | EDM Ack Channel Sync | 16B + // |-----------------------| - + // | Connection Semaphore | 16B | + // |-----------------------| | + // | Buffer Index | 16B >- Per Sender Channel (On EDM) + // |-----------------------| | + // | Worker Connection Info| 16B |worker + // |-----------------------| -/ + // |-----------------------| + // + static constexpr size_t edm_sender_channel_field_stride_bytes = 16; + + FORCE_INLINE void open() { + auto dest_addr = this->edm_semaphore_addr; + static constexpr uint32_t open_connection_value = 1; + // May need to force buffer index to be a semaphore address + // remove the address portion to replace with the connection terminate address + dest_addr &= ~0x0000000FFFFFFFFFl; + uint64_t remote_buffer_index_addr = dest_addr | (edm_connection_handshake_addr + edm_sender_channel_field_stride_bytes); + ASSERT(remote_buffer_index_addr > 0); + noc_async_read(remote_buffer_index_addr, reinterpret_cast(this->buffer_index_ptr), sizeof(uint32_t)); + + ASSERT(edm_worker_location_info_addr == edm_connection_handshake_addr + 2 * edm_sender_channel_field_stride_bytes); + dest_addr &= ~0x0000000FFFFFFFFFl; + dest_addr |= edm_worker_location_info_addr; + // TODO: Need to change byte enable to be word enable + noc_inline_dw_write(dest_addr, reinterpret_cast(worker_sem_addr)); + noc_inline_dw_write(dest_addr + sizeof(uint32_t), ttnn::ccl::WorkerXY(my_x[0], my_y[0]).to_uint32()); + + dest_addr &= ~0x0000000FFFFFFFFFl; + dest_addr |= edm_connection_handshake_addr; + noc_inline_dw_write(dest_addr, open_connection_value); + noc_async_read_barrier(); + } + + FORCE_INLINE void close() { + auto dest_addr = this->edm_semaphore_addr; + static constexpr uint32_t terminate_connection_value = 0; + // remove the address portion to replace with the connection terminate address + dest_addr &= ~0x0000000FFFFFFFFFl; + dest_addr |= edm_connection_handshake_addr; + noc_inline_dw_write(dest_addr, terminate_connection_value); + + // buffer index stored at location after handshake addr + dest_addr &= ~0x0000000FFFFFFFFFl; + dest_addr |= edm_connection_handshake_addr + edm_sender_channel_field_stride_bytes; + noc_inline_dw_write(dest_addr, *this->buffer_index_ptr); + noc_async_write_barrier(); + } + + uint64_t edm_buffer_addr; + uint64_t edm_semaphore_addr; + size_t edm_connection_handshake_addr; + size_t edm_worker_location_info_addr; + volatile uint32_t * const worker_sem_addr; + std::size_t edm_buffer_base_addr; + std::size_t num_buffers_per_channel; + std::size_t last_buffer_index; + std::size_t edm_l1_sem_addr; + std::size_t buffer_size_bytes; + std::size_t *buffer_index_ptr; + + private: + template + FORCE_INLINE void send_payload_from_address_impl(uint32_t source_address, size_t size_bytes) { + this->clear_flow_control_semaphore(); + uint64_t buffer_address = this->edm_buffer_addr + (*this->buffer_index_ptr * (this->buffer_size_bytes + sizeof(eth_channel_sync_t))); + + ASSERT(size_bytes <= this->buffer_size_bytes); + ASSERT(static_cast(buffer_address & 0x0FFFFFFF) <= 270000); + + /*{ // For debug purposes only. Useful to permanently backup the packet somewhere we can inspect with ttx-status + uint32_t dram_noc_x = my_y[0] == 1 ? 0 : 0; + uint32_t dram_noc_y = my_y[0] == 1 ? 
0 : 5; + // noc_inline_dw_write(get_noc_addr(dram_noc_x, dram_noc_y, storage_offset), 0x0F); + // noc_async_writes_flushed(); + // noc_inline_dw_write(get_noc_addr(dram_noc_x, dram_noc_y, storage_offset + 4), 0); + // auto pkthdr_size_words = sizeof(tt::fabric::PacketHeader) >> 2; + // for (size_t i = 0; i < pkthdr_size_words; i++) { + // reinterpret_cast(source_address)[pkthdr_size_words - i] = + // reinterpret_cast(source_address)[pkthdr_size_words - 1 - i]; + // } + // reinterpret_cast(source_address)[0] = 0xc0ffee; + // DPRINT << "NEXT STORAGE OFF: " << (uint32_t)storage_offset << "\n"; + noc_async_write(source_address, get_noc_addr(dram_noc_x, dram_noc_y, storage_offset), size_bytes); + storage_offset += size_bytes; + storage_offset += 64; + storage_offset = storage_offset & (~0x1F); + }*/ + + send_chunk_from_address(source_address, 1, size_bytes, buffer_address); + noc_semaphore_inc(edm_semaphore_addr, 1); + + *this->buffer_index_ptr = (*this->buffer_index_ptr == this->last_buffer_index) ? 0 : *this->buffer_index_ptr + 1; + } + + template + FORCE_INLINE void send_payload_impl(uint32_t cb_id, uint32_t num_pages, uint32_t page_size) { + this->clear_flow_control_semaphore(); + uint64_t buffer_address = this->edm_buffer_addr + (*this->buffer_index_ptr * (this->buffer_size_bytes + sizeof(eth_channel_sync_t))); + ASSERT(num_pages * page_size <= this->buffer_size_bytes); + send_chunk(cb_id, num_pages, page_size, buffer_address); + noc_semaphore_inc(edm_semaphore_addr, 1); + *this->buffer_index_ptr = (*this->buffer_index_ptr == this->last_buffer_index) ? 0 : *this->buffer_index_ptr + 1; + } +}; + + +} // namespace tt::fabric diff --git a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp new file mode 100644 index 000000000000..37210c2d0128 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp @@ -0,0 +1,214 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include <cstddef> +#include <cstdint> + +namespace tt::fabric { + +enum TerminationSignal : uint32_t { + KEEP_RUNNING = 0, + + // Wait for messages to drain + GRACEFULLY_TERMINATE = 1, + + // Immediately terminate - don't wait for any outstanding messages to arrive or drain out + IMMEDIATELY_TERMINATE = 2 +}; + +// 2 bits +enum CommandType : uint8_t { + WRITE = 0, + ATOMIC_INC = 1 +}; + +// How to send the payload across the cluster +// 1 bit +enum ChipSendType : uint8_t { + CHIP_UNICAST = 0, + CHIP_MULTICAST = 1 +}; +enum NocSendType : uint8_t { + NOC_UNICAST = 0, + NOC_MULTICAST = 1 +}; + + +struct UnicastRoutingCommandHeader { + uint8_t distance_in_hops; +}; +static_assert(sizeof(UnicastRoutingCommandHeader) == 1, "UnicastRoutingCommandHeader size is not 1 byte"); +struct MulticastRoutingCommandHeader { + uint8_t start_distance_in_hops: 4; + uint8_t range_hops: 4; // 0 implies unicast +}; +static_assert(sizeof(MulticastRoutingCommandHeader) == 1, "MulticastRoutingCommandHeader size is not 1 byte"); +union RoutingFields { + UnicastRoutingCommandHeader chip_unicast; + MulticastRoutingCommandHeader chip_mcast; +}; +static_assert(sizeof(RoutingFields) == sizeof(UnicastRoutingCommandHeader), "RoutingFields size is not 1 byte"); + +struct NocUnicastCommandHeader { + uint32_t address; + uint32_t size; + uint8_t noc_x; + uint8_t noc_y; + uint16_t reserved; + // ignores header size + inline uint32_t get_payload_only_size() const { + return size; + } +}; +struct NocUnicastAtomicIncCommandHeader { + NocUnicastAtomicIncCommandHeader(uint32_t address, uint16_t val, uint16_t wrap, uint8_t noc_x, uint8_t noc_y) + : address(address), val(val), wrap(wrap), noc_x(noc_x), noc_y(noc_y) {} + + uint32_t address; + uint16_t val; + uint16_t wrap; + uint8_t noc_x; + uint8_t noc_y; + +}; +struct NocMulticastCommandHeader { + uint32_t address; + uint32_t size; + uint8_t noc_x_start; + uint8_t noc_y_start; + uint8_t mcast_rect_size_x; + uint8_t mcast_rect_size_y; + + // ignores header size + inline uint32_t get_payload_only_size() const { + return size; + } +}; +struct NocMulticastAtomicIncCommandHeader { + uint32_t address; + uint16_t val; + uint16_t wrap; + uint8_t noc_x_start; + uint8_t noc_y_start; + uint8_t size_x; + uint8_t size_y; +}; +static_assert(sizeof(NocUnicastCommandHeader) == 12, "NocUnicastCommandHeader size is not 12 bytes"); +static_assert(sizeof(NocMulticastCommandHeader) == 12, "NocMulticastCommandHeader size is not 12 bytes"); +static_assert(sizeof(NocUnicastAtomicIncCommandHeader) == 12, "NocUnicastAtomicIncCommandHeader size is not 12 bytes"); +static_assert(sizeof(NocMulticastAtomicIncCommandHeader) == 12, "NocMulticastAtomicIncCommandHeader size is not 12 bytes"); +union CommandFields { + NocUnicastCommandHeader unicast_write; + NocMulticastCommandHeader mcast_write; + NocUnicastAtomicIncCommandHeader unicast_seminc; + NocMulticastAtomicIncCommandHeader mcast_seminc; +}; +static_assert(sizeof(CommandFields) <= 15, "CommandFields size exceeds 15 bytes"); + +// TODO: wrap this in a debug version that holds type info so we can assert for field/command/ +struct PacketHeader { + // TODO: trim this down noc_send_type 2 bits (4 values): + // -> unicast_write, mcast_write, unicast_seminc, mcast_seminc + // For now, kept it separate so I could do reads which would be handled differently + // but for our purposes we shouldn't need read so we should be able to omit the support + CommandType command_type : 2; + ChipSendType chip_send_type : 1; + NocSendType noc_send_type : 1;
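+    // the remaining 4 bits of this leading byte are reserved (currently unused) +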
uint8_t reserved : 4; + + RoutingFields routing_fields; + uint16_t reserved2; + CommandFields command_fields; + + // Sort of hack to work-around DRAM read alignment issues that must be 32B aligned + // To simplify worker kernel code, we for now decide to pad up the packet header + // to 32B so the user can simplify shift into their CB chunk by sizeof(tt::fabric::PacketHeader) + // and automatically work around the DRAM read alignment bug. + // + // Future changes will remove this padding and require the worker kernel to be aware of this bug + // and pad their own CBs conditionally when reading from DRAM. It'll be up to the users to + // manage this complexity. + uint32_t padding0; + uint32_t padding1; + uint32_t padding2; + uint32_t padding3; + + inline void set_command_type(CommandType &type) { this->command_type = type; } + inline void set_chip_send_type(ChipSendType &type) { this->chip_send_type = type; } + inline void set_noc_send_type(NocSendType &type) { this->noc_send_type = type; } + inline void set_routing_fields(RoutingFields &fields) { this->routing_fields = fields; } + inline void set_command_fields(CommandFields &fields) { this->command_fields = fields; } + + size_t get_payload_size_excluding_header() volatile const { + switch(this->command_type) { + case WRITE: { + switch(this->noc_send_type) { + case NOC_UNICAST: { + return this->command_fields.unicast_write.size - sizeof(PacketHeader); + } break; + case NOC_MULTICAST: { + return this->command_fields.mcast_write.size - sizeof(PacketHeader); + } break; + default: + return 0; + } + } break; + case ATOMIC_INC: { + return 0; + } break; + default: + return 0; + } + } + inline size_t get_payload_size_including_header() volatile const { + return get_payload_size_excluding_header() + sizeof(PacketHeader); + } + + inline PacketHeader& to_write() { this->command_type = WRITE; return *this; } + inline PacketHeader& to_atomic_inc() { this->command_type = ATOMIC_INC; return *this; } + + inline PacketHeader &to_chip_unicast(UnicastRoutingCommandHeader const &chip_unicast_command_header) { + this->chip_send_type = CHIP_UNICAST; + this->routing_fields.chip_unicast = chip_unicast_command_header; + return *this; + } + inline PacketHeader &to_chip_multicast(MulticastRoutingCommandHeader const &chip_multicast_command_header) { + this->chip_send_type = CHIP_MULTICAST; + this->routing_fields.chip_mcast = chip_multicast_command_header; + return *this; + } + inline PacketHeader &to_noc_unicast(NocUnicastCommandHeader const &noc_unicast_command_header) { + this->noc_send_type = NOC_UNICAST; + this->command_fields.unicast_write = noc_unicast_command_header; + return *this; + } + inline PacketHeader &to_noc_multicast(NocMulticastCommandHeader const &noc_multicast_command_header) { + this->noc_send_type = NOC_MULTICAST; + this->command_fields.mcast_write = noc_multicast_command_header; + return *this; + } + inline PacketHeader &to_noc_unicast_atomic_inc( + NocUnicastAtomicIncCommandHeader const &noc_unicast_atomic_inc_command_header) { + this->noc_send_type = NOC_UNICAST; + this->command_fields.unicast_seminc = noc_unicast_atomic_inc_command_header; + return *this; + } + inline PacketHeader &to_noc_multicast_atomic_inc( + NocMulticastAtomicIncCommandHeader const &noc_multicast_atomic_inc_command_header) { + this->noc_send_type = NOC_MULTICAST; + this->command_fields.mcast_seminc = noc_multicast_atomic_inc_command_header; + return *this; + } +}; + + +// TODO: When we remove the 32B padding requirement, reduce to 16B size check 
+static_assert(sizeof(PacketHeader) == 32, "sizeof(PacketHeader) is not equal to 32B"); + +static constexpr size_t header_size_bytes = sizeof(PacketHeader); + + +} // namespace tt::fabric diff --git a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header_validate.hpp b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header_validate.hpp new file mode 100644 index 000000000000..22267eb2bdba --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header_validate.hpp @@ -0,0 +1,18 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp" +#include "debug/assert.h" + +namespace tt::fabric { + +FORCE_INLINE void validate(PacketHeader const& packet_header) { + ASSERT(packet_header.command_type < 2); + ASSERT(packet_header.chip_send_type < 2); + ASSERT(packet_header.noc_send_type < 2); +} + +} // namespace tt::fabric diff --git a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_transmission.hpp b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_transmission.hpp new file mode 100644 index 000000000000..9e6ba23c4b1c --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_transmission.hpp @@ -0,0 +1,228 @@ + +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "tt_metal/hw/inc/dataflow_api.h" +#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_types.hpp" +#include + +void write_unicast_blocking(uint32_t local_address, uint64_t dest_address, uint32_t size_bytes) { + noc_async_write(local_address, dest_address, size_bytes); + noc_async_writes_flushed(); +} + +void print_pkt_hdr_routing_fields(volatile tt::fabric::PacketHeader *const packet_start) { + switch (packet_start->chip_send_type) { + case tt::fabric::CHIP_UNICAST: { + DPRINT << "C_UNI: dist:" << (uint32_t) packet_start->routing_fields.chip_unicast.distance_in_hops << "\n"; + break; + } + case tt::fabric::CHIP_MULTICAST: { + DPRINT << "C_MCST: dist:" << (uint32_t) packet_start->routing_fields.chip_mcast.start_distance_in_hops << + ", rng:" << (uint32_t) packet_start->routing_fields.chip_mcast.range_hops << "\n"; + break; + } + }; +} + +void print_pkt_header_noc_fields(volatile tt::fabric::PacketHeader *const packet_start) { + switch (packet_start->noc_send_type) { + case tt::fabric::NocSendType::NOC_UNICAST: { + switch (packet_start->command_type) { + case tt::fabric::CommandType::WRITE: { + DPRINT << "N_WR addr:"<<(uint32_t)packet_start->command_fields.unicast_write.address << + ", size:" << (uint32_t) packet_start->command_fields.unicast_write.size << + ", x:" << (uint32_t) packet_start->command_fields.unicast_write.noc_x << + ", y:" << (uint32_t) packet_start->command_fields.unicast_write.noc_y << "\n"; + } break; + case tt::fabric::CommandType::ATOMIC_INC: { + DPRINT << "N_WR addr:"<<(uint32_t)packet_start->command_fields.unicast_seminc.address << + ", val:" << (uint32_t) packet_start->command_fields.unicast_seminc.val << + ", x:" << (uint32_t) packet_start->command_fields.unicast_seminc.noc_x << + ", y:" << (uint32_t) packet_start->command_fields.unicast_seminc.noc_y << "\n"; + + } break; + } + break; + 
} + case tt::fabric::NocSendType::NOC_MULTICAST: { + break; + } + } +} + +void print_pkt_header(volatile tt::fabric::PacketHeader *const packet_start) { + auto const& header = *packet_start; + DPRINT << "PKT: cmd_t:" << (uint32_t) packet_start->command_type << + ", csnd_t:" << (uint32_t) packet_start->chip_send_type << + ", nsnd_t:" << (uint32_t) packet_start->noc_send_type << "\n"; + print_pkt_hdr_routing_fields(packet_start); + print_pkt_header_noc_fields(packet_start); +} + + +// Since we unicast to local, we must omit the packet header +void execute_chip_unicast_to_local_chip(volatile tt::fabric::PacketHeader *const packet_start) { + auto const& header = *packet_start; + uint32_t payload_start_address = reinterpret_cast(packet_start) + sizeof(tt::fabric::PacketHeader); + + tt::fabric::CommandType command_type = packet_start->command_type; + tt::fabric::NocSendType noc_send_type = packet_start->noc_send_type; + switch (command_type) { + case tt::fabric::CommandType::WRITE: { + switch (noc_send_type) { + case tt::fabric::NocSendType::NOC_UNICAST: { + auto const dest_address = get_noc_addr( + header.command_fields.unicast_write.noc_x, + header.command_fields.unicast_write.noc_y, + header.command_fields.unicast_write.address); + auto const size = header.command_fields.unicast_write.size - sizeof(tt::fabric::PacketHeader); + write_unicast_blocking(payload_start_address, dest_address, size); + + }break; + case tt::fabric::NocSendType::NOC_MULTICAST: { + // TODO: confirm if we need to adjust dest core count if we span eth or dram cores + auto const mcast_dest_address = get_noc_multicast_addr( + header.command_fields.mcast_write.noc_x_start, + header.command_fields.mcast_write.noc_y_start, + header.command_fields.mcast_write.noc_x_start + header.command_fields.mcast_write.mcast_rect_size_x, + header.command_fields.mcast_write.noc_y_start + header.command_fields.mcast_write.mcast_rect_size_y, + header.command_fields.mcast_write.address); + auto const num_dests = header.command_fields.mcast_write.mcast_rect_size_x * header.command_fields.mcast_write.mcast_rect_size_y; + auto const size = header.command_fields.mcast_write.size - sizeof(tt::fabric::PacketHeader); + noc_async_write_multicast_one_packet(payload_start_address, mcast_dest_address, size, num_dests); + noc_async_writes_flushed(); + + }break; + default: { + ASSERT(false); + } + } + break; + } + case tt::fabric::CommandType::ATOMIC_INC: { + switch (noc_send_type) { + case tt::fabric::NocSendType::NOC_UNICAST: { + auto const dest_address = get_noc_addr( + header.command_fields.unicast_seminc.noc_x, + header.command_fields.unicast_seminc.noc_y, + header.command_fields.unicast_seminc.address); + auto const increment = header.command_fields.unicast_seminc.val; + noc_semaphore_inc(dest_address, increment); + + }break; + case tt::fabric::NocSendType::NOC_MULTICAST: { + ASSERT(false); + // noc_async_write(payload_start_address, header.dest_address, header.size_bytes); + + }break; + default: { + ASSERT(false); + } + } + break; + + }; + + default: { + ASSERT(false); + } + }; +} + + + +void update_packet_header_for_next_hop(volatile tt::fabric::PacketHeader * packet_header) { + switch (packet_header->chip_send_type) { + case tt::fabric::CHIP_UNICAST: { + packet_header->routing_fields.chip_unicast.distance_in_hops--; + } break; + case tt::fabric::CHIP_MULTICAST: { + if (packet_header->routing_fields.chip_mcast.start_distance_in_hops == 0) { + packet_header->routing_fields.chip_mcast.range_hops--; + } else { + 
packet_header->routing_fields.chip_mcast.start_distance_in_hops--; + } + } break; + } +} + +// This function forwards a packet to the downstream EDM channel for eventual sending +// to the next chip in the line/ring +// +// Modifies the packet header (decrements hop counts), so it must only be called once the +// consume-locally / forward decision has already been made for this packet. +// +// !!!WARNING!!! +// !!!WARNING!!! do NOT call before determining if the packet should be consumed locally or forwarded +// !!!WARNING!!! +tt::fabric::SendStatus forward_payload_to_downstream_edm( + volatile tt::fabric::PacketHeader *packet_header, + tt::fabric::WorkerToFabricEdmSender &downstream_edm_interface + ) { + // SHOULD BE ABLE TO ASSERT ON THIS SINCE WE CHECK FOR THIS IN THE CALLER + // TODO: PERF + bool safe_to_send = downstream_edm_interface.consumer_has_space(); + if (!safe_to_send) { + return tt::fabric::SendStatus::NOT_SENT; + } + + // print_pkt_header(packet_header); + ASSERT(const_cast<tt::fabric::PacketHeader*>(packet_header)->get_payload_size_including_header() < 100000); + update_packet_header_for_next_hop(packet_header); + + downstream_edm_interface.send_payload_blocking_from_address( + reinterpret_cast<uint32_t>(packet_header), + const_cast<tt::fabric::PacketHeader*>(packet_header)->get_payload_size_including_header()); + + return tt::fabric::SendStatus::SENT_PAYLOAD_AND_SYNC; +} + +void execute_chip_multicast_to_local_chip(volatile tt::fabric::PacketHeader *const packet_start) { + ASSERT(false); +} + +bool packet_must_be_consumed_locally(tt::fabric::PacketHeader const& packet_header) { + switch (packet_header.chip_send_type) { + case tt::fabric::ChipSendType::CHIP_UNICAST: { + // TODO: does it make more sense to have 0 as the terminating distance or 1? + // depends where we want to do the decrement and what the starting value + // is expected to be for worker + // Maybe at API level we just always decrement by 1 under the hood + // so user can call `fabric_send_packet(payload_addr, size, n_hops=1)` + return packet_header.routing_fields.chip_unicast.distance_in_hops == 0; + } + case tt::fabric::ChipSendType::CHIP_MULTICAST: { + return packet_header.routing_fields.chip_mcast.start_distance_in_hops == 0; + } + default: { + ASSERT(false); + return false; + } + } +} + + +bool packet_must_be_forwarded_to_next_chip(tt::fabric::PacketHeader const& packet_header) { + switch (packet_header.chip_send_type) { + case tt::fabric::ChipSendType::CHIP_UNICAST: { + // TODO: does it make more sense to have 0 as the terminating distance or 1? + // depends where we want to do the decrement and what the starting value + // is expected to be for worker + // Maybe at API level we just always decrement by 1 under the hood + // so user can call `fabric_send_packet(payload_addr, size, n_hops=1)` + return packet_header.routing_fields.chip_unicast.distance_in_hops != 0; + } + case tt::fabric::ChipSendType::CHIP_MULTICAST: { + return packet_header.routing_fields.chip_mcast.range_hops != 0; + } + default: { + ASSERT(false); + return false; + } + } +}
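The two predicates above are meant to be checked independently for each packet, since a multicast packet inside its range is both written locally and forwarded. A receiver-side sketch of that dispatch, assuming a connected downstream adapter (the real loop lives in `fabric_erisc_datamover.cpp` below; names are illustrative):

    // Sketch: per-packet receiver decision. The branches are not mutually
    // exclusive: an mcast packet inside its range is consumed AND forwarded.
    void receiver_dispatch_packet(
        volatile tt::fabric::PacketHeader* header,
        tt::fabric::WorkerToFabricEdmSender& downstream_edm) {
        auto const& h = *const_cast<tt::fabric::PacketHeader*>(header);
        if (packet_must_be_consumed_locally(h)) {
            // (mcast local writes would go through
            //  execute_chip_multicast_to_local_chip once that path is brought up)
            execute_chip_unicast_to_local_chip(header);
        }
        if (packet_must_be_forwarded_to_next_chip(h)) {
            // forward_payload_to_downstream_edm only mutates the header once the
            // downstream channel has space, so retrying on NOT_SENT is safe.
            while (forward_payload_to_downstream_edm(header, downstream_edm) ==
                   tt::fabric::SendStatus::NOT_SENT) {
            }
        }
    }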
diff --git a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_types.hpp b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_types.hpp new file mode 100644 index 000000000000..2366c8758de9 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_types.hpp @@ -0,0 +1,56 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" +#include <cstdint> + +namespace tt::fabric { +enum BlockingMode: uint8_t { + // will busy-wait (spin) without context switching + BUSY_WAIT_BLOCKING, + + // will wait and allow context switching + CTX_SWITCH_BLOCKING, + + // function will exit early if not able to send + NON_BLOCKING +}; + +enum SendStatus : uint8_t { + // Indicates that the sender was able to send the payload + // but was not able to send the channel_sync_t at the end of the + // buffer + // + // This enum should only ever be returned if we are sending less than + // a full packet/buffer of data AND when we are trying to send the + // channel_sync_t at the end of the buffer (which must be as a separate + // command) but the eth_tx_cmd_q is busy for that second message + // + // Receiving this value indicates we + // MUST: + // - Eventually send the channel_sync_t before advancing to the next buffer + // MUST NOT: + // - Advance to the next buffer index + // - Forward the other sender channel's data (if it has any) + SENT_PAYLOAD_ONLY, + + // Indicates both the payload and the channel sync were sent successfully + SENT_PAYLOAD_AND_SYNC, + + // Indicates no data was sent because the eth_tx_cmd_q was busy + NOT_SENT, + + ERROR, +}; + +struct EDMChannelWorkerLocationInfo { + uint32_t worker_semaphore_address; + ttnn::ccl::WorkerXY worker_xy; +}; + +static_assert(sizeof(EDMChannelWorkerLocationInfo) <= 16); + +} // namespace tt::fabric diff --git a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover.cpp b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover.cpp new file mode 100644 index 000000000000..d105e2bf6d08 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover.cpp @@ -0,0 +1,881 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include <array> +#include <cstddef> +#include <cstdint> + +#include "dataflow_api.h" +#include "tt_metal/hw/inc/ethernet/dataflow_api.h" +#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm/edm_handshake.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header_validate.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_transmission.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover_channels.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" + +using ttnn::ccl::WorkerXY; + +/* + +The fabric Erisc Data Mover (EDM) is a component that can be used to build *very* simple linear topology fabrics. +One of these EDMs can be instantiated on each ethernet link. It is built from 3 "channels" (though the definition +of channel here is a little loose since two of the 3 will merge traffic, so this setup could be interpreted as a +two channel setup). This EDM implements packet-based transfers only - concepts like sockets are not supported. + +## EDM Structure + +There are two sender channels and one receiver channel. "Sender" and "receiver" are relative to the Ethernet link, +not the chip. Sender sends over the link and receiver receives from the link.
+ +Each sender channel serves a different purpose: +- Sender channel 0 : Accepts packets from a workers on the local chip +- Sender channel 1: accepts packets from an upstream EDM (i.e. an upstream + EDM receiver channel on the same chip but different core) + +The receiver channel accepts packets from the Ethernet link and can do one (or both) of: +- Write the packet to local chhip if it is the intended destination (unicast or mcast) +- Forward the packet to the next chip in the line if: + - Unicast and not the target chip + - Multicast and this chip is in the multicast target range + +Sender channels will merge traffic into the remote EDM's receiver channel. + +Below is a diagram that shows how EDMs can be connected over an ethernet link. In this case, the two +EDM kernels are run on separate, but connected ethernet link cores. + + ┌───────────────────────┐ ┌───────────────────────┐ + │ Sender Channel 0 │ │ Receiver Channel │ + │ ┌────────────────┐ │ │ ┌────────────────┐ │ + │ │ ┼──┼───┬───────┼───► │ │ + │ │ │ │ │ │ │ │ │ + │ └────────────────┘ │ │ │ └────────────────┘ │ + │ Sender Channel 1 │ │ │ Sender Channel 1 │ + │ ┌────────────────┐ │ │ │ ┌────────────────┐ │ + │ │ ┼──┼───┘ │ │ │ │ + │ │ │ │ ┌─┼───┼ │ │ + │ └────────────────┘ │ │ │ └────────────────┘ │ + │ Receiver Channel │ │ │ Sender Channel 0 │ + │ ┌────────────────┐ │ │ │ ┌────────────────┐ │ + │ │ │ │ │ │ │ │ │ + │ │ ◄──┼─────────┴─┼───┼ │ │ + │ └────────────────┘ │ │ └────────────────┘ │ + │ │ │ │ + │ │ │ │ + └───────────────────────┘ └───────────────────────┘ + + +## Building a "Fabric" + +At present, only linear topologies are supported, and one per ethernet link along that given line. +Below shows the intended connectivity of EDMs across chips in a hypothetical 3-chip fabric. For longer +lines, the pattern would be extended. + + CHIP 0 CHIP 1 CHIP 2 + ┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ + │ │ │ │ │ │ +┌────┴─────┐ ▲ ┌─────┴────┐ ┌────┴─────┐ ▲ ┌─────┴────┐ ┌────┴─────┐ ▲ ┌─────┴────┐ +│ EDM │ │ │ EDM │ │ EDM │ │ │ EDM │ │ EDM │ │ │ EDM │ +│ ┌──────┐ │ │ │ ┌──────┐ │ │ ┌──────┐ │ │ │ ┌──────┐ │ │ ┌──────┐ │ │ │ ┌──────┐ │ +│ │ Rx ┼─┼─┴───┼─► S1 ┼─┼─┬────┼─► Rx ┼─┼─┴───┼─► S1 ┼─┼┬─────┼─► Rx ┼─┼─┘ | | S1 │ │ +│ └──────┘ │ │ └──────┘ │ │ │ └──────┘ │ │ └──────┘ ││ │ └──────┘ │ │ └──────┘ │ +│ ┌──────┐ │ │ ┌──────┐ │ │ │ ┌──────┐ │ │ ┌──────┐ ││ │ ┌──────┐ │ │ ┌──────┐ │ +│ │ S0 ◄─┼──┬──┼─► S0 ┼─┼─┘ ┌┼─┼ S0 ◄─┼──┬──┼─► S0 ┼─┼┘ ┌┼─┼ S0 ◄─┼──┬──┼─► S0 │ │ +│ └──────┘ │ │ │ └──────┘ │ ││ └──────┘ │ │ │ └──────┘ │ ││ └──────┘ │ │ │ └──────┘ │ +│ ┌──────┐ │ │ │ ┌──────┐ │ ││ ┌──────┐ │ │ │ ┌──────┐ │ ││ ┌──────┐ │ │ │ ┌──────┐ │ +│ │ S1 | | │ ┌┼─┼ Rx ◄─┼─────┴┼─┼ S1 ◄─┼─┐│ ┌┼─┼ Rx ◄─┼─────┴┼─┼ S1 ◄─┼─┐│ ┌┼─┼ Rx │ │ +│ └──────┘ │ | |│ └──────┘ │ │ └──────┘ │ └┼─┤│ └──────┘ │ │ └──────┘ │ └┼─┤│ └──────┘ │ +└────┬─────┘ │ │└─────┬────┘ └────┬─────┘ │ │└─────┬────┘ └────┬─────┘ │ │└─────┬────┘ + │ ▼ │ │ ▼ │ │ ▼ │ + └─────────────────┘ └─────────────────┘ └─────────────────┘ + + +## Connecting Workers to Channels + +As mentioned, only one worker can push to a given EDM sender channel at a time. In order to send to an EDM +sender channel, the worker must establish a connection. The connection protocol is as follows and is started +by the worker (the EDM is a slave in this protocol). + +*NOTE*: If multiple workers try to connect to the same EDM sender channel at the same time, the behavior is undefined. +*NOTE*: Additionally, if a worker pushes packets to a channel it isn't connected to, behaviour is undefined. 
+*NOTE*: Undefined == likely hang + +The `WorkerToFabricEdmSender` from `ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp` +provides an implementation of the connection protocol. `WorkerToFabricEdmSender` also acts as a wrapper around that +protocol so workers can simply call `open()` to execute the connection protocol without having to manually reimplement +for each kernel. + +### Protocol +Worker: +- Read from EDM sender channel buffer_index address + - Required so that the worker knows where to write its first packet (since the channel may already contain packets from + a previous connection) +- Write worker core X/Y (NOC 0 based) +- Write worker flow control semaphore L1 address + +EDM Sender Channel: +- Check local connection valid semaphore for new established connection + - When the connection semaphore indicates an active connection, the channel assumes all other relevant fields were + correctly populated by the worker: + - Worker core_x (on NOC 0) + - Worker core_y (on NOC 0) + - Worker flow control semaphore L1 address + + +## Tearing Down Connections + +Every worker is required to explicitly teardown its connection with the EDM before terminating. To do this, the worker +must simply write a `0` to the EDM sender channel's connection semaphore address. As long as the worker has sent all +of its packets to the EDM before this, then the EDM will guarantee to forward the messages correctly. + +At this point, it is safe for another kernel to establish a connection. + +## Packet Structure + +Workers are responsible for populating packet headers before sending to the EDM. The packet header structure is defined +in `ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp`. + +## Channel structure + +Each EDM channel is built from one or more buffers. Each buffer is the same size and can hold atmost one packet. +Neighbouring packets occupy nehighouring buffers - with the exception of the last buffer index. The next packet after a write +into the last buffer index will wrap around to the first buffer index. Even if packets do not occupy the full buffer, subsequent +packets will always be written into the next logical buffer. A gap will exist in memory but the EDM will not send that padded data +(unless it is more performant - which is possible in some special cases) + + Example channel with 8 buffers +┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐ +│ │ │ │ │ │ │ │ │ +│ │ │ │ │ │ │ │ │ +└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘ + buf 0 buf 1 buf 2 buf 3 buf 4 buf 5 buf 6 buf 7 + + +Here we have an example of a channel with 4 buffers, filled with some number of packets. Each packet is a different size. +Packets 0, 2, and 3 are smaller than the full buffer size, while packet 1 is the full buffer size. + +┌───────────────┬───────────────┬───────────────┬───────────────┐ +│H|Payload| / / │H|Payload │H|Pyld| / / / /│H|Payload |/ /│ +│ | |/ / /│ | │ | |/ / / / │ | | / │ +└───────────────┴───────────────┴───────────────┴───────────────┘ + buf 0 buf 1 buf 2 buf 3 + + +A detail of the channel structure is omitted from the above diagram, namely the EDM <-> EDM flow control region for each buffer. 
+Each buffer really looks something like this:
+
+
+             &header->  |----------------| channel_base_address
+                        |    header      |
+            &payload->  |----------------|
+                        |                |
+                        |    payload     |
+                        |                |
+       &channel_sync->  |----------------|
+                        |  channel_sync  |  // This is new
+                        ------------------
+
+The "channel_sync" is an `eth_channel_sync_t` and is internal to the EDM implementation. It is used to indicate packet
+transmission state between the sender and receiver EDMs.
+
+The protocol for its use is:
+1) Sender updates the fields indicating new data:
+   - set `bytes_sent` to a non-zero value indicating new data
+   - clear `receiver_ack` to 0
+   - set `src_id` to the sender channel id so the receiver knows who the sender was (and where the ack should go)
+2) Sender sends this channel sync to the corresponding location in the receiver channel (either in the same transmission
+   as the packet or separately)
+3) Receiver sees that `bytes_sent` is non-zero, indicating a new packet. It sends back an acknowledgement (first level):
+   - set `receiver_ack` to non-zero
+   *NOTE* IMPORTANT: To avoid a race, the receiver must be sure to send this first level ack from a different address
+   than the one it uses for the second level acknowledgement
+   3b) When the sender receives an ack, it knows it can overwrite its local copy of the packet with new data
+4) After the receiver properly writes out its packet, it sends a second level acknowledgement, indicating it can receive new
+   data into this specific buffer index:
+   - clear the `bytes_sent` and `receiver_ack` fields and send the `channel_sync` back to the sender
+
+
+
+## Sending Packets
+Sending a packet is done as follows:
+
+1) Worker waits for a flow control semaphore increment from the EDM sender channel
+   - Indicates there is space at the next buffer index for a packet
+2) Worker performs a noc write of its packet to the EDM sender channel at the buffer index
+
+*NOTE*: !!!ALL PACKETS MUST CONTAIN DESTINATION NOC X/Y AS NOC 0 COORDINATES, REGARDLESS OF THE `noc_index` OF THE SENDER!!!
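+
+To make the flow above concrete, below is a minimal sketch of a worker-side send loop built on
+`WorkerToFabricEdmSender`. It is illustrative only: `open()` and the flow control semaphore are described
+above, but `wait_for_empty_write_slot()`, `send_packet()`, and `close()` are hypothetical stand-ins for
+whatever helpers the worker kernel actually uses to wait for space and issue its noc writes.
+
+    // Hypothetical worker-side sketch (not the adapter's literal API)
+    auto sender = WorkerToFabricEdmSender(...);  // built from runtime args
+    sender.open();                               // run the connection protocol described above
+    for (size_t i = 0; i < num_packets; i++) {
+        // 1) wait for the EDM's flow control semaphore to indicate a free buffer index
+        wait_for_empty_write_slot(sender);
+        // 2) the packet header must carry NOC-0 based destination x/y; then noc-write
+        //    header + payload into the channel's current buffer slot
+        send_packet(sender, packet_l1_address, packet_size_bytes);
+    }
+    sender.close();  // teardown: writes 0 to the EDM sender channel's connection semaphore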
+
+*/
+
+////////////////////////////////////////////////
+// Data structures, types, enums, and constants
+////////////////////////////////////////////////
+
+enum SenderState : uint8_t {
+    SENDER_DONE = 0,
+
+    // we are ready to tell the worker(s) that the buffer is available for writing into
+    SENDER_SIGNALING_WORKER,
+
+    // we are waiting for the payload to arrive in L1; we are checking local semaphore for worker
+    // completion
+    SENDER_WAITING_FOR_WORKER,
+
+    // this state is entered if the sender was able to send the payload but not the channel sync
+    SENDER_SEND_CHANNEL_SYNC,
+
+    // Sender channel is not connected to a worker and is waiting for a new connection
+    SENDER_WAIT_WORKER_HANDSHAKE,
+
+    // means we are waiting for ack from receiver that payload was received
+    SENDER_WAITING_FOR_ETH,
+
+};
+
+enum ReceiverState : uint8_t {
+    RECEIVER_DONE = 0,
+
+    // Receiver is processing the packet, either writing it locally or forwarding to the next EDM
+    // (toward next chip), or both
+    RECEIVER_SENDING_PAYLOAD,
+
+    // Enter this state after performing writes of the current packet as a sort of soft barrier
+    // (for this channel only) so we can make progress on other channels while waiting for the
+    // writes to flush
+    RECEIVER_WAITING_FOR_WRITE_FLUSH,
+
+    // means we are waiting for a payload from sender
+    RECEIVER_WAITING_FOR_ETH,
+};
+
+
+enum PacketLocalForwardType : uint8_t {
+    PACKET_FORWARD_INVALID = 0x0,
+    PACKET_FORWARD_LOCAL_ONLY = 0x1,
+    PACKET_FORWARD_REMOTE_ONLY = 0x2,
+    PACKET_FORWARD_LOCAL_AND_REMOTE = 0x3
+};
+
+static constexpr uint32_t SWITCH_INTERVAL = 4000000;
+static constexpr size_t ETH_BYTES_TO_WORDS_SHIFT = 4;
+static constexpr size_t NUM_SENDER_CHANNELS = 2;
+static constexpr size_t num_workers_ctor = 1;
+static constexpr size_t num_messages_to_move_ctor_value = 1;
+// Doesn't REALLY matter but for consistency I picked the next available ID
+static constexpr size_t receiver_channel_id = NUM_SENDER_CHANNELS;
+static constexpr size_t worker_info_offset_past_connection_semaphore = 32;
+
+/////////////////////////////////////////////
+// SENDER SIDE HELPERS
+/////////////////////////////////////////////
+
+FORCE_INLINE void sender_notify_workers_if_buffer_available_sequence(
+    tt::fabric::EdmChannelWorkerInterface &local_sender_worker_interface) {
+    local_sender_worker_interface.clear_local_semaphore();
+    local_sender_worker_interface.increment_worker_semaphore();
+}
+
+template <size_t SENDER_NUM_BUFFERS, size_t RECEIVER_NUM_BUFFERS>
+void send_channel_sync(
+    tt::fabric::EthChannelBuffer<SENDER_NUM_BUFFERS> &sender_buffer_channel,
+    tt::fabric::EthChannelBuffer<RECEIVER_NUM_BUFFERS> &receiver_buffer_channel) {
+
+    eth_send_bytes_over_channel_payload_only_unsafe(
+        reinterpret_cast<size_t>(sender_buffer_channel.get_current_bytes_sent_address()),
+        reinterpret_cast<size_t>(receiver_buffer_channel.get_current_bytes_sent_address()),
+        sizeof(eth_channel_sync_t),
+        sizeof(eth_channel_sync_t),
+        sizeof(eth_channel_sync_t) >> ETH_BYTES_TO_WORDS_SHIFT);
+}
+
+template <size_t SENDER_NUM_BUFFERS, size_t RECEIVER_NUM_BUFFERS>
+tt::fabric::SendStatus send_next_data(
+    tt::fabric::EthChannelBuffer<SENDER_NUM_BUFFERS> &sender_buffer_channel,
+    tt::fabric::EthChannelBuffer<RECEIVER_NUM_BUFFERS> &receiver_buffer_channel) {
+
+    auto status = tt::fabric::SendStatus::NOT_SENT;
+
+    ASSERT(!eth_txq_is_busy());
+
+    status = tt::fabric::SendStatus::SENT_PAYLOAD_AND_SYNC;
+    ASSERT(
+        reinterpret_cast<size_t>(sender_buffer_channel.get_current_bytes_sent_address()) ==
+        (reinterpret_cast<size_t>(sender_buffer_channel.get_current_buffer_address()) +
+         sender_buffer_channel.get_current_max_eth_payload_size() -
+         (uint32_t)sizeof(eth_channel_sync_t)));
+
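+    // Stamp this buffer's trailing eth_channel_sync_t before sending (protocol step 1 in the
+    // comment block above): a non-zero bytes_sent marks new data, receiver_ack is cleared, and
+    // src_id tells the receiver which sender channel the acks should be directed back to.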
+    *sender_buffer_channel.get_current_bytes_sent_address() = sender_buffer_channel.get_current_max_eth_payload_size();
+    *sender_buffer_channel.get_current_bytes_acked_address() = 0;
+    *sender_buffer_channel.get_current_src_id_address() = sender_buffer_channel.get_id();
+    ASSERT(*sender_buffer_channel.get_current_src_id_address() < 2);
+
+    // TODO: TUNING - experiment with only conditionally breaking the transfer up into multiple packets if we are
+    // a certain threshold less than full packet
+    //   we can precompute this value even on host and pass it in so we can get away with a single integer
+    //   compare
+    //   NOTE: if we always send the full packet, then we don't need the second branch below dedicated for
+    //   the channel sync
+    tt::fabric::validate(*const_cast<tt::fabric::PacketHeader *>(
+        reinterpret_cast<volatile tt::fabric::PacketHeader *>(receiver_buffer_channel.get_current_buffer_address())));
+    const size_t payload_size = sender_buffer_channel.get_current_payload_plus_channel_sync_size();
+    eth_send_bytes_over_channel_payload_only_unsafe(
+        sender_buffer_channel.get_current_buffer_address(),
+        receiver_buffer_channel.get_current_buffer_address(),  // get_remote_eth_buffer_address(),
+        payload_size,
+        payload_size,
+        payload_size >> ETH_BYTES_TO_WORDS_SHIFT);
+
+    bool sent_payload_and_channel_sync_in_one_shot =
+        payload_size == sender_buffer_channel.get_channel_buffer_max_size_in_bytes();
+    if (!sent_payload_and_channel_sync_in_one_shot) {
+        // We weren't able to send the channel_sync_t in one shot with the payload so we need to send a second
+        // packet
+        // TODO: TUNING - consider busy waiting for a maximum amount of time
+        if (!eth_txq_is_busy()) {
+            send_channel_sync(sender_buffer_channel, receiver_buffer_channel);
+        } else {
+            status = tt::fabric::SendStatus::SENT_PAYLOAD_ONLY;
+        }
+    }
+
+    // Note: We can only advance to the next buffer index if we have fully completed the send (both the payload and sync
+    // messages)
+    if (status == tt::fabric::SendStatus::SENT_PAYLOAD_AND_SYNC) {
+        sender_buffer_channel.advance_buffer_index();
+        receiver_buffer_channel.advance_buffer_index();
+    }
+
+    return status;
+}
+
+template <size_t SENDER_NUM_BUFFERS, size_t RECEIVER_NUM_BUFFERS>
+FORCE_INLINE bool sender_noc_receive_payload_ack_check_sequence(
+    tt::fabric::EthChannelBuffer<SENDER_NUM_BUFFERS> &sender_buffer_channel,
+    tt::fabric::EthChannelBuffer<RECEIVER_NUM_BUFFERS> &receiver_buffer_channel) {
+    return sender_buffer_channel.is_local_semaphore_full();
+}
+
+template <size_t SENDER_NUM_BUFFERS>
+FORCE_INLINE void sender_eth_check_receiver_ack_sequence(
+    tt::fabric::EthChannelBuffer<SENDER_NUM_BUFFERS> &sender_buffer_channel,
+    tt::fabric::EdmChannelWorkerInterface &sender_worker_interface) {
+    sender_buffer_channel.eth_clear_sender_channel_ack();
+
+    sender_notify_workers_if_buffer_available_sequence(sender_worker_interface);
+}
+
+/////////////////////////////////////////////
+// RECEIVER SIDE HELPERS
+/////////////////////////////////////////////
+
+template <size_t RECEIVER_NUM_BUFFERS>
+FORCE_INLINE bool new_unacknowledged_packet_available_on_receiver_channel(
+    tt::fabric::EthChannelBuffer<RECEIVER_NUM_BUFFERS> &local_receiver_channel) {
+    return local_receiver_channel.eth_bytes_are_available_on_channel();
+}
+
+/*
+ * Acting as the receiver, we are looking at our receiver channel and acking the sender who sent us the latest packet.
+ * Doesn't check to see if indeed a new message is available. It's assumed the caller has handled that separately.
+ */
+// MUST CHECK !is_eth_txq_busy() before calling
+template <size_t SENDER_NUM_BUFFERS, size_t RECEIVER_NUM_BUFFERS>
+void receiver_send_received_ack(
+    std::array<tt::fabric::EthChannelBuffer<SENDER_NUM_BUFFERS>, NUM_SENDER_CHANNELS> &remote_sender_channels,
+    tt::fabric::EthChannelBuffer<RECEIVER_NUM_BUFFERS> &local_receiver_buffer_channel) {
+    // Set the acknowledgement bits. We have a different location than the packet's own channel_sync
+    // word to send this first level ack from, so that it cannot race with the second level
+    // (completion) acknowledgement sent from the packet's channel_sync location.
+
+    const auto src_id = *local_receiver_buffer_channel.get_current_src_id_address();
+    ASSERT(src_id < NUM_SENDER_CHANNELS);
+    auto &sender_buffer_channel = remote_sender_channels[src_id];
+    ASSERT(
+        reinterpret_cast<size_t>(sender_buffer_channel.get_current_bytes_sent_address()) ==
+        reinterpret_cast<size_t>(sender_buffer_channel.get_current_buffer_address()) +
+            sender_buffer_channel.get_current_max_eth_payload_size() -
+            sizeof(eth_channel_sync_t));
+
+    const size_t local_ack_channel_sync_src_addr =
+        local_receiver_buffer_channel.get_eth_transaction_ack_word_addr() + (src_id * sizeof(eth_channel_sync_t));
+    reinterpret_cast<volatile eth_channel_sync_t *>(local_ack_channel_sync_src_addr)->bytes_sent =
+        *local_receiver_buffer_channel.get_current_bytes_sent_address();
+    reinterpret_cast<volatile eth_channel_sync_t *>(local_ack_channel_sync_src_addr)->receiver_ack = 1;
+    reinterpret_cast<volatile eth_channel_sync_t *>(local_ack_channel_sync_src_addr)->src_id =
+        *local_receiver_buffer_channel.get_current_src_id_address();
+
+    // Make sure we don't alias the erisc_info eth_channel_sync_t
+    ASSERT(
+        reinterpret_cast<volatile eth_channel_sync_t *>(local_receiver_buffer_channel.get_current_bytes_sent_address())
+            ->bytes_sent != 0);
+    ASSERT(
+        reinterpret_cast<volatile eth_channel_sync_t *>(local_receiver_buffer_channel.get_current_bytes_sent_address())
+            ->receiver_ack == 0);
+
+    ASSERT(!eth_txq_is_busy());
+    internal_::eth_send_packet_unsafe(
+        0,
+        local_ack_channel_sync_src_addr >> 4,
+        ((uint32_t)(sender_buffer_channel.get_current_bytes_sent_address())) >> 4,
+        1);
+}
+
+// MUST CHECK !is_eth_txq_busy() before calling
+template <size_t SENDER_NUM_BUFFERS, size_t RECEIVER_NUM_BUFFERS>
+FORCE_INLINE void receiver_send_completion_ack(
+    std::array<tt::fabric::EthChannelBuffer<SENDER_NUM_BUFFERS>, NUM_SENDER_CHANNELS> &remote_sender_channels,
+    tt::fabric::EthChannelBuffer<RECEIVER_NUM_BUFFERS> &local_receiver_buffer_channel) {
+    volatile auto local_bytes_sent_addr = local_receiver_buffer_channel.get_current_bytes_sent_address();
+    volatile auto local_src_id_ptr = local_receiver_buffer_channel.get_current_src_id_address();
+
+    auto src_sender_channel = *local_src_id_ptr;
+    *(local_bytes_sent_addr) = 0;
+    *(local_receiver_buffer_channel.get_current_bytes_acked_address()) = 0;
+    ASSERT(src_sender_channel < NUM_SENDER_CHANNELS);
+
+    ASSERT(!eth_txq_is_busy());
+    internal_::eth_send_packet_unsafe(
+        0,
+        (uint32_t)(local_bytes_sent_addr) >> 4,
+        (uint32_t)(remote_sender_channels[src_sender_channel].get_current_bytes_sent_address()) >> 4,
+        1);
+
+    local_receiver_buffer_channel.advance_buffer_index();
+    remote_sender_channels[src_sender_channel].advance_buffer_index();
+}
+
+
+PacketLocalForwardType get_packet_local_forward_type(const tt::fabric::PacketHeader &packet_header) {
+    const bool local_chip_is_packet_destination = packet_must_be_consumed_locally(packet_header);
+    const bool packet_needs_forwarding = packet_must_be_forwarded_to_next_chip(packet_header);
+    PacketLocalForwardType forward_type =
+        static_cast<PacketLocalForwardType>(packet_needs_forwarding << 1 | local_chip_is_packet_destination);
+    return forward_type;
+}
+
+FORCE_INLINE bool can_forward_packet_completely(
+    const tt::fabric::PacketHeader &packet_header, tt::fabric::WorkerToFabricEdmSender &downstream_edm_interface) {
+    auto forward_status = get_packet_local_forward_type(packet_header);
+    switch (forward_status) {
+        case PACKET_FORWARD_INVALID: return false;
+        case PACKET_FORWARD_LOCAL_ONLY: return true;
+
+        case PACKET_FORWARD_REMOTE_ONLY:
+        case PACKET_FORWARD_LOCAL_AND_REMOTE: return downstream_edm_interface.consumer_has_space();
+        default: ASSERT(false); return false;
+    }
+}
+
+// template
+tt::fabric::SendStatus receiver_forward_packet(
+    volatile tt::fabric::PacketHeader *packet_start, tt::fabric::WorkerToFabricEdmSender &downstream_edm_interface) {
+    // Just cache the packet_header - we don't really expect (or care) if contents change during this function.
+    tt::fabric::PacketHeader const &packet_header = *const_cast<tt::fabric::PacketHeader const *>(packet_start);
+    tt::fabric::validate(packet_header);
+    auto forward_status = get_packet_local_forward_type(packet_header);
+
+    switch (forward_status) {
+        case PACKET_FORWARD_LOCAL_ONLY: {
+            execute_chip_unicast_to_local_chip(packet_start);
+            return tt::fabric::SendStatus::SENT_PAYLOAD_AND_SYNC;
+        } break;
+
+        case PACKET_FORWARD_REMOTE_ONLY: {
+            return forward_payload_to_downstream_edm(packet_start, downstream_edm_interface);
+        } break;
+
+        case PACKET_FORWARD_LOCAL_AND_REMOTE: {
+            ASSERT(packet_header.chip_send_type == tt::fabric::ChipSendType::CHIP_MULTICAST);
+            // TODO: make local chip write non-blocking
+            execute_chip_unicast_to_local_chip(packet_start);
+            return forward_payload_to_downstream_edm(packet_start, downstream_edm_interface);
+        } break;
+
+        case PACKET_FORWARD_INVALID:
+        default: ASSERT(false); return tt::fabric::SendStatus::ERROR;
+    }
+}
+
+////////////////////////////////////
+////////////////////////////////////
+// Main Control Loop
+////////////////////////////////////
+////////////////////////////////////
+template <size_t SENDER_NUM_BUFFERS, size_t RECEIVER_NUM_BUFFERS>
+bool run_sender_channel_state_machine_step(
+    tt::fabric::EthChannelBuffer<SENDER_NUM_BUFFERS> &local_sender_channel,
+    tt::fabric::EdmChannelWorkerInterface &local_sender_channel_worker_interface,
+    tt::fabric::EthChannelBuffer<RECEIVER_NUM_BUFFERS> &remote_receiver_channel,
+    SenderState *const sender_state_out) {
+    bool incr_sender_channel_index = true;
+    switch (*sender_state_out) {
+        case SenderState::SENDER_WAITING_FOR_WORKER: {
+            bool able_to_send = local_sender_channel_worker_interface.has_payload() && !eth_txq_is_busy() &&
+                                local_sender_channel.eth_is_receiver_channel_send_done();
+            if (able_to_send) {
+                auto send_status = send_next_data(local_sender_channel, remote_receiver_channel);
+                // TODO: align the enums and state values so I can just do
+                // sender_states[sender_channel_index] += send_status :)
+                ASSERT(send_status != tt::fabric::SendStatus::ERROR);
+                *sender_state_out =
+                    send_status == tt::fabric::SendStatus::NOT_SENT ? SenderState::SENDER_WAITING_FOR_WORKER
+                    : send_status == tt::fabric::SendStatus::SENT_PAYLOAD_ONLY
+                        ? SenderState::SENDER_SEND_CHANNEL_SYNC
+                        : SenderState::SENDER_WAITING_FOR_ETH;
+                // Avoid any sort of starvation/bubbles so we only advance if we've sent the packet and channel sync
+                // otherwise what can happen is we could start sending another large payload from the other channel
+                // and not be able to send the channel sync for the packet we just sent, which overall negatively
+                // impacts latency
+                incr_sender_channel_index = send_status != tt::fabric::SendStatus::SENT_PAYLOAD_ONLY;
+            } else {
+                if (local_sender_channel_worker_interface.has_worker_teardown_request()) {
+                    local_sender_channel_worker_interface.teardown_connection();
+                    *sender_state_out = SenderState::SENDER_WAIT_WORKER_HANDSHAKE;
+                }
+            }
+        } break;
+
+        case SenderState::SENDER_WAIT_WORKER_HANDSHAKE:
+            if (local_sender_channel_worker_interface.connection_is_live()) {
+                bool is_safe_to_receive_next_message = local_sender_channel.eth_is_receiver_channel_send_acked() ||
+                                                       local_sender_channel.eth_is_receiver_channel_send_done();
+                if (is_safe_to_receive_next_message) {
+                    sender_notify_workers_if_buffer_available_sequence(local_sender_channel_worker_interface);
+                    *sender_state_out = SenderState::SENDER_WAITING_FOR_WORKER;
+                } else {
+                    *sender_state_out = SenderState::SENDER_WAITING_FOR_ETH;
+                }
+            }
+            break;
+
+        case SenderState::SENDER_SEND_CHANNEL_SYNC: {
+            bool can_send_channel_sync_without_blocking = !eth_txq_is_busy();
+            if (can_send_channel_sync_without_blocking) {
+                send_channel_sync(local_sender_channel, remote_receiver_channel);
+                local_sender_channel.advance_buffer_index();
+                remote_receiver_channel.advance_buffer_index();
+                *sender_state_out = SenderState::SENDER_WAITING_FOR_ETH;
+            }
+        } break;
+
+        case SenderState::SENDER_WAITING_FOR_ETH: {
+            bool is_safe_to_receive_next_message = local_sender_channel.eth_is_receiver_channel_send_acked() ||
+                                                   local_sender_channel.eth_is_receiver_channel_send_done();
+            if (is_safe_to_receive_next_message) {
+                // This also notifies workers in the same call
+                sender_eth_check_receiver_ack_sequence(local_sender_channel, local_sender_channel_worker_interface);
+                *sender_state_out = SenderState::SENDER_WAITING_FOR_WORKER;
+            }
+        } break;
+
+        default: break;
+    }
+
+    return incr_sender_channel_index;
+}
+
+template <size_t RECEIVER_NUM_BUFFERS, size_t SENDER_NUM_BUFFERS>
+void run_receiver_channel_state_machine_step(
+    tt::fabric::EthChannelBuffer<RECEIVER_NUM_BUFFERS> &local_receiver_channel,
+    std::array<tt::fabric::EthChannelBuffer<SENDER_NUM_BUFFERS>, NUM_SENDER_CHANNELS> &remote_sender_channels,
+    tt::fabric::WorkerToFabricEdmSender &downstream_edm_interface,
+    ReceiverState *const receiver_state_out) {
+    switch (*receiver_state_out) {
+        case ReceiverState::RECEIVER_WAITING_FOR_ETH: {
+            bool got_payload = local_receiver_channel.eth_bytes_are_available_on_channel();
+            if (got_payload) {
+                bool can_ack = !eth_txq_is_busy();
+                if (can_ack) {
+                    tt::fabric::validate(
+                        *const_cast<tt::fabric::PacketHeader *>(local_receiver_channel.get_current_packet_header()));
+                    ASSERT(
+                        local_receiver_channel.get_current_packet_header()->command_fields.unicast_write.size < 100000);
+                    receiver_send_received_ack(remote_sender_channels, local_receiver_channel);
+                    // TODO: PERF Need to add a feature to let us perform the local noc write and defer the forward
+                    // to the downstream EDM if we are mcasting to the local chip and neighbours, but the downstream
+                    // EDM isn't currently able to accept the packet
+                    // ...
+                    // but as a starting point we can do the dumb thing and just wait for space downstream
+                    // before we do either.
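+                    // At this point the first level ack is on the wire, so the remote sender is free to
+                    // overwrite its local copy of this packet; our copy stays valid until we send the
+                    // second level (completion) ack below.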
+                    *receiver_state_out = ReceiverState::RECEIVER_SENDING_PAYLOAD;
+                    // TODO: PERF - SHORT CIRCUIT IF WE CAN TO NEXT STATE TO MINIMIZE LATENCY BUT CURRENTLY
+                    // A LITTLE CODE SIZE BOUND
+                }
+            }
+        } break;
+
+        case ReceiverState::RECEIVER_SENDING_PAYLOAD: {
+            auto packet_header =
+                *const_cast<tt::fabric::PacketHeader *>(local_receiver_channel.get_current_packet_header());
+            bool can_send_to_all_local_chip_receivers =
+                can_forward_packet_completely(packet_header, downstream_edm_interface);
+            if (can_send_to_all_local_chip_receivers) {
+                receiver_forward_packet(local_receiver_channel.get_current_packet_header(), downstream_edm_interface);
+                *receiver_state_out = ReceiverState::RECEIVER_WAITING_FOR_WRITE_FLUSH;
+            }
+        } break;
+
+        case ReceiverState::RECEIVER_WAITING_FOR_WRITE_FLUSH: {
+            bool writes_flushed = ncrisc_noc_nonposted_writes_sent(noc_index);
+            if (writes_flushed) {
+                bool can_send_ack_without_blocking = !eth_txq_is_busy();
+                if (can_send_ack_without_blocking) {
+                    receiver_send_completion_ack(remote_sender_channels, local_receiver_channel);
+                    *receiver_state_out = ReceiverState::RECEIVER_WAITING_FOR_ETH;
+                }
+            }
+        } break;
+
+        default: break;
+    }
+}
+
+
+/* Termination signal handling */
+FORCE_INLINE bool got_immediate_termination_signal(volatile tt::fabric::TerminationSignal *termination_signal_ptr) {
+    return *termination_signal_ptr == tt::fabric::TerminationSignal::IMMEDIATELY_TERMINATE;
+}
+FORCE_INLINE bool got_graceful_termination_signal(volatile tt::fabric::TerminationSignal *termination_signal_ptr) {
+    return *termination_signal_ptr == tt::fabric::TerminationSignal::GRACEFULLY_TERMINATE;
+}
+FORCE_INLINE bool got_termination_signal(volatile tt::fabric::TerminationSignal *termination_signal_ptr) {
+    return got_immediate_termination_signal(termination_signal_ptr) ||
+           got_graceful_termination_signal(termination_signal_ptr);
+}
+
+/*
+ * Main control loop for the fabric EDM. Runs indefinitely until a termination signal is received.
+ *
+ * Every loop iteration visits a sender channel and the receiver channel. We switch between sender
+ * channels every iteration unless it is unsafe/undesirable to do so (e.g. for performance reasons).
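+ *
+ * Roughly, each iteration does the following:
+ *   step(sender_channels[i]);  if (safe to switch) i = 1 - i;
+ *   step(receiver_channel);
+ *   every SWITCH_INTERVAL iterations: run_routing();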
+ */
+template <size_t RECEIVER_NUM_BUFFERS, size_t SENDER_NUM_BUFFERS>
+void run_fabric_edm_main_loop(
+    tt::fabric::EthChannelBuffer<RECEIVER_NUM_BUFFERS> &local_receiver_channel,
+    std::array<tt::fabric::EthChannelBuffer<SENDER_NUM_BUFFERS>, NUM_SENDER_CHANNELS> &local_sender_channels,
+    std::array<tt::fabric::EdmChannelWorkerInterface, NUM_SENDER_CHANNELS> &local_sender_channel_worker_interfaces,
+    tt::fabric::WorkerToFabricEdmSender &downstream_edm_noc_interface,
+    std::array<tt::fabric::EthChannelBuffer<SENDER_NUM_BUFFERS>, NUM_SENDER_CHANNELS> &remote_sender_channels,
+    tt::fabric::EthChannelBuffer<RECEIVER_NUM_BUFFERS> &remote_receiver_channel,
+    volatile tt::fabric::TerminationSignal *termination_signal_ptr) {
+
+    std::array<SenderState, NUM_SENDER_CHANNELS> sender_states = {
+        SenderState::SENDER_WAIT_WORKER_HANDSHAKE, SenderState::SENDER_WAIT_WORKER_HANDSHAKE};
+    ReceiverState receiver_state = ReceiverState::RECEIVER_WAITING_FOR_ETH;
+    size_t sender_channel_index = 0;
+    size_t did_nothing_count = 0;
+    *termination_signal_ptr = tt::fabric::TerminationSignal::KEEP_RUNNING;
+
+    while (!got_termination_signal(termination_signal_ptr)) {
+        auto &local_sender_channel = local_sender_channels[sender_channel_index];
+        auto &local_sender_channel_worker_interface = local_sender_channel_worker_interfaces[sender_channel_index];
+        // There are some cases, mainly for performance, where we don't want to switch between sender channels
+        // so we introduce this to provide finer grain control over when we disable the automatic switching
+        bool incr_sender_channel_index = run_sender_channel_state_machine_step(
+            local_sender_channel,
+            local_sender_channel_worker_interface,
+            remote_receiver_channel,
+            &(sender_states[sender_channel_index]));
+        if (incr_sender_channel_index) {
+            // TODO: this can probably be optimized
+            sender_channel_index = 1 - sender_channel_index;
+        }
+
+        run_receiver_channel_state_machine_step(
+            local_receiver_channel, remote_sender_channels, downstream_edm_noc_interface, &receiver_state);
+
+        if (did_nothing_count++ > SWITCH_INTERVAL) {
+            did_nothing_count = 0;
+            run_routing();
+        }
+    }
+}
+
+void kernel_main() {
+    //
+    // COMMON CT ARGS (not specific to sender or receiver)
+    //
+    static constexpr bool is_handshake_sender = get_compile_time_arg_val(0) != 0;
+    static constexpr size_t handshake_addr = get_compile_time_arg_val(1);
+    auto eth_transaction_ack_word_addr = handshake_addr + sizeof(eth_channel_sync_t);
+
+    if constexpr (is_handshake_sender) {
+        erisc::datamover::handshake::sender_side_start(handshake_addr);
+    } else {
+        erisc::datamover::handshake::receiver_side_start(handshake_addr);
+    }
+
+    // The size of one of the buffers within a sender channel. For example, if `channel_buffer_size` = 4k
+    // and `SENDER_NUM_BUFFERS` = 2, then the total amount of buffering for that channel is 8k (excluding
+    // the per-buffer eth_channel_sync_t words).
+    static constexpr size_t channel_buffer_size = get_compile_time_arg_val(2);
+
+    static constexpr size_t SENDER_NUM_BUFFERS = get_compile_time_arg_val(3);
+    static constexpr size_t RECEIVER_NUM_BUFFERS = get_compile_time_arg_val(4);
+    static constexpr size_t local_sender_0_channel_address = get_compile_time_arg_val(5);
+    static constexpr size_t local_sender_channel_0_connection_buffer_index_addr = get_compile_time_arg_val(6);
+    static constexpr size_t local_sender_channel_0_connection_info_addr = get_compile_time_arg_val(7);
+    static constexpr size_t local_sender_1_channel_address = get_compile_time_arg_val(8);
+    static constexpr size_t local_sender_channel_1_connection_buffer_index_addr = get_compile_time_arg_val(9);
+    static constexpr size_t local_sender_channel_1_connection_info_addr = get_compile_time_arg_val(10);
+    static constexpr size_t local_receiver_channel_buffer_address = get_compile_time_arg_val(11);
+    static constexpr size_t remote_receiver_channel_buffer_address = get_compile_time_arg_val(12);
+    static constexpr size_t remote_sender_0_channel_address = get_compile_time_arg_val(13);
+    static constexpr size_t remote_sender_1_channel_address = get_compile_time_arg_val(14);
+
+    // TODO: CONVERT TO SEMAPHORE
+    volatile auto termination_signal_ptr =
+        reinterpret_cast<volatile tt::fabric::TerminationSignal *>(get_compile_time_arg_val(15));
+
+    static_assert(SENDER_NUM_BUFFERS > 0, "compile time argument [3]: SENDER_NUM_BUFFERS must be > 0");
+    static_assert(RECEIVER_NUM_BUFFERS > 0, "compile time argument [4]: RECEIVER_NUM_BUFFERS must be > 0");
+
+    *reinterpret_cast<volatile uint32_t *>(local_sender_channel_0_connection_buffer_index_addr) = 0;
+    *reinterpret_cast<volatile uint32_t *>(local_sender_channel_1_connection_buffer_index_addr) = 0;
+
+    size_t arg_idx = 0;
+    ///////////////////////
+    // Common runtime args:
+    ///////////////////////
+
+    const size_t local_sender_channel_0_connection_semaphore_addr =
+        get_semaphore(get_arg_val<uint32_t>(arg_idx++));
+    const size_t local_sender_channel_1_connection_semaphore_addr =
+        get_semaphore(get_arg_val<uint32_t>(arg_idx++));
+    // downstream EDM semaphore location
+    const bool has_downstream_edm_buffer_connection = get_arg_val<uint32_t>(arg_idx++) != 0;
+    const auto downstream_edm_buffer_base_address = get_arg_val<uint32_t>(arg_idx++);
+    const auto downstream_edm_noc_x = get_arg_val<uint32_t>(arg_idx++);
+    const auto downstream_edm_noc_y = get_arg_val<uint32_t>(arg_idx++);
+
+    // remote address for flow control
+    const auto downstream_edm_semaphore_id = get_arg_val<uint32_t>(arg_idx++);  // TODO: Convert to semaphore ID
+    const auto downstream_edm_worker_registration_address =
+        get_semaphore(get_arg_val<uint32_t>(arg_idx++));
+    const auto downstream_edm_worker_location_info_address = get_arg_val<uint32_t>(arg_idx++);
+    const auto downstream_noc_interface_buffer_index_addr =
+        get_semaphore(get_arg_val<uint32_t>(arg_idx++));
+
+    // Receiver channel's local semaphore for managing flow control with the downstream EDM.
+    // The downstream EDM should be sending semaphore updates to this address any time it can
+    // accept a new message
+    const auto edm_forwarding_semaphore_address =
+        get_semaphore(get_arg_val<uint32_t>(arg_idx++));
+
+    ////////////////////////
+    // Sender runtime args
+    ////////////////////////
+    auto sender0_worker_semaphore_ptr = reinterpret_cast<volatile uint32_t *>(
+        get_semaphore(get_arg_val<uint32_t>(arg_idx++)));
+    auto sender1_worker_semaphore_ptr = reinterpret_cast<volatile uint32_t *>(
+        get_semaphore(get_arg_val<uint32_t>(arg_idx++)));
+    *sender0_worker_semaphore_ptr = 0;
+    *sender1_worker_semaphore_ptr = 0;
+
+    //////////////////////////////
+    //////////////////////////////
+    // Object Setup
+    //////////////////////////////
+    //////////////////////////////
+
+    auto const &local_sender_buffer_addresses =
+        std::array<size_t, NUM_SENDER_CHANNELS>{local_sender_0_channel_address, local_sender_1_channel_address};
+    auto const &remote_sender_buffer_addresses =
+        std::array<size_t, NUM_SENDER_CHANNELS>{remote_sender_0_channel_address, remote_sender_1_channel_address};
+    std::array<tt::fabric::EthChannelBuffer<SENDER_NUM_BUFFERS>, NUM_SENDER_CHANNELS> remote_sender_channels;
+    std::array<tt::fabric::EthChannelBuffer<SENDER_NUM_BUFFERS>, NUM_SENDER_CHANNELS> local_sender_channels;
+    std::array<tt::fabric::EdmChannelWorkerInterface, NUM_SENDER_CHANNELS> local_sender_channel_worker_interfaces;
+    std::array<size_t, NUM_SENDER_CHANNELS> local_sender_flow_control_semaphores = {
+        reinterpret_cast<size_t>(sender0_worker_semaphore_ptr), reinterpret_cast<size_t>(sender1_worker_semaphore_ptr)};
+    std::array<size_t, NUM_SENDER_CHANNELS> local_sender_connection_live_semaphore_addresses = {
+        local_sender_channel_0_connection_semaphore_addr, local_sender_channel_1_connection_semaphore_addr};
+    std::array<size_t, NUM_SENDER_CHANNELS> local_sender_connection_info_addresses = {
+        local_sender_channel_0_connection_info_addr, local_sender_channel_1_connection_info_addr};
+    auto downstream_edm_noc_interface =
+        has_downstream_edm_buffer_connection
+            ? tt::fabric::WorkerToFabricEdmSender(
+                  downstream_edm_noc_x,
+                  downstream_edm_noc_y,
+                  downstream_edm_buffer_base_address,
+                  SENDER_NUM_BUFFERS,
+                  downstream_edm_semaphore_id,
+                  downstream_edm_worker_registration_address,  // edm_connection_handshake_addr,
+                  downstream_edm_worker_location_info_address,
+                  channel_buffer_size,
+                  reinterpret_cast<volatile uint32_t *>(edm_forwarding_semaphore_address),
+                  downstream_noc_interface_buffer_index_addr)
+            : tt::fabric::WorkerToFabricEdmSender();
+
+    auto local_receiver_channel = tt::fabric::EthChannelBuffer<RECEIVER_NUM_BUFFERS>(
+        local_receiver_channel_buffer_address,
+        channel_buffer_size,
+        tt::fabric::header_size_bytes,
+        eth_transaction_ack_word_addr,  // Assume for receiver channel, this address points to a chunk of memory
+                                        // that can fit 2 eth_channel_syncs for the acks
+        receiver_channel_id);
+    auto remote_receiver_channel = tt::fabric::EthChannelBuffer<RECEIVER_NUM_BUFFERS>(
+        remote_receiver_channel_buffer_address,
+        channel_buffer_size,
+        tt::fabric::header_size_bytes,
+        eth_transaction_ack_word_addr,  // Assume for receiver channel, this address points to a chunk of memory
+                                        // that can fit 2 eth_channel_syncs for the acks
+        receiver_channel_id);
+
+    uint32_t args_offset = 0;
+
+    for (uint8_t i = 0; i < NUM_SENDER_CHANNELS; i++) {
+        new (&local_sender_channels[i]) tt::fabric::EthChannelBuffer<SENDER_NUM_BUFFERS>(
+            local_sender_buffer_addresses[i],
+            channel_buffer_size,
+            tt::fabric::header_size_bytes,
+            0,  // For sender channels there is no eth_transaction_ack_word_addr because they don't send acks
+            i);
+        new (&remote_sender_channels[i]) tt::fabric::EthChannelBuffer<SENDER_NUM_BUFFERS>(
+            remote_sender_buffer_addresses[i],
+            channel_buffer_size,
+            tt::fabric::header_size_bytes,
+            0,  // For sender channels there is no eth_transaction_ack_word_addr because they don't send acks
+            i);
+
+        auto connection_live_semaphore_ptr =
+            reinterpret_cast<volatile tt_l1_ptr uint32_t *>(local_sender_connection_live_semaphore_addresses[i]);
+        auto connection_worker_info_ptr = reinterpret_cast<volatile tt::fabric::EDMChannelWorkerLocationInfo *>(
+            local_sender_connection_info_addresses[i]);
+        new (&local_sender_channel_worker_interfaces[i]) tt::fabric::EdmChannelWorkerInterface(
+            connection_worker_info_ptr,  // worker_location_info_ptr,
+            reinterpret_cast<volatile tt_l1_ptr uint32_t *>(
+                local_sender_flow_control_semaphores[i]),  // local_semaphore_address,
+            reinterpret_cast<volatile tt_l1_ptr uint32_t *>(connection_live_semaphore_ptr));
+    }
+
+    if (has_downstream_edm_buffer_connection) {
+        downstream_edm_noc_interface.open();
+    }
+
+    if constexpr (is_handshake_sender) {
+        erisc::datamover::handshake::sender_side_finish(handshake_addr);
+    } else {
+        erisc::datamover::handshake::receiver_side_finish(handshake_addr);
+    }
+
+    //////////////////////////////
+    //////////////////////////////
+    // MAIN LOOP
+    //////////////////////////////
+    //////////////////////////////
+    run_fabric_edm_main_loop(
+        local_receiver_channel,
+        local_sender_channels,
+        local_sender_channel_worker_interfaces,
+        downstream_edm_noc_interface,
+        remote_sender_channels,
+        remote_receiver_channel,
+        termination_signal_ptr);
+
+    if (got_graceful_termination_signal(termination_signal_ptr)) {
+        ASSERT(false);
+    } else {
+        // So long suckers!
+    }
+
+    WAYPOINT("DONE");
+}
diff --git a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover_channels.hpp b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover_channels.hpp
new file mode 100644
index 000000000000..90f2c692aa2e
--- /dev/null
+++ b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover_channels.hpp
@@ -0,0 +1,225 @@
+// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include <array>
+#include <cstddef>
+#include <cstdint>
+
+#include "debug/dprint.h"
+#include "tt_metal/hw/inc/dataflow_api.h"
+#include "tt_metal/hw/inc/ethernet/tunneling.h"
+#include "tt_metal/hw/inc/risc_attribs.h"
+#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp"
+#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_types.hpp"
+#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp"
+
+namespace tt::fabric {
+// Increments val and wraps to 0 if it reaches limit
+template <typename T, size_t LIMIT>
+auto wrap_increment(T val) -> T {
+    static_assert(LIMIT != 0, "wrap_increment called with limit of 0; it must be greater than 0");
+    if constexpr (LIMIT == 1) {
+        return val;
+    } else if constexpr (LIMIT == 2) {
+        return 1 - val;
+    } else if constexpr ((LIMIT > 0) && (LIMIT & (LIMIT - 1)) == 0) {
+        // LIMIT is a power of two: wrap with a cheap mask instead of a compare
+        return (val + 1) & (LIMIT - 1);
+    } else {
+        return (val == LIMIT - 1) ? 0 : val + 1;
+    }
+}
+
+template <typename T>
+FORCE_INLINE auto wrap_increment(T val, size_t max) {
+    return (val == max - 1) ? 0 : val + 1;
+}
+
+template <size_t NUM_BUFFERS>
+class EthChannelBuffer final {
+   public:
+    // The channel structure is as follows:
+    //              &header->  |----------------| channel_base_address
+    //                         |    header      |
+    //             &payload->  |----------------|
+    //                         |                |
+    //                         |    payload     |
+    //                         |                |
+    //        &channel_sync->  |----------------|
+    //                         |  channel_sync  |
+    //                         ------------------
+    EthChannelBuffer() : buffer_size_in_bytes(0), eth_transaction_ack_word_addr(0), max_eth_payload_size_in_bytes(0) {}
+
+    /*
+     * Expected that *buffer_index_ptr is initialized outside of this object
+     */
+    EthChannelBuffer(
+        size_t channel_base_address,
+        size_t buffer_size_bytes,
+        size_t header_size_bytes,
+        size_t eth_transaction_ack_word_addr,  // Assume for receiver channel, this address points to a chunk of memory
+                                               // that can fit 2 eth_channel_syncs for the acks
+        uint8_t channel_id) :
+        buffer_size_in_bytes(buffer_size_bytes),
+        eth_transaction_ack_word_addr(eth_transaction_ack_word_addr),
+        max_eth_payload_size_in_bytes(buffer_size_in_bytes + sizeof(eth_channel_sync_t)),
+        buff_idx(0),
+        channel_id(channel_id) {
+        for (uint8_t i = 0; i < NUM_BUFFERS; i++) {
+            this->buffer_addresses[i] =
+                channel_base_address + i * this->max_eth_payload_size_in_bytes;
+
+            uint32_t channel_sync_addr = this->buffer_addresses[i] + buffer_size_in_bytes;
+            auto channel_sync_ptr = reinterpret_cast<volatile eth_channel_sync_t *>(channel_sync_addr);
+
+            channel_bytes_sent_addresses[i] =
+                reinterpret_cast<volatile tt_l1_ptr size_t *>(&(channel_sync_ptr->bytes_sent));
+            channel_bytes_acked_addresses[i] =
+                reinterpret_cast<volatile tt_l1_ptr size_t *>(&(channel_sync_ptr->receiver_ack));
+            channel_src_id_addresses[i] = reinterpret_cast<volatile tt_l1_ptr size_t *>(&(channel_sync_ptr->src_id));
+
+            ASSERT((uint32_t)channel_bytes_acked_addresses[i] != (uint32_t)(channel_bytes_sent_addresses[i]));
+            *(channel_bytes_sent_addresses[i]) = 0;
+            *(channel_bytes_acked_addresses[i]) = 0;
+            // Note we don't need to overwrite the `channel_src_id_addresses` except perhaps for
+            // debug purposes where we may wish to tag this with a special value
+        }
+    }
+
+    [[nodiscard]] FORCE_INLINE size_t get_current_buffer_address() const {
+        return this->buffer_addresses[this->buffer_index()];
+    }
+
+    [[nodiscard]] FORCE_INLINE volatile PacketHeader *get_current_packet_header() const {
+        return reinterpret_cast<volatile PacketHeader *>(this->buffer_addresses[this->buffer_index()]);
+    }
+
+    [[nodiscard]] FORCE_INLINE size_t get_current_payload_size() const {
+        return get_current_packet_header()->get_payload_size_including_header();
+    }
+
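+    // Payload size (including the packet header) plus the trailing eth_channel_sync_t word -
+    // i.e. the number of bytes the sender EDM actually pushes over Ethernet for the current packet.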
+    [[nodiscard]] FORCE_INLINE size_t get_current_payload_plus_channel_sync_size() const {
+        return get_current_packet_header()->get_payload_size_including_header() + sizeof(eth_channel_sync_t);
+    }
+
+    // TODO: Split off into two separate functions:
+    //       volatile tt_l1_ptr size_t *get_current_bytes_sent_ptr() const
+    //       size_t get_current_bytes_sent_address() const
+    [[nodiscard]] FORCE_INLINE volatile tt_l1_ptr size_t *get_current_bytes_sent_address() const {
+        return this->channel_bytes_sent_addresses[this->buffer_index()];
+    }
+
+    [[nodiscard]] FORCE_INLINE volatile tt_l1_ptr size_t *get_current_bytes_acked_address() const {
+        return this->channel_bytes_acked_addresses[this->buffer_index()];
+    }
+
+    [[nodiscard]] FORCE_INLINE volatile tt_l1_ptr size_t *get_current_src_id_address() const {
+        return this->channel_src_id_addresses[this->buffer_index()];
+    }
+
+    [[nodiscard]] FORCE_INLINE size_t get_channel_buffer_max_size_in_bytes() const {
+        return this->buffer_size_in_bytes;
+    }
+
+    // Doesn't return the message size, only the maximum eth payload size
+    [[nodiscard]] FORCE_INLINE size_t get_current_max_eth_payload_size() const {
+        return this->max_eth_payload_size_in_bytes;
+    }
+
+    [[nodiscard]] FORCE_INLINE size_t get_id() const { return this->channel_id; }
+
+    [[nodiscard]] FORCE_INLINE bool eth_is_receiver_channel_send_done() const {
+        return *(this->get_current_bytes_sent_address()) == 0;
+    }
+    [[nodiscard]] FORCE_INLINE bool eth_bytes_are_available_on_channel() const {
+        return *(this->get_current_bytes_sent_address()) != 0;
+    }
+    [[nodiscard]] FORCE_INLINE bool eth_is_receiver_channel_send_acked() const {
+        return *(this->get_current_bytes_acked_address()) != 0;
+    }
+    FORCE_INLINE void eth_clear_sender_channel_ack() const {
+        *(this->channel_bytes_acked_addresses[this->buffer_index()]) = 0;
+    }
+
+    [[nodiscard]] FORCE_INLINE size_t get_eth_transaction_ack_word_addr() const {
+        return this->eth_transaction_ack_word_addr;
+    }
+
+    FORCE_INLINE void advance_buffer_index() {
+        this->buff_idx = wrap_increment<decltype(this->buff_idx), NUM_BUFFERS>(this->buff_idx);
+    }
+
+   private:
+    FORCE_INLINE auto buffer_index() const {
+        ASSERT(this->buff_idx < NUM_BUFFERS);
+        return buff_idx;
+    }
+
+    std::array<size_t, NUM_BUFFERS> buffer_addresses;
+    std::array<volatile tt_l1_ptr size_t *, NUM_BUFFERS> channel_bytes_sent_addresses;
+    std::array<volatile tt_l1_ptr size_t *, NUM_BUFFERS> channel_bytes_acked_addresses;
+    std::array<volatile tt_l1_ptr size_t *, NUM_BUFFERS> channel_src_id_addresses;
+
+    // Header + payload regions only
+    const std::size_t buffer_size_in_bytes;
+    const std::size_t eth_transaction_ack_word_addr;
+    // Includes header + payload + channel_sync
+    const std::size_t max_eth_payload_size_in_bytes;
+    uint8_t buff_idx;
+    uint8_t channel_id;
+};
+
+struct EdmChannelWorkerInterface {
+    EdmChannelWorkerInterface() :
+        worker_location_info_ptr(nullptr), local_semaphore_address(nullptr), connection_live_semaphore(nullptr) {}
+    EdmChannelWorkerInterface(
+        // TODO: PERF: See if we can make this non-volatile and then only
+        // mark it volatile when we know we need to reload it (i.e. after we receive a
+        // "done" message from sender)
+        // Have a volatile update function that only triggers after reading the volatile
+        // completion field so that way we don't have to do a volatile read for every
+        // packet...
+        // Then we'll also be able to cache the uint64_t addr of the worker semaphore
+        // directly (saving on regenerating it each time)
+        volatile EDMChannelWorkerLocationInfo *worker_location_info_ptr,
+        volatile tt_l1_ptr uint32_t *const local_semaphore_address,
+        volatile tt_l1_ptr uint32_t *const connection_live_semaphore) :
+        worker_location_info_ptr(worker_location_info_ptr),
+        local_semaphore_address(local_semaphore_address),
+        connection_live_semaphore(connection_live_semaphore) {}
+
+    // Flow control methods
+    //
+    [[nodiscard]] FORCE_INLINE auto local_semaphore_value() const { return *local_semaphore_address; }
+
+    [[nodiscard]] FORCE_INLINE bool has_payload() { return *local_semaphore_address != 0; }
+
+    FORCE_INLINE void clear_local_semaphore() { noc_semaphore_set(local_semaphore_address, 0); }
+
+    [[nodiscard]] FORCE_INLINE uint32_t get_worker_semaphore_address() const {
+        return worker_location_info_ptr->worker_semaphore_address;
+    }
+
+    void increment_worker_semaphore() const {
+        auto const &worker_info = *worker_location_info_ptr;
+        uint64_t worker_semaphore_address = get_noc_addr(
+            (uint32_t)worker_info.worker_xy.x, (uint32_t)worker_info.worker_xy.y, worker_info.worker_semaphore_address);
+
+        DPRINT << "EDMS notif @ " << (uint64_t)worker_semaphore_address << "\n";
+        noc_semaphore_inc(worker_semaphore_address, 1);
+    }
+
+    // Connection management methods
+    //
+    FORCE_INLINE void teardown_connection() const { increment_worker_semaphore(); }
+
+    [[nodiscard]] FORCE_INLINE bool has_worker_teardown_request() const { return *connection_live_semaphore == 0; }
+
+    [[nodiscard]] FORCE_INLINE bool connection_is_live() const { return *connection_live_semaphore == 1; }
+
+    volatile EDMChannelWorkerLocationInfo *worker_location_info_ptr;
+    volatile tt_l1_ptr uint32_t *const local_semaphore_address;
+    volatile tt_l1_ptr uint32_t *const connection_live_semaphore;
+};
+
+}  // namespace tt::fabric