Skip to content

Commit

Permalink
#17477: Adopt ND coordinate system in system mesh, coordinate transla…
Browse files Browse the repository at this point in the history
…tion (#17926)

### Ticket
#17477

### Problem description
TT-distributed needs to adopt ND coordinate system for mesh primitives.

### What's changed
Plumbed `SimpleMeshShape` in `SystemMesh`, logical to physical
coordinate translation mapping.

### Checklist
- [X] [All post
commit](https://github.com/tenstorrent/tt-metal/actions/runs/13395057290)
- [X] New/Existing tests provide coverage for changes
  • Loading branch information
omilyutin-tt authored and dgomezTT committed Feb 19, 2025
1 parent f4a1916 commit d985a51
Show file tree
Hide file tree
Showing 12 changed files with 153 additions and 130 deletions.
21 changes: 20 additions & 1 deletion tests/tt_metal/distributed/test_mesh_coord.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@

#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <unordered_set>

#include "mesh_coord.hpp"

namespace tt::tt_metal::distributed {
namespace {

using ::testing::ElementsAre;

using ::testing::UnorderedElementsAre;
TEST(SimpleMeshShapeTest, Construction) {
SimpleMeshShape shape_1d(3);
EXPECT_EQ(shape_1d.dims(), 1);
Expand Down Expand Up @@ -100,6 +101,21 @@ TEST(MeshCoordinateTest, Comparison) {
EXPECT_NE(coord1, MeshCoordinate(1, 2, 1));
}

TEST(MeshCoordinateTest, UnorderedSet) {
std::unordered_set<MeshCoordinate> set;
set.insert(MeshCoordinate(0, 0, 0));
set.insert(MeshCoordinate(0, 0, 1));
set.insert(MeshCoordinate(0, 0, 2));

EXPECT_FALSE(set.insert(MeshCoordinate(0, 0, 2)).second);
EXPECT_THAT(
set,
UnorderedElementsAre(
MeshCoordinate(0, 0, 0), //
MeshCoordinate(0, 0, 1),
MeshCoordinate(0, 0, 2)));
}

TEST(MeshCoordinateRangeTest, FromShape) {
SimpleMeshShape shape(2, 3);
MeshCoordinateRange range(shape);
Expand Down Expand Up @@ -232,6 +248,7 @@ TEST(MeshContainerTest, ElementAccessRowMajor) {
MeshCoordinate(1, 1),
MeshCoordinate(1, 2)));
EXPECT_THAT(values, ElementsAre(0, 1, 2, 3, 4, 5));
EXPECT_THAT(container.values(), ElementsAre(0, 1, 2, 3, 4, 5));
}

TEST(MeshContainerTest, ConstContainer) {
Expand All @@ -254,6 +271,7 @@ TEST(MeshContainerTest, ConstContainer) {
MeshCoordinate(1, 1),
MeshCoordinate(1, 2)));
EXPECT_THAT(values, ElementsAre(0, 0, 0, 0, 0, 0));
EXPECT_THAT(container.values(), ElementsAre(0, 0, 0, 0, 0, 0));
}

TEST(MeshContainerTest, MutateThroughProxy) {
Expand All @@ -276,6 +294,7 @@ TEST(MeshContainerTest, MutateThroughProxy) {
values.push_back(value);
}
EXPECT_THAT(values, ElementsAre(0, 1, 2, 3, 4, 5));
EXPECT_THAT(container.values(), ElementsAre(0, 1, 2, 3, 4, 5));
}

TEST(MeshContainerTest, OutOfBounds) {
Expand Down
16 changes: 10 additions & 6 deletions tests/ttnn/distributed/test_distributed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

#include <gtest/gtest.h>

#include <cstddef>
#include <ttnn/core.hpp>
#include <ttnn/distributed/api.hpp>

Expand All @@ -19,11 +18,16 @@ class DistributedTest : public ::testing::Test {
TEST_F(DistributedTest, TestSystemMeshTearDownWithoutClose) {
auto& sys = SystemMesh::instance();
auto mesh = ttnn::distributed::open_mesh_device(
{2, 4}, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1, tt::tt_metal::DispatchCoreType::WORKER);

auto [rows, cols] = sys.get_shape();
EXPECT_GT(rows, 0);
EXPECT_GT(cols, 0);
/*mesh_shape=*/{2, 4},
DEFAULT_L1_SMALL_SIZE,
DEFAULT_TRACE_REGION_SIZE,
1,
tt::tt_metal::DispatchCoreType::WORKER);

const auto system_shape = sys.get_shape();
ASSERT_EQ(system_shape.dims(), 2);
EXPECT_EQ(system_shape[0], 2);
EXPECT_EQ(system_shape[1], 4);
}

TEST_F(DistributedTest, TestMemoryAllocationStatistics) {
Expand Down
7 changes: 4 additions & 3 deletions tests/ttnn/distributed/test_distributed_atexit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ TEST(DistributedTestStandalone, TestSystemMeshTearDownWithoutClose) {
mesh = ttnn::distributed::open_mesh_device(
{2, 4}, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1, tt::tt_metal::DispatchCoreType::WORKER);

auto [rows, cols] = sys.get_shape();
EXPECT_GT(rows, 0);
EXPECT_GT(cols, 0);
const auto system_shape = sys.get_shape();
ASSERT_EQ(system_shape.dims(), 2);
EXPECT_EQ(system_shape[0], 2);
EXPECT_EQ(system_shape[1], 4);
}

} // namespace ttnn::distributed::test
39 changes: 25 additions & 14 deletions tt_metal/api/tt-metalium/mesh_coord.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <vector>

#include "shape_base.hpp"
#include "utils.hpp"

namespace tt::tt_metal::distributed {

Expand All @@ -21,7 +22,7 @@ class SimpleMeshShape : public ShapeBase {
using ShapeBase::operator[];

// Shorthands for constructing 1D, 2D and 3D shapes.
SimpleMeshShape(uint32_t x);
explicit SimpleMeshShape(uint32_t x);
SimpleMeshShape(uint32_t x, uint32_t y);
SimpleMeshShape(uint32_t x, uint32_t y, uint32_t z);

Expand Down Expand Up @@ -56,7 +57,7 @@ class SimpleMeshShape : public ShapeBase {
class MeshCoordinate {
public:
// Shorthands for constructing 1D, 2D and 3D coordinates.
MeshCoordinate(uint32_t x);
explicit MeshCoordinate(uint32_t x);
MeshCoordinate(uint32_t x, uint32_t y);
MeshCoordinate(uint32_t x, uint32_t y, uint32_t z);

Expand Down Expand Up @@ -199,7 +200,10 @@ class MeshContainer {
using ValueProxy = detail::MeshCoordinateValueProxy<T>;

Iterator& operator++();
ValueProxy& operator*();
ValueProxy& operator*() { return value_proxy_; }
const ValueProxy& operator*() const { return value_proxy_; }
ValueProxy* operator->() { return &value_proxy_; }
const ValueProxy* operator->() const { return &value_proxy_; }
bool operator==(const Iterator& other) const;
bool operator!=(const Iterator& other) const;

Expand All @@ -220,7 +224,8 @@ class MeshContainer {
using ValueProxy = detail::MeshCoordinateValueProxy<const T>;

ConstIterator& operator++();
const ValueProxy& operator*() const;
const ValueProxy& operator*() const { return value_proxy_; }
const ValueProxy* operator->() const { return &value_proxy_; }
bool operator==(const ConstIterator& other) const;
bool operator!=(const ConstIterator& other) const;

Expand All @@ -237,11 +242,16 @@ class MeshContainer {
ValueProxy value_proxy_;
};

// Iterators provide a reference to the value along with the coordinate.
Iterator begin();
Iterator end();
ConstIterator begin() const;
ConstIterator end() const;

// View of the flat container of values.
std::vector<T>& values() { return values_; }
const std::vector<T>& values() const { return values_; }

private:
SimpleMeshShape shape_;
MeshCoordinateRange coord_range_;
Expand Down Expand Up @@ -283,11 +293,6 @@ typename MeshContainer<T>::Iterator& MeshContainer<T>::Iterator::operator++() {
return *this;
}

template <typename T>
typename MeshContainer<T>::Iterator::ValueProxy& MeshContainer<T>::Iterator::operator*() {
return value_proxy_;
}

template <typename T>
MeshContainer<T>::ConstIterator::ConstIterator(
const MeshContainer* container, const MeshCoordinateRange::Iterator& coord_iter, size_t linear_index) :
Expand All @@ -304,11 +309,6 @@ typename MeshContainer<T>::ConstIterator& MeshContainer<T>::ConstIterator::opera
return *this;
}

template <typename T>
const typename MeshContainer<T>::ConstIterator::ValueProxy& MeshContainer<T>::ConstIterator::operator*() const {
return value_proxy_;
}

template <typename T>
bool MeshContainer<T>::Iterator::operator==(const Iterator& other) const {
return container_ == other.container_ && coord_iter_ == other.coord_iter_ && linear_index_ == other.linear_index_;
Expand Down Expand Up @@ -367,4 +367,15 @@ struct tuple_element<1, tt::tt_metal::distributed::detail::MeshCoordinateValuePr
using type = T;
};

template <>
struct hash<tt::tt_metal::distributed::MeshCoordinate> {
size_t operator()(const tt::tt_metal::distributed::MeshCoordinate& coord) const noexcept {
size_t seed = 0;
for (const auto coord_value : coord.coords()) {
tt::utils::hash_combine(seed, coord_value);
}
return seed;
}
};

} // namespace std
4 changes: 2 additions & 2 deletions tt_metal/api/tt-metalium/mesh_device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class MeshDevice : public IDevice, public std::enable_shared_from_this<MeshDevic
class ScopedDevices {
private:
std::map<chip_id_t, IDevice*> opened_devices_;
std::vector<IDevice*> devices_;
MeshContainer<IDevice*> devices_;

public:
// Constructor acquires physical resources
Expand All @@ -50,6 +50,7 @@ class MeshDevice : public IDevice, public std::enable_shared_from_this<MeshDevic
ScopedDevices& operator=(const ScopedDevices&) = delete;

const std::vector<IDevice*>& get_devices() const;
IDevice* get_device(const MeshCoordinate& coord) const;
};

std::shared_ptr<ScopedDevices> scoped_devices_;
Expand Down Expand Up @@ -202,7 +203,6 @@ class MeshDevice : public IDevice, public std::enable_shared_from_this<MeshDevic

// Returns the devices in the mesh in row-major order.
std::vector<IDevice*> get_devices() const;
IDevice* get_device_index(size_t logical_device_id) const;
IDevice* get_device(chip_id_t physical_device_id) const;
IDevice* get_device(size_t row_idx, size_t col_idx) const;
IDevice* get_device(const MeshCoordinate& coord) const;
Expand Down
9 changes: 3 additions & 6 deletions tt_metal/api/tt-metalium/system_mesh.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
#include <vector>

#include "mesh_config.hpp"
#include "mesh_device.hpp"
#include "device.hpp"
#include "mesh_coord.hpp"

namespace tt::tt_metal::distributed {

Expand All @@ -21,7 +20,6 @@ class SystemMesh {
class Impl; // Forward declaration only
std::unique_ptr<Impl> pimpl_;
SystemMesh();
~SystemMesh();

public:
static SystemMesh& instance();
Expand All @@ -30,11 +28,10 @@ class SystemMesh {
SystemMesh(SystemMesh&&) = delete;
SystemMesh& operator=(SystemMesh&&) = delete;

const MeshShape& get_shape() const;
size_t get_num_devices() const;
const SimpleMeshShape& get_shape() const;

// Gets the physical device ID for a given logical row and column index
chip_id_t get_physical_device_id(size_t logical_row_idx, size_t logical_col_idx) const;
chip_id_t get_physical_device_id(const MeshCoordinate& coord) const;

// Get the physical device IDs mapped to a MeshDevice
std::vector<chip_id_t> get_mapped_physical_device_ids(const MeshDeviceConfig& config) const;
Expand Down
58 changes: 24 additions & 34 deletions tt_metal/distributed/coordinate_translation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ CoordinateTranslationMap load_translation_map(const std::string& filename, const
TT_THROW("Invalid coordinate format in JSON file: {}", filename);
}
result.emplace(
Coordinate{mapping[0][0], mapping[0][1]},
MeshCoordinate(mapping[0][0], mapping[0][1]),
PhysicalCoordinate{
mapping[1][0], // cluster_id
mapping[1][2], // x
Expand All @@ -49,49 +49,39 @@ CoordinateTranslationMap load_translation_map(const std::string& filename, const
return result;
}

MeshShape get_system_mesh_shape(size_t system_num_devices) {
static const std::unordered_map<size_t, MeshShape> system_mesh_to_shape = {
{1, MeshShape{1, 1}}, // single-device
{2, MeshShape{1, 2}}, // N300
{8, MeshShape{2, 4}}, // T3000; as ring to match existing tests
{32, MeshShape{8, 4}}, // TG, QG
{64, MeshShape{8, 8}}, // TGG
};
TT_FATAL(
system_mesh_to_shape.contains(system_num_devices), "Unsupported number of devices: {}", system_num_devices);
auto shape = system_mesh_to_shape.at(system_num_devices);
log_debug(LogMetal, "Logical SystemMesh Shape: {}x{}", shape.num_rows, shape.num_cols);
return shape;
}

} // namespace

std::pair<CoordinateTranslationMap, MeshShape> get_system_mesh_coordinate_translation_map() {
static const auto* cached_translation_map = new std::pair<CoordinateTranslationMap, MeshShape>([] {
auto system_num_devices = tt::Cluster::instance().number_of_user_devices();
const std::pair<CoordinateTranslationMap, SimpleMeshShape>& get_system_mesh_coordinate_translation_map() {
static const auto* cached_translation_map = new std::pair<CoordinateTranslationMap, SimpleMeshShape>([] {
const auto system_num_devices = tt::Cluster::instance().number_of_user_devices();

std::string galaxy_mesh_descriptor = "TG.json";
if (tt::Cluster::instance().number_of_pci_devices() == system_num_devices) {
galaxy_mesh_descriptor = "QG.json";
}
const bool is_qg = tt::Cluster::instance().number_of_pci_devices() == system_num_devices;

const std::unordered_map<size_t, std::string> system_mesh_translation_map = {
{1, "device.json"},
{2, "N300.json"},
{8, "T3000.json"},
{32, galaxy_mesh_descriptor},
{64, "TGG.json"},
// TODO: #17477 - This assumes shapes and coordinates are in 2D. This will be extended for 3D.
// Consider if 1D can be used for single device and N300.
const std::unordered_map<size_t, std::pair<std::string, SimpleMeshShape>> system_mesh_translation_map = {
{1, std::make_pair("device.json", SimpleMeshShape(1, 1))},
{2, std::make_pair("N300.json", SimpleMeshShape(1, 2))},
{8, std::make_pair("T3000.json", SimpleMeshShape(2, 4))},
{32, std::make_pair(is_qg ? "QG.json" : "TG.json", SimpleMeshShape(8, 4))},
{64, std::make_pair("TGG.json", SimpleMeshShape(8, 8))},
};

TT_FATAL(
system_mesh_translation_map.contains(system_num_devices),
"Unsupported number of devices: {}",
system_num_devices);

auto translation_config_file = get_config_path(system_mesh_translation_map.at(system_num_devices));
return std::pair<CoordinateTranslationMap, MeshShape>{
load_translation_map(translation_config_file, "logical_to_physical_coordinates"),
get_system_mesh_shape(system_num_devices)};
const auto [translation_config_file, shape] = system_mesh_translation_map.at(system_num_devices);
TT_FATAL(
system_num_devices == shape.mesh_size(),
"Mismatch between number of devices and the mesh shape: {} != {}",
system_num_devices,
shape.mesh_size());
log_debug(LogMetal, "Logical SystemMesh Shape: {}", shape);

return std::pair<CoordinateTranslationMap, SimpleMeshShape>{
load_translation_map(get_config_path(translation_config_file), /*key=*/"logical_to_physical_coordinates"),
shape};
}());

return *cached_translation_map;
Expand Down
9 changes: 5 additions & 4 deletions tt_metal/distributed/coordinate_translation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,18 @@
#include <unordered_map>

#include "umd/device/types/cluster_descriptor_types.h"
#include <mesh_coord.hpp>
#include <mesh_device_view.hpp>

namespace tt::tt_metal::distributed {

// TODO: Consider conversion to StrongType instead of alias
using LogicalCoordinate = Coordinate;
using PhysicalCoordinate = eth_coord_t;
using CoordinateTranslationMap = std::unordered_map<LogicalCoordinate, PhysicalCoordinate>;
using CoordinateTranslationMap = std::unordered_map<MeshCoordinate, PhysicalCoordinate>;

// Returns a translation map between logical coordinates in logical 2D space
// Returns a translation map between logical coordinates in logical ND space
// to the physical coordinates as defined by the UMD layer.
std::pair<CoordinateTranslationMap, MeshShape> get_system_mesh_coordinate_translation_map();
// TODO: #17477 - Return MeshContainer<PhysicalCoordinate> that contains everything we need.
const std::pair<CoordinateTranslationMap, SimpleMeshShape>& get_system_mesh_coordinate_translation_map();

} // namespace tt::tt_metal::distributed
Loading

0 comments on commit d985a51

Please sign in to comment.