Adding subdevice support in all gather. Doesn't fix hang on synchronize devices.
avoraTT authored and caixunshiren committed Jan 30, 2025
1 parent 433f602 commit e5564d4
Showing 3 changed files with 34 additions and 9 deletions.
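
For context, a minimal sketch of how a caller might pass the new parameter. The namespace qualification, the leading arguments, and the variable names are illustrative assumptions; only the tail of the parameter list (num_links, memory_config, topology, sub_device_id, enable_persistent_fabric_mode) comes from the header change below:

    // Hypothetical call site; `device`, `input_tensor`, `dim`, and the
    // multi-device semaphore are assumed to exist and are not part of this diff.
    std::optional<SubDeviceId> sub_device_id = device->get_sub_device_ids().at(0);
    Tensor output = all_gather_async(
        input_tensor,
        dim,
        multi_device_global_semaphore,
        /*num_links=*/1,
        /*memory_config=*/std::nullopt,
        ttnn::ccl::Topology::Ring,
        sub_device_id,  // new in this commit: pins worker cores to this sub-device
        /*enable_persistent_fabric_mode=*/true);

Passing std::nullopt (the default) keeps the previous behavior of using the device's first sub-device.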
@@ -20,6 +20,7 @@ AllGatherAsync create_all_gather_async_struct(
     const std::vector<IDevice*>& devices,
     const ttnn::ccl::Topology topology,
     const std::vector<GlobalSemaphore>& semaphores,
+    std::optional<SubDeviceId> sub_device_id,
     bool enable_persistent_fabric_mode) {
     uint32_t num_devices = devices.size();

@@ -50,6 +51,7 @@ AllGatherAsync create_all_gather_async_struct(
         memory_config.value_or(input_tensor.memory_config()),
         topology,
         semaphore.value(),
+        sub_device_id,
         enable_persistent_fabric_mode};
 }

@@ -119,6 +121,7 @@ operation::ProgramWithCallbacks AllGatherAsync::create_program(
         this->ring_index,
         this->topology,
         this->semaphore,
+        this->sub_device_id,
         this->enable_persistent_fabric_mode);
 }

@@ -153,7 +156,7 @@ Tensor all_gather_async(
     const uint32_t num_links,
     const std::optional<MemoryConfig>& memory_config,
     const ttnn::ccl::Topology topology,
-    std::optional<SubDeviceId> subdevice_id,
+    std::optional<SubDeviceId> sub_device_id,
     bool enable_persistent_fabric_mode) {
     TT_FATAL(
         std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr,
@@ -179,7 +182,15 @@ Tensor all_gather_async(
     std::vector<GlobalSemaphore> semaphores = multi_device_global_semaphore.global_semaphores;
 
     operation::launch_op(
-        [dim, num_links, num_devices, memory_config, devices, ccl_topology, semaphores, enable_persistent_fabric_mode](
+        [dim,
+         num_links,
+         num_devices,
+         memory_config,
+         devices,
+         ccl_topology,
+         semaphores,
+         sub_device_id,
+         enable_persistent_fabric_mode](
             const std::vector<Tensor>& input_tensors,
             const std::vector<std::optional<const Tensor>>& optional_input_tensors,
             const std::vector<std::optional<Tensor>>& optional_output_tensors) mutable -> std::vector<Tensor> {
@@ -194,6 +205,7 @@ Tensor all_gather_async(
                     devices,
                     ccl_topology,
                     semaphores,
+                    sub_device_id,
                     enable_persistent_fabric_mode),
                 {input_tensor});
         },
@@ -211,7 +223,7 @@
     const global_semaphore::MultiDeviceGlobalSemaphore& multi_device_global_semaphore,
     const std::optional<MemoryConfig>& memory_config,
     const std::optional<size_t> num_preferred_links,
-    std::optional<SubDeviceId> subdevice_id,
+    std::optional<SubDeviceId> sub_device_id,
     bool enable_persistent_fabric_mode) {
     TT_FATAL(
         topology == ttnn::ccl::Topology::Linear,
@@ -245,6 +257,7 @@ Tensor all_gather_async(
         num_devices,
         topology,
         semaphores,
+        sub_device_id,
         enable_persistent_fabric_mode](
            const std::vector<Tensor>& input_tensors,
            const std::vector<std::optional<const Tensor>>& optional_input_tensors,
@@ -266,6 +279,7 @@ Tensor all_gather_async(
                     devices,
                     topology,
                     semaphores,
+                    sub_device_id,
                     enable_persistent_fabric_mode),
                 {input_tensor});
         },
@@ -35,6 +35,7 @@ struct AllGatherAsync {
     const MemoryConfig output_mem_config;
     const ccl::Topology topology;
     const GlobalSemaphore semaphore;
+    std::optional<SubDeviceId> sub_device_id;
     bool enable_persistent_fabric_mode;
 
     AllGatherAsync(
@@ -47,6 +48,7 @@ struct AllGatherAsync {
         MemoryConfig output_mem_config,
         ccl::Topology topology,
         GlobalSemaphore semaphore,
+        std::optional<SubDeviceId>& sub_device_id,
         bool enable_persistent_fabric_mode) :
         forward_device(forward_device),
         backward_device(backward_device),
@@ -57,6 +59,7 @@
         output_mem_config(output_mem_config),
         topology(topology),
         semaphore(semaphore),
+        sub_device_id(sub_device_id),
         enable_persistent_fabric_mode(enable_persistent_fabric_mode) {}
 
     // Add attributes method for reflection
@@ -92,6 +95,7 @@ AllGatherAsync create_all_gather_async_struct(
     const std::vector<IDevice*>& devices,
     const ccl::Topology topology,
     const std::vector<GlobalSemaphore>& semaphores,
+    std::optional<SubDeviceId> sub_device_id,
     bool enable_persistent_fabric_mode);
 } // namespace all_gather_async_detail
 } // namespace ccl
@@ -108,6 +112,7 @@ operation::ProgramWithCallbacks all_gather_async_multi_core_with_workers(
     const uint32_t ring_index,
     ccl::Topology topology,
     const GlobalSemaphore semaphore,
+    const std::optional<SubDeviceId>& sub_device_id,
     bool enable_persistent_fabric_mode);
 
 namespace operations {
@@ -121,7 +126,7 @@ Tensor all_gather_async(
     const uint32_t num_links = 1,
     const std::optional<MemoryConfig>& memory_config = std::nullopt,
     const ttnn::ccl::Topology topology = ttnn::ccl::Topology::Ring,
-    std::optional<SubDeviceId> subdevice_id = std::nullopt,
+    std::optional<SubDeviceId> sub_device_id = std::nullopt,
     bool enable_persistent_fabric_mode = false); // TODO make reference
 
 Tensor all_gather_async(
@@ -133,7 +138,7 @@ Tensor all_gather_async(
     const global_semaphore::MultiDeviceGlobalSemaphore& multi_device_global_semaphore,
     const std::optional<MemoryConfig>& memory_config = std::nullopt,
     const std::optional<size_t> num_preferred_links = std::nullopt,
-    std::optional<SubDeviceId> subdevice_id = std::nullopt,
+    std::optional<SubDeviceId> sub_device_id = std::nullopt,
     bool enable_persistent_fabric_mode = false);
 
 } // namespace ccl
@@ -73,13 +73,18 @@ static void print_tensor_slice(const ttnn::ccl::v2::TensorSlice& slice_v2) {
 }
 
 std::tuple<CoreRangeSet, std::vector<CoreCoord>> choose_worker_cores(
-    size_t num_links, size_t num_workers_per_link, bool persistent_fabric_mode, IDevice* device) {
+    size_t num_links,
+    size_t num_workers_per_link,
+    bool persistent_fabric_mode,
+    IDevice* device,
+    const std::optional<SubDeviceId>& sub_device_id) {
     std::tuple<CoreRangeSet, std::vector<CoreCoord>> result;
     CoreRangeSet sender_worker_core_range;
     if (persistent_fabric_mode) {
         const size_t num_workers_preferred = num_workers_per_link * num_links;
-        const auto available_cores =
-            device->worker_cores(HalProgrammableCoreType::TENSIX, device->get_sub_device_ids().at(0));
+        const auto available_cores = device->worker_cores(
+            HalProgrammableCoreType::TENSIX,
+            sub_device_id.has_value() ? *sub_device_id : device->get_sub_device_ids().at(0));
         if (available_cores.num_cores() < num_workers_preferred) {
             log_warning(
                 tt::LogOp,
@@ -131,6 +136,7 @@ operation::ProgramWithCallbacks all_gather_async_multi_core_with_workers(
     const uint32_t ring_index,
     ccl::Topology topology,
     const GlobalSemaphore semaphore,
+    const std::optional<SubDeviceId>& sub_device_id,
     bool enable_persistent_fabric_mode) {
     tt::tt_metal::Program program{};
     const bool enable_async_output_tensor = false;
@@ -174,7 +180,7 @@
     // Get worker cores, assuming 1 worker per link
     uint32_t num_workers_per_link = 1;
     const auto [sender_worker_core_range, sender_worker_cores] =
-        choose_worker_cores(num_links, num_workers_per_link, enable_persistent_fabric_mode, device);
+        choose_worker_cores(num_links, num_workers_per_link, enable_persistent_fabric_mode, device, sub_device_id);
 
     // L1 Scratch CB Creation
     const size_t packet_size_bytes = local_fabric_handle->get_edm_buffer_size_bytes();
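
The behavioral core of the commit is the fallback in choose_worker_cores: use the caller's sub-device when one is provided, otherwise fall back to the device's first sub-device, as before. A self-contained sketch of that pattern with stand-in types (SubDeviceId and the device query are mocked here; only the optional-fallback logic mirrors the diff):

    #include <cassert>
    #include <optional>
    #include <vector>

    using SubDeviceId = int;  // stand-in for the real strongly-typed ID

    // Mirrors the selection logic above: prefer the requested sub-device,
    // otherwise fall back to the first sub-device configured on the device.
    SubDeviceId resolve_sub_device(
        const std::optional<SubDeviceId>& requested,
        const std::vector<SubDeviceId>& device_sub_device_ids) {
        return requested.has_value() ? *requested : device_sub_device_ids.at(0);
    }

    int main() {
        const std::vector<SubDeviceId> ids = {0, 1, 2};
        assert(resolve_sub_device(std::nullopt, ids) == 0);    // pre-commit behavior
        assert(resolve_sub_device(SubDeviceId{2}, ids) == 2);  // new: caller's choice
        return 0;
    }

Because sub_device_id defaults to std::nullopt throughout, existing call sites keep resolving to get_sub_device_ids().at(0), matching the pre-commit code path; per the commit message, this does not fix the hang on synchronize devices.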
