From d53909a6c964f3d65b697477ef8b42c84905fe58 Mon Sep 17 00:00:00 2001 From: Oleg Milyutin Date: Fri, 7 Feb 2025 07:05:23 +0000 Subject: [PATCH 01/10] Remove std::span dependency from shapes --- tt_metal/api/tt-metalium/shape_base.hpp | 6 +++--- tt_metal/common/shape_base.cpp | 2 +- ttnn/cpp/pybind11/pytensor.cpp | 6 ++---- ttnn/cpp/ttnn/operations/data_movement/pad/pad.cpp | 5 +---- .../repeat/device/host/repeat_program_factory.cpp | 8 ++++---- 5 files changed, 11 insertions(+), 16 deletions(-) diff --git a/tt_metal/api/tt-metalium/shape_base.hpp b/tt_metal/api/tt-metalium/shape_base.hpp index 350e8833d82..cb207b79794 100644 --- a/tt_metal/api/tt-metalium/shape_base.hpp +++ b/tt_metal/api/tt-metalium/shape_base.hpp @@ -5,9 +5,9 @@ #pragma once #include -#include #include "small_vector.hpp" +#include "span.hpp" namespace tt::tt_metal { @@ -24,7 +24,7 @@ class ShapeBase { explicit ShapeBase(const std::array& arr) : value_(arr.begin(), arr.end()) { init(); } - explicit ShapeBase(std::span span) : value_(span.begin(), span.end()) { init(); } + explicit ShapeBase(tt::stl::Span span) : value_(span.begin(), span.end()) { init(); } template bool operator==(const std::array& other) const { @@ -42,7 +42,7 @@ class ShapeBase { Container::const_iterator cbegin() const; Container::const_iterator cend() const; - std::span view() const; + tt::stl::Span view() const; bool empty() const; diff --git a/tt_metal/common/shape_base.cpp b/tt_metal/common/shape_base.cpp index 57e69bb49e6..fdb94487ed8 100644 --- a/tt_metal/common/shape_base.cpp +++ b/tt_metal/common/shape_base.cpp @@ -46,7 +46,7 @@ bool ShapeBase::empty() const { return original_size_ == 0; } size_t ShapeBase::size() const { return original_size_; } -std::span ShapeBase::view() const { return std::span(cbegin(), cend()); } +tt::stl::Span ShapeBase::view() const { return tt::stl::Span(value_); } bool ShapeBase::operator==(const ShapeBase& other) const = default; diff --git a/ttnn/cpp/pybind11/pytensor.cpp b/ttnn/cpp/pybind11/pytensor.cpp index f6e55603d8a..c26294362e1 100644 --- a/ttnn/cpp/pybind11/pytensor.cpp +++ b/ttnn/cpp/pybind11/pytensor.cpp @@ -520,8 +520,7 @@ py::object convert_tt_tensor_to_torch_tensor(const Tensor& tt_tensor, const bool buffer); auto logical_shape = tt_tensor.get_logical_shape(); - auto view = logical_shape.view(); - std::vector torch_shape(view.begin(), view.end()); + std::vector torch_shape(logical_shape.cbegin(), logical_shape.cend()); auto tensor = [&]() { if (tt_tensor.volume() == 0) { auto pytorch_empty = torch.attr("empty"); @@ -580,8 +579,7 @@ py::object convert_tt_tensor_to_numpy_tensor(const Tensor& tt_tensor) { buffer); auto logical_shape = tt_tensor.get_logical_shape(); - auto view = logical_shape.view(); - std::vector np_shape(view.begin(), view.end()); + std::vector np_shape(logical_shape.cbegin(), logical_shape.cend()); auto tensor = frombuffer(buffer, py::arg("dtype") = np_dtype); tensor = tensor.attr("reshape")(np_shape); tensor = np.attr("ascontiguousarray")(tensor); diff --git a/ttnn/cpp/ttnn/operations/data_movement/pad/pad.cpp b/ttnn/cpp/ttnn/operations/data_movement/pad/pad.cpp index b5232f2c464..9e4382f3d73 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/pad/pad.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/pad/pad.cpp @@ -14,10 +14,7 @@ namespace ttnn::operations::data_movement { namespace { -template -bool eq_spans(const ArrayType& a, const ArrayType& b) { - return std::equal(a.begin(), a.end(), b.begin(), b.end()); -} +bool eq_spans(const auto a, const auto b) { return std::equal(a.begin(), a.end(), b.begin(), b.end()); } ttnn::Shape update_original_shape(const ttnn::Shape& padded_shape, const ttnn::Shape& input_shape) { ttnn::SmallVector updated_shape; diff --git a/ttnn/cpp/ttnn/operations/data_movement/repeat/device/host/repeat_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/repeat/device/host/repeat_program_factory.cpp index e8266b2ee50..cc28c0610d0 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/repeat/device/host/repeat_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/repeat/device/host/repeat_program_factory.cpp @@ -39,8 +39,8 @@ tt::tt_metal::operation::ProgramWithCallbacks rm_repeater_last_dim( uint32_t num_cores_y = compute_with_storage_grid_size.y; uint32_t num_cores_total = num_cores_x * num_cores_y; CoreRange total_cores({0, 0}, {num_cores_x - 1, num_cores_y - 1}); - ttnn::Shape input_log_shape = ttnn::Shape(input.get_logical_shape().view()); - ttnn::Shape output_log_shape = ttnn::Shape(output.get_logical_shape().view()); + ttnn::Shape input_log_shape = input.get_logical_shape(); + ttnn::Shape output_log_shape = output.get_logical_shape(); tt::log_debug("row major reshape"); tt::log_debug("input shape: {}", input_log_shape); tt::log_debug("output shape: {}", output_log_shape); @@ -139,8 +139,8 @@ tt::tt_metal::operation::ProgramWithCallbacks rm_repeater( uint32_t num_cores_total = num_cores_x * num_cores_y; CoreRange total_cores({0, 0}, {num_cores_x - 1, num_cores_y - 1}); - ttnn::Shape input_log_shape = ttnn::Shape(input.get_logical_shape().view()); - ttnn::Shape output_log_shape = ttnn::Shape(output.get_logical_shape().view()); + ttnn::Shape input_log_shape = input.get_logical_shape(); + ttnn::Shape output_log_shape = output.get_logical_shape(); tt::log_debug("row major reshape"); tt::log_debug("input shape: {}", input_log_shape); tt::log_debug("output shape: {}", output_log_shape); From a9b0079b2e47b6ddf41b248b9e80ce7b4135f028 Mon Sep 17 00:00:00 2001 From: Oleg Milyutin Date: Fri, 7 Feb 2025 22:38:57 +0000 Subject: [PATCH 02/10] ND mesh coordinate system --- tests/tt_metal/distributed/CMakeLists.txt | 1 + .../tt_metal/distributed/test_mesh_coord.cpp | 199 +++++++++++++++ tt_metal/api/tt-metalium/mesh_buffer.hpp | 7 +- tt_metal/api/tt-metalium/mesh_coord.hpp | 230 ++++++++++++++++++ tt_metal/api/tt-metalium/mesh_device.hpp | 2 + tt_metal/common/CMakeLists.txt | 1 + tt_metal/common/mesh_coord.cpp | 161 ++++++++++++ tt_metal/distributed/mesh_buffer.cpp | 28 +-- tt_metal/distributed/mesh_device.cpp | 6 +- 9 files changed, 615 insertions(+), 20 deletions(-) create mode 100644 tests/tt_metal/distributed/test_mesh_coord.cpp create mode 100644 tt_metal/api/tt-metalium/mesh_coord.hpp create mode 100644 tt_metal/common/mesh_coord.cpp diff --git a/tests/tt_metal/distributed/CMakeLists.txt b/tests/tt_metal/distributed/CMakeLists.txt index 27bb9ee7b53..08fededb592 100644 --- a/tests/tt_metal/distributed/CMakeLists.txt +++ b/tests/tt_metal/distributed/CMakeLists.txt @@ -1,6 +1,7 @@ set(UNIT_TESTS_DISTRIBUTED_SRC ${CMAKE_CURRENT_SOURCE_DIR}/test_distributed.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test_mesh_buffer.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/test_mesh_coord.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test_mesh_workload.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test_mesh_sub_device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test_mesh_allocator.cpp diff --git a/tests/tt_metal/distributed/test_mesh_coord.cpp b/tests/tt_metal/distributed/test_mesh_coord.cpp new file mode 100644 index 00000000000..d6284e64467 --- /dev/null +++ b/tests/tt_metal/distributed/test_mesh_coord.cpp @@ -0,0 +1,199 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include + +#include "mesh_coord.hpp" + +namespace tt::tt_metal::distributed { +namespace { + +using ::testing::ElementsAre; + +TEST(SimpleMeshShapeTest, Construction) { + SimpleMeshShape shape_2d(3, 4); + EXPECT_EQ(shape_2d.dims(), 2); + EXPECT_EQ(shape_2d[0], 3); + EXPECT_EQ(shape_2d[1], 4); + EXPECT_EQ(shape_2d.mesh_size(), 12); + + SimpleMeshShape shape_3d(2, 3, 4); + EXPECT_EQ(shape_3d.dims(), 3); + EXPECT_EQ(shape_3d[0], 2); + EXPECT_EQ(shape_3d[1], 3); + EXPECT_EQ(shape_3d[2], 4); + EXPECT_EQ(shape_3d.mesh_size(), 24); + + SimpleMeshShape shape_5d({2, 3, 4, 5, 6}); + EXPECT_EQ(shape_5d.dims(), 5); + EXPECT_EQ(shape_5d[0], 2); + EXPECT_EQ(shape_5d[1], 3); + EXPECT_EQ(shape_5d[2], 4); + EXPECT_EQ(shape_5d[3], 5); + EXPECT_EQ(shape_5d[4], 6); + EXPECT_EQ(shape_5d.mesh_size(), 720); +} + +TEST(SimpleMeshShapeTest, Strides) { + SimpleMeshShape shape(2, 3, 4); + EXPECT_EQ(shape.get_stride(0), 12); // 3 * 4 + EXPECT_EQ(shape.get_stride(1), 4); // 4 + EXPECT_EQ(shape.get_stride(2), 1); // 1 +} + +TEST(SimpleMeshShapeTest, Comparison) { + SimpleMeshShape shape1(2, 3); + SimpleMeshShape shape2(2, 3); + SimpleMeshShape shape3(3, 2); + + EXPECT_EQ(shape1, shape2); + EXPECT_NE(shape1, shape3); +} + +TEST(MeshCoordinateTest, Construction) { + MeshCoordinate coord_2d(1, 2); + EXPECT_EQ(coord_2d.dims(), 2); + EXPECT_THAT(coord_2d.coords(), ElementsAre(1, 2)); + EXPECT_EQ(coord_2d[0], 1); + EXPECT_EQ(coord_2d[1], 2); + + MeshCoordinate coord_3d(1, 2, 3); + EXPECT_EQ(coord_3d.dims(), 3); + EXPECT_THAT(coord_3d.coords(), ElementsAre(1, 2, 3)); + EXPECT_EQ(coord_3d[0], 1); + EXPECT_EQ(coord_3d[1], 2); + EXPECT_EQ(coord_3d[2], 3); + + std::vector values = {1, 2, 3, 4, 5}; + MeshCoordinate coord_span(values); + EXPECT_EQ(coord_span.dims(), 5); + EXPECT_THAT(coord_span.coords(), ElementsAre(1, 2, 3, 4, 5)); + EXPECT_EQ(coord_span[0], 1); + EXPECT_EQ(coord_span[1], 2); + EXPECT_EQ(coord_span[2], 3); + EXPECT_EQ(coord_span[3], 4); + EXPECT_EQ(coord_span[4], 5); +} + +TEST(MeshCoordinateTest, Comparison) { + MeshCoordinate coord1(1, 2); + MeshCoordinate coord2(1, 2); + MeshCoordinate coord3(2, 1); + + EXPECT_EQ(coord1, coord2); + EXPECT_NE(coord1, coord3); +} + +TEST(MeshCoordinateRangeTest, FromShape) { + SimpleMeshShape shape(2, 3); + MeshCoordinateRange range(shape); + + std::vector coords; + for (const auto& coord : range) { + coords.push_back(coord); + } + + EXPECT_THAT( + coords, + ElementsAre( + MeshCoordinate(0, 0), + MeshCoordinate(0, 1), + MeshCoordinate(0, 2), + MeshCoordinate(1, 0), + MeshCoordinate(1, 1), + MeshCoordinate(1, 2))); +} + +TEST(MeshCoordinateRangeTest, Subrange) { + MeshCoordinate start(1, 0); + MeshCoordinate end(2, 2); + MeshCoordinateRange range(start, end); + + std::vector coords; + for (const auto& coord : range) { + coords.push_back(coord); + } + + EXPECT_THAT( + coords, + ElementsAre( + MeshCoordinate(1, 0), + MeshCoordinate(1, 1), + MeshCoordinate(1, 2), + MeshCoordinate(2, 0), + MeshCoordinate(2, 1), + MeshCoordinate(2, 2))); +} + +TEST(MeshCoordinateRangeTest, MismatchedDimensions) { + MeshCoordinate start(1, 0); + MeshCoordinate end(2, 3, 1); + EXPECT_ANY_THROW(MeshCoordinateRange(start, end)); +} + +TEST(MeshCoordinateRangeTest, InvalidRange) { + MeshCoordinate start(1, 2, 0); + MeshCoordinate end(1, 1, 1); + EXPECT_ANY_THROW(MeshCoordinateRange(start, end)); +} + +TEST(ToLinearIndexTest, Basic) { + SimpleMeshShape shape(2, 3); + + EXPECT_EQ(to_linear_index(shape, MeshCoordinate(0, 0)), 0); + EXPECT_EQ(to_linear_index(shape, MeshCoordinate(0, 1)), 1); + EXPECT_EQ(to_linear_index(shape, MeshCoordinate(0, 2)), 2); + EXPECT_EQ(to_linear_index(shape, MeshCoordinate(1, 0)), 3); + EXPECT_EQ(to_linear_index(shape, MeshCoordinate(1, 1)), 4); + EXPECT_EQ(to_linear_index(shape, MeshCoordinate(1, 2)), 5); +} + +TEST(ToLinearIndexTest, MismatchedDimensions) { + EXPECT_ANY_THROW(to_linear_index(SimpleMeshShape(2, 3), MeshCoordinate(2, 0))); +} + +TEST(ToLinearIndexTest, OutOfBounds) { + EXPECT_ANY_THROW(to_linear_index(SimpleMeshShape(1, 2, 3), MeshCoordinate(0, 0))); +} + +TEST(MeshContainerTest, InitialValues) { + SimpleMeshShape shape(2, 3); + MeshContainer container(shape, 3); + + std::vector initial_values; + for (const auto& [coord, value] : container) { + initial_values.push_back(value); + } + EXPECT_THAT(initial_values, ElementsAre(3, 3, 3, 3, 3, 3)); +} + +TEST(MeshContainerTest, ElementAccessRowMajor) { + SimpleMeshShape shape(2, 3); + MeshContainer container(shape, 0); + + container.at(MeshCoordinate(0, 0)) = 0; + container.at(MeshCoordinate(0, 1)) = 1; + container.at(MeshCoordinate(0, 2)) = 2; + container.at(MeshCoordinate(1, 0)) = 3; + container.at(MeshCoordinate(1, 1)) = 4; + container.at(MeshCoordinate(1, 2)) = 5; + + std::vector values; + for (const auto& [coord, value] : container) { + values.push_back(value); + } + EXPECT_THAT(values, ElementsAre(0, 1, 2, 3, 4, 5)); +} + +TEST(MeshContainerTest, OutOfBounds) { + SimpleMeshShape shape(2, 3); + MeshContainer container(shape, 0); + + EXPECT_ANY_THROW(container.at(MeshCoordinate(2, 0))); + EXPECT_ANY_THROW(container.at(MeshCoordinate(0, 0, 0))); +} + +} // namespace +} // namespace tt::tt_metal::distributed diff --git a/tt_metal/api/tt-metalium/mesh_buffer.hpp b/tt_metal/api/tt-metalium/mesh_buffer.hpp index 0e029685b47..8656fc02e67 100644 --- a/tt_metal/api/tt-metalium/mesh_buffer.hpp +++ b/tt_metal/api/tt-metalium/mesh_buffer.hpp @@ -6,6 +6,7 @@ #include "buffer.hpp" #include "buffer_constants.hpp" +#include "mesh_coord.hpp" #include "mesh_device.hpp" #include "mesh_device_view.hpp" #include "shape2d.hpp" @@ -96,6 +97,7 @@ class MeshBuffer { const DeviceLocalBufferConfig& device_local_config() const { return device_local_config_; } std::shared_ptr get_device_buffer(const Coordinate& device_coord) const; + std::shared_ptr get_device_buffer(const MeshCoordinate& device_coord) const; uint32_t datum_size_bytes() const; Shape2D physical_shard_shape() const; std::pair replicated_dims() const; @@ -108,6 +110,7 @@ class MeshBuffer { DeviceAddr device_local_size, MeshDevice* mesh_device, std::shared_ptr backing_buffer) : + buffers_(SimpleMeshShape(mesh_device->shape()), nullptr), config_(config), device_local_config_(device_local_config), mesh_device_(mesh_device), @@ -122,6 +125,7 @@ class MeshBuffer { DeviceAddr address, DeviceAddr device_local_size, MeshDevice* mesh_device) : + buffers_(SimpleMeshShape(mesh_device->shape()), /*fill_value=*/nullptr), config_(config), device_local_config_(device_local_config), mesh_device_(mesh_device), @@ -136,8 +140,7 @@ class MeshBuffer { DeviceAddr address_ = 0; DeviceAddr device_local_size_ = 0; - // TODO: Consider optimizing with SmallVector. - std::vector>> buffers_; + MeshContainer> buffers_; // `MeshBufferState` specifies the state of the MeshBuffer. It can either be: // 1. Owned - a single device buffer is responsible for providing the address for the entire mesh buffer. diff --git a/tt_metal/api/tt-metalium/mesh_coord.hpp b/tt_metal/api/tt-metalium/mesh_coord.hpp new file mode 100644 index 00000000000..16556d41e3d --- /dev/null +++ b/tt_metal/api/tt-metalium/mesh_coord.hpp @@ -0,0 +1,230 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +#include "shape_base.hpp" + +namespace tt::tt_metal::distributed { + +struct MeshShape; + +// TODO: #17477 - Rename to `MeshShape` when the legacy type is gone. +class SimpleMeshShape : public ShapeBase { +public: + using ShapeBase::ShapeBase; + using ShapeBase::operator[]; + using ShapeBase::cbegin; + using ShapeBase::cend; + using ShapeBase::empty; + using ShapeBase::size; + using ShapeBase::view; + + // Shorthands for constructing 2D and 3D shapes. + SimpleMeshShape(uint32_t num_rows, uint32_t num_cols); + SimpleMeshShape(uint32_t x, uint32_t y, uint32_t z); + + // Temporary constructor for transitioning to `SimpleMeshShape`. + SimpleMeshShape(const MeshShape& legacy_shape); + + // Returns the dimensionality of the mesh. + size_t dims() const; + + // Returns the stride for the given dimension. + size_t get_stride(size_t dim) const; + + // Returns the total number of elements in the mesh. + size_t mesh_size() const; + + // Needed for reflect / fmt + static constexpr auto attribute_names = std::forward_as_tuple("value"); + auto attribute_values() const { return std::forward_as_tuple(value_); } + + friend bool operator==(const SimpleMeshShape& lhs, const SimpleMeshShape& rhs); + friend bool operator!=(const SimpleMeshShape& lhs, const SimpleMeshShape& rhs); + friend std::ostream& operator<<(std::ostream& os, const SimpleMeshShape& shape); + +private: + void compute_strides(); + tt::stl::SmallVector strides_; +}; + +class MeshCoordinate { +public: + // Shorthands for constructing 2D and 3D coordinates. + MeshCoordinate(uint32_t row, uint32_t col); + MeshCoordinate(uint32_t x, uint32_t y, uint32_t z); + + // Constructs a generic N-dimensional coordinate. + explicit MeshCoordinate(tt::stl::Span coords); + + // Returns the dimensionality of the coordinate. + size_t dims() const; + + // Returns the coordinate values as a span. + tt::stl::Span coords() const; + + // Returns the coordinate value at the given index. + uint32_t operator[](size_t dim) const; + + // Needed for reflect / fmt + static constexpr auto attribute_names = std::forward_as_tuple("value"); + auto attribute_values() const { return std::forward_as_tuple(value_); } + + friend bool operator==(const MeshCoordinate& lhs, const MeshCoordinate& rhs); + friend bool operator!=(const MeshCoordinate& lhs, const MeshCoordinate& rhs); + friend std::ostream& operator<<(std::ostream& os, const MeshCoordinate& shape); + +private: + tt::stl::SmallVector value_; +}; + +// Converts a MeshCoordinate to a linear index. +// Throws if `coord` is out of bounds of `shape`. +size_t to_linear_index(const SimpleMeshShape& shape, const MeshCoordinate& coord); + +// Represents a range of MeshCoordinates. Requires that mesh coordinates have the same dimensionality. +class MeshCoordinateRange { +public: + // Constructs an inclusive range that iterates between `start` and `end`. + MeshCoordinateRange(const MeshCoordinate& start, const MeshCoordinate& end); + + // Constructs a range that iterates over all coordinates in the mesh. + MeshCoordinateRange(const SimpleMeshShape& shape); + + const MeshCoordinate& start_coord() const; + const MeshCoordinate& end_coord() const; + + class Iterator { + public: + Iterator& operator++(); + MeshCoordinate operator*() const; + bool operator==(const Iterator& other) const; + bool operator!=(const Iterator& other) const; + + private: + Iterator(const MeshCoordinateRange* range, const MeshCoordinate& current_coord, size_t linear_index); + friend class MeshCoordinateRange; + + const MeshCoordinateRange* range_ = nullptr; + + // For simplicity, rely on `linear_index_` for the iterator boundary check, and allow + // MeshCoordinate to wrap around the range end. + MeshCoordinate current_coord_; + size_t linear_index_ = 0; + }; + + Iterator begin() const; + Iterator end() const; + + friend bool operator==(const MeshCoordinateRange& lhs, const MeshCoordinateRange& rhs); + friend bool operator!=(const MeshCoordinateRange& lhs, const MeshCoordinateRange& rhs); + +private: + MeshCoordinate start_; + MeshCoordinate end_; +}; + +// Allows storing data in a mesh-shaped container, with convenient accessors and iterators. +template +class MeshContainer { +public: + MeshContainer(const SimpleMeshShape& shape, const T& fill_value); + + // Returns a shape of the container. + const SimpleMeshShape& shape() const; + + // Accessor methods. + T& at(const MeshCoordinate& coord); + const T& at(const MeshCoordinate& coord) const; + + // Allows to iterate over the container elements, returning a pair of (coordinate, value reference). + class Iterator { + public: + using value_type = std::pair; + using reference = std::pair>; + + Iterator& operator++(); + reference operator*() const; + bool operator==(const Iterator& other) const; + bool operator!=(const Iterator& other) const; + + private: + Iterator(MeshContainer* container, const MeshCoordinateRange::Iterator& coord_iter, size_t linear_index); + friend class MeshContainer; + + MeshContainer* container_ = nullptr; + MeshCoordinateRange::Iterator coord_iter_; + size_t linear_index_ = 0; + }; + + Iterator begin(); + Iterator end(); + +private: + SimpleMeshShape shape_; + MeshCoordinateRange coord_range_; + std::vector values_; +}; + +template +MeshContainer::MeshContainer(const SimpleMeshShape& shape, const T& fill_value) : + shape_(shape), coord_range_(shape), values_(shape.mesh_size(), fill_value) {} + +template +const SimpleMeshShape& MeshContainer::shape() const { + return shape_; +} + +template +T& MeshContainer::at(const MeshCoordinate& coord) { + return values_.at(to_linear_index(shape_, coord)); +} + +template +const T& MeshContainer::at(const MeshCoordinate& coord) const { + return values_.at(to_linear_index(shape_, coord)); +} + +template +MeshContainer::Iterator::Iterator( + MeshContainer* container, const MeshCoordinateRange::Iterator& coord_iter, size_t linear_index) : + container_(container), coord_iter_(coord_iter), linear_index_(linear_index) {} + +template +typename MeshContainer::Iterator& MeshContainer::Iterator::operator++() { + ++linear_index_; + ++coord_iter_; + return *this; +} + +template +typename MeshContainer::Iterator::reference MeshContainer::Iterator::operator*() const { + return {*coord_iter_, std::ref(container_->values_[linear_index_])}; +} + +template +bool MeshContainer::Iterator::operator==(const Iterator& other) const { + return container_ == other.container_ && coord_iter_ == other.coord_iter_ && linear_index_ == other.linear_index_; +} + +template +bool MeshContainer::Iterator::operator!=(const Iterator& other) const { + return !(*this == other); +} + +template +typename MeshContainer::Iterator MeshContainer::begin() { + return Iterator(this, coord_range_.begin(), /* linear_index = */ 0); +} + +template +typename MeshContainer::Iterator MeshContainer::end() { + return Iterator(this, coord_range_.end(), shape_.mesh_size()); +} + +} // namespace tt::tt_metal::distributed diff --git a/tt_metal/api/tt-metalium/mesh_device.hpp b/tt_metal/api/tt-metalium/mesh_device.hpp index 91638a57cb6..979e603a6cd 100644 --- a/tt_metal/api/tt-metalium/mesh_device.hpp +++ b/tt_metal/api/tt-metalium/mesh_device.hpp @@ -12,6 +12,7 @@ #include "device.hpp" #include "mesh_config.hpp" +#include "mesh_coord.hpp" #include "mesh_device_view.hpp" #include "sub_device_types.hpp" #include "span.hpp" @@ -204,6 +205,7 @@ class MeshDevice : public IDevice, public std::enable_shared_from_this + +#include +#include +#include +#include +#include +#include + +namespace tt::tt_metal::distributed { +namespace { + +// Returns a zero coordinate of dimensionality `dims`. +MeshCoordinate zero_coordinate(size_t dims) { return MeshCoordinate(tt::stl::SmallVector(dims, 0)); } + +// Returns the last valid coordinate for the provided `shape`. +MeshCoordinate shape_back(const SimpleMeshShape& shape) { + tt::stl::SmallVector coords; + for (int i = 0; i < shape.dims(); i++) { + coords.push_back(shape[i] - 1); + } + return MeshCoordinate(coords); +} + +} // namespace + +SimpleMeshShape::SimpleMeshShape(uint32_t num_rows, uint32_t num_cols) : ShapeBase({num_rows, num_cols}) { + compute_strides(); +} +SimpleMeshShape::SimpleMeshShape(uint32_t x, uint32_t y, uint32_t z) : ShapeBase({x, y, z}) { compute_strides(); } + +SimpleMeshShape::SimpleMeshShape(const MeshShape& legacy_shape) : + SimpleMeshShape(legacy_shape.num_rows, legacy_shape.num_cols) {} + +void SimpleMeshShape::compute_strides() { + size_t stride = 1; + strides_.resize(dims()); + for (int dim = dims() - 1; dim >= 0; --dim) { + strides_[dim] = stride; + stride *= (*this)[dim]; + } +} + +size_t SimpleMeshShape::get_stride(size_t dim) const { return strides_[dim]; } + +size_t SimpleMeshShape::dims() const { return size(); } +size_t SimpleMeshShape::mesh_size() const { + return std::accumulate(value_.begin(), value_.end(), 1, std::multiplies()); +} + +bool operator==(const SimpleMeshShape& lhs, const SimpleMeshShape& rhs) { return lhs.value_ == rhs.value_; } +bool operator!=(const SimpleMeshShape& lhs, const SimpleMeshShape& rhs) { return !(lhs == rhs); } + +std::ostream& operator<<(std::ostream& os, const SimpleMeshShape& shape) { + os << "SimpleMeshShape(["; + for (size_t i = 0; i < shape.dims(); ++i) { + if (i > 0) { + os << ", "; + } + os << shape[i]; + } + os << "])"; + return os; +} + +MeshCoordinate::MeshCoordinate(uint32_t row, uint32_t col) : value_({row, col}) {} +MeshCoordinate::MeshCoordinate(uint32_t x, uint32_t y, uint32_t z) : value_({x, y, z}) {} + +MeshCoordinate::MeshCoordinate(tt::stl::Span coords) : value_(coords.begin(), coords.end()) {} + +size_t MeshCoordinate::dims() const { return value_.size(); } +tt::stl::Span MeshCoordinate::coords() const { return value_; } +uint32_t MeshCoordinate::operator[](size_t dim) const { return value_[dim]; } + +bool operator==(const MeshCoordinate& lhs, const MeshCoordinate& rhs) { + return lhs.dims() == rhs.dims() && std::equal(lhs.coords().begin(), lhs.coords().end(), rhs.coords().begin()); +} +bool operator!=(const MeshCoordinate& lhs, const MeshCoordinate& rhs) { return !(lhs == rhs); } + +std::ostream& operator<<(std::ostream& os, const MeshCoordinate& coord) { + os << "MeshCoordinate(" << coord.dims() << ", ["; + for (size_t dim : coord.coords()) { + os << dim << ", "; + } + os << "])"; + return os; +} + +MeshCoordinateRange::MeshCoordinateRange(const MeshCoordinate& start, const MeshCoordinate& end) : + start_(start), end_(end) { + TT_FATAL( + start.dims() == end.dims(), + "Start and end dimensions of a coordinate range do not match: {} != {}", + start.dims(), + end.dims()); + for (size_t i = 0; i < start.dims(); ++i) { + TT_FATAL(start[i] <= end[i], "Start coordinate is greater than end coordinate: {} > {}", start, end); + } +} + +MeshCoordinateRange::MeshCoordinateRange(const SimpleMeshShape& shape) : + MeshCoordinateRange(zero_coordinate(shape.dims()), shape_back(shape)) {} + +const MeshCoordinate& MeshCoordinateRange::start_coord() const { return start_; } +const MeshCoordinate& MeshCoordinateRange::end_coord() const { return end_; } + +MeshCoordinateRange::Iterator::Iterator( + const MeshCoordinateRange* range, const MeshCoordinate& current, size_t linear_index) : + range_(range), current_coord_(current), linear_index_(linear_index) {} + +MeshCoordinateRange::Iterator& MeshCoordinateRange::Iterator::operator++() { + ++linear_index_; + + tt::stl::SmallVector new_coords(current_coord_.coords().begin(), current_coord_.coords().end()); + for (int i = new_coords.size() - 1; i >= 0; --i) { + auto& dimension_value = new_coords[i]; + if (++dimension_value > range_->end_coord()[i]) { + dimension_value = 0; + } else { + break; + } + } + current_coord_ = MeshCoordinate(new_coords); + return *this; +} +MeshCoordinate MeshCoordinateRange::Iterator::operator*() const { return current_coord_; } +bool MeshCoordinateRange::Iterator::operator==(const Iterator& other) const { + return range_ == other.range_ && linear_index_ == other.linear_index_; +} +bool MeshCoordinateRange::Iterator::operator!=(const Iterator& other) const { return !(*this == other); } + +MeshCoordinateRange::Iterator MeshCoordinateRange::begin() const { return Iterator(this, start_, /*linear_index=*/0); } +MeshCoordinateRange::Iterator MeshCoordinateRange::end() const { + size_t range_size = 1; + for (size_t i = 0; i < start_.dims(); ++i) { + range_size *= end_[i] - start_[i] + 1; + } + // Set `start_` coordinate but `range_size` linear index as the wrap around condition. + return Iterator(this, start_, range_size); +} + +size_t to_linear_index(const SimpleMeshShape& shape, const MeshCoordinate& coord) { + TT_FATAL( + shape.dims() == coord.dims(), + "Shape and coordinate dimensions do not match: {} != {}", + shape.dims(), + coord.dims()); + + size_t linear_index = 0; + for (size_t dim = 0; dim < coord.dims(); ++dim) { + TT_FATAL(coord[dim] < shape[dim], "Coordinate {} is out of bounds for shape {}", coord, shape); + linear_index += coord[dim] * shape.get_stride(dim); + } + return linear_index; +} + +} // namespace tt::tt_metal::distributed diff --git a/tt_metal/distributed/mesh_buffer.cpp b/tt_metal/distributed/mesh_buffer.cpp index a0bf7b76e86..660a0fe8529 100644 --- a/tt_metal/distributed/mesh_buffer.cpp +++ b/tt_metal/distributed/mesh_buffer.cpp @@ -4,6 +4,8 @@ // SPDX-License-Identifier: Apache-2.0 #include +#include +#include #include #include @@ -110,12 +112,9 @@ std::shared_ptr MeshBuffer::create( } void MeshBuffer::initialize_device_buffers() { - buffers_ = std::vector>>( - mesh_device_->num_rows(), std::vector>(mesh_device_->num_cols())); - - auto init_device_buffer_at_address = [this](const Coordinate& coord) { + auto init_device_buffer_at_address = [this](const MeshCoordinate& coord) { std::shared_ptr buffer = Buffer::create( - mesh_device_->get_device(coord.row, coord.col), + mesh_device_->get_device(coord), address_, device_local_size_, device_local_config_.page_size, @@ -126,10 +125,8 @@ void MeshBuffer::initialize_device_buffers() { return buffer; }; - for (int row = 0; row < mesh_device_->num_rows(); row++) { - for (int col = 0; col < mesh_device_->num_cols(); col++) { - buffers_[row][col] = init_device_buffer_at_address(Coordinate{row, col}); - } + for (auto [coord, device_buffer] : buffers_) { + device_buffer.get() = init_device_buffer_at_address(coord); } } @@ -138,14 +135,11 @@ bool MeshBuffer::is_allocated() const { return not std::holds_alternative MeshBuffer::get_device_buffer(const Coordinate& device_coord) const { - TT_FATAL( - device_coord.row < mesh_device_->num_rows() and device_coord.col < mesh_device_->num_cols(), - "Logical coordinates must be within the bounds of the mesh: {}, {}, mesh shape: {}, {}", - device_coord.row, - device_coord.col, - mesh_device_->num_rows(), - mesh_device_->num_cols()); - return buffers_[device_coord.row][device_coord.col]; + return get_device_buffer(MeshCoordinate(device_coord.row, device_coord.col)); +} + +std::shared_ptr MeshBuffer::get_device_buffer(const MeshCoordinate& device_coord) const { + return buffers_.at(device_coord); } DeviceAddr MeshBuffer::size() const { diff --git a/tt_metal/distributed/mesh_device.cpp b/tt_metal/distributed/mesh_device.cpp index 04edd94373b..603ce95212e 100644 --- a/tt_metal/distributed/mesh_device.cpp +++ b/tt_metal/distributed/mesh_device.cpp @@ -211,7 +211,11 @@ std::vector MeshDevice::get_devices() const { return view_->get_device // TODO: Remove this function once we have a proper view interface IDevice* MeshDevice::get_device(size_t row_idx, size_t col_idx) const { - return this->get_device_index(row_idx * num_cols() + col_idx); + return get_device(MeshCoordinate{row_idx, col_idx}); +} + +IDevice* MeshDevice::get_device(const MeshCoordinate& coord) const { + return this->get_device_index(to_linear_index(SimpleMeshShape(mesh_shape_), coord)); } MeshCommandQueue& MeshDevice::mesh_command_queue(std::size_t cq_id) const { From 1eb08d4ccc58e0b47e674215512930b2e2400c27 Mon Sep 17 00:00:00 2001 From: Oleg Milyutin Date: Fri, 7 Feb 2025 22:59:08 +0000 Subject: [PATCH 03/10] Revert ttnn changes --- ttnn/cpp/pybind11/pytensor.cpp | 6 ++++-- .../repeat/device/host/repeat_program_factory.cpp | 8 ++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/ttnn/cpp/pybind11/pytensor.cpp b/ttnn/cpp/pybind11/pytensor.cpp index c26294362e1..f6e55603d8a 100644 --- a/ttnn/cpp/pybind11/pytensor.cpp +++ b/ttnn/cpp/pybind11/pytensor.cpp @@ -520,7 +520,8 @@ py::object convert_tt_tensor_to_torch_tensor(const Tensor& tt_tensor, const bool buffer); auto logical_shape = tt_tensor.get_logical_shape(); - std::vector torch_shape(logical_shape.cbegin(), logical_shape.cend()); + auto view = logical_shape.view(); + std::vector torch_shape(view.begin(), view.end()); auto tensor = [&]() { if (tt_tensor.volume() == 0) { auto pytorch_empty = torch.attr("empty"); @@ -579,7 +580,8 @@ py::object convert_tt_tensor_to_numpy_tensor(const Tensor& tt_tensor) { buffer); auto logical_shape = tt_tensor.get_logical_shape(); - std::vector np_shape(logical_shape.cbegin(), logical_shape.cend()); + auto view = logical_shape.view(); + std::vector np_shape(view.begin(), view.end()); auto tensor = frombuffer(buffer, py::arg("dtype") = np_dtype); tensor = tensor.attr("reshape")(np_shape); tensor = np.attr("ascontiguousarray")(tensor); diff --git a/ttnn/cpp/ttnn/operations/data_movement/repeat/device/host/repeat_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/repeat/device/host/repeat_program_factory.cpp index cc28c0610d0..e8266b2ee50 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/repeat/device/host/repeat_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/repeat/device/host/repeat_program_factory.cpp @@ -39,8 +39,8 @@ tt::tt_metal::operation::ProgramWithCallbacks rm_repeater_last_dim( uint32_t num_cores_y = compute_with_storage_grid_size.y; uint32_t num_cores_total = num_cores_x * num_cores_y; CoreRange total_cores({0, 0}, {num_cores_x - 1, num_cores_y - 1}); - ttnn::Shape input_log_shape = input.get_logical_shape(); - ttnn::Shape output_log_shape = output.get_logical_shape(); + ttnn::Shape input_log_shape = ttnn::Shape(input.get_logical_shape().view()); + ttnn::Shape output_log_shape = ttnn::Shape(output.get_logical_shape().view()); tt::log_debug("row major reshape"); tt::log_debug("input shape: {}", input_log_shape); tt::log_debug("output shape: {}", output_log_shape); @@ -139,8 +139,8 @@ tt::tt_metal::operation::ProgramWithCallbacks rm_repeater( uint32_t num_cores_total = num_cores_x * num_cores_y; CoreRange total_cores({0, 0}, {num_cores_x - 1, num_cores_y - 1}); - ttnn::Shape input_log_shape = input.get_logical_shape(); - ttnn::Shape output_log_shape = output.get_logical_shape(); + ttnn::Shape input_log_shape = ttnn::Shape(input.get_logical_shape().view()); + ttnn::Shape output_log_shape = ttnn::Shape(output.get_logical_shape().view()); tt::log_debug("row major reshape"); tt::log_debug("input shape: {}", input_log_shape); tt::log_debug("output shape: {}", output_log_shape); From 800e1f3cd47ccd1d2edcf742f56ee08f191cac80 Mon Sep 17 00:00:00 2001 From: Oleg Milyutin Date: Fri, 7 Feb 2025 23:37:54 +0000 Subject: [PATCH 04/10] Remove unnecessary alias --- tt_metal/api/tt-metalium/mesh_coord.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/tt_metal/api/tt-metalium/mesh_coord.hpp b/tt_metal/api/tt-metalium/mesh_coord.hpp index 16556d41e3d..6a0f6496c3b 100644 --- a/tt_metal/api/tt-metalium/mesh_coord.hpp +++ b/tt_metal/api/tt-metalium/mesh_coord.hpp @@ -145,7 +145,6 @@ class MeshContainer { // Allows to iterate over the container elements, returning a pair of (coordinate, value reference). class Iterator { public: - using value_type = std::pair; using reference = std::pair>; Iterator& operator++(); From 86dc0cc0a72a2b33799eee2ce2616a391de490b1 Mon Sep 17 00:00:00 2001 From: Oleg Milyutin Date: Sat, 8 Feb 2025 19:19:04 +0000 Subject: [PATCH 05/10] Better tests, 1D case --- .../tt_metal/distributed/test_mesh_coord.cpp | 68 +++++++++++++------ tt_metal/api/tt-metalium/mesh_coord.hpp | 6 +- tt_metal/common/mesh_coord.cpp | 10 +-- 3 files changed, 59 insertions(+), 25 deletions(-) diff --git a/tests/tt_metal/distributed/test_mesh_coord.cpp b/tests/tt_metal/distributed/test_mesh_coord.cpp index d6284e64467..80f3f299198 100644 --- a/tests/tt_metal/distributed/test_mesh_coord.cpp +++ b/tests/tt_metal/distributed/test_mesh_coord.cpp @@ -13,6 +13,11 @@ namespace { using ::testing::ElementsAre; TEST(SimpleMeshShapeTest, Construction) { + SimpleMeshShape shape_1d(3); + EXPECT_EQ(shape_1d.dims(), 1); + EXPECT_EQ(shape_1d[0], 3); + EXPECT_EQ(shape_1d.mesh_size(), 3); + SimpleMeshShape shape_2d(3, 4); EXPECT_EQ(shape_2d.dims(), 2); EXPECT_EQ(shape_2d[0], 3); @@ -36,6 +41,12 @@ TEST(SimpleMeshShapeTest, Construction) { EXPECT_EQ(shape_5d.mesh_size(), 720); } +TEST(SimpleMeshShapeTest, ZeroShape) { + SimpleMeshShape shape({}); + EXPECT_EQ(shape.dims(), 0); + EXPECT_EQ(shape.mesh_size(), 0); +} + TEST(SimpleMeshShapeTest, Strides) { SimpleMeshShape shape(2, 3, 4); EXPECT_EQ(shape.get_stride(0), 12); // 3 * 4 @@ -44,15 +55,19 @@ TEST(SimpleMeshShapeTest, Strides) { } TEST(SimpleMeshShapeTest, Comparison) { - SimpleMeshShape shape1(2, 3); - SimpleMeshShape shape2(2, 3); - SimpleMeshShape shape3(3, 2); + SimpleMeshShape shape(2, 3); - EXPECT_EQ(shape1, shape2); - EXPECT_NE(shape1, shape3); + EXPECT_EQ(shape, SimpleMeshShape(2, 3)); + EXPECT_NE(shape, SimpleMeshShape(3, 2)); + EXPECT_NE(shape, SimpleMeshShape(1, 2, 3)); } TEST(MeshCoordinateTest, Construction) { + MeshCoordinate coord_1d(1); + EXPECT_EQ(coord_1d.dims(), 1); + EXPECT_THAT(coord_1d.coords(), ElementsAre(1)); + EXPECT_EQ(coord_1d[0], 1); + MeshCoordinate coord_2d(1, 2); EXPECT_EQ(coord_2d.dims(), 2); EXPECT_THAT(coord_2d.coords(), ElementsAre(1, 2)); @@ -79,11 +94,10 @@ TEST(MeshCoordinateTest, Construction) { TEST(MeshCoordinateTest, Comparison) { MeshCoordinate coord1(1, 2); - MeshCoordinate coord2(1, 2); - MeshCoordinate coord3(2, 1); - EXPECT_EQ(coord1, coord2); - EXPECT_NE(coord1, coord3); + EXPECT_EQ(coord1, MeshCoordinate(1, 2)); + EXPECT_NE(coord1, MeshCoordinate(2, 1)); + EXPECT_NE(coord1, MeshCoordinate(1, 2, 1)); } TEST(MeshCoordinateRangeTest, FromShape) { @@ -107,8 +121,8 @@ TEST(MeshCoordinateRangeTest, FromShape) { } TEST(MeshCoordinateRangeTest, Subrange) { - MeshCoordinate start(1, 0); - MeshCoordinate end(2, 2); + MeshCoordinate start(1, 1, 1); + MeshCoordinate end(2, 1, 4); MeshCoordinateRange range(start, end); std::vector coords; @@ -119,12 +133,27 @@ TEST(MeshCoordinateRangeTest, Subrange) { EXPECT_THAT( coords, ElementsAre( - MeshCoordinate(1, 0), - MeshCoordinate(1, 1), - MeshCoordinate(1, 2), - MeshCoordinate(2, 0), - MeshCoordinate(2, 1), - MeshCoordinate(2, 2))); + MeshCoordinate(1, 1, 1), + MeshCoordinate(1, 1, 2), + MeshCoordinate(1, 1, 3), + MeshCoordinate(1, 1, 4), + MeshCoordinate(2, 1, 1), + MeshCoordinate(2, 1, 2), + MeshCoordinate(2, 1, 3), + MeshCoordinate(2, 1, 4))); +} + +TEST(MeshCoordinateRangeTest, SubrangeOneElement) { + MeshCoordinate start(1, 1, 1); + MeshCoordinate end(1, 1, 1); + MeshCoordinateRange range(start, end); + + std::vector coords; + for (const auto& coord : range) { + coords.push_back(coord); + } + + EXPECT_THAT(coords, ElementsAre(MeshCoordinate(1, 1, 1))); } TEST(MeshCoordinateRangeTest, MismatchedDimensions) { @@ -151,11 +180,12 @@ TEST(ToLinearIndexTest, Basic) { } TEST(ToLinearIndexTest, MismatchedDimensions) { - EXPECT_ANY_THROW(to_linear_index(SimpleMeshShape(2, 3), MeshCoordinate(2, 0))); + EXPECT_ANY_THROW(to_linear_index(SimpleMeshShape(1, 2, 3), MeshCoordinate(0, 0))); } TEST(ToLinearIndexTest, OutOfBounds) { - EXPECT_ANY_THROW(to_linear_index(SimpleMeshShape(1, 2, 3), MeshCoordinate(0, 0))); + EXPECT_ANY_THROW(to_linear_index(SimpleMeshShape(2, 3), MeshCoordinate(2, 0))); + EXPECT_ANY_THROW(to_linear_index(SimpleMeshShape(2, 3), MeshCoordinate(0, 3))); } TEST(MeshContainerTest, InitialValues) { diff --git a/tt_metal/api/tt-metalium/mesh_coord.hpp b/tt_metal/api/tt-metalium/mesh_coord.hpp index 6a0f6496c3b..b3d74dca4b1 100644 --- a/tt_metal/api/tt-metalium/mesh_coord.hpp +++ b/tt_metal/api/tt-metalium/mesh_coord.hpp @@ -24,7 +24,8 @@ class SimpleMeshShape : public ShapeBase { using ShapeBase::size; using ShapeBase::view; - // Shorthands for constructing 2D and 3D shapes. + // Shorthands for constructing 1D, 2D and 3D shapes. + SimpleMeshShape(uint32_t num_elements); SimpleMeshShape(uint32_t num_rows, uint32_t num_cols); SimpleMeshShape(uint32_t x, uint32_t y, uint32_t z); @@ -55,7 +56,8 @@ class SimpleMeshShape : public ShapeBase { class MeshCoordinate { public: - // Shorthands for constructing 2D and 3D coordinates. + // Shorthands for constructing 1D, 2D and 3D coordinates. + MeshCoordinate(uint32_t coord); MeshCoordinate(uint32_t row, uint32_t col); MeshCoordinate(uint32_t x, uint32_t y, uint32_t z); diff --git a/tt_metal/common/mesh_coord.cpp b/tt_metal/common/mesh_coord.cpp index ee20b5a8690..c79af09a63d 100644 --- a/tt_metal/common/mesh_coord.cpp +++ b/tt_metal/common/mesh_coord.cpp @@ -28,6 +28,7 @@ MeshCoordinate shape_back(const SimpleMeshShape& shape) { } // namespace +SimpleMeshShape::SimpleMeshShape(uint32_t num_elements) : ShapeBase({num_elements}) { compute_strides(); } SimpleMeshShape::SimpleMeshShape(uint32_t num_rows, uint32_t num_cols) : ShapeBase({num_rows, num_cols}) { compute_strides(); } @@ -49,11 +50,11 @@ size_t SimpleMeshShape::get_stride(size_t dim) const { return strides_[dim]; } size_t SimpleMeshShape::dims() const { return size(); } size_t SimpleMeshShape::mesh_size() const { - return std::accumulate(value_.begin(), value_.end(), 1, std::multiplies()); + return empty() ? 0 : std::accumulate(value_.begin(), value_.end(), 1, std::multiplies()); } -bool operator==(const SimpleMeshShape& lhs, const SimpleMeshShape& rhs) { return lhs.value_ == rhs.value_; } -bool operator!=(const SimpleMeshShape& lhs, const SimpleMeshShape& rhs) { return !(lhs == rhs); } +bool operator==(const SimpleMeshShape& lhs, const SimpleMeshShape& rhs) = default; +bool operator!=(const SimpleMeshShape& lhs, const SimpleMeshShape& rhs) = default; std::ostream& operator<<(std::ostream& os, const SimpleMeshShape& shape) { os << "SimpleMeshShape(["; @@ -67,6 +68,7 @@ std::ostream& operator<<(std::ostream& os, const SimpleMeshShape& shape) { return os; } +MeshCoordinate::MeshCoordinate(uint32_t coord) : value_({coord}) {} MeshCoordinate::MeshCoordinate(uint32_t row, uint32_t col) : value_({row, col}) {} MeshCoordinate::MeshCoordinate(uint32_t x, uint32_t y, uint32_t z) : value_({x, y, z}) {} @@ -119,7 +121,7 @@ MeshCoordinateRange::Iterator& MeshCoordinateRange::Iterator::operator++() { for (int i = new_coords.size() - 1; i >= 0; --i) { auto& dimension_value = new_coords[i]; if (++dimension_value > range_->end_coord()[i]) { - dimension_value = 0; + dimension_value = range_->start_coord()[i]; } else { break; } From 24f249ed62398e01293c5696ed29070861b0d10c Mon Sep 17 00:00:00 2001 From: Oleg Milyutin Date: Sun, 9 Feb 2025 06:26:06 +0000 Subject: [PATCH 06/10] Value proxy class for better container iteration --- .../tt_metal/distributed/test_mesh_coord.cpp | 33 +++++++ tt_metal/api/tt-metalium/mesh_coord.hpp | 93 ++++++++++++++++--- tt_metal/common/mesh_coord.cpp | 2 +- tt_metal/distributed/mesh_buffer.cpp | 4 +- 4 files changed, 118 insertions(+), 14 deletions(-) diff --git a/tests/tt_metal/distributed/test_mesh_coord.cpp b/tests/tt_metal/distributed/test_mesh_coord.cpp index 80f3f299198..a0baff76b74 100644 --- a/tests/tt_metal/distributed/test_mesh_coord.cpp +++ b/tests/tt_metal/distributed/test_mesh_coord.cpp @@ -210,6 +210,39 @@ TEST(MeshContainerTest, ElementAccessRowMajor) { container.at(MeshCoordinate(1, 1)) = 4; container.at(MeshCoordinate(1, 2)) = 5; + std::vector coords; + std::vector values; + for (const auto& [coord, value] : container) { + coords.push_back(coord); + values.push_back(value); + } + EXPECT_THAT( + coords, + ElementsAre( + MeshCoordinate(0, 0), + MeshCoordinate(0, 1), + MeshCoordinate(0, 2), + MeshCoordinate(1, 0), + MeshCoordinate(1, 1), + MeshCoordinate(1, 2))); + EXPECT_THAT(values, ElementsAre(0, 1, 2, 3, 4, 5)); +} + +TEST(MeshContainerTest, MutateThroughProxy) { + SimpleMeshShape shape(2, 3); + MeshContainer container(shape, 0); + + // Proxy class provides access to the container value through the mutable reference. + int updated_value = 0; + for (auto& [coord, value] : container) { + value = updated_value++; + } + + // `auto` makes a copy of the value, verify this loop is a no-op. + for (auto [coord, value] : container) { + value = updated_value++; + } + std::vector values; for (const auto& [coord, value] : container) { values.push_back(value); diff --git a/tt_metal/api/tt-metalium/mesh_coord.hpp b/tt_metal/api/tt-metalium/mesh_coord.hpp index b3d74dca4b1..20718751dab 100644 --- a/tt_metal/api/tt-metalium/mesh_coord.hpp +++ b/tt_metal/api/tt-metalium/mesh_coord.hpp @@ -5,6 +5,7 @@ #pragma once #include +#include #include #include "shape_base.hpp" @@ -18,11 +19,6 @@ class SimpleMeshShape : public ShapeBase { public: using ShapeBase::ShapeBase; using ShapeBase::operator[]; - using ShapeBase::cbegin; - using ShapeBase::cend; - using ShapeBase::empty; - using ShapeBase::size; - using ShapeBase::view; // Shorthands for constructing 1D, 2D and 3D shapes. SimpleMeshShape(uint32_t num_elements); @@ -50,6 +46,9 @@ class SimpleMeshShape : public ShapeBase { friend std::ostream& operator<<(std::ostream& os, const SimpleMeshShape& shape); private: + using ShapeBase::empty; + using ShapeBase::size; + void compute_strides(); tt::stl::SmallVector strides_; }; @@ -104,7 +103,7 @@ class MeshCoordinateRange { class Iterator { public: Iterator& operator++(); - MeshCoordinate operator*() const; + const MeshCoordinate& operator*() const; bool operator==(const Iterator& other) const; bool operator!=(const Iterator& other) const; @@ -131,6 +130,53 @@ class MeshCoordinateRange { MeshCoordinate end_; }; +namespace detail { + +// Proxy class that allows convenient structured binding to a pair of a coordinate and the value it points to. +// This supports iterator semantics similar to `std::map` / `std::unordered_map`. +template +class MeshCoordinateValueProxy { +public: + MeshCoordinateValueProxy(const MeshCoordinate* coord, T* value_ptr) : coord_(coord), value_ptr_(value_ptr) {} + + const MeshCoordinate& coord() const { return *coord_; } + T& value() { return *value_ptr_; } + const T& value() const { return *value_ptr_; } + + template + decltype(auto) get() & { + if constexpr (I == 0) { + return coord(); + } else if constexpr (I == 1) { + return value(); + } else { + static_assert(I < 2); + } + } + + template + decltype(auto) get() const& { + if constexpr (I == 0) { + return coord(); + } else if constexpr (I == 1) { + return value(); + } else { + static_assert(I < 2); + } + } + + template + auto get() const&& { + return get(); + } + +private: + const MeshCoordinate* coord_ = nullptr; + T* value_ptr_ = nullptr; +}; + +} // namespace detail + // Allows storing data in a mesh-shaped container, with convenient accessors and iterators. template class MeshContainer { @@ -147,10 +193,10 @@ class MeshContainer { // Allows to iterate over the container elements, returning a pair of (coordinate, value reference). class Iterator { public: - using reference = std::pair>; + using ValueProxy = detail::MeshCoordinateValueProxy; Iterator& operator++(); - reference operator*() const; + ValueProxy& operator*(); bool operator==(const Iterator& other) const; bool operator!=(const Iterator& other) const; @@ -161,6 +207,9 @@ class MeshContainer { MeshContainer* container_ = nullptr; MeshCoordinateRange::Iterator coord_iter_; size_t linear_index_ = 0; + + // Provides mutable access to the container value along with the coordinate from the range iterator. + ValueProxy value_proxy_; }; Iterator begin(); @@ -194,18 +243,22 @@ const T& MeshContainer::at(const MeshCoordinate& coord) const { template MeshContainer::Iterator::Iterator( MeshContainer* container, const MeshCoordinateRange::Iterator& coord_iter, size_t linear_index) : - container_(container), coord_iter_(coord_iter), linear_index_(linear_index) {} + container_(container), + coord_iter_(coord_iter), + linear_index_(linear_index), + value_proxy_(&(*coord_iter_), &container_->values_[linear_index_]) {} template typename MeshContainer::Iterator& MeshContainer::Iterator::operator++() { ++linear_index_; ++coord_iter_; + value_proxy_ = ValueProxy(&(*coord_iter_), &container_->values_[linear_index_]); return *this; } template -typename MeshContainer::Iterator::reference MeshContainer::Iterator::operator*() const { - return {*coord_iter_, std::ref(container_->values_[linear_index_])}; +typename MeshContainer::Iterator::ValueProxy& MeshContainer::Iterator::operator*() { + return value_proxy_; } template @@ -229,3 +282,21 @@ typename MeshContainer::Iterator MeshContainer::end() { } } // namespace tt::tt_metal::distributed + +namespace std { + +template +struct tuple_size> : std::integral_constant { +}; + +template +struct tuple_element<0, tt::tt_metal::distributed::detail::MeshCoordinateValueProxy> { + using type = const tt::tt_metal::distributed::MeshCoordinate; +}; + +template +struct tuple_element<1, tt::tt_metal::distributed::detail::MeshCoordinateValueProxy> { + using type = T; +}; + +} // namespace std diff --git a/tt_metal/common/mesh_coord.cpp b/tt_metal/common/mesh_coord.cpp index c79af09a63d..f0e7890bfda 100644 --- a/tt_metal/common/mesh_coord.cpp +++ b/tt_metal/common/mesh_coord.cpp @@ -129,7 +129,7 @@ MeshCoordinateRange::Iterator& MeshCoordinateRange::Iterator::operator++() { current_coord_ = MeshCoordinate(new_coords); return *this; } -MeshCoordinate MeshCoordinateRange::Iterator::operator*() const { return current_coord_; } +const MeshCoordinate& MeshCoordinateRange::Iterator::operator*() const { return current_coord_; } bool MeshCoordinateRange::Iterator::operator==(const Iterator& other) const { return range_ == other.range_ && linear_index_ == other.linear_index_; } diff --git a/tt_metal/distributed/mesh_buffer.cpp b/tt_metal/distributed/mesh_buffer.cpp index 660a0fe8529..13d1fc5e6cc 100644 --- a/tt_metal/distributed/mesh_buffer.cpp +++ b/tt_metal/distributed/mesh_buffer.cpp @@ -125,8 +125,8 @@ void MeshBuffer::initialize_device_buffers() { return buffer; }; - for (auto [coord, device_buffer] : buffers_) { - device_buffer.get() = init_device_buffer_at_address(coord); + for (auto& [coord, device_buffer] : buffers_) { + device_buffer = init_device_buffer_at_address(coord); } } From 5e5b71a329bcda186e5d4bc42e12d0af4065bf5d Mon Sep 17 00:00:00 2001 From: Oleg Milyutin Date: Fri, 14 Feb 2025 17:57:43 +0000 Subject: [PATCH 07/10] Use x, y instead of rows, cols. --- tt_metal/api/tt-metalium/mesh_coord.hpp | 8 ++++---- tt_metal/common/mesh_coord.cpp | 10 ++++------ 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/tt_metal/api/tt-metalium/mesh_coord.hpp b/tt_metal/api/tt-metalium/mesh_coord.hpp index 20718751dab..2ce5d4767c4 100644 --- a/tt_metal/api/tt-metalium/mesh_coord.hpp +++ b/tt_metal/api/tt-metalium/mesh_coord.hpp @@ -21,8 +21,8 @@ class SimpleMeshShape : public ShapeBase { using ShapeBase::operator[]; // Shorthands for constructing 1D, 2D and 3D shapes. - SimpleMeshShape(uint32_t num_elements); - SimpleMeshShape(uint32_t num_rows, uint32_t num_cols); + SimpleMeshShape(uint32_t x); + SimpleMeshShape(uint32_t x, uint32_t y); SimpleMeshShape(uint32_t x, uint32_t y, uint32_t z); // Temporary constructor for transitioning to `SimpleMeshShape`. @@ -56,8 +56,8 @@ class SimpleMeshShape : public ShapeBase { class MeshCoordinate { public: // Shorthands for constructing 1D, 2D and 3D coordinates. - MeshCoordinate(uint32_t coord); - MeshCoordinate(uint32_t row, uint32_t col); + MeshCoordinate(uint32_t x); + MeshCoordinate(uint32_t x, uint32_t y); MeshCoordinate(uint32_t x, uint32_t y, uint32_t z); // Constructs a generic N-dimensional coordinate. diff --git a/tt_metal/common/mesh_coord.cpp b/tt_metal/common/mesh_coord.cpp index f0e7890bfda..dcfddad7728 100644 --- a/tt_metal/common/mesh_coord.cpp +++ b/tt_metal/common/mesh_coord.cpp @@ -28,11 +28,9 @@ MeshCoordinate shape_back(const SimpleMeshShape& shape) { } // namespace -SimpleMeshShape::SimpleMeshShape(uint32_t num_elements) : ShapeBase({num_elements}) { compute_strides(); } -SimpleMeshShape::SimpleMeshShape(uint32_t num_rows, uint32_t num_cols) : ShapeBase({num_rows, num_cols}) { - compute_strides(); -} -SimpleMeshShape::SimpleMeshShape(uint32_t x, uint32_t y, uint32_t z) : ShapeBase({x, y, z}) { compute_strides(); } +SimpleMeshShape::SimpleMeshShape(uint32_t x) : ShapeBase({x}) {} +SimpleMeshShape::SimpleMeshShape(uint32_t x, uint32_t y) : ShapeBase({x, y}) {} +SimpleMeshShape::SimpleMeshShape(uint32_t x, uint32_t y, uint32_t z) : ShapeBase({x, y, z}) {} SimpleMeshShape::SimpleMeshShape(const MeshShape& legacy_shape) : SimpleMeshShape(legacy_shape.num_rows, legacy_shape.num_cols) {} @@ -69,7 +67,7 @@ std::ostream& operator<<(std::ostream& os, const SimpleMeshShape& shape) { } MeshCoordinate::MeshCoordinate(uint32_t coord) : value_({coord}) {} -MeshCoordinate::MeshCoordinate(uint32_t row, uint32_t col) : value_({row, col}) {} +MeshCoordinate::MeshCoordinate(uint32_t x, uint32_t y) : value_({x, y}) {} MeshCoordinate::MeshCoordinate(uint32_t x, uint32_t y, uint32_t z) : value_({x, y, z}) {} MeshCoordinate::MeshCoordinate(tt::stl::Span coords) : value_(coords.begin(), coords.end()) {} From 5c8b650c8b92472305a924bb05ab08c4e67238c4 Mon Sep 17 00:00:00 2001 From: Oleg Milyutin Date: Sat, 15 Feb 2025 05:56:23 +0000 Subject: [PATCH 08/10] Add const iterator --- .../tt_metal/distributed/test_mesh_coord.cpp | 30 ++++++-- tt_metal/api/tt-metalium/mesh_coord.hpp | 70 ++++++++++++++++++- tt_metal/common/mesh_coord.cpp | 6 +- .../distributed_buffer_rw.cpp | 2 +- .../distributed_eltwise_add.cpp | 2 +- 5 files changed, 100 insertions(+), 10 deletions(-) diff --git a/tests/tt_metal/distributed/test_mesh_coord.cpp b/tests/tt_metal/distributed/test_mesh_coord.cpp index a0baff76b74..40cd659ae05 100644 --- a/tests/tt_metal/distributed/test_mesh_coord.cpp +++ b/tests/tt_metal/distributed/test_mesh_coord.cpp @@ -193,7 +193,7 @@ TEST(MeshContainerTest, InitialValues) { MeshContainer container(shape, 3); std::vector initial_values; - for (const auto& [coord, value] : container) { + for (const auto& [_, value] : container) { initial_values.push_back(value); } EXPECT_THAT(initial_values, ElementsAre(3, 3, 3, 3, 3, 3)); @@ -228,23 +228,45 @@ TEST(MeshContainerTest, ElementAccessRowMajor) { EXPECT_THAT(values, ElementsAre(0, 1, 2, 3, 4, 5)); } +TEST(MeshContainerTest, ConstContainer) { + SimpleMeshShape shape(2, 3); + const MeshContainer container(shape, 0); + + std::vector coords; + std::vector values; + for (const auto& [coord, value] : container) { + coords.push_back(coord); + values.push_back(value); + } + EXPECT_THAT( + coords, + ElementsAre( + MeshCoordinate(0, 0), + MeshCoordinate(0, 1), + MeshCoordinate(0, 2), + MeshCoordinate(1, 0), + MeshCoordinate(1, 1), + MeshCoordinate(1, 2))); + EXPECT_THAT(values, ElementsAre(0, 0, 0, 0, 0, 0)); +} + TEST(MeshContainerTest, MutateThroughProxy) { SimpleMeshShape shape(2, 3); MeshContainer container(shape, 0); // Proxy class provides access to the container value through the mutable reference. int updated_value = 0; - for (auto& [coord, value] : container) { + for (auto& [_, value] : container) { value = updated_value++; } // `auto` makes a copy of the value, verify this loop is a no-op. - for (auto [coord, value] : container) { + for (auto [_, value] : container) { value = updated_value++; } std::vector values; - for (const auto& [coord, value] : container) { + for (const auto& [_, value] : container) { values.push_back(value); } EXPECT_THAT(values, ElementsAre(0, 1, 2, 3, 4, 5)); diff --git a/tt_metal/api/tt-metalium/mesh_coord.hpp b/tt_metal/api/tt-metalium/mesh_coord.hpp index 2ce5d4767c4..e346ce2ca83 100644 --- a/tt_metal/api/tt-metalium/mesh_coord.hpp +++ b/tt_metal/api/tt-metalium/mesh_coord.hpp @@ -97,6 +97,7 @@ class MeshCoordinateRange { // Constructs a range that iterates over all coordinates in the mesh. MeshCoordinateRange(const SimpleMeshShape& shape); + // Returns start and (inclusive) end coordinates of the range. const MeshCoordinate& start_coord() const; const MeshCoordinate& end_coord() const; @@ -165,6 +166,7 @@ class MeshCoordinateValueProxy { } } + // Force a copy via `auto`. template auto get() const&& { return get(); @@ -177,7 +179,8 @@ class MeshCoordinateValueProxy { } // namespace detail -// Allows storing data in a mesh-shaped container, with convenient accessors and iterators. +// Allows storing data in a mesh-shaped flat container, with convenient accessors and iterators. +// The iteration order and the storage memory layout is row-major. template class MeshContainer { public: @@ -212,8 +215,32 @@ class MeshContainer { ValueProxy value_proxy_; }; + class ConstIterator { + public: + using ValueProxy = detail::MeshCoordinateValueProxy; + + ConstIterator& operator++(); + const ValueProxy& operator*() const; + bool operator==(const ConstIterator& other) const; + bool operator!=(const ConstIterator& other) const; + + private: + ConstIterator( + const MeshContainer* container, const MeshCoordinateRange::Iterator& coord_iter, size_t linear_index); + friend class MeshContainer; + + const MeshContainer* container_ = nullptr; + MeshCoordinateRange::Iterator coord_iter_; + size_t linear_index_ = 0; + + // Provides mutable access to the container value along with the coordinate from the range iterator. + ValueProxy value_proxy_; + }; + Iterator begin(); Iterator end(); + ConstIterator begin() const; + ConstIterator end() const; private: SimpleMeshShape shape_; @@ -261,6 +288,27 @@ typename MeshContainer::Iterator::ValueProxy& MeshContainer::Iterator::ope return value_proxy_; } +template +MeshContainer::ConstIterator::ConstIterator( + const MeshContainer* container, const MeshCoordinateRange::Iterator& coord_iter, size_t linear_index) : + container_(container), + coord_iter_(coord_iter), + linear_index_(linear_index), + value_proxy_(&(*coord_iter_), &container_->values_[linear_index_]) {} + +template +typename MeshContainer::ConstIterator& MeshContainer::ConstIterator::operator++() { + ++linear_index_; + ++coord_iter_; + value_proxy_ = ValueProxy(&(*coord_iter_), &container_->values_[linear_index_]); + return *this; +} + +template +const typename MeshContainer::ConstIterator::ValueProxy& MeshContainer::ConstIterator::operator*() const { + return value_proxy_; +} + template bool MeshContainer::Iterator::operator==(const Iterator& other) const { return container_ == other.container_ && coord_iter_ == other.coord_iter_ && linear_index_ == other.linear_index_; @@ -271,6 +319,16 @@ bool MeshContainer::Iterator::operator!=(const Iterator& other) const { return !(*this == other); } +template +bool MeshContainer::ConstIterator::operator==(const ConstIterator& other) const { + return container_ == other.container_ && coord_iter_ == other.coord_iter_ && linear_index_ == other.linear_index_; +} + +template +bool MeshContainer::ConstIterator::operator!=(const ConstIterator& other) const { + return !(*this == other); +} + template typename MeshContainer::Iterator MeshContainer::begin() { return Iterator(this, coord_range_.begin(), /* linear_index = */ 0); @@ -281,6 +339,16 @@ typename MeshContainer::Iterator MeshContainer::end() { return Iterator(this, coord_range_.end(), shape_.mesh_size()); } +template +typename MeshContainer::ConstIterator MeshContainer::begin() const { + return ConstIterator(this, coord_range_.begin(), /* linear_index = */ 0); +} + +template +typename MeshContainer::ConstIterator MeshContainer::end() const { + return ConstIterator(this, coord_range_.end(), shape_.mesh_size()); +} + } // namespace tt::tt_metal::distributed namespace std { diff --git a/tt_metal/common/mesh_coord.cpp b/tt_metal/common/mesh_coord.cpp index dcfddad7728..9a98a0ce801 100644 --- a/tt_metal/common/mesh_coord.cpp +++ b/tt_metal/common/mesh_coord.cpp @@ -28,9 +28,9 @@ MeshCoordinate shape_back(const SimpleMeshShape& shape) { } // namespace -SimpleMeshShape::SimpleMeshShape(uint32_t x) : ShapeBase({x}) {} -SimpleMeshShape::SimpleMeshShape(uint32_t x, uint32_t y) : ShapeBase({x, y}) {} -SimpleMeshShape::SimpleMeshShape(uint32_t x, uint32_t y, uint32_t z) : ShapeBase({x, y, z}) {} +SimpleMeshShape::SimpleMeshShape(uint32_t x) : ShapeBase({x}) { compute_strides(); } +SimpleMeshShape::SimpleMeshShape(uint32_t x, uint32_t y) : ShapeBase({x, y}) { compute_strides(); } +SimpleMeshShape::SimpleMeshShape(uint32_t x, uint32_t y, uint32_t z) : ShapeBase({x, y, z}) { compute_strides(); } SimpleMeshShape::SimpleMeshShape(const MeshShape& legacy_shape) : SimpleMeshShape(legacy_shape.num_rows, legacy_shape.num_cols) {} diff --git a/tt_metal/programming_examples/distributed/2_distributed_buffer_rw/distributed_buffer_rw.cpp b/tt_metal/programming_examples/distributed/2_distributed_buffer_rw/distributed_buffer_rw.cpp index d54d6a1c6e7..a1b17cec8d5 100644 --- a/tt_metal/programming_examples/distributed/2_distributed_buffer_rw/distributed_buffer_rw.cpp +++ b/tt_metal/programming_examples/distributed/2_distributed_buffer_rw/distributed_buffer_rw.cpp @@ -26,7 +26,7 @@ int main(int argc, char** argv) { // We will create a distributed buffer with 8 shards of {32, 32} and distribute it across the devices in the mesh. auto shard_shape = Shape2D{32, 32}; auto distributed_buffer_shape = Shape2D{32 * mesh_device->num_rows(), 32 * mesh_device->num_cols()}; - uint32_t tile_size_bytes = detail::TileSize(tt::DataFormat::UInt32); + uint32_t tile_size_bytes = tt::tt_metal::detail::TileSize(tt::DataFormat::UInt32); uint32_t distributed_buffer_size_bytes = 64 * 128 * tile_size_bytes; auto local_buffer_config = DeviceLocalBufferConfig{ diff --git a/tt_metal/programming_examples/distributed/3_distributed_eltwise_add/distributed_eltwise_add.cpp b/tt_metal/programming_examples/distributed/3_distributed_eltwise_add/distributed_eltwise_add.cpp index 73bf18ee0be..9dbf0bbbd61 100644 --- a/tt_metal/programming_examples/distributed/3_distributed_eltwise_add/distributed_eltwise_add.cpp +++ b/tt_metal/programming_examples/distributed/3_distributed_eltwise_add/distributed_eltwise_add.cpp @@ -92,7 +92,7 @@ int main(int argc, char** argv) { auto distributed_buffer_shape = Shape2D{shard_shape.height() * mesh_device->num_rows(), shard_shape.width() * mesh_device->num_cols()}; auto num_tiles = 1; - auto tile_size_bytes = detail::TileSize(tt::DataFormat::Float16_b); + auto tile_size_bytes = tt::tt_metal::detail::TileSize(tt::DataFormat::Float16_b); auto distributed_buffer_size_bytes = mesh_device->num_rows() * mesh_device->num_cols() * tile_size_bytes; // Configure device-local buffer settings From 16be729c79145f895e2363b99ca3844e2d35677d Mon Sep 17 00:00:00 2001 From: Oleg Milyutin Date: Sat, 15 Feb 2025 06:12:50 +0000 Subject: [PATCH 09/10] Better to_linear_index test --- .../tt_metal/distributed/test_mesh_coord.cpp | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/tests/tt_metal/distributed/test_mesh_coord.cpp b/tests/tt_metal/distributed/test_mesh_coord.cpp index 40cd659ae05..09853a488a0 100644 --- a/tests/tt_metal/distributed/test_mesh_coord.cpp +++ b/tests/tt_metal/distributed/test_mesh_coord.cpp @@ -169,14 +169,20 @@ TEST(MeshCoordinateRangeTest, InvalidRange) { } TEST(ToLinearIndexTest, Basic) { - SimpleMeshShape shape(2, 3); - - EXPECT_EQ(to_linear_index(shape, MeshCoordinate(0, 0)), 0); - EXPECT_EQ(to_linear_index(shape, MeshCoordinate(0, 1)), 1); - EXPECT_EQ(to_linear_index(shape, MeshCoordinate(0, 2)), 2); - EXPECT_EQ(to_linear_index(shape, MeshCoordinate(1, 0)), 3); - EXPECT_EQ(to_linear_index(shape, MeshCoordinate(1, 1)), 4); - EXPECT_EQ(to_linear_index(shape, MeshCoordinate(1, 2)), 5); + SimpleMeshShape shape(2, 2, 3); + + EXPECT_EQ(to_linear_index(shape, MeshCoordinate(0, 0, 0)), 0); + EXPECT_EQ(to_linear_index(shape, MeshCoordinate(0, 0, 1)), 1); + EXPECT_EQ(to_linear_index(shape, MeshCoordinate(0, 0, 2)), 2); + EXPECT_EQ(to_linear_index(shape, MeshCoordinate(0, 1, 0)), 3); + EXPECT_EQ(to_linear_index(shape, MeshCoordinate(0, 1, 1)), 4); + EXPECT_EQ(to_linear_index(shape, MeshCoordinate(0, 1, 2)), 5); + EXPECT_EQ(to_linear_index(shape, MeshCoordinate(1, 0, 0)), 6); + EXPECT_EQ(to_linear_index(shape, MeshCoordinate(1, 0, 1)), 7); + EXPECT_EQ(to_linear_index(shape, MeshCoordinate(1, 0, 2)), 8); + EXPECT_EQ(to_linear_index(shape, MeshCoordinate(1, 1, 0)), 9); + EXPECT_EQ(to_linear_index(shape, MeshCoordinate(1, 1, 1)), 10); + EXPECT_EQ(to_linear_index(shape, MeshCoordinate(1, 1, 2)), 11); } TEST(ToLinearIndexTest, MismatchedDimensions) { From 5d46d98bdd34a67af5956c47336a84fd28e10aba Mon Sep 17 00:00:00 2001 From: Oleg Milyutin Date: Sat, 15 Feb 2025 18:48:14 +0000 Subject: [PATCH 10/10] Fix shape base view --- tt_metal/common/shape_base.cpp | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tt_metal/common/shape_base.cpp b/tt_metal/common/shape_base.cpp index fdb94487ed8..33acd941d22 100644 --- a/tt_metal/common/shape_base.cpp +++ b/tt_metal/common/shape_base.cpp @@ -4,7 +4,9 @@ #include "assert.hpp" #include "shape_base.hpp" +#include #include +#include #include "fmt/color.h" namespace tt::tt_metal { @@ -46,7 +48,14 @@ bool ShapeBase::empty() const { return original_size_ == 0; } size_t ShapeBase::size() const { return original_size_; } -tt::stl::Span ShapeBase::view() const { return tt::stl::Span(value_); } +tt::stl::Span ShapeBase::view() const { + const auto begin = cbegin(); + const auto end = cend(); + // `Span` constructor requires a contiguous range of data. + static_assert( + std::is_base_of_v::iterator_category>); + return tt::stl::Span(&*begin, std::distance(begin, end)); +} bool ShapeBase::operator==(const ShapeBase& other) const = default;