Skip to content

Commit

Permalink
#18332: Move input stats to reader file (#18335)
Browse files Browse the repository at this point in the history
Continuation of another PR. Will be merged once CI passes
Used for testing
  • Loading branch information
VirdhatchaniKN committed Feb 26, 2025
1 parent 8b3ef53 commit 2afcad4
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 108 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ void set_or_update_runtime_arguments(
tt::tt_metal::split_work_to_cores(compute_with_storage_grid_size, num_output_tiles, row_major);

auto cores = grid_to_cores(num_cores_total, num_cores_x, num_cores_y, row_major);
constexpr size_t num_reader_args = 9;
constexpr size_t num_writer_args = 12;
constexpr size_t num_reader_args = 15;
constexpr size_t num_writer_args = 8;
constexpr size_t num_kernel_args = 3;
for (uint32_t i = 0, start_tile_id = 0; i < num_cores_total; i++) {
const auto& core = cores[i];
Expand All @@ -78,6 +78,9 @@ void set_or_update_runtime_arguments(
? std::bit_cast<uint32_t>(scalar)
: pack_two_bfloat16_into_uint32({scalar, scalar});

const auto weight_addr = weight_has_value ? weight_tensor->buffer()->address() : 0;
const auto bias_addr = bias_has_value ? bias_tensor->buffer()->address() : 0;

std::array reader_runtime_args = {
packed_scalar_eps,
input_tensor.buffer()->address(),
Expand All @@ -87,17 +90,18 @@ void set_or_update_runtime_arguments(
aHt * aWt * aC * (aN > 1),
aHt * aWt * (aC > 1),
cN,
cC};
cC,
bHt * bWt * bC * (bN > 1),
bHt * bWt * (bC > 1),
batch_var_tensor.buffer()->address(), // batch var
weight_addr, // weight
bias_addr, // bias
batch_mean_tensor.buffer()->address() // batch mean
};
handle_args(program, reader_kernel_id, core, reader_runtime_args);

const auto weight_addr = weight_has_value ? weight_tensor->buffer()->address() : 0;
const auto bias_addr = bias_has_value ? bias_tensor->buffer()->address() : 0;
std::array writer_runtime_args = {
batch_mean_tensor.buffer()->address(), // batch mean
batch_var_tensor.buffer()->address(), // batch var
weight_addr, // weight
bias_addr, // bias
c.buffer()->address(), // output
c.buffer()->address(), // output
start_tile_id,
num_tiles_per_core,
cHtWt,
Expand Down Expand Up @@ -232,7 +236,21 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
program,
"ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/reader_batch_norm.cpp",
all_device_cores,
tt_metal::ReaderDataMovementConfig({a_is_dram, input_tensor_cb, eps_cb}, std::move(reader_defines)));
tt_metal::ReaderDataMovementConfig(
{a_is_dram,
input_tensor_cb,
eps_cb,
d_is_dram,
batch_var_tensor_cb,
e_is_dram,
weight_tensor_cb,
static_cast<uint32_t>(weight_has_value),
bias_tensor_cb,
f_is_dram,
static_cast<uint32_t>(bias_has_value),
b_is_dram,
batch_mean_tensor_cb},
std::move(reader_defines)));

// WRITER KERNEL
auto writer_defines = dataflow_defines;
Expand All @@ -242,18 +260,8 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
all_device_cores,
tt_metal::WriterDataMovementConfig(
{
b_is_dram,
c_is_dram,
d_is_dram,
e_is_dram,
f_is_dram,
static_cast<uint32_t>(weight_has_value),
static_cast<uint32_t>(bias_has_value),
batch_mean_tensor_cb,
output_tensor_cb,
batch_var_tensor_cb,
weight_tensor_cb,
bias_tensor_cb,
},
std::move(writer_defines)));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ void kernel_main() {
uint32_t c_stride = get_arg_val<uint32_t>(6);
uint32_t N = get_arg_val<uint32_t>(7);
uint32_t C = get_arg_val<uint32_t>(8);
uint32_t n_stride_stat = get_arg_val<uint32_t>(9);
uint32_t c_stride_stat = get_arg_val<uint32_t>(10);
uint32_t batch_var_addr = get_arg_val<uint32_t>(11); // batch_var
uint32_t weight_addr = get_arg_val<uint32_t>(12); // weight
uint32_t bias_addr = get_arg_val<uint32_t>(13); // bias
uint32_t batch_mean_addr = get_arg_val<uint32_t>(14); // batch_mean

constexpr bool src_is_dram = get_compile_time_arg_val(0) == 1;

Expand Down Expand Up @@ -54,12 +60,92 @@ void kernel_main() {
// Input tile offset
uint32_t tile_offset = start_n * n_stride + start_c * c_stride + start_t;

// Inputs stats offset
uint32_t tile_offset_stat = start_n * n_stride_stat + start_c * c_stride_stat;
uint32_t next_batch_shift_stat = n_stride_stat - c_stride_stat * C;

uint32_t next_channel_shift = c_stride - HtWt;
uint32_t next_batch_shift = n_stride - c_stride * C;

// batch_mean
constexpr auto cb_id_batch_mean = get_compile_time_arg_val(12);
constexpr bool batch_mean_is_dram = get_compile_time_arg_val(11) == 1;
const uint32_t batch_mean_tile_bytes = get_tile_size(cb_id_batch_mean);
const DataFormat batch_mean_data_format = get_dataformat(cb_id_batch_mean);

const InterleavedAddrGenFast<batch_mean_is_dram> batch_mean = {
.bank_base_address = batch_mean_addr,
.page_size = batch_mean_tile_bytes,
.data_format = batch_mean_data_format};

// batch_var
constexpr auto cb_id_batch_var = get_compile_time_arg_val(4);
constexpr bool batch_var_is_dram = get_compile_time_arg_val(3) == 1;
const uint32_t batch_var_tile_bytes = get_tile_size(cb_id_batch_var);
const DataFormat batch_var_data_format = get_dataformat(cb_id_batch_var);

const InterleavedAddrGenFast<batch_var_is_dram> batch_var = {
.bank_base_address = batch_var_addr, .page_size = batch_var_tile_bytes, .data_format = batch_var_data_format};

// weight
constexpr auto cb_id_weight = get_compile_time_arg_val(6);
constexpr bool weight_is_dram = get_compile_time_arg_val(5) == 1;
const uint32_t weight_tile_bytes = get_tile_size(cb_id_weight);
const DataFormat weight_data_format = get_dataformat(cb_id_weight);

const InterleavedAddrGenFast<weight_is_dram> weight = {
.bank_base_address = weight_addr, .page_size = weight_tile_bytes, .data_format = weight_data_format};

constexpr bool weight_has_value = get_compile_time_arg_val(7) == 1;

// bias
constexpr auto cb_id_bias = get_compile_time_arg_val(8);
constexpr bool bias_is_dram = get_compile_time_arg_val(9) == 1;
const uint32_t bias_tile_bytes = get_tile_size(cb_id_bias);
const DataFormat bias_data_format = get_dataformat(cb_id_bias);

const InterleavedAddrGenFast<bias_is_dram> bias = {
.bank_base_address = bias_addr, .page_size = bias_tile_bytes, .data_format = bias_data_format};

constexpr bool bias_has_value = get_compile_time_arg_val(10) == 1;

uint32_t num_tiles_read = 0;
for (uint32_t n = start_n; n < N && num_tiles_read < num_tiles; ++n, start_c = 0) {
for (uint32_t c = start_c; c < C && num_tiles_read < num_tiles; ++c, start_t = 0) {
// read a tile from batch_mean
cb_reserve_back(cb_id_batch_mean, onetile);
uint32_t l1_write_addr = get_write_ptr(cb_id_batch_mean);
noc_async_read_tile(tile_offset_stat, batch_mean, l1_write_addr);
noc_async_read_barrier();
FILL_TILE_WITH_FIRST_ELEMENT(cb_id_batch_mean);
cb_push_back(cb_id_batch_mean, onetile);

// read a tile from batch variance
cb_reserve_back(cb_id_batch_var, onetile);
uint32_t l1_batch_var_write_addr = get_write_ptr(cb_id_batch_var);
noc_async_read_tile(tile_offset_stat, batch_var, l1_batch_var_write_addr);
noc_async_read_barrier();
FILL_TILE_WITH_FIRST_ELEMENT(cb_id_batch_var);
cb_push_back(cb_id_batch_var, onetile);

if constexpr (weight_has_value) { // read a tile from weight tensor
cb_reserve_back(cb_id_weight, onetile);
uint32_t l1_weight_write_addr = get_write_ptr(cb_id_weight);
noc_async_read_tile(tile_offset_stat, weight, l1_weight_write_addr);
noc_async_read_barrier();
FILL_TILE_WITH_FIRST_ELEMENT(cb_id_weight);
cb_push_back(cb_id_weight, onetile);
}

if constexpr (bias_has_value) { // read a tile from bias tensor
cb_reserve_back(cb_id_bias, onetile);
uint32_t l1_bias_write_addr = get_write_ptr(cb_id_bias);
noc_async_read_tile(tile_offset_stat, bias, l1_bias_write_addr);
noc_async_read_barrier();
FILL_TILE_WITH_FIRST_ELEMENT(cb_id_bias);
cb_push_back(cb_id_bias, onetile);
}

for (uint32_t t = start_t; t < HtWt && num_tiles_read < num_tiles; ++t, ++num_tiles_read, ++tile_offset) {
cb_reserve_back(cb_id_src, onetile);
uint32_t l1_write_addr_src = get_write_ptr(cb_id_src);
Expand All @@ -68,7 +154,9 @@ void kernel_main() {
cb_push_back(cb_id_src, onetile);
}
tile_offset += next_channel_shift;
tile_offset_stat += c_stride_stat;
}
tile_offset += next_batch_shift;
tile_offset_stat += next_batch_shift_stat;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,69 +8,26 @@
#include "cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/fill_tile_utils.hpp"

void kernel_main() {
uint32_t src_addr = get_arg_val<uint32_t>(0); // batch_mean
uint32_t batch_var_addr = get_arg_val<uint32_t>(1); // batch_var
uint32_t weight_addr = get_arg_val<uint32_t>(2); // weight
uint32_t bias_addr = get_arg_val<uint32_t>(3); // bias
uint32_t dst_addr = get_arg_val<uint32_t>(4); // output
uint32_t start_tile_id = get_arg_val<uint32_t>(5);
uint32_t num_tiles = get_arg_val<uint32_t>(6);
uint32_t HtWt = get_arg_val<uint32_t>(7);
uint32_t n_stride = get_arg_val<uint32_t>(8);
uint32_t c_stride = get_arg_val<uint32_t>(9);
uint32_t N = get_arg_val<uint32_t>(10);
uint32_t C = get_arg_val<uint32_t>(11);
uint32_t dst_addr = get_arg_val<uint32_t>(0); // output
uint32_t start_tile_id = get_arg_val<uint32_t>(1);
uint32_t num_tiles = get_arg_val<uint32_t>(2);
uint32_t HtWt = get_arg_val<uint32_t>(3);
uint32_t n_stride = get_arg_val<uint32_t>(4);
uint32_t c_stride = get_arg_val<uint32_t>(5);
uint32_t N = get_arg_val<uint32_t>(6);
uint32_t C = get_arg_val<uint32_t>(7);

constexpr uint32_t onetile = 1;

// batch_mean
constexpr auto cb_id_src = get_compile_time_arg_val(7);
constexpr bool src_is_dram = get_compile_time_arg_val(0) == 1;
const uint32_t src_tile_bytes = get_tile_size(cb_id_src);
const DataFormat src_data_format = get_dataformat(cb_id_src);

const InterleavedAddrGenFast<src_is_dram> src = {
.bank_base_address = src_addr, .page_size = src_tile_bytes, .data_format = src_data_format};

// output
constexpr auto cb_id_dst = get_compile_time_arg_val(8);
constexpr bool dst_is_dram = get_compile_time_arg_val(1) == 1;
constexpr auto cb_id_dst = get_compile_time_arg_val(1);
constexpr bool dst_is_dram = get_compile_time_arg_val(0) == 1;
const uint32_t dst_tile_bytes = get_tile_size(cb_id_dst);
const DataFormat dst_data_format = get_dataformat(cb_id_dst);

const InterleavedAddrGenFast<dst_is_dram> dst = {
.bank_base_address = dst_addr, .page_size = dst_tile_bytes, .data_format = dst_data_format};

// batch_var
constexpr auto cb_id_batch_var = get_compile_time_arg_val(9);
constexpr bool batch_var_is_dram = get_compile_time_arg_val(2) == 1;
const uint32_t batch_var_tile_bytes = get_tile_size(cb_id_batch_var);
const DataFormat batch_var_data_format = get_dataformat(cb_id_batch_var);

const InterleavedAddrGenFast<batch_var_is_dram> batch_var = {
.bank_base_address = batch_var_addr, .page_size = batch_var_tile_bytes, .data_format = batch_var_data_format};

// weight
constexpr auto cb_id_weight = get_compile_time_arg_val(10);
constexpr bool weight_is_dram = get_compile_time_arg_val(3) == 1;
const uint32_t weight_tile_bytes = get_tile_size(cb_id_weight);
const DataFormat weight_data_format = get_dataformat(cb_id_weight);

const InterleavedAddrGenFast<weight_is_dram> weight = {
.bank_base_address = weight_addr, .page_size = weight_tile_bytes, .data_format = weight_data_format};

// bias
constexpr auto cb_id_bias = get_compile_time_arg_val(11);
constexpr bool bias_is_dram = get_compile_time_arg_val(4) == 1;
const uint32_t bias_tile_bytes = get_tile_size(cb_id_bias);
const DataFormat bias_data_format = get_dataformat(cb_id_bias);

const InterleavedAddrGenFast<bias_is_dram> bias = {
.bank_base_address = bias_addr, .page_size = bias_tile_bytes, .data_format = bias_data_format};

constexpr bool weight_has_value = get_compile_time_arg_val(5) == 1;
constexpr bool bias_has_value = get_compile_time_arg_val(6) == 1;

uint32_t tiles_per_batch = HtWt * C;
uint32_t start_n = start_tile_id / tiles_per_batch;
uint32_t start_remaining = start_tile_id % tiles_per_batch;
Expand All @@ -84,40 +41,6 @@ void kernel_main() {
uint32_t num_tiles_written = 0;
for (uint32_t n = start_n; n < N && num_tiles_written < num_tiles; ++n, start_c = 0) {
for (uint32_t c = start_c; c < C && num_tiles_written < num_tiles; ++c, start_t = 0) {
// read a tile from src
cb_reserve_back(cb_id_src, onetile);
uint32_t l1_write_addr = get_write_ptr(cb_id_src);
noc_async_read_tile(tile_offset, src, l1_write_addr);
noc_async_read_barrier();
FILL_TILE_WITH_FIRST_ELEMENT(cb_id_src);
cb_push_back(cb_id_src, onetile);

// read a tile from batch variance
cb_reserve_back(cb_id_batch_var, onetile);
uint32_t l1_batch_var_write_addr = get_write_ptr(cb_id_batch_var);
noc_async_read_tile(tile_offset, batch_var, l1_batch_var_write_addr);
noc_async_read_barrier();
FILL_TILE_WITH_FIRST_ELEMENT(cb_id_batch_var);
cb_push_back(cb_id_batch_var, onetile);

if constexpr (weight_has_value) { // read a tile from weight tensor
cb_reserve_back(cb_id_weight, onetile);
uint32_t l1_weight_write_addr = get_write_ptr(cb_id_weight);
noc_async_read_tile(tile_offset, weight, l1_weight_write_addr);
noc_async_read_barrier();
FILL_TILE_WITH_FIRST_ELEMENT(cb_id_weight);
cb_push_back(cb_id_weight, onetile);
}

if constexpr (bias_has_value) { // read a tile from bias tensor
cb_reserve_back(cb_id_bias, onetile);
uint32_t l1_bias_write_addr = get_write_ptr(cb_id_bias);
noc_async_read_tile(tile_offset, bias, l1_bias_write_addr);
noc_async_read_barrier();
FILL_TILE_WITH_FIRST_ELEMENT(cb_id_bias);
cb_push_back(cb_id_bias, onetile);
}

for (uint32_t t = start_t; t < HtWt && num_tiles_written < num_tiles; ++t, ++num_tiles_written) {
// write a tile to dst
cb_wait_front(cb_id_dst, onetile);
Expand Down

0 comments on commit 2afcad4

Please sign in to comment.