diff --git a/tests/tt_metal/tt_metal/unit_tests_common/watcher/test_assert.cpp b/tests/tt_metal/tt_metal/unit_tests_common/watcher/test_assert.cpp index 0a793f08af09..04df836f63c0 100644 --- a/tests/tt_metal/tt_metal/unit_tests_common/watcher/test_assert.cpp +++ b/tests/tt_metal/tt_metal/unit_tests_common/watcher/test_assert.cpp @@ -154,7 +154,7 @@ static void RunTest(WatcherFixture *fixture, Device *device, riscv_id_t riscv_ty // We should be able to find the expected watcher error in the log as well, // expected error message depends on the risc we're running on. string kernel = "tests/tt_metal/tt_metal/test_kernels/misc/watcher_asserts.cpp"; - int line_num = 57; + int line_num = 56; string expected = fmt::format( "Device {} {} core(x={:2},y={:2}) phys(x={:2},y={:2}): {} tripped an assert on line {}. Current kernel: {}.", diff --git a/tt_metal/impl/device/device.cpp b/tt_metal/impl/device/device.cpp index 55e21198d0af..626ae15cc59b 100644 --- a/tt_metal/impl/device/device.cpp +++ b/tt_metal/impl/device/device.cpp @@ -269,27 +269,38 @@ void Device::initialize_build() { this->build_env_.init(this->build_key(), this->arch()); - auto init_helper = [this] (bool is_fw) -> JitBuildStateSet { + CoreType dispatch_core_type = dispatch_core_manager::instance().get_dispatch_core_type(this->id()); + uint32_t dispatch_message_addr = + dispatch_constants::get(dispatch_core_type).get_device_command_queue_addr(CommandQueueDeviceAddrType::DISPATCH_MESSAGE); + + auto init_helper = [this, dispatch_message_addr] (bool is_fw) -> JitBuildStateSet { std::vector> build_states; build_states.resize(arch() == tt::ARCH::GRAYSKULL ? 5 : 7); build_states[build_processor_type_to_index(JitBuildProcessorType::DATA_MOVEMENT).first + 0] = - std::make_shared(this->build_env_, 0, is_fw); + std::make_shared( + this->build_env_, JitBuiltStateConfig{.processor_id = 0, .is_fw=is_fw, .dispatch_message_addr=dispatch_message_addr}); build_states[build_processor_type_to_index(JitBuildProcessorType::DATA_MOVEMENT).first + 1] = - std::make_shared(this->build_env_, 1, is_fw); + std::make_shared( + this->build_env_, JitBuiltStateConfig{.processor_id = 1, .is_fw=is_fw, .dispatch_message_addr=dispatch_message_addr}); build_states[build_processor_type_to_index(JitBuildProcessorType::COMPUTE).first + 0] = - std::make_shared(this->build_env_, 0, is_fw); + std::make_shared( + this->build_env_, JitBuiltStateConfig{.processor_id = 0, .is_fw=is_fw, .dispatch_message_addr=dispatch_message_addr}); build_states[build_processor_type_to_index(JitBuildProcessorType::COMPUTE).first + 1] = - std::make_shared(this->build_env_, 1, is_fw); + std::make_shared( + this->build_env_, JitBuiltStateConfig{.processor_id = 1, .is_fw=is_fw, .dispatch_message_addr=dispatch_message_addr}); build_states[build_processor_type_to_index(JitBuildProcessorType::COMPUTE).first + 2] = - std::make_shared(this->build_env_, 2, is_fw); + std::make_shared( + this->build_env_, JitBuiltStateConfig{.processor_id = 2, .is_fw=is_fw, .dispatch_message_addr=dispatch_message_addr}); if (arch() != tt::ARCH::GRAYSKULL) { build_states[build_processor_type_to_index(JitBuildProcessorType::ETHERNET).first + 0] = - std::make_shared(this->build_env_, 0, is_fw); + std::make_shared( + this->build_env_, JitBuiltStateConfig{.processor_id = 0, .is_fw=is_fw, .dispatch_message_addr=dispatch_message_addr}); build_states[build_processor_type_to_index(JitBuildProcessorType::ETHERNET).first + 1] = - std::make_shared(this->build_env_, 1, is_fw); + std::make_shared( + this->build_env_, JitBuiltStateConfig{.processor_id = 1, .is_fw=is_fw, .dispatch_message_addr=dispatch_message_addr}); } return build_states; @@ -1186,11 +1197,12 @@ void Device::update_workers_build_settings(std::vectordispatch_s_enabled()) { // Populate settings for dispatch_s + uint32_t dispatch_buffer_base = dispatch_constants::get(dispatch_core_type).dispatch_buffer_base(); for (uint32_t cq_id = 0; cq_id < num_hw_cqs; cq_id++) { // Initialize dispatch_s settings as invalid values. To be populated if dispatch_s is enabled. settings.cb_log_page_size = dispatch_constants::DISPATCH_S_BUFFER_LOG_PAGE_SIZE; @@ -1591,12 +1604,12 @@ void Device::setup_tunnel_for_remote_devices() { settings.semaphores.push_back(0); // dispatch_s waits on this until dispatch_d increments it if (dispatch_core_type == CoreType::WORKER) { // dispatch_s is on the same Tensix core as dispatch_d. Shared resources. Offset CB start and sem idx. - settings.cb_start_address = dispatch_constants::DISPATCH_BUFFER_BASE + (1 << dispatch_constants::DISPATCH_BUFFER_LOG_PAGE_SIZE) * dispatch_constants::get(dispatch_core_type).dispatch_buffer_pages(); + settings.cb_start_address = dispatch_buffer_base + (1 << dispatch_constants::DISPATCH_BUFFER_LOG_PAGE_SIZE) * dispatch_constants::get(dispatch_core_type).dispatch_buffer_pages(); settings.producer_semaphore_id = 2; // sync with producer (prefetcher) settings.consumer_semaphore_id = 3; // sync with dispatch_d (this is the "consumer" of dispatch_s) } else { // dispatch_d and dispatch_s are on different cores. No shared resources: dispatch_s CB and semaphores start at base. - settings.cb_start_address = dispatch_constants::DISPATCH_BUFFER_BASE; + settings.cb_start_address = dispatch_buffer_base; settings.producer_semaphore_id = 0; // sync with producer (prefetcher) settings.consumer_semaphore_id = 1; // sync with dispatch_d (this is the "consumer" of dispatch_s) } @@ -1749,13 +1762,14 @@ void Device::compile_command_queue_programs() { // Skip allocating dispatch_s for multi-CQ configurations with ethernet dispatch dispatch_s_core = dispatch_core_manager::instance().dispatcher_s_core(device_id, channel, cq_id); dispatch_s_physical_core = get_physical_core_coordinate(dispatch_s_core, dispatch_core_type); + uint32_t dispatch_buffer_base = dispatch_constants::get(dispatch_core_type).dispatch_buffer_base(); if (dispatch_core_type == CoreType::WORKER) { // dispatch_s is on the same Tensix core as dispatch_d. Shared resources. Offset CB start idx. - dispatch_s_buffer_base = dispatch_constants::DISPATCH_BUFFER_BASE + (1 << dispatch_constants::DISPATCH_BUFFER_LOG_PAGE_SIZE) * dispatch_constants::get(dispatch_core_type).dispatch_buffer_pages(); + dispatch_s_buffer_base = dispatch_buffer_base + (1 << dispatch_constants::DISPATCH_BUFFER_LOG_PAGE_SIZE) * dispatch_constants::get(dispatch_core_type).dispatch_buffer_pages(); } else { // dispatch_d and dispatch_s are on different cores. No shared resources: dispatch_s CB starts at base. - dispatch_s_buffer_base = dispatch_constants::DISPATCH_BUFFER_BASE; + dispatch_s_buffer_base = dispatch_buffer_base; } dispatch_s_sem = tt::tt_metal::CreateSemaphore(*command_queue_program_ptr, dispatch_s_core, 0, dispatch_core_type); // used by dispatch_s to sync with prefetch dispatch_s_sync_sem_id = tt::tt_metal::CreateSemaphore(*command_queue_program_ptr, dispatch_s_core, 0, dispatch_core_type); // used by dispatch_d to signal that dispatch_s can send go signal diff --git a/tt_metal/jit_build/build.cpp b/tt_metal/jit_build/build.cpp index 048cfc3d2cea..e231daac62a0 100644 --- a/tt_metal/jit_build/build.cpp +++ b/tt_metal/jit_build/build.cpp @@ -140,8 +140,8 @@ void JitBuildEnv::init(uint32_t build_key, tt::ARCH arch) { this->lflags_ += "-fno-exceptions -Wl,-z,max-page-size=16 -Wl,-z,common-page-size=16 -nostartfiles "; } -JitBuildState::JitBuildState(const JitBuildEnv& env, int which, bool is_fw) : - env_(env), core_id_(which), is_fw_(is_fw) {} +JitBuildState::JitBuildState(const JitBuildEnv& env, const JitBuiltStateConfig &build_config) : + env_(env), core_id_(build_config.processor_id), is_fw_(build_config.is_fw), dispatch_message_addr_(build_config.dispatch_message_addr) {} // Fill in common state derived from the default state set up in the constructors void JitBuildState::finish_init() { @@ -150,6 +150,7 @@ void JitBuildState::finish_init() { } else { this->defines_ += "-DKERNEL_BUILD "; } + this->defines_ += "-DDISPATCH_MESSAGE_ADDR=" + to_string(this->dispatch_message_addr_) + " "; // Create the objs from the srcs for (string src : srcs_) { @@ -189,8 +190,8 @@ void JitBuildState::finish_init() { this->target_full_path_ = "/" + this->target_name_ + "/" + this->target_name_ + ".hex"; } -JitBuildDataMovement::JitBuildDataMovement(const JitBuildEnv& env, int which, bool is_fw) : - JitBuildState(env, which, is_fw) { +JitBuildDataMovement::JitBuildDataMovement(const JitBuildEnv& env, const JitBuiltStateConfig &build_config) : + JitBuildState(env, build_config) { TT_ASSERT(this->core_id_ >= 0 && this->core_id_ < 2, "Invalid data movement processor"); this->out_path_ = this->is_fw_ ? env_.out_firmware_root_ : env_.out_kernel_root_; @@ -201,9 +202,6 @@ JitBuildDataMovement::JitBuildDataMovement(const JitBuildEnv& env, int which, bo "tt_metal/hw/ckernels/" + env.arch_name_ + "/metal/llk_io "; this->defines_ = env_.defines_; - uint32_t dispatch_message_addr = - dispatch_constants::get(CoreType::WORKER).get_device_command_queue_addr(CommandQueueDeviceAddrType::DISPATCH_MESSAGE); - this->defines_ += "-DDISPATCH_MESSAGE_ADDR=" + to_string(dispatch_message_addr) + " "; uint32_t l1_cache_disable_mask = tt::llrt::OptionsG.get_feature_riscv_mask(tt::llrt::RunTimeDebugFeatureDisableL1DataCache); @@ -266,7 +264,7 @@ JitBuildDataMovement::JitBuildDataMovement(const JitBuildEnv& env, int which, bo finish_init(); } -JitBuildCompute::JitBuildCompute(const JitBuildEnv& env, int which, bool is_fw) : JitBuildState(env, which, is_fw) { +JitBuildCompute::JitBuildCompute(const JitBuildEnv& env, const JitBuiltStateConfig &build_config) : JitBuildState(env, build_config) { TT_ASSERT(this->core_id_ >= 0 && this->core_id_ < 3, "Invalid compute processor"); this->out_path_ = this->is_fw_ ? env_.out_firmware_root_ : env_.out_kernel_root_; @@ -282,9 +280,6 @@ JitBuildCompute::JitBuildCompute(const JitBuildEnv& env, int which, bool is_fw) if ((l1_cache_disable_mask & debug_compute_mask) == debug_compute_mask) { this->defines_ += "-DDISABLE_L1_DATA_CACHE "; } - uint32_t dispatch_message_addr = - dispatch_constants::get(CoreType::WORKER).get_device_command_queue_addr(CommandQueueDeviceAddrType::DISPATCH_MESSAGE); - this->defines_ += "-DDISPATCH_MESSAGE_ADDR=" + to_string(dispatch_message_addr) + " "; this->includes_ = env_.includes_ + "-I" + env_.root_ + "tt_metal/hw/ckernels/" + env.arch_name_ + "/inc " + "-I" + env_.root_ + "tt_metal/hw/ckernels/" + env.arch_name_ + "/metal/common " + "-I" + env_.root_ + @@ -355,7 +350,7 @@ JitBuildCompute::JitBuildCompute(const JitBuildEnv& env, int which, bool is_fw) finish_init(); } -JitBuildEthernet::JitBuildEthernet(const JitBuildEnv& env, int which, bool is_fw) : JitBuildState(env, which, is_fw) { +JitBuildEthernet::JitBuildEthernet(const JitBuildEnv& env, const JitBuiltStateConfig &build_config) : JitBuildState(env, build_config) { TT_ASSERT(this->core_id_ >= 0 && this->core_id_ < 2, "Invalid ethernet processor"); this->out_path_ = this->is_fw_ ? env_.out_firmware_root_ : env_.out_kernel_root_; @@ -369,9 +364,6 @@ JitBuildEthernet::JitBuildEthernet(const JitBuildEnv& env, int which, bool is_fw if ((l1_cache_disable_mask & tt::llrt::DebugHartFlags::RISCV_ER) == tt::llrt::DebugHartFlags::RISCV_ER) { this->defines_ += "-DDISABLE_L1_DATA_CACHE "; } - uint32_t dispatch_message_addr = - dispatch_constants::get(CoreType::ETH).get_device_command_queue_addr(CommandQueueDeviceAddrType::DISPATCH_MESSAGE); - this->defines_ += "-DDISPATCH_MESSAGE_ADDR=" + to_string(dispatch_message_addr) + " "; switch (this->core_id_) { case 0: { diff --git a/tt_metal/jit_build/build.hpp b/tt_metal/jit_build/build.hpp index 7c3bf86a053c..43a3b6f7714a 100644 --- a/tt_metal/jit_build/build.hpp +++ b/tt_metal/jit_build/build.hpp @@ -34,6 +34,12 @@ enum class JitBuildProcessorType { ETHERNET }; +struct JitBuiltStateConfig { + int processor_id = 0; + bool is_fw = false; + uint32_t dispatch_message_addr = 0; +}; + // The build environment // Includes the path to the src/output and global defines, flags, etc // Device specific @@ -83,6 +89,7 @@ class alignas(CACHE_LINE_ALIGNMENT) JitBuildState { int core_id_; int is_fw_; + uint32_t dispatch_message_addr_; bool process_defines_at_compile; string out_path_; @@ -109,7 +116,7 @@ class alignas(CACHE_LINE_ALIGNMENT) JitBuildState { void extract_zone_src_locations(const string& log_file) const; public: - JitBuildState(const JitBuildEnv& env, int which, bool is_fw = false); + JitBuildState(const JitBuildEnv& env, const JitBuiltStateConfig &build_config); virtual ~JitBuildState() = default; void finish_init(); @@ -137,19 +144,19 @@ class JitBuildDataMovement : public JitBuildState { private: public: - JitBuildDataMovement(const JitBuildEnv& env, int which, bool is_fw = false); + JitBuildDataMovement(const JitBuildEnv& env, const JitBuiltStateConfig &build_config); }; class JitBuildCompute : public JitBuildState { private: public: - JitBuildCompute(const JitBuildEnv& env, int which, bool is_fw = false); + JitBuildCompute(const JitBuildEnv& env, const JitBuiltStateConfig &build_config); }; class JitBuildEthernet : public JitBuildState { private: public: - JitBuildEthernet(const JitBuildEnv& env, int which, bool is_fw = false); + JitBuildEthernet(const JitBuildEnv& env, const JitBuiltStateConfig &build_config); }; // Abstract base class for kernel specialization