diff --git a/tests/ttnn/unit_tests/operations/test_upsample.py b/tests/ttnn/unit_tests/operations/test_upsample.py index 86047a86581..fa57a486650 100644 --- a/tests/ttnn/unit_tests/operations/test_upsample.py +++ b/tests/ttnn/unit_tests/operations/test_upsample.py @@ -109,16 +109,30 @@ def test_upsample_single_core(device, input_shapes, scale_h, scale_w): [1, 64, 132, 10], [1, 32, 8, 8], [2, 640, 32, 32], + # some random shapes + [1, 32, 5, 4], + [3, 32, 4, 4], + [5, 64, 5, 5], + [1, 128, 5, 8], + [1, 32, 5, 4], + [7, 64, 128, 17], + [3, 64, 132, 19], ], ) -@pytest.mark.parametrize("scale_h", [2]) -@pytest.mark.parametrize("scale_w", [2]) +@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True) +@pytest.mark.parametrize("scale_h", [2, 3]) +@pytest.mark.parametrize("scale_w", [2, 3]) @pytest.mark.parametrize("shard_strategy", [ttnn.ShardStrategy.HEIGHT, ttnn.ShardStrategy.BLOCK]) def test_upsample_multi_core(device, input_shape, scale_h, scale_w, shard_strategy): ## input shape is N C H W batch_size, num_channels, height, width = input_shape torch.manual_seed(0) input = torch.rand(input_shape, dtype=torch.bfloat16) + # for i in range(input_shape[0]): + # for j in range(input_shape[1]): + # for k in range(input_shape[2]): + # for l in range(input_shape[3]): + # input[i, j, k, l] = k * width + l + 1 ## golden reference using torch scale_factor = (scale_h, scale_w) @@ -136,15 +150,15 @@ def test_upsample_multi_core(device, input_shape, scale_h, scale_w, shard_strate max_grid_size = (device_grid.y, device_grid.x) if shard_strategy == ttnn.ShardStrategy.HEIGHT: ## nsticks per shard should be divisible by in_w - max_nshards = min(batch_size * height, max_grid_size[0] * max_grid_size[1]) + max_nshards = min(batch_size * height * width, max_grid_size[0] * max_grid_size[1]) nshards = max_nshards while nshards > 0: - if batch_size * height % nshards == 0: + if batch_size * height * width % nshards == 0: break nshards -= 1 ncores = nshards elif shard_strategy == ttnn.ShardStrategy.BLOCK: - max_nshards_h = min(batch_size * height, max_grid_size[0]) ## height along NHW + max_nshards_h = min(batch_size * height * width, max_grid_size[0]) ## height along NHW max_nshards_w = min(num_channels, max_grid_size[1]) ## width along C ## find nshards_h along NHW nshards_h = max_nshards_h @@ -353,6 +367,7 @@ def test_bilinear_multi_core( ## compare the results torch_result = torch_result.permute(0, 2, 3, 1) + passing, pcc_msg = check_with_pcc_without_tensor_printout(torch_result, output_tensor, pcc=0.999) allclose = torch.allclose(output_tensor, torch_result, atol=1e-1, rtol=1e-1) logger.info(pcc_msg) diff --git a/ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/writer_upsample_multi_core_sharded.cpp b/ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/writer_upsample_multi_core_sharded.cpp index 03530ea7433..91e9a6ff9a2 100644 --- a/ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/writer_upsample_multi_core_sharded.cpp +++ b/ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/writer_upsample_multi_core_sharded.cpp @@ -3,12 +3,28 @@ // SPDX-License-Identifier: Apache-2.0 #include + #include "dataflow_api.h" +#define ENABLE_DEBUG_PRINT 0 -void kernel_main() { +#if ENABLE_DEBUG_PRINT == 1 +#include "debug/dprint.h" + +inline void print_pages(uint32_t l1_addr, uint32_t pagelen, uint32_t npages, uint32_t start = 0) { + volatile tt_l1_ptr uint16_t* ptr = reinterpret_cast(l1_addr) + start * pagelen; + for (uint32_t page = 0; page < npages; ++page) { + DPRINT << start + page << ": "; + for (uint32_t j = 0; j < pagelen; ++j, ++ptr) { + DPRINT << BF16(*ptr) << " "; + } + DPRINT << ENDL(); + } +} +#endif +void kernel_main() { uint32_t stick_nbytes = get_arg_val(0); - uint32_t in_image_rows_per_core = get_arg_val(1); + uint32_t in_nsticks_per_core = get_arg_val(1); uint32_t scale_h = get_arg_val(2); uint32_t scale_w = get_arg_val(3); uint32_t in_w = get_arg_val(4); @@ -17,46 +33,37 @@ void kernel_main() { constexpr uint32_t in_cb_id = get_compile_time_arg_val(0); constexpr uint32_t out_cb_id = get_compile_time_arg_val(1); constexpr uint32_t is_reader = get_compile_time_arg_val(2); + constexpr uint32_t config_cb_id = get_compile_time_arg_val(3); + + uint32_t reader_nsticks_per_core = (in_nsticks_per_core + is_reader) / 2; + uint32_t writer_nsticks_per_core = in_nsticks_per_core / 2; + uint32_t image_row_begin = is_reader ? 0 : reader_nsticks_per_core; + uint32_t image_row_end = is_reader ? reader_nsticks_per_core : in_nsticks_per_core; + uint32_t l1_read_addr = get_read_ptr(in_cb_id); + uint32_t l1_write_addr = get_write_ptr(out_cb_id) + image_row_begin * scale_h * scale_w * stick_nbytes; - uint32_t in_image_row_nbytes = in_w * stick_nbytes; - uint32_t out_image_row_nbytes = out_w * stick_nbytes; - uint32_t reader_image_rows_per_core = (in_image_rows_per_core + is_reader) / 2; - uint32_t writer_image_rows_per_core = in_image_rows_per_core / 2; - uint32_t image_row_begin = is_reader ? 0 : reader_image_rows_per_core; - uint32_t image_row_end = is_reader ? reader_image_rows_per_core : in_image_rows_per_core; - uint32_t l1_read_addr = get_read_ptr(in_cb_id) + image_row_begin * in_image_row_nbytes; - uint32_t l1_write_addr = get_write_ptr(out_cb_id) + image_row_begin * scale_h * out_image_row_nbytes; + uint32_t config_l1_addr = get_read_ptr(config_cb_id); + volatile tt_l1_ptr uint16_t* config_data = reinterpret_cast(config_l1_addr); + uint32_t reader_idx = 0; + if (!is_reader) { + reader_idx = 4 * (scale_h * image_row_begin); + } cb_reserve_back(out_cb_id, out_w); - // assuming shard begins with a new row. TODO: generalize? - for (uint32_t image_row = image_row_begin; image_row < image_row_end; ++image_row) { - uint32_t l1_write_addr_image_row_start = l1_write_addr; - for (uint32_t i = 0; i < in_w; ++i) { + for (uint32_t row_begin = image_row_begin; row_begin < image_row_end; ++row_begin) { + for (uint32_t sh = 0; sh < scale_h; sh++) { + uint16_t corex = config_data[reader_idx++]; + uint16_t corey = config_data[reader_idx++]; + uint16_t offset = config_data[reader_idx++]; + reader_idx++; + uint64_t src_remote_addr = get_noc_addr(corex, corey, l1_read_addr + offset * stick_nbytes); // replicate stick scale_w times. - for (uint32_t sw = 0; sw < scale_w; ++sw) { - // replicate stick scale_w times. - if constexpr (is_reader) { - uint64_t src_noc_addr = get_noc_addr(l1_read_addr); - noc_async_read(src_noc_addr, l1_write_addr, stick_nbytes); - } else { - uint64_t dst_noc_addr = get_noc_addr(l1_write_addr); - noc_async_write(l1_read_addr, dst_noc_addr, stick_nbytes); - } + for (uint32_t sw = 0; sw < scale_w; sw++) { + noc_async_read(src_remote_addr, l1_write_addr, stick_nbytes); l1_write_addr += stick_nbytes; } - l1_read_addr += stick_nbytes; - } - - // Duplicate the whole image row in one shot - if constexpr (is_reader) { - uint64_t src_noc_addr = get_noc_addr(l1_write_addr_image_row_start); - noc_async_read(src_noc_addr, l1_write_addr, out_image_row_nbytes); - } else { - uint64_t dst_noc_addr = get_noc_addr(l1_write_addr); - noc_async_write(l1_write_addr_image_row_start, dst_noc_addr, out_image_row_nbytes); } - l1_write_addr += out_image_row_nbytes; } cb_push_back(out_cb_id, out_w); diff --git a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp index b2deccc8f2f..507aa6c18ae 100644 --- a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp +++ b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp @@ -3,22 +3,61 @@ // SPDX-License-Identifier: Apache-2.0 #include +#include -#include "upsample_op.hpp" -#include "ttnn/operations/math.hpp" +#include "buffers/buffer_constants.hpp" +#include "common/core_coord.hpp" +#include "ttnn/tensor/host_buffer/functions.hpp" #include "tt_metal/host_api.hpp" -#include "tt_metal/common/constants.hpp" -#include "tt_metal/detail/util.hpp" #include "tt_metal/common/math.hpp" -#include "tt_metal/tt_stl/reflection.hpp" using namespace tt::constants; namespace ttnn::operations::upsample { using namespace tt; +Tensor create_config_tensor( + Device *device, + ShardSpec &input_shard_spec, + const uint32_t batch_size, + const uint32_t in_h, + const uint32_t in_w, + const uint32_t scale_factor_h, + const uint32_t scale_factor_w, + const uint32_t ncores) { + std::vector config_vector; + uint32_t input_nsticks_per_core = input_shard_spec.shape[0]; + uint32_t ncores_x = device->compute_with_storage_grid_size().x; + uint32_t in_core = 0; + uint32_t w = 0; + uint32_t curr_stick = 0; + auto core_coords = device->worker_core_from_logical_core(CoreCoord(in_core % ncores_x, in_core / ncores_x)); + for (uint32_t b = 0; b < batch_size; b++) { + for (uint32_t h = 0; h < in_h; h++) { + for (uint32_t w = 0; w < in_w; w++) { + if (curr_stick == input_nsticks_per_core) { + curr_stick = 0; + in_core++; + core_coords = + device->worker_core_from_logical_core(CoreCoord(in_core % ncores_x, in_core / ncores_x)); + } + config_vector.insert(config_vector.end(), {core_coords.x, core_coords.y, curr_stick, 0}); + curr_stick++; + } + for (uint32_t j = 0; j < scale_factor_h - 1; j++) + config_vector.insert(config_vector.end(), config_vector.end() - (4 * in_w), config_vector.end()); + } + } + + uint32_t elems_per_core = 4 * scale_factor_h * input_nsticks_per_core; + Shape config_shape = Shape({config_vector.size() / elems_per_core, elems_per_core}); + auto config_buffer = owned_buffer::create(std::move(config_vector)); + Tensor config_tensor = Tensor(OwnedStorage{config_buffer}, config_shape, DataType::UINT16, Layout::ROW_MAJOR); + return config_tensor; +} + operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor& output, const uint32_t scale_factor_h, const uint32_t scale_factor_w) { Program program = CreateProgram(); Device *device = input.device(); @@ -53,7 +92,6 @@ operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor& // extra limitation to avoid post upsample step of resharding if (input.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) { - TT_FATAL(in_nsticks_per_core % in_w == 0, "Restriction: Input sticks per core {} should be divisible by input width {}. TODO to remove this restriction", in_nsticks_per_core, in_w); } else if (input.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED) { ncores_x = all_cores.ranges().begin()->end_coord.x + 1; ncores_nhw = all_cores.ranges().begin()->end_coord.y + 1; @@ -68,8 +106,6 @@ operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor& // TODO: Support non-multiple case TT_FATAL(in_nsticks_per_core == input_nsticks_per_core, "Input sticks per shard {} should be same as input sticks per core {}", in_nsticks_per_core, input_nsticks_per_core); - TT_FATAL(out_nsticks_per_core == output_nsticks_per_core, "Output sticks per shard {} should be same as output sticks per core {}", out_nsticks_per_core, output_nsticks_per_core); - TT_FATAL(input_nsticks_per_core % in_w == 0, "Error"); // CBs @@ -105,12 +141,37 @@ operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor& log_debug(LogOp, "ncores: {}, ncores_x: {}", ncores, ncores_x); log_debug(LogOp, "input_nsticks_per_core: {}, output_nsticks_per_core: {}", input_nsticks_per_core, output_nsticks_per_core); + // create config tensor + Tensor config_tensor = create_config_tensor( + device, + shard_spec, + input.legacy_shape()[0], + input.legacy_shape()[1], + in_w, + scale_factor_h, + scale_factor_w, + ncores); + auto shard_shape = std::array({1, (uint32_t)config_tensor.get_shape()[-1]}); + ShardSpec config_shard_spec(input.shard_spec().value().grid, shard_shape, ShardOrientation::ROW_MAJOR, false); + MemoryConfig memory_config{TensorMemoryLayout::HEIGHT_SHARDED, BufferType::L1_SMALL, config_shard_spec}; + auto config_tensor_device = config_tensor.to(device, memory_config); + tt::tt_metal::detail::AddConfigBuffer(program, config_tensor_device.device_buffer()); + + tt::DataFormat config_df = tt::DataFormat::RawUInt16; + Buffer *config_buffer = config_tensor_device.buffer(); + uint32_t config_cb_id = tt::CB::c_in2; + auto config_cb_config = CircularBufferConfig(config_buffer->size(), {{config_cb_id, config_df}}) + .set_page_size(config_cb_id, config_buffer->page_size()) + .set_globally_allocated_address(*config_buffer); + CBHandle config_cb = CreateCircularBuffer(program, all_cores, config_cb_config); + // Kernels std::vector writer_compile_time_args = { in_cb_id, out_cb_id, false, + config_cb_id, }; auto writer_kernel_fname = std::string("ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/writer_upsample_multi_core_sharded.cpp"); auto writer_kernel = @@ -120,6 +181,7 @@ operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor& in_cb_id, out_cb_id, true, + config_cb_id, }; auto reader_kernel_fname = std::string("ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/writer_upsample_multi_core_sharded.cpp"); auto reader_kernel = @@ -132,11 +194,11 @@ operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor& uint32_t writer_nargs = 7; std::vector writer_rt_args(writer_nargs); writer_rt_args[0] = input_stick_nbytes; - writer_rt_args[1] = input_nsticks_per_core / in_w; + writer_rt_args[1] = input_nsticks_per_core; writer_rt_args[2] = scale_factor_h; writer_rt_args[3] = scale_factor_w; - writer_rt_args[4] = in_w; - writer_rt_args[5] = out_w; + writer_rt_args[4] = input_nsticks_per_core; + writer_rt_args[5] = output_nsticks_per_core / 2; // half of the outputs are processed by each core writer_rt_args[6] = 0; // set for each core below uint32_t start_input_stick_id = 0; @@ -162,7 +224,7 @@ operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor& TT_THROW("Unsupported memory layout"); } - auto override_runtime_args_callback = [writer_kernel, cb_src0, out_cb]( + auto override_runtime_args_callback = [writer_kernel, cb_src0, config_cb, out_cb]( const void* operation, Program &program, const std::vector& input_tensors,