Add support for rank-n tensors to tilize and untilize (#15520)
### Ticket
#15165

### What's changed
- Adds support for rank-n tensors to the tilize and untilize ops. In
particular, this introduces support for 5D tensors.
- We support rank-n tensors by first squeezing to a supported rank <= 4,
then performing the tilize/untilize, then unsqueezing back to the
original shape/rank (see the sketch after this list).
- Adds tests for a handful of 5D cases to both the untilize and tilize
unit test suites.
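
For illustration, here is a minimal Python sketch of that strategy. The helper name and the fold into a single leading dimension are assumptions for illustration only; the real implementation is the C++ `build_ndiml_tilize`/`build_ndiml_untilize` wrappers shown in the diff below.

```python
import torch
import ttnn

def tilize_rank_n(tensor: ttnn.Tensor) -> ttnn.Tensor:
    # Sketch: fold all leading dims into one so the tensor is rank <= 4,
    # run the existing 4D tilize, then restore the original shape.
    original_shape = list(tensor.shape)
    if len(original_shape) > 4:
        lead = 1
        for dim in original_shape[:-3]:
            lead *= dim
        tensor = ttnn.reshape(tensor, [lead, *original_shape[-3:]])
    tiled = ttnn.tilize(tensor)
    return ttnn.reshape(tiled, original_shape)
```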

### Checklist
- [~] [Post commit CI
passes](https://github.com/tenstorrent/tt-metal/actions/runs/12063111681)
- [x] [Model regression CI testing
passes](https://github.com/tenstorrent/tt-metal/actions/runs/12059576026)
- [x] [Device performance regression CI testing
passes](https://github.com/tenstorrent/tt-metal/actions/runs/12059576677)
- [x] New/Existing tests provide coverage for changes

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
jaykru-tt and github-actions[bot] authored Nov 28, 2024
1 parent 4c7b888 commit efc6f70
Showing 5 changed files with 153 additions and 24 deletions.
@@ -28,8 +28,7 @@
),
)
def test_run_tilize_test(nb, nc, nh, nw, multicore, device):
-    nt = nb * nc * nh * nw
-    shape = [nb, nc, 32 * nh, 32 * nw]
+    shape = [nb, nc, nh * 32, nw * 32]

inp = torch.rand(*shape).bfloat16()

@@ -43,3 +42,37 @@ def test_run_tilize_test(nb, nc, nh, nw, multicore, device):
tilized_inp = tilize(inp)
passing = torch.equal(tilized_inp, c)
assert passing


@pytest.mark.parametrize(
"shape",
(
[1, 1, 1, 5, 1],
[1, 1, 1, 4, 2],
[1, 1, 1, 3, 3],
[1, 1, 1, 2, 4],
[1, 1, 1, 1, 5],
[1, 2, 3, 2, 1],
),
)
@pytest.mark.parametrize(
"multicore",
(
False,
True,
),
)
def test_tilize_5d(shape, multicore, device):
# tests that host -> device -> tilize -> untilize -> host is a no-op
shape[-1] *= 32
shape[-2] *= 32

inp = torch.rand(*shape).bfloat16()
a = ttnn.Tensor(
inp,
ttnn.bfloat16,
).to(device)
b = ttnn.tilize(a, use_multicore=multicore)
c = ttnn.untilize(b)
d = c.cpu().to_torch()
assert torch.equal(inp, d)
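
The `shape[-1] *= 32` / `shape[-2] *= 32` lines make the last two dimensions tile-aligned, since tilize operates on 32x32 tiles. For these rank-5 shapes, the squeeze described in the PR summary reduces to something like the following sketch (assuming `squeeze_to_le_4D` collapses the extra leading dimensions; the real helper lives in the C++ data-movement common utilities):

```python
import math

def squeeze_to_le_4d(shape: list) -> list:
    # Assumed behavior: collapse the extra leading dims into one, leaving
    # the last three dims (including the 32x32 tile dims) untouched,
    # e.g. [1, 2, 3, 2 * 32, 1 * 32] -> [2, 3, 64, 32].
    if len(shape) <= 4:
        return list(shape)
    return [math.prod(shape[:-3])] + list(shape[-3:])
```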
@@ -74,3 +74,48 @@ def test_run_untilize_test(dtype, nb, nc, nh, nw, device):
passing1 = torch.equal(untilized_inp, c1)

assert passing1


@pytest.mark.parametrize(
"dtype",
(ttnn.bfloat16, ttnn.float32),
ids=["bfloat16", "float"],
)
@pytest.mark.parametrize(
"shape",
(
[1, 1, 1, 32 * 5, 32 * 1],
[1, 1, 1, 32 * 4, 32 * 2],
[1, 1, 1, 32 * 3, 32 * 3],
[1, 1, 1, 32 * 2, 32 * 4],
[1, 1, 1, 32 * 1, 32 * 5],
[1, 2, 3, 32 * 2, 32 * 1],
),
)
def test_run_untilize_5d(dtype, shape, device):
if is_grayskull() and dtype == ttnn.float32:
pytest.skip("Skipping float32 tests on Grayskull")

torch.set_printoptions(precision=3, sci_mode=False, linewidth=3000, threshold=10000, edgeitems=128)

torch.manual_seed(10)

if dtype == ttnn.float32:
inp = torch.rand(*shape).float() * 1000.0
else:
inp = torch.rand(*shape).bfloat16()

a = ttnn.from_torch(inp, dtype=dtype, device=device, layout=ttnn.TILE_LAYOUT)

out_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.L1)

our_untilized = ttnn.untilize(a, memory_config=out_mem_config, use_multicore=True, use_pack_untilize=True)
our_untilized = our_untilized.cpu().to_torch()

if dtype == ttnn.float32:
passing1, output = comp_pcc(inp, our_untilized, 0.999999)
logger.info(output)
else:
passing1 = torch.equal(inp, our_untilized)

assert passing1
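
The float32 case is checked against a PCC threshold instead of exact equality because, as the comment in the untilize.cpp diff below notes, fp32 data may be converted to fp16b on its way through DST, so bit-exact equality is not guaranteed. A rough Python equivalent of the `comp_pcc` check (the semantics here are assumed):

```python
import torch

def comp_pcc_sketch(expected, actual, threshold=0.999999):
    # Assumed semantics of comp_pcc: the Pearson correlation coefficient
    # of the flattened tensors must meet the threshold; returns a
    # (passing, message) pair matching how the test unpacks it.
    stacked = torch.stack([expected.flatten().float(), actual.flatten().float()])
    pcc = torch.corrcoef(stacked)[0, 1].item()
    return pcc >= threshold, f"PCC: {pcc}"
```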
1 change: 0 additions & 1 deletion ttnn/cpp/ttnn/operations/data_movement/concat/concat.cpp
@@ -33,7 +33,6 @@ inline void concat_db_print(bool condition, const std::string& msg) {
namespace ttnn {
namespace operations {
namespace data_movement {
-using ConcatArgs = std::tuple<const std::vector<ttnn::Tensor>&, int, unsigned int>;
using OwnedConcatArgs = std::tuple<std::vector<ttnn::Tensor>, int, unsigned int>;

using MassagedConcat = MassagedOperation<ttnn::Tensor, const std::vector<ttnn::Tensor>&, int, unsigned int>;
46 changes: 36 additions & 10 deletions ttnn/cpp/ttnn/operations/data_movement/tilize/tilize.cpp
@@ -7,27 +7,53 @@
#include "device/tilize_op.hpp"
#include "ttnn/common/constants.hpp"
#include "ttnn/run_operation.hpp"
#include "ttnn/operations/data_movement/common/common.hpp"
#include "ttnn/operations/data_movement/reshape_view/reshape.hpp"

using namespace tt::tt_metal;

namespace ttnn::operations::data_movement {
using OwnedTilizeArgs = std::tuple<ttnn::Tensor>;
using BaseTilizeType = std::function<ttnn::Tensor(const ttnn::Tensor&)>;

using MassagedTilize = MassagedOperation<ttnn::Tensor, const ttnn::Tensor&>;
using MassagedTilizeParams = MassagedOperationParams<ttnn::Tensor, const ttnn::Tensor&>;

MassagedTilize build_ndiml_tilize(BaseTilizeType base_tilize) {
auto original_shape = std::make_shared<ttnn::Shape>(ttnn::Shape{});
return MassagedTilize(MassagedTilizeParams{
.predicate = [](const ttnn::Tensor& input_tensor) -> bool { return input_tensor.get_shape().rank() > 4; },
.pre_transform = [=](const ttnn::Tensor& input_tensor) -> OwnedTilizeArgs {
*original_shape = input_tensor.get_shape();
ttnn::Tensor squeezed_tensor = squeeze_to_le_4D(input_tensor);
return std::make_tuple(squeezed_tensor);
},
.post_transform = [=](const ttnn::Tensor& output) -> ttnn::Tensor {
auto unsqueezed_tensor = ttnn::reshape(output, *original_shape);
return unsqueezed_tensor;
},
.operation = std::move(base_tilize)});
}

ttnn::Tensor ExecuteTilize::invoke(
uint8_t queue_id,
const ttnn::Tensor& input_tensor,
const std::optional<MemoryConfig>& memory_config,
std::optional<DataType> output_dtype,
bool use_multicore) {
-    return operation::run(
-        Tilize{
-            memory_config.value_or(input_tensor.memory_config()),
-            output_dtype.value_or(input_tensor.get_dtype()),
-            use_multicore},
-        {input_tensor},
-        {},
-        {},
-        queue_id)
-        .at(0);
+    auto base_tilize = [=](const ttnn::Tensor& input_tensor) {
+        return operation::run(
+            Tilize{
+                memory_config.value_or(input_tensor.memory_config()),
+                output_dtype.value_or(input_tensor.get_dtype()),
+                use_multicore},
+            {input_tensor},
+            {},
+            {},
+            queue_id)[0];
+    };
+
+    return build_ndiml_tilize(base_tilize)(input_tensor);
}

ttnn::Tensor ExecuteTilize::invoke(
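`MassagedOperation` packages this predicate/pre/post pattern so any op can be wrapped the same way. As a rough mental model only (a Python sketch, not the real C++ template interface):

```python
from dataclasses import dataclass
from typing import Callable

@dataclass
class MassagedOpSketch:
    predicate: Callable      # does the input need massaging? (here: rank > 4)
    pre_transform: Callable  # e.g. squeeze the tensor to rank <= 4
    post_transform: Callable # e.g. reshape back to the original shape
    operation: Callable      # the base op (tilize/untilize)

    def __call__(self, *args):
        if not self.predicate(*args):
            return self.operation(*args)
        # pre_transform returns the (possibly rewritten) argument tuple.
        return self.post_transform(self.operation(*self.pre_transform(*args)))
```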
48 changes: 37 additions & 11 deletions ttnn/cpp/ttnn/operations/data_movement/untilize/untilize.cpp
@@ -7,10 +7,33 @@
#include "device/untilize_op.hpp"
#include "ttnn/common/constants.hpp"
#include "ttnn/run_operation.hpp"
#include "ttnn/operations/data_movement/common/common.hpp"
#include "ttnn/operations/data_movement/reshape_view/reshape.hpp"

using namespace tt::tt_metal;

namespace ttnn::operations::data_movement {
using OwnedUntilizeArgs = std::tuple<ttnn::Tensor>;
using BaseUntilizeType = std::function<ttnn::Tensor(const ttnn::Tensor&)>;

using MassagedUntilize = MassagedOperation<ttnn::Tensor, const ttnn::Tensor&>;
using MassagedUntilizeParams = MassagedOperationParams<ttnn::Tensor, const ttnn::Tensor&>;

MassagedUntilize build_ndiml_untilize(BaseUntilizeType base_untilize) {
auto original_shape = std::make_shared<ttnn::Shape>(ttnn::Shape{});
return MassagedUntilize(MassagedUntilizeParams{
.predicate = [](const ttnn::Tensor& input_tensor) -> bool { return input_tensor.get_shape().rank() > 4; },
.pre_transform = [=](const ttnn::Tensor& input_tensor) -> OwnedUntilizeArgs {
*original_shape = input_tensor.get_shape();
ttnn::Tensor squeezed_tensor = squeeze_to_le_4D(input_tensor);
return std::make_tuple(squeezed_tensor);
},
.post_transform = [=](const ttnn::Tensor& output) -> ttnn::Tensor {
auto unsqueezed_tensor = ttnn::reshape(output, *original_shape);
return unsqueezed_tensor;
},
.operation = std::move(base_untilize)});
}

ttnn::Tensor ExecuteUntilize::invoke(
uint8_t queue_id,
@@ -22,17 +45,20 @@ ttnn::Tensor ExecuteUntilize::invoke(
input_tensor.get_dtype() ==
DataType::UINT32; // MT: Currently only uint32 is moved to DST directly, fp32 is converted to fp16b

-    return operation::run(
-        Untilize{
-            memory_config.value_or(input_tensor.memory_config()),
-            use_multicore,
-            use_pack_untilize,
-            fp32_dest_acc_en},
-        {input_tensor},
-        {},
-        {},
-        queue_id)
-        .at(0);
+    auto base_untilize = [=](const ttnn::Tensor& input_tensor) {
+        return operation::run(
+            Untilize{
+                memory_config.value_or(input_tensor.memory_config()),
+                use_multicore,
+                use_pack_untilize,
+                fp32_dest_acc_en},
+            {input_tensor},
+            {},
+            {},
+            queue_id)[0];
+    };
+
+    return build_ndiml_untilize(base_untilize)(input_tensor);
}

ttnn::Tensor ExecuteUntilize::invoke(
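Putting it together, the new rank-5 path can be exercised from Python much as the tests above do; the device open/close calls here are a generic sketch:

```python
import torch
import ttnn

device = ttnn.open_device(device_id=0)

x = torch.rand(1, 2, 3, 2 * 32, 1 * 32).bfloat16()
t = ttnn.from_torch(x, dtype=ttnn.bfloat16, device=device)

tiled = ttnn.tilize(t, use_multicore=True)  # squeezed to rank 4 internally
back = ttnn.untilize(tiled)                 # original rank restored

assert torch.equal(x, back.cpu().to_torch())
ttnn.close_device(device)
```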
