From e49454bba016ccba5f7c19627e87994a837e24c0 Mon Sep 17 00:00:00 2001 From: Joseph Chu Date: Wed, 2 Oct 2024 18:04:25 +0000 Subject: [PATCH] #0: done --- conftest.py | 9 ++- .../unit_tests/gtests/ttnn_test_fixtures.hpp | 4 +- tests/ttnn/unit_tests/test_multi_device.py | 23 ++++++++ tt_metal/impl/device/mesh_device.cpp | 56 ++++++++++++------- tt_metal/impl/device/mesh_device.hpp | 42 +++++++++++--- tt_metal/impl/device/mesh_device_view.cpp | 19 ++++--- tt_metal/impl/device/mesh_device_view.hpp | 14 ++--- ttnn/cpp/pybind11/multi_device.hpp | 31 +++++++--- ttnn/cpp/ttnn/multi_device.cpp | 5 +- ttnn/cpp/ttnn/multi_device.hpp | 15 ++++- ttnn/ttnn/__init__.py | 1 + ttnn/ttnn/multi_device.py | 4 ++ 12 files changed, 161 insertions(+), 62 deletions(-) diff --git a/conftest.py b/conftest.py index df6c842636d..d3fc41b414a 100644 --- a/conftest.py +++ b/conftest.py @@ -235,7 +235,11 @@ def pcie_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, devic request.node.pci_ids = device_ids[:num_pcie_devices_requested] mesh_device = ttnn.open_mesh_device( - ttnn.MeshShape(1, 4), dispatch_core_type=get_dispatch_core_type(), **device_params, offset=(0, 1) + ttnn.MeshShape(2, 2), + dispatch_core_type=get_dispatch_core_type(), + **device_params, + offset=(0, 1), + mesh_type=ttnn.MeshType.Ring, ) logger.debug(f"multidevice with {mesh_device.get_num_devices()} devices is created") @@ -280,9 +284,10 @@ def t3k_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, device request.node.pci_ids = ttnn.get_pcie_device_ids() mesh_device = ttnn.open_mesh_device( - ttnn.MeshShape(1, 8), + ttnn.MeshShape(2, 4), dispatch_core_type=get_dispatch_core_type(), **device_params, + mesh_type=ttnn.MeshType.Ring, ) logger.debug(f"multidevice with {mesh_device.get_num_devices()} devices is created") diff --git a/tests/ttnn/unit_tests/gtests/ttnn_test_fixtures.hpp b/tests/ttnn/unit_tests/gtests/ttnn_test_fixtures.hpp index 9e779b7f0cc..2b0a8fc04a1 100644 --- 
a/tests/ttnn/unit_tests/gtests/ttnn_test_fixtures.hpp +++ b/tests/ttnn/unit_tests/gtests/ttnn_test_fixtures.hpp @@ -69,11 +69,11 @@ class T3kMultiDeviceFixture : public ::testing::Test { } constexpr auto DEFAULT_NUM_COMMAND_QUEUES = 1; mesh_device_ = MeshDevice::create( - MeshShape{2, 4}, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, DEFAULT_NUM_COMMAND_QUEUES, - DispatchCoreType::WORKER); + DispatchCoreType::WORKER, + MeshDeviceConfig(MeshShape{2, 4}, MeshType::Ring)); } void TearDown() override { diff --git a/tests/ttnn/unit_tests/test_multi_device.py b/tests/ttnn/unit_tests/test_multi_device.py index 32eb74e3c99..4753dcedf8d 100644 --- a/tests/ttnn/unit_tests/test_multi_device.py +++ b/tests/ttnn/unit_tests/test_multi_device.py @@ -591,3 +591,26 @@ def test_validate_as_tensor(tmp_path, mesh_device, height, width): def test_visualize_mesh_device(t3k_mesh_device): ttnn.visualize_mesh_device(t3k_mesh_device) + + +def test_matmul_multiple_submeshes(t3k_mesh_device): + """Test all_gather with multiple submeshes""" + + def model(submesh): + ttnn.visualize_mesh_device(submesh) + + full_tensor = torch.ones((1, 1, 32, 32 * submesh.get_num_devices()), dtype=torch.bfloat16) + for i in range(submesh.get_num_devices()): + full_tensor[..., i * 32 : (i + 1) * 32] = i + + ttnn_tensor = ttnn.from_torch(full_tensor, mesh_mapper=ShardTensorToMesh(submesh, dim=3)) + ttnn_tensor = ttnn.to_device(ttnn_tensor, submesh) + ttnn_tensor = ttnn.all_gather(ttnn_tensor, dim=3, num_links=1) + + for device_tensor in ttnn.get_device_tensors(ttnn_tensor): + device_tensor_torch = ttnn.to_torch(device_tensor) + assert torch.all(device_tensor_torch == full_tensor) + + submesh_devices = t3k_mesh_device.create_submeshes((2, 2), ttnn.MeshType.Ring) + for submesh in submesh_devices: + model(submesh) diff --git a/tt_metal/impl/device/mesh_device.cpp b/tt_metal/impl/device/mesh_device.cpp index e6e8d3d1097..39232fbb465 100644 --- a/tt_metal/impl/device/mesh_device.cpp +++ 
b/tt_metal/impl/device/mesh_device.cpp @@ -155,21 +155,20 @@ std::vector SystemMesh::map_mesh_device( std::size_t l1_small_size, std::size_t trace_region_size, DispatchCoreType dispatch_core_type, - const std::pair& offset, - const std::vector& user_provided_physical_device_ids) { + const MeshDeviceConfig& config) { auto [requested_num_rows, requested_num_cols] = mesh_device->shape(); auto [max_num_rows, max_num_cols] = this->logical_mesh_shape; - auto [row_offset, col_offset] = offset; + auto [row_offset, col_offset] = config.offset; log_debug(LogMetal, "Mapping MeshDevice ({}x{}) with offset: {}, {}", requested_num_rows, requested_num_cols, row_offset, col_offset); TT_FATAL(requested_num_rows <= max_num_rows, "Requested too many rows: {} > {}", requested_num_rows, max_num_rows); TT_FATAL(requested_num_rows*requested_num_cols <= max_num_rows*max_num_cols, "Requested submesh is too big: {}x{}", requested_num_rows, requested_num_cols); - auto physical_device_ids = user_provided_physical_device_ids.empty() ? - this->get_mapped_physical_device_ids(MeshDeviceConfig{mesh_device->shape(), offset}) : - user_provided_physical_device_ids; + auto physical_device_ids = config.physical_device_ids.empty() ? 
+ this->get_mapped_physical_device_ids(config) : + config.physical_device_ids; this->opened_devices[mesh_device->get_mesh_id()] = tt::tt_metal::detail::CreateDevices( physical_device_ids, num_command_queues, l1_small_size, trace_region_size, dispatch_core_type); @@ -181,7 +180,7 @@ std::vector SystemMesh::map_mesh_device( this->assigned_physical_id_to_device.insert({physical_device_id, mapped_device}); } - this->register_mesh_device(mesh_device, mapped_devices); // TODO: change this + this->register_mesh_device(mesh_device, mapped_devices); // here return mapped_devices; } @@ -213,25 +212,27 @@ static MeshDeviceID generate_unique_mesh_id() { return next_id++; } -MeshDevice::MeshDevice(const MeshShape& mesh_device_shape, std::shared_ptr parent_mesh) - : mesh_device_shape(mesh_device_shape), mesh_id(generate_unique_mesh_id()), parent_mesh(parent_mesh) {} +MeshDevice::MeshDevice(const MeshShape& mesh_device_shape, MeshType order, std::shared_ptr parent_mesh) + : mesh_device_shape(mesh_device_shape), order(order), mesh_id(generate_unique_mesh_id()), parent_mesh(parent_mesh) {} std::shared_ptr MeshDevice::create( - const MeshShape& mesh_device_shape, std::size_t l1_small_size, std::size_t trace_region_size, std::size_t num_command_queues, DispatchCoreType dispatch_core_type, - const std::pair& offset, - const std::vector& user_provided_physical_device_ids) + const MeshDeviceConfig& config) { - auto mesh_device = std::make_shared(mesh_device_shape); - mesh_device->initialize(l1_small_size, trace_region_size, num_command_queues, dispatch_core_type, offset, user_provided_physical_device_ids); + auto mesh_device = std::make_shared(config.mesh_shape, config.mesh_type); + mesh_device->initialize(l1_small_size, trace_region_size, num_command_queues, dispatch_core_type, config); return mesh_device; } -std::shared_ptr MeshDevice::create_submesh(const MeshShape &submesh_shape, const MeshOffset &offset) { +std::shared_ptr MeshDevice::create_submesh( + const MeshShape 
&submesh_shape, + const MeshOffset &offset, + MeshType order) +{ if (submesh_shape.first <= 0 || submesh_shape.second <= 0) { TT_THROW("Invalid submesh shape: ({}, {}). Both dimensions must be positive.", submesh_shape.first, submesh_shape.second); } @@ -248,7 +249,7 @@ std::shared_ptr MeshDevice::create_submesh(const MeshShape &submesh_ this->mesh_device_shape.first, this->mesh_device_shape.second); } - auto submesh = std::make_shared(submesh_shape, shared_from_this()); + auto submesh = std::make_shared(submesh_shape, order, shared_from_this()); auto start_coordinate = Coordinate{offset.first, offset.second}; auto end_coordinate = Coordinate{offset.first + submesh_shape.first - 1, offset.second + submesh_shape.second - 1}; submesh->primary_view = std::make_shared(*this, start_coordinate, end_coordinate); @@ -261,13 +262,26 @@ std::shared_ptr MeshDevice::create_submesh(const MeshShape &submesh_ return submesh; } +std::vector> MeshDevice::create_submeshes( + const MeshShape &submesh_shape, + MeshType order) +{ + std::vector> submeshes; + for (int row = 0; row < this->num_rows(); row += submesh_shape.first) { + for (int col = 0; col < this->num_cols(); col += submesh_shape.second) { + auto submesh = this->create_submesh(submesh_shape, MeshOffset{row, col}, order); + submeshes.push_back(submesh); + } + } + return submeshes; +} + void MeshDevice::initialize( std::size_t l1_small_size, std::size_t trace_region_size, std::size_t num_command_queues, DispatchCoreType dispatch_core_type, - const std::pair& offset, - const std::vector& physical_device_ids) + const MeshDeviceConfig& config) { auto [num_rows, num_cols] = this->shape(); auto num_requested_devices = num_rows * num_cols; @@ -279,7 +293,7 @@ void MeshDevice::initialize( auto& instance = SystemMesh::instance(); this->devices = instance.map_mesh_device( - shared_from_this(), num_command_queues, l1_small_size, trace_region_size, dispatch_core_type, offset, physical_device_ids); + shared_from_this(), 
num_command_queues, l1_small_size, trace_region_size, dispatch_core_type, config); this->primary_view = std::make_shared(*this); } @@ -304,7 +318,7 @@ Device* MeshDevice::get_device(chip_id_t physical_device_id) const { return SystemMesh::instance().get_device(physical_device_id); } -std::vector MeshDevice::get_devices() const { return this->primary_view->get_devices(IterationOrder::LINE); } +std::vector MeshDevice::get_devices() const { return this->primary_view->get_devices(this->order); } Device* MeshDevice::get_device(std::size_t row_idx, std::size_t col_idx) const { return this->get_device_index(row_idx * num_cols() + col_idx); @@ -417,7 +431,7 @@ std::vector get_t3k_physical_device_ids_ring() { TT_FATAL(num_devices == 8, "T3000 ring topology only works with 8 devices"); auto physical_device_ids = instance.get_mapped_physical_device_ids( - MeshDeviceConfig{MeshShape{1, 8}, MeshOffset{0, 0}}); + MeshDeviceConfig(MeshShape{1, 8}, MeshOffset{0, 0})); return physical_device_ids; } diff --git a/tt_metal/impl/device/mesh_device.hpp b/tt_metal/impl/device/mesh_device.hpp index d806c7ee9f9..7b7ac608b40 100644 --- a/tt_metal/impl/device/mesh_device.hpp +++ b/tt_metal/impl/device/mesh_device.hpp @@ -22,6 +22,26 @@ class MeshDeviceView; struct MeshDeviceConfig { MeshShape mesh_shape; MeshOffset offset; + std::vector physical_device_ids; + MeshType mesh_type; + + MeshDeviceConfig( + const MeshShape &mesh_shape, + MeshType mesh_type = MeshType::RowMajor) : + mesh_shape(mesh_shape), + offset(MeshOffset{0, 0}), + physical_device_ids(std::vector()), + mesh_type(mesh_type) {} + + MeshDeviceConfig( + const MeshShape &mesh_shape, + const MeshOffset &offset = MeshOffset{0, 0}, + const std::vector &physical_device_ids = {}, + MeshType mesh_type = MeshType::RowMajor) : + mesh_shape(mesh_shape), + offset(offset), + physical_device_ids(physical_device_ids), + mesh_type(mesh_type) {} }; // SystemMesh creates a virtualization over the physical devices in the system. 
@@ -80,8 +100,7 @@ class SystemMesh { std::size_t l1_small_size, std::size_t trace_region_size, DispatchCoreType dispatch_core_type, - const std::pair &offset = {0, 0}, - const std::vector &physical_device_ids = {}); + const MeshDeviceConfig &config); // Unmap MeshDevice, releasing the associated physical devices. void unmap_mesh_device(const MeshDevice* mesh_device); @@ -93,6 +112,7 @@ class MeshDevice : public std::enable_shared_from_this { private: MeshDeviceID mesh_id; MeshShape mesh_device_shape; + MeshType order; std::shared_ptr primary_view; std::vector devices; std::shared_ptr parent_mesh; @@ -103,11 +123,10 @@ class MeshDevice : public std::enable_shared_from_this { std::size_t trace_region_size, std::size_t num_command_queues, DispatchCoreType dispatch_core_type, - const std::pair &offset, - const std::vector &physical_device_ids); + const MeshDeviceConfig &config); public: - MeshDevice(const MeshShape &mesh_device_shape, std::shared_ptr parent_mesh = nullptr); + MeshDevice(const MeshShape &mesh_device_shape, MeshType order, std::shared_ptr parent_mesh = nullptr); ~MeshDevice(); MeshDevice(const MeshDevice &) = delete; @@ -146,17 +165,22 @@ class MeshDevice : public std::enable_shared_from_this { std::vector> get_submesh_views(); std::shared_ptr get_view(const Device* device); - std::shared_ptr create_submesh(const MeshShape &submesh_shape, const MeshOffset &offset = {0, 0}); + std::shared_ptr create_submesh( + const MeshShape &submesh_shape, + const MeshOffset &offset = MeshOffset{0, 0}, + MeshType order = MeshType::RowMajor); + + std::vector> create_submeshes( + const MeshShape &submesh_shape, + MeshType order = MeshType::RowMajor); static std::shared_ptr fetch_mesh_device(const std::vector& devices); static std::shared_ptr create( - const MeshShape &mesh_device_shape, std::size_t l1_small_size, std::size_t trace_region_size, std::size_t num_command_queues, DispatchCoreType dispatch_core_type, - const std::pair &offset = {0, 0}, - const std::vector 
&physical_device_ids = {}); + const MeshDeviceConfig &config); }; std::ostream &operator<<(std::ostream &os, const MeshDevice &mesh_device); diff --git a/tt_metal/impl/device/mesh_device_view.cpp b/tt_metal/impl/device/mesh_device_view.cpp index 6d1f7527b30..911329d4678 100644 --- a/tt_metal/impl/device/mesh_device_view.cpp +++ b/tt_metal/impl/device/mesh_device_view.cpp @@ -243,11 +243,14 @@ std::vector MeshDeviceView::get_ring_coordinates(const MeshShape& ri // Traverse the bottom row from right to left, if there is more than one row if (ring_rows > 1 and ring_cols > 1) { - for (std::size_t col = end_col - 1; col >= start_col; --col) { - boundary_coords.emplace_back(Coordinate{end_row, col}); + // Traverse the bottom row from right to left + for (int col = static_cast(end_col - 1); col >= static_cast(start_col); --col) { + boundary_coords.emplace_back(Coordinate{end_row, static_cast(col)}); } - for (std::size_t row = end_row - 1; row >= start_row; --row) { - boundary_coords.emplace_back(Coordinate{row, start_col}); + + // Traverse the leftmost column from bottom-1 to top+1 + for (int row = static_cast(end_row - 1); row > static_cast(start_row); --row) { + boundary_coords.emplace_back(Coordinate{static_cast(row), start_col}); } } @@ -271,13 +274,13 @@ std::vector MeshDeviceView::get_ring_devices() { return get_devices_from_coordinates(*this, boundary_coords); } -MeshDeviceView::DeviceView MeshDeviceView::get_devices(IterationOrder order) { +MeshDeviceView::DeviceView MeshDeviceView::get_devices(MeshType order) { switch (order) { - case IterationOrder::ROW_MAJOR: + case MeshType::RowMajor: return this->devices_; - case IterationOrder::RING: + case MeshType::Ring: return this->get_ring_devices(); - case IterationOrder::LINE: + case MeshType::Line: return this->get_line_devices(); default: TT_THROW("Unsupported iteration order: {}", order); diff --git a/tt_metal/impl/device/mesh_device_view.hpp b/tt_metal/impl/device/mesh_device_view.hpp index 
bb49e63d50e..2e2378ed16a 100644 --- a/tt_metal/impl/device/mesh_device_view.hpp +++ b/tt_metal/impl/device/mesh_device_view.hpp @@ -54,10 +54,10 @@ struct Coordinate { * (CCL-ops), such as line all-gather, which require column or row views of the device mesh. */ -enum class IterationOrder { - ROW_MAJOR, - RING, - LINE +enum class MeshType { + RowMajor, + Ring, + Line }; class MeshDeviceView { @@ -79,7 +79,7 @@ class MeshDeviceView { // devices are returned in row-major order with start/end coordinates inclusive [[nodiscard]] DeviceView get_devices(const Coordinate& start, const Coordinate& end); [[nodiscard]] DeviceView get_devices(const MeshShape& shape); - [[nodiscard]] DeviceView get_devices(IterationOrder order = IterationOrder::ROW_MAJOR); + [[nodiscard]] DeviceView get_devices(MeshType order = MeshType::RowMajor); [[nodiscard]] DeviceView get_devices_on_row(std::size_t row) const; [[nodiscard]] DeviceView get_devices_on_column(std::size_t col) const; @@ -110,10 +110,10 @@ class MeshDeviceView { // The current support only provides left-to-right and right-to-left snaking of the line. 
[[nodiscard]] static std::vector get_line_coordinates(std::size_t length, const Coordinate& offset, std::size_t num_rows, std::size_t num_cols); [[nodiscard]] std::vector get_ring_coordinates(const MeshShape& ring_shape, const Coordinate& offset, std::size_t num_rows, std::size_t num_cols); - -private: [[nodiscard]] std::vector get_ring_devices(); [[nodiscard]] std::vector get_line_devices(); + +private: std::vector devices_; std::unordered_map device_coordinates_; Coordinate top_left_; diff --git a/ttnn/cpp/pybind11/multi_device.hpp b/ttnn/cpp/pybind11/multi_device.hpp index fb2ec846a2f..8642727884c 100644 --- a/ttnn/cpp/pybind11/multi_device.hpp +++ b/ttnn/cpp/pybind11/multi_device.hpp @@ -16,9 +16,16 @@ namespace ttnn { namespace multi_device { -void py_module_types(py::module& module) { py::class_>(module, "MeshDevice"); } +void py_module_types(py::module& module) { + py::class_>(module, "MeshDevice"); +} void py_module(py::module& module) { + py::enum_(module, "MeshType") + .value("RowMajor", MeshType::RowMajor) + .value("Ring", MeshType::Ring) + .value("Line", MeshType::Line) + .export_values(); auto py_mesh_device = static_cast>>(module.attr("MeshDevice")); py_mesh_device .def( @@ -28,15 +35,15 @@ void py_module(py::module& module) { std::size_t num_command_queues, DispatchCoreType dispatch_core_type, const std::pair& offset, - const std::vector& physical_device_ids) { + const std::vector& physical_device_ids, + MeshType mesh_type) { + auto config = MeshDeviceConfig(mesh_device_shape, offset, physical_device_ids, mesh_type); return MeshDevice::create( - mesh_device_shape, l1_small_size, trace_region_size, num_command_queues, dispatch_core_type, - offset, - physical_device_ids); + config); }), py::kw_only(), py::arg("mesh_shape"), @@ -45,7 +52,8 @@ void py_module(py::module& module) { py::arg("num_command_queues"), py::arg("dispatch_core_type"), py::arg("offset"), - py::arg("physical_device_ids")) + py::arg("physical_device_ids"), + py::arg("mesh_type")) 
.def("get_num_devices", &MeshDevice::num_devices) .def("get_mesh_id", &MeshDevice::get_mesh_id) .def("get_device_ids", &MeshDevice::get_device_ids) @@ -63,7 +71,12 @@ void py_module(py::module& module) { Returns: List[Device]: The devices in the device mesh. )doc") - .def("create_submesh", &MeshDevice::create_submesh, py::arg("submesh_shape"), py::arg("offset") = std::pair{0, 0}, py::return_value_policy::reference_internal, py::keep_alive<0, 1>()) + .def("create_submesh", &MeshDevice::create_submesh, + py::arg("submesh_shape"), py::arg("offset"), py::arg("mesh_type"), + py::keep_alive<1, 0>()) // Keep MeshDevice alive as long as SubmeshDevice is alive + .def("create_submeshes", &MeshDevice::create_submeshes, + py::arg("submesh_shape"), py::arg("mesh_type"), + py::keep_alive<1, 0>()) // Keep MeshDevice alive as long as SubmeshDevices are alive .def( "compute_with_storage_grid_size", &MeshDevice::compute_with_storage_grid_size, @@ -108,7 +121,9 @@ void py_module(py::module& module) { py::arg("trace_region_size"), py::arg("num_command_queues"), py::arg("dispatch_core_type"), - py::arg("physical_device_ids")); + py::arg("offset"), + py::arg("physical_device_ids"), + py::arg("mesh_type")); module.def("close_mesh_device", &close_mesh_device, py::arg("mesh_device"), py::kw_only()); module.def( diff --git a/ttnn/cpp/ttnn/multi_device.cpp b/ttnn/cpp/ttnn/multi_device.cpp index 7fa5f9e0d65..b8a9e91a900 100644 --- a/ttnn/cpp/ttnn/multi_device.cpp +++ b/ttnn/cpp/ttnn/multi_device.cpp @@ -12,8 +12,9 @@ namespace ttnn::multi_device { -std::shared_ptr open_mesh_device(const MeshShape& mesh_shape, size_t l1_small_size, size_t trace_region_size, size_t num_command_queues, DispatchCoreType dispatch_core_type, const std::pair& offset) { - return MeshDevice::create(mesh_shape, l1_small_size, trace_region_size, num_command_queues, dispatch_core_type, offset); +std::shared_ptr open_mesh_device(const MeshShape& mesh_shape, size_t l1_small_size, size_t trace_region_size, size_t 
num_command_queues, DispatchCoreType dispatch_core_type, MeshType mesh_type, const std::pair& offset, const std::vector& physical_device_ids) { + auto config = MeshDeviceConfig(mesh_shape, offset, physical_device_ids, mesh_type); + return MeshDevice::create(l1_small_size, trace_region_size, num_command_queues, dispatch_core_type, config); } void close_mesh_device(const std::shared_ptr& mesh_device) { diff --git a/ttnn/cpp/ttnn/multi_device.hpp b/ttnn/cpp/ttnn/multi_device.hpp index d7db05721bc..60dd6451aca 100644 --- a/ttnn/cpp/ttnn/multi_device.hpp +++ b/ttnn/cpp/ttnn/multi_device.hpp @@ -6,16 +6,25 @@ #include -#include "ttnn/types.hpp" -#include "ttnn/tensor/tensor.hpp" #include "tt_metal/impl/device/mesh_device.hpp" +#include "ttnn/tensor/tensor.hpp" +#include "ttnn/types.hpp" using Device = ttnn::Device; namespace ttnn { namespace multi_device { -std::shared_ptr open_mesh_device(const MeshShape& mesh_shape, size_t l1_small_size, size_t trace_region_size, size_t num_command_queues, DispatchCoreType dispatch_core_type, const std::pair& offset = {0, 0}); +std::shared_ptr open_mesh_device( + const MeshShape& mesh_shape, + size_t l1_small_size, + size_t trace_region_size, + size_t num_command_queues, + DispatchCoreType dispatch_core_type, + MeshType mesh_type = MeshType::RowMajor, + const std::pair& offset = std::pair(0, 0), + const std::vector& physical_device_ids = std::vector()); + void close_mesh_device(const std::shared_ptr& mesh_device); std::vector get_device_tensors(const ttnn::Tensor& tensor); diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py index ddd1e95725d..8dcad39de63 100644 --- a/ttnn/ttnn/__init__.py +++ b/ttnn/ttnn/__init__.py @@ -197,6 +197,7 @@ def manage_config(name, value): visualize_mesh_device, ConcatMesh2dToTensor, distribute, + MeshType, ) from ttnn.core import ( diff --git a/ttnn/ttnn/multi_device.py b/ttnn/ttnn/multi_device.py index 87f3297dbee..72574ab7503 100644 --- a/ttnn/ttnn/multi_device.py +++ b/ttnn/ttnn/multi_device.py @@ 
-18,6 +18,7 @@ def get_mesh_device_core_grid(mesh_device): MeshDevice = ttnn._ttnn.multi_device.MeshDevice MeshDevice.core_grid = property(get_mesh_device_core_grid) DispatchCoreType = ttnn._ttnn.device.DispatchCoreType +MeshType = ttnn._ttnn.multi_device.MeshType def _get_rich_table( @@ -140,6 +141,7 @@ def open_mesh_device( dispatch_core_type: int = DispatchCoreType.WORKER, offset: Tuple[int, int] = (0, 0), physical_device_ids: List[int] = [], + mesh_type: "MeshType" = MeshType.RowMajor, ): """ Open a mesh device with the specified configuration. @@ -151,6 +153,7 @@ def open_mesh_device( num_command_queues (int, optional): Number of command queues. Defaults to 1. dispatch_core_type (int, optional): Type of dispatch core. Defaults to DispatchCoreType.WORKER. offset (Tuple[int, int], optional): Offset in logical mesh coordinates for the mesh device. Defaults to (0, 0). + mesh_type (MeshType, optional): Order in which devices are iterated. Defaults to MeshType.RowMajor. Returns: ttnn._ttnn.multi_device.MeshDevice: The opened mesh device. @@ -164,6 +167,7 @@ def open_mesh_device( dispatch_core_type=dispatch_core_type, offset=offset, physical_device_ids=physical_device_ids, + mesh_type=mesh_type, )