From ba505af5c4c8620632c302f8278df33414e76d8c Mon Sep 17 00:00:00 2001 From: Chuang Zhu <111838961+chuangz0@users.noreply.github.com> Date: Tue, 18 Jun 2024 20:55:01 +0800 Subject: [PATCH] Mnnvl with split comm (#185) support split comm and get_local_mnnvl_comm split_comm ``` def split_communicator(comm: WholeMemoryCommunicator, color: int, key: int = 0): """Split Communicator. Creates a set of new communicators from an existing one. Ranks which pass the same color value will be part of the same group; color must be a non-negative value. The value of key will determine the rank order, and the smaller key means the smaller rank in new communicator. If keys are equal between ranks, then the rank in the original communicator will be used to order ranks. """ ``` Authors: - Chuang Zhu (https://github.com/chuangz0) Approvers: - https://github.com/linhu-nv - Brad Rees (https://github.com/BradReesWork) URL: https://github.com/rapidsai/wholegraph/pull/185 --- cpp/CMakeLists.txt | 5 + cpp/include/wholememory/wholememory.h | 33 +++ cpp/src/wholememory/communicator.cpp | 191 +++++++++++++++++- cpp/src/wholememory/communicator.hpp | 16 +- cpp/src/wholememory/memory_handle.cpp | 2 +- cpp/src/wholememory/nccl_comms.cpp | 4 +- cpp/src/wholememory/nccl_comms.hpp | 3 +- cpp/src/wholememory/system_info.cpp | 70 ++++++- cpp/src/wholememory/system_info.hpp | 15 +- cpp/src/wholememory/wholememory.cpp | 19 ++ .../binding/wholememory_binding.pyx | 30 +++ .../pylibwholegraph/torch/__init__.py | 4 +- .../pylibwholegraph/torch/comm.py | 107 ++++++++-- 13 files changed, 456 insertions(+), 43 deletions(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index cea2c0459..9c364b0f6 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -199,6 +199,11 @@ PRIVATE NCCL::NCCL ) +if (CUDAToolkit_VERSION VERSION_GREATER "12.3") + # Link the NVML library if CUDA version is greater than 12.3 + target_link_libraries(wholegraph PRIVATE CUDA::nvml) +endif() + if(BUILD_WITH_NVSHMEM) file(GLOB_RECURSE NVSHMEM_SOURCE_FILES "src/wholememory_ops/functions/nvshmem*.cu") diff --git a/cpp/include/wholememory/wholememory.h b/cpp/include/wholememory/wholememory.h index 08f16213f..f6bacccb3 100644 --- a/cpp/include/wholememory/wholememory.h +++ b/cpp/include/wholememory/wholememory.h @@ -40,6 +40,7 @@ enum wholememory_error_code_t { WHOLEMEMORY_INVALID_VALUE, /*!< input value is invalid */ WHOLEMEMORY_OUT_OF_MEMORY, /*!< out of memory */ WHOLEMEMORY_NOT_SUPPORTED, /*!< not supported */ + WHOLEMEMORY_SYSTEM_ERROR, /*!< system error>*/ }; #define WHOLEMEMORY_RETURN_ON_FAIL(X) \ @@ -90,6 +91,7 @@ enum LogLevel { LEVEL_TRACE /*!< Trace */ }; +#define WHOLEMEMORY_SPILT_NO_COLOR -1 /** * Initialize WholeMemory library * @param flags : reserved should be 0 @@ -111,6 +113,15 @@ wholememory_error_code_t wholememory_finalize(); */ typedef struct wholememory_comm_* wholememory_comm_t; +struct clique_info_t { + int is_in_clique; // is_in_clique >0 means the gpu belongs to a mnnvl domain + int clique_first_rank; + int clique_rank; // the rank of gpu in a mnnvl domain + int clique_rank_num; // the num of gpu in the mnnvl domain + int clique_id; // the id of clique + int clique_num; // the num of clique in the comm domain. +}; + #define WHOLEMEMORY_UNIQUE_ID_BYTES (128) /** * @brief Unique ID for WholeMemory Communicators @@ -142,6 +153,24 @@ wholememory_error_code_t wholememory_create_communicator(wholememory_comm_t* com int rank, int size); +/** + * Split WholeMemory Communicator + * @param new_comm: returned the splited wholeMemory Communicator + * @param comm: WholeMemory Communicator to split + * @param color: color value to split communicator,Ranks which pass the same color value will be + * part of the same group; color must be a non-negative value. If it is passed as + * WHOLEMEMORY_SPLIT_NOCOLOR, it means that the rank will not be part of any group, therefore + * returning NULL as newcomm. + * @param key: key value to split communicator,the value of key will determine the + * rank order, and the smaller key means the smaller rank in new communicator. If keys are equal + * between ranks, then the rank in the original communicator will be used to order ranks. + * @return : wholememory_error_code_t + +*/ +wholememory_error_code_t wholememory_split_communicator(wholememory_comm_t* new_comm, + wholememory_comm_t comm, + int color, + int key); /** * Destroy WholeMemory Communicator * @param comm : WholeMemory Communicator to destroy @@ -177,6 +206,9 @@ wholememory_error_code_t wholememory_communicator_get_rank(int* rank, wholememor */ wholememory_error_code_t wholememory_communicator_get_size(int* size, wholememory_comm_t comm); +wholememory_error_code_t wholememory_communicator_get_clique_info(clique_info_t* clique_info, + wholememory_comm_t comm); + bool wholememory_communicator_is_bind_to_nvshmem(wholememory_comm_t comm); wholememory_error_code_t wholememory_communicator_set_distributed_backend( @@ -393,6 +425,7 @@ wholememory_error_code_t wholememory_store_to_file(wholememory_handle_t wholemem */ bool wholememory_is_intranode_communicator(wholememory_comm_t comm); +bool wholememory_is_intra_mnnvl_communicator(wholememory_comm_t comm); bool wholememory_is_build_with_nvshmem(); #ifdef WITH_NVSHMEM_SUPPORT wholememory_error_code_t wholememory_get_nvshmem_reference( diff --git a/cpp/src/wholememory/communicator.cpp b/cpp/src/wholememory/communicator.cpp index 5e8bf855c..d08fe0804 100644 --- a/cpp/src/wholememory/communicator.cpp +++ b/cpp/src/wholememory/communicator.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ #include "communicator.hpp" #include +#include #include #include @@ -352,18 +353,19 @@ void wholememory_comm_::device_multicast_sendrecv(const void* sendbuf, bool wholememory_comm_::is_intranode() const { return intra_node_rank_num == world_size; } +bool wholememory_comm_::is_intra_mnnvl() const { return support_mnnvl; } bool wholememory_comm_::support_type_location(wholememory_memory_type_t memory_type, wholememory_memory_location_t memory_location) const { if (memory_location == WHOLEMEMORY_ML_HOST) { if (is_intranode() || memory_type == WHOLEMEMORY_MT_DISTRIBUTED) return true; - return SupportMNNVLForEGM(); + return is_intra_mnnvl() && SupportEGM(); } else if (memory_location == WHOLEMEMORY_ML_DEVICE) { if (memory_type == WHOLEMEMORY_MT_DISTRIBUTED) return true; if (is_intranode()) { return DevicesCanAccessP2P(&local_gpu_ids[0], intra_node_rank_num); } else { - return DevicesCanAccessP2P(&local_gpu_ids[0], intra_node_rank_num) && SupportMNNVL(); + return DevicesCanAccessP2P(&local_gpu_ids[0], intra_node_rank_num) && is_intra_mnnvl(); } } else { return false; @@ -422,6 +424,10 @@ struct rank_info { int rank; int size; int gpu_id; +// MNNVL support +#if CUDA_VERSION >= 12030 + nvmlGpuFabricInfo_t fabric_info; +#endif }; static void get_host_name(char* hostname, int maxlen, const char delim) @@ -487,20 +493,72 @@ void get_host_info(host_info* phi) get_shm_devid(&phi->shm_dev); } +bool comm_support_mnnvl(wholememory_comm_t wm_comm, const std::unique_ptr& p_rank_info) +{ +#if CUDA_VERSION >= 12030 + int flag = 0; + CUdevice currentDev; + WM_CU_CHECK_NO_THROW(cuDeviceGet(¤tDev, wm_comm->dev_id)); + // Ignore error if CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED is not supported + WM_CU_CHECK_NO_THROW( + cuDeviceGetAttribute(&flag, CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED, currentDev)); + if (!flag) return false; + + nvmlGpuFabricInfo_t gpuFabricInfo; + WHOLEMEMORY_CHECK_NOTHROW(wholememory::GetGpuFabricInfo(wm_comm->dev_id, &gpuFabricInfo) == + WHOLEMEMORY_SUCCESS); + + if (gpuFabricInfo.state != NVML_GPU_FABRIC_STATE_COMPLETED) { return false; } + + // Check that all ranks have initialized the fabric fully + for (int i = 0; i < wm_comm->world_rank; i++) { + if (p_rank_info.get()[i].fabric_info.state != NVML_GPU_FABRIC_STATE_COMPLETED) return 0; + } + + return GetCudaCompCap() >= 90; +#else + + return 0; +#endif +}; + void exchange_rank_info(wholememory_comm_t wm_comm) { rank_info ri; get_host_info(&ri.rank_host_info); - ri.rank = wm_comm->world_rank; - ri.size = wm_comm->world_size; - ri.pid = getpid(); - ri.gpu_id = wm_comm->dev_id; + ri.rank = wm_comm->world_rank; + ri.size = wm_comm->world_size; + ri.pid = getpid(); + ri.gpu_id = wm_comm->dev_id; + wm_comm->clique_info.is_in_clique = 0; + +#if CUDA_VERSION >= 12030 + memset(&ri.fabric_info, 0, sizeof(ri.fabric_info)); + WHOLEMEMORY_CHECK_NOTHROW(GetGpuFabricInfo(wm_comm->dev_id, &ri.fabric_info) == + WHOLEMEMORY_SUCCESS); + + // // A zero UUID means we don't have MNNVL fabric info + if (((((long*)ri.fabric_info.clusterUuid)[0] | ((long*)ri.fabric_info.clusterUuid)[1]) == 0)) { + wm_comm->clique_info.is_in_clique = 0; + + } else { + wm_comm->clique_info.is_in_clique = 1; + } + +#endif std::unique_ptr p_rank_info(new rank_info[ri.size]); wm_comm->host_allgather(&ri, p_rank_info.get(), sizeof(rank_info), WHOLEMEMORY_DT_INT8); wm_comm->intra_node_first_rank = -1; wm_comm->intra_node_rank_num = 0; wm_comm->intra_node_rank = -1; + + wm_comm->clique_info.clique_first_rank = -1; + wm_comm->clique_info.clique_rank = -1; + wm_comm->clique_info.clique_rank_num = 0; + + std::set clique_ids{}; + for (int r = 0; r < wm_comm->world_size; r++) { WHOLEMEMORY_CHECK(r == p_rank_info.get()[r].rank); if (ri.rank_host_info == p_rank_info.get()[r].rank_host_info) { @@ -512,7 +570,36 @@ void exchange_rank_info(wholememory_comm_t wm_comm) wm_comm->local_gpu_ids[wm_comm->intra_node_rank_num] = p_rank_info.get()[r].gpu_id; wm_comm->intra_node_rank_num++; } + +#if CUDA_VERSION >= 12030 + + if ((memcmp(ri.fabric_info.clusterUuid, + p_rank_info.get()[r].fabric_info.clusterUuid, + NVML_GPU_FABRIC_UUID_LEN) == 0) && + (ri.fabric_info.cliqueId == p_rank_info.get()[r].fabric_info.cliqueId)) { + if (r == wm_comm->world_rank) { + wm_comm->clique_info.clique_rank = wm_comm->clique_info.clique_rank_num; + } + if (wm_comm->clique_info.clique_rank_num == 0) { wm_comm->clique_info.clique_first_rank = r; } + wm_comm->clique_info.clique_rank_num++; + } + clique_ids.insert(p_rank_info.get()[r].fabric_info.cliqueId); + +#endif + } + +#if CUDA_VERSION >= 12030 + wm_comm->clique_info.clique_num = clique_ids.size(); + int id = 0; + for (auto clique_id : clique_ids) { + if (clique_id == ri.fabric_info.cliqueId) { wm_comm->clique_info.clique_id = id; } + id++; } + + wm_comm->support_mnnvl = (comm_support_mnnvl(wm_comm, p_rank_info)) && + (wm_comm->clique_info.clique_rank_num == wm_comm->world_size); + +#endif } void negotiate_communicator_id_locked(wholememory_comm_t wm_comm) @@ -648,6 +735,70 @@ wholememory_error_code_t create_communicator(wholememory_comm_t* comm, } } +/** + * + * Ranks which pass the same color value will be part of the same group; color must be a + * non-negative value. If it is passed as WHOLEMEMORY_SPLIT_NOCOLOR, it means that the rank will not + * be part of any group, therefore returning NULL as newcomm. The value of key will determine the + * rank order, and the smaller key means the smaller rank in new communicator. If keys are equal + * between ranks, then the rank in the original communicator will be used to order ranks. + * + */ + +wholememory_error_code_t split_communicator(wholememory_comm_t* new_comm, + wholememory_comm_t parent_comm, + int color, + int key) noexcept +{ + try { + std::unique_lock mlock(comm_mu); + + WHOLEMEMORY_EXPECTS(wholememory_communicator_is_bind_to_nvshmem(parent_comm) == false, + "Cannot split a communicator that is already bind to NVSHMEM"); + + ncclComm_t nccl_comm = parent_comm->raft_nccl_comm->raw_nccl_comm(); + WHOLEMEMORY_CHECK(nccl_comm != nullptr); + ncclComm_t new_nccl_comm; + WHOLEMEMORY_CHECK(ncclCommSplit(nccl_comm, color, key, &new_nccl_comm, NULL) == ncclSuccess); + cudaStream_t cuda_stream; + WM_CUDA_CHECK(cudaStreamCreateWithFlags(&cuda_stream, cudaStreamNonBlocking)); + if (new_nccl_comm == NULL) { + *new_comm = nullptr; + return WHOLEMEMORY_SUCCESS; + } + int new_rank; + int new_size; + WHOLEMEMORY_CHECK(ncclCommUserRank(new_nccl_comm, &new_rank) == ncclSuccess); + WHOLEMEMORY_CHECK(ncclCommCount(new_nccl_comm, &new_size) == ncclSuccess); + + auto* wm_comm = new wholememory_comm_(new_nccl_comm, new_size, new_rank, cuda_stream); + *new_comm = wm_comm; + WM_COMM_CHECK_ALL_SAME(wm_comm, WM_COMM_OP_STARTING); + + exchange_rank_info(wm_comm); + + negotiate_communicator_id_locked(wm_comm); + + maybe_create_temp_dir(wm_comm); + + determine_alloc_granularity(wm_comm); + + return WHOLEMEMORY_SUCCESS; + } catch (const wholememory::cu_error& wce) { + WHOLEMEMORY_FAIL_NOTHROW("%s", wce.what()); + } catch (const wholememory::cuda_error& wce) { + WHOLEMEMORY_FAIL_NOTHROW("%s", wce.what()); + } catch (const wholememory::logic_error& wle) { + WHOLEMEMORY_FAIL_NOTHROW("%s", wle.what()); + } catch (const raft::exception& re) { + WHOLEMEMORY_FAIL_NOTHROW("%s", re.what()); + } catch (const std::bad_alloc& sba) { + WHOLEMEMORY_FAIL_NOTHROW("%s", sba.what()); + } catch (...) { + WHOLEMEMORY_FAIL_NOTHROW("Unknown exception."); + } +} + void destroy_all_wholememory(wholememory_comm_t comm) noexcept { try { @@ -740,6 +891,27 @@ wholememory_error_code_t communicator_get_size(int* size, wholememory_comm_t com return WHOLEMEMORY_SUCCESS; } +// wholememory_error_code_t communicator_get_clique_rank(int* clique_rank, +// wholememory_comm_t comm) noexcept +// { +// *clique_rank = comm->clique_rank; +// return WHOLEMEMORY_SUCCESS; +// } + +// wholememory_error_code_t communicator_get_clique_size(int* clique_size, +// wholememory_comm_t comm) noexcept +// { +// *clique_size = comm->clique_rank_num; +// return WHOLEMEMORY_SUCCESS; +// } + +wholememory_error_code_t communicator_get_clique_info(clique_info_t* clique_info, + wholememory_comm_t comm) noexcept +{ + *clique_info = comm->clique_info; + return WHOLEMEMORY_SUCCESS; +} + bool communicator_is_bind_to_nvshmem(wholememory_comm_t comm) noexcept { #ifdef WITH_NVSHMEM_SUPPORT @@ -772,6 +944,11 @@ void communicator_barrier(wholememory_comm_t comm) bool is_intranode_communicator(wholememory_comm_t comm) noexcept { return comm->is_intranode(); } +bool is_intra_mnnvl_communicator(wholememory_comm_t comm) noexcept +{ + return comm->is_intra_mnnvl(); +} + #ifdef WITH_NVSHMEM_SUPPORT wholememory_error_code_t init_nvshmem_with_comm(wholememory_comm_t comm) noexcept { diff --git a/cpp/src/wholememory/communicator.hpp b/cpp/src/wholememory/communicator.hpp index 5ed68a9df..b48d66b77 100644 --- a/cpp/src/wholememory/communicator.hpp +++ b/cpp/src/wholememory/communicator.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -192,6 +192,7 @@ struct wholememory_comm_ { bool is_intranode() const; + bool is_intra_mnnvl() const; bool support_type_location(wholememory_memory_type_t memory_type, wholememory_memory_location_t memory_location) const; @@ -212,10 +213,13 @@ struct wholememory_comm_ { int intra_node_rank_num = 0; int intra_node_first_rank_pid = -1; + clique_info_t clique_info; + int comm_id = -1; int dev_id = -1; int local_gpu_ids[16] = {0}; + bool support_mnnvl = false; size_t alloc_granularity = 2 * 1024 * 1024UL; @@ -267,6 +271,11 @@ wholememory_error_code_t create_communicator(wholememory_comm_t* comm, int rank, int size) noexcept; +wholememory_error_code_t split_communicator(wholememory_comm_t* new_comm, + wholememory_comm_t parent_comm, + int color, + int key) noexcept; + wholememory_error_code_t destroy_communicator_locked(wholememory_comm_t comm) noexcept; wholememory_error_code_t destroy_communicator(wholememory_comm_t comm) noexcept; @@ -282,10 +291,15 @@ wholememory_error_code_t communicator_get_rank(int* rank, wholememory_comm_t com wholememory_error_code_t communicator_get_size(int* size, wholememory_comm_t comm) noexcept; +wholememory_error_code_t communicator_get_clique_info(clique_info_t* clique_info, + wholememory_comm_t comm) noexcept; + void communicator_barrier(wholememory_comm_t comm); bool is_intranode_communicator(wholememory_comm_t comm) noexcept; +bool is_intra_mnnvl_communicator(wholememory_comm_t comm) noexcept; + std::string get_temporary_directory_path(wholememory_comm_t comm); std::string get_shm_prefix(wholememory_comm_t comm); diff --git a/cpp/src/wholememory/memory_handle.cpp b/cpp/src/wholememory/memory_handle.cpp index ca8b0ad75..2e55edb62 100644 --- a/cpp/src/wholememory/memory_handle.cpp +++ b/cpp/src/wholememory/memory_handle.cpp @@ -1318,7 +1318,7 @@ class continuous_mnnvl_wholememory_impl : public continuous_device_wholememory_i void check_valid() { if (location_ == WHOLEMEMORY_ML_HOST) { WHOLEMEMORY_CHECK_NOTHROW(SupportEGM()); } - WHOLEMEMORY_CHECK_NOTHROW(SupportMNNVL()); + WHOLEMEMORY_CHECK_NOTHROW(comm_->is_intra_mnnvl()); } void create_memory() override { diff --git a/cpp/src/wholememory/nccl_comms.cpp b/cpp/src/wholememory/nccl_comms.cpp index b06313551..4f6f96806 100644 --- a/cpp/src/wholememory/nccl_comms.cpp +++ b/cpp/src/wholememory/nccl_comms.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -514,4 +514,6 @@ void nccl_comms::group_start() const { RAFT_NCCL_TRY(ncclGroupStart()); } void nccl_comms::group_end() const { RAFT_NCCL_TRY(ncclGroupEnd()); } +ncclComm_t nccl_comms::raw_nccl_comm() const { return nccl_comm_; } + } // namespace wholememory diff --git a/cpp/src/wholememory/nccl_comms.hpp b/cpp/src/wholememory/nccl_comms.hpp index 49babab9d..55eab2437 100644 --- a/cpp/src/wholememory/nccl_comms.hpp +++ b/cpp/src/wholememory/nccl_comms.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -197,6 +197,7 @@ class nccl_comms { void group_start() const; void group_end() const; + ncclComm_t raw_nccl_comm() const; private: ncclComm_t nccl_comm_; diff --git a/cpp/src/wholememory/system_info.cpp b/cpp/src/wholememory/system_info.cpp index c8a35f400..01c124a6f 100644 --- a/cpp/src/wholememory/system_info.cpp +++ b/cpp/src/wholememory/system_info.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,21 @@ #include "cuda_macros.hpp" +#include "logger.hpp" +#include "system_info.hpp" +#include "wholememory/wholememory.h" +#if CUDA_VERSION >= 12030 +#include + +namespace { + +std::mutex lock; // NVML has had some thread safety bugs +bool nvmlInitialized = false; +thread_local bool threadInitialized = false; +wholememory_error_code_t initResult; +}; // namespace + +#endif bool DevAttrPagebleMemoryAccess() { int current_dev_id = -1; @@ -88,16 +103,55 @@ const char* GetCPUArch() return arch_str; } -bool SupportMNNVL() -{ - // TODO: replace with NVML, nvmlDeviceGetGpuFabricInfo - return GetCudaCompCap() >= 90; -} - bool SupportEGM() { std::string const arch_str = GetCPUArch(); return arch_str == "arm64" && DevAttrPagebleMemoryAccess(); } -bool SupportMNNVLForEGM() { return SupportMNNVL() && SupportEGM(); } +// bool SupportMNNVLForEGM() { return SupportMNNVL() && SupportEGM(); } +#if CUDA_VERSION >= 12030 + +namespace wholememory { + +wholememory_error_code_t NvmlEnsureInitialized() +{ + // Optimization to avoid repeatedly grabbing the lock when we only want to + // read from the global tables. + if (threadInitialized) return initResult; + threadInitialized = true; + + std::lock_guard locked(lock); + + if (nvmlInitialized) return initResult; + nvmlInitialized = true; + nvmlReturn_t nvml_res = nvmlInit(); + if (nvml_res != NVML_SUCCESS) { + WHOLEMEMORY_ERROR("nvmlInit() failed, the error is %s", nvmlErrorString(nvml_res)); + initResult = WHOLEMEMORY_SYSTEM_ERROR; + + return initResult; + } + initResult = WHOLEMEMORY_SUCCESS; + + return initResult; +} + +wholememory_error_code_t GetGpuFabricInfo(int dev, nvmlGpuFabricInfo_t* gpuFabricInfo) +{ + WHOLEMEMORY_CHECK_NOTHROW(NvmlEnsureInitialized() == WHOLEMEMORY_SUCCESS); + std::lock_guard locked(lock); + // gpuFabricInfo->version = nvmlGpuFabricInfo_v2; + nvmlDevice_t nvml_device; + nvmlReturn_t ret = nvmlDeviceGetHandleByIndex(dev, &nvml_device); + WHOLEMEMORY_EXPECTS_NOTHROW( + ret == NVML_SUCCESS, "nvmlDeviceGetHandleByIndex error:%s", nvmlErrorString(ret)); + ret = nvmlDeviceGetGpuFabricInfo(nvml_device, gpuFabricInfo); + WHOLEMEMORY_EXPECTS_NOTHROW( + ret == NVML_SUCCESS, "nvmlDeviceGetGpuFabricInfo error:%s", nvmlErrorString(ret)); + + return WHOLEMEMORY_SUCCESS; +} + +}; // namespace wholememory +#endif diff --git a/cpp/src/wholememory/system_info.hpp b/cpp/src/wholememory/system_info.hpp index f62364300..a157924eb 100644 --- a/cpp/src/wholememory/system_info.hpp +++ b/cpp/src/wholememory/system_info.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,6 +15,11 @@ */ #pragma once +#include "wholememory/wholememory.h" + +#if CUDA_VERSION >= 12030 +#include +#endif bool DevAttrPagebleMemoryAccess(); bool DeviceCanAccessPeer(int peer_device); @@ -29,4 +34,10 @@ bool SupportMNNVL(); bool SupportEGM(); -bool SupportMNNVLForEGM(); +// bool SupportMNNVLForEGM(); +#if CUDA_VERSION >= 12030 +namespace wholememory { +wholememory_error_code_t GetGpuFabricInfo(int dev, nvmlGpuFabricInfo_t* gpuFabricInfo); +} + +#endif diff --git a/cpp/src/wholememory/wholememory.cpp b/cpp/src/wholememory/wholememory.cpp index 180da2f01..814e90087 100644 --- a/cpp/src/wholememory/wholememory.cpp +++ b/cpp/src/wholememory/wholememory.cpp @@ -45,6 +45,14 @@ wholememory_error_code_t wholememory_create_communicator(wholememory_comm_t* com return wholememory::create_communicator(comm, unique_id, rank, size); } +wholememory_error_code_t wholememory_split_communicator(wholememory_comm_t* new_comm, + wholememory_comm_t comm, + int color, + int key) +{ + return wholememory::split_communicator(new_comm, comm, color, key); +} + wholememory_error_code_t wholememory_destroy_communicator(wholememory_comm_t comm) { return wholememory::destroy_communicator(comm); @@ -266,6 +274,17 @@ bool wholememory_is_intranode_communicator(wholememory_comm_t comm) return wholememory::is_intranode_communicator(comm); } +bool wholememory_is_intra_mnnvl_communicator(wholememory_comm_t comm) +{ + return wholememory::is_intra_mnnvl_communicator(comm); +} + +wholememory_error_code_t wholememory_communicator_get_clique_info(clique_info_t* clique_info, + wholememory_comm_t comm) +{ + return wholememory::communicator_get_clique_info(clique_info, comm); +} + bool wholememory_is_build_with_nvshmem() { #ifdef WITH_NVSHMEM_SUPPORT diff --git a/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx b/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx index ddf5de544..dc72eb32c 100644 --- a/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx +++ b/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx @@ -185,6 +185,24 @@ cdef extern from "wholememory/wholememory.h": cdef wholememory_distributed_backend_t wholememory_communicator_get_distributed_backend( wholememory_comm_t comm) cdef bool wholememory_is_intranode_communicator(wholememory_comm_t comm) + cdef bool wholememory_is_intra_mnnvl_communicator(wholememory_comm_t comm) + + + cdef struct clique_info_t: + int is_in_clique + int clique_first_rank + int clique_rank + int clique_rank_num + int clique_id + int clique_num + + cdef wholememory_error_code_t wholememory_communicator_get_clique_info(clique_info_t* clique_info, wholememory_comm_t comm) + + + cdef wholememory_error_code_t wholememory_split_communicator(wholememory_comm_t* new_comm, + wholememory_comm_t comm, + int color, + int key) cpdef enum WholeMemoryErrorCode: Success = WHOLEMEMORY_SUCCESS @@ -1267,6 +1285,14 @@ cdef class PyWholeMemoryComm: cdef int world_size = -1 check_wholememory_error_code(wholememory_communicator_get_size(&world_size, self.comm_id)) return world_size + def get_clique_info(self): + cdef clique_info_t clique_info + check_wholememory_error_code(wholememory_communicator_get_clique_info(&clique_info,self.comm_id)) + + cdef bint is_in_clique = clique_info.is_in_clique > 0 + + return is_in_clique,clique_info.clique_first_rank,clique_info.clique_rank,clique_info.clique_rank_num,clique_info.clique_id,clique_info.clique_num + def barrier(self): check_wholememory_error_code(wholememory_communicator_barrier(self.comm_id)) @@ -1628,6 +1654,10 @@ def create_communicator(PyWholeMemoryUniqueID py_uid, int world_rank, int world_ def destroy_communicator(PyWholeMemoryComm py_comm): check_wholememory_error_code(wholememory_destroy_communicator(py_comm.comm_id)) +def split_communicator(PyWholeMemoryComm comm,int color,int key): + py_comm = PyWholeMemoryComm() + check_wholememory_error_code(wholememory_split_communicator(&py_comm.comm_id,comm.comm_id,color,key)) + return py_comm def communicator_set_distributed_backend(PyWholeMemoryComm py_comm,WholeMemoryDistributedBackend distributed_backend): check_wholememory_error_code(wholememory_communicator_set_distributed_backend(py_comm.comm_id,int(distributed_backend))) diff --git a/python/pylibwholegraph/pylibwholegraph/torch/__init__.py b/python/pylibwholegraph/pylibwholegraph/torch/__init__.py index ced391605..873f0d729 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/__init__.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2023, NVIDIA CORPORATION. +# Copyright (c) 2019-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -20,6 +20,8 @@ get_global_communicator, get_local_node_communicator, get_local_device_communicator, + split_communicator, + get_local_mnnvl_communicator, ) from .embedding import ( diff --git a/python/pylibwholegraph/pylibwholegraph/torch/comm.py b/python/pylibwholegraph/pylibwholegraph/torch/comm.py index aa15d3a0a..1f8d9f520 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/comm.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/comm.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2023, NVIDIA CORPORATION. +# Copyright (c) 2019-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,12 +19,13 @@ str_to_wmb_wholememory_distributed_backend_type, wholememory_distributed_backend_type_to_str, str_to_wmb_wholememory_memory_type, - str_to_wmb_wholememory_location + str_to_wmb_wholememory_location, ) global_communicators = {} local_node_communicator = None local_device_communicator = None +local_mnnvl_communicator = None all_comm_world_rank = 0 all_comm_world_size = 1 @@ -34,10 +35,11 @@ def reset_communicators(): global all_comm_world_rank, all_comm_world_size, all_comm_local_rank, all_comm_local_size - global global_communicators, local_node_communicator, local_device_communicator + global global_communicators, local_node_communicator, local_device_communicator, local_mnnvl_communicator global_communicators = {} local_node_communicator = None local_device_communicator = None + local_mnnvl_communicator = None all_comm_world_rank = 0 all_comm_world_size = 1 @@ -82,6 +84,18 @@ def get_size(self): """Get world size of this communicator""" return self.wmb_comm.get_size() + def get_clique_info(self): + """Get info of clique where current process is located, a clique is made up of GPUs in same mnnvl domain. + return: + is_in_clique: is_in_clique >0 means the gpu belongs to a mnnvl domain + clique_first_rank; // the rank in the comm of first gpu in the clique , + clique_rank; // the rank of gpu in a mnnvl domain + clique_rank_num; // the num of gpu in the mnnvl domain + clique_id; // the id of clique + clique_num; // the num of clique in the comm domain. + """ + return self.wmb_comm.get_clique_info() + def barrier(self): """ Barrier on WholeMemory Communicator. @@ -91,9 +105,7 @@ def barrier(self): """ return self.wmb_comm.barrier() - def support_type_location(self, - memory_type: str, - memory_location: str): + def support_type_location(self, memory_type: str, memory_location: str): """ Return True if Communicator supports combination of memory_type and memory_location. """ @@ -107,11 +119,15 @@ def destroy(self): @property def distributed_backend(self): - return wholememory_distributed_backend_type_to_str(self.wmb_comm.get_distributed_backend()) + return wholememory_distributed_backend_type_to_str( + self.wmb_comm.get_distributed_backend() + ) @distributed_backend.setter def distributed_backend(self, value): - self.wmb_comm.set_distributed_backend(str_to_wmb_wholememory_distributed_backend_type(value)) + self.wmb_comm.set_distributed_backend( + str_to_wmb_wholememory_distributed_backend_type(value) + ) def create_group_communicator(group_size: int = -1, comm_stride: int = 1): @@ -152,6 +168,21 @@ def create_group_communicator(group_size: int = -1, comm_stride: int = 1): return WholeMemoryCommunicator(wm_comm) +def split_communicator(comm: WholeMemoryCommunicator, color: int, key: int = 0): + """Split Communicator. + Creates a set of new communicators from an existing one. Ranks which pass the same color value will be part of the + same group; color must be a non-negative value. + The value of key will determine the rank order, and the smaller key means the smaller rank in new communicator. + If keys are equal between ranks, then the rank in the original communicator will be used to order ranks. + """ + if not isinstance(color, int) or not isinstance(key, int): + raise TypeError("color and key must be int") + if color < 0: + return None + new_wm_comm = wmb.split_communicator(comm.wmb_comm, color, key) + return WholeMemoryCommunicator(new_wm_comm) + + def destroy_communicator(wm_comm: WholeMemoryCommunicator): """ Destroy WholeMemoryCommunicator @@ -163,19 +194,24 @@ def destroy_communicator(wm_comm: WholeMemoryCommunicator): wm_comm.wmb_comm = None -def get_global_communicator(distributed_backend='nccl'): +def get_global_communicator(distributed_backend="nccl"): """ Get the global communicator of this job :return: WholeMemoryCommunicator that has all GPUs in it. """ - global global_communicators, local_node_communicator, local_device_communicator + global global_communicators, local_node_communicator, local_device_communicator, local_mnnvl_communicator global all_comm_local_size, all_comm_world_size if distributed_backend not in global_communicators: global_communicator = create_group_communicator() comm_set_distributed_backend(global_communicator, distributed_backend) global_communicators[distributed_backend] = global_communicator - if distributed_backend == 'nccl': # local_node/device_communicator can only be nccl backend for now - if local_node_communicator is None and all_comm_local_size == all_comm_world_size: + if ( + distributed_backend == "nccl" + ): # local_node/device_communicator can only be nccl backend for now + if ( + local_node_communicator is None + and all_comm_local_size == all_comm_world_size + ): local_node_communicator = global_communicator if local_device_communicator is None and all_comm_world_size == 1: local_device_communicator = global_communicator @@ -187,13 +223,13 @@ def get_local_node_communicator(): Get the local node communicator of this job :return: WholeMemoryCommunicator that has GPUs in the same node. """ - global global_communicators, local_node_communicator, local_device_communicator + global global_communicators, local_node_communicator, local_device_communicator, local_mnnvl_communicator global all_comm_local_size, all_comm_world_size if local_node_communicator is None: local_node_communicator = create_group_communicator(all_comm_local_size) if all_comm_local_size == all_comm_world_size: - assert 'nccl' not in global_communicators - global_communicators['nccl'] = local_node_communicator + assert "nccl" not in global_communicators + global_communicators["nccl"] = local_node_communicator if all_comm_local_size == 1: assert local_device_communicator is None local_device_communicator = local_node_communicator @@ -205,7 +241,7 @@ def get_local_device_communicator(): Get the local device communicator of this job :return: WholeMemoryCommunicator that has only the GPU belonging to current process. """ - global global_communicators, local_node_communicator, local_device_communicator + global global_communicators, local_node_communicator, local_device_communicator, local_mnnvl_communicator global all_comm_local_size, all_comm_world_size if local_device_communicator is None: local_device_communicator = create_group_communicator(1) @@ -213,13 +249,42 @@ def get_local_device_communicator(): assert local_node_communicator is None local_node_communicator = local_device_communicator if all_comm_world_size == 1: - assert 'nccl' not in global_communicators - global_communicators['nccl'] = local_device_communicator + assert "nccl" not in global_communicators + global_communicators["nccl"] = local_device_communicator return local_device_communicator -def comm_set_distributed_backend(wm_comm: WholeMemoryCommunicator, distributed_backend: str): +def get_local_mnnvl_communicator(): + """ """ + global global_communicators, local_node_communicator, local_device_communicator, local_mnnvl_communicator + global all_comm_local_size, all_comm_world_size - wmb.communicator_set_distributed_backend(wm_comm.wmb_comm, - str_to_wmb_wholememory_distributed_backend_type(distributed_backend)) + if local_mnnvl_communicator is None: + g_communicator = get_global_communicator() + ( + is_in_clique, + _, + _, + _, + clique_id, + _, + ) = g_communicator.get_clique_info() + if not is_in_clique: + raise RuntimeError( + "the gpu does not belong to any mnnvl domain,can not create local_mnnvl_communicator" + ) + + local_mnnvl_communicator = split_communicator(g_communicator, clique_id) + + return local_mnnvl_communicator + + +def comm_set_distributed_backend( + wm_comm: WholeMemoryCommunicator, distributed_backend: str +): + + wmb.communicator_set_distributed_backend( + wm_comm.wmb_comm, + str_to_wmb_wholememory_distributed_backend_type(distributed_backend), + ) return