#15572: Rewrite of Reshape OP from scratch
Merged
Commits (changes shown from 36 of 37 commits)
5d343d4
#15269: reshape fully on device now
jvegaTT e2bbda5
#15558 edited comment to mention this issue
jvegaTT 43b0bd2
#0: move tt_memmove to common library and ensure tilize/untilize is o…
jvegaTT 8a94c1b
#0: added corrector for implied shape dimensions
jvegaTT 4fd8c45
#13889: Added test to prove this issue is resolved
jvegaTT 285d1fa
#12153: Adding test to verify that issue is resolved
jvegaTT 7d40412
#15048: being more careful about bandaid for issues #15137 and #13338
jvegaTT 7335c11
#14676: Adding test to verify that this issue is resolved
jvegaTT 6110d47
#0: adding libraries for memmove to common
jvegaTT cbc2830
#14513: Adding test to prove issue is resolved
jvegaTT 98c4b01
Merge branch 'main' into jvega/reshape_rm_on_device
jvegaTT 055dc0f
#15269: added multi-core support
jvegaTT 1d058c7
Merge branch 'jvega/reshape_rm_on_device' of github.com:tenstorrent/t…
jvegaTT fac56d8
#0: addressing PR comments
jvegaTT 90b6255
Merge branch 'main' into jvega/reshape_rm_on_device
jvegaTT 7ff95d2
#0: small oops
jvegaTT 2edceb5
#15269: Host code optimizations
jvegaTT ad28b55
#15269: Move compute buffers to compile time
jvegaTT 5501324
#15269: adding override_runtime_args_callback
jvegaTT 109615f
#15269: improve the tt_memmove to use read or write noc as per user n…
jvegaTT 631b1c5
#15269: improve the tt_memmove to use read or write datamover
jvegaTT 10de9a2
#15269: add broken multi risk code
jvegaTT b709d9c
#15269: employing pow2 optimization
jvegaTT 4fa2bf3
#15269: Added optimization to do one less copy on aligned only transfers
jvegaTT 72f2d5c
#15269: added packet form of noc_async_read and write
jvegaTT 09482f3
#15269: further small optimizations
jvegaTT 3e30253
Merge branch 'main' into jvega/reshape_rm_on_device
jvegaTT 323ace7
#15702: Added a skip for grayskull due to issue 15702
jvegaTT 81c282d
#15269: updating mnist device perf targets
jvegaTT 82c0eb7
#15269: updating other perf targets
jvegaTT d835144
#15269: removing broken unused kernel code
jvegaTT 16dd8ad
Merge branch 'main' into jvega/reshape_rm_on_device
jvegaTT dad9a4f
#0: pre commit formatting change
jvegaTT 0ea118b
#0 addressing artem PR review changes on 15572 PR
jvegaTT 66a28c1
#15269: updating vgg device targets
jvegaTT 0e411b0
#0: addressing austin PR review changes on 15572 PR
jvegaTT b8ce2b0
Merge branch 'main' into jvega/reshape_rm_on_device
jvegaTT
205 additions & 0 deletions: ttnn/cpp/ttnn/operations/data_movement/reshape_view/device/device/rm_reshape_interleaved.cpp
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

/*
Function reads from RM and writes to RM

Assumptions:

Compile arguments
0. src0_is_dram: 1 if source is DRAM else 0
1. read_size_is_pow2: 1 if read size is a power of 2 else 0
2. log_base_2_of_page_size: log base 2 of the read page size
3. write_size_is_pow2: 1 if write size is a power of 2 else 0
4. log_base_2_of_page_size: log base 2 of the write page size
5. needs_read_alignment: 1 if the read needs alignment else 0
// Needed if DRAM and the page size is not a multiple of 64 bytes

Runtime arguments
0. src_addr: source address
1. dst_addr: destination address
2. source_page_size_bytes: source page size in bytes
3. dest_page_size_bytes: destination page size in bytes
4. source_read_size_bytes: source read size in bytes
5. read_start_page: read start page
6. read_end_page: read end page
7. write_start_page: write start page
*/
#include <stdint.h>
#include "dataflow_api.h"
#include "debug/dprint.h"  // required in all kernels using DPRINT
#include "ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp"
void kernel_main() {
    // We are guaranteed to be going from 2D to 2D

    const uint32_t src_addr = get_arg_val<uint32_t>(0);
    const uint32_t dst_addr = get_arg_val<uint32_t>(1);
    // If DDR this is source_page_size_bytes + 64 (rounded up to the next 64B);
    // if L1 this is source_page_size_bytes + 16 (rounded up to the next 16B)
    const uint32_t source_read_size_bytes = get_arg_val<uint32_t>(2);
    const uint32_t read_start_page = get_arg_val<uint32_t>(3);
    const uint32_t read_end_page = get_arg_val<uint32_t>(4);
    const uint32_t write_start_page = get_arg_val<uint32_t>(5);
    const uint32_t write_start_offset = get_arg_val<uint32_t>(6);
    const uint32_t nop = get_arg_val<uint32_t>(9);

    constexpr bool tensor_is_dram = get_compile_time_arg_val(0) == 1;
#define src_aligned_to_64 get_compile_time_arg_val(1) == 1
#define src_aligned_to_16 get_compile_time_arg_val(2) == 1
    constexpr uint32_t cb_id_in0 = get_compile_time_arg_val(3);
    constexpr uint32_t cb_id_in1 = get_compile_time_arg_val(4);
    constexpr uint32_t source_page_size_bytes = get_compile_time_arg_val(5);
    constexpr uint32_t dest_page_size_bytes = get_compile_time_arg_val(6);
#define source_page_is_pow_2 get_compile_time_arg_val(7) == 1
    constexpr uint32_t source_page_pow_2 = get_compile_time_arg_val(8);
#define dest_page_is_pow_2 get_compile_time_arg_val(9) == 1
    constexpr uint32_t dest_page_pow_2 = get_compile_time_arg_val(10);

    // We operate on a grid of cores, but pages don't always split evenly across them;
    // if nop is set, this core has no work to do.
    if (nop == 1) {
        return;
    }

#if source_page_is_pow_2
    const InterleavedPow2AddrGen<tensor_is_dram> s = {
        .bank_base_address = src_addr, .log_base_2_of_page_size = source_page_pow_2};
#else
    const InterleavedAddrGen<tensor_is_dram> s = {
        .bank_base_address = src_addr, .page_size = source_page_size_bytes};
#endif
#if dest_page_is_pow_2
    const InterleavedPow2AddrGen<tensor_is_dram> d = {
        .bank_base_address = dst_addr, .log_base_2_of_page_size = dest_page_pow_2};
#else
    const InterleavedAddrGen<tensor_is_dram> d = {
        .bank_base_address = dst_addr, .page_size = dest_page_size_bytes};
#endif

    uint32_t read_offset = 0;
    uint32_t write_page = write_start_page;
    uint32_t readable = 0;
    uint32_t end_to_write = 0;
    uint32_t transaction = 0;
    uint32_t writable = dest_page_size_bytes - write_start_offset;
    // cb_id_in0 is a CB with a page size of source_read_size_bytes, 1 page
    // cb_id_in1 is a CB with a page size of dest_page_size_bytes + alignment_to_64, 1 page
    cb_reserve_back(cb_id_in0, 1);
    cb_reserve_back(cb_id_in1, 1);
    const uint32_t source_buffer = get_write_ptr(cb_id_in0);
    const uint32_t dest_buffer = get_write_ptr(cb_id_in1);
    cb_push_back(cb_id_in1, 1);
    cb_push_back(cb_id_in0, 1);

    uint64_t dst_noc_addr = get_noc_addr(write_page, d);
    uint64_t write_offset = (dst_noc_addr & OFFSET_16) + write_start_offset;
    uint64_t begin_write_offset = write_offset;
    constexpr bool can_be_clean = ((source_page_size_bytes % 16) == 0 && (dest_page_size_bytes % 16) == 0);
    uint64_t dst_noc_addr_offset = 0;
    for (uint32_t i = read_start_page; i < read_end_page; i++) {
        // Read from source
        uint64_t src_noc_addr = s.get_noc_addr(i, 0);

#if (src_aligned_to_64 || ((!tensor_is_dram) && src_aligned_to_16))
        // Aligned to 64 bytes, or to 16 bytes on L1
        tt::data_movement::common::enhanced_noc_async_read<source_page_size_bytes, false>(
            src_noc_addr, source_buffer, source_page_size_bytes);
        read_offset = 0;
#elif (tensor_is_dram)
        // DDR but not aligned to 64 (potentially also not aligned to 16)
        tt::data_movement::common::enhanced_noc_async_read<(source_page_size_bytes + 128), false>(
            src_noc_addr & MASK_64, source_buffer, source_read_size_bytes);
        read_offset = src_noc_addr & OFFSET_64;
#else
        // L1 but not aligned to 16
        tt::data_movement::common::enhanced_noc_async_read<(source_page_size_bytes + 128), false>(
            src_noc_addr & MASK_16, source_buffer, source_read_size_bytes);
        read_offset = src_noc_addr & OFFSET_16;
#endif
        readable = source_page_size_bytes;
        noc_async_read_barrier();

        // Write to dest
        while (readable > 0) {
            noc_async_write_barrier();
            if (readable < writable) {
                if constexpr (can_be_clean) {
                    tt::data_movement::common::enhanced_noc_async_write<dest_page_size_bytes, false>(
                        source_buffer + read_offset, dst_noc_addr + dst_noc_addr_offset, readable);
                    dst_noc_addr_offset = dst_noc_addr_offset + readable;
                } else {
                    tt::data_movement::common::tt_memmove<false, true, false, dest_page_size_bytes>(
                        dest_buffer + write_offset, source_buffer + read_offset, readable);
                    if (i == read_end_page - 1) {
                        tt::data_movement::common::enhanced_noc_async_write<dest_page_size_bytes, false>(
                            dest_buffer + begin_write_offset, dst_noc_addr, end_to_write);
                        return;
                    }
                    write_offset = write_offset + readable;
                    end_to_write = end_to_write + readable;
                }
                writable = writable - readable;
                readable = 0;
            } else if (readable == writable) {
                if constexpr (can_be_clean) {
                    tt::data_movement::common::enhanced_noc_async_write<dest_page_size_bytes, false>(
                        source_buffer + read_offset, dst_noc_addr + dst_noc_addr_offset, readable);
                } else {
                    tt::data_movement::common::tt_memmove<false, false, false, dest_page_size_bytes>(
                        dest_buffer + write_offset, source_buffer + read_offset, readable);
                    tt::data_movement::common::enhanced_noc_async_write<dest_page_size_bytes, false>(
                        dest_buffer + begin_write_offset, dst_noc_addr, dest_page_size_bytes);
                }
                dst_noc_addr_offset = 0;

                writable = dest_page_size_bytes;
                readable = 0;
                if (i == read_end_page - 1) {
                    return;
                }
                write_page++;
                dst_noc_addr = get_noc_addr(write_page, d);
                if constexpr (!can_be_clean) {
                    end_to_write = 0;
                    write_offset = dst_noc_addr & OFFSET_16;
                    begin_write_offset = write_offset;
                }
            } else {
                // writable < readable
                if constexpr (can_be_clean) {
                    tt::data_movement::common::enhanced_noc_async_write<dest_page_size_bytes, false>(
                        source_buffer + read_offset, dst_noc_addr + dst_noc_addr_offset, writable);
                } else {
                    tt::data_movement::common::tt_memmove<false, false, false, dest_page_size_bytes>(
                        dest_buffer + write_offset, source_buffer + read_offset, writable);
                    tt::data_movement::common::enhanced_noc_async_write<dest_page_size_bytes, false>(
                        dest_buffer + begin_write_offset, dst_noc_addr, dest_page_size_bytes);
                }
                readable = readable - writable;
                read_offset = read_offset + writable;
                write_page++;
                dst_noc_addr_offset = 0;
                dst_noc_addr = get_noc_addr(write_page, d);
                if constexpr (!can_be_clean) {
                    end_to_write = 0;
                    write_offset = dst_noc_addr & OFFSET_16;
                    begin_write_offset = write_offset;
                }
                writable = dest_page_size_bytes;
            }
        }
    }
    return;
}
is this a perf drop or increase (lower or higher is better)?
Perf increased, but the nature of the test makes it report a decrease. The device test uses Tracy to record how long the op runs on the device over the duration of the model. We are now running the formerly very expensive on-host reshapes (which wasted time moving data in and out of the device) on device instead, which yields a significant end-to-end perf improvement. However, this also means device runtime goes up, since the device is doing more and the host is doing less, and the very expensive host-side memory copy is no longer reflected in the measurement. That is why we also ran full-model experiments, which showed a significant perf improvement when measured holistically. We do, however, need to update these targets to reflect the new device runtime.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
freaking excellent, great job @jvegaTT !