Skip to content

Commit

Permalink
#17758: Update Running stats Writer kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
VirdhatchaniKN committed Feb 11, 2025
1 parent c03ba32 commit 075f224
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@ void kernel_main() {

constexpr uint32_t onetile = 1;

constexpr auto cb_id_src = tt::CBIndex::c_1;
constexpr auto cb_id_src = get_compile_time_arg_val(6);
constexpr bool src_is_dram = get_compile_time_arg_val(0) == 1;
const uint32_t src_tile_bytes = get_tile_size(cb_id_src);
const DataFormat src_data_format = get_dataformat(cb_id_src);

const InterleavedAddrGenFast<src_is_dram> src = {
.bank_base_address = src_addr, .page_size = src_tile_bytes, .data_format = src_data_format};

constexpr auto cb_id_dst = tt::CBIndex::c_2;
constexpr auto cb_id_dst = get_compile_time_arg_val(7);
constexpr bool dst_is_dram = get_compile_time_arg_val(1) == 1;
const uint32_t dst_tile_bytes = get_tile_size(cb_id_dst);
const DataFormat dst_data_format = get_dataformat(cb_id_dst);
Expand All @@ -39,7 +39,7 @@ void kernel_main() {
.bank_base_address = dst_addr, .page_size = dst_tile_bytes, .data_format = dst_data_format};

// old running mean
constexpr auto cb_id_old_running_mean = tt::CBIndex::c_3;
constexpr auto cb_id_old_running_mean = get_compile_time_arg_val(8);
constexpr bool old_running_mean_is_dram = get_compile_time_arg_val(2) == 1;
const uint32_t old_running_mean_tile_bytes = get_tile_size(cb_id_old_running_mean);
const DataFormat old_running_mean_data_format = get_dataformat(cb_id_old_running_mean);
Expand All @@ -50,7 +50,7 @@ void kernel_main() {
.data_format = old_running_mean_data_format};

// old running var
constexpr auto cb_id_old_running_var = tt::CBIndex::c_4;
constexpr auto cb_id_old_running_var = get_compile_time_arg_val(9);
constexpr bool old_running_var_is_dram = get_compile_time_arg_val(3) == 1;
const uint32_t old_running_var_tile_bytes = get_tile_size(cb_id_old_running_var);
const DataFormat old_running_var_data_format = get_dataformat(cb_id_old_running_var);
Expand All @@ -62,8 +62,8 @@ void kernel_main() {

constexpr bool old_running_mean_has_value = get_compile_time_arg_val(4) == 1;
constexpr bool old_running_var_has_value = get_compile_time_arg_val(5) == 1;
constexpr auto cb_id_updated_running_mean = tt::CBIndex::c_27;
constexpr auto cb_id_updated_running_var = tt::CBIndex::c_28;
constexpr auto cb_id_updated_running_mean = get_compile_time_arg_val(10);
constexpr auto cb_id_updated_running_var = get_compile_time_arg_val(11);

uint32_t tiles_per_batch = HtWt * C;
uint32_t start_n = start_tile_id / tiles_per_batch;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,12 @@ RunningStatistics::RunningStatisticsProgramFactory::create(
e_is_dram,
static_cast<uint32_t>(running_mean_has_value),
static_cast<uint32_t>(running_var_has_value),
batch_var_tensor_cb,
output_tensor_cb,
old_running_mean_tensor_cb,
old_running_var_tensor_cb,
updated_m_cb,
updated_v_cb,
},
std::move(writer_defines)));

Expand Down

0 comments on commit 075f224

Please sign in to comment.