Skip to content

Commit

Permalink
Value proxy class for better container iteration
Browse files Browse the repository at this point in the history
  • Loading branch information
omilyutin-tt committed Feb 9, 2025
1 parent 7b6a11c commit be2fd92
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 14 deletions.
33 changes: 33 additions & 0 deletions tests/tt_metal/distributed/test_mesh_coord.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,39 @@ TEST(MeshContainerTest, ElementAccessRowMajor) {
container.at(MeshCoordinate(1, 1)) = 4;
container.at(MeshCoordinate(1, 2)) = 5;

std::vector<MeshCoordinate> coords;
std::vector<int> values;
for (const auto& [coord, value] : container) {
coords.push_back(coord);
values.push_back(value);
}
EXPECT_THAT(
coords,
ElementsAre(
MeshCoordinate(0, 0),
MeshCoordinate(0, 1),
MeshCoordinate(0, 2),
MeshCoordinate(1, 0),
MeshCoordinate(1, 1),
MeshCoordinate(1, 2)));
EXPECT_THAT(values, ElementsAre(0, 1, 2, 3, 4, 5));
}

TEST(MeshContainerTest, MutateThroughProxy) {
SimpleMeshShape shape(2, 3);
MeshContainer<int> container(shape, 0);

// Proxy class provides access to the container value through the mutable reference.
int updated_value = 0;
for (auto& [coord, value] : container) {
value = updated_value++;
}

// `auto` makes a copy of the value, verify this loop is a no-op.
for (auto [coord, value] : container) {
value = updated_value++;
}

std::vector<int> values;
for (const auto& [coord, value] : container) {
values.push_back(value);
Expand Down
93 changes: 82 additions & 11 deletions tt_metal/api/tt-metalium/mesh_coord.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#pragma once

#include <cstddef>
#include <type_traits>
#include <vector>

#include "shape_base.hpp"
Expand All @@ -18,11 +19,6 @@ class SimpleMeshShape : public ShapeBase {
public:
using ShapeBase::ShapeBase;
using ShapeBase::operator[];
using ShapeBase::cbegin;
using ShapeBase::cend;
using ShapeBase::empty;
using ShapeBase::size;
using ShapeBase::view;

// Shorthands for constructing 1D, 2D and 3D shapes.
SimpleMeshShape(uint32_t num_elements);
Expand Down Expand Up @@ -50,6 +46,9 @@ class SimpleMeshShape : public ShapeBase {
friend std::ostream& operator<<(std::ostream& os, const SimpleMeshShape& shape);

private:
using ShapeBase::empty;
using ShapeBase::size;

void compute_strides();
tt::stl::SmallVector<size_t> strides_;
};
Expand Down Expand Up @@ -104,7 +103,7 @@ class MeshCoordinateRange {
class Iterator {
public:
Iterator& operator++();
MeshCoordinate operator*() const;
const MeshCoordinate& operator*() const;
bool operator==(const Iterator& other) const;
bool operator!=(const Iterator& other) const;

Expand All @@ -131,6 +130,53 @@ class MeshCoordinateRange {
MeshCoordinate end_;
};

namespace detail {

// Proxy class that allows convenient structured binding to a pair of a coordinate and the value it points to.
// This supports iterator semantics similar to `std::map` / `std::unordered_map`.
template <typename T>
class MeshCoordinateValueProxy {
public:
MeshCoordinateValueProxy(const MeshCoordinate* coord, T* value_ptr) : coord_(coord), value_ptr_(value_ptr) {}

const MeshCoordinate& coord() const { return *coord_; }
T& value() { return *value_ptr_; }
const T& value() const { return *value_ptr_; }

template <std::size_t I>
decltype(auto) get() & {
if constexpr (I == 0) {
return coord();
} else if constexpr (I == 1) {
return value();
} else {
static_assert(I < 2);
}
}

template <std::size_t I>
decltype(auto) get() const& {
if constexpr (I == 0) {
return coord();
} else if constexpr (I == 1) {
return value();
} else {
static_assert(I < 2);
}
}

template <std::size_t I>
auto get() const&& {
return get<I>();
}

private:
const MeshCoordinate* coord_ = nullptr;
T* value_ptr_ = nullptr;
};

} // namespace detail

// Allows storing data in a mesh-shaped container, with convenient accessors and iterators.
template <typename T>
class MeshContainer {
Expand All @@ -147,10 +193,10 @@ class MeshContainer {
// Allows to iterate over the container elements, returning a pair of (coordinate, value reference).
class Iterator {
public:
using reference = std::pair<MeshCoordinate, std::reference_wrapper<T>>;
using ValueProxy = detail::MeshCoordinateValueProxy<T>;

Iterator& operator++();
reference operator*() const;
ValueProxy& operator*();
bool operator==(const Iterator& other) const;
bool operator!=(const Iterator& other) const;

Expand All @@ -161,6 +207,9 @@ class MeshContainer {
MeshContainer* container_ = nullptr;
MeshCoordinateRange::Iterator coord_iter_;
size_t linear_index_ = 0;

// Provides mutable access to the container value along with the coordinate from the range iterator.
ValueProxy value_proxy_;
};

Iterator begin();
Expand Down Expand Up @@ -194,18 +243,22 @@ const T& MeshContainer<T>::at(const MeshCoordinate& coord) const {
template <typename T>
MeshContainer<T>::Iterator::Iterator(
MeshContainer* container, const MeshCoordinateRange::Iterator& coord_iter, size_t linear_index) :
container_(container), coord_iter_(coord_iter), linear_index_(linear_index) {}
container_(container),
coord_iter_(coord_iter),
linear_index_(linear_index),
value_proxy_(&(*coord_iter_), &container_->values_[linear_index_]) {}

template <typename T>
typename MeshContainer<T>::Iterator& MeshContainer<T>::Iterator::operator++() {
++linear_index_;
++coord_iter_;
value_proxy_ = ValueProxy(&(*coord_iter_), &container_->values_[linear_index_]);
return *this;
}

template <typename T>
typename MeshContainer<T>::Iterator::reference MeshContainer<T>::Iterator::operator*() const {
return {*coord_iter_, std::ref(container_->values_[linear_index_])};
typename MeshContainer<T>::Iterator::ValueProxy& MeshContainer<T>::Iterator::operator*() {
return value_proxy_;
}

template <typename T>
Expand All @@ -229,3 +282,21 @@ typename MeshContainer<T>::Iterator MeshContainer<T>::end() {
}

} // namespace tt::tt_metal::distributed

namespace std {

template <typename T>
struct tuple_size<tt::tt_metal::distributed::detail::MeshCoordinateValueProxy<T>> : std::integral_constant<size_t, 2> {
};

template <typename T>
struct tuple_element<0, tt::tt_metal::distributed::detail::MeshCoordinateValueProxy<T>> {
using type = const tt::tt_metal::distributed::MeshCoordinate;
};

template <typename T>
struct tuple_element<1, tt::tt_metal::distributed::detail::MeshCoordinateValueProxy<T>> {
using type = T;
};

} // namespace std
2 changes: 1 addition & 1 deletion tt_metal/common/mesh_coord.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ MeshCoordinateRange::Iterator& MeshCoordinateRange::Iterator::operator++() {
current_coord_ = MeshCoordinate(new_coords);
return *this;
}
MeshCoordinate MeshCoordinateRange::Iterator::operator*() const { return current_coord_; }
const MeshCoordinate& MeshCoordinateRange::Iterator::operator*() const { return current_coord_; }
bool MeshCoordinateRange::Iterator::operator==(const Iterator& other) const {
return range_ == other.range_ && linear_index_ == other.linear_index_;
}
Expand Down
4 changes: 2 additions & 2 deletions tt_metal/distributed/mesh_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,8 @@ void MeshBuffer::initialize_device_buffers() {
return buffer;
};

for (auto [coord, device_buffer] : buffers_) {
device_buffer.get() = init_device_buffer_at_address(coord);
for (auto& [coord, device_buffer] : buffers_) {
device_buffer = init_device_buffer_at_address(coord);
}
}

Expand Down

0 comments on commit be2fd92

Please sign in to comment.