Rewrite of Reshape OP from scratch #15572
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,9 +6,45 @@ | |
// It's best to copy and paste the functions in rather than include the header as code size will likely explode | ||
// Best to separate in to cpp/hpp at some point to avoid the code size explosion but need to figure out the linking | ||
// issues | ||
#include <stdio.h> | ||
#include <cstring> | ||
#define MASK_64 0xFFFFFFFFFFFFFFC0 | ||
#define OFFSET_64 0x000000000000003F | ||
#define MASK_16 0xFFFFFFFFFFFFFFF0 | ||
#define OFFSET_16 0x000000000000000F | ||
|
||
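The four constants above are the mask/offset pairs for 64-byte and 16-byte alignment. As a minimal sketch (the helper names `base_mask`, `offset_mask`, and `split_address` are hypothetical and not part of this PR), both halves of each pair can be derived from a single power-of-two alignment value, so only that one value would need to change for a different architecture:

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helpers, not part of the PR: derive both masks from one
// power-of-two alignment value instead of hardcoding four constants.
constexpr uint64_t offset_mask(uint64_t alignment) { return alignment - 1; }
constexpr uint64_t base_mask(uint64_t alignment) { return ~(alignment - 1); }

// The derived values agree with the hardcoded constants in the diff.
static_assert(base_mask(64) == 0xFFFFFFFFFFFFFFC0ull, "MASK_64");
static_assert(offset_mask(64) == 0x000000000000003Full, "OFFSET_64");
static_assert(base_mask(16) == 0xFFFFFFFFFFFFFFF0ull, "MASK_16");
static_assert(offset_mask(16) == 0x000000000000000Full, "OFFSET_16");

struct SplitAddr {
    uint64_t base;    // address rounded down to the alignment boundary
    uint64_t offset;  // bytes between the boundary and the address
};

// Split an address into its aligned base and the offset within that block.
inline SplitAddr split_address(uint64_t addr, uint64_t alignment) {
    assert(alignment != 0 && (alignment & (alignment - 1)) == 0);  // power of two
    return {addr & base_mask(alignment), addr & offset_mask(alignment)};
}
```

For example, `split_address(0x1234, 64)` gives base `0x1200` and offset `0x34`.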
namespace tt::data_movement::common { | ||
|
||
template <bool guaranteed_16B_alligned, bool copy_async> | ||
FORCE_INLINE | ||
void tt_memmove ( | ||
This is great. Can you announce this utility function in TT developers? This might be useful for others. |
||
const uint32_t dst_l1_addr, | ||
const uint64_t src_l1_addr, | ||
const uint32_t bytes) | ||
{ | ||
//Function performs a memory copy between two L1 addresses in the local core | ||
//Uses noc_async_read when possible to copy the data over | ||
//Set guaranteed_16B_alligned to true if the source and destination are externally guaranteed to be 16B aligned (dangerous) | ||
You should also not be hardcoding alignment values like 16 and 64, since there's no guarantee they're always that value across arches. These can be queried and passed from host as compile time args, or accessed in kernels, e.g. L1_ALIGNMENT.
This is a bit annoying to fix, as I am also using MASK and OFFSET values that are hardcoded to match alignments of 16 and 64 respectively, and they are also used in other files, as this is the common code for all data mover ops. I can't take the value as a compile time arg as this is in common. However, because it is in common, I think updating it should a new arch use a new alignment is not very difficult.
Come to think of it, I think the cleanest fix would be to define a DDR_ALIGNMENT at the same location L1_ALIGNMENT is defined, set automatically at compile time depending on the ARCH type. The OPs would then use that constant value instead, which is auto updated to match the architecture. However, that would not be the solution for this PR.
Would it be acceptable to let this merge in as is for now, and then I will post an issue where I can fix it in a cleaner way? This is just in the interest of getting this in as soon as possible, since the change would alter multiple files and require a new function, and I am not comfortable doing that without re-running the pipelines that take hours.
DRAM_ALIGNMENT is also already defined (is this what you meant by DDR_ALIGNMENT?). It's fine if you want to address later. I'm not blocking the PR on any of these comments but wanted to make sure this was considered, as there is currently work being done regarding what alignments are being used.
Oh cool, good to know. Yeah, will do a subsequent PR since I do want to avoid having to re-visit alignment when the next arch comes in. |
||
//Set copy_async to true to perform the operation asynchronously; in that case, add a noc_async_read_barrier later to synchronize | ||
if constexpr (guaranteed_16B_alligned) | ||
{ | ||
noc_async_read(get_noc_addr(src_l1_addr),dst_l1_addr, bytes); | ||
if constexpr (!copy_async) {noc_async_read_barrier();} | ||
} | ||
else | ||
{ | ||
if ((dst_l1_addr&OFFSET_16) == (src_l1_addr&OFFSET_16)) | ||
{ | ||
noc_async_read(get_noc_addr(src_l1_addr),dst_l1_addr, bytes); | ||
if constexpr (!copy_async) {noc_async_read_barrier();} | ||
} | ||
else | ||
{ | ||
memmove((void *)(dst_l1_addr), (void *)(src_l1_addr), (size_t) (bytes)); | ||
} | ||
} | ||
} | ||
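The branch structure of `tt_memmove` can be summarized as a small decision rule. The sketch below is a host-side model only: `CopyPath` and `choose_copy_path` are hypothetical names, and no NoC call is made. It shows that the NoC read is chosen when alignment is asserted by the caller or when source and destination share the same offset within a 16-byte block, and that plain `memmove` is the fallback otherwise. Note that on the `memmove` path the copy is synchronous regardless of `copy_async`.

```cpp
#include <cstdint>

// Hypothetical host-side model of the dispatch in tt_memmove.
enum class CopyPath { NocRead, Memmove };

inline CopyPath choose_copy_path(bool guaranteed_aligned, uint32_t dst_addr, uint32_t src_addr) {
    if (guaranteed_aligned) {
        // Caller asserts 16B alignment of both addresses; no runtime check.
        return CopyPath::NocRead;
    }
    // Same low 4 bits means both addresses sit at the same offset inside
    // their 16B blocks, which is the condition the diff tests with OFFSET_16.
    const bool same_offset = (dst_addr & 0xFu) == (src_addr & 0xFu);
    return same_offset ? CopyPath::NocRead : CopyPath::Memmove;
}
```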
|
||
// this function is useful for converting bfloat16 values to float32 | ||
FORCE_INLINE float bfloat16_to_float32(uint16_t bfloat16_data) { | ||
uint32_t bits = static_cast<uint32_t>(bfloat16_data) << 16; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
|
||
/* | ||
Function reads from RM and writes to RM | ||
|
||
Assumptions: | ||
|
||
Compile arguments | ||
0. src0_is_dram: 1 if source is DRAM else 0 | ||
1. src_aligned_to_64: 1 if source page size is a multiple of 64 bytes else 0 | ||
2. src_aligned_to_16: 1 if source page size is a multiple of 16 bytes else 0 | ||
3. dst_aligned_to_16: 1 if destination page size is a multiple of 16 bytes else 0 | ||
|
||
Runtime arguments | ||
0. src_addr: source address | ||
1. dst_addr: destination address | ||
2. source_page_size_bytes: source page size in bytes | ||
3. dest_page_size_bytes: destination page size in bytes | ||
4. source_read_size_bytes: source read size in bytes | ||
5. read_start_page: read start page | ||
6. read_end_page: read end page | ||
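7. write_start_page: write start page | ||
8. cb_id_in0: index of the circular buffer used to stage source pages | ||
9. cb_id_in1: index of the circular buffer used to stage destination pages | ||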
*/ | ||
#include <stdint.h> | ||
#include "dataflow_api.h" | ||
#include "debug/dprint.h" // required in all kernels using DPRINT | ||
#include "ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp" | ||
|
||
void kernel_main() { | ||
//We are guaranteed to be in 2D going to 2D | ||
|
||
const uint32_t src_addr = get_arg_val<uint32_t>(0); | ||
const uint32_t dst_addr = get_arg_val<uint32_t>(1); | ||
const uint32_t source_page_size_bytes = get_arg_val<uint32_t>(2); | ||
const uint32_t dest_page_size_bytes = get_arg_val<uint32_t>(3); | ||
//If DDR this is source_page_size_bytes + 64 (rounded up to next 64B), if L1 this is source_page_size_bytes + 16 (rounded up to next 16B) | ||
const uint32_t source_read_size_bytes = get_arg_val<uint32_t>(4); | ||
const uint32_t read_start_page = get_arg_val<uint32_t>(5); | ||
const uint32_t read_end_page = get_arg_val<uint32_t>(6); | ||
const uint32_t write_start_page = get_arg_val<uint32_t>(7); | ||
//cb_id_in0 is a circular buffer with 1 source_page_size_bytes page if no alignment needed | ||
//source_read_size_bytes otherwise | ||
const uint32_t cb_id_in0 = get_arg_val<uint32_t>(8); | ||
//cb_id_in1 is a circular buffer with 1 dest_page_size_bytes+16 (rounded up to next 64B) page | ||
const uint32_t cb_id_in1 = get_arg_val<uint32_t>(9); | ||
|
||
|
||
constexpr bool tensor_is_dram = get_compile_time_arg_val(0) == 1; | ||
#define src_aligned_to_64 get_compile_time_arg_val<uint32_t>(1) == 1 | ||
#define src_aligned_to_16 get_compile_time_arg_val<uint32_t>(2) == 1 | ||
#define dst_aligned_to_16 get_compile_time_arg_val<uint32_t>(3) == 1 | ||
|
||
|
||
const InterleavedAddrGen<tensor_is_dram> s = { | ||
.bank_base_address = src_addr, | ||
.page_size = source_page_size_bytes | ||
}; | ||
|
||
const InterleavedAddrGen<tensor_is_dram> d = { | ||
.bank_base_address = dst_addr, | ||
.page_size = dest_page_size_bytes | ||
}; | ||
|
||
|
||
uint32_t read_offset = 0; | ||
uint32_t write_page = write_start_page; | ||
uint32_t readable = 0; | ||
uint32_t transaction = 0; | ||
uint32_t writable = dest_page_size_bytes; | ||
//cb_id_in0 is a CB source_read_size_bytes page size, 1 page | ||
//cb_id_in1 is a CB dest_page_size_bytes + alignment_to_64 page size, 1 page | ||
cb_reserve_back(cb_id_in0, 1); | ||
cb_reserve_back(cb_id_in1, 1); | ||
const uint32_t source_buffer = get_write_ptr(cb_id_in0); | ||
const uint32_t dest_buffer = get_write_ptr(cb_id_in1); | ||
|
||
uint64_t dst_noc_addr = get_noc_addr(write_page, d); | ||
#if (dst_aligned_to_16) | ||
uint32_t write_offset = 0; | ||
#else | ||
uint32_t write_offset = dst_noc_addr&OFFSET_16; | ||
uint32_t begin_write_offset = write_offset; | ||
#endif | ||
for (uint32_t i = read_start_page; i <= read_end_page; i++) { | ||
//Read from source | ||
uint64_t src_noc_addr = s.get_noc_addr(i,0); | ||
|
||
#if (src_aligned_to_64 || ((!tensor_is_dram) && src_aligned_to_16)) | ||
//Aligned to 64 bytes or 16 bytes but L1 | ||
noc_async_read(src_noc_addr, source_buffer, source_page_size_bytes); | ||
read_offset = 0; | ||
#elif (tensor_is_dram) | ||
//DDR but not aligned to 64 (potentially also not aligned to 16) | ||
noc_async_read(src_noc_addr&MASK_64, source_buffer, source_read_size_bytes); | ||
read_offset = src_noc_addr&OFFSET_64; | ||
#else | ||
//L1 but not aligned to 16 | ||
noc_async_read(src_noc_addr&MASK_16, source_buffer, source_read_size_bytes); | ||
read_offset = src_noc_addr&OFFSET_16; | ||
#endif | ||
readable = source_page_size_bytes; | ||
noc_async_read_barrier(); | ||
|
||
//Write to dest | ||
while (readable > 0) | ||
{ | ||
noc_async_write_barrier(); | ||
if (readable < writable) | ||
{ | ||
tt::data_movement::common::tt_memmove<false,true>(dest_buffer+write_offset, source_buffer + read_offset, readable); | ||
writable = writable -readable; | ||
write_offset = write_offset + readable; | ||
readable = 0; | ||
} | ||
else if (readable == writable) | ||
{ | ||
tt::data_movement::common::tt_memmove<false,false>(dest_buffer+write_offset, source_buffer + read_offset, readable); | ||
#if ((dst_aligned_to_16)) | ||
noc_async_write(dest_buffer,dst_noc_addr, dest_page_size_bytes); | ||
#else | ||
noc_async_write(dest_buffer+begin_write_offset,dst_noc_addr, dest_page_size_bytes); | ||
#endif | ||
writable = dest_page_size_bytes; | ||
readable = 0; | ||
if (i == read_end_page-1) | ||
{ | ||
cb_push_back(cb_id_in0, 1); | ||
cb_push_back(cb_id_in1, 1); | ||
return; | ||
} | ||
write_page++; | ||
dst_noc_addr = get_noc_addr(write_page, d); | ||
#if ((dst_aligned_to_16)) | ||
write_offset=0; | ||
#else | ||
write_offset = dst_noc_addr&OFFSET_16; | ||
begin_write_offset = write_offset; | ||
#endif | ||
} | ||
else | ||
{ | ||
//writable < readable | ||
|
||
tt::data_movement::common::tt_memmove<false,false>(dest_buffer+write_offset, source_buffer + read_offset, writable); | ||
#if ((dst_aligned_to_16)) | ||
noc_async_write(dest_buffer,dst_noc_addr, dest_page_size_bytes); | ||
#else | ||
noc_async_write(dest_buffer+begin_write_offset,dst_noc_addr, dest_page_size_bytes); | ||
#endif | ||
readable = readable - writable; | ||
read_offset = read_offset + writable; | ||
write_page++; | ||
dst_noc_addr = get_noc_addr(write_page, d); | ||
#if ((dst_aligned_to_16)) | ||
write_offset=0; | ||
#else | ||
write_offset = dst_noc_addr&OFFSET_16; | ||
begin_write_offset = write_offset; | ||
#endif | ||
writable = dest_page_size_bytes; | ||
} | ||
} | ||
} | ||
cb_push_back(cb_id_in0, 1); | ||
cb_push_back(cb_id_in1, 1); | ||
return; | ||
} |
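The `readable`/`writable` bookkeeping in the kernel above re-chunks source rows into destination rows of a different width. A host-side sketch of that repacking, with no NoC calls or alignment handling, is given here. The function name `repack_rows` is hypothetical, and the kernel's three branches (`readable < writable`, `==`, `>`) are folded into one step using `std::min`, which is a simplification for illustration rather than the PR's structure. It assumes the total byte count is a multiple of the destination page size, as it is for a valid reshape.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical host-side model of the kernel's repacking loop.
inline std::vector<std::vector<uint8_t>> repack_rows(
    const std::vector<std::vector<uint8_t>>& src_pages, std::size_t dst_page_bytes) {
    std::vector<std::vector<uint8_t>> dst_pages;
    if (dst_page_bytes == 0) {
        return dst_pages;  // nothing sensible to produce
    }
    std::vector<uint8_t> staging(dst_page_bytes);  // plays the role of dest_buffer
    std::size_t write_offset = 0;
    std::size_t writable = dst_page_bytes;  // bytes still free in the staging page
    for (const auto& page : src_pages) {
        std::size_t read_offset = 0;
        std::size_t readable = page.size();  // bytes left in the current source page
        while (readable > 0) {
            const std::size_t n = std::min(readable, writable);
            std::memcpy(staging.data() + write_offset, page.data() + read_offset, n);
            read_offset += n;
            write_offset += n;
            readable -= n;
            writable -= n;
            if (writable == 0) {
                dst_pages.push_back(staging);  // stands in for noc_async_write
                write_offset = 0;
                writable = dst_page_bytes;
            }
        }
    }
    return dst_pages;
}
```

For instance, two 6-byte source rows holding bytes 0..11 repacked with `dst_page_bytes = 4` give three rows: `{0,1,2,3}`, `{4,5,6,7}`, `{8,9,10,11}`.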
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include <math.h> | ||
|
||
#include "ttnn/operations/cb_utils.hpp" | ||
#include "ttnn/operations/math.hpp" | ||
#include "ttnn/operation.hpp" | ||
#include "ttnn/operations/core/work_split/work_split_tilize.hpp" | ||
#include "tt_metal/common/constants.hpp" | ||
#include "tt_metal/detail/util.hpp" | ||
#include "tt_metal/host_api.hpp" | ||
#include "ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape_common.hpp" | ||
|
||
#include <optional> | ||
#include <variant> | ||
|
||
#include "ttnn/tensor/tensor.hpp" | ||
#include "ttnn/core.hpp" | ||
#include "ttnn/device_operation.hpp" | ||
#include "ttnn/types.hpp" | ||
#include "ttnn/decorators.hpp" | ||
|
||
#define MASK_64 0xFFFFFFFFFFFFFFC0 | ||
#define OFFSET_64 0x000000000000003F | ||
#define MASK_16 0xFFFFFFFFFFFFFFF0 | ||
#define OFFSET_16 0x000000000000000F | ||
|
||
namespace ttnn::operations::data_movement::rm_reshape{ | ||
|
||
operation::ProgramWithCallbacks rm_reshape_preparer(const Tensor& input, const Tensor& output) | ||
{ | ||
tt::tt_metal::Program program = tt::tt_metal::CreateProgram(); | ||
//get datum size | ||
tt::DataFormat cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(input.get_dtype()); | ||
const uint32_t data_size = input.element_size(); | ||
CoreRange core({0, 0}, {0, 0}); | ||
|
||
tt::tt_metal::Device *device = input.device(); | ||
ttnn::Shape input_log_shape = ttnn::Shape(input.get_logical_shape().view()); | ||
ttnn::Shape output_log_shape = ttnn::Shape(output.get_logical_shape().view()); | ||
tt::log_debug("row major reshape"); | ||
tt::log_debug("input shape: {}", input_log_shape); | ||
tt::log_debug("output shape: {}", output_log_shape); | ||
tt::log_debug("data size: {}", data_size); | ||
uint32_t source_page_size_bytes = input_log_shape[-1] * data_size; | ||
uint32_t dest_page_size_bytes = output_log_shape[-1] * data_size; | ||
uint32_t source_read_size_bytes = ((source_page_size_bytes-1) & MASK_64) + 128; | ||
uint32_t read_start_page = 0; | ||
uint32_t read_end_page = input_log_shape[-2]; | ||
uint32_t write_start_page = 0; | ||
tt::tt_metal::Buffer *src_buffer = input.buffer(); | ||
tt::tt_metal::Buffer *dst_buffer = output.buffer(); | ||
TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); | ||
|
||
const uint32_t cb_size0 = source_read_size_bytes; | ||
const uint32_t cb_size1 = ((dest_page_size_bytes-1)&MASK_64) + 80; | ||
|
||
uint32_t src0_cb_index = 0; | ||
uint32_t src1_cb_index = 1; | ||
tt::tt_metal::CircularBufferConfig cb_src0_config = tt::tt_metal::CircularBufferConfig(cb_size0*2, {{src0_cb_index, cb_data_format}}) | ||
.set_page_size(src0_cb_index, cb_size0); | ||
auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, core, cb_src0_config); | ||
tt::tt_metal::CircularBufferConfig cb_src1_config = tt::tt_metal::CircularBufferConfig(cb_size1, {{src1_cb_index, cb_data_format}}) | ||
.set_page_size(src1_cb_index, cb_size1); | ||
auto cb_src1 = tt::tt_metal::CreateCircularBuffer(program, core, cb_src1_config); | ||
//set the runtime args | ||
//set the compile time args | ||
uint32_t src0_is_dram = src_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; | ||
std::vector<uint32_t> compile_time_args = { | ||
(std::uint32_t) src0_is_dram, | ||
(std::uint32_t) (source_page_size_bytes%64==0) ? 1 : 0, | ||
(std::uint32_t) (source_page_size_bytes%16==0) ? 1 : 0, | ||
(std::uint32_t) (dest_page_size_bytes%16==0) ? 1 : 0, | ||
}; | ||
|
||
tt::tt_metal::KernelHandle reader_kernel_id = tt::tt_metal::CreateKernel( | ||
program, | ||
"ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp", | ||
core, | ||
tt::tt_metal::ReaderDataMovementConfig(compile_time_args)); | ||
std::vector<uint32_t> reader_runtime_args = { | ||
src_buffer->address(), | ||
dst_buffer->address(), | ||
source_page_size_bytes, | ||
dest_page_size_bytes, | ||
source_read_size_bytes, | ||
read_start_page, | ||
read_end_page, | ||
write_start_page, | ||
src0_cb_index, | ||
src1_cb_index | ||
}; | ||
tt::tt_metal::SetRuntimeArgs( | ||
program, | ||
reader_kernel_id, | ||
core, | ||
reader_runtime_args | ||
); | ||
return {.program=std::move(program)}; | ||
} | ||
}; // namespace ttnn::operations::data_movement::rm_reshape |
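The buffer sizes `source_read_size_bytes` and `cb_size1` in the host code add fixed padding (128 and 80 bytes) on top of a rounded-down page size. The sketch below repeats that arithmetic with hypothetical helper names and checks the property the kernel appears to rely on: a page that starts at the worst-case offset (63 bytes into a 64-byte block for reads, 15 bytes into a 16-byte block for writes) still fits in the buffer. The 32-bit mask `kMask64` is an assumption standing in for the low 32 bits of `MASK_64`.

```cpp
#include <cstdint>

// Assumed 32-bit equivalent of MASK_64 from the diff.
constexpr uint32_t kMask64 = 0xFFFFFFC0u;

// Same arithmetic as source_read_size_bytes in the host code.
constexpr uint32_t source_read_size(uint32_t page_bytes) {
    return ((page_bytes - 1) & kMask64) + 128;
}

// Same arithmetic as cb_size1 in the host code.
constexpr uint32_t dest_cb_size(uint32_t page_bytes) {
    return ((page_bytes - 1) & kMask64) + 80;
}

// Worst case for reads: the page starts 63 bytes into a 64B block.
constexpr bool read_window_fits(uint32_t page_bytes) {
    return 63 + page_bytes <= source_read_size(page_bytes);
}

// Worst case for writes: the page starts 15 bytes into a 16B block.
constexpr bool write_window_fits(uint32_t page_bytes) {
    return 15 + page_bytes <= dest_cb_size(page_bytes);
}

static_assert(source_read_size(64) == 128, "(63 & ~63) + 128");
static_assert(source_read_size(65) == 192, "(64 & ~63) + 128");
static_assert(dest_cb_size(100) == 144, "(99 & ~63) + 80");
```

Since `(p - 1) & ~63` is at least `p - 64`, the read buffer is at least `p + 64` bytes and the write buffer at least `p + 16` bytes, which covers both worst cases for any nonzero page size that does not overflow 32 bits.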
this works now, right? "test_broken_reshape" is such a confusing name :D
I just copied the code from the issue, I am going to rename the test to test_previously_failing_test