Skip to content
This repository has been archived by the owner on Nov 25, 2024. It is now read-only.

Commit

Permalink
Mnnvl with split comm (#185)
Browse files Browse the repository at this point in the history
support split comm and get_local_mnnvl_comm

split_comm
```

def split_communicator(comm: WholeMemoryCommunicator, color: int, key: int = 0):
    """Split Communicator.
    Creates a set of new communicators from an existing one. Ranks which pass the same color value will be part of the
    same group; color must be a non-negative value.
    The value of key will determine the rank order, and the smaller key means the smaller rank in new communicator.
    If keys are equal between ranks, then the rank in the original communicator will be used to order ranks.
    """
```

Authors:
  - Chuang Zhu (https://github.com/chuangz0)

Approvers:
  - https://github.com/linhu-nv
  - Brad Rees (https://github.com/BradReesWork)

URL: #185
  • Loading branch information
chuangz0 authored Jun 18, 2024
1 parent 8d4cd9b commit ba505af
Show file tree
Hide file tree
Showing 13 changed files with 456 additions and 43 deletions.
5 changes: 5 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,11 @@ PRIVATE
NCCL::NCCL
)

if (CUDAToolkit_VERSION VERSION_GREATER "12.3")
  # Link the NVML library if CUDA version is greater than 12.3.
  # NVML provides the GPU fabric-info queries (nvmlGpuFabricInfo_t) that the
  # MNNVL support code compiles in when CUDA_VERSION >= 12030.
  target_link_libraries(wholegraph PRIVATE CUDA::nvml)
endif()

if(BUILD_WITH_NVSHMEM)

file(GLOB_RECURSE NVSHMEM_SOURCE_FILES "src/wholememory_ops/functions/nvshmem*.cu")
Expand Down
33 changes: 33 additions & 0 deletions cpp/include/wholememory/wholememory.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ enum wholememory_error_code_t {
WHOLEMEMORY_INVALID_VALUE, /*!< input value is invalid */
WHOLEMEMORY_OUT_OF_MEMORY, /*!< out of memory */
WHOLEMEMORY_NOT_SUPPORTED, /*!< not supported */
WHOLEMEMORY_SYSTEM_ERROR, /*!< system error */
};

#define WHOLEMEMORY_RETURN_ON_FAIL(X) \
Expand Down Expand Up @@ -90,6 +91,7 @@ enum LogLevel {
LEVEL_TRACE /*!< Trace */
};

#define WHOLEMEMORY_SPILT_NO_COLOR -1
/**
* Initialize WholeMemory library
* @param flags : reserved should be 0
Expand All @@ -111,6 +113,15 @@ wholememory_error_code_t wholememory_finalize();
*/
typedef struct wholememory_comm_* wholememory_comm_t;

// Per-rank view of the MNNVL (multi-node NVLink) clique this GPU belongs to,
// as gathered during communicator rank-info exchange.
struct clique_info_t {
  int is_in_clique;       // >0 means this GPU belongs to an MNNVL domain
  int clique_first_rank;  // communicator rank of the first rank in this clique (-1 if none)
  int clique_rank;        // rank of this GPU within its MNNVL domain
  int clique_rank_num;    // number of GPUs in this MNNVL domain
  int clique_id;          // id of this clique among the cliques in the communicator
  int clique_num;         // total number of cliques in the communicator domain
};

#define WHOLEMEMORY_UNIQUE_ID_BYTES (128)
/**
* @brief Unique ID for WholeMemory Communicators
Expand Down Expand Up @@ -142,6 +153,24 @@ wholememory_error_code_t wholememory_create_communicator(wholememory_comm_t* com
int rank,
int size);

/**
 * Split WholeMemory Communicator
 * @param new_comm : returned split WholeMemory Communicator (NULL if this rank joins no group)
 * @param comm : WholeMemory Communicator to split
 * @param color : ranks which pass the same color value will be part of the same group; color
 * must be a non-negative value. If it is passed as WHOLEMEMORY_SPILT_NO_COLOR (NOTE(review):
 * macro name contains a typo for "SPLIT"), the rank will not be part of any group, therefore
 * NULL is returned as new_comm.
 * @param key : the value of key will determine the rank order: the smaller key means the
 * smaller rank in the new communicator. If keys are equal between ranks, then the rank in the
 * original communicator will be used to order ranks.
 * @return : wholememory_error_code_t
 */
wholememory_error_code_t wholememory_split_communicator(wholememory_comm_t* new_comm,
                                                        wholememory_comm_t comm,
                                                        int color,
                                                        int key);
/**
* Destroy WholeMemory Communicator
* @param comm : WholeMemory Communicator to destroy
Expand Down Expand Up @@ -177,6 +206,9 @@ wholememory_error_code_t wholememory_communicator_get_rank(int* rank, wholememor
*/
wholememory_error_code_t wholememory_communicator_get_size(int* size, wholememory_comm_t comm);

/**
 * Get the MNNVL clique information recorded on a WholeMemory Communicator
 * @param clique_info : returned clique_info_t of this rank
 * @param comm : WholeMemory Communicator
 * @return : wholememory_error_code_t
 */
wholememory_error_code_t wholememory_communicator_get_clique_info(clique_info_t* clique_info,
                                                                  wholememory_comm_t comm);

bool wholememory_communicator_is_bind_to_nvshmem(wholememory_comm_t comm);

wholememory_error_code_t wholememory_communicator_set_distributed_backend(
Expand Down Expand Up @@ -393,6 +425,7 @@ wholememory_error_code_t wholememory_store_to_file(wholememory_handle_t wholemem
*/
bool wholememory_is_intranode_communicator(wholememory_comm_t comm);

bool wholememory_is_intra_mnnvl_communicator(wholememory_comm_t comm);
bool wholememory_is_build_with_nvshmem();
#ifdef WITH_NVSHMEM_SUPPORT
wholememory_error_code_t wholememory_get_nvshmem_reference(
Expand Down
191 changes: 184 additions & 7 deletions cpp/src/wholememory/communicator.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,6 +16,7 @@
#include "communicator.hpp"

#include <cstdlib>
#include <set>
#include <sys/stat.h>
#include <unistd.h>

Expand Down Expand Up @@ -352,18 +353,19 @@ void wholememory_comm_::device_multicast_sendrecv(const void* sendbuf,

bool wholememory_comm_::is_intranode() const { return intra_node_rank_num == world_size; }

bool wholememory_comm_::is_intra_mnnvl() const { return support_mnnvl; }
bool wholememory_comm_::support_type_location(wholememory_memory_type_t memory_type,
wholememory_memory_location_t memory_location) const
{
if (memory_location == WHOLEMEMORY_ML_HOST) {
if (is_intranode() || memory_type == WHOLEMEMORY_MT_DISTRIBUTED) return true;
return SupportMNNVLForEGM();
return is_intra_mnnvl() && SupportEGM();
} else if (memory_location == WHOLEMEMORY_ML_DEVICE) {
if (memory_type == WHOLEMEMORY_MT_DISTRIBUTED) return true;
if (is_intranode()) {
return DevicesCanAccessP2P(&local_gpu_ids[0], intra_node_rank_num);
} else {
return DevicesCanAccessP2P(&local_gpu_ids[0], intra_node_rank_num) && SupportMNNVL();
return DevicesCanAccessP2P(&local_gpu_ids[0], intra_node_rank_num) && is_intra_mnnvl();
}
} else {
return false;
Expand Down Expand Up @@ -422,6 +424,10 @@ struct rank_info {
int rank;
int size;
int gpu_id;
// MNNVL support
#if CUDA_VERSION >= 12030
nvmlGpuFabricInfo_t fabric_info;
#endif
};

static void get_host_name(char* hostname, int maxlen, const char delim)
Expand Down Expand Up @@ -487,20 +493,72 @@ void get_host_info(host_info* phi)
get_shm_devid(&phi->shm_dev);
}

/**
 * Decide whether MNNVL can be used across this communicator.
 * @param wm_comm : the communicator being probed
 * @param p_rank_info : per-rank info of all world_size ranks, filled by host_allgather
 * @return : true only when this device supports fabric handles, its own fabric state is
 *           COMPLETED, every rank's fabric state is COMPLETED, and compute capability >= 9.0.
 */
bool comm_support_mnnvl(wholememory_comm_t wm_comm, const std::unique_ptr<rank_info[]>& p_rank_info)
{
#if CUDA_VERSION >= 12030
  int flag = 0;
  CUdevice currentDev;
  WM_CU_CHECK_NO_THROW(cuDeviceGet(&currentDev, wm_comm->dev_id));
  // Ignore error if CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED is not supported
  WM_CU_CHECK_NO_THROW(
    cuDeviceGetAttribute(&flag, CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED, currentDev));
  if (!flag) return false;

  nvmlGpuFabricInfo_t gpuFabricInfo;
  WHOLEMEMORY_CHECK_NOTHROW(wholememory::GetGpuFabricInfo(wm_comm->dev_id, &gpuFabricInfo) ==
                            WHOLEMEMORY_SUCCESS);

  if (gpuFabricInfo.state != NVML_GPU_FABRIC_STATE_COMPLETED) { return false; }

  // Check that all ranks have initialized the fabric fully.
  // BUG FIX: the loop must cover all world_size entries of p_rank_info (one per rank);
  // the original iterated only up to world_rank, skipping ranks at or above this one.
  for (int i = 0; i < wm_comm->world_size; i++) {
    if (p_rank_info.get()[i].fabric_info.state != NVML_GPU_FABRIC_STATE_COMPLETED) return false;
  }

  // MNNVL requires compute capability 9.0 (Hopper) or newer.
  return GetCudaCompCap() >= 90;
#else
  return false;
#endif
}

void exchange_rank_info(wholememory_comm_t wm_comm)
{
rank_info ri;
get_host_info(&ri.rank_host_info);
ri.rank = wm_comm->world_rank;
ri.size = wm_comm->world_size;
ri.pid = getpid();
ri.gpu_id = wm_comm->dev_id;
ri.rank = wm_comm->world_rank;
ri.size = wm_comm->world_size;
ri.pid = getpid();
ri.gpu_id = wm_comm->dev_id;
wm_comm->clique_info.is_in_clique = 0;

#if CUDA_VERSION >= 12030
memset(&ri.fabric_info, 0, sizeof(ri.fabric_info));
WHOLEMEMORY_CHECK_NOTHROW(GetGpuFabricInfo(wm_comm->dev_id, &ri.fabric_info) ==
WHOLEMEMORY_SUCCESS);

// // A zero UUID means we don't have MNNVL fabric info
if (((((long*)ri.fabric_info.clusterUuid)[0] | ((long*)ri.fabric_info.clusterUuid)[1]) == 0)) {
wm_comm->clique_info.is_in_clique = 0;

} else {
wm_comm->clique_info.is_in_clique = 1;
}

#endif

std::unique_ptr<rank_info[]> p_rank_info(new rank_info[ri.size]);
wm_comm->host_allgather(&ri, p_rank_info.get(), sizeof(rank_info), WHOLEMEMORY_DT_INT8);
wm_comm->intra_node_first_rank = -1;
wm_comm->intra_node_rank_num = 0;
wm_comm->intra_node_rank = -1;

wm_comm->clique_info.clique_first_rank = -1;
wm_comm->clique_info.clique_rank = -1;
wm_comm->clique_info.clique_rank_num = 0;

std::set<int> clique_ids{};

for (int r = 0; r < wm_comm->world_size; r++) {
WHOLEMEMORY_CHECK(r == p_rank_info.get()[r].rank);
if (ri.rank_host_info == p_rank_info.get()[r].rank_host_info) {
Expand All @@ -512,7 +570,36 @@ void exchange_rank_info(wholememory_comm_t wm_comm)
wm_comm->local_gpu_ids[wm_comm->intra_node_rank_num] = p_rank_info.get()[r].gpu_id;
wm_comm->intra_node_rank_num++;
}

#if CUDA_VERSION >= 12030

if ((memcmp(ri.fabric_info.clusterUuid,
p_rank_info.get()[r].fabric_info.clusterUuid,
NVML_GPU_FABRIC_UUID_LEN) == 0) &&
(ri.fabric_info.cliqueId == p_rank_info.get()[r].fabric_info.cliqueId)) {
if (r == wm_comm->world_rank) {
wm_comm->clique_info.clique_rank = wm_comm->clique_info.clique_rank_num;
}
if (wm_comm->clique_info.clique_rank_num == 0) { wm_comm->clique_info.clique_first_rank = r; }
wm_comm->clique_info.clique_rank_num++;
}
clique_ids.insert(p_rank_info.get()[r].fabric_info.cliqueId);

#endif
}

#if CUDA_VERSION >= 12030
wm_comm->clique_info.clique_num = clique_ids.size();
int id = 0;
for (auto clique_id : clique_ids) {
if (clique_id == ri.fabric_info.cliqueId) { wm_comm->clique_info.clique_id = id; }
id++;
}

wm_comm->support_mnnvl = (comm_support_mnnvl(wm_comm, p_rank_info)) &&
(wm_comm->clique_info.clique_rank_num == wm_comm->world_size);

#endif
}

void negotiate_communicator_id_locked(wholememory_comm_t wm_comm)
Expand Down Expand Up @@ -648,6 +735,70 @@ wholememory_error_code_t create_communicator(wholememory_comm_t* comm,
}
}

/**
*
* Ranks which pass the same color value will be part of the same group; color must be a
* non-negative value. If it is passed as WHOLEMEMORY_SPLIT_NOCOLOR, it means that the rank will not
* be part of any group, therefore returning NULL as newcomm. The value of key will determine the
* rank order, and the smaller key means the smaller rank in new communicator. If keys are equal
* between ranks, then the rank in the original communicator will be used to order ranks.
*
*/

/**
 * Split a WholeMemory communicator via ncclCommSplit.
 * @param new_comm : receives the new communicator, or nullptr when this rank gets no group
 *                   (NCCL returns a NULL communicator for NCCL_SPLIT_NOCOLOR ranks)
 * @param parent_comm : communicator to split; must not be bound to NVSHMEM
 * @param color : group selector — ranks passing the same non-negative color share a group
 * @param key : rank-ordering key inside the new communicator (smaller key -> smaller rank)
 * @return : wholememory_error_code_t (WHOLEMEMORY_SUCCESS even when *new_comm is nullptr)
 */
wholememory_error_code_t split_communicator(wholememory_comm_t* new_comm,
                                            wholememory_comm_t parent_comm,
                                            int color,
                                            int key) noexcept
{
  try {
    std::unique_lock<std::mutex> mlock(comm_mu);

    WHOLEMEMORY_EXPECTS(wholememory_communicator_is_bind_to_nvshmem(parent_comm) == false,
                        "Cannot split a communicator that is already bind to NVSHMEM");

    ncclComm_t nccl_comm = parent_comm->raft_nccl_comm->raw_nccl_comm();
    WHOLEMEMORY_CHECK(nccl_comm != nullptr);
    ncclComm_t new_nccl_comm;
    WHOLEMEMORY_CHECK(ncclCommSplit(nccl_comm, color, key, &new_nccl_comm, NULL) == ncclSuccess);
    if (new_nccl_comm == NULL) {
      *new_comm = nullptr;
      return WHOLEMEMORY_SUCCESS;
    }
    // BUG FIX: create the stream only after the NULL check above; the original created it
    // before, leaking one CUDA stream on every rank excluded from the new communicator.
    cudaStream_t cuda_stream;
    WM_CUDA_CHECK(cudaStreamCreateWithFlags(&cuda_stream, cudaStreamNonBlocking));
    int new_rank;
    int new_size;
    WHOLEMEMORY_CHECK(ncclCommUserRank(new_nccl_comm, &new_rank) == ncclSuccess);
    WHOLEMEMORY_CHECK(ncclCommCount(new_nccl_comm, &new_size) == ncclSuccess);

    // Build the wrapper and run the same initialization sequence as create_communicator.
    auto* wm_comm = new wholememory_comm_(new_nccl_comm, new_size, new_rank, cuda_stream);
    *new_comm    = wm_comm;
    WM_COMM_CHECK_ALL_SAME(wm_comm, WM_COMM_OP_STARTING);

    exchange_rank_info(wm_comm);

    negotiate_communicator_id_locked(wm_comm);

    maybe_create_temp_dir(wm_comm);

    determine_alloc_granularity(wm_comm);

    return WHOLEMEMORY_SUCCESS;
  } catch (const wholememory::cu_error& wce) {
    WHOLEMEMORY_FAIL_NOTHROW("%s", wce.what());
  } catch (const wholememory::cuda_error& wce) {
    WHOLEMEMORY_FAIL_NOTHROW("%s", wce.what());
  } catch (const wholememory::logic_error& wle) {
    WHOLEMEMORY_FAIL_NOTHROW("%s", wle.what());
  } catch (const raft::exception& re) {
    WHOLEMEMORY_FAIL_NOTHROW("%s", re.what());
  } catch (const std::bad_alloc& sba) {
    WHOLEMEMORY_FAIL_NOTHROW("%s", sba.what());
  } catch (...) {
    WHOLEMEMORY_FAIL_NOTHROW("Unknown exception.");
  }
}

void destroy_all_wholememory(wholememory_comm_t comm) noexcept
{
try {
Expand Down Expand Up @@ -740,6 +891,27 @@ wholememory_error_code_t communicator_get_size(int* size, wholememory_comm_t com
return WHOLEMEMORY_SUCCESS;
}

// wholememory_error_code_t communicator_get_clique_rank(int* clique_rank,
// wholememory_comm_t comm) noexcept
// {
// *clique_rank = comm->clique_rank;
// return WHOLEMEMORY_SUCCESS;
// }

// wholememory_error_code_t communicator_get_clique_size(int* clique_size,
// wholememory_comm_t comm) noexcept
// {
// *clique_size = comm->clique_rank_num;
// return WHOLEMEMORY_SUCCESS;
// }

// Copy the clique (MNNVL domain) info recorded on the communicator out to the caller.
wholememory_error_code_t communicator_get_clique_info(clique_info_t* clique_info,
                                                      wholememory_comm_t comm) noexcept
{
  *clique_info = comm->clique_info;
  return WHOLEMEMORY_SUCCESS;
}

bool communicator_is_bind_to_nvshmem(wholememory_comm_t comm) noexcept
{
#ifdef WITH_NVSHMEM_SUPPORT
Expand Down Expand Up @@ -772,6 +944,11 @@ void communicator_barrier(wholememory_comm_t comm)

bool is_intranode_communicator(wholememory_comm_t comm) noexcept { return comm->is_intranode(); }

// Thin wrapper: forwards to wholememory_comm_::is_intra_mnnvl() (mirrors
// is_intranode_communicator above).
bool is_intra_mnnvl_communicator(wholememory_comm_t comm) noexcept
{
  return comm->is_intra_mnnvl();
}

#ifdef WITH_NVSHMEM_SUPPORT
wholememory_error_code_t init_nvshmem_with_comm(wholememory_comm_t comm) noexcept
{
Expand Down
Loading

0 comments on commit ba505af

Please sign in to comment.