#0: removed remote refs tensor
wransom-TT committed Mar 3, 2025
1 parent 6e5fbcb commit a98659d
Showing 9 changed files with 54 additions and 100 deletions.
2 changes: 1 addition & 1 deletion tests/ttnn/unit_tests/operations/test_maxpool2d.py
@@ -406,7 +406,7 @@ def run_max_pool(
@pytest.mark.parametrize(
"in_place_halo",
[
# False,
False,
True,
],
)
@@ -89,18 +89,18 @@ void kernel_main() {
constexpr uint32_t padding_config_cb_id = get_compile_time_arg_val(0); // has untilized input shard
constexpr uint32_t local_config_cb_id = get_compile_time_arg_val(1); // has untilized input shard
constexpr uint32_t remote_config_cb_id = get_compile_time_arg_val(2); // has untilized input shard
constexpr uint32_t src_cb_id = get_compile_time_arg_val(5); // has untilized input shard
constexpr uint32_t in_cb_id = get_compile_time_arg_val(6); // has untilized input shard
constexpr uint32_t out_cb_id = get_compile_time_arg_val(7); // output shard with padding and halo goes here
constexpr uint32_t pad_cb_id = get_compile_time_arg_val(8); // cb for const pad val buffer
constexpr uint32_t pad_val_u32 = get_compile_time_arg_val(9); // pad value to fill pad buffer with
constexpr uint32_t in_nsticks = get_compile_time_arg_val(10); // number of sticks
constexpr uint32_t stick_nbytes = get_compile_time_arg_val(11); // stick size in bytes (post untilize)
constexpr uint32_t is_block_sharded = get_compile_time_arg_val(12);
constexpr uint32_t remote_read = get_compile_time_arg_val(13);
constexpr bool is_col_major = get_compile_time_arg_val(14) == 1;
constexpr uint32_t is_width_sharded = get_compile_time_arg_val(15);
constexpr uint32_t input_aligned_page_size = get_compile_time_arg_val(16);
constexpr uint32_t src_cb_id = get_compile_time_arg_val(4); // has untilized input shard
constexpr uint32_t in_cb_id = get_compile_time_arg_val(5); // has untilized input shard
constexpr uint32_t out_cb_id = get_compile_time_arg_val(6); // output shard with padding and halo goes here
constexpr uint32_t pad_cb_id = get_compile_time_arg_val(7); // cb for const pad val buffer
constexpr uint32_t pad_val_u32 = get_compile_time_arg_val(8); // pad value to fill pad buffer with
constexpr uint32_t in_nsticks = get_compile_time_arg_val(9); // number of sticks
constexpr uint32_t stick_nbytes = get_compile_time_arg_val(10); // stick size in bytes (post untilize)
constexpr uint32_t is_block_sharded = get_compile_time_arg_val(11);
constexpr uint32_t remote_read = get_compile_time_arg_val(12);
constexpr bool is_col_major = get_compile_time_arg_val(13) == 1;
constexpr uint32_t is_width_sharded = get_compile_time_arg_val(14);
constexpr uint32_t input_aligned_page_size = get_compile_time_arg_val(15);

constexpr uint32_t elem_nbytes = sizeof(uint16_t);
constexpr uint16_t pad_core_id = 0xFFFF;
@@ -155,16 +155,15 @@ void copy_sticks_async(
uint64_t dst_addr = base_addr + dst_offset;
uint32_t src_addr = in_base_l1_addr + src_offset;

if ((dst_local_idx > src_local_idx + in_out_buffer_start_delta &&
dst_local_idx <= src_local_idx + in_out_buffer_start_delta + nsticks) ||
(dst_local_idx + nsticks >= src_local_idx + in_out_buffer_start_delta &&
dst_local_idx + nsticks <
src_local_idx + in_out_buffer_start_delta +
nsticks)) { // dst and src data overlaps, stick by stick copy is necessary
if (dst_local_idx > src_local_idx + in_out_buffer_start_delta &&
dst_local_idx <= src_local_idx + in_out_buffer_start_delta +
nsticks) { // dst data is being moved "in front" of the source data, reverse
// ordering of stick by stick copy is necessary
bool is_forward_copy = dst_local_idx > src_local_idx + in_out_buffer_start_delta &&
dst_local_idx <= src_local_idx + in_out_buffer_start_delta + nsticks;
bool is_overlap_copy = (dst_local_idx > src_local_idx + in_out_buffer_start_delta &&
dst_local_idx <= src_local_idx + in_out_buffer_start_delta + nsticks) ||
(dst_local_idx + nsticks >= src_local_idx + in_out_buffer_start_delta &&
dst_local_idx + nsticks < src_local_idx + in_out_buffer_start_delta + nsticks);
if (is_overlap_copy) { // dst and src data overlaps, stick by stick copy is necessary
if (is_forward_copy) { // dst data is being moved "in front" of the source data, reverse
// ordering of stick by stick copy is necessary
for (int16_t k = nsticks - 1; k >= 0; k--) {
noc_async_write(src_addr + k * stick_nbytes, dst_addr + k * stick_nbytes, stick_nbytes);
}
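The rewritten condition above splits the check into two named predicates: is_overlap_copy (the source and destination stick ranges intersect) and is_forward_copy (the destination starts ahead of the source). The standalone sketch below is not part of this diff; it uses a made-up helper with plain memcpy in place of noc_async_write, and only illustrates why an overlapping forward copy must run from the last stick to the first, the same reasoning memmove relies on.

#include <cstddef>
#include <cstring>

// Illustrative only: copy nsticks fixed-size sticks within one buffer,
// choosing the copy order so overlapping moves never clobber source
// sticks before they have been read.
void copy_sticks_overlapping(char* base, size_t src_idx, size_t dst_idx, size_t nsticks, size_t stick_nbytes) {
    char* src = base + src_idx * stick_nbytes;
    char* dst = base + dst_idx * stick_nbytes;
    bool overlap = (dst_idx < src_idx + nsticks) && (src_idx < dst_idx + nsticks);
    if (!overlap) {
        std::memcpy(dst, src, nsticks * stick_nbytes);  // disjoint ranges: one bulk copy is safe
    } else if (dst_idx > src_idx) {
        // Destination sits "in front" of the source: copy the last stick first,
        // otherwise early writes would overwrite sticks that are still unread.
        for (size_t k = nsticks; k-- > 0;) {
            std::memcpy(dst + k * stick_nbytes, src + k * stick_nbytes, stick_nbytes);
        }
    } else {
        // Destination is behind the source: forward order is safe.
        for (size_t k = 0; k < nsticks; ++k) {
            std::memcpy(dst + k * stick_nbytes, src + k * stick_nbytes, stick_nbytes);
        }
    }
}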
@@ -186,27 +185,26 @@ void kernel_main() {
constexpr uint32_t padding_config_cb_id = get_compile_time_arg_val(0); // has untilized input shard
constexpr uint32_t local_config_cb_id = get_compile_time_arg_val(1); // has untilized input shard
constexpr uint32_t remote_config_cb_id = get_compile_time_arg_val(2); // has untilized input shard
constexpr uint32_t remote_ref_counts_cb_id = get_compile_time_arg_val(3); // has untilized input shard
constexpr uint32_t remote_temp_cb_id = get_compile_time_arg_val(4); // has untilized input shard
constexpr uint32_t src_cb_id = get_compile_time_arg_val(5); // has untilized input shard
constexpr uint32_t in_cb_id = get_compile_time_arg_val(6); // has untilized input shard
constexpr uint32_t out_cb_id = get_compile_time_arg_val(7); // output shard with padding and halo goes here
constexpr uint32_t pad_cb_id = get_compile_time_arg_val(8); // cb for const pad val buffer
constexpr uint32_t pad_val_u32 = get_compile_time_arg_val(9); // pad value to fill pad buffer with
constexpr uint32_t in_nsticks = get_compile_time_arg_val(10); // number of sticks
constexpr uint32_t stick_nbytes = get_compile_time_arg_val(11); // stick size in bytes (post untilize)
constexpr uint32_t is_block_sharded = get_compile_time_arg_val(12);
constexpr uint32_t remote_read = get_compile_time_arg_val(13);
constexpr bool is_col_major = get_compile_time_arg_val(14) == 1;
constexpr uint32_t is_width_sharded = get_compile_time_arg_val(15);
constexpr uint32_t input_aligned_page_size = get_compile_time_arg_val(16);
constexpr uint32_t noc_00_x = get_compile_time_arg_val(17);
constexpr uint32_t noc_00_y = get_compile_time_arg_val(18);
constexpr uint32_t num_cores_nhw = get_compile_time_arg_val(19);
constexpr uint32_t num_cores_c = get_compile_time_arg_val(20);
constexpr uint32_t num_cores_x = get_compile_time_arg_val(21);
constexpr uint32_t semaphore_id = get_compile_time_arg_val(22);
constexpr uint32_t max_out_nsticks_per_core = get_compile_time_arg_val(23);
constexpr uint32_t remote_temp_cb_id = get_compile_time_arg_val(3); // has untilized input shard
constexpr uint32_t src_cb_id = get_compile_time_arg_val(4); // has untilized input shard
constexpr uint32_t in_cb_id = get_compile_time_arg_val(5); // has untilized input shard
constexpr uint32_t out_cb_id = get_compile_time_arg_val(6); // output shard with padding and halo goes here
constexpr uint32_t pad_cb_id = get_compile_time_arg_val(7); // cb for const pad val buffer
constexpr uint32_t pad_val_u32 = get_compile_time_arg_val(8); // pad value to fill pad buffer with
constexpr uint32_t in_nsticks = get_compile_time_arg_val(9); // number of sticks
constexpr uint32_t stick_nbytes = get_compile_time_arg_val(10); // stick size in bytes (post untilize)
constexpr uint32_t is_block_sharded = get_compile_time_arg_val(11);
constexpr uint32_t remote_read = get_compile_time_arg_val(12);
constexpr bool is_col_major = get_compile_time_arg_val(13) == 1;
constexpr uint32_t is_width_sharded = get_compile_time_arg_val(14);
constexpr uint32_t input_aligned_page_size = get_compile_time_arg_val(15);
constexpr uint32_t noc_00_x = get_compile_time_arg_val(16);
constexpr uint32_t noc_00_y = get_compile_time_arg_val(17);
constexpr uint32_t num_cores_nhw = get_compile_time_arg_val(18);
constexpr uint32_t num_cores_c = get_compile_time_arg_val(19);
constexpr uint32_t num_cores_x = get_compile_time_arg_val(20);
constexpr uint32_t semaphore_id = get_compile_time_arg_val(21);
constexpr uint32_t max_out_nsticks_per_core = get_compile_time_arg_val(22);

constexpr uint32_t num_cores = num_cores_nhw * num_cores_c;

@@ -93,7 +93,6 @@ operation::ProgramWithCallbacks UntilizeWithHaloV2::create_program(
local_config,
remote_config,
std::nullopt,
std::nullopt,
remote_read_,
transpose_mcast_,
output_tensor,
@@ -31,7 +31,6 @@ operation::ProgramWithCallbacks untilize_with_halo_multi_core_v2(
const Tensor& padding_config,
const Tensor& local_config,
const Tensor& remote_config,
std::optional<std::reference_wrapper<const Tensor>> remote_ref_counts,
std::optional<std::reference_wrapper<const Tensor>> remote_temp,
const bool remote_read,
const bool transpose_mcast,
@@ -122,7 +121,6 @@ operation::ProgramWithCallbacks untilize_with_halo_multi_core_v2(
uint32_t padding_config_cb_id = tt::CBIndex::c_2;
uint32_t local_config_cb_id = tt::CBIndex::c_3;
uint32_t remote_config_cb_id = tt::CBIndex::c_4;
uint32_t remote_ref_counts_cb_id = remote_ref_counts.has_value() ? tt::CBIndex::c_5 : 0;
uint32_t remote_temp_cb_id = remote_temp.has_value() ? tt::CBIndex::c_6 : 0;

tt::DataFormat kernel_config_df = tt::DataFormat::RawUInt16; // NOTE: UInt16 is not supported for CB types
@@ -183,20 +181,8 @@ operation::ProgramWithCallbacks untilize_with_halo_multi_core_v2(
uint32_t semaphore_id = 0;
if (in_place) {
TT_ASSERT(!remote_read, "remote_read is not supported for in place operation");
TT_ASSERT(
(remote_ref_counts.has_value() && remote_temp.has_value()) ||
(!remote_ref_counts.has_value() && !remote_temp.has_value()),
"remote_ref_counts and remote_temp should be both present or absent");
if (remote_ref_counts.has_value()) {
auto remote_ref_counts_buffer = remote_ref_counts.value().get().device_buffer();
auto remote_ref_counts_cb_config =
CircularBufferConfig(
remote_ref_counts_buffer->size() / num_cores, {{remote_ref_counts_cb_id, kernel_config_df}})
.set_page_size(remote_ref_counts_cb_id, remote_ref_counts_buffer->page_size())
.set_globally_allocated_address(*remote_ref_counts_buffer);
CBHandle remote_ref_counts_cb = CreateCircularBuffer(program, all_cores, remote_ref_counts_cb_config);
}

// create the remote temp CB
if (remote_temp.has_value()) {
auto remote_temp_buffer = remote_temp.value().get().device_buffer();
auto remote_temp_cb_config =
@@ -206,6 +192,7 @@ operation::ProgramWithCallbacks untilize_with_halo_multi_core_v2(
CBHandle remote_temp_cb = CreateCircularBuffer(program, all_cores, remote_temp_cb_config);
}

// compute core data and create semaphore
auto core_id_to_noc_coords = [is_block_sharded, transpose_mcast, device](uint32_t core_id) -> CoreCoord {
auto num_cores_x = device->compute_with_storage_grid_size().x;
auto core_coord = is_block_sharded ? (transpose_mcast ? CoreCoord(core_id, 0) : CoreCoord(0, core_id))
@@ -215,7 +202,6 @@ operation::ProgramWithCallbacks untilize_with_halo_multi_core_v2(
noc_00 = core_id_to_noc_coords(0);
num_cores_x = device->compute_with_storage_grid_size().x;
num_cores_y = device->compute_with_storage_grid_size().y;

semaphore_id = tt::tt_metal::CreateSemaphore(program, all_cores, 0);
}

@@ -231,7 +217,6 @@ operation::ProgramWithCallbacks untilize_with_halo_multi_core_v2(
0, // padding_config_cb_id
0, // local_config_cb_id
0, // remote_config_cb_id
0, // remote_ref_counts_cb_id
0, // remote_temp_cb_id
src_cb_id,
input_to_writer_cb_id,
@@ -256,8 +241,7 @@ operation::ProgramWithCallbacks untilize_with_halo_multi_core_v2(
reader_ct_args[0] = 0;
reader_ct_args[1] = local_config_cb_id;
reader_ct_args[2] = remote_config_cb_id;
reader_ct_args[3] = 0;
reader_ct_args[4] = remote_temp_cb_id;
reader_ct_args[3] = remote_temp_cb_id;

KernelHandle reader_kernel_id0 = CreateKernel(
program,
@@ -273,7 +257,6 @@ operation::ProgramWithCallbacks untilize_with_halo_multi_core_v2(
reader_ct_args[1] = 0;
reader_ct_args[2] = 0;
reader_ct_args[3] = 0;
reader_ct_args[4] = 0;

KernelHandle reader_kernel_id1 = CreateKernel(
program,
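Dropping the remote_ref_counts circular buffer shifts every later compile-time argument down by one index, so the host-side reader_ct_args vector and the kernel's get_compile_time_arg_val(...) calls have to be renumbered together, as the hunks above and the kernel changes earlier in this commit do. A minimal sketch of that host/kernel correspondence, using placeholder CB ids rather than the op's real values:

#include <cstdint>
#include <vector>

// Host side: the position of each value is the index the kernel reads it from.
std::vector<uint32_t> reader_ct_args = {
    2,  // 0: padding_config_cb_id
    3,  // 1: local_config_cb_id
    4,  // 2: remote_config_cb_id
    6,  // 3: remote_temp_cb_id (was index 4 before remote_ref_counts_cb_id was removed)
    0,  // 4: src_cb_id
    1,  // 5: in_cb_id
    // ... remaining arguments continue in the order the kernel expects
};

// Device side, for comparison: the kernel reads the same positions, e.g.
//   constexpr uint32_t remote_temp_cb_id = get_compile_time_arg_val(3);
//   constexpr uint32_t src_cb_id         = get_compile_time_arg_val(4);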
@@ -18,7 +18,6 @@ tt::tt_metal::operation::ProgramWithCallbacks untilize_with_halo_multi_core_v2(
const Tensor& padding_config,
const Tensor& local_config,
const Tensor& remote_config,
std::optional<std::reference_wrapper<const Tensor>> remote_ref_counts,
std::optional<std::reference_wrapper<const Tensor>> remote_temp,
const bool remote_read,
const bool transpose_mcast,
@@ -119,8 +119,7 @@ operation::ProgramWithCallbacks HaloDeviceOperation::create_program(
const auto& pad_config = std::get<0>(kernel_config);
const auto& local_config = std::get<1>(kernel_config);
const auto& remote_config = std::get<2>(kernel_config);
const auto& remote_ref_counts = std::get<3>(kernel_config);
const auto& max_ref_size = std::get<4>(kernel_config);
const auto& max_ref_size = std::get<3>(kernel_config);

auto pad_config_tensor =
sliding_window::construct_on_host_config_tensor(pad_config, this->config_, this->parallel_config_);
@@ -142,29 +141,14 @@ operation::ProgramWithCallbacks HaloDeviceOperation::create_program(
sliding_window::move_config_tensor_to_device(remote_config_tensor, parallel_config_, is_block_sharded, device);

std::optional<Tensor> remote_temp_device_tensor;
std::optional<Tensor> remote_ref_counts_device_tensor;
if (max_ref_size > 0 && this->in_place_) {
// create the remote temp tensor, TODO do we need to vary this type for bfloat8?
int remote_temp_size = max_ref_size * stick_size * num_cores;
// TODO do we need to vary this type for bfloat8?
auto remote_temp_buffer = owned_buffer::create<bfloat16>(std::vector<bfloat16>(remote_temp_size));
ttnn::Shape remote_temp_shape = ttnn::Shape({num_cores, max_ref_size, stick_size});
Tensor remote_temp_tensor(OwnedStorage{remote_temp_buffer}, remote_temp_shape, type, Layout::ROW_MAJOR);

uint32_t repeat_factor = num_cores;
std::vector<std::vector<uint16_t>> remote_ref_counts_repeated;
for (uint32_t i = 0; i < repeat_factor; ++i) {
remote_ref_counts_repeated.push_back(remote_ref_counts);
}
ttnn::Shape config_shape(
{(uint32_t)remote_ref_counts_repeated.size(), (uint32_t)remote_ref_counts_repeated[0].size()});
std::vector<uint16_t> config_vector = sliding_window::flatten(remote_ref_counts_repeated);
auto config_buffer = owned_buffer::create<uint16_t>(std::move(config_vector));
auto remote_ref_counts_tensor =
Tensor(OwnedStorage{config_buffer}, config_shape, DataType::UINT16, Layout::ROW_MAJOR);

remote_ref_counts_device_tensor = sliding_window::move_config_tensor_to_device(
remote_ref_counts_tensor, parallel_config_, is_block_sharded, device);

// move tensors to device
auto shard_shape = std::array<uint32_t, 2>({max_ref_size, stick_size});
ShardSpec shard_spec(parallel_config_.grid, shard_shape, ShardOrientation::ROW_MAJOR);
MemoryConfig memory_config{TensorMemoryLayout::HEIGHT_SHARDED, BufferType::L1, shard_spec};
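With remote_ref_counts gone, the in-place path only stages the remote_temp tensor, sized from max_ref_size (the largest per-core remote transfer count reported by the config generator), the stick size, and the core count, and height-sharded so each core holds a {max_ref_size, stick_size} shard. A small arithmetic sketch with made-up example numbers, not taken from any real run:

#include <array>
#include <cstdint>
#include <iostream>

int main() {
    uint32_t max_ref_size = 12;  // largest remote-ref entry count on any core (example value)
    uint32_t stick_size = 64;    // elements per stick (example value)
    uint32_t num_cores = 8;      // example value

    uint32_t remote_temp_elems = max_ref_size * stick_size * num_cores;     // 12 * 64 * 8 = 6144
    std::array<uint32_t, 3> shape = {num_cores, max_ref_size, stick_size};  // logical tensor shape
    std::array<uint32_t, 2> shard_shape = {max_ref_size, stick_size};       // per-core HEIGHT_SHARDED shard

    std::cout << "remote_temp elements: " << remote_temp_elems << "\n";
    std::cout << "shard rows x cols: " << shard_shape[0] << " x " << shard_shape[1] << "\n";
    return 0;
}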
@@ -183,9 +167,6 @@ operation::ProgramWithCallbacks HaloDeviceOperation::create_program(
pad_config_device_tensor,
local_config_device_tensor,
remote_config_device_tensor,
remote_ref_counts_device_tensor
? std::optional<std::reference_wrapper<const Tensor>>(*remote_ref_counts_device_tensor)
: std::nullopt,
remote_temp_device_tensor ? std::optional<std::reference_wrapper<const Tensor>>(*remote_temp_device_tensor)
: std::nullopt,
remote_read_,
17 changes: 6 additions & 11 deletions ttnn/cpp/ttnn/operations/sliding_window/sliding_window.cpp
@@ -351,7 +351,6 @@ std::tuple<
std::vector<std::vector<uint16_t>>,
std::vector<std::vector<uint16_t>>,
std::vector<std::vector<uint16_t>>,
std::vector<uint16_t>,
int>
generate_halo_kernel_config_tensors(
const std::vector<PixelMetadata>& tensor_metadata,
@@ -513,9 +512,8 @@ generate_halo_kernel_config_tensors(
return flattened_config;
};

auto flatten_remote_config =
[core_id_to_noc_coords,
&device](auto& config) -> std::tuple<std::vector<std::vector<uint16_t>>, std::vector<uint16_t>, int> {
auto flatten_remote_config = [core_id_to_noc_coords,
&device](auto& config) -> std::tuple<std::vector<std::vector<uint16_t>>, int> {
// find max length
size_t max_len = 0;
for (auto& core_config : config) {
@@ -529,9 +527,8 @@ generate_halo_kernel_config_tensors(
int num_cores_x = device->compute_with_storage_grid_size().x;
int num_cores_y = device->compute_with_storage_grid_size().y;
int num_cores = num_cores_x * num_cores_y;
std::vector<uint16_t> remote_ref_counts(num_cores, 0);
CoreCoord noc_00 = core_id_to_noc_coords(0);
int max_ref_size = 0;
int max_ref_size = 0; // track the max remote ref size for sizing the remote temp tensor
int core = 0;
for (auto& core_config : config) {
std::vector<uint16_t> flat_data(max_len, 0);
@@ -543,7 +540,6 @@ generate_halo_kernel_config_tensors(
flat_data[idx++] = nocy;
flat_data[idx++] = len;
int ref_ind = nocx - noc_00.x + (nocy - noc_00.y) * num_cores_x;
remote_ref_counts[ref_ind]++;
for (size_t i = 0; i < key_data.second.size(); ++i) {
auto [src_start, dst_start, length] = key_data.second[i];
flat_data[idx++] = src_start;
@@ -561,12 +557,12 @@ generate_halo_kernel_config_tensors(
flattened_config.emplace_back(flat_data);
}

return std::make_tuple(flattened_config, remote_ref_counts, max_ref_size);
return std::make_tuple(flattened_config, max_ref_size);
};

auto flattened_pad_config = flatten_pad_config(pad_config);
auto flattened_local_config = flatten_local_config(local_config);
auto [flattened_remote_config, remote_ref_counts, max_ref_size] = flatten_remote_config(remote_config);
auto [flattened_remote_config, max_ref_size] = flatten_remote_config(remote_config);

auto align_config = [](auto& config, size_t align_granularity = 1, uint16_t align_value = 0) {
size_t max_len = 0;
@@ -591,8 +587,7 @@ generate_halo_kernel_config_tensors(
align_config(flattened_local_config, 2);
align_config(flattened_remote_config, 2);

return std::make_tuple(
flattened_pad_config, flattened_local_config, flattened_remote_config, remote_ref_counts, max_ref_size);
return std::make_tuple(flattened_pad_config, flattened_local_config, flattened_remote_config, max_ref_size);
}

std::vector<std::vector<uint16_t>> generate_sliding_window_op_config(
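For reference, the flattening loop above implies a per-core layout for the remote config of [noc_x, noc_y, len] followed by (src_start, dst_start, length) triples. The sketch below reconstructs that layout under the assumption that len is the number of triples for the destination core; the destructuring of key_data sits outside the lines shown here, so the field meanings are inferred rather than taken from the full source.

#include <cstdint>
#include <tuple>
#include <vector>

struct RemoteTransfer {
    uint16_t src_start;
    uint16_t dst_start;
    uint16_t length;
};

// Flatten one core's remote config into the [noc_x, noc_y, len, (src, dst, length)...] layout.
std::vector<uint16_t> flatten_one_core(
    const std::vector<std::tuple<uint16_t, uint16_t, std::vector<RemoteTransfer>>>& per_dst_core) {
    std::vector<uint16_t> flat;
    for (const auto& [noc_x, noc_y, transfers] : per_dst_core) {
        flat.push_back(noc_x);
        flat.push_back(noc_y);
        flat.push_back(static_cast<uint16_t>(transfers.size()));  // assumed meaning of "len"
        for (const auto& t : transfers) {
            flat.push_back(t.src_start);
            flat.push_back(t.dst_start);
            flat.push_back(t.length);
        }
    }
    return flat;
}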