Add transpose WH sharded, generalize row major permute when N > 4, and do a minor refactor of ttnn::permute (#15881)

### Ticket
#14790 add transpose wh sharded implementation when shard shape < height
dimension
#15165 add N-d permute with width dimension
#15589 correct permute dimensionality when less than 4D
#15750 remove the composite flag from permute
#12550 re-enable some permute tests for blackhole
#12349 re-enable working transpose tests for blackhole
#16066 disable test uniform as it's stochastic

### Problem description
This PR addresses several permute and transpose problems at once:

- Transpose WH sharded does not currently work when the shard shape is
less than the height dimension
- Permute on more than 4 dimensions does not work when the width
dimension is moved (for both tiled and row-major inputs)
- The permute kernel is single-core when the width dimension is not moved
- Permute has an unclean API: it exposes a composite flag that is not
generically applicable
- Permute on fewer than 4 dimensions gets an incorrect output shape in
cases where it is a no-op (a minimal repro is sketched after this list)
- Permute tests are disabled for BH due to LLK issues
- Transpose tests are disabled for BH due to LLK issues
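
A minimal repro of the no-op shape issue, mirroring the new `test_permute_squeeze` test added in this PR (it assumes a locally attached device; shapes are just an example):

```python
import ttnn

device = ttnn.open_device(device_id=0)

# An identity permutation on a rank-3 tensor is logically a no-op, so the
# output must come back with the original (1, 1, 3) shape. Previously the
# shape calculation for sub-4D inputs could get this wrong.
ones = ttnn.ones((1, 1, 3))
tensor = ttnn.to_device(ones, device)
out = ttnn.permute(tensor, (0, 1, 2))
assert ttnn.to_torch(out).shape == (1, 1, 3)

ttnn.close_device(device)
```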

### What's changed
- Add a transpose WH sharded implementation for when the shard shape is
less than the height dim (it outputs a block-sharded tensor; a usage
sketch follows the notes below)
- Add an N-d permute kernel that works generically on any row-major
input (see the sketch after this list). We have to call a global init on
each loop of the compute kernel because transpose sets some registers
that are never cleared (there is no transpose_uninit); without it we get
bad PCC when there is more than one loop. For GS/BH, even the global init
doesn't solve the problem, so the test is disabled there. For tiled
inputs we need 5D untilize/tilize. This increases sweeps coverage for
permute from **50%** to **86%**
- Make the kernel multi-core for the optimized case where permute does
not move the width dimension
- Remove the composite flag, which was set by default and made permute
non-generic. It caused Forge models to have bad PCC because they were
not aware of this optional argument
- Refactor ttnn::permute to add no-op checks and correct shape
calculations
- Re-enable permute and transpose tests for Blackhole
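
A minimal sketch of the newly supported case referenced above: a 5D row-major permute that moves the width dimension, called with no composite flag. The shape and permutation are illustrative and mirror the new `test_permute_5d_width` / `test_permute_5d_blocked` tests:

```python
import torch
import ttnn

device = ttnn.open_device(device_id=0)

torch_input = torch.randn((7, 7, 7, 7, 7), dtype=torch.bfloat16)
# Width (dim 4) moves to the front -- previously unsupported for N > 4.
torch_output = torch.permute(torch_input, (4, 0, 3, 2, 1))

tt_input = ttnn.from_torch(
    torch_input,
    device=device,
    layout=ttnn.ROW_MAJOR_LAYOUT,
    dtype=ttnn.bfloat16,
    memory_config=ttnn.L1_MEMORY_CONFIG,
)
# No composite/optional flag any more; permute picks the right path itself.
tt_output = ttnn.permute(tt_input, (4, 0, 3, 2, 1))
assert ttnn.to_torch(tt_output).shape == torch_output.shape

ttnn.close_device(device)
```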

When replacing variants of transpose with this RM permute kernel, a lot
of tests on BH/GS failed, so I will address that in a follow-up; the LLK
issues are causing pain there. Once we have N-d untilize/tilize support
and the LLK issues are fixed, permute should be able to become fully
generic. The remaining issue for the PyTorch 2.0 sweeps after the
untilize/tilize fix is the CB overflow on transpose WH, which should be
fixed out of the box when we replace the kernel that is used (which I am
not doing in this PR since it doesn't work for GS/BH atm).
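
For reference, the sharded WH transpose path from the first bullet above is exercised roughly like this — a condensed, illustrative version of the new `test_tranpose_hw_sharded_tiled_8_cores` test (the grid and 32x32 shard shape are just the values used there):

```python
import torch
import ttnn

device = ttnn.open_device(device_id=0)

torch_input = torch.rand((1, 1, 256, 32), dtype=torch.bfloat16)
torch_reference = torch_input.transpose(2, 3)

tt_input = ttnn.from_torch(
    torch_input,
    dtype=ttnn.DataType.BFLOAT16,
    layout=ttnn.TILE_LAYOUT,
    device=device,
    memory_config=ttnn.L1_MEMORY_CONFIG,
)

# Height-shard with a 32x32 shard shape, i.e. the shard height (32) is
# smaller than the tensor height (256) -- the case this PR adds support for.
sharded_mem_config = ttnn.create_sharded_memory_config(
    (32, 32),
    core_grid=ttnn.CoreRangeSet(
        {
            ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, 6)),
            ttnn.CoreRange(ttnn.CoreCoord(1, 0), ttnn.CoreCoord(1, 0)),
        }
    ),
    strategy=ttnn.ShardStrategy.HEIGHT,
    orientation=ttnn.ShardOrientation.COL_MAJOR,
    use_height_and_width_as_shard_shape=True,
)
tt_input = ttnn.to_memory_config(tt_input, sharded_mem_config)

tt_output = ttnn.transpose(tt_input, 2, 3, memory_config=sharded_mem_config)
assert ttnn.to_torch(tt_output).shape == torch_reference.shape  # full tests check PCC

ttnn.close_device(device)
```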

### Checklist
- [x] Post commit CI passes
https://github.com/tenstorrent/tt-metal/actions/runs/12367177499/job/34547311782
(the failing test also fails on main)
- [x] Blackhole Post commit (if applicable)
https://github.com/tenstorrent/tt-metal/actions/runs/12367175575
- [x] Model regression CI testing passes (if applicable)
https://github.com/tenstorrent/tt-metal/actions/runs/12357119737
- [x] Device performance regression CI testing passes (if applicable)
https://github.com/tenstorrent/tt-metal/actions/runs/12357115316
- [ ] **(For models and ops writers)** Full [new
models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml)
tests pass
- [ ] New/Existing tests provide coverage for changes
sjameelTT authored Dec 17, 2024
1 parent c45a5b7 commit b80a975
Showing 21 changed files with 1,041 additions and 179 deletions.
@@ -20,7 +20,6 @@
]


@skip_for_blackhole("Mismatching on BH, see #12349")
@pytest.mark.parametrize("input_shapes, permute_args", params)
def test_run_permute_test(input_shapes, permute_args, device, function_level_defaults):
datagen_func = [
@@ -299,7 +299,6 @@ def test_transpose_wh_sharded_program_cache(dtype, device, use_program_cache):
)


@skip_for_blackhole("Mismatching on BH, see #12349")
@skip_for_grayskull("Grayskull has pcc issue when transpose used untilize")
@pytest.mark.parametrize("n", [1])
@pytest.mark.parametrize("c", [1])
@@ -333,7 +332,6 @@ def test_tranpose_hw_rm_with_padding(device, n, c, h, w):
assert_with_pcc(torch_output_tensor, activation_pyt_padded_out, 0.9999)


@skip_for_blackhole("Mismatching on BH, see #12349")
@skip_for_grayskull("Grayskull has pcc issue when transpose used untilize")
@pytest.mark.parametrize("n", [16])
@pytest.mark.parametrize("c", [128])
@@ -369,13 +367,10 @@ def run_tranpose_hw_rm_program_cache(device, n, c, h, w, use_program_cache):
memory_config=ttnn.L1_MEMORY_CONFIG,
)
activation_pyt_padded = ttnn.transpose(activation_pyt_padded, 2, 3, memory_config=ttnn.L1_MEMORY_CONFIG)
activation_pyt_padded_out = ttnn.to_memory_config(activation_pyt_padded, ttnn.L1_MEMORY_CONFIG)
activation_pyt_padded_out = ttnn.from_device(activation_pyt_padded_out)
activation_pyt_padded_out = ttnn.to_torch(activation_pyt_padded_out)
activation_pyt_padded_out = ttnn.to_torch(activation_pyt_padded)
assert_with_pcc(torch_output_tensor, activation_pyt_padded_out, 0.9999)


@skip_for_blackhole("Mismatching on BH, see #12349")
@skip_for_grayskull("Grayskull has pcc issue when transpose used untilize")
@pytest.mark.parametrize("n", [16])
@pytest.mark.parametrize("c", [128])
@@ -402,7 +397,7 @@ def test_tranpose_hw_rm_with_program_cache(device, n, c, h, w, use_program_cache
@pytest.mark.parametrize("c", [224])
@pytest.mark.parametrize("h", [16])
@pytest.mark.parametrize("w", [112])
def test_tranpose_hw_sharded_rm(device, n, c, h, w):
def test_transpose_hw_sharded_rm(device, n, c, h, w):
torch.manual_seed(2005)
torch_input_tensor = torch.rand((n, c, h, w), dtype=torch.bfloat16)
torch_output_tensor = torch_input_tensor.transpose(2, 3)
@@ -469,7 +464,6 @@ def run_tranpose_hw_sharded_rm_with_program_cache(device, n, c, h, w):
assert_with_pcc(torch_output_tensor, tt_output_tensor, 0.9999)


@skip_for_blackhole("Mismatching on BH, see #12349")
@pytest.mark.parametrize("n", [16])
@pytest.mark.parametrize("c", [128])
@pytest.mark.parametrize("h", [128])
@@ -581,7 +575,6 @@ def run_tranpose_hc_sharded(device, n, c, h, w, grid_size):
assert_with_pcc(torch_output_tensor, tt_output_tensor, 0.9999)


@skip_for_blackhole("Mismatching on BH, see #12349")
@pytest.mark.parametrize(
"n, c, h, w, grid_size",
[
@@ -1011,3 +1004,76 @@ def test_transpose_forge_hc(device, b, h, w, dim0, dim1):
output_tensor = ttnn.to_torch(output_tensor)

assert_with_pcc(torch_output_tensor, output_tensor)


@pytest.mark.parametrize("n", [1])
@pytest.mark.parametrize("c", [1])
@pytest.mark.parametrize("h", [256])
@pytest.mark.parametrize("w", [32])
def test_tranpose_hw_sharded_tiled_8_cores(device, n, c, h, w):
torch.manual_seed(2005)
torch_input_tensor = torch.rand((n, c, h, w), dtype=torch.bfloat16)
torch_output_tensor = torch_input_tensor.transpose(2, 3)
tt_input_tensor = ttnn.from_torch(
torch_input_tensor,
dtype=ttnn.DataType.BFLOAT16,
layout=ttnn.TILE_LAYOUT,
device=device,
memory_config=ttnn.L1_MEMORY_CONFIG,
)

sharded_mem_config = ttnn.create_sharded_memory_config(
(32, 32),
core_grid=ttnn.CoreRangeSet(
{
ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, 6)),
ttnn.CoreRange(ttnn.CoreCoord(1, 0), ttnn.CoreCoord(1, 0)),
}
),
strategy=ttnn.ShardStrategy.HEIGHT,
orientation=ttnn.ShardOrientation.COL_MAJOR,
use_height_and_width_as_shard_shape=True,
)
tt_input_tensor = ttnn.to_memory_config(tt_input_tensor, sharded_mem_config)

tt_output_tensor = ttnn.transpose(tt_input_tensor, 2, 3, memory_config=sharded_mem_config)
tt_output_tensor = ttnn.to_torch(tt_output_tensor)

assert_with_pcc(torch_output_tensor, tt_output_tensor, 0.9999)


@pytest.mark.parametrize("n", [1])
@pytest.mark.parametrize("c", [1])
@pytest.mark.parametrize("h", [224])
@pytest.mark.parametrize("w", [32])
def test_tranpose_hw_sharded_tiled_n_cores(device, n, c, h, w):
torch.manual_seed(2005)
torch_input_tensor = torch.rand((n, c, h, w), dtype=torch.bfloat16)
torch_output_tensor = torch_input_tensor.transpose(2, 3)
tt_input_tensor = ttnn.from_torch(
torch_input_tensor,
dtype=ttnn.DataType.BFLOAT16,
layout=ttnn.TILE_LAYOUT,
device=device,
memory_config=ttnn.L1_MEMORY_CONFIG,
)

sharded_mem_config = ttnn.create_sharded_memory_config(
(32, 32),
core_grid=ttnn.CoreRangeSet(
{
ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, h // 32 - 1)),
}
),
strategy=ttnn.ShardStrategy.HEIGHT,
orientation=ttnn.ShardOrientation.COL_MAJOR,
use_height_and_width_as_shard_shape=True,
)
tt_input_tensor = ttnn.to_memory_config(tt_input_tensor, sharded_mem_config)

tt_output_tensor = ttnn.transpose(tt_input_tensor, 2, 3, memory_config=sharded_mem_config)
tt_output_tensor = ttnn.to_memory_config(tt_output_tensor, ttnn.L1_MEMORY_CONFIG)
tt_output_tensor = ttnn.from_device(tt_output_tensor)
tt_output_tensor = ttnn.to_torch(tt_output_tensor)

assert_with_pcc(torch_output_tensor, tt_output_tensor, 0.9999)
101 changes: 100 additions & 1 deletion tests/ttnn/unit_tests/operations/test_permute.py
@@ -7,9 +7,10 @@
import torch

import ttnn
import itertools

from tests.ttnn.utils_for_testing import assert_with_pcc
from models.utility_functions import is_blackhole
from models.utility_functions import is_blackhole, is_grayskull, skip_for_grayskull, skip_for_blackhole


@pytest.mark.parametrize("h", [32])
@@ -171,3 +172,101 @@ def test_permute_pad_value(device, pad_value):
assert ttnn.to_torch(a) == float("-inf")
tt_output = ttnn.to_torch(tt_output)
assert_with_pcc(torch_output, tt_output, 0.9999)


def generate_permutations(N):
"""
Generator function that yields all permutations of tuples with values 0 to N-1.
:param N: The number defining the range of values (0 to N-1).
:yield: Tuples representing each permutation.
"""
for perm in itertools.permutations(range(N)):
yield perm


@skip_for_blackhole("tilize_block gives bad pcc after second iteration")
@skip_for_grayskull("tilize_block gives bad pcc after second iteration")
@pytest.mark.parametrize("shape", [(7, 7, 7, 7, 7)])
@pytest.mark.parametrize("perm", generate_permutations(5))
@pytest.mark.parametrize("memory_config", [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG])
@pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.float32])
def test_permute_5d_width(shape, perm, memory_config, dtype, device):
torch.manual_seed(2005)
input_a = torch.randn(shape)
torch_output = torch.permute(input_a, perm)

tt_input = ttnn.from_torch(
input_a, device=device, layout=ttnn.ROW_MAJOR_LAYOUT, dtype=dtype, memory_config=memory_config
)

tt_output = ttnn.permute(tt_input, perm)
tt_output = ttnn.to_torch(tt_output)
assert_with_pcc(torch_output, tt_output, 0.9999)


@skip_for_blackhole("tilize_block gives bad pcc after second iteration")
@skip_for_grayskull("tilize_block gives bad pcc after second iteration")
@pytest.mark.parametrize("shape", [(3, 65, 3, 3, 65), (1, 6, 256, 20, 50), (6, 20, 50, 1, 256)])
@pytest.mark.parametrize("perm", [(4, 0, 3, 2, 1), (1, 3, 4, 0, 2), (3, 0, 4, 1, 2)])
@pytest.mark.parametrize("memory_config", [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG])
@pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.float32])
def test_permute_5d_blocked(shape, perm, memory_config, dtype, device):
torch.manual_seed(520)
input_a = torch.randn(shape)

torch_output = torch.permute(input_a, perm)

tt_input = ttnn.from_torch(
input_a, device=device, layout=ttnn.ROW_MAJOR_LAYOUT, dtype=dtype, memory_config=memory_config
)

tt_output = ttnn.permute(tt_input, perm)
tt_output = ttnn.to_torch(tt_output)

assert_with_pcc(torch_output, tt_output, 0.9999)


@skip_for_blackhole("tilize_block gives bad pcc after second iteration")
@skip_for_grayskull("tilize_block gives bad pcc after second iteration")
def test_permute_nd(device):
torch_tensor = torch.rand((1, 3, 16, 16, 16, 16), dtype=torch.bfloat16)
input_tensor = ttnn.from_torch(torch_tensor, layout=ttnn.ROW_MAJOR_LAYOUT, device=device)
output_tensor = ttnn.permute(input_tensor, (0, 2, 4, 3, 5, 1))
output_tensor = ttnn.to_torch(output_tensor)
torch_output = torch.permute(torch_tensor, (0, 2, 4, 3, 5, 1))
assert_with_pcc(torch_output, output_tensor, 0.9999)


def test_permute_squeeze(device):
ones = ttnn.ones((1, 1, 3))
tensor = ttnn.to_device(ones, device)
out = ttnn.permute(tensor, (0, 1, 2))
assert_with_pcc(ttnn.to_torch(out), ttnn.to_torch(ones), 0.9999)


@pytest.mark.parametrize("shape", [(1, 49, 768)])
@pytest.mark.parametrize("perm", generate_permutations(3))
@pytest.mark.parametrize("layout", [ttnn.TILE_LAYOUT])
@pytest.mark.parametrize("memory_config", [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG])
@pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.float32])
def test_permute_3D(shape, perm, layout, memory_config, dtype, device):
if is_grayskull() and dtype == ttnn.float32:
pytest.skip("Grayskull doesn't support float32")
torch_tensor = torch.rand(shape, dtype=torch.bfloat16)
input_tensor = ttnn.from_torch(torch_tensor, layout=layout, device=device, dtype=dtype, memory_config=memory_config)
output_tensor = ttnn.permute(input_tensor, perm)
output_tensor = ttnn.to_torch(output_tensor)
torch_output = torch.permute(torch_tensor, perm)
assert torch_output.shape == output_tensor.shape
assert_with_pcc(torch_output, output_tensor, 0.9999)


def test_nil_volume_permute(device):
torch_tensor = torch.rand([1, 0, 30, 32], dtype=torch.bfloat16)
input_tensor = ttnn.from_torch(torch_tensor, layout=ttnn.TILE_LAYOUT, device=device)
output_tensor = ttnn.permute(input_tensor, (0, 1, 3, 2))
output_tensor = ttnn.to_torch(output_tensor)
torch_output = torch.permute(torch_tensor, (0, 1, 3, 2))
assert torch_output.shape == output_tensor.shape
assert_with_pcc(torch_output, output_tensor, 0.9999)
1 change: 1 addition & 0 deletions tests/ttnn/unit_tests/operations/test_uniform.py
@@ -94,6 +94,7 @@ def run_uniform(shape, rand_range, dtype, device, compute_kernel_options=None, m
)


@pytest.mark.skip("#16066: Undefined behaviour. It will fail on some runs and pass on others since it's stochastic.")
@skip_for_grayskull("Requires wormhole_b0 to run")
@pytest.mark.parametrize(
"shape",
39 changes: 39 additions & 0 deletions ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp
@@ -137,4 +137,43 @@ template <uint32_t a, uint32_t b>
FORCE_INLINE constexpr uint32_t round_up() {
return b * div_up<a, b>();
}

// Function template to swap two elements in a uint32_t array
template <size_t N>
FORCE_INLINE void swap_elements(uint32_t (&array)[N], size_t i, size_t j) {
// Perform the swap
uint32_t temp = array[i];
array[i] = array[j];
array[j] = temp;
}

// 2D Transpose function for debug use in reader/writer kernels
FORCE_INLINE void transpose_2d(
uint32_t input_l1_addr,
uint32_t output_l1_addr,
uint32_t X,
uint32_t W,
uint32_t element_size,
uint32_t input_page_size,
uint32_t output_page_size) {
volatile tt_l1_ptr uint8_t* input_ptr = reinterpret_cast<volatile tt_l1_ptr uint8_t*>(input_l1_addr);
volatile tt_l1_ptr uint8_t* output_ptr = reinterpret_cast<volatile tt_l1_ptr uint8_t*>(output_l1_addr);
// transpose from XW, where X is outer and W inner, to WX, where W is outer and X is inner
// each element is element_size bytes
// each row is W elements, and each row is separated by input_page_size bytes
// each output row is X elements, and each row is separated by output_page_size bytes

for (uint32_t x = 0; x < X; ++x) {
for (uint32_t w = 0; w < W; ++w) {
// Compute the input and output addresses
uint32_t input_addr = x * input_page_size + w * element_size;
uint32_t output_addr = w * output_page_size + x * element_size;
// Copy the element - do we have memcpy? use this for now
for (uint32_t i = 0; i < element_size; ++i) {
output_ptr[output_addr + i] = input_ptr[input_addr + i];
}
}
}
}

} // namespace tt::data_movement::common
@@ -0,0 +1,65 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <cstdint>

#include "compute_kernel_api/eltwise_unary/eltwise_unary.h"
#include "compute_kernel_api/transpose_wh.h"
#include "compute_kernel_api/tilize.h"
#include "compute_kernel_api/untilize.h"
#include "compute_kernel_api/pack_untilize.h"

namespace NAMESPACE {
void MAIN {
constexpr uint32_t x_block_size = get_compile_time_arg_val(0);
constexpr uint32_t w_block_size = get_compile_time_arg_val(1);

uint32_t num_blocks = get_arg_val<uint32_t>(0);

constexpr auto cb_in = tt::CBIndex::c_0;
constexpr auto cb_tilize = tt::CBIndex::c_1;
constexpr auto cb_out = tt::CBIndex::c_2;

unary_op_init_common(cb_in, cb_out);

for (uint32_t n = 0; n < num_blocks; n++) {
// tilize input via unpack and then pack
tilize_init_short(cb_in, 1);

cb_wait_front(cb_in, x_block_size);
cb_reserve_back(cb_tilize, 1);

tilize_block(cb_in, 1, cb_tilize); // tilize and pack into cb_tilize

// tile slice according to unpacker is garbage after tilize_block in the second iteration, missing an uninit?
cb_push_back(cb_tilize, 1);
cb_pop_front(cb_in, x_block_size);

tilize_uninit(cb_in);

// transpose input
cb_wait_front(cb_tilize, 1);
transpose_wh_init_short(cb_tilize);
pack_untilize_dst_init_short<1>(cb_out);

tile_regs_acquire();
transpose_wh_tile(cb_tilize, 0, 0); // transpose call
tile_regs_commit();

// pack and untilize
cb_reserve_back(cb_out, w_block_size);

tile_regs_wait();
pack_untilize_dst<1>(cb_out); // pack call
tile_regs_release();

cb_push_back(cb_out, w_block_size);

cb_wait_front(cb_out, w_block_size);
pack_untilize_uninit(cb_out);

cb_pop_front(cb_tilize, 1);
}
}
} // namespace NAMESPACE