From e5564d471a7ac52c46e1f654796b76816e88ae44 Mon Sep 17 00:00:00 2001
From: avoraTT
Date: Thu, 23 Jan 2025 08:08:58 -0600
Subject: [PATCH] Adding subdevice support in all gather. Doesn't fix hang on
 synchronize devices.

---
 .../device/all_gather_async_op.cpp           | 20 ++++++++++++++++---
 .../device/all_gather_async_op.hpp           |  9 +++++++--
 .../device/all_gather_async_program.cpp      | 14 +++++++++----
 3 files changed, 34 insertions(+), 9 deletions(-)

diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.cpp
index 2cd2a82cdd5..13f1979d54d 100644
--- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.cpp
+++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.cpp
@@ -20,6 +20,7 @@ AllGatherAsync create_all_gather_async_struct(
     const std::vector<IDevice*>& devices,
     const ttnn::ccl::Topology topology,
     const std::vector<GlobalSemaphore>& semaphores,
+    std::optional<SubDeviceId> sub_device_id,
     bool enable_persistent_fabric_mode) {
     uint32_t num_devices = devices.size();
 
@@ -50,6 +51,7 @@ AllGatherAsync create_all_gather_async_struct(
         memory_config.value_or(input_tensor.memory_config()),
         topology,
         semaphore.value(),
+        sub_device_id,
         enable_persistent_fabric_mode};
 }
 
@@ -119,6 +121,7 @@ operation::ProgramWithCallbacks AllGatherAsync::create_program(
         this->ring_index,
         this->topology,
         this->semaphore,
+        this->sub_device_id,
         this->enable_persistent_fabric_mode);
 }
 
@@ -153,7 +156,7 @@ Tensor all_gather_async(
     const uint32_t num_links,
     const std::optional<MemoryConfig>& memory_config,
     const ttnn::ccl::Topology topology,
-    std::optional<SubDeviceId> subdevice_id,
+    std::optional<SubDeviceId> sub_device_id,
     bool enable_persistent_fabric_mode) {
     TT_FATAL(
         std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr,
@@ -179,7 +182,15 @@
     std::vector<GlobalSemaphore> semaphores = multi_device_global_semaphore.global_semaphores;
 
     operation::launch_op(
-        [dim, num_links, num_devices, memory_config, devices, ccl_topology, semaphores, enable_persistent_fabric_mode](
+        [dim,
+         num_links,
+         num_devices,
+         memory_config,
+         devices,
+         ccl_topology,
+         semaphores,
+         sub_device_id,
+         enable_persistent_fabric_mode](
            const std::vector<Tensor>& input_tensors,
            const std::vector<std::optional<const Tensor>>& optional_input_tensors,
            const std::vector<std::optional<Tensor>>& optional_output_tensors) mutable -> std::vector<Tensor> {
@@ -194,6 +205,7 @@ Tensor all_gather_async(
                    devices,
                    ccl_topology,
                    semaphores,
+                   sub_device_id,
                    enable_persistent_fabric_mode),
                {input_tensor});
        },
@@ -211,7 +223,7 @@
     const global_semaphore::MultiDeviceGlobalSemaphore& multi_device_global_semaphore,
     const std::optional<MemoryConfig>& memory_config,
     const std::optional<size_t> num_preferred_links,
-    std::optional<SubDeviceId> subdevice_id,
+    std::optional<SubDeviceId> sub_device_id,
     bool enable_persistent_fabric_mode) {
     TT_FATAL(
         topology == ttnn::ccl::Topology::Linear,
@@ -245,6 +257,7 @@
         num_devices,
         topology,
         semaphores,
+        sub_device_id,
         enable_persistent_fabric_mode](
            const std::vector<Tensor>& input_tensors,
            const std::vector<std::optional<const Tensor>>& optional_input_tensors,
            const std::vector<std::optional<Tensor>>& optional_output_tensors) mutable -> std::vector<Tensor> {
@@ -266,6 +279,7 @@
                    devices,
                    topology,
                    semaphores,
+                   sub_device_id,
                    enable_persistent_fabric_mode),
                {input_tensor});
        },
diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.hpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.hpp
index 2e49d65f189..6050b9a8cce 100644
--- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.hpp
+++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.hpp
@@ -35,6 +35,7 @@ struct AllGatherAsync {
     const MemoryConfig output_mem_config;
     const ccl::Topology topology;
     const GlobalSemaphore semaphore;
+    std::optional<SubDeviceId> sub_device_id;
     bool enable_persistent_fabric_mode;
 
     AllGatherAsync(
@@ -47,6 +48,7 @@
         MemoryConfig output_mem_config,
         ccl::Topology topology,
         GlobalSemaphore semaphore,
+        std::optional<SubDeviceId>& sub_device_id,
         bool enable_persistent_fabric_mode) :
         forward_device(forward_device),
         backward_device(backward_device),
@@ -57,6 +59,7 @@
         output_mem_config(output_mem_config),
         topology(topology),
         semaphore(semaphore),
+        sub_device_id(sub_device_id),
         enable_persistent_fabric_mode(enable_persistent_fabric_mode) {}
 
     // Add attributes method for reflection
@@ -92,6 +95,7 @@ AllGatherAsync create_all_gather_async_struct(
     const std::vector<IDevice*>& devices,
     const ccl::Topology topology,
     const std::vector<GlobalSemaphore>& semaphores,
+    std::optional<SubDeviceId> sub_device_id,
     bool enable_persistent_fabric_mode);
 }  // namespace all_gather_async_detail
 }  // namespace ccl
@@ -108,6 +112,7 @@ operation::ProgramWithCallbacks all_gather_async_multi_core_with_workers(
     const uint32_t ring_index,
     ccl::Topology topology,
     const GlobalSemaphore semaphore,
+    const std::optional<SubDeviceId>& sub_device_id,
     bool enable_persistent_fabric_mode);
 
 namespace operations {
@@ -121,7 +126,7 @@ Tensor all_gather_async(
     const uint32_t num_links = 1,
     const std::optional<MemoryConfig>& memory_config = std::nullopt,
     const ttnn::ccl::Topology topology = ttnn::ccl::Topology::Ring,
-    std::optional<SubDeviceId> subdevice_id = std::nullopt,
+    std::optional<SubDeviceId> sub_device_id = std::nullopt,
     bool enable_persistent_fabric_mode = false);
 // TODO make reference
 Tensor all_gather_async(
@@ -133,7 +138,7 @@
     const global_semaphore::MultiDeviceGlobalSemaphore& multi_device_global_semaphore,
     const std::optional<MemoryConfig>& memory_config = std::nullopt,
     const std::optional<size_t> num_preferred_links = std::nullopt,
-    std::optional<SubDeviceId> subdevice_id = std::nullopt,
+    std::optional<SubDeviceId> sub_device_id = std::nullopt,
     bool enable_persistent_fabric_mode = false);
 
 }  // namespace ccl
diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_program.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_program.cpp
index 8b9bd0a5fb8..8fe983e403f 100644
--- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_program.cpp
+++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_program.cpp
@@ -73,13 +73,18 @@ static void print_tensor_slice(const ttnn::ccl::v2::TensorSlice& slice_v2) {
 }
 
 std::tuple<CoreRangeSet, std::vector<CoreCoord>> choose_worker_cores(
-    size_t num_links, size_t num_workers_per_link, bool persistent_fabric_mode, IDevice* device) {
+    size_t num_links,
+    size_t num_workers_per_link,
+    bool persistent_fabric_mode,
+    IDevice* device,
+    const std::optional<SubDeviceId>& sub_device_id) {
     std::tuple<CoreRangeSet, std::vector<CoreCoord>> result;
     CoreRangeSet sender_worker_core_range;
     if (persistent_fabric_mode) {
         const size_t num_workers_preferred = num_workers_per_link * num_links;
-        const auto available_cores =
-            device->worker_cores(HalProgrammableCoreType::TENSIX, device->get_sub_device_ids().at(0));
+        const auto available_cores = device->worker_cores(
+            HalProgrammableCoreType::TENSIX,
+            sub_device_id.has_value() ? *sub_device_id : device->get_sub_device_ids().at(0));
         if (available_cores.num_cores() < num_workers_preferred) {
             log_warning(
                 tt::LogOp,
@@ -131,6 +136,7 @@ operation::ProgramWithCallbacks all_gather_async_multi_core_with_workers(
     const uint32_t ring_index,
     ccl::Topology topology,
     const GlobalSemaphore semaphore,
+    const std::optional<SubDeviceId>& sub_device_id,
     bool enable_persistent_fabric_mode) {
     tt::tt_metal::Program program{};
     const bool enable_async_output_tensor = false;
@@ -174,7 +180,7 @@ operation::ProgramWithCallbacks all_gather_async_multi_core_with_workers(
     // Get worker cores, assuming 1 worker per link
     uint32_t num_workers_per_link = 1;
     const auto [sender_worker_core_range, sender_worker_cores] =
-        choose_worker_cores(num_links, num_workers_per_link, enable_persistent_fabric_mode, device);
+        choose_worker_cores(num_links, num_workers_per_link, enable_persistent_fabric_mode, device, sub_device_id);
 
     // L1 Scratch CB Creation
     const size_t packet_size_bytes = local_fabric_handle->get_edm_buffer_size_bytes();
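
Note: the sketch below shows how a caller might exercise the new sub_device_id argument. It is illustrative only and not part of the patch: the wrapper name run_all_gather_on_sub_device is hypothetical, the namespace path assumes the usual ttnn::operations::experimental::ccl nesting of this header, and the input tensor, global semaphore, and sub-device are assumed to have been created elsewhere (e.g. via a sub-device manager).

    #include <optional>

    #include "ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.hpp"

    // Hypothetical caller: pins the all-gather worker cores to a specific
    // sub-device instead of the previously hard-coded get_sub_device_ids().at(0).
    ttnn::Tensor run_all_gather_on_sub_device(
        const ttnn::Tensor& input_tensor,
        const ttnn::global_semaphore::MultiDeviceGlobalSemaphore& semaphore,
        tt::tt_metal::SubDeviceId worker_sub_device_id) {
        return ttnn::operations::experimental::ccl::all_gather_async(
            input_tensor,
            /*dim=*/3,
            semaphore,
            /*num_links=*/1,
            /*memory_config=*/std::nullopt,
            ttnn::ccl::Topology::Linear,
            /*sub_device_id=*/worker_sub_device_id,
            /*enable_persistent_fabric_mode=*/true);
    }

Passing std::nullopt for sub_device_id keeps the old behavior, since choose_worker_cores() falls back to device->get_sub_device_ids().at(0).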