diff --git a/tests/ttnn/unit_tests/operations/test_maxpool2d.py b/tests/ttnn/unit_tests/operations/test_maxpool2d.py index 6ee1bb97cd8a..88ffd61332e9 100644 --- a/tests/ttnn/unit_tests/operations/test_maxpool2d.py +++ b/tests/ttnn/unit_tests/operations/test_maxpool2d.py @@ -406,7 +406,7 @@ def run_max_pool( @pytest.mark.parametrize( "in_place_halo", [ - # False, + False, True, ], ) diff --git a/ttnn/cpp/ttnn/operations/data_movement/untilize_with_halo_v2/device/kernels/dataflow/halo_gather.cpp b/ttnn/cpp/ttnn/operations/data_movement/untilize_with_halo_v2/device/kernels/dataflow/halo_gather.cpp index 585ee4f7bf25..92146dfb42f7 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/untilize_with_halo_v2/device/kernels/dataflow/halo_gather.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/untilize_with_halo_v2/device/kernels/dataflow/halo_gather.cpp @@ -89,18 +89,18 @@ void kernel_main() { constexpr uint32_t padding_config_cb_id = get_compile_time_arg_val(0); // has untilized input shard constexpr uint32_t local_config_cb_id = get_compile_time_arg_val(1); // has untilized input shard constexpr uint32_t remote_config_cb_id = get_compile_time_arg_val(2); // has untilized input shard - constexpr uint32_t src_cb_id = get_compile_time_arg_val(5); // has untilized input shard - constexpr uint32_t in_cb_id = get_compile_time_arg_val(6); // has untilized input shard - constexpr uint32_t out_cb_id = get_compile_time_arg_val(7); // output shard with padding and halo goes here - constexpr uint32_t pad_cb_id = get_compile_time_arg_val(8); // cb for const pad val buffer - constexpr uint32_t pad_val_u32 = get_compile_time_arg_val(9); // pad value to fill pad buffer with - constexpr uint32_t in_nsticks = get_compile_time_arg_val(10); // number of sticks - constexpr uint32_t stick_nbytes = get_compile_time_arg_val(11); // stick size in bytes (post untilize) - constexpr uint32_t is_block_sharded = get_compile_time_arg_val(12); - constexpr uint32_t remote_read = get_compile_time_arg_val(13); - constexpr bool is_col_major = get_compile_time_arg_val(14) == 1; - constexpr uint32_t is_width_sharded = get_compile_time_arg_val(15); - constexpr uint32_t input_aligned_page_size = get_compile_time_arg_val(16); + constexpr uint32_t src_cb_id = get_compile_time_arg_val(4); // has untilized input shard + constexpr uint32_t in_cb_id = get_compile_time_arg_val(5); // has untilized input shard + constexpr uint32_t out_cb_id = get_compile_time_arg_val(6); // output shard with padding and halo goes here + constexpr uint32_t pad_cb_id = get_compile_time_arg_val(7); // cb for const pad val buffer + constexpr uint32_t pad_val_u32 = get_compile_time_arg_val(8); // pad value to fill pad buffer with + constexpr uint32_t in_nsticks = get_compile_time_arg_val(9); // number of sticks + constexpr uint32_t stick_nbytes = get_compile_time_arg_val(10); // stick size in bytes (post untilize) + constexpr uint32_t is_block_sharded = get_compile_time_arg_val(11); + constexpr uint32_t remote_read = get_compile_time_arg_val(12); + constexpr bool is_col_major = get_compile_time_arg_val(13) == 1; + constexpr uint32_t is_width_sharded = get_compile_time_arg_val(14); + constexpr uint32_t input_aligned_page_size = get_compile_time_arg_val(15); constexpr uint32_t elem_nbytes = sizeof(uint16_t); constexpr uint16_t pad_core_id = 0xFFFF; diff --git a/ttnn/cpp/ttnn/operations/data_movement/untilize_with_halo_v2/device/kernels/dataflow/halo_gather_in_place.cpp b/ttnn/cpp/ttnn/operations/data_movement/untilize_with_halo_v2/device/kernels/dataflow/halo_gather_in_place.cpp index c0a71051ae99..f939e7ca1328 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/untilize_with_halo_v2/device/kernels/dataflow/halo_gather_in_place.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/untilize_with_halo_v2/device/kernels/dataflow/halo_gather_in_place.cpp @@ -155,16 +155,15 @@ void copy_sticks_async( uint64_t dst_addr = base_addr + dst_offset; uint32_t src_addr = in_base_l1_addr + src_offset; - if ((dst_local_idx > src_local_idx + in_out_buffer_start_delta && - dst_local_idx <= src_local_idx + in_out_buffer_start_delta + nsticks) || - (dst_local_idx + nsticks >= src_local_idx + in_out_buffer_start_delta && - dst_local_idx + nsticks < - src_local_idx + in_out_buffer_start_delta + - nsticks)) { // dst and src data overlaps, stick by stick copy is necessary - if (dst_local_idx > src_local_idx + in_out_buffer_start_delta && - dst_local_idx <= src_local_idx + in_out_buffer_start_delta + - nsticks) { // dst data is being moved "in front" of the source data, reverse - // ordering of stick by stick copy is necessary + bool is_forward_copy = dst_local_idx > src_local_idx + in_out_buffer_start_delta && + dst_local_idx <= src_local_idx + in_out_buffer_start_delta + nsticks; + bool is_overlap_copy = (dst_local_idx > src_local_idx + in_out_buffer_start_delta && + dst_local_idx <= src_local_idx + in_out_buffer_start_delta + nsticks) || + (dst_local_idx + nsticks >= src_local_idx + in_out_buffer_start_delta && + dst_local_idx + nsticks < src_local_idx + in_out_buffer_start_delta + nsticks); + if (is_overlap_copy) { // dst and src data overlaps, stick by stick copy is necessary + if (is_forward_copy) { // dst data is being moved "in front" of the source data, reverse + // ordering of stick by stick copy is necessary for (int16_t k = nsticks - 1; k >= 0; k--) { noc_async_write(src_addr + k * stick_nbytes, dst_addr + k * stick_nbytes, stick_nbytes); } @@ -186,27 +185,26 @@ void kernel_main() { constexpr uint32_t padding_config_cb_id = get_compile_time_arg_val(0); // has untilized input shard constexpr uint32_t local_config_cb_id = get_compile_time_arg_val(1); // has untilized input shard constexpr uint32_t remote_config_cb_id = get_compile_time_arg_val(2); // has untilized input shard - constexpr uint32_t remote_ref_counts_cb_id = get_compile_time_arg_val(3); // has untilized input shard - constexpr uint32_t remote_temp_cb_id = get_compile_time_arg_val(4); // has untilized input shard - constexpr uint32_t src_cb_id = get_compile_time_arg_val(5); // has untilized input shard - constexpr uint32_t in_cb_id = get_compile_time_arg_val(6); // has untilized input shard - constexpr uint32_t out_cb_id = get_compile_time_arg_val(7); // output shard with padding and halo goes here - constexpr uint32_t pad_cb_id = get_compile_time_arg_val(8); // cb for const pad val buffer - constexpr uint32_t pad_val_u32 = get_compile_time_arg_val(9); // pad value to fill pad buffer with - constexpr uint32_t in_nsticks = get_compile_time_arg_val(10); // number of sticks - constexpr uint32_t stick_nbytes = get_compile_time_arg_val(11); // stick size in bytes (post untilize) - constexpr uint32_t is_block_sharded = get_compile_time_arg_val(12); - constexpr uint32_t remote_read = get_compile_time_arg_val(13); - constexpr bool is_col_major = get_compile_time_arg_val(14) == 1; - constexpr uint32_t is_width_sharded = get_compile_time_arg_val(15); - constexpr uint32_t input_aligned_page_size = get_compile_time_arg_val(16); - constexpr uint32_t noc_00_x = get_compile_time_arg_val(17); - constexpr uint32_t noc_00_y = get_compile_time_arg_val(18); - constexpr uint32_t num_cores_nhw = get_compile_time_arg_val(19); - constexpr uint32_t num_cores_c = get_compile_time_arg_val(20); - constexpr uint32_t num_cores_x = get_compile_time_arg_val(21); - constexpr uint32_t semaphore_id = get_compile_time_arg_val(22); - constexpr uint32_t max_out_nsticks_per_core = get_compile_time_arg_val(23); + constexpr uint32_t remote_temp_cb_id = get_compile_time_arg_val(3); // has untilized input shard + constexpr uint32_t src_cb_id = get_compile_time_arg_val(4); // has untilized input shard + constexpr uint32_t in_cb_id = get_compile_time_arg_val(5); // has untilized input shard + constexpr uint32_t out_cb_id = get_compile_time_arg_val(6); // output shard with padding and halo goes here + constexpr uint32_t pad_cb_id = get_compile_time_arg_val(7); // cb for const pad val buffer + constexpr uint32_t pad_val_u32 = get_compile_time_arg_val(8); // pad value to fill pad buffer with + constexpr uint32_t in_nsticks = get_compile_time_arg_val(9); // number of sticks + constexpr uint32_t stick_nbytes = get_compile_time_arg_val(10); // stick size in bytes (post untilize) + constexpr uint32_t is_block_sharded = get_compile_time_arg_val(11); + constexpr uint32_t remote_read = get_compile_time_arg_val(12); + constexpr bool is_col_major = get_compile_time_arg_val(13) == 1; + constexpr uint32_t is_width_sharded = get_compile_time_arg_val(14); + constexpr uint32_t input_aligned_page_size = get_compile_time_arg_val(15); + constexpr uint32_t noc_00_x = get_compile_time_arg_val(16); + constexpr uint32_t noc_00_y = get_compile_time_arg_val(17); + constexpr uint32_t num_cores_nhw = get_compile_time_arg_val(18); + constexpr uint32_t num_cores_c = get_compile_time_arg_val(19); + constexpr uint32_t num_cores_x = get_compile_time_arg_val(20); + constexpr uint32_t semaphore_id = get_compile_time_arg_val(21); + constexpr uint32_t max_out_nsticks_per_core = get_compile_time_arg_val(22); constexpr uint32_t num_cores = num_cores_nhw * num_cores_c; diff --git a/ttnn/cpp/ttnn/operations/data_movement/untilize_with_halo_v2/device/untilize_with_halo_v2_op.cpp b/ttnn/cpp/ttnn/operations/data_movement/untilize_with_halo_v2/device/untilize_with_halo_v2_op.cpp index 569f9a63fb6e..098956c91d88 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/untilize_with_halo_v2/device/untilize_with_halo_v2_op.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/untilize_with_halo_v2/device/untilize_with_halo_v2_op.cpp @@ -93,7 +93,6 @@ operation::ProgramWithCallbacks UntilizeWithHaloV2::create_program( local_config, remote_config, std::nullopt, - std::nullopt, remote_read_, transpose_mcast_, output_tensor, diff --git a/ttnn/cpp/ttnn/operations/data_movement/untilize_with_halo_v2/device/untilize_with_halo_v2_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/untilize_with_halo_v2/device/untilize_with_halo_v2_program_factory.cpp index a9980567e9e9..6a535353876e 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/untilize_with_halo_v2/device/untilize_with_halo_v2_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/untilize_with_halo_v2/device/untilize_with_halo_v2_program_factory.cpp @@ -31,7 +31,6 @@ operation::ProgramWithCallbacks untilize_with_halo_multi_core_v2( const Tensor& padding_config, const Tensor& local_config, const Tensor& remote_config, - std::optional> remote_ref_counts, std::optional> remote_temp, const bool remote_read, const bool transpose_mcast, @@ -122,7 +121,6 @@ operation::ProgramWithCallbacks untilize_with_halo_multi_core_v2( uint32_t padding_config_cb_id = tt::CBIndex::c_2; uint32_t local_config_cb_id = tt::CBIndex::c_3; uint32_t remote_config_cb_id = tt::CBIndex::c_4; - uint32_t remote_ref_counts_cb_id = remote_ref_counts.has_value() ? tt::CBIndex::c_5 : 0; uint32_t remote_temp_cb_id = remote_temp.has_value() ? tt::CBIndex::c_6 : 0; tt::DataFormat kernel_config_df = tt::DataFormat::RawUInt16; // NOTE: UInt16 is not supported for CB types @@ -183,20 +181,8 @@ operation::ProgramWithCallbacks untilize_with_halo_multi_core_v2( uint32_t semaphore_id = 0; if (in_place) { TT_ASSERT(!remote_read, "remote_read is not supported for in place operation"); - TT_ASSERT( - (remote_ref_counts.has_value() && remote_temp.has_value()) || - (!remote_ref_counts.has_value() && !remote_temp.has_value()), - "remote_ref_counts and remote_temp should be both present or absent"); - if (remote_ref_counts.has_value()) { - auto remote_ref_counts_buffer = remote_ref_counts.value().get().device_buffer(); - auto remote_ref_counts_cb_config = - CircularBufferConfig( - remote_ref_counts_buffer->size() / num_cores, {{remote_ref_counts_cb_id, kernel_config_df}}) - .set_page_size(remote_ref_counts_cb_id, remote_ref_counts_buffer->page_size()) - .set_globally_allocated_address(*remote_ref_counts_buffer); - CBHandle remote_ref_counts_cb = CreateCircularBuffer(program, all_cores, remote_ref_counts_cb_config); - } + // create the remote temp CB if (remote_temp.has_value()) { auto remote_temp_buffer = remote_temp.value().get().device_buffer(); auto remote_temp_cb_config = @@ -206,6 +192,7 @@ operation::ProgramWithCallbacks untilize_with_halo_multi_core_v2( CBHandle remote_temp_cb = CreateCircularBuffer(program, all_cores, remote_temp_cb_config); } + // compute core data and create semaphore auto core_id_to_noc_coords = [is_block_sharded, transpose_mcast, device](uint32_t core_id) -> CoreCoord { auto num_cores_x = device->compute_with_storage_grid_size().x; auto core_coord = is_block_sharded ? (transpose_mcast ? CoreCoord(core_id, 0) : CoreCoord(0, core_id)) @@ -215,7 +202,6 @@ operation::ProgramWithCallbacks untilize_with_halo_multi_core_v2( noc_00 = core_id_to_noc_coords(0); num_cores_x = device->compute_with_storage_grid_size().x; num_cores_y = device->compute_with_storage_grid_size().y; - semaphore_id = tt::tt_metal::CreateSemaphore(program, all_cores, 0); } @@ -231,7 +217,6 @@ operation::ProgramWithCallbacks untilize_with_halo_multi_core_v2( 0, // padding_config_cb_id 0, // local_config_cb_id 0, // remote_config_cb_id - 0, // remote_ref_counts_cb_id 0, // remote_temp_cb_id src_cb_id, input_to_writer_cb_id, @@ -256,8 +241,7 @@ operation::ProgramWithCallbacks untilize_with_halo_multi_core_v2( reader_ct_args[0] = 0; reader_ct_args[1] = local_config_cb_id; reader_ct_args[2] = remote_config_cb_id; - reader_ct_args[3] = 0; - reader_ct_args[4] = remote_temp_cb_id; + reader_ct_args[3] = remote_temp_cb_id; KernelHandle reader_kernel_id0 = CreateKernel( program, @@ -273,7 +257,6 @@ operation::ProgramWithCallbacks untilize_with_halo_multi_core_v2( reader_ct_args[1] = 0; reader_ct_args[2] = 0; reader_ct_args[3] = 0; - reader_ct_args[4] = 0; KernelHandle reader_kernel_id1 = CreateKernel( program, diff --git a/ttnn/cpp/ttnn/operations/data_movement/untilize_with_halo_v2/device/untilize_with_halo_v2_program_factory.hpp b/ttnn/cpp/ttnn/operations/data_movement/untilize_with_halo_v2/device/untilize_with_halo_v2_program_factory.hpp index 99b517c56fac..76864b769354 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/untilize_with_halo_v2/device/untilize_with_halo_v2_program_factory.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/untilize_with_halo_v2/device/untilize_with_halo_v2_program_factory.hpp @@ -18,7 +18,6 @@ tt::tt_metal::operation::ProgramWithCallbacks untilize_with_halo_multi_core_v2( const Tensor& padding_config, const Tensor& local_config, const Tensor& remote_config, - std::optional> remote_ref_counts, std::optional> remote_temp, const bool remote_read, const bool transpose_mcast, diff --git a/ttnn/cpp/ttnn/operations/sliding_window/halo/device/halo_device_operation.cpp b/ttnn/cpp/ttnn/operations/sliding_window/halo/device/halo_device_operation.cpp index 29c982ebe097..3ae0061733cf 100644 --- a/ttnn/cpp/ttnn/operations/sliding_window/halo/device/halo_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/sliding_window/halo/device/halo_device_operation.cpp @@ -119,8 +119,7 @@ operation::ProgramWithCallbacks HaloDeviceOperation::create_program( const auto& pad_config = std::get<0>(kernel_config); const auto& local_config = std::get<1>(kernel_config); const auto& remote_config = std::get<2>(kernel_config); - const auto& remote_ref_counts = std::get<3>(kernel_config); - const auto& max_ref_size = std::get<4>(kernel_config); + const auto& max_ref_size = std::get<3>(kernel_config); auto pad_config_tensor = sliding_window::construct_on_host_config_tensor(pad_config, this->config_, this->parallel_config_); @@ -142,29 +141,14 @@ operation::ProgramWithCallbacks HaloDeviceOperation::create_program( sliding_window::move_config_tensor_to_device(remote_config_tensor, parallel_config_, is_block_sharded, device); std::optional remote_temp_device_tensor; - std::optional remote_ref_counts_device_tensor; if (max_ref_size > 0 && this->in_place_) { + // create the remote temp tensor, TODO do we need to vary this type for bfloat8? int remote_temp_size = max_ref_size * stick_size * num_cores; - // TODO do we need to vary this type for bfloat8? auto remote_temp_buffer = owned_buffer::create(std::vector(remote_temp_size)); ttnn::Shape remote_temp_shape = ttnn::Shape({num_cores, max_ref_size, stick_size}); Tensor remote_temp_tensor(OwnedStorage{remote_temp_buffer}, remote_temp_shape, type, Layout::ROW_MAJOR); - uint32_t repeat_factor = num_cores; - std::vector> remote_ref_counts_repeated; - for (uint32_t i = 0; i < repeat_factor; ++i) { - remote_ref_counts_repeated.push_back(remote_ref_counts); - } - ttnn::Shape config_shape( - {(uint32_t)remote_ref_counts_repeated.size(), (uint32_t)remote_ref_counts_repeated[0].size()}); - std::vector config_vector = sliding_window::flatten(remote_ref_counts_repeated); - auto config_buffer = owned_buffer::create(std::move(config_vector)); - auto remote_ref_counts_tensor = - Tensor(OwnedStorage{config_buffer}, config_shape, DataType::UINT16, Layout::ROW_MAJOR); - - remote_ref_counts_device_tensor = sliding_window::move_config_tensor_to_device( - remote_ref_counts_tensor, parallel_config_, is_block_sharded, device); - + // move tensors to device auto shard_shape = std::array({max_ref_size, stick_size}); ShardSpec shard_spec(parallel_config_.grid, shard_shape, ShardOrientation::ROW_MAJOR); MemoryConfig memory_config{TensorMemoryLayout::HEIGHT_SHARDED, BufferType::L1, shard_spec}; @@ -183,9 +167,6 @@ operation::ProgramWithCallbacks HaloDeviceOperation::create_program( pad_config_device_tensor, local_config_device_tensor, remote_config_device_tensor, - remote_ref_counts_device_tensor - ? std::optional>(*remote_ref_counts_device_tensor) - : std::nullopt, remote_temp_device_tensor ? std::optional>(*remote_temp_device_tensor) : std::nullopt, remote_read_, diff --git a/ttnn/cpp/ttnn/operations/sliding_window/sliding_window.cpp b/ttnn/cpp/ttnn/operations/sliding_window/sliding_window.cpp index 44bfaabeb1d2..98d889ec48de 100644 --- a/ttnn/cpp/ttnn/operations/sliding_window/sliding_window.cpp +++ b/ttnn/cpp/ttnn/operations/sliding_window/sliding_window.cpp @@ -351,7 +351,6 @@ std::tuple< std::vector>, std::vector>, std::vector>, - std::vector, int> generate_halo_kernel_config_tensors( const std::vector& tensor_metadata, @@ -513,9 +512,8 @@ generate_halo_kernel_config_tensors( return flattened_config; }; - auto flatten_remote_config = - [core_id_to_noc_coords, - &device](auto& config) -> std::tuple>, std::vector, int> { + auto flatten_remote_config = [core_id_to_noc_coords, + &device](auto& config) -> std::tuple>, int> { // find max length size_t max_len = 0; for (auto& core_config : config) { @@ -529,9 +527,8 @@ generate_halo_kernel_config_tensors( int num_cores_x = device->compute_with_storage_grid_size().x; int num_cores_y = device->compute_with_storage_grid_size().y; int num_cores = num_cores_x * num_cores_y; - std::vector remote_ref_counts(num_cores, 0); CoreCoord noc_00 = core_id_to_noc_coords(0); - int max_ref_size = 0; + int max_ref_size = 0; // track the max remote ref size for sizing the remote temp tensor int core = 0; for (auto& core_config : config) { std::vector flat_data(max_len, 0); @@ -543,7 +540,6 @@ generate_halo_kernel_config_tensors( flat_data[idx++] = nocy; flat_data[idx++] = len; int ref_ind = nocx - noc_00.x + (nocy - noc_00.y) * num_cores_x; - remote_ref_counts[ref_ind]++; for (size_t i = 0; i < key_data.second.size(); ++i) { auto [src_start, dst_start, length] = key_data.second[i]; flat_data[idx++] = src_start; @@ -561,12 +557,12 @@ generate_halo_kernel_config_tensors( flattened_config.emplace_back(flat_data); } - return std::make_tuple(flattened_config, remote_ref_counts, max_ref_size); + return std::make_tuple(flattened_config, max_ref_size); }; auto flattened_pad_config = flatten_pad_config(pad_config); auto flattened_local_config = flatten_local_config(local_config); - auto [flattened_remote_config, remote_ref_counts, max_ref_size] = flatten_remote_config(remote_config); + auto [flattened_remote_config, max_ref_size] = flatten_remote_config(remote_config); auto align_config = [](auto& config, size_t align_granularity = 1, uint16_t align_value = 0) { size_t max_len = 0; @@ -591,8 +587,7 @@ generate_halo_kernel_config_tensors( align_config(flattened_local_config, 2); align_config(flattened_remote_config, 2); - return std::make_tuple( - flattened_pad_config, flattened_local_config, flattened_remote_config, remote_ref_counts, max_ref_size); + return std::make_tuple(flattened_pad_config, flattened_local_config, flattened_remote_config, max_ref_size); } std::vector> generate_sliding_window_op_config( diff --git a/ttnn/cpp/ttnn/operations/sliding_window/sliding_window.hpp b/ttnn/cpp/ttnn/operations/sliding_window/sliding_window.hpp index 009979586362..3e9e3772d68f 100644 --- a/ttnn/cpp/ttnn/operations/sliding_window/sliding_window.hpp +++ b/ttnn/cpp/ttnn/operations/sliding_window/sliding_window.hpp @@ -114,7 +114,6 @@ std::tuple< std::vector>, std::vector>, std::vector>, - std::vector, int> generate_halo_kernel_config_tensors( const std::vector& tensor_metadata,