#15061: Refactor utilities related to Mesh infra (#15757)
### Ticket
#15061 

### What's changed
* Refactor `DistributedTensorConfig` into its own header
* Use typed `struct`s to represent `MeshShape` and `MeshOffset` (see the sketch below)

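For illustration, here is a minimal sketch of the two typed structs and how a call site reads after the change. The field and type names are taken from the diff below; everything else (the `main` wrapper, the example values) is illustrative only and not code from this PR.

```cpp
#include <cstddef>

// Sketch of the structs introduced in mesh_device_view.hpp / mesh_device.hpp below.
struct MeshShape {
    size_t num_rows = 0;
    size_t num_cols = 0;
};

struct MeshOffset {
    size_t row = 0;
    size_t col = 0;
};

int main() {
    // Before: shape and offset were std::pair<size_t, size_t>, accessed via .first/.second.
    // After: every call site names the dimension it means.
    MeshShape shape{2, 4};
    MeshOffset offset{0, 1};
    std::size_t num_devices = shape.num_rows * shape.num_cols;  // 8
    (void)offset;
    (void)num_devices;
    return 0;
}
```
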
### Checklist
- [X] [Post commit CI passes](https://github.com/tenstorrent/tt-metal/actions/runs/12210236362)
- [X] New/Existing tests provide coverage for changes
omilyutin-tt authored Dec 9, 2024
1 parent 9818f7f commit e3526de
Showing 25 changed files with 245 additions and 205 deletions.
2 changes: 1 addition & 1 deletion conftest.py
@@ -257,7 +257,7 @@ def pcie_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, devic
mesh_shape=ttnn.MeshShape(2, 2),
dispatch_core_config=dispatch_core_config,
**device_params,
- offset=(0, 1),
+ offset=ttnn.MeshOffset(0, 1),
mesh_type=ttnn.MeshType.Ring,
)

2 changes: 1 addition & 1 deletion tests/tt_metal/distributed/test_distributed.cpp
@@ -46,7 +46,7 @@ TEST(MeshDeviceSuite, Test1x1SystemMeshInitialize) {
auto& sys = tt::tt_metal::distributed::SystemMesh::instance();

auto config =
- tt::tt_metal::distributed::MeshDeviceConfig({1, 1}, std::pair<size_t, size_t>(0, 0), {}, MeshType::RowMajor);
+ tt::tt_metal::distributed::MeshDeviceConfig(MeshShape(1, 1), MeshOffset(0, 0), {}, MeshType::RowMajor);

EXPECT_NO_THROW({
auto mesh = tt::tt_metal::distributed::MeshDevice::create(
2 changes: 1 addition & 1 deletion tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp
@@ -216,7 +216,7 @@ TEST(GalaxyTests, TestReduceScatterDeadlock) {
auto view = ttnn::MeshDeviceView(*mesh);
std::vector<Device*> ring_devices = view.get_devices_on_row(0); // Tunnel 0
std::vector<Device*> ring_devices_1 =
- view.get_devices_on_column(mesh_shape.second - 1); // Orthogonal to tunnel .. no deadlocks
+ view.get_devices_on_column(mesh_shape.num_cols - 1); // Orthogonal to tunnel .. no deadlocks
ring_devices_1 = std::vector<Device*>(ring_devices_1.begin() + 1, ring_devices_1.end());
std::vector<Device*> ring_devices_2 =
view.get_devices_on_row(7); // Tunnel 7 .. potential deadlocks with lack of buffering
20 changes: 10 additions & 10 deletions tt-train/sources/ttml/core/distributed_mapping.hpp
@@ -74,7 +74,7 @@ class XTensorToMesh {
tt::tt_metal::distributed::MeshShape m_mesh_shape;

size_t get_num_devices() const {
- return m_mesh_shape.first * m_mesh_shape.second;
+ return m_mesh_shape.num_rows * m_mesh_shape.num_cols;
}
};

@@ -130,8 +130,8 @@ class ShardTensor2dMesh : public XTensorToMesh<ShardTensor2dMesh<T>, T> {
throw std::invalid_argument("ShardTensor2dMesh requires at least one dimension to shard");
}

- int rows = Base::m_mesh_shape.first;
- int cols = Base::m_mesh_shape.second;
+ int rows = Base::m_mesh_shape.num_rows;
+ int cols = Base::m_mesh_shape.num_cols;
auto row_dim = m_dims.first;
auto col_dim = m_dims.second;

@@ -178,8 +178,8 @@ class ShardTensor2dMesh : public XTensorToMesh<ShardTensor2dMesh<T>, T> {
std::unordered_map<std::string, std::string> config_impl() const {
return {
{"strategy", "shard_2d"},
- {"mesh_shape_y", std::to_string(Base::m_mesh_shape.first)},
- {"mesh_shape_x", std::to_string(Base::m_mesh_shape.second)}};
+ {"mesh_shape_y", std::to_string(Base::m_mesh_shape.num_rows)},
+ {"mesh_shape_x", std::to_string(Base::m_mesh_shape.num_cols)}};
}

private:
@@ -193,16 +193,16 @@ class ConcatMesh2dToTensor : public MeshToXTensor<ConcatMesh2dToTensor<T>, T> {
ConcatMesh2dToTensor(
tt::tt_metal::distributed::MeshShape mesh_shape, const tt::tt_metal::distributed::MeshShape& dims) :
Base(std::move(mesh_shape)), m_dims(dims) {
- if (m_dims.first == m_dims.second) {
+ if (m_dims.num_rows == m_dims.num_cols) {
throw std::invalid_argument("Dimensions in 'dims' must be different");
}
}

std::vector<xt::xarray<T>> compose_impl(const std::vector<xt::xarray<T>>& tensors) const {
- int rows = Base::m_mesh_shape.first;
- int cols = Base::m_mesh_shape.second;
- size_t row_dim = m_dims.first;
- size_t col_dim = m_dims.second;
+ int rows = Base::m_mesh_shape.num_rows;
+ int cols = Base::m_mesh_shape.num_cols;
+ size_t row_dim = m_dims.num_rows;
+ size_t col_dim = m_dims.num_cols;

std::vector<xt::xarray<T>> row_concatenated;
row_concatenated.reserve(static_cast<size_t>(rows));
2 changes: 1 addition & 1 deletion tt-train/tests/core/distributed_test.cpp
@@ -83,7 +83,7 @@ TYPED_TEST(MeshOpsTest, ShardTensor2dMeshTwoDimSharding) {

TYPED_TEST(MeshOpsTest, ReplicateXTensorToMeshReplication) {
tt::tt_metal::distributed::MeshShape mesh_shape = {2, 2};
- int num_devices = mesh_shape.first * mesh_shape.second; // 4
+ int num_devices = mesh_shape.num_rows * mesh_shape.num_cols; // 4

auto tensor = xt::arange<TypeParam>(4); // [0,1,2,3]

48 changes: 24 additions & 24 deletions tt_metal/distributed/mesh_device.cpp
@@ -105,7 +105,7 @@ MeshShape SystemMesh::Impl::get_system_mesh_shape(size_t system_num_devices) {
TT_FATAL(
system_mesh_to_shape.contains(system_num_devices), "Unsupported number of devices: {}", system_num_devices);
auto shape = system_mesh_to_shape.at(system_num_devices);
- log_debug(LogMetal, "Logical SystemMesh Shape: {}x{}", shape.first, shape.second);
+ log_debug(LogMetal, "Logical SystemMesh Shape: {}x{}", shape.num_rows, shape.num_cols);
return shape;
}

@@ -293,32 +293,32 @@ std::shared_ptr<MeshDevice> MeshDevice::create(

std::shared_ptr<MeshDevice> MeshDevice::create_submesh(
const MeshShape& submesh_shape, const MeshOffset& offset, MeshType type) {
- if (submesh_shape.first <= 0 || submesh_shape.second <= 0) {
+ if (submesh_shape.num_rows <= 0 || submesh_shape.num_cols <= 0) {
TT_THROW(
"Invalid submesh shape: ({}, {}). Both dimensions must be positive.",
- submesh_shape.first,
- submesh_shape.second);
+ submesh_shape.num_rows,
+ submesh_shape.num_cols);
}

- if (offset.first < 0 || offset.second < 0) {
- TT_THROW("Invalid offset: ({}, {}). Offset must be non-negative.", offset.first, offset.second);
+ if (offset.row < 0 || offset.col < 0) {
+ TT_THROW("Invalid offset: ({}, {}). Offset must be non-negative.", offset.row, offset.col);
}

- if (offset.first + submesh_shape.first > this->mesh_device_shape.first ||
- offset.second + submesh_shape.second > this->mesh_device_shape.second) {
+ if (offset.row + submesh_shape.num_rows > this->mesh_device_shape.num_rows ||
+ offset.col + submesh_shape.num_cols > this->mesh_device_shape.num_cols) {
TT_THROW(
"Submesh ({}x{}) with offset ({}, {}) does not fit within parent mesh ({}x{}).",
- submesh_shape.first,
- submesh_shape.second,
- offset.first,
- offset.second,
- this->mesh_device_shape.first,
- this->mesh_device_shape.second);
+ submesh_shape.num_rows,
+ submesh_shape.num_cols,
+ offset.row,
+ offset.col,
+ this->mesh_device_shape.num_rows,
+ this->mesh_device_shape.num_cols);
}

auto submesh = std::make_shared<MeshDevice>(submesh_shape, type, shared_from_this());
- auto start_coordinate = Coordinate{offset.first, offset.second};
- auto end_coordinate = Coordinate{offset.first + submesh_shape.first - 1, offset.second + submesh_shape.second - 1};
+ auto start_coordinate = Coordinate{offset.row, offset.col};
+ auto end_coordinate = Coordinate{offset.row + submesh_shape.num_rows - 1, offset.col + submesh_shape.num_cols - 1};
submesh->primary_view = std::make_shared<MeshDeviceView>(*this, start_coordinate, end_coordinate);
submesh->devices = submesh->primary_view->get_devices();
SystemMesh::instance().register_mesh_device(submesh, submesh->devices);
@@ -327,19 +327,19 @@ std::shared_ptr<MeshDevice> MeshDevice::create_submesh(
LogMetal,
"Instantiating submesh {}: {}x{} with offset: {} {}",
submesh->get_mesh_id(),
- submesh_shape.first,
- submesh_shape.second,
- offset.first,
- offset.second);
+ submesh_shape.num_rows,
+ submesh_shape.num_cols,
+ offset.row,
+ offset.col);
log_trace(LogMetal, "Submesh {} instantiated with {} devices", submesh->get_mesh_id(), submesh->devices);

return submesh;
}

std::vector<std::shared_ptr<MeshDevice>> MeshDevice::create_submeshes(const MeshShape& submesh_shape, MeshType type) {
std::vector<std::shared_ptr<MeshDevice>> submeshes;
- for (int row = 0; row < this->num_rows(); row += submesh_shape.first) {
- for (int col = 0; col < this->num_cols(); col += submesh_shape.second) {
+ for (int row = 0; row < this->num_rows(); row += submesh_shape.num_rows) {
+ for (int col = 0; col < this->num_cols(); col += submesh_shape.num_cols) {
auto submesh = this->create_submesh(submesh_shape, MeshOffset{row, col}, type);
submeshes.push_back(submesh);
}
@@ -413,9 +413,9 @@ CoreCoord MeshDevice::dram_grid_size() const { return this->reference_device()->

tt::ARCH MeshDevice::arch() const { return this->reference_device()->arch(); }

- size_t MeshDevice::num_rows() const { return this->mesh_device_shape.first; }
+ size_t MeshDevice::num_rows() const { return this->mesh_device_shape.num_rows; }

- size_t MeshDevice::num_cols() const { return this->mesh_device_shape.second; }
+ size_t MeshDevice::num_cols() const { return this->mesh_device_shape.num_cols; }

MeshShape MeshDevice::shape() const { return this->mesh_device_shape; }

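As a hedged usage sketch (not code from this commit), here is how a caller might exercise the refactored `create_submesh` signature shown above; `parent_mesh` is a hypothetical, already-created `MeshDevice`, and the shape/offset values are placeholders.

```cpp
// Hypothetical call site; assumes `parent_mesh` is a std::shared_ptr<MeshDevice>
// obtained elsewhere (e.g., from MeshDevice::create).
auto submesh = parent_mesh->create_submesh(
    MeshShape{/*num_rows=*/2, /*num_cols=*/2},
    MeshOffset{/*row=*/0, /*col=*/2},
    MeshType::RowMajor);
```
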
5 changes: 4 additions & 1 deletion tt_metal/distributed/mesh_device.hpp
@@ -18,7 +18,10 @@ namespace tt::tt_metal::distributed {

using DeviceIds = std::vector<int>;
using MeshDeviceID = size_t;
- using MeshOffset = std::pair<size_t, size_t>;
+ struct MeshOffset {
+ size_t row = 0;
+ size_t col = 0;
+ };
class MeshDeviceView;

struct MeshSubDeviceManagerId;
4 changes: 2 additions & 2 deletions tt_metal/distributed/mesh_device_view.cpp
@@ -83,7 +83,7 @@ MeshDeviceView::DeviceView MeshDeviceView::get_devices(const Coordinate& start,
}

MeshDeviceView::DeviceView MeshDeviceView::get_devices(const MeshShape& shape) {
- return get_devices({0, 0}, {shape.first - 1, shape.second - 1});
+ return get_devices({0, 0}, {shape.num_rows - 1, shape.num_cols - 1});
}

std::vector<MeshDeviceView::device_pointer> MeshDeviceView::get_devices_on_row(size_t row) const {
@@ -128,7 +128,7 @@ bool MeshDeviceView::empty() const noexcept { return devices_.empty(); }

size_t MeshDeviceView::size() const noexcept { return devices_.size(); }

- std::pair<size_t, size_t> MeshDeviceView::shape() const noexcept { return {num_rows(), num_cols()}; }
+ MeshShape MeshDeviceView::shape() const noexcept { return {num_rows(), num_cols()}; }

bool MeshDeviceView::contains(const Coordinate& coord) const noexcept {
return coord.row >= top_left_.row && coord.row <= bottom_right_.row && coord.col >= top_left_.col &&
5 changes: 4 additions & 1 deletion tt_metal/distributed/mesh_device_view.hpp
@@ -17,7 +17,10 @@ namespace tt::tt_metal::distributed {

// Forward declaration of MeshDevice
class MeshDevice;
- using MeshShape = std::pair<size_t, size_t>;
+ struct MeshShape {
+ size_t num_rows = 0;
+ size_t num_cols = 0;
+ };

struct Coordinate {
size_t row;
1 change: 1 addition & 0 deletions ttnn/CMakeLists.txt
@@ -8,6 +8,7 @@ set(ALL_TTNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/global_semaphore.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/run_operation.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/distributed/api.cpp
+ ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/distributed/distributed_tensor_config.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/distributed/distributed_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/graph/graph_processor.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/graph/graph_trace_utils.cpp
50 changes: 25 additions & 25 deletions ttnn/cpp/ttnn/distributed/api.cpp
@@ -6,8 +6,10 @@

#include <memory>

+ #include "tt_metal/tt_stl/overloaded.hpp"
#include "ttnn/tensor/tensor.hpp"
#include "ttnn/tensor/tensor_utils.hpp"
+ #include "ttnn/distributed/distributed_tensor_config.hpp"
#include "tt_metal/distributed/mesh_device.hpp"

using namespace tt::tt_metal;
@@ -21,7 +23,7 @@ std::shared_ptr<MeshDevice> open_mesh_device(
size_t num_command_queues,
const DispatchCoreConfig& dispatch_core_config,
MeshType mesh_type,
- const std::pair<size_t, size_t>& offset,
+ const MeshOffset& offset,
const std::vector<int>& physical_device_ids) {
auto config = MeshDeviceConfig(mesh_shape, offset, physical_device_ids, mesh_type);
return MeshDevice::create(config, l1_small_size, trace_region_size, num_command_queues, dispatch_core_config);
@@ -58,18 +60,20 @@ std::vector<ttnn::Tensor> get_device_tensors(const ttnn::Tensor& tensor) {
TT_THROW("Expected tensor to be on MultiDeviceHostStorage type!");
}

- Tensor aggregate_as_tensor(std::vector<Tensor>& tensor_shards) {
+ Tensor aggregate_as_tensor(
+ const std::vector<Tensor>& tensor_shards, const tt::tt_metal::DistributedTensorConfig& config) {
TT_ASSERT(tensor_shards.size() > 0, "At least one tensor shard must be provided");
+ const auto& reference_shard = tensor_shards.at(0);
for (const auto& shard : tensor_shards) {
- if (shard.storage_type() != tensor_shards.at(0).storage_type()) {
+ if (shard.storage_type() != reference_shard.storage_type()) {
TT_THROW("All tensor shards must have the same storage type");
}
}

// Based whether the first tensor shard has OwnedBuffer or Device buffer,
// we want to use MultiDeviceHostStorage or MultiDeviceStorage
- StorageType storage_type = tensor_shards.at(0).storage_type();
- Tile tile = tensor_shards.at(0).get_tensor_spec().tile();
+ StorageType storage_type = reference_shard.storage_type();
+ Tile tile = reference_shard.get_tensor_spec().tile();
if (storage_type == StorageType::OWNED) {
std::vector<ttnn::Shape> shapes;
std::vector<OwnedBuffer> host_owned_buffers;
@@ -81,20 +85,20 @@ Tensor aggregate_as_tensor(std::vector<Tensor>& tensor_shards) {
TT_THROW(
"Error aggregating multichip tensors: Attempting to aggregate tensors with different tiling "
"configurations. Device {} has tiling ({}x{}) while device {} has tiling {}x{}.",
- tensor_shards.at(0).device()->id(),
+ reference_shard.device()->id(),
tile.get_height(),
tile.get_width(),
shard.device()->id(),
shard_tile.get_height(),
shard_tile.get_width());
}
}
- auto storage = MultiDeviceHostStorage{AllGatherTensor(), std::move(host_owned_buffers), shapes};
+ auto storage = MultiDeviceHostStorage{config, std::move(host_owned_buffers), shapes};
return Tensor(
std::move(storage),
- tensor_shards.at(0).get_legacy_shape(),
- tensor_shards.at(0).get_dtype(),
- tensor_shards.at(0).get_layout(),
+ reference_shard.get_legacy_shape(),
+ reference_shard.get_dtype(),
+ reference_shard.get_layout(),
tile);
} else {
std::vector<int> ordered_device_ids;
@@ -111,20 +115,20 @@ Tensor aggregate_as_tensor(std::vector<Tensor>& tensor_shards) {
TT_THROW(
"Error aggregating multichip tensors: Attempting to aggregate tensors with different tiling "
"configurations. Device {} has tiling ({}x{}) while device {} has tiling {}x{}.",
- tensor_shards.at(0).device()->id(),
+ reference_shard.device()->id(),
tile.get_height(),
tile.get_width(),
shard.device()->id(),
shard_tile.get_height(),
shard_tile.get_width());
}
}
- auto storage = MultiDeviceStorage{AllGatherTensor(), ordered_device_ids, std::move(device_buffers), shapes};
+ auto storage = MultiDeviceStorage{config, ordered_device_ids, std::move(device_buffers), shapes};
return Tensor(
std::move(storage),
- tensor_shards.at(0).get_legacy_shape(),
- tensor_shards.at(0).get_dtype(),
- tensor_shards.at(0).get_layout(),
+ reference_shard.get_legacy_shape(),
+ reference_shard.get_dtype(),
+ reference_shard.get_layout(),
tile);
}
}
@@ -140,7 +144,7 @@ std::vector<int> get_t3k_physical_device_ids_ring() {
return physical_device_ids;
}

- std::vector<Device*> distribute_tensor_to_mesh(const Tensor& tensor, MeshDevice& mesh_device) {
+ std::vector<Device*> get_mapped_devices(const Tensor& tensor, MeshDevice& mesh_device) {
// For multi-device tensors, returns the number of workers capped by the number of buffers
// Otherwise, returns all available workes from mesh_device.
auto get_workers_for_tensor = [&tensor, &mesh_device]() {
@@ -151,19 +155,15 @@ std::vector<Device*> distribute_tensor_to_mesh(const Tensor& tensor, MeshDevice&
}
return workers;
};

if (mesh_device.get_view() != nullptr and std::holds_alternative<MultiDeviceHostStorage>(tensor.get_storage())) {
const auto& host_storage = std::get<tt::tt_metal::MultiDeviceHostStorage>(tensor.get_storage());

return std::visit(
- [&](const auto& strategy) {
- using StrategyType = std::decay_t<decltype(strategy)>;
- if constexpr (std::is_same_v<StrategyType, ShardTensor2D>) {
- return mesh_device.get_view()->get_devices(strategy.shard_mesh);
- } else {
- return get_workers_for_tensor();
- }
- },
+ tt::stl::overloaded{
+ [&](const ShardTensor2D& s) {
+ return mesh_device.get_view()->get_devices(MeshShape{s.shard_mesh.y, s.shard_mesh.x});
+ },
+ [&](const auto&) { return get_workers_for_tensor(); }},
host_storage.strategy);
} else if (std::holds_alternative<MultiDeviceStorage>(tensor.get_storage())) {
return tensor.workers;
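The `tt::stl::overloaded` helper used in the new visitor above is presumably the usual overloaded-lambdas idiom for `std::visit`; the following is a self-contained sketch of that idiom in general, not the tt::stl implementation and not tied to the storage types used here.

```cpp
#include <iostream>
#include <string>
#include <variant>

// Generic overloaded-lambdas helper: inherits the call operator of every lambda passed in.
template <class... Ts>
struct overloaded : Ts... {
    using Ts::operator()...;
};
template <class... Ts>
overloaded(Ts...) -> overloaded<Ts...>;  // deduction guide (implicit in C++20)

int main() {
    std::variant<int, std::string> strategy = std::string("shard_2d");
    std::visit(
        overloaded{
            [](int replication_factor) { std::cout << "replicate x" << replication_factor << "\n"; },
            [](const std::string& name) { std::cout << "strategy: " << name << "\n"; }},
        strategy);
    return 0;
}
```
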
10 changes: 7 additions & 3 deletions ttnn/cpp/ttnn/distributed/api.hpp
@@ -7,6 +7,7 @@
#include <memory>

#include "ttnn/tensor/tensor.hpp"
+ #include "ttnn/distributed/distributed_tensor_config.hpp"
#include "ttnn/distributed/types.hpp"

namespace ttnn::distributed::api {
@@ -18,19 +19,22 @@ std::shared_ptr<MeshDevice> open_mesh_device(
size_t num_command_queues,
const tt::tt_metal::DispatchCoreConfig& dispatch_core_config,
MeshType mesh_type = MeshType::RowMajor,
- const std::pair<size_t, size_t>& offset = std::pair<size_t, size_t>(0, 0),
+ const MeshOffset& offset = MeshOffset(0, 0),
const std::vector<int>& physical_device_ids = {});

void close_mesh_device(const std::shared_ptr<MeshDevice>& mesh_device);

// Given a multi-device tensor, returns a list of individual per-device tensors.
std::vector<ttnn::Tensor> get_device_tensors(const ttnn::Tensor& tensor);

- Tensor aggregate_as_tensor(std::vector<Tensor>& tensor_shards);
+ // Given a list of per-device shards, returns multi-device tensor.
+ Tensor aggregate_as_tensor(
+ const std::vector<Tensor>& tensor_shards, const tt::tt_metal::DistributedTensorConfig& config);

std::vector<int> get_t3k_physical_device_ids_ring();

// Maps a tensor to the set of devices in the device-mesh that the shards will be distributed across.
- std::vector<Device*> distribute_tensor_to_mesh(const Tensor& tensor, MeshDevice& mesh_device);
+ std::vector<Device*> get_mapped_devices(const Tensor& tensor, MeshDevice& mesh_device);

// Get the distributed tensor config from a tensor.
tt::tt_metal::DistributedTensorConfig get_distributed_tensor_config_from_tensor(const Tensor& tensor);
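A hedged sketch of calling the updated `aggregate_as_tensor`, which now takes the distribution strategy explicitly; the shard vector is assumed to come from `get_device_tensors`, and passing `AllGatherTensor()` mirrors the value the old implementation hard-coded internally.

```cpp
// Hypothetical call site; `multi_device_tensor` is assumed to already exist.
std::vector<ttnn::Tensor> shards = ttnn::distributed::api::get_device_tensors(multi_device_tensor);
ttnn::Tensor aggregated = ttnn::distributed::api::aggregate_as_tensor(
    shards, tt::tt_metal::AllGatherTensor());  // was the implicit default before this change
```
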
(The remaining changed files in this commit are not shown here.)
