Skip to content

Commit

Permalink
#0: Use StrongType for SubDevice ID types (#18366)
Browse files Browse the repository at this point in the history
### Ticket
N/A

### Problem description
SubDevice ID types duplicate what StrongType is supposed to provide.
There are 2 problems:
1. Default-constructing either `SubDeviceId` or `SubDeviceManagerId`
leaves its `id` member uninitialized, and reading that value is
[UB](https://github.com/tenstorrent/tt-metal/blob/4910164b2d7860b810121c02f5d1f42b345c1f39/contributing/BestPractices.md#15-initialize-primitive-types-on-declaration).
2. Adding or subtracting IDs is not semantically valid. In the few places
where IDs really are enumerated linearly, an explicit conversion to/from
the underlying type should be used so that the intent is visible.

### What's changed
Adopt StrongType for the SubDevice ID types.

### Checklist
- [X] [All post
commit](https://github.com/tenstorrent/tt-metal/actions/runs/13551182295),
[compiler error
fixed](https://github.com/tenstorrent/tt-metal/actions/runs/13572261235)
- [X] New/Existing tests provide coverage for changes
  • Loading branch information
omilyutin-tt authored Feb 27, 2025
1 parent 0548b09 commit 9b3d855
Show file tree
Hide file tree
Showing 18 changed files with 102 additions and 165 deletions.
8 changes: 7 additions & 1 deletion tt_metal/api/tt-metalium/strong_type.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@

#pragma once

#include <utility>
#include <ostream>
#include <tuple>
#include <utility>

namespace tt::stl {

Expand Down Expand Up @@ -55,6 +56,8 @@ namespace tt::stl {
template <typename T, typename Tag>
class StrongType {
public:
using value_type = T;

constexpr explicit StrongType(T v) : value_(std::move(v)) {}

StrongType(const StrongType&) = default;
Expand All @@ -66,6 +69,9 @@ class StrongType {

auto operator<=>(const StrongType&) const = default;

static constexpr auto attribute_names = std::forward_as_tuple("value");
auto attribute_values() const { return std::forward_as_tuple(value_); }

private:
T value_;
};
Expand Down
2 changes: 1 addition & 1 deletion tt_metal/api/tt-metalium/sub_device_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class SubDeviceManager {
public:
static constexpr uint32_t MAX_NUM_SUB_DEVICES = 16;
// Every valid sub-device index must be representable in the integer type that
// backs SubDeviceId. The message names `value_type`, the alias StrongType
// exposes, rather than the removed `SubDeviceId::Id`.
static_assert(
    MAX_NUM_SUB_DEVICES <= std::numeric_limits<SubDeviceId::value_type>::max(),
    "MAX_NUM_SUB_DEVICES must be less than or equal to the max value of SubDeviceId::value_type");
// Constructor used for the default/global device
SubDeviceManager(
Expand Down
84 changes: 4 additions & 80 deletions tt_metal/api/tt-metalium/sub_device_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,88 +5,12 @@
#pragma once

#include <cstdint>
#include <functional>
#include <tuple>
#include <type_traits>

namespace tt::tt_metal {

struct SubDeviceId {
using Id = uint8_t;
Id id;

Id to_index() const { return id; }

SubDeviceId& operator++() {
id++;
return *this;
}

SubDeviceId operator++(int) {
auto ret = *this;
this->operator++();
return ret;
}

SubDeviceId& operator+=(Id n) {
id += n;
return *this;
}

bool operator<(size_t other) const { return id < other; }
bool operator==(const SubDeviceId& other) const { return id == other.id; }

bool operator!=(const SubDeviceId& other) const { return id != other.id; }

static constexpr auto attribute_names = std::forward_as_tuple("id");
constexpr auto attribute_values() const { return std::forward_as_tuple(this->id); }
};

struct SubDeviceManagerId {
using Id = uint64_t;
Id id;
#include "strong_type.hpp"

Id to_index() const { return id; }

SubDeviceManagerId& operator++() {
id++;
return *this;
}

SubDeviceManagerId operator++(int) {
auto ret = *this;
this->operator++();
return ret;
}

SubDeviceManagerId& operator+=(Id n) {
id += n;
return *this;
}

bool operator==(const SubDeviceManagerId& other) const { return id == other.id; }

bool operator!=(const SubDeviceManagerId& other) const { return id != other.id; }
namespace tt::tt_metal {

static constexpr auto attribute_names = std::forward_as_tuple("id");
constexpr auto attribute_values() const { return std::forward_as_tuple(this->id); }
};
using SubDeviceId = tt::stl::StrongType<uint8_t, struct SubDeviceIdTag>;
using SubDeviceManagerId = tt::stl::StrongType<uint64_t, struct SubDeviceManagerIdTag>;

} // namespace tt::tt_metal

namespace std {
template <>
struct hash<tt::tt_metal::SubDeviceId> {
std::size_t operator()(tt::tt_metal::SubDeviceId const& o) const {
return std::hash<decltype(tt::tt_metal::SubDeviceId::id)>{}(o.to_index());
}
};

template <>
struct hash<tt::tt_metal::SubDeviceManagerId> {
std::size_t operator()(tt::tt_metal::SubDeviceManagerId const& o) const {
return std::hash<decltype(tt::tt_metal::SubDeviceManagerId::id)>{}(o.to_index());
}
};

} // namespace std
17 changes: 8 additions & 9 deletions tt_metal/distributed/mesh_command_queue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,7 @@ CoreType MeshCommandQueue::dispatch_core_type() const { return this->dispatch_co
void MeshCommandQueue::enqueue_mesh_workload(MeshWorkload& mesh_workload, bool blocking) {
std::unordered_set<SubDeviceId> sub_device_ids = mesh_workload.determine_sub_device_ids(mesh_device_);
TT_FATAL(sub_device_ids.size() == 1, "Programs must be executed on a single sub-device");
auto sub_device_id = *(sub_device_ids.begin());
auto sub_device_index = sub_device_id.to_index();
SubDeviceId sub_device_id = *(sub_device_ids.begin());
auto mesh_device_id = this->mesh_device_->id();
auto& sysmem_manager = this->reference_sysmem_manager();
auto dispatch_core_config = DispatchQueryManager::instance().get_dispatch_core_config();
Expand All @@ -96,10 +95,10 @@ void MeshCommandQueue::enqueue_mesh_workload(MeshWorkload& mesh_workload, bool b
program_dispatch::ProgramDispatchMetadata dispatch_metadata;
uint32_t expected_num_workers_completed = sysmem_manager.get_bypass_mode()
? trace_ctx_->descriptors[sub_device_id].num_completion_worker_cores
: expected_num_workers_completed_[sub_device_index];
: expected_num_workers_completed_[*sub_device_id];
// Reserve space in the L1 Kernel Config Ring Buffer for this workload.
program_dispatch::reserve_space_in_kernel_config_buffer(
this->get_config_buffer_mgr(sub_device_index),
this->get_config_buffer_mgr(*sub_device_id),
mesh_workload.get_program_config_sizes(),
mesh_workload.get_program_binary_status(mesh_device_id),
num_workers,
Expand All @@ -116,8 +115,8 @@ void MeshCommandQueue::enqueue_mesh_workload(MeshWorkload& mesh_workload, bool b
program_dispatch::update_program_dispatch_commands(
program,
program_cmd_seq,
sysmem_manager.get_worker_launch_message_buffer_state()[sub_device_index].get_mcast_wptr(),
sysmem_manager.get_worker_launch_message_buffer_state()[sub_device_index].get_unicast_wptr(),
sysmem_manager.get_worker_launch_message_buffer_state()[*sub_device_id].get_mcast_wptr(),
sysmem_manager.get_worker_launch_message_buffer_state()[*sub_device_id].get_unicast_wptr(),
expected_num_workers_completed,
this->virtual_program_dispatch_core(),
dispatch_core_type,
Expand Down Expand Up @@ -160,10 +159,10 @@ void MeshCommandQueue::enqueue_mesh_workload(MeshWorkload& mesh_workload, bool b
}
// Increment Launch Message Buffer Write Pointers
if (mcast_go_signals) {
sysmem_manager.get_worker_launch_message_buffer_state()[sub_device_index].inc_mcast_wptr(1);
sysmem_manager.get_worker_launch_message_buffer_state()[*sub_device_id].inc_mcast_wptr(1);
}
if (unicast_go_signals) {
sysmem_manager.get_worker_launch_message_buffer_state()[sub_device_index].inc_unicast_wptr(1);
sysmem_manager.get_worker_launch_message_buffer_state()[*sub_device_id].inc_unicast_wptr(1);
}

if (sysmem_manager.get_bypass_mode()) {
Expand All @@ -176,7 +175,7 @@ void MeshCommandQueue::enqueue_mesh_workload(MeshWorkload& mesh_workload, bool b
// Update the expected number of workers dispatch must wait on
trace_ctx_->descriptors[sub_device_id].num_completion_worker_cores += num_workers;
} else {
expected_num_workers_completed_[sub_device_index] += num_workers;
expected_num_workers_completed_[*sub_device_id] += num_workers;
}
// From the dispatcher's perspective, binaries are now committed to DRAM
mesh_workload.set_program_binary_status(mesh_device_id, ProgramBinaryStatus::Committed);
Expand Down
8 changes: 5 additions & 3 deletions tt_metal/distributed/mesh_device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -785,7 +785,7 @@ MeshSubDeviceManagerId MeshDevice::mesh_create_sub_device_manager(

std::tuple<MeshSubDeviceManagerId, SubDeviceId> MeshDevice::mesh_create_sub_device_manager_with_fabric(tt::stl::Span<const SubDevice> sub_devices, DeviceAddr local_l1_size) {
MeshSubDeviceManagerId mesh_sub_device_manager_id(*this);
SubDeviceId fabric_sub_device_id;
SubDeviceId fabric_sub_device_id(0);
const auto& devices = scoped_devices_->root_devices();
for (uint32_t i = 0; i < devices.size(); i++) {
auto* device = devices[i];
Expand Down Expand Up @@ -838,8 +838,10 @@ void MeshDevice::mesh_reset_sub_device_stall_group() {
}

MeshSubDeviceManagerId::MeshSubDeviceManagerId(const MeshDevice& mesh_device) {
this->sub_device_manager_ids.resize(mesh_device.num_devices());
this->sub_device_manager_ids.reserve(mesh_device.num_devices());
for (uint32_t i = 0; i < mesh_device.num_devices(); i++) {
this->sub_device_manager_ids.push_back(SubDeviceManagerId(0));
}
}


} // namespace tt::tt_metal::distributed
2 changes: 1 addition & 1 deletion tt_metal/distributed/mesh_workload_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ void write_go_signal(

auto dispatch_core_config = DispatchQueryManager::instance().get_dispatch_core_config();
CoreType dispatch_core_type = dispatch_core_config.get_core_type();
auto sub_device_index = sub_device_id.to_index();
auto sub_device_index = *sub_device_id;

HugepageDeviceCommand go_signal_cmd_sequence(cmd_region, cmd_sequence_sizeB);
go_msg_t run_program_go_signal;
Expand Down
11 changes: 4 additions & 7 deletions tt_metal/impl/buffers/dispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ void issue_buffer_dispatch_command_sequence(
DispatchMemMap::get(dispatch_core_type)
.get_device_command_queue_addr(CommandQueueDeviceAddrType::DISPATCH_MESSAGE);
for (const auto& sub_device_id : sub_device_ids) {
auto offset_index = sub_device_id.to_index();
auto offset_index = *sub_device_id;
uint32_t dispatch_message_addr =
dispatch_message_base_addr +
DispatchMemMap::get(dispatch_core_type).get_dispatch_message_offset(offset_index);
Expand Down Expand Up @@ -640,14 +640,14 @@ void issue_read_buffer_dispatch_command_sequence(
uint32_t last_index = num_worker_counters - 1;
// We only need the write barrier + prefetch stall for the last wait cmd
for (uint32_t i = 0; i < last_index; ++i) {
auto offset_index = sub_device_ids[i].to_index();
auto offset_index = *sub_device_ids[i];
uint32_t dispatch_message_addr =
dispatch_message_base_addr +
DispatchMemMap::get(dispatch_core_type).get_dispatch_message_offset(offset_index);
command_sequence.add_dispatch_wait(
false, dispatch_message_addr, dispatch_params.expected_num_workers_completed[offset_index]);
}
auto offset_index = sub_device_ids[last_index].to_index();
auto offset_index = *sub_device_ids[last_index];
uint32_t dispatch_message_addr =
dispatch_message_base_addr + DispatchMemMap::get(dispatch_core_type).get_dispatch_message_offset(offset_index);
command_sequence.add_dispatch_wait_with_prefetch_stall(
Expand Down Expand Up @@ -997,10 +997,7 @@ tt::stl::Span<const SubDeviceId> select_sub_device_ids(
return device->get_sub_device_stall_group();
} else {
for (const auto& sub_device_id : sub_device_ids) {
TT_FATAL(
sub_device_id.to_index() < device->num_sub_devices(),
"Invalid sub-device id specified {}",
sub_device_id.to_index());
TT_FATAL(*sub_device_id < device->num_sub_devices(), "Invalid sub-device id specified {}", *sub_device_id);
}
return sub_device_ids;
}
Expand Down
7 changes: 3 additions & 4 deletions tt_metal/impl/dispatch/hardware_command_queue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ void HWCommandQueue::enqueue_program(Program& program, bool blocking) {
}
#endif
auto sub_device_id = sub_device_ids[0];
auto sub_device_index = sub_device_id.to_index();
auto sub_device_index = *sub_device_id;

// Snapshot of expected workers from previous programs, used for dispatch_wait cmd generation.
uint32_t expected_workers_completed = this->manager.get_bypass_mode()
Expand Down Expand Up @@ -325,8 +325,7 @@ void HWCommandQueue::enqueue_program(Program& program, bool blocking) {
}
}

auto& worker_launch_message_buffer_state =
this->manager.get_worker_launch_message_buffer_state()[sub_device_id.to_index()];
auto& worker_launch_message_buffer_state = this->manager.get_worker_launch_message_buffer_state()[*sub_device_id];
auto command = EnqueueProgramCommand(
this->id_,
this->device_,
Expand Down Expand Up @@ -560,7 +559,7 @@ void HWCommandQueue::record_end() {
// separately
this->trace_ctx->sub_device_ids.reserve(this->trace_ctx->descriptors.size());
for (const auto& [id, _] : this->trace_ctx->descriptors) {
auto index = id.to_index();
auto index = *id;
this->trace_ctx->sub_device_ids.push_back(id);
}
this->tid_ = std::nullopt;
Expand Down
2 changes: 1 addition & 1 deletion tt_metal/impl/dispatch/host_runtime_commands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ EnqueueProgramCommand::EnqueueProgramCommand(
this->dispatch_message_addr =
DispatchMemMap::get(this->dispatch_core_type)
.get_device_command_queue_addr(CommandQueueDeviceAddrType::DISPATCH_MESSAGE) +
DispatchMemMap::get(this->dispatch_core_type).get_dispatch_message_offset(this->sub_device_id.to_index());
DispatchMemMap::get(this->dispatch_core_type).get_dispatch_message_offset(*this->sub_device_id);
}

void EnqueueProgramCommand::process() {
Expand Down
2 changes: 1 addition & 1 deletion tt_metal/impl/event/dispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ void issue_record_event_commands(

uint32_t last_index = num_worker_counters - 1;
for (uint32_t i = 0; i < num_worker_counters; ++i) {
auto offset_index = sub_device_ids[i].to_index();
auto offset_index = *sub_device_ids[i];
uint32_t dispatch_message_addr =
dispatch_message_base_addr +
DispatchMemMap::get(dispatch_core_type).get_dispatch_message_offset(offset_index);
Expand Down
8 changes: 6 additions & 2 deletions tt_metal/impl/flatbuffer/program_types_from_flatbuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,14 @@ EthernetConfig from_flatbuffer(const flatbuffer::EthernetConfig* fb_config) {
}

std::vector<SubDeviceId> from_flatbuffer(const flatbuffers::Vector<uint8_t>* fb_sub_device_ids) {
std::vector<SubDeviceId> sub_device_ids(fb_sub_device_ids ? fb_sub_device_ids->size() : 0);
if (!fb_sub_device_ids) {
return {};
}

std::vector<SubDeviceId> sub_device_ids;
sub_device_ids.reserve(fb_sub_device_ids->size());
// Bound the loop by the flatbuffer vector, not by `sub_device_ids`: after
// reserve() the destination is still empty (size() == 0), so using its size
// would skip every element and the function would always return an empty vector.
for (size_t i = 0; i < fb_sub_device_ids->size(); ++i) {
    sub_device_ids.push_back(SubDeviceId{(*fb_sub_device_ids)[i]});
}

return sub_device_ids;
Expand Down
2 changes: 1 addition & 1 deletion tt_metal/impl/flatbuffer/program_types_to_flatbuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ flatbuffers::Offset<flatbuffers::Vector<uint8_t>> to_flatbuffer(
flatbuffers::FlatBufferBuilder& builder, tt::stl::Span<const SubDeviceId> sub_device_ids) {
std::vector<uint8_t> fb_sub_device_ids(sub_device_ids.size());
for (size_t i = 0; i < sub_device_ids.size(); ++i) {
fb_sub_device_ids[i] = sub_device_ids[i].id;
fb_sub_device_ids[i] = *sub_device_ids[i];
}
return builder.CreateVector(fb_sub_device_ids);
}
Expand Down
3 changes: 1 addition & 2 deletions tt_metal/impl/lightmetal/host_api_capture_helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,7 @@ void CaptureBufferCreate(
// and one without, so commonize via single capture function and schema and treat it as optional.
auto address_offset = address.has_value() ? flatbuffer::CreateUint32Optional(fbb, address.value()) : 0;
auto bottom_up_offset = bottom_up.has_value() ? flatbuffer::CreateBoolOptional(fbb, bottom_up.value()) : 0;
auto sub_device_id_offset =
sub_device_id.has_value() ? flatbuffer::CreateUint8Optional(fbb, sub_device_id.value().id) : 0;
auto sub_device_id_offset = sub_device_id.has_value() ? flatbuffer::CreateUint8Optional(fbb, **sub_device_id) : 0;
auto shard_parameters_offset = to_flatbuffer(shard_parameters, fbb);

auto cmd = tt::tt_metal::flatbuffer::CreateBufferCreateCommand(
Expand Down
4 changes: 2 additions & 2 deletions tt_metal/impl/lightmetal/lightmetal_capture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ TraceDescriptorByTraceIdOffset to_flatbuffer(
descriptor.num_traced_programs_needing_go_signal_unicast);
auto mapping_offset = tt::tt_metal::flatbuffer::CreateSubDeviceDescriptorMapping(
builder,
sub_device_id.to_index(), // No need for static_cast; directly use uint8_t
*sub_device_id, // No need for static_cast; directly use uint8_t
descriptor_offset);
sub_device_descriptor_offsets.push_back(mapping_offset);
}
Expand All @@ -218,7 +218,7 @@ TraceDescriptorByTraceIdOffset to_flatbuffer(
std::vector<uint8_t> sub_device_ids_converted;
sub_device_ids_converted.reserve(trace_desc.sub_device_ids.size());
for (const auto& sub_device_id : trace_desc.sub_device_ids) {
sub_device_ids_converted.push_back(sub_device_id.to_index());
sub_device_ids_converted.push_back(*sub_device_id);
}
auto sub_device_ids_offset = builder.CreateVector(sub_device_ids_converted);

Expand Down
Loading

0 comments on commit 9b3d855

Please sign in to comment.