Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#17477: Introduce ND coordinate system for TT-distributed #17745

Merged
merged 10 commits into from
Feb 16, 2025
1 change: 1 addition & 0 deletions tests/tt_metal/distributed/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
set(UNIT_TESTS_DISTRIBUTED_SRC
${CMAKE_CURRENT_SOURCE_DIR}/test_distributed.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test_mesh_buffer.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test_mesh_coord.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test_mesh_workload.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test_mesh_sub_device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test_mesh_allocator.cpp
Expand Down
262 changes: 262 additions & 0 deletions tests/tt_metal/distributed/test_mesh_coord.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,262 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <gtest/gtest.h>
#include <gmock/gmock.h>

#include "mesh_coord.hpp"

namespace tt::tt_metal::distributed {
namespace {

using ::testing::ElementsAre;

TEST(SimpleMeshShapeTest, Construction) {
SimpleMeshShape shape_1d(3);
EXPECT_EQ(shape_1d.dims(), 1);
EXPECT_EQ(shape_1d[0], 3);
EXPECT_EQ(shape_1d.mesh_size(), 3);

SimpleMeshShape shape_2d(3, 4);
EXPECT_EQ(shape_2d.dims(), 2);
EXPECT_EQ(shape_2d[0], 3);
EXPECT_EQ(shape_2d[1], 4);
EXPECT_EQ(shape_2d.mesh_size(), 12);

SimpleMeshShape shape_3d(2, 3, 4);
EXPECT_EQ(shape_3d.dims(), 3);
EXPECT_EQ(shape_3d[0], 2);
EXPECT_EQ(shape_3d[1], 3);
EXPECT_EQ(shape_3d[2], 4);
EXPECT_EQ(shape_3d.mesh_size(), 24);

SimpleMeshShape shape_5d({2, 3, 4, 5, 6});
EXPECT_EQ(shape_5d.dims(), 5);
EXPECT_EQ(shape_5d[0], 2);
EXPECT_EQ(shape_5d[1], 3);
EXPECT_EQ(shape_5d[2], 4);
EXPECT_EQ(shape_5d[3], 5);
EXPECT_EQ(shape_5d[4], 6);
EXPECT_EQ(shape_5d.mesh_size(), 720);
}

TEST(SimpleMeshShapeTest, ZeroShape) {
SimpleMeshShape shape({});
EXPECT_EQ(shape.dims(), 0);
EXPECT_EQ(shape.mesh_size(), 0);
}

TEST(SimpleMeshShapeTest, Strides) {
SimpleMeshShape shape(2, 3, 4);
EXPECT_EQ(shape.get_stride(0), 12); // 3 * 4
EXPECT_EQ(shape.get_stride(1), 4); // 4
EXPECT_EQ(shape.get_stride(2), 1); // 1
}

TEST(SimpleMeshShapeTest, Comparison) {
SimpleMeshShape shape(2, 3);

EXPECT_EQ(shape, SimpleMeshShape(2, 3));
EXPECT_NE(shape, SimpleMeshShape(3, 2));
EXPECT_NE(shape, SimpleMeshShape(1, 2, 3));
}

TEST(MeshCoordinateTest, Construction) {
MeshCoordinate coord_1d(1);
EXPECT_EQ(coord_1d.dims(), 1);
EXPECT_THAT(coord_1d.coords(), ElementsAre(1));
EXPECT_EQ(coord_1d[0], 1);

MeshCoordinate coord_2d(1, 2);
EXPECT_EQ(coord_2d.dims(), 2);
EXPECT_THAT(coord_2d.coords(), ElementsAre(1, 2));
EXPECT_EQ(coord_2d[0], 1);
EXPECT_EQ(coord_2d[1], 2);

MeshCoordinate coord_3d(1, 2, 3);
EXPECT_EQ(coord_3d.dims(), 3);
EXPECT_THAT(coord_3d.coords(), ElementsAre(1, 2, 3));
EXPECT_EQ(coord_3d[0], 1);
EXPECT_EQ(coord_3d[1], 2);
EXPECT_EQ(coord_3d[2], 3);

std::vector<uint32_t> values = {1, 2, 3, 4, 5};
MeshCoordinate coord_span(values);
EXPECT_EQ(coord_span.dims(), 5);
EXPECT_THAT(coord_span.coords(), ElementsAre(1, 2, 3, 4, 5));
EXPECT_EQ(coord_span[0], 1);
EXPECT_EQ(coord_span[1], 2);
EXPECT_EQ(coord_span[2], 3);
EXPECT_EQ(coord_span[3], 4);
EXPECT_EQ(coord_span[4], 5);
}

TEST(MeshCoordinateTest, Comparison) {
MeshCoordinate coord1(1, 2);

EXPECT_EQ(coord1, MeshCoordinate(1, 2));
EXPECT_NE(coord1, MeshCoordinate(2, 1));
EXPECT_NE(coord1, MeshCoordinate(1, 2, 1));
}

TEST(MeshCoordinateRangeTest, FromShape) {
SimpleMeshShape shape(2, 3);
MeshCoordinateRange range(shape);

std::vector<MeshCoordinate> coords;
for (const auto& coord : range) {
coords.push_back(coord);
}

EXPECT_THAT(
coords,
ElementsAre(
MeshCoordinate(0, 0),
MeshCoordinate(0, 1),
MeshCoordinate(0, 2),
MeshCoordinate(1, 0),
MeshCoordinate(1, 1),
MeshCoordinate(1, 2)));
}

TEST(MeshCoordinateRangeTest, Subrange) {
MeshCoordinate start(1, 1, 1);
MeshCoordinate end(2, 1, 4);
MeshCoordinateRange range(start, end);

std::vector<MeshCoordinate> coords;
for (const auto& coord : range) {
coords.push_back(coord);
}

EXPECT_THAT(
coords,
ElementsAre(
MeshCoordinate(1, 1, 1),
MeshCoordinate(1, 1, 2),
MeshCoordinate(1, 1, 3),
MeshCoordinate(1, 1, 4),
MeshCoordinate(2, 1, 1),
MeshCoordinate(2, 1, 2),
MeshCoordinate(2, 1, 3),
MeshCoordinate(2, 1, 4)));
}

TEST(MeshCoordinateRangeTest, SubrangeOneElement) {
MeshCoordinate start(1, 1, 1);
MeshCoordinate end(1, 1, 1);
MeshCoordinateRange range(start, end);

std::vector<MeshCoordinate> coords;
for (const auto& coord : range) {
coords.push_back(coord);
}

EXPECT_THAT(coords, ElementsAre(MeshCoordinate(1, 1, 1)));
}

TEST(MeshCoordinateRangeTest, MismatchedDimensions) {
MeshCoordinate start(1, 0);
MeshCoordinate end(2, 3, 1);
EXPECT_ANY_THROW(MeshCoordinateRange(start, end));
}

TEST(MeshCoordinateRangeTest, InvalidRange) {
MeshCoordinate start(1, 2, 0);
MeshCoordinate end(1, 1, 1);
EXPECT_ANY_THROW(MeshCoordinateRange(start, end));
}

TEST(ToLinearIndexTest, Basic) {
SimpleMeshShape shape(2, 3);

EXPECT_EQ(to_linear_index(shape, MeshCoordinate(0, 0)), 0);
EXPECT_EQ(to_linear_index(shape, MeshCoordinate(0, 1)), 1);
EXPECT_EQ(to_linear_index(shape, MeshCoordinate(0, 2)), 2);
EXPECT_EQ(to_linear_index(shape, MeshCoordinate(1, 0)), 3);
EXPECT_EQ(to_linear_index(shape, MeshCoordinate(1, 1)), 4);
EXPECT_EQ(to_linear_index(shape, MeshCoordinate(1, 2)), 5);
}

TEST(ToLinearIndexTest, MismatchedDimensions) {
EXPECT_ANY_THROW(to_linear_index(SimpleMeshShape(1, 2, 3), MeshCoordinate(0, 0)));
}

TEST(ToLinearIndexTest, OutOfBounds) {
EXPECT_ANY_THROW(to_linear_index(SimpleMeshShape(2, 3), MeshCoordinate(2, 0)));
EXPECT_ANY_THROW(to_linear_index(SimpleMeshShape(2, 3), MeshCoordinate(0, 3)));
}

TEST(MeshContainerTest, InitialValues) {
SimpleMeshShape shape(2, 3);
MeshContainer<int> container(shape, 3);

std::vector<int> initial_values;
for (const auto& [coord, value] : container) {
initial_values.push_back(value);
}
EXPECT_THAT(initial_values, ElementsAre(3, 3, 3, 3, 3, 3));
}

TEST(MeshContainerTest, ElementAccessRowMajor) {
SimpleMeshShape shape(2, 3);
MeshContainer<int> container(shape, 0);

container.at(MeshCoordinate(0, 0)) = 0;
container.at(MeshCoordinate(0, 1)) = 1;
container.at(MeshCoordinate(0, 2)) = 2;
container.at(MeshCoordinate(1, 0)) = 3;
container.at(MeshCoordinate(1, 1)) = 4;
container.at(MeshCoordinate(1, 2)) = 5;

std::vector<MeshCoordinate> coords;
std::vector<int> values;
for (const auto& [coord, value] : container) {
coords.push_back(coord);
values.push_back(value);
}
EXPECT_THAT(
coords,
ElementsAre(
MeshCoordinate(0, 0),
MeshCoordinate(0, 1),
MeshCoordinate(0, 2),
MeshCoordinate(1, 0),
MeshCoordinate(1, 1),
MeshCoordinate(1, 2)));
EXPECT_THAT(values, ElementsAre(0, 1, 2, 3, 4, 5));
}

TEST(MeshContainerTest, MutateThroughProxy) {
SimpleMeshShape shape(2, 3);
MeshContainer<int> container(shape, 0);

// Proxy class provides access to the container value through the mutable reference.
int updated_value = 0;
for (auto& [coord, value] : container) {
value = updated_value++;
}

// `auto` makes a copy of the value, verify this loop is a no-op.
for (auto [coord, value] : container) {
value = updated_value++;
}

std::vector<int> values;
for (const auto& [coord, value] : container) {
values.push_back(value);
}
EXPECT_THAT(values, ElementsAre(0, 1, 2, 3, 4, 5));
}

TEST(MeshContainerTest, OutOfBounds) {
SimpleMeshShape shape(2, 3);
MeshContainer<int> container(shape, 0);

EXPECT_ANY_THROW(container.at(MeshCoordinate(2, 0)));
EXPECT_ANY_THROW(container.at(MeshCoordinate(0, 0, 0)));
}

} // namespace
} // namespace tt::tt_metal::distributed
7 changes: 5 additions & 2 deletions tt_metal/api/tt-metalium/mesh_buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "buffer.hpp"
#include "buffer_constants.hpp"
#include "mesh_coord.hpp"
#include "mesh_device.hpp"
#include "mesh_device_view.hpp"
#include "shape2d.hpp"
Expand Down Expand Up @@ -96,6 +97,7 @@ class MeshBuffer {
const DeviceLocalBufferConfig& device_local_config() const { return device_local_config_; }

std::shared_ptr<Buffer> get_device_buffer(const Coordinate& device_coord) const;
std::shared_ptr<Buffer> get_device_buffer(const MeshCoordinate& device_coord) const;
uint32_t datum_size_bytes() const;
Shape2D physical_shard_shape() const;
std::pair<bool, bool> replicated_dims() const;
Expand All @@ -108,6 +110,7 @@ class MeshBuffer {
DeviceAddr device_local_size,
MeshDevice* mesh_device,
std::shared_ptr<Buffer> backing_buffer) :
buffers_(SimpleMeshShape(mesh_device->shape()), nullptr),
config_(config),
device_local_config_(device_local_config),
mesh_device_(mesh_device),
Expand All @@ -122,6 +125,7 @@ class MeshBuffer {
DeviceAddr address,
DeviceAddr device_local_size,
MeshDevice* mesh_device) :
buffers_(SimpleMeshShape(mesh_device->shape()), /*fill_value=*/nullptr),
config_(config),
device_local_config_(device_local_config),
mesh_device_(mesh_device),
Expand All @@ -136,8 +140,7 @@ class MeshBuffer {
DeviceAddr address_ = 0;
DeviceAddr device_local_size_ = 0;

// TODO: Consider optimizing with SmallVector.
std::vector<std::vector<std::shared_ptr<Buffer>>> buffers_;
MeshContainer<std::shared_ptr<Buffer>> buffers_;

// `MeshBufferState` specifies the state of the MeshBuffer. It can either be:
// 1. Owned - a single device buffer is responsible for providing the address for the entire mesh buffer.
Expand Down
Loading
Loading