-
Notifications
You must be signed in to change notification settings - Fork 111
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
### Ticket #17477 ### Problem description Existing mesh infra assumes 2D. This assumption won't hold in the future. ### What's changed Introduce a new `SimpleMeshShape` that will gradually replace the existing `MeshShape`, after which it will be renamed to `MeshShape`. Introduce `MeshCoordinate`, `MeshCoordinateRange`, and `MeshContainer` - primitives designed to work with the new ND coordinate system. `MeshContainer` allows efficient flat representation of various metadata that matches the mesh shape. Iterators are available to make it easy to use. `MeshCoordinate` along with strides that are precomputed on `SimpleMeshShape` allows for an easy point access. The integration with `MeshBuffer` demonstrates the use case. Next steps: * Replace the existing `MeshShape`, `MeshOffset`, and the related aliases with the new `SimpleMeshShape`, and `MeshCoordinate`. * No plans to generalize with `CoreCoord`, for now. Cores are fundamentally in 2D, so a more specialized system can be used for efficiency. Also it is not desired to make `CoreCoord` to interop with `MeshCoordinate` - the 2 sets of coordinates mean entirely different concepts. * More functionality might be added, as we continue working on TT-distributed. ### Checklist - [X] [All post commit](https://github.com/tenstorrent/tt-metal/actions/runs/13347753550) - [X] New/Existing tests provide coverage for changes
- Loading branch information
1 parent
907fffd
commit 52c53d5
Showing
14 changed files
with
862 additions
and
30 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,290 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include <gtest/gtest.h> | ||
#include <gmock/gmock.h> | ||
|
||
#include "mesh_coord.hpp" | ||
|
||
namespace tt::tt_metal::distributed { | ||
namespace { | ||
|
||
using ::testing::ElementsAre; | ||
|
||
TEST(SimpleMeshShapeTest, Construction) { | ||
SimpleMeshShape shape_1d(3); | ||
EXPECT_EQ(shape_1d.dims(), 1); | ||
EXPECT_EQ(shape_1d[0], 3); | ||
EXPECT_EQ(shape_1d.mesh_size(), 3); | ||
|
||
SimpleMeshShape shape_2d(3, 4); | ||
EXPECT_EQ(shape_2d.dims(), 2); | ||
EXPECT_EQ(shape_2d[0], 3); | ||
EXPECT_EQ(shape_2d[1], 4); | ||
EXPECT_EQ(shape_2d.mesh_size(), 12); | ||
|
||
SimpleMeshShape shape_3d(2, 3, 4); | ||
EXPECT_EQ(shape_3d.dims(), 3); | ||
EXPECT_EQ(shape_3d[0], 2); | ||
EXPECT_EQ(shape_3d[1], 3); | ||
EXPECT_EQ(shape_3d[2], 4); | ||
EXPECT_EQ(shape_3d.mesh_size(), 24); | ||
|
||
SimpleMeshShape shape_5d({2, 3, 4, 5, 6}); | ||
EXPECT_EQ(shape_5d.dims(), 5); | ||
EXPECT_EQ(shape_5d[0], 2); | ||
EXPECT_EQ(shape_5d[1], 3); | ||
EXPECT_EQ(shape_5d[2], 4); | ||
EXPECT_EQ(shape_5d[3], 5); | ||
EXPECT_EQ(shape_5d[4], 6); | ||
EXPECT_EQ(shape_5d.mesh_size(), 720); | ||
} | ||
|
||
TEST(SimpleMeshShapeTest, ZeroShape) { | ||
SimpleMeshShape shape({}); | ||
EXPECT_EQ(shape.dims(), 0); | ||
EXPECT_EQ(shape.mesh_size(), 0); | ||
} | ||
|
||
TEST(SimpleMeshShapeTest, Strides) { | ||
SimpleMeshShape shape(2, 3, 4); | ||
EXPECT_EQ(shape.get_stride(0), 12); // 3 * 4 | ||
EXPECT_EQ(shape.get_stride(1), 4); // 4 | ||
EXPECT_EQ(shape.get_stride(2), 1); // 1 | ||
} | ||
|
||
TEST(SimpleMeshShapeTest, Comparison) { | ||
SimpleMeshShape shape(2, 3); | ||
|
||
EXPECT_EQ(shape, SimpleMeshShape(2, 3)); | ||
EXPECT_NE(shape, SimpleMeshShape(3, 2)); | ||
EXPECT_NE(shape, SimpleMeshShape(1, 2, 3)); | ||
} | ||
|
||
TEST(MeshCoordinateTest, Construction) { | ||
MeshCoordinate coord_1d(1); | ||
EXPECT_EQ(coord_1d.dims(), 1); | ||
EXPECT_THAT(coord_1d.coords(), ElementsAre(1)); | ||
EXPECT_EQ(coord_1d[0], 1); | ||
|
||
MeshCoordinate coord_2d(1, 2); | ||
EXPECT_EQ(coord_2d.dims(), 2); | ||
EXPECT_THAT(coord_2d.coords(), ElementsAre(1, 2)); | ||
EXPECT_EQ(coord_2d[0], 1); | ||
EXPECT_EQ(coord_2d[1], 2); | ||
|
||
MeshCoordinate coord_3d(1, 2, 3); | ||
EXPECT_EQ(coord_3d.dims(), 3); | ||
EXPECT_THAT(coord_3d.coords(), ElementsAre(1, 2, 3)); | ||
EXPECT_EQ(coord_3d[0], 1); | ||
EXPECT_EQ(coord_3d[1], 2); | ||
EXPECT_EQ(coord_3d[2], 3); | ||
|
||
std::vector<uint32_t> values = {1, 2, 3, 4, 5}; | ||
MeshCoordinate coord_span(values); | ||
EXPECT_EQ(coord_span.dims(), 5); | ||
EXPECT_THAT(coord_span.coords(), ElementsAre(1, 2, 3, 4, 5)); | ||
EXPECT_EQ(coord_span[0], 1); | ||
EXPECT_EQ(coord_span[1], 2); | ||
EXPECT_EQ(coord_span[2], 3); | ||
EXPECT_EQ(coord_span[3], 4); | ||
EXPECT_EQ(coord_span[4], 5); | ||
} | ||
|
||
TEST(MeshCoordinateTest, Comparison) { | ||
MeshCoordinate coord1(1, 2); | ||
|
||
EXPECT_EQ(coord1, MeshCoordinate(1, 2)); | ||
EXPECT_NE(coord1, MeshCoordinate(2, 1)); | ||
EXPECT_NE(coord1, MeshCoordinate(1, 2, 1)); | ||
} | ||
|
||
TEST(MeshCoordinateRangeTest, FromShape) { | ||
SimpleMeshShape shape(2, 3); | ||
MeshCoordinateRange range(shape); | ||
|
||
std::vector<MeshCoordinate> coords; | ||
for (const auto& coord : range) { | ||
coords.push_back(coord); | ||
} | ||
|
||
EXPECT_THAT( | ||
coords, | ||
ElementsAre( | ||
MeshCoordinate(0, 0), | ||
MeshCoordinate(0, 1), | ||
MeshCoordinate(0, 2), | ||
MeshCoordinate(1, 0), | ||
MeshCoordinate(1, 1), | ||
MeshCoordinate(1, 2))); | ||
} | ||
|
||
TEST(MeshCoordinateRangeTest, Subrange) { | ||
MeshCoordinate start(1, 1, 1); | ||
MeshCoordinate end(2, 1, 4); | ||
MeshCoordinateRange range(start, end); | ||
|
||
std::vector<MeshCoordinate> coords; | ||
for (const auto& coord : range) { | ||
coords.push_back(coord); | ||
} | ||
|
||
EXPECT_THAT( | ||
coords, | ||
ElementsAre( | ||
MeshCoordinate(1, 1, 1), | ||
MeshCoordinate(1, 1, 2), | ||
MeshCoordinate(1, 1, 3), | ||
MeshCoordinate(1, 1, 4), | ||
MeshCoordinate(2, 1, 1), | ||
MeshCoordinate(2, 1, 2), | ||
MeshCoordinate(2, 1, 3), | ||
MeshCoordinate(2, 1, 4))); | ||
} | ||
|
||
TEST(MeshCoordinateRangeTest, SubrangeOneElement) { | ||
MeshCoordinate start(1, 1, 1); | ||
MeshCoordinate end(1, 1, 1); | ||
MeshCoordinateRange range(start, end); | ||
|
||
std::vector<MeshCoordinate> coords; | ||
for (const auto& coord : range) { | ||
coords.push_back(coord); | ||
} | ||
|
||
EXPECT_THAT(coords, ElementsAre(MeshCoordinate(1, 1, 1))); | ||
} | ||
|
||
TEST(MeshCoordinateRangeTest, MismatchedDimensions) { | ||
MeshCoordinate start(1, 0); | ||
MeshCoordinate end(2, 3, 1); | ||
EXPECT_ANY_THROW(MeshCoordinateRange(start, end)); | ||
} | ||
|
||
TEST(MeshCoordinateRangeTest, InvalidRange) { | ||
MeshCoordinate start(1, 2, 0); | ||
MeshCoordinate end(1, 1, 1); | ||
EXPECT_ANY_THROW(MeshCoordinateRange(start, end)); | ||
} | ||
|
||
TEST(ToLinearIndexTest, Basic) { | ||
SimpleMeshShape shape(2, 2, 3); | ||
|
||
EXPECT_EQ(to_linear_index(shape, MeshCoordinate(0, 0, 0)), 0); | ||
EXPECT_EQ(to_linear_index(shape, MeshCoordinate(0, 0, 1)), 1); | ||
EXPECT_EQ(to_linear_index(shape, MeshCoordinate(0, 0, 2)), 2); | ||
EXPECT_EQ(to_linear_index(shape, MeshCoordinate(0, 1, 0)), 3); | ||
EXPECT_EQ(to_linear_index(shape, MeshCoordinate(0, 1, 1)), 4); | ||
EXPECT_EQ(to_linear_index(shape, MeshCoordinate(0, 1, 2)), 5); | ||
EXPECT_EQ(to_linear_index(shape, MeshCoordinate(1, 0, 0)), 6); | ||
EXPECT_EQ(to_linear_index(shape, MeshCoordinate(1, 0, 1)), 7); | ||
EXPECT_EQ(to_linear_index(shape, MeshCoordinate(1, 0, 2)), 8); | ||
EXPECT_EQ(to_linear_index(shape, MeshCoordinate(1, 1, 0)), 9); | ||
EXPECT_EQ(to_linear_index(shape, MeshCoordinate(1, 1, 1)), 10); | ||
EXPECT_EQ(to_linear_index(shape, MeshCoordinate(1, 1, 2)), 11); | ||
} | ||
|
||
TEST(ToLinearIndexTest, MismatchedDimensions) { | ||
EXPECT_ANY_THROW(to_linear_index(SimpleMeshShape(1, 2, 3), MeshCoordinate(0, 0))); | ||
} | ||
|
||
TEST(ToLinearIndexTest, OutOfBounds) { | ||
EXPECT_ANY_THROW(to_linear_index(SimpleMeshShape(2, 3), MeshCoordinate(2, 0))); | ||
EXPECT_ANY_THROW(to_linear_index(SimpleMeshShape(2, 3), MeshCoordinate(0, 3))); | ||
} | ||
|
||
TEST(MeshContainerTest, InitialValues) { | ||
SimpleMeshShape shape(2, 3); | ||
MeshContainer<int> container(shape, 3); | ||
|
||
std::vector<int> initial_values; | ||
for (const auto& [_, value] : container) { | ||
initial_values.push_back(value); | ||
} | ||
EXPECT_THAT(initial_values, ElementsAre(3, 3, 3, 3, 3, 3)); | ||
} | ||
|
||
TEST(MeshContainerTest, ElementAccessRowMajor) { | ||
SimpleMeshShape shape(2, 3); | ||
MeshContainer<int> container(shape, 0); | ||
|
||
container.at(MeshCoordinate(0, 0)) = 0; | ||
container.at(MeshCoordinate(0, 1)) = 1; | ||
container.at(MeshCoordinate(0, 2)) = 2; | ||
container.at(MeshCoordinate(1, 0)) = 3; | ||
container.at(MeshCoordinate(1, 1)) = 4; | ||
container.at(MeshCoordinate(1, 2)) = 5; | ||
|
||
std::vector<MeshCoordinate> coords; | ||
std::vector<int> values; | ||
for (const auto& [coord, value] : container) { | ||
coords.push_back(coord); | ||
values.push_back(value); | ||
} | ||
EXPECT_THAT( | ||
coords, | ||
ElementsAre( | ||
MeshCoordinate(0, 0), | ||
MeshCoordinate(0, 1), | ||
MeshCoordinate(0, 2), | ||
MeshCoordinate(1, 0), | ||
MeshCoordinate(1, 1), | ||
MeshCoordinate(1, 2))); | ||
EXPECT_THAT(values, ElementsAre(0, 1, 2, 3, 4, 5)); | ||
} | ||
|
||
TEST(MeshContainerTest, ConstContainer) { | ||
SimpleMeshShape shape(2, 3); | ||
const MeshContainer<int> container(shape, 0); | ||
|
||
std::vector<MeshCoordinate> coords; | ||
std::vector<int> values; | ||
for (const auto& [coord, value] : container) { | ||
coords.push_back(coord); | ||
values.push_back(value); | ||
} | ||
EXPECT_THAT( | ||
coords, | ||
ElementsAre( | ||
MeshCoordinate(0, 0), | ||
MeshCoordinate(0, 1), | ||
MeshCoordinate(0, 2), | ||
MeshCoordinate(1, 0), | ||
MeshCoordinate(1, 1), | ||
MeshCoordinate(1, 2))); | ||
EXPECT_THAT(values, ElementsAre(0, 0, 0, 0, 0, 0)); | ||
} | ||
|
||
TEST(MeshContainerTest, MutateThroughProxy) { | ||
SimpleMeshShape shape(2, 3); | ||
MeshContainer<int> container(shape, 0); | ||
|
||
// Proxy class provides access to the container value through the mutable reference. | ||
int updated_value = 0; | ||
for (auto& [_, value] : container) { | ||
value = updated_value++; | ||
} | ||
|
||
// `auto` makes a copy of the value, verify this loop is a no-op. | ||
for (auto [_, value] : container) { | ||
value = updated_value++; | ||
} | ||
|
||
std::vector<int> values; | ||
for (const auto& [_, value] : container) { | ||
values.push_back(value); | ||
} | ||
EXPECT_THAT(values, ElementsAre(0, 1, 2, 3, 4, 5)); | ||
} | ||
|
||
TEST(MeshContainerTest, OutOfBounds) { | ||
SimpleMeshShape shape(2, 3); | ||
MeshContainer<int> container(shape, 0); | ||
|
||
EXPECT_ANY_THROW(container.at(MeshCoordinate(2, 0))); | ||
EXPECT_ANY_THROW(container.at(MeshCoordinate(0, 0, 0))); | ||
} | ||
|
||
} // namespace | ||
} // namespace tt::tt_metal::distributed |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.