Remove restriction of input_nsticks_per_core % w == 0 for height sharded tensor inputs.

Signed-off-by: Nilaykumar Patel <[email protected]>
nkpatel-tt committed Nov 19, 2024
1 parent bdbf3e0 commit 607825e
Showing 3 changed files with 135 additions and 51 deletions.
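
For height-sharded inputs, the check removed by this commit (the TT_FATAL in the program factory below) required every core's input shard to hold a whole number of image rows, i.e. sticks per core divisible by in_w. The sketch below is only an illustration, not part of the commit; it assumes an 8x8 worker grid and shows how the newly added [1, 32, 5, 4] test shape can now be split one stick per core instead of one row per core.

# Illustrative sketch (not from the commit); an 8x8 grid of 64 cores is assumed.
def pick_nshards(batch, height, width, max_cores=64, whole_rows_only=False):
    total_sticks = batch * height * width            # one stick per (n, h, w) position
    for nshards in range(min(total_sticks, max_cores), 0, -1):
        if total_sticks % nshards:
            continue                                 # sticks must split evenly across cores
        if whole_rows_only and (total_sticks // nshards) % width:
            continue                                 # old rule: shards end on row boundaries
        return nshards
    return 1

# Newly added test shape [1, 32, 5, 4]; channels do not affect NHW sharding.
print(pick_nshards(1, 5, 4, whole_rows_only=True))   # 5  -> one full row per core
print(pick_nshards(1, 5, 4, whole_rows_only=False))  # 20 -> one stick per core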
25 changes: 20 additions & 5 deletions tests/ttnn/unit_tests/operations/test_upsample.py
@@ -109,16 +109,30 @@ def test_upsample_single_core(device, input_shapes, scale_h, scale_w):
[1, 64, 132, 10],
[1, 32, 8, 8],
[2, 640, 32, 32],
# some random shapes
[1, 32, 5, 4],
[3, 32, 4, 4],
[5, 64, 5, 5],
[1, 128, 5, 8],
[1, 32, 5, 4],
[7, 64, 128, 17],
[3, 64, 132, 19],
],
)
@pytest.mark.parametrize("scale_h", [2])
@pytest.mark.parametrize("scale_w", [2])
@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True)
@pytest.mark.parametrize("scale_h", [2, 3])
@pytest.mark.parametrize("scale_w", [2, 3])
@pytest.mark.parametrize("shard_strategy", [ttnn.ShardStrategy.HEIGHT, ttnn.ShardStrategy.BLOCK])
def test_upsample_multi_core(device, input_shape, scale_h, scale_w, shard_strategy):
## input shape is N C H W
batch_size, num_channels, height, width = input_shape
torch.manual_seed(0)
input = torch.rand(input_shape, dtype=torch.bfloat16)
# for i in range(input_shape[0]):
# for j in range(input_shape[1]):
# for k in range(input_shape[2]):
# for l in range(input_shape[3]):
# input[i, j, k, l] = k * width + l + 1

## golden reference using torch
scale_factor = (scale_h, scale_w)
@@ -136,15 +150,15 @@ def test_upsample_multi_core(device, input_shape, scale_h, scale_w, shard_strategy):
max_grid_size = (device_grid.y, device_grid.x)
if shard_strategy == ttnn.ShardStrategy.HEIGHT:
## nsticks per shard should be divisible by in_w
max_nshards = min(batch_size * height, max_grid_size[0] * max_grid_size[1])
max_nshards = min(batch_size * height * width, max_grid_size[0] * max_grid_size[1])
nshards = max_nshards
while nshards > 0:
if batch_size * height % nshards == 0:
if batch_size * height * width % nshards == 0:
break
nshards -= 1
ncores = nshards
elif shard_strategy == ttnn.ShardStrategy.BLOCK:
max_nshards_h = min(batch_size * height, max_grid_size[0]) ## height along NHW
max_nshards_h = min(batch_size * height * width, max_grid_size[0]) ## height along NHW
max_nshards_w = min(num_channels, max_grid_size[1]) ## width along C
## find nshards_h along NHW
nshards_h = max_nshards_h
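
The block-sharded branch of this search is cut off by the hunk above. As a minimal sketch, assuming it mirrors the height-sharded loop, the test would pick the largest grid height that evenly divides the N*H*W sticks and the largest grid width that evenly divides the channels:

# Sketch under the stated assumption; grid_y and grid_x default to an 8x8 grid.
def pick_block_grid(batch, height, width, channels, grid_y=8, grid_x=8):
    nshards_h = min(batch * height * width, grid_y)
    while nshards_h > 1 and (batch * height * width) % nshards_h:
        nshards_h -= 1                     # NHW sticks split along the grid height
    nshards_w = min(channels, grid_x)
    while nshards_w > 1 and channels % nshards_w:
        nshards_w -= 1                     # channels split along the grid width
    return nshards_h, nshards_w

print(pick_block_grid(2, 32, 32, 640))     # (8, 8) for the [2, 640, 32, 32] case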
@@ -353,6 +367,7 @@ def test_bilinear_multi_core(

## compare the results
torch_result = torch_result.permute(0, 2, 3, 1)

passing, pcc_msg = check_with_pcc_without_tensor_printout(torch_result, output_tensor, pcc=0.999)
allclose = torch.allclose(output_tensor, torch_result, atol=1e-1, rtol=1e-1)
logger.info(pcc_msg)
ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/writer_upsample_multi_core_sharded.cpp
@@ -3,12 +3,28 @@
// SPDX-License-Identifier: Apache-2.0

#include <stdint.h>

#include "dataflow_api.h"
#define ENABLE_DEBUG_PRINT 0

void kernel_main() {
#if ENABLE_DEBUG_PRINT == 1
#include "debug/dprint.h"

inline void print_pages(uint32_t l1_addr, uint32_t pagelen, uint32_t npages, uint32_t start = 0) {
volatile tt_l1_ptr uint16_t* ptr = reinterpret_cast<volatile tt_l1_ptr uint16_t*>(l1_addr) + start * pagelen;
for (uint32_t page = 0; page < npages; ++page) {
DPRINT << start + page << ": ";
for (uint32_t j = 0; j < pagelen; ++j, ++ptr) {
DPRINT << BF16(*ptr) << " ";
}
DPRINT << ENDL();
}
}
#endif

void kernel_main() {
uint32_t stick_nbytes = get_arg_val<uint32_t>(0);
uint32_t in_image_rows_per_core = get_arg_val<uint32_t>(1);
uint32_t in_nsticks_per_core = get_arg_val<uint32_t>(1);
uint32_t scale_h = get_arg_val<uint32_t>(2);
uint32_t scale_w = get_arg_val<uint32_t>(3);
uint32_t in_w = get_arg_val<uint32_t>(4);
@@ -17,46 +33,37 @@ void kernel_main() {
constexpr uint32_t in_cb_id = get_compile_time_arg_val(0);
constexpr uint32_t out_cb_id = get_compile_time_arg_val(1);
constexpr uint32_t is_reader = get_compile_time_arg_val(2);
constexpr uint32_t config_cb_id = get_compile_time_arg_val(3);

uint32_t reader_nsticks_per_core = (in_nsticks_per_core + is_reader) / 2;
uint32_t writer_nsticks_per_core = in_nsticks_per_core / 2;
uint32_t image_row_begin = is_reader ? 0 : reader_nsticks_per_core;
uint32_t image_row_end = is_reader ? reader_nsticks_per_core : in_nsticks_per_core;
uint32_t l1_read_addr = get_read_ptr(in_cb_id);
uint32_t l1_write_addr = get_write_ptr(out_cb_id) + image_row_begin * scale_h * scale_w * stick_nbytes;

uint32_t in_image_row_nbytes = in_w * stick_nbytes;
uint32_t out_image_row_nbytes = out_w * stick_nbytes;
uint32_t reader_image_rows_per_core = (in_image_rows_per_core + is_reader) / 2;
uint32_t writer_image_rows_per_core = in_image_rows_per_core / 2;
uint32_t image_row_begin = is_reader ? 0 : reader_image_rows_per_core;
uint32_t image_row_end = is_reader ? reader_image_rows_per_core : in_image_rows_per_core;
uint32_t l1_read_addr = get_read_ptr(in_cb_id) + image_row_begin * in_image_row_nbytes;
uint32_t l1_write_addr = get_write_ptr(out_cb_id) + image_row_begin * scale_h * out_image_row_nbytes;
uint32_t config_l1_addr = get_read_ptr(config_cb_id);
volatile tt_l1_ptr uint16_t* config_data = reinterpret_cast<volatile tt_l1_ptr uint16_t*>(config_l1_addr);

uint32_t reader_idx = 0;
if (!is_reader) {
reader_idx = 4 * (scale_h * image_row_begin);
}
cb_reserve_back(out_cb_id, out_w);

// assuming shard begins with a new row. TODO: generalize?
for (uint32_t image_row = image_row_begin; image_row < image_row_end; ++image_row) {
uint32_t l1_write_addr_image_row_start = l1_write_addr;
for (uint32_t i = 0; i < in_w; ++i) {
for (uint32_t row_begin = image_row_begin; row_begin < image_row_end; ++row_begin) {
for (uint32_t sh = 0; sh < scale_h; sh++) {
uint16_t corex = config_data[reader_idx++];
uint16_t corey = config_data[reader_idx++];
uint16_t offset = config_data[reader_idx++];
reader_idx++;
uint64_t src_remote_addr = get_noc_addr(corex, corey, l1_read_addr + offset * stick_nbytes);
// replicate stick scale_w times.
for (uint32_t sw = 0; sw < scale_w; ++sw) {
// replicate stick scale_w times.
if constexpr (is_reader) {
uint64_t src_noc_addr = get_noc_addr(l1_read_addr);
noc_async_read(src_noc_addr, l1_write_addr, stick_nbytes);
} else {
uint64_t dst_noc_addr = get_noc_addr(l1_write_addr);
noc_async_write(l1_read_addr, dst_noc_addr, stick_nbytes);
}
for (uint32_t sw = 0; sw < scale_w; sw++) {
noc_async_read(src_remote_addr, l1_write_addr, stick_nbytes);
l1_write_addr += stick_nbytes;
}
l1_read_addr += stick_nbytes;
}

// Duplicate the whole image row in one shot
if constexpr (is_reader) {
uint64_t src_noc_addr = get_noc_addr(l1_write_addr_image_row_start);
noc_async_read(src_noc_addr, l1_write_addr, out_image_row_nbytes);
} else {
uint64_t dst_noc_addr = get_noc_addr(l1_write_addr);
noc_async_write(l1_write_addr_image_row_start, dst_noc_addr, out_image_row_nbytes);
}
l1_write_addr += out_image_row_nbytes;
}

cb_push_back(out_cb_id, out_w);
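
The remaining hunks belong to the host-side program factory, which builds a config tensor recording, for each group of output sticks, the worker core and in-shard offset of the source input stick, and wires it into both kernels through a circular buffer. With the config in place, the kernel above no longer copies whole local image rows; for every output group it looks up the source stick and issues a NOC read, so an input shard may begin or end in the middle of a row. A hypothetical Python simulation of one core's copy loop (not device code) looks like this:

# Hypothetical simulation of the kernel's copy loop. config_records is this
# core's slice of the config tensor, one (core_x, core_y, offset, pad) record
# per group of scale_w output sticks; input_shards maps worker coordinates to
# that core's list of input sticks.
def upsample_shard(config_records, input_shards, n_in_sticks, scale_h, scale_w):
    out = []
    records = iter(config_records)
    for _ in range(n_in_sticks):            # sticks assigned to this core
        for _ in range(scale_h):            # one record per vertical copy
            core_x, core_y, offset, _pad = next(records)
            stick = input_shards[(core_x, core_y)][offset]
            out.extend([stick] * scale_w)   # replicate the stick scale_w times
    return out

The offset in each record is relative to the owning core's input shard, matching the kernel's l1_read_addr + offset * stick_nbytes address calculation.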
@@ -3,22 +3,61 @@
// SPDX-License-Identifier: Apache-2.0

#include <math.h>
#include <vector>

#include "upsample_op.hpp"
#include "ttnn/operations/math.hpp"
#include "buffers/buffer_constants.hpp"
#include "common/core_coord.hpp"
#include "ttnn/tensor/host_buffer/functions.hpp"

#include "tt_metal/host_api.hpp"
#include "tt_metal/common/constants.hpp"
#include "tt_metal/detail/util.hpp"
#include "tt_metal/common/math.hpp"

#include "tt_metal/tt_stl/reflection.hpp"

using namespace tt::constants;

namespace ttnn::operations::upsample {
using namespace tt;

Tensor create_config_tensor(
Device *device,
ShardSpec &input_shard_spec,
const uint32_t batch_size,
const uint32_t in_h,
const uint32_t in_w,
const uint32_t scale_factor_h,
const uint32_t scale_factor_w,
const uint32_t ncores) {
std::vector<uint16_t> config_vector;
uint32_t input_nsticks_per_core = input_shard_spec.shape[0];
uint32_t ncores_x = device->compute_with_storage_grid_size().x;
uint32_t in_core = 0;
uint32_t w = 0;
uint32_t curr_stick = 0;
auto core_coords = device->worker_core_from_logical_core(CoreCoord(in_core % ncores_x, in_core / ncores_x));
for (uint32_t b = 0; b < batch_size; b++) {
for (uint32_t h = 0; h < in_h; h++) {
for (uint32_t w = 0; w < in_w; w++) {
if (curr_stick == input_nsticks_per_core) {
curr_stick = 0;
in_core++;
core_coords =
device->worker_core_from_logical_core(CoreCoord(in_core % ncores_x, in_core / ncores_x));
}
config_vector.insert(config_vector.end(), {core_coords.x, core_coords.y, curr_stick, 0});
curr_stick++;
}
for (uint32_t j = 0; j < scale_factor_h - 1; j++)
config_vector.insert(config_vector.end(), config_vector.end() - (4 * in_w), config_vector.end());
}
}

uint32_t elems_per_core = 4 * scale_factor_h * input_nsticks_per_core;
Shape config_shape = Shape({config_vector.size() / elems_per_core, elems_per_core});
auto config_buffer = owned_buffer::create<uint16_t>(std::move(config_vector));
Tensor config_tensor = Tensor(OwnedStorage{config_buffer}, config_shape, DataType::UINT16, Layout::ROW_MAJOR);
return config_tensor;
}

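create_config_tensor walks the input sticks in N, H, W order, records which core and in-shard offset each stick maps to under the input shard spec, and then repeats each image row's records so they appear once per output row (scale_factor_h copies). A compact Python rendering of the same walk follows; using a flat logical core index in place of the worker NOC coordinates is an assumption made for readability.

# Mirror of the walk above; (core, 0, offset, 0) stands in for the real
# (core_coords.x, core_coords.y, curr_stick, 0) record.
def build_config(batch, in_h, in_w, scale_h, sticks_per_core):
    cfg, core, stick = [], 0, 0
    for _ in range(batch):
        for _ in range(in_h):
            row = []
            for _ in range(in_w):
                if stick == sticks_per_core:     # this stick spills onto the next core
                    stick, core = 0, core + 1
                row.append((core, 0, stick, 0))
                stick += 1
            cfg.extend(row * scale_h)            # the row's records appear scale_h times
    return cfg

# 1x2x3 input, scale_h=2, 4 sticks per shard: the second row straddles two cores,
# the case the removed restriction used to reject.
print(build_config(1, 2, 3, 2, 4))
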
operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor& output, const uint32_t scale_factor_h, const uint32_t scale_factor_w) {
Program program = CreateProgram();
Device *device = input.device();
@@ -53,7 +92,6 @@ operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor&

// extra limitation to avoid post upsample step of resharding
if (input.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) {
TT_FATAL(in_nsticks_per_core % in_w == 0, "Restriction: Input sticks per core {} should be divisible by input width {}. TODO to remove this restriction", in_nsticks_per_core, in_w);
} else if (input.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED) {
ncores_x = all_cores.ranges().begin()->end_coord.x + 1;
ncores_nhw = all_cores.ranges().begin()->end_coord.y + 1;
@@ -68,8 +106,6 @@ operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor&

// TODO: Support non-multiple case
TT_FATAL(in_nsticks_per_core == input_nsticks_per_core, "Input sticks per shard {} should be same as input sticks per core {}", in_nsticks_per_core, input_nsticks_per_core);
TT_FATAL(out_nsticks_per_core == output_nsticks_per_core, "Output sticks per shard {} should be same as output sticks per core {}", out_nsticks_per_core, output_nsticks_per_core);
TT_FATAL(input_nsticks_per_core % in_w == 0, "Error");

// CBs

@@ -105,12 +141,37 @@ operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor&
log_debug(LogOp, "ncores: {}, ncores_x: {}", ncores, ncores_x);
log_debug(LogOp, "input_nsticks_per_core: {}, output_nsticks_per_core: {}", input_nsticks_per_core, output_nsticks_per_core);

// create config tensor
Tensor config_tensor = create_config_tensor(
device,
shard_spec,
input.legacy_shape()[0],
input.legacy_shape()[1],
in_w,
scale_factor_h,
scale_factor_w,
ncores);
auto shard_shape = std::array<uint32_t, 2>({1, (uint32_t)config_tensor.get_shape()[-1]});
ShardSpec config_shard_spec(input.shard_spec().value().grid, shard_shape, ShardOrientation::ROW_MAJOR, false);
MemoryConfig memory_config{TensorMemoryLayout::HEIGHT_SHARDED, BufferType::L1_SMALL, config_shard_spec};
auto config_tensor_device = config_tensor.to(device, memory_config);
tt::tt_metal::detail::AddConfigBuffer(program, config_tensor_device.device_buffer());

tt::DataFormat config_df = tt::DataFormat::RawUInt16;
Buffer *config_buffer = config_tensor_device.buffer();
uint32_t config_cb_id = tt::CB::c_in2;
auto config_cb_config = CircularBufferConfig(config_buffer->size(), {{config_cb_id, config_df}})
.set_page_size(config_cb_id, config_buffer->page_size())
.set_globally_allocated_address(*config_buffer);
CBHandle config_cb = CreateCircularBuffer(program, all_cores, config_cb_config);

// Kernels

std::vector<uint32_t> writer_compile_time_args = {
in_cb_id,
out_cb_id,
false,
config_cb_id,
};
auto writer_kernel_fname = std::string("ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/writer_upsample_multi_core_sharded.cpp");
auto writer_kernel =
@@ -120,6 +181,7 @@ operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor&
in_cb_id,
out_cb_id,
true,
config_cb_id,
};
auto reader_kernel_fname = std::string("ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/writer_upsample_multi_core_sharded.cpp");
auto reader_kernel =
@@ -132,11 +194,11 @@ operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor&
uint32_t writer_nargs = 7;
std::vector<uint32_t> writer_rt_args(writer_nargs);
writer_rt_args[0] = input_stick_nbytes;
writer_rt_args[1] = input_nsticks_per_core / in_w;
writer_rt_args[1] = input_nsticks_per_core;
writer_rt_args[2] = scale_factor_h;
writer_rt_args[3] = scale_factor_w;
writer_rt_args[4] = in_w;
writer_rt_args[5] = out_w;
writer_rt_args[4] = input_nsticks_per_core;
writer_rt_args[5] = output_nsticks_per_core / 2; // half of the outputs are processed by each core
writer_rt_args[6] = 0; // set for each core below

uint32_t start_input_stick_id = 0;
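
Each core's work is split between its reader and writer kernels, so each kernel produces roughly half of the core's output sticks; per the in-line comment, that is why writer_rt_args[5] passes output_nsticks_per_core / 2. The reader covers the first ceil(n/2) input sticks and the writer the rest, matching the (in_nsticks_per_core + is_reader) / 2 arithmetic in the kernel. A small sketch of that split:

# Sketch of the per-core reader/writer split; not part of the commit.
def split_sticks(in_nsticks_per_core):
    reader = (in_nsticks_per_core + 1) // 2      # is_reader == 1 rounds up
    writer = in_nsticks_per_core // 2            # is_reader == 0 rounds down
    return (0, reader), (reader, in_nsticks_per_core)

print(split_sticks(5))    # ((0, 3), (3, 5)) -> reader takes the extra stick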
@@ -162,7 +224,7 @@ operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor&
TT_THROW("Unsupported memory layout");
}

auto override_runtime_args_callback = [writer_kernel, cb_src0, out_cb](
auto override_runtime_args_callback = [writer_kernel, cb_src0, config_cb, out_cb](
const void* operation,
Program &program,
const std::vector<Tensor>& input_tensors,
