From d277980875bf0b8ba2e60113c9009a236622b564 Mon Sep 17 00:00:00 2001
From: Sean Nijjar
Date: Tue, 18 Feb 2025 10:30:18 -0500
Subject: [PATCH 1/8] Apply EDM Fabric Optimizations - Up to 13.5 GB/s bidir unicast and 10.5 GB/s bidir mcast @4k packet size (#17930)

Numerous EDM Fabric (1D Fabric) optimizations that take the EDM fabric to the following approximate performance with 4K packet size:
- 13.5 GB/s in neighbour exchange test
- 10.5 GB/s in 4chip mcast test

Measured ~1 GB/s higher when compiling with -O3, but that is currently not enabled in this PR.

The optimizations in this PR include:
- Add optimized power-of-2 queue pointer handling and enable power-of-2 buffer slot counts
- Add optimized power-of-2 transaction ID handling and use power-of-2 transaction IDs on write
- Mild cleanup/optimizations of volatile pointer usage
- Optimize main top level control loop of EDM fabric
- Reduce the frequency of context switch/teardown checks
- Nest main control loop in a tight loop
- Partially unroll sender state execution steps (one for each channel) instead of using a sender channel ID to alternate through them
---
 .../gtests/ccl/kernels/edm_fabric_writer.cpp  |  13 +-
 ...erisc_data_mover_loopback_with_workers.cpp |  11 ++
 .../ccl/erisc_datamover_builder.cpp           |  27 ++-
 .../ccl/erisc_datamover_builder.hpp           |   2 +
 .../edm_fabric_flow_control_helpers.hpp       | 162 +++++++++++++++++
 .../edm_fabric/edm_fabric_worker_adapters.hpp |  93 +++++++---
 .../fabric_edm_packet_transmission.hpp        |  17 +-
 .../edm_fabric/fabric_erisc_datamover.cpp     | 163 +++++++++++-------
 .../fabric_erisc_datamover_channels.hpp       | 147 +---------------
 9 files changed, 381 insertions(+), 254 deletions(-)
 create mode 100644 ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_flow_control_helpers.hpp

diff --git a/tests/ttnn/unit_tests/gtests/ccl/kernels/edm_fabric_writer.cpp b/tests/ttnn/unit_tests/gtests/ccl/kernels/edm_fabric_writer.cpp
index 952a49631045..91fe40d181e4 100644
--- a/tests/ttnn/unit_tests/gtests/ccl/kernels/edm_fabric_writer.cpp
+++ b/tests/ttnn/unit_tests/gtests/ccl/kernels/edm_fabric_writer.cpp
@@ -139,13 +139,9 @@ void kernel_main() {
         safe_get_noc_addr(static_cast(dest_noc_x), static_cast(dest_noc_y), dest_bank_addr);
     noc_async_write(source_l1_buffer_address, dest_addr, packet_payload_size_bytes);
     if (fabric_connection.has_forward_connection()) {
-        DeviceZoneScopedN("WR-FWD");
         mcast_fwd_packet_header->to_noc_unicast_write(
             NocUnicastCommandHeader{noc0_dest_addr}, packet_payload_size_bytes);
-        {
-            DeviceZoneScopedN("WR-FWD-WAIT");
-            fabric_connection.get_forward_connection().wait_for_empty_write_slot();
-        }
+        fabric_connection.get_forward_connection().wait_for_empty_write_slot();
         print_pkt_header(mcast_fwd_packet_header);
         fabric_connection.get_forward_connection().send_payload_without_header_non_blocking_from_address(
             source_l1_buffer_address, packet_payload_size_bytes);
@@ -154,13 +150,9 @@
     if (fabric_connection.has_backward_connection()) {
-        DeviceZoneScopedN("WR-BWD");
         mcast_bwd_packet_header->to_noc_unicast_write(
             NocUnicastCommandHeader{noc0_dest_addr}, packet_payload_size_bytes);
-        {
-            DeviceZoneScopedN("WR-BWD-WAIT");
-            fabric_connection.get_backward_connection().wait_for_empty_write_slot();
-        }
+        fabric_connection.get_backward_connection().wait_for_empty_write_slot();
         print_pkt_header(mcast_bwd_packet_header);
         fabric_connection.get_backward_connection().send_payload_without_header_non_blocking_from_address(
             source_l1_buffer_address, packet_payload_size_bytes);
@@ -176,7 +168,6 @@ void
kernel_main() { for (size_t i = 0; i < num_unicasts; i++) { auto noc0_dest_addr = safe_get_noc_addr(static_cast(dest_noc_x), static_cast(dest_noc_y), dest_bank_addr, 0); - DeviceZoneScopedN("UNICAST-WRITE"); auto& fabric_conn = unicast_is_fwd ? fabric_connection.get_forward_connection() : fabric_connection.get_backward_connection(); unicast_packet_header->to_noc_unicast_write(NocUnicastCommandHeader{noc0_dest_addr}, packet_payload_size_bytes); diff --git a/tests/ttnn/unit_tests/gtests/ccl/test_fabric_erisc_data_mover_loopback_with_workers.cpp b/tests/ttnn/unit_tests/gtests/ccl/test_fabric_erisc_data_mover_loopback_with_workers.cpp index 4f9eadf730c2..1ab121ffec7d 100644 --- a/tests/ttnn/unit_tests/gtests/ccl/test_fabric_erisc_data_mover_loopback_with_workers.cpp +++ b/tests/ttnn/unit_tests/gtests/ccl/test_fabric_erisc_data_mover_loopback_with_workers.cpp @@ -3590,6 +3590,17 @@ TEST(EdmFabric, BasicMcastThroughputTest_2) { RunWriteThroughputStabilityTestWithPersistentFabric(num_mcasts, num_unicasts, num_links, num_op_invocations); } +TEST(EdmFabric, BasicMcastThroughputTest_3_SingleLink) { + const size_t num_mcasts = 200000; + const size_t num_unicasts = 0; + const size_t num_links = 1; + const size_t num_op_invocations = 1; + const bool line_sync = true; + WriteThroughputStabilityTestWithPersistentFabricParams params; + params.line_sync = line_sync; + RunWriteThroughputStabilityTestWithPersistentFabric( + num_mcasts, num_unicasts, num_links, num_op_invocations, params); +} TEST(EdmFabric, BasicMcastThroughputTest_3) { const size_t num_mcasts = 200000; const size_t num_unicasts = 2; diff --git a/ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.cpp b/ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.cpp index 8be28978f47a..2f505f415863 100644 --- a/ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.cpp @@ -75,24 +75,43 @@ FabricEriscDatamoverConfig::FabricEriscDatamoverConfig( TT_FATAL(sender_channel_1_buffer_index_address != sender_channel_0_buffer_index_address, "FabricEriscDatamoverConfig was constructed with illegal buffer index address"); const size_t min_buffer_size = sizeof(tt::fabric::PacketHeader) + 2 * FabricEriscDatamoverConfig::eth_channel_sync_size; TT_FATAL(channel_buffer_size_bytes >= min_buffer_size, "FabricEriscDatamoverConfig was constructed with `channel_buffer_size_bytes` argument set smaller than minimum size of {}", min_buffer_size); + + constexpr size_t default_pow2_num_sender_buffer_slots = 8; + constexpr size_t default_pow2_num_receiver_buffer_slots = 16; + const std::size_t channel_buffer_size_with_channel_sync = channel_buffer_size_bytes + sizeof(tt::fabric::PacketHeader); // + 16 // sizeof(tt::fabric::PacketHeader); - this->channel_buffer_size_bytes = channel_buffer_size_bytes; + const size_t next_lowest_power_of_2_buffer_slot_count = + + this->channel_buffer_size_bytes = channel_buffer_size_bytes; this->channel_buffer_size_bytes_with_channel_sync = channel_buffer_size_with_channel_sync; const std::size_t total_ratio_count = 2 * sender_ratio_size + receiver_ratio_size; + this->sender_0_channel_size_bytes = tt::round_down( (available_channel_buffering_space / total_ratio_count) * sender_ratio_size, channel_buffer_size_with_channel_sync); - this->sender_0_num_buffers = this->sender_0_channel_size_bytes / channel_buffer_size_with_channel_sync; + if constexpr (FabricEriscDatamoverConfig::constrain_to_power_of_2_buffer_slot_counts) { + this->sender_0_num_buffers = 
default_pow2_num_sender_buffer_slots; + } else { + this->sender_0_num_buffers = this->sender_0_channel_size_bytes / channel_buffer_size_with_channel_sync; + } this->sender_1_channel_size_bytes = tt::round_down( (available_channel_buffering_space / total_ratio_count) * sender_ratio_size, channel_buffer_size_with_channel_sync); - this->sender_1_num_buffers = this->sender_1_channel_size_bytes / channel_buffer_size_with_channel_sync; + if constexpr (FabricEriscDatamoverConfig::constrain_to_power_of_2_buffer_slot_counts) { + this->sender_1_num_buffers = default_pow2_num_sender_buffer_slots; + } else { + this->sender_1_num_buffers = this->sender_1_channel_size_bytes / channel_buffer_size_with_channel_sync; + } this->receiver_channel_size_bytes = tt::round_down( (available_channel_buffering_space / total_ratio_count) * receiver_ratio_size, channel_buffer_size_with_channel_sync); - this->receiver_num_buffers = this->receiver_channel_size_bytes / channel_buffer_size_with_channel_sync; + if constexpr (FabricEriscDatamoverConfig::constrain_to_power_of_2_buffer_slot_counts) { + this->receiver_num_buffers = default_pow2_num_receiver_buffer_slots; + } else { + this->receiver_num_buffers = this->receiver_channel_size_bytes / channel_buffer_size_with_channel_sync; + } this->sender_0_channel_base_address = buffer_region_start; this->sender_1_channel_base_address = this->sender_0_channel_base_address + this->sender_0_channel_size_bytes; diff --git a/ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.hpp b/ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.hpp index 1d32db7f8c33..a9d1a076ba67 100644 --- a/ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.hpp @@ -30,6 +30,8 @@ namespace ccl { struct FabricEriscDatamoverConfig { + static constexpr bool constrain_to_power_of_2_buffer_slot_counts = true; + static constexpr std::size_t field_size = 16; static constexpr std::size_t buffer_alignment = 32; static constexpr std::size_t eth_word_l1_alignment = 16; diff --git a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_flow_control_helpers.hpp b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_flow_control_helpers.hpp new file mode 100644 index 000000000000..63bf9bad9f36 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_flow_control_helpers.hpp @@ -0,0 +1,162 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
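Stepping back to the builder change above: with `constrain_to_power_of_2_buffer_slot_counts` enabled, the slot counts are pinned to fixed powers of two (8 sender slots, 16 receiver slots) rather than derived from the available buffering space, and the dangling `next_lowest_power_of_2_buffer_slot_count` local hints at rounding a derived count down instead. A minimal sketch of such a rounding helper; the name mirrors that local but the helper itself is hypothetical, and it assumes C++20's <bit>:

#include <bit>
#include <cstddef>

// Largest power of two <= n; std::bit_floor returns 0 for n == 0.
constexpr std::size_t next_lowest_power_of_2(std::size_t n) { return std::bit_floor(n); }

static_assert(next_lowest_power_of_2(9) == 8);
static_assert(next_lowest_power_of_2(16) == 16);

Power-of-two slot counts are what let the device-side pointer math in the new header below compile down to masking instead of compares.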
+// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +#include "tt_metal/hw/inc/utils/utils.h" +#include "risc_attribs.h" + +namespace tt::fabric { + +template +class NamedType { +public: + FORCE_INLINE explicit NamedType(const T& value) : value_(value) {} + FORCE_INLINE explicit NamedType(T&& value) : value_(std::move(value)) {} + FORCE_INLINE NamedType& operator=(const NamedType& rhs) = default; + FORCE_INLINE T& get() { return value_; } + FORCE_INLINE const T& get() const { return value_; } + FORCE_INLINE operator T() const { return value_; } + FORCE_INLINE operator T&() { return value_; } + +private: + T value_; +}; + +using BufferIndex = NamedType; +using BufferPtr = NamedType; + +// Increments val and wraps to 0 if it reaches limit +template +FORCE_INLINE auto wrap_increment(T val) -> T { + constexpr bool is_pow2 = LIMIT != 0 && is_power_of_2(LIMIT); + if constexpr (LIMIT == 1) { + return val; + } else if constexpr (LIMIT == 2) { + return 1 - val; + } else if constexpr (is_pow2) { + return (val + 1) & (static_cast(LIMIT - 1)); + } else { + return (val == static_cast(LIMIT - 1)) ? static_cast(0) : static_cast(val + 1); + } +} +template +FORCE_INLINE auto wrap_increment_n(T val, uint8_t increment) -> T { + constexpr bool is_pow2 = LIMIT != 0 && is_power_of_2(LIMIT); + if constexpr (LIMIT == 1) { + return val; + } else if constexpr (LIMIT == 2) { + return 1 - val; + } else if constexpr (is_pow2) { + return (val + increment) & (LIMIT - 1); + } else { + T new_unadjusted_val = val + increment; + bool wraps = new_unadjusted_val >= LIMIT; + return wraps ? static_cast(new_unadjusted_val - LIMIT) : static_cast(new_unadjusted_val); + } +} + +FORCE_INLINE +auto normalize_ptr(BufferPtr ptr, uint8_t num_buffers) -> BufferIndex { + // note it may make sense to calculate this only when we increment + // which will save calculations overall (but may add register pressure) + // and introduce undesirable loads + bool normalize = ptr >= num_buffers; + uint8_t normalized_ptr = ptr.get() - static_cast(normalize * num_buffers); + ASSERT(normalized_ptr < num_buffers); + return BufferIndex{normalized_ptr}; +} +template +FORCE_INLINE auto normalize_ptr(BufferPtr ptr) -> BufferIndex { + static_assert(NUM_BUFFERS != 0, "normalize_ptr called with NUM_BUFFERS of 0; it must be greater than 0"); + constexpr bool is_size_pow2 = NUM_BUFFERS != 0 && (NUM_BUFFERS & (NUM_BUFFERS - 1)) == 0; + constexpr bool is_size_2 = NUM_BUFFERS == 2; + constexpr bool is_size_1 = NUM_BUFFERS == 1; + constexpr uint8_t wrap_mask = NUM_BUFFERS - 1; + if constexpr (is_size_pow2) { + return BufferIndex{static_cast(ptr.get() & wrap_mask)}; + } else if constexpr (is_size_2) { + return BufferIndex{(uint8_t)1 - ptr.get()}; + } else if constexpr (is_size_1) { + return BufferIndex{0}; + } else { + // note it may make sense to calculate this only when we increment + // which will save calculations overall (but may add register pressure) + // and introduce undesirable loads + return normalize_ptr(ptr, NUM_BUFFERS); + } +} + +FORCE_INLINE uint8_t +distance_behind(const BufferPtr& trailing_ptr, const BufferPtr& leading_ptr, uint8_t ptr_wrap_size) { + bool leading_gte_trailing_ptr = leading_ptr >= trailing_ptr; + return leading_gte_trailing_ptr ? 
leading_ptr - trailing_ptr : ptr_wrap_size - (trailing_ptr - leading_ptr); +} +template +FORCE_INLINE uint8_t distance_behind(const BufferPtr& trailing_ptr, const BufferPtr& leading_ptr) { + static_assert(NUM_BUFFERS != 0, "distance_behind called with NUM_BUFFERS of 0; it must be greater than 0"); + constexpr bool is_size_pow2 = is_power_of_2(NUM_BUFFERS); + constexpr uint8_t ptr_wrap_mask = (2 * NUM_BUFFERS) - 1; + constexpr uint8_t ptr_wrap_size = 2 * NUM_BUFFERS; + bool leading_gte_trailing_ptr = leading_ptr >= trailing_ptr; + if constexpr (is_size_pow2) { + return (leading_ptr - trailing_ptr) & ptr_wrap_mask; + } else { + return distance_behind(trailing_ptr, leading_ptr, ptr_wrap_size); + } +} + +template +class ChannelBufferPointer { + static_assert( + NUM_BUFFERS <= std::numeric_limits::max() / 2, + "NUM_BUFFERS must be less than or half of std::numeric_limits::max() due to the internal " + "implementation"); + +public: + static constexpr bool is_size_pow2 = (NUM_BUFFERS & (NUM_BUFFERS - 1)) == 0; + static constexpr bool is_size_2 = NUM_BUFFERS == 2; + static constexpr bool is_size_1 = NUM_BUFFERS == 1; + static constexpr uint8_t ptr_wrap_size = 2 * NUM_BUFFERS; + + // Only to use if is_size_pow2 + static constexpr uint8_t ptr_wrap_mask = (2 * NUM_BUFFERS) - 1; + static constexpr uint8_t buffer_wrap_mask = NUM_BUFFERS - 1; + ChannelBufferPointer() : ptr(0) {} + /* + * Returns the "raw" pointer - not usable to index the buffer channel + */ + FORCE_INLINE BufferPtr get_ptr() const { return this->ptr; } + + FORCE_INLINE bool is_caught_up_to(const ChannelBufferPointer& leading_ptr) const { + return this->is_caught_up_to(leading_ptr.get_ptr()); + } + FORCE_INLINE uint8_t distance_behind(const ChannelBufferPointer& leading_ptr) const { + return this->distance_behind(leading_ptr.get_ptr()); + } + + /* + * Returns the buffer index pointer which is usable to index into the buffer memory + */ + FORCE_INLINE BufferIndex get_buffer_index() const { return BufferIndex{normalize_ptr(this->ptr)}; } + + FORCE_INLINE void increment_n(uint8_t n) { + this->ptr = BufferPtr{wrap_increment_n<2 * NUM_BUFFERS>(this->ptr.get(), n)}; + } + FORCE_INLINE void increment() { this->ptr = BufferPtr{wrap_increment<2 * NUM_BUFFERS>(this->ptr.get())}; } + +private: + // Make these private to make sure caller doesn't accidentally mix two pointers pointing to + // different sized channels + FORCE_INLINE bool is_caught_up_to(const BufferPtr& leading_ptr) const { return this->get_ptr() == leading_ptr; } + FORCE_INLINE uint8_t distance_behind(const BufferPtr& leading_ptr) const { + return tt::fabric::distance_behind(this->ptr, leading_ptr); + } + BufferPtr ptr = BufferPtr{0}; +}; + +} // namespace tt::fabric diff --git a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp index e6b2253c2778..4864cea0b293 100644 --- a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp @@ -10,6 +10,8 @@ #include "cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp" #include "cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header_validate.hpp" #include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_types.hpp" +#include "cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_flow_control_helpers.hpp" +#include "tt_metal/hw/inc/utils/utils.h" #include "debug/assert.h" #include "debug/dprint.h" 
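A note on the scheme the new flow-control header above implements: raw pointers advance modulo 2 * NUM_BUFFERS while slot indexing masks back into [0, NUM_BUFFERS), so a completely full channel (distance of N) stays distinguishable from an empty one (distance of 0). A self-contained sketch of the power-of-2 case with illustrative names, not the header's own:

#include <cstdint>

constexpr uint8_t NUM_BUFFERS = 8;                      // must be a power of 2 here
constexpr uint8_t PTR_WRAP_MASK = 2 * NUM_BUFFERS - 1;  // raw pointers live in [0, 2N)
constexpr uint8_t IDX_MASK = NUM_BUFFERS - 1;           // slot indices live in [0, N)

uint8_t advance(uint8_t ptr) { return (ptr + 1) & PTR_WRAP_MASK; }  // wrap_increment
uint8_t slot_index(uint8_t ptr) { return ptr & IDX_MASK; }          // normalize_ptr
uint8_t slots_used(uint8_t rdptr, uint8_t wrptr) {                  // distance_behind
    return static_cast<uint8_t>(wrptr - rdptr) & PTR_WRAP_MASK;
}
bool has_free_slot(uint8_t rdptr, uint8_t wrptr) {
    return slots_used(rdptr, wrptr) < NUM_BUFFERS;  // full == N, empty == 0
}

The non-power-of-two fallback replaces the masks with compares and subtracts, which is exactly the compile-time split the `wrap_increment`/`normalize_ptr` overloads above encode.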
#include @@ -17,7 +19,7 @@ namespace tt::fabric { /* - * The WorkerToFabricEdmSender acts as an adapter between the worker and the EDM, it hides details + * The WorkerToFabricEdmSenderImpl acts as an adapter between the worker and the EDM, it hides details * of the communication between worker and EDM to provide flexibility for the implementation to change * over time without kernel updates. Additionally, details for adapter setup w.r.t runtime args is also hidden. * The main functionality provided is: @@ -34,15 +36,20 @@ namespace tt::fabric { * As the adapter writes into the EDM, it updates the local wrptr. As the EDM reads from its local L1 channel buffer, * it will notify the worker/adapter (here) by updating the worker remote_rdptr to carry the value of the EDM rdptr. */ -struct WorkerToFabricEdmSender { +template +struct WorkerToFabricEdmSenderImpl { + static constexpr bool USER_DEFINED_NUM_BUFFER_SLOTS = EDM_NUM_BUFFER_SLOTS != 0; + static constexpr bool IS_POW2_NUM_BUFFERS = USER_DEFINED_NUM_BUFFER_SLOTS && is_power_of_2(EDM_NUM_BUFFER_SLOTS); + static constexpr size_t BUFFER_SLOT_PTR_WRAP = EDM_NUM_BUFFER_SLOTS * 2; + static constexpr size_t LAST_BUFFER_SLOT_PTR_BEFORE_WRAP = BUFFER_SLOT_PTR_WRAP - 1; static constexpr uint32_t unused_connection_value = 0; static constexpr uint32_t open_connection_value = 1; static constexpr uint32_t close_connection_request_value = 2; - WorkerToFabricEdmSender() : from_remote_buffer_slot_rdptr_ptr(nullptr) {} + WorkerToFabricEdmSenderImpl() : from_remote_buffer_slot_rdptr_ptr(nullptr) {} template - static WorkerToFabricEdmSender build_from_args(std::size_t& arg_idx) { + static WorkerToFabricEdmSenderImpl build_from_args(std::size_t& arg_idx) { bool is_persistent_fabric = get_arg_val(arg_idx++); WorkerXY const edm_worker_xy = WorkerXY::from_uint32(get_arg_val(arg_idx++)); auto const edm_buffer_base_addr = get_arg_val(arg_idx++); @@ -64,7 +71,7 @@ struct WorkerToFabricEdmSender { (my_core_type == ProgrammableCoreType::TENSIX && (uint32_t)writer_send_sem_addr < 1499136) || (my_core_type == ProgrammableCoreType::ACTIVE_ETH && (uint32_t)writer_send_sem_addr < 262144)); ASSERT(edm_buffer_index_addr < 262144); - return WorkerToFabricEdmSender( + return WorkerToFabricEdmSenderImpl( is_persistent_fabric, edm_worker_xy.x, edm_worker_xy.y, @@ -80,7 +87,7 @@ struct WorkerToFabricEdmSender { worker_buffer_index_semaphore_addr); } - WorkerToFabricEdmSender( + WorkerToFabricEdmSenderImpl( bool connected_to_persistent_fabric, uint8_t edm_worker_x, uint8_t edm_worker_y, @@ -116,18 +123,45 @@ struct WorkerToFabricEdmSender { edm_noc_x(edm_worker_x), edm_noc_y(edm_worker_y) { ASSERT(buffer_size_bytes > 0); + if constexpr (USER_DEFINED_NUM_BUFFER_SLOTS) { + ASSERT(num_buffers_per_channel == EDM_NUM_BUFFER_SLOTS); + } } FORCE_INLINE bool edm_has_space_for_packet() const { - auto const wrptr = *this->buffer_slot_wrptr_ptr; - auto const rdptr = *this->from_remote_buffer_slot_rdptr_ptr; - bool wrptr_ge_rptr = wrptr >= rdptr; - uint8_t slots_used = wrptr_ge_rptr ? 
(wrptr - rdptr) : ((2 * this->num_buffers_per_channel) - rdptr) + wrptr; - return slots_used < this->num_buffers_per_channel; + using namespace tt::fabric; + if constexpr (USER_DEFINED_NUM_BUFFER_SLOTS) { + auto slots_used = distance_behind( + BufferPtr{static_cast(*this->from_remote_buffer_slot_rdptr_ptr)}, + BufferPtr{static_cast(*this->buffer_slot_wrptr_ptr)}); + return slots_used < this->num_buffers_per_channel; + } else { + auto const rdptr = *this->from_remote_buffer_slot_rdptr_ptr; + auto const wrptr = *this->buffer_slot_wrptr_ptr; + auto buffer_ptr_wrap = 2 * this->num_buffers_per_channel; + auto slots_used = distance_behind( + BufferPtr{static_cast(rdptr)}, + BufferPtr{static_cast(wrptr)}, + buffer_ptr_wrap); + return slots_used < this->num_buffers_per_channel; + } } FORCE_INLINE void wait_for_empty_write_slot() const { - while (!this->edm_has_space_for_packet()); + using namespace tt::fabric; + if constexpr (USER_DEFINED_NUM_BUFFER_SLOTS) { + while (distance_behind(BufferPtr{static_cast(*this->from_remote_buffer_slot_rdptr_ptr)}, BufferPtr{static_cast(*this->buffer_slot_wrptr_ptr)}) < this->num_buffers_per_channel); + } else { + auto const first_rdptr = *this->from_remote_buffer_slot_rdptr_ptr; + auto buffer_ptr_wrap = 2 * this->num_buffers_per_channel; + bool has_space = distance_behind( + BufferPtr{static_cast(first_rdptr)}, + BufferPtr{static_cast(*this->buffer_slot_wrptr_ptr)}, + buffer_ptr_wrap) < this->num_buffers_per_channel; + if (!has_space) { + while (first_rdptr == *this->from_remote_buffer_slot_rdptr_ptr); + } + } } FORCE_INLINE void send_payload_blocking(uint32_t cb_id, uint32_t num_pages, uint32_t page_size) { @@ -192,6 +226,8 @@ struct WorkerToFabricEdmSender { const uint64_t edm_connection_handshake_noc_addr = dest_noc_addr_coord_only | edm_connection_handshake_l1_addr; noc_inline_dw_write(edm_connection_handshake_noc_addr, open_connection_value); noc_async_read_barrier(); + + this->edm_buffer_addr = this->edm_buffer_base_addr + (this->get_buffer_slot_index() * (this->buffer_size_bytes + sizeof(eth_channel_sync_t))); ASSERT(*this->buffer_slot_wrptr_ptr < 20); } @@ -249,25 +285,27 @@ struct WorkerToFabricEdmSender { noc_inline_dw_write(noc_sem_addr, *this->buffer_slot_wrptr_ptr); } - FORCE_INLINE void advance_buffer_slot_wrptr() { - // TODO: smarter addition if we are working with pow2 - uint8_t wrptr = *this->buffer_slot_wrptr_ptr; - *this->buffer_slot_wrptr_ptr = - !(wrptr == ((this->num_buffers_per_channel * 2) - 1)) ? wrptr + 1 : 0; - } - FORCE_INLINE uint8_t get_buffer_slot_index() const { - auto const wrptr = *this->buffer_slot_wrptr_ptr; - bool normalize = wrptr >= this->num_buffers_per_channel; - return wrptr - (normalize * this->num_buffers_per_channel); + if constexpr (USER_DEFINED_NUM_BUFFER_SLOTS) { + return normalize_ptr(BufferPtr{static_cast(*this->buffer_slot_wrptr_ptr)}); + } else { + return normalize_ptr(BufferPtr{static_cast(*this->buffer_slot_wrptr_ptr)}, this->num_buffers_per_channel); + } } - FORCE_INLINE uint32_t compute_dest_buffer_slot_bank_address() const { - return this->edm_buffer_addr + (this->get_buffer_slot_index() * (this->buffer_size_bytes + sizeof(eth_channel_sync_t))); + FORCE_INLINE void advance_buffer_slot_wrptr() { + if constexpr (USER_DEFINED_NUM_BUFFER_SLOTS) { + *this->buffer_slot_wrptr_ptr = wrap_increment(*this->buffer_slot_wrptr_ptr); + } else { + uint8_t wrptr = *this->buffer_slot_wrptr_ptr; + *this->buffer_slot_wrptr_ptr = + !(wrptr == ((this->num_buffers_per_channel * 2) - 1)) ? 
wrptr + 1 : 0; + } + this->edm_buffer_addr = this->edm_buffer_base_addr + (this->get_buffer_slot_index() * (this->buffer_size_bytes + sizeof(eth_channel_sync_t))); } FORCE_INLINE uint64_t compute_dest_buffer_slot_noc_addr() const { - return get_noc_addr(this->edm_noc_x, this->edm_noc_y, this->compute_dest_buffer_slot_bank_address()); + return get_noc_addr(this->edm_noc_x, this->edm_noc_y, this->edm_buffer_addr); } FORCE_INLINE void post_send_payload_increment_pointers() { @@ -319,4 +357,9 @@ struct WorkerToFabricEdmSender { } }; +using WorkerToFabricEdmSender = WorkerToFabricEdmSenderImpl<0>; + +template +using EdmToEdmSender = WorkerToFabricEdmSenderImpl; + } // namespace tt::fabric diff --git a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_transmission.hpp b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_transmission.hpp index 35533d4d26e5..85553bf6dab1 100644 --- a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_transmission.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_transmission.hpp @@ -16,7 +16,7 @@ static constexpr size_t DESTINATION_HOP_COUNT = 1; // TODO: make 0 and the associated field to num mcast destinations static constexpr size_t LAST_MCAST_DESTINATION = 1; -void print_pkt_hdr_routing_fields(volatile tt::fabric::PacketHeader *const packet_start) { +FORCE_INLINE void print_pkt_hdr_routing_fields(volatile tt::fabric::PacketHeader *const packet_start) { #ifdef DEBUG_PRINT_ENABLED switch (packet_start->chip_send_type) { case tt::fabric::CHIP_UNICAST: { @@ -32,7 +32,7 @@ void print_pkt_hdr_routing_fields(volatile tt::fabric::PacketHeader *const packe #endif } -void print_pkt_header_noc_fields(volatile tt::fabric::PacketHeader *const packet_start) { +FORCE_INLINE void print_pkt_header_noc_fields(volatile tt::fabric::PacketHeader *const packet_start) { #ifdef DEBUG_PRINT_ENABLED switch (packet_start->noc_send_type) { case tt::fabric::NocSendType::NOC_UNICAST_WRITE: { @@ -50,7 +50,7 @@ void print_pkt_header_noc_fields(volatile tt::fabric::PacketHeader *const packet #endif } -void print_pkt_header(volatile tt::fabric::PacketHeader *const packet_start) { +FORCE_INLINE void print_pkt_header(volatile tt::fabric::PacketHeader *const packet_start) { #ifdef DEBUG_PRINT_ENABLED auto const& header = *packet_start; DPRINT << "PKT: nsnd_t:" << (uint32_t) packet_start->noc_send_type << @@ -64,12 +64,12 @@ void print_pkt_header(volatile tt::fabric::PacketHeader *const packet_start) { // Since we unicast to local, we must omit the packet header -FORCE_INLINE void execute_chip_unicast_to_local_chip(volatile tt::fabric::PacketHeader *const packet_start, uint32_t transaction_id) { +FORCE_INLINE void execute_chip_unicast_to_local_chip( + volatile tt::fabric::PacketHeader *const packet_start, uint16_t payload_size_bytes, uint32_t transaction_id) { auto const& header = *packet_start; uint32_t payload_start_address = reinterpret_cast(packet_start) + sizeof(tt::fabric::PacketHeader); tt::fabric::NocSendType noc_send_type = packet_start->noc_send_type; - auto const payload_size_bytes = header.payload_size_bytes; switch (noc_send_type) { case tt::fabric::NocSendType::NOC_UNICAST_WRITE: { auto const dest_address = header.command_fields.unicast_write.noc_address; @@ -125,13 +125,14 @@ FORCE_INLINE void update_packet_header_for_next_hop(volatile tt::fabric::PacketH // !!!WARNING!!! * do NOT call before determining if the packet should be consumed locally or forwarded // !!!WARNING!!! 
* ENSURE DOWNSTREAM EDM HAS SPACE FOR PACKET BEFORE CALLING // !!!WARNING!!! +template FORCE_INLINE void forward_payload_to_downstream_edm( volatile tt::fabric::PacketHeader *packet_header, + uint16_t payload_size_bytes, tt::fabric::RoutingFields cached_routing_fields, - tt::fabric::WorkerToFabricEdmSender &downstream_edm_interface, + tt::fabric::EdmToEdmSender &downstream_edm_interface, uint8_t transaction_id ) { - DPRINT << "Fwding pkt to downstream\n"; // TODO: PERF - this should already be getting checked by the caller so this should be redundant make it an ASSERT ASSERT(downstream_edm_interface.edm_has_space_for_packet()); // best effort check @@ -140,6 +141,6 @@ FORCE_INLINE void forward_payload_to_downstream_edm( update_packet_header_for_next_hop(packet_header, cached_routing_fields); downstream_edm_interface.send_payload_non_blocking_from_address_with_trid( reinterpret_cast(packet_header), - packet_header->get_payload_size_including_header(), + payload_size_bytes + sizeof(tt::fabric::PacketHeader), transaction_id); } diff --git a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover.cpp b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover.cpp index b0c732ee00b8..4f7b82b5ce70 100644 --- a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover.cpp @@ -14,6 +14,7 @@ #include "cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" #include "noc_overlay_parameters.h" +#include "tt_metal/hw/inc/utils/utils.h" #include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_counters.hpp" @@ -23,7 +24,7 @@ using ttnn::ccl::WorkerXY; -static constexpr bool enable_first_level_ack = true; +static constexpr bool enable_first_level_ack = false; static constexpr bool fuse_receiver_flush_and_completion_ptr = true; /* @@ -110,8 +111,8 @@ by the worker (the EDM is a slave in this protocol). *NOTE*: Additionally, if a worker pushes packets to a channel it isn't connected to, behaviour is undefined. *NOTE*: Undefined == likely hang -The `WorkerToFabricEdmSender` from `ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp` -provides an implementation of the connection protocol. `WorkerToFabricEdmSender` also acts as a wrapper around that +The `EdmToEdmSender` from `ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp` +provides an implementation of the connection protocol. `EdmToEdmSender` also acts as a wrapper around that protocol so workers can simply call `open()` to execute the connection protocol without having to manually reimplement for each kernel. 
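A rough sketch of the worker-side call sequence this paragraph describes; argument layout, packet-header setup, and the `build_from_args` template parameter (the core type) are elided, and `close()` is assumed here as the counterpart to `open()`, matching the `close_connection_request_value` constant in the adapter:

size_t arg_idx = 0;
auto sender = tt::fabric::WorkerToFabricEdmSender::build_from_args(arg_idx);
sender.open();                       // one-time connection handshake with the EDM
sender.wait_for_empty_write_slot();  // flow control: spin until rdptr/wrptr show a free slot
sender.send_payload_without_header_non_blocking_from_address(
    source_l1_buffer_address, packet_payload_size_bytes);
sender.close();                      // assumed teardown counterpart to open()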
@@ -265,40 +266,64 @@ struct TransactionIdCounter { template struct WriteTransactionIdTracker { static constexpr uint8_t INVALID_TRID = MAX_TRANSACTION_IDS; + static constexpr bool N_TRIDS_IS_POW2 = is_power_of_2(MAX_TRANSACTION_IDS); + static constexpr bool N_CHANS_IS_POW2 = is_power_of_2(NUM_CHANNELS); + static constexpr uint8_t TRID_POW2_MASK = MAX_TRANSACTION_IDS - 1; + static constexpr bool BOTH_PARAMS_ARE_POW2 = N_TRIDS_IS_POW2 && N_CHANS_IS_POW2; + WriteTransactionIdTracker() { for (size_t i = 0; i < NUM_CHANNELS; i++) { this->buffer_slot_trids[i] = INVALID_TRID; } } FORCE_INLINE void set_buffer_slot_trid(uint8_t trid, tt::fabric::BufferIndex buffer_index) { - this->buffer_slot_trids[buffer_index] = trid; - } - - FORCE_INLINE void advance_trid_counter() { - this->trid_counter.increment(); + if constexpr (!BOTH_PARAMS_ARE_POW2) { + this->buffer_slot_trids[buffer_index] = trid; + } } FORCE_INLINE uint8_t update_buffer_slot_to_next_trid_and_advance_trid_counter(tt::fabric::BufferIndex buffer_index) { - uint8_t next_trid = this->trid_counter.get(); - this->buffer_slot_trids[buffer_index] = next_trid; - this->trid_counter.increment(); - return next_trid; + if constexpr (BOTH_PARAMS_ARE_POW2) { + uint8_t next_trid = buffer_index & TRID_POW2_MASK; + this->trid_counter.increment(); + return next_trid; + } else { + uint8_t next_trid = this->trid_counter.get(); + this->buffer_slot_trids[buffer_index] = next_trid; + this->trid_counter.increment(); + return next_trid; + } } FORCE_INLINE void clear_trid_at_buffer_slot(tt::fabric::BufferIndex buffer_index) { - this->buffer_slot_trids[buffer_index] = INVALID_TRID; + if constexpr (!BOTH_PARAMS_ARE_POW2) { + this->buffer_slot_trids[buffer_index] = INVALID_TRID; + } } FORCE_INLINE uint8_t get_buffer_slot_trid(tt::fabric::BufferIndex buffer_index) const { - return this->buffer_slot_trids[buffer_index]; + if constexpr (BOTH_PARAMS_ARE_POW2) { + return buffer_index & TRID_POW2_MASK; + } else { + return this->buffer_slot_trids[buffer_index]; + } } FORCE_INLINE bool transaction_flushed(tt::fabric::BufferIndex buffer_index) const { - auto trid = this->get_buffer_slot_trid(buffer_index); - return trid == INVALID_TRID || ncrisc_noc_nonposted_write_with_transaction_id_flushed(noc_index, trid); + if constexpr (BOTH_PARAMS_ARE_POW2) { + auto trid = this->get_buffer_slot_trid(buffer_index); + return ncrisc_noc_nonposted_write_with_transaction_id_flushed(noc_index, trid); + } else { + // TODO: should be able to remove compare against INVALID_TRID + auto trid = this->get_buffer_slot_trid(buffer_index); + return trid == INVALID_TRID || ncrisc_noc_nonposted_write_with_transaction_id_flushed(noc_index, trid); + } } private: std::array buffer_slot_trids; TransactionIdCounter trid_counter; + + // TODO: cleanup - only used for when both params are pow2, else above are used. 
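Concretely, the power-of-2 fast path in the tracker above reduces transaction-ID assignment to a mask of the buffer slot, which is why the per-slot trid array and the INVALID_TRID sentinel drop out on that path. An isolated sketch with illustrative values:

#include <cstdint>

constexpr uint8_t MAX_TRANSACTION_IDS = 4;  // power of 2, illustrative
constexpr uint8_t TRID_POW2_MASK = MAX_TRANSACTION_IDS - 1;

// trid is a pure function of the slot: 0,1,2,3,0,1,2,3,... as slots advance,
// so no per-slot bookkeeping is needed.
uint8_t trid_for_slot(uint8_t buffer_index) { return buffer_index & TRID_POW2_MASK; }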
+ uint8_t next_trid = 0; }; static constexpr uint32_t DEFAULT_ETH_TXQ = 0; @@ -366,6 +391,8 @@ constexpr std::array to_sender_packets_completed_streams = {{ */ template struct OutboundReceiverChannelPointers { + static constexpr bool is_pow2 = is_power_of_2(RECEIVER_NUM_BUFFERS); + tt::fabric::ChannelBufferPointer wrptr; tt::fabric::ChannelBufferPointer ack_ptr; tt::fabric::ChannelBufferPointer completion_ptr; @@ -571,11 +598,10 @@ FORCE_INLINE void receiver_send_completion_ack( remote_sender_completion_ptr.increment(); } - +template FORCE_INLINE bool can_forward_packet_completely( - const volatile tt::fabric::PacketHeader* packet_header, tt::fabric::RoutingFields cached_routing_fields, - tt::fabric::WorkerToFabricEdmSender& downstream_edm_interface) { + tt::fabric::EdmToEdmSender& downstream_edm_interface) { // We always check if it is the terminal mcast packet value. We can do this because all unicast packets have the // mcast terminal value masked in to the routing field. This simplifies the check here to a single compare. bool deliver_locally_only = cached_routing_fields.value == tt::fabric::RoutingFields::LAST_MCAST_VAL; @@ -583,20 +609,22 @@ FORCE_INLINE bool can_forward_packet_completely( } // !!!WARNING!!! - MAKE SURE CONSUMER HAS SPACE BEFORE CALLING +template FORCE_INLINE void receiver_forward_packet( // TODO: have a separate cached copy of the packet header to save some additional L1 loads volatile tt::fabric::PacketHeader *packet_start, tt::fabric::RoutingFields cached_routing_fields, - tt::fabric::WorkerToFabricEdmSender &downstream_edm_interface, + tt::fabric::EdmToEdmSender &downstream_edm_interface, uint8_t transaction_id) { bool start_distance_is_terminal_value = (cached_routing_fields.value & tt::fabric::RoutingFields::HOP_DISTANCE_MASK) == tt::fabric::RoutingFields::LAST_HOP_DISTANCE_VAL; + uint16_t payload_size_bytes = packet_start->payload_size_bytes; if (start_distance_is_terminal_value) { - execute_chip_unicast_to_local_chip(packet_start, transaction_id); + execute_chip_unicast_to_local_chip(packet_start, payload_size_bytes, transaction_id); } bool not_last_destination_device = cached_routing_fields.value != tt::fabric::RoutingFields::LAST_MCAST_VAL; if (not_last_destination_device) { - forward_payload_to_downstream_edm(packet_start, cached_routing_fields, downstream_edm_interface, transaction_id); + forward_payload_to_downstream_edm(packet_start, payload_size_bytes, cached_routing_fields, downstream_edm_interface, transaction_id); } } @@ -633,7 +661,6 @@ FORCE_INLINE bool run_sender_channel_step( tt::fabric::validate(*packet_header); packet_header_recorder.record_packet_header(packet_header); } - print_pkt_header(packet_header); send_next_data( local_sender_channel, local_sender_channel_worker_interface, @@ -710,17 +737,16 @@ FORCE_INLINE bool run_sender_channel_step( return did_something; }; -template +template FORCE_INLINE void run_receiver_channel_step( tt::fabric::EthChannelBuffer &local_receiver_channel, std::array, NUM_SENDER_CHANNELS> &remote_sender_channnels, - tt::fabric::WorkerToFabricEdmSender &downstream_edm_interface, + tt::fabric::EdmToEdmSender &downstream_edm_interface, volatile tt::fabric::EdmFabricReceiverChannelCounters *receiver_channel_counters_ptr, std::array, NUM_SENDER_CHANNELS> &remote_eth_sender_wrptrs, ReceiverChannelPointers &receiver_channel_pointers, PacketHeaderRecorder &packet_header_recorder, - WriteTransactionIdTracker &receiver_channel_trid_tracker, - ReceiverState *const receiver_state_out) { + WriteTransactionIdTracker 
    &receiver_channel_trid_tracker) {
    auto &ack_ptr = receiver_channel_pointers.ack_ptr;

    auto pkts_received_since_last_check = get_ptr_val();
@@ -750,12 +776,11 @@ FORCE_INLINE void run_receiver_channel_step(
         volatile auto packet_header = local_receiver_channel.get_packet_header(receiver_buffer_index);
         tt::fabric::RoutingFields cached_routing_fields =
             const_cast(packet_header)->routing_fields;
-        print_pkt_header(packet_header);
         bool can_send_to_all_local_chip_receivers =
-            can_forward_packet_completely(packet_header, cached_routing_fields, downstream_edm_interface);
+            can_forward_packet_completely(
+                cached_routing_fields, downstream_edm_interface);
         bool trid_flushed = receiver_channel_trid_tracker.transaction_flushed(receiver_buffer_index);
         if (can_send_to_all_local_chip_receivers && trid_flushed) {
-            // DeviceZoneScopedN("EDMR-Send-Impl");
             uint8_t trid = receiver_channel_trid_tracker.update_buffer_slot_to_next_trid_and_advance_trid_counter(receiver_buffer_index);
             receiver_forward_packet(packet_header, cached_routing_fields, downstream_edm_interface, trid);
             wr_sent_ptr.increment();
@@ -822,7 +847,7 @@ FORCE_INLINE bool got_termination_signal(volatile tt::fabric::TerminationSignal *termination_signal_ptr) {
         got_graceful_termination_signal(termination_signal_ptr);
 }

-template
+template
 bool all_channels_drained(tt::fabric::EthChannelBuffer &local_receiver_channel,
     std::array, NUM_SENDER_CHANNELS> &local_sender_channels,
     std::array, NUM_SENDER_CHANNELS> &local_sender_channel_worker_interfaces,
@@ -849,12 +874,12 @@ bool all_channels_drained(tt::fabric::EthChannelBuffer &lo
 * Every loop iteration visit a sender channel and the receiver channel. Switch between sender
 * channels every iteration unless it is unsafe/undesirable to do so (e.g. for performance reasons).
 */
-template
+template
 void run_fabric_edm_main_loop(
     tt::fabric::EthChannelBuffer &local_receiver_channel,
     std::array, NUM_SENDER_CHANNELS> &local_sender_channels,
     std::array, NUM_SENDER_CHANNELS> &local_sender_channel_worker_interfaces,
-    tt::fabric::WorkerToFabricEdmSender &downstream_edm_noc_interface,
+    tt::fabric::EdmToEdmSender &downstream_edm_noc_interface,
     std::array, NUM_SENDER_CHANNELS> &remote_sender_channels,
     tt::fabric::EthChannelBuffer &remote_receiver_channel,
     volatile tt::fabric::TerminationSignal *termination_signal_ptr,
@@ -864,7 +889,6 @@ void run_fabric_edm_main_loop(
     std::array &sender_channel_packet_recorders) {
     std::array sender_states = {
         SenderState::SENDER_WAIT_WORKER_HANDSHAKE, SenderState::SENDER_WAIT_WORKER_HANDSHAKE};
-    ReceiverState receiver_state = ReceiverState::RECEIVER_WAITING_FOR_ETH;
     size_t sender_channel_index = 0;
     size_t did_nothing_count = 0;
     *termination_signal_ptr = tt::fabric::TerminationSignal::KEEP_RUNNING;
@@ -883,6 +907,11 @@ void run_fabric_edm_main_loop(
     WriteTransactionIdTracker receiver_channel_trid_tracker;

+    // This value defines the number of loop iterations we perform of the main control sequence before exiting
+    // to check for termination and context switch. Removing these checks from the inner loop can drastically
+    // improve performance. The value of 32 was chosen somewhat empirically and then raised up slightly.
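The constant this comment documents follows just below. In isolation the pattern is "check rarely, spin hotly", which looks like the following self-contained sketch, where `worker_step` and `should_stop` stand in for the sender/receiver steps and the termination-signal read:

#include <cstdint>

// should_stop() is the comparatively expensive check; worker_step() is the
// hot path that runs back to back without any such checks.
template <typename Step, typename Stop>
void run_hot_loop(Step&& worker_step, Stop&& should_stop, uint32_t inner_iters = 32) {
    while (!should_stop()) {  // checked once per outer pass
        for (uint32_t i = 0; i < inner_iters; i++) {
            worker_step();    // no termination/context-switch checks here
        }
    }
}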
+    constexpr uint32_t DEFAULT_ITERATIONS_BETWEEN_CTX_SWITCH_AND_TEARDOWN_CHECKS = 32;
+
     while (!got_immediate_termination_signal(termination_signal_ptr)) {
         bool got_graceful_termination = got_graceful_termination_signal(termination_signal_ptr);
         if (got_graceful_termination) {
@@ -894,33 +923,41 @@ void run_fabric_edm_main_loop(
                 return;
             }
         }
-
-        // Capture these to see if we made progress
-        auto old_recv_state = receiver_state;
-
-        // There are some cases, mainly for performance, where we don't want to switch between sender channels
-        // so we interoduce this to provide finer grain control over when we disable the automatic switching
-        bool did_something_sender = run_sender_channel_step(
-            local_sender_channels[sender_channel_index],
-            local_sender_channel_worker_interfaces[sender_channel_index],
-            outbound_to_receiver_channel_pointers,
-            remote_receiver_channel,
-            sender_channel_counters_ptrs[sender_channel_index],
-            sender_channel_packet_recorders[sender_channel_index],
-            channel_connection_established[sender_channel_index],
-            sender_channel_index);
-
-        sender_channel_index = 1 - sender_channel_index;
-
-        run_receiver_channel_step(
-            local_receiver_channel, remote_sender_channels, downstream_edm_noc_interface, receiver_channel_counters_ptr,
-            remote_eth_sender_wrptrs,
-            receiver_channel_pointers,
-            receiver_channel_packet_recorder,
-            receiver_channel_trid_tracker,
-            &receiver_state);
-
-        bool did_something = did_something_sender || old_recv_state != receiver_state;
+        bool did_something = false;
+        for (size_t i = 0; i < DEFAULT_ITERATIONS_BETWEEN_CTX_SWITCH_AND_TEARDOWN_CHECKS; i++) {
+            // Capture these to see if we made progress
+
+            // There are some cases, mainly for performance, where we don't want to switch between sender channels
+            // so we introduce this to provide finer grain control over when we disable the automatic switching
+            bool did_something_sender = run_sender_channel_step(
+                local_sender_channels[0],
+                local_sender_channel_worker_interfaces[0],
+                outbound_to_receiver_channel_pointers,
+                remote_receiver_channel,
+                sender_channel_counters_ptrs[0],
+                sender_channel_packet_recorders[0],
+                channel_connection_established[0],
+                0);
+
+            run_receiver_channel_step(
+                local_receiver_channel, remote_sender_channels, downstream_edm_noc_interface, receiver_channel_counters_ptr,
+                remote_eth_sender_wrptrs,
+                receiver_channel_pointers,
+                receiver_channel_packet_recorder,
+                receiver_channel_trid_tracker);
+
+            bool did_something_sender2 = run_sender_channel_step(
+                local_sender_channels[1],
+                local_sender_channel_worker_interfaces[1],
+                outbound_to_receiver_channel_pointers,
+                remote_receiver_channel,
+                sender_channel_counters_ptrs[1],
+                sender_channel_packet_recorders[1],
+                channel_connection_established[1],
+                1);
+
+            did_something = did_something || did_something_sender || did_something_sender2;
+        }

         if (did_something) {
             did_nothing_count = 0;
@@ -1113,7 +1150,7 @@ void kernel_main() {
     }
     auto downstream_edm_noc_interface = has_downstream_edm_buffer_connection
-        ?
tt::fabric::EdmToEdmSender( //persistent_mode -> hardcode to false because for EDM -> EDM // connections we must always use semaphore lookup false, @@ -1129,7 +1166,7 @@ void kernel_main() { reinterpret_cast(edm_forwarding_semaphore_address), reinterpret_cast(edm_teardown_semaphore_address), downstream_noc_interface_buffer_index_local_addr) - : tt::fabric::WorkerToFabricEdmSender(); + : tt::fabric::EdmToEdmSender(); auto local_receiver_channel = tt::fabric::EthChannelBuffer( local_receiver_channel_buffer_address, diff --git a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover_channels.hpp b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover_channels.hpp index 2285a6c42cbe..369c4f57f335 100644 --- a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover_channels.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover_channels.hpp @@ -17,148 +17,9 @@ #include "cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_types.hpp" #include "cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" #include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp" - +#include "cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_flow_control_helpers.hpp" namespace tt::fabric { -template -class NamedType -{ -public: - FORCE_INLINE explicit NamedType(T const& value) : value_(value) {} - FORCE_INLINE explicit NamedType(T&& value) : value_(std::move(value)) {} - FORCE_INLINE NamedType &operator=(NamedType const& rhs) = default; - FORCE_INLINE T& get() { return value_; } - FORCE_INLINE T const& get() const {return value_; } - FORCE_INLINE operator T() const { return value_; } - FORCE_INLINE operator T&() { return value_; } -private: - T value_; -}; - -using BufferIndex = NamedType; -using BufferPtr = NamedType; - - -// Increments val and wraps to 0 if it reaches limit -template -FORCE_INLINE -auto wrap_increment(T val) -> T { - static_assert(LIMIT != 0, "wrap_increment called with limit of 0; it must be greater than 0"); - constexpr bool is_pow2 = is_power_of_2(LIMIT); - if constexpr (LIMIT == 1) { - return val; - } else if constexpr (LIMIT == 2) { - return 1 - val; - } else if constexpr (is_pow2) { - return (val + 1) & (LIMIT - 1); - } else { - return (val == static_cast(LIMIT - 1)) ? static_cast(0) : static_cast(val + 1); - } -} -template -FORCE_INLINE -auto wrap_increment_n(T val, uint8_t increment) -> T { - static_assert(LIMIT != 0, "wrap_increment called with limit of 0; it must be greater than 0"); - constexpr bool is_pow2 = is_power_of_2(LIMIT); - if constexpr (LIMIT == 1) { - return val; - } else if constexpr (LIMIT == 2) { - return 1 - val; - } else if constexpr (is_pow2) { - return (val + increment) & (LIMIT - 1); - } else { - T new_unadjusted_val = val + increment; - bool wraps = new_unadjusted_val >= LIMIT; - return wraps ? 
static_cast(new_unadjusted_val - LIMIT) : static_cast(new_unadjusted_val); - } -} - -template -FORCE_INLINE -auto normalize_ptr(BufferPtr ptr) -> BufferIndex { - static_assert(NUM_BUFFERS != 0, "normalize_ptr called with NUM_BUFFERS of 0; it must be greater than 0"); - constexpr bool is_size_pow2 = (NUM_BUFFERS & (NUM_BUFFERS - 1)) == 0; - constexpr bool is_size_2 = NUM_BUFFERS == 2; - constexpr bool is_size_1 = NUM_BUFFERS == 1; - constexpr uint8_t wrap_mask = NUM_BUFFERS - 1; - if constexpr (is_size_pow2) { - return BufferIndex{ptr & wrap_mask}; - } else if constexpr (is_size_2) { - return BufferIndex{(uint8_t)1 - ptr}; - } else if constexpr (is_size_1) { - return BufferIndex{0}; - } else { - // note it may make sense to calculate this only when we increment - // which will save calculations overall (but may add register pressure) - // and introduce undesirable loads - bool normalize = ptr >= NUM_BUFFERS; - uint8_t normalized_ptr = ptr.get() - static_cast(normalize * NUM_BUFFERS); - ASSERT(normalized_ptr < NUM_BUFFERS); - return BufferIndex{normalized_ptr}; - } -} - - -template -class ChannelBufferPointer { - static_assert(NUM_BUFFERS <= std::numeric_limits::max() / 2, "NUM_BUFFERS must be less than or half of std::numeric_limits::max() due to the internal implementation"); - public: - static constexpr bool is_size_pow2 = (NUM_BUFFERS & (NUM_BUFFERS - 1)) == 0; - static constexpr bool is_size_2 = NUM_BUFFERS == 2; - static constexpr bool is_size_1 = NUM_BUFFERS == 1; - static constexpr uint8_t ptr_wrap_size = 2 * NUM_BUFFERS; - - // Only to use if is_size_pow2 - static constexpr uint8_t ptr_wrap_mask = (2 * NUM_BUFFERS) - 1; - static constexpr uint8_t buffer_wrap_mask = NUM_BUFFERS - 1; - ChannelBufferPointer() : ptr(0) {} - /* - * Returns the "raw" pointer - not usable to index the buffer channel - */ - FORCE_INLINE BufferPtr get_ptr() const { - return this->ptr; - } - - FORCE_INLINE bool is_caught_up_to(ChannelBufferPointer const& leading_ptr) const { - return this->is_caught_up_to(leading_ptr.get_ptr()); - } - FORCE_INLINE uint8_t distance_behind(ChannelBufferPointer const& leading_ptr) const { - return this->distance_behind(leading_ptr.get_ptr()); - } - - /* - * Returns the buffer index pointer which is usable to index into the buffer memory - */ - FORCE_INLINE BufferIndex get_buffer_index() const { - return BufferIndex{normalize_ptr(this->ptr)}; - } - - FORCE_INLINE void increment_n(uint8_t n) { - this->ptr = BufferPtr{wrap_increment_n<2*NUM_BUFFERS>(this->ptr.get(), n)}; - } - FORCE_INLINE void increment() { - this->ptr = wrap_increment<2*NUM_BUFFERS>(this->ptr); - } - - private: - // Make these private to make sure caller doesn't accidentally mix two pointers pointing to - // different sized channels - FORCE_INLINE bool is_caught_up_to(BufferPtr const& leading_ptr) const { - return this->get_ptr() == leading_ptr; - } - FORCE_INLINE uint8_t distance_behind(BufferPtr const& leading_ptr) const { - bool leading_gte_trailing_ptr = leading_ptr >= this->ptr; - if constexpr (is_size_pow2) { - return (leading_ptr - this->ptr) & ptr_wrap_mask; - } else { - return leading_gte_trailing_ptr ? 
- leading_ptr - this->ptr : - ptr_wrap_size - (this->ptr - leading_ptr); - } - } - BufferPtr ptr = BufferPtr{0}; -}; - template FORCE_INLINE auto wrap_increment(T val, size_t max) { @@ -310,7 +171,7 @@ struct EdmChannelWorkerInterface { (uint32_t)worker_info.worker_xy.x, (uint32_t)worker_info.worker_xy.y, worker_info.worker_teardown_semaphore_address); // Set connection to unused so it's available for next worker - *this->connection_live_semaphore = tt::fabric::WorkerToFabricEdmSender::unused_connection_value; + *this->connection_live_semaphore = tt::fabric::EdmToEdmSender<0>::unused_connection_value; *reinterpret_cast(&(worker_location_info_ptr->edm_rdptr)) = last_edm_rdptr_value; @@ -329,8 +190,8 @@ struct EdmChannelWorkerInterface { worker_location_info_ptr->edm_rdptr = local_ackptr.get_ptr(); } - [[nodiscard]] FORCE_INLINE bool has_worker_teardown_request() const { return *connection_live_semaphore == tt::fabric::WorkerToFabricEdmSender::close_connection_request_value; } - [[nodiscard]] FORCE_INLINE bool connection_is_live() const { return *connection_live_semaphore == tt::fabric::WorkerToFabricEdmSender::open_connection_value; } + [[nodiscard]] FORCE_INLINE bool has_worker_teardown_request() const { return *connection_live_semaphore == tt::fabric::EdmToEdmSender<0>::close_connection_request_value; } + [[nodiscard]] FORCE_INLINE bool connection_is_live() const { return *connection_live_semaphore == tt::fabric::EdmToEdmSender<0>::open_connection_value; } volatile EDMChannelWorkerLocationInfo *worker_location_info_ptr; volatile tt_l1_ptr uint32_t *const remote_producer_wrptr; From 2958cac744e213b1816e1565b92b71c19786f07e Mon Sep 17 00:00:00 2001 From: Nour Ardo Date: Tue, 18 Feb 2025 10:38:18 -0500 Subject: [PATCH 2/8] Fix shape in outer (#17492) ### Ticket Link to Github Issue https://github.com/tenstorrent/tt-metal/issues/16882 ### Problem description ttnn::outer fails after tilizing the inputs ### What's changed outer op is checking the padded size of the inputs which is causing the error. This PR changes the shape used in outer ### Checklist - [x] Post commit CI passes https://github.com/tenstorrent/tt-metal/actions/runs/13167635235 - [ ] Blackhole Post commit (if applicable) - [ ] Model regression CI testing passes (if applicable) - [ ] Device performance regression CI testing passes (if applicable) - [ ] **(For models and ops writers)** Full [new models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml) tests passes - [ ] New/Existing tests provide coverage for changes --- .../eltwise/binary/device/binary_composite_op.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp index a4dac8812f1e..7a9cbc4be601 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp @@ -476,8 +476,8 @@ Tensor _scatter(const Tensor& input_a, const Tensor& input_b, const std::optiona * by running reshape. 
*/ Tensor _outer(const Tensor& input_a, const Tensor& input_b, const std::optional& output_mem_config) { - const ttnn::Shape s_a = input_a.padded_shape(); - const ttnn::Shape s_b = input_b.padded_shape(); + const ttnn::Shape s_a = input_a.get_logical_shape(); + const ttnn::Shape s_b = input_b.get_logical_shape(); auto num_ones = [](const ttnn::Shape& s) -> uint32_t { uint32_t num1s = 0; for (uint32_t idx = 0; idx < 4; idx++) { @@ -497,10 +497,12 @@ Tensor _outer(const Tensor& input_a, const Tensor& input_b, const std::optional< Tensor b_slim = input_b; if (!skip_reshape_a) { - a_slim = ttnn::reshape(input_a, ttnn::Shape{std::array{1, 1, input_a.volume(), 1}}); + uint32_t a_volume = s_a[0] * s_a[1] * s_a[2] * s_a[3]; + a_slim = ttnn::reshape(input_a, ttnn::Shape{std::array{1, 1, a_volume, 1}}); } if (!skip_reshape_b) { - b_slim = ttnn::reshape(input_b, ttnn::Shape{std::array{1, 1, 1, input_b.volume()}}); + uint32_t b_volume = s_b[0] * s_b[1] * s_b[2] * s_b[3]; + b_slim = ttnn::reshape(input_b, ttnn::Shape{std::array{1, 1, 1, b_volume}}); } a_slim = ttnn::to_layout(a_slim, ttnn::TILE_LAYOUT, std::nullopt, std::nullopt, (IDevice*)nullptr); b_slim = ttnn::to_layout(b_slim, ttnn::TILE_LAYOUT, std::nullopt, std::nullopt, (IDevice*)nullptr); From be555b1d3d9c165f24c2f1019be3aca179e59b1c Mon Sep 17 00:00:00 2001 From: Nicholas Smith Date: Fri, 14 Feb 2025 15:12:10 -0600 Subject: [PATCH 3/8] Install RPATH ORIGIN Add ORIGIN to both ttnn and tt_metal library RPATH's to simplify wheel installation for upstream consumers. --- tt_metal/CMakeLists.txt | 2 +- ttnn/CMakeLists.txt | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/tt_metal/CMakeLists.txt b/tt_metal/CMakeLists.txt index 44f80bb4ec0f..11c36177fa9e 100644 --- a/tt_metal/CMakeLists.txt +++ b/tt_metal/CMakeLists.txt @@ -131,7 +131,7 @@ set_target_properties( tt_metal PROPERTIES INSTALL_RPATH - "${PROJECT_BINARY_DIR}/lib" + "${PROJECT_BINARY_DIR}/lib;$ORIGIN" ADDITIONAL_CLEAN_FILES "${PROJECT_BINARY_DIR}/lib;${PROJECT_BINARY_DIR}/obj" ) diff --git a/ttnn/CMakeLists.txt b/ttnn/CMakeLists.txt index 7eb79f85d0df..eb63d038eda8 100644 --- a/ttnn/CMakeLists.txt +++ b/ttnn/CMakeLists.txt @@ -861,6 +861,7 @@ TT_ENABLE_UNITY_BUILD(ttnn) set(TTNN_INSTALL_RPATH "${PROJECT_BINARY_DIR}/lib" "$ORIGIN/build/lib" + "$ORIGIN" ) #Make sure library built is _ttnn.so and that it can find all it's linked libraries From ed210e7dae8dafba91a5434d6fbb50dc7dce8932 Mon Sep 17 00:00:00 2001 From: Atul Krishnadas Date: Tue, 18 Feb 2025 08:36:59 -0800 Subject: [PATCH 4/8] #17094: fill implicit pad sharded using the new shardedAddrGen (#17692) --- .../unit_tests/operations/test_fill_pad.py | 153 +++++++++++++++++- .../fill_pad/device/fill_pad_op.cpp | 6 - .../device/fill_pad_program_factory.cpp | 13 +- .../kernels/dataflow/fill_pad_writer.cpp | 28 +++- 4 files changed, 187 insertions(+), 13 deletions(-) diff --git a/tests/ttnn/unit_tests/operations/test_fill_pad.py b/tests/ttnn/unit_tests/operations/test_fill_pad.py index 48dff554b6c3..489cb371325c 100644 --- a/tests/ttnn/unit_tests/operations/test_fill_pad.py +++ b/tests/ttnn/unit_tests/operations/test_fill_pad.py @@ -5,6 +5,7 @@ import pytest import torch import ttnn +import math from tests.ttnn.utils_for_testing import assert_with_pcc from models.utility_functions import torch_random, run_for_wormhole_b0 @@ -52,12 +53,12 @@ def create_nd_padded_tiled_tensor(shape, tile_size, fill_value, dtype): ttnn.bfloat16: torch.float32, } +# torch.set_printoptions(threshold=10000) + -# @pytest.mark.parametrize("shape", 
[(2, 32, 300, 256)]) @pytest.mark.parametrize( "shape", [ - # 2D shapes with edge cases for fill_pad (1, 16), (16, 1), (1, 17), @@ -67,6 +68,7 @@ def create_nd_padded_tiled_tensor(shape, tile_size, fill_value, dtype): (31, 31), (33, 33), (65, 65), + (97, 97), (1, 2, 3, 2, 1, 2, 97, 97), ], ) @@ -96,3 +98,150 @@ def test_fill_pad( padded_torch_output_tensor = ttnn.from_device(output_tensor).to_torch() assert_with_pcc(padded_torch_tensor, padded_torch_output_tensor) + + +@pytest.mark.parametrize("fill_value", [1]) +@pytest.mark.parametrize( + "shape", + [ + (1, 16), + (97, 97), + ], +) +@pytest.mark.parametrize( + "shard_scheme", + [ + ttnn.TensorMemoryLayout.HEIGHT_SHARDED, + ttnn.TensorMemoryLayout.WIDTH_SHARDED, + ttnn.TensorMemoryLayout.BLOCK_SHARDED, + ], +) +@pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.uint32]) +def test_fill_pad_complex_sharding(device, fill_value, shape, shard_scheme, dtype): + torch.manual_seed(1234) + torch_input_tensor, padded_torch_tensor = create_nd_padded_tiled_tensor( + shape, 32, fill_value, ttnn_dtype_to_torch_dtype[dtype] + ) + num_cores_xblock = 2 + num_cores_yblock = 4 + num_cores = num_cores_xblock * num_cores_yblock + + # Add complex shard grid with 2 X 4 = 8 cores + shard_grid = ttnn.CoreRangeSet( + [ + ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, 1)), + ttnn.CoreRange(ttnn.CoreCoord(2, 0), ttnn.CoreCoord(3, 1)), + ttnn.CoreRange(ttnn.CoreCoord(0, 4), ttnn.CoreCoord(0, 5)), + ] + ) + + tiles_per_2d = padded_torch_tensor.shape[-2] * padded_torch_tensor.shape[-1] / (32 * 32) + dims_b4_last_dim = 1 + for i in range(len(padded_torch_tensor.shape) - 1): + dims_b4_last_dim *= padded_torch_tensor.shape[i] + + shard_shape = [32, 32] + if shard_scheme == ttnn.TensorMemoryLayout.WIDTH_SHARDED: + shard_shape = (dims_b4_last_dim, 32 * math.ceil((math.ceil(padded_torch_tensor.shape[-1] / 32) / num_cores))) + elif shard_scheme == ttnn.TensorMemoryLayout.HEIGHT_SHARDED: + tile_widths_per_core = math.ceil(dims_b4_last_dim / num_cores) + shard_shape = (32 * tile_widths_per_core, padded_torch_tensor.shape[-1]) + elif shard_scheme == ttnn.TensorMemoryLayout.BLOCK_SHARDED: + tile_widths_per_core = math.ceil(dims_b4_last_dim / num_cores_xblock) + shard_shape = ( + 32 * tile_widths_per_core, + 32 * math.ceil((math.ceil(padded_torch_tensor.shape[-1] / 32) / num_cores_yblock)), + ) + else: + shard_shape = (math.ceil(math.sqrt(tiles_per_core)), math.ceil(math.sqrt(tiles_per_core))) + + shard_spec = ttnn.ShardSpec(shard_grid, shard_shape, ttnn.ShardOrientation.ROW_MAJOR) + output_mem_config = ttnn.MemoryConfig( + shard_scheme, + ttnn.BufferType.L1, + shard_spec, + ) + + input_tensor = ttnn.to_device( + ttnn.from_torch(torch_input_tensor, dtype=dtype, layout=ttnn.TILE_LAYOUT), + device, + memory_config=output_mem_config, + ) + + output_tensor = ttnn.fill_implicit_tile_padding(input_tensor, fill_value, memory_config=ttnn.DRAM_MEMORY_CONFIG) + padded_torch_output_tensor = ttnn.from_device(output_tensor).to_torch() + + assert_with_pcc(padded_torch_tensor, padded_torch_output_tensor, 0.99) + + +@pytest.mark.parametrize("fill_value", [1]) +@pytest.mark.parametrize( + "shape", + [ + (1, 16), + (16, 1), + (17, 17), + (17, 1), + (16, 16), + (17, 17), + (31, 31), + (33, 33), + (97, 97), + ], +) +@pytest.mark.parametrize( + "shard_scheme", + [ + ttnn.TensorMemoryLayout.HEIGHT_SHARDED, + ttnn.TensorMemoryLayout.WIDTH_SHARDED, + ttnn.TensorMemoryLayout.BLOCK_SHARDED, + ], +) +@pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.uint32]) +def 
test_fill_pad_sharded(device, fill_value, shape, shard_scheme, dtype): + torch.manual_seed(1234) + torch_input_tensor, padded_torch_tensor = create_nd_padded_tiled_tensor( + shape, 32, fill_value, ttnn_dtype_to_torch_dtype[dtype] + ) + + num_cores_x = 8 + num_cores_y = 7 + num_cores = num_cores_x * num_cores_y + shard_grid = ttnn.CoreRangeSet( + [ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(num_cores_x - 1, num_cores_y - 1))] + ) + + tiles_per_2d = padded_torch_tensor.shape[-2] * padded_torch_tensor.shape[-1] / (32 * 32) + dims_b4_last_dim = 1 + for i in range(len(padded_torch_tensor.shape) - 1): + dims_b4_last_dim *= padded_torch_tensor.shape[i] + + shard_shape = [32, 32] + if shard_scheme == ttnn.TensorMemoryLayout.WIDTH_SHARDED: + shard_shape = (dims_b4_last_dim, 32 * math.ceil((math.ceil(padded_torch_tensor.shape[-1] / 32) / num_cores))) + elif shard_scheme == ttnn.TensorMemoryLayout.HEIGHT_SHARDED: + tile_widths_per_core = math.ceil(dims_b4_last_dim / num_cores) + shard_shape = (32 * tile_widths_per_core, padded_torch_tensor.shape[-1]) + elif shard_scheme == ttnn.TensorMemoryLayout.BLOCK_SHARDED: + tile_widths_per_core = math.ceil(dims_b4_last_dim / num_cores_x) + shard_shape = (32 * tile_widths_per_core, 32 * math.ceil((padded_torch_tensor.shape[-1] / 32 / num_cores_y))) + else: + shard_shape = (math.ceil(math.sqrt(tiles_per_core)), math.ceil(math.sqrt(tiles_per_core))) + + shard_spec = ttnn.ShardSpec(shard_grid, shard_shape, ttnn.ShardOrientation.ROW_MAJOR) + output_mem_config = ttnn.MemoryConfig( + shard_scheme, + ttnn.BufferType.L1, + shard_spec, + ) + + input_tensor = ttnn.to_device( + ttnn.from_torch(torch_input_tensor, dtype=dtype, layout=ttnn.TILE_LAYOUT), + device, + memory_config=output_mem_config, + ) + + output_tensor = ttnn.fill_implicit_tile_padding(input_tensor, fill_value, memory_config=ttnn.DRAM_MEMORY_CONFIG) + padded_torch_output_tensor = ttnn.from_device(output_tensor).to_torch() + + assert_with_pcc(padded_torch_tensor, padded_torch_output_tensor, 0.99) diff --git a/ttnn/cpp/ttnn/operations/data_movement/fill_pad/device/fill_pad_op.cpp b/ttnn/cpp/ttnn/operations/data_movement/fill_pad/device/fill_pad_op.cpp index 78c13267c69d..3de81f581ff6 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/fill_pad/device/fill_pad_op.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/fill_pad/device/fill_pad_op.cpp @@ -14,12 +14,6 @@ namespace ttnn::operations::data_movement { void FillPad::validate(const std::vector& input_tensors) const { const auto& input_tensor_a = input_tensors.at(0); TT_FATAL(input_tensor_a.get_layout() == TILE_LAYOUT, "FillPad should only be used for tile layout"); - TT_FATAL( - input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED, - "FillPad does not currently support sharding"); - TT_FATAL( - this->output_mem_config.memory_layout == TensorMemoryLayout::INTERLEAVED, - "FillPad does not currently support sharding"); } std::vector FillPad::compute_output_specs(const std::vector& input_tensors) const { diff --git a/ttnn/cpp/ttnn/operations/data_movement/fill_pad/device/fill_pad_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/fill_pad/device/fill_pad_program_factory.cpp index e798d9f0c3f2..b07c6e65bf03 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/fill_pad/device/fill_pad_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/fill_pad/device/fill_pad_program_factory.cpp @@ -9,6 +9,7 @@ #include #include #include +#include "ttnn/operations/ccl/sharding_addrgen_helper.hpp" bool 
is_power_of_two_at_least_32(uint32_t value) { return value >= 32 && (value & (value - 1)) == 0; } @@ -68,6 +69,8 @@ operation::ProgramWithCallbacks fill_pad_multi_core(const Tensor& input_tensor, padded_height / tt::constants::TILE_HEIGHT * padded_width / tt::constants::TILE_HEIGHT; uint32_t tiles_per_tile_row = padded_width / tt::constants::TILE_HEIGHT; + bool sharded = input_tensor.memory_config().memory_layout != TensorMemoryLayout::INTERLEAVED; + // create kernel // reader compile time args std::vector writer_compile_time_args = { @@ -82,7 +85,12 @@ operation::ProgramWithCallbacks fill_pad_multi_core(const Tensor& input_tensor, (std::uint32_t)tiles_per_2d_tensor, (std::uint32_t)tiles_per_tile_row, (std::uint32_t)tt::constants::TILE_HEIGHT, - (std::uint32_t)tt::constants::FACE_HEIGHT}; + (std::uint32_t)tt::constants::FACE_HEIGHT, + (std::uint32_t)sharded}; + + if (sharded) { + shard_builder::extend_sharding_compile_time_args(input_tensor, writer_compile_time_args); + } tt::tt_metal::KernelHandle writer_kernel_id = tt::tt_metal::CreateKernel( program, @@ -102,6 +110,9 @@ operation::ProgramWithCallbacks fill_pad_multi_core(const Tensor& input_tensor, { writer_runtime_args[2] = tile_offset; writer_runtime_args[3] = local_num_2d_tensors; + if (sharded) { + shard_builder::extend_sharding_run_time_args(input_tensor, writer_runtime_args); + } tt_metal::SetRuntimeArgs(program, writer_kernel_id, core, writer_runtime_args); } diff --git a/ttnn/cpp/ttnn/operations/data_movement/fill_pad/device/kernels/dataflow/fill_pad_writer.cpp b/ttnn/cpp/ttnn/operations/data_movement/fill_pad/device/kernels/dataflow/fill_pad_writer.cpp index a94aa7fdea0f..91d166e95100 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/fill_pad/device/kernels/dataflow/fill_pad_writer.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/fill_pad/device/kernels/dataflow/fill_pad_writer.cpp @@ -3,6 +3,8 @@ // SPDX-License-Identifier: Apache-2.0 #include "dataflow_api.h" +#include "cpp/ttnn/operations/ccl/shared_with_host/sharded_tensor_addr_gen.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/kernel_common/sharding_addrgen.hpp" void kernel_main() { constexpr uint32_t cb_id_0 = get_compile_time_arg_val(0); @@ -19,20 +21,38 @@ void kernel_main() { constexpr uint32_t tile_size = get_compile_time_arg_val(10); constexpr uint32_t tile_hw = tile_size * tile_size; constexpr uint32_t face_size = get_compile_time_arg_val(11); +#define SHARDED get_compile_time_arg_val(12) == 1 constexpr uint32_t face_hw = face_size * face_size; constexpr uint32_t alignment_adjustor = 16; - uint32_t dst_addr = get_arg_val(0); - uint32_t cb_page_size = get_arg_val(1); - uint32_t starting_tile_offset = get_arg_val(2); - uint32_t num_2d_tensors = get_arg_val(3); + uint32_t rt_arg_ind = 0; + uint32_t dst_addr = get_arg_val(rt_arg_ind++); + uint32_t cb_page_size = get_arg_val(rt_arg_ind++); + uint32_t starting_tile_offset = get_arg_val(rt_arg_ind++); + uint32_t num_2d_tensors = get_arg_val(rt_arg_ind++); +#if (SHARDED) + typedef ShardedInfo< + get_compile_time_arg_val(13), + get_compile_time_arg_val(14), + get_compile_time_arg_val(15), + get_compile_time_arg_val(16), + get_compile_time_arg_val(17), + get_compile_time_arg_val(18), + get_compile_time_arg_val(19)> + tensor_shard_info; + + const auto [mapping_table, rt_increment] = + experimental::shard_addr_gen_utils::get_shard_map(get_arg_addr(rt_arg_ind)); + experimental::ShardedAddrGen s0 = {.bank_base_address = dst_addr, .shard_array = mapping_table}; +#else const DataFormat data_format = get_dataformat(cb_id_0); const 
InterleavedAddrGenFast s0 = {
         .bank_base_address = dst_addr,
         .page_size = tile_hw * element_size_bytes,
         .data_format = data_format  // page_size needs to be tile_size_bytes
     };
+#endif

     // Reserve and push the fill value into the circular buffer
     cb_reserve_back(cb_id_0, 1);

From 6e257a5c5fdbbd7d4b1bd6944936c82ece768460 Mon Sep 17 00:00:00 2001
From: William Ly
Date: Tue, 18 Feb 2025 12:24:08 -0500
Subject: [PATCH 5/8] [skip ci] #0: Fix produce_data bug "jq: error: writing
 output failed: Broken pipe" (#17953)

### Ticket

### Problem description
Recent produce_data workflows started failing on the line that queries the
GitHub API for artifacts whose names start with "test_reports_*", erroring
out with `jq: error: writing output failed: Broken pipe`:
https://github.com/tenstorrent/tt-metal/actions/runs/13382103493/job/37372300588#step:7:9

`grep -q` exits as soon as it sees a match, which closes the pipe while `jq`
is still writing to it.

### What's changed
Store all output from `gh api` in a variable, then `grep -q` it afterwards.

### Checklist
- [x] New/Existing tests provide coverage for changes

Same failing workflow, rerun on branch with fix:
https://github.com/tenstorrent/tt-metal/actions/runs/13396159663
---
 .../github/download_cicd_logs_and_artifacts.sh | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/infra/data_collection/github/download_cicd_logs_and_artifacts.sh b/infra/data_collection/github/download_cicd_logs_and_artifacts.sh
index 1c5d3852a8da..48e265c6f61a 100755
--- a/infra/data_collection/github/download_cicd_logs_and_artifacts.sh
+++ b/infra/data_collection/github/download_cicd_logs_and_artifacts.sh
@@ -17,7 +17,9 @@ download_artifacts() {
     local repo=$1
     local workflow_run_id=$2

-    if gh api --paginate /repos/$repo/actions/runs/$workflow_run_id/artifacts | jq '.artifacts[] | .name' | grep -q "test_reports_"; then
+    echo "[info] Downloading test reports for workflow run $workflow_run_id"
+    api_output=$(gh api --paginate /repos/$repo/actions/runs/$workflow_run_id/artifacts | jq -r '.artifacts[] | .name')
+    if echo "$api_output" | grep -q "test_reports_"; then
         gh run download --repo $repo -D generated/cicd/$workflow_run_id/artifacts --pattern test_reports_* $workflow_run_id
     else
         echo "[Warning] Test reports not found for workflow run $workflow_run_id"

From d08245ef3c03197bab2b199a49e6fd5d99f3b195 Mon Sep 17 00:00:00 2001
From: Oleg Milyutin
Date: Tue, 18 Feb 2025 12:37:44 -0500
Subject: [PATCH 6/8] #0: Include <span> in xtensor conversion utils (#17948)

### Ticket
N/A

### Problem description
`tt::stl::SmallVector` removed a dependency on C++20 `std::span`, which was
transitively included in this header. This
[breaks](https://github.com/tenstorrent/tt-mlir/actions/runs/13384256221/job/37378049606?pr=2194)
tt-mlir.

### What's changed
Include `<span>`.

### Checklist
- Compilation tested locally, @brataTT confirmed the fix works for tt-mlir.
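For context, a minimal sketch of the failure mode (the helper below is illustrative, not code from tt-mlir): a translation unit that uses `std::span` but never includes `<span>` itself compiles only as long as some other header pulls it in transitively, and breaks the moment that header drops the include.

```cpp
// Illustrative only: a stand-in for any consumer of this header.
// Before this patch, omitting <span> here happened to compile because a
// transitive include provided it; after the SmallVector cleanup it does not.
#include <span>    // the explicit include this patch adds
#include <vector>

// hypothetical helper in the spirit of the conversion utils
template <typename T>
std::span<const T> as_span(const std::vector<T>& v) {
    return {v.data(), v.size()};
}
```

Including what you use directly is what keeps downstream builds like tt-mlir from breaking on unrelated header cleanups.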
---
 ttnn/cpp/ttnn/tensor/xtensor/conversion_utils.hpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/ttnn/cpp/ttnn/tensor/xtensor/conversion_utils.hpp b/ttnn/cpp/ttnn/tensor/xtensor/conversion_utils.hpp
index df97212e6489..fa7b15c6ee40 100644
--- a/ttnn/cpp/ttnn/tensor/xtensor/conversion_utils.hpp
+++ b/ttnn/cpp/ttnn/tensor/xtensor/conversion_utils.hpp
@@ -4,6 +4,7 @@

 #pragma once

+#include <span>
 #include

 #include "ttnn/tensor/tensor.hpp"

From 6573fa85d63b8f2041076cabe33afdb3c3ef9643 Mon Sep 17 00:00:00 2001
From: aagarwalTT
Date: Tue, 18 Feb 2025 11:41:28 -0600
Subject: [PATCH 7/8] Remove gatekeeper kernel from fabric launch

---
 .../kernels/tt_fabric_traffic_gen_tx.cpp      |  13 +-
 .../routing/kernels/tt_fabric_tx_ubench.cpp   |  12 +-
 .../routing/test_tt_fabric_sanity.cpp         | 151 ++++--------------
 tt_fabric/hw/inc/tt_fabric.h                  |   2 +-
 tt_fabric/hw/inc/tt_fabric_api.h              |  36 +----
 tt_fabric/impl/kernels/tt_fabric_router.cpp   |  69 +++++---
 6 files changed, 93 insertions(+), 190 deletions(-)

diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/routing/kernels/tt_fabric_traffic_gen_tx.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/routing/kernels/tt_fabric_traffic_gen_tx.cpp
index 483513270025..2dac3ffaebe6 100644
--- a/tests/tt_metal/tt_metal/perf_microbenchmark/routing/kernels/tt_fabric_traffic_gen_tx.cpp
+++ b/tests/tt_metal/tt_metal/perf_microbenchmark/routing/kernels/tt_fabric_traffic_gen_tx.cpp
@@ -83,10 +83,6 @@ packet_header_t packet_header __attribute__((aligned(16)));
 uint32_t target_address;
 uint32_t noc_offset;
 uint32_t rx_addr_hi;
-
-uint32_t gk_interface_addr_l;
-uint32_t gk_interface_addr_h;
-
 uint32_t controller_noc_offset;

 // flag to check if need to zero out notification addr
@@ -389,11 +385,9 @@ void kernel_main() {
     src_endpoint_id = get_arg_val<uint32_t>(increment_arg_idx(rt_args_idx));
     noc_offset = get_arg_val<uint32_t>(increment_arg_idx(rt_args_idx));
     controller_noc_offset = get_arg_val<uint32_t>(increment_arg_idx(rt_args_idx));
-    uint32_t routing_plane = get_arg_val<uint32_t>(increment_arg_idx(rt_args_idx));
+    uint32_t outbound_eth_chan = get_arg_val<uint32_t>(increment_arg_idx(rt_args_idx));
     dest_device = get_arg_val<uint32_t>(increment_arg_idx(rt_args_idx));
     uint32_t rx_buf_size = get_arg_val<uint32_t>(increment_arg_idx(rt_args_idx));
-    gk_interface_addr_l = get_arg_val<uint32_t>(increment_arg_idx(rt_args_idx));
-    gk_interface_addr_h = get_arg_val<uint32_t>(increment_arg_idx(rt_args_idx));

     if constexpr (ASYNC_WR & test_command) {
         base_target_address = get_arg_val<uint32_t>(increment_arg_idx(rt_args_idx));
@@ -462,9 +456,8 @@ void kernel_main() {
     uint32_t packet_count = 0;

     // initalize client
-    fabric_endpoint_init(client_interface_addr, gk_interface_addr_l, gk_interface_addr_h);
-    routing_table = reinterpret_cast<tt_l1_ptr fabric_router_l1_config_t*>(
-        client_interface->routing_tables_l1_offset + sizeof(fabric_router_l1_config_t) * routing_plane);
+    fabric_endpoint_init(client_interface_addr, outbound_eth_chan);
+    routing_table = reinterpret_cast<tt_l1_ptr fabric_router_l1_config_t*>(client_interface->routing_tables_l1_offset);

     while (true) {
         iter++;
diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/routing/kernels/tt_fabric_tx_ubench.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/routing/kernels/tt_fabric_tx_ubench.cpp
index d9991ed8b675..ae1bebc19deb 100644
--- a/tests/tt_metal/tt_metal/perf_microbenchmark/routing/kernels/tt_fabric_tx_ubench.cpp
+++ b/tests/tt_metal/tt_metal/perf_microbenchmark/routing/kernels/tt_fabric_tx_ubench.cpp
@@ -68,8 +68,6 @@ volatile fabric_client_interface_t* client_interface;
 uint64_t xy_local_addr;
 uint32_t target_address;
 uint32_t noc_offset;
-uint32_t gk_interface_addr_l;
-uint32_t gk_interface_addr_h;
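// Sketch of the new client init flow (illustrative summary; the actual
// changes are in the hunks around this point): instead of resolving routing
// state through the gatekeeper's NOC address pair, a TX kernel now takes an
// outbound ethernet channel as a runtime arg and initializes against it:
//
//   uint32_t outbound_eth_chan = get_arg_val<uint32_t>(increment_arg_idx(rt_args_idx));
//   fabric_endpoint_init(client_interface_addr, outbound_eth_chan);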
uint32_t controller_noc_offset; uint32_t time_seed; @@ -94,11 +92,9 @@ void kernel_main() { src_endpoint_id = get_arg_val(increment_arg_idx(rt_args_idx)); noc_offset = get_arg_val(increment_arg_idx(rt_args_idx)); controller_noc_offset = get_arg_val(increment_arg_idx(rt_args_idx)); - uint32_t routing_plane = get_arg_val(increment_arg_idx(rt_args_idx)); + uint32_t outbound_eth_chan = get_arg_val(increment_arg_idx(rt_args_idx)); dest_device = get_arg_val(increment_arg_idx(rt_args_idx)); uint32_t rx_buf_size = get_arg_val(increment_arg_idx(rt_args_idx)); - gk_interface_addr_l = get_arg_val(increment_arg_idx(rt_args_idx)); - gk_interface_addr_h = get_arg_val(increment_arg_idx(rt_args_idx)); if constexpr (ASYNC_WR & test_command) { base_target_address = get_arg_val(increment_arg_idx(rt_args_idx)); @@ -140,7 +136,7 @@ void kernel_main() { } // initalize client - fabric_endpoint_init(client_interface_addr, gk_interface_addr_l, gk_interface_addr_h); + fabric_endpoint_init(client_interface_addr, outbound_eth_chan); // notify the controller kernel that this worker is ready to proceed notify_traffic_controller(); @@ -161,7 +157,7 @@ void kernel_main() { client_interface->local_pull_request.pull_request.words_read = 0; if constexpr (mcast_data) { fabric_async_write_multicast( - routing_plane, // the network plane to use for this transaction + 0, // the network plane to use for this transaction data_buffer_start_addr, // source address in sender’s memory dest_device >> 16, dest_device & 0xFFFF, @@ -173,7 +169,7 @@ void kernel_main() { s_depth); } else { fabric_async_write( - routing_plane, // the network plane to use for this transaction + 0, // the network plane to use for this transaction data_buffer_start_addr, // source address in sender’s memory dest_device >> 16, dest_device & 0xFFFF, diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/routing/test_tt_fabric_sanity.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/routing/test_tt_fabric_sanity.cpp index a0e91bd4dc29..f9ff6e036706 100644 --- a/tests/tt_metal/tt_metal/perf_microbenchmark/routing/test_tt_fabric_sanity.cpp +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/routing/test_tt_fabric_sanity.cpp @@ -34,15 +34,7 @@ uint32_t time_seed; // decides if the tx puts the data directly on eth or if a noc hop is allowed as well bool allow_1st_noc_hop = false; -// Gatekeeper kernel coordinates -uint32_t gk_x, gk_y; - -// Check if gatekeeper runs on tensix worker or idle ethernet based on the board type -bool run_gk_on_idle_ethernet; - uint32_t routing_table_addr; -uint32_t gk_interface_addr; -uint32_t socket_info_addr; // if the traffic b/w any pair of chips is bi-directional bool bidirectional_traffic; @@ -54,7 +46,6 @@ uint32_t tx_signal_address; uint32_t host_signal_address; // kernels -const std::string gatekeeper_kernel_src = "tt_fabric/impl/kernels/tt_fabric_gatekeeper.cpp"; const std::string router_kernel_src = "tt_fabric/impl/kernels/tt_fabric_router.cpp"; const std::string traffic_controller_src = "tests/tt_metal/tt_metal/perf_microbenchmark/routing/kernels/tt_fabric_traffic_controller.cpp"; @@ -149,11 +140,6 @@ typedef struct test_board { } else { physical_chip_ids = available_chip_ids; } - - // gatekeeper - run on idle ethernet for n300/T3K - if (("n300" == board_type_) || ("t3k" == board_type_)) { - run_gk_on_idle_ethernet = true; - } } void _init_galaxy_board(uint32_t num_chips, bool all_pcie = false) { @@ -468,13 +454,11 @@ typedef struct test_device { std::vector router_virtual_cores; CoreCoord core_range_start_virtual; CoreCoord 
core_range_end_virtual; - CoreCoord gk_logical_core; - CoreCoord gk_phys_core; mesh_id_t mesh_id; chip_id_t logical_chip_id; + uint32_t master_router_idx; uint32_t mesh_chip_id = 0; uint32_t router_mask = 0; - uint32_t gk_noc_offset; metal_SocDescriptor soc_desc; std::unordered_map>> router_worker_map; // router chan to worker logical cores @@ -519,20 +503,7 @@ typedef struct test_device { _generate_router_worker_map(); } - // gatekeeper - if (run_gk_on_idle_ethernet) { - auto idle_eth_cores = device_handle->get_inactive_ethernet_cores(); - if (idle_eth_cores.size() == 0) { - throw std::runtime_error("No idle ethernet cores found on the device"); - } - - gk_logical_core = *idle_eth_cores.begin(); - gk_phys_core = device_handle->ethernet_core_from_logical_core(gk_logical_core); - } else { - gk_logical_core = {gk_x, gk_y}; - gk_phys_core = device_handle->worker_core_from_logical_core(gk_logical_core); - } - gk_noc_offset = tt_metal::hal.noc_xy_encoding(gk_phys_core.x, gk_phys_core.y); + master_router_idx = 0; } void create_router_kernels(std::vector& compile_args, std::map& defines) { @@ -540,14 +511,21 @@ typedef struct test_device { std::vector zero_buf(1, 0); for (auto i = 0; i < num_routers; i++) { + std::vector router_compile_args = compile_args; // setup run time args std::vector runtime_args = { - num_routers, // 0: number of active fabric routers - router_mask, // 1: active fabric router mask - gk_interface_addr, // 2: gk_message_addr_l - gk_noc_offset, // 3: gk_message_addr_h + num_routers, // 0: number of active fabric routers + router_mask, // 1: active fabric router mask + router_logical_cores[master_router_idx].y // 2: master router eth chan }; + // pass is_master flag as compile arg, index 0 is master + if (master_router_idx == i) { + router_compile_args.push_back(1); + } else { + router_compile_args.push_back(0); + } + // initialize the semaphore tt::llrt::write_hex_vec_to_core( device_handle->id(), router_virtual_cores[i], zero_buf, FABRIC_ROUTER_SYNC_SEM); @@ -557,70 +535,25 @@ typedef struct test_device { router_kernel_src, router_logical_cores[i], tt_metal::EthernetConfig{ - .noc = tt_metal::NOC::NOC_0, .compile_args = compile_args, .defines = defines}); + .noc = tt_metal::NOC::NOC_0, .compile_args = router_compile_args, .defines = defines}); tt_metal::SetRuntimeArgs(program_handle, kernel, router_logical_cores[i], runtime_args); } } - void create_gatekeeper_kernel(std::vector& compile_args, std::map& defines) { - uint32_t num_routers = router_logical_cores.size(); - std::vector zero_buf(12, 0); - - std::vector runtime_args = { - num_routers, // 0: number of active fabric routers - router_mask, // 1: active fabric router mask - }; - - // initialize the semaphore - tt::llrt::write_hex_vec_to_core(device_handle->id(), gk_phys_core, zero_buf, gk_interface_addr); - - KernelHandle kernel; - - if (run_gk_on_idle_ethernet) { - kernel = tt_metal::CreateKernel( - program_handle, - gatekeeper_kernel_src, - {gk_logical_core}, - tt_metal::EthernetConfig{ - .eth_mode = Eth::IDLE, - .noc = tt_metal::NOC::NOC_0, - .compile_args = compile_args, - .defines = defines}); - } else { - kernel = tt_metal::CreateKernel( - program_handle, - gatekeeper_kernel_src, - {gk_logical_core}, - tt_metal::DataMovementConfig{ - .processor = tt_metal::DataMovementProcessor::RISCV_0, - .noc = tt_metal::NOC::RISCV_0_default, - .compile_args = compile_args, - .defines = defines}); - } - - tt_metal::SetRuntimeArgs(program_handle, kernel, gk_logical_core, runtime_args); - } - - void wait_for_gatekeeper_sync() { - 
uint32_t gk_status = 0; - uint32_t num_routers = router_logical_cores.size(); - uint32_t sync_addr = gk_interface_addr + offsetof(gatekeeper_info_t, router_sync) + offsetof(sync_word_t, val); - while (num_routers != gk_status) { - gk_status = tt::llrt::read_hex_vec_from_core(device_handle->id(), gk_phys_core, sync_addr, 4)[0]; + void wait_for_router_sync() { + uint32_t master_router_status = 0; + uint32_t expected_val = router_logical_cores.size(); + while (expected_val != master_router_status) { + master_router_status = tt::llrt::read_hex_vec_from_core( + device_handle->id(), router_virtual_cores[master_router_idx], FABRIC_ROUTER_SYNC_SEM, 4)[0]; } } void terminate_router_kernels() { std::vector zero_buf(1, 0); - for (auto& core : router_virtual_cores) { - tt::llrt::write_hex_vec_to_core(device_handle->id(), core, zero_buf, FABRIC_ROUTER_SYNC_SEM); - } - } - - void terminate_gatekeeper_kernel() { - std::vector zero_buf(12, 0); - tt::llrt::write_hex_vec_to_core(device_handle->id(), gk_phys_core, zero_buf, gk_interface_addr); + tt::llrt::write_hex_vec_to_core( + device_handle->id(), router_virtual_cores[master_router_idx], zero_buf, FABRIC_ROUTER_SYNC_SEM); } std::vector select_random_worker_cores(uint32_t count) { @@ -951,11 +884,9 @@ typedef struct test_traffic { tx_device->get_endpoint_id(tx_core), // 1: src_endpoint_id rx_devices[0]->get_noc_offset(rx_core), // 2: dest_noc_offset tx_device->get_noc_offset(controller_logical_core), // 3: controller noc offset - routing_plane, // 4: routing plane to use + eth_chan, // 4: outbound eth chan mesh_chip_id, // 5: mesh and chip id rx_buf_size, // 6: space in rx's L1 - gk_interface_addr, // 7: gk_message_addr_l - tx_device->gk_noc_offset, // 8: gk_message_addr_h }; if (ASYNC_WR & fabric_command) { @@ -968,13 +899,14 @@ typedef struct test_traffic { log_info( LogTest, - "[Device: Phys: {}, Logical: {}] TX kernel running on: logical: x={},y={}; virtual: x={},y={}", + "[Device: Phys: {}, Logical: {}] TX running on: logical: x={},y={}; virtual: x={},y={}, Eth chan: {}", tx_device->physical_chip_id, (uint32_t)tx_device->logical_chip_id, tx_core.x, tx_core.y, tx_virtual_cores[i].x, - tx_virtual_cores[i].y); + tx_virtual_cores[i].y, + (uint32_t)eth_chan); auto kernel = tt_metal::CreateKernel( tx_device->program_handle, tx_kernel_src, @@ -1262,8 +1194,6 @@ int main(int argc, char **argv) { constexpr uint32_t default_tx_y = 0; constexpr uint32_t default_rx_x = 0; constexpr uint32_t default_rx_y = 3; - constexpr uint32_t default_gk_x = 0; - constexpr uint32_t default_gk_y = 9; constexpr uint32_t default_mux_x = 0; constexpr uint32_t default_mux_y = 1; @@ -1379,8 +1309,6 @@ int main(int argc, char **argv) { uint32_t tx_y = test_args::get_command_option_uint32(input_args, "--tx_y", default_tx_y); uint32_t rx_x = test_args::get_command_option_uint32(input_args, "--rx_x", default_rx_x); uint32_t rx_y = test_args::get_command_option_uint32(input_args, "--rx_y", default_rx_y); - gk_x = test_args::get_command_option_uint32(input_args, "--gk_x", default_gk_x); - gk_y = test_args::get_command_option_uint32(input_args, "--gk_y", default_gk_y); uint32_t prng_seed = test_args::get_command_option_uint32(input_args, "--prng_seed", default_prng_seed); uint32_t data_kb_per_tx = test_args::get_command_option_uint32(input_args, "--data_kb_per_tx", default_data_kb_per_tx); @@ -1618,14 +1546,6 @@ int main(int argc, char **argv) { uint32_t worker_unreserved_base_addr = hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalL1MemAddrType::UNRESERVED); - if 
(run_gk_on_idle_ethernet) { - routing_table_addr = hal.get_dev_addr(HalProgrammableCoreType::IDLE_ETH, HalL1MemAddrType::UNRESERVED); - } else { - routing_table_addr = worker_unreserved_base_addr; - } - gk_interface_addr = routing_table_addr + sizeof(fabric_router_l1_config_t) * 4; - socket_info_addr = gk_interface_addr + sizeof(gatekeeper_info_t); - // create router kernels std::vector router_compile_args = { (tunneler_queue_size_bytes >> 4), // 0: rx_queue_size_words @@ -1637,19 +1557,6 @@ int main(int argc, char **argv) { test_device->create_router_kernels(router_compile_args, defines); } - // create gatekeeper kernel - std::vector gatekeeper_compile_args = { - gk_interface_addr, // 0: gk info addr - socket_info_addr, // 1: - routing_table_addr, // 2: - test_results_addr, // 3: test_results_addr - test_results_size, // 4: test_results_size - 0, // 5: timeout_cycles - }; - for (auto& [chip_id, test_device] : test_devices) { - test_device->create_gatekeeper_kernel(gatekeeper_compile_args, defines); - } - if (check_txrx_timeout) { defines["CHECK_TIMEOUT"] = ""; } @@ -1719,9 +1626,9 @@ int main(int argc, char **argv) { tt_metal::detail::LaunchProgram(test_device->device_handle, test_device->program_handle, false); } - // wait for all routers to handshake with their gatekeepers + // wait for all routers to handshake with master router for (auto& [chip_id, test_device] : test_devices) { - test_device->wait_for_gatekeeper_sync(); + test_device->wait_for_router_sync(); } // notify tx controller to signal the tx workers @@ -1735,7 +1642,7 @@ int main(int argc, char **argv) { } // terminate fabric routers for (auto& [chip_id, test_device] : test_devices) { - test_device->terminate_gatekeeper_kernel(); + test_device->terminate_router_kernels(); } // wait for programs to exit diff --git a/tt_fabric/hw/inc/tt_fabric.h b/tt_fabric/hw/inc/tt_fabric.h index 04fa643b82cb..6065f927953e 100644 --- a/tt_fabric/hw/inc/tt_fabric.h +++ b/tt_fabric/hw/inc/tt_fabric.h @@ -23,7 +23,7 @@ const uint32_t SYNC_BUF_PTR_MASK = ((SYNC_BUF_SIZE << 1) - 1); extern uint64_t xy_local_addr; extern volatile local_pull_request_t* local_pull_request; -extern volatile fabric_router_l1_config_t* routing_table; +extern volatile tt_l1_ptr fabric_router_l1_config_t* routing_table; extern chan_payload_ptr inbound_rdptr_ack; extern volatile chan_payload_ptr remote_rdptr; diff --git a/tt_fabric/hw/inc/tt_fabric_api.h b/tt_fabric/hw/inc/tt_fabric_api.h index 5b66fa860d1a..fd96de1a1bd1 100644 --- a/tt_fabric/hw/inc/tt_fabric_api.h +++ b/tt_fabric/hw/inc/tt_fabric_api.h @@ -245,43 +245,19 @@ inline void fabric_socket_connect(socket_handle_t* socket_handle) { while (((volatile socket_handle_t*)socket_handle)->socket_state != SocketState::ACTIVE); } -inline void fabric_endpoint_init(uint32_t base_address, uint32_t gk_interface_addr_l, uint32_t gk_interface_addr_h) { +inline void fabric_endpoint_init(uint32_t base_address, uint32_t outbound_eth_chan) { tt_fabric_init(); client_interface = (volatile fabric_client_interface_t*)base_address; uint32_t routing_tables_offset = base_address + sizeof(fabric_client_interface_t); zero_l1_buf((uint32_t*)client_interface, sizeof(fabric_client_interface_t)); - client_interface->gk_interface_addr = ((uint64_t)gk_interface_addr_h << 32) | gk_interface_addr_l; - client_interface->gk_msg_buf_addr = - (((uint64_t)gk_interface_addr_h << 32) | gk_interface_addr_l) + offsetof(gatekeeper_info_t, gk_msg_buf); client_interface->routing_tables_l1_offset = routing_tables_offset; + 
client_interface->num_routing_planes = 1;

-    // make sure fabric node gatekeeper is available.
-    uint64_t noc_addr = client_interface->gk_interface_addr + offsetof(gatekeeper_info_t, ep_sync);
-    client_interface->return_status[0] = 0;
-    while (1) {
-        noc_async_read_one_packet(noc_addr, (uint32_t)&client_interface->return_status[0], 4);
-        noc_async_read_barrier();
-        if (client_interface->return_status[0] != 0) {
-            break;
-        }
-    }
-
-    // read the gk info first at routing table addr and later override with routing tables
-    noc_async_read_one_packet(
-        client_interface->gk_interface_addr, client_interface->routing_tables_l1_offset, sizeof(gatekeeper_info_t));
-    noc_async_read_barrier();
-
-    client_interface->num_routing_planes = ((gatekeeper_info_t*)routing_tables_offset)->routing_planes;
-
-    // read routing tables
-    uint64_t gk_rt_noc_addr = client_interface->gk_interface_addr - sizeof(fabric_router_l1_config_t) * 4;
-    uint32_t table_offset;
-    for (uint32_t i = 0; i < client_interface->num_routing_planes; i++) {
-        table_offset = sizeof(fabric_router_l1_config_t) * i;
-        noc_async_read_one_packet(
-            gk_rt_noc_addr + table_offset, routing_tables_offset + table_offset, sizeof(fabric_router_l1_config_t));
-    }
+    // read routing table
+    uint64_t dest_addr = get_noc_addr_helper(
+        eth_chan_to_noc_xy[noc_index][outbound_eth_chan], eth_l1_mem::address_map::FABRIC_ROUTER_CONFIG_BASE);
+    noc_async_read_one_packet(dest_addr, routing_tables_offset, sizeof(fabric_router_l1_config_t));
     noc_async_read_barrier();
 }

diff --git a/tt_fabric/impl/kernels/tt_fabric_router.cpp b/tt_fabric/impl/kernels/tt_fabric_router.cpp
index 0eeb7879f9d7..9cd08cbe2d87 100644
--- a/tt_fabric/impl/kernels/tt_fabric_router.cpp
+++ b/tt_fabric/impl/kernels/tt_fabric_router.cpp
@@ -24,10 +24,12 @@ constexpr uint32_t fvc_data_buf_size_bytes = fvc_data_buf_size_words * PACKET_WO
 constexpr uint32_t kernel_status_buf_addr_arg = get_compile_time_arg_val(1);
 constexpr uint32_t kernel_status_buf_size_bytes = get_compile_time_arg_val(2);
 constexpr uint32_t timeout_cycles = get_compile_time_arg_val(3);
+constexpr bool is_master = get_compile_time_arg_val(4);

 uint32_t sync_val;
 uint32_t router_mask;
-uint32_t gk_message_addr_l;
-uint32_t gk_message_addr_h;
+uint32_t master_router_chan;
+uint64_t xy_local_addr;
+bool terminated_slave_routers = false;

 // careful, may be null
 tt_l1_ptr uint32_t* const kernel_status = reinterpret_cast<tt_l1_ptr uint32_t*>(kernel_status_buf_addr_arg);
@@ -35,16 +37,23 @@ tt_l1_ptr volatile chan_req_buf* fvc_consumer_req_buf =
     reinterpret_cast<tt_l1_ptr volatile chan_req_buf*>(FABRIC_ROUTER_REQ_QUEUE_START);
 volatile tt_l1_ptr fabric_router_l1_config_t* routing_table =
     reinterpret_cast<volatile tt_l1_ptr fabric_router_l1_config_t*>(eth_l1_mem::address_map::FABRIC_ROUTER_CONFIG_BASE);
-uint64_t xy_local_addr;
+
+volatile uint32_t* sync_sem_addr = (volatile uint32_t*)FABRIC_ROUTER_SYNC_SEM;

 #define SWITCH_THRESHOLD 0x3FFF

-inline void notify_gatekeeper() {
-    // send semaphore increment to gatekeeper on this device.
+inline void wait_for_sem(uint32_t value) {
+    while (*sync_sem_addr != value) {
+        // context switch while waiting to allow slow dispatch traffic to go through
+        internal_::risc_context_switch();
+    }
+}
+
+inline void notify_master_router() {
+    // send semaphore increment to master router on this device.
     // semaphore notifies all other routers that this router has completed
     // startup handshake with its ethernet peer.
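// Sketch of the handshake that replaces the gatekeeper sync (illustrative
// summary of the code below; sync_val is the number of routers on the chip):
//
//   slave router:   notify_master_router();      // atomic +1 on master's sem
//                   wait_for_sem(sync_val);      // wait for master's go-signal
//
//   master router:  wait_for_sem(sync_val - 1);  // all slaves checked in
//                   notify_slave_routers(sync_val);
//                   *sync_sem_addr += 1;         // sem reaches sync_val, which is
//                                                // what the host polls for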
- uint64_t dest_addr = - (((uint64_t)gk_message_addr_h << 32) | gk_message_addr_l) + offsetof(gatekeeper_info_t, router_sync); + uint64_t dest_addr = get_noc_addr_helper(eth_chan_to_noc_xy[noc_index][master_router_chan], FABRIC_ROUTER_SYNC_SEM); noc_fast_atomic_increment( noc_index, NCRISC_AT_CMD_BUF, @@ -55,27 +64,31 @@ inline void notify_gatekeeper() { false, false, MEM_NOC_ATOMIC_RET_VAL_ADDR); +} - volatile uint32_t* sync_sem_addr = (volatile uint32_t*)FABRIC_ROUTER_SYNC_SEM; - // wait for all device routers to have incremented the sync semaphore. - // sync_val is equal to number of tt-fabric routers running on a device. - while (*sync_sem_addr != sync_val) { - // context switch while waiting to allow slow dispatch traffic to go through - internal_::risc_context_switch(); +inline void notify_slave_routers(uint32_t notification) { + uint32_t remaining_cores = router_mask; + for (uint32_t i = 0; i < 16; i++) { + if (remaining_cores == 0) { + break; + } + if ((remaining_cores & (0x1 << i)) && (master_router_chan != i)) { + uint64_t dest_addr = get_noc_addr_helper(eth_chan_to_noc_xy[noc_index][i], FABRIC_ROUTER_SYNC_SEM); + noc_inline_dw_write(dest_addr, notification); + remaining_cores &= ~(0x1 << i); + } } } void kernel_main() { + tt_fabric_init(); fvc_producer_state_t fvc_producer_state; rtos_context_switch_ptr = (void (*)())RtosTable[0]; uint32_t rt_args_idx = 0; sync_val = get_arg_val(rt_args_idx++); router_mask = get_arg_val(rt_args_idx++); - gk_message_addr_l = get_arg_val(rt_args_idx++); - gk_message_addr_h = get_arg_val(rt_args_idx++); - - tt_fabric_init(); + master_router_chan = get_arg_val(rt_args_idx++); write_kernel_status(kernel_status, TT_FABRIC_STATUS_INDEX, TT_FABRIC_STATUS_STARTED); write_kernel_status(kernel_status, TT_FABRIC_MISC_INDEX, 0xff000000); @@ -112,7 +125,19 @@ void kernel_main() { return; } - notify_gatekeeper(); + if constexpr (is_master) { + // wait for all device routers to have incremented the sync semaphore. + // sync_val is equal to number of tt-fabric routers running on a device. + wait_for_sem(sync_val - 1); + notify_slave_routers(sync_val); + // increment the sync sem to signal host that handshake is complete + *sync_sem_addr += 1; + } else { + notify_master_router(); + // wait for the signal from the master router + wait_for_sem(sync_val); + } + uint64_t start_timestamp = get_timestamp(); write_kernel_status(kernel_status, TT_FABRIC_MISC_INDEX, 0xff000001); @@ -176,7 +201,13 @@ void kernel_main() { internal_::risc_context_switch(); } if (*(volatile uint32_t*)FABRIC_ROUTER_SYNC_SEM == 0) { - // terminate signal from host sw. 
+ // terminate signal from host sw + if constexpr (is_master) { + if (!terminated_slave_routers) { + notify_slave_routers(0); + terminated_slave_routers = true; + } + } if (loop_count >= 0x1000) { break; } From 2d4f9945fbb70a8bc4fe1525ef645d99ff6247c3 Mon Sep 17 00:00:00 2001 From: Brian Liu Date: Wed, 12 Feb 2025 09:36:41 -0800 Subject: [PATCH 8/8] #0: Clean up ShardSpecBuffer - Rename tensor2d_shape() to tensor2d_shape_in_pages() - Rename size() to num_pages() - Flip height/width in shape_in_pages() - Remove DEBUG_PRINT_SHARD --- .../tt_metal/distributed/test_mesh_buffer.cpp | 10 +++-- ...queueWriteBuffer_and_EnqueueReadBuffer.cpp | 40 +++++++++++-------- tt_metal/api/tt-metalium/buffer.hpp | 26 ++++++------ tt_metal/api/tt-metalium/tt_metal.hpp | 2 +- tt_metal/impl/buffers/buffer.cpp | 20 +++++----- tt_metal/impl/buffers/dispatch.cpp | 10 +++-- tt_metal/tt_metal.cpp | 5 --- .../multi_core/all_gather_op_multi_core.cpp | 8 ++-- .../ccl/sharding_addrgen_helper.cpp | 7 ++-- .../operations/experimental/reshape/view.cpp | 2 +- ttnn/cpp/ttnn/tensor/layout/tensor_layout.cpp | 4 +- ttnn/cpp/ttnn/tensor/tensor.cpp | 4 +- 12 files changed, 73 insertions(+), 65 deletions(-) diff --git a/tests/tt_metal/distributed/test_mesh_buffer.cpp b/tests/tt_metal/distributed/test_mesh_buffer.cpp index 5fdc6369a24d..f85f57a329b2 100644 --- a/tests/tt_metal/distributed/test_mesh_buffer.cpp +++ b/tests/tt_metal/distributed/test_mesh_buffer.cpp @@ -25,11 +25,11 @@ struct DeviceLocalShardedBufferTestConfig { TensorMemoryLayout mem_config = TensorMemoryLayout::HEIGHT_SHARDED; ShardOrientation shard_orientation = ShardOrientation::ROW_MAJOR; - Shape2D tensor2d_shape() { + Shape2D tensor2d_shape_in_pages() { return {num_pages_per_core.height() * num_cores.height(), num_pages_per_core.width() * num_cores.width()}; } - uint32_t num_pages() { return tensor2d_shape().height() * tensor2d_shape().width(); } + uint32_t num_pages() { return tensor2d_shape_in_pages().height() * tensor2d_shape_in_pages().width(); } std::array shard_shape() { return {num_pages_per_core.height() * page_shape.height(), num_pages_per_core.width() * page_shape.width()}; @@ -44,7 +44,11 @@ struct DeviceLocalShardedBufferTestConfig { ShardSpecBuffer shard_parameters() { return ShardSpecBuffer( - this->shard_grid(), this->shard_shape(), this->shard_orientation, this->page_shape, this->tensor2d_shape()); + this->shard_grid(), + this->shard_shape(), + this->shard_orientation, + this->page_shape, + this->tensor2d_shape_in_pages()); } }; diff --git a/tests/tt_metal/tt_metal/dispatch/dispatch_buffer/test_EnqueueWriteBuffer_and_EnqueueReadBuffer.cpp b/tests/tt_metal/tt_metal/dispatch/dispatch_buffer/test_EnqueueWriteBuffer_and_EnqueueReadBuffer.cpp index 4b5b1826c976..77a870d07f35 100644 --- a/tests/tt_metal/tt_metal/dispatch/dispatch_buffer/test_EnqueueWriteBuffer_and_EnqueueReadBuffer.cpp +++ b/tests/tt_metal/tt_metal/dispatch/dispatch_buffer/test_EnqueueWriteBuffer_and_EnqueueReadBuffer.cpp @@ -56,11 +56,11 @@ class BufferStressTestConfigSharded { this->num_cores = cores; } - std::array tensor2d_shape() { + std::array tensor2d_shape_in_pages() { return {num_pages_per_core[0] * num_cores[0], num_pages_per_core[1] * num_cores[1]}; } - uint32_t num_pages() { return tensor2d_shape()[0] * tensor2d_shape()[1]; } + uint32_t num_pages() { return tensor2d_shape_in_pages()[0] * tensor2d_shape_in_pages()[1]; } std::array shard_shape() { return {num_pages_per_core[0] * page_shape[0], num_pages_per_core[1] * page_shape[1]}; @@ -73,7 +73,11 @@ class 
BufferStressTestConfigSharded { ShardSpecBuffer shard_parameters() { return ShardSpecBuffer( - this->shard_grid(), this->shard_shape(), this->shard_orientation, this->page_shape, this->tensor2d_shape()); + this->shard_grid(), + this->shard_shape(), + this->shard_orientation, + this->page_shape, + this->tensor2d_shape_in_pages()); } uint32_t page_size() { return page_shape[0] * page_shape[1] * element_size; } @@ -87,7 +91,7 @@ struct ShardedSubBufferStressTestConfig { CoreRangeSet cores; Shape2D shard_shape; Shape2D page_shape; - Shape2D tensor2d_shape; + Shape2D tensor2d_shape_in_pages; TensorMemoryLayout layout; ShardOrientation orientation; }; @@ -133,11 +137,12 @@ vector generate_sharded_sub_buffer_test_config uint32_t page_shape_width_div_factor = 1; while (page_shape_width_div_factor <= num_pages_per_shard) { if (page_shape_width_div_factor * page_shape_height_div_factor == num_pages_per_shard) { - uint32_t tensor2d_shape_height = page_shape_height_div_factor; - while (tensor2d_shape_height <= num_pages) { - uint32_t tensor2d_shape_width = page_shape_width_div_factor; - while (tensor2d_shape_width <= num_pages) { - if (tensor2d_shape_height * tensor2d_shape_width == num_pages) { + uint32_t tensor2d_shape_in_pages_height = page_shape_height_div_factor; + while (tensor2d_shape_in_pages_height <= num_pages) { + uint32_t tensor2d_shape_in_pages_width = page_shape_width_div_factor; + while (tensor2d_shape_in_pages_width <= num_pages) { + if (tensor2d_shape_in_pages_height * tensor2d_shape_in_pages_width == + num_pages) { for (TensorMemoryLayout layout : {TensorMemoryLayout::HEIGHT_SHARDED, TensorMemoryLayout::BLOCK_SHARDED, @@ -157,17 +162,18 @@ vector generate_sharded_sub_buffer_test_config page_shape_height_div_factor, tt::constants::TILE_WIDTH / page_shape_width_div_factor}, - .tensor2d_shape = - {tensor2d_shape_height, tensor2d_shape_width}, + .tensor2d_shape_in_pages = + {tensor2d_shape_in_pages_height, + tensor2d_shape_in_pages_width}, .layout = layout, .orientation = orientation}; configs.push_back(config); } } } - tensor2d_shape_width += page_shape_width_div_factor; + tensor2d_shape_in_pages_width += page_shape_width_div_factor; } - tensor2d_shape_height += page_shape_height_div_factor; + tensor2d_shape_in_pages_height += page_shape_height_div_factor; } } page_shape_width_div_factor += 1; @@ -1018,7 +1024,7 @@ TEST_F(CommandQueueSingleCardBufferFixture, TestReadWriteShardedSubBufferForL1) tt::log_debug( tt::LogTest, "Device: {} buffer_size: {} page_size: {} region_offset: {} region_size: {} shard_shape: [{}, {}] " - "page_shape: [{}, {}] tensor2d_shape: [{}, {}] layout: {} orientation: {} cores: {}", + "page_shape: [{}, {}] tensor2d_shape_in_pages: [{}, {}] layout: {} orientation: {} cores: {}", device->id(), config.buffer_size, config.page_size, @@ -1028,8 +1034,8 @@ TEST_F(CommandQueueSingleCardBufferFixture, TestReadWriteShardedSubBufferForL1) config.shard_shape.width(), config.page_shape.height(), config.page_shape.width(), - config.tensor2d_shape.height(), - config.tensor2d_shape.width(), + config.tensor2d_shape_in_pages.height(), + config.tensor2d_shape_in_pages.width(), magic_enum::enum_name(config.layout).data(), magic_enum::enum_name(config.orientation).data(), config.cores.str()); @@ -1039,7 +1045,7 @@ TEST_F(CommandQueueSingleCardBufferFixture, TestReadWriteShardedSubBufferForL1) {tt::constants::TILE_HEIGHT, tt::constants::TILE_WIDTH}, config.orientation, config.page_shape, - config.tensor2d_shape); + config.tensor2d_shape_in_pages); auto buffer = Buffer::create(device, 
config.buffer_size, config.page_size, BufferType::L1, config.layout, shard_spec); diff --git a/tt_metal/api/tt-metalium/buffer.hpp b/tt_metal/api/tt-metalium/buffer.hpp index 119900e59290..e52f45b21059 100644 --- a/tt_metal/api/tt-metalium/buffer.hpp +++ b/tt_metal/api/tt-metalium/buffer.hpp @@ -86,33 +86,33 @@ std::ostream& operator<<(std::ostream& os, const ShardSpec& spec); struct ShardSpecBuffer { ShardSpec tensor_shard_spec; std::array page_shape; - std::array tensor2d_shape; + std::array tensor2d_shape_in_pages; ShardSpecBuffer( - const CoreRangeSet &core_sets_, - const std::array &shard_shape_, - const ShardOrientation &shard_orientation_, - const std::array &page_shape, - const std::array &tensor2d_shape) : + const CoreRangeSet& core_sets_, + const std::array& shard_shape_, + const ShardOrientation& shard_orientation_, + const std::array& page_shape, + const std::array& tensor2d_shape_in_pages) : tensor_shard_spec(core_sets_, shard_shape_, shard_orientation_) { this->page_shape = page_shape; - this->tensor2d_shape = tensor2d_shape; + this->tensor2d_shape_in_pages = tensor2d_shape_in_pages; } ShardSpecBuffer( - const ShardSpec &shard_spec, - const std::array &page_shape, - const std::array &tensor2d_shape) : + const ShardSpec& shard_spec, + const std::array& page_shape, + const std::array& tensor2d_shape_in_pages) : tensor_shard_spec(shard_spec) { this->page_shape = page_shape; - this->tensor2d_shape = tensor2d_shape; + this->tensor2d_shape_in_pages = tensor2d_shape_in_pages; } CoreRangeSet grid() const { return tensor_shard_spec.grid; } std::array shape() const { return tensor_shard_spec.shape; } ShardOrientation orientation() const { return tensor_shard_spec.orientation; } void set_shard_spec(const ShardSpec& shard_spec) { tensor_shard_spec = shard_spec; }; - /* Shape in pages of the full tensor, not per core */ + /* Shape in pages of the full shard */ std::array shape_in_pages() const; - DeviceAddr size() const; + DeviceAddr num_pages() const; }; inline namespace v0 { diff --git a/tt_metal/api/tt-metalium/tt_metal.hpp b/tt_metal/api/tt-metalium/tt_metal.hpp index c5d3bf708b28..b56b6fd168d4 100644 --- a/tt_metal/api/tt-metalium/tt_metal.hpp +++ b/tt_metal/api/tt-metalium/tt_metal.hpp @@ -112,7 +112,7 @@ void ReadShard(Buffer& buffer, uint8_t* host_buffer, const uint32_t& core_id); */ template void ReadShard(Buffer& buffer, std::vector& host_buffer, const uint32_t& core_id) { - host_buffer.resize(buffer.page_size() * buffer.shard_spec().size()); + host_buffer.resize(buffer.page_size() * buffer.shard_spec().num_pages()); ReadShard(buffer, reinterpret_cast(host_buffer.data()), core_id); } diff --git a/tt_metal/impl/buffers/buffer.cpp b/tt_metal/impl/buffers/buffer.cpp index e615e87669c8..29cdf05c9802 100644 --- a/tt_metal/impl/buffers/buffer.cpp +++ b/tt_metal/impl/buffers/buffer.cpp @@ -208,12 +208,12 @@ BufferPageMapping generate_buffer_page_mapping(const Buffer& buffer) { uint32_t num_dev_pages = buffer.num_dev_pages(); auto [core_host_page_indices, shard_shape] = core_to_host_pages( num_dev_pages, - shard_spec.size(), + shard_spec.num_pages(), num_cores, buffer.buffer_layout(), shard_spec.page_shape, shard_spec.shape(), - shard_spec.tensor2d_shape); + shard_spec.tensor2d_shape_in_pages); buffer_page_mapping.core_host_page_indices_ = std::vector>(num_cores); @@ -229,7 +229,7 @@ BufferPageMapping generate_buffer_page_mapping(const Buffer& buffer) { auto shape_in_pages = shard_spec.shape_in_pages(); for (uint32_t core_index = 0; core_index < core_host_page_indices.size(); 
core_index++) { uint32_t valid_shard_page = 0; - buffer_page_mapping.core_host_page_indices_[core_index].reserve(shard_spec.size()); + buffer_page_mapping.core_host_page_indices_[core_index].reserve(shard_spec.num_pages()); uint32_t shard_page_id = 0; for (uint32_t shard_page_x = 0; shard_page_x < shape_in_pages[0]; shard_page_x++) { for (uint32_t shard_page_y = 0; shard_page_y < shape_in_pages[1]; shard_page_y++) { @@ -469,7 +469,7 @@ uint32_t Buffer::num_dev_pages() const { return this->num_pages(); } - return this->shard_spec().size() * this->num_cores().value(); + return this->shard_spec().num_pages() * this->num_cores().value(); } CoreType Buffer::core_type() const { @@ -523,7 +523,7 @@ DeviceAddr Buffer::bank_local_page_address(uint32_t bank_id, uint32_t page_index uint32_t offset; if (is_sharded(this->buffer_layout())) { auto shard_spec = this->shard_spec(); - uint32_t pages_offset_within_bank = page_index % shard_spec.size(); + uint32_t pages_offset_within_bank = page_index % shard_spec.num_pages(); offset = (round_up(this->page_size(), this->alignment()) * pages_offset_within_bank); } else { uint32_t pages_offset_within_bank = page_index / num_banks; @@ -550,7 +550,7 @@ DeviceAddr Buffer::aligned_size_per_bank() const { DeviceAddr Buffer::sharded_page_address(uint32_t bank_id, uint32_t page_index) const { TT_FATAL(is_sharded(this->buffer_layout()), "Buffer not sharded"); auto shard_spec = this->shard_spec(); - uint32_t pages_offset_within_bank = page_index % shard_spec.size(); + uint32_t pages_offset_within_bank = page_index % shard_spec.num_pages(); auto offset = (round_up(this->page_size(), this->alignment()) * pages_offset_within_bank); return translate_page_address(offset, bank_id); } @@ -591,12 +591,12 @@ bool ShardSpec::operator==(const ShardSpec&) const = default; bool ShardSpec::operator!=(const ShardSpec&) const = default; std::array ShardSpecBuffer::shape_in_pages() const { - auto width_in_pages = page_shape[0] == 0 ? 0 : tensor_shard_spec.shape[0] / page_shape[0]; - auto height_in_pages = page_shape[1] == 0 ? 0 : tensor_shard_spec.shape[1] / page_shape[1]; - return {width_in_pages, height_in_pages}; + auto height_in_pages = page_shape[0] == 0 ? 0 : tensor_shard_spec.shape[0] / page_shape[0]; + auto width_in_pages = page_shape[1] == 0 ? 0 : tensor_shard_spec.shape[1] / page_shape[1]; + return {height_in_pages, width_in_pages}; } -DeviceAddr ShardSpecBuffer::size() const { +DeviceAddr ShardSpecBuffer::num_pages() const { auto shape_in_pages_ = this->shape_in_pages(); return shape_in_pages_[0] * shape_in_pages_[1]; } diff --git a/tt_metal/impl/buffers/dispatch.cpp b/tt_metal/impl/buffers/dispatch.cpp index 8655c8307093..f1de42f22e9b 100644 --- a/tt_metal/impl/buffers/dispatch.cpp +++ b/tt_metal/impl/buffers/dispatch.cpp @@ -77,11 +77,12 @@ ShardedBufferWriteDispatchParams initialize_sharded_buf_dispatch_params( const BufferDispatchConstants& buf_dispatch_constants, const BufferRegion& region) { ShardedBufferWriteDispatchParams dispatch_params; - dispatch_params.width_split = buffer.shard_spec().shape_in_pages()[1] != buffer.shard_spec().tensor2d_shape[1]; + dispatch_params.width_split = + buffer.shard_spec().shape_in_pages()[1] != buffer.shard_spec().tensor2d_shape_in_pages[1]; dispatch_params.buffer_page_mapping = (dispatch_params.width_split) ? 
buffer.get_buffer_page_mapping() : nullptr; dispatch_params.total_pages_to_write = region.size / buffer.page_size(); dispatch_params.total_pages_written = 0; - dispatch_params.max_pages_per_shard = buffer.shard_spec().size(); + dispatch_params.max_pages_per_shard = buffer.shard_spec().num_pages(); dispatch_params.page_size_to_write = buffer.aligned_page_size(); dispatch_params.dst_page_index = region.offset / buffer.page_size(); dispatch_params.starting_dst_host_page_index = region.offset / buffer.page_size(); @@ -587,11 +588,12 @@ ShardedBufferReadDispatchParams initialize_sharded_buf_read_dispatch_params( dispatch_params.src_page_index = region.offset / buffer.page_size(); dispatch_params.starting_src_host_page_index = region.offset / buffer.page_size(); dispatch_params.unpadded_dst_offset = 0; - dispatch_params.width_split = buffer.shard_spec().shape_in_pages()[1] != buffer.shard_spec().tensor2d_shape[1]; + dispatch_params.width_split = + buffer.shard_spec().shape_in_pages()[1] != buffer.shard_spec().tensor2d_shape_in_pages[1]; dispatch_params.buffer_page_mapping = (dispatch_params.width_split) ? buffer.get_buffer_page_mapping() : nullptr; dispatch_params.total_pages_to_read = region.size / buffer.page_size(); dispatch_params.total_pages_read = 0; - dispatch_params.max_pages_per_shard = buffer.shard_spec().size(); + dispatch_params.max_pages_per_shard = buffer.shard_spec().num_pages(); dispatch_params.expected_num_workers_completed = expected_num_workers_completed; return dispatch_params; } diff --git a/tt_metal/tt_metal.cpp b/tt_metal/tt_metal.cpp index 4caeae9b22c1..59e6543a82e1 100644 --- a/tt_metal/tt_metal.cpp +++ b/tt_metal/tt_metal.cpp @@ -293,8 +293,6 @@ inline void SetRuntimeArgsImpl( } // namespace -// #define DEBUG_PRINT_SHARD - namespace detail { bool WriteToDeviceDRAMChannel(IDevice* device, int dram_channel, uint32_t address, std::vector& host_buffer) { @@ -586,9 +584,6 @@ void ReadFromDeviceSharded(Buffer& buffer, uint8_t* host_buffer, bool shard_orde TensorMemoryLayout buffer_layout = buffer.buffer_layout(); auto device = buffer.device(); -#ifdef DEBUG_PRINT_SHARD - std::cout << "Reading From Device Height Sharded " << std::endl; -#endif auto total_pages = buffer.num_dev_pages(); uint32_t page_size = buffer.page_size(); diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core.cpp index 6951764459fb..a31309388e32 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core.cpp @@ -342,12 +342,12 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper( log_trace(tt::LogOp, "input_buffer->page_size: {}", input_page_size); log_trace( tt::LogOp, - "input_buffer->shard_spec().tensor2d_shape[0]: {}", - input_buffer->shard_spec().tensor2d_shape[0]); + "input_buffer->shard_spec().tensor2d_shape_in_pages[0]: {}", + input_buffer->shard_spec().tensor2d_shape_in_pages[0]); log_trace( tt::LogOp, - "input_buffer->shard_spec().tensor2d_shape[1]: {}", - input_buffer->shard_spec().tensor2d_shape[1]); + "input_buffer->shard_spec().tensor2d_shape_in_pages[1]: {}", + input_buffer->shard_spec().tensor2d_shape_in_pages[1]); } const uint32_t max_buffer_per_chunk = tt::round_down(all_gather_config.get_eth_buffer_size(), input_page_size); const uint32_t max_pages_per_chunk = max_buffer_per_chunk / input_page_size; diff --git 
a/ttnn/cpp/ttnn/operations/ccl/sharding_addrgen_helper.cpp b/ttnn/cpp/ttnn/operations/ccl/sharding_addrgen_helper.cpp index 1bb57fa6e514..5e221b3fdf73 100644 --- a/ttnn/cpp/ttnn/operations/ccl/sharding_addrgen_helper.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/sharding_addrgen_helper.cpp @@ -155,16 +155,17 @@ std::vector generate_compile_time_args(const tt::tt_metal::Tensor& t) shard_addr_gen_consts::ContiguityType contiguity = (t.buffer()->aligned_page_size() != t.buffer()->page_size()) ? shard_addr_gen_consts::ContiguityType::PADDING_BETWEEN_PAGES - : (buf_shard_spec.tensor2d_shape[1] == (pages_per_shard_x * get_sharding_core_count(t))) + : (buf_shard_spec.tensor2d_shape_in_pages[1] == (pages_per_shard_x * get_sharding_core_count(t))) ? shard_addr_gen_consts::ContiguityType::NO_SHARD_PADDING : shard_addr_gen_consts::ContiguityType::PADDING_IN_RIGHTMOST_SHARD; args.push_back(static_cast(t.memory_config().memory_layout)); // Memory layout args.push_back(static_cast(get_sharding_core_count(t))); // The number of sharding cores args.push_back(static_cast(t.buffer()->aligned_page_size())); // The page size we offset each write to TT_FATAL(t.buffer()->aligned_page_size() > 0, "aligned page size is 0"); - TT_FATAL(buf_shard_spec.tensor2d_shape[1] > 0, "the page is empty"); + TT_FATAL(buf_shard_spec.tensor2d_shape_in_pages[1] > 0, "the page is empty"); args.push_back(static_cast( - buf_shard_spec.tensor2d_shape[1])); // The number of pages in each sharding row not including padding pages + buf_shard_spec + .tensor2d_shape_in_pages[1])); // The number of pages in each sharding row not including padding pages args.push_back(static_cast(contiguity)); // This defines times when contiguous pages can't be calculated args.push_back(pages_per_shard_x); args.push_back(pages_per_shard_y); diff --git a/ttnn/cpp/ttnn/operations/experimental/reshape/view.cpp b/ttnn/cpp/ttnn/operations/experimental/reshape/view.cpp index 1a7aaf2fa0d1..0753f8468dc8 100644 --- a/ttnn/cpp/ttnn/operations/experimental/reshape/view.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/reshape/view.cpp @@ -108,7 +108,7 @@ Tensor tensor_reshape( shard_spec.shape[1] = new_logical_shape[-1]; shard_spec_buffer.page_shape = {1, new_logical_shape[-1]}; - shard_spec_buffer.tensor2d_shape = {shard_spec.shape[0], 1}; + shard_spec_buffer.tensor2d_shape_in_pages = {shard_spec.shape[0], 1}; shard_spec_buffer.set_shard_spec(shard_spec); device_buffer->set_shard_spec(shard_spec_buffer); diff --git a/ttnn/cpp/ttnn/tensor/layout/tensor_layout.cpp b/ttnn/cpp/ttnn/tensor/layout/tensor_layout.cpp index f119c7bc6217..298f9c6f5e64 100644 --- a/ttnn/cpp/ttnn/tensor/layout/tensor_layout.cpp +++ b/ttnn/cpp/ttnn/tensor/layout/tensor_layout.cpp @@ -184,7 +184,7 @@ std::optional TensorLayout::compute_shard_spec_buffer(const ttn page_shape.height()); const auto width_in_pages = physical_size.width() / page_shape.width(); const auto height_in_pages = physical_size.height() / page_shape.height(); - const std::array tensor2d_shape{height_in_pages, width_in_pages}; + const std::array tensor2d_shape_in_pages{height_in_pages, width_in_pages}; auto shard_spec = memory_config_.shard_spec.value(); @@ -198,7 +198,7 @@ std::optional TensorLayout::compute_shard_spec_buffer(const ttn default: TT_THROW("Unsupported shard mode {} in compute_shard_spec_buffer!", shard_spec.mode); } - ShardSpecBuffer shard_spec_buffer(shard_spec, std::array(page_shape), tensor2d_shape); + ShardSpecBuffer shard_spec_buffer(shard_spec, std::array(page_shape), tensor2d_shape_in_pages); return 
shard_spec_buffer;
 }

diff --git a/ttnn/cpp/ttnn/tensor/tensor.cpp b/ttnn/cpp/ttnn/tensor/tensor.cpp
index 1e5e153417bd..fef10f167c2b 100644
--- a/ttnn/cpp/ttnn/tensor/tensor.cpp
+++ b/ttnn/cpp/ttnn/tensor/tensor.cpp
@@ -809,8 +809,8 @@ bool Tensor::is_allocated() const {
 std::vector<uint32_t> Tensor::host_page_ordering() {
     const auto& buffer_page_mapping = *this->buffer()->get_buffer_page_mapping();
     auto cores = buffer_page_mapping.all_cores_;
-    auto shard_size = buffer()->shard_spec().size();
-    auto num_pages = cores.size() * shard_size;
+    auto shard_num_pages = buffer()->shard_spec().num_pages();
+    auto num_pages = cores.size() * shard_num_pages;
     std::vector<uint32_t> ret_vec;
     ret_vec.reserve(num_pages);
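As a quick reference for reviewers, a minimal standalone sketch of the renamed accessors and the height/width fix (illustrative only; the field layout is simplified from the real class in tt-metalium/buffer.hpp):

```cpp
#include <array>
#include <cstdint>

// Simplified stand-in for ShardSpecBuffer after this patch series.
struct ShardSpecBufferSketch {
    std::array<uint32_t, 2> shard_shape;              // shard {height, width} in elements
    std::array<uint32_t, 2> page_shape;               // page {height, width} in elements
    std::array<uint32_t, 2> tensor2d_shape_in_pages;  // full tensor, in pages (renamed field)

    // Post-patch: index 0 is height and index 1 is width (previously flipped).
    std::array<uint32_t, 2> shape_in_pages() const {
        uint32_t height_in_pages = page_shape[0] == 0 ? 0 : shard_shape[0] / page_shape[0];
        uint32_t width_in_pages = page_shape[1] == 0 ? 0 : shard_shape[1] / page_shape[1];
        return {height_in_pages, width_in_pages};
    }

    // Renamed from size(): the unit is pages per shard, not bytes.
    uint64_t num_pages() const {
        auto pages = shape_in_pages();
        return static_cast<uint64_t>(pages[0]) * pages[1];
    }
};
```

The renames make the units explicit: `size()` read as if it returned bytes, while the value has always been the page count of one shard.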