Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#15572: Rewrite of Reshape OP from scratch #15572

Merged
merged 37 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
5d343d4
#15269: reshape fully on device now
jvegaTT Nov 28, 2024
e2bbda5
#15558 edited comment to mention this issue
jvegaTT Nov 28, 2024
43b0bd2
#0: move tt_memmove to common library and ensure tilize/untilize is o…
jvegaTT Nov 29, 2024
8a94c1b
#0: added corrector for implied shape dimensions
jvegaTT Nov 29, 2024
4fd8c45
#13889: Added test to prove this issue is resolved
jvegaTT Nov 29, 2024
285d1fa
#12153: Adding test to verify that issue is resolved
jvegaTT Nov 29, 2024
7d40412
#15048: being more careful about bandaid for issues #15137 and #13338
jvegaTT Nov 29, 2024
7335c11
#14676: Adding test to verify that this issue is resolved
jvegaTT Nov 29, 2024
6110d47
#0: adding libraries for memmove to common
jvegaTT Nov 29, 2024
cbc2830
#14513: Adding test to prove issue is resolved
jvegaTT Nov 29, 2024
98c4b01
Merge branch 'main' into jvega/reshape_rm_on_device
jvegaTT Nov 29, 2024
055dc0f
#15269: added multi-core support
jvegaTT Nov 29, 2024
1d058c7
Merge branch 'jvega/reshape_rm_on_device' of github.com:tenstorrent/t…
jvegaTT Nov 29, 2024
fac56d8
#0: addressing PR comments
jvegaTT Nov 29, 2024
90b6255
Merge branch 'main' into jvega/reshape_rm_on_device
jvegaTT Nov 29, 2024
7ff95d2
#0: small oops
jvegaTT Nov 29, 2024
2edceb5
#15269: Host code optimizations
jvegaTT Dec 2, 2024
ad28b55
#15269: Move compute buffers to compile time
jvegaTT Dec 2, 2024
5501324
#15269: adding override_runtime_args_callback
jvegaTT Dec 2, 2024
109615f
#15269: improve the tt_memmove to use read or write noc as per user n…
jvegaTT Dec 2, 2024
631b1c5
#15269: improve the tt_memmove to use read or write datamover
jvegaTT Dec 2, 2024
10de9a2
#15269: add broken multi risk code
jvegaTT Dec 2, 2024
b709d9c
#15269: employing pow2 optimization
jvegaTT Dec 2, 2024
4fa2bf3
#15269: Added optimization to do one less copy on aligned only transfers
jvegaTT Dec 3, 2024
72f2d5c
#15269: added packet form of noc_async_read and write
jvegaTT Dec 3, 2024
09482f3
#15269: further small optimizations
jvegaTT Dec 3, 2024
3e30253
Merge branch 'main' into jvega/reshape_rm_on_device
jvegaTT Dec 3, 2024
323ace7
#15702: Added a skip for grayskull due to issue 15702
jvegaTT Dec 4, 2024
81c282d
#15269: updating mnist device perf targets
jvegaTT Dec 4, 2024
82c0eb7
#15269: updating other perf targets
jvegaTT Dec 4, 2024
d835144
#15269: removing broken unused kernel code
jvegaTT Dec 4, 2024
16dd8ad
Merge branch 'main' into jvega/reshape_rm_on_device
jvegaTT Dec 4, 2024
dad9a4f
#0: pre commit formatting change
jvegaTT Dec 4, 2024
0ea118b
#0 addressing artem PR review changes on 15572 PR
jvegaTT Dec 4, 2024
66a28c1
#15269: updating vgg device targets
jvegaTT Dec 4, 2024
0e411b0
#0: addressing austin PR review changes on 15572 PR
jvegaTT Dec 4, 2024
b8ce2b0
Merge branch 'main' into jvega/reshape_rm_on_device
jvegaTT Dec 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions tests/ttnn/unit_tests/test_reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,10 @@ def test_reshape_tile_layout_only_change_shape(device):
((1, 1445, 192), (1445, 192)),
((1, 256), (1, 1, 256)),
((16, 1, 32), (16, 1, 32)),
((1, 32, 4608), (1, 32, 16, 3, 96)), # issue 13889
((2888, 49, 96), (8, 19, 19, 7, 7, 96)), # issue 12153
((128, 1, 1, 128), (128, 128)), # issue 14676
((5, 4, 208, 156), (3, 13, 8, 2080)), # issue 14513
],
)
@pytest.mark.parametrize("layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT])
Expand All @@ -304,6 +308,26 @@ def test_reshape_tile_with_padding(input_shape, output_shape, layout, device):
ttnn_output = ttnn.reshape(input_tensor, output_shape)
assert layout == ttnn_output.layout
output = ttnn.to_torch(ttnn_output)
assert_with_pcc(torch_result, output, 0.9999)


# issue 15048
def test_broken_reshape(device):
    """Regression test for issue 15048: reshaping a (1, 56, 56, 64) tiled
    tensor to (1, 1, 56*56, 64) on device previously produced wrong results."""
    original_shape = (1, 56, 56, 64)
    flattened_shape = (1, 1, 56 * 56, 64)

    reference_input = torch.randn(original_shape, dtype=torch.bfloat16)
    expected = reference_input.reshape(flattened_shape)

    device_input = ttnn.from_torch(
        reference_input,
        dtype=ttnn.bfloat16,
        layout=ttnn.TILE_LAYOUT,
        device=device,
        memory_config=ttnn.DRAM_MEMORY_CONFIG,
    )
    actual = ttnn.to_torch(ttnn.reshape(device_input, flattened_shape))

    assert_with_pcc(expected, actual, 0.9999)

Expand Down
2 changes: 2 additions & 0 deletions ttnn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ set(ALL_TTNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/reshape_on_device/device/reshape_program_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/reshape_view/reshape_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/reshape_view/device/reshape_rm_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/reshape_view/device/host/reshape_rm_host_prep.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/unsqueeze/unsqueeze.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/unsqueeze/unsqueeze_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/squeeze/squeeze.cpp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,45 @@
// It's best to copy and paste the functions in rather than include the header as code size will likely explode
// Best to separate in to cpp/hpp at some point to avoid the code size explosion but need to figure out the linking
// issues
#include <stdio.h>
#include <cstring>
#define MASK_64 0xFFFFFFFFFFFFFFC0
#define OFFSET_64 0x000000000000003F
#define MASK_16 0xFFFFFFFFFFFFFFF0
#define OFFSET_16 0x000000000000000F

namespace tt::data_movement::common {

template <bool guaranteed_16B_alligned, bool copy_async>
FORCE_INLINE void tt_memmove(const uint32_t dst_l1_addr, const uint64_t src_l1_addr, const uint32_t bytes) {
    // Performs a memory copy of `bytes` bytes between two L1 addresses on the
    // local core, using noc_async_read to move the data whenever the NoC
    // alignment requirements allow it and falling back to a software memmove
    // otherwise.
    //
    // Template parameters:
    //   guaranteed_16B_alligned: set to true only if the caller externally
    //     guarantees both source and destination are 16B aligned (dangerous).
    //   copy_async: set to true to perform the copy asynchronously; the caller
    //     must then issue noc_async_read_barrier() before using the data.
    //
    // NOTE(review): the 16B sub-alignment check relies on the hardcoded
    // MASK_16/OFFSET_16 constants; these should eventually be derived from the
    // arch's L1_ALIGNMENT/DRAM_ALIGNMENT so new architectures are handled
    // automatically (see PR discussion).
    if constexpr (guaranteed_16B_alligned) {
        noc_async_read(get_noc_addr(src_l1_addr), dst_l1_addr, bytes);
        if constexpr (!copy_async) {
            noc_async_read_barrier();
        }
    } else {
        if ((dst_l1_addr & OFFSET_16) == (src_l1_addr & OFFSET_16)) {
            // Source and destination share the same 16B sub-alignment, so the
            // NoC read path is usable.
            noc_async_read(get_noc_addr(src_l1_addr), dst_l1_addr, bytes);
            if constexpr (!copy_async) {
                noc_async_read_barrier();
            }
        } else {
            // BUGFIX: convert the 64-bit source address through uintptr_t
            // instead of a C-style (void*) cast of a uint64_t, which is a
            // truncating conversion on the 32-bit RISC-V data-movement cores.
            memmove(
                reinterpret_cast<void*>(static_cast<uintptr_t>(dst_l1_addr)),
                reinterpret_cast<void*>(static_cast<uintptr_t>(src_l1_addr)),
                static_cast<size_t>(bytes));
        }
    }
}

// this function is useful for converting bfloat16 values to float32
FORCE_INLINE float bfloat16_to_float32(uint16_t bfloat16_data) {
uint32_t bits = static_cast<uint32_t>(bfloat16_data) << 16;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0


/*
Function reads from RM and writes to RM

Assumptions:

Compile-time arguments (set by the host in reshape_rm_host_prep.cpp):
0. src0_is_dram: 1 if the source tensor lives in DRAM, else 0
1. src_aligned_to_64: 1 if the source page size is a multiple of 64 bytes, else 0
2. src_aligned_to_16: 1 if the source page size is a multiple of 16 bytes, else 0
3. dst_aligned_to_16: 1 if the destination page size is a multiple of 16 bytes, else 0
//Extra read alignment handling is needed when the source is DRAM and its page
//size is not a multiple of 64 bytes

Runtime arguments
0. src_addr: source address
1. dst_addr: destination address
2. source_page_size_bytes: source page size in bytes
3. dest_page_size_bytes: destination page size in bytes
4. source_read_size_bytes: source read size in bytes (page size padded up for alignment)
5. read_start_page: first source page to read
6. read_end_page: last source page to read
7. write_start_page: first destination page to write
8. cb_id_in0: circular buffer index staging one source page
9. cb_id_in1: circular buffer index staging one destination page
*/
#include <stdint.h>
#include "dataflow_api.h"
#include "debug/dprint.h" // required in all kernels using DPRINT
#include "ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp"

void kernel_main() {
//We are guranteed to be in 2D going to 2D

const uint32_t src_addr = get_arg_val<uint32_t>(0);
const uint32_t dst_addr = get_arg_val<uint32_t>(1);
const uint32_t source_page_size_bytes = get_arg_val<uint32_t>(2);
const uint32_t dest_page_size_bytes = get_arg_val<uint32_t>(3);
//If DDR this is source_page_size_bytes + 64 (rounded up to next 64B), if L1 this is source_page_size_bytes + 16 (rounded up to next 16B)
const uint32_t source_read_size_bytes = get_arg_val<uint32_t>(4);
const uint32_t read_start_page = get_arg_val<uint32_t>(5);
const uint32_t read_end_page = get_arg_val<uint32_t>(6);
const uint32_t write_start_page = get_arg_val<uint32_t>(7);
//cb_id_in0 is a circular buffer with 1 source_page_size_bytes page if no alignment needed
//source_read_size_bytes otherwise
const uint32_t cb_id_in0 = get_arg_val<uint32_t>(8);
//cb_id_in1 is a circular buffer with 1 dest_page_size_bytes+16 (rounded up to next 64B) page
const uint32_t cb_id_in1 = get_arg_val<uint32_t>(9);


constexpr bool tensor_is_dram = get_compile_time_arg_val(0) == 1;
#define src_aligned_to_64 get_compile_time_arg_val<uint32_t>(1) == 1
#define src_aligned_to_16 get_compile_time_arg_val<uint32_t>(2) == 1
#define dst_aligned_to_16 get_compile_time_arg_val<uint32_t>(3) == 1


const InterleavedAddrGen<tensor_is_dram> s = {
.bank_base_address = src_addr,
.page_size = source_page_size_bytes
};

const InterleavedAddrGen<tensor_is_dram> d = {
.bank_base_address = dst_addr,
.page_size = dest_page_size_bytes
};


uint32_t read_offset = 0;
uint32_t write_page = write_start_page;
uint32_t readable = 0;
uint32_t transaction = 0;
uint32_t writable = dest_page_size_bytes;
//cb_id_in0 is a CB source_read_size_bytes page size, 1 page
//cb_id_in1 is a CB dest_page_size_bytes + allignment_to_64 page size, 1 page
cb_reserve_back(cb_id_in0, 1);
cb_reserve_back(cb_id_in1, 1);
const uint32_t source_buffer = get_write_ptr(cb_id_in0);
const uint32_t dest_buffer = get_write_ptr(cb_id_in1);

uint64_t dst_noc_addr = get_noc_addr(write_page, d);
#if (dst_aligned_to_16)
uint32_t write_offset = 0;
#else
uint32_t write_offset = dst_noc_addr&OFFSET_16;
uint32_t begin_write_offset = write_offset;
#endif
for (uint32_t i = read_start_page; i <= read_end_page; i++) {
//Read from source
uint64_t src_noc_addr = s.get_noc_addr(i,0);

#if (src_aligned_to_64 || ((!tensor_is_dram) && src_aligned_to_16))
//Aligned to 64 bytes or 16 bytes but L1
noc_async_read(src_noc_addr, source_buffer, source_page_size_bytes);
read_offset = 0;
#elif (tensor_is_dram)
//DDR but not alligned to 64 (potentially also not alligned to 16)
noc_async_read(src_noc_addr&MASK_64, source_buffer, source_read_size_bytes);
read_offset = src_noc_addr&OFFSET_64;
#else
//L1 but not alligned to 16
noc_async_read(src_noc_addr&MASK_16, source_buffer, source_read_size_bytes);
read_offset = src_noc_addr&OFFSET_16;
#endif
readable = source_page_size_bytes;
noc_async_read_barrier();

//Write to dest
while (readable > 0)
{
noc_async_write_barrier();
if (readable < writable)
{
tt::data_movement::common::tt_memmove<false,true>(dest_buffer+write_offset, source_buffer + read_offset, readable);
writable = writable -readable;
write_offset = write_offset + readable;
readable = 0;
}
else if (readable == writable)
{
tt::data_movement::common::tt_memmove<false,false>(dest_buffer+write_offset, source_buffer + read_offset, readable);
#if ((dst_aligned_to_16))
noc_async_write(dest_buffer,dst_noc_addr, dest_page_size_bytes);
#else
noc_async_write(dest_buffer+begin_write_offset,dst_noc_addr, dest_page_size_bytes);
#endif
writable = dest_page_size_bytes;
readable = 0;
if (i == read_end_page-1)
{
cb_push_back(cb_id_in0, 1);
cb_push_back(cb_id_in1, 1);
return;
}
write_page++;
dst_noc_addr = get_noc_addr(write_page, d);
#if ((dst_aligned_to_16))
write_offset=0;
#else
write_offset = dst_noc_addr&OFFSET_16;
begin_write_offset = write_offset;
#endif
}
else
{
//writable < readable

tt::data_movement::common::tt_memmove<false,false>(dest_buffer+write_offset, source_buffer + read_offset, writable);
#if ((dst_aligned_to_16))
noc_async_write(dest_buffer,dst_noc_addr, dest_page_size_bytes);
#else
noc_async_write(dest_buffer+begin_write_offset,dst_noc_addr, dest_page_size_bytes);
#endif
readable = readable - writable;
read_offset = read_offset + writable;
write_page++;
dst_noc_addr = get_noc_addr(write_page, d);
#if ((dst_aligned_to_16))
write_offset=0;
#else
write_offset = dst_noc_addr&OFFSET_16;
begin_write_offset = write_offset;
#endif
writable = dest_page_size_bytes;
}
}
}
cb_push_back(cb_id_in0, 1);
cb_push_back(cb_id_in1, 1);
return;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <math.h>

#include "ttnn/operations/cb_utils.hpp"
#include "ttnn/operations/math.hpp"
#include "ttnn/operation.hpp"
#include "ttnn/operations/core/work_split/work_split_tilize.hpp"
#include "tt_metal/common/constants.hpp"
#include "tt_metal/detail/util.hpp"
#include "tt_metal/host_api.hpp"
#include "ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape_common.hpp"

#include <optional>
#include <variant>

#include "ttnn/tensor/tensor.hpp"
#include "ttnn/core.hpp"
#include "ttnn/device_operation.hpp"
#include "ttnn/types.hpp"
#include "ttnn/decorators.hpp"

#define MASK_64 0xFFFFFFFFFFFFFFC0
#define OFFSET_64 0x000000000000003F
#define MASK_16 0xFFFFFFFFFFFFFFF0
#define OFFSET_16 0x000000000000000F

namespace ttnn::operations::data_movement::rm_reshape{

// Builds the program that performs a row-major reshape by streaming source
// pages through L1 and re-packing them into destination-sized pages (the
// device side is rm_reshape_interleaved.cpp).
operation::ProgramWithCallbacks rm_reshape_preparer(const Tensor& input, const Tensor& output)
{
    tt::tt_metal::Program program = tt::tt_metal::CreateProgram();
    //get datum size
    tt::DataFormat cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(input.get_dtype());
    const uint32_t data_size = input.element_size();
    // NOTE(review): single-core only — all work is placed on core (0,0).
    CoreRange core({0, 0}, {0, 0});

    tt::tt_metal::Device *device = input.device();  // currently unused
    ttnn::Shape input_log_shape = ttnn::Shape(input.get_logical_shape().view());
    ttnn::Shape output_log_shape = ttnn::Shape(output.get_logical_shape().view());
    tt::log_debug("row major reshape");
    tt::log_debug("input shape: {}", input_log_shape);
    tt::log_debug("output shape: {}", output_log_shape);
    tt::log_debug("data size: {}", data_size);
    // One page = one row (innermost dimension) of the row-major tensor, in bytes.
    uint32_t source_page_size_bytes = input_log_shape[-1] * data_size;
    uint32_t dest_page_size_bytes = output_log_shape[-1] * data_size;
    // Page size rounded down to a 64B boundary plus 128B of slack so the kernel
    // can start its read at the preceding 64B-aligned address.
    // NOTE(review): 64B/16B alignment values are hardcoded here and in the
    // MASK_* macros — should be derived from the arch's DRAM/L1 alignment
    // constants (see PR discussion).
    uint32_t source_read_size_bytes = ((source_page_size_bytes-1) & MASK_64) + 128;
    uint32_t read_start_page = 0;
    uint32_t read_end_page = input_log_shape[-2];
    uint32_t write_start_page = 0;
    tt::tt_metal::Buffer *src_buffer = input.buffer();
    tt::tt_metal::Buffer *dst_buffer = output.buffer();
    TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!");

    // CB 0 stages one (alignment-padded) source page; CB 1 stages one
    // destination page plus alignment slack.
    const uint32_t cb_size0 = source_read_size_bytes;
    const uint32_t cb_size1 = ((dest_page_size_bytes-1)&MASK_64) + 80;

    uint32_t src0_cb_index = 0;
    uint32_t src1_cb_index = 1;
    tt::tt_metal::CircularBufferConfig cb_src0_config = tt::tt_metal::CircularBufferConfig(cb_size0*2, {{src0_cb_index, cb_data_format}})
        .set_page_size(src0_cb_index, cb_size0);
    auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, core, cb_src0_config);
    tt::tt_metal::CircularBufferConfig cb_src1_config = tt::tt_metal::CircularBufferConfig(cb_size1, {{src1_cb_index, cb_data_format}})
        .set_page_size(src1_cb_index, cb_size1);
    auto cb_src1 = tt::tt_metal::CreateCircularBuffer(program, core, cb_src1_config);
    //set the runtime args
    //set the compile time args
    uint32_t src0_is_dram = src_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0;
    // Compile-time args consumed by the kernel, in order:
    // {src_is_dram, src 64B-aligned, src 16B-aligned, dst 16B-aligned}
    std::vector<uint32_t> compile_time_args = {
        (std::uint32_t) src0_is_dram,
        (std::uint32_t) (source_page_size_bytes%64==0) ? 1 : 0,
        (std::uint32_t) (source_page_size_bytes%16==0) ? 1 : 0,
        (std::uint32_t) (dest_page_size_bytes%16==0) ? 1 : 0,
    };

    // NOTE(review): the kernel path contains "device/device" — confirm this
    // matches the actual file location in the tree.
    tt::tt_metal::KernelHandle reader_kernel_id = tt::tt_metal::CreateKernel(
        program,
        "ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp",
        core,
        tt::tt_metal::ReaderDataMovementConfig(compile_time_args));
    // Runtime args 0-9 (buffer addresses, page sizes, page ranges, CB ids);
    // the order must match the get_arg_val calls in kernel_main.
    std::vector<uint32_t> reader_runtime_args = {
        src_buffer->address(),
        dst_buffer->address(),
        source_page_size_bytes,
        dest_page_size_bytes,
        source_read_size_bytes,
        read_start_page,
        read_end_page,
        write_start_page,
        src0_cb_index,
        src1_cb_index
    };
    tt::tt_metal::SetRuntimeArgs(
        program,
        reader_kernel_id,
        core,
        reader_runtime_args
    );
    // NOTE(review): no override_runtime_args_callback is installed, so a cached
    // program re-run with re-allocated buffers would keep stale addresses in
    // args 0/1 — confirm against how this factory is invoked.
    return {.program=std::move(program)};
}
}; // namespace ttnn::operations::data_movement::rm_reshape
Loading
Loading