diff --git a/tests/tt_metal/distributed/test_mesh_events.cpp b/tests/tt_metal/distributed/test_mesh_events.cpp index 85d5cae74d7..e08e2caead8 100644 --- a/tests/tt_metal/distributed/test_mesh_events.cpp +++ b/tests/tt_metal/distributed/test_mesh_events.cpp @@ -39,12 +39,11 @@ TEST_F(MeshEventsTestSuite, ReplicatedAsyncIO) { std::iota(src_vec.begin(), src_vec.end(), i); std::vector> readback_vecs = {}; - std::shared_ptr event = std::make_shared(); // Writes on CQ 0 EnqueueWriteMeshBuffer(mesh_device_->mesh_command_queue(0), buf, src_vec); // Device to Device Synchronization - EnqueueRecordEvent(mesh_device_->mesh_command_queue(0), event); - EnqueueWaitForEvent(mesh_device_->mesh_command_queue(1), event); + auto write_event = EnqueueRecordEvent(mesh_device_->mesh_command_queue(0)); + EnqueueWaitForEvent(mesh_device_->mesh_command_queue(1), write_event); // Reads on CQ 1 for (const auto& coord : MeshCoordinateRange(mesh_device_->shape())) { @@ -86,17 +85,16 @@ TEST_F(MeshEventsTestT3000, ShardedAsyncIO) { std::vector src_vec = std::vector(global_buffer_shape.height() * global_buffer_shape.width(), 0); std::iota(src_vec.begin(), src_vec.end(), i); - std::shared_ptr event = std::make_shared(); // Writes on CQ 0 EnqueueWriteMeshBuffer(mesh_device_->mesh_command_queue(0), mesh_buffer, src_vec); if (i % 2) { // Test Host <-> Device synchronization - EnqueueRecordEventToHost(mesh_device_->mesh_command_queue(0), event); - EventSynchronize(event); + auto write_event = EnqueueRecordEventToHost(mesh_device_->mesh_command_queue(0)); + EventSynchronize(write_event); } else { // Test Device <-> Device synchronization - EnqueueRecordEvent(mesh_device_->mesh_command_queue(0), event); - EnqueueWaitForEvent(mesh_device_->mesh_command_queue(1), event); + auto write_event = EnqueueRecordEvent(mesh_device_->mesh_command_queue(0)); + EnqueueWaitForEvent(mesh_device_->mesh_command_queue(1), write_event); } // Reads on CQ 1 std::vector dst_vec = {}; @@ -127,9 +125,6 @@ TEST_F(MeshEventsTestSuite, AsyncWorkloadAndIO) { std::vector src0_vec = create_constant_vector_of_bfloat16(src0_bufs[0]->size(), iter + 2); std::vector src1_vec = create_constant_vector_of_bfloat16(src1_bufs[0]->size(), iter + 3); - std::shared_ptr write_event = std::make_shared(); - std::shared_ptr op_event = std::make_shared(); - // Issue writes on MeshCQ 1 for (std::size_t col_idx = 0; col_idx < worker_grid_size.x; col_idx++) { for (std::size_t row_idx = 0; row_idx < worker_grid_size.y; row_idx++) { @@ -141,22 +136,22 @@ TEST_F(MeshEventsTestSuite, AsyncWorkloadAndIO) { } if (iter % 2) { // Test Host <-> Device Synchronization - EnqueueRecordEventToHost(mesh_device_->mesh_command_queue(1), write_event); + auto write_event = EnqueueRecordEventToHost(mesh_device_->mesh_command_queue(1)); EventSynchronize(write_event); } else { // Test Device <-> Device Synchronization - EnqueueRecordEvent(mesh_device_->mesh_command_queue(1), write_event); + auto write_event = EnqueueRecordEvent(mesh_device_->mesh_command_queue(1)); EnqueueWaitForEvent(mesh_device_->mesh_command_queue(0), write_event); } // Issue workloads on MeshCQ 0 EnqueueMeshWorkload(mesh_device_->mesh_command_queue(0), mesh_workload, false); if (iter % 2) { // Test Device <-> Device Synchronization - EnqueueRecordEvent(mesh_device_->mesh_command_queue(0), op_event); + auto op_event = EnqueueRecordEvent(mesh_device_->mesh_command_queue(0)); EnqueueWaitForEvent(mesh_device_->mesh_command_queue(1), op_event); } else { // Test Host <-> Device Synchronization - EnqueueRecordEventToHost(mesh_device_->mesh_command_queue(0), op_event); + auto op_event = EnqueueRecordEventToHost(mesh_device_->mesh_command_queue(0)); EventSynchronize(op_event); } @@ -210,12 +205,10 @@ TEST_F(MeshEventsTestSuite, CustomDeviceRanges) { MeshCoordinateRange devices_1(MeshCoordinate{1, 0}, MeshCoordinate{1, mesh_device_->num_cols() - 1}); std::vector> readback_vecs = {}; - std::shared_ptr event_0 = std::make_shared(); - std::shared_ptr event_1 = std::make_shared(); mesh_device_->mesh_command_queue(1).enqueue_write_shard_to_sub_grid(*buf, src_vec.data(), devices_0, false); - EnqueueRecordEvent(mesh_device_->mesh_command_queue(1), event_0, {}, devices_0); - EnqueueWaitForEvent(mesh_device_->mesh_command_queue(0), event_0); + auto event0 = EnqueueRecordEvent(mesh_device_->mesh_command_queue(1), {}, devices_0); + EnqueueWaitForEvent(mesh_device_->mesh_command_queue(0), event0); for (const auto& coord : devices_0) { readback_vecs.push_back({}); @@ -224,8 +217,8 @@ TEST_F(MeshEventsTestSuite, CustomDeviceRanges) { } mesh_device_->mesh_command_queue(1).enqueue_write_shard_to_sub_grid(*buf, src_vec.data(), devices_1, false); - EnqueueRecordEventToHost(mesh_device_->mesh_command_queue(1), event_1, {}, devices_1); - EventSynchronize(event_1); + auto event1 = EnqueueRecordEventToHost(mesh_device_->mesh_command_queue(1), {}, devices_1); + EventSynchronize(event1); for (const auto& coord : devices_1) { readback_vecs.push_back({}); diff --git a/tt_metal/api/tt-metalium/distributed.hpp b/tt_metal/api/tt-metalium/distributed.hpp index dae23824eee..6f3ffa37eb1 100644 --- a/tt_metal/api/tt-metalium/distributed.hpp +++ b/tt_metal/api/tt-metalium/distributed.hpp @@ -80,21 +80,19 @@ void EnqueueReadMeshBuffer( mesh_cq.enqueue_read_mesh_buffer(dst.data(), mesh_buffer, blocking); } -void EnqueueRecordEvent( +MeshEvent EnqueueRecordEvent( MeshCommandQueue& mesh_cq, - const std::shared_ptr& event, tt::stl::Span sub_device_ids = {}, const std::optional& device_range = std::nullopt); -void EnqueueRecordEventToHost( +MeshEvent EnqueueRecordEventToHost( MeshCommandQueue& mesh_cq, - const std::shared_ptr& event, tt::stl::Span sub_device_ids = {}, const std::optional& device_range = std::nullopt); -void EnqueueWaitForEvent(MeshCommandQueue& mesh_cq, const std::shared_ptr& event); +void EnqueueWaitForEvent(MeshCommandQueue& mesh_cq, const MeshEvent& event); -void EventSynchronize(const std::shared_ptr& event); +void EventSynchronize(const MeshEvent& event); MeshTraceId BeginTraceCapture(MeshDevice* device, uint8_t cq_id); diff --git a/tt_metal/api/tt-metalium/mesh_command_queue.hpp b/tt_metal/api/tt-metalium/mesh_command_queue.hpp index cd6feb366a9..1cd7025e793 100644 --- a/tt_metal/api/tt-metalium/mesh_command_queue.hpp +++ b/tt_metal/api/tt-metalium/mesh_command_queue.hpp @@ -50,8 +50,7 @@ class MeshCommandQueue { // Helper functions for read and write entire Sharded-MeshBuffers void write_sharded_buffer(const MeshBuffer& buffer, const void* src); void read_sharded_buffer(MeshBuffer& buffer, void* dst); - void enqueue_record_event_helper( - const std::shared_ptr& event, + MeshEvent enqueue_record_event_helper( tt::stl::Span sub_device_ids, bool notify_host, const std::optional& device_range = std::nullopt); @@ -156,17 +155,15 @@ class MeshCommandQueue { const std::shared_ptr& mesh_buffer, bool blocking); - void enqueue_record_event( - const std::shared_ptr& event, + MeshEvent enqueue_record_event( tt::stl::Span sub_device_ids = {}, const std::optional& device_range = std::nullopt); - void enqueue_record_event_to_host( - const std::shared_ptr& event, + MeshEvent enqueue_record_event_to_host( tt::stl::Span sub_device_ids = {}, const std::optional& device_range = std::nullopt); - void enqueue_wait_for_event(const std::shared_ptr& sync_event); + void enqueue_wait_for_event(const MeshEvent& sync_event); void drain_events_from_completion_queue(); - void verify_reported_events_after_draining(const std::shared_ptr& event); + void verify_reported_events_after_draining(const MeshEvent& event); void finish(tt::stl::Span sub_device_ids = {}); void reset_worker_state( bool reset_launch_msg_state, diff --git a/tt_metal/api/tt-metalium/mesh_coord.hpp b/tt_metal/api/tt-metalium/mesh_coord.hpp index a8f5e961616..025a38089da 100644 --- a/tt_metal/api/tt-metalium/mesh_coord.hpp +++ b/tt_metal/api/tt-metalium/mesh_coord.hpp @@ -124,6 +124,10 @@ class MeshCoordinateRange { // Returns the intersection of the range with the given range. std::optional intersection(const MeshCoordinateRange& range) const; + // Needed for reflect / fmt + static constexpr auto attribute_names = std::forward_as_tuple("start", "end"); + auto attribute_values() const { return std::forward_as_tuple(start_, end_); } + class Iterator { public: Iterator& operator++(); diff --git a/tt_metal/api/tt-metalium/mesh_event.hpp b/tt_metal/api/tt-metalium/mesh_event.hpp index 72beaeaef94..ed508062a89 100644 --- a/tt_metal/api/tt-metalium/mesh_event.hpp +++ b/tt_metal/api/tt-metalium/mesh_event.hpp @@ -4,16 +4,28 @@ #pragma once +#include #include "mesh_device.hpp" namespace tt::tt_metal::distributed { class MeshEvent { public: - MeshDevice* device = nullptr; - MeshCoordinateRange device_range = MeshCoordinateRange(MeshCoordinate(0, 0), MeshCoordinate(0, 0)); - uint32_t cq_id = 0; - uint32_t event_id = 0; + MeshEvent(uint32_t id, MeshDevice* device, uint32_t mesh_cq_id, const MeshCoordinateRange& device_range); + + // Returns references to the event data. + uint32_t id() const; + MeshDevice* device() const; + uint32_t mesh_cq_id() const; + const MeshCoordinateRange& device_range() const; + + friend std::ostream& operator<<(std::ostream& os, const MeshEvent& event); + +private: + uint32_t id_ = 0; + MeshDevice* device_ = nullptr; + uint32_t mesh_cq_id_ = 0; + MeshCoordinateRange device_range_; }; } // namespace tt::tt_metal::distributed diff --git a/tt_metal/distributed/CMakeLists.txt b/tt_metal/distributed/CMakeLists.txt index 3879a1648eb..4c44e2951b7 100644 --- a/tt_metal/distributed/CMakeLists.txt +++ b/tt_metal/distributed/CMakeLists.txt @@ -1,14 +1,15 @@ set(DISTRIBUTED_SRC ${CMAKE_CURRENT_SOURCE_DIR}/coordinate_translation.cpp ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/system_mesh.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/mesh_buffer.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/mesh_command_queue.cpp ${CMAKE_CURRENT_SOURCE_DIR}/mesh_device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/mesh_device_view.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/mesh_event.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/mesh_trace.cpp ${CMAKE_CURRENT_SOURCE_DIR}/mesh_workload.cpp ${CMAKE_CURRENT_SOURCE_DIR}/mesh_workload_utils.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/mesh_command_queue.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/mesh_buffer.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/mesh_trace.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/system_mesh.cpp ) add_library(distributed OBJECT ${DISTRIBUTED_SRC}) diff --git a/tt_metal/distributed/distributed.cpp b/tt_metal/distributed/distributed.cpp index 9bc3a2f9185..7fc4f87d032 100644 --- a/tt_metal/distributed/distributed.cpp +++ b/tt_metal/distributed/distributed.cpp @@ -19,28 +19,24 @@ void EnqueueMeshWorkload(MeshCommandQueue& mesh_cq, MeshWorkload& mesh_workload, mesh_cq.enqueue_mesh_workload(mesh_workload, blocking); } -void EnqueueRecordEvent( +MeshEvent EnqueueRecordEvent( MeshCommandQueue& mesh_cq, - const std::shared_ptr& event, tt::stl::Span sub_device_ids, const std::optional& device_range) { - mesh_cq.enqueue_record_event(event, sub_device_ids, device_range); + return mesh_cq.enqueue_record_event(sub_device_ids, device_range); } -void EnqueueRecordEventToHost( +MeshEvent EnqueueRecordEventToHost( MeshCommandQueue& mesh_cq, - const std::shared_ptr& event, tt::stl::Span sub_device_ids, const std::optional& device_range) { - mesh_cq.enqueue_record_event_to_host(event, sub_device_ids, device_range); + return mesh_cq.enqueue_record_event_to_host(sub_device_ids, device_range); } -void EnqueueWaitForEvent(MeshCommandQueue& mesh_cq, const std::shared_ptr& event) { - mesh_cq.enqueue_wait_for_event(event); -} +void EnqueueWaitForEvent(MeshCommandQueue& mesh_cq, const MeshEvent& event) { mesh_cq.enqueue_wait_for_event(event); } -void EventSynchronize(const std::shared_ptr& event) { - auto& mesh_cq = event->device->mesh_command_queue(event->cq_id); +void EventSynchronize(const MeshEvent& event) { + auto& mesh_cq = event.device()->mesh_command_queue(event.mesh_cq_id()); mesh_cq.drain_events_from_completion_queue(); mesh_cq.verify_reported_events_after_draining(event); } diff --git a/tt_metal/distributed/mesh_command_queue.cpp b/tt_metal/distributed/mesh_command_queue.cpp index a46cb035266..7635a4bf4ec 100644 --- a/tt_metal/distributed/mesh_command_queue.cpp +++ b/tt_metal/distributed/mesh_command_queue.cpp @@ -188,8 +188,7 @@ void MeshCommandQueue::enqueue_mesh_workload(MeshWorkload& mesh_workload, bool b } void MeshCommandQueue::finish(tt::stl::Span sub_device_ids) { - std::shared_ptr event = std::make_shared(); - this->enqueue_record_event_to_host(event, sub_device_ids); + auto event = this->enqueue_record_event_to_host(sub_device_ids); this->drain_events_from_completion_queue(); this->verify_reported_events_after_draining(event); } @@ -478,22 +477,22 @@ void MeshCommandQueue::enqueue_read_shards( } } -void MeshCommandQueue::enqueue_record_event_helper( - const std::shared_ptr& event, +MeshEvent MeshCommandQueue::enqueue_record_event_helper( tt::stl::Span sub_device_ids, bool notify_host, const std::optional& device_range) { auto& sysmem_manager = this->reference_sysmem_manager(); - event->cq_id = id_; - event->event_id = sysmem_manager.get_next_event(id_); - event->device = mesh_device_; - event->device_range = device_range.value_or(MeshCoordinateRange(mesh_device_->shape())); + auto event = MeshEvent( + sysmem_manager.get_next_event(id_), + mesh_device_, + id_, + device_range.value_or(MeshCoordinateRange(mesh_device_->shape()))); sub_device_ids = buffer_dispatch::select_sub_device_ids(mesh_device_, sub_device_ids); - for (const auto& coord : event->device_range) { + for (const auto& coord : event.device_range()) { event_dispatch::issue_record_event_commands( mesh_device_, - event->event_id, + event.id(), id_, mesh_device_->num_hw_cqs(), mesh_device_->get_device(coord)->sysmem_manager(), @@ -501,28 +500,27 @@ void MeshCommandQueue::enqueue_record_event_helper( expected_num_workers_completed_, notify_host); } + + return event; } -void MeshCommandQueue::enqueue_record_event( - const std::shared_ptr& event, - tt::stl::Span sub_device_ids, - const std::optional& device_range) { - this->enqueue_record_event_helper(event, sub_device_ids, false, device_range); +MeshEvent MeshCommandQueue::enqueue_record_event( + tt::stl::Span sub_device_ids, const std::optional& device_range) { + return this->enqueue_record_event_helper(sub_device_ids, /*notify_host=*/false, device_range); } -void MeshCommandQueue::enqueue_record_event_to_host( - const std::shared_ptr& event, - tt::stl::Span sub_device_ids, - const std::optional& device_range) { - this->enqueue_record_event_helper(event, sub_device_ids, true, device_range); +MeshEvent MeshCommandQueue::enqueue_record_event_to_host( + tt::stl::Span sub_device_ids, const std::optional& device_range) { + auto event = this->enqueue_record_event_helper(sub_device_ids, /*notify_host=*/true, device_range); event_descriptors_.push(std::make_shared(MeshReadEventDescriptor{ - .single_device_descriptor = ReadEventDescriptor(event->event_id), .device_range = event->device_range})); + .single_device_descriptor = ReadEventDescriptor(event.id()), .device_range = event.device_range()})); + return event; } -void MeshCommandQueue::enqueue_wait_for_event(const std::shared_ptr& sync_event) { - for (const auto& coord : sync_event->device_range) { +void MeshCommandQueue::enqueue_wait_for_event(const MeshEvent& sync_event) { + for (const auto& coord : sync_event.device_range()) { event_dispatch::issue_wait_for_event_commands( - id_, sync_event->cq_id, mesh_device_->get_device(coord)->sysmem_manager(), sync_event->event_id); + id_, sync_event.mesh_cq_id(), mesh_device_->get_device(coord)->sysmem_manager(), sync_event.id()); } } @@ -546,13 +544,14 @@ void MeshCommandQueue::drain_events_from_completion_queue() { } } -void MeshCommandQueue::verify_reported_events_after_draining(const std::shared_ptr& event) { - auto& device_range = event->device_range; +void MeshCommandQueue::verify_reported_events_after_draining(const MeshEvent& event) { + auto& device_range = event.device_range(); for (const auto& coord : device_range) { TT_FATAL( - mesh_device_->get_device(coord)->sysmem_manager().get_last_completed_event(event->cq_id) >= event->event_id, + mesh_device_->get_device(coord)->sysmem_manager().get_last_completed_event(event.mesh_cq_id()) >= + event.id(), "Expected to see event id {} in completion queue", - event->event_id); + event.id()); } } diff --git a/tt_metal/distributed/mesh_event.cpp b/tt_metal/distributed/mesh_event.cpp new file mode 100644 index 00000000000..a2711672f2b --- /dev/null +++ b/tt_metal/distributed/mesh_event.cpp @@ -0,0 +1,23 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include + +namespace tt::tt_metal::distributed { + +MeshEvent::MeshEvent(uint32_t id, MeshDevice* device, uint32_t mesh_cq_id, const MeshCoordinateRange& device_range) : + id_(id), device_(device), mesh_cq_id_(mesh_cq_id), device_range_(device_range) {} + +uint32_t MeshEvent::id() const { return id_; } +MeshDevice* MeshEvent::device() const { return device_; } +uint32_t MeshEvent::mesh_cq_id() const { return mesh_cq_id_; } +const MeshCoordinateRange& MeshEvent::device_range() const { return device_range_; } + +std::ostream& operator<<(std::ostream& os, const MeshEvent& event) { + os << "MeshEvent(id=" << event.id() << ", device_id=" << event.device()->id() + << ", mesh_cq_id=" << event.mesh_cq_id() << ", device_range=" << event.device_range() << ")"; + return os; +} + +} // namespace tt::tt_metal::distributed diff --git a/tt_metal/impl/dispatch/hardware_command_queue.cpp b/tt_metal/impl/dispatch/hardware_command_queue.cpp index b26c5d08263..e309e87d7f9 100644 --- a/tt_metal/impl/dispatch/hardware_command_queue.cpp +++ b/tt_metal/impl/dispatch/hardware_command_queue.cpp @@ -382,7 +382,7 @@ void HWCommandQueue::enqueue_record_event( event->cq_id = this->id_; event->event_id = this->manager.get_next_event(this->id_); event->device = this->device_; - event->ready = true; // what does this mean??? + event->ready = true; sub_device_ids = buffer_dispatch::select_sub_device_ids(this->device_, sub_device_ids); event_dispatch::issue_record_event_commands( diff --git a/tt_metal/programming_examples/distributed/4_distributed_trace_and_events/distributed_trace_and_events.cpp b/tt_metal/programming_examples/distributed/4_distributed_trace_and_events/distributed_trace_and_events.cpp index 452165c348c..b8eb9472f10 100644 --- a/tt_metal/programming_examples/distributed/4_distributed_trace_and_events/distributed_trace_and_events.cpp +++ b/tt_metal/programming_examples/distributed/4_distributed_trace_and_events/distributed_trace_and_events.cpp @@ -243,20 +243,17 @@ int main(int argc, char** argv) { // =========== Step 7: Write inputs on MeshCQ1 =========== // IO is done through MeshCQ1 and Workload dispatch is done through MeshCQ0. Use MeshEvents to synchronize the // independent MeshCQs. - std::shared_ptr write_event = std::make_shared(); - std::shared_ptr trace_event = std::make_shared(); - EnqueueWriteMeshBuffer(data_movement_cq, add_src0_buf, add_src0_vec); EnqueueWriteMeshBuffer(data_movement_cq, add_src1_buf, add_src1_vec); EnqueueWriteMeshBuffer(data_movement_cq, mul_sub_src0_buf, mul_sub_src0_vec); EnqueueWriteMeshBuffer(data_movement_cq, mul_sub_src1_buf, mul_sub_src1_vec); // Synchronize - EnqueueRecordEvent(data_movement_cq, write_event); + MeshEvent write_event = EnqueueRecordEvent(data_movement_cq); EnqueueWaitForEvent(workload_cq, write_event); // =========== Step 8: Run MeshTrace on MeshCQ0 =========== ReplayTrace(mesh_device.get(), workload_cq_id, trace_id, false); // Synchronize - EnqueueRecordEvent(workload_cq, trace_event); + MeshEvent trace_event = EnqueueRecordEvent(workload_cq); EnqueueWaitForEvent(data_movement_cq, trace_event); // =========== Step 9: Read Outputs on MeshCQ1 =========== std::vector add_dst_vec = {}; diff --git a/ttnn/cpp/pybind11/events.cpp b/ttnn/cpp/pybind11/events.cpp index abc64a7cf2f..354f0660fc0 100644 --- a/ttnn/cpp/pybind11/events.cpp +++ b/ttnn/cpp/pybind11/events.cpp @@ -17,6 +17,11 @@ namespace ttnn::events { void py_module_types(py::module& module) { py::class_>(module, "event"); py::class_(module, "multi_device_event"); + py::class_(module, "MeshEvent").def("__repr__", [](const MeshEvent& self) { + std::ostringstream str; + str << self; + return str.str(); + }); } void py_module(py::module& module) { @@ -99,6 +104,25 @@ void py_module(py::module& module) { event (event): The event used to record completion of preceeding commands. sub_device_ids (List[ttnn.SubDeviceId], optional): The sub-device IDs to record completion for. Defaults to sub-devices set by set_sub_device_stall_group. )doc"); + + module.def( + "record_mesh_event", + py::overload_cast< + MeshDevice*, + QueueId, + const std::vector&, + const std::optional&>(&record_mesh_event), + py::arg("mesh_device"), + py::arg("cq_id"), + py::arg("sub_device_ids") = std::vector(), + py::arg("device_range") = std::nullopt); + + module.def( + "wait_for_mesh_event", + py::overload_cast(&wait_for_mesh_event), + py::arg("mesh_device"), + py::arg("cq_id"), + py::arg("mesh_event")); } } // namespace ttnn::events diff --git a/ttnn/cpp/pybind11/operations/trace.hpp b/ttnn/cpp/pybind11/operations/trace.hpp index ffb72241f4d..2a9f35ad87a 100644 --- a/ttnn/cpp/pybind11/operations/trace.hpp +++ b/ttnn/cpp/pybind11/operations/trace.hpp @@ -97,32 +97,32 @@ void py_module(py::module& module) { module.def( "begin_mesh_trace_capture", - [](MeshDevice* device, MeshCommandQueueId mesh_cq_id) { - return ttnn::operations::trace::begin_mesh_trace_capture(device, mesh_cq_id); + [](MeshDevice* device, QueueId cq_id) { + return ttnn::operations::trace::begin_mesh_trace_capture(device, cq_id); }, py::arg("mesh_device"), py::kw_only(), - py::arg("cq_id") = ttnn::DefaultMeshCommandQueueId); + py::arg("cq_id") = ttnn::DefaultQueueId); module.def( "end_mesh_trace_capture", - [](MeshDevice* device, MeshTraceId trace_id, MeshCommandQueueId mesh_cq_id) { - return ttnn::operations::trace::end_mesh_trace_capture(device, trace_id, mesh_cq_id); + [](MeshDevice* device, MeshTraceId trace_id, QueueId cq_id) { + return ttnn::operations::trace::end_mesh_trace_capture(device, trace_id, cq_id); }, py::arg("mesh_device"), py::arg("trace_id"), py::kw_only(), - py::arg("cq_id") = ttnn::DefaultMeshCommandQueueId); + py::arg("cq_id") = ttnn::DefaultQueueId); module.def( "execute_mesh_trace", - [](MeshDevice* device, MeshTraceId trace_id, MeshCommandQueueId mesh_cq_id, bool blocking) { - return ttnn::operations::trace::execute_mesh_trace(device, trace_id, mesh_cq_id, blocking); + [](MeshDevice* device, MeshTraceId trace_id, QueueId cq_id, bool blocking) { + return ttnn::operations::trace::execute_mesh_trace(device, trace_id, cq_id, blocking); }, py::arg("mesh_device"), py::arg("trace_id"), py::kw_only(), - py::arg("cq_id") = ttnn::DefaultMeshCommandQueueId, + py::arg("cq_id") = ttnn::DefaultQueueId, py::arg("blocking") = true); module.def( diff --git a/ttnn/cpp/pybind11/types.cpp b/ttnn/cpp/pybind11/types.cpp index e4a479eb181..330bb31e5d6 100644 --- a/ttnn/cpp/pybind11/types.cpp +++ b/ttnn/cpp/pybind11/types.cpp @@ -30,15 +30,6 @@ void py_module_types(py::module& module) { "__repr__", [](const ttnn::QueueId& self) { return "QueueId(" + std::to_string(static_cast(*self)) + ")"; }) .def(py::self == py::self); - py::class_(module, "MeshCommandQueueId") - .def(py::init()) - .def("__int__", [](const ttnn::MeshCommandQueueId& self) { return static_cast(*self); }) - .def( - "__repr__", - [](const ttnn::MeshCommandQueueId& self) { - return "MeshCommandQueueId(" + std::to_string(static_cast(*self)) + ")"; - }) - .def(py::self == py::self); export_enum(module, "BcastOpMath"); export_enum(module, "BcastOpDim"); diff --git a/ttnn/cpp/ttnn/common/queue_id.hpp b/ttnn/cpp/ttnn/common/queue_id.hpp index 7bd316d6b0e..b7e98d9858b 100644 --- a/ttnn/cpp/ttnn/common/queue_id.hpp +++ b/ttnn/cpp/ttnn/common/queue_id.hpp @@ -19,9 +19,6 @@ namespace ttnn { using QueueId = tt::stl::StrongType; constexpr QueueId DefaultQueueId = QueueId(0); -using MeshCommandQueueId = tt::stl::StrongType; -constexpr MeshCommandQueueId DefaultMeshCommandQueueId = MeshCommandQueueId(0); - } // namespace ttnn // Exporting to tt::tt_metal namespace because ttnn @@ -29,6 +26,5 @@ constexpr MeshCommandQueueId DefaultMeshCommandQueueId = MeshCommandQueueId(0); namespace tt::tt_metal { using QueueId = ttnn::QueueId; -using MeshCommandQueueId = ttnn::MeshCommandQueueId; } // namespace tt::tt_metal diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index c6bafa80a3f..9ad24cf4aee 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -27,8 +27,9 @@ namespace py = pybind11; void py_module_types(py::module& module) { py::class_>(module, "MeshDevice"); py::class_(module, "MeshSubDeviceManagerId"); - py::class_(module, "MeshShape", "Struct representing the shape of a mesh device."); - py::class_(module, "MeshCoordinate", "Struct representing the coordinate of a mesh device."); + py::class_(module, "MeshShape", "Shape of a mesh device."); + py::class_(module, "MeshCoordinate", "Coordinate within a mesh device."); + py::class_(module, "MeshCoordinateRange", "Range of coordinates within a mesh device."); } void py_module(py::module& module) { @@ -88,6 +89,29 @@ void py_module(py::module& module) { [](const MeshCoordinate& mc) { return py::make_iterator(mc.coords().begin(), mc.coords().end()); }, py::keep_alive<0, 1>()); + static_cast>(module.attr("MeshCoordinateRange")) + .def( + py::init( + [](const MeshCoordinate& start, const MeshCoordinate& end) { return MeshCoordinateRange(start, end); }), + "Constructor with specified start and end coordinates.", + py::arg("start"), + py::arg("end")) + .def( + py::init([](const MeshShape& shape) { return MeshCoordinateRange(shape); }), + "Constructor that spans the entire mesh.", + py::arg("shape")) + .def( + "__repr__", + [](const MeshCoordinateRange& mcr) { + std::ostringstream str; + str << mcr; + return str.str(); + }) + .def( + "__iter__", + [](const MeshCoordinateRange& mcr) { return py::make_iterator(mcr.begin(), mcr.end()); }, + py::keep_alive<0, 1>()); + auto py_mesh_device = static_cast>>(module.attr("MeshDevice")); py_mesh_device .def( @@ -392,7 +416,6 @@ void py_module(py::module& module) { py::arg("tensors"), py::kw_only()); module.def("get_t3k_physical_device_ids_ring", &get_t3k_physical_device_ids_ring); - module.attr("DefaultMeshCommandQueueId") = ttnn::DefaultMeshCommandQueueId; } } // namespace ttnn::distributed diff --git a/ttnn/cpp/ttnn/distributed/types.hpp b/ttnn/cpp/ttnn/distributed/types.hpp index c97df2a667d..4388e5d8c67 100644 --- a/ttnn/cpp/ttnn/distributed/types.hpp +++ b/ttnn/cpp/ttnn/distributed/types.hpp @@ -14,6 +14,7 @@ namespace ttnn::distributed { using MeshShape = tt::tt_metal::distributed::MeshShape; using MeshCoordinate = tt::tt_metal::distributed::MeshCoordinate; +using MeshCoordinateRange = tt::tt_metal::distributed::MeshCoordinateRange; using DeviceIds = tt::tt_metal::distributed::DeviceIds; using MeshDevice = tt::tt_metal::distributed::MeshDevice; using SystemMesh = tt::tt_metal::distributed::SystemMesh; @@ -28,6 +29,7 @@ namespace ttnn { // These types are exported to the ttnn namespace for convenience. using ttnn::distributed::DeviceIds; using ttnn::distributed::MeshCoordinate; +using ttnn::distributed::MeshCoordinateRange; using ttnn::distributed::MeshDevice; using ttnn::distributed::MeshDeviceConfig; using ttnn::distributed::MeshDeviceView; diff --git a/ttnn/cpp/ttnn/events.cpp b/ttnn/cpp/ttnn/events.cpp index 54d13fead11..525c353941b 100644 --- a/ttnn/cpp/ttnn/events.cpp +++ b/ttnn/cpp/ttnn/events.cpp @@ -6,24 +6,18 @@ #include #include +#include "tt-metalium/distributed.hpp" +#include "ttnn/common/queue_id.hpp" #include "ttnn/distributed/types.hpp" #include -using namespace tt::tt_metal; - namespace ttnn::events { -MultiDeviceEvent::MultiDeviceEvent(MeshDevice* mesh_device) { - TT_ASSERT( - mesh_device != nullptr, "Must provide a valid mesh_device when initializing an event on multiple devices."); - auto devices = mesh_device->get_devices(); - this->events = std::vector>(devices.size()); - for (int event_idx = 0; event_idx < devices.size(); event_idx++) { - this->events[event_idx] = std::make_shared(); - this->events[event_idx]->device = devices[event_idx]; - } -} +using ::tt::tt_metal::EnqueueRecordEvent; +using ::tt::tt_metal::EnqueueWaitForEvent; +using ::tt::tt_metal::distributed::EnqueueRecordEventToHost; +using ::tt::tt_metal::distributed::EnqueueWaitForEvent; std::shared_ptr create_event(IDevice* device) { std::shared_ptr event = std::make_shared(); @@ -43,7 +37,15 @@ void wait_for_event(QueueId cq_id, const std::shared_ptr& event) { device->push_work([device, event, cq_id] { EnqueueWaitForEvent(device->command_queue(*cq_id), event); }); } -MultiDeviceEvent create_event(MeshDevice* mesh_device) { return MultiDeviceEvent(mesh_device); } +MultiDeviceEvent create_event(MeshDevice* mesh_device) { + MultiDeviceEvent multi_device_event; + + multi_device_event.events.reserve(mesh_device->get_devices().size()); + for (auto* device : mesh_device->get_devices()) { + multi_device_event.events.push_back(create_event(device)); + } + return multi_device_event; +} void record_event( QueueId cq_id, const MultiDeviceEvent& multi_device_event, const std::vector& sub_device_ids) { @@ -58,4 +60,16 @@ void wait_for_event(QueueId cq_id, const MultiDeviceEvent& multi_device_event) { } } +MeshEvent record_mesh_event( + MeshDevice* mesh_device, + QueueId cq_id, + const std::vector& sub_device_ids, + const std::optional& device_range) { + return EnqueueRecordEventToHost(mesh_device->mesh_command_queue(*cq_id), sub_device_ids, device_range); +} + +void wait_for_mesh_event(MeshDevice* mesh_device, QueueId cq_id, const MeshEvent& event) { + EnqueueWaitForEvent(mesh_device->mesh_command_queue(*cq_id), event); +} + } // namespace ttnn::events diff --git a/ttnn/cpp/ttnn/events.hpp b/ttnn/cpp/ttnn/events.hpp index b07435706b8..cb20d24e78a 100644 --- a/ttnn/cpp/ttnn/events.hpp +++ b/ttnn/cpp/ttnn/events.hpp @@ -5,6 +5,7 @@ #pragma once #include +#include "tt-metalium/mesh_event.hpp" #include "ttnn/common/queue_id.hpp" #include "ttnn/distributed/types.hpp" @@ -12,12 +13,12 @@ #include "tt-metalium/event.hpp" #include "tt-metalium/sub_device_types.hpp" -namespace ttnn::events { +namespace ttnn { + +using MeshEvent = tt::tt_metal::distributed::MeshEvent; + +namespace events { -struct MultiDeviceEvent { - MultiDeviceEvent(MeshDevice* mesh_device); - std::vector> events; -}; // Single Device APIs std::shared_ptr create_event(IDevice* device); void record_event( @@ -25,10 +26,22 @@ void record_event( const std::shared_ptr& event, const std::vector& sub_device_ids = {}); void wait_for_event(QueueId cq_id, const std::shared_ptr& event); + // Multi Device APIs +struct MultiDeviceEvent { + std::vector> events; +}; MultiDeviceEvent create_event(MeshDevice* mesh_device); void record_event( QueueId cq_id, const MultiDeviceEvent& event, const std::vector& sub_device_ids = {}); void wait_for_event(QueueId cq_id, const MultiDeviceEvent& event); -} // namespace ttnn::events +MeshEvent record_mesh_event( + MeshDevice* mesh_device, + QueueId cq_id, + const std::vector& sub_device_ids = {}, + const std::optional& device_range = std::nullopt); +void wait_for_mesh_event(MeshDevice* mesh_device, QueueId cq_id, const MeshEvent& event); + +} // namespace events +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/trace.cpp b/ttnn/cpp/ttnn/operations/trace.cpp index 88d1f52c9da..00f7c13253c 100644 --- a/ttnn/cpp/ttnn/operations/trace.cpp +++ b/ttnn/cpp/ttnn/operations/trace.cpp @@ -33,19 +33,19 @@ void release_trace(IDevice* device, uint32_t trace_id) { device->release_trace(trace_id); } -MeshTraceId begin_mesh_trace_capture(MeshDevice* device, MeshCommandQueueId mesh_cq_id) { +MeshTraceId begin_mesh_trace_capture(MeshDevice* device, QueueId cq_id) { ZoneScoped; MeshTraceId trace_id = tt::tt_metal::distributed::MeshTrace::next_id(); - device->begin_mesh_trace(*mesh_cq_id, trace_id); + device->begin_mesh_trace(*cq_id, trace_id); return trace_id; } -void end_mesh_trace_capture(MeshDevice* device, MeshTraceId trace_id, MeshCommandQueueId mesh_cq_id) { +void end_mesh_trace_capture(MeshDevice* device, MeshTraceId trace_id, QueueId cq_id) { ZoneScoped; - device->end_mesh_trace(*mesh_cq_id, trace_id); + device->end_mesh_trace(*cq_id, trace_id); } -void execute_mesh_trace(MeshDevice* device, MeshTraceId trace_id, MeshCommandQueueId mesh_cq_id, bool blocking) { +void execute_mesh_trace(MeshDevice* device, MeshTraceId trace_id, QueueId cq_id, bool blocking) { ZoneScoped; - device->replay_mesh_trace(*mesh_cq_id, trace_id, blocking); + device->replay_mesh_trace(*cq_id, trace_id, blocking); } void release_mesh_trace(MeshDevice* device, MeshTraceId trace_id) { ZoneScoped; diff --git a/ttnn/cpp/ttnn/operations/trace.hpp b/ttnn/cpp/ttnn/operations/trace.hpp index bed9c1b300c..7958141fa2a 100644 --- a/ttnn/cpp/ttnn/operations/trace.hpp +++ b/ttnn/cpp/ttnn/operations/trace.hpp @@ -23,9 +23,9 @@ void execute_trace(IDevice* device, uint32_t trace_id, QueueId cq_id, bool block void release_trace(IDevice* device, uint32_t trace_id); // Trace APIs - Multi Device -MeshTraceId begin_mesh_trace_capture(MeshDevice* device, MeshCommandQueueId mesh_cq_id); -void end_mesh_trace_capture(MeshDevice* device, MeshTraceId trace_id, MeshCommandQueueId mesh_cq_id); -void execute_mesh_trace(MeshDevice* device, MeshTraceId trace_id, MeshCommandQueueId mesh_cq_id, bool blocking); +MeshTraceId begin_mesh_trace_capture(MeshDevice* device, QueueId cq_id); +void end_mesh_trace_capture(MeshDevice* device, MeshTraceId trace_id, QueueId cq_id); +void execute_mesh_trace(MeshDevice* device, MeshTraceId trace_id, QueueId cq_id, bool blocking); void release_mesh_trace(MeshDevice* device, MeshTraceId trace_id); } // namespace trace diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py index ffa60f1d85d..0e9a074211d 100644 --- a/ttnn/ttnn/__init__.py +++ b/ttnn/ttnn/__init__.py @@ -99,10 +99,18 @@ def manage_config(name, value): get_device_tensors, aggregate_as_tensor, get_t3k_physical_device_ids_ring, - DefaultMeshCommandQueueId, ) -from ttnn._ttnn.events import create_event, record_event, wait_for_event +from ttnn._ttnn.events import ( + MeshEvent, + create_event, + record_event, + wait_for_event, + record_mesh_event, + wait_for_mesh_event, + record_mesh_event, + wait_for_mesh_event, +) from ttnn._ttnn.operations.trace import ( MeshTraceId, @@ -170,6 +178,8 @@ def manage_config(name, value): GrayskullComputeKernelConfig, MeshShape, MeshCoordinate, + MeshCoordinateRange, + QueueId, UnaryWithParam, UnaryOpType, BinaryOpType, diff --git a/ttnn/ttnn/types.py b/ttnn/ttnn/types.py index d8cd7380a52..a13cd72a3e4 100644 --- a/ttnn/ttnn/types.py +++ b/ttnn/ttnn/types.py @@ -66,6 +66,7 @@ class ShardStrategy(Enum): MeshShape = ttnn._ttnn.multi_device.MeshShape MeshCoordinate = ttnn._ttnn.multi_device.MeshCoordinate +MeshCoordinateRange = ttnn._ttnn.multi_device.MeshCoordinateRange ShardOrientation = ttnn._ttnn.tensor.ShardOrientation ShardMode = ttnn._ttnn.tensor.ShardMode ShardSpec = ttnn._ttnn.tensor.ShardSpec @@ -73,6 +74,7 @@ class ShardStrategy(Enum): CoreRange = ttnn._ttnn.tensor.CoreRange CoreCoord = ttnn._ttnn.tensor.CoreCoord +QueueId = ttnn._ttnn.types.QueueId UnaryWithParam = ttnn._ttnn.activation.UnaryWithParam UnaryOpType = ttnn._ttnn.activation.UnaryOpType