From 52c53d562387ad929c76d08f6b26834d5f4fa60e Mon Sep 17 00:00:00 2001 From: Oleg Milyutin Date: Sat, 15 Feb 2025 20:55:29 -0500 Subject: [PATCH] #17477: Introduce ND coordinate system for TT-distributed (#17745) ### Ticket #17477 ### Problem description The existing mesh infrastructure assumes 2D. This assumption won't hold in the future. ### What's changed Introduce a new `SimpleMeshShape` that will gradually replace the existing `MeshShape`, after which it will be renamed to `MeshShape`. Introduce `MeshCoordinate`, `MeshCoordinateRange`, and `MeshContainer` - primitives designed to work with the new ND coordinate system. `MeshContainer` provides an efficient flat representation of metadata laid out to match the mesh shape. Iterators are available to make it easy to use. `MeshCoordinate`, along with the strides that are precomputed on `SimpleMeshShape`, allows easy point access. The integration with `MeshBuffer` demonstrates the use case. Next steps: * Replace the existing `MeshShape`, `MeshOffset`, and the related aliases with the new `SimpleMeshShape` and `MeshCoordinate`. * No plans to generalize with `CoreCoord`, for now. Cores are fundamentally 2D, so a more specialized system can be used for efficiency. Also, it is not desirable to make `CoreCoord` interoperate with `MeshCoordinate` - the two coordinate systems represent entirely different concepts. * More functionality might be added as we continue working on TT-distributed. 
### Checklist - [X] [All post commit](https://github.com/tenstorrent/tt-metal/actions/runs/13347753550) - [X] New/Existing tests provide coverage for changes --- tests/tt_metal/distributed/CMakeLists.txt | 1 + .../tt_metal/distributed/test_mesh_coord.cpp | 290 ++++++++++++++ tt_metal/api/tt-metalium/mesh_buffer.hpp | 7 +- tt_metal/api/tt-metalium/mesh_coord.hpp | 370 ++++++++++++++++++ tt_metal/api/tt-metalium/mesh_device.hpp | 2 + tt_metal/api/tt-metalium/shape_base.hpp | 6 +- tt_metal/common/CMakeLists.txt | 1 + tt_metal/common/mesh_coord.cpp | 161 ++++++++ tt_metal/common/shape_base.cpp | 11 +- tt_metal/distributed/mesh_buffer.cpp | 28 +- tt_metal/distributed/mesh_device.cpp | 6 +- .../distributed_buffer_rw.cpp | 2 +- .../distributed_eltwise_add.cpp | 2 +- .../ttnn/operations/data_movement/pad/pad.cpp | 5 +- 14 files changed, 862 insertions(+), 30 deletions(-) create mode 100644 tests/tt_metal/distributed/test_mesh_coord.cpp create mode 100644 tt_metal/api/tt-metalium/mesh_coord.hpp create mode 100644 tt_metal/common/mesh_coord.cpp diff --git a/tests/tt_metal/distributed/CMakeLists.txt b/tests/tt_metal/distributed/CMakeLists.txt index 27bb9ee7b53..08fededb592 100644 --- a/tests/tt_metal/distributed/CMakeLists.txt +++ b/tests/tt_metal/distributed/CMakeLists.txt @@ -1,6 +1,7 @@ set(UNIT_TESTS_DISTRIBUTED_SRC ${CMAKE_CURRENT_SOURCE_DIR}/test_distributed.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test_mesh_buffer.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/test_mesh_coord.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test_mesh_workload.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test_mesh_sub_device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test_mesh_allocator.cpp diff --git a/tests/tt_metal/distributed/test_mesh_coord.cpp b/tests/tt_metal/distributed/test_mesh_coord.cpp new file mode 100644 index 00000000000..09853a488a0 --- /dev/null +++ b/tests/tt_metal/distributed/test_mesh_coord.cpp @@ -0,0 +1,290 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include + +#include "mesh_coord.hpp" + +namespace tt::tt_metal::distributed { +namespace { + +using ::testing::ElementsAre; + +TEST(SimpleMeshShapeTest, Construction) { + SimpleMeshShape shape_1d(3); + EXPECT_EQ(shape_1d.dims(), 1); + EXPECT_EQ(shape_1d[0], 3); + EXPECT_EQ(shape_1d.mesh_size(), 3); + + SimpleMeshShape shape_2d(3, 4); + EXPECT_EQ(shape_2d.dims(), 2); + EXPECT_EQ(shape_2d[0], 3); + EXPECT_EQ(shape_2d[1], 4); + EXPECT_EQ(shape_2d.mesh_size(), 12); + + SimpleMeshShape shape_3d(2, 3, 4); + EXPECT_EQ(shape_3d.dims(), 3); + EXPECT_EQ(shape_3d[0], 2); + EXPECT_EQ(shape_3d[1], 3); + EXPECT_EQ(shape_3d[2], 4); + EXPECT_EQ(shape_3d.mesh_size(), 24); + + SimpleMeshShape shape_5d({2, 3, 4, 5, 6}); + EXPECT_EQ(shape_5d.dims(), 5); + EXPECT_EQ(shape_5d[0], 2); + EXPECT_EQ(shape_5d[1], 3); + EXPECT_EQ(shape_5d[2], 4); + EXPECT_EQ(shape_5d[3], 5); + EXPECT_EQ(shape_5d[4], 6); + EXPECT_EQ(shape_5d.mesh_size(), 720); +} + +TEST(SimpleMeshShapeTest, ZeroShape) { + SimpleMeshShape shape({}); + EXPECT_EQ(shape.dims(), 0); + EXPECT_EQ(shape.mesh_size(), 0); +} + +TEST(SimpleMeshShapeTest, Strides) { + SimpleMeshShape shape(2, 3, 4); + EXPECT_EQ(shape.get_stride(0), 12); // 3 * 4 + EXPECT_EQ(shape.get_stride(1), 4); // 4 + EXPECT_EQ(shape.get_stride(2), 1); // 1 +} + +TEST(SimpleMeshShapeTest, Comparison) { + SimpleMeshShape shape(2, 3); + + EXPECT_EQ(shape, SimpleMeshShape(2, 3)); + EXPECT_NE(shape, SimpleMeshShape(3, 2)); + EXPECT_NE(shape, SimpleMeshShape(1, 2, 3)); +} + +TEST(MeshCoordinateTest, Construction) { + MeshCoordinate coord_1d(1); + EXPECT_EQ(coord_1d.dims(), 1); + EXPECT_THAT(coord_1d.coords(), ElementsAre(1)); + EXPECT_EQ(coord_1d[0], 1); + + MeshCoordinate coord_2d(1, 2); + EXPECT_EQ(coord_2d.dims(), 2); + EXPECT_THAT(coord_2d.coords(), ElementsAre(1, 2)); + EXPECT_EQ(coord_2d[0], 1); + EXPECT_EQ(coord_2d[1], 2); + + MeshCoordinate coord_3d(1, 2, 3); + EXPECT_EQ(coord_3d.dims(), 3); + 
EXPECT_THAT(coord_3d.coords(), ElementsAre(1, 2, 3)); + EXPECT_EQ(coord_3d[0], 1); + EXPECT_EQ(coord_3d[1], 2); + EXPECT_EQ(coord_3d[2], 3); + + std::vector values = {1, 2, 3, 4, 5}; + MeshCoordinate coord_span(values); + EXPECT_EQ(coord_span.dims(), 5); + EXPECT_THAT(coord_span.coords(), ElementsAre(1, 2, 3, 4, 5)); + EXPECT_EQ(coord_span[0], 1); + EXPECT_EQ(coord_span[1], 2); + EXPECT_EQ(coord_span[2], 3); + EXPECT_EQ(coord_span[3], 4); + EXPECT_EQ(coord_span[4], 5); +} + +TEST(MeshCoordinateTest, Comparison) { + MeshCoordinate coord1(1, 2); + + EXPECT_EQ(coord1, MeshCoordinate(1, 2)); + EXPECT_NE(coord1, MeshCoordinate(2, 1)); + EXPECT_NE(coord1, MeshCoordinate(1, 2, 1)); +} + +TEST(MeshCoordinateRangeTest, FromShape) { + SimpleMeshShape shape(2, 3); + MeshCoordinateRange range(shape); + + std::vector coords; + for (const auto& coord : range) { + coords.push_back(coord); + } + + EXPECT_THAT( + coords, + ElementsAre( + MeshCoordinate(0, 0), + MeshCoordinate(0, 1), + MeshCoordinate(0, 2), + MeshCoordinate(1, 0), + MeshCoordinate(1, 1), + MeshCoordinate(1, 2))); +} + +TEST(MeshCoordinateRangeTest, Subrange) { + MeshCoordinate start(1, 1, 1); + MeshCoordinate end(2, 1, 4); + MeshCoordinateRange range(start, end); + + std::vector coords; + for (const auto& coord : range) { + coords.push_back(coord); + } + + EXPECT_THAT( + coords, + ElementsAre( + MeshCoordinate(1, 1, 1), + MeshCoordinate(1, 1, 2), + MeshCoordinate(1, 1, 3), + MeshCoordinate(1, 1, 4), + MeshCoordinate(2, 1, 1), + MeshCoordinate(2, 1, 2), + MeshCoordinate(2, 1, 3), + MeshCoordinate(2, 1, 4))); +} + +TEST(MeshCoordinateRangeTest, SubrangeOneElement) { + MeshCoordinate start(1, 1, 1); + MeshCoordinate end(1, 1, 1); + MeshCoordinateRange range(start, end); + + std::vector coords; + for (const auto& coord : range) { + coords.push_back(coord); + } + + EXPECT_THAT(coords, ElementsAre(MeshCoordinate(1, 1, 1))); +} + +TEST(MeshCoordinateRangeTest, MismatchedDimensions) { + MeshCoordinate start(1, 0); + 
MeshCoordinate end(2, 3, 1); + EXPECT_ANY_THROW(MeshCoordinateRange(start, end)); +} + +TEST(MeshCoordinateRangeTest, InvalidRange) { + MeshCoordinate start(1, 2, 0); + MeshCoordinate end(1, 1, 1); + EXPECT_ANY_THROW(MeshCoordinateRange(start, end)); +} + +TEST(ToLinearIndexTest, Basic) { + SimpleMeshShape shape(2, 2, 3); + + EXPECT_EQ(to_linear_index(shape, MeshCoordinate(0, 0, 0)), 0); + EXPECT_EQ(to_linear_index(shape, MeshCoordinate(0, 0, 1)), 1); + EXPECT_EQ(to_linear_index(shape, MeshCoordinate(0, 0, 2)), 2); + EXPECT_EQ(to_linear_index(shape, MeshCoordinate(0, 1, 0)), 3); + EXPECT_EQ(to_linear_index(shape, MeshCoordinate(0, 1, 1)), 4); + EXPECT_EQ(to_linear_index(shape, MeshCoordinate(0, 1, 2)), 5); + EXPECT_EQ(to_linear_index(shape, MeshCoordinate(1, 0, 0)), 6); + EXPECT_EQ(to_linear_index(shape, MeshCoordinate(1, 0, 1)), 7); + EXPECT_EQ(to_linear_index(shape, MeshCoordinate(1, 0, 2)), 8); + EXPECT_EQ(to_linear_index(shape, MeshCoordinate(1, 1, 0)), 9); + EXPECT_EQ(to_linear_index(shape, MeshCoordinate(1, 1, 1)), 10); + EXPECT_EQ(to_linear_index(shape, MeshCoordinate(1, 1, 2)), 11); +} + +TEST(ToLinearIndexTest, MismatchedDimensions) { + EXPECT_ANY_THROW(to_linear_index(SimpleMeshShape(1, 2, 3), MeshCoordinate(0, 0))); +} + +TEST(ToLinearIndexTest, OutOfBounds) { + EXPECT_ANY_THROW(to_linear_index(SimpleMeshShape(2, 3), MeshCoordinate(2, 0))); + EXPECT_ANY_THROW(to_linear_index(SimpleMeshShape(2, 3), MeshCoordinate(0, 3))); +} + +TEST(MeshContainerTest, InitialValues) { + SimpleMeshShape shape(2, 3); + MeshContainer container(shape, 3); + + std::vector initial_values; + for (const auto& [_, value] : container) { + initial_values.push_back(value); + } + EXPECT_THAT(initial_values, ElementsAre(3, 3, 3, 3, 3, 3)); +} + +TEST(MeshContainerTest, ElementAccessRowMajor) { + SimpleMeshShape shape(2, 3); + MeshContainer container(shape, 0); + + container.at(MeshCoordinate(0, 0)) = 0; + container.at(MeshCoordinate(0, 1)) = 1; + container.at(MeshCoordinate(0, 2)) = 2; 
+ container.at(MeshCoordinate(1, 0)) = 3; + container.at(MeshCoordinate(1, 1)) = 4; + container.at(MeshCoordinate(1, 2)) = 5; + + std::vector coords; + std::vector values; + for (const auto& [coord, value] : container) { + coords.push_back(coord); + values.push_back(value); + } + EXPECT_THAT( + coords, + ElementsAre( + MeshCoordinate(0, 0), + MeshCoordinate(0, 1), + MeshCoordinate(0, 2), + MeshCoordinate(1, 0), + MeshCoordinate(1, 1), + MeshCoordinate(1, 2))); + EXPECT_THAT(values, ElementsAre(0, 1, 2, 3, 4, 5)); +} + +TEST(MeshContainerTest, ConstContainer) { + SimpleMeshShape shape(2, 3); + const MeshContainer container(shape, 0); + + std::vector coords; + std::vector values; + for (const auto& [coord, value] : container) { + coords.push_back(coord); + values.push_back(value); + } + EXPECT_THAT( + coords, + ElementsAre( + MeshCoordinate(0, 0), + MeshCoordinate(0, 1), + MeshCoordinate(0, 2), + MeshCoordinate(1, 0), + MeshCoordinate(1, 1), + MeshCoordinate(1, 2))); + EXPECT_THAT(values, ElementsAre(0, 0, 0, 0, 0, 0)); +} + +TEST(MeshContainerTest, MutateThroughProxy) { + SimpleMeshShape shape(2, 3); + MeshContainer container(shape, 0); + + // Proxy class provides access to the container value through the mutable reference. + int updated_value = 0; + for (auto& [_, value] : container) { + value = updated_value++; + } + + // `auto` makes a copy of the value, verify this loop is a no-op. 
+ for (auto [_, value] : container) { + value = updated_value++; + } + + std::vector values; + for (const auto& [_, value] : container) { + values.push_back(value); + } + EXPECT_THAT(values, ElementsAre(0, 1, 2, 3, 4, 5)); +} + +TEST(MeshContainerTest, OutOfBounds) { + SimpleMeshShape shape(2, 3); + MeshContainer container(shape, 0); + + EXPECT_ANY_THROW(container.at(MeshCoordinate(2, 0))); + EXPECT_ANY_THROW(container.at(MeshCoordinate(0, 0, 0))); +} + +} // namespace +} // namespace tt::tt_metal::distributed diff --git a/tt_metal/api/tt-metalium/mesh_buffer.hpp b/tt_metal/api/tt-metalium/mesh_buffer.hpp index 0e029685b47..8656fc02e67 100644 --- a/tt_metal/api/tt-metalium/mesh_buffer.hpp +++ b/tt_metal/api/tt-metalium/mesh_buffer.hpp @@ -6,6 +6,7 @@ #include "buffer.hpp" #include "buffer_constants.hpp" +#include "mesh_coord.hpp" #include "mesh_device.hpp" #include "mesh_device_view.hpp" #include "shape2d.hpp" @@ -96,6 +97,7 @@ class MeshBuffer { const DeviceLocalBufferConfig& device_local_config() const { return device_local_config_; } std::shared_ptr get_device_buffer(const Coordinate& device_coord) const; + std::shared_ptr get_device_buffer(const MeshCoordinate& device_coord) const; uint32_t datum_size_bytes() const; Shape2D physical_shard_shape() const; std::pair replicated_dims() const; @@ -108,6 +110,7 @@ class MeshBuffer { DeviceAddr device_local_size, MeshDevice* mesh_device, std::shared_ptr backing_buffer) : + buffers_(SimpleMeshShape(mesh_device->shape()), nullptr), config_(config), device_local_config_(device_local_config), mesh_device_(mesh_device), @@ -122,6 +125,7 @@ class MeshBuffer { DeviceAddr address, DeviceAddr device_local_size, MeshDevice* mesh_device) : + buffers_(SimpleMeshShape(mesh_device->shape()), /*fill_value=*/nullptr), config_(config), device_local_config_(device_local_config), mesh_device_(mesh_device), @@ -136,8 +140,7 @@ class MeshBuffer { DeviceAddr address_ = 0; DeviceAddr device_local_size_ = 0; - // TODO: Consider optimizing 
with SmallVector. - std::vector>> buffers_; + MeshContainer> buffers_; // `MeshBufferState` specifies the state of the MeshBuffer. It can either be: // 1. Owned - a single device buffer is responsible for providing the address for the entire mesh buffer. diff --git a/tt_metal/api/tt-metalium/mesh_coord.hpp b/tt_metal/api/tt-metalium/mesh_coord.hpp new file mode 100644 index 00000000000..e346ce2ca83 --- /dev/null +++ b/tt_metal/api/tt-metalium/mesh_coord.hpp @@ -0,0 +1,370 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include + +#include "shape_base.hpp" + +namespace tt::tt_metal::distributed { + +struct MeshShape; + +// TODO: #17477 - Rename to `MeshShape` when the legacy type is gone. +class SimpleMeshShape : public ShapeBase { +public: + using ShapeBase::ShapeBase; + using ShapeBase::operator[]; + + // Shorthands for constructing 1D, 2D and 3D shapes. + SimpleMeshShape(uint32_t x); + SimpleMeshShape(uint32_t x, uint32_t y); + SimpleMeshShape(uint32_t x, uint32_t y, uint32_t z); + + // Temporary constructor for transitioning to `SimpleMeshShape`. + SimpleMeshShape(const MeshShape& legacy_shape); + + // Returns the dimensionality of the mesh. + size_t dims() const; + + // Returns the stride for the given dimension. + size_t get_stride(size_t dim) const; + + // Returns the total number of elements in the mesh. 
+ size_t mesh_size() const; + + // Needed for reflect / fmt + static constexpr auto attribute_names = std::forward_as_tuple("value"); + auto attribute_values() const { return std::forward_as_tuple(value_); } + + friend bool operator==(const SimpleMeshShape& lhs, const SimpleMeshShape& rhs); + friend bool operator!=(const SimpleMeshShape& lhs, const SimpleMeshShape& rhs); + friend std::ostream& operator<<(std::ostream& os, const SimpleMeshShape& shape); + +private: + using ShapeBase::empty; + using ShapeBase::size; + + void compute_strides(); + tt::stl::SmallVector strides_; +}; + +class MeshCoordinate { +public: + // Shorthands for constructing 1D, 2D and 3D coordinates. + MeshCoordinate(uint32_t x); + MeshCoordinate(uint32_t x, uint32_t y); + MeshCoordinate(uint32_t x, uint32_t y, uint32_t z); + + // Constructs a generic N-dimensional coordinate. + explicit MeshCoordinate(tt::stl::Span coords); + + // Returns the dimensionality of the coordinate. + size_t dims() const; + + // Returns the coordinate values as a span. + tt::stl::Span coords() const; + + // Returns the coordinate value at the given index. + uint32_t operator[](size_t dim) const; + + // Needed for reflect / fmt + static constexpr auto attribute_names = std::forward_as_tuple("value"); + auto attribute_values() const { return std::forward_as_tuple(value_); } + + friend bool operator==(const MeshCoordinate& lhs, const MeshCoordinate& rhs); + friend bool operator!=(const MeshCoordinate& lhs, const MeshCoordinate& rhs); + friend std::ostream& operator<<(std::ostream& os, const MeshCoordinate& shape); + +private: + tt::stl::SmallVector value_; +}; + +// Converts a MeshCoordinate to a linear index. +// Throws if `coord` is out of bounds of `shape`. +size_t to_linear_index(const SimpleMeshShape& shape, const MeshCoordinate& coord); + +// Represents a range of MeshCoordinates. Requires that mesh coordinates have the same dimensionality. 
+class MeshCoordinateRange { +public: + // Constructs an inclusive range that iterates between `start` and `end`. Throws if the dimensionalities differ, or if `start` exceeds `end` along any dimension. + MeshCoordinateRange(const MeshCoordinate& start, const MeshCoordinate& end); + + // Constructs a range that iterates over all coordinates in the mesh, from the all-zero coordinate to the last valid coordinate of `shape`. + MeshCoordinateRange(const SimpleMeshShape& shape); + + // Returns start and (inclusive) end coordinates of the range. + const MeshCoordinate& start_coord() const; + const MeshCoordinate& end_coord() const; + + class Iterator { + public: + Iterator& operator++(); + const MeshCoordinate& operator*() const; + bool operator==(const Iterator& other) const; + bool operator!=(const Iterator& other) const; + + private: + Iterator(const MeshCoordinateRange* range, const MeshCoordinate& current_coord, size_t linear_index); + friend class MeshCoordinateRange; + + const MeshCoordinateRange* range_ = nullptr; + + // For simplicity, rely on `linear_index_` for the iterator boundary check, and allow + // MeshCoordinate to wrap around the range end. Two iterators compare equal when they share the same range and linear index. + MeshCoordinate current_coord_; + size_t linear_index_ = 0; + }; + + Iterator begin() const; + Iterator end() const; + + friend bool operator==(const MeshCoordinateRange& lhs, const MeshCoordinateRange& rhs); + friend bool operator!=(const MeshCoordinateRange& lhs, const MeshCoordinateRange& rhs); // NOTE(review): no definition of these two range operators is visible in the mesh_coord.cpp added by this patch; confirm they are defined before first use. + +private: + MeshCoordinate start_; + MeshCoordinate end_; +}; + +namespace detail { + +// Proxy class that allows convenient structured binding to a pair of a coordinate and the value it points to. +// This supports iterator semantics similar to `std::map` / `std::unordered_map`. 
+template +class MeshCoordinateValueProxy { +public: + MeshCoordinateValueProxy(const MeshCoordinate* coord, T* value_ptr) : coord_(coord), value_ptr_(value_ptr) {} + + const MeshCoordinate& coord() const { return *coord_; } + T& value() { return *value_ptr_; } + const T& value() const { return *value_ptr_; } + + template + decltype(auto) get() & { + if constexpr (I == 0) { + return coord(); + } else if constexpr (I == 1) { + return value(); + } else { + static_assert(I < 2); + } + } + + template + decltype(auto) get() const& { + if constexpr (I == 0) { + return coord(); + } else if constexpr (I == 1) { + return value(); + } else { + static_assert(I < 2); + } + } + + // Force a copy via `auto`. + template + auto get() const&& { + return get(); + } + +private: + const MeshCoordinate* coord_ = nullptr; + T* value_ptr_ = nullptr; +}; + +} // namespace detail + +// Allows storing data in a mesh-shaped flat container, with convenient accessors and iterators. +// The iteration order and the storage memory layout is row-major. +template +class MeshContainer { +public: + MeshContainer(const SimpleMeshShape& shape, const T& fill_value); + + // Returns a shape of the container. + const SimpleMeshShape& shape() const; + + // Accessor methods. + T& at(const MeshCoordinate& coord); + const T& at(const MeshCoordinate& coord) const; + + // Allows to iterate over the container elements, returning a pair of (coordinate, value reference). 
+ class Iterator { + public: + using ValueProxy = detail::MeshCoordinateValueProxy; + + Iterator& operator++(); + ValueProxy& operator*(); + bool operator==(const Iterator& other) const; + bool operator!=(const Iterator& other) const; + + private: + Iterator(MeshContainer* container, const MeshCoordinateRange::Iterator& coord_iter, size_t linear_index); + friend class MeshContainer; + + MeshContainer* container_ = nullptr; + MeshCoordinateRange::Iterator coord_iter_; + size_t linear_index_ = 0; + + // Provides mutable access to the container value along with the coordinate from the range iterator. + ValueProxy value_proxy_; + }; + + class ConstIterator { + public: + using ValueProxy = detail::MeshCoordinateValueProxy; + + ConstIterator& operator++(); + const ValueProxy& operator*() const; + bool operator==(const ConstIterator& other) const; + bool operator!=(const ConstIterator& other) const; + + private: + ConstIterator( + const MeshContainer* container, const MeshCoordinateRange::Iterator& coord_iter, size_t linear_index); + friend class MeshContainer; + + const MeshContainer* container_ = nullptr; + MeshCoordinateRange::Iterator coord_iter_; + size_t linear_index_ = 0; + + // Provides read-only access to the container value along with the coordinate from the range iterator. 
+ ValueProxy value_proxy_; + }; + + Iterator begin(); + Iterator end(); + ConstIterator begin() const; + ConstIterator end() const; + +private: + SimpleMeshShape shape_; + MeshCoordinateRange coord_range_; + std::vector values_; +}; + +template +MeshContainer::MeshContainer(const SimpleMeshShape& shape, const T& fill_value) : + shape_(shape), coord_range_(shape), values_(shape.mesh_size(), fill_value) {} + +template +const SimpleMeshShape& MeshContainer::shape() const { + return shape_; +} + +template +T& MeshContainer::at(const MeshCoordinate& coord) { + return values_.at(to_linear_index(shape_, coord)); +} + +template +const T& MeshContainer::at(const MeshCoordinate& coord) const { + return values_.at(to_linear_index(shape_, coord)); +} + +template +MeshContainer::Iterator::Iterator( + MeshContainer* container, const MeshCoordinateRange::Iterator& coord_iter, size_t linear_index) : + container_(container), + coord_iter_(coord_iter), + linear_index_(linear_index), + value_proxy_(&(*coord_iter_), &container_->values_[linear_index_]) {} + +template +typename MeshContainer::Iterator& MeshContainer::Iterator::operator++() { + ++linear_index_; + ++coord_iter_; + value_proxy_ = ValueProxy(&(*coord_iter_), &container_->values_[linear_index_]); + return *this; +} + +template +typename MeshContainer::Iterator::ValueProxy& MeshContainer::Iterator::operator*() { + return value_proxy_; +} + +template +MeshContainer::ConstIterator::ConstIterator( + const MeshContainer* container, const MeshCoordinateRange::Iterator& coord_iter, size_t linear_index) : + container_(container), + coord_iter_(coord_iter), + linear_index_(linear_index), + value_proxy_(&(*coord_iter_), &container_->values_[linear_index_]) {} + +template +typename MeshContainer::ConstIterator& MeshContainer::ConstIterator::operator++() { + ++linear_index_; + ++coord_iter_; + value_proxy_ = ValueProxy(&(*coord_iter_), &container_->values_[linear_index_]); + return *this; +} + +template +const typename 
MeshContainer::ConstIterator::ValueProxy& MeshContainer::ConstIterator::operator*() const { + return value_proxy_; +} + +template +bool MeshContainer::Iterator::operator==(const Iterator& other) const { + return container_ == other.container_ && coord_iter_ == other.coord_iter_ && linear_index_ == other.linear_index_; +} + +template +bool MeshContainer::Iterator::operator!=(const Iterator& other) const { + return !(*this == other); +} + +template +bool MeshContainer::ConstIterator::operator==(const ConstIterator& other) const { + return container_ == other.container_ && coord_iter_ == other.coord_iter_ && linear_index_ == other.linear_index_; +} + +template +bool MeshContainer::ConstIterator::operator!=(const ConstIterator& other) const { + return !(*this == other); +} + +template +typename MeshContainer::Iterator MeshContainer::begin() { + return Iterator(this, coord_range_.begin(), /* linear_index = */ 0); +} + +template +typename MeshContainer::Iterator MeshContainer::end() { + return Iterator(this, coord_range_.end(), shape_.mesh_size()); +} + +template +typename MeshContainer::ConstIterator MeshContainer::begin() const { + return ConstIterator(this, coord_range_.begin(), /* linear_index = */ 0); +} + +template +typename MeshContainer::ConstIterator MeshContainer::end() const { + return ConstIterator(this, coord_range_.end(), shape_.mesh_size()); +} + +} // namespace tt::tt_metal::distributed + +namespace std { + +template +struct tuple_size> : std::integral_constant { +}; + +template +struct tuple_element<0, tt::tt_metal::distributed::detail::MeshCoordinateValueProxy> { + using type = const tt::tt_metal::distributed::MeshCoordinate; +}; + +template +struct tuple_element<1, tt::tt_metal::distributed::detail::MeshCoordinateValueProxy> { + using type = T; +}; + +} // namespace std diff --git a/tt_metal/api/tt-metalium/mesh_device.hpp b/tt_metal/api/tt-metalium/mesh_device.hpp index 91638a57cb6..979e603a6cd 100644 --- a/tt_metal/api/tt-metalium/mesh_device.hpp +++ 
b/tt_metal/api/tt-metalium/mesh_device.hpp @@ -12,6 +12,7 @@ #include "device.hpp" #include "mesh_config.hpp" +#include "mesh_coord.hpp" #include "mesh_device_view.hpp" #include "sub_device_types.hpp" #include "span.hpp" @@ -204,6 +205,7 @@ class MeshDevice : public IDevice, public std::enable_shared_from_this -#include #include "small_vector.hpp" +#include "span.hpp" namespace tt::tt_metal { @@ -24,7 +24,7 @@ class ShapeBase { explicit ShapeBase(const std::array& arr) : value_(arr.begin(), arr.end()) { init(); } - explicit ShapeBase(std::span span) : value_(span.begin(), span.end()) { init(); } + explicit ShapeBase(tt::stl::Span span) : value_(span.begin(), span.end()) { init(); } template bool operator==(const std::array& other) const { @@ -42,7 +42,7 @@ class ShapeBase { Container::const_iterator cbegin() const; Container::const_iterator cend() const; - std::span view() const; + tt::stl::Span view() const; bool empty() const; diff --git a/tt_metal/common/CMakeLists.txt b/tt_metal/common/CMakeLists.txt index 28f27de3edf..7d43d25d5b0 100644 --- a/tt_metal/common/CMakeLists.txt +++ b/tt_metal/common/CMakeLists.txt @@ -1,6 +1,7 @@ set(COMMON_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/core_assignment.cpp ${CMAKE_CURRENT_SOURCE_DIR}/core_coord.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/mesh_coord.cpp ${CMAKE_CURRENT_SOURCE_DIR}/metal_soc_descriptor.cpp ${CMAKE_CURRENT_SOURCE_DIR}/shape2d.cpp ${CMAKE_CURRENT_SOURCE_DIR}/shape_base.cpp diff --git a/tt_metal/common/mesh_coord.cpp b/tt_metal/common/mesh_coord.cpp new file mode 100644 index 00000000000..9a98a0ce801 --- /dev/null +++ b/tt_metal/common/mesh_coord.cpp @@ -0,0 +1,161 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include +#include +#include +#include +#include +#include + +namespace tt::tt_metal::distributed { +namespace { + +// Returns a zero coordinate of dimensionality `dims`. 
+MeshCoordinate zero_coordinate(size_t dims) { return MeshCoordinate(tt::stl::SmallVector(dims, 0)); } + +// Returns the last valid coordinate for the provided `shape`. +MeshCoordinate shape_back(const SimpleMeshShape& shape) { + tt::stl::SmallVector coords; + for (int i = 0; i < shape.dims(); i++) { + coords.push_back(shape[i] - 1); + } + return MeshCoordinate(coords); +} + +} // namespace + +SimpleMeshShape::SimpleMeshShape(uint32_t x) : ShapeBase({x}) { compute_strides(); } +SimpleMeshShape::SimpleMeshShape(uint32_t x, uint32_t y) : ShapeBase({x, y}) { compute_strides(); } +SimpleMeshShape::SimpleMeshShape(uint32_t x, uint32_t y, uint32_t z) : ShapeBase({x, y, z}) { compute_strides(); } + +SimpleMeshShape::SimpleMeshShape(const MeshShape& legacy_shape) : + SimpleMeshShape(legacy_shape.num_rows, legacy_shape.num_cols) {} + +void SimpleMeshShape::compute_strides() { + size_t stride = 1; + strides_.resize(dims()); + for (int dim = dims() - 1; dim >= 0; --dim) { + strides_[dim] = stride; + stride *= (*this)[dim]; + } +} + +size_t SimpleMeshShape::get_stride(size_t dim) const { return strides_[dim]; } + +size_t SimpleMeshShape::dims() const { return size(); } +size_t SimpleMeshShape::mesh_size() const { + return empty() ? 
0 : std::accumulate(value_.begin(), value_.end(), 1, std::multiplies()); +} + +bool operator==(const SimpleMeshShape& lhs, const SimpleMeshShape& rhs) = default; +bool operator!=(const SimpleMeshShape& lhs, const SimpleMeshShape& rhs) = default; + +std::ostream& operator<<(std::ostream& os, const SimpleMeshShape& shape) { + os << "SimpleMeshShape(["; + for (size_t i = 0; i < shape.dims(); ++i) { + if (i > 0) { + os << ", "; + } + os << shape[i]; + } + os << "])"; + return os; +} + +MeshCoordinate::MeshCoordinate(uint32_t coord) : value_({coord}) {} +MeshCoordinate::MeshCoordinate(uint32_t x, uint32_t y) : value_({x, y}) {} +MeshCoordinate::MeshCoordinate(uint32_t x, uint32_t y, uint32_t z) : value_({x, y, z}) {} + +MeshCoordinate::MeshCoordinate(tt::stl::Span coords) : value_(coords.begin(), coords.end()) {} + +size_t MeshCoordinate::dims() const { return value_.size(); } +tt::stl::Span MeshCoordinate::coords() const { return value_; } +uint32_t MeshCoordinate::operator[](size_t dim) const { return value_[dim]; } + +bool operator==(const MeshCoordinate& lhs, const MeshCoordinate& rhs) { + return lhs.dims() == rhs.dims() && std::equal(lhs.coords().begin(), lhs.coords().end(), rhs.coords().begin()); +} +bool operator!=(const MeshCoordinate& lhs, const MeshCoordinate& rhs) { return !(lhs == rhs); } + +std::ostream& operator<<(std::ostream& os, const MeshCoordinate& coord) { + os << "MeshCoordinate(" << coord.dims() << ", ["; + for (size_t dim : coord.coords()) { + os << dim << ", "; + } + os << "])"; + return os; +} + +MeshCoordinateRange::MeshCoordinateRange(const MeshCoordinate& start, const MeshCoordinate& end) : + start_(start), end_(end) { + TT_FATAL( + start.dims() == end.dims(), + "Start and end dimensions of a coordinate range do not match: {} != {}", + start.dims(), + end.dims()); + for (size_t i = 0; i < start.dims(); ++i) { + TT_FATAL(start[i] <= end[i], "Start coordinate is greater than end coordinate: {} > {}", start, end); + } +} + 
+MeshCoordinateRange::MeshCoordinateRange(const SimpleMeshShape& shape) : + MeshCoordinateRange(zero_coordinate(shape.dims()), shape_back(shape)) {} + +const MeshCoordinate& MeshCoordinateRange::start_coord() const { return start_; } +const MeshCoordinate& MeshCoordinateRange::end_coord() const { return end_; } + +MeshCoordinateRange::Iterator::Iterator( + const MeshCoordinateRange* range, const MeshCoordinate& current, size_t linear_index) : + range_(range), current_coord_(current), linear_index_(linear_index) {} + +MeshCoordinateRange::Iterator& MeshCoordinateRange::Iterator::operator++() { + ++linear_index_; + + tt::stl::SmallVector new_coords(current_coord_.coords().begin(), current_coord_.coords().end()); + for (int i = new_coords.size() - 1; i >= 0; --i) { + auto& dimension_value = new_coords[i]; + if (++dimension_value > range_->end_coord()[i]) { + dimension_value = range_->start_coord()[i]; + } else { + break; + } + } + current_coord_ = MeshCoordinate(new_coords); + return *this; +} +const MeshCoordinate& MeshCoordinateRange::Iterator::operator*() const { return current_coord_; } +bool MeshCoordinateRange::Iterator::operator==(const Iterator& other) const { + return range_ == other.range_ && linear_index_ == other.linear_index_; +} +bool MeshCoordinateRange::Iterator::operator!=(const Iterator& other) const { return !(*this == other); } + +MeshCoordinateRange::Iterator MeshCoordinateRange::begin() const { return Iterator(this, start_, /*linear_index=*/0); } +MeshCoordinateRange::Iterator MeshCoordinateRange::end() const { + size_t range_size = 1; + for (size_t i = 0; i < start_.dims(); ++i) { + range_size *= end_[i] - start_[i] + 1; + } + // Set `start_` coordinate but `range_size` linear index as the wrap around condition. 
+ return Iterator(this, start_, range_size); +} + +size_t to_linear_index(const SimpleMeshShape& shape, const MeshCoordinate& coord) { + TT_FATAL( + shape.dims() == coord.dims(), + "Shape and coordinate dimensions do not match: {} != {}", + shape.dims(), + coord.dims()); + + size_t linear_index = 0; + for (size_t dim = 0; dim < coord.dims(); ++dim) { + TT_FATAL(coord[dim] < shape[dim], "Coordinate {} is out of bounds for shape {}", coord, shape); + linear_index += coord[dim] * shape.get_stride(dim); + } + return linear_index; +} + +} // namespace tt::tt_metal::distributed diff --git a/tt_metal/common/shape_base.cpp b/tt_metal/common/shape_base.cpp index 57e69bb49e6..33acd941d22 100644 --- a/tt_metal/common/shape_base.cpp +++ b/tt_metal/common/shape_base.cpp @@ -4,7 +4,9 @@ #include "assert.hpp" #include "shape_base.hpp" +#include #include +#include #include "fmt/color.h" namespace tt::tt_metal { @@ -46,7 +48,14 @@ bool ShapeBase::empty() const { return original_size_ == 0; } size_t ShapeBase::size() const { return original_size_; } -std::span ShapeBase::view() const { return std::span(cbegin(), cend()); } +tt::stl::Span ShapeBase::view() const { + const auto begin = cbegin(); + const auto end = cend(); + // `Span` constructor requires a contiguous range of data. 
+ static_assert( + std::is_base_of_v::iterator_category>); + return tt::stl::Span(&*begin, std::distance(begin, end)); +} bool ShapeBase::operator==(const ShapeBase& other) const = default; diff --git a/tt_metal/distributed/mesh_buffer.cpp b/tt_metal/distributed/mesh_buffer.cpp index a0bf7b76e86..13d1fc5e6cc 100644 --- a/tt_metal/distributed/mesh_buffer.cpp +++ b/tt_metal/distributed/mesh_buffer.cpp @@ -4,6 +4,8 @@ // SPDX-License-Identifier: Apache-2.0 #include +#include +#include #include #include @@ -110,12 +112,9 @@ std::shared_ptr MeshBuffer::create( } void MeshBuffer::initialize_device_buffers() { - buffers_ = std::vector>>( - mesh_device_->num_rows(), std::vector>(mesh_device_->num_cols())); - - auto init_device_buffer_at_address = [this](const Coordinate& coord) { + auto init_device_buffer_at_address = [this](const MeshCoordinate& coord) { std::shared_ptr buffer = Buffer::create( - mesh_device_->get_device(coord.row, coord.col), + mesh_device_->get_device(coord), address_, device_local_size_, device_local_config_.page_size, @@ -126,10 +125,8 @@ void MeshBuffer::initialize_device_buffers() { return buffer; }; - for (int row = 0; row < mesh_device_->num_rows(); row++) { - for (int col = 0; col < mesh_device_->num_cols(); col++) { - buffers_[row][col] = init_device_buffer_at_address(Coordinate{row, col}); - } + for (auto& [coord, device_buffer] : buffers_) { + device_buffer = init_device_buffer_at_address(coord); } } @@ -138,14 +135,11 @@ bool MeshBuffer::is_allocated() const { return not std::holds_alternative MeshBuffer::get_device_buffer(const Coordinate& device_coord) const { - TT_FATAL( - device_coord.row < mesh_device_->num_rows() and device_coord.col < mesh_device_->num_cols(), - "Logical coordinates must be within the bounds of the mesh: {}, {}, mesh shape: {}, {}", - device_coord.row, - device_coord.col, - mesh_device_->num_rows(), - mesh_device_->num_cols()); - return buffers_[device_coord.row][device_coord.col]; + return 
get_device_buffer(MeshCoordinate(device_coord.row, device_coord.col)); +} + +std::shared_ptr MeshBuffer::get_device_buffer(const MeshCoordinate& device_coord) const { + return buffers_.at(device_coord); } DeviceAddr MeshBuffer::size() const { diff --git a/tt_metal/distributed/mesh_device.cpp b/tt_metal/distributed/mesh_device.cpp index 04edd94373b..603ce95212e 100644 --- a/tt_metal/distributed/mesh_device.cpp +++ b/tt_metal/distributed/mesh_device.cpp @@ -211,7 +211,11 @@ std::vector MeshDevice::get_devices() const { return view_->get_device // TODO: Remove this function once we have a proper view interface IDevice* MeshDevice::get_device(size_t row_idx, size_t col_idx) const { - return this->get_device_index(row_idx * num_cols() + col_idx); + return get_device(MeshCoordinate{row_idx, col_idx}); +} + +IDevice* MeshDevice::get_device(const MeshCoordinate& coord) const { + return this->get_device_index(to_linear_index(SimpleMeshShape(mesh_shape_), coord)); } MeshCommandQueue& MeshDevice::mesh_command_queue(std::size_t cq_id) const { diff --git a/tt_metal/programming_examples/distributed/2_distributed_buffer_rw/distributed_buffer_rw.cpp b/tt_metal/programming_examples/distributed/2_distributed_buffer_rw/distributed_buffer_rw.cpp index d54d6a1c6e7..a1b17cec8d5 100644 --- a/tt_metal/programming_examples/distributed/2_distributed_buffer_rw/distributed_buffer_rw.cpp +++ b/tt_metal/programming_examples/distributed/2_distributed_buffer_rw/distributed_buffer_rw.cpp @@ -26,7 +26,7 @@ int main(int argc, char** argv) { // We will create a distributed buffer with 8 shards of {32, 32} and distribute it across the devices in the mesh. 
auto shard_shape = Shape2D{32, 32}; auto distributed_buffer_shape = Shape2D{32 * mesh_device->num_rows(), 32 * mesh_device->num_cols()}; - uint32_t tile_size_bytes = detail::TileSize(tt::DataFormat::UInt32); + uint32_t tile_size_bytes = tt::tt_metal::detail::TileSize(tt::DataFormat::UInt32); uint32_t distributed_buffer_size_bytes = 64 * 128 * tile_size_bytes; auto local_buffer_config = DeviceLocalBufferConfig{ diff --git a/tt_metal/programming_examples/distributed/3_distributed_eltwise_add/distributed_eltwise_add.cpp b/tt_metal/programming_examples/distributed/3_distributed_eltwise_add/distributed_eltwise_add.cpp index 73bf18ee0be..9dbf0bbbd61 100644 --- a/tt_metal/programming_examples/distributed/3_distributed_eltwise_add/distributed_eltwise_add.cpp +++ b/tt_metal/programming_examples/distributed/3_distributed_eltwise_add/distributed_eltwise_add.cpp @@ -92,7 +92,7 @@ int main(int argc, char** argv) { auto distributed_buffer_shape = Shape2D{shard_shape.height() * mesh_device->num_rows(), shard_shape.width() * mesh_device->num_cols()}; auto num_tiles = 1; - auto tile_size_bytes = detail::TileSize(tt::DataFormat::Float16_b); + auto tile_size_bytes = tt::tt_metal::detail::TileSize(tt::DataFormat::Float16_b); auto distributed_buffer_size_bytes = mesh_device->num_rows() * mesh_device->num_cols() * tile_size_bytes; // Configure device-local buffer settings diff --git a/ttnn/cpp/ttnn/operations/data_movement/pad/pad.cpp b/ttnn/cpp/ttnn/operations/data_movement/pad/pad.cpp index b5232f2c464..9e4382f3d73 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/pad/pad.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/pad/pad.cpp @@ -14,10 +14,7 @@ namespace ttnn::operations::data_movement { namespace { -template -bool eq_spans(const ArrayType& a, const ArrayType& b) { - return std::equal(a.begin(), a.end(), b.begin(), b.end()); -} +bool eq_spans(const auto a, const auto b) { return std::equal(a.begin(), a.end(), b.begin(), b.end()); } ttnn::Shape update_original_shape(const 
ttnn::Shape& padded_shape, const ttnn::Shape& input_shape) { ttnn::SmallVector updated_shape;