diff --git a/tests/tt_metal/distributed/test_mesh_coord.cpp b/tests/tt_metal/distributed/test_mesh_coord.cpp index 09853a488a0..9c364c735b4 100644 --- a/tests/tt_metal/distributed/test_mesh_coord.cpp +++ b/tests/tt_metal/distributed/test_mesh_coord.cpp @@ -4,6 +4,7 @@ #include #include +#include #include "mesh_coord.hpp" @@ -11,7 +12,7 @@ namespace tt::tt_metal::distributed { namespace { using ::testing::ElementsAre; - +using ::testing::UnorderedElementsAre; TEST(SimpleMeshShapeTest, Construction) { SimpleMeshShape shape_1d(3); EXPECT_EQ(shape_1d.dims(), 1); @@ -100,6 +101,21 @@ TEST(MeshCoordinateTest, Comparison) { EXPECT_NE(coord1, MeshCoordinate(1, 2, 1)); } +TEST(MeshCoordinateTest, UnorderedSet) { + std::unordered_set set; + set.insert(MeshCoordinate(0, 0, 0)); + set.insert(MeshCoordinate(0, 0, 1)); + set.insert(MeshCoordinate(0, 0, 2)); + + EXPECT_FALSE(set.insert(MeshCoordinate(0, 0, 2)).second); + EXPECT_THAT( + set, + UnorderedElementsAre( + MeshCoordinate(0, 0, 0), // + MeshCoordinate(0, 0, 1), + MeshCoordinate(0, 0, 2))); +} + TEST(MeshCoordinateRangeTest, FromShape) { SimpleMeshShape shape(2, 3); MeshCoordinateRange range(shape); @@ -232,6 +248,7 @@ TEST(MeshContainerTest, ElementAccessRowMajor) { MeshCoordinate(1, 1), MeshCoordinate(1, 2))); EXPECT_THAT(values, ElementsAre(0, 1, 2, 3, 4, 5)); + EXPECT_THAT(container.values(), ElementsAre(0, 1, 2, 3, 4, 5)); } TEST(MeshContainerTest, ConstContainer) { @@ -254,6 +271,7 @@ TEST(MeshContainerTest, ConstContainer) { MeshCoordinate(1, 1), MeshCoordinate(1, 2))); EXPECT_THAT(values, ElementsAre(0, 0, 0, 0, 0, 0)); + EXPECT_THAT(container.values(), ElementsAre(0, 0, 0, 0, 0, 0)); } TEST(MeshContainerTest, MutateThroughProxy) { @@ -276,6 +294,7 @@ TEST(MeshContainerTest, MutateThroughProxy) { values.push_back(value); } EXPECT_THAT(values, ElementsAre(0, 1, 2, 3, 4, 5)); + EXPECT_THAT(container.values(), ElementsAre(0, 1, 2, 3, 4, 5)); } TEST(MeshContainerTest, OutOfBounds) { diff --git a/tests/ttnn/distributed/test_distributed.cpp b/tests/ttnn/distributed/test_distributed.cpp index cb4d22448c5..f6e4cf7d5da 100644 --- a/tests/ttnn/distributed/test_distributed.cpp +++ b/tests/ttnn/distributed/test_distributed.cpp @@ -4,7 +4,6 @@ #include -#include #include #include @@ -19,11 +18,16 @@ class DistributedTest : public ::testing::Test { TEST_F(DistributedTest, TestSystemMeshTearDownWithoutClose) { auto& sys = SystemMesh::instance(); auto mesh = ttnn::distributed::open_mesh_device( - {2, 4}, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1, tt::tt_metal::DispatchCoreType::WORKER); - - auto [rows, cols] = sys.get_shape(); - EXPECT_GT(rows, 0); - EXPECT_GT(cols, 0); + /*mesh_shape=*/{2, 4}, + DEFAULT_L1_SMALL_SIZE, + DEFAULT_TRACE_REGION_SIZE, + 1, + tt::tt_metal::DispatchCoreType::WORKER); + + const auto system_shape = sys.get_shape(); + ASSERT_EQ(system_shape.dims(), 2); + EXPECT_EQ(system_shape[0], 2); + EXPECT_EQ(system_shape[1], 4); } TEST_F(DistributedTest, TestMemoryAllocationStatistics) { diff --git a/tests/ttnn/distributed/test_distributed_atexit.cpp b/tests/ttnn/distributed/test_distributed_atexit.cpp index 283076076b2..6d4461f7386 100644 --- a/tests/ttnn/distributed/test_distributed_atexit.cpp +++ b/tests/ttnn/distributed/test_distributed_atexit.cpp @@ -18,9 +18,10 @@ TEST(DistributedTestStandalone, TestSystemMeshTearDownWithoutClose) { mesh = ttnn::distributed::open_mesh_device( {2, 4}, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1, tt::tt_metal::DispatchCoreType::WORKER); - auto [rows, cols] = sys.get_shape(); - EXPECT_GT(rows, 0); - EXPECT_GT(cols, 0); + const auto system_shape = sys.get_shape(); + ASSERT_EQ(system_shape.dims(), 2); + EXPECT_EQ(system_shape[0], 2); + EXPECT_EQ(system_shape[1], 4); } } // namespace ttnn::distributed::test diff --git a/tt_metal/api/tt-metalium/mesh_coord.hpp b/tt_metal/api/tt-metalium/mesh_coord.hpp index e346ce2ca83..5160bdb745f 100644 --- a/tt_metal/api/tt-metalium/mesh_coord.hpp +++ b/tt_metal/api/tt-metalium/mesh_coord.hpp @@ -9,6 +9,7 @@ #include #include "shape_base.hpp" +#include "utils.hpp" namespace tt::tt_metal::distributed { @@ -21,7 +22,7 @@ class SimpleMeshShape : public ShapeBase { using ShapeBase::operator[]; // Shorthands for constructing 1D, 2D and 3D shapes. - SimpleMeshShape(uint32_t x); + explicit SimpleMeshShape(uint32_t x); SimpleMeshShape(uint32_t x, uint32_t y); SimpleMeshShape(uint32_t x, uint32_t y, uint32_t z); @@ -56,7 +57,7 @@ class SimpleMeshShape : public ShapeBase { class MeshCoordinate { public: // Shorthands for constructing 1D, 2D and 3D coordinates. - MeshCoordinate(uint32_t x); + explicit MeshCoordinate(uint32_t x); MeshCoordinate(uint32_t x, uint32_t y); MeshCoordinate(uint32_t x, uint32_t y, uint32_t z); @@ -199,7 +200,10 @@ class MeshContainer { using ValueProxy = detail::MeshCoordinateValueProxy; Iterator& operator++(); - ValueProxy& operator*(); + ValueProxy& operator*() { return value_proxy_; } + const ValueProxy& operator*() const { return value_proxy_; } + ValueProxy* operator->() { return &value_proxy_; } + const ValueProxy* operator->() const { return &value_proxy_; } bool operator==(const Iterator& other) const; bool operator!=(const Iterator& other) const; @@ -220,7 +224,8 @@ class MeshContainer { using ValueProxy = detail::MeshCoordinateValueProxy; ConstIterator& operator++(); - const ValueProxy& operator*() const; + const ValueProxy& operator*() const { return value_proxy_; } + const ValueProxy* operator->() const { return &value_proxy_; } bool operator==(const ConstIterator& other) const; bool operator!=(const ConstIterator& other) const; @@ -237,11 +242,16 @@ class MeshContainer { ValueProxy value_proxy_; }; + // Iterators provide a reference to the value along with the coordinate. Iterator begin(); Iterator end(); ConstIterator begin() const; ConstIterator end() const; + // View of the flat container of values. + std::vector& values() { return values_; } + const std::vector& values() const { return values_; } + private: SimpleMeshShape shape_; MeshCoordinateRange coord_range_; @@ -283,11 +293,6 @@ typename MeshContainer::Iterator& MeshContainer::Iterator::operator++() { return *this; } -template -typename MeshContainer::Iterator::ValueProxy& MeshContainer::Iterator::operator*() { - return value_proxy_; -} - template MeshContainer::ConstIterator::ConstIterator( const MeshContainer* container, const MeshCoordinateRange::Iterator& coord_iter, size_t linear_index) : @@ -304,11 +309,6 @@ typename MeshContainer::ConstIterator& MeshContainer::ConstIterator::opera return *this; } -template -const typename MeshContainer::ConstIterator::ValueProxy& MeshContainer::ConstIterator::operator*() const { - return value_proxy_; -} - template bool MeshContainer::Iterator::operator==(const Iterator& other) const { return container_ == other.container_ && coord_iter_ == other.coord_iter_ && linear_index_ == other.linear_index_; @@ -367,4 +367,15 @@ struct tuple_element<1, tt::tt_metal::distributed::detail::MeshCoordinateValuePr using type = T; }; +template <> +struct hash { + size_t operator()(const tt::tt_metal::distributed::MeshCoordinate& coord) const noexcept { + size_t seed = 0; + for (const auto coord_value : coord.coords()) { + tt::utils::hash_combine(seed, coord_value); + } + return seed; + } +}; + } // namespace std diff --git a/tt_metal/api/tt-metalium/mesh_device.hpp b/tt_metal/api/tt-metalium/mesh_device.hpp index 979e603a6cd..1ff63629b16 100644 --- a/tt_metal/api/tt-metalium/mesh_device.hpp +++ b/tt_metal/api/tt-metalium/mesh_device.hpp @@ -33,7 +33,7 @@ class MeshDevice : public IDevice, public std::enable_shared_from_this opened_devices_; - std::vector devices_; + MeshContainer devices_; public: // Constructor acquires physical resources @@ -50,6 +50,7 @@ class MeshDevice : public IDevice, public std::enable_shared_from_this& get_devices() const; + IDevice* get_device(const MeshCoordinate& coord) const; }; std::shared_ptr scoped_devices_; @@ -202,7 +203,6 @@ class MeshDevice : public IDevice, public std::enable_shared_from_this get_devices() const; - IDevice* get_device_index(size_t logical_device_id) const; IDevice* get_device(chip_id_t physical_device_id) const; IDevice* get_device(size_t row_idx, size_t col_idx) const; IDevice* get_device(const MeshCoordinate& coord) const; diff --git a/tt_metal/api/tt-metalium/system_mesh.hpp b/tt_metal/api/tt-metalium/system_mesh.hpp index 64c040edf82..1ee91588dcc 100644 --- a/tt_metal/api/tt-metalium/system_mesh.hpp +++ b/tt_metal/api/tt-metalium/system_mesh.hpp @@ -8,8 +8,7 @@ #include #include "mesh_config.hpp" -#include "mesh_device.hpp" -#include "device.hpp" +#include "mesh_coord.hpp" namespace tt::tt_metal::distributed { @@ -21,7 +20,6 @@ class SystemMesh { class Impl; // Forward declaration only std::unique_ptr pimpl_; SystemMesh(); - ~SystemMesh(); public: static SystemMesh& instance(); @@ -30,11 +28,10 @@ class SystemMesh { SystemMesh(SystemMesh&&) = delete; SystemMesh& operator=(SystemMesh&&) = delete; - const MeshShape& get_shape() const; - size_t get_num_devices() const; + const SimpleMeshShape& get_shape() const; // Gets the physical device ID for a given logical row and column index - chip_id_t get_physical_device_id(size_t logical_row_idx, size_t logical_col_idx) const; + chip_id_t get_physical_device_id(const MeshCoordinate& coord) const; // Get the physical device IDs mapped to a MeshDevice std::vector get_mapped_physical_device_ids(const MeshDeviceConfig& config) const; diff --git a/tt_metal/distributed/coordinate_translation.cpp b/tt_metal/distributed/coordinate_translation.cpp index e834ae37e2d..2070a138ed0 100644 --- a/tt_metal/distributed/coordinate_translation.cpp +++ b/tt_metal/distributed/coordinate_translation.cpp @@ -36,7 +36,7 @@ CoordinateTranslationMap load_translation_map(const std::string& filename, const TT_THROW("Invalid coordinate format in JSON file: {}", filename); } result.emplace( - Coordinate{mapping[0][0], mapping[0][1]}, + MeshCoordinate(mapping[0][0], mapping[0][1]), PhysicalCoordinate{ mapping[1][0], // cluster_id mapping[1][2], // x @@ -49,49 +49,39 @@ CoordinateTranslationMap load_translation_map(const std::string& filename, const return result; } -MeshShape get_system_mesh_shape(size_t system_num_devices) { - static const std::unordered_map system_mesh_to_shape = { - {1, MeshShape{1, 1}}, // single-device - {2, MeshShape{1, 2}}, // N300 - {8, MeshShape{2, 4}}, // T3000; as ring to match existing tests - {32, MeshShape{8, 4}}, // TG, QG - {64, MeshShape{8, 8}}, // TGG - }; - TT_FATAL( - system_mesh_to_shape.contains(system_num_devices), "Unsupported number of devices: {}", system_num_devices); - auto shape = system_mesh_to_shape.at(system_num_devices); - log_debug(LogMetal, "Logical SystemMesh Shape: {}x{}", shape.num_rows, shape.num_cols); - return shape; -} - } // namespace -std::pair get_system_mesh_coordinate_translation_map() { - static const auto* cached_translation_map = new std::pair([] { - auto system_num_devices = tt::Cluster::instance().number_of_user_devices(); +const std::pair& get_system_mesh_coordinate_translation_map() { + static const auto* cached_translation_map = new std::pair([] { + const auto system_num_devices = tt::Cluster::instance().number_of_user_devices(); - std::string galaxy_mesh_descriptor = "TG.json"; - if (tt::Cluster::instance().number_of_pci_devices() == system_num_devices) { - galaxy_mesh_descriptor = "QG.json"; - } + const bool is_qg = tt::Cluster::instance().number_of_pci_devices() == system_num_devices; - const std::unordered_map system_mesh_translation_map = { - {1, "device.json"}, - {2, "N300.json"}, - {8, "T3000.json"}, - {32, galaxy_mesh_descriptor}, - {64, "TGG.json"}, + // TODO: #17477 - This assumes shapes and coordinates are in 2D. This will be extended for 3D. + // Consider if 1D can be used for single device and N300. + const std::unordered_map> system_mesh_translation_map = { + {1, std::make_pair("device.json", SimpleMeshShape(1, 1))}, + {2, std::make_pair("N300.json", SimpleMeshShape(1, 2))}, + {8, std::make_pair("T3000.json", SimpleMeshShape(2, 4))}, + {32, std::make_pair(is_qg ? "QG.json" : "TG.json", SimpleMeshShape(8, 4))}, + {64, std::make_pair("TGG.json", SimpleMeshShape(8, 8))}, }; - TT_FATAL( system_mesh_translation_map.contains(system_num_devices), "Unsupported number of devices: {}", system_num_devices); - auto translation_config_file = get_config_path(system_mesh_translation_map.at(system_num_devices)); - return std::pair{ - load_translation_map(translation_config_file, "logical_to_physical_coordinates"), - get_system_mesh_shape(system_num_devices)}; + const auto [translation_config_file, shape] = system_mesh_translation_map.at(system_num_devices); + TT_FATAL( + system_num_devices == shape.mesh_size(), + "Mismatch between number of devices and the mesh shape: {} != {}", + system_num_devices, + shape.mesh_size()); + log_debug(LogMetal, "Logical SystemMesh Shape: {}", shape); + + return std::pair{ + load_translation_map(get_config_path(translation_config_file), /*key=*/"logical_to_physical_coordinates"), + shape}; }()); return *cached_translation_map; diff --git a/tt_metal/distributed/coordinate_translation.hpp b/tt_metal/distributed/coordinate_translation.hpp index b4fc5c21b85..5aa0f7242f0 100644 --- a/tt_metal/distributed/coordinate_translation.hpp +++ b/tt_metal/distributed/coordinate_translation.hpp @@ -7,17 +7,18 @@ #include #include "umd/device/types/cluster_descriptor_types.h" +#include #include namespace tt::tt_metal::distributed { // TODO: Consider conversion to StrongType instead of alias -using LogicalCoordinate = Coordinate; using PhysicalCoordinate = eth_coord_t; -using CoordinateTranslationMap = std::unordered_map; +using CoordinateTranslationMap = std::unordered_map; -// Returns a translation map between logical coordinates in logical 2D space +// Returns a translation map between logical coordinates in logical ND space // to the physical coordinates as defined by the UMD layer. -std::pair get_system_mesh_coordinate_translation_map(); +// TODO: #17477 - Return MeshContainer that contains everything we need. +const std::pair& get_system_mesh_coordinate_translation_map(); } // namespace tt::tt_metal::distributed diff --git a/tt_metal/distributed/mesh_device.cpp b/tt_metal/distributed/mesh_device.cpp index 603ce95212e..63cf7a6621a 100644 --- a/tt_metal/distributed/mesh_device.cpp +++ b/tt_metal/distributed/mesh_device.cpp @@ -68,27 +68,36 @@ MeshDevice::ScopedDevices::ScopedDevices( size_t trace_region_size, size_t num_command_queues, const DispatchCoreConfig& dispatch_core_config, - const MeshDeviceConfig& config) { + const MeshDeviceConfig& config) : + devices_(SimpleMeshShape(config.mesh_shape), /*fill_value=*/nullptr) { auto& system_mesh = SystemMesh::instance(); auto physical_device_ids = system_mesh.request_available_devices(config); opened_devices_ = tt::tt_metal::detail::CreateDevices( physical_device_ids, num_command_queues, l1_small_size, trace_region_size, dispatch_core_config); + TT_FATAL( + physical_device_ids.size() == devices_.shape().mesh_size(), + "Device size mismatch; expected: {}, actual: {}", + devices_.shape().mesh_size(), + opened_devices_.size()); + + auto it = devices_.begin(); for (auto physical_device_id : physical_device_ids) { - devices_.push_back(opened_devices_.at(physical_device_id)); + it->value() = opened_devices_.at(physical_device_id); + ++it; } } MeshDevice::ScopedDevices::~ScopedDevices() { - if (not opened_devices_.empty()) { + if (!opened_devices_.empty()) { tt::tt_metal::detail::CloseDevices(opened_devices_); - opened_devices_.clear(); - devices_.clear(); } } -const std::vector& MeshDevice::ScopedDevices::get_devices() const { return devices_; } +const std::vector& MeshDevice::ScopedDevices::get_devices() const { return devices_.values(); } + +IDevice* MeshDevice::ScopedDevices::get_device(const MeshCoordinate& coord) const { return devices_.at(coord); } uint8_t MeshDevice::num_hw_cqs() const { return validate_and_get_reference_value( @@ -192,12 +201,6 @@ std::vector> MeshDevice::create_submeshes(const Mesh MeshDevice::~MeshDevice() {} -IDevice* MeshDevice::get_device_index(size_t device_index) const { - TT_FATAL(device_index >= 0 and device_index < num_devices(), "Invalid device index"); - const auto& devices = scoped_devices_->get_devices(); - return devices.at(device_index); -} - IDevice* MeshDevice::get_device(chip_id_t physical_device_id) const { for (auto device : this->get_devices()) { if (device->id() == physical_device_id) { @@ -214,9 +217,7 @@ IDevice* MeshDevice::get_device(size_t row_idx, size_t col_idx) const { return get_device(MeshCoordinate{row_idx, col_idx}); } -IDevice* MeshDevice::get_device(const MeshCoordinate& coord) const { - return this->get_device_index(to_linear_index(SimpleMeshShape(mesh_shape_), coord)); -} +IDevice* MeshDevice::get_device(const MeshCoordinate& coord) const { return scoped_devices_->get_device(coord); } MeshCommandQueue& MeshDevice::mesh_command_queue(std::size_t cq_id) const { TT_FATAL(this->using_fast_dispatch(), "Can only access the MeshCommandQueue when using Fast Dispatch."); diff --git a/tt_metal/distributed/mesh_workload.cpp b/tt_metal/distributed/mesh_workload.cpp index 21fd77cc409..a9efcb406c7 100644 --- a/tt_metal/distributed/mesh_workload.cpp +++ b/tt_metal/distributed/mesh_workload.cpp @@ -257,12 +257,11 @@ uint32_t MeshWorkload::get_sem_size( std::shared_ptr& mesh_device, CoreCoord logical_core, CoreType core_type) { uint32_t sem_size = 0; uint32_t program_idx = 0; - IDevice* device = mesh_device->get_device_index(0); for (auto& [device_range, program] : programs_) { if (program_idx) { - TT_ASSERT(sem_size == program.get_sem_size(device, logical_core, core_type)); + TT_ASSERT(sem_size == program.get_sem_size(mesh_device.get(), logical_core, core_type)); } else { - sem_size = program.get_sem_size(device, logical_core, core_type); + sem_size = program.get_sem_size(mesh_device.get(), logical_core, core_type); } program_idx++; } @@ -281,12 +280,11 @@ uint32_t MeshWorkload::get_cb_size( std::shared_ptr& mesh_device, CoreCoord logical_core, CoreType core_type) { uint32_t cb_size = 0; uint32_t program_idx = 0; - IDevice* device = mesh_device->get_device_index(0); for (auto& [device_range, program] : programs_) { if (program_idx) { - TT_ASSERT(cb_size == program.get_cb_size(device, logical_core, core_type)); + TT_ASSERT(cb_size == program.get_cb_size(mesh_device.get(), logical_core, core_type)); } else { - cb_size = program.get_cb_size(device, logical_core, core_type); + cb_size = program.get_cb_size(mesh_device.get(), logical_core, core_type); } program_idx++; } diff --git a/tt_metal/distributed/system_mesh.cpp b/tt_metal/distributed/system_mesh.cpp index c90fed6f897..20d912a3b1a 100644 --- a/tt_metal/distributed/system_mesh.cpp +++ b/tt_metal/distributed/system_mesh.cpp @@ -7,31 +7,30 @@ #include "umd/device/types/cluster_descriptor_types.h" #include "tt_metal/distributed/coordinate_translation.hpp" +#include "mesh_coord.hpp" #include "tt_cluster.hpp" namespace tt::tt_metal::distributed { class SystemMesh::Impl { private: - MeshShape logical_mesh_shape_; + SimpleMeshShape logical_mesh_shape_; CoordinateTranslationMap logical_to_physical_coordinates_; - std::unordered_map logical_to_device_id_; + std::unordered_map logical_to_device_id_; std::unordered_map physical_coordinate_to_device_id_; std::unordered_map physical_device_id_to_coordinate_; public: Impl() = default; - ~Impl() = default; bool is_system_mesh_initialized() const; void initialize(); - const MeshShape& get_shape() const; - size_t get_num_devices() const; + const SimpleMeshShape& get_shape() const; std::vector get_mapped_physical_device_ids(const MeshDeviceConfig& config) const; std::vector request_available_devices(const MeshDeviceConfig& config) const; - IDevice* get_device(const chip_id_t physical_device_id) const; - chip_id_t get_physical_device_id(size_t logical_row_idx, size_t logical_col_idx) const; + IDevice* get_device(const chip_id_t physical_device_id) const; + chip_id_t get_physical_device_id(const MeshCoordinate& coord) const; }; // Implementation of public methods @@ -69,30 +68,34 @@ void SystemMesh::Impl::initialize() { } } -const MeshShape& SystemMesh::Impl::get_shape() const { return logical_mesh_shape_; } -size_t SystemMesh::Impl::get_num_devices() const { - auto [num_rows, num_cols] = this->get_shape(); - return num_rows * num_cols; -} +const SimpleMeshShape& SystemMesh::Impl::get_shape() const { return logical_mesh_shape_; } -chip_id_t SystemMesh::Impl::get_physical_device_id(size_t logical_row_idx, size_t logical_col_idx) const { +chip_id_t SystemMesh::Impl::get_physical_device_id(const MeshCoordinate& coord) const { TT_FATAL( - logical_row_idx < logical_mesh_shape_.num_rows, - "Row index out of bounds: {} >= {}", - logical_row_idx, - logical_mesh_shape_.num_rows); - TT_FATAL( - logical_col_idx < logical_mesh_shape_.num_cols, - "Column index out of bounds: {} >= {}", - logical_col_idx, - logical_mesh_shape_.num_cols); - auto logical_coordinate = Coordinate{logical_row_idx, logical_col_idx}; - return logical_to_device_id_.at(logical_coordinate); + coord.dims() == logical_mesh_shape_.dims(), + "Coordinate dimensions mismatch: {} != {}", + coord.dims(), + logical_mesh_shape_.dims()); + for (size_t i = 0; i < coord.dims(); ++i) { + TT_FATAL( + coord[i] < logical_mesh_shape_[i], + "Coordinate at index {} out of bounds; mesh shape {}, coordinate {}", + i, + logical_mesh_shape_, + coord); + } + return logical_to_device_id_.at(coord); } std::vector SystemMesh::Impl::get_mapped_physical_device_ids(const MeshDeviceConfig& config) const { std::vector physical_device_ids; - auto [system_mesh_rows, system_mesh_cols] = this->get_shape(); + // TODO: #17477 - Extend to ND. + TT_FATAL( + logical_mesh_shape_.dims() == 2, + "SystemMesh only supports 2D meshes; requested dimensions: {}", + logical_mesh_shape_.dims()); + + auto [system_mesh_rows, system_mesh_cols] = std::make_tuple(logical_mesh_shape_[0], logical_mesh_shape_[1]); auto [requested_num_rows, requested_num_cols] = config.mesh_shape; auto [row_offset, col_offset] = config.offset; @@ -112,7 +115,8 @@ std::vector SystemMesh::Impl::get_mapped_physical_device_ids(const Me auto line_coords = MeshDeviceView::get_line_coordinates( line_length, Coordinate{row_offset, col_offset}, system_mesh_rows, system_mesh_cols); for (const auto& logical_coordinate : line_coords) { - auto physical_device_id = logical_to_device_id_.at(logical_coordinate); + auto physical_device_id = + logical_to_device_id_.at(MeshCoordinate(logical_coordinate.row, logical_coordinate.col)); physical_device_ids.push_back(physical_device_id); log_debug( @@ -178,17 +182,18 @@ std::vector SystemMesh::Impl::get_mapped_physical_device_ids(const Me } TT_FATAL( - logical_coordinate.row < logical_mesh_shape_.num_rows, + logical_coordinate.row < system_mesh_rows, "Row coordinate out of bounds: {} >= {}", logical_coordinate.row, - logical_mesh_shape_.num_rows); + system_mesh_rows); TT_FATAL( - logical_coordinate.col < logical_mesh_shape_.num_cols, + logical_coordinate.col < system_mesh_cols, "Column coordinate out of bounds: {} >= {}", logical_coordinate.col, - logical_mesh_shape_.num_cols); + system_mesh_cols); - auto physical_device_id = logical_to_device_id_.at(logical_coordinate); + auto physical_device_id = + logical_to_device_id_.at(MeshCoordinate(logical_coordinate.row, logical_coordinate.col)); physical_device_ids.push_back(physical_device_id); log_debug( @@ -200,7 +205,6 @@ std::vector SystemMesh::Impl::get_mapped_physical_device_ids(const Me std::vector SystemMesh::Impl::request_available_devices(const MeshDeviceConfig& config) const { auto [requested_num_rows, requested_num_cols] = config.mesh_shape; - auto [max_num_rows, max_num_cols] = logical_mesh_shape_; auto [row_offset, col_offset] = config.offset; log_debug( @@ -216,7 +220,6 @@ std::vector SystemMesh::Impl::request_available_devices(const MeshDev } SystemMesh::SystemMesh() : pimpl_(std::make_unique()) {} -SystemMesh::~SystemMesh() = default; SystemMesh& SystemMesh::instance() { static SystemMesh instance; @@ -226,13 +229,11 @@ SystemMesh& SystemMesh::instance() { return instance; } -chip_id_t SystemMesh::get_physical_device_id(size_t logical_row_idx, size_t logical_col_idx) const { - return pimpl_->get_physical_device_id(logical_row_idx, logical_col_idx); +chip_id_t SystemMesh::get_physical_device_id(const MeshCoordinate& coord) const { + return pimpl_->get_physical_device_id(coord); } -const MeshShape& SystemMesh::get_shape() const { return pimpl_->get_shape(); } - -size_t SystemMesh::get_num_devices() const { return pimpl_->get_num_devices(); } +const SimpleMeshShape& SystemMesh::get_shape() const { return pimpl_->get_shape(); } std::vector SystemMesh::request_available_devices(const MeshDeviceConfig& config) const { return pimpl_->request_available_devices(config); diff --git a/ttnn/cpp/ttnn/distributed/api.cpp b/ttnn/cpp/ttnn/distributed/api.cpp index bd0fd35a206..9133ec419ac 100644 --- a/ttnn/cpp/ttnn/distributed/api.cpp +++ b/ttnn/cpp/ttnn/distributed/api.cpp @@ -124,7 +124,7 @@ Tensor aggregate_as_tensor( std::vector get_t3k_physical_device_ids_ring() { using namespace tt::tt_metal::distributed; auto& instance = SystemMesh::instance(); - auto num_devices = instance.get_num_devices(); + auto num_devices = instance.get_shape().mesh_size(); TT_FATAL(num_devices == 8, "T3000 ring topology only works with 8 devices"); auto physical_device_ids =