Skip to content

Commit

Permalink
Expose mesh event to TTNN (#18461)
Browse files Browse the repository at this point in the history
### Ticket
N/A

### Problem description
Expose mesh events to TTNN, further integration of TT-distributed with
TTNN.

### What's changed
Related cleanups bundled in this PR:
* Make `EnqueueRecordEvent`, `EnqueueRecordEventToHost` return
`MeshEvent` by value, instead of accepting `std::shared_ptr<MeshEvent>`
and mutating the reference internally.
* Make `EnqueueWaitForEvent` and `EventSynchronize` accept `MeshEvent`
by constant reference.
* Expose `MeshCoordinateRange` to TTNN - this is needed for the
`MeshEvent` APIs.

### Checklist
- [X] [All post
commit](https://github.com/tenstorrent/tt-metal/actions/runs/13577312078)
- [X] New/Existing tests provide coverage for changes - ran `MeshEvents`
tests from `distributed_unit_tests`
  • Loading branch information
omilyutin-tt authored Feb 28, 2025
1 parent b20f868 commit e05b927
Show file tree
Hide file tree
Showing 23 changed files with 238 additions and 143 deletions.
35 changes: 14 additions & 21 deletions tests/tt_metal/distributed/test_mesh_events.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,11 @@ TEST_F(MeshEventsTestSuite, ReplicatedAsyncIO) {
std::iota(src_vec.begin(), src_vec.end(), i);

std::vector<std::vector<uint32_t>> readback_vecs = {};
std::shared_ptr<MeshEvent> event = std::make_shared<MeshEvent>();
// Writes on CQ 0
EnqueueWriteMeshBuffer(mesh_device_->mesh_command_queue(0), buf, src_vec);
// Device to Device Synchronization
EnqueueRecordEvent(mesh_device_->mesh_command_queue(0), event);
EnqueueWaitForEvent(mesh_device_->mesh_command_queue(1), event);
auto write_event = EnqueueRecordEvent(mesh_device_->mesh_command_queue(0));
EnqueueWaitForEvent(mesh_device_->mesh_command_queue(1), write_event);

// Reads on CQ 1
for (const auto& coord : MeshCoordinateRange(mesh_device_->shape())) {
Expand Down Expand Up @@ -86,17 +85,16 @@ TEST_F(MeshEventsTestT3000, ShardedAsyncIO) {
std::vector<uint32_t> src_vec =
std::vector<uint32_t>(global_buffer_shape.height() * global_buffer_shape.width(), 0);
std::iota(src_vec.begin(), src_vec.end(), i);
std::shared_ptr<MeshEvent> event = std::make_shared<MeshEvent>();
// Writes on CQ 0
EnqueueWriteMeshBuffer(mesh_device_->mesh_command_queue(0), mesh_buffer, src_vec);
if (i % 2) {
// Test Host <-> Device synchronization
EnqueueRecordEventToHost(mesh_device_->mesh_command_queue(0), event);
EventSynchronize(event);
auto write_event = EnqueueRecordEventToHost(mesh_device_->mesh_command_queue(0));
EventSynchronize(write_event);
} else {
// Test Device <-> Device synchronization
EnqueueRecordEvent(mesh_device_->mesh_command_queue(0), event);
EnqueueWaitForEvent(mesh_device_->mesh_command_queue(1), event);
auto write_event = EnqueueRecordEvent(mesh_device_->mesh_command_queue(0));
EnqueueWaitForEvent(mesh_device_->mesh_command_queue(1), write_event);
}
// Reads on CQ 1
std::vector<uint32_t> dst_vec = {};
Expand Down Expand Up @@ -127,9 +125,6 @@ TEST_F(MeshEventsTestSuite, AsyncWorkloadAndIO) {
std::vector<uint32_t> src0_vec = create_constant_vector_of_bfloat16(src0_bufs[0]->size(), iter + 2);
std::vector<uint32_t> src1_vec = create_constant_vector_of_bfloat16(src1_bufs[0]->size(), iter + 3);

std::shared_ptr<MeshEvent> write_event = std::make_shared<MeshEvent>();
std::shared_ptr<MeshEvent> op_event = std::make_shared<MeshEvent>();

// Issue writes on MeshCQ 1
for (std::size_t col_idx = 0; col_idx < worker_grid_size.x; col_idx++) {
for (std::size_t row_idx = 0; row_idx < worker_grid_size.y; row_idx++) {
Expand All @@ -141,22 +136,22 @@ TEST_F(MeshEventsTestSuite, AsyncWorkloadAndIO) {
}
if (iter % 2) {
// Test Host <-> Device Synchronization
EnqueueRecordEventToHost(mesh_device_->mesh_command_queue(1), write_event);
auto write_event = EnqueueRecordEventToHost(mesh_device_->mesh_command_queue(1));
EventSynchronize(write_event);
} else {
// Test Device <-> Device Synchronization
EnqueueRecordEvent(mesh_device_->mesh_command_queue(1), write_event);
auto write_event = EnqueueRecordEvent(mesh_device_->mesh_command_queue(1));
EnqueueWaitForEvent(mesh_device_->mesh_command_queue(0), write_event);
}
// Issue workloads on MeshCQ 0
EnqueueMeshWorkload(mesh_device_->mesh_command_queue(0), mesh_workload, false);
if (iter % 2) {
// Test Device <-> Device Synchronization
EnqueueRecordEvent(mesh_device_->mesh_command_queue(0), op_event);
auto op_event = EnqueueRecordEvent(mesh_device_->mesh_command_queue(0));
EnqueueWaitForEvent(mesh_device_->mesh_command_queue(1), op_event);
} else {
// Test Host <-> Device Synchronization
EnqueueRecordEventToHost(mesh_device_->mesh_command_queue(0), op_event);
auto op_event = EnqueueRecordEventToHost(mesh_device_->mesh_command_queue(0));
EventSynchronize(op_event);
}

Expand Down Expand Up @@ -210,12 +205,10 @@ TEST_F(MeshEventsTestSuite, CustomDeviceRanges) {
MeshCoordinateRange devices_1(MeshCoordinate{1, 0}, MeshCoordinate{1, mesh_device_->num_cols() - 1});

std::vector<std::vector<uint32_t>> readback_vecs = {};
std::shared_ptr<MeshEvent> event_0 = std::make_shared<MeshEvent>();
std::shared_ptr<MeshEvent> event_1 = std::make_shared<MeshEvent>();

mesh_device_->mesh_command_queue(1).enqueue_write_shard_to_sub_grid(*buf, src_vec.data(), devices_0, false);
EnqueueRecordEvent(mesh_device_->mesh_command_queue(1), event_0, {}, devices_0);
EnqueueWaitForEvent(mesh_device_->mesh_command_queue(0), event_0);
auto event0 = EnqueueRecordEvent(mesh_device_->mesh_command_queue(1), {}, devices_0);
EnqueueWaitForEvent(mesh_device_->mesh_command_queue(0), event0);

for (const auto& coord : devices_0) {
readback_vecs.push_back({});
Expand All @@ -224,8 +217,8 @@ TEST_F(MeshEventsTestSuite, CustomDeviceRanges) {
}

mesh_device_->mesh_command_queue(1).enqueue_write_shard_to_sub_grid(*buf, src_vec.data(), devices_1, false);
EnqueueRecordEventToHost(mesh_device_->mesh_command_queue(1), event_1, {}, devices_1);
EventSynchronize(event_1);
auto event1 = EnqueueRecordEventToHost(mesh_device_->mesh_command_queue(1), {}, devices_1);
EventSynchronize(event1);

for (const auto& coord : devices_1) {
readback_vecs.push_back({});
Expand Down
10 changes: 4 additions & 6 deletions tt_metal/api/tt-metalium/distributed.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,21 +80,19 @@ void EnqueueReadMeshBuffer(
mesh_cq.enqueue_read_mesh_buffer(dst.data(), mesh_buffer, blocking);
}

void EnqueueRecordEvent(
MeshEvent EnqueueRecordEvent(
MeshCommandQueue& mesh_cq,
const std::shared_ptr<MeshEvent>& event,
tt::stl::Span<const SubDeviceId> sub_device_ids = {},
const std::optional<MeshCoordinateRange>& device_range = std::nullopt);

void EnqueueRecordEventToHost(
MeshEvent EnqueueRecordEventToHost(
MeshCommandQueue& mesh_cq,
const std::shared_ptr<MeshEvent>& event,
tt::stl::Span<const SubDeviceId> sub_device_ids = {},
const std::optional<MeshCoordinateRange>& device_range = std::nullopt);

void EnqueueWaitForEvent(MeshCommandQueue& mesh_cq, const std::shared_ptr<MeshEvent>& event);
void EnqueueWaitForEvent(MeshCommandQueue& mesh_cq, const MeshEvent& event);

void EventSynchronize(const std::shared_ptr<MeshEvent>& event);
void EventSynchronize(const MeshEvent& event);

MeshTraceId BeginTraceCapture(MeshDevice* device, uint8_t cq_id);

Expand Down
13 changes: 5 additions & 8 deletions tt_metal/api/tt-metalium/mesh_command_queue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ class MeshCommandQueue {
// Helper functions for read and write entire Sharded-MeshBuffers
void write_sharded_buffer(const MeshBuffer& buffer, const void* src);
void read_sharded_buffer(MeshBuffer& buffer, void* dst);
void enqueue_record_event_helper(
const std::shared_ptr<MeshEvent>& event,
MeshEvent enqueue_record_event_helper(
tt::stl::Span<const SubDeviceId> sub_device_ids,
bool notify_host,
const std::optional<MeshCoordinateRange>& device_range = std::nullopt);
Expand Down Expand Up @@ -156,17 +155,15 @@ class MeshCommandQueue {
const std::shared_ptr<MeshBuffer>& mesh_buffer,
bool blocking);

void enqueue_record_event(
const std::shared_ptr<MeshEvent>& event,
MeshEvent enqueue_record_event(
tt::stl::Span<const SubDeviceId> sub_device_ids = {},
const std::optional<MeshCoordinateRange>& device_range = std::nullopt);
void enqueue_record_event_to_host(
const std::shared_ptr<MeshEvent>& event,
MeshEvent enqueue_record_event_to_host(
tt::stl::Span<const SubDeviceId> sub_device_ids = {},
const std::optional<MeshCoordinateRange>& device_range = std::nullopt);
void enqueue_wait_for_event(const std::shared_ptr<MeshEvent>& sync_event);
void enqueue_wait_for_event(const MeshEvent& sync_event);
void drain_events_from_completion_queue();
void verify_reported_events_after_draining(const std::shared_ptr<MeshEvent>& event);
void verify_reported_events_after_draining(const MeshEvent& event);
void finish(tt::stl::Span<const SubDeviceId> sub_device_ids = {});
void reset_worker_state(
bool reset_launch_msg_state,
Expand Down
4 changes: 4 additions & 0 deletions tt_metal/api/tt-metalium/mesh_coord.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ class MeshCoordinateRange {
// Returns the intersection of the range with the given range.
std::optional<MeshCoordinateRange> intersection(const MeshCoordinateRange& range) const;

// Needed for reflect / fmt
static constexpr auto attribute_names = std::forward_as_tuple("start", "end");
auto attribute_values() const { return std::forward_as_tuple(start_, end_); }

class Iterator {
public:
Iterator& operator++();
Expand Down
20 changes: 16 additions & 4 deletions tt_metal/api/tt-metalium/mesh_event.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,28 @@

#pragma once

#include <cstdint>
#include "mesh_device.hpp"

namespace tt::tt_metal::distributed {

class MeshEvent {
public:
MeshDevice* device = nullptr;
MeshCoordinateRange device_range = MeshCoordinateRange(MeshCoordinate(0, 0), MeshCoordinate(0, 0));
uint32_t cq_id = 0;
uint32_t event_id = 0;
MeshEvent(uint32_t id, MeshDevice* device, uint32_t mesh_cq_id, const MeshCoordinateRange& device_range);

// Returns references to the event data.
uint32_t id() const;
MeshDevice* device() const;
uint32_t mesh_cq_id() const;
const MeshCoordinateRange& device_range() const;

friend std::ostream& operator<<(std::ostream& os, const MeshEvent& event);

private:
uint32_t id_ = 0;
MeshDevice* device_ = nullptr;
uint32_t mesh_cq_id_ = 0;
MeshCoordinateRange device_range_;
};

} // namespace tt::tt_metal::distributed
9 changes: 5 additions & 4 deletions tt_metal/distributed/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
set(DISTRIBUTED_SRC
${CMAKE_CURRENT_SOURCE_DIR}/coordinate_translation.cpp
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
${CMAKE_CURRENT_SOURCE_DIR}/system_mesh.cpp
${CMAKE_CURRENT_SOURCE_DIR}/mesh_buffer.cpp
${CMAKE_CURRENT_SOURCE_DIR}/mesh_command_queue.cpp
${CMAKE_CURRENT_SOURCE_DIR}/mesh_device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/mesh_device_view.cpp
${CMAKE_CURRENT_SOURCE_DIR}/mesh_event.cpp
${CMAKE_CURRENT_SOURCE_DIR}/mesh_trace.cpp
${CMAKE_CURRENT_SOURCE_DIR}/mesh_workload.cpp
${CMAKE_CURRENT_SOURCE_DIR}/mesh_workload_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/mesh_command_queue.cpp
${CMAKE_CURRENT_SOURCE_DIR}/mesh_buffer.cpp
${CMAKE_CURRENT_SOURCE_DIR}/mesh_trace.cpp
${CMAKE_CURRENT_SOURCE_DIR}/system_mesh.cpp
)

add_library(distributed OBJECT ${DISTRIBUTED_SRC})
Expand Down
18 changes: 7 additions & 11 deletions tt_metal/distributed/distributed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,28 +19,24 @@ void EnqueueMeshWorkload(MeshCommandQueue& mesh_cq, MeshWorkload& mesh_workload,
mesh_cq.enqueue_mesh_workload(mesh_workload, blocking);
}

void EnqueueRecordEvent(
MeshEvent EnqueueRecordEvent(
MeshCommandQueue& mesh_cq,
const std::shared_ptr<MeshEvent>& event,
tt::stl::Span<const SubDeviceId> sub_device_ids,
const std::optional<MeshCoordinateRange>& device_range) {
mesh_cq.enqueue_record_event(event, sub_device_ids, device_range);
return mesh_cq.enqueue_record_event(sub_device_ids, device_range);
}

void EnqueueRecordEventToHost(
MeshEvent EnqueueRecordEventToHost(
MeshCommandQueue& mesh_cq,
const std::shared_ptr<MeshEvent>& event,
tt::stl::Span<const SubDeviceId> sub_device_ids,
const std::optional<MeshCoordinateRange>& device_range) {
mesh_cq.enqueue_record_event_to_host(event, sub_device_ids, device_range);
return mesh_cq.enqueue_record_event_to_host(sub_device_ids, device_range);
}

void EnqueueWaitForEvent(MeshCommandQueue& mesh_cq, const std::shared_ptr<MeshEvent>& event) {
mesh_cq.enqueue_wait_for_event(event);
}
void EnqueueWaitForEvent(MeshCommandQueue& mesh_cq, const MeshEvent& event) { mesh_cq.enqueue_wait_for_event(event); }

void EventSynchronize(const std::shared_ptr<MeshEvent>& event) {
auto& mesh_cq = event->device->mesh_command_queue(event->cq_id);
void EventSynchronize(const MeshEvent& event) {
auto& mesh_cq = event.device()->mesh_command_queue(event.mesh_cq_id());
mesh_cq.drain_events_from_completion_queue();
mesh_cq.verify_reported_events_after_draining(event);
}
Expand Down
55 changes: 27 additions & 28 deletions tt_metal/distributed/mesh_command_queue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,7 @@ void MeshCommandQueue::enqueue_mesh_workload(MeshWorkload& mesh_workload, bool b
}

void MeshCommandQueue::finish(tt::stl::Span<const SubDeviceId> sub_device_ids) {
std::shared_ptr<MeshEvent> event = std::make_shared<MeshEvent>();
this->enqueue_record_event_to_host(event, sub_device_ids);
auto event = this->enqueue_record_event_to_host(sub_device_ids);
this->drain_events_from_completion_queue();
this->verify_reported_events_after_draining(event);
}
Expand Down Expand Up @@ -478,51 +477,50 @@ void MeshCommandQueue::enqueue_read_shards(
}
}

void MeshCommandQueue::enqueue_record_event_helper(
const std::shared_ptr<MeshEvent>& event,
MeshEvent MeshCommandQueue::enqueue_record_event_helper(
tt::stl::Span<const SubDeviceId> sub_device_ids,
bool notify_host,
const std::optional<MeshCoordinateRange>& device_range) {
auto& sysmem_manager = this->reference_sysmem_manager();
event->cq_id = id_;
event->event_id = sysmem_manager.get_next_event(id_);
event->device = mesh_device_;
event->device_range = device_range.value_or(MeshCoordinateRange(mesh_device_->shape()));
auto event = MeshEvent(
sysmem_manager.get_next_event(id_),
mesh_device_,
id_,
device_range.value_or(MeshCoordinateRange(mesh_device_->shape())));

sub_device_ids = buffer_dispatch::select_sub_device_ids(mesh_device_, sub_device_ids);
for (const auto& coord : event->device_range) {
for (const auto& coord : event.device_range()) {
event_dispatch::issue_record_event_commands(
mesh_device_,
event->event_id,
event.id(),
id_,
mesh_device_->num_hw_cqs(),
mesh_device_->get_device(coord)->sysmem_manager(),
sub_device_ids,
expected_num_workers_completed_,
notify_host);
}

return event;
}

void MeshCommandQueue::enqueue_record_event(
const std::shared_ptr<MeshEvent>& event,
tt::stl::Span<const SubDeviceId> sub_device_ids,
const std::optional<MeshCoordinateRange>& device_range) {
this->enqueue_record_event_helper(event, sub_device_ids, false, device_range);
MeshEvent MeshCommandQueue::enqueue_record_event(
tt::stl::Span<const SubDeviceId> sub_device_ids, const std::optional<MeshCoordinateRange>& device_range) {
return this->enqueue_record_event_helper(sub_device_ids, /*notify_host=*/false, device_range);
}

void MeshCommandQueue::enqueue_record_event_to_host(
const std::shared_ptr<MeshEvent>& event,
tt::stl::Span<const SubDeviceId> sub_device_ids,
const std::optional<MeshCoordinateRange>& device_range) {
this->enqueue_record_event_helper(event, sub_device_ids, true, device_range);
MeshEvent MeshCommandQueue::enqueue_record_event_to_host(
tt::stl::Span<const SubDeviceId> sub_device_ids, const std::optional<MeshCoordinateRange>& device_range) {
auto event = this->enqueue_record_event_helper(sub_device_ids, /*notify_host=*/true, device_range);
event_descriptors_.push(std::make_shared<MeshReadEventDescriptor>(MeshReadEventDescriptor{
.single_device_descriptor = ReadEventDescriptor(event->event_id), .device_range = event->device_range}));
.single_device_descriptor = ReadEventDescriptor(event.id()), .device_range = event.device_range()}));
return event;
}

void MeshCommandQueue::enqueue_wait_for_event(const std::shared_ptr<MeshEvent>& sync_event) {
for (const auto& coord : sync_event->device_range) {
void MeshCommandQueue::enqueue_wait_for_event(const MeshEvent& sync_event) {
for (const auto& coord : sync_event.device_range()) {
event_dispatch::issue_wait_for_event_commands(
id_, sync_event->cq_id, mesh_device_->get_device(coord)->sysmem_manager(), sync_event->event_id);
id_, sync_event.mesh_cq_id(), mesh_device_->get_device(coord)->sysmem_manager(), sync_event.id());
}
}

Expand All @@ -546,13 +544,14 @@ void MeshCommandQueue::drain_events_from_completion_queue() {
}
}

void MeshCommandQueue::verify_reported_events_after_draining(const std::shared_ptr<MeshEvent>& event) {
auto& device_range = event->device_range;
void MeshCommandQueue::verify_reported_events_after_draining(const MeshEvent& event) {
auto& device_range = event.device_range();
for (const auto& coord : device_range) {
TT_FATAL(
mesh_device_->get_device(coord)->sysmem_manager().get_last_completed_event(event->cq_id) >= event->event_id,
mesh_device_->get_device(coord)->sysmem_manager().get_last_completed_event(event.mesh_cq_id()) >=
event.id(),
"Expected to see event id {} in completion queue",
event->event_id);
event.id());
}
}

Expand Down
23 changes: 23 additions & 0 deletions tt_metal/distributed/mesh_event.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <mesh_event.hpp>

namespace tt::tt_metal::distributed {

MeshEvent::MeshEvent(uint32_t id, MeshDevice* device, uint32_t mesh_cq_id, const MeshCoordinateRange& device_range) :
id_(id), device_(device), mesh_cq_id_(mesh_cq_id), device_range_(device_range) {}

uint32_t MeshEvent::id() const { return id_; }
MeshDevice* MeshEvent::device() const { return device_; }
uint32_t MeshEvent::mesh_cq_id() const { return mesh_cq_id_; }
const MeshCoordinateRange& MeshEvent::device_range() const { return device_range_; }

std::ostream& operator<<(std::ostream& os, const MeshEvent& event) {
os << "MeshEvent(id=" << event.id() << ", device_id=" << event.device()->id()
<< ", mesh_cq_id=" << event.mesh_cq_id() << ", device_range=" << event.device_range() << ")";
return os;
}

} // namespace tt::tt_metal::distributed
2 changes: 1 addition & 1 deletion tt_metal/impl/dispatch/hardware_command_queue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ void HWCommandQueue::enqueue_record_event(
event->cq_id = this->id_;
event->event_id = this->manager.get_next_event(this->id_);
event->device = this->device_;
event->ready = true; // what does this mean???
event->ready = true;

sub_device_ids = buffer_dispatch::select_sub_device_ids(this->device_, sub_device_ids);
event_dispatch::issue_record_event_commands(
Expand Down
Loading

0 comments on commit e05b927

Please sign in to comment.