From dea1940e7c9b4c3a64377760fb845107e64e8494 Mon Sep 17 00:00:00 2001
From: asaigal
Date: Wed, 9 Oct 2024 15:22:04 +0000
Subject: [PATCH] #0: Add MeshProgram class

- Includes APIs to set MeshProgram configuration across the entire MeshDevice or per device in
  the Mesh. APIs are analogous to the Program config APIs (a usage sketch is appended at the
  end of this message)
- Basic getter APIs to return individual programs and state across the MeshProgram
- Relies on the distributed_impl_ and distribute_to_mesh_device_impl_ functions
  - distributed_impl_ serves as the MeshProgram entry point on the Controller or Executor to
    process attributes of this data structure on host(s)
    - This function is currently implemented as a simple loop, but it can be swapped out for a
      set of RPC calls on the Controller and asynchronous calls on the Executor
  - distribute_to_mesh_device_impl_ serves as the interface between the MeshProgram on host and
    the MeshDevice. Currently used in EnqueueMeshProgram and implemented using a simple loop.
    Can be used to interface with TT-Fabric, once the infra is available
- Design aspects to consider as we go along:
  - Does a MeshProgram span Controllers, or is it limited to a Controller connected to a single
    cluster? APIs and hierarchy may be easier with a MeshProgram <--> Controller <--> MeshDevice
    mapping
  - For programs we want to broadcast, the host does not need to perform repeated work across
    the entire device mesh (it currently does). It likely makes sense to have a bcast setting in
    the MeshProgram class, so that the host sets configuration data once and fast dispatch/fabric
    performs the bcast
  - Potential hierarchy:
    - The Controller populates the MeshProgram with kernels, sems, CBs and RTAs (individual
      programs, single-program bcast or multi-program bcast). Population can be done with a
      reimplementation of distributed_impl_ on the Controller
    - The MeshProgram is sent to Executors through a virtual CQ (RPC call + cq_id). This is done
      through a specialized distribute_to_mesh_device_impl_ on the Controller
    - Executors get a MeshProgram, which can be scattered or broadcast through the specified CQ
      using Fast Dispatch and Fabric. Assembling FD/Fabric commands can be host intensive, so it
      makes sense to distribute these steps across Executors. Executors get the MeshProgram to
      the MeshDevice through a specialized distribute_to_mesh_device_impl_
    - For generic entry points that mutate Mesh data structures and send them to the MeshDevice,
      we need generic distribute* functions that can accept any data type and perform generic
      processing (assemble FD commands, make RPC calls, send programs or data to the MeshDevice,
      etc.)
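- Usage sketch (illustrative only, not part of this patch): minimal broadcast-style host code
  against the APIs added here, assuming the 2x4 mesh used in the added test. The kernel path,
  core ranges and runtime args are placeholder values taken from or modeled on the test:

    #include "tt_metal/distributed/mesh_program.hpp"

    using namespace tt::tt_metal;

    int main() {
        // One MeshDevice spanning a 2x4 cluster; the MeshProgram holds one Program per device
        std::shared_ptr<MeshDevice> mesh_device =
            MeshDevice::create(MeshDeviceConfig(MeshShape{2, 4}, MeshType::RowMajor));
        MeshProgram mesh_program(mesh_device->get_devices().size());

        CoreRangeSet cr_set({CoreRange({0, 0}, {0, 0})});

        // Broadcast-style config: distributed_impl_ applies these to every Program in the mesh
        KernelHandle kernel = CreateKernel(
            mesh_program, "tt_metal/kernels/dataflow/blank.cpp", cr_set,
            DataMovementConfig{.processor = DataMovementProcessor::RISCV_0, .noc = NOC::RISCV_0_default});
        CreateSemaphore(mesh_program, cr_set, /*initial_value=*/0);

        // Per-device override: same API with a trailing device_id
        SetRuntimeArgs(mesh_program, kernel, cr_set, {0xA, 0xB}, /*device_id=*/0);

        // Dispatch through distribute_to_mesh_device_impl_ (currently a loop over devices)
        EnqueueMeshProgram(/*cq_id=*/0, mesh_program, mesh_device, /*blocking=*/false);
        Finish(mesh_device, /*cq_id=*/0);
        mesh_device->close_devices();
        return 0;
    }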
---
 .../unit_tests_fast_dispatch/CMakeLists.txt   |   1 +
 .../command_queue/test_mesh_program.cpp       | 201 ++++++++++++++++++
 tt_metal/CMakeLists.txt                       |   2 +-
 tt_metal/distributed/mesh_program.cpp         | 144 +++++++++++++
 tt_metal/distributed/mesh_program.hpp         | 134 ++++++++++++
 5 files changed, 481 insertions(+), 1 deletion(-)
 create mode 100644 tests/tt_metal/tt_metal/unit_tests_fast_dispatch/command_queue/test_mesh_program.cpp
 create mode 100644 tt_metal/distributed/mesh_program.cpp
 create mode 100644 tt_metal/distributed/mesh_program.hpp

diff --git a/tests/tt_metal/tt_metal/unit_tests_fast_dispatch/CMakeLists.txt b/tests/tt_metal/tt_metal/unit_tests_fast_dispatch/CMakeLists.txt
index 23707d70cd5..9872e115e8c 100644
--- a/tests/tt_metal/tt_metal/unit_tests_fast_dispatch/CMakeLists.txt
+++ b/tests/tt_metal/tt_metal/unit_tests_fast_dispatch/CMakeLists.txt
@@ -2,6 +2,7 @@ set(UNIT_TESTS_FD_SRC
     ${CMAKE_CURRENT_SOURCE_DIR}/command_queue/test_CommandQueue.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/command_queue/test_EnqueueProgram.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/command_queue/test_mesh_program.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/command_queue/test_EnqueueTrace.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/command_queue/test_EnqueueWriteBuffer_and_EnqueueReadBuffer.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/command_queue/test_events.cpp
diff --git a/tests/tt_metal/tt_metal/unit_tests_fast_dispatch/command_queue/test_mesh_program.cpp b/tests/tt_metal/tt_metal/unit_tests_fast_dispatch/command_queue/test_mesh_program.cpp
new file mode 100644
index 00000000000..b187944caa6
--- /dev/null
+++ b/tests/tt_metal/tt_metal/unit_tests_fast_dispatch/command_queue/test_mesh_program.cpp
@@ -0,0 +1,201 @@
+// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include
+#include "command_queue_fixture.hpp"
+#include "command_queue_test_utils.hpp"
+#include "gtest/gtest.h"
+#include "impl/buffers/buffer.hpp"
+#include "impl/device/device.hpp"
+#include "tt_metal/common/bfloat16.hpp"
+#include "tt_metal/common/scoped_timer.hpp"
+#include "tt_metal/host_api.hpp"
+#include "tt_metal/detail/tt_metal.hpp"
+#include "tt_metal/distributed/mesh_program.hpp"
+
+using namespace tt::tt_metal;
+
+struct CBConfig {
+    uint32_t cb_id;
+    uint32_t num_pages;
+    uint32_t page_size;
+    tt::DataFormat data_format;
+};
+
+struct DummyProgramConfig {
+    CoreRangeSet cr_set;
+    CBConfig cb_config;
+    uint32_t num_cbs;
+    uint32_t num_sems;
+};
+
+struct DummyProgramMultiCBConfig {
+    CoreRangeSet cr_set;
+    std::vector<CBConfig> cb_config_vector;
+    uint32_t num_sems;
+};
+
+void initialize_dummy_kernels(MeshProgram& mesh_program, const CoreRangeSet& cr_set) {
+    auto dummy_reader_kernel = CreateKernel(
+        mesh_program, "tt_metal/kernels/dataflow/blank.cpp", cr_set,
+        DataMovementConfig{.processor = DataMovementProcessor::RISCV_1, .noc = NOC::RISCV_1_default});
+
+    auto dummy_writer_kernel = CreateKernel(
+        mesh_program, "tt_metal/kernels/dataflow/blank.cpp", cr_set,
+        DataMovementConfig{.processor = DataMovementProcessor::RISCV_0, .noc = NOC::RISCV_0_default});
+
+    auto dummy_compute_kernel = CreateKernel(mesh_program, "tt_metal/kernels/compute/blank.cpp", cr_set, ComputeConfig{});
+}
+
+std::vector<CBHandle> initialize_dummy_circular_buffers(MeshProgram& mesh_program, const CoreRangeSet& cr_set, const std::vector<CBConfig>& cb_configs)
+{
+    std::vector<CBHandle> cb_handles;
+    for (uint32_t i = 0; i < cb_configs.size(); i++) {
+        const CBConfig& cb_config = cb_configs[i];
+        const uint32_t cb_id = cb_config.cb_id;
+        const uint32_t cb_num_pages = cb_config.num_pages;
+        const uint32_t page_size = cb_config.page_size;
+        const uint32_t cb_size = cb_num_pages * page_size;
+        const tt::DataFormat data_format = cb_config.data_format;
+        const CircularBufferConfig circular_buffer_config = CircularBufferConfig(cb_size, {{cb_id, data_format}}).set_page_size(cb_id, page_size);
+        const CBHandle cb_handle = CreateCircularBuffer(mesh_program, cr_set, circular_buffer_config);
+        cb_handles.push_back(cb_handle);
+    }
+    return cb_handles;
+}
+
+void initialize_dummy_semaphores(MeshProgram& mesh_program, const std::variant<CoreRange, CoreRangeSet>& core_ranges, const vector<uint32_t>& init_values)
+{
+    for (uint32_t i = 0; i < init_values.size(); i++)
+    {
+        CreateSemaphore(mesh_program, core_ranges, init_values[i]);
+    }
+}
+
+bool cb_config_successful(std::shared_ptr<MeshDevice> mesh_device, MeshProgram &mesh_program, const DummyProgramMultiCBConfig & program_config){
+    bool pass = true;
+
+    // Need to use old APIs to read since we cannot allocate a buffer in the reserved space we're trying
+    // to read from
+    vector<uint32_t> cb_config_vector;
+    uint32_t cb_config_buffer_size = NUM_CIRCULAR_BUFFERS * UINT32_WORDS_PER_CIRCULAR_BUFFER_CONFIG * sizeof(uint32_t);
+
+    for (const CoreRange& core_range : program_config.cr_set.ranges()) {
+        for (const CoreCoord& core_coord : core_range) {
+            auto sem_base_addrs_across_mesh = mesh_program.get_sem_base_addr(mesh_device, core_coord, CoreType::WORKER);
+            uint32_t dev_idx = 0;
+            for (auto device : mesh_device->get_devices()) {
+                tt::tt_metal::detail::ReadFromDeviceL1(device, core_coord,
+                    sem_base_addrs_across_mesh.at(dev_idx),
+                    cb_config_buffer_size, cb_config_vector);
+
+                uint32_t cb_addr = device->get_base_allocator_addr(HalMemType::L1);
+                for (uint32_t i = 0; i < program_config.cb_config_vector.size(); i++) {
+                    const uint32_t index = program_config.cb_config_vector[i].cb_id * sizeof(uint32_t);
+                    const uint32_t cb_num_pages = program_config.cb_config_vector[i].num_pages;
+                    const uint32_t cb_size = cb_num_pages * program_config.cb_config_vector[i].page_size;
+                    const bool addr_match = cb_config_vector.at(index) == ((cb_addr) >> 4);
+                    const bool size_match = cb_config_vector.at(index + 1) == (cb_size >> 4);
+                    const bool num_pages_match = cb_config_vector.at(index + 2) == cb_num_pages;
+                    pass &= (addr_match and size_match and num_pages_match);
+
+                    cb_addr += cb_size;
+                }
+                dev_idx++;
+            }
+        }
+    }
+    return pass;
+}
+
+bool test_dummy_EnqueueProgram_with_cbs(std::shared_ptr<MeshDevice> mesh_device, uint8_t cq_id, DummyProgramMultiCBConfig& program_config) {
+    MeshProgram mesh_program(mesh_device->get_devices().size());
+
+    initialize_dummy_circular_buffers(mesh_program, program_config.cr_set, program_config.cb_config_vector);
+    initialize_dummy_kernels(mesh_program, program_config.cr_set);
+    const bool is_blocking_op = false;
+    EnqueueMeshProgram(cq_id, mesh_program, mesh_device, is_blocking_op);
+    Finish(mesh_device, cq_id);
+    // return true;
+    return cb_config_successful(mesh_device, mesh_program, program_config);
+}
+
+bool test_dummy_EnqueueProgram_with_sems(std::shared_ptr<MeshDevice> mesh_device, uint8_t cq_id, MeshProgram& mesh_program, const DummyProgramConfig& program_config, const vector<vector<uint32_t>>& expected_semaphore_vals) {
+    TT_ASSERT(program_config.cr_set.size() == expected_semaphore_vals.size());
+
+    bool are_all_semaphore_values_correct = true;
+
+    const bool is_blocking_op = false;
+    EnqueueMeshProgram(cq_id, mesh_program, mesh_device, is_blocking_op);
+    Finish(mesh_device, cq_id);
+
+    uint32_t expected_semaphore_vals_idx = 0;
+    for (const CoreRange& core_range : program_config.cr_set.ranges())
+    {
+        const vector<uint32_t>& expected_semaphore_vals_for_core = expected_semaphore_vals[expected_semaphore_vals_idx];
+        TT_ASSERT(expected_semaphore_vals_for_core.size() == program_config.num_sems);
+        expected_semaphore_vals_idx++;
+        for (const CoreCoord& core_coord : core_range)
+        {
+            auto sem_base_addrs_across_mesh = mesh_program.get_sem_base_addr(mesh_device, core_coord, CoreType::WORKER);
+            uint32_t dev_idx = 0;
+            for (auto device : mesh_device->get_devices()) {
+                vector<uint32_t> semaphore_vals;
+                uint32_t expected_semaphore_vals_for_core_idx = 0;
+                const uint32_t semaphore_buffer_size = program_config.num_sems * hal.get_alignment(HalMemType::L1);
+                tt::tt_metal::detail::ReadFromDeviceL1(device, core_coord, sem_base_addrs_across_mesh.at(dev_idx), semaphore_buffer_size, semaphore_vals);
+                for (uint32_t i = 0; i < semaphore_vals.size(); i += (hal.get_alignment(HalMemType::L1) / sizeof(uint32_t)))
+                {
+                    const bool is_semaphore_value_correct = semaphore_vals[i] == expected_semaphore_vals_for_core[expected_semaphore_vals_for_core_idx];
+                    expected_semaphore_vals_for_core_idx++;
+                    if (!is_semaphore_value_correct)
+                    {
+                        are_all_semaphore_values_correct = false;
+                    }
+                }
+                dev_idx++;
+            }
+        }
+    }
+
+    return are_all_semaphore_values_correct;
+}
+
+bool test_dummy_EnqueueProgram_with_sems(std::shared_ptr<MeshDevice> mesh_device, uint8_t cq_id, const DummyProgramConfig& program_config) {
+    MeshProgram mesh_program(mesh_device->get_devices().size());
+
+    vector<uint32_t> expected_semaphore_values;
+
+    for (uint32_t initial_sem_value = 0; initial_sem_value < program_config.num_sems; initial_sem_value++) {
+        expected_semaphore_values.push_back(initial_sem_value);
+    }
+
+    initialize_dummy_semaphores(mesh_program, program_config.cr_set, expected_semaphore_values);
+    return test_dummy_EnqueueProgram_with_sems(mesh_device, cq_id, mesh_program, program_config, {expected_semaphore_values});
+}
+
+TEST(MeshProgram, TestMeshProgramCB) {
+    std::shared_ptr<MeshDevice> mesh_device = MeshDevice::create(MeshDeviceConfig(MeshShape{2, 4}, MeshType::RowMajor));
+    CoreRange cr({0, 0}, {0, 0});
+    CoreRangeSet cr_set({cr});
+
+    CBConfig cb_config = {.cb_id=0, .num_pages = 4, .page_size = 2048, .data_format = tt::DataFormat::Float16_b};
+
+    DummyProgramMultiCBConfig config = {.cr_set = cr_set, .cb_config_vector = {cb_config} };
+    EXPECT_EQ(true, test_dummy_EnqueueProgram_with_cbs(mesh_device, 0, config));
+    mesh_device->close_devices();
+}
+
+TEST(MeshProgram, TestMeshProgramSem) {
+    std::shared_ptr<MeshDevice> mesh_device = MeshDevice::create(MeshDeviceConfig(MeshShape{2, 4}, MeshType::RowMajor));
+    CoreCoord worker_grid_size = mesh_device->compute_with_storage_grid_size();
+
+    CoreRange cr({0, 0}, {worker_grid_size.x - 1, worker_grid_size.y - 1});
+    CoreRangeSet cr_set({cr});
+
+    DummyProgramConfig config = {.cr_set = cr_set, .num_sems = NUM_SEMAPHORES};
+
+    EXPECT_TRUE(test_dummy_EnqueueProgram_with_sems(mesh_device, 0, config));
+    mesh_device->close_devices();
+}
diff --git a/tt_metal/CMakeLists.txt b/tt_metal/CMakeLists.txt
index cfecfa2cdb7..0c3b81d684c 100644
--- a/tt_metal/CMakeLists.txt
+++ b/tt_metal/CMakeLists.txt
@@ -13,6 +13,7 @@ set(TT_METAL_OBJECTS
     ${CMAKE_CURRENT_SOURCE_DIR}/graph/graph_tracking.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/distributed/mesh_device.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/distributed/mesh_device_view.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/distributed/mesh_program.cpp
     $
     $
     $
@@ -53,4 +54,3 @@ set_target_properties(tt_metal PROPERTIES
 if(BUILD_PROGRAMMING_EXAMPLES)
     add_subdirectory(programming_examples)
 endif(BUILD_PROGRAMMING_EXAMPLES)
-
diff --git a/tt_metal/distributed/mesh_program.cpp b/tt_metal/distributed/mesh_program.cpp
new file mode 100644
index 00000000000..7fde6e673f9
--- /dev/null
+++ b/tt_metal/distributed/mesh_program.cpp
@@ -0,0 +1,144 @@
+// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include "mesh_program.hpp"
+#include "tt_metal/host_api.hpp"
+#include "tt_metal/detail/tt_metal.hpp"
+
+namespace tt::tt_metal {
+
+MeshProgram::MeshProgram(std::size_t num_devices) {
+    this->programs.reserve(num_devices);
+    for (int i = 0; i < num_devices; ++i) {
+        this->programs.push_back(std::make_shared<Program>());
+    }
+}
+
+Program& MeshProgram::at(std::size_t device_index) {
+    TT_ASSERT(device_index < this->programs.size());
+    return *this->programs.at(device_index);
+}
+
+std::vector<uint32_t> MeshProgram::get_sem_base_addr(std::shared_ptr<MeshDevice> mesh_device, CoreCoord logical_core, CoreType core_type) const {
+    return this->distributed_impl_(
+        std::variant<std::function<uint32_t(Program&)>, std::function<uint32_t(Program&, Device*)>>(
+            std::function<uint32_t(Program&, Device*)>(
+                [logical_core, core_type](Program& program, Device* device) -> uint32_t {
+                    return program.get_sem_base_addr(device, logical_core, core_type);
+                }
+            )
+        ),
+        mesh_device
+    );
+}
+
+uint32_t CreateSemaphore(
+    MeshProgram& mesh_program,
+    const std::variant<CoreRange, CoreRangeSet> &core_spec,
+    uint32_t initial_value,
+    CoreType core_type) {
+    return mesh_program.distributed_impl_(
+        std::function<uint32_t(Program&)>(
+            [&core_spec, initial_value, core_type] (Program& program) -> uint32_t {
+                return CreateSemaphore(program, core_spec, initial_value, core_type);
+            }
+        )
+    );
+}
+
+uint32_t CreateSemaphore(
+    MeshProgram& mesh_program,
+    const std::variant<CoreRange, CoreRangeSet> &core_spec,
+    uint32_t initial_value,
+    CoreType core_type,
+    chip_id_t device_id) {
+    return CreateSemaphore(mesh_program.at(device_id), core_spec, initial_value, core_type);
+}
+
+CBHandle CreateCircularBuffer(
+    MeshProgram& mesh_program,
+    const std::variant<CoreCoord, CoreRange, CoreRangeSet> &core_spec,
+    const CircularBufferConfig &config) {
+    return mesh_program.distributed_impl_(
+        std::function<CBHandle(Program&)>(
+            [&core_spec, &config] (Program& program) -> CBHandle {
+                return CreateCircularBuffer(program, core_spec, config);
+            }
+        )
+    );
+}
+
+CBHandle CreateCircularBuffer(
+    MeshProgram& mesh_program,
+    const std::variant<CoreCoord, CoreRange, CoreRangeSet> &core_spec,
+    const CircularBufferConfig &config,
+    chip_id_t device_id) {
+    return CreateCircularBuffer(mesh_program.at(device_id), core_spec, config);
+}
+
+void SetRuntimeArgs(
+    MeshProgram& mesh_program,
+    KernelHandle kernel,
+    const std::variant<CoreCoord, CoreRange, CoreRangeSet> &core_spec,
+    const std::vector<uint32_t> &runtime_args) {
+    mesh_program.distributed_impl_(
+        std::function<void(Program&)>(
+            [kernel, &core_spec, &runtime_args] (Program& program) -> void {
+                return SetRuntimeArgs(program, kernel, core_spec, runtime_args);
+            }
+        )
+    );
+}
+
+void SetRuntimeArgs(
+    MeshProgram& mesh_program,
+    KernelHandle kernel,
+    const std::variant<CoreCoord, CoreRange, CoreRangeSet> &core_spec,
+    const std::vector<uint32_t> &runtime_args,
+    chip_id_t device_id) {
+    SetRuntimeArgs(mesh_program.at(device_id), kernel, core_spec, runtime_args);
+}
+
+KernelHandle CreateKernel(
+    MeshProgram& mesh_program,
+    const std::string &file_name,
+    const std::variant<CoreCoord, CoreRange, CoreRangeSet> &core_spec,
+    const std::variant<DataMovementConfig, ComputeConfig, EthernetConfig> &config) {
+    return mesh_program.distributed_impl_(
+        std::function<KernelHandle(Program&)>(
+            [&file_name, &core_spec, &config] (Program& program) -> KernelHandle {
+                return CreateKernel(program, file_name, core_spec, config);
+            }
+        )
+    );
+}
+
+KernelHandle CreateKernel(
+    MeshProgram& mesh_program,
+    const std::string &file_name,
+    const std::variant<CoreCoord, CoreRange, CoreRangeSet> &core_spec,
+    const std::variant<DataMovementConfig, ComputeConfig, EthernetConfig> &config,
+    chip_id_t device_id) {
+    return CreateKernel(mesh_program.at(device_id), file_name, core_spec, config);
+}
+
+void EnqueueMeshProgram(
+    uint8_t cq_id, MeshProgram& mesh_program, std::shared_ptr<MeshDevice> mesh_device, bool blocking) {
+    mesh_program.distribute_to_mesh_device_impl_(
+        std::function<void(Program&, Device*)>(
+            [cq_id, blocking] (Program& program, Device* device) -> void {
+                EnqueueProgram(device->command_queue(cq_id), program, blocking);
+            }
+        ),
+        mesh_device
+    );
+}
+
+void Finish(std::shared_ptr<MeshDevice> mesh_device, uint8_t cq_id) {
+    for (auto device : mesh_device->get_devices()) {
+        Finish(device->command_queue(cq_id));
+    }
+}
+
+}
diff --git a/tt_metal/distributed/mesh_program.hpp b/tt_metal/distributed/mesh_program.hpp
new file mode 100644
index 00000000000..423e0e50006
--- /dev/null
+++ b/tt_metal/distributed/mesh_program.hpp
@@ -0,0 +1,134 @@
+// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include "mesh_device.hpp"
+#include "tt_metal/impl/program/program.hpp"
+
+namespace tt::tt_metal {
+// Mesh buffer, Mesh Op, Compilation
+class MeshProgram {
+   public:
+    MeshProgram(std::size_t num_devices);
+    ~MeshProgram() = default;
+    Program& at(std::size_t device_index);
+    std::vector<uint32_t> get_sem_base_addr(std::shared_ptr<MeshDevice> mesh_device, CoreCoord logical_core, CoreType core_type) const;
+
+    template <typename T>
+    T distributed_impl_(const std::function<T(Program&)>& callable) {
+        if constexpr (std::is_same<T, void>::value) {
+            for (std::size_t program_idx = 0; program_idx < this->programs.size(); program_idx++) {
+                callable(*this->programs.at(program_idx));
+            }
+        } else {
+            for (std::size_t program_idx = 0; program_idx < this->programs.size() - 1; program_idx++) {
+                callable(*this->programs.at(program_idx));
+            }
+            return callable(*this->programs.at(this->programs.size() - 1));
+        }
+    }
+
+    template <typename T>
+    std::vector<T> distributed_impl_(const std::variant<std::function<T(Program&)>, std::function<T(Program&, Device*)>>& callable, std::shared_ptr<MeshDevice> mesh_device = nullptr) const {
+        std::vector<T> rval = {};
+        std::vector<Device*> devices = {};
+        if (mesh_device != nullptr) {
+            devices = mesh_device->get_devices();
+            TT_ASSERT(devices.size() == this->programs.size(),
+                "MeshProgram created for {} devices cannot be mapped to a MeshDevice with {} devices",
+                this->programs.size(), devices.size());
+            TT_ASSERT(std::holds_alternative<std::function<T(Program&, Device*)>>(callable));
+            auto f = std::get<std::function<T(Program&, Device*)>>(callable);
+            for (std::size_t program_idx = 0; program_idx < devices.size(); program_idx++) {
+                rval.push_back(f(*this->programs.at(program_idx), devices.at(program_idx)));
+            }
+        } else {
+            TT_ASSERT(std::holds_alternative<std::function<T(Program&)>>(callable));
+            auto f = std::get<std::function<T(Program&)>>(callable);
+            // Collect a result for every program in the mesh
+            for (std::size_t program_idx = 0; program_idx < this->programs.size(); program_idx++) {
+                rval.push_back(f(*this->programs.at(program_idx)));
+            }
+        }
+        return rval;
+    }
+
+    template <typename T>
+    T distribute_to_mesh_device_impl_(const std::function<T(Program&, Device*)>& callable, std::shared_ptr<MeshDevice>& mesh_device) {
+        auto devices = mesh_device->get_devices();
+        TT_ASSERT(devices.size() == this->programs.size(),
+            "MeshProgram created for {} devices cannot be mapped to a MeshDevice with {} devices",
+            this->programs.size(), devices.size());
+        if constexpr (std::is_same<T, void>::value) {
+            for (std::size_t program_idx = 0; program_idx < devices.size(); program_idx++) {
+                callable(*this->programs.at(program_idx), devices.at(program_idx));
+            }
+        } else {
+            for (std::size_t program_idx = 0; program_idx < devices.size() - 1; program_idx++) {
+                callable(*this->programs.at(program_idx), devices.at(program_idx));
+            }
+            return callable(*this->programs.at(devices.size() - 1), devices.at(devices.size() - 1));
+        }
+    }
+   private:
+    std::vector<std::shared_ptr<Program>> programs = {};
+
+};
+
+uint32_t CreateSemaphore(
+    MeshProgram& mesh_program,
+    const std::variant<CoreRange, CoreRangeSet> &core_spec,
+    uint32_t initial_value,
+    CoreType core_type = CoreType::WORKER);
+
+uint32_t CreateSemaphore(
+    MeshProgram& mesh_program,
+    const std::variant<CoreRange, CoreRangeSet> &core_spec,
+    uint32_t initial_value,
+    CoreType core_type,
+    chip_id_t device_id);
+
+CBHandle CreateCircularBuffer(
+    MeshProgram& mesh_program,
+    const std::variant<CoreCoord, CoreRange, CoreRangeSet> &core_spec,
+    const CircularBufferConfig &config);
+
+CBHandle CreateCircularBuffer(
+    MeshProgram& mesh_program,
+    const std::variant<CoreCoord, CoreRange, CoreRangeSet> &core_spec,
+    const CircularBufferConfig &config,
+    chip_id_t device_id);
+
+void SetRuntimeArgs(
+    MeshProgram& mesh_program,
+    KernelHandle kernel,
+    const std::variant<CoreCoord, CoreRange, CoreRangeSet> &core_spec,
+    const std::vector<uint32_t> &runtime_args);
+
+void SetRuntimeArgs(
+    MeshProgram& mesh_program,
+    KernelHandle kernel,
+    const std::variant<CoreCoord, CoreRange, CoreRangeSet> &core_spec,
+    const std::vector<uint32_t> &runtime_args,
+    chip_id_t device_id);
+
+KernelHandle CreateKernel(
+    MeshProgram& mesh_program,
+    const std::string &file_name,
+    const std::variant<CoreCoord, CoreRange, CoreRangeSet> &core_spec,
+    const std::variant<DataMovementConfig, ComputeConfig, EthernetConfig> &config);
+
+KernelHandle CreateKernel(
+    MeshProgram& mesh_program,
+    const std::string &file_name,
+    const std::variant<CoreCoord, CoreRange, CoreRangeSet> &core_spec,
+    const std::variant<DataMovementConfig, ComputeConfig, EthernetConfig> &config,
+    chip_id_t device_id);
+
+void EnqueueMeshProgram(
+    uint8_t cq_id, MeshProgram& mesh_program, std::shared_ptr<MeshDevice> mesh_device, bool blocking);
+
+void Finish(std::shared_ptr<MeshDevice> mesh_device, uint8_t cq_id);
+
+}