Skip to content

Commit

Permalink
#13931: Move index_size to compile time
Browse files Browse the repository at this point in the history
  • Loading branch information
thanhnguyen-moreh committed Oct 26, 2024
1 parent 95b01ba commit 127be69
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ IndexFillOperation::MultiCore::cached_program_t IndexFillOperation::MultiCore::c
Device* device = input.device();

auto compute_with_storage_grid_size = device->compute_with_storage_grid_size();
uint32_t num_cores_x = compute_with_storage_grid_size.x;
uint32_t num_cores_y = compute_with_storage_grid_size.y;

auto [num_cores, all_cores, core_group_1, core_group_2, num_rows_per_core_group_1, num_rows_per_core_group_2] =
Expand Down Expand Up @@ -107,7 +108,8 @@ IndexFillOperation::MultiCore::cached_program_t IndexFillOperation::MultiCore::c
(std::uint32_t)index_is_dram,
(std::uint32_t)src_cb_index,
(std::uint32_t)index_cb_index,
(std::uint32_t)(dim == n - 1)};
(std::uint32_t)(dim == n - 1),
(std::uint32_t)index.volume()};

auto reader_kernel_id = CreateKernel(
program,
Expand All @@ -124,10 +126,11 @@ IndexFillOperation::MultiCore::cached_program_t IndexFillOperation::MultiCore::c
WriterDataMovementConfig(writer_compile_time_args));

uint32_t unit_offset = 0;
for (uint32_t i = 0; i < num_cores; i++) {
const CoreCoord core(i / num_cores_y, i % num_cores_y);

uint32_t num_rows_per_core = 0;
uint32_t num_cores_group_1 = core_group_1.num_cores();
auto cores = grid_to_cores(num_cores, num_cores_x, num_cores_y);
for (uint32_t i = 0; i < cores.size(); i++) {
const auto& core = cores[i];
uint32_t num_rows_per_core = i < num_cores_group_1 ? num_rows_per_core_group_1 : num_rows_per_core_group_2;
if (core_group_1.core_coord_in_core_ranges(core)) {
num_rows_per_core = num_rows_per_core_group_1;
} else if (core_group_2.core_coord_in_core_ranges(core)) {
Expand All @@ -146,8 +149,6 @@ IndexFillOperation::MultiCore::cached_program_t IndexFillOperation::MultiCore::c
index_unit_size,
unit_offset,
num_rows_per_core,
input_shape[-1],
index.volume(),
num_rows_to_fill_per_index,
input_shape[dim]});
SetRuntimeArgs(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,15 @@ void kernel_main() {
uint32_t index_page_size = get_arg_val<uint32_t>(4);
uint32_t start_row_id = get_arg_val<uint32_t>(5);
uint32_t num_rows_per_core = get_arg_val<uint32_t>(6);
uint32_t row_size = get_arg_val<uint32_t>(7);
uint32_t index_size = get_arg_val<uint32_t>(8);
uint32_t num_rows_to_fill_per_index = get_arg_val<uint32_t>(9);
uint32_t dim = get_arg_val<uint32_t>(10);
uint32_t num_rows_to_fill_per_index = get_arg_val<uint32_t>(7);
uint32_t dim = get_arg_val<uint32_t>(8);

constexpr bool input_is_dram = get_compile_time_arg_val(0) == 1;
constexpr bool index_is_dram = get_compile_time_arg_val(1) == 1;
constexpr uint32_t src_cb_id = get_compile_time_arg_val(2);
constexpr uint32_t index_cb_id = get_compile_time_arg_val(3);
constexpr bool is_last_dim = get_compile_time_arg_val(4) == 1;
constexpr uint32_t index_size = get_compile_time_arg_val(5);

constexpr uint32_t onetile = 1;

Expand All @@ -55,7 +54,6 @@ void kernel_main() {
noc_async_read(index_noc_addr, index_cb_reader, index_page_size);
noc_async_read_barrier();
uint32_t *index_ptr = reinterpret_cast<uint32_t *>(index_cb_reader);

if (is_last_dim) {
for (uint32_t row_id = start_row_id; row_id < start_row_id + num_rows_per_core; row_id++) {
cb_reserve_back(src_cb_id, onetile);
Expand All @@ -66,7 +64,7 @@ void kernel_main() {

uint32_t *input_ptr = reinterpret_cast<uint32_t *>(src_cb_reader);

for (uint32_t i = 0; i < index_page_size / 4; i++) {
for (uint32_t i = 0; i < index_size; i++) {
uint32_t current_index = index_ptr[i];
input_ptr[current_index] = fill_value;
}
Expand All @@ -81,22 +79,22 @@ void kernel_main() {
noc_async_read(input_noc_addr, src_cb_reader, input_page_size);
noc_async_read_barrier();

if (is_in_indices(index_ptr, index_page_size / 4, row_id / num_rows_to_fill_per_index % dim)) {
if (is_in_indices(index_ptr, index_size, row_id / num_rows_to_fill_per_index % dim)) {
#ifdef OUTPUT_DTYPE_BFLOAT16
auto ptr = reinterpret_cast<uint16_t *>(write_addr);
for (uint32_t i = 0; i < input_page_size / 4; ++i) {
for (uint32_t i = 0; i < index_size; ++i) {
ptr[i] = val.u >> 16;
}
#endif
#ifdef OUTPUT_DTYPE_INT32
auto ptr = reinterpret_cast<uint32_t *>(write_addr);
for (uint32_t i = 0; i < input_page_size / 4; ++i) {
for (uint32_t i = 0; i < index_size; ++i) {
ptr[i] = fill_value;
}
#endif
#ifdef OUTPUT_DTYPE_FLOAT32
auto ptr = reinterpret_cast<float *>(write_addr);
for (uint32_t i = 0; i < input_page_size / 4; ++i) {
for (uint32_t i = 0; i < index_size; ++i) {
ptr[i] = val.f;
}
#endif
Expand Down

0 comments on commit 127be69

Please sign in to comment.