From be2fd92f3568cc91e57bcc0a113cd976d27725fc Mon Sep 17 00:00:00 2001 From: Oleg Milyutin Date: Sun, 9 Feb 2025 06:26:06 +0000 Subject: [PATCH] Value proxy class for better container iteration --- .../tt_metal/distributed/test_mesh_coord.cpp | 33 +++++++ tt_metal/api/tt-metalium/mesh_coord.hpp | 93 ++++++++++++++++--- tt_metal/common/mesh_coord.cpp | 2 +- tt_metal/distributed/mesh_buffer.cpp | 4 +- 4 files changed, 118 insertions(+), 14 deletions(-) diff --git a/tests/tt_metal/distributed/test_mesh_coord.cpp b/tests/tt_metal/distributed/test_mesh_coord.cpp index 80f3f299198..a0baff76b74 100644 --- a/tests/tt_metal/distributed/test_mesh_coord.cpp +++ b/tests/tt_metal/distributed/test_mesh_coord.cpp @@ -210,6 +210,39 @@ TEST(MeshContainerTest, ElementAccessRowMajor) { container.at(MeshCoordinate(1, 1)) = 4; container.at(MeshCoordinate(1, 2)) = 5; + std::vector coords; + std::vector values; + for (const auto& [coord, value] : container) { + coords.push_back(coord); + values.push_back(value); + } + EXPECT_THAT( + coords, + ElementsAre( + MeshCoordinate(0, 0), + MeshCoordinate(0, 1), + MeshCoordinate(0, 2), + MeshCoordinate(1, 0), + MeshCoordinate(1, 1), + MeshCoordinate(1, 2))); + EXPECT_THAT(values, ElementsAre(0, 1, 2, 3, 4, 5)); +} + +TEST(MeshContainerTest, MutateThroughProxy) { + SimpleMeshShape shape(2, 3); + MeshContainer container(shape, 0); + + // Proxy class provides access to the container value through the mutable reference. + int updated_value = 0; + for (auto& [coord, value] : container) { + value = updated_value++; + } + + // `auto` makes a copy of the value, verify this loop is a no-op. + for (auto [coord, value] : container) { + value = updated_value++; + } + std::vector values; for (const auto& [coord, value] : container) { values.push_back(value); diff --git a/tt_metal/api/tt-metalium/mesh_coord.hpp b/tt_metal/api/tt-metalium/mesh_coord.hpp index b3d74dca4b1..20718751dab 100644 --- a/tt_metal/api/tt-metalium/mesh_coord.hpp +++ b/tt_metal/api/tt-metalium/mesh_coord.hpp @@ -5,6 +5,7 @@ #pragma once #include +#include #include #include "shape_base.hpp" @@ -18,11 +19,6 @@ class SimpleMeshShape : public ShapeBase { public: using ShapeBase::ShapeBase; using ShapeBase::operator[]; - using ShapeBase::cbegin; - using ShapeBase::cend; - using ShapeBase::empty; - using ShapeBase::size; - using ShapeBase::view; // Shorthands for constructing 1D, 2D and 3D shapes. SimpleMeshShape(uint32_t num_elements); @@ -50,6 +46,9 @@ class SimpleMeshShape : public ShapeBase { friend std::ostream& operator<<(std::ostream& os, const SimpleMeshShape& shape); private: + using ShapeBase::empty; + using ShapeBase::size; + void compute_strides(); tt::stl::SmallVector strides_; }; @@ -104,7 +103,7 @@ class MeshCoordinateRange { class Iterator { public: Iterator& operator++(); - MeshCoordinate operator*() const; + const MeshCoordinate& operator*() const; bool operator==(const Iterator& other) const; bool operator!=(const Iterator& other) const; @@ -131,6 +130,53 @@ class MeshCoordinateRange { MeshCoordinate end_; }; +namespace detail { + +// Proxy class that allows convenient structured binding to a pair of a coordinate and the value it points to. +// This supports iterator semantics similar to `std::map` / `std::unordered_map`. +template +class MeshCoordinateValueProxy { +public: + MeshCoordinateValueProxy(const MeshCoordinate* coord, T* value_ptr) : coord_(coord), value_ptr_(value_ptr) {} + + const MeshCoordinate& coord() const { return *coord_; } + T& value() { return *value_ptr_; } + const T& value() const { return *value_ptr_; } + + template + decltype(auto) get() & { + if constexpr (I == 0) { + return coord(); + } else if constexpr (I == 1) { + return value(); + } else { + static_assert(I < 2); + } + } + + template + decltype(auto) get() const& { + if constexpr (I == 0) { + return coord(); + } else if constexpr (I == 1) { + return value(); + } else { + static_assert(I < 2); + } + } + + template + auto get() const&& { + return get(); + } + +private: + const MeshCoordinate* coord_ = nullptr; + T* value_ptr_ = nullptr; +}; + +} // namespace detail + // Allows storing data in a mesh-shaped container, with convenient accessors and iterators. template class MeshContainer { @@ -147,10 +193,10 @@ class MeshContainer { // Allows to iterate over the container elements, returning a pair of (coordinate, value reference). class Iterator { public: - using reference = std::pair>; + using ValueProxy = detail::MeshCoordinateValueProxy; Iterator& operator++(); - reference operator*() const; + ValueProxy& operator*(); bool operator==(const Iterator& other) const; bool operator!=(const Iterator& other) const; @@ -161,6 +207,9 @@ class MeshContainer { MeshContainer* container_ = nullptr; MeshCoordinateRange::Iterator coord_iter_; size_t linear_index_ = 0; + + // Provides mutable access to the container value along with the coordinate from the range iterator. + ValueProxy value_proxy_; }; Iterator begin(); @@ -194,18 +243,22 @@ const T& MeshContainer::at(const MeshCoordinate& coord) const { template MeshContainer::Iterator::Iterator( MeshContainer* container, const MeshCoordinateRange::Iterator& coord_iter, size_t linear_index) : - container_(container), coord_iter_(coord_iter), linear_index_(linear_index) {} + container_(container), + coord_iter_(coord_iter), + linear_index_(linear_index), + value_proxy_(&(*coord_iter_), &container_->values_[linear_index_]) {} template typename MeshContainer::Iterator& MeshContainer::Iterator::operator++() { ++linear_index_; ++coord_iter_; + value_proxy_ = ValueProxy(&(*coord_iter_), &container_->values_[linear_index_]); return *this; } template -typename MeshContainer::Iterator::reference MeshContainer::Iterator::operator*() const { - return {*coord_iter_, std::ref(container_->values_[linear_index_])}; +typename MeshContainer::Iterator::ValueProxy& MeshContainer::Iterator::operator*() { + return value_proxy_; } template @@ -229,3 +282,21 @@ typename MeshContainer::Iterator MeshContainer::end() { } } // namespace tt::tt_metal::distributed + +namespace std { + +template +struct tuple_size> : std::integral_constant { +}; + +template +struct tuple_element<0, tt::tt_metal::distributed::detail::MeshCoordinateValueProxy> { + using type = const tt::tt_metal::distributed::MeshCoordinate; +}; + +template +struct tuple_element<1, tt::tt_metal::distributed::detail::MeshCoordinateValueProxy> { + using type = T; +}; + +} // namespace std diff --git a/tt_metal/common/mesh_coord.cpp b/tt_metal/common/mesh_coord.cpp index c79af09a63d..f0e7890bfda 100644 --- a/tt_metal/common/mesh_coord.cpp +++ b/tt_metal/common/mesh_coord.cpp @@ -129,7 +129,7 @@ MeshCoordinateRange::Iterator& MeshCoordinateRange::Iterator::operator++() { current_coord_ = MeshCoordinate(new_coords); return *this; } -MeshCoordinate MeshCoordinateRange::Iterator::operator*() const { return current_coord_; } +const MeshCoordinate& MeshCoordinateRange::Iterator::operator*() const { return current_coord_; } bool MeshCoordinateRange::Iterator::operator==(const Iterator& other) const { return range_ == other.range_ && linear_index_ == other.linear_index_; } diff --git a/tt_metal/distributed/mesh_buffer.cpp b/tt_metal/distributed/mesh_buffer.cpp index 660a0fe8529..13d1fc5e6cc 100644 --- a/tt_metal/distributed/mesh_buffer.cpp +++ b/tt_metal/distributed/mesh_buffer.cpp @@ -125,8 +125,8 @@ void MeshBuffer::initialize_device_buffers() { return buffer; }; - for (auto [coord, device_buffer] : buffers_) { - device_buffer.get() = init_device_buffer_at_address(coord); + for (auto& [coord, device_buffer] : buffers_) { + device_buffer = init_device_buffer_at_address(coord); } }