Skip to content

Commit

Permalink
#17477: Introduce ND coordinate system for TT-distributed (#17745)
Browse files Browse the repository at this point in the history
### Ticket
#17477 

### Problem description
Existing mesh infra assumes 2D. This assumption won't hold in the
future.

### What's changed
Introduce a new `SimpleMeshShape` that will gradually replace the
existing `MeshShape`, after which it will be renamed to `MeshShape`.

Introduce `MeshCoordinate`, `MeshCoordinateRange`, and `MeshContainer` -
primitives designed to work with the new ND coordinate system.

`MeshContainer` allows efficient flat representation of various metadata
that matches the mesh shape. Iterators are available to make it easy to
use. `MeshCoordinate`, along with strides that are precomputed on
`SimpleMeshShape`, allows for easy point access. The integration with
`MeshBuffer` demonstrates the use case.

Next steps:
* Replace the existing `MeshShape`, `MeshOffset`, and the related
aliases with the new `SimpleMeshShape`, and `MeshCoordinate`.
* No plans to generalize with `CoreCoord`, for now. Cores are
fundamentally in 2D, so a more specialized system can be used for
efficiency. Also, it is not desirable to make `CoreCoord` interoperate with
`MeshCoordinate` - the 2 sets of coordinates represent entirely different
concepts.
* More functionality might be added, as we continue working on
TT-distributed.

### Checklist
- [X] [All post
commit](https://github.com/tenstorrent/tt-metal/actions/runs/13347753550)
- [X] New/Existing tests provide coverage for changes
  • Loading branch information
omilyutin-tt authored Feb 16, 2025
1 parent 907fffd commit 52c53d5
Show file tree
Hide file tree
Showing 14 changed files with 862 additions and 30 deletions.
1 change: 1 addition & 0 deletions tests/tt_metal/distributed/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
set(UNIT_TESTS_DISTRIBUTED_SRC
${CMAKE_CURRENT_SOURCE_DIR}/test_distributed.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test_mesh_buffer.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test_mesh_coord.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test_mesh_workload.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test_mesh_sub_device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test_mesh_allocator.cpp
Expand Down
290 changes: 290 additions & 0 deletions tests/tt_metal/distributed/test_mesh_coord.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,290 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <gtest/gtest.h>
#include <gmock/gmock.h>

#include "mesh_coord.hpp"

namespace tt::tt_metal::distributed {
namespace {

using ::testing::ElementsAre;

// Builds shapes of rank 1, 2, 3 and 5 and checks the reported rank, the extent of every
// dimension, and the total element count.
TEST(SimpleMeshShapeTest, Construction) {
    {
        // Rank 1: single extent.
        SimpleMeshShape shape(3);
        EXPECT_EQ(shape.dims(), 1);
        EXPECT_EQ(shape[0], 3);
        EXPECT_EQ(shape.mesh_size(), 3);
    }

    {
        // Rank 2: 3 x 4 = 12 elements.
        SimpleMeshShape shape(3, 4);
        EXPECT_EQ(shape.dims(), 2);
        EXPECT_EQ(shape[0], 3);
        EXPECT_EQ(shape[1], 4);
        EXPECT_EQ(shape.mesh_size(), 12);
    }

    {
        // Rank 3: 2 x 3 x 4 = 24 elements.
        SimpleMeshShape shape(2, 3, 4);
        EXPECT_EQ(shape.dims(), 3);
        EXPECT_EQ(shape[0], 2);
        EXPECT_EQ(shape[1], 3);
        EXPECT_EQ(shape[2], 4);
        EXPECT_EQ(shape.mesh_size(), 24);
    }

    {
        // Rank 5, built from a braced list: 2 x 3 x 4 x 5 x 6 = 720 elements.
        SimpleMeshShape shape({2, 3, 4, 5, 6});
        EXPECT_EQ(shape.dims(), 5);
        EXPECT_EQ(shape[0], 2);
        EXPECT_EQ(shape[1], 3);
        EXPECT_EQ(shape[2], 4);
        EXPECT_EQ(shape[3], 5);
        EXPECT_EQ(shape[4], 6);
        EXPECT_EQ(shape.mesh_size(), 720);
    }
}

// A shape built from an empty list is expected to have no dimensions and no elements.
TEST(SimpleMeshShapeTest, ZeroShape) {
    SimpleMeshShape empty_shape({});

    EXPECT_EQ(empty_shape.dims(), 0);
    EXPECT_EQ(empty_shape.mesh_size(), 0);
}

// For a 2 x 3 x 4 shape, the stride of each dimension is expected to equal the product of
// the extents of all dimensions that follow it.
TEST(SimpleMeshShapeTest, Strides) {
    SimpleMeshShape cube(2, 3, 4);

    // Outermost dimension: skips a full 3 x 4 slab.
    EXPECT_EQ(cube.get_stride(0), 3 * 4);
    // Middle dimension: skips one row of 4.
    EXPECT_EQ(cube.get_stride(1), 4);
    // Innermost dimension: adjacent elements.
    EXPECT_EQ(cube.get_stride(2), 1);
}

// Shapes compare equal only when both the rank and every extent match.
TEST(SimpleMeshShapeTest, Comparison) {
    SimpleMeshShape reference(2, 3);

    // Same rank, same extents.
    EXPECT_EQ(reference, SimpleMeshShape(2, 3));
    // Same rank, extents swapped.
    EXPECT_NE(reference, SimpleMeshShape(3, 2));
    // Different rank.
    EXPECT_NE(reference, SimpleMeshShape(1, 2, 3));
}

// Builds coordinates of rank 1, 2 and 3 from individual values, and of rank 5 from a vector,
// then checks the reported rank, the full coordinate list, and indexed access.
TEST(MeshCoordinateTest, Construction) {
    {
        // Rank 1.
        MeshCoordinate coord(1);
        EXPECT_EQ(coord.dims(), 1);
        EXPECT_THAT(coord.coords(), ElementsAre(1));
        EXPECT_EQ(coord[0], 1);
    }

    {
        // Rank 2.
        MeshCoordinate coord(1, 2);
        EXPECT_EQ(coord.dims(), 2);
        EXPECT_THAT(coord.coords(), ElementsAre(1, 2));
        EXPECT_EQ(coord[0], 1);
        EXPECT_EQ(coord[1], 2);
    }

    {
        // Rank 3.
        MeshCoordinate coord(1, 2, 3);
        EXPECT_EQ(coord.dims(), 3);
        EXPECT_THAT(coord.coords(), ElementsAre(1, 2, 3));
        EXPECT_EQ(coord[0], 1);
        EXPECT_EQ(coord[1], 2);
        EXPECT_EQ(coord[2], 3);
    }

    {
        // Rank 5, built from a contiguous sequence of values.
        std::vector<uint32_t> raw_values = {1, 2, 3, 4, 5};
        MeshCoordinate coord(raw_values);
        EXPECT_EQ(coord.dims(), 5);
        EXPECT_THAT(coord.coords(), ElementsAre(1, 2, 3, 4, 5));
        EXPECT_EQ(coord[0], 1);
        EXPECT_EQ(coord[1], 2);
        EXPECT_EQ(coord[2], 3);
        EXPECT_EQ(coord[3], 4);
        EXPECT_EQ(coord[4], 5);
    }
}

// Coordinates compare equal only when both the rank and every component match.
TEST(MeshCoordinateTest, Comparison) {
    MeshCoordinate reference(1, 2);

    // Same rank, same components.
    EXPECT_EQ(reference, MeshCoordinate(1, 2));
    // Same rank, components swapped.
    EXPECT_NE(reference, MeshCoordinate(2, 1));
    // Different rank.
    EXPECT_NE(reference, MeshCoordinate(1, 2, 1));
}

// A range built from a 2 x 3 shape is expected to visit all six coordinates, starting at the
// origin, with the last dimension varying fastest.
TEST(MeshCoordinateRangeTest, FromShape) {
    SimpleMeshShape grid_shape(2, 3);
    MeshCoordinateRange full_range(grid_shape);

    std::vector<MeshCoordinate> visited;
    for (const auto& position : full_range) {
        visited.push_back(position);
    }

    EXPECT_THAT(
        visited,
        ElementsAre(
            // First row.
            MeshCoordinate(0, 0),
            MeshCoordinate(0, 1),
            MeshCoordinate(0, 2),
            // Second row.
            MeshCoordinate(1, 0),
            MeshCoordinate(1, 1),
            MeshCoordinate(1, 2)));
}

// A range built from explicit start and end coordinates is expected to include both endpoints
// and to iterate with the last dimension varying fastest.
TEST(MeshCoordinateRangeTest, Subrange) {
    MeshCoordinate lower(1, 1, 1);
    MeshCoordinate upper(2, 1, 4);
    MeshCoordinateRange sub_range(lower, upper);

    std::vector<MeshCoordinate> visited;
    for (const auto& position : sub_range) {
        visited.push_back(position);
    }

    EXPECT_THAT(
        visited,
        ElementsAre(
            // First index 1.
            MeshCoordinate(1, 1, 1),
            MeshCoordinate(1, 1, 2),
            MeshCoordinate(1, 1, 3),
            MeshCoordinate(1, 1, 4),
            // First index 2.
            MeshCoordinate(2, 1, 1),
            MeshCoordinate(2, 1, 2),
            MeshCoordinate(2, 1, 3),
            MeshCoordinate(2, 1, 4)));
}

// When start and end are the same coordinate, the range is expected to yield exactly that
// coordinate.
TEST(MeshCoordinateRangeTest, SubrangeOneElement) {
    MeshCoordinate lower(1, 1, 1);
    MeshCoordinate upper(1, 1, 1);
    MeshCoordinateRange single_point(lower, upper);

    std::vector<MeshCoordinate> visited;
    for (const auto& position : single_point) {
        visited.push_back(position);
    }

    EXPECT_THAT(visited, ElementsAre(MeshCoordinate(1, 1, 1)));
}

// Constructing a range from a rank-2 start and a rank-3 end is expected to throw.
TEST(MeshCoordinateRangeTest, MismatchedDimensions) {
    MeshCoordinate rank2_start(1, 0);
    MeshCoordinate rank3_end(2, 3, 1);

    EXPECT_ANY_THROW(MeshCoordinateRange(rank2_start, rank3_end));
}

// Constructing a range whose start exceeds its end in the second dimension (2 > 1) is
// expected to throw.
TEST(MeshCoordinateRangeTest, InvalidRange) {
    MeshCoordinate lower(1, 2, 0);
    MeshCoordinate upper(1, 1, 1);

    EXPECT_ANY_THROW(MeshCoordinateRange(lower, upper));
}

// Enumerates every coordinate of a 2 x 2 x 3 shape and checks that the linear index counts up
// from 0 to 11 with the last dimension varying fastest.
TEST(ToLinearIndexTest, Basic) {
    SimpleMeshShape mesh_shape(2, 2, 3);

    // First index 0, second index 0.
    EXPECT_EQ(to_linear_index(mesh_shape, MeshCoordinate(0, 0, 0)), 0);
    EXPECT_EQ(to_linear_index(mesh_shape, MeshCoordinate(0, 0, 1)), 1);
    EXPECT_EQ(to_linear_index(mesh_shape, MeshCoordinate(0, 0, 2)), 2);

    // First index 0, second index 1.
    EXPECT_EQ(to_linear_index(mesh_shape, MeshCoordinate(0, 1, 0)), 3);
    EXPECT_EQ(to_linear_index(mesh_shape, MeshCoordinate(0, 1, 1)), 4);
    EXPECT_EQ(to_linear_index(mesh_shape, MeshCoordinate(0, 1, 2)), 5);

    // First index 1, second index 0.
    EXPECT_EQ(to_linear_index(mesh_shape, MeshCoordinate(1, 0, 0)), 6);
    EXPECT_EQ(to_linear_index(mesh_shape, MeshCoordinate(1, 0, 1)), 7);
    EXPECT_EQ(to_linear_index(mesh_shape, MeshCoordinate(1, 0, 2)), 8);

    // First index 1, second index 1.
    EXPECT_EQ(to_linear_index(mesh_shape, MeshCoordinate(1, 1, 0)), 9);
    EXPECT_EQ(to_linear_index(mesh_shape, MeshCoordinate(1, 1, 1)), 10);
    EXPECT_EQ(to_linear_index(mesh_shape, MeshCoordinate(1, 1, 2)), 11);
}

// Linearizing a rank-2 coordinate against a rank-3 shape is expected to throw.
TEST(ToLinearIndexTest, MismatchedDimensions) {
    SimpleMeshShape rank3_shape(1, 2, 3);

    EXPECT_ANY_THROW(to_linear_index(rank3_shape, MeshCoordinate(0, 0)));
}

// Linearizing a coordinate whose component equals the extent of its dimension is expected to
// throw, for either dimension of a 2 x 3 shape.
TEST(ToLinearIndexTest, OutOfBounds) {
    SimpleMeshShape mesh_shape(2, 3);

    // First component equals the first extent.
    EXPECT_ANY_THROW(to_linear_index(mesh_shape, MeshCoordinate(2, 0)));
    // Second component equals the second extent.
    EXPECT_ANY_THROW(to_linear_index(mesh_shape, MeshCoordinate(0, 3)));
}

// A container built over a 2 x 3 shape with fill value 3 is expected to expose six elements,
// all equal to 3.
TEST(MeshContainerTest, InitialValues) {
    SimpleMeshShape grid_shape(2, 3);
    MeshContainer<int> filled(grid_shape, 3);

    std::vector<int> observed;
    for (const auto& [_, stored] : filled) {
        observed.push_back(stored);
    }

    EXPECT_THAT(observed, ElementsAre(3, 3, 3, 3, 3, 3));
}

// Writes a distinct value at every coordinate through `at()`, then iterates the container and
// checks that coordinates and values come back with the last dimension varying fastest.
TEST(MeshContainerTest, ElementAccessRowMajor) {
    SimpleMeshShape grid_shape(2, 3);
    MeshContainer<int> grid(grid_shape, 0);

    // First row.
    grid.at(MeshCoordinate(0, 0)) = 0;
    grid.at(MeshCoordinate(0, 1)) = 1;
    grid.at(MeshCoordinate(0, 2)) = 2;
    // Second row.
    grid.at(MeshCoordinate(1, 0)) = 3;
    grid.at(MeshCoordinate(1, 1)) = 4;
    grid.at(MeshCoordinate(1, 2)) = 5;

    std::vector<MeshCoordinate> seen_coords;
    std::vector<int> seen_values;
    for (const auto& [position, stored] : grid) {
        seen_coords.push_back(position);
        seen_values.push_back(stored);
    }

    EXPECT_THAT(
        seen_coords,
        ElementsAre(
            MeshCoordinate(0, 0),
            MeshCoordinate(0, 1),
            MeshCoordinate(0, 2),
            MeshCoordinate(1, 0),
            MeshCoordinate(1, 1),
            MeshCoordinate(1, 2)));
    EXPECT_THAT(seen_values, ElementsAre(0, 1, 2, 3, 4, 5));
}

// Iteration must also work on a const container: all six coordinates of a 2 x 3 shape are
// expected in order, each paired with the fill value 0.
TEST(MeshContainerTest, ConstContainer) {
    SimpleMeshShape grid_shape(2, 3);
    const MeshContainer<int> read_only(grid_shape, 0);

    std::vector<MeshCoordinate> seen_coords;
    std::vector<int> seen_values;
    for (const auto& [position, stored] : read_only) {
        seen_coords.push_back(position);
        seen_values.push_back(stored);
    }

    EXPECT_THAT(
        seen_coords,
        ElementsAre(
            MeshCoordinate(0, 0),
            MeshCoordinate(0, 1),
            MeshCoordinate(0, 2),
            MeshCoordinate(1, 0),
            MeshCoordinate(1, 1),
            MeshCoordinate(1, 2)));
    EXPECT_THAT(seen_values, ElementsAre(0, 0, 0, 0, 0, 0));
}

// Checks which range-for binding forms write through to the container:
// - `auto&` bindings are expected to modify the stored values;
// - plain `auto` bindings are expected to leave the stored values untouched.
// The final read-back must therefore show only the values written by the first loop.
TEST(MeshContainerTest, MutateThroughProxy) {
SimpleMeshShape shape(2, 3);
MeshContainer<int> container(shape, 0);

// Proxy class provides access to the container value through the mutable reference.
// This loop assigns 0..5 to the six elements in iteration order.
int updated_value = 0;
for (auto& [_, value] : container) {
value = updated_value++;
}

// `auto` makes a copy of the value, verify this loop is a no-op.
// The counter keeps advancing (6..11), so any write-through would be visible below.
for (auto [_, value] : container) {
value = updated_value++;
}

// Read the container back; only the values from the first loop should be present.
std::vector<int> values;
for (const auto& [_, value] : container) {
values.push_back(value);
}
EXPECT_THAT(values, ElementsAre(0, 1, 2, 3, 4, 5));
}

// `at()` on a container over a 2 x 3 shape is expected to throw both for a coordinate outside
// the shape and for a coordinate of the wrong rank.
TEST(MeshContainerTest, OutOfBounds) {
    SimpleMeshShape grid_shape(2, 3);
    MeshContainer<int> grid(grid_shape, 0);

    // First component equals the first extent.
    EXPECT_ANY_THROW(grid.at(MeshCoordinate(2, 0)));
    // Rank-3 coordinate against a rank-2 container.
    EXPECT_ANY_THROW(grid.at(MeshCoordinate(0, 0, 0)));
}

} // namespace
} // namespace tt::tt_metal::distributed
7 changes: 5 additions & 2 deletions tt_metal/api/tt-metalium/mesh_buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "buffer.hpp"
#include "buffer_constants.hpp"
#include "mesh_coord.hpp"
#include "mesh_device.hpp"
#include "mesh_device_view.hpp"
#include "shape2d.hpp"
Expand Down Expand Up @@ -96,6 +97,7 @@ class MeshBuffer {
const DeviceLocalBufferConfig& device_local_config() const { return device_local_config_; }

std::shared_ptr<Buffer> get_device_buffer(const Coordinate& device_coord) const;
std::shared_ptr<Buffer> get_device_buffer(const MeshCoordinate& device_coord) const;
uint32_t datum_size_bytes() const;
Shape2D physical_shard_shape() const;
std::pair<bool, bool> replicated_dims() const;
Expand All @@ -108,6 +110,7 @@ class MeshBuffer {
DeviceAddr device_local_size,
MeshDevice* mesh_device,
std::shared_ptr<Buffer> backing_buffer) :
buffers_(SimpleMeshShape(mesh_device->shape()), nullptr),
config_(config),
device_local_config_(device_local_config),
mesh_device_(mesh_device),
Expand All @@ -122,6 +125,7 @@ class MeshBuffer {
DeviceAddr address,
DeviceAddr device_local_size,
MeshDevice* mesh_device) :
buffers_(SimpleMeshShape(mesh_device->shape()), /*fill_value=*/nullptr),
config_(config),
device_local_config_(device_local_config),
mesh_device_(mesh_device),
Expand All @@ -136,8 +140,7 @@ class MeshBuffer {
DeviceAddr address_ = 0;
DeviceAddr device_local_size_ = 0;

// TODO: Consider optimizing with SmallVector.
std::vector<std::vector<std::shared_ptr<Buffer>>> buffers_;
MeshContainer<std::shared_ptr<Buffer>> buffers_;

// `MeshBufferState` specifies the state of the MeshBuffer. It can either be:
// 1. Owned - a single device buffer is responsible for providing the address for the entire mesh buffer.
Expand Down
Loading

0 comments on commit 52c53d5

Please sign in to comment.