#0: add c++ api for fold
yugaoTT committed Jul 29, 2024
1 parent c3a318e · commit 37a564f
Showing 8 changed files with 123 additions and 629 deletions.
@@ -162,8 +162,27 @@ def test_fold_with_permute_reshape_on_device(device, n, c, h, w, pad_h, pad_w, s
torch_input_tensor, pad_h, pad_w, stride_h, stride_w
)
torch_output_tensor = torch.permute(torch_output_tensor, (0, 2, 3, 1))
tt_output_tensor = pad_and_fold_with_permute_and_reshape_on_device(
device, torch_input_tensor, pad_h, pad_w, stride_h, stride_w
# pad on host
n, c, h, w = torch_input_tensor.shape
C = _nearest_y(c, 4)
padded_h = h + pad_h * 2
padded_w = w + pad_w * 2
w_pad32 = padded_w + (32 - padded_w % 32) % 32
pad_w_right = w_pad32 - (w + pad_w)
torch_input_tensor_padded = torch.nn.functional.pad(torch_input_tensor, (pad_w, pad_w_right, pad_h, pad_h))
# on device
tt_input_tensor = ttnn.from_torch(
torch_input_tensor_padded, layout=ttnn.ROW_MAJOR_LAYOUT, device=device, memory_config=ttnn.L1_MEMORY_CONFIG
)
tt_output_tensor = ttl.tensor.fold(
tt_input_tensor,
stride_h,
stride_w,
use_transpose_as_fold=True,
output_shape=(n, padded_h // stride_h, padded_w // stride_w, C * (stride_h * stride_w)),
pad_c=C - c,
pad_h=pad_h,
pad_w=0,
)
tt_output_tensor = ttnn.to_torch(tt_output_tensor)
assert_with_pcc(torch_output_tensor, tt_output_tensor, 1)
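A note on the host-side padding arithmetic above: the extra right padding is whatever brings the padded width up to the next multiple of 32 (one tile width), so only the trailing edge absorbs the alignment. A standalone sketch of that computation, with illustrative sizes not taken from the test:

h, w = 224, 224             # illustrative input size
pad_h, pad_w = 3, 3         # illustrative halo padding

padded_w = w + pad_w * 2                         # 230: logical width after front+back padding
w_pad32 = padded_w + (32 - padded_w % 32) % 32   # 256: round up to the next multiple of 32
pad_w_right = w_pad32 - (w + pad_w)              # 29: original pad_w (3) + alignment sticks (26)

assert w_pad32 % 32 == 0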
93 changes: 92 additions & 1 deletion ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/fold/fold_op.cpp
@@ -6,7 +6,95 @@

#include "ttnn/run_operation.hpp"


#include "ttnn/operations/data_movement/transpose/transpose.hpp"
#include "ttnn/cpp/ttnn/operations/data_movement/slice/slice.hpp"
#include "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/reshape/reshape_op.hpp"
#include "ttnn/cpp/ttnn/operations/data_movement/pad/pad.hpp"

namespace tt::tt_metal {

std::vector<Tensor> fold_with_transpose_(
const Tensor& input, const std::optional<const Shape>& output_shape, uint8_t stride_h, uint8_t stride_w, uint8_t pad_c, uint8_t pad_h, uint8_t pad_w) {

Device * device;

// Get the device
if (input.storage_type() != StorageType::DEVICE) {
device = AutoFormat::GetDefaultDevice();
TT_ASSERT(device != nullptr, "Requires setting default device if no inputs to op are on device");
} else {
device = input.device();
}

uint32_t n = input.shape()[0], c = input.shape()[1], h = input.shape()[2], w = input.shape()[3];
auto padded_c = c + pad_c; // end padding only
auto padded_h = h + pad_h * 2; // front and end padding
auto padded_w = w + pad_w * 2; // front and end padding
auto padded_h32 = round_up(padded_h, TILE_HEIGHT);
auto padded_w32 = round_up(padded_w, TILE_WIDTH);

auto L1_mem_config = tt::tt_metal::MemoryConfig{.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED, .buffer_type=BufferType::L1};

tt::log_debug("input: {}", input.shape());

// pad input tensor
tt::tt_metal::Array4D padded_shape = {n, padded_c, padded_h32, padded_w32};
auto pad_output = ttnn::pad(input, padded_shape, tt::tt_metal::Array4D({0, 0, 0, 0}), 0);

tt::log_debug("pad_output: {}", pad_output.shape());

// transpose
auto transpose_hw_output = ttnn::transpose(pad_output, 2, 3, L1_mem_config);

tt::log_debug("transpose_hw_output: {}", transpose_hw_output.shape());

// transpose
auto transpose_hc_output = ttnn::transpose(transpose_hw_output, 1, 2, L1_mem_config);

tt::log_debug("transpose_hc_output: {}", transpose_hc_output.shape());

// reshape
n = transpose_hc_output.shape()[0], w = transpose_hc_output.shape()[1], c = transpose_hc_output.shape()[2], h = transpose_hc_output.shape()[3];
auto reshape_hc_output = tt::tt_metal::reshape(transpose_hc_output, n, (w / stride_w), (c * stride_w), h, L1_mem_config);

tt::log_debug("reshape_hc_output: {}", reshape_hc_output.shape());

// transpose
auto transpose_hw_output2 = ttnn::transpose(reshape_hc_output, 2, 3, L1_mem_config);

tt::log_debug("transpose_hw_output2: {}", transpose_hw_output2.shape());

// reshape
n = transpose_hw_output2.shape()[0], w = transpose_hw_output2.shape()[1], h = transpose_hw_output2.shape()[2], c = transpose_hw_output2.shape()[3];
auto reshape_hw_output = tt::tt_metal::reshape(transpose_hw_output2, n, w, (h / stride_h), (c * stride_h), L1_mem_config);

tt::log_debug("reshape_hw_output: {}", reshape_hw_output.shape());

// transpose
auto transpose_hc_output2 = ttnn::transpose(reshape_hw_output, 1, 2, L1_mem_config);

tt::log_debug("transpose_hc_output2: {}", transpose_hc_output2.shape());

std::vector<Tensor> output_tensors;
if (output_shape.has_value()) {
// slice
n = output_shape.value()[0], w = output_shape.value()[1], h = output_shape.value()[2], c = output_shape.value()[3];
tt::tt_metal::Array4D slice_output_tensor_start = {0, 0, 0, 0};
tt::tt_metal::Array4D slice_output_tensor_end = {n - 1, w - 1, h - 1, c - 1};
auto slice_output = ttnn::slice(transpose_hc_output2, slice_output_tensor_start, slice_output_tensor_end, L1_mem_config);

output_tensors.emplace_back(slice_output);

tt::log_debug("slice_output: {}", slice_output.shape());
} else {
output_tensors.emplace_back(transpose_hc_output2);
}

return output_tensors;

}

FoldOpParallelizationStrategy Fold::get_parallelization_strategy(const std::vector<Tensor> &input_tensors) const {
if (is_sharded) {
return FoldOpParallelizationStrategy::SHARDED_MULTI_CORE;
@@ -85,9 +173,12 @@ operation::ProgramWithCallbacks Fold::create_program(
}
}

Tensor fold(const Tensor &input_tensor, uint8_t stride_h, uint8_t stride_w) {
Tensor fold(const Tensor &input_tensor, uint8_t stride_h, uint8_t stride_w, bool use_transpose_as_fold, const std::optional<const Shape>& output_shape, uint8_t pad_c, uint8_t pad_h, uint8_t pad_w) {
bool is_sharded = input_tensor.is_sharded();

if (use_transpose_as_fold) {
return operation::decorate_as_composite(__func__, fold_with_transpose_)(input_tensor, output_shape, stride_h, stride_w, pad_c, pad_h, pad_w).at(0);
}
return operation::run(Fold{.stride_h = stride_h, .stride_w = stride_w, .is_sharded = is_sharded}, {input_tensor})
.at(0);
}
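The transpose-as-fold path is just a chain of materialized data movements, so it can be sanity-checked on host. A minimal torch sketch that mirrors the same transpose → reshape sequence (the function name is mine; it assumes channel/height padding has already been applied and that h and w are divisible by the strides):

import torch

def fold_via_transpose_host(x: torch.Tensor, stride_h: int, stride_w: int) -> torch.Tensor:
    # x is NCHW with h % stride_h == 0 and w % stride_w == 0
    n, c, h, w = x.shape
    x = x.transpose(2, 3)                             # (n, c, w, h)
    x = x.transpose(1, 2)                             # (n, w, c, h)
    x = x.reshape(n, w // stride_w, c * stride_w, h)  # fold w blocks into channels
    x = x.transpose(2, 3)                             # (n, w/sw, h, c*sw)
    x = x.reshape(n, w // stride_w, h // stride_h, c * stride_w * stride_h)
    return x.transpose(1, 2)                          # (n, h/sh, w/sw, c*sh*sw)

On device each intermediate is materialized in L1, which is why the op logs the shape after every step.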
@@ -36,5 +36,5 @@ operation::ProgramWithCallbacks fold_single_core(
operation::ProgramWithCallbacks fold_multi_core(
const Tensor &input, const Tensor &output, uint8_t stride_h, uint8_t stride_w);

Tensor fold(const Tensor &input_tensor_a, uint8_t stride_h, uint8_t stride_w);
Tensor fold(const Tensor &input_tensor_a, uint8_t stride_h, uint8_t stride_w, bool use_transpose_as_fold = false, const std::optional<const Shape>& output_shape = std::nullopt, uint8_t pad_c = 0, uint8_t pad_h = 0, uint8_t pad_w = 0);
} // namespace tt::tt_metal
@@ -209,7 +209,7 @@ namespace tt::tt_metal::detail{
)doc");

m_tensor.def("fold", &fold,
py::arg("input").noconvert(), py::arg("stride_h"), py::arg("stride_w"), R"doc(
py::arg("input").noconvert(), py::arg("stride_h"), py::arg("stride_w"), py::arg("use_transpose_as_fold")=false, py::arg("output_shape")=std::nullopt, py::arg("pad_c")=0, py::arg("pad_h")=0, py::arg("pad_w")=0, R"doc(
Fold TT Tensor.
Input tensor must be on TT accelerator device, in ROW_MAJOR.
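With the extended binding, the composite path is reachable from Python. A usage sketch along the lines of the updated test (n, padded_h, padded_w, C, c, and pad_h are assumed to be computed as in that test):

tt_output_tensor = ttl.tensor.fold(
    tt_input_tensor,        # ROW_MAJOR tensor on device
    stride_h,
    stride_w,
    use_transpose_as_fold=True,
    output_shape=(n, padded_h // stride_h, padded_w // stride_w, C * stride_h * stride_w),
    pad_c=C - c,            # pad channels up to the nearest multiple of 4
    pad_h=pad_h,
    pad_w=0,                # width padding already applied on host
)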
17 changes: 8 additions & 9 deletions ttnn/cpp/ttnn/operations/data_movement/pad/pad.cpp
@@ -70,9 +70,6 @@ static ttnn::Tensor pad_impl(
TT_FATAL(
padding.size() == original_rank,
"ttnn.pad: padding must be the same length as the input tensor rank");
TT_FATAL(
input_tensor.get_layout() != ttnn::ROW_MAJOR_LAYOUT,
"ttnn.pad: row-major tensors have to use fallback because the kernel currently causes a PCC error");

// Unsqueeze Tensor to 4D if it is not already
ttnn::Tensor input_tensor_4D = ttnn::unsqueeze_to_4D(input_tensor);
@@ -94,12 +91,14 @@
front_padding_is_zero,
"ttnn.pad: on device padding does not support front padding");

const int target_height = output_padded_shape[padding.size() - 2];
const int target_width = output_padded_shape[padding.size() - 1];
TT_FATAL(
target_height % ttnn::TILE_SIZE == 0 || target_width % ttnn::TILE_SIZE == 0,
"ttnn.pad: for tiled tensors padding end must be a multiple of the tile size on height and width for a "
"tensor in tile layout");
if (input_tensor.get_layout() == ttnn::TILE_LAYOUT) {
const int target_height = output_padded_shape[padding.size() - 2];
const int target_width = output_padded_shape[padding.size() - 1];
TT_FATAL(
target_height % ttnn::TILE_SIZE == 0 || target_width % ttnn::TILE_SIZE == 0,
"ttnn.pad: for tiled tensors padding end must be a multiple of the tile size on height and width for a "
"tensor in tile layout");
}

// Performing actual padding
ShapeType pad_front_array;
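Dropping the row-major fatal means ttnn.pad no longer forces a host fallback for ROW_MAJOR inputs, which the transpose-as-fold path depends on; the tile-multiple check is now scoped to TILE_LAYOUT tensors only. A hedged Python sketch of the newly allowed case (shapes and pad amounts are illustrative, and it assumes the padding-list overload of ttnn.pad):

import torch
import ttnn

device = ttnn.open_device(device_id=0)
x = ttnn.from_torch(
    torch.randn(1, 4, 6, 8),    # arbitrary, non-tile-aligned shape
    layout=ttnn.ROW_MAJOR_LAYOUT,
    device=device,
    memory_config=ttnn.L1_MEMORY_CONFIG,
)
# Back padding only: front padding is still rejected on device.
y = ttnn.pad(x, [(0, 0), (0, 0), (0, 2), (0, 8)], value=0.0)
ttnn.close_device(device)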
@@ -56,56 +56,3 @@ void kernel_main() {
cb_push_back(cb_id_in0, num_read_per_barrier);
}
}


// // SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
// //
// // SPDX-License-Identifier: Apache-2.0

// #include <stdint.h>
// #include "dataflow_api.h"

// void kernel_main() {

// const uint32_t src_addr = get_arg_val<uint32_t>(0);
// const uint32_t padded_stick_size = get_arg_val<uint32_t>(1);
// const uint32_t unpadded_stick_size = get_arg_val<uint32_t>(2);
// const uint32_t num_dims = get_arg_val<uint32_t>(3);
// const uint32_t start_id = get_arg_val<uint32_t>(4);
// const uint32_t num_sticks = get_arg_val<uint32_t>(5);

// tt_l1_ptr uint32_t * num_unpadded_sticks = (tt_l1_ptr uint32_t*)(get_arg_addr(6));
// volatile tt_l1_ptr uint32_t * num_padded_sticks = num_unpadded_sticks + num_dims;
// volatile tt_l1_ptr uint32_t * id_per_dim = num_padded_sticks + num_dims;

// constexpr bool src0_is_dram = get_compile_time_arg_val(0) == 1;

// const InterleavedAddrGen<src0_is_dram> s0 = {
// .bank_base_address = src_addr,
// .page_size = padded_stick_size
// };

// constexpr uint32_t cb_id_in0 = 0;

// uint32_t src_stick_id = start_id;

// for(uint32_t i = 0; i < num_sticks; i++) {
// // Copy Input
// cb_reserve_back(cb_id_in0, 1);
// uint32_t src_buffer_l1_addr = get_write_ptr(cb_id_in0);
// uint64_t src_noc_addr = get_noc_addr(src_stick_id, s0);
// noc_async_read(src_noc_addr, src_buffer_l1_addr, unpadded_stick_size);
// noc_async_read_barrier();
// cb_push_back(cb_id_in0, 1);
// src_stick_id++;
// for(uint32_t j = 0; j < num_dims; j++) {
// id_per_dim[j]++;
// if (id_per_dim[j] == num_unpadded_sticks[j]) {
// id_per_dim[j] = 0;
// src_stick_id += num_padded_sticks[j];
// } else {
// break;
// }
// }
// }
// }
@@ -43,39 +43,3 @@ void kernel_main() {
cb_pop_front(cb_id_out0, num_read_per_barrier);
}
}


// // SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
// //
// // SPDX-License-Identifier: Apache-2.0

// #include <stdint.h>
// #include "dataflow_api.h"

// void kernel_main() {


// uint32_t dst_addr = get_arg_val<uint32_t>(0);
// uint32_t stick_size = get_arg_val<uint32_t>(1);
// uint32_t num_sticks = get_arg_val<uint32_t>(2);
// uint32_t start_id = get_arg_val<uint32_t>(3);

// constexpr uint32_t cb_id_out0 = get_compile_time_arg_val(0);
// constexpr bool dst0_is_dram = get_compile_time_arg_val(1) == 1;


// const InterleavedAddrGen<dst0_is_dram> s0 = {
// .bank_base_address = dst_addr,
// .page_size = stick_size
// };


// for (uint32_t i = start_id; i < start_id + num_sticks; i++) {
// cb_wait_front(cb_id_out0, 1);
// uint32_t l1_read_addr = get_read_ptr(cb_id_out0);
// uint64_t dst_noc_addr = get_noc_addr(i, s0);
// noc_async_write(l1_read_addr, dst_noc_addr, stick_size);
// noc_async_write_barrier();
// cb_pop_front(cb_id_out0, 1);
// }
// }