Skip to content

Commit

Permalink
#0: Minor cleanup/fix for generating reduce/bcast scalar tiles
Browse files Browse the repository at this point in the history
  • Loading branch information
tt-aho committed Jul 29, 2024
1 parent 616d148 commit 483b120
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,38 +7,43 @@
#include "dataflow_api.h"

// W-bcast scalar
// Tile is assumed to have 16-bit elements
// Scalar is assumed to be a 16-bit value double packed into a u32
FORCE_INLINE void generate_bcast_col_scalar(const uint32_t cb_id, const uint32_t scalar) {
const uint16_t scalar_val = scalar>>16;
const uint16_t scalar_val = scalar >> 16;
cb_reserve_back(cb_id, 1);
volatile tt_l1_ptr uint16_t* ptr = reinterpret_cast<volatile tt_l1_ptr uint16_t*>(get_write_ptr(cb_id));
for (int k = 0; k < 4; k+=2) {
for (int k = 0; k < 4; k += 2) {
uint32_t idx = k << 8;
for (int j = 0; j < 256; j+=16) {
for (int j = 0; j < 256; j += 16) {
ptr[idx + j] = scalar_val;
}
}
cb_push_back(cb_id, 1);
}

// H-bcast scalar
// Tile is assumed to have 16-bit elements
// Scalar is assumed to be a 16-bit value double packed into a u32
FORCE_INLINE void generate_bcast_row_scalar(const uint32_t cb_id, const uint32_t scalar) {
const uint32_t scalar_val = scalar>>16;
cb_reserve_back(cb_id, 1);
volatile tt_l1_ptr uint32_t* ptr = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(get_write_ptr(cb_id));
for (int k = 0; k < 2; ++k) {
uint32_t idx = k << 7;
for (int j = 0; j < 8; ++j) {
ptr[idx + j] = scalar_val;
ptr[idx + j] = scalar;
}
}
cb_push_back(cb_id, 1);
}

// HW-bcast scalar
// Tile is assumed to have 16-bit elements
// Scalar is assumed to be a 16-bit value double packed into a u32
FORCE_INLINE void generate_bcast_unary_scalar(const uint32_t cb_id, const uint32_t scalar) {
const uint32_t scalar_val = scalar>>16;
const uint32_t scalar_val = scalar >> 16;
cb_reserve_back(cb_id, 1);
volatile tt_l1_ptr uint32_t* ptr = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(get_write_ptr(cb_id));
ptr[0] = scalar>>16;
ptr[0] = scalar >> 16;
cb_push_back(cb_id, 1);
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

#include "dataflow_api.h"

#include "debug/dprint.h"

// Tile is assumed to have 16-bit elements
// Scaler is assumed to be a 16-bit value double packed into a u32
FORCE_INLINE void generate_reduce_scaler(const uint32_t cb_id, const uint32_t scaler) {
cb_reserve_back(cb_id, 1);

Expand All @@ -17,8 +17,10 @@ FORCE_INLINE void generate_reduce_scaler(const uint32_t cb_id, const uint32_t sc
volatile tt_l1_ptr uint32_t* ptr = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(write_addr);

// Fill tile with zeros
// TODO: src addr does not need to be rewritten. Update/add api for this
noc_async_read_one_packet_set_state(zeros_noc_addr, MEM_ZEROS_SIZE);
for (uint32_t i = 0; i < num_zeros_reads; ++i) {
noc_async_read(zeros_noc_addr, write_addr, MEM_ZEROS_SIZE);
noc_async_read_one_packet_with_state(zeros_noc_addr, write_addr);
write_addr += MEM_ZEROS_SIZE;
}
noc_async_read_barrier();
Expand Down

0 comments on commit 483b120

Please sign in to comment.