diff --git a/tests/ttnn/unit_tests/gtests/tensor/test_mesh_tensor.cpp b/tests/ttnn/unit_tests/gtests/tensor/test_mesh_tensor.cpp
index ece06de1c41d..e517db3b32b6 100644
--- a/tests/ttnn/unit_tests/gtests/tensor/test_mesh_tensor.cpp
+++ b/tests/ttnn/unit_tests/gtests/tensor/test_mesh_tensor.cpp
@@ -21,26 +21,28 @@ using ::testing::FloatEq;
 using ::testing::Pointwise;
 
 using MeshTensorTest = T3kMultiDeviceFixture;
-/*
 TEST_F(MeshTensorTest, Lifecycle) {
     const TensorSpec tensor_spec = TensorSpec(
         ttnn::Shape{1, 1, 32, 32}, TensorLayout(DataType::FLOAT32, Layout::ROW_MAJOR, MemoryConfig{}));
 
     Tensor input_tensor = allocate_tensor_on_mesh(tensor_spec, mesh_device_.get());
 
-    EXPECT_EQ(input_tensor.workers.size(), mesh_device_->num_devices());
     EXPECT_TRUE(input_tensor.is_allocated());
 
     const auto& storage = input_tensor.get_storage();
-    auto* multi_device_storage = std::get_if<MultiDeviceStorage>(&storage);
+    auto* device_storage = std::get_if<DeviceStorage>(&storage);
 
-    ASSERT_NE(multi_device_storage, nullptr);
-    EXPECT_NE(multi_device_storage->mesh_buffer, nullptr);
+    ASSERT_NE(device_storage, nullptr);
+    EXPECT_NE(device_storage->mesh_buffer, nullptr);
 
     // Buffer address is the same across all device buffers.
-    const auto buffer_address = multi_device_storage->mesh_buffer->address();
-    for (auto* device : mesh_device_->get_devices()) {
-        auto buffer = multi_device_storage->get_buffer_for_device(device);
+    const auto& view = mesh_device_->get_view();
+    const auto buffer_address = device_storage->mesh_buffer->address();
+
+    for (auto* device : view.get_devices()) {
+        auto coordinate = view.find_device(device->id());
+        auto buffer = device_storage->mesh_buffer->get_device_buffer(coordinate);
+
         ASSERT_NE(buffer, nullptr);
         EXPECT_TRUE(buffer->is_allocated());
         EXPECT_EQ(buffer->address(), buffer_address);
@@ -49,7 +51,6 @@ TEST_F(MeshTensorTest, Lifecycle) {
     input_tensor.deallocate();
     EXPECT_FALSE(input_tensor.is_allocated());
 }
-*/
 
 using MeshTensorDeviceTest = T3kMultiDeviceFixture;
 
@@ -62,7 +63,7 @@ TEST_F(MeshTensorDeviceTest, ToHostNonMeshTensor) {
     EXPECT_ANY_THROW(tensor_impl::to_host_mesh_tensor_wrapper(input_host_tensor));
 }
 
-/*
+
 TEST_F(MeshTensorDeviceTest, ReplicateHostTensor) {
     const ttnn::Shape shape{1, 1, 32, 32};
     const TensorSpec tensor_spec =
@@ -82,12 +83,9 @@ TEST_F(MeshTensorDeviceTest, ReplicateHostTensor) {
     EXPECT_TRUE(distributed::is_mesh_buffer_tensor(device_tensor));
     EXPECT_EQ(device_tensor.get_tensor_spec().logical_shape(), shape);
 
-    auto* multi_device_storage = std::get_if<MultiDeviceStorage>(&device_tensor.get_storage());
-    ASSERT_NE(multi_device_storage, nullptr);
-    for (const auto& [_, shard_spec] : multi_device_storage->specs) {
-        EXPECT_EQ(shard_spec.logical_shape(), shape);
-    }
-    EXPECT_TRUE(std::holds_alternative<ReplicateTensor>(multi_device_storage->strategy));
+    auto* device_storage = std::get_if<DeviceStorage>(&device_tensor.get_storage());
+    ASSERT_NE(device_storage, nullptr);
+    EXPECT_NE(device_storage->mesh_buffer, nullptr);
 
     // Read the tensor back, and compare it with input data.
     Tensor output_host_tensor = tensor_impl::to_host_mesh_tensor_wrapper(device_tensor);
@@ -99,9 +97,9 @@ TEST_F(MeshTensorDeviceTest, ReplicateHostTensor) {
         EXPECT_THAT(tensor.to_vector<float>(), Pointwise(FloatEq(), host_data));
     }
 }
-*/
-/*
-TEST_F(MeshTensorDeviceTest, WriteMultiDeviceHostTensor) {
+
+// TODO(jchu): Re-enable this test when we have handling for uneven shard shapes.
+TEST_F(MeshTensorDeviceTest, DISABLED_WriteMultiDeviceHostTensor) {
     const int num_devices = mesh_device_->num_devices();
     ASSERT_EQ(num_devices, 8);
     // Test uneven shard shapes.
@@ -129,11 +127,8 @@ TEST_F(MeshTensorDeviceTest, WriteMultiDeviceHostTensor) {
         tensor_impl::to_device_mesh_tensor_wrapper(input_host_tensor_sharded, mesh_device_.get(), MemoryConfig{});
     EXPECT_TRUE(distributed::is_mesh_buffer_tensor(device_tensor));
 
-    auto* multi_device_storage = std::get_if<MultiDeviceStorage>(&device_tensor.get_storage());
-    ASSERT_NE(multi_device_storage, nullptr);
-    const auto* device_tensor_strategy = std::get_if<ShardTensor>(&multi_device_storage->strategy);
-    ASSERT_NE(device_tensor_strategy, nullptr);
-    EXPECT_EQ(device_tensor_strategy->shard_dimension, 1);
+    auto* device_storage = std::get_if<DeviceStorage>(&device_tensor.get_storage());
+    ASSERT_NE(device_storage, nullptr);
 
     // Read the tensor back, and compare it with input data.
     Tensor output_host_tensor = aggregate_tensor(
@@ -142,8 +137,6 @@ TEST_F(MeshTensorDeviceTest, WriteMultiDeviceHostTensor) {
 
     EXPECT_EQ(output_host_tensor.get_tensor_spec().logical_shape(), shape);
     EXPECT_THAT(output_host_tensor.to_vector<float>(), Pointwise(FloatEq(), host_data));
-}  // namespace
-*/
-
+}
 }  // namespace
 }  // namespace ttnn::distributed::test
diff --git a/ttnn/cpp/ttnn/distributed/api.cpp b/ttnn/cpp/ttnn/distributed/api.cpp
index 242a193b5d9f..db5cb7c52b2a 100644
--- a/ttnn/cpp/ttnn/distributed/api.cpp
+++ b/ttnn/cpp/ttnn/distributed/api.cpp
@@ -199,19 +199,11 @@ Tensor get_device_tensor(const Tensor& multi_device_tensor, const IDevice* devic
 
 bool is_host_mesh_tensor(const Tensor& tensor) { return tensor.storage_type() == StorageType::MULTI_DEVICE_HOST; }
 
-bool is_multi_device_tensor(const Tensor& tensor) {
-    TT_THROW("TODO(jchu): Not implemented");
-    return /*tensor.storage_type() == StorageType::MULTI_DEVICE or */
-           tensor.storage_type() == StorageType::MULTI_DEVICE_HOST;
-}
+bool is_multi_device_tensor(const Tensor& tensor) { return tensor.storage_type() == StorageType::MULTI_DEVICE_HOST; }
 
 bool is_mesh_buffer_tensor(const Tensor& tensor) {
-    TT_THROW("TODO(jchu): Not implemented");
-    /*
-    auto* multi_device_storage = std::get_if<MultiDeviceStorage>(&tensor.get_storage());
-    return multi_device_storage != nullptr && multi_device_storage->mesh_buffer != nullptr;
-    */
-    return false;
+    auto* device_storage = std::get_if<DeviceStorage>(&tensor.get_storage());
+    return device_storage != nullptr && device_storage->mesh_buffer != nullptr;
 }
 
 std::vector<Tensor> get_tensors_from_multi_device_storage(const Tensor& multi_device_tensor) {
diff --git a/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp b/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp
index 3d82d24714fc..6ec223fc8bb8 100644
--- a/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp
+++ b/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp
@@ -196,8 +196,8 @@ Tensor distribute_tensor(
 }
 
 Tensor aggregate_tensor(const Tensor& tensor, const MeshToTensor& composer) {
-    return is_multi_device_tensor(tensor) ? composer.compose(get_tensors_from_multi_device_storage(tensor))
-                                          : composer.compose({tensor});
+    return is_host_mesh_tensor(tensor) ? composer.compose(get_tensors_from_multi_device_storage(tensor))
+                                       : composer.compose({tensor});
 }
 
 }  // namespace ttnn::distributed
diff --git a/ttnn/cpp/ttnn/tensor/tensor.cpp b/ttnn/cpp/ttnn/tensor/tensor.cpp
index b9aa75c2fa5a..0cd9be116d39 100644
--- a/ttnn/cpp/ttnn/tensor/tensor.cpp
+++ b/ttnn/cpp/ttnn/tensor/tensor.cpp
@@ -273,7 +273,9 @@ void Tensor::deallocate_impl(bool force, bool deallocation_through_destructor) {
             [force, attr](auto&& s) {
                 using type = std::decay_t<decltype(s)>;
                 if constexpr (std::is_same_v<type, DeviceStorage>) {
-                    if (force or s.buffer.use_count() == 1) {
+                    if (s.mesh_buffer != nullptr and (force or s.mesh_buffer.use_count() == 1)) {
+                        s.mesh_buffer->deallocate();
+                    } else if (force or s.buffer.use_count() == 1) {
                         DeallocateBuffer(*(s.buffer));
                     }
                     // Safe to reset this buf object since this is the last reference (in
diff --git a/ttnn/cpp/ttnn/tensor/tensor_impl.cpp b/ttnn/cpp/ttnn/tensor/tensor_impl.cpp
index dc0aeba59f8c..b8b2cad3ee5b 100644
--- a/ttnn/cpp/ttnn/tensor/tensor_impl.cpp
+++ b/ttnn/cpp/ttnn/tensor/tensor_impl.cpp
@@ -817,7 +817,6 @@ Tensor to_device_mesh_tensor(
     }
 
     TT_FATAL(tt::tt_metal::detail::InMainThread(), "to_device_mesh_tensor must be called from the main thread");
-    TT_FATAL(tensor.storage_type() != StorageType::MULTI_DEVICE, "Tensor is already on device!");
     TT_FATAL(mesh_device != nullptr, "Need target device in order to move tensor to device!");
     TT_FATAL(tensor.is_allocated(), "Need data to exist in order to move it to device");