Skip to content

Commit

Permalink
#0: Fix ordering of devices after reshape applied on MeshDevice
Browse files Browse the repository at this point in the history
  • Loading branch information
cfjchu committed Feb 7, 2025
1 parent b552fb8 commit e3920bd
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 63 deletions.
87 changes: 30 additions & 57 deletions tests/ttnn/distributed/test_distributed_reshape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ TEST_P(MeshReshapeTest, TestReshapeBetweenConfigurations) {
EXPECT_EQ(mesh->num_rows(), old_shape.num_rows);
EXPECT_EQ(mesh->num_cols(), old_shape.num_cols);

auto original_order = get_physical_device_ids(*mesh);
auto original_order = mesh->get_device_ids();

// Attempt reshape
mesh->reshape({new_shape.num_rows, new_shape.num_cols});
Expand All @@ -93,7 +93,7 @@ TEST_P(MeshReshapeTest, TestReshapeBetweenConfigurations) {
EXPECT_EQ(mesh->num_cols(), new_shape.num_cols);

// Verify device ordering is preserved
EXPECT_EQ(get_physical_device_ids(*mesh), original_order);
EXPECT_EQ(mesh->get_device_ids(), original_order);
}

// Generate all possible combinations of shapes from kMeshShapes
Expand Down Expand Up @@ -121,35 +121,34 @@ TEST_F(T3000ReshapeTest, InvalidReshapeDimensions) {
EXPECT_EQ(mesh->num_cols(), 8);
}

TEST_F(T3000ReshapeTest, From1x8To2x4) {
TEST_F(T3000ReshapeTest, From1x8To2x4ThenBackTo1x8) {
auto mesh = ttnn::distributed::open_mesh_device(
{1, 8}, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1, tt::tt_metal::DispatchCoreType::WORKER);

EXPECT_EQ(mesh->num_rows(), 1);
EXPECT_EQ(mesh->num_cols(), 8);
auto original_order = get_physical_device_ids(*mesh);

mesh->reshape({2, 4});
EXPECT_EQ(mesh->num_rows(), 2);
EXPECT_EQ(mesh->num_cols(), 4);
auto new_order = get_physical_device_ids(*mesh);
EXPECT_EQ(original_order, new_order);
}

TEST_F(T3000ReshapeTest, OnRingTopology) {
auto mesh = ttnn::distributed::open_mesh_device(
{1, 8}, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1, tt::tt_metal::DispatchCoreType::WORKER);

EXPECT_EQ(mesh->num_rows(), 1);
EXPECT_EQ(mesh->num_cols(), 8);
auto original_order = get_physical_device_ids(*mesh);
auto original_order = mesh->get_device_ids();

mesh->reshape({2, 4});

EXPECT_EQ(mesh->num_rows(), 2);
EXPECT_EQ(mesh->num_cols(), 4);
auto new_order = get_physical_device_ids(*mesh);
EXPECT_EQ(original_order, new_order);
std::vector<chip_id_t> expected_physical_device_id_order = {
original_order[0],
original_order[1],
original_order[2],
original_order[3],
original_order[7],
original_order[6],
original_order[5],
original_order[4],
};

auto new_order = mesh->get_device_ids();
EXPECT_EQ(new_order, expected_physical_device_id_order);

mesh->reshape({1, 8});
EXPECT_EQ(mesh->get_device_ids(), original_order);
}

TEST_F(T3000ReshapeTest, InvalidTotalDeviceCount) {
Expand All @@ -165,26 +164,6 @@ TEST_F(T3000ReshapeTest, InvalidTotalDeviceCount) {
EXPECT_EQ(mesh->num_cols(), 8);
}

// Opens a 1x8 mesh and reshapes it 1x8 -> 2x4 -> 4x2 -> 1x8, asserting after every step that
// get_physical_device_ids(*mesh) returns the same sequence it returned before the first reshape.
// NOTE(review): the surrounding hunk header (-165,26 +164,6) indicates this test is deleted by the
// commit. Its "order unchanged after 1x8 -> 2x4" expectation also differs from the reordered
// expectation in From1x8To2x4ThenBackTo1x8 (which uses mesh->get_device_ids()) -- confirm which
// ordering contract applies before reviving this test.
TEST_F(T3000ReshapeTest, MultipleReshapes) {
auto mesh = ttnn::distributed::open_mesh_device(
{1, 8}, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1, tt::tt_metal::DispatchCoreType::WORKER);

// Baseline ordering captured before any reshape; every later snapshot is compared against it.
auto original_order = get_physical_device_ids(*mesh);

// Test multiple reshapes
mesh->reshape({2, 4}); // 1x8 -> 2x4
auto order1 = get_physical_device_ids(*mesh);
EXPECT_EQ(order1, original_order);

mesh->reshape({4, 2}); // 2x4 -> 4x2
auto order2 = get_physical_device_ids(*mesh);
EXPECT_EQ(order2, original_order);

mesh->reshape({1, 8}); // 4x2 -> 1x8 (back to original)
auto final_order = get_physical_device_ids(*mesh);
EXPECT_EQ(final_order, original_order);
}

TEST_F(T3000ReshapeTest, RingPreservation) {
auto mesh = ttnn::distributed::open_mesh_device(
{1, 8}, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1, tt::tt_metal::DispatchCoreType::WORKER);
Expand Down Expand Up @@ -239,7 +218,7 @@ TEST_F(T3000ReshapeTest, From1x4To2x2Valid) {
mesh->reshape({2, 2});
EXPECT_EQ(mesh->num_rows(), 2);
EXPECT_EQ(mesh->num_cols(), 2);
auto new_layout = get_physical_device_ids(*mesh);
auto new_layout = mesh->get_device_ids();
for (auto physical_device_id : physical_device_ids) {
EXPECT_TRUE(std::find(new_layout.begin(), new_layout.end(), physical_device_id) != new_layout.end());
}
Expand All @@ -249,27 +228,21 @@ TEST_F(T3000ReshapeTest, From2x2To1x4) {
auto mesh = ttnn::distributed::open_mesh_device(
{2, 2}, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1, tt::tt_metal::DispatchCoreType::WORKER);

std::vector<chip_id_t> original_layout;
for (size_t i = 0; i < mesh->num_rows(); ++i) {
for (size_t j = 0; j < mesh->num_cols(); ++j) {
auto id = mesh->get_device(i, j)->id();
original_layout.push_back(id);
}
}
auto mesh_2x2_device_ids = mesh->get_device_ids();

mesh->reshape({1, 4});
EXPECT_EQ(mesh->num_rows(), 1);
EXPECT_EQ(mesh->num_cols(), 4);

std::vector<chip_id_t> new_layout;
for (size_t i = 0; i < mesh->num_rows(); ++i) {
for (size_t j = 0; j < mesh->num_cols(); ++j) {
auto id = mesh->get_device(i, j)->id();
new_layout.push_back(id);
}
}
auto mesh_1x4_device_ids = mesh->get_device_ids();
std::vector<chip_id_t> expected_1x4_device_ids = {
mesh_2x2_device_ids[0],
mesh_2x2_device_ids[1],
mesh_2x2_device_ids[3],
mesh_2x2_device_ids[2],
};

EXPECT_EQ(new_layout, original_layout);
EXPECT_EQ(mesh_1x4_device_ids, expected_1x4_device_ids);
}

} // namespace ttnn::distributed::test
22 changes: 22 additions & 0 deletions tests/ttnn/unit_tests/test_multi_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,3 +694,25 @@ def model(submesh):
submesh_devices = mesh_device.create_submeshes(ttnn.MeshShape(2, 2))
for submesh in submesh_devices:
model(submesh)


@pytest.mark.parametrize("mesh_device", [pytest.param((1, 8), id="1x8_line")], indirect=True)
def test_line_all_gather_after_reshape(mesh_device):
    """Smoke test: run a linear-topology all_gather on a mesh that was reshaped from 1x8 to 2x4.

    The test is skipped when fewer than 8 devices are available. It only checks that
    the reshape, the 2D-sharded upload and the all_gather call complete; the gathered
    tensor's contents are not inspected.
    """
    if mesh_device.get_num_devices() < 8:
        pytest.skip()

    # View the 1x8 line as a 2x4 grid before any tensor is placed on it.
    mesh_device.reshape(ttnn.MeshShape(2, 4))

    host_input = torch.rand((1, 1, 64, 128), dtype=torch.bfloat16)

    # Shard tensor dims 2 and 3 across the (reshaped) mesh shape.
    shard_over_mesh = ttnn.ShardTensor2dMesh(mesh_device, mesh_shape=list(mesh_device.shape), dims=(2, 3))
    sharded_input = ttnn.from_torch(
        host_input,
        layout=ttnn.TILE_LAYOUT,
        device=mesh_device,
        mesh_mapper=shard_over_mesh,
    )

    # Gather along tensor dim 2 over cluster axis 0 of the reshaped mesh.
    gathered = ttnn.all_gather(
        sharded_input,
        dim=2,
        cluster_axis=0,
        mesh_device=mesh_device,
        topology=ttnn.Topology.Linear,
    )
3 changes: 3 additions & 0 deletions tt_metal/api/tt-metalium/mesh_device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ class MeshDevice : public IDevice, public std::enable_shared_from_this<MeshDevic
// This is a reference device used to query properties that are the same for all devices in the mesh.
IDevice* reference_device() const;

// Returns this mesh's devices as a flat list laid out in row-major order for `new_shape`,
// which is the ordering MeshDeviceView expects when it is constructed for that shape.
// When neither dimension of `new_shape` is 1, the ordering is obtained by asking SystemMesh for
// devices forming that shape; otherwise the current view's line ordering is returned.
// Does not modify the mesh; reshape() uses the result to build the new view.
std::vector<IDevice*> get_row_major_devices(const MeshShape& new_shape) const;

public:
MeshDevice(
std::shared_ptr<ScopedDevices> mesh_handle,
Expand Down
37 changes: 31 additions & 6 deletions tt_metal/distributed/mesh_device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -252,18 +252,31 @@ size_t MeshDevice::num_cols() const { return mesh_shape_.num_cols; }

MeshShape MeshDevice::shape() const { return mesh_shape_; }

void MeshDevice::reshape(const MeshShape& new_shape) {
TT_FATAL(
new_shape.num_rows * new_shape.num_cols == this->num_devices(),
"New shape must have the same number of devices as current shape");

std::vector<IDevice*> MeshDevice::get_row_major_devices(const MeshShape& new_shape) const {
// MeshDeviceView requires devices to be provided as a 1D array in row-major order for the target mesh shape.
// The physical connectivity between devices must be preserved when reshaping.
//
// Example:
// Given 4 devices physically connected in a 2x2 grid like this:
// [0]--[1]
// | |
// [3]--[2]
//
// For a 1x4 mesh shape:
// - Devices must form a line: 0->1->2->3
// - Row-major order will be: [0,1,2,3]
//
// For a 2x2 mesh shape:
// - Preserves original 2x2 physical connectivity
// - Row-major order will be: [0,1,3,2]
std::unordered_map<chip_id_t, size_t> physical_device_id_to_linearized_index;
for (size_t i = 0; i < this->num_devices(); i++) {
physical_device_id_to_linearized_index[this->get_devices()[i]->id()] = i;
}

// From an MxN mesh, we can always reduce rank to a 1x(M*N) line mesh.
// However, going from a line mesh to an MxN mesh is not always possible.
std::vector<IDevice*> new_device_order;
if (new_shape.num_rows != 1 and new_shape.num_cols != 1) {
auto new_physical_device_ids =
SystemMesh::instance().request_available_devices(
Expand All @@ -285,10 +298,22 @@ void MeshDevice::reshape(const MeshShape& new_shape) {
this->num_cols());
}
}
for (size_t i = 0; i < new_physical_device_ids.size(); i++) {
new_device_order.push_back(this->get_device(new_physical_device_ids[i]));
}
} else {
new_device_order = view_->get_line_devices();
}
return new_device_order;
}

// Re-views the devices already owned by this mesh under `new_shape`.
//
// `new_shape` must contain exactly as many devices as the mesh currently holds; otherwise
// TT_FATAL fires before any state is touched.
//
// The new view is built *before* mesh_shape_ is overwritten: get_row_major_devices() reads the
// current mesh state (num_cols(), which returns mesh_shape_.num_cols, and the current view_), so
// it must observe the pre-reshape shape. Committing the shape last also means that if the device
// order cannot be resolved, mesh_shape_ and view_ are both left describing the old shape rather
// than disagreeing with each other.
void MeshDevice::reshape(const MeshShape& new_shape) {
    TT_FATAL(
        new_shape.num_rows * new_shape.num_cols == this->num_devices(),
        "New shape must have the same number of devices as current shape");

    // Single view construction: the device list is already ordered row-major for new_shape.
    view_ = std::make_unique<MeshDeviceView>(this->get_row_major_devices(new_shape), new_shape);
    mesh_shape_ = new_shape;
}

bool MeshDevice::close() {
Expand Down

0 comments on commit e3920bd

Please sign in to comment.