#0: Add programming examples for TT-Metalium multi-device native APIs
Showing 12 changed files with 365 additions and 0 deletions.
tt_metal/programming_examples/distributed/1_distributed_program_dispatch/CMakeLists.txt (18 additions, 0 deletions)
set(DISTRIBUTED_PROGRAM_DISPATCH_SRC ${CMAKE_CURRENT_SOURCE_DIR}/distributed_program_dispatch.cpp)
add_executable(distributed_program_dispatch ${DISTRIBUTED_PROGRAM_DISPATCH_SRC})

target_link_libraries(
    distributed_program_dispatch
    PUBLIC
        tt_metal
        pthread
)

target_include_directories(distributed_program_dispatch PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})

set_target_properties(
    distributed_program_dispatch
    PROPERTIES
        RUNTIME_OUTPUT_DIRECTORY
            ${PROJECT_BINARY_DIR}/programming_examples/distributed
)
...ming_examples/distributed/1_distributed_program_dispatch/distributed_program_dispatch.cpp (47 additions, 0 deletions)
// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <tt-metalium/distributed.hpp>

// Stand-alone example demonstrating usage of native multi-device TT-Metalium APIs
// for issuing a program dispatch across a mesh of devices.
int main(int argc, char** argv) {
    using namespace tt::tt_metal::distributed;

    auto mesh_device = MeshDevice::create(MeshDeviceConfig{.mesh_shape{2, 4}});
    auto& cq = mesh_device->mesh_command_queue();

    // In a typical single-device fashion, instantiate a program with
    // an example compute kernel targeting a 2x2 core range.
    auto example_program = CreateProgram();
    auto target_tensix_cores = CoreRange{
        CoreCoord{0, 0} /* start_coord */, CoreCoord{1, 1} /* end_coord */
    };

    auto compute_kernel_id = CreateKernel(
        example_program,
        "tt_metal/programming_examples/distributed/1_distributed_program_dispatch/kernels/void_kernel.cpp",
        target_tensix_cores,
        ComputeConfig{.compile_args = {}});

    // Configure the runtime arguments for the kernel.
    auto runtime_args = std::vector<uint32_t>{};
    SetRuntimeArgs(example_program, compute_kernel_id, target_tensix_cores, runtime_args);

    // Instantiate a MeshWorkload and attach the example program. We'll broadcast
    // this program by enqueueing it across all devices in our 2x4 mesh.
    auto mesh_workload = CreateMeshWorkload();
    auto target_devices = LogicalDeviceRange{
        DeviceCoord{0, 0} /* start_coord */,
        DeviceCoord{mesh_device->num_cols(), mesh_device->num_rows()} /* end_coord */
    };

    AddProgramToMeshWorkload(mesh_workload, example_program, target_devices);
    EnqueueMeshWorkload(cq, mesh_workload, false /* blocking */);

    // Synchronize the mesh command queue to ensure the workload has completed.
    Finish(cq);

    return 0;
}
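A note on synchronization: the example enqueues the workload non-blocking and then waits with Finish(cq). As a small hedged variation (not part of this file), passing true for the blocking argument should presumably make the enqueue itself wait for completion, matching how the blocking read is used in the buffer read/write example below:

    // Hedged variation: let the enqueue block until the workload completes,
    // rather than returning immediately and synchronizing with Finish(cq).
    EnqueueMeshWorkload(cq, mesh_workload, true /* blocking */);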
...l/programming_examples/distributed/1_distributed_program_dispatch/kernels/void_kernel.cpp (17 additions, 0 deletions)
// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "debug/dprint.h"  // required in all kernels using DPRINT
#include "compute_kernel_api.h"

namespace NAMESPACE {

void MAIN {
    // Nothing to compute. Print a response message.
    // Make sure to export TT_METAL_DPRINT_CORES=0,0 before running.

    DPRINT_MATH(DPRINT << "Hello, World! I'm running a void compute kernel." << ENDL());
}

}  // namespace NAMESPACE
tt_metal/programming_examples/distributed/2_distributed_buffer_rw/CMakeLists.txt (18 additions, 0 deletions)
set(DISTRIBUTED_BUFFER_RW_SRC ${CMAKE_CURRENT_SOURCE_DIR}/distributed_buffer_rw.cpp)
add_executable(distributed_buffer_rw ${DISTRIBUTED_BUFFER_RW_SRC})

target_link_libraries(
    distributed_buffer_rw
    PUBLIC
        tt_metal
        pthread
)

target_include_directories(distributed_buffer_rw PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})

set_target_properties(
    distributed_buffer_rw
    PROPERTIES
        RUNTIME_OUTPUT_DIRECTORY
            ${PROJECT_BINARY_DIR}/programming_examples/distributed
)
tt_metal/programming_examples/distributed/2_distributed_buffer_rw/distributed_buffer_rw.cpp (58 additions, 0 deletions)
// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <cassert>
#include <chrono>

#include <tt-metalium/distributed.hpp>
#include <tt-metalium/mesh_buffer.hpp>
#include <tt-metalium/bfloat16.hpp>

// Stand-alone example demonstrating usage of native multi-device TT-Metalium APIs
// for issuing Read and Write commands to a distributed memory buffer spanning
// multiple devices in a mesh.
//
// The example demonstrates how to:
// 1. Perform a lock-step allocation of a distributed L1 MeshBuffer
//    containing data scattered across multiple devices in a mesh
// 2. Enqueue a Write command to the MeshBuffer with random data
// 3. Enqueue a Read command from the MeshBuffer back into a local buffer
// 4. Verify that the data read back matches the original data
int main(int argc, char** argv) {
    using namespace tt::tt_metal::distributed;
    using tt::tt_metal::distributed::ShardedBufferConfig;

    auto mesh_device = MeshDevice::create(MeshDeviceConfig{.mesh_shape{2, 4}});
    auto& cq = mesh_device->mesh_command_queue();

    // Define the shape of the shard and the distributed buffer.
    // We will create a distributed buffer with 8 shards of {32, 32} and distribute it across the devices in the mesh.
    auto shard_shape = Shape2D{32, 32};
    auto distributed_buffer_shape = Shape2D{32 * mesh_device->num_rows(), 32 * mesh_device->num_cols()};
    uint32_t tile_size_bytes = detail::TileSize(tt::DataFormat::UInt32);
    uint32_t distributed_buffer_size_bytes = 64 * 128 * tile_size_bytes;

    auto local_buffer_config = DeviceLocalBufferConfig{
        .page_size = tile_size_bytes,
        .buffer_type = BufferType::L1,
        .buffer_layout = TensorMemoryLayout::INTERLEAVED,
        .bottom_up = false};
    auto distributed_buffer_config = ShardedBufferConfig{
        .global_size = distributed_buffer_size_bytes,
        .global_buffer_shape = distributed_buffer_shape,
        .shard_shape = shard_shape};

    // Allocate a distributed buffer in L1 memory, spanning devices in the mesh.
    auto mesh_buffer = MeshBuffer::create(distributed_buffer_config, local_buffer_config, mesh_device.get());

    // Enqueue a write to the distributed buffer (L1 banks across devices) with random data.
    std::vector<uint32_t> src_data = create_random_vector_of_bfloat16(
        distributed_buffer_size_bytes, 1, std::chrono::system_clock::now().time_since_epoch().count());
    EnqueueWriteMeshBuffer(cq, mesh_buffer, src_data);

    // Enqueue a read from the distributed buffer (L1 banks across devices) to a local buffer.
    std::vector<uint32_t> read_back_data{};
    EnqueueReadMeshBuffer(cq, read_back_data, mesh_buffer, true /* blocking */);

    // Data read back across all devices in the mesh should match the original data.
    assert(src_data == read_back_data);

    return 0;
}
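The assert above compares the two vectors wholesale and is compiled out in NDEBUG builds. Below is a minimal sketch of a more explicit check, in plain C++ with no additional TT-Metalium APIs assumed (it requires <iostream>), that reports the first mismatching word instead:

    // Plain C++ verification sketch: report the first mismatch explicitly,
    // since assert() disappears in release (NDEBUG) builds.
    if (read_back_data.size() != src_data.size()) {
        std::cerr << "Size mismatch: wrote " << src_data.size() << " words, read "
                  << read_back_data.size() << " words\n";
        return 1;
    }
    for (size_t i = 0; i < src_data.size(); ++i) {
        if (src_data[i] != read_back_data[i]) {
            std::cerr << "Mismatch at word " << i << ": wrote " << src_data[i]
                      << ", read " << read_back_data[i] << "\n";
            return 1;
        }
    }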
tt_metal/programming_examples/distributed/3_distributed_eltwise_add/CMakeLists.txt (18 additions, 0 deletions)
set(DISTRIBUTED_ELTWISE_ADD_SRC ${CMAKE_CURRENT_SOURCE_DIR}/distributed_eltwise_add.cpp)
add_executable(distributed_eltwise_add ${DISTRIBUTED_ELTWISE_ADD_SRC})

target_link_libraries(
    distributed_eltwise_add
    PUBLIC
        tt_metal
        pthread
)

target_include_directories(distributed_eltwise_add PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})

set_target_properties(
    distributed_eltwise_add
    PROPERTIES
        RUNTIME_OUTPUT_DIRECTORY
            ${PROJECT_BINARY_DIR}/programming_examples/distributed
)
...al/programming_examples/distributed/3_distributed_eltwise_add/distributed_eltwise_add.cpp (172 additions, 0 deletions)
// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <functional>
#include <iostream>
#include <stdexcept>

#include <tt-metalium/distributed.hpp>
#include <tt-metalium/bfloat16.hpp>

using namespace tt;
using namespace tt::tt_metal;
using namespace tt::tt_metal::distributed;

Program CreateEltwiseAddProgram(
    const std::shared_ptr<MeshBuffer>& a,
    const std::shared_ptr<MeshBuffer>& b,
    const std::shared_ptr<MeshBuffer>& c,
    size_t tile_size_bytes,
    uint32_t num_tiles) {
    auto program = CreateProgram();
    auto target_tensix_core = CoreRange(CoreCoord{0, 0});

    // Add circular buffers for data movement
    constexpr uint32_t src0_cb_index = tt::CBIndex::c_0;
    constexpr uint32_t num_input_tiles = 1;
    CircularBufferConfig cb_src0_config =
        CircularBufferConfig(num_input_tiles * tile_size_bytes, {{src0_cb_index, tt::DataFormat::Float16_b}})
            .set_page_size(src0_cb_index, tile_size_bytes);
    CBHandle cb_src0 = tt_metal::CreateCircularBuffer(program, target_tensix_core, cb_src0_config);

    constexpr uint32_t src1_cb_index = tt::CBIndex::c_1;
    CircularBufferConfig cb_src1_config =
        CircularBufferConfig(num_input_tiles * tile_size_bytes, {{src1_cb_index, tt::DataFormat::Float16_b}})
            .set_page_size(src1_cb_index, tile_size_bytes);
    CBHandle cb_src1 = tt_metal::CreateCircularBuffer(program, target_tensix_core, cb_src1_config);

    constexpr uint32_t output_cb_index = tt::CBIndex::c_16;
    constexpr uint32_t num_output_tiles = 1;
    CircularBufferConfig cb_output_config =
        CircularBufferConfig(num_output_tiles * tile_size_bytes, {{output_cb_index, tt::DataFormat::Float16_b}})
            .set_page_size(output_cb_index, tile_size_bytes);
    CBHandle cb_output = tt_metal::CreateCircularBuffer(program, target_tensix_core, cb_output_config);

    // Add data movement kernels
    KernelHandle reader = CreateKernel(
        program,
        "tt_metal/programming_examples/contributed/vecadd/kernels/interleaved_tile_read.cpp",
        target_tensix_core,
        DataMovementConfig{.processor = DataMovementProcessor::RISCV_1, .noc = NOC::RISCV_1_default});

    KernelHandle writer = CreateKernel(
        program,
        "tt_metal/programming_examples/contributed/vecadd/kernels/tile_write.cpp",
        target_tensix_core,
        DataMovementConfig{.processor = DataMovementProcessor::RISCV_0, .noc = NOC::RISCV_0_default});

    // Create the eltwise binary kernel
    auto compute = CreateKernel(
        program,
        "tt_metal/programming_examples/contributed/vecadd/kernels/add.cpp",
        target_tensix_core,
        ComputeConfig{
            .math_fidelity = MathFidelity::HiFi4,
            .fp32_dest_acc_en = false,
            .math_approx_mode = false,
            .compile_args = {},
            .defines = {{"ELTWISE_OP", "add_tiles"}, {"ELTWISE_OP_TYPE", "EltwiseBinaryType::ELWADD"}}});

    // Set runtime arguments for each device
    SetRuntimeArgs(program, reader, target_tensix_core, {a->address(), b->address(), num_tiles});
    SetRuntimeArgs(program, writer, target_tensix_core, {c->address(), num_tiles});
    SetRuntimeArgs(program, compute, target_tensix_core, {num_tiles});

    return program;
}

// The example demonstrates distributed element-wise addition across a 2x4 mesh of devices:
//
// 1. Allocating distributed buffers that automatically shard data across the mesh
// 2. Initializing and distributing input data to each device's local memory
// 3. Executing a distributed compute kernel that performs element-wise addition
//    in parallel across all devices
// 4. Gathering and validating the results from the distributed computation
//
// The example showcases TT-Metalium's ability to abstract away the complexity
// of distributed memory management and compute.
int main(int argc, char** argv) {
    auto mesh_device = MeshDevice::create(MeshDeviceConfig{.mesh_shape{2, 4}});

    // Define the global buffer shape and shard shape for distributed buffers
    auto shard_shape = Shape2D{32, 32};
    auto distributed_buffer_shape =
        Shape2D{shard_shape.height() * mesh_device->num_rows(), shard_shape.width() * mesh_device->num_cols()};
    auto num_tiles = 1;
    auto tile_size_bytes = detail::TileSize(tt::DataFormat::Float16_b);
    auto distributed_buffer_size_bytes = mesh_device->num_rows() * mesh_device->num_cols() * tile_size_bytes;

    // Configure device-local buffer settings
    auto local_buffer_config = DeviceLocalBufferConfig{
        .page_size = tile_size_bytes,
        .buffer_type = BufferType::DRAM,
        .buffer_layout = TensorMemoryLayout::INTERLEAVED,
        .bottom_up = false};
    auto distributed_buffer_config = tt::tt_metal::distributed::ShardedBufferConfig{
        .global_size = distributed_buffer_size_bytes,
        .global_buffer_shape = distributed_buffer_shape,
        .shard_shape = shard_shape,
        .shard_orientation = ShardOrientation::ROW_MAJOR};

    // Create distributed buffers for inputs and output
    auto a = MeshBuffer::create(distributed_buffer_config, local_buffer_config, mesh_device.get());
    auto b = MeshBuffer::create(distributed_buffer_config, local_buffer_config, mesh_device.get());
    auto c = MeshBuffer::create(distributed_buffer_config, local_buffer_config, mesh_device.get());

    // Create and initialize source data
    constexpr float val_to_add = 0.5f;
    std::vector<uint32_t> a_data =
        create_random_vector_of_bfloat16(distributed_buffer_size_bytes, 1 /* rand_max_float */, 0 /* seed */);
    std::vector<uint32_t> b_data = create_constant_vector_of_bfloat16(distributed_buffer_size_bytes, val_to_add);

    // Write data to distributed buffers
    auto& cq = mesh_device->mesh_command_queue();
    EnqueueWriteMeshBuffer(cq, a, a_data, false /* blocking */);
    EnqueueWriteMeshBuffer(cq, b, b_data, false /* blocking */);

    // Create program for distributed computation
    auto program = CreateEltwiseAddProgram(a, b, c, tile_size_bytes, num_tiles);

    // Create mesh workload and broadcast the program across all devices
    auto mesh_workload = CreateMeshWorkload();
    auto device_range = LogicalDeviceRange{
        DeviceCoord{0, 0} /* start_coord */,
        DeviceCoord{mesh_device->num_cols(), mesh_device->num_rows()} /* end_coord */
    };

    AddProgramToMeshWorkload(mesh_workload, program, device_range);
    EnqueueMeshWorkload(cq, mesh_workload, false /* blocking */);

    // Read back results
    std::vector<uint32_t> result_data(a_data.size(), 0);
    EnqueueReadMeshBuffer(cq, result_data, c, true /* blocking */);

    // Verify results
    auto transform_to_golden = [val_to_add](const bfloat16& a) { return bfloat16(a.to_float() + val_to_add); };
    std::vector<uint32_t> golden_data =
        pack_bfloat16_vec_into_uint32_vec(unpack_uint32_vec_into_bfloat16_vec(a_data, transform_to_golden));

    // Summarize how many results match the golden values (expect some tolerance
    // due to BFP16 precision).
    std::cout << "Partial results: (note we are running under BFP16. It's going to be less accurate)\n";
    bfloat16* a_bf16 = reinterpret_cast<bfloat16*>(a_data.data());
    bfloat16* b_bf16 = reinterpret_cast<bfloat16*>(b_data.data());
    bfloat16* c_bf16 = reinterpret_cast<bfloat16*>(result_data.data());
    bfloat16* golden_bf16 = reinterpret_cast<bfloat16*>(golden_data.data());

    size_t num_failures = 0;
    auto total_values = result_data.size() * 2;
    for (size_t i = 0; i < total_values; i++) {
        if (!is_close(c_bf16[i].to_float(), golden_bf16[i].to_float())) {
            num_failures++;
        }
    }

    std::cout << "Total values: " << total_values << "\n";
    std::cout << "Distributed elementwise add verification: " << (total_values - num_failures) << " / " << total_values
              << " passed\n";
    if (num_failures > 0) {
        std::cout << "Distributed elementwise add verification failed with " << num_failures << " failures\n";
        throw std::runtime_error("Distributed elementwise add verification failed");
    }

    return 0;
}
tt_metal/programming_examples/distributed/CMakeLists.txt (3 additions, 0 deletions)
add_subdirectory(1_distributed_program_dispatch)
add_subdirectory(2_distributed_buffer_rw)
add_subdirectory(3_distributed_eltwise_add)
tt_metal/programming_examples/distributed/README.md (7 additions, 0 deletions)
# Distributed Programming Examples

This directory contains examples of the distributed programming model using the TT-Metalium API.

They are intended as simple demonstrations of distributed program dispatch, distributed memory management, and end-to-end distributed program execution.

Users familiar with the single-device TT-Metal programming model will find the distributed programming model to be a natural extension.
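For a concrete sense of that parallel, the sketch below condenses the dispatch example in this commit down to its core calls. It is illustrative only; the single-device calls named in the comments refer to the standard TT-Metal host API rather than anything introduced here.

```cpp
// Mesh analogue of the familiar single-device dispatch flow:
//   single device: create device -> command queue -> enqueue program -> finish
//   mesh:          MeshDevice::create -> mesh_command_queue() -> EnqueueMeshWorkload -> Finish
#include <tt-metalium/distributed.hpp>

int main() {
    using namespace tt::tt_metal::distributed;

    // One mesh handle stands in for one device handle.
    auto mesh_device = MeshDevice::create(MeshDeviceConfig{.mesh_shape{2, 4}});
    auto& cq = mesh_device->mesh_command_queue();

    // Programs are built exactly as in the single-device model...
    auto program = CreateProgram();

    // ...but are wrapped in a MeshWorkload and enqueued across a device range
    // instead of being enqueued directly on a single device's command queue.
    auto workload = CreateMeshWorkload();
    auto all_devices = LogicalDeviceRange{
        DeviceCoord{0, 0},
        DeviceCoord{mesh_device->num_cols(), mesh_device->num_rows()}};
    AddProgramToMeshWorkload(workload, program, all_devices);
    EnqueueMeshWorkload(cq, workload, false /* blocking */);
    Finish(cq);
    return 0;
}
```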