From 4893de89c83e1656ebf3850664e0bfc5ff95090a Mon Sep 17 00:00:00 2001 From: Joseph Chu Date: Thu, 6 Feb 2025 23:52:09 +0000 Subject: [PATCH] #0: Fix failing Llama TG tests by preserving old behavior for ShardTensorToMesh Previously, when we had an MxN MeshDevice, a mesh_mapper of ShardTensorToMesh would behave differently based on whether the `mesh_type` passed into the MeshDevice was MeshType::RowMajor or MeshType::Ring. With the removal of `MeshType` from the MeshDevice specification, this changed the default behavior for users constructing a MeshDevice with the default mesh_type=MeshType::RowMajor. This change now preserves the old behavior so that shards are distributed in row-major order instead of along a line. --- conftest.py | 1 + ttnn/cpp/ttnn/distributed/api.cpp | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/conftest.py b/conftest.py index 510905dd8f7..4be5deca442 100644 --- a/conftest.py +++ b/conftest.py @@ -258,6 +258,7 @@ def pcie_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, devic **updated_device_params, offset=ttnn.MeshOffset(0, 1), ) + mesh_device.reshape(ttnn.MeshShape(1, 4)) logger.debug(f"multidevice with {mesh_device.get_num_devices()} devices is created") yield mesh_device diff --git a/ttnn/cpp/ttnn/distributed/api.cpp b/ttnn/cpp/ttnn/distributed/api.cpp index 831c1f4cbd5..bd0fd35a206 100644 --- a/ttnn/cpp/ttnn/distributed/api.cpp +++ b/ttnn/cpp/ttnn/distributed/api.cpp @@ -153,7 +153,6 @@ std::vector get_mapped_devices(const Tensor& tensor, MeshDevice& mesh_ [&](const ShardTensor2D& s) { return mesh_device.get_view().get_devices(MeshShape{s.shard_mesh.y, s.shard_mesh.x}); }, - [&](const ShardTensor& s) { return get_workers_for_tensor(mesh_device.get_view().get_line_devices()); }, [&](const auto&) { return get_workers_for_tensor(mesh_device.get_devices()); }}, host_storage.strategy); } else if (std::holds_alternative(tensor.get_storage())) {