Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#17477: Adopt ND coordinate system in system mesh, coordinate translation #17926

Merged
merged 2 commits into from
Feb 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion tests/tt_metal/distributed/test_mesh_coord.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@

#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <unordered_set>

#include "mesh_coord.hpp"

namespace tt::tt_metal::distributed {
namespace {

using ::testing::ElementsAre;

using ::testing::UnorderedElementsAre;
TEST(SimpleMeshShapeTest, Construction) {
SimpleMeshShape shape_1d(3);
EXPECT_EQ(shape_1d.dims(), 1);
Expand Down Expand Up @@ -100,6 +101,21 @@ TEST(MeshCoordinateTest, Comparison) {
EXPECT_NE(coord1, MeshCoordinate(1, 2, 1));
}

TEST(MeshCoordinateTest, UnorderedSet) {
std::unordered_set<MeshCoordinate> set;
set.insert(MeshCoordinate(0, 0, 0));
set.insert(MeshCoordinate(0, 0, 1));
set.insert(MeshCoordinate(0, 0, 2));

EXPECT_FALSE(set.insert(MeshCoordinate(0, 0, 2)).second);
EXPECT_THAT(
set,
UnorderedElementsAre(
MeshCoordinate(0, 0, 0), //
MeshCoordinate(0, 0, 1),
MeshCoordinate(0, 0, 2)));
}

TEST(MeshCoordinateRangeTest, FromShape) {
SimpleMeshShape shape(2, 3);
MeshCoordinateRange range(shape);
Expand Down Expand Up @@ -232,6 +248,7 @@ TEST(MeshContainerTest, ElementAccessRowMajor) {
MeshCoordinate(1, 1),
MeshCoordinate(1, 2)));
EXPECT_THAT(values, ElementsAre(0, 1, 2, 3, 4, 5));
EXPECT_THAT(container.values(), ElementsAre(0, 1, 2, 3, 4, 5));
}

TEST(MeshContainerTest, ConstContainer) {
Expand All @@ -254,6 +271,7 @@ TEST(MeshContainerTest, ConstContainer) {
MeshCoordinate(1, 1),
MeshCoordinate(1, 2)));
EXPECT_THAT(values, ElementsAre(0, 0, 0, 0, 0, 0));
EXPECT_THAT(container.values(), ElementsAre(0, 0, 0, 0, 0, 0));
}

TEST(MeshContainerTest, MutateThroughProxy) {
Expand All @@ -276,6 +294,7 @@ TEST(MeshContainerTest, MutateThroughProxy) {
values.push_back(value);
}
EXPECT_THAT(values, ElementsAre(0, 1, 2, 3, 4, 5));
EXPECT_THAT(container.values(), ElementsAre(0, 1, 2, 3, 4, 5));
}

TEST(MeshContainerTest, OutOfBounds) {
Expand Down
16 changes: 10 additions & 6 deletions tests/ttnn/distributed/test_distributed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

#include <gtest/gtest.h>

#include <cstddef>
#include <ttnn/core.hpp>
#include <ttnn/distributed/api.hpp>

Expand All @@ -19,11 +18,16 @@ class DistributedTest : public ::testing::Test {
TEST_F(DistributedTest, TestSystemMeshTearDownWithoutClose) {
auto& sys = SystemMesh::instance();
auto mesh = ttnn::distributed::open_mesh_device(
{2, 4}, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1, tt::tt_metal::DispatchCoreType::WORKER);

auto [rows, cols] = sys.get_shape();
EXPECT_GT(rows, 0);
EXPECT_GT(cols, 0);
/*mesh_shape=*/{2, 4},
DEFAULT_L1_SMALL_SIZE,
DEFAULT_TRACE_REGION_SIZE,
1,
tt::tt_metal::DispatchCoreType::WORKER);

const auto system_shape = sys.get_shape();
ASSERT_EQ(system_shape.dims(), 2);
EXPECT_EQ(system_shape[0], 2);
EXPECT_EQ(system_shape[1], 4);
}

TEST_F(DistributedTest, TestMemoryAllocationStatistics) {
Expand Down
7 changes: 4 additions & 3 deletions tests/ttnn/distributed/test_distributed_atexit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ TEST(DistributedTestStandalone, TestSystemMeshTearDownWithoutClose) {
mesh = ttnn::distributed::open_mesh_device(
{2, 4}, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1, tt::tt_metal::DispatchCoreType::WORKER);

auto [rows, cols] = sys.get_shape();
EXPECT_GT(rows, 0);
EXPECT_GT(cols, 0);
const auto system_shape = sys.get_shape();
ASSERT_EQ(system_shape.dims(), 2);
EXPECT_EQ(system_shape[0], 2);
EXPECT_EQ(system_shape[1], 4);
}

} // namespace ttnn::distributed::test
39 changes: 25 additions & 14 deletions tt_metal/api/tt-metalium/mesh_coord.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <vector>

#include "shape_base.hpp"
#include "utils.hpp"

namespace tt::tt_metal::distributed {

Expand All @@ -21,7 +22,7 @@ class SimpleMeshShape : public ShapeBase {
using ShapeBase::operator[];

// Shorthands for constructing 1D, 2D and 3D shapes.
SimpleMeshShape(uint32_t x);
explicit SimpleMeshShape(uint32_t x);
SimpleMeshShape(uint32_t x, uint32_t y);
SimpleMeshShape(uint32_t x, uint32_t y, uint32_t z);

Expand Down Expand Up @@ -56,7 +57,7 @@ class SimpleMeshShape : public ShapeBase {
class MeshCoordinate {
public:
// Shorthands for constructing 1D, 2D and 3D coordinates.
MeshCoordinate(uint32_t x);
explicit MeshCoordinate(uint32_t x);
MeshCoordinate(uint32_t x, uint32_t y);
MeshCoordinate(uint32_t x, uint32_t y, uint32_t z);

Expand Down Expand Up @@ -199,7 +200,10 @@ class MeshContainer {
using ValueProxy = detail::MeshCoordinateValueProxy<T>;

Iterator& operator++();
ValueProxy& operator*();
ValueProxy& operator*() { return value_proxy_; }
const ValueProxy& operator*() const { return value_proxy_; }
ValueProxy* operator->() { return &value_proxy_; }
const ValueProxy* operator->() const { return &value_proxy_; }
bool operator==(const Iterator& other) const;
bool operator!=(const Iterator& other) const;

Expand All @@ -220,7 +224,8 @@ class MeshContainer {
using ValueProxy = detail::MeshCoordinateValueProxy<const T>;

ConstIterator& operator++();
const ValueProxy& operator*() const;
const ValueProxy& operator*() const { return value_proxy_; }
const ValueProxy* operator->() const { return &value_proxy_; }
bool operator==(const ConstIterator& other) const;
bool operator!=(const ConstIterator& other) const;

Expand All @@ -237,11 +242,16 @@ class MeshContainer {
ValueProxy value_proxy_;
};

// Iterators provide a reference to the value along with the coordinate.
Iterator begin();
Iterator end();
ConstIterator begin() const;
ConstIterator end() const;

// View of the flat container of values.
std::vector<T>& values() { return values_; }
const std::vector<T>& values() const { return values_; }

private:
SimpleMeshShape shape_;
MeshCoordinateRange coord_range_;
Expand Down Expand Up @@ -283,11 +293,6 @@ typename MeshContainer<T>::Iterator& MeshContainer<T>::Iterator::operator++() {
return *this;
}

template <typename T>
typename MeshContainer<T>::Iterator::ValueProxy& MeshContainer<T>::Iterator::operator*() {
return value_proxy_;
}

template <typename T>
MeshContainer<T>::ConstIterator::ConstIterator(
const MeshContainer* container, const MeshCoordinateRange::Iterator& coord_iter, size_t linear_index) :
Expand All @@ -304,11 +309,6 @@ typename MeshContainer<T>::ConstIterator& MeshContainer<T>::ConstIterator::opera
return *this;
}

template <typename T>
const typename MeshContainer<T>::ConstIterator::ValueProxy& MeshContainer<T>::ConstIterator::operator*() const {
return value_proxy_;
}

template <typename T>
bool MeshContainer<T>::Iterator::operator==(const Iterator& other) const {
return container_ == other.container_ && coord_iter_ == other.coord_iter_ && linear_index_ == other.linear_index_;
Expand Down Expand Up @@ -367,4 +367,15 @@ struct tuple_element<1, tt::tt_metal::distributed::detail::MeshCoordinateValuePr
using type = T;
};

template <>
struct hash<tt::tt_metal::distributed::MeshCoordinate> {
size_t operator()(const tt::tt_metal::distributed::MeshCoordinate& coord) const noexcept {
size_t seed = 0;
for (const auto coord_value : coord.coords()) {
tt::utils::hash_combine(seed, coord_value);
}
return seed;
}
};

} // namespace std
4 changes: 2 additions & 2 deletions tt_metal/api/tt-metalium/mesh_device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class MeshDevice : public IDevice, public std::enable_shared_from_this<MeshDevic
class ScopedDevices {
private:
std::map<chip_id_t, IDevice*> opened_devices_;
std::vector<IDevice*> devices_;
MeshContainer<IDevice*> devices_;

public:
// Constructor acquires physical resources
Expand All @@ -50,6 +50,7 @@ class MeshDevice : public IDevice, public std::enable_shared_from_this<MeshDevic
ScopedDevices& operator=(const ScopedDevices&) = delete;

const std::vector<IDevice*>& get_devices() const;
IDevice* get_device(const MeshCoordinate& coord) const;
};

std::shared_ptr<ScopedDevices> scoped_devices_;
Expand Down Expand Up @@ -202,7 +203,6 @@ class MeshDevice : public IDevice, public std::enable_shared_from_this<MeshDevic

// Returns the devices in the mesh in row-major order.
std::vector<IDevice*> get_devices() const;
IDevice* get_device_index(size_t logical_device_id) const;
IDevice* get_device(chip_id_t physical_device_id) const;
IDevice* get_device(size_t row_idx, size_t col_idx) const;
IDevice* get_device(const MeshCoordinate& coord) const;
Expand Down
9 changes: 3 additions & 6 deletions tt_metal/api/tt-metalium/system_mesh.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
#include <vector>

#include "mesh_config.hpp"
#include "mesh_device.hpp"
#include "device.hpp"
#include "mesh_coord.hpp"

namespace tt::tt_metal::distributed {

Expand All @@ -21,7 +20,6 @@ class SystemMesh {
class Impl; // Forward declaration only
std::unique_ptr<Impl> pimpl_;
SystemMesh();
~SystemMesh();

public:
static SystemMesh& instance();
Expand All @@ -30,11 +28,10 @@ class SystemMesh {
SystemMesh(SystemMesh&&) = delete;
SystemMesh& operator=(SystemMesh&&) = delete;

const MeshShape& get_shape() const;
size_t get_num_devices() const;
const SimpleMeshShape& get_shape() const;

// Gets the physical device ID for a given logical row and column index
chip_id_t get_physical_device_id(size_t logical_row_idx, size_t logical_col_idx) const;
chip_id_t get_physical_device_id(const MeshCoordinate& coord) const;

// Get the physical device IDs mapped to a MeshDevice
std::vector<chip_id_t> get_mapped_physical_device_ids(const MeshDeviceConfig& config) const;
Expand Down
58 changes: 24 additions & 34 deletions tt_metal/distributed/coordinate_translation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ CoordinateTranslationMap load_translation_map(const std::string& filename, const
TT_THROW("Invalid coordinate format in JSON file: {}", filename);
}
result.emplace(
Coordinate{mapping[0][0], mapping[0][1]},
MeshCoordinate(mapping[0][0], mapping[0][1]),
PhysicalCoordinate{
mapping[1][0], // cluster_id
mapping[1][2], // x
Expand All @@ -49,49 +49,39 @@ CoordinateTranslationMap load_translation_map(const std::string& filename, const
return result;
}

MeshShape get_system_mesh_shape(size_t system_num_devices) {
static const std::unordered_map<size_t, MeshShape> system_mesh_to_shape = {
{1, MeshShape{1, 1}}, // single-device
{2, MeshShape{1, 2}}, // N300
{8, MeshShape{2, 4}}, // T3000; as ring to match existing tests
{32, MeshShape{8, 4}}, // TG, QG
{64, MeshShape{8, 8}}, // TGG
};
TT_FATAL(
system_mesh_to_shape.contains(system_num_devices), "Unsupported number of devices: {}", system_num_devices);
auto shape = system_mesh_to_shape.at(system_num_devices);
log_debug(LogMetal, "Logical SystemMesh Shape: {}x{}", shape.num_rows, shape.num_cols);
return shape;
}

} // namespace

std::pair<CoordinateTranslationMap, MeshShape> get_system_mesh_coordinate_translation_map() {
static const auto* cached_translation_map = new std::pair<CoordinateTranslationMap, MeshShape>([] {
auto system_num_devices = tt::Cluster::instance().number_of_user_devices();
const std::pair<CoordinateTranslationMap, SimpleMeshShape>& get_system_mesh_coordinate_translation_map() {
static const auto* cached_translation_map = new std::pair<CoordinateTranslationMap, SimpleMeshShape>([] {
const auto system_num_devices = tt::Cluster::instance().number_of_user_devices();

std::string galaxy_mesh_descriptor = "TG.json";
if (tt::Cluster::instance().number_of_pci_devices() == system_num_devices) {
galaxy_mesh_descriptor = "QG.json";
}
const bool is_qg = tt::Cluster::instance().number_of_pci_devices() == system_num_devices;

const std::unordered_map<size_t, std::string> system_mesh_translation_map = {
{1, "device.json"},
{2, "N300.json"},
{8, "T3000.json"},
{32, galaxy_mesh_descriptor},
{64, "TGG.json"},
// TODO: #17477 - This assumes shapes and coordinates are in 2D. This will be extended for 3D.
// Consider if 1D can be used for single device and N300.
const std::unordered_map<size_t, std::pair<std::string, SimpleMeshShape>> system_mesh_translation_map = {
{1, std::make_pair("device.json", SimpleMeshShape(1, 1))},
{2, std::make_pair("N300.json", SimpleMeshShape(1, 2))},
{8, std::make_pair("T3000.json", SimpleMeshShape(2, 4))},
{32, std::make_pair(is_qg ? "QG.json" : "TG.json", SimpleMeshShape(8, 4))},
{64, std::make_pair("TGG.json", SimpleMeshShape(8, 8))},
Comment on lines +63 to +67
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Keeping these 2 together to ensure things are in sync.

};

TT_FATAL(
system_mesh_translation_map.contains(system_num_devices),
"Unsupported number of devices: {}",
system_num_devices);

auto translation_config_file = get_config_path(system_mesh_translation_map.at(system_num_devices));
return std::pair<CoordinateTranslationMap, MeshShape>{
load_translation_map(translation_config_file, "logical_to_physical_coordinates"),
get_system_mesh_shape(system_num_devices)};
const auto [translation_config_file, shape] = system_mesh_translation_map.at(system_num_devices);
TT_FATAL(
system_num_devices == shape.mesh_size(),
"Mismatch between number of devices and the mesh shape: {} != {}",
system_num_devices,
shape.mesh_size());
log_debug(LogMetal, "Logical SystemMesh Shape: {}", shape);

return std::pair<CoordinateTranslationMap, SimpleMeshShape>{
load_translation_map(get_config_path(translation_config_file), /*key=*/"logical_to_physical_coordinates"),
shape};
}());

return *cached_translation_map;
Expand Down
9 changes: 5 additions & 4 deletions tt_metal/distributed/coordinate_translation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,18 @@
#include <unordered_map>

#include "umd/device/types/cluster_descriptor_types.h"
#include <mesh_coord.hpp>
#include <mesh_device_view.hpp>

namespace tt::tt_metal::distributed {

// TODO: Consider conversion to StrongType instead of alias
using LogicalCoordinate = Coordinate;
using PhysicalCoordinate = eth_coord_t;
using CoordinateTranslationMap = std::unordered_map<LogicalCoordinate, PhysicalCoordinate>;
using CoordinateTranslationMap = std::unordered_map<MeshCoordinate, PhysicalCoordinate>;

// Returns a translation map between logical coordinates in logical 2D space
// Returns a translation map between logical coordinates in logical ND space
// to the physical coordinates as defined by the UMD layer.
std::pair<CoordinateTranslationMap, MeshShape> get_system_mesh_coordinate_translation_map();
// TODO: #17477 - Return MeshContainer<PhysicalCoordinate> that contains everything we need.
const std::pair<CoordinateTranslationMap, SimpleMeshShape>& get_system_mesh_coordinate_translation_map();

} // namespace tt::tt_metal::distributed
Loading
Loading