From 5d343d461db631892460498ca736a814a1743a9f Mon Sep 17 00:00:00 2001 From: Juan Camilo Vega Date: Thu, 28 Nov 2024 22:17:48 +0000 Subject: [PATCH 01/31] #15269: reshape fully on device now --- ttnn/CMakeLists.txt | 2 + .../device/device/rm_reshape_interleaved.cpp | 208 ++++++++++ .../device/host/reshape_rm_host_prep.cpp | 103 +++++ .../reshape_view/device/reshape_rm_op.cpp | 42 +++ .../reshape_view/device/reshape_rm_op.hpp | 31 ++ .../data_movement/reshape_view/reshape.cpp | 356 ++++++++++++------ .../data_movement/reshape_view/reshape.hpp | 19 +- .../reshape_view/reshape_common.hpp | 5 + .../reshape_view/reshape_pybind.cpp | 40 +- 9 files changed, 685 insertions(+), 121 deletions(-) create mode 100644 ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp create mode 100644 ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/host/reshape_rm_host_prep.cpp create mode 100644 ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/reshape_rm_op.cpp create mode 100644 ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/reshape_rm_op.hpp create mode 100644 ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape_common.hpp diff --git a/ttnn/CMakeLists.txt b/ttnn/CMakeLists.txt index ca000262872..a7290e58c96 100644 --- a/ttnn/CMakeLists.txt +++ b/ttnn/CMakeLists.txt @@ -100,6 +100,8 @@ set(ALL_TTNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/reshape_on_device/device/reshape_program_factory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/reshape_view/reshape_pybind.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/reshape_view/device/reshape_rm_op.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/reshape_view/device/host/reshape_rm_host_prep.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/unsqueeze/unsqueeze.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/unsqueeze/unsqueeze_pybind.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/squeeze/squeeze.cpp diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp new file mode 100644 index 00000000000..96b0650a4fa --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp @@ -0,0 +1,208 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + + +/* +Function reads from RM and writes to RM + +Assumptions: + +Compile arguments +0. src0_is_dram: 1 if source is dram else 0 +1. read_size_is_pow2: 1 if read size is power of 2 else 0 +2. log_base_2_of_page_size: log base 2 of page size +3. write_size_is_pow2: 1 if write size is power of 2 else 0 +4. log_base_2_of_page_size: log base 2 of page size +5. needs_read_allignment: 1 if read needs allignment else 0 +//Needed if BRAM and page size is not multiple of 64 bytes + +Runtime arguments +0. src_addr: source address +1. dst_addr: destination address +2. source_page_size_bytes: source page size in bytes +3. dest_page_size_bytes: destination page size in bytes +4. source_read_size_bytes: source read size in bytes +5. read_start_page: read start page +6. read_end_page: read end page +7. 
write_start_page: write start page +*/ +#include +#include "dataflow_api.h" +#include +#include +#include "debug/dprint.h" // required in all kernels using DPRINT + +#define MASK_64 0xFFFFFFFFFFFFFFC0 +#define OFFSET_64 0x000000000000003F +#define MASK_16 0xFFFFFFFFFFFFFFF0 +#define OFFSET_16 0x000000000000000F + + +template +FORCE_INLINE +void tt_memmove ( + const uint32_t dst_l1_addr, + const uint64_t src_l1_addr, + const uint32_t bytes) +{ + //Uses noc_async_read when possible to copy the data over + if constexpr (guaranteed_16B_alligned) + { + noc_async_read(get_noc_addr(src_l1_addr),dst_l1_addr, bytes); + noc_async_read_barrier(); + } + else + { + if ((dst_l1_addr&OFFSET_16) == (src_l1_addr&OFFSET_16)) + { + noc_async_read(get_noc_addr(src_l1_addr),dst_l1_addr, bytes); + noc_async_read_barrier(); + } + else + { + memmove((void *)(dst_l1_addr), (void *)(src_l1_addr), (size_t) (bytes)); + } + } +} + + +void kernel_main() { + //We are guranteed to be in 2D going to 2D + + const uint32_t src_addr = get_arg_val(0); + const uint32_t dst_addr = get_arg_val(1); + const uint32_t source_page_size_bytes = get_arg_val(2); + const uint32_t dest_page_size_bytes = get_arg_val(3); + //If DDR this is source_page_size_bytes + 64 (rounded up to next 64B), if L1 this is source_page_size_bytes + 16 (rounded up to next 16B) + const uint32_t source_read_size_bytes = get_arg_val(4); + const uint32_t read_start_page = get_arg_val(5); + const uint32_t read_end_page = get_arg_val(6); + const uint32_t write_start_page = get_arg_val(7); + //cb_id_in0 is a circular buffer with 1 source_page_size_bytes page if no alignment needed + //source_read_size_bytes otherwise + const uint32_t cb_id_in0 = get_arg_val(8); + //cb_id_in1 is a circular buffer with 1 dest_page_size_bytes+16 (rounded up to next 64B) page + const uint32_t cb_id_in1 = get_arg_val(9); + + + constexpr bool tensor_is_dram = get_compile_time_arg_val(0) == 1; + #define src_aligned_to_64 get_compile_time_arg_val(1) == 1 + #define src_aligned_to_16 get_compile_time_arg_val(2) == 1 + #define dst_aligned_to_16 get_compile_time_arg_val(3) == 1 + + + const InterleavedAddrGen s = { + .bank_base_address = src_addr, + .page_size = source_page_size_bytes + }; + + const InterleavedAddrGen d = { + .bank_base_address = dst_addr, + .page_size = dest_page_size_bytes + }; + + + uint32_t read_offset = 0; + uint32_t write_page = write_start_page; + uint32_t readable = 0; + uint32_t transaction = 0; + uint32_t writable = dest_page_size_bytes; + //cb_id_in0 is a CB source_read_size_bytes page size, 1 page + //cb_id_in1 is a CB dest_page_size_bytes + allignment_to_64 page size, 1 page + cb_reserve_back(cb_id_in0, 1); + cb_reserve_back(cb_id_in1, 1); + const uint32_t source_buffer = get_write_ptr(cb_id_in0); + const uint32_t dest_buffer = get_write_ptr(cb_id_in1); + + uint64_t dst_noc_addr = get_noc_addr(write_page, d); +#if (dst_aligned_to_16) + uint32_t write_offset = 0; +#else + uint32_t write_offset = dst_noc_addr&OFFSET_16; + uint32_t begin_write_offset = write_offset; +#endif + for (uint32_t i = read_start_page; i <= read_end_page; i++) { + //Read from source + uint64_t src_noc_addr = s.get_noc_addr(i,0); + +#if (src_aligned_to_64 || ((!tensor_is_dram) && src_aligned_to_16)) + //Aligned to 64 bytes or 16 bytes but L1 + noc_async_read(src_noc_addr, source_buffer, source_page_size_bytes); + read_offset = 0; +#elif (tensor_is_dram) + //DDR but not alligned to 64 (potentially also not alligned to 16) + noc_async_read(src_noc_addr&MASK_64, source_buffer, 
source_read_size_bytes); + read_offset = src_noc_addr&OFFSET_64; +#else + //L1 but not alligned to 16 + noc_async_read(src_noc_addr&MASK_16, source_buffer, source_read_size_bytes); + read_offset = src_noc_addr&OFFSET_16; +#endif + readable = source_page_size_bytes; + noc_async_read_barrier(); + + //Write to dest + while (readable > 0) + { + noc_async_write_barrier(); + if (readable < writable) + { + tt_memmove(dest_buffer+write_offset, source_buffer + read_offset, readable); + writable = writable -readable; + write_offset = write_offset + readable; + readable = 0; + } + else if (readable == writable) + { + tt_memmove(dest_buffer+write_offset, source_buffer + read_offset, readable); +#if ((dst_aligned_to_16)) + noc_async_write(dest_buffer,dst_noc_addr, dest_page_size_bytes); +#else + noc_async_write(dest_buffer+begin_write_offset,dst_noc_addr, dest_page_size_bytes); +#endif + writable = dest_page_size_bytes; + readable = 0; + if (i == read_end_page-1) + { + cb_push_back(cb_id_in0, 1); + cb_push_back(cb_id_in1, 1); + return; + } + write_page++; + dst_noc_addr = get_noc_addr(write_page, d); +#if ((dst_aligned_to_16)) + write_offset=0; +#else + write_offset = dst_noc_addr&OFFSET_16; + begin_write_offset = write_offset; +#endif + } + else + { + //writable < readable + + tt_memmove(dest_buffer+write_offset, source_buffer + read_offset, writable); +#if ((dst_aligned_to_16)) + noc_async_write(dest_buffer,dst_noc_addr, dest_page_size_bytes); +#else + noc_async_write(dest_buffer+begin_write_offset,dst_noc_addr, dest_page_size_bytes); +#endif + readable = readable - writable; + read_offset = read_offset + writable; + write_page++; + dst_noc_addr = get_noc_addr(write_page, d); +#if ((dst_aligned_to_16)) + write_offset=0; +#else + write_offset = dst_noc_addr&OFFSET_16; + begin_write_offset = write_offset; +#endif + writable = dest_page_size_bytes; + } + } + } + cb_push_back(cb_id_in0, 1); + cb_push_back(cb_id_in1, 1); + return; +} diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/host/reshape_rm_host_prep.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/host/reshape_rm_host_prep.cpp new file mode 100644 index 00000000000..2c7410f7a6b --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/host/reshape_rm_host_prep.cpp @@ -0,0 +1,103 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "ttnn/operations/cb_utils.hpp" +#include "ttnn/operations/math.hpp" +#include "ttnn/operation.hpp" +#include "ttnn/operations/core/work_split/work_split_tilize.hpp" +#include "tt_metal/common/constants.hpp" +#include "tt_metal/detail/util.hpp" +#include "tt_metal/host_api.hpp" +#include "ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape_common.hpp" + +#include +#include + +#include "ttnn/tensor/tensor.hpp" +#include "ttnn/core.hpp" +#include "ttnn/device_operation.hpp" +#include "ttnn/types.hpp" +#include "ttnn/decorators.hpp" + +#define MASK_64 0xFFFFFFFFFFFFFFC0 +#define OFFSET_64 0x000000000000003F +#define MASK_16 0xFFFFFFFFFFFFFFF0 +#define OFFSET_16 0x000000000000000F + +namespace ttnn::operations::data_movement::rm_reshape{ + +operation::ProgramWithCallbacks rm_reshape_preparer(const Tensor& input, const Tensor& output) +{ + tt::tt_metal::Program program = tt::tt_metal::CreateProgram(); + //get datum size + tt::DataFormat cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(input.get_dtype()); + const uint32_t data_size = input.element_size(); + CoreRange core({0, 0}, {0, 0}); + + tt::tt_metal::Device *device = input.device(); + ttnn::Shape input_log_shape = ttnn::Shape(input.get_logical_shape().view()); + ttnn::Shape output_log_shape = ttnn::Shape(output.get_logical_shape().view()); + tt::log_debug("row major reshape"); + tt::log_debug("input shape: {}", input_log_shape); + tt::log_debug("output shape: {}", output_log_shape); + tt::log_debug("data size: {}", data_size); + uint32_t source_page_size_bytes = input_log_shape[-1] * data_size; + uint32_t dest_page_size_bytes = output_log_shape[-1] * data_size; + uint32_t source_read_size_bytes = ((source_page_size_bytes-1) & MASK_64) + 128; + uint32_t read_start_page = 0; + uint32_t read_end_page = input_log_shape[-2]; + uint32_t write_start_page = 0; + tt::tt_metal::Buffer *src_buffer = input.buffer(); + tt::tt_metal::Buffer *dst_buffer = output.buffer(); + TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); + + const uint32_t cb_size0 = source_read_size_bytes; + const uint32_t cb_size1 = ((dest_page_size_bytes-1)&MASK_64) + 80; + + uint32_t src0_cb_index = 0; + uint32_t src1_cb_index = 1; + tt::tt_metal::CircularBufferConfig cb_src0_config = tt::tt_metal::CircularBufferConfig(cb_size0*2, {{src0_cb_index, cb_data_format}}) + .set_page_size(src0_cb_index, cb_size0); + auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, core, cb_src0_config); + tt::tt_metal::CircularBufferConfig cb_src1_config = tt::tt_metal::CircularBufferConfig(cb_size1, {{src1_cb_index, cb_data_format}}) + .set_page_size(src1_cb_index, cb_size1); + auto cb_src1 = tt::tt_metal::CreateCircularBuffer(program, core, cb_src1_config); + //set the runtime args + //set the compile time args + uint32_t src0_is_dram = src_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; + std::vector compile_time_args = { + (std::uint32_t) src0_is_dram, + (std::uint32_t) (source_page_size_bytes%64==0) ? 1 : 0, + (std::uint32_t) (source_page_size_bytes%16==0) ? 1 : 0, + (std::uint32_t) (dest_page_size_bytes%16==0) ? 
1 : 0, + }; + + tt::tt_metal::KernelHandle reader_kernel_id = tt::tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp", + core, + tt::tt_metal::ReaderDataMovementConfig(compile_time_args)); + std::vector reader_runtime_args = { + src_buffer->address(), + dst_buffer->address(), + source_page_size_bytes, + dest_page_size_bytes, + source_read_size_bytes, + read_start_page, + read_end_page, + write_start_page, + src0_cb_index, + src1_cb_index + }; + tt::tt_metal::SetRuntimeArgs( + program, + reader_kernel_id, + core, + reader_runtime_args + ); + return {.program=std::move(program)}; +} +}; // namespace ttnn::operations::data_movement::rm_reshape diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/reshape_rm_op.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/reshape_rm_op.cpp new file mode 100644 index 00000000000..6dae98e84aa --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/reshape_rm_op.cpp @@ -0,0 +1,42 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttnn/operations/data_movement/reshape_view/device/reshape_rm_op.hpp" +#include "tt_metal/host_api.hpp" + +#include + +namespace ttnn { + +void RM_RESHAPE_STRUCT::validate(const std::vector& input_tensors) const { + //Validate the input tensor + const Tensor& input_tensor_a = input_tensors.at(0); + TT_FATAL(input_tensor_a.storage_type() == StorageType::DEVICE, "Operands to reshape need to be on device!"); + TT_FATAL(input_tensor_a.buffer() != nullptr, "Operands need to be allocated in buffers on device!"); + TT_FATAL(input_tensor_a.get_layout() == Layout::ROW_MAJOR, "This function is for RM->RM"); + TT_FATAL(input_tensor_a.get_dtype() == DataType::BFLOAT16 or input_tensor_a.get_dtype() == DataType::UINT32 or input_tensor_a.get_dtype() == DataType::FLOAT32, "Can only work with bfloat16/float32 or uint32 tensors"); + TT_FATAL(this->output_mem_config.memory_layout == input_tensor_a.memory_config().memory_layout, "Output tensor must have the same memory layout as input tensor"); +} + +std::vector RM_RESHAPE_STRUCT::compute_output_shapes(const std::vector& input_tensors) const { + return {output_shape.logical_shape()}; +} + +std::vector RM_RESHAPE_STRUCT::create_output_tensors(const std::vector& input_tensors) const { + //Create the output tensor + const auto& input_tensor_a = input_tensors.at(0); + auto mem_config = this->output_mem_config; + if (input_tensor_a.memory_config().is_sharded()) { + auto shard_spec = input_tensor_a.shard_spec().value(); + shard_spec.shape[0] = output_shape[0]; + mem_config.shard_spec = shard_spec; + } + return {create_device_tensor(output_shape, input_tensor_a.get_dtype(), input_tensor_a.get_layout(), input_tensor_a.device(), mem_config)}; +} + +operation::ProgramWithCallbacks RM_RESHAPE_STRUCT::create_program( const std::vector& input_tensors, std::vector& output_tensors) const +{ + return operations::data_movement::rm_reshape::rm_reshape_preparer(input_tensors.at(0), output_tensors.at(0)); +} +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/reshape_rm_op.hpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/reshape_rm_op.hpp new file mode 100644 index 00000000000..6ac1e4911fc --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/reshape_rm_op.hpp @@ -0,0 +1,31 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ttnn/run_operation.hpp" +#include "ttnn/operations/eltwise/binary/binary.hpp" +#include "ttnn/operations/data_movement/reshape_view/reshape_common.hpp" +namespace ttnn { + +struct RM_RESHAPE_STRUCT { + const ttnn::Shape output_shape; + MemoryConfig output_mem_config; + + + //Required functions to all tensor op functions + void update_structure (const Tensor& input_tensor); + void validate(const std::vector &input_tensors) const; + std::vector compute_output_shapes(const std::vector &input_tensors) const; + std::vector create_output_tensors(const std::vector &input_tensors) const; + operation::ProgramWithCallbacks create_program( + const std::vector &input_tensors, std::vector &output_tensors) const; +}; + + +}// namespace ttnn +namespace ttnn::operations::data_movement::rm_reshape{ + +operation::ProgramWithCallbacks rm_reshape_preparer(const Tensor& input, const Tensor& output); +} diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp index 1bf6949fc0b..f2137dfc2e3 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp @@ -2,9 +2,11 @@ // // SPDX-License-Identifier: Apache-2.0 + #include "ttnn/common/constants.hpp" #include "ttnn/run_operation.hpp" #include "reshape.hpp" +#include "reshape_common.hpp" #include "tt_metal/common/constants.hpp" #include #include @@ -14,12 +16,20 @@ #include "ttnn/operations/data_movement/data_transfer/data_transfer.hpp" #include "ttnn/operations/data_movement/slice/slice.hpp" #include "ttnn/operations/core/core.hpp" +#include "device/reshape_rm_op.hpp" +#include "ttnn/operations/data_movement/sharded/sharded_to_interleaved/sharded_to_interleaved.hpp" +#include "ttnn/operations/data_movement/sharded/interleaved_to_sharded/interleaved_to_sharded.hpp" +#include "ttnn/operations/data_movement/untilize_with_unpadding/untilize_with_unpadding.hpp" +#include "ttnn/operations/data_movement/tilize_with_val_padding/tilize_with_val_padding.hpp" namespace ttnn::operations::data_movement { + namespace detail { ttnn::Tensor host_reshape(const ttnn::Tensor& tensor, const ttnn::Shape& shape) { + //This function is due to embedding issue + tt::log_warning("host_reshape is deprecated and will be removed in the near future"); if (!ttnn::has_storage_type_of(tensor, ttnn::StorageType::DEVICE)) { return tensor.reshape(shape); } @@ -48,116 +58,189 @@ ttnn::Tensor host_reshape(const ttnn::Tensor& tensor, const ttnn::Shape& shape) return device_tensor; } +//Wrapper to turn the ND-> MD problem into 3D->3D for tiled and 2D->2D for Row Major + ttnn::Tensor convert_tensor_to_rm_reshape_convert_back_to_orig_layout( - const ttnn::Tensor& tensor, const ttnn::Shape& shape) { + const ttnn::Tensor& tensor, + const ttnn::Shape& shape, + const uint32_t tile_first_dim, + const uint32_t tile_second_dim, + const MemoryConfig &memory_config, + const uint8_t queue_id, + const PadValue &pad_value + ) +{ + //This function turns ND -> MD into 2D->MD for row major and 3D->MD for tiled using a 0 cost view const auto layout = tensor.get_layout(); - auto shape_with_padding = shape.padded_shape(); - auto tensor_shape = tensor.get_shape(); - auto tensor_shape_with_padding = tensor_shape.padded_shape(); + const auto tensor_shape = tensor.get_shape(); + TT_FATAL((tensor_shape.rank()!=0), "can't do reshape from rank 0 tensor"); + if(layout == 
ttnn::ROW_MAJOR_LAYOUT) + { + //Collapse into the second last dimension + uint32_t second_dim = 1; + for (int i=0; i MD into an equivalent 2D->2D conversion and then turns the 2D output back to MD using a 0 cost view + TT_FATAL((shape.rank()!=0), "can't do reshape to rank 0 tensor"); + //Collapse into the second last dimension + uint32_t second_dim = 1; + for (int i=0; i = 2 and shape.rank() >= 2) { - // Handle the case when the tensor is not contiguous but the last two dimensions are the same and so reshape - // is possible - if (tensor_shape[-1] == shape[-1] and tensor_shape[-2] == shape[-2] and - tensor_shape_with_padding[-1] == shape_with_padding[-1] and - tensor_shape_with_padding[-2] == shape_with_padding[-2]) { - reshaped_rm_tensor = rm_tensor.reshape(shape); - } - } else { - reshaped_rm_tensor = host_reshape(tensor, shape); - } +//Entry points into device prep code +ttnn::Tensor perform_reshape_on_2D_RM( + const ttnn::Tensor& tensor, + const ttnn::Shape& shape, + const MemoryConfig &memory_config, + const uint8_t queue_id + ) +{ + auto temp_tensor = tensor; + auto intermediate_mem_config = tensor.memory_config(); + auto intermediate_out_memory_config = memory_config; + if(tensor.memory_config().is_sharded()) + { + auto temp_memory_config = tensor.memory_config(); + temp_memory_config.memory_layout = TensorMemoryLayout::INTERLEAVED; + temp_tensor = ttnn::sharded_to_interleaved(queue_id, tensor, temp_memory_config, std::nullopt); } - // Can'd do untilize on device due to inner dim size - else { - reshaped_rm_tensor = host_reshape(tensor, shape); + if (memory_config.is_sharded()) + { + intermediate_out_memory_config.memory_layout = TensorMemoryLayout::INTERLEAVED; } - - if (((shape[-1] * tensor.element_size()) % sizeof(uint32_t) == 0) and reshaped_rm_tensor.layout() != layout) { - return ttnn::to_layout(reshaped_rm_tensor, layout, std::nullopt, std::nullopt, (Device*)nullptr); - } else { - return reshaped_rm_tensor; + //Guaranteed to be interleaved + //We are guaranteed to be working 2D->2D in this function + auto temp_tensor2 = operation::run( + RM_RESHAPE_STRUCT + { + shape, + intermediate_out_memory_config + }, + {temp_tensor}, + {}, + {}, + queue_id + ).at(0); + if(memory_config.is_sharded()) + { + return ttnn::interleaved_to_sharded(queue_id,temp_tensor2, memory_config,std::nullopt); + } + else + { + return temp_tensor2; } } -} // namespace detail +} -ttnn::Shape tiling_reshape_corrector(const ttnn::Shape& shape) { - // Apply the correct padding metadata to the target shape +ttnn::Shape tiling_reshape_corrector(const ttnn::Shape& shape, const uint32_t tile_first_dim, const uint32_t tile_second_dim) { + //Apply the correct padding metadata to the target shape auto padded = shape.with_tile_padding(); auto rank = shape.rank(); - const int8_t correction_1 = - (ttnn::types::TILE_SIZE - (int)padded[-1] % ttnn::types::TILE_SIZE) % ttnn::types::TILE_SIZE; - if (rank == 1) { - return ttnn::Shape({1, shape[0]}, {32, padded[0] + correction_1}); - } - const int8_t correction_2 = - (ttnn::types::TILE_SIZE - (int)padded[-2] % ttnn::types::TILE_SIZE) % ttnn::types::TILE_SIZE; - switch (rank) { - case 2: return ttnn::Shape({shape[0], shape[1]}, {padded[0] + correction_2, padded[1] + correction_1}); break; + const int8_t correction_1 =(tile_first_dim - (int)padded[-1] % tile_first_dim) % tile_first_dim; + if(rank == 1) + { + return ttnn::Shape({1,shape[0]},{32,padded[0]+correction_1}); + } + const int8_t correction_2 =(tile_second_dim - (int)padded[-2] % tile_second_dim) % tile_second_dim; + 
switch(rank) + { + case 2: + return ttnn::Shape({shape[0],shape[1]},{padded[0]+correction_2,padded[1]+correction_1}); + break; case 3: - return ttnn::Shape( - {shape[0], shape[1], shape[2]}, {padded[0], padded[1] + correction_2, padded[2] + correction_1}); + return ttnn::Shape({shape[0],shape[1],shape[2]},{padded[0],padded[1]+correction_2,padded[2]+correction_1}); break; case 4: - return ttnn::Shape( - {shape[0], shape[1], shape[2], shape[3]}, - {padded[0], padded[1], padded[2] + correction_2, padded[3] + correction_1}); + return ttnn::Shape({shape[0],shape[1],shape[2],shape[3]},{padded[0],padded[1],padded[2]+correction_2,padded[3]+correction_1}); break; + } return shape; } -ttnn::Tensor PerformView(const ttnn::Tensor& tensor, const ttnn::Shape& shape) { +ttnn::Tensor PerformView(const ttnn::Tensor& tensor, const ttnn::Shape& shape, const uint32_t tile_first_dim, const uint32_t tile_second_dim) { + if (tensor.get_shape() == shape) { + return tensor; + } if (tensor.get_layout() == ttnn::TILE_LAYOUT && - (shape[-1] % ttnn::types::TILE_SIZE != 0 || shape[-2] % ttnn::types::TILE_SIZE != 0)) { - // Correct the output shape to add padding metadata before reshape (view) - return tensor.reshape(tiling_reshape_corrector(shape)); + (shape[-1]%tile_first_dim!=0 || shape.rank()==1 || shape[-2]%tile_second_dim!=0 )) + { + //Correct the output shape to add padding metadata before reshape (view) + return tensor.reshape(tiling_reshape_corrector(shape, tile_first_dim, tile_second_dim)); } - // Perform a reshape (view) + //Perform a reshape (view) return tensor.reshape(shape); } -void Validate_transform(const ttnn::Shape& input_shape, const ttnn::Shape& output_shape) { - // Reshape should not be adding or removing data - uint32_t input_volume = 1; - ; - uint32_t output_volume = 1; - for (int i = 0; i < input_shape.rank(); i++) { - input_volume = input_volume * input_shape[i]; - } - for (int i = 0; i < output_shape.rank(); i++) { - output_volume = output_volume * output_shape[i]; - } - TT_FATAL(input_volume == output_volume, "Invalid Reshape, input and output volume must match"); -} - -ttnn::Tensor ReshapeViewOperation::invoke(const ttnn::Tensor& tensor, const ttnn::Shape& shape) { +ttnn::Tensor ReshapeViewOperation::invoke( + const ttnn::Tensor& tensor, + const ttnn::Shape& shape, + const std::optional &memory_config, + const uint8_t queue_id, + const std::optional &pad_value + ) { auto layout = tensor.get_layout(); auto tensor_shape = tensor.get_shape(); @@ -165,21 +248,28 @@ ttnn::Tensor ReshapeViewOperation::invoke(const ttnn::Tensor& tensor, const ttnn if (tensor_shape == shape) { return tensor; } - // This is a constraint Torch places on reshape I was assuming, but it causes half of the codebase to fail if added - // Validate_transform(tensor_shape, shape) - // For view the following cases work: - // RM: The last dimension is the same - // Tiled: The last two dimensions are the same or there is no padding on the second last dimension - const uint32_t shape_second_last_dim = shape.rank() >= 2 ? shape[-2] : 1; - const uint32_t tensor_shape_second_last_dim = tensor_shape.rank() >= 2 ? 
tensor_shape[-2] : 1; - bool this_is_view = - (tensor_shape[-1] == shape[-1]) && - ((tensor.get_layout() == ttnn::ROW_MAJOR_LAYOUT) || // Its row major - (shape_second_last_dim == tensor_shape_second_last_dim) || // Second last dimension is the same - (shape_second_last_dim % ttnn::types::TILE_SIZE == 0 && - tensor_shape_second_last_dim % ttnn::types::TILE_SIZE == - 0)); // There is no padding on the second last dimension + PadValue default_pad_value; + if(tensor.get_dtype() == DataType::BFLOAT16 or tensor.get_dtype() == DataType::FLOAT32) { + default_pad_value = 0.0f; + } + else { + default_pad_value = (uint32_t)0; + } + //const uint32_t tile_first_dim =tensor.get_tile().get_width(); + //const uint32_t tile_second_dim =tensor.get_tile().get_height(); + const uint32_t tile_first_dim = 32; + const uint32_t tile_second_dim = 32; + //The following case should only be called for the device storage case, the rest is a bandaid + //for issue 15317 + + + const uint32_t shape_second_last_dim = shape.rank() >= 2 ? shape[-2]:1; + const uint32_t tensor_shape_second_last_dim = tensor_shape.rank() >= 2 ? tensor_shape[-2]:1; + bool this_is_view = (tensor_shape[-1] == shape[-1]) && + ((tensor.get_layout() == ttnn::ROW_MAJOR_LAYOUT) || //Its row major + (tensor_shape_second_last_dim==shape_second_last_dim) || //Second last dimension is the same + (shape_second_last_dim % tile_second_dim==0 && tensor_shape_second_last_dim % tile_first_dim==0)); //There is no padding on the second last dimension bool tile_tensor_view_reshape_possible = (layout == ttnn::Layout::TILE and shape.with_tile_padding().rank() >= 2 and shape.with_tile_padding()[-2] % ttnn::TILE_SIZE == 0 and @@ -190,30 +280,68 @@ ttnn::Tensor ReshapeViewOperation::invoke(const ttnn::Tensor& tensor, const ttnn // This case has been allowed in the past though it means introducing padding values to the data return tensor.reshape(shape); } - if (!(ttnn::has_storage_type_of(tensor, ttnn::StorageType::DEVICE)) or this_is_view) { - return PerformView(tensor, shape); + + + if (this_is_view) { + return PerformView(tensor,shape, tile_first_dim, tile_second_dim); } - if (tensor_shape.rank() > 3) { - uint32_t mult_factor = 1; - for (int i = 0; i < tensor_shape.rank() - 3; i++) { - mult_factor = mult_factor * tensor_shape[i]; - } - const ttnn::Tensor temp_tensor = - PerformView(tensor, ttnn::Shape{tensor_shape[-3] * mult_factor, tensor_shape[-2], tensor_shape[-1]}); - return detail::convert_tensor_to_rm_reshape_convert_back_to_orig_layout(temp_tensor, shape); + if(shape.logical_shape().volume() != tensor.get_logical_volume()) + { + //This is a completely incorrect test but it is due to issue + return detail::host_reshape(tensor, shape); } // Catch-all // Do the reshape in row-major + return detail::convert_tensor_to_rm_reshape_convert_back_to_orig_layout( + tensor, + shape, + tile_first_dim, + tile_second_dim, + memory_config.value_or(tensor.memory_config()), + queue_id, + pad_value.value_or(default_pad_value) + ); +} + +ttnn::Tensor ReshapeViewOperation::invoke( + const ttnn::Tensor& tensor, + const ttnn::Shape& shape + ) { + return invoke(tensor, shape,std::nullopt,0,std::nullopt); + } + +ttnn::Tensor ReshapeViewOperation::invoke( + const ttnn::Tensor& tensor, + const ttnn::SimpleShape& shape, + const std::optional &memory_config, + const uint8_t queue_id, + const std::optional &pad_value + ) { + return invoke(tensor, ttnn::Shape(shape.view()),memory_config,queue_id,pad_value); +} - return detail::convert_tensor_to_rm_reshape_convert_back_to_orig_layout(tensor, 
shape); +ttnn::Tensor ReshapeViewOperation::invoke( + const ttnn::Tensor& tensor, + const ttnn::SimpleShape& shape + ) { + return invoke(tensor, ttnn::Shape(shape.view()),std::nullopt,0,std::nullopt); } -ttnn::Tensor ReshapeViewOperation::invoke(const ttnn::Tensor& tensor, const ttnn::SimpleShape& shape) { - return invoke(tensor, ttnn::Shape(shape.view())); +ttnn::Tensor ReshapeViewOperation::invoke( + const ttnn::Tensor& tensor, + tt::stl::Span shape_vector, + const std::optional &memory_config, + const uint8_t queue_id, + const std::optional &pad_value + ) { + return invoke(tensor, tt::tt_metal::infer_dims_for_reshape(tensor, shape_vector),memory_config,queue_id,pad_value); } -ttnn::Tensor ReshapeViewOperation::invoke(const ttnn::Tensor& tensor, tt::stl::Span shape_vector) { - return invoke(tensor, tt::tt_metal::infer_dims_for_reshape(tensor, shape_vector)); +ttnn::Tensor ReshapeViewOperation::invoke( + const ttnn::Tensor& tensor, + tt::stl::Span shape_vector + ) { + return invoke(tensor, tt::tt_metal::infer_dims_for_reshape(tensor, shape_vector),std::nullopt,0,std::nullopt); } -} // namespace ttnn::operations::data_movement +} // ttnn::operations::data_movement namespace diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.hpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.hpp index 063f21063e4..208d6b526d8 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.hpp @@ -5,19 +5,34 @@ #pragma once #include "ttnn/decorators.hpp" +#include "ttnn/operations/data_movement/reshape_view/reshape_common.hpp" + namespace ttnn { namespace operations::data_movement { +namespace detail { + ttnn::Tensor host_reshape(const ttnn::Tensor& tensor, const ttnn::Shape& shape); + ttnn::Tensor convert_tensor_to_rm_reshape_convert_back_to_orig_layout(const ttnn::Tensor& tensor, const ttnn::Shape& shape, const uint32_t tile_first_dim, const uint32_t tile_second_dim, const MemoryConfig &memory_config, const uint8_t queue_id, const PadValue &pad_value); + ttnn::Tensor fix_shape_and_perform_reshape_on_2D_RM(const ttnn::Tensor& tensor, const ttnn::Shape& shape, const uint32_t tile_first_dim, const uint32_t tile_second_dim, const MemoryConfig &memory_config, const uint8_t queue_id); + ttnn::Tensor perform_reshape_on_2D_RM(const ttnn::Tensor& tensor, const ttnn::Shape& shape, const MemoryConfig &memory_config, const uint8_t queue_id); +} + +ttnn::Shape tiling_reshape_corrector(const ttnn::Shape& shape); +ttnn::Tensor PerformView(const ttnn::Tensor& tensor, const ttnn::Shape& shapeconst, uint32_t tile_first_dim, const uint32_t tile_second_dim); +void Validate_transform (const ttnn::Shape& input_shape, const ttnn::Shape& output_shape); struct ReshapeViewOperation { + static ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, const ttnn::Shape& shape,const std::optional &memory_config,const uint8_t queue_id,const std::optional &pad_value); + static ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, const ttnn::SimpleShape& logical_shape,const std::optional &memory_config,const uint8_t queue_id,const std::optional &pad_value); + static ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, tt::stl::Span shape_vector,const std::optional &memory_config,const uint8_t queue_id,const std::optional &pad_value); static ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, const ttnn::Shape& shape); static ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, const ttnn::SimpleShape& logical_shape); 
static ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, tt::stl::Span shape_vector); }; + } // namespace operations::data_movement -constexpr auto reshape = - ttnn::register_operation<"ttnn::reshape", ttnn::operations::data_movement::ReshapeViewOperation>(); +constexpr auto reshape = ttnn::register_operation<"ttnn::reshape", ttnn::operations::data_movement::ReshapeViewOperation>(); } // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape_common.hpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape_common.hpp new file mode 100644 index 00000000000..a39245856b0 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape_common.hpp @@ -0,0 +1,5 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +typedef std::variant PadValue; diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape_pybind.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape_pybind.cpp index 0df612b0e1c..2b757878e4c 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape_pybind.cpp @@ -10,6 +10,8 @@ #include "ttnn/cpp/pybind11/decorators.hpp" #include "ttnn/operations/data_movement/reshape_view/reshape.hpp" #include "ttnn/types.hpp" +#include "ttnn/operations/data_movement/reshape_view/reshape_common.hpp" + namespace ttnn::operations::data_movement { @@ -22,22 +24,45 @@ void bind_reshape_view(pybind11::module& module, const data_movement_operation_t operation, doc, ttnn::pybind_overload_t{ - [](const data_movement_operation_t& self, const ttnn::Tensor& input_tensor, const ttnn::Shape& shape) - -> ttnn::Tensor { return self(input_tensor, shape); }, + [](const data_movement_operation_t& self, + const ttnn::Tensor& input_tensor, + const ttnn::Shape& shape, + const std::optional &memory_config, + const uint8_t queue_id, + const std::optional &pad_value + ) -> ttnn::Tensor { + return self(input_tensor, shape); + }, py::arg("input_tensor"), py::arg("shape"), - }, + py::kw_only(), + py::arg("memory_config") = std::nullopt, + py::arg("queue_id") = 0, + py::arg("pad_value") = std::nullopt + }, ttnn::pybind_overload_t{ [](const data_movement_operation_t& self, const ttnn::Tensor& input_tensor, - const ttnn::SmallVector shape) -> ttnn::Tensor { return self(input_tensor, shape); }, + const ttnn::SmallVector shape, + const std::optional &memory_config, + const uint8_t queue_id, + const std::optional &pad_value + ) -> ttnn::Tensor { + return self(input_tensor, shape); + }, py::arg("input_tensor"), py::arg("shape"), - }); + py::kw_only(), + py::arg("memory_config") = std::nullopt, + py::arg("queue_id") = 0, + py::arg("pad_value") = std::nullopt + } + ); } } // namespace detail + void py_bind_reshape_view(pybind11::module& module) { detail::bind_reshape_view( module, @@ -53,6 +78,11 @@ void py_bind_reshape_view(pybind11::module& module) { * input_tensor: Input Tensor. * new_shape: New shape of tensor. + Keyword Args: + * :attr:`memory_config`: Memory Config of the output tensor. Default is to match input tensor memory config + * :attr:`queue_id`: command queue id. Default is 0. + * :attr:`pad_value` (number): Value to pad the output tensor. Default is 0 + Returns: ttnn.Tensor: the output tensor with the new shape. 
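A minimal usage sketch of the keyword arguments this patch adds to ttnn.reshape (memory_config, queue_id, pad_value), assuming an already-opened ttnn device handle named device; the shapes and values below are illustrative only, and the defaults shown match the pybind bindings above rather than a documented example from the patch:

import torch
import ttnn

# Illustrative shapes only; any volume-preserving reshape is called the same way.
torch_input = torch.randn((1, 56, 56, 64), dtype=torch.bfloat16)
torch_expected = torch_input.reshape((1, 1, 56 * 56, 64))

input_tensor = ttnn.from_torch(
    torch_input,
    dtype=ttnn.bfloat16,
    layout=ttnn.TILE_LAYOUT,
    device=device,  # assumed: a ttnn device opened elsewhere
    memory_config=ttnn.DRAM_MEMORY_CONFIG,
)

# The new keyword arguments are optional; omitting them keeps the prior
# behaviour (output uses the input's memory config, command queue 0, and
# pad value 0 for any padding introduced by tiling).
ttnn_output = ttnn.reshape(
    input_tensor,
    (1, 1, 56 * 56, 64),
    memory_config=ttnn.DRAM_MEMORY_CONFIG,
    queue_id=0,
    pad_value=0.0,
)
output = ttnn.to_torch(ttnn_output)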
From e2bbda5e4b57ac9351d3a9153cf77223c92e9800 Mon Sep 17 00:00:00 2001 From: Juan Camilo Vega Date: Thu, 28 Nov 2024 22:34:25 +0000 Subject: [PATCH 02/31] #15558 edited comment to mention this issue --- .../ttnn/operations/data_movement/reshape_view/reshape.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp index f2137dfc2e3..bb6ec763ceb 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp @@ -28,7 +28,7 @@ namespace ttnn::operations::data_movement { namespace detail { ttnn::Tensor host_reshape(const ttnn::Tensor& tensor, const ttnn::Shape& shape) { - //This function is due to embedding issue + //This function is due to embedding issue 15558, once the issue is fixed we want to delete it tt::log_warning("host_reshape is deprecated and will be removed in the near future"); if (!ttnn::has_storage_type_of(tensor, ttnn::StorageType::DEVICE)) { return tensor.reshape(shape); @@ -287,7 +287,7 @@ ttnn::Tensor ReshapeViewOperation::invoke( } if(shape.logical_shape().volume() != tensor.get_logical_volume()) { - //This is a completely incorrect test but it is due to issue + //This is a completely incorrect test but it is due to issue 15558 return detail::host_reshape(tensor, shape); } // Catch-all From 43b0bd23d95475c030306025fca3b5f017d947ad Mon Sep 17 00:00:00 2001 From: Juan Camilo Vega Date: Fri, 29 Nov 2024 13:42:37 +0000 Subject: [PATCH 03/31] #0: move tt_memmove to common library and ensure tilize/untilize is only ever called on 3D shapes --- .../data_movement/common/kernels/common.hpp | 35 ++++++++ .../device/device/rm_reshape_interleaved.cpp | 41 +-------- .../data_movement/reshape_view/reshape.cpp | 85 +++++++++++++++++-- .../data_movement/reshape_view/reshape.hpp | 2 + 4 files changed, 120 insertions(+), 43 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp b/ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp index 5a7e3472c57..df7d5d7dfca 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp @@ -7,8 +7,43 @@ // Best to separate in to cpp/hpp at some point to avoid the code size explosion but need to figure out the linking // issues +#define MASK_64 0xFFFFFFFFFFFFFFC0 +#define OFFSET_64 0x000000000000003F +#define MASK_16 0xFFFFFFFFFFFFFFF0 +#define OFFSET_16 0x000000000000000F + namespace tt::data_movement::common { +template +FORCE_INLINE +void tt_memmove ( + const uint32_t dst_l1_addr, + const uint64_t src_l1_addr, + const uint32_t bytes) +{ + //Function performs a memory copy between two l1 addresses in the local core + //Uses noc_async_read when possible to copy the data over + //Set guaranteed 16B alligned to true if the source and destination are externally guaranteed to be 16B alligned (dangerous) + //Set copy_async to true if you wish to perform the operation asynchronously, in this case you can add a noc_async_read_barrier to synchronize later + if constexpr (guaranteed_16B_alligned) + { + noc_async_read(get_noc_addr(src_l1_addr),dst_l1_addr, bytes); + if constexpr (!copy_async) {noc_async_read_barrier();} + } + else + { + if ((dst_l1_addr&OFFSET_16) == (src_l1_addr&OFFSET_16)) + { + noc_async_read(get_noc_addr(src_l1_addr),dst_l1_addr, bytes); + if constexpr (!copy_async) 
{noc_async_read_barrier();} + } + else + { + memmove((void *)(dst_l1_addr), (void *)(src_l1_addr), (size_t) (bytes)); + } + } +} + // this function is useful for converting bfloat16 values to float32 FORCE_INLINE float bfloat16_to_float32(uint16_t bfloat16_data) { uint32_t bits = static_cast(bfloat16_data) << 16; diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp index 96b0650a4fa..98f61f04633 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp @@ -32,40 +32,7 @@ Runtime arguments #include #include #include "debug/dprint.h" // required in all kernels using DPRINT - -#define MASK_64 0xFFFFFFFFFFFFFFC0 -#define OFFSET_64 0x000000000000003F -#define MASK_16 0xFFFFFFFFFFFFFFF0 -#define OFFSET_16 0x000000000000000F - - -template -FORCE_INLINE -void tt_memmove ( - const uint32_t dst_l1_addr, - const uint64_t src_l1_addr, - const uint32_t bytes) -{ - //Uses noc_async_read when possible to copy the data over - if constexpr (guaranteed_16B_alligned) - { - noc_async_read(get_noc_addr(src_l1_addr),dst_l1_addr, bytes); - noc_async_read_barrier(); - } - else - { - if ((dst_l1_addr&OFFSET_16) == (src_l1_addr&OFFSET_16)) - { - noc_async_read(get_noc_addr(src_l1_addr),dst_l1_addr, bytes); - noc_async_read_barrier(); - } - else - { - memmove((void *)(dst_l1_addr), (void *)(src_l1_addr), (size_t) (bytes)); - } - } -} - +#include "ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp" void kernel_main() { //We are guranteed to be in 2D going to 2D @@ -148,14 +115,14 @@ void kernel_main() { noc_async_write_barrier(); if (readable < writable) { - tt_memmove(dest_buffer+write_offset, source_buffer + read_offset, readable); + tt::data_movement::common::tt_memmove(dest_buffer+write_offset, source_buffer + read_offset, readable); writable = writable -readable; write_offset = write_offset + readable; readable = 0; } else if (readable == writable) { - tt_memmove(dest_buffer+write_offset, source_buffer + read_offset, readable); + tt::data_movement::common::tt_memmove(dest_buffer+write_offset, source_buffer + read_offset, readable); #if ((dst_aligned_to_16)) noc_async_write(dest_buffer,dst_noc_addr, dest_page_size_bytes); #else @@ -182,7 +149,7 @@ void kernel_main() { { //writable < readable - tt_memmove(dest_buffer+write_offset, source_buffer + read_offset, writable); + tt::data_movement::common::tt_memmove(dest_buffer+write_offset, source_buffer + read_offset, writable); #if ((dst_aligned_to_16)) noc_async_write(dest_buffer,dst_noc_addr, dest_page_size_bytes); #else diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp index bb6ec763ceb..ec761a1649b 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp @@ -27,6 +27,29 @@ namespace ttnn::operations::data_movement { namespace detail { +ttnn::Tensor convert_tile_to_rm( + const ttnn::Tensor& tensor, + const ttnn::Shape& shape, + const uint32_t tile_first_dim, + const uint32_t tile_second_dim, + const MemoryConfig &memory_config, + const uint8_t queue_id, + const PadValue &pad_value +) { + //Convert the 3D->3D reshaping to row major and back to tile + auto rm_tensor = 
ttnn::to_layout(tensor, ttnn::ROW_MAJOR_LAYOUT, std::nullopt, std::nullopt, (Device*)nullptr); + rm_tensor = convert_tensor_to_rm_reshape_convert_back_to_orig_layout( + rm_tensor, + shape, + tile_first_dim, + tile_second_dim, + memory_config, + queue_id, + pad_value + ); + rm_tensor = ttnn::to_layout(rm_tensor, ttnn::Layout::TILE, std::nullopt, std::nullopt, (Device*)nullptr); + return rm_tensor; +} ttnn::Tensor host_reshape(const ttnn::Tensor& tensor, const ttnn::Shape& shape) { //This function is due to embedding issue 15558, once the issue is fixed we want to delete it tt::log_warning("host_reshape is deprecated and will be removed in the near future"); @@ -82,6 +105,7 @@ ttnn::Tensor convert_tensor_to_rm_reshape_convert_back_to_orig_layout( { second_dim = second_dim * tensor_shape[i]; } + //Call reshape with the equivalent data 2D Row Major input tensor return fix_shape_and_perform_reshape_on_2D_RM( PerformView ( @@ -99,22 +123,71 @@ ttnn::Tensor convert_tensor_to_rm_reshape_convert_back_to_orig_layout( } else if (layout == ttnn::Layout::TILE) { - auto rm_tensor = ttnn::to_layout(tensor, ttnn::ROW_MAJOR_LAYOUT, std::nullopt, std::nullopt, (Device*)nullptr); - rm_tensor = convert_tensor_to_rm_reshape_convert_back_to_orig_layout( - rm_tensor, - shape, + uint32_t third_dim = 1; + //Collapse into the third last dimension + for (int i=0; i 1 ? tensor_shape[-2] : 1; + //Call reshape with the equivalent data 3D Tile input tensor + return fix_shape_and_perform_reshape_on_3D_TILE( + PerformView + ( + tensor, + ttnn::Shape{third_dim,second_dim,tensor_shape[-1]}, + tile_first_dim, + tile_second_dim + ) + ,shape, tile_first_dim, tile_second_dim, memory_config, queue_id, pad_value ); - rm_tensor = ttnn::to_layout(rm_tensor, ttnn::Layout::TILE, std::nullopt, std::nullopt, (Device*)nullptr); - return rm_tensor; } TT_FATAL(false, "layout is neither tile nor row major"); } + +ttnn::Tensor fix_shape_and_perform_reshape_on_3D_TILE( + const ttnn::Tensor& tensor, + const ttnn::Shape& shape, + const uint32_t tile_first_dim, + const uint32_t tile_second_dim, + const MemoryConfig &memory_config, + const uint8_t queue_id, + const PadValue &pad_value + ) +{ + //This function turns a TILE 3D->MD into an equivalent 3D->3D conversion and then turns the 3D output back to MD using a 0 cost view + //Collapse into the third last dimension + TT_FATAL((shape.rank()!=0), "can't do reshape to rank 0 tensor"); + uint32_t third_dim = 1; + for (int i=0; i 1 ? 
shape[-2] : 1; + return PerformView + ( + convert_tile_to_rm( + tensor, + ttnn::Shape{third_dim,second_dim,shape[-1]}, + tile_first_dim, + tile_second_dim, + memory_config, + queue_id, + pad_value + ), + shape, + tile_first_dim, + tile_second_dim); +} + ttnn::Tensor fix_shape_and_perform_reshape_on_2D_RM( const ttnn::Tensor& tensor, const ttnn::Shape& shape, diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.hpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.hpp index 208d6b526d8..566a0d1250c 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.hpp @@ -14,7 +14,9 @@ namespace detail { ttnn::Tensor host_reshape(const ttnn::Tensor& tensor, const ttnn::Shape& shape); ttnn::Tensor convert_tensor_to_rm_reshape_convert_back_to_orig_layout(const ttnn::Tensor& tensor, const ttnn::Shape& shape, const uint32_t tile_first_dim, const uint32_t tile_second_dim, const MemoryConfig &memory_config, const uint8_t queue_id, const PadValue &pad_value); ttnn::Tensor fix_shape_and_perform_reshape_on_2D_RM(const ttnn::Tensor& tensor, const ttnn::Shape& shape, const uint32_t tile_first_dim, const uint32_t tile_second_dim, const MemoryConfig &memory_config, const uint8_t queue_id); + ttnn::Tensor fix_shape_and_perform_reshape_on_3D_TILE( const ttnn::Tensor& tensor, const ttnn::Shape& shape, const uint32_t tile_first_dim, const uint32_t tile_second_dim, const MemoryConfig &memory_config, const uint8_t queue_id, const PadValue &pad_value); ttnn::Tensor perform_reshape_on_2D_RM(const ttnn::Tensor& tensor, const ttnn::Shape& shape, const MemoryConfig &memory_config, const uint8_t queue_id); + ttnn::Tensor convert_tile_to_rm(const ttnn::Tensor& tensor, const ttnn::Shape& shape, const uint32_t tile_first_dim, const uint32_t tile_second_dim, const MemoryConfig &memory_config, const uint8_t queue_id, const PadValue &pad_value); } ttnn::Shape tiling_reshape_corrector(const ttnn::Shape& shape); From 8a94c1beb37685f7080297bcac29f7f1348baebd Mon Sep 17 00:00:00 2001 From: Juan Camilo Vega Date: Fri, 29 Nov 2024 14:45:33 +0000 Subject: [PATCH 04/31] #0: added corrector for implied shape dimensions --- .../data_movement/reshape_view/reshape.cpp | 32 +++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp index ec761a1649b..f9465a25edc 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp @@ -307,16 +307,44 @@ ttnn::Tensor PerformView(const ttnn::Tensor& tensor, const ttnn::Shape& shape, c return tensor.reshape(shape); } +ttnn::Shape shape_corrector(const ttnn::Tensor& tensor, const ttnn::Shape& shape) { + //Correct the shape to account for inferred dimensions + uint32_t input_volume = tensor.get_logical_volume(); + uint32_t output_volume = 1; + uint32_t inferred_dim = -1; + for (uint32_t i=0; i< shape.rank(); i++) { + if (((int)(shape[i])) == -1) { + if (inferred_dim != -1) { + TT_FATAL(false, "Only one dimension can be inferred in reshape"); + } + inferred_dim = i; + } else { + output_volume = output_volume * shape[i]; + } + } + if (inferred_dim == -1) + { + return shape; + } + + uint32_t implied_dim_value = (output_volume == 0) ? 
0: input_volume/output_volume; + ttnn::SmallVector new_shape(shape.size()); + auto old_shape = shape.logical_shape().view(); + std::copy(old_shape.begin(), old_shape.end(), new_shape.begin()); + new_shape[inferred_dim] = implied_dim_value; + return ttnn::Shape(std::move(new_shape)); +} + ttnn::Tensor ReshapeViewOperation::invoke( const ttnn::Tensor& tensor, - const ttnn::Shape& shape, + const ttnn::Shape& input_shape, const std::optional &memory_config, const uint8_t queue_id, const std::optional &pad_value ) { auto layout = tensor.get_layout(); auto tensor_shape = tensor.get_shape(); - + const ttnn::Shape shape = shape_corrector(tensor, input_shape); // First Case, No reshape Required if (tensor_shape == shape) { return tensor; From 4fd8c45dd4bf73a52bde0a4ee8fd2bc6a6c29791 Mon Sep 17 00:00:00 2001 From: Juan Camilo Vega Date: Fri, 29 Nov 2024 14:54:01 +0000 Subject: [PATCH 05/31] #13889: Added test to prove this issue is resolved --- tests/ttnn/unit_tests/test_reshape.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/ttnn/unit_tests/test_reshape.py b/tests/ttnn/unit_tests/test_reshape.py index f3ae5e8112f..b0ce6e6e9fd 100644 --- a/tests/ttnn/unit_tests/test_reshape.py +++ b/tests/ttnn/unit_tests/test_reshape.py @@ -293,6 +293,7 @@ def test_reshape_tile_layout_only_change_shape(device): ((1, 1445, 192), (1445, 192)), ((1, 256), (1, 1, 256)), ((16, 1, 32), (16, 1, 32)), + ((1, 32, 4608), (1, 32, 16, 3, 96)), # issue 13889 ], ) @pytest.mark.parametrize("layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT]) From 285d1fac3ad4abd9a4d8c74318cbac56ab2eedd2 Mon Sep 17 00:00:00 2001 From: Juan Camilo Vega Date: Fri, 29 Nov 2024 15:06:46 +0000 Subject: [PATCH 06/31] #12153: Adding test to verify that issue is resolved --- tests/ttnn/unit_tests/test_reshape.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/ttnn/unit_tests/test_reshape.py b/tests/ttnn/unit_tests/test_reshape.py index b0ce6e6e9fd..b0d64b72e20 100644 --- a/tests/ttnn/unit_tests/test_reshape.py +++ b/tests/ttnn/unit_tests/test_reshape.py @@ -294,6 +294,7 @@ def test_reshape_tile_layout_only_change_shape(device): ((1, 256), (1, 1, 256)), ((16, 1, 32), (16, 1, 32)), ((1, 32, 4608), (1, 32, 16, 3, 96)), # issue 13889 + ((2888, 49, 96), (8, 19, 19, 7, 7, 96)), # issue 12153 ], ) @pytest.mark.parametrize("layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT]) From 7d40412ccccae2b00ed8c1bb9bba0b6aea835b3d Mon Sep 17 00:00:00 2001 From: Juan Camilo Vega Date: Fri, 29 Nov 2024 16:30:06 +0000 Subject: [PATCH 07/31] #15048: being more careful about bandaid for issues #15137 and #13338 --- tests/ttnn/unit_tests/test_reshape.py | 20 ++++++++++++ .../data_movement/reshape_view/reshape.cpp | 32 ++++++++++--------- 2 files changed, 37 insertions(+), 15 deletions(-) diff --git a/tests/ttnn/unit_tests/test_reshape.py b/tests/ttnn/unit_tests/test_reshape.py index b0d64b72e20..a21a99ce3fd 100644 --- a/tests/ttnn/unit_tests/test_reshape.py +++ b/tests/ttnn/unit_tests/test_reshape.py @@ -306,6 +306,26 @@ def test_reshape_tile_with_padding(input_shape, output_shape, layout, device): ttnn_output = ttnn.reshape(input_tensor, output_shape) assert layout == ttnn_output.layout output = ttnn.to_torch(ttnn_output) + assert_with_pcc(torch_result, output, 0.9999) + + +# issue 15048 +def test_broken_reshape(device): + src_shape = (1, 56, 56, 64) + target_shape = (1, 1, 56 * 56, 64) + torch_input_tensor = torch.randn(src_shape, dtype=torch.bfloat16) + torch_result = torch_input_tensor.reshape(target_shape) + + input_tensor = ttnn.from_torch( + 
torch_input_tensor, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + + ttnn_output = ttnn.reshape(input_tensor, target_shape) + output = ttnn.to_torch(ttnn_output) assert_with_pcc(torch_result, output, 0.9999) diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp index f9465a25edc..55b38e45a16 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp @@ -38,16 +38,14 @@ ttnn::Tensor convert_tile_to_rm( ) { //Convert the 3D->3D reshaping to row major and back to tile auto rm_tensor = ttnn::to_layout(tensor, ttnn::ROW_MAJOR_LAYOUT, std::nullopt, std::nullopt, (Device*)nullptr); - rm_tensor = convert_tensor_to_rm_reshape_convert_back_to_orig_layout( + rm_tensor = ReshapeViewOperation::invoke( rm_tensor, shape, - tile_first_dim, - tile_second_dim, memory_config, queue_id, pad_value ); - rm_tensor = ttnn::to_layout(rm_tensor, ttnn::Layout::TILE, std::nullopt, std::nullopt, (Device*)nullptr); + rm_tensor = ttnn::to_layout(rm_tensor, ttnn::TILE_LAYOUT, std::nullopt, std::nullopt, (Device*)nullptr); return rm_tensor; } ttnn::Tensor host_reshape(const ttnn::Tensor& tensor, const ttnn::Shape& shape) { @@ -371,23 +369,27 @@ ttnn::Tensor ReshapeViewOperation::invoke( ((tensor.get_layout() == ttnn::ROW_MAJOR_LAYOUT) || //Its row major (tensor_shape_second_last_dim==shape_second_last_dim) || //Second last dimension is the same (shape_second_last_dim % tile_second_dim==0 && tensor_shape_second_last_dim % tile_first_dim==0)); //There is no padding on the second last dimension - bool tile_tensor_view_reshape_possible = - (layout == ttnn::Layout::TILE and shape.with_tile_padding().rank() >= 2 and - shape.with_tile_padding()[-2] % ttnn::TILE_SIZE == 0 and - shape.with_tile_padding()[-1] % ttnn::TILE_SIZE == 0 and - tensor_shape.with_tile_padding()[-1] == shape.with_tile_padding()[-1]); - - if (!(ttnn::has_storage_type_of(tensor, ttnn::StorageType::DEVICE)) or tile_tensor_view_reshape_possible) { - // This case has been allowed in the past though it means introducing padding values to the data - return tensor.reshape(shape); - } - + if (!(ttnn::has_storage_type_of(tensor, ttnn::StorageType::DEVICE))) { + // This case has been allowed in the past though it means introducing padding values to the data + return tensor.reshape(shape); + } if (this_is_view) { return PerformView(tensor,shape, tile_first_dim, tile_second_dim); } if(shape.logical_shape().volume() != tensor.get_logical_volume()) { + //This is completely incorrect but it is due to issue 15137 or issue 15558 + bool tile_tensor_view_reshape_possible = + (layout == ttnn::Layout::TILE and shape.with_tile_padding().rank() >= 2 and + shape.with_tile_padding()[-2] % ttnn::TILE_SIZE == 0 and + shape.with_tile_padding()[-1] % ttnn::TILE_SIZE == 0 and + tensor_shape.with_tile_padding()[-1] == shape.with_tile_padding()[-1]); + + if (tile_tensor_view_reshape_possible) { + // This case has been allowed in the past though it means introducing padding values to the data + return tensor.reshape(shape); + } //This is a completely incorrect test but it is due to issue 15558 return detail::host_reshape(tensor, shape); } From 7335c1169d192dd5c26a1118d210330c89d98344 Mon Sep 17 00:00:00 2001 From: Juan Camilo Vega Date: Fri, 29 Nov 2024 16:38:45 +0000 Subject: [PATCH 08/31] #14676: Adding test to verify that this issue is resolved --- 
tests/ttnn/unit_tests/test_reshape.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/ttnn/unit_tests/test_reshape.py b/tests/ttnn/unit_tests/test_reshape.py index a21a99ce3fd..f942b9f02b1 100644 --- a/tests/ttnn/unit_tests/test_reshape.py +++ b/tests/ttnn/unit_tests/test_reshape.py @@ -295,6 +295,7 @@ def test_reshape_tile_layout_only_change_shape(device): ((16, 1, 32), (16, 1, 32)), ((1, 32, 4608), (1, 32, 16, 3, 96)), # issue 13889 ((2888, 49, 96), (8, 19, 19, 7, 7, 96)), # issue 12153 + ((128, 1, 1, 128), (128, 128)), # issue 14676 ], ) @pytest.mark.parametrize("layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT]) From 6110d4769ff9c5c759ce0f36ee1b282bfc026f1c Mon Sep 17 00:00:00 2001 From: Juan Camilo Vega Date: Fri, 29 Nov 2024 17:03:50 +0000 Subject: [PATCH 09/31] #0: adding libraries for memmove to common --- .../ttnn/operations/data_movement/common/kernels/common.hpp | 3 ++- .../reshape_view/device/device/rm_reshape_interleaved.cpp | 2 -- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp b/ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp index df7d5d7dfca..30e7d597a90 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp @@ -6,7 +6,8 @@ // It's best to copy and paste the functions in rather than include the header as code size will likely explode // Best to separate in to cpp/hpp at some point to avoid the code size explosion but need to figure out the linking // issues - +#include +#include #define MASK_64 0xFFFFFFFFFFFFFFC0 #define OFFSET_64 0x000000000000003F #define MASK_16 0xFFFFFFFFFFFFFFF0 diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp index 98f61f04633..5e4b446e3fd 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp @@ -29,8 +29,6 @@ Runtime arguments */ #include #include "dataflow_api.h" -#include -#include #include "debug/dprint.h" // required in all kernels using DPRINT #include "ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp" From cbc2830a3aedd52653f68276a439258dcd80f598 Mon Sep 17 00:00:00 2001 From: Juan Camilo Vega Date: Fri, 29 Nov 2024 17:28:17 +0000 Subject: [PATCH 10/31] #14513: Adding test to prove issue is resolved --- tests/ttnn/unit_tests/test_reshape.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/ttnn/unit_tests/test_reshape.py b/tests/ttnn/unit_tests/test_reshape.py index f942b9f02b1..56f4c5e29b0 100644 --- a/tests/ttnn/unit_tests/test_reshape.py +++ b/tests/ttnn/unit_tests/test_reshape.py @@ -296,6 +296,7 @@ def test_reshape_tile_layout_only_change_shape(device): ((1, 32, 4608), (1, 32, 16, 3, 96)), # issue 13889 ((2888, 49, 96), (8, 19, 19, 7, 7, 96)), # issue 12153 ((128, 1, 1, 128), (128, 128)), # issue 14676 + ((5, 4, 208, 156), (3, 13, 8, 2080)), # issue 14513 ], ) @pytest.mark.parametrize("layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT]) From 055dc0f817ed5f61907f9b8614b734851c505905 Mon Sep 17 00:00:00 2001 From: Juan Camilo Vega Date: Fri, 29 Nov 2024 22:39:34 +0000 Subject: [PATCH 11/31] #15269: added multi-core support --- .../device/device/rm_reshape_interleaved.cpp | 52 ++++----- .../device/host/reshape_rm_host_prep.cpp 
| 107 ++++++++++++------ 2 files changed, 93 insertions(+), 66 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp index 5e4b446e3fd..3be51578b04 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp @@ -44,18 +44,21 @@ void kernel_main() { const uint32_t read_start_page = get_arg_val(5); const uint32_t read_end_page = get_arg_val(6); const uint32_t write_start_page = get_arg_val(7); + const uint32_t write_start_offset = get_arg_val(8); //cb_id_in0 is a circular buffer with 1 source_page_size_bytes page if no alignment needed //source_read_size_bytes otherwise - const uint32_t cb_id_in0 = get_arg_val(8); + const uint32_t cb_id_in0 = get_arg_val(9); //cb_id_in1 is a circular buffer with 1 dest_page_size_bytes+16 (rounded up to next 64B) page - const uint32_t cb_id_in1 = get_arg_val(9); - - + const uint32_t cb_id_in1 = get_arg_val(10); + const uint32_t nop = get_arg_val(11); constexpr bool tensor_is_dram = get_compile_time_arg_val(0) == 1; #define src_aligned_to_64 get_compile_time_arg_val(1) == 1 #define src_aligned_to_16 get_compile_time_arg_val(2) == 1 - #define dst_aligned_to_16 get_compile_time_arg_val(3) == 1 + if (nop == 1) + { + return; + } const InterleavedAddrGen s = { .bank_base_address = src_addr, @@ -71,8 +74,9 @@ void kernel_main() { uint32_t read_offset = 0; uint32_t write_page = write_start_page; uint32_t readable = 0; + uint32_t end_to_write = 0; uint32_t transaction = 0; - uint32_t writable = dest_page_size_bytes; + uint32_t writable = dest_page_size_bytes - write_start_offset; //cb_id_in0 is a CB source_read_size_bytes page size, 1 page //cb_id_in1 is a CB dest_page_size_bytes + allignment_to_64 page size, 1 page cb_reserve_back(cb_id_in0, 1); @@ -81,13 +85,9 @@ void kernel_main() { const uint32_t dest_buffer = get_write_ptr(cb_id_in1); uint64_t dst_noc_addr = get_noc_addr(write_page, d); -#if (dst_aligned_to_16) - uint32_t write_offset = 0; -#else - uint32_t write_offset = dst_noc_addr&OFFSET_16; - uint32_t begin_write_offset = write_offset; -#endif - for (uint32_t i = read_start_page; i <= read_end_page; i++) { + uint64_t write_offset = dst_noc_addr&OFFSET_16 + write_start_offset; + uint64_t begin_write_offset = write_offset; + for (uint32_t i = read_start_page; i < read_end_page; i++) { //Read from source uint64_t src_noc_addr = s.get_noc_addr(i,0); @@ -117,15 +117,19 @@ void kernel_main() { writable = writable -readable; write_offset = write_offset + readable; readable = 0; + end_to_write = end_to_write + readable; + if (i == read_end_page-1) + { + noc_async_write(dest_buffer+begin_write_offset,dst_noc_addr, end_to_write); + cb_push_back(cb_id_in0, 1); + cb_push_back(cb_id_in1, 1); + return; + } } else if (readable == writable) { tt::data_movement::common::tt_memmove(dest_buffer+write_offset, source_buffer + read_offset, readable); -#if ((dst_aligned_to_16)) - noc_async_write(dest_buffer,dst_noc_addr, dest_page_size_bytes); -#else noc_async_write(dest_buffer+begin_write_offset,dst_noc_addr, dest_page_size_bytes); -#endif writable = dest_page_size_bytes; readable = 0; if (i == read_end_page-1) @@ -134,35 +138,25 @@ void kernel_main() { cb_push_back(cb_id_in1, 1); return; } + end_to_write = 0; write_page++; dst_noc_addr = get_noc_addr(write_page, d); -#if 
((dst_aligned_to_16)) - write_offset=0; -#else write_offset = dst_noc_addr&OFFSET_16; begin_write_offset = write_offset; -#endif } else { //writable < readable tt::data_movement::common::tt_memmove(dest_buffer+write_offset, source_buffer + read_offset, writable); -#if ((dst_aligned_to_16)) - noc_async_write(dest_buffer,dst_noc_addr, dest_page_size_bytes); -#else noc_async_write(dest_buffer+begin_write_offset,dst_noc_addr, dest_page_size_bytes); -#endif + end_to_write = 0; readable = readable - writable; read_offset = read_offset + writable; write_page++; dst_noc_addr = get_noc_addr(write_page, d); -#if ((dst_aligned_to_16)) - write_offset=0; -#else write_offset = dst_noc_addr&OFFSET_16; begin_write_offset = write_offset; -#endif writable = dest_page_size_bytes; } } diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/host/reshape_rm_host_prep.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/host/reshape_rm_host_prep.cpp index 2c7410f7a6b..4956c8be3df 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/host/reshape_rm_host_prep.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/host/reshape_rm_host_prep.cpp @@ -35,11 +35,16 @@ operation::ProgramWithCallbacks rm_reshape_preparer(const Tensor& input, const T //get datum size tt::DataFormat cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(input.get_dtype()); const uint32_t data_size = input.element_size(); - CoreRange core({0, 0}, {0, 0}); - tt::tt_metal::Device *device = input.device(); + //Multi device pre-computation + auto compute_with_storage_grid_size = device->compute_with_storage_grid_size(); + uint32_t num_cores_x = compute_with_storage_grid_size.x; + uint32_t num_cores_y = compute_with_storage_grid_size.y; + uint32_t num_cores_total = num_cores_x * num_cores_y; + CoreRange total_cores({0, 0}, {num_cores_x - 1, num_cores_y - 1}); ttnn::Shape input_log_shape = ttnn::Shape(input.get_logical_shape().view()); ttnn::Shape output_log_shape = ttnn::Shape(output.get_logical_shape().view()); + uint32_t responsibility = (input_log_shape[-2]-1)/num_cores_total + 1; //How many input pages each core is responsible for tt::log_debug("row major reshape"); tt::log_debug("input shape: {}", input_log_shape); tt::log_debug("output shape: {}", output_log_shape); @@ -48,56 +53,84 @@ operation::ProgramWithCallbacks rm_reshape_preparer(const Tensor& input, const T uint32_t dest_page_size_bytes = output_log_shape[-1] * data_size; uint32_t source_read_size_bytes = ((source_page_size_bytes-1) & MASK_64) + 128; uint32_t read_start_page = 0; - uint32_t read_end_page = input_log_shape[-2]; uint32_t write_start_page = 0; + uint32_t write_start_offset = 0; tt::tt_metal::Buffer *src_buffer = input.buffer(); tt::tt_metal::Buffer *dst_buffer = output.buffer(); TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); - const uint32_t cb_size0 = source_read_size_bytes; - const uint32_t cb_size1 = ((dest_page_size_bytes-1)&MASK_64) + 80; - - uint32_t src0_cb_index = 0; - uint32_t src1_cb_index = 1; - tt::tt_metal::CircularBufferConfig cb_src0_config = tt::tt_metal::CircularBufferConfig(cb_size0*2, {{src0_cb_index, cb_data_format}}) - .set_page_size(src0_cb_index, cb_size0); - auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, core, cb_src0_config); - tt::tt_metal::CircularBufferConfig cb_src1_config = tt::tt_metal::CircularBufferConfig(cb_size1, {{src1_cb_index, cb_data_format}}) - .set_page_size(src1_cb_index, cb_size1); - auto cb_src1 = 
tt::tt_metal::CreateCircularBuffer(program, core, cb_src1_config); - //set the runtime args - //set the compile time args uint32_t src0_is_dram = src_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; std::vector compile_time_args = { (std::uint32_t) src0_is_dram, (std::uint32_t) (source_page_size_bytes%64==0) ? 1 : 0, - (std::uint32_t) (source_page_size_bytes%16==0) ? 1 : 0, - (std::uint32_t) (dest_page_size_bytes%16==0) ? 1 : 0, + (std::uint32_t) (source_page_size_bytes%16==0) ? 1 : 0 }; tt::tt_metal::KernelHandle reader_kernel_id = tt::tt_metal::CreateKernel( program, "ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp", - core, - tt::tt_metal::ReaderDataMovementConfig(compile_time_args)); - std::vector reader_runtime_args = { - src_buffer->address(), - dst_buffer->address(), - source_page_size_bytes, - dest_page_size_bytes, - source_read_size_bytes, - read_start_page, - read_end_page, - write_start_page, - src0_cb_index, - src1_cb_index - }; - tt::tt_metal::SetRuntimeArgs( - program, - reader_kernel_id, - core, - reader_runtime_args + total_cores, + tt::tt_metal::ReaderDataMovementConfig(compile_time_args) ); + + const uint32_t cb_size0 = source_read_size_bytes; + const uint32_t cb_size1 = ((dest_page_size_bytes-1)&MASK_64) + 80; + uint32_t done = 0; + for(int i=0; i input_log_shape[-2] ? input_log_shape[-2] : end_of_read; + read_start_page = end_of_read; + write_start_page = write_start_page + write_start_offset/dest_page_size_bytes; + write_start_offset = write_start_offset % dest_page_size_bytes; + while (write_start_offset != 0 && end_of_read < input_log_shape[-2] ) + { + + read_start_page++; + end_of_read++; + write_start_offset = write_start_offset + source_page_size_bytes; + write_start_page = write_start_page + write_start_offset/dest_page_size_bytes; + write_start_offset = write_start_offset % dest_page_size_bytes; + } + std::vector reader_runtime_args = { + src_buffer->address(), + dst_buffer->address(), + source_page_size_bytes, + dest_page_size_bytes, + source_read_size_bytes, + done ? 0 : start_of_read, + done ? 0 : end_of_read, + cur_write_start, + cur_write_offset, + src0_cb_index, + src1_cb_index, + done + + }; + done = (end_of_read == input_log_shape[-2]) ? 
1 : 0; + tt::tt_metal::SetRuntimeArgs( + program, + reader_kernel_id, + core, + reader_runtime_args + ); + } return {.program=std::move(program)}; } }; // namespace ttnn::operations::data_movement::rm_reshape From fac56d8ef248acf847d9948c2dd621530f553b03 Mon Sep 17 00:00:00 2001 From: Juan Camilo Vega Date: Fri, 29 Nov 2024 23:07:55 +0000 Subject: [PATCH 12/31] #0: addressing PR comments --- .../reshape_view/device/device/rm_reshape_interleaved.cpp | 2 +- .../ttnn/operations/data_movement/reshape_view/reshape.cpp | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp index 3be51578b04..02fb9911daa 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp @@ -54,7 +54,7 @@ void kernel_main() { constexpr bool tensor_is_dram = get_compile_time_arg_val(0) == 1; #define src_aligned_to_64 get_compile_time_arg_val(1) == 1 #define src_aligned_to_16 get_compile_time_arg_val(2) == 1 - + //Since we need to operate on a grid of cores but sometimes pages don't split properly, if nop then don't use this core if (nop == 1) { return; diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp index 55b38e45a16..7c0de2deb67 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp @@ -340,6 +340,7 @@ ttnn::Tensor ReshapeViewOperation::invoke( const uint8_t queue_id, const std::optional &pad_value ) { + MemoryConfig mem_config = memory_config.value_or(tensor.memory_config()); auto layout = tensor.get_layout(); auto tensor_shape = tensor.get_shape(); const ttnn::Shape shape = shape_corrector(tensor, input_shape); @@ -366,6 +367,8 @@ ttnn::Tensor ReshapeViewOperation::invoke( const uint32_t shape_second_last_dim = shape.rank() >= 2 ? shape[-2]:1; const uint32_t tensor_shape_second_last_dim = tensor_shape.rank() >= 2 ? 
tensor_shape[-2]:1; bool this_is_view = (tensor_shape[-1] == shape[-1]) && + (mem_config.is_sharded()==tensor.memory_config().is_sharded()) && + (mem_config.is_l1()==tensor.memory_config().is_l1()) ((tensor.get_layout() == ttnn::ROW_MAJOR_LAYOUT) || //Its row major (tensor_shape_second_last_dim==shape_second_last_dim) || //Second last dimension is the same (shape_second_last_dim % tile_second_dim==0 && tensor_shape_second_last_dim % tile_first_dim==0)); //There is no padding on the second last dimension @@ -400,7 +403,7 @@ ttnn::Tensor ReshapeViewOperation::invoke( shape, tile_first_dim, tile_second_dim, - memory_config.value_or(tensor.memory_config()), + mem_config, queue_id, pad_value.value_or(default_pad_value) ); From 7ff95d2b443d4b4f243a9e6482a0a67f0cc5b22a Mon Sep 17 00:00:00 2001 From: Juan Camilo Vega Date: Fri, 29 Nov 2024 23:16:16 +0000 Subject: [PATCH 13/31] #0: small oops --- .../data_movement/reshape_view/reshape.cpp | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp index 7c0de2deb67..76752416694 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp @@ -366,12 +366,13 @@ ttnn::Tensor ReshapeViewOperation::invoke( const uint32_t shape_second_last_dim = shape.rank() >= 2 ? shape[-2]:1; const uint32_t tensor_shape_second_last_dim = tensor_shape.rank() >= 2 ? tensor_shape[-2]:1; - bool this_is_view = (tensor_shape[-1] == shape[-1]) && - (mem_config.is_sharded()==tensor.memory_config().is_sharded()) && - (mem_config.is_l1()==tensor.memory_config().is_l1()) - ((tensor.get_layout() == ttnn::ROW_MAJOR_LAYOUT) || //Its row major - (tensor_shape_second_last_dim==shape_second_last_dim) || //Second last dimension is the same - (shape_second_last_dim % tile_second_dim==0 && tensor_shape_second_last_dim % tile_first_dim==0)); //There is no padding on the second last dimension + bool this_is_view = + (tensor_shape[-1] == shape[-1]) && (mem_config.is_sharded() == tensor.memory_config().is_sharded()) && + (mem_config.is_l1() == tensor.memory_config().is_l1()) && + ((tensor.get_layout() == ttnn::ROW_MAJOR_LAYOUT) || // Its row major + (tensor_shape_second_last_dim == shape_second_last_dim) || // Second last dimension is the same + (shape_second_last_dim % tile_second_dim == 0 && + tensor_shape_second_last_dim % tile_first_dim == 0)); // There is no padding on the second last dimension if (!(ttnn::has_storage_type_of(tensor, ttnn::StorageType::DEVICE))) { // This case has been allowed in the past though it means introducing padding values to the data return tensor.reshape(shape); From 2edceb59e850aabc4451aac88ef1e1f3e69e4221 Mon Sep 17 00:00:00 2001 From: Juan Camilo Vega Date: Mon, 2 Dec 2024 16:42:33 +0000 Subject: [PATCH 14/31] #15269: Host code optimizations --- .../device/host/reshape_rm_host_prep.cpp | 116 ++++++++++-------- .../data_movement/reshape_view/reshape.cpp | 5 +- 2 files changed, 65 insertions(+), 56 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/host/reshape_rm_host_prep.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/host/reshape_rm_host_prep.cpp index 4956c8be3df..91d6177e5e4 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/host/reshape_rm_host_prep.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/host/reshape_rm_host_prep.cpp @@ 
-44,7 +44,6 @@ operation::ProgramWithCallbacks rm_reshape_preparer(const Tensor& input, const T CoreRange total_cores({0, 0}, {num_cores_x - 1, num_cores_y - 1}); ttnn::Shape input_log_shape = ttnn::Shape(input.get_logical_shape().view()); ttnn::Shape output_log_shape = ttnn::Shape(output.get_logical_shape().view()); - uint32_t responsibility = (input_log_shape[-2]-1)/num_cores_total + 1; //How many input pages each core is responsible for tt::log_debug("row major reshape"); tt::log_debug("input shape: {}", input_log_shape); tt::log_debug("output shape: {}", output_log_shape); @@ -54,11 +53,16 @@ operation::ProgramWithCallbacks rm_reshape_preparer(const Tensor& input, const T uint32_t source_read_size_bytes = ((source_page_size_bytes-1) & MASK_64) + 128; uint32_t read_start_page = 0; uint32_t write_start_page = 0; - uint32_t write_start_offset = 0; tt::tt_metal::Buffer *src_buffer = input.buffer(); tt::tt_metal::Buffer *dst_buffer = output.buffer(); TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); - + // Find how many input pages each core is responsible for so that we always start at the begining of a read and + // write page Since the logical volumes match, we are guaranteed that the very last page is aligned + uint32_t responsibility = (input_log_shape[-2] - 1) / num_cores_total + 1; + while ((responsibility * source_page_size_bytes) % dest_page_size_bytes != 0) { + responsibility++; + } + const uint32_t write_jump = (responsibility * source_page_size_bytes) / dest_page_size_bytes; uint32_t src0_is_dram = src_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; std::vector compile_time_args = { (std::uint32_t) src0_is_dram, @@ -76,60 +80,64 @@ operation::ProgramWithCallbacks rm_reshape_preparer(const Tensor& input, const T const uint32_t cb_size0 = source_read_size_bytes; const uint32_t cb_size1 = ((dest_page_size_bytes-1)&MASK_64) + 80; uint32_t done = 0; - for(int i=0; i reader_runtime_args = { + src_buffer->address(), + dst_buffer->address(), + source_page_size_bytes, + dest_page_size_bytes, + source_read_size_bytes, + 0, + 0, + 0, + 0, + 0, + 1, + 1 - uint32_t src0_cb_index = 0; - uint32_t src1_cb_index = 1; - tt::tt_metal::CircularBufferConfig cb_src0_config = tt::tt_metal::CircularBufferConfig(cb_size0*2, {{src0_cb_index, cb_data_format}}) - .set_page_size(src0_cb_index, cb_size0); - auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, core, cb_src0_config); - tt::tt_metal::CircularBufferConfig cb_src1_config = tt::tt_metal::CircularBufferConfig(cb_size1, {{src1_cb_index, cb_data_format}}) - .set_page_size(src1_cb_index, cb_size1); - auto cb_src1 = tt::tt_metal::CreateCircularBuffer(program, core, cb_src1_config); - //set the runtime args - //set the compile time args - uint32_t start_of_read = read_start_page; - uint32_t end_of_read = read_start_page + responsibility; - uint32_t cur_write_start = write_start_page; - uint32_t cur_write_offset = write_start_offset; - write_start_offset = write_start_offset + responsibility * source_page_size_bytes; - end_of_read = end_of_read > input_log_shape[-2] ? 
input_log_shape[-2] : end_of_read; - read_start_page = end_of_read; - write_start_page = write_start_page + write_start_offset/dest_page_size_bytes; - write_start_offset = write_start_offset % dest_page_size_bytes; - while (write_start_offset != 0 && end_of_read < input_log_shape[-2] ) - { + }; + tt::tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, reader_runtime_args); + } else { + // Create the circular buffers + uint32_t src0_cb_index = 0; + uint32_t src1_cb_index = 1; + tt::tt_metal::CircularBufferConfig cb_src0_config = + tt::tt_metal::CircularBufferConfig(cb_size0 * 2, {{src0_cb_index, cb_data_format}}) + .set_page_size(src0_cb_index, cb_size0); + auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, core, cb_src0_config); + tt::tt_metal::CircularBufferConfig cb_src1_config = + tt::tt_metal::CircularBufferConfig(cb_size1, {{src1_cb_index, cb_data_format}}) + .set_page_size(src1_cb_index, cb_size1); + auto cb_src1 = tt::tt_metal::CreateCircularBuffer(program, core, cb_src1_config); + // set the runtime args + // set the compile time args + uint32_t start_of_read = read_start_page; + uint32_t end_of_read = read_start_page + responsibility; - read_start_page++; - end_of_read++; - write_start_offset = write_start_offset + source_page_size_bytes; - write_start_page = write_start_page + write_start_offset/dest_page_size_bytes; - write_start_offset = write_start_offset % dest_page_size_bytes; - } - std::vector reader_runtime_args = { - src_buffer->address(), - dst_buffer->address(), - source_page_size_bytes, - dest_page_size_bytes, - source_read_size_bytes, - done ? 0 : start_of_read, - done ? 0 : end_of_read, - cur_write_start, - cur_write_offset, - src0_cb_index, - src1_cb_index, - done + std::vector reader_runtime_args = { + src_buffer->address(), + dst_buffer->address(), + source_page_size_bytes, + dest_page_size_bytes, + source_read_size_bytes, + start_of_read, + end_of_read, + write_start_page, + 0, + src0_cb_index, + src1_cb_index, + done - }; - done = (end_of_read == input_log_shape[-2]) ? 1 : 0; - tt::tt_metal::SetRuntimeArgs( - program, - reader_kernel_id, - core, - reader_runtime_args - ); + }; + write_start_page += write_jump; + read_start_page = end_of_read; + done = (end_of_read == input_log_shape[-2]) ? 
1 : 0; + tt::tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, reader_runtime_args); + } + } } return {.program=std::move(program)}; } diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp index 76752416694..9b4be6f13b0 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp @@ -37,7 +37,8 @@ ttnn::Tensor convert_tile_to_rm( const PadValue &pad_value ) { //Convert the 3D->3D reshaping to row major and back to tile - auto rm_tensor = ttnn::to_layout(tensor, ttnn::ROW_MAJOR_LAYOUT, std::nullopt, std::nullopt, (Device*)nullptr); + auto rm_tensor = + ttnn::to_layout(tensor, ttnn::ROW_MAJOR_LAYOUT, tensor.get_dtype(), std::nullopt, (Device*)nullptr); rm_tensor = ReshapeViewOperation::invoke( rm_tensor, shape, @@ -45,7 +46,7 @@ ttnn::Tensor convert_tile_to_rm( queue_id, pad_value ); - rm_tensor = ttnn::to_layout(rm_tensor, ttnn::TILE_LAYOUT, std::nullopt, std::nullopt, (Device*)nullptr); + rm_tensor = ttnn::to_layout(rm_tensor, ttnn::TILE_LAYOUT, rm_tensor.get_dtype(), memory_config, (Device*)nullptr); return rm_tensor; } ttnn::Tensor host_reshape(const ttnn::Tensor& tensor, const ttnn::Shape& shape) { From ad28b55a6d4012d3f146e097dea430313a295a22 Mon Sep 17 00:00:00 2001 From: Juan Camilo Vega Date: Mon, 2 Dec 2024 16:57:57 +0000 Subject: [PATCH 15/31] #15269: Move compute buffers to compile time --- .../device/device/rm_reshape_interleaved.cpp | 16 +++---- .../device/host/reshape_rm_host_prep.cpp | 45 +++++++++---------- 2 files changed, 29 insertions(+), 32 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp index 02fb9911daa..9432e112052 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp @@ -45,15 +45,13 @@ void kernel_main() { const uint32_t read_end_page = get_arg_val(6); const uint32_t write_start_page = get_arg_val(7); const uint32_t write_start_offset = get_arg_val(8); - //cb_id_in0 is a circular buffer with 1 source_page_size_bytes page if no alignment needed - //source_read_size_bytes otherwise - const uint32_t cb_id_in0 = get_arg_val(9); - //cb_id_in1 is a circular buffer with 1 dest_page_size_bytes+16 (rounded up to next 64B) page - const uint32_t cb_id_in1 = get_arg_val(10); - const uint32_t nop = get_arg_val(11); - constexpr bool tensor_is_dram = get_compile_time_arg_val(0) == 1; - #define src_aligned_to_64 get_compile_time_arg_val(1) == 1 - #define src_aligned_to_16 get_compile_time_arg_val(2) == 1 + const uint32_t nop = get_arg_val(9); + + constexpr bool tensor_is_dram = get_compile_time_arg_val(0) == 1; +#define src_aligned_to_64 get_compile_time_arg_val(1) == 1 +#define src_aligned_to_16 get_compile_time_arg_val(2) == 1 + constexpr uint32_t cb_id_in0 = get_compile_time_arg_val(3); + constexpr uint32_t cb_id_in1 = get_compile_time_arg_val(4); //Since we need to operate on a grid of cores but sometimes pages don't split properly, if nop then don't use this core if (nop == 1) { diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/host/reshape_rm_host_prep.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/host/reshape_rm_host_prep.cpp index 
91d6177e5e4..3c374158b46 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/host/reshape_rm_host_prep.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/host/reshape_rm_host_prep.cpp @@ -64,21 +64,33 @@ operation::ProgramWithCallbacks rm_reshape_preparer(const Tensor& input, const T } const uint32_t write_jump = (responsibility * source_page_size_bytes) / dest_page_size_bytes; uint32_t src0_is_dram = src_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; + + const uint32_t cb_size0 = source_read_size_bytes; + const uint32_t cb_size1 = ((dest_page_size_bytes - 1) & MASK_64) + 80; + + uint32_t src0_cb_index = 0; + uint32_t src1_cb_index = 1; + tt::tt_metal::CircularBufferConfig cb_src0_config = + tt::tt_metal::CircularBufferConfig(cb_size0 * 2, {{src0_cb_index, cb_data_format}}) + .set_page_size(src0_cb_index, cb_size0); + auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, total_cores, cb_src0_config); + tt::tt_metal::CircularBufferConfig cb_src1_config = + tt::tt_metal::CircularBufferConfig(cb_size1, {{src1_cb_index, cb_data_format}}) + .set_page_size(src1_cb_index, cb_size1); + auto cb_src1 = tt::tt_metal::CreateCircularBuffer(program, total_cores, cb_src1_config); + std::vector compile_time_args = { - (std::uint32_t) src0_is_dram, - (std::uint32_t) (source_page_size_bytes%64==0) ? 1 : 0, - (std::uint32_t) (source_page_size_bytes%16==0) ? 1 : 0 - }; + (std::uint32_t)src0_is_dram, + (std::uint32_t)(source_page_size_bytes % 64 == 0) ? 1 : 0, + (std::uint32_t)(source_page_size_bytes % 16 == 0) ? 1 : 0, + src0_cb_index, + src1_cb_index}; tt::tt_metal::KernelHandle reader_kernel_id = tt::tt_metal::CreateKernel( program, "ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp", total_cores, - tt::tt_metal::ReaderDataMovementConfig(compile_time_args) - ); - - const uint32_t cb_size0 = source_read_size_bytes; - const uint32_t cb_size1 = ((dest_page_size_bytes-1)&MASK_64) + 80; + tt::tt_metal::ReaderDataMovementConfig(compile_time_args)); uint32_t done = 0; for (int core_x = 0; core_x < num_cores_x; core_x++) { for (int core_y = 0; core_y < num_cores_y; core_y++) { @@ -94,24 +106,13 @@ operation::ProgramWithCallbacks rm_reshape_preparer(const Tensor& input, const T 0, 0, 0, - 0, - 1, 1 }; tt::tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, reader_runtime_args); } else { // Create the circular buffers - uint32_t src0_cb_index = 0; - uint32_t src1_cb_index = 1; - tt::tt_metal::CircularBufferConfig cb_src0_config = - tt::tt_metal::CircularBufferConfig(cb_size0 * 2, {{src0_cb_index, cb_data_format}}) - .set_page_size(src0_cb_index, cb_size0); - auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, core, cb_src0_config); - tt::tt_metal::CircularBufferConfig cb_src1_config = - tt::tt_metal::CircularBufferConfig(cb_size1, {{src1_cb_index, cb_data_format}}) - .set_page_size(src1_cb_index, cb_size1); - auto cb_src1 = tt::tt_metal::CreateCircularBuffer(program, core, cb_src1_config); + // set the runtime args // set the compile time args uint32_t start_of_read = read_start_page; @@ -127,8 +128,6 @@ operation::ProgramWithCallbacks rm_reshape_preparer(const Tensor& input, const T end_of_read, write_start_page, 0, - src0_cb_index, - src1_cb_index, done }; From 55013247f2d7f48c8b5acac970efca4b34a83280 Mon Sep 17 00:00:00 2001 From: Juan Camilo Vega Date: Mon, 2 Dec 2024 17:19:25 +0000 Subject: [PATCH 16/31] #15269: adding override_runtime_args_callback --- .../device/host/reshape_rm_host_prep.cpp | 
84 +++++++++++++++++-- 1 file changed, 79 insertions(+), 5 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/host/reshape_rm_host_prep.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/host/reshape_rm_host_prep.cpp index 3c374158b46..f6ecc4f1bec 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/host/reshape_rm_host_prep.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/host/reshape_rm_host_prep.cpp @@ -96,7 +96,7 @@ operation::ProgramWithCallbacks rm_reshape_preparer(const Tensor& input, const T for (int core_y = 0; core_y < num_cores_y; core_y++) { CoreCoord core = {core_x, core_y}; if (done == 1) { - std::vector reader_runtime_args = { + const std::vector reader_runtime_args = { src_buffer->address(), dst_buffer->address(), source_page_size_bytes, @@ -115,10 +115,10 @@ operation::ProgramWithCallbacks rm_reshape_preparer(const Tensor& input, const T // set the runtime args // set the compile time args - uint32_t start_of_read = read_start_page; - uint32_t end_of_read = read_start_page + responsibility; + const uint32_t start_of_read = read_start_page; + const uint32_t end_of_read = read_start_page + responsibility; - std::vector reader_runtime_args = { + const std::vector reader_runtime_args = { src_buffer->address(), dst_buffer->address(), source_page_size_bytes, @@ -138,6 +138,80 @@ operation::ProgramWithCallbacks rm_reshape_preparer(const Tensor& input, const T } } } - return {.program=std::move(program)}; + auto override_runtime_args_callback = [reader_kernel_id, compute_with_storage_grid_size]( + const void* operation, + const Program& program, + const std::vector& input_tensors, + const std::vector>&, + const std::vector& output_tensors) { + auto input = input_tensors.at(0); + auto output = output_tensors.at(0); + const uint32_t data_size = input.element_size(); + tt::tt_metal::Buffer* src_buffer = input.buffer(); + tt::tt_metal::Buffer* dst_buffer = output.buffer(); + uint32_t num_cores_x = compute_with_storage_grid_size.x; + uint32_t num_cores_y = compute_with_storage_grid_size.y; + uint32_t num_cores_total = num_cores_x * num_cores_y; + ttnn::Shape input_log_shape = ttnn::Shape(input.get_logical_shape().view()); + ttnn::Shape output_log_shape = ttnn::Shape(output.get_logical_shape().view()); + uint32_t source_page_size_bytes = input_log_shape[-1] * data_size; + uint32_t dest_page_size_bytes = output_log_shape[-1] * data_size; + uint32_t source_read_size_bytes = ((source_page_size_bytes - 1) & MASK_64) + 128; + uint32_t read_start_page = 0; + uint32_t write_start_page = 0; + uint32_t responsibility = (input_log_shape[-2] - 1) / num_cores_total + 1; + while ((responsibility * source_page_size_bytes) % dest_page_size_bytes != 0) { + responsibility++; + } + const uint32_t write_jump = (responsibility * source_page_size_bytes) / dest_page_size_bytes; + uint32_t done = 0; + for (int core_x = 0; core_x < num_cores_x; core_x++) { + for (int core_y = 0; core_y < num_cores_y; core_y++) { + CoreCoord core = {core_x, core_y}; + if (done == 1) { + const std::vector reader_runtime_args = { + src_buffer->address(), + dst_buffer->address(), + source_page_size_bytes, + dest_page_size_bytes, + source_read_size_bytes, + 0, + 0, + 0, + 0, + 1 + + }; + tt::tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, reader_runtime_args); + } else { + // Create the circular buffers + + // set the runtime args + // set the compile time args + const uint32_t start_of_read = read_start_page; + const uint32_t 
end_of_read = read_start_page + responsibility; + + const std::vector reader_runtime_args = { + src_buffer->address(), + dst_buffer->address(), + source_page_size_bytes, + dest_page_size_bytes, + source_read_size_bytes, + start_of_read, + end_of_read, + write_start_page, + 0, + done + + }; + write_start_page += write_jump; + read_start_page = end_of_read; + done = (end_of_read == input_log_shape[-2]) ? 1 : 0; + tt::tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, reader_runtime_args); + } + } + } + }; + return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_args_callback}; } }; // namespace ttnn::operations::data_movement::rm_reshape From 109615f431db3e36530fe254af07ec17c33decd9 Mon Sep 17 00:00:00 2001 From: Juan Camilo Vega Date: Mon, 2 Dec 2024 22:41:51 +0000 Subject: [PATCH 17/31] #15269: improve the tt_memmove to use read or write noc as per user needs --- .../device/device/rm_reshape_interleaved.cpp | 22 +++++++++---------- 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp index 9432e112052..3950c1f276c 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp @@ -81,6 +81,8 @@ void kernel_main() { cb_reserve_back(cb_id_in1, 1); const uint32_t source_buffer = get_write_ptr(cb_id_in0); const uint32_t dest_buffer = get_write_ptr(cb_id_in1); + cb_push_back(cb_id_in0, 1); + cb_push_back(cb_id_in1, 1); uint64_t dst_noc_addr = get_noc_addr(write_page, d); uint64_t write_offset = dst_noc_addr&OFFSET_16 + write_start_offset; @@ -111,29 +113,26 @@ void kernel_main() { noc_async_write_barrier(); if (readable < writable) { - tt::data_movement::common::tt_memmove(dest_buffer+write_offset, source_buffer + read_offset, readable); + tt::data_movement::common::tt_memmove( + dest_buffer + write_offset, source_buffer + read_offset, readable); writable = writable -readable; write_offset = write_offset + readable; readable = 0; end_to_write = end_to_write + readable; if (i == read_end_page-1) { - noc_async_write(dest_buffer+begin_write_offset,dst_noc_addr, end_to_write); - cb_push_back(cb_id_in0, 1); - cb_push_back(cb_id_in1, 1); + noc_async_write(dest_buffer + begin_write_offset, dst_noc_addr, end_to_write); return; } } else if (readable == writable) { - tt::data_movement::common::tt_memmove(dest_buffer+write_offset, source_buffer + read_offset, readable); + tt::data_movement::common::tt_memmove( + dest_buffer + write_offset, source_buffer + read_offset, readable); noc_async_write(dest_buffer+begin_write_offset,dst_noc_addr, dest_page_size_bytes); writable = dest_page_size_bytes; readable = 0; - if (i == read_end_page-1) - { - cb_push_back(cb_id_in0, 1); - cb_push_back(cb_id_in1, 1); + if (i == read_end_page - 1) { return; } end_to_write = 0; @@ -146,7 +145,8 @@ void kernel_main() { { //writable < readable - tt::data_movement::common::tt_memmove(dest_buffer+write_offset, source_buffer + read_offset, writable); + tt::data_movement::common::tt_memmove( + dest_buffer + write_offset, source_buffer + read_offset, writable); noc_async_write(dest_buffer+begin_write_offset,dst_noc_addr, dest_page_size_bytes); end_to_write = 0; readable = readable - writable; @@ -159,7 +159,5 @@ void kernel_main() { } } } - 
cb_push_back(cb_id_in0, 1); - cb_push_back(cb_id_in1, 1); return; } From 631b1c51343ab9859d88d1250a96a8611c56f70b Mon Sep 17 00:00:00 2001 From: Juan Camilo Vega Date: Mon, 2 Dec 2024 22:42:56 +0000 Subject: [PATCH 18/31] #15269: improve the tt_memmove to use read or write datamover --- .../data_movement/common/kernels/common.hpp | 53 +++++++++++-------- 1 file changed, 32 insertions(+), 21 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp b/ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp index 30e7d597a90..d559630736e 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp @@ -15,32 +15,43 @@ namespace tt::data_movement::common { -template -FORCE_INLINE -void tt_memmove ( - const uint32_t dst_l1_addr, - const uint64_t src_l1_addr, - const uint32_t bytes) -{ +template +FORCE_INLINE void tt_memmove(const uint32_t dst_l1_addr, const uint32_t src_l1_addr, const uint32_t bytes) { //Function performs a memory copy between two l1 addresses in the local core //Uses noc_async_read when possible to copy the data over //Set guaranteed 16B alligned to true if the source and destination are externally guaranteed to be 16B alligned (dangerous) //Set copy_async to true if you wish to perform the operation asynchronously, in this case you can add a noc_async_read_barrier to synchronize later - if constexpr (guaranteed_16B_alligned) - { - noc_async_read(get_noc_addr(src_l1_addr),dst_l1_addr, bytes); - if constexpr (!copy_async) {noc_async_read_barrier();} - } - else - { - if ((dst_l1_addr&OFFSET_16) == (src_l1_addr&OFFSET_16)) - { - noc_async_read(get_noc_addr(src_l1_addr),dst_l1_addr, bytes); - if constexpr (!copy_async) {noc_async_read_barrier();} + if constexpr (use_read_datamover) { + if constexpr (guaranteed_16B_alligned) { + noc_async_read(get_noc_addr(src_l1_addr), dst_l1_addr, bytes); + if constexpr (!copy_async) { + noc_async_read_barrier(); + } + } else { + if ((dst_l1_addr & OFFSET_16) == (src_l1_addr & OFFSET_16)) { + noc_async_read(get_noc_addr(src_l1_addr), dst_l1_addr, bytes); + if constexpr (!copy_async) { + noc_async_read_barrier(); + } + } else { + memmove((void*)(dst_l1_addr), (void*)(src_l1_addr), (size_t)(bytes)); + } } - else - { - memmove((void *)(dst_l1_addr), (void *)(src_l1_addr), (size_t) (bytes)); + } else { + if constexpr (guaranteed_16B_alligned) { + noc_async_write(src_l1_addr, get_noc_addr(dst_l1_addr), bytes); + if constexpr (!copy_async) { + noc_async_write_barrier(); + } + } else { + if ((dst_l1_addr & OFFSET_16) == (src_l1_addr & OFFSET_16)) { + noc_async_write(src_l1_addr, get_noc_addr(dst_l1_addr), bytes); + if constexpr (!copy_async) { + noc_async_write_barrier(); + } + } else { + memmove((void*)(dst_l1_addr), (void*)(src_l1_addr), (size_t)(bytes)); + } } } } From 10de9a27f76721c105930eb9d5ebc7e55484ff3b Mon Sep 17 00:00:00 2001 From: Juan Camilo Vega Date: Mon, 2 Dec 2024 22:45:26 +0000 Subject: [PATCH 19/31] #15269: add broken multi risk code --- .../device/rm_reshape_interleaved_reader.cpp | 163 +++++++++++++ .../device/rm_reshape_interleaved_writer.cpp | 193 +++++++++++++++ .../device/host/reshape_rm_host_prep.cpp | 220 +++++++++++++++++- 3 files changed, 574 insertions(+), 2 deletions(-) create mode 100644 ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved_reader.cpp create mode 100644 
ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved_writer.cpp diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved_reader.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved_reader.cpp new file mode 100644 index 00000000000..8a5cc4d333c --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved_reader.cpp @@ -0,0 +1,163 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +/* +NOTE: This function is an improvement on rm_reshape_interleaved.cpp but it has a bug causing a hang for some cases that +needs to be debugged first + +Function reads from RM and writes to RM + +Assumptions: + +Compile arguments +0. src0_is_dram: 1 if source is dram else 0 +1. read_size_is_pow2: 1 if read size is power of 2 else 0 +2. log_base_2_of_page_size: log base 2 of page size +3. write_size_is_pow2: 1 if write size is power of 2 else 0 +4. log_base_2_of_page_size: log base 2 of page size +5. needs_read_allignment: 1 if read needs allignment else 0 +//Needed if BRAM and page size is not multiple of 64 bytes + +Runtime arguments +0. src_addr: source address +1. dst_addr: destination address +2. source_page_size_bytes: source page size in bytes +3. dest_page_size_bytes: destination page size in bytes +4. source_read_size_bytes: source read size in bytes +5. read_start_page: read start page +6. read_end_page: read end page +7. write_start_page: write start page +*/ +#include +#include "dataflow_api.h" +#include "debug/dprint.h" // required in all kernels using DPRINT +#include "ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp" + +void kernel_main() { + // We are guranteed to be in 2D going to 2D + + const uint32_t src_addr = get_arg_val(0); + const uint32_t dst_addr = get_arg_val(1); + const uint32_t source_page_size_bytes = get_arg_val(2); + const uint32_t dest_page_size_bytes = get_arg_val(3); + // If DDR this is source_page_size_bytes + 64 (rounded up to next 64B), if L1 this is source_page_size_bytes + 16 + // (rounded up to next 16B) + const uint32_t source_read_size_bytes = get_arg_val(4); + const uint32_t read_start_page = get_arg_val(5); + const uint32_t read_end_page = get_arg_val(6); + const uint32_t write_start_page = get_arg_val(7); + const uint32_t write_start_offset = get_arg_val(8); + const uint32_t nop = get_arg_val(9); + const uint64_t ping_read_has_data = get_noc_addr(get_semaphore(get_arg_val(10))); + const uint64_t pong_read_has_data = get_noc_addr(get_semaphore(get_arg_val(11))); + volatile uint32_t* ping_buf_is_free = + reinterpret_cast(get_semaphore(get_arg_val(12))); + volatile uint32_t* pong_buf_is_free = + reinterpret_cast(get_semaphore(get_arg_val(13))); + constexpr bool tensor_is_dram = get_compile_time_arg_val(0) == 1; +#define src_aligned_to_64 get_compile_time_arg_val(1) == 1 +#define src_aligned_to_16 get_compile_time_arg_val(2) == 1 + constexpr uint32_t cb_id_in0 = get_compile_time_arg_val(3); + constexpr uint32_t cb_id_in1 = get_compile_time_arg_val(4); + constexpr uint32_t cb_id_in2 = get_compile_time_arg_val(4); + // Since we need to operate on a grid of cores but sometimes pages don't split properly, if nop then don't use this + // core + if (nop == 1) { + return; + } + + const InterleavedAddrGen s = {.bank_base_address = src_addr, .page_size = source_page_size_bytes}; + + uint32_t read_offset = 0; + uint32_t write_page = write_start_page; 
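// ---------------------------------------------------------------------------
// Editorial sketch (not taken from the diff): the reader/writer kernel pair
// introduced in this patch coordinates through a two-buffer ping/pong
// handshake -- the reader fills a staging buffer and signals "has data", the
// writer drains it and signals "is free".  The stripped-down producer loop
// below shows only that protocol.  The function name, argument list and the
// reset-to-zero steps are illustrative assumptions, not the exact kernel code
// (which the commit message itself flags as hanging for some cases).
// ---------------------------------------------------------------------------
#include "dataflow_api.h"

inline void producer_loop(
    const InterleavedAddrGen<true>& src,             // source pages in DRAM
    uint32_t n_pages,
    uint32_t page_bytes,
    uint32_t ping_buf, uint32_t pong_buf,            // two L1 staging buffers
    uint64_t ping_has_data, uint64_t pong_has_data,  // NoC addresses of the writer's "data ready" semaphores
    volatile tt_l1_ptr uint32_t* ping_is_free,       // local semaphores the writer bumps when a buffer drains
    volatile tt_l1_ptr uint32_t* pong_is_free) {
    bool ping = true;
    for (uint32_t p = 0; p < n_pages; ++p) {
        volatile tt_l1_ptr uint32_t* is_free = ping ? ping_is_free : pong_is_free;
        noc_semaphore_wait(is_free, 1);              // wait until the writer has released this buffer
        noc_semaphore_set(is_free, 0);               // consume the token before reusing the buffer
        const uint32_t buf = ping ? ping_buf : pong_buf;
        noc_async_read(src.get_noc_addr(p, 0), buf, page_bytes);
        noc_async_read_barrier();
        noc_semaphore_inc(ping ? ping_has_data : pong_has_data, 1);  // hand the filled buffer to the writer
        ping = !ping;
    }
}

// The writer mirrors this loop: wait on *_has_data, reset it, copy the page
// out with noc_async_write, then noc_semaphore_inc the matching *_is_free
// address.  If either side skips its reset, the counter climbs past the value
// the equality-based noc_semaphore_wait is polling for, which is one common
// way a pipeline like this deadlocks.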
+ uint32_t readable = 0; + uint32_t end_to_write = 0; + uint32_t transaction = 0; + uint32_t writable = dest_page_size_bytes - write_start_offset; + // cb_id_in0 is a CB source_read_size_bytes +4 page size, 1 page + // cb_id_in1 is a CB source_read_size_bytes +4 page size, 1 page + // cb_id_in1 is a CB dest_page_size_bytes + allignment_to_64 page size, 1 page + cb_reserve_back(cb_id_in0, 1); + cb_reserve_back(cb_id_in1, 1); + cb_reserve_back(cb_id_in2, 1); + const uint32_t source_buffer_ping = get_write_ptr(cb_id_in0); + const uint32_t source_buffer_pong = get_write_ptr(cb_id_in1); + const uint32_t dest_buffer = get_write_ptr(cb_id_in2); + cb_push_back(cb_id_in0, 1); + cb_push_back(cb_id_in1, 1); + cb_push_back(cb_id_in2, 1); + uint32_t source_buffer; + + volatile tt_l1_ptr std::uint32_t* read_offset_ptr_ping = + (volatile tt_l1_ptr uint32_t*)(source_buffer_ping + source_read_size_bytes); + volatile tt_l1_ptr std::uint32_t* read_offset_ptr_pong = + (volatile tt_l1_ptr uint32_t*)(source_buffer_pong + source_read_size_bytes); + bool is_ping = true; + bool first = true; + bool second = true; + bool third = true; + bool first_pong = true; + bool second_pong = true; + bool third_pong = true; + for (uint32_t i = read_start_page; i < read_end_page; i++) { + // Read from source + uint64_t src_noc_addr = s.get_noc_addr(i, 0); + if (is_ping) { + if (first) { + first = false; + WAYPOINT("FARW"); + } else if (second) { + second = false; + WAYPOINT("SARW"); + } else if (third) { + third = false; + WAYPOINT("TARW"); + } else { + WAYPOINT("ARW"); + } + source_buffer = source_buffer_ping; + noc_semaphore_wait(ping_buf_is_free, 1); + WAYPOINT("ARD"); + } else { + if (first_pong) { + first_pong = false; + WAYPOINT("FBRW"); + } else if (second_pong) { + second_pong = false; + WAYPOINT("SBRW"); + } else { + WAYPOINT("BRW"); + } + source_buffer = source_buffer_pong; + noc_semaphore_wait(pong_buf_is_free, 1); + WAYPOINT("BRD"); + } + +#if (src_aligned_to_64 || ((!tensor_is_dram) && src_aligned_to_16)) + // Aligned to 64 bytes or 16 bytes but L1 + noc_async_read(src_noc_addr, source_buffer, source_page_size_bytes); + read_offset = 0; +#elif (tensor_is_dram) + // DDR but not alligned to 64 (potentially also not alligned to 16) + noc_async_read(src_noc_addr & MASK_64, source_buffer, source_read_size_bytes); + read_offset = src_noc_addr & OFFSET_64; +#else + // L1 but not alligned to 16 + noc_async_read(src_noc_addr & MASK_16, source_buffer, source_read_size_bytes); + read_offset = src_noc_addr & OFFSET_16; +#endif + if (is_ping) { + *read_offset_ptr_ping = read_offset; + } else { + *read_offset_ptr_pong = read_offset; + } + noc_async_read_barrier(); + if (is_ping) { + noc_semaphore_inc(ping_read_has_data, 1); + } else { + noc_semaphore_inc(pong_read_has_data, 1); + } + } + return; +} diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved_writer.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved_writer.cpp new file mode 100644 index 00000000000..c2feffa6669 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved_writer.cpp @@ -0,0 +1,193 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +/* +NOTE: This function is an improvement on rm_reshape_interleaved.cpp but it has a bug causing a hang for some cases that +needs to be debugged first Function reads from RM and writes to RM + +Assumptions: + +Compile arguments +0. 
src0_is_dram: 1 if source is dram else 0 +1. read_size_is_pow2: 1 if read size is power of 2 else 0 +2. log_base_2_of_page_size: log base 2 of page size +3. write_size_is_pow2: 1 if write size is power of 2 else 0 +4. log_base_2_of_page_size: log base 2 of page size +5. needs_read_allignment: 1 if read needs allignment else 0 +//Needed if BRAM and page size is not multiple of 64 bytes + +Runtime arguments +0. src_addr: source address +1. dst_addr: destination address +2. source_page_size_bytes: source page size in bytes +3. dest_page_size_bytes: destination page size in bytes +4. source_read_size_bytes: source read size in bytes +5. read_start_page: read start page +6. read_end_page: read end page +7. write_start_page: write start page +*/ +#include +#include "dataflow_api.h" +#include "debug/dprint.h" // required in all kernels using DPRINT +#include "ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp" + +void kernel_main() { + // We are guranteed to be in 2D going to 2D + + const uint32_t src_addr = get_arg_val(0); + const uint32_t dst_addr = get_arg_val(1); + const uint32_t source_page_size_bytes = get_arg_val(2); + const uint32_t dest_page_size_bytes = get_arg_val(3); + // If DDR this is source_page_size_bytes + 64 (rounded up to next 64B), if L1 this is source_page_size_bytes + 16 + // (rounded up to next 16B) + const uint32_t source_read_size_bytes = get_arg_val(4); + const uint32_t read_start_page = get_arg_val(5); + const uint32_t read_end_page = get_arg_val(6); + const uint32_t write_start_page = get_arg_val(7); + const uint32_t write_start_offset = get_arg_val(8); + const uint32_t nop = get_arg_val(9); + volatile uint32_t* ping_read_has_data = + reinterpret_cast(get_semaphore(get_arg_val(10))); + volatile uint32_t* pong_read_has_data = + reinterpret_cast(get_semaphore(get_arg_val(11))); + const uint64_t ping_buf_is_free = get_noc_addr(get_semaphore(get_arg_val(12))); + const uint64_t pong_buf_is_free = get_noc_addr(get_semaphore(get_arg_val(13))); + constexpr bool tensor_is_dram = get_compile_time_arg_val(0) == 1; +#define src_aligned_to_64 get_compile_time_arg_val(1) == 1 +#define src_aligned_to_16 get_compile_time_arg_val(2) == 1 + constexpr uint32_t cb_id_in0 = get_compile_time_arg_val(3); + constexpr uint32_t cb_id_in1 = get_compile_time_arg_val(4); + constexpr uint32_t cb_id_in2 = get_compile_time_arg_val(4); + // Since we need to operate on a grid of cores but sometimes pages don't split properly, if nop then don't use this + // core + if (nop == 1) { + return; + } + + const InterleavedAddrGen d = {.bank_base_address = dst_addr, .page_size = dest_page_size_bytes}; + uint32_t read_offset = 0; + uint32_t write_page = write_start_page; + uint32_t readable = 0; + uint32_t end_to_write = 0; + uint32_t transaction = 0; + uint32_t writable = dest_page_size_bytes - write_start_offset; + // cb_id_in0 is a CB source_read_size_bytes +4 page size, 1 page + // cb_id_in1 is a CB source_read_size_bytes +4 page size, 1 page + // cb_id_in1 is a CB dest_page_size_bytes + allignment_to_64 page size, 1 page + cb_reserve_back(cb_id_in0, 1); + cb_reserve_back(cb_id_in1, 1); + cb_reserve_back(cb_id_in2, 1); + const uint32_t source_buffer_ping = get_write_ptr(cb_id_in0); + const uint32_t source_buffer_pong = get_write_ptr(cb_id_in1); + const uint32_t dest_buffer = get_write_ptr(cb_id_in2); + cb_push_back(cb_id_in0, 1); + cb_push_back(cb_id_in1, 1); + cb_push_back(cb_id_in2, 1); + uint32_t source_buffer; + uint64_t dst_noc_addr = get_noc_addr(write_page, d); + uint64_t write_offset = 
(dst_noc_addr & OFFSET_16) + write_start_offset; + uint64_t begin_write_offset = write_offset; + volatile tt_l1_ptr std::uint32_t* read_offset_ptr_ping = + (volatile tt_l1_ptr uint32_t*)(source_buffer_ping + source_read_size_bytes); + volatile tt_l1_ptr std::uint32_t* read_offset_ptr_pong = + (volatile tt_l1_ptr uint32_t*)(source_buffer_pong + source_read_size_bytes); + bool is_ping = true; + bool first = true; + bool second = true; + bool third = true; + bool first_pong = true; + bool second_pong = true; + for (uint32_t i = read_start_page; i < read_end_page; i++) { + if (is_ping) { + if (first) { + first = false; + WAYPOINT("FAWW"); + } else if (second) { + second = false; + WAYPOINT("SAWW"); + } else if (third) { + third = false; + WAYPOINT("TAWW"); + } else { + WAYPOINT("AWW"); + } + source_buffer = source_buffer_ping; + noc_semaphore_wait(ping_read_has_data, 1); + read_offset = *read_offset_ptr_ping; + WAYPOINT("AWD"); + } else { + if (first_pong) { + first_pong = false; + WAYPOINT("FBWW"); + } else if (second_pong) { + second_pong = false; + WAYPOINT("SBWW"); + } + + else { + WAYPOINT("BWW"); + } + source_buffer = source_buffer_pong; + noc_semaphore_wait(pong_read_has_data, 1); + read_offset = *read_offset_ptr_pong; + WAYPOINT("BWD"); + } + readable = source_page_size_bytes; + // Write to dest + while (readable > 0) { + noc_async_write_barrier(); + if (readable < writable) { + tt::data_movement::common::tt_memmove( + dest_buffer + write_offset, source_buffer + read_offset, readable); + if (is_ping) { + noc_semaphore_inc(ping_buf_is_free, 1); + } else { + noc_semaphore_inc(pong_buf_is_free, 1); + } + writable = writable - readable; + write_offset = write_offset + readable; + readable = 0; + end_to_write = end_to_write + readable; + if (i == read_end_page - 1) { + noc_async_write(dest_buffer + begin_write_offset, dst_noc_addr, end_to_write); + return; + } + } else if (readable == writable) { + tt::data_movement::common::tt_memmove( + dest_buffer + write_offset, source_buffer + read_offset, readable); + if (is_ping) { + noc_semaphore_inc(ping_buf_is_free, 1); + } else { + noc_semaphore_inc(pong_buf_is_free, 1); + } + noc_async_write(dest_buffer + begin_write_offset, dst_noc_addr, dest_page_size_bytes); + writable = dest_page_size_bytes; + readable = 0; + if (i == read_end_page - 1) { + return; + } + end_to_write = 0; + write_page++; + dst_noc_addr = get_noc_addr(write_page, d); + write_offset = dst_noc_addr & OFFSET_16; + begin_write_offset = write_offset; + } else { + // writable < readable + + tt::data_movement::common::tt_memmove( + dest_buffer + write_offset, source_buffer + read_offset, writable); + noc_async_write(dest_buffer + begin_write_offset, dst_noc_addr, dest_page_size_bytes); + end_to_write = 0; + readable = readable - writable; + read_offset = read_offset + writable; + write_page++; + dst_noc_addr = get_noc_addr(write_page, d); + write_offset = dst_noc_addr & OFFSET_16; + begin_write_offset = write_offset; + writable = dest_page_size_bytes; + } + } + } + return; +} diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/host/reshape_rm_host_prep.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/host/reshape_rm_host_prep.cpp index f6ecc4f1bec..b6c6fedfaa7 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/host/reshape_rm_host_prep.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/host/reshape_rm_host_prep.cpp @@ -29,8 +29,7 @@ namespace ttnn::operations::data_movement::rm_reshape{ 
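// ---------------------------------------------------------------------------
// Editorial sketch (not taken from the diff): the host-side prep in this patch
// splits the input rows across the core grid, then rounds each core's share
// ("responsibility") up until that core's slice of source bytes is a whole
// number of destination pages, so every core starts writing at the beginning
// of an output page.  The standalone program below reproduces just that
// arithmetic for the issue-14513 test shape; the 2-byte datum and the 8x8
// (64-core) grid are assumptions for illustration only.
// ---------------------------------------------------------------------------
#include <cstdint>
#include <iostream>

int main() {
    const uint32_t data_size = 2;                               // assumed element size (e.g. bfloat16)
    const uint32_t num_cores_total = 64;                        // assumed 8x8 compute grid
    // Test shape from issue 14513: (5, 4, 208, 156) -> (3, 13, 8, 2080)
    const uint32_t input_rows = 5 * 4 * 208;                    // 4160 source rows (pages)
    const uint32_t source_page_size_bytes = 156 * data_size;    // 312 B per source row
    const uint32_t dest_page_size_bytes = 2080 * data_size;     // 4160 B per destination row

    // Same rounding as rm_reshape_preparer: ceil-divide, then grow until the
    // per-core byte count lands exactly on a destination page boundary.
    uint32_t responsibility = (input_rows - 1) / num_cores_total + 1;   // 65
    while ((responsibility * source_page_size_bytes) % dest_page_size_bytes != 0) {
        responsibility++;
    }
    const uint32_t write_jump = (responsibility * source_page_size_bytes) / dest_page_size_bytes;
    const uint32_t busy_cores = (input_rows + responsibility - 1) / responsibility;

    std::cout << "source rows per core: " << responsibility    // 80
              << ", dest rows per core: " << write_jump        // 6
              << ", cores with work: " << busy_cores           // 52; the remaining cores get nop=1
              << "\n";
    return 0;
}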
-operation::ProgramWithCallbacks rm_reshape_preparer(const Tensor& input, const Tensor& output) -{ +operation::ProgramWithCallbacks rm_reshape_preparer_single_risk(const Tensor& input, const Tensor& output) { tt::tt_metal::Program program = tt::tt_metal::CreateProgram(); //get datum size tt::DataFormat cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(input.get_dtype()); @@ -214,4 +213,221 @@ operation::ProgramWithCallbacks rm_reshape_preparer(const Tensor& input, const T }; return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_args_callback}; } + +operation::ProgramWithCallbacks rm_reshape_preparer_multi_risk(const Tensor& input, const Tensor& output) { + // NOTE: This function is an improvement on rm_reshape_preparer_single_risk but it has a bug causing a hang for some + // cases that needs to be debugged first This function uses both risk cores + tt::tt_metal::Program program = tt::tt_metal::CreateProgram(); + // get datum size + tt::DataFormat cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(input.get_dtype()); + const uint32_t data_size = input.element_size(); + tt::tt_metal::Device* device = input.device(); + // Multi device pre-computation + auto compute_with_storage_grid_size = device->compute_with_storage_grid_size(); + uint32_t num_cores_x = compute_with_storage_grid_size.x; + uint32_t num_cores_y = compute_with_storage_grid_size.y; + uint32_t num_cores_total = num_cores_x * num_cores_y; + CoreRange total_cores({0, 0}, {num_cores_x - 1, num_cores_y - 1}); + ttnn::Shape input_log_shape = ttnn::Shape(input.get_logical_shape().view()); + ttnn::Shape output_log_shape = ttnn::Shape(output.get_logical_shape().view()); + tt::log_debug("row major reshape"); + tt::log_debug("input shape: {}", input_log_shape); + tt::log_debug("output shape: {}", output_log_shape); + tt::log_debug("data size: {}", data_size); + uint32_t source_page_size_bytes = input_log_shape[-1] * data_size; + uint32_t dest_page_size_bytes = output_log_shape[-1] * data_size; + uint32_t source_read_size_bytes = ((source_page_size_bytes - 1) & MASK_64) + 128; + uint32_t read_start_page = 0; + uint32_t write_start_page = 0; + tt::tt_metal::Buffer* src_buffer = input.buffer(); + tt::tt_metal::Buffer* dst_buffer = output.buffer(); + TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); + // Find how many input pages each core is responsible for so that we always start at the begining of a read and + // write page Since the logical volumes match, we are guaranteed that the very last page is aligned + uint32_t responsibility = (input_log_shape[-2] - 1) / num_cores_total + 1; + while ((responsibility * source_page_size_bytes) % dest_page_size_bytes != 0) { + responsibility++; + } + const uint32_t write_jump = (responsibility * source_page_size_bytes) / dest_page_size_bytes; + uint32_t src0_is_dram = src_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 
1 : 0; + + const uint32_t cb_size0 = source_read_size_bytes; + const uint32_t cb_size1 = source_read_size_bytes; + const uint32_t cb_size2 = (((dest_page_size_bytes - 1) & MASK_64) + 80); + + const uint32_t src0_cb_index = 0; + const uint32_t src1_cb_index = 1; + const uint32_t src2_cb_index = 2; + tt::tt_metal::CircularBufferConfig cb_src0_config = + tt::tt_metal::CircularBufferConfig(cb_size0 * 2, {{src0_cb_index, cb_data_format}}) + .set_page_size(src0_cb_index, cb_size0); + auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, total_cores, cb_src0_config); + tt::tt_metal::CircularBufferConfig cb_src1_config = + tt::tt_metal::CircularBufferConfig(cb_size1, {{src1_cb_index, cb_data_format}}) + .set_page_size(src1_cb_index, cb_size1); + auto cb_src1 = tt::tt_metal::CreateCircularBuffer(program, total_cores, cb_src1_config); + tt::tt_metal::CircularBufferConfig cb_src2_config = + tt::tt_metal::CircularBufferConfig(cb_size2, {{src2_cb_index, cb_data_format}}) + .set_page_size(src2_cb_index, cb_size2); + auto cb_src2 = tt::tt_metal::CreateCircularBuffer(program, total_cores, cb_src2_config); + + std::vector compile_time_args = { + (std::uint32_t)src0_is_dram, + (std::uint32_t)(source_page_size_bytes % 64 == 0) ? 1 : 0, + (std::uint32_t)(source_page_size_bytes % 16 == 0) ? 1 : 0, + src0_cb_index, + src1_cb_index, + src2_cb_index}; + + tt::tt_metal::KernelHandle reader_kernel_id = tt::tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved_reader.cpp", + total_cores, + tt::tt_metal::ReaderDataMovementConfig(compile_time_args)); + tt::tt_metal::KernelHandle writer_kernel_id = tt::tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved_writer.cpp", + total_cores, + tt::tt_metal::WriterDataMovementConfig(compile_time_args)); + uint32_t done = 0; + for (int core_x = 0; core_x < num_cores_x; core_x++) { + for (int core_y = 0; core_y < num_cores_y; core_y++) { + CoreCoord core = {core_x, core_y}; + if (done == 1) { + const std::vector reader_runtime_args = { + src_buffer->address(), + dst_buffer->address(), + source_page_size_bytes, + dest_page_size_bytes, + source_read_size_bytes, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0 + + }; + tt::tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, reader_runtime_args); + tt::tt_metal::SetRuntimeArgs(program, writer_kernel_id, core, reader_runtime_args); + } else { + // Create the circular buffers + auto ping_read_buf_has_data = CreateSemaphore(program, core, 0); + auto pong_read_buf_has_data = CreateSemaphore(program, core, 0); + auto ping_read_buf_is_free = CreateSemaphore(program, core, 1); + auto pong_read_buf_is_free = CreateSemaphore(program, core, 1); + // set the runtime args + // set the compile time args + const uint32_t start_of_read = read_start_page; + const uint32_t end_of_read = read_start_page + responsibility; + + const std::vector reader_runtime_args = { + src_buffer->address(), + dst_buffer->address(), + source_page_size_bytes, + dest_page_size_bytes, + source_read_size_bytes, + start_of_read, + end_of_read, + write_start_page, + 0, + done, + ping_read_buf_has_data, + pong_read_buf_has_data, + ping_read_buf_is_free, + pong_read_buf_is_free + + }; + write_start_page += write_jump; + read_start_page = end_of_read; + done = (end_of_read == input_log_shape[-2]) ? 
1 : 0; + tt::tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, reader_runtime_args); + tt::tt_metal::SetRuntimeArgs(program, writer_kernel_id, core, reader_runtime_args); + } + } + } + auto override_runtime_args_callback = [reader_kernel_id, compute_with_storage_grid_size]( + const void* operation, + const Program& program, + const std::vector& input_tensors, + const std::vector>&, + const std::vector& output_tensors) { + auto input = input_tensors.at(0); + auto output = output_tensors.at(0); + const uint32_t data_size = input.element_size(); + tt::tt_metal::Buffer* src_buffer = input.buffer(); + tt::tt_metal::Buffer* dst_buffer = output.buffer(); + uint32_t num_cores_x = compute_with_storage_grid_size.x; + uint32_t num_cores_y = compute_with_storage_grid_size.y; + uint32_t num_cores_total = num_cores_x * num_cores_y; + ttnn::Shape input_log_shape = ttnn::Shape(input.get_logical_shape().view()); + ttnn::Shape output_log_shape = ttnn::Shape(output.get_logical_shape().view()); + uint32_t source_page_size_bytes = input_log_shape[-1] * data_size; + uint32_t dest_page_size_bytes = output_log_shape[-1] * data_size; + uint32_t source_read_size_bytes = ((source_page_size_bytes - 1) & MASK_64) + 128; + uint32_t read_start_page = 0; + uint32_t write_start_page = 0; + uint32_t responsibility = (input_log_shape[-2] - 1) / num_cores_total + 1; + while ((responsibility * source_page_size_bytes) % dest_page_size_bytes != 0) { + responsibility++; + } + const uint32_t write_jump = (responsibility * source_page_size_bytes) / dest_page_size_bytes; + uint32_t done = 0; + for (int core_x = 0; core_x < num_cores_x; core_x++) { + for (int core_y = 0; core_y < num_cores_y; core_y++) { + CoreCoord core = {core_x, core_y}; + if (done == 1) { + const std::vector reader_runtime_args = { + src_buffer->address(), + dst_buffer->address(), + source_page_size_bytes, + dest_page_size_bytes, + source_read_size_bytes, + 0, + 0, + 0, + 0, + 1 + + }; + tt::tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, reader_runtime_args); + } else { + // Create the circular buffers + + // set the runtime args + // set the compile time args + const uint32_t start_of_read = read_start_page; + const uint32_t end_of_read = read_start_page + responsibility; + + const std::vector reader_runtime_args = { + src_buffer->address(), + dst_buffer->address(), + source_page_size_bytes, + dest_page_size_bytes, + source_read_size_bytes, + start_of_read, + end_of_read, + write_start_page, + 0, + done + + }; + write_start_page += write_jump; + read_start_page = end_of_read; + done = (end_of_read == input_log_shape[-2]) ? 
1 : 0; + tt::tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, reader_runtime_args); + } + } + } + }; + return {.program = std::move(program)}; +} + +operation::ProgramWithCallbacks rm_reshape_preparer(const Tensor& input, const Tensor& output) { + return rm_reshape_preparer_single_risk(input, output); +} + }; // namespace ttnn::operations::data_movement::rm_reshape From b709d9c7a10c4ee792fd4d1b41c823d3af2f87f9 Mon Sep 17 00:00:00 2001 From: Juan Camilo Vega Date: Mon, 2 Dec 2024 23:14:40 +0000 Subject: [PATCH 20/31] #15269: employing pow2 optimization --- .../device/device/rm_reshape_interleaved.cpp | 33 ++++++++++----- .../device/host/reshape_rm_host_prep.cpp | 41 +++++++------------ 2 files changed, 36 insertions(+), 38 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp index 3950c1f276c..21ff099cab4 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp @@ -36,15 +36,13 @@ void kernel_main() { //We are guranteed to be in 2D going to 2D const uint32_t src_addr = get_arg_val(0); - const uint32_t dst_addr = get_arg_val(1); - const uint32_t source_page_size_bytes = get_arg_val(2); - const uint32_t dest_page_size_bytes = get_arg_val(3); + const uint32_t dst_addr = get_arg_val(1); //If DDR this is source_page_size_bytes + 64 (rounded up to next 64B), if L1 this is source_page_size_bytes + 16 (rounded up to next 16B) - const uint32_t source_read_size_bytes = get_arg_val(4); - const uint32_t read_start_page = get_arg_val(5); - const uint32_t read_end_page = get_arg_val(6); - const uint32_t write_start_page = get_arg_val(7); - const uint32_t write_start_offset = get_arg_val(8); + const uint32_t source_read_size_bytes = get_arg_val(2); + const uint32_t read_start_page = get_arg_val(3); + const uint32_t read_end_page = get_arg_val(4); + const uint32_t write_start_page = get_arg_val(5); + const uint32_t write_start_offset = get_arg_val(6); const uint32_t nop = get_arg_val(9); constexpr bool tensor_is_dram = get_compile_time_arg_val(0) == 1; @@ -52,22 +50,35 @@ void kernel_main() { #define src_aligned_to_16 get_compile_time_arg_val(2) == 1 constexpr uint32_t cb_id_in0 = get_compile_time_arg_val(3); constexpr uint32_t cb_id_in1 = get_compile_time_arg_val(4); + constexpr uint32_t source_page_size_bytes = get_compile_time_arg_val(5); + constexpr uint32_t dest_page_size_bytes = get_compile_time_arg_val(6); +#define source_page_is_pow_2 get_compile_time_arg_val(7) == 1 + constexpr uint32_t source_page_pow_2 = get_compile_time_arg_val(8); +#define dest_page_is_pow_2 get_compile_time_arg_val(9) == 1 + constexpr uint32_t dest_page_pow_2 = get_compile_time_arg_val(10); //Since we need to operate on a grid of cores but sometimes pages don't split properly, if nop then don't use this core if (nop == 1) { return; } - +#if source_page_is_pow_2 + const InterleavedPow2AddrGen s = { + .bank_base_address = src_addr, .log_base_2_of_page_size = source_page_pow_2}; +#else const InterleavedAddrGen s = { .bank_base_address = src_addr, .page_size = source_page_size_bytes }; - +#endif +#if dest_page_is_pow_2 + const InterleavedPow2AddrGen d = { + .bank_base_address = dst_addr, .log_base_2_of_page_size = dest_page_pow_2}; +#else const InterleavedAddrGen d = { .bank_base_address = dst_addr, .page_size = 
dest_page_size_bytes }; - +#endif uint32_t read_offset = 0; uint32_t write_page = write_start_page; diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/host/reshape_rm_host_prep.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/host/reshape_rm_host_prep.cpp index b6c6fedfaa7..8e15fae6536 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/host/reshape_rm_host_prep.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/host/reshape_rm_host_prep.cpp @@ -77,13 +77,22 @@ operation::ProgramWithCallbacks rm_reshape_preparer_single_risk(const Tensor& in tt::tt_metal::CircularBufferConfig(cb_size1, {{src1_cb_index, cb_data_format}}) .set_page_size(src1_cb_index, cb_size1); auto cb_src1 = tt::tt_metal::CreateCircularBuffer(program, total_cores, cb_src1_config); - + bool source_page_is_pow_2 = tt::tt_metal::is_power_of_two_at_least_32(source_page_size_bytes); + uint32_t source_page_pow_2 = source_page_is_pow_2 ? (std::uint32_t)std::log2(source_page_size_bytes) : 0; + bool dest_page_is_pow_2 = tt::tt_metal::is_power_of_two_at_least_32(dest_page_size_bytes); + uint32_t dest_page_pow_2 = dest_page_is_pow_2 ? (std::uint32_t)std::log2(dest_page_size_bytes) : 0; std::vector compile_time_args = { (std::uint32_t)src0_is_dram, (std::uint32_t)(source_page_size_bytes % 64 == 0) ? 1 : 0, (std::uint32_t)(source_page_size_bytes % 16 == 0) ? 1 : 0, src0_cb_index, - src1_cb_index}; + src1_cb_index, + source_page_size_bytes, + dest_page_size_bytes, + source_page_is_pow_2, + source_page_pow_2, + dest_page_is_pow_2, + dest_page_pow_2}; tt::tt_metal::KernelHandle reader_kernel_id = tt::tt_metal::CreateKernel( program, @@ -96,16 +105,7 @@ operation::ProgramWithCallbacks rm_reshape_preparer_single_risk(const Tensor& in CoreCoord core = {core_x, core_y}; if (done == 1) { const std::vector reader_runtime_args = { - src_buffer->address(), - dst_buffer->address(), - source_page_size_bytes, - dest_page_size_bytes, - source_read_size_bytes, - 0, - 0, - 0, - 0, - 1 + src_buffer->address(), dst_buffer->address(), source_read_size_bytes, 0, 0, 0, 0, 1 }; tt::tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, reader_runtime_args); @@ -120,8 +120,6 @@ operation::ProgramWithCallbacks rm_reshape_preparer_single_risk(const Tensor& in const std::vector reader_runtime_args = { src_buffer->address(), dst_buffer->address(), - source_page_size_bytes, - dest_page_size_bytes, source_read_size_bytes, start_of_read, end_of_read, @@ -169,16 +167,7 @@ operation::ProgramWithCallbacks rm_reshape_preparer_single_risk(const Tensor& in CoreCoord core = {core_x, core_y}; if (done == 1) { const std::vector reader_runtime_args = { - src_buffer->address(), - dst_buffer->address(), - source_page_size_bytes, - dest_page_size_bytes, - source_read_size_bytes, - 0, - 0, - 0, - 0, - 1 + src_buffer->address(), dst_buffer->address(), source_read_size_bytes, 0, 0, 0, 0, 1 }; tt::tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, reader_runtime_args); @@ -193,8 +182,6 @@ operation::ProgramWithCallbacks rm_reshape_preparer_single_risk(const Tensor& in const std::vector reader_runtime_args = { src_buffer->address(), dst_buffer->address(), - source_page_size_bytes, - dest_page_size_bytes, source_read_size_bytes, start_of_read, end_of_read, @@ -236,7 +223,7 @@ operation::ProgramWithCallbacks rm_reshape_preparer_multi_risk(const Tensor& inp tt::log_debug("data size: {}", data_size); uint32_t source_page_size_bytes = input_log_shape[-1] * data_size; uint32_t dest_page_size_bytes = 
output_log_shape[-1] * data_size; - uint32_t source_read_size_bytes = ((source_page_size_bytes - 1) & MASK_64) + 128; + uint32_t source_read_size_bytes = ((source_page_size_bytes - 1) & MASK_64) + 256; uint32_t read_start_page = 0; uint32_t write_start_page = 0; tt::tt_metal::Buffer* src_buffer = input.buffer(); From 4fa2bf3055e3a7a1ac011136c84411c3bc25933d Mon Sep 17 00:00:00 2001 From: Juan Camilo Vega Date: Tue, 3 Dec 2024 15:47:10 +0000 Subject: [PATCH 21/31] #15269: Added optimization to do one less copy on aligned only transfers --- .../device/device/rm_reshape_interleaved.cpp | 49 ++-- .../device/host/reshape_rm_host_prep.cpp | 233 ++---------------- 2 files changed, 51 insertions(+), 231 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp index 21ff099cab4..44d503ea40b 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp @@ -96,8 +96,10 @@ void kernel_main() { cb_push_back(cb_id_in1, 1); uint64_t dst_noc_addr = get_noc_addr(write_page, d); - uint64_t write_offset = dst_noc_addr&OFFSET_16 + write_start_offset; + uint64_t write_offset = (dst_noc_addr & OFFSET_16) + write_start_offset; uint64_t begin_write_offset = write_offset; + constexpr bool can_be_clean = ((source_page_size_bytes % 16) == 0 && (dest_page_size_bytes % 16) == 0); + uint64_t dst_noc_addr_offset = 0; for (uint32_t i = read_start_page; i < read_end_page; i++) { //Read from source uint64_t src_noc_addr = s.get_noc_addr(i,0); @@ -124,23 +126,34 @@ void kernel_main() { noc_async_write_barrier(); if (readable < writable) { - tt::data_movement::common::tt_memmove( - dest_buffer + write_offset, source_buffer + read_offset, readable); + if constexpr (can_be_clean) { + noc_async_write(source_buffer + read_offset, dst_noc_addr + dst_noc_addr_offset, readable); + dst_noc_addr_offset = dst_noc_addr_offset + readable; + } else { + tt::data_movement::common::tt_memmove( + dest_buffer + write_offset, source_buffer + read_offset, readable); + if (i == read_end_page - 1) { + noc_async_write(dest_buffer + begin_write_offset, dst_noc_addr, end_to_write); + return; + } + } writable = writable -readable; write_offset = write_offset + readable; readable = 0; end_to_write = end_to_write + readable; - if (i == read_end_page-1) - { - noc_async_write(dest_buffer + begin_write_offset, dst_noc_addr, end_to_write); - return; - } + } else if (readable == writable) { - tt::data_movement::common::tt_memmove( - dest_buffer + write_offset, source_buffer + read_offset, readable); - noc_async_write(dest_buffer+begin_write_offset,dst_noc_addr, dest_page_size_bytes); + if constexpr (can_be_clean) { + noc_async_write(source_buffer + read_offset, dst_noc_addr + dst_noc_addr_offset, readable); + } else { + tt::data_movement::common::tt_memmove( + dest_buffer + write_offset, source_buffer + read_offset, readable); + noc_async_write(dest_buffer + begin_write_offset, dst_noc_addr, dest_page_size_bytes); + } + dst_noc_addr_offset = 0; + writable = dest_page_size_bytes; readable = 0; if (i == read_end_page - 1) { @@ -154,15 +167,19 @@ void kernel_main() { } else { - //writable < readable - - tt::data_movement::common::tt_memmove( - dest_buffer + write_offset, source_buffer + read_offset, writable); - 
noc_async_write(dest_buffer+begin_write_offset,dst_noc_addr, dest_page_size_bytes); + if constexpr (can_be_clean) { + noc_async_write(source_buffer + read_offset, dst_noc_addr + dst_noc_addr_offset, writable); + } else { + tt::data_movement::common::tt_memmove( + dest_buffer + write_offset, source_buffer + read_offset, writable); + noc_async_write(dest_buffer + begin_write_offset, dst_noc_addr, dest_page_size_bytes); + } + // writable < readable end_to_write = 0; readable = readable - writable; read_offset = read_offset + writable; write_page++; + dst_noc_addr_offset = 0; dst_noc_addr = get_noc_addr(write_page, d); write_offset = dst_noc_addr&OFFSET_16; begin_write_offset = write_offset; diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/host/reshape_rm_host_prep.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/host/reshape_rm_host_prep.cpp index 8e15fae6536..ac86b4576c1 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/host/reshape_rm_host_prep.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/host/reshape_rm_host_prep.cpp @@ -52,18 +52,26 @@ operation::ProgramWithCallbacks rm_reshape_preparer_single_risk(const Tensor& in uint32_t source_read_size_bytes = ((source_page_size_bytes-1) & MASK_64) + 128; uint32_t read_start_page = 0; uint32_t write_start_page = 0; + uint32_t write_start_offset = 0; tt::tt_metal::Buffer *src_buffer = input.buffer(); tt::tt_metal::Buffer *dst_buffer = output.buffer(); TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); // Find how many input pages each core is responsible for so that we always start at the begining of a read and // write page Since the logical volumes match, we are guaranteed that the very last page is aligned - uint32_t responsibility = (input_log_shape[-2] - 1) / num_cores_total + 1; + uint32_t responsibility = ((input_log_shape[-2] - 1) / num_cores_total) + 1; while ((responsibility * source_page_size_bytes) % dest_page_size_bytes != 0) { responsibility++; } const uint32_t write_jump = (responsibility * source_page_size_bytes) / dest_page_size_bytes; + const uint32_t offset_jump = (responsibility * source_page_size_bytes) % dest_page_size_bytes; uint32_t src0_is_dram = src_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; - + printf( + "source is %d bytes, dest is %d bytes, responsibility %d, wj %d, oj %d\n", + source_page_size_bytes, + dest_page_size_bytes, + responsibility, + write_jump, + offset_jump); const uint32_t cb_size0 = source_read_size_bytes; const uint32_t cb_size1 = ((dest_page_size_bytes - 1) & MASK_64) + 80; @@ -115,7 +123,8 @@ operation::ProgramWithCallbacks rm_reshape_preparer_single_risk(const Tensor& in // set the runtime args // set the compile time args const uint32_t start_of_read = read_start_page; - const uint32_t end_of_read = read_start_page + responsibility; + uint32_t end_of_read = read_start_page + responsibility; + end_of_read = end_of_read < input_log_shape[-2] ? end_of_read : input_log_shape[-2]; const std::vector reader_runtime_args = { src_buffer->address(), @@ -124,215 +133,20 @@ operation::ProgramWithCallbacks rm_reshape_preparer_single_risk(const Tensor& in start_of_read, end_of_read, write_start_page, - 0, + write_start_offset, done }; - write_start_page += write_jump; - read_start_page = end_of_read; - done = (end_of_read == input_log_shape[-2]) ? 
1 : 0; - tt::tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, reader_runtime_args); - } - } - } - auto override_runtime_args_callback = [reader_kernel_id, compute_with_storage_grid_size]( - const void* operation, - const Program& program, - const std::vector& input_tensors, - const std::vector>&, - const std::vector& output_tensors) { - auto input = input_tensors.at(0); - auto output = output_tensors.at(0); - const uint32_t data_size = input.element_size(); - tt::tt_metal::Buffer* src_buffer = input.buffer(); - tt::tt_metal::Buffer* dst_buffer = output.buffer(); - uint32_t num_cores_x = compute_with_storage_grid_size.x; - uint32_t num_cores_y = compute_with_storage_grid_size.y; - uint32_t num_cores_total = num_cores_x * num_cores_y; - ttnn::Shape input_log_shape = ttnn::Shape(input.get_logical_shape().view()); - ttnn::Shape output_log_shape = ttnn::Shape(output.get_logical_shape().view()); - uint32_t source_page_size_bytes = input_log_shape[-1] * data_size; - uint32_t dest_page_size_bytes = output_log_shape[-1] * data_size; - uint32_t source_read_size_bytes = ((source_page_size_bytes - 1) & MASK_64) + 128; - uint32_t read_start_page = 0; - uint32_t write_start_page = 0; - uint32_t responsibility = (input_log_shape[-2] - 1) / num_cores_total + 1; - while ((responsibility * source_page_size_bytes) % dest_page_size_bytes != 0) { - responsibility++; - } - const uint32_t write_jump = (responsibility * source_page_size_bytes) / dest_page_size_bytes; - uint32_t done = 0; - for (int core_x = 0; core_x < num_cores_x; core_x++) { - for (int core_y = 0; core_y < num_cores_y; core_y++) { - CoreCoord core = {core_x, core_y}; - if (done == 1) { - const std::vector reader_runtime_args = { - src_buffer->address(), dst_buffer->address(), source_read_size_bytes, 0, 0, 0, 0, 1 - - }; - tt::tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, reader_runtime_args); + write_start_offset += offset_jump; + if (write_start_offset >= dest_page_size_bytes) { + write_start_page += write_jump + 1; + write_start_offset -= dest_page_size_bytes; } else { - // Create the circular buffers - - // set the runtime args - // set the compile time args - const uint32_t start_of_read = read_start_page; - const uint32_t end_of_read = read_start_page + responsibility; - - const std::vector reader_runtime_args = { - src_buffer->address(), - dst_buffer->address(), - source_read_size_bytes, - start_of_read, - end_of_read, - write_start_page, - 0, - done - - }; write_start_page += write_jump; - read_start_page = end_of_read; - done = (end_of_read == input_log_shape[-2]) ? 
1 : 0; - tt::tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, reader_runtime_args); } - } - } - }; - return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_args_callback}; -} - -operation::ProgramWithCallbacks rm_reshape_preparer_multi_risk(const Tensor& input, const Tensor& output) { - // NOTE: This function is an improvement on rm_reshape_preparer_single_risk but it has a bug causing a hang for some - // cases that needs to be debugged first This function uses both risk cores - tt::tt_metal::Program program = tt::tt_metal::CreateProgram(); - // get datum size - tt::DataFormat cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(input.get_dtype()); - const uint32_t data_size = input.element_size(); - tt::tt_metal::Device* device = input.device(); - // Multi device pre-computation - auto compute_with_storage_grid_size = device->compute_with_storage_grid_size(); - uint32_t num_cores_x = compute_with_storage_grid_size.x; - uint32_t num_cores_y = compute_with_storage_grid_size.y; - uint32_t num_cores_total = num_cores_x * num_cores_y; - CoreRange total_cores({0, 0}, {num_cores_x - 1, num_cores_y - 1}); - ttnn::Shape input_log_shape = ttnn::Shape(input.get_logical_shape().view()); - ttnn::Shape output_log_shape = ttnn::Shape(output.get_logical_shape().view()); - tt::log_debug("row major reshape"); - tt::log_debug("input shape: {}", input_log_shape); - tt::log_debug("output shape: {}", output_log_shape); - tt::log_debug("data size: {}", data_size); - uint32_t source_page_size_bytes = input_log_shape[-1] * data_size; - uint32_t dest_page_size_bytes = output_log_shape[-1] * data_size; - uint32_t source_read_size_bytes = ((source_page_size_bytes - 1) & MASK_64) + 256; - uint32_t read_start_page = 0; - uint32_t write_start_page = 0; - tt::tt_metal::Buffer* src_buffer = input.buffer(); - tt::tt_metal::Buffer* dst_buffer = output.buffer(); - TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); - // Find how many input pages each core is responsible for so that we always start at the begining of a read and - // write page Since the logical volumes match, we are guaranteed that the very last page is aligned - uint32_t responsibility = (input_log_shape[-2] - 1) / num_cores_total + 1; - while ((responsibility * source_page_size_bytes) % dest_page_size_bytes != 0) { - responsibility++; - } - const uint32_t write_jump = (responsibility * source_page_size_bytes) / dest_page_size_bytes; - uint32_t src0_is_dram = src_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 
1 : 0; - - const uint32_t cb_size0 = source_read_size_bytes; - const uint32_t cb_size1 = source_read_size_bytes; - const uint32_t cb_size2 = (((dest_page_size_bytes - 1) & MASK_64) + 80); - - const uint32_t src0_cb_index = 0; - const uint32_t src1_cb_index = 1; - const uint32_t src2_cb_index = 2; - tt::tt_metal::CircularBufferConfig cb_src0_config = - tt::tt_metal::CircularBufferConfig(cb_size0 * 2, {{src0_cb_index, cb_data_format}}) - .set_page_size(src0_cb_index, cb_size0); - auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, total_cores, cb_src0_config); - tt::tt_metal::CircularBufferConfig cb_src1_config = - tt::tt_metal::CircularBufferConfig(cb_size1, {{src1_cb_index, cb_data_format}}) - .set_page_size(src1_cb_index, cb_size1); - auto cb_src1 = tt::tt_metal::CreateCircularBuffer(program, total_cores, cb_src1_config); - tt::tt_metal::CircularBufferConfig cb_src2_config = - tt::tt_metal::CircularBufferConfig(cb_size2, {{src2_cb_index, cb_data_format}}) - .set_page_size(src2_cb_index, cb_size2); - auto cb_src2 = tt::tt_metal::CreateCircularBuffer(program, total_cores, cb_src2_config); - - std::vector compile_time_args = { - (std::uint32_t)src0_is_dram, - (std::uint32_t)(source_page_size_bytes % 64 == 0) ? 1 : 0, - (std::uint32_t)(source_page_size_bytes % 16 == 0) ? 1 : 0, - src0_cb_index, - src1_cb_index, - src2_cb_index}; - - tt::tt_metal::KernelHandle reader_kernel_id = tt::tt_metal::CreateKernel( - program, - "ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved_reader.cpp", - total_cores, - tt::tt_metal::ReaderDataMovementConfig(compile_time_args)); - tt::tt_metal::KernelHandle writer_kernel_id = tt::tt_metal::CreateKernel( - program, - "ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved_writer.cpp", - total_cores, - tt::tt_metal::WriterDataMovementConfig(compile_time_args)); - uint32_t done = 0; - for (int core_x = 0; core_x < num_cores_x; core_x++) { - for (int core_y = 0; core_y < num_cores_y; core_y++) { - CoreCoord core = {core_x, core_y}; - if (done == 1) { - const std::vector reader_runtime_args = { - src_buffer->address(), - dst_buffer->address(), - source_page_size_bytes, - dest_page_size_bytes, - source_read_size_bytes, - 0, - 0, - 0, - 0, - 1, - 0, - 0, - 0, - 0 - - }; - tt::tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, reader_runtime_args); - tt::tt_metal::SetRuntimeArgs(program, writer_kernel_id, core, reader_runtime_args); - } else { - // Create the circular buffers - auto ping_read_buf_has_data = CreateSemaphore(program, core, 0); - auto pong_read_buf_has_data = CreateSemaphore(program, core, 0); - auto ping_read_buf_is_free = CreateSemaphore(program, core, 1); - auto pong_read_buf_is_free = CreateSemaphore(program, core, 1); - // set the runtime args - // set the compile time args - const uint32_t start_of_read = read_start_page; - const uint32_t end_of_read = read_start_page + responsibility; - - const std::vector reader_runtime_args = { - src_buffer->address(), - dst_buffer->address(), - source_page_size_bytes, - dest_page_size_bytes, - source_read_size_bytes, - start_of_read, - end_of_read, - write_start_page, - 0, - done, - ping_read_buf_has_data, - pong_read_buf_has_data, - ping_read_buf_is_free, - pong_read_buf_is_free - - }; - write_start_page += write_jump; read_start_page = end_of_read; done = (end_of_read == input_log_shape[-2]) ? 
1 : 0; tt::tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, reader_runtime_args); - tt::tt_metal::SetRuntimeArgs(program, writer_kernel_id, core, reader_runtime_args); } } } @@ -368,16 +182,7 @@ operation::ProgramWithCallbacks rm_reshape_preparer_multi_risk(const Tensor& inp CoreCoord core = {core_x, core_y}; if (done == 1) { const std::vector reader_runtime_args = { - src_buffer->address(), - dst_buffer->address(), - source_page_size_bytes, - dest_page_size_bytes, - source_read_size_bytes, - 0, - 0, - 0, - 0, - 1 + src_buffer->address(), dst_buffer->address(), source_read_size_bytes, 0, 0, 0, 0, 1 }; tt::tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, reader_runtime_args); @@ -392,8 +197,6 @@ operation::ProgramWithCallbacks rm_reshape_preparer_multi_risk(const Tensor& inp const std::vector reader_runtime_args = { src_buffer->address(), dst_buffer->address(), - source_page_size_bytes, - dest_page_size_bytes, source_read_size_bytes, start_of_read, end_of_read, From 72f2d5c53af5e82cda229bc2b196af6dd0f577ee Mon Sep 17 00:00:00 2001 From: Juan Camilo Vega Date: Tue, 3 Dec 2024 18:43:02 +0000 Subject: [PATCH 22/31] #15269: added packet form of noc_async_read and write --- .../data_movement/common/kernels/common.hpp | 36 ++++++++++++++++--- .../device/device/rm_reshape_interleaved.cpp | 36 ++++++++++++------- 2 files changed, 54 insertions(+), 18 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp b/ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp index d559630736e..dd302e8628f 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp @@ -15,7 +15,33 @@ namespace tt::data_movement::common { -template +#define max_packet_size 8192 + +template +FORCE_INLINE void enhanced_noc_async_read( + const uint64_t src_noc_addr, const uint32_t dst_l1_addr, const uint32_t bytes) { + // If you do not know the max_transfer_size at compile time write 0 to it. + // only reads is true if we ONLY use noc_async_read and all calls to tt_memmove have use_read_datamover as True + if constexpr (((max_transfer_size < max_packet_size) && (max_transfer_size != 0)) || only_reads) { + noc_async_read_one_packet(src_noc_addr, dst_l1_addr, bytes); + } else { + noc_async_read(src_noc_addr, dst_l1_addr, bytes); + } +} + +template +FORCE_INLINE void enhanced_noc_async_write( + const uint32_t src_l1_addr, const uint64_t dst_noc_addr, const uint32_t bytes) { + // If you do not know the max_transfer_size at compile time write 0 to it. 
+ // only writes is true if we ONLY use noc_async_read and all calls to tt_memmove have use_read_datamover as False + if constexpr (((max_transfer_size < max_packet_size) && (max_transfer_size != 0)) || only_writes) { + noc_async_write_one_packet(src_l1_addr, dst_noc_addr, bytes); + } else { + noc_async_write(src_l1_addr, dst_noc_addr, bytes); + } +} + +template FORCE_INLINE void tt_memmove(const uint32_t dst_l1_addr, const uint32_t src_l1_addr, const uint32_t bytes) { //Function performs a memory copy between two l1 addresses in the local core //Uses noc_async_read when possible to copy the data over @@ -23,13 +49,13 @@ FORCE_INLINE void tt_memmove(const uint32_t dst_l1_addr, const uint32_t src_l1_a //Set copy_async to true if you wish to perform the operation asynchronously, in this case you can add a noc_async_read_barrier to synchronize later if constexpr (use_read_datamover) { if constexpr (guaranteed_16B_alligned) { - noc_async_read(get_noc_addr(src_l1_addr), dst_l1_addr, bytes); + enhanced_noc_async_read(get_noc_addr(src_l1_addr), dst_l1_addr, bytes); if constexpr (!copy_async) { noc_async_read_barrier(); } } else { if ((dst_l1_addr & OFFSET_16) == (src_l1_addr & OFFSET_16)) { - noc_async_read(get_noc_addr(src_l1_addr), dst_l1_addr, bytes); + enhanced_noc_async_read(get_noc_addr(src_l1_addr), dst_l1_addr, bytes); if constexpr (!copy_async) { noc_async_read_barrier(); } @@ -39,13 +65,13 @@ FORCE_INLINE void tt_memmove(const uint32_t dst_l1_addr, const uint32_t src_l1_a } } else { if constexpr (guaranteed_16B_alligned) { - noc_async_write(src_l1_addr, get_noc_addr(dst_l1_addr), bytes); + enhanced_noc_async_write(src_l1_addr, get_noc_addr(dst_l1_addr), bytes); if constexpr (!copy_async) { noc_async_write_barrier(); } } else { if ((dst_l1_addr & OFFSET_16) == (src_l1_addr & OFFSET_16)) { - noc_async_write(src_l1_addr, get_noc_addr(dst_l1_addr), bytes); + enhanced_noc_async_write(src_l1_addr, get_noc_addr(dst_l1_addr), bytes); if constexpr (!copy_async) { noc_async_write_barrier(); } diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp index 44d503ea40b..320bda81277 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp @@ -56,6 +56,7 @@ void kernel_main() { constexpr uint32_t source_page_pow_2 = get_compile_time_arg_val(8); #define dest_page_is_pow_2 get_compile_time_arg_val(9) == 1 constexpr uint32_t dest_page_pow_2 = get_compile_time_arg_val(10); + //Since we need to operate on a grid of cores but sometimes pages don't split properly, if nop then don't use this core if (nop == 1) { @@ -92,8 +93,8 @@ void kernel_main() { cb_reserve_back(cb_id_in1, 1); const uint32_t source_buffer = get_write_ptr(cb_id_in0); const uint32_t dest_buffer = get_write_ptr(cb_id_in1); - cb_push_back(cb_id_in0, 1); cb_push_back(cb_id_in1, 1); + cb_push_back(cb_id_in0, 1); uint64_t dst_noc_addr = get_noc_addr(write_page, d); uint64_t write_offset = (dst_noc_addr & OFFSET_16) + write_start_offset; @@ -106,15 +107,18 @@ void kernel_main() { #if (src_aligned_to_64 || ((!tensor_is_dram) && src_aligned_to_16)) //Aligned to 64 bytes or 16 bytes but L1 - noc_async_read(src_noc_addr, source_buffer, source_page_size_bytes); + tt::data_movement::common::enhanced_noc_async_read( + src_noc_addr, source_buffer, 
source_page_size_bytes); read_offset = 0; #elif (tensor_is_dram) //DDR but not alligned to 64 (potentially also not alligned to 16) - noc_async_read(src_noc_addr&MASK_64, source_buffer, source_read_size_bytes); + tt::data_movement::common::enhanced_noc_async_read<(source_page_size_bytes + 128), false>( + src_noc_addr & MASK_64, source_buffer, source_read_size_bytes); read_offset = src_noc_addr&OFFSET_64; #else //L1 but not alligned to 16 - noc_async_read(src_noc_addr&MASK_16, source_buffer, source_read_size_bytes); + tt::data_movement::common::enhanced_noc_async_read<(source_page_size_bytes + 128), false>( + src_noc_addr & MASK_16, source_buffer, source_read_size_bytes); read_offset = src_noc_addr&OFFSET_16; #endif readable = source_page_size_bytes; @@ -127,13 +131,15 @@ void kernel_main() { if (readable < writable) { if constexpr (can_be_clean) { - noc_async_write(source_buffer + read_offset, dst_noc_addr + dst_noc_addr_offset, readable); + tt::data_movement::common::enhanced_noc_async_write( + source_buffer + read_offset, dst_noc_addr + dst_noc_addr_offset, readable); dst_noc_addr_offset = dst_noc_addr_offset + readable; } else { - tt::data_movement::common::tt_memmove( + tt::data_movement::common::tt_memmove( dest_buffer + write_offset, source_buffer + read_offset, readable); if (i == read_end_page - 1) { - noc_async_write(dest_buffer + begin_write_offset, dst_noc_addr, end_to_write); + tt::data_movement::common::enhanced_noc_async_write( + dest_buffer + begin_write_offset, dst_noc_addr, end_to_write); return; } } @@ -146,11 +152,13 @@ void kernel_main() { else if (readable == writable) { if constexpr (can_be_clean) { - noc_async_write(source_buffer + read_offset, dst_noc_addr + dst_noc_addr_offset, readable); + tt::data_movement::common::enhanced_noc_async_write( + source_buffer + read_offset, dst_noc_addr + dst_noc_addr_offset, readable); } else { - tt::data_movement::common::tt_memmove( + tt::data_movement::common::tt_memmove( dest_buffer + write_offset, source_buffer + read_offset, readable); - noc_async_write(dest_buffer + begin_write_offset, dst_noc_addr, dest_page_size_bytes); + tt::data_movement::common::enhanced_noc_async_write( + dest_buffer + begin_write_offset, dst_noc_addr, dest_page_size_bytes); } dst_noc_addr_offset = 0; @@ -168,11 +176,13 @@ void kernel_main() { else { if constexpr (can_be_clean) { - noc_async_write(source_buffer + read_offset, dst_noc_addr + dst_noc_addr_offset, writable); + tt::data_movement::common::enhanced_noc_async_write( + source_buffer + read_offset, dst_noc_addr + dst_noc_addr_offset, writable); } else { - tt::data_movement::common::tt_memmove( + tt::data_movement::common::tt_memmove( dest_buffer + write_offset, source_buffer + read_offset, writable); - noc_async_write(dest_buffer + begin_write_offset, dst_noc_addr, dest_page_size_bytes); + tt::data_movement::common::enhanced_noc_async_write( + dest_buffer + begin_write_offset, dst_noc_addr, dest_page_size_bytes); } // writable < readable end_to_write = 0; From 09482f3ddeb4f223017706475b77f9944e76cbda Mon Sep 17 00:00:00 2001 From: Juan Camilo Vega Date: Tue, 3 Dec 2024 19:04:05 +0000 Subject: [PATCH 23/31] #15269: further small optimizations --- .../device/device/rm_reshape_interleaved.cpp | 22 +++++++++++-------- .../device/host/reshape_rm_host_prep.cpp | 7 ------ 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp 
b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp index 320bda81277..3a66652597c 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp @@ -142,11 +142,11 @@ void kernel_main() { dest_buffer + begin_write_offset, dst_noc_addr, end_to_write); return; } + write_offset = write_offset + readable; + end_to_write = end_to_write + readable; } - writable = writable -readable; - write_offset = write_offset + readable; + writable = writable - readable; readable = 0; - end_to_write = end_to_write + readable; } else if (readable == writable) @@ -167,11 +167,13 @@ void kernel_main() { if (i == read_end_page - 1) { return; } - end_to_write = 0; write_page++; dst_noc_addr = get_noc_addr(write_page, d); - write_offset = dst_noc_addr&OFFSET_16; - begin_write_offset = write_offset; + if constexpr (!can_be_clean) { + end_to_write = 0; + write_offset = dst_noc_addr & OFFSET_16; + begin_write_offset = write_offset; + } } else { @@ -185,14 +187,16 @@ void kernel_main() { dest_buffer + begin_write_offset, dst_noc_addr, dest_page_size_bytes); } // writable < readable - end_to_write = 0; readable = readable - writable; read_offset = read_offset + writable; write_page++; dst_noc_addr_offset = 0; dst_noc_addr = get_noc_addr(write_page, d); - write_offset = dst_noc_addr&OFFSET_16; - begin_write_offset = write_offset; + if constexpr (!can_be_clean) { + end_to_write = 0; + write_offset = dst_noc_addr & OFFSET_16; + begin_write_offset = write_offset; + } writable = dest_page_size_bytes; } } diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/host/reshape_rm_host_prep.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/host/reshape_rm_host_prep.cpp index ac86b4576c1..e46916bf993 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/host/reshape_rm_host_prep.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/host/reshape_rm_host_prep.cpp @@ -65,13 +65,6 @@ operation::ProgramWithCallbacks rm_reshape_preparer_single_risk(const Tensor& in const uint32_t write_jump = (responsibility * source_page_size_bytes) / dest_page_size_bytes; const uint32_t offset_jump = (responsibility * source_page_size_bytes) % dest_page_size_bytes; uint32_t src0_is_dram = src_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 
1 : 0; - printf( - "source is %d bytes, dest is %d bytes, responsibility %d, wj %d, oj %d\n", - source_page_size_bytes, - dest_page_size_bytes, - responsibility, - write_jump, - offset_jump); const uint32_t cb_size0 = source_read_size_bytes; const uint32_t cb_size1 = ((dest_page_size_bytes - 1) & MASK_64) + 80; From 323ace7b4f396d3a786dc888d01bf5d6af726375 Mon Sep 17 00:00:00 2001 From: Juan Camilo Vega Date: Wed, 4 Dec 2024 14:42:17 +0000 Subject: [PATCH 24/31] #15702: Added a skip for grayskull due to issue 15702 --- tests/ttnn/unit_tests/test_reshape.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/ttnn/unit_tests/test_reshape.py b/tests/ttnn/unit_tests/test_reshape.py index 56f4c5e29b0..bf7d1bf568d 100644 --- a/tests/ttnn/unit_tests/test_reshape.py +++ b/tests/ttnn/unit_tests/test_reshape.py @@ -9,6 +9,7 @@ import ttnn from tests.ttnn.utils_for_testing import assert_with_pcc +from models.utility_functions import skip_for_grayskull @pytest.mark.parametrize("n", [16]) @@ -354,6 +355,7 @@ def test_reshape_host(input_shape, output_shape, device): # required for Embedding +@skip_for_grayskull("avoid this test while issue 15702 is resolved") @pytest.mark.parametrize( "input_shape, output_shape", [ From 81c282de7b1cf93a5fdcec05f00bc1ec3fddd824 Mon Sep 17 00:00:00 2001 From: Juan Camilo Vega Date: Wed, 4 Dec 2024 15:37:11 +0000 Subject: [PATCH 25/31] #15269: updating mnist device perf targets --- models/demos/mnist/tests/test_perf_mnist.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/models/demos/mnist/tests/test_perf_mnist.py b/models/demos/mnist/tests/test_perf_mnist.py index c60efe5f16e..a4d701bcaba 100644 --- a/models/demos/mnist/tests/test_perf_mnist.py +++ b/models/demos/mnist/tests/test_perf_mnist.py @@ -112,9 +112,9 @@ def test_perf_device_bare_metal(batch_size, reset_seeds): num_iterations = 1 margin = 0.03 if is_grayskull(): - expected_perf = 653017.5 + expected_perf = 390000.0 elif is_wormhole_b0(): - expected_perf = 1383185.64944 + expected_perf = 900000.0 command = f"pytest tests/ttnn/integration_tests/mnist/test_mnist.py::test_mnist" cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"] From 82c0eb788cbe35e418bffe573b910b0ec3dd2853 Mon Sep 17 00:00:00 2001 From: Juan Camilo Vega Date: Wed, 4 Dec 2024 19:15:38 +0000 Subject: [PATCH 26/31] #15269: updating other perf targets --- models/demos/bert_tiny/tests/test_performance.py | 2 +- models/demos/vgg/tests/test_perf_vgg.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/models/demos/bert_tiny/tests/test_performance.py b/models/demos/bert_tiny/tests/test_performance.py index e8b44f7d9cb..6ce1bc48b44 100644 --- a/models/demos/bert_tiny/tests/test_performance.py +++ b/models/demos/bert_tiny/tests/test_performance.py @@ -119,7 +119,7 @@ def test_perf_device_bare_metal(batch_size, expected_perf): margin = 0.03 if is_wormhole_b0(): - expected_perf = 4155.25 + expected_perf = 3990.0 else: expected_perf = 3476.55 diff --git a/models/demos/vgg/tests/test_perf_vgg.py b/models/demos/vgg/tests/test_perf_vgg.py index a91bd0b3fa4..14d2fa3b587 100644 --- a/models/demos/vgg/tests/test_perf_vgg.py +++ b/models/demos/vgg/tests/test_perf_vgg.py @@ -137,10 +137,10 @@ def test_perf_device_bare_metal_vgg(batch_size, model_name): margin = 0.03 if model_name == "ttnn_vgg11": - expected_perf = 168 if is_grayskull() else 283.289 + expected_perf = 36 if is_grayskull() else 283.289 command = f"pytest tests/ttnn/integration_tests/vgg/test_ttnn_vgg11.py" else: - expected_perf = 144 if 
is_grayskull() else 194.84 + expected_perf = 34 if is_grayskull() else 194.84 command = f"pytest tests/ttnn/integration_tests/vgg/test_ttnn_vgg16.py" cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"] From d835144361ee895ac45f27f71cbebdcc19b505cd Mon Sep 17 00:00:00 2001 From: Juan Camilo Vega Date: Wed, 4 Dec 2024 19:55:55 +0000 Subject: [PATCH 27/31] #15269: removing broken unused kernel code --- .../device/rm_reshape_interleaved_reader.cpp | 163 --------------- .../device/rm_reshape_interleaved_writer.cpp | 193 ------------------ 2 files changed, 356 deletions(-) delete mode 100644 ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved_reader.cpp delete mode 100644 ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved_writer.cpp diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved_reader.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved_reader.cpp deleted file mode 100644 index 8a5cc4d333c..00000000000 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved_reader.cpp +++ /dev/null @@ -1,163 +0,0 @@ -// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -/* -NOTE: This function is an improvement on rm_reshape_interleaved.cpp but it has a bug causing a hang for some cases that -needs to be debugged first - -Function reads from RM and writes to RM - -Assumptions: - -Compile arguments -0. src0_is_dram: 1 if source is dram else 0 -1. read_size_is_pow2: 1 if read size is power of 2 else 0 -2. log_base_2_of_page_size: log base 2 of page size -3. write_size_is_pow2: 1 if write size is power of 2 else 0 -4. log_base_2_of_page_size: log base 2 of page size -5. needs_read_allignment: 1 if read needs allignment else 0 -//Needed if BRAM and page size is not multiple of 64 bytes - -Runtime arguments -0. src_addr: source address -1. dst_addr: destination address -2. source_page_size_bytes: source page size in bytes -3. dest_page_size_bytes: destination page size in bytes -4. source_read_size_bytes: source read size in bytes -5. read_start_page: read start page -6. read_end_page: read end page -7. 
write_start_page: write start page -*/ -#include -#include "dataflow_api.h" -#include "debug/dprint.h" // required in all kernels using DPRINT -#include "ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp" - -void kernel_main() { - // We are guranteed to be in 2D going to 2D - - const uint32_t src_addr = get_arg_val(0); - const uint32_t dst_addr = get_arg_val(1); - const uint32_t source_page_size_bytes = get_arg_val(2); - const uint32_t dest_page_size_bytes = get_arg_val(3); - // If DDR this is source_page_size_bytes + 64 (rounded up to next 64B), if L1 this is source_page_size_bytes + 16 - // (rounded up to next 16B) - const uint32_t source_read_size_bytes = get_arg_val(4); - const uint32_t read_start_page = get_arg_val(5); - const uint32_t read_end_page = get_arg_val(6); - const uint32_t write_start_page = get_arg_val(7); - const uint32_t write_start_offset = get_arg_val(8); - const uint32_t nop = get_arg_val(9); - const uint64_t ping_read_has_data = get_noc_addr(get_semaphore(get_arg_val(10))); - const uint64_t pong_read_has_data = get_noc_addr(get_semaphore(get_arg_val(11))); - volatile uint32_t* ping_buf_is_free = - reinterpret_cast(get_semaphore(get_arg_val(12))); - volatile uint32_t* pong_buf_is_free = - reinterpret_cast(get_semaphore(get_arg_val(13))); - constexpr bool tensor_is_dram = get_compile_time_arg_val(0) == 1; -#define src_aligned_to_64 get_compile_time_arg_val(1) == 1 -#define src_aligned_to_16 get_compile_time_arg_val(2) == 1 - constexpr uint32_t cb_id_in0 = get_compile_time_arg_val(3); - constexpr uint32_t cb_id_in1 = get_compile_time_arg_val(4); - constexpr uint32_t cb_id_in2 = get_compile_time_arg_val(4); - // Since we need to operate on a grid of cores but sometimes pages don't split properly, if nop then don't use this - // core - if (nop == 1) { - return; - } - - const InterleavedAddrGen s = {.bank_base_address = src_addr, .page_size = source_page_size_bytes}; - - uint32_t read_offset = 0; - uint32_t write_page = write_start_page; - uint32_t readable = 0; - uint32_t end_to_write = 0; - uint32_t transaction = 0; - uint32_t writable = dest_page_size_bytes - write_start_offset; - // cb_id_in0 is a CB source_read_size_bytes +4 page size, 1 page - // cb_id_in1 is a CB source_read_size_bytes +4 page size, 1 page - // cb_id_in1 is a CB dest_page_size_bytes + allignment_to_64 page size, 1 page - cb_reserve_back(cb_id_in0, 1); - cb_reserve_back(cb_id_in1, 1); - cb_reserve_back(cb_id_in2, 1); - const uint32_t source_buffer_ping = get_write_ptr(cb_id_in0); - const uint32_t source_buffer_pong = get_write_ptr(cb_id_in1); - const uint32_t dest_buffer = get_write_ptr(cb_id_in2); - cb_push_back(cb_id_in0, 1); - cb_push_back(cb_id_in1, 1); - cb_push_back(cb_id_in2, 1); - uint32_t source_buffer; - - volatile tt_l1_ptr std::uint32_t* read_offset_ptr_ping = - (volatile tt_l1_ptr uint32_t*)(source_buffer_ping + source_read_size_bytes); - volatile tt_l1_ptr std::uint32_t* read_offset_ptr_pong = - (volatile tt_l1_ptr uint32_t*)(source_buffer_pong + source_read_size_bytes); - bool is_ping = true; - bool first = true; - bool second = true; - bool third = true; - bool first_pong = true; - bool second_pong = true; - bool third_pong = true; - for (uint32_t i = read_start_page; i < read_end_page; i++) { - // Read from source - uint64_t src_noc_addr = s.get_noc_addr(i, 0); - if (is_ping) { - if (first) { - first = false; - WAYPOINT("FARW"); - } else if (second) { - second = false; - WAYPOINT("SARW"); - } else if (third) { - third = false; - WAYPOINT("TARW"); - } else { - 
WAYPOINT("ARW"); - } - source_buffer = source_buffer_ping; - noc_semaphore_wait(ping_buf_is_free, 1); - WAYPOINT("ARD"); - } else { - if (first_pong) { - first_pong = false; - WAYPOINT("FBRW"); - } else if (second_pong) { - second_pong = false; - WAYPOINT("SBRW"); - } else { - WAYPOINT("BRW"); - } - source_buffer = source_buffer_pong; - noc_semaphore_wait(pong_buf_is_free, 1); - WAYPOINT("BRD"); - } - -#if (src_aligned_to_64 || ((!tensor_is_dram) && src_aligned_to_16)) - // Aligned to 64 bytes or 16 bytes but L1 - noc_async_read(src_noc_addr, source_buffer, source_page_size_bytes); - read_offset = 0; -#elif (tensor_is_dram) - // DDR but not alligned to 64 (potentially also not alligned to 16) - noc_async_read(src_noc_addr & MASK_64, source_buffer, source_read_size_bytes); - read_offset = src_noc_addr & OFFSET_64; -#else - // L1 but not alligned to 16 - noc_async_read(src_noc_addr & MASK_16, source_buffer, source_read_size_bytes); - read_offset = src_noc_addr & OFFSET_16; -#endif - if (is_ping) { - *read_offset_ptr_ping = read_offset; - } else { - *read_offset_ptr_pong = read_offset; - } - noc_async_read_barrier(); - if (is_ping) { - noc_semaphore_inc(ping_read_has_data, 1); - } else { - noc_semaphore_inc(pong_read_has_data, 1); - } - } - return; -} diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved_writer.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved_writer.cpp deleted file mode 100644 index c2feffa6669..00000000000 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved_writer.cpp +++ /dev/null @@ -1,193 +0,0 @@ -// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -/* -NOTE: This function is an improvement on rm_reshape_interleaved.cpp but it has a bug causing a hang for some cases that -needs to be debugged first Function reads from RM and writes to RM - -Assumptions: - -Compile arguments -0. src0_is_dram: 1 if source is dram else 0 -1. read_size_is_pow2: 1 if read size is power of 2 else 0 -2. log_base_2_of_page_size: log base 2 of page size -3. write_size_is_pow2: 1 if write size is power of 2 else 0 -4. log_base_2_of_page_size: log base 2 of page size -5. needs_read_allignment: 1 if read needs allignment else 0 -//Needed if BRAM and page size is not multiple of 64 bytes - -Runtime arguments -0. src_addr: source address -1. dst_addr: destination address -2. source_page_size_bytes: source page size in bytes -3. dest_page_size_bytes: destination page size in bytes -4. source_read_size_bytes: source read size in bytes -5. read_start_page: read start page -6. read_end_page: read end page -7. 
write_start_page: write start page -*/ -#include -#include "dataflow_api.h" -#include "debug/dprint.h" // required in all kernels using DPRINT -#include "ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp" - -void kernel_main() { - // We are guranteed to be in 2D going to 2D - - const uint32_t src_addr = get_arg_val(0); - const uint32_t dst_addr = get_arg_val(1); - const uint32_t source_page_size_bytes = get_arg_val(2); - const uint32_t dest_page_size_bytes = get_arg_val(3); - // If DDR this is source_page_size_bytes + 64 (rounded up to next 64B), if L1 this is source_page_size_bytes + 16 - // (rounded up to next 16B) - const uint32_t source_read_size_bytes = get_arg_val(4); - const uint32_t read_start_page = get_arg_val(5); - const uint32_t read_end_page = get_arg_val(6); - const uint32_t write_start_page = get_arg_val(7); - const uint32_t write_start_offset = get_arg_val(8); - const uint32_t nop = get_arg_val(9); - volatile uint32_t* ping_read_has_data = - reinterpret_cast(get_semaphore(get_arg_val(10))); - volatile uint32_t* pong_read_has_data = - reinterpret_cast(get_semaphore(get_arg_val(11))); - const uint64_t ping_buf_is_free = get_noc_addr(get_semaphore(get_arg_val(12))); - const uint64_t pong_buf_is_free = get_noc_addr(get_semaphore(get_arg_val(13))); - constexpr bool tensor_is_dram = get_compile_time_arg_val(0) == 1; -#define src_aligned_to_64 get_compile_time_arg_val(1) == 1 -#define src_aligned_to_16 get_compile_time_arg_val(2) == 1 - constexpr uint32_t cb_id_in0 = get_compile_time_arg_val(3); - constexpr uint32_t cb_id_in1 = get_compile_time_arg_val(4); - constexpr uint32_t cb_id_in2 = get_compile_time_arg_val(4); - // Since we need to operate on a grid of cores but sometimes pages don't split properly, if nop then don't use this - // core - if (nop == 1) { - return; - } - - const InterleavedAddrGen d = {.bank_base_address = dst_addr, .page_size = dest_page_size_bytes}; - uint32_t read_offset = 0; - uint32_t write_page = write_start_page; - uint32_t readable = 0; - uint32_t end_to_write = 0; - uint32_t transaction = 0; - uint32_t writable = dest_page_size_bytes - write_start_offset; - // cb_id_in0 is a CB source_read_size_bytes +4 page size, 1 page - // cb_id_in1 is a CB source_read_size_bytes +4 page size, 1 page - // cb_id_in1 is a CB dest_page_size_bytes + allignment_to_64 page size, 1 page - cb_reserve_back(cb_id_in0, 1); - cb_reserve_back(cb_id_in1, 1); - cb_reserve_back(cb_id_in2, 1); - const uint32_t source_buffer_ping = get_write_ptr(cb_id_in0); - const uint32_t source_buffer_pong = get_write_ptr(cb_id_in1); - const uint32_t dest_buffer = get_write_ptr(cb_id_in2); - cb_push_back(cb_id_in0, 1); - cb_push_back(cb_id_in1, 1); - cb_push_back(cb_id_in2, 1); - uint32_t source_buffer; - uint64_t dst_noc_addr = get_noc_addr(write_page, d); - uint64_t write_offset = (dst_noc_addr & OFFSET_16) + write_start_offset; - uint64_t begin_write_offset = write_offset; - volatile tt_l1_ptr std::uint32_t* read_offset_ptr_ping = - (volatile tt_l1_ptr uint32_t*)(source_buffer_ping + source_read_size_bytes); - volatile tt_l1_ptr std::uint32_t* read_offset_ptr_pong = - (volatile tt_l1_ptr uint32_t*)(source_buffer_pong + source_read_size_bytes); - bool is_ping = true; - bool first = true; - bool second = true; - bool third = true; - bool first_pong = true; - bool second_pong = true; - for (uint32_t i = read_start_page; i < read_end_page; i++) { - if (is_ping) { - if (first) { - first = false; - WAYPOINT("FAWW"); - } else if (second) { - second = false; - WAYPOINT("SAWW"); - } 
else if (third) { - third = false; - WAYPOINT("TAWW"); - } else { - WAYPOINT("AWW"); - } - source_buffer = source_buffer_ping; - noc_semaphore_wait(ping_read_has_data, 1); - read_offset = *read_offset_ptr_ping; - WAYPOINT("AWD"); - } else { - if (first_pong) { - first_pong = false; - WAYPOINT("FBWW"); - } else if (second_pong) { - second_pong = false; - WAYPOINT("SBWW"); - } - - else { - WAYPOINT("BWW"); - } - source_buffer = source_buffer_pong; - noc_semaphore_wait(pong_read_has_data, 1); - read_offset = *read_offset_ptr_pong; - WAYPOINT("BWD"); - } - readable = source_page_size_bytes; - // Write to dest - while (readable > 0) { - noc_async_write_barrier(); - if (readable < writable) { - tt::data_movement::common::tt_memmove( - dest_buffer + write_offset, source_buffer + read_offset, readable); - if (is_ping) { - noc_semaphore_inc(ping_buf_is_free, 1); - } else { - noc_semaphore_inc(pong_buf_is_free, 1); - } - writable = writable - readable; - write_offset = write_offset + readable; - readable = 0; - end_to_write = end_to_write + readable; - if (i == read_end_page - 1) { - noc_async_write(dest_buffer + begin_write_offset, dst_noc_addr, end_to_write); - return; - } - } else if (readable == writable) { - tt::data_movement::common::tt_memmove( - dest_buffer + write_offset, source_buffer + read_offset, readable); - if (is_ping) { - noc_semaphore_inc(ping_buf_is_free, 1); - } else { - noc_semaphore_inc(pong_buf_is_free, 1); - } - noc_async_write(dest_buffer + begin_write_offset, dst_noc_addr, dest_page_size_bytes); - writable = dest_page_size_bytes; - readable = 0; - if (i == read_end_page - 1) { - return; - } - end_to_write = 0; - write_page++; - dst_noc_addr = get_noc_addr(write_page, d); - write_offset = dst_noc_addr & OFFSET_16; - begin_write_offset = write_offset; - } else { - // writable < readable - - tt::data_movement::common::tt_memmove( - dest_buffer + write_offset, source_buffer + read_offset, writable); - noc_async_write(dest_buffer + begin_write_offset, dst_noc_addr, dest_page_size_bytes); - end_to_write = 0; - readable = readable - writable; - read_offset = read_offset + writable; - write_page++; - dst_noc_addr = get_noc_addr(write_page, d); - write_offset = dst_noc_addr & OFFSET_16; - begin_write_offset = write_offset; - writable = dest_page_size_bytes; - } - } - } - return; -} From dad9a4fac4c809168fb2481834d2f0c845b49fa6 Mon Sep 17 00:00:00 2001 From: Juan Camilo Vega Date: Wed, 4 Dec 2024 20:53:21 +0000 Subject: [PATCH 28/31] #0: pre commit formatting change --- models/demos/vgg/tests/test_perf_vgg.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/models/demos/vgg/tests/test_perf_vgg.py b/models/demos/vgg/tests/test_perf_vgg.py index 6187dee4f7d..14d2fa3b587 100644 --- a/models/demos/vgg/tests/test_perf_vgg.py +++ b/models/demos/vgg/tests/test_perf_vgg.py @@ -137,12 +137,10 @@ def test_perf_device_bare_metal_vgg(batch_size, model_name): margin = 0.03 if model_name == "ttnn_vgg11": - expected_perf = 36 if is_grayskull() else 283.289 command = f"pytest tests/ttnn/integration_tests/vgg/test_ttnn_vgg11.py" else: expected_perf = 34 if is_grayskull() else 194.84 - command = f"pytest tests/ttnn/integration_tests/vgg/test_ttnn_vgg16.py" cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"] From 0ea118be3461c73b1df36a48f5ad48da11facb82 Mon Sep 17 00:00:00 2001 From: Juan Camilo Vega Date: Wed, 4 Dec 2024 22:13:51 +0000 Subject: [PATCH 29/31] #0 addressing artem PR review changes on 15572 PR --- tests/ttnn/unit_tests/test_reshape.py | 2 +- 
ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp | 2 +- .../operations/data_movement/reshape_view/reshape_common.hpp | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/ttnn/unit_tests/test_reshape.py b/tests/ttnn/unit_tests/test_reshape.py index bf7d1bf568d..f448624e19b 100644 --- a/tests/ttnn/unit_tests/test_reshape.py +++ b/tests/ttnn/unit_tests/test_reshape.py @@ -313,7 +313,7 @@ def test_reshape_tile_with_padding(input_shape, output_shape, layout, device): # issue 15048 -def test_broken_reshape(device): +def test_previously_failing_test(device): src_shape = (1, 56, 56, 64) target_shape = (1, 1, 56 * 56, 64) torch_input_tensor = torch.randn(src_shape, dtype=torch.bfloat16) diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp index 9b4be6f13b0..6af8646a007 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp @@ -312,7 +312,7 @@ ttnn::Shape shape_corrector(const ttnn::Tensor& tensor, const ttnn::Shape& shape uint32_t output_volume = 1; uint32_t inferred_dim = -1; for (uint32_t i=0; i< shape.rank(); i++) { - if (((int)(shape[i])) == -1) { + if ((static_cast(shape[i])) == -1) { if (inferred_dim != -1) { TT_FATAL(false, "Only one dimension can be inferred in reshape"); } diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape_common.hpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape_common.hpp index a39245856b0..c20674ff369 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape_common.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape_common.hpp @@ -2,4 +2,4 @@ // // SPDX-License-Identifier: Apache-2.0 -typedef std::variant PadValue; +using PadValue = std::variant; From 66a28c1b613d06233db1db02488825be779d9d5f Mon Sep 17 00:00:00 2001 From: Juan Camilo Vega Date: Wed, 4 Dec 2024 22:19:30 +0000 Subject: [PATCH 30/31] #15269: updating vgg device targets --- models/demos/vgg/tests/test_perf_vgg.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/models/demos/vgg/tests/test_perf_vgg.py b/models/demos/vgg/tests/test_perf_vgg.py index 14d2fa3b587..4a74e17bd95 100644 --- a/models/demos/vgg/tests/test_perf_vgg.py +++ b/models/demos/vgg/tests/test_perf_vgg.py @@ -137,10 +137,10 @@ def test_perf_device_bare_metal_vgg(batch_size, model_name): margin = 0.03 if model_name == "ttnn_vgg11": - expected_perf = 36 if is_grayskull() else 283.289 + expected_perf = 36 if is_grayskull() else 104 command = f"pytest tests/ttnn/integration_tests/vgg/test_ttnn_vgg11.py" else: - expected_perf = 34 if is_grayskull() else 194.84 + expected_perf = 34 if is_grayskull() else 90 command = f"pytest tests/ttnn/integration_tests/vgg/test_ttnn_vgg16.py" cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"] From 0e411b0f94383e0d25f87aa71e80a14ead2dcf01 Mon Sep 17 00:00:00 2001 From: Juan Camilo Vega Date: Wed, 4 Dec 2024 22:36:46 +0000 Subject: [PATCH 31/31] #0: addressing austin PR review changes on 15572 PR --- .../data_movement/common/kernels/common.hpp | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp b/ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp index dd302e8628f..6ab3a58f5cb 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp +++ 
b/ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp @@ -15,14 +15,12 @@ namespace tt::data_movement::common { -#define max_packet_size 8192 - template FORCE_INLINE void enhanced_noc_async_read( const uint64_t src_noc_addr, const uint32_t dst_l1_addr, const uint32_t bytes) { // If you do not know the max_transfer_size at compile time write 0 to it. // only reads is true if we ONLY use noc_async_read and all calls to tt_memmove have use_read_datamover as True - if constexpr (((max_transfer_size < max_packet_size) && (max_transfer_size != 0)) || only_reads) { + if constexpr (((max_transfer_size < NOC_MAX_BURST_SIZE) && (max_transfer_size != 0)) || only_reads) { noc_async_read_one_packet(src_noc_addr, dst_l1_addr, bytes); } else { noc_async_read(src_noc_addr, dst_l1_addr, bytes); @@ -34,21 +32,22 @@ FORCE_INLINE void enhanced_noc_async_write( const uint32_t src_l1_addr, const uint64_t dst_noc_addr, const uint32_t bytes) { // If you do not know the max_transfer_size at compile time write 0 to it. // only writes is true if we ONLY use noc_async_read and all calls to tt_memmove have use_read_datamover as False - if constexpr (((max_transfer_size < max_packet_size) && (max_transfer_size != 0)) || only_writes) { + if constexpr (((max_transfer_size < NOC_MAX_BURST_SIZE) && (max_transfer_size != 0)) || only_writes) { noc_async_write_one_packet(src_l1_addr, dst_noc_addr, bytes); } else { noc_async_write(src_l1_addr, dst_noc_addr, bytes); } } -template +template FORCE_INLINE void tt_memmove(const uint32_t dst_l1_addr, const uint32_t src_l1_addr, const uint32_t bytes) { - //Function performs a memory copy between two l1 addresses in the local core - //Uses noc_async_read when possible to copy the data over - //Set guaranteed 16B alligned to true if the source and destination are externally guaranteed to be 16B alligned (dangerous) - //Set copy_async to true if you wish to perform the operation asynchronously, in this case you can add a noc_async_read_barrier to synchronize later + // Function performs a memory copy between two l1 addresses in the local core + // Uses noc_async_read when possible to copy the data over + // Set guaranteed 16B aligned to true if the source and destination are externally guaranteed to be 16B aligned + // (dangerous) Set copy_async to true if you wish to perform the operation asynchronously, in this case you can add + // a noc_async_read_barrier to synchronize later if constexpr (use_read_datamover) { - if constexpr (guaranteed_16B_alligned) { + if constexpr (guaranteed_16B_aligned) { enhanced_noc_async_read(get_noc_addr(src_l1_addr), dst_l1_addr, bytes); if constexpr (!copy_async) { noc_async_read_barrier(); @@ -64,7 +63,7 @@ FORCE_INLINE void tt_memmove(const uint32_t dst_l1_addr, const uint32_t src_l1_a } } } else { - if constexpr (guaranteed_16B_alligned) { + if constexpr (guaranteed_16B_aligned) { enhanced_noc_async_write(src_l1_addr, get_noc_addr(dst_l1_addr), bytes); if constexpr (!copy_async) { noc_async_write_barrier();
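
A note on the work-splitting used throughout the host prep in the patches above: every core is handed a contiguous run of source rows, and its first destination page (plus, after patch 21, a byte offset inside that page) follows from how many source bytes precede its first row. The sketch below is a simplified host-side illustration of that arithmetic; the names (CoreSplit, split_rows_across_cores) are invented for this note, and it omits the extra rounding the earlier revisions apply so that each core starts exactly on a destination-page boundary.

#include <algorithm>
#include <cstdint>
#include <vector>

// Illustrative only: mirrors the responsibility / write_jump / offset_jump
// bookkeeping in reshape_rm_host_prep.cpp, using hypothetical names.
struct CoreSplit {
    uint32_t read_start_page;    // first source row this core reads
    uint32_t read_end_page;      // one past the last source row it reads
    uint32_t write_start_page;   // destination row its first byte lands in
    uint32_t write_start_offset; // byte offset inside that destination row
    bool nop;                    // true if this core has no work
};

std::vector<CoreSplit> split_rows_across_cores(
    uint32_t num_rows, uint32_t num_cores, uint32_t src_row_bytes, uint32_t dst_row_bytes) {
    // Each core takes a roughly equal share of source rows (ceiling division).
    const uint32_t responsibility = ((num_rows - 1) / num_cores) + 1;
    std::vector<CoreSplit> splits;
    uint32_t read_start = 0;
    for (uint32_t c = 0; c < num_cores; ++c) {
        CoreSplit s{};
        s.nop = (read_start >= num_rows);
        s.read_start_page = read_start;
        s.read_end_page = std::min(read_start + responsibility, num_rows);
        // A core that starts after k whole source rows starts at byte k * src_row_bytes
        // of the flattened tensor; that byte maps to a destination row and an offset.
        const uint64_t start_byte = uint64_t(read_start) * src_row_bytes;
        s.write_start_page = uint32_t(start_byte / dst_row_bytes);
        s.write_start_offset = uint32_t(start_byte % dst_row_bytes);
        splits.push_back(s);
        read_start = s.read_end_page;
    }
    return splits;
}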
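
The kernel's copy loop is easier to follow with the NoC, circular-buffer and alignment handling stripped away: it tracks how many bytes are still readable from the current source row and still writable into the current destination row, and moves the smaller of the two each iteration. Below is a plain host-side sketch of that state machine with std::memcpy standing in for the L1/NoC transfers; it illustrates the algorithm only and is not the device kernel.

#include <algorithm>
#include <cstdint>
#include <cstring>

// Illustrative only: the readable/writable state machine from
// rm_reshape_interleaved.cpp, with memcpy replacing the NoC transfers.
// dst must hold num_src_rows * src_row_bytes bytes, which by construction
// equals the destination row count times dst_row_bytes.
void repack_rows(const uint8_t* src, uint8_t* dst,
                 uint32_t num_src_rows, uint32_t src_row_bytes, uint32_t dst_row_bytes) {
    uint32_t write_page = 0;
    uint32_t write_offset = 0;          // bytes already written into the current dst row
    uint32_t writable = dst_row_bytes;  // bytes still free in the current dst row
    for (uint32_t i = 0; i < num_src_rows; ++i) {
        const uint8_t* row = src + uint64_t(i) * src_row_bytes;
        uint32_t readable = src_row_bytes;  // bytes not yet consumed from this src row
        uint32_t read_offset = 0;
        while (readable > 0) {
            const uint32_t chunk = std::min(readable, writable);
            std::memcpy(dst + uint64_t(write_page) * dst_row_bytes + write_offset,
                        row + read_offset, chunk);
            readable -= chunk;
            read_offset += chunk;
            writable -= chunk;
            write_offset += chunk;
            if (writable == 0) {  // destination row is full: advance to the next one
                ++write_page;
                write_offset = 0;
                writable = dst_row_bytes;
            }
        }
    }
}

The three branches in the real kernel (readable < writable, readable == writable, readable > writable) collapse to the single min() here only because this sketch does not need to decide when a staged destination page is complete enough to flush over the NoC.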
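
Patch 20's power-of-two optimization is decided entirely on the host: when a page size is a power of two of at least 32 bytes, the kernel is compiled to use InterleavedPow2AddrGen with the log2 of the page size instead of the generic InterleavedAddrGen. A minimal sketch of that host-side decision follows; is_pow2_at_least_32, log2_u32 and make_page_args are stand-ins invented for this note, not the helpers the patch actually calls (the patch uses tt::tt_metal::is_power_of_two_at_least_32 and std::log2).

#include <cstdint>

// Illustrative only: how the host prep could derive the pow2-related compile-time args.
static bool is_pow2_at_least_32(uint32_t v) { return v >= 32 && (v & (v - 1)) == 0; }

static uint32_t log2_u32(uint32_t v) {
    uint32_t r = 0;
    while (v > 1) { v >>= 1; ++r; }
    return r;
}

struct PageArgs {
    uint32_t page_size_bytes;
    uint32_t is_pow2;    // 1 -> kernel instantiates InterleavedPow2AddrGen
    uint32_t log2_page;  // only meaningful when is_pow2 == 1
};

PageArgs make_page_args(uint32_t page_size_bytes) {
    PageArgs a{page_size_bytes, 0, 0};
    if (is_pow2_at_least_32(page_size_bytes)) {
        a.is_pow2 = 1;
        a.log2_page = log2_u32(page_size_bytes);
    }
    return a;
}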