Skip to content

Commit

Permalink
Re-enable couple mesh tensors tests
Browse files — browse the repository at this point in the history
Author: cfjchu, committed Feb 21, 2025
1 parent: 560ad32 · commit: f8193f9
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 42 deletions.
47 changes: 20 additions & 27 deletions tests/ttnn/unit_tests/gtests/tensor/test_mesh_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,26 +21,28 @@ using ::testing::FloatEq;
using ::testing::Pointwise;

using MeshTensorTest = T3kMultiDeviceFixture;
/*
TEST_F(MeshTensorTest, Lifecycle) {
const TensorSpec tensor_spec =
TensorSpec(ttnn::Shape{1, 1, 32, 32}, TensorLayout(DataType::FLOAT32, Layout::ROW_MAJOR, MemoryConfig{}));

Tensor input_tensor = allocate_tensor_on_mesh(tensor_spec, mesh_device_.get());

EXPECT_EQ(input_tensor.workers.size(), mesh_device_->num_devices());
EXPECT_TRUE(input_tensor.is_allocated());

const auto& storage = input_tensor.get_storage();
auto* multi_device_storage = std::get_if<tt::tt_metal::MultiDeviceStorage>(&storage);
auto* device_storage = std::get_if<tt::tt_metal::DeviceStorage>(&storage);

ASSERT_NE(multi_device_storage, nullptr);
EXPECT_NE(multi_device_storage->mesh_buffer, nullptr);
ASSERT_NE(device_storage, nullptr);
EXPECT_NE(device_storage->mesh_buffer, nullptr);

// Buffer address is the same across all device buffers.
const auto buffer_address = multi_device_storage->mesh_buffer->address();
for (auto* device : mesh_device_->get_devices()) {
auto buffer = multi_device_storage->get_buffer_for_device(device);
const auto& view = mesh_device_->get_view();
const auto buffer_address = device_storage->mesh_buffer->address();

for (auto* device : view.get_devices()) {
auto coordinate = view.find_device(device->id());
auto buffer = device_storage->mesh_buffer->get_device_buffer(coordinate);

ASSERT_NE(buffer, nullptr);
EXPECT_TRUE(buffer->is_allocated());
EXPECT_EQ(buffer->address(), buffer_address);
Expand All @@ -49,7 +51,6 @@ TEST_F(MeshTensorTest, Lifecycle) {
input_tensor.deallocate();
EXPECT_FALSE(input_tensor.is_allocated());
}
*/

using MeshTensorDeviceTest = T3kMultiDeviceFixture;

Expand All @@ -62,7 +63,7 @@ TEST_F(MeshTensorDeviceTest, ToHostNonMeshTensor) {

EXPECT_ANY_THROW(tensor_impl::to_host_mesh_tensor_wrapper(input_host_tensor));
}
/*

TEST_F(MeshTensorDeviceTest, ReplicateHostTensor) {
const ttnn::Shape shape{1, 1, 32, 32};
const TensorSpec tensor_spec =
Expand All @@ -82,12 +83,9 @@ TEST_F(MeshTensorDeviceTest, ReplicateHostTensor) {
EXPECT_TRUE(distributed::is_mesh_buffer_tensor(device_tensor));
EXPECT_EQ(device_tensor.get_tensor_spec().logical_shape(), shape);

auto* multi_device_storage = std::get_if<tt::tt_metal::MultiDeviceStorage>(&device_tensor.get_storage());
ASSERT_NE(multi_device_storage, nullptr);
for (const auto& [_, shard_spec] : multi_device_storage->specs) {
EXPECT_EQ(shard_spec.logical_shape(), shape);
}
EXPECT_TRUE(std::holds_alternative<tt::tt_metal::ReplicateTensor>(multi_device_storage->strategy));
auto* device_storage = std::get_if<tt::tt_metal::DeviceStorage>(&device_tensor.get_storage());
ASSERT_NE(device_storage, nullptr);
EXPECT_NE(device_storage->mesh_buffer, nullptr);

// Read the tensor back, and compare it with input data.
Tensor output_host_tensor = tensor_impl::to_host_mesh_tensor_wrapper(device_tensor);
Expand All @@ -99,9 +97,9 @@ TEST_F(MeshTensorDeviceTest, ReplicateHostTensor) {
EXPECT_THAT(tensor.to_vector<float>(), Pointwise(FloatEq(), host_data));
}
}
*/
/*
TEST_F(MeshTensorDeviceTest, WriteMultiDeviceHostTensor) {

// TODO(jchu): Re-enable this test when we have handling for uneven shard shapes.
TEST_F(MeshTensorDeviceTest, DISABLED_WriteMultiDeviceHostTensor) {
const int num_devices = mesh_device_->num_devices();
ASSERT_EQ(num_devices, 8);
// Test uneven shard shapes.
Expand Down Expand Up @@ -129,11 +127,8 @@ TEST_F(MeshTensorDeviceTest, WriteMultiDeviceHostTensor) {
tensor_impl::to_device_mesh_tensor_wrapper(input_host_tensor_sharded, mesh_device_.get(), MemoryConfig{});
EXPECT_TRUE(distributed::is_mesh_buffer_tensor(device_tensor));

auto* multi_device_storage = std::get_if<tt::tt_metal::MultiDeviceStorage>(&device_tensor.get_storage());
ASSERT_NE(multi_device_storage, nullptr);
const auto* device_tensor_strategy = std::get_if<tt::tt_metal::ShardTensor>(&multi_device_storage->strategy);
ASSERT_NE(device_tensor_strategy, nullptr);
EXPECT_EQ(device_tensor_strategy->shard_dimension, 1);
auto* device_storage = std::get_if<tt::tt_metal::DeviceStorage>(&device_tensor.get_storage());
ASSERT_NE(device_storage, nullptr);

// Read the tensor back, and compare it with input data.
Tensor output_host_tensor = aggregate_tensor(
Expand All @@ -142,8 +137,6 @@ TEST_F(MeshTensorDeviceTest, WriteMultiDeviceHostTensor) {
EXPECT_EQ(output_host_tensor.get_tensor_spec().logical_shape(), shape);

EXPECT_THAT(output_host_tensor.to_vector<float>(), Pointwise(FloatEq(), host_data));
} // namespace
*/

}
} // namespace
} // namespace ttnn::distributed::test
14 changes: 3 additions & 11 deletions ttnn/cpp/ttnn/distributed/api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,19 +199,11 @@ Tensor get_device_tensor(const Tensor& multi_device_tensor, const IDevice* devic

bool is_host_mesh_tensor(const Tensor& tensor) { return tensor.storage_type() == StorageType::MULTI_DEVICE_HOST; }

bool is_multi_device_tensor(const Tensor& tensor) {
TT_THROW("TODO(jchu): Not implemented");
return /*tensor.storage_type() == StorageType::MULTI_DEVICE or */
tensor.storage_type() == StorageType::MULTI_DEVICE_HOST;
}
bool is_multi_device_tensor(const Tensor& tensor) { return tensor.storage_type() == StorageType::MULTI_DEVICE_HOST; }

bool is_mesh_buffer_tensor(const Tensor& tensor) {
TT_THROW("TODO(jchu): Not implemented");
/*
auto* multi_device_storage = std::get_if<MultiDeviceStorage>(&tensor.get_storage());
return multi_device_storage != nullptr && multi_device_storage->mesh_buffer != nullptr;
*/
return false;
auto* device_storage = std::get_if<DeviceStorage>(&tensor.get_storage());
return device_storage != nullptr && device_storage->mesh_buffer != nullptr;
}

std::vector<Tensor> get_tensors_from_multi_device_storage(const Tensor& multi_device_tensor) {
Expand Down
4 changes: 2 additions & 2 deletions ttnn/cpp/ttnn/distributed/distributed_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,8 @@ Tensor distribute_tensor(
}

Tensor aggregate_tensor(const Tensor& tensor, const MeshToTensor& composer) {
return is_multi_device_tensor(tensor) ? composer.compose(get_tensors_from_multi_device_storage(tensor))
: composer.compose({tensor});
return is_host_mesh_tensor(tensor) ? composer.compose(get_tensors_from_multi_device_storage(tensor))
: composer.compose({tensor});
}

} // namespace ttnn::distributed
4 changes: 3 additions & 1 deletion ttnn/cpp/ttnn/tensor/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,9 @@ void Tensor::deallocate_impl(bool force, bool deallocation_through_destructor) {
[force, attr](auto&& s) {
using type = std::decay_t<decltype(s)>;
if constexpr (std::is_same_v<type, DeviceStorage>) {
if (force or s.buffer.use_count() == 1) {
if (s.mesh_buffer != nullptr and (force or s.mesh_buffer.use_count() == 1)) {
s.mesh_buffer->deallocate();
} else if (force or s.buffer.use_count() == 1) {
DeallocateBuffer(*(s.buffer));
}
// Safe to reset this buf object since this is the last reference (in
Expand Down
1 change: 0 additions & 1 deletion ttnn/cpp/ttnn/tensor/tensor_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -817,7 +817,6 @@ Tensor to_device_mesh_tensor(
}

TT_FATAL(tt::tt_metal::detail::InMainThread(), "to_device_mesh_tensor must be called from the main thread");
TT_FATAL(tensor.storage_type() != StorageType::MULTI_DEVICE, "Tensor is already on device!");
TT_FATAL(mesh_device != nullptr, "Need target device in order to move tensor to device!");
TT_FATAL(tensor.is_allocated(), "Need data to exist in order to move it to device");

Expand Down

0 comments on commit f8193f9

Please sign in to comment.