Skip to content

Commit

Permalink
#11830: Pass correct dispatch msg addr to define via build state config
Browse files Browse the repository at this point in the history
  • Loading branch information
abhullar-tt committed Oct 3, 2024
1 parent 65cc65c commit fd4a4d4
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ static void RunTest(WatcherFixture *fixture, Device *device, riscv_id_t riscv_ty
// We should be able to find the expected watcher error in the log as well,
// expected error message depends on the risc we're running on.
string kernel = "tests/tt_metal/tt_metal/test_kernels/misc/watcher_asserts.cpp";
int line_num = 57;
int line_num = 56;

string expected = fmt::format(
"Device {} {} core(x={:2},y={:2}) phys(x={:2},y={:2}): {} tripped an assert on line {}. Current kernel: {}.",
Expand Down
42 changes: 28 additions & 14 deletions tt_metal/impl/device/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,27 +269,38 @@ void Device::initialize_build() {

this->build_env_.init(this->build_key(), this->arch());

auto init_helper = [this] (bool is_fw) -> JitBuildStateSet {
CoreType dispatch_core_type = dispatch_core_manager::instance().get_dispatch_core_type(this->id());
uint32_t dispatch_message_addr =
dispatch_constants::get(dispatch_core_type).get_device_command_queue_addr(CommandQueueDeviceAddrType::DISPATCH_MESSAGE);

auto init_helper = [this, dispatch_message_addr] (bool is_fw) -> JitBuildStateSet {
std::vector<std::shared_ptr<JitBuildState>> build_states;

build_states.resize(arch() == tt::ARCH::GRAYSKULL ? 5 : 7);

build_states[build_processor_type_to_index(JitBuildProcessorType::DATA_MOVEMENT).first + 0] =
std::make_shared<JitBuildDataMovement>(this->build_env_, 0, is_fw);
std::make_shared<JitBuildDataMovement>(
this->build_env_, JitBuiltStateConfig{.processor_id = 0, .is_fw=is_fw, .dispatch_message_addr=dispatch_message_addr});
build_states[build_processor_type_to_index(JitBuildProcessorType::DATA_MOVEMENT).first + 1] =
std::make_shared<JitBuildDataMovement>(this->build_env_, 1, is_fw);
std::make_shared<JitBuildDataMovement>(
this->build_env_, JitBuiltStateConfig{.processor_id = 1, .is_fw=is_fw, .dispatch_message_addr=dispatch_message_addr});
build_states[build_processor_type_to_index(JitBuildProcessorType::COMPUTE).first + 0] =
std::make_shared<JitBuildCompute>(this->build_env_, 0, is_fw);
std::make_shared<JitBuildCompute>(
this->build_env_, JitBuiltStateConfig{.processor_id = 0, .is_fw=is_fw, .dispatch_message_addr=dispatch_message_addr});
build_states[build_processor_type_to_index(JitBuildProcessorType::COMPUTE).first + 1] =
std::make_shared<JitBuildCompute>(this->build_env_, 1, is_fw);
std::make_shared<JitBuildCompute>(
this->build_env_, JitBuiltStateConfig{.processor_id = 1, .is_fw=is_fw, .dispatch_message_addr=dispatch_message_addr});
build_states[build_processor_type_to_index(JitBuildProcessorType::COMPUTE).first + 2] =
std::make_shared<JitBuildCompute>(this->build_env_, 2, is_fw);
std::make_shared<JitBuildCompute>(
this->build_env_, JitBuiltStateConfig{.processor_id = 2, .is_fw=is_fw, .dispatch_message_addr=dispatch_message_addr});

if (arch() != tt::ARCH::GRAYSKULL) {
build_states[build_processor_type_to_index(JitBuildProcessorType::ETHERNET).first + 0] =
std::make_shared<JitBuildEthernet>(this->build_env_, 0, is_fw);
std::make_shared<JitBuildEthernet>(
this->build_env_, JitBuiltStateConfig{.processor_id = 0, .is_fw=is_fw, .dispatch_message_addr=dispatch_message_addr});
build_states[build_processor_type_to_index(JitBuildProcessorType::ETHERNET).first + 1] =
std::make_shared<JitBuildEthernet>(this->build_env_, 1, is_fw);
std::make_shared<JitBuildEthernet>(
this->build_env_, JitBuiltStateConfig{.processor_id = 1, .is_fw=is_fw, .dispatch_message_addr=dispatch_message_addr});
}

return build_states;
Expand Down Expand Up @@ -1186,11 +1197,12 @@ void Device::update_workers_build_settings(std::vector<std::vector<std::tuple<tt
uint32_t scratch_db_size = dispatch_constants::get(dispatch_core_type).scratch_db_size();
const uint32_t l1_size = dispatch_core_type == CoreType::WORKER ? MEM_L1_SIZE : MEM_ETH_SIZE;
uint32_t dispatch_s_buffer_base;
uint32_t dispatch_buffer_base = dispatch_constants::get(dispatch_core_type).dispatch_buffer_base();
if (dispatch_core_type == CoreType::WORKER) {
dispatch_s_buffer_base = dispatch_constants::DISPATCH_BUFFER_BASE + (1 << dispatch_constants::DISPATCH_BUFFER_LOG_PAGE_SIZE) * dispatch_constants::get(dispatch_core_type).dispatch_buffer_pages();
dispatch_s_buffer_base = dispatch_buffer_base + (1 << dispatch_constants::DISPATCH_BUFFER_LOG_PAGE_SIZE) * dispatch_constants::get(dispatch_core_type).dispatch_buffer_pages();
}
else {
dispatch_s_buffer_base = dispatch_constants::DISPATCH_BUFFER_BASE;
dispatch_s_buffer_base = dispatch_buffer_base;
}
TT_ASSERT(scratch_db_base + scratch_db_size <= l1_size);

Expand Down Expand Up @@ -1584,19 +1596,20 @@ void Device::setup_tunnel_for_remote_devices() {
}
if (this->dispatch_s_enabled()) {
// Populate settings for dispatch_s
uint32_t dispatch_buffer_base = dispatch_constants::get(dispatch_core_type).dispatch_buffer_base();
for (uint32_t cq_id = 0; cq_id < num_hw_cqs; cq_id++) {
// Initialize dispatch_s settings as invalid values. To be populated if dispatch_s is enabled.
settings.cb_log_page_size = dispatch_constants::DISPATCH_S_BUFFER_LOG_PAGE_SIZE;
settings.semaphores.push_back(0); // used by dispatch_s to sync with prefetch_d
settings.semaphores.push_back(0); // dispatch_s waits on this until dispatch_d increments it
if (dispatch_core_type == CoreType::WORKER) {
// dispatch_s is on the same Tensix core as dispatch_d. Shared resources. Offset CB start and sem idx.
settings.cb_start_address = dispatch_constants::DISPATCH_BUFFER_BASE + (1 << dispatch_constants::DISPATCH_BUFFER_LOG_PAGE_SIZE) * dispatch_constants::get(dispatch_core_type).dispatch_buffer_pages();
settings.cb_start_address = dispatch_buffer_base + (1 << dispatch_constants::DISPATCH_BUFFER_LOG_PAGE_SIZE) * dispatch_constants::get(dispatch_core_type).dispatch_buffer_pages();
settings.producer_semaphore_id = 2; // sync with producer (prefetcher)
settings.consumer_semaphore_id = 3; // sync with dispatch_d (this is the "consumer" of dispatch_s)
} else {
// dispatch_d and dispatch_s are on different cores. No shared resources: dispatch_s CB and semaphores start at base.
settings.cb_start_address = dispatch_constants::DISPATCH_BUFFER_BASE;
settings.cb_start_address = dispatch_buffer_base;
settings.producer_semaphore_id = 0; // sync with producer (prefetcher)
settings.consumer_semaphore_id = 1; // sync with dispatch_d (this is the "consumer" of dispatch_s)
}
Expand Down Expand Up @@ -1749,13 +1762,14 @@ void Device::compile_command_queue_programs() {
// Skip allocating dispatch_s for multi-CQ configurations with ethernet dispatch
dispatch_s_core = dispatch_core_manager::instance().dispatcher_s_core(device_id, channel, cq_id);
dispatch_s_physical_core = get_physical_core_coordinate(dispatch_s_core, dispatch_core_type);
uint32_t dispatch_buffer_base = dispatch_constants::get(dispatch_core_type).dispatch_buffer_base();
if (dispatch_core_type == CoreType::WORKER) {
// dispatch_s is on the same Tensix core as dispatch_d. Shared resources. Offset CB start idx.
dispatch_s_buffer_base = dispatch_constants::DISPATCH_BUFFER_BASE + (1 << dispatch_constants::DISPATCH_BUFFER_LOG_PAGE_SIZE) * dispatch_constants::get(dispatch_core_type).dispatch_buffer_pages();
dispatch_s_buffer_base = dispatch_buffer_base + (1 << dispatch_constants::DISPATCH_BUFFER_LOG_PAGE_SIZE) * dispatch_constants::get(dispatch_core_type).dispatch_buffer_pages();
}
else {
// dispatch_d and dispatch_s are on different cores. No shared resources: dispatch_s CB starts at base.
dispatch_s_buffer_base = dispatch_constants::DISPATCH_BUFFER_BASE;
dispatch_s_buffer_base = dispatch_buffer_base;
}
dispatch_s_sem = tt::tt_metal::CreateSemaphore(*command_queue_program_ptr, dispatch_s_core, 0, dispatch_core_type); // used by dispatch_s to sync with prefetch
dispatch_s_sync_sem_id = tt::tt_metal::CreateSemaphore(*command_queue_program_ptr, dispatch_s_core, 0, dispatch_core_type); // used by dispatch_d to signal that dispatch_s can send go signal
Expand Down
22 changes: 7 additions & 15 deletions tt_metal/jit_build/build.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,8 @@ void JitBuildEnv::init(uint32_t build_key, tt::ARCH arch) {
this->lflags_ += "-fno-exceptions -Wl,-z,max-page-size=16 -Wl,-z,common-page-size=16 -nostartfiles ";
}

JitBuildState::JitBuildState(const JitBuildEnv& env, int which, bool is_fw) :
env_(env), core_id_(which), is_fw_(is_fw) {}
// Copies the per-processor settings out of `build_config` into the base build state:
// processor_id -> core_id_, is_fw -> is_fw_, dispatch_message_addr -> dispatch_message_addr_.
// The body is intentionally empty; everything is set in the member initializer list.
JitBuildState::JitBuildState(const JitBuildEnv& env, const JitBuiltStateConfig &build_config) :
    env_(env),
    core_id_(build_config.processor_id),
    is_fw_(build_config.is_fw),
    dispatch_message_addr_(build_config.dispatch_message_addr) {}

// Fill in common state derived from the default state set up in the constructors
void JitBuildState::finish_init() {
Expand All @@ -150,6 +150,7 @@ void JitBuildState::finish_init() {
} else {
this->defines_ += "-DKERNEL_BUILD ";
}
this->defines_ += "-DDISPATCH_MESSAGE_ADDR=" + to_string(this->dispatch_message_addr_) + " ";

// Create the objs from the srcs
for (string src : srcs_) {
Expand Down Expand Up @@ -189,8 +190,8 @@ void JitBuildState::finish_init() {
this->target_full_path_ = "/" + this->target_name_ + "/" + this->target_name_ + ".hex";
}

JitBuildDataMovement::JitBuildDataMovement(const JitBuildEnv& env, int which, bool is_fw) :
JitBuildState(env, which, is_fw) {
JitBuildDataMovement::JitBuildDataMovement(const JitBuildEnv& env, const JitBuiltStateConfig &build_config) :
JitBuildState(env, build_config) {
TT_ASSERT(this->core_id_ >= 0 && this->core_id_ < 2, "Invalid data movement processor");

this->out_path_ = this->is_fw_ ? env_.out_firmware_root_ : env_.out_kernel_root_;
Expand All @@ -201,9 +202,6 @@ JitBuildDataMovement::JitBuildDataMovement(const JitBuildEnv& env, int which, bo
"tt_metal/hw/ckernels/" + env.arch_name_ + "/metal/llk_io ";

this->defines_ = env_.defines_;
uint32_t dispatch_message_addr =
dispatch_constants::get(CoreType::WORKER).get_device_command_queue_addr(CommandQueueDeviceAddrType::DISPATCH_MESSAGE);
this->defines_ += "-DDISPATCH_MESSAGE_ADDR=" + to_string(dispatch_message_addr) + " ";

uint32_t l1_cache_disable_mask =
tt::llrt::OptionsG.get_feature_riscv_mask(tt::llrt::RunTimeDebugFeatureDisableL1DataCache);
Expand Down Expand Up @@ -266,7 +264,7 @@ JitBuildDataMovement::JitBuildDataMovement(const JitBuildEnv& env, int which, bo
finish_init();
}

JitBuildCompute::JitBuildCompute(const JitBuildEnv& env, int which, bool is_fw) : JitBuildState(env, which, is_fw) {
JitBuildCompute::JitBuildCompute(const JitBuildEnv& env, const JitBuiltStateConfig &build_config) : JitBuildState(env, build_config) {
TT_ASSERT(this->core_id_ >= 0 && this->core_id_ < 3, "Invalid compute processor");

this->out_path_ = this->is_fw_ ? env_.out_firmware_root_ : env_.out_kernel_root_;
Expand All @@ -282,9 +280,6 @@ JitBuildCompute::JitBuildCompute(const JitBuildEnv& env, int which, bool is_fw)
if ((l1_cache_disable_mask & debug_compute_mask) == debug_compute_mask) {
this->defines_ += "-DDISABLE_L1_DATA_CACHE ";
}
uint32_t dispatch_message_addr =
dispatch_constants::get(CoreType::WORKER).get_device_command_queue_addr(CommandQueueDeviceAddrType::DISPATCH_MESSAGE);
this->defines_ += "-DDISPATCH_MESSAGE_ADDR=" + to_string(dispatch_message_addr) + " ";

this->includes_ = env_.includes_ + "-I" + env_.root_ + "tt_metal/hw/ckernels/" + env.arch_name_ + "/inc " + "-I" +
env_.root_ + "tt_metal/hw/ckernels/" + env.arch_name_ + "/metal/common " + "-I" + env_.root_ +
Expand Down Expand Up @@ -355,7 +350,7 @@ JitBuildCompute::JitBuildCompute(const JitBuildEnv& env, int which, bool is_fw)
finish_init();
}

JitBuildEthernet::JitBuildEthernet(const JitBuildEnv& env, int which, bool is_fw) : JitBuildState(env, which, is_fw) {
JitBuildEthernet::JitBuildEthernet(const JitBuildEnv& env, const JitBuiltStateConfig &build_config) : JitBuildState(env, build_config) {
TT_ASSERT(this->core_id_ >= 0 && this->core_id_ < 2, "Invalid ethernet processor");
this->out_path_ = this->is_fw_ ? env_.out_firmware_root_ : env_.out_kernel_root_;

Expand All @@ -369,9 +364,6 @@ JitBuildEthernet::JitBuildEthernet(const JitBuildEnv& env, int which, bool is_fw
if ((l1_cache_disable_mask & tt::llrt::DebugHartFlags::RISCV_ER) == tt::llrt::DebugHartFlags::RISCV_ER) {
this->defines_ += "-DDISABLE_L1_DATA_CACHE ";
}
uint32_t dispatch_message_addr =
dispatch_constants::get(CoreType::ETH).get_device_command_queue_addr(CommandQueueDeviceAddrType::DISPATCH_MESSAGE);
this->defines_ += "-DDISPATCH_MESSAGE_ADDR=" + to_string(dispatch_message_addr) + " ";

switch (this->core_id_) {
case 0: {
Expand Down
15 changes: 11 additions & 4 deletions tt_metal/jit_build/build.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ enum class JitBuildProcessorType {
ETHERNET
};

// Bundle of per-processor settings passed to the JitBuildState family of constructors.
// Device::initialize_build fills it with designated initializers
// (.processor_id, .is_fw, .dispatch_message_addr), so keep the fields in this order.
struct JitBuiltStateConfig {
// Index of the processor within its type; copied into JitBuildState::core_id_.
int processor_id = 0;
// Copied into JitBuildState::is_fw_, which selects the firmware output root
// (out_firmware_root_) over the kernel output root (out_kernel_root_).
bool is_fw = false;
// Copied into JitBuildState::dispatch_message_addr_ and emitted by
// JitBuildState::finish_init() as the -DDISPATCH_MESSAGE_ADDR=<value> define.
uint32_t dispatch_message_addr = 0;
};

// The build environment
// Includes the path to the src/output and global defines, flags, etc
// Device specific
Expand Down Expand Up @@ -83,6 +89,7 @@ class alignas(CACHE_LINE_ALIGNMENT) JitBuildState {

int core_id_;
int is_fw_;
uint32_t dispatch_message_addr_;
bool process_defines_at_compile;

string out_path_;
Expand All @@ -109,7 +116,7 @@ class alignas(CACHE_LINE_ALIGNMENT) JitBuildState {
void extract_zone_src_locations(const string& log_file) const;

public:
JitBuildState(const JitBuildEnv& env, int which, bool is_fw = false);
JitBuildState(const JitBuildEnv& env, const JitBuiltStateConfig &build_config);
virtual ~JitBuildState() = default;
void finish_init();

Expand Down Expand Up @@ -137,19 +144,19 @@ class JitBuildDataMovement : public JitBuildState {
private:

public:
JitBuildDataMovement(const JitBuildEnv& env, int which, bool is_fw = false);
JitBuildDataMovement(const JitBuildEnv& env, const JitBuiltStateConfig &build_config);
};

// JitBuildState specialization for the compute processors.
// The constructor definition in build.cpp asserts that the processor id is 0, 1 or 2
// ("Invalid compute processor") and ends by calling finish_init().
class JitBuildCompute : public JitBuildState {
private:
public:
JitBuildCompute(const JitBuildEnv& env, int which, bool is_fw = false);
JitBuildCompute(const JitBuildEnv& env, const JitBuiltStateConfig &build_config);
};

// JitBuildState specialization for the ethernet processors.
// The constructor definition in build.cpp asserts that the processor id is 0 or 1
// ("Invalid ethernet processor").
class JitBuildEthernet : public JitBuildState {
private:
public:
JitBuildEthernet(const JitBuildEnv& env, int which, bool is_fw = false);
JitBuildEthernet(const JitBuildEnv& env, const JitBuiltStateConfig &build_config);
};

// Abstract base class for kernel specialization
Expand Down

0 comments on commit fd4a4d4

Please sign in to comment.