-
Notifications
You must be signed in to change notification settings - Fork 116
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add transpose WH sharded, generalize row major permute when N > 4, an…
…d do a minor refactor of ttnn::permute (#15881) ### Ticket #14790 add transpose wh sharded implementation when shard shape < height dimension #15165 add N-d permute with width dimension #15589 correct permute dimensionality when less than 4D #15750 remove the composite flag from permute #12550 re-enable some permute tests for blackhole #12349 re-enable working transpose tests for blackhole #16066 disable test uniform as it's stochastic ### Problem description This PR addresses several permute and transpose problems all at once - Transpose WH sharded does not currently work when the shard shape is less than the height - Permute on greater than 4 dimensions does not work when moving width around (for both tiled and RM) - The Permute kernel when width doesn't change is single core - Permute has an unclean API in which we have a composite flag that is not generically applicable - Permute on less than 4 dimensions gets an incorrect output shape in cases where it's a no-op - Permute tests are disabled for BH due to LLK issues - Transpose tests are disabled for BH due to LLK issues ### What's changed - Add transpose WH sharded implementation for when shard shape is less than the height dim (outputs a block sharded output) - Add an N-d permute kernel that works generically on any row major input. We have to call a global init each loop of the compute kernel as transpose sets some registers that aren't cleared (there's no transpose_uninit). This results in bad pcc when there's more than one loop. For GS/BH, even the global init doesn't solve the problem so the test is disabled. For Tiled, we need 5D untilize/tilize. This increases sweeps coverage for permute from **50%** to **86%** - For the optimized case where Permute's width dimension is not shuffled, make the kernel multicore - Remove composite flag that is default set to to make permute non-generic. This has caused forge models to have bad pcc as they were not aware of this optional argument. - Refactor ttnn::permute to add nop checks and correct shape calculations - Re-enable permute and transpose tests for blackhole When replacing variants of transpose with this RM permute kernel, a lot of tests on BH/GS failed, so I will do that in a follow-up to address. The LLK issues are causing pains there. If we get N-d untilize/tilize support and once the LLK issues are fixed, permute should have the ability to be generic. The remaining issues for the pytorch 2.0 sweeps after the untilize/tilize fix are the CB overflow on transpose wh, which should be fixed out of the box when we replace the kernel that is used (which I am not doing in this PR since it doesn't work for GS/BH atm). ### Checklist - [x] Post commit CI passes https://github.com/tenstorrent/tt-metal/actions/runs/12367177499/job/34547311782 (failing test is failing on main) - [x] Blackhole Post commit (if applicable) https://github.com/tenstorrent/tt-metal/actions/runs/12367175575 - [x] Model regression CI testing passes (if applicable) https://github.com/tenstorrent/tt-metal/actions/runs/12357119737 - [x] Device performance regression CI testing passes (if applicable) https://github.com/tenstorrent/tt-metal/actions/runs/12357115316 - [ ] **(For models and ops writers)** Full [new models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml) tests passes - [ ] New/Existing tests provide coverage for changes
- Loading branch information
Showing
21 changed files
with
1,041 additions
and
179 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
65 changes: 65 additions & 0 deletions
65
...rations/data_movement/permute/device/kernels/compute/transpose_xw_rm_single_tile_size.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include <cstdint> | ||
|
||
#include "compute_kernel_api/eltwise_unary/eltwise_unary.h" | ||
#include "compute_kernel_api/transpose_wh.h" | ||
#include "compute_kernel_api/tilize.h" | ||
#include "compute_kernel_api/untilize.h" | ||
#include "compute_kernel_api/pack_untilize.h" | ||
|
||
namespace NAMESPACE { | ||
void MAIN { | ||
constexpr uint32_t x_block_size = get_compile_time_arg_val(0); | ||
constexpr uint32_t w_block_size = get_compile_time_arg_val(1); | ||
|
||
uint32_t num_blocks = get_arg_val<uint32_t>(0); | ||
|
||
constexpr auto cb_in = tt::CBIndex::c_0; | ||
constexpr auto cb_tilize = tt::CBIndex::c_1; | ||
constexpr auto cb_out = tt::CBIndex::c_2; | ||
|
||
unary_op_init_common(cb_in, cb_out); | ||
|
||
for (uint32_t n = 0; n < num_blocks; n++) { | ||
// tilize input via unpack and then pack | ||
tilize_init_short(cb_in, 1); | ||
|
||
cb_wait_front(cb_in, x_block_size); | ||
cb_reserve_back(cb_tilize, 1); | ||
|
||
tilize_block(cb_in, 1, cb_tilize); // tilize and pack into cb_tilize | ||
|
||
// tile slice according to unpacker is garbage after tilize_block in the second iteration, missing an uninit? | ||
cb_push_back(cb_tilize, 1); | ||
cb_pop_front(cb_in, x_block_size); | ||
|
||
tilize_uninit(cb_in); | ||
|
||
// transpose input | ||
cb_wait_front(cb_tilize, 1); | ||
transpose_wh_init_short(cb_tilize); | ||
pack_untilize_dst_init_short<1>(cb_out); | ||
|
||
tile_regs_acquire(); | ||
transpose_wh_tile(cb_tilize, 0, 0); // transpose call | ||
tile_regs_commit(); | ||
|
||
// pack and untilize | ||
cb_reserve_back(cb_out, w_block_size); | ||
|
||
tile_regs_wait(); | ||
pack_untilize_dst<1>(cb_out); // pack call | ||
tile_regs_release(); | ||
|
||
cb_push_back(cb_out, w_block_size); | ||
|
||
cb_wait_front(cb_out, w_block_size); | ||
pack_untilize_uninit(cb_out); | ||
|
||
cb_pop_front(cb_tilize, 1); | ||
} | ||
} | ||
} // namespace NAMESPACE |
Oops, something went wrong.