Skip to content

Commit

Permalink
Tightens CB indices for Softmax
Browse files Browse the repository at this point in the history
  • Loading branch information
edwinleeTT committed Feb 27, 2025
1 parent c560617 commit 25ce98a
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ void MAIN {
const uint32_t ndst = get_arg_val<uint32_t>(3);
const uint32_t start_ht = get_arg_val<uint32_t>(4);
const uint32_t mask_padded_data = get_arg_val<uint32_t>(5);
binary_op_init_common(tt::CBIndex::c_0, tt::CBIndex::c_2, tt::CBIndex::c_24);
binary_op_init_common(tt::CBIndex::c_0, tt::CBIndex::c_2, tt::CBIndex::c_6);

constexpr uint32_t onetile = 1;
// reserve one tile for zeros on cb_in2
Expand All @@ -86,14 +86,14 @@ void MAIN {
constexpr auto cb_fused_scale = tt::CBIndex::c_3;
constexpr auto cb_fused_attn = tt::CBIndex::c_4;
constexpr auto cb_mask_padded = tt::CBIndex::c_5;
constexpr auto cb_exps = tt::CBIndex::c_24;
constexpr auto cb_scale_mask = tt::CBIndex::c_27;
constexpr auto cb_recipsumexps = tt::CBIndex::c_25;
constexpr auto cb_exps = tt::CBIndex::c_6;
constexpr auto cb_scale_mask = tt::CBIndex::c_9;
constexpr auto cb_recipsumexps = tt::CBIndex::c_7;
constexpr auto cb_in0 = tt::CBIndex::c_0;
constexpr auto cb_out0 = tt::CBIndex::c_16;
constexpr auto cb_out0 = tt::CBIndex::c_11;
#ifdef NUMERIC_STABLE
constexpr auto cb_max = tt::CBIndex::c_26;
constexpr auto cb_x = tt::CBIndex::c_28;
constexpr auto cb_max = tt::CBIndex::c_8;
constexpr auto cb_x = tt::CBIndex::c_10;
#else
constexpr auto cb_x = cb_exps;
#endif
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,19 +70,19 @@ void MAIN {
constexpr uint32_t subblock_w = get_compile_time_arg_val(2);
constexpr uint32_t num_subblocks_w = get_compile_time_arg_val(3);

binary_op_init_common(tt::CBIndex::c_0, tt::CBIndex::c_1, tt::CBIndex::c_24);
binary_op_init_common(tt::CBIndex::c_0, tt::CBIndex::c_1, tt::CBIndex::c_6);

constexpr auto cb_in0 = tt::CBIndex::c_0;
constexpr auto cb_bcast_scaler = tt::CBIndex::c_1;
constexpr auto cb_fused_scale = tt::CBIndex::c_2;
constexpr auto cb_fused_attn = tt::CBIndex::c_3;
constexpr auto cb_exps = tt::CBIndex::c_24;
constexpr auto cb_recipsumexps = tt::CBIndex::c_25;
constexpr auto cb_scale_mask = tt::CBIndex::c_26;
constexpr auto cb_out0 = tt::CBIndex::c_16;
constexpr auto cb_exps = tt::CBIndex::c_6;
constexpr auto cb_recipsumexps = tt::CBIndex::c_7;
constexpr auto cb_scale_mask = tt::CBIndex::c_8;
constexpr auto cb_out0 = tt::CBIndex::c_11;
#ifdef NUMERIC_STABLE
constexpr auto cb_max = tt::CBIndex::c_27;
constexpr auto cb_x = tt::CBIndex::c_28;
constexpr auto cb_max = tt::CBIndex::c_9;
constexpr auto cb_x = tt::CBIndex::c_10;
#else
constexpr auto cb_x = cb_exps;
#endif
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ void kernel_main() {
const uint32_t Wt = get_arg_val<uint32_t>(5);

constexpr bool src0_is_dram = get_compile_time_arg_val(0) == 1;
constexpr uint32_t cb_id_in0 = 0, cb_id_in1 = 1;
constexpr uint32_t cb_id_in0 = tt::CBIndex::c_0, cb_id_in1 = tt::CBIndex::c_1;

// ublocks size defined in tiles
constexpr uint32_t onetile = 1;
Expand Down Expand Up @@ -58,7 +58,7 @@ void kernel_main() {

// TODO(AP): cleanup, probably with named args/param pack/reflection.
{
constexpr uint32_t cb_in_2 = 2;
constexpr uint32_t cb_in_2 = tt::CBIndex::c_2;
const uint32_t reduce_scaler = get_arg_val<uint32_t>(10);
generate_reduce_scaler(cb_in_2, reduce_scaler);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ void kernel_main() {

constexpr bool dst_is_dram = get_compile_time_arg_val(0) == 1;

constexpr uint32_t cb_id_out0 = 16;
constexpr uint32_t cb_id_out0 = tt::CBIndex::c_11;
constexpr uint32_t onetile = 1;
const uint32_t tile_bytes = get_tile_size(cb_id_out0);
const DataFormat data_format = get_dataformat(cb_id_out0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,26 +210,26 @@ operation::ProgramWithCallbacks scale_mask_softmax_multi_core(
auto c_in0_config = CircularBufferConfig(in0_t * in0_tile_size, {{tt::CBIndex::c_0, in0_cb_data_format}})
.set_page_size(tt::CBIndex::c_0, in0_tile_size);
auto cb_in0_id = CreateCircularBuffer(program, all_device_cores, c_in0_config);
auto c_out0_config = CircularBufferConfig(out0_t * out0_tile_size, {{tt::CBIndex::c_16, out0_cb_data_format}})
.set_page_size(tt::CBIndex::c_16, out0_tile_size);
auto c_out0_config = CircularBufferConfig(out0_t * out0_tile_size, {{tt::CBIndex::c_11, out0_cb_data_format}})
.set_page_size(tt::CBIndex::c_11, out0_tile_size);
auto cb_out0_id = CreateCircularBuffer(program, all_device_cores, c_out0_config);
auto c_intermed1_config = CircularBufferConfig(im1_t * im_tile_size, {{tt::CBIndex::c_25, im_cb_data_format}})
.set_page_size(tt::CBIndex::c_25, im_tile_size);
auto c_intermed1_config = CircularBufferConfig(im1_t * im_tile_size, {{tt::CBIndex::c_7, im_cb_data_format}})
.set_page_size(tt::CBIndex::c_7, im_tile_size);
auto cb_intermed1_id = CreateCircularBuffer(program, all_device_cores, c_intermed1_config);
auto c_in2_config = CircularBufferConfig(in2_t * scalar_tile_size, {{tt::CBIndex::c_2, scalar_cb_data_format}})
.set_page_size(tt::CBIndex::c_2, scalar_tile_size);
auto cb_in2_id = CreateCircularBuffer(program, all_device_cores, c_in2_config);
auto c_intermed0_config = CircularBufferConfig(im0_t * im_tile_size, {{tt::CBIndex::c_24, im_cb_data_format}})
.set_page_size(tt::CBIndex::c_24, im_tile_size);
auto c_intermed0_config = CircularBufferConfig(im0_t * im_tile_size, {{tt::CBIndex::c_6, im_cb_data_format}})
.set_page_size(tt::CBIndex::c_6, im_tile_size);
auto cb_intermed0_id = CreateCircularBuffer(program, all_device_cores, c_intermed0_config);
std::optional<CBHandle> cb_intermed3_id;
std::optional<CBHandle> cb_in3_id;
std::optional<CBHandle> cb_in4_id;
std::optional<CBHandle> cb_in5_id;
if (mask.has_value()) {
CircularBufferConfig c_intermed3_config =
CircularBufferConfig(im3_t * im_tile_size, {{tt::CBIndex::c_27, im_cb_data_format}})
.set_page_size(tt::CBIndex::c_27, im_tile_size);
CircularBufferConfig(im3_t * im_tile_size, {{tt::CBIndex::c_9, im_cb_data_format}})
.set_page_size(tt::CBIndex::c_9, im_tile_size);
cb_intermed3_id = CreateCircularBuffer(program, all_device_cores, c_intermed3_config);
CircularBufferConfig c_in3_config =
CircularBufferConfig(in3_t * scalar_tile_size, {{tt::CBIndex::c_3, scalar_cb_data_format}})
Expand All @@ -248,12 +248,12 @@ operation::ProgramWithCallbacks scale_mask_softmax_multi_core(
std::optional<CBHandle> cb_intermed4_id;
if (numeric_stable) {
// cb_max
auto c_intermed2_config = CircularBufferConfig(im2_t * im_tile_size, {{tt::CBIndex::c_26, im_cb_data_format}})
.set_page_size(tt::CBIndex::c_26, im_tile_size);
auto c_intermed2_config = CircularBufferConfig(im2_t * im_tile_size, {{tt::CBIndex::c_8, im_cb_data_format}})
.set_page_size(tt::CBIndex::c_8, im_tile_size);
cb_intermed2_id = CreateCircularBuffer(program, all_device_cores, c_intermed2_config);
// cb_x
auto c_x_config = CircularBufferConfig(im4_t * im_tile_size, {{tt::CBIndex::c_28, im_cb_data_format}})
.set_page_size(tt::CBIndex::c_28, im_tile_size);
auto c_x_config = CircularBufferConfig(im4_t * im_tile_size, {{tt::CBIndex::c_10, im_cb_data_format}})
.set_page_size(tt::CBIndex::c_10, im_tile_size);
cb_intermed4_id = CreateCircularBuffer(program, all_device_cores, c_x_config);
}

Expand Down Expand Up @@ -771,8 +771,8 @@ operation::ProgramWithCallbacks scale_mask_softmax_sharded_multi_core(
std::optional<CBHandle> cb_in3_id;
if (mask.has_value()) {
// im2
auto c_intermed2_config = CircularBufferConfig(im2_CB_size, {{tt::CBIndex::c_26, im_cb_data_format}})
.set_page_size(tt::CBIndex::c_26, im_tile_size);
auto c_intermed2_config = CircularBufferConfig(im2_CB_size, {{tt::CBIndex::c_8, im_cb_data_format}})
.set_page_size(tt::CBIndex::c_8, im_tile_size);
cb_intermed2_id = CreateCircularBuffer(program, all_device_cores, c_intermed2_config);
// in2 scale
auto c_in2_config = CircularBufferConfig(in2_CB_size, {{tt::CBIndex::c_2, scale_cb_data_format}})
Expand All @@ -792,26 +792,26 @@ operation::ProgramWithCallbacks scale_mask_softmax_sharded_multi_core(
}
}
// out
auto c_out0_config = CircularBufferConfig(out_CB_size, {{tt::CBIndex::c_16, out0_cb_data_format}})
.set_page_size(tt::CBIndex::c_16, out0_tile_size)
auto c_out0_config = CircularBufferConfig(out_CB_size, {{tt::CBIndex::c_11, out0_cb_data_format}})
.set_page_size(tt::CBIndex::c_11, out0_tile_size)
.set_globally_allocated_address(*out0_buffer);
auto cb_out0_id = CreateCircularBuffer(program, all_device_cores, c_out0_config);
// im0 for exp(x)
auto c_intermed0_config = CircularBufferConfig(im0_CB_size, {{tt::CBIndex::c_24, im_cb_data_format}})
.set_page_size(tt::CBIndex::c_24, im_tile_size);
auto c_intermed0_config = CircularBufferConfig(im0_CB_size, {{tt::CBIndex::c_6, im_cb_data_format}})
.set_page_size(tt::CBIndex::c_6, im_tile_size);
auto cb_intermed0_id = CreateCircularBuffer(program, all_device_cores, c_intermed0_config);
// im1 for 1/sum(exp(x))
auto c_intermed1_config = CircularBufferConfig(im1_CB_size, {{tt::CBIndex::c_25, im_cb_data_format}})
.set_page_size(tt::CBIndex::c_25, im_tile_size);
auto c_intermed1_config = CircularBufferConfig(im1_CB_size, {{tt::CBIndex::c_7, im_cb_data_format}})
.set_page_size(tt::CBIndex::c_7, im_tile_size);
auto cb_intermed1_id = CreateCircularBuffer(program, all_device_cores, c_intermed1_config);
if (numeric_stable) {
// cb_max
auto c_intermed3_config = CircularBufferConfig(max_CB_size, {{tt::CBIndex::c_27, im_cb_data_format}})
.set_page_size(tt::CBIndex::c_27, im_tile_size);
auto c_intermed3_config = CircularBufferConfig(max_CB_size, {{tt::CBIndex::c_9, im_cb_data_format}})
.set_page_size(tt::CBIndex::c_9, im_tile_size);
auto cb_intermed3_id = CreateCircularBuffer(program, all_device_cores, c_intermed3_config);
// cb_x
auto c_intermed4_config = CircularBufferConfig(x_CB_size, {{tt::CBIndex::c_28, im_cb_data_format}})
.set_page_size(tt::CBIndex::c_28, im_tile_size);
auto c_intermed4_config = CircularBufferConfig(x_CB_size, {{tt::CBIndex::c_10, im_cb_data_format}})
.set_page_size(tt::CBIndex::c_10, im_tile_size);
auto cb_intermed4_id = CreateCircularBuffer(program, all_device_cores, c_intermed4_config);
}

Expand Down

0 comments on commit 25ce98a

Please sign in to comment.