Skip to content

Commit

Permalink
#13086: Revising moreh_getitem (#13087)
Browse files Browse the repository at this point in the history
* #13086: Revising moreh_getitem

* #13086: resolve comment
  • Loading branch information
thd1007 authored Sep 28, 2024
1 parent 5acea64 commit 27e1896
Show file tree
Hide file tree
Showing 15 changed files with 432 additions and 266 deletions.
397 changes: 281 additions & 116 deletions tests/ttnn/unit_tests/operations/test_moreh_getitem.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ void MorehGetItemOperation::validate_inputs(
TT_FATAL(input_tensor.storage_type() == StorageType::DEVICE, "Operands to getitem need to be on device!");
TT_FATAL(input_tensor.buffer() != nullptr, "Operands to getitem need to be allocated in buffers on device!");
auto dtype = input_tensor.get_dtype();
TT_FATAL(dtype == DataType::INT32 || dtype == DataType::BFLOAT16, "Input tensor must be of type INT32 or BFLOAT16!");
TT_FATAL(
dtype == DataType::INT32 || dtype == DataType::BFLOAT16, "Input tensor must be of type INT32 or BFLOAT16!");

// validate index tensors
uint32_t index_size = index_tensors[0].get_shape()[-1];
Expand Down Expand Up @@ -53,7 +54,8 @@ void MorehGetItemOperation::validate_inputs(
for (auto dim : operation_attributes.index_dims) {
TT_FATAL(
dim_start + i == dim,
"The value of index_dims={} must be consecutive integers.", operation_attributes.index_dims);
"The value of index_dims={} must be consecutive integers.",
operation_attributes.index_dims);
i++;
}
if (!output_tensor.has_value()) {
Expand Down Expand Up @@ -184,7 +186,7 @@ MorehGetItemOperation::tensor_return_value_t MorehGetItemOperation::create_outpu
tensor_args.input.get_dtype(),
tensor_args.input.get_layout(),
tensor_args.input.device(),
operation_attributes.output_memory_config);
operation_attributes.memory_config);
};

std::tuple<MorehGetItemOperation::operation_attributes_t, MorehGetItemOperation::tensor_args_t>
Expand All @@ -193,8 +195,8 @@ MorehGetItemOperation::invoke(
const std::vector<Tensor>& index_tensors,
const std::vector<uint32_t> index_dims,
const std::optional<Tensor>& output,
const std::optional<MemoryConfig> output_memory_config) {
operation_attributes_t operation_attributes = {index_dims, output_memory_config.value_or(input.memory_config())};
const std::optional<MemoryConfig> memory_config) {
operation_attributes_t operation_attributes = {index_dims, memory_config.value_or(input.memory_config())};
tensor_args_t tensor_args = {input, index_tensors, output};
return {operation_attributes, tensor_args};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ struct MorehGetItemOperation {
struct operation_attributes_t {
const std::vector<uint32_t> index_dims;
// const CoreRange core_range;
const MemoryConfig output_memory_config;
const MemoryConfig memory_config;
};

struct tensor_args_t {
Expand Down Expand Up @@ -90,7 +90,7 @@ struct MorehGetItemOperation {
const std::vector<uint32_t> index_dims,
const std::optional<Tensor>& output,
// const CoreRange core_range,
const std::optional<MemoryConfig> output_memory_config);
const std::optional<MemoryConfig> memory_config);
};
} // namespace ttnn::operations::moreh::moreh_getitem

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ void kernel_main() {
bool is_first_index = true;
int32_t output_dim = 3;
for (int32_t dim = 3; dim >= 0; dim--) {

uint32_t input_stick_idx_stride = input_stick_idx_strides[dim];
auto output_size = output_size_list[output_dim];

Expand Down Expand Up @@ -167,8 +166,7 @@ void kernel_main() {
noc_async_read(index_noc_addr, index_l1_addr, index_stick_sizes[dim]);
noc_async_read_barrier();

volatile tt_l1_ptr int32_t* index_l1_ptr =
reinterpret_cast<volatile tt_l1_ptr int32_t*>(index_l1_addr);
volatile tt_l1_ptr int32_t* index_l1_ptr = reinterpret_cast<volatile tt_l1_ptr int32_t*>(index_l1_addr);
int32_t noc_idx = index_l1_ptr[index_index];

if (noc_idx < 0) {
Expand All @@ -186,7 +184,7 @@ void kernel_main() {
output_stick_idx /= output_size;
}
if (!(index_start_dim < dim && dim <= index_end_dim)) {
output_dim --;
output_dim--;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,10 @@ void kernel_main() {

constexpr bool dst_is_dram = get_compile_time_arg_val(0) == 1;

const InterleavedAddrGen<dst_is_dram> s0 = {
.bank_base_address = dst_addr,
.page_size = output_stick_size
};
const InterleavedAddrGen<dst_is_dram> s0 = {.bank_base_address = dst_addr, .page_size = output_stick_size};

uint32_t end_id = start_id + num_sticks;
for (uint32_t i = start_id; i < end_id; ++ i) {
for (uint32_t i = start_id; i < end_id; ++i) {
cb_wait_front(cb_id_out, 1);
uint32_t l1_read_addr = get_read_ptr(cb_id_out);
uint64_t dst_noc_addr = get_noc_addr(i, s0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ MorehGetItemOperation::MorehGetItemRmFactory::cached_program_t MorehGetItemOpera
auto index_tensors = tensor_args.index_tensors;
auto output = output_tensor;
auto index_dims = operation_attributes.index_dims;
auto output_memory_config = operation_attributes.output_memory_config;
auto memory_config = operation_attributes.memory_config;
// auto core_range = operation_attributes.core_range;
auto device = input.device();
auto grid_coord = device->compute_with_storage_grid_size();
Expand Down Expand Up @@ -94,9 +94,8 @@ MorehGetItemOperation::MorehGetItemRmFactory::cached_program_t MorehGetItemOpera

auto src_cb_index = CB::c_in0;
auto rounded_input_page_size = round_up_to_mul32(input_unit_size);
auto cb_src0_config =
CircularBufferConfig(rounded_input_page_size, {{src_cb_index, src_cb_data_format}})
.set_page_size(src_cb_index, rounded_input_page_size);
auto cb_src0_config = CircularBufferConfig(rounded_input_page_size, {{src_cb_index, src_cb_data_format}})
.set_page_size(src_cb_index, rounded_input_page_size);
auto cb_src0 = CreateCircularBuffer(program, all_cores, cb_src0_config);

for (uint32_t dim = 0; dim < 5; dim++) {
Expand All @@ -105,17 +104,15 @@ MorehGetItemOperation::MorehGetItemRmFactory::cached_program_t MorehGetItemOpera

auto src1_cb_index = CB::c_in1 + dim;
auto index_page_size = round_up_to_mul32(index_info[dim].unit_size);
auto cb_index_config =
CircularBufferConfig(index_page_size, {{src1_cb_index, index_cb_data_format}})
.set_page_size(src1_cb_index, index_page_size);
auto cb_index_config = CircularBufferConfig(index_page_size, {{src1_cb_index, index_cb_data_format}})
.set_page_size(src1_cb_index, index_page_size);
auto cb_src1 = CreateCircularBuffer(program, all_cores, cb_index_config);
}

auto out_cb_index = CB::c_out0;
auto rounded_output_page_size = round_up_to_mul32(input_unit_size);
auto cb_out0_config =
CircularBufferConfig(rounded_input_page_size, {{out_cb_index, output_cb_data_format}})
.set_page_size(out_cb_index, rounded_input_page_size);
auto cb_out0_config = CircularBufferConfig(rounded_input_page_size, {{out_cb_index, output_cb_data_format}})
.set_page_size(out_cb_index, rounded_input_page_size);
auto cb_out0 = CreateCircularBuffer(program, all_cores, cb_out0_config);

// create read/write kernel
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ MorehGetItemOperation::MorehGetItemTilizedFactory::create(
auto index_tensors = tensor_args.index_tensors;
auto output = output_tensor;
auto index_dims = operation_attributes.index_dims;
auto output_memory_config = operation_attributes.output_memory_config;
auto memory_config = operation_attributes.memory_config;
auto TILE_HEIGHT = constants::TILE_HEIGHT;
auto TILE_WIDTH = constants::TILE_WIDTH;
// auto core_range = operation_attributes.core_range;
Expand Down Expand Up @@ -122,9 +122,8 @@ MorehGetItemOperation::MorehGetItemTilizedFactory::create(

auto src_cb_index = CB::c_in0;
auto rounded_input_page_size = round_up_to_mul32(input_unit_size);
auto cb_src0_config =
CircularBufferConfig(rounded_input_page_size, {{src_cb_index, src_cb_data_format}})
.set_page_size(src_cb_index, rounded_input_page_size);
auto cb_src0_config = CircularBufferConfig(rounded_input_page_size, {{src_cb_index, src_cb_data_format}})
.set_page_size(src_cb_index, rounded_input_page_size);
auto cb_src0 = CreateCircularBuffer(program, all_cores, cb_src0_config);

for (uint32_t dim = 0; dim < 5; dim++) {
Expand All @@ -133,23 +132,20 @@ MorehGetItemOperation::MorehGetItemTilizedFactory::create(

auto src1_cb_index = CB::c_in1 + dim;
auto index_page_size = 1024 * 4;
auto cb_index_config =
CircularBufferConfig(index_page_size, {{src1_cb_index, index_cb_data_format}})
.set_page_size(src1_cb_index, index_page_size);
auto cb_index_config = CircularBufferConfig(index_page_size, {{src1_cb_index, index_cb_data_format}})
.set_page_size(src1_cb_index, index_page_size);
auto cb_src1 = CreateCircularBuffer(program, all_cores, cb_index_config);
}

auto out_cb0_index = CB::c_out0;
auto rounded_output_page_size = round_up_to_mul32(output_unit_size);
auto cb_out0_config =
CircularBufferConfig(rounded_output_page_size, {{out_cb0_index, output_cb_data_format}})
.set_page_size(out_cb0_index, rounded_output_page_size);
auto cb_out0_config = CircularBufferConfig(rounded_output_page_size, {{out_cb0_index, output_cb_data_format}})
.set_page_size(out_cb0_index, rounded_output_page_size);
auto cb_out0 = CreateCircularBuffer(program, all_cores, cb_out0_config);

auto out_cb1_index = CB::c_out1;
auto cb_out1_config =
CircularBufferConfig(rounded_output_page_size, {{out_cb1_index, output_cb_data_format}})
.set_page_size(out_cb1_index, rounded_output_page_size);
auto cb_out1_config = CircularBufferConfig(rounded_output_page_size, {{out_cb1_index, output_cb_data_format}})
.set_page_size(out_cb1_index, rounded_output_page_size);
auto cb_out1 = CreateCircularBuffer(program, all_cores, cb_out1_config);

// create read/write kernel
Expand Down Expand Up @@ -359,9 +355,8 @@ MorehGetItemOperation::MorehGetItemTilizedFactory::create(

auto src_cb_index = CB::c_in0;
auto rounded_input_page_size = round_up_to_mul32(input_unit_size);
auto cb_src0_config =
CircularBufferConfig(rounded_input_page_size, {{src_cb_index, src_cb_data_format}})
.set_page_size(src_cb_index, rounded_input_page_size);
auto cb_src0_config = CircularBufferConfig(rounded_input_page_size, {{src_cb_index, src_cb_data_format}})
.set_page_size(src_cb_index, rounded_input_page_size);
auto cb_src0 = CreateCircularBuffer(program, all_cores, cb_src0_config);

for (uint32_t dim = 0; dim < 5; dim++) {
Expand All @@ -371,17 +366,15 @@ MorehGetItemOperation::MorehGetItemTilizedFactory::create(
auto src1_cb_index = CB::c_in1 + dim;
// auto index_page_size = round_up_to_mul32(index_info[dim].unit_size);
auto index_page_size = 1024 * 4;
auto cb_index_config =
CircularBufferConfig(index_page_size, {{src1_cb_index, index_cb_data_format}})
.set_page_size(src1_cb_index, index_page_size);
auto cb_index_config = CircularBufferConfig(index_page_size, {{src1_cb_index, index_cb_data_format}})
.set_page_size(src1_cb_index, index_page_size);
auto cb_src1 = CreateCircularBuffer(program, all_cores, cb_index_config);
}

auto out_cb_index = CB::c_out0;
auto rounded_output_page_size = round_up_to_mul32(input_unit_size);
auto cb_out0_config =
CircularBufferConfig(rounded_input_page_size, {{out_cb_index, output_cb_data_format}})
.set_page_size(out_cb_index, rounded_input_page_size);
auto cb_out0_config = CircularBufferConfig(rounded_input_page_size, {{out_cb_index, output_cb_data_format}})
.set_page_size(out_cb_index, rounded_input_page_size);
auto cb_out0 = CreateCircularBuffer(program, all_cores, cb_out0_config);

// create read/write kernel
Expand Down Expand Up @@ -585,7 +578,7 @@ void MorehGetItemOperation::MorehGetItemTilizedFactory::override_runtime_argumen
runtime_args[2] = index_info[1].address;
runtime_args[3] = index_info[2].address;
runtime_args[4] = index_info[3].address;
runtime_args[4] = index_info[4].address;
runtime_args[5] = index_info[4].address;
}

{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,19 @@ uint32_t get_noc_offset_in_tile(uint32_t stick_h, uint32_t stick_w, uint32_t til

const uint32_t stick_bytes = FACE_WIDTH * element_size;

if (stick_h < FACE_WIDTH && is_even) noc_offset += stick_h * stick_bytes;
else if (stick_h < FACE_WIDTH && is_odd) noc_offset += (16 + stick_h) * stick_bytes;
else if (stick_h >= FACE_WIDTH && is_even) noc_offset += (16 + stick_h) * stick_bytes;
else if (stick_h >= FACE_WIDTH && is_odd) noc_offset += (32 + stick_h) * stick_bytes;
if (stick_h < FACE_WIDTH && is_even)
noc_offset += stick_h * stick_bytes;
else if (stick_h < FACE_WIDTH && is_odd)
noc_offset += (16 + stick_h) * stick_bytes;
else if (stick_h >= FACE_WIDTH && is_even)
noc_offset += (16 + stick_h) * stick_bytes;
else if (stick_h >= FACE_WIDTH && is_odd)
noc_offset += (32 + stick_h) * stick_bytes;

return noc_offset;
}

struct Idx4d
{
struct Idx4d {
uint32_t n;
uint32_t c;
uint32_t h;
Expand Down Expand Up @@ -59,16 +62,16 @@ Idx4d get_tile_indices(Idx4d stick_index_4d) {
return tile_index_4d;
}

struct Idx5d
{
struct Idx5d {
uint32_t n;
uint32_t c;
uint32_t d;
uint32_t h;
uint32_t w;
};

Idx5d get_stick_indices(uint32_t stick_idx, uint32_t size_c, uint32_t size_d, uint32_t size_h, uint32_t num_stick_width) {
Idx5d get_stick_indices(
uint32_t stick_idx, uint32_t size_c, uint32_t size_d, uint32_t size_h, uint32_t num_stick_width) {
Idx5d stick_index_5d;

stick_index_5d.w = stick_idx % num_stick_width;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,16 +83,11 @@ void kernel_main() {
.page_size = 1024 * element_size,
};

const InterleavedAddrGen<index0_is_dram> index0 = {
.bank_base_address = index0_addr, .page_size = INDEX_TILE_SIZE};
const InterleavedAddrGen<index1_is_dram> index1 = {
.bank_base_address = index1_addr, .page_size = INDEX_TILE_SIZE};
const InterleavedAddrGen<index2_is_dram> index2 = {
.bank_base_address = index2_addr, .page_size = INDEX_TILE_SIZE};
const InterleavedAddrGen<index3_is_dram> index3 = {
.bank_base_address = index3_addr, .page_size = INDEX_TILE_SIZE};
const InterleavedAddrGen<index4_is_dram> index4 = {
.bank_base_address = index4_addr, .page_size = INDEX_TILE_SIZE};
const InterleavedAddrGen<index0_is_dram> index0 = {.bank_base_address = index0_addr, .page_size = INDEX_TILE_SIZE};
const InterleavedAddrGen<index1_is_dram> index1 = {.bank_base_address = index1_addr, .page_size = INDEX_TILE_SIZE};
const InterleavedAddrGen<index2_is_dram> index2 = {.bank_base_address = index2_addr, .page_size = INDEX_TILE_SIZE};
const InterleavedAddrGen<index3_is_dram> index3 = {.bank_base_address = index3_addr, .page_size = INDEX_TILE_SIZE};
const InterleavedAddrGen<index4_is_dram> index4 = {.bank_base_address = index4_addr, .page_size = INDEX_TILE_SIZE};

uint32_t index_is_defined[5] = {
index0_is_defined,
Expand Down Expand Up @@ -134,8 +129,8 @@ void kernel_main() {
input_stick_idx_stride_w,
};

#define NOC_MINIMUM_READ_SIZE (32)
#define INDEX_SIZE (4)
#define NOC_MINIMUM_READ_SIZE (32)
#define INDEX_SIZE (4)

uint32_t end_id = start_id + num_sticks;

Expand All @@ -160,7 +155,7 @@ void kernel_main() {
index_index = output_stick_idx % index_size;
is_first_index = false;
}
#ifdef TILIZE_INDEX
#ifdef TILIZE_INDEX
uint32_t index_noc_id = index_index / TILE_HEIGHT;
if (dim == 0) {
index_noc_addr = get_noc_addr(index_noc_id, index0);
Expand All @@ -177,17 +172,19 @@ void kernel_main() {
noc_async_read(index_noc_addr, index_l1_addr, INDEX_TILE_SIZE);
noc_async_read_barrier();

volatile tt_l1_ptr int32_t* index_l1_ptr =
reinterpret_cast<volatile tt_l1_ptr int32_t*>(index_l1_addr);
volatile tt_l1_ptr int32_t* index_l1_ptr = reinterpret_cast<volatile tt_l1_ptr int32_t*>(index_l1_addr);
uint32_t index_dim_offset;
uint32_t index_tile_idx = index_index % TILE_WIDTH;
if (index_tile_idx < FACE_WIDTH) index_dim_offset = index_tile_idx;
else index_dim_offset = index_tile_idx + 256 - 16;
if (index_tile_idx < FACE_WIDTH)
index_dim_offset = index_tile_idx;
else
index_dim_offset = index_tile_idx + 256 - 16;

int32_t index_val = index_l1_ptr[index_dim_offset];
#endif
#ifdef ROW_MAJOR_INDEX
uint32_t noc_offset = ((uint32_t)((index_index * INDEX_SIZE) / NOC_MINIMUM_READ_SIZE)) * NOC_MINIMUM_READ_SIZE;
#endif
#ifdef ROW_MAJOR_INDEX
uint32_t noc_offset =
((uint32_t)((index_index * INDEX_SIZE) / NOC_MINIMUM_READ_SIZE)) * NOC_MINIMUM_READ_SIZE;
if (dim == 0) {
index_noc_addr = get_noc_addr(0, index0, noc_offset);
}
Expand All @@ -203,13 +200,12 @@ void kernel_main() {
noc_async_read(index_noc_addr, index_l1_addr, NOC_MINIMUM_READ_SIZE);
noc_async_read_barrier();

volatile tt_l1_ptr int32_t* index_l1_ptr =
reinterpret_cast<volatile tt_l1_ptr int32_t*>(index_l1_addr);
volatile tt_l1_ptr int32_t* index_l1_ptr = reinterpret_cast<volatile tt_l1_ptr int32_t*>(index_l1_addr);

uint32_t index_dim_offset = (index_index * INDEX_SIZE - noc_offset) / INDEX_SIZE;
int32_t index_val = index_l1_ptr[index_dim_offset];

#endif
#endif

if (index_val < 0) {
index_val += input_size_list[dim];
Expand Down Expand Up @@ -240,16 +236,19 @@ void kernel_main() {
cb_reserve_back(cb_in0, 1);
uint32_t l1_write_addr = get_write_ptr(cb_in0);

Idx5d stick_index_5d = get_stick_indices(input_stick_idx, input_size_c_without_padding, input_size_d_without_padding, input_size_h_without_padding, input_num_stick_width);
Idx5d stick_index_5d = get_stick_indices(
input_stick_idx,
input_size_c_without_padding,
input_size_d_without_padding,
input_size_h_without_padding,
input_num_stick_width);
Idx5d tile_index_5d = get_tile_indices(stick_index_5d);

uint32_t noc_id = tile_index_5d.n * input_noc_id_stride_n +
tile_index_5d.c * input_noc_id_stride_c +
tile_index_5d.d * input_noc_id_stride_d +
tile_index_5d.h * input_noc_id_stride_h +
tile_index_5d.w;
uint32_t noc_id = tile_index_5d.n * input_noc_id_stride_n + tile_index_5d.c * input_noc_id_stride_c +
tile_index_5d.d * input_noc_id_stride_d + tile_index_5d.h * input_noc_id_stride_h +
tile_index_5d.w;

uint32_t noc_offset = get_noc_offset_in_tile(stick_index_5d.h , stick_index_5d.w, tile_index_5d.h, element_size);
uint32_t noc_offset = get_noc_offset_in_tile(stick_index_5d.h, stick_index_5d.w, tile_index_5d.h, element_size);

uint64_t src_noc_addr = get_noc_addr(noc_id, s0, noc_offset);

Expand Down
Loading

0 comments on commit 27e1896

Please sign in to comment.