From 16870d68424fe02eb1128d96b13699bc13b306d3 Mon Sep 17 00:00:00 2001 From: panzezhong Date: Mon, 19 Feb 2024 15:10:23 +0800 Subject: [PATCH] =?UTF-8?q?feat=20(dist):=20nccl=E9=80=9A=E4=BF=A1?= =?UTF-8?q?=E5=BA=93=E6=8E=A5=E5=85=A5=EF=BC=8Callreduce=E7=AE=97=E5=AD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/04kernel/CMakeLists.txt | 3 + src/04kernel/cmake/FindNCCL.cmake | 165 ++++++++++++++++++ .../cuda/include/kernel/cuda/functions.cuh | 2 + src/04kernel/cuda/src/functions.cu | 4 + .../include/kernel/attributes/communication.h | 14 ++ .../include/kernel/collectors/all_reduce.h | 21 +++ src/04kernel/src/collectors/all_reduce.cc | 20 +++ .../src/kernels/all_reduce/nccl_kernel.cc | 32 ++++ .../src/kernels/all_reduce/nccl_kernel.cu | 20 +++ .../src/kernels/all_reduce/nccl_kernel.hh | 28 +++ .../src/utilities/cuda/nccl_communicator.cu | 61 +++++++ .../src/utilities/cuda/nccl_communicator.hh | 84 +++++++++ .../all_reduce/test_allreduce_nccl.cpp | 62 +++++++ .../computation/operators/all_reduce.h | 23 +++ src/05computation/src/operators/all_reduce.cc | 59 +++++++ .../src/operators/all_reduce.cc | 70 +++++++- .../src/operators/all_reduce.hh | 7 +- 17 files changed, 666 insertions(+), 9 deletions(-) create mode 100644 src/04kernel/cmake/FindNCCL.cmake create mode 100644 src/04kernel/include/kernel/attributes/communication.h create mode 100644 src/04kernel/include/kernel/collectors/all_reduce.h create mode 100644 src/04kernel/src/collectors/all_reduce.cc create mode 100644 src/04kernel/src/kernels/all_reduce/nccl_kernel.cc create mode 100644 src/04kernel/src/kernels/all_reduce/nccl_kernel.cu create mode 100644 src/04kernel/src/kernels/all_reduce/nccl_kernel.hh create mode 100644 src/04kernel/src/utilities/cuda/nccl_communicator.cu create mode 100644 src/04kernel/src/utilities/cuda/nccl_communicator.hh create mode 100644 src/04kernel/test/kernels/all_reduce/test_allreduce_nccl.cpp create mode 100644 src/05computation/include/computation/operators/all_reduce.h create mode 100644 src/05computation/src/operators/all_reduce.cc diff --git a/src/04kernel/CMakeLists.txt b/src/04kernel/CMakeLists.txt index b75fff174..77b655c0e 100644 --- a/src/04kernel/CMakeLists.txt +++ b/src/04kernel/CMakeLists.txt @@ -28,6 +28,9 @@ if(USE_CUDA) # cudnn for conv and others target_link_libraries(kernel PUBLIC cuda nvrtc cublas cublasLt cudnn kernel_cuda) target_include_directories(kernel PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) + list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake) + find_package(NCCL REQUIRED) + target_link_libraries(kernel PUBLIC nccl) endif() if(USE_KUNLUN) include_directories(${KUNLUN_HOME}/XTDK/include/) diff --git a/src/04kernel/cmake/FindNCCL.cmake b/src/04kernel/cmake/FindNCCL.cmake new file mode 100644 index 000000000..d2f2f8358 --- /dev/null +++ b/src/04kernel/cmake/FindNCCL.cmake @@ -0,0 +1,165 @@ +# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved. 
+# +# From PyTorch: +# +# Copyright (c) 2016- Facebook, Inc (Adam Paszke) +# Copyright (c) 2014- Facebook, Inc (Soumith Chintala) +# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) +# Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) +# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) +# Copyright (c) 2011-2013 NYU (Clement Farabet) +# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) +# Copyright (c) 2006 Idiap Research Institute (Samy Bengio) +# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) +# +# From Caffe2: +# +# Copyright (c) 2016-present, Facebook Inc. All rights reserved. +# +# All contributions by Facebook: +# Copyright (c) 2016 Facebook Inc. +# +# All contributions by Google: +# Copyright (c) 2015 Google Inc. +# All rights reserved. +# +# All contributions by Yangqing Jia: +# Copyright (c) 2015 Yangqing Jia +# All rights reserved. +# +# All contributions by Kakao Brain: +# Copyright 2019-2020 Kakao Brain +# +# All contributions from Caffe: +# Copyright(c) 2013, 2014, 2015, the respective contributors +# All rights reserved. +# +# All other contributions: +# Copyright(c) 2015, 2016 the respective contributors +# All rights reserved. +# +# Caffe2 uses a copyright model similar to Caffe: each contributor holds +# copyright over their contributions to Caffe2. The project versioning records +# all such contribution and copyright details. If a contributor wants to further +# mark their specific copyright on a particular contribution, they should +# indicate their copyright solely in the commit message of the change when it is +# committed. +# +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America +# and IDIAP Research Institute nor the names of its contributors may be +# used to endorse or promote products derived from this software without +# specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. 
+#
+# Find the nccl libraries
+#
+# The following variables are optionally searched for defaults
+#  NCCL_ROOT: Base directory where all NCCL components are found
+#  NCCL_INCLUDE_DIR: Directory where NCCL header is found
+#  NCCL_LIB_DIR: Directory where NCCL library is found
+#
+# The following are set after configuration is done:
+#  NCCL_FOUND
+#  NCCL_INCLUDE_DIRS
+#  NCCL_LIBRARIES
+#
+# The path hints include CUDA_TOOLKIT_ROOT_DIR seeing as some folks
+# install NCCL in the same location as the CUDA toolkit.
+# See https://github.com/caffe2/caffe2/issues/1601
+
+set(NCCL_INCLUDE_DIR $ENV{NCCL_INCLUDE_DIR} CACHE PATH "Folder contains NVIDIA NCCL headers")
+set(NCCL_LIB_DIR $ENV{NCCL_LIB_DIR} CACHE PATH "Folder contains NVIDIA NCCL libraries")
+set(NCCL_VERSION $ENV{NCCL_VERSION} CACHE STRING "Version of NCCL to build with")
+
+if ($ENV{NCCL_ROOT_DIR})
+  message(WARNING "NCCL_ROOT_DIR is deprecated. Please set NCCL_ROOT instead.")
+endif()
+list(APPEND NCCL_ROOT $ENV{NCCL_ROOT_DIR} ${CUDA_TOOLKIT_ROOT_DIR})
+# Compatible layer for CMake <3.12. NCCL_ROOT will be accounted in for searching paths and libraries for CMake >=3.12.
+list(APPEND CMAKE_PREFIX_PATH ${NCCL_ROOT})
+
+find_path(NCCL_INCLUDE_DIRS
+  NAMES nccl.h
+  HINTS ${NCCL_INCLUDE_DIR})
+
+if (USE_STATIC_NCCL)
+  MESSAGE(STATUS "USE_STATIC_NCCL is set. Linking with static NCCL library.")
+  SET(NCCL_LIBNAME "nccl_static")
+  if (NCCL_VERSION)  # Prefer the versioned library if a specific NCCL version is specified
+    set(CMAKE_FIND_LIBRARY_SUFFIXES ".a.${NCCL_VERSION}" ${CMAKE_FIND_LIBRARY_SUFFIXES})
+  endif()
+else()
+  SET(NCCL_LIBNAME "nccl")
+  if (NCCL_VERSION)  # Prefer the versioned library if a specific NCCL version is specified
+    set(CMAKE_FIND_LIBRARY_SUFFIXES ".so.${NCCL_VERSION}" ${CMAKE_FIND_LIBRARY_SUFFIXES})
+  endif()
+endif()
+
+find_library(NCCL_LIBRARIES
+  NAMES ${NCCL_LIBNAME}
+  HINTS ${NCCL_LIB_DIR})
+
+include(FindPackageHandleStandardArgs)
+find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
+
+if(NCCL_FOUND)  # obtaining NCCL version and some sanity checks
+  set (NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h")
+  message (STATUS "Determining NCCL version from ${NCCL_HEADER_FILE}...")
+  set (OLD_CMAKE_REQUIRED_INCLUDES ${CMAKE_REQUIRED_INCLUDES})
+  list (APPEND CMAKE_REQUIRED_INCLUDES ${NCCL_INCLUDE_DIRS})
+  include(CheckCXXSymbolExists)
+  check_cxx_symbol_exists(NCCL_VERSION_CODE nccl.h NCCL_VERSION_DEFINED)
+
+  if (NCCL_VERSION_DEFINED)
+    set(file "${PROJECT_BINARY_DIR}/detect_nccl_version.cc")
+    file(WRITE ${file} "
+      #include <iostream>
+      #include <nccl.h>
+      int main()
+      {
+        std::cout << NCCL_MAJOR << '.' << NCCL_MINOR << '.' << NCCL_PATCH << std::endl;
+        int x;
+        ncclGetVersion(&x);
+        return x == NCCL_VERSION_CODE;
+      }
+")
+    try_run(NCCL_VERSION_MATCHED compile_result ${PROJECT_BINARY_DIR} ${file}
+      RUN_OUTPUT_VARIABLE NCCL_VERSION_FROM_HEADER
+      CMAKE_FLAGS  "-DINCLUDE_DIRECTORIES=${NCCL_INCLUDE_DIRS}"
+      LINK_LIBRARIES ${NCCL_LIBRARIES})
+    if (NOT NCCL_VERSION_MATCHED)
+      message(FATAL_ERROR "Found NCCL header version and library version do not match!
\ +(include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES}) Please set NCCL_INCLUDE_DIR and NCCL_LIB_DIR manually.") + endif() + message(STATUS "NCCL version: ${NCCL_VERSION_FROM_HEADER}") + else() + # message(STATUS "NCCL version < 2.3.5-5") + endif () + set (CMAKE_REQUIRED_INCLUDES ${OLD_CMAKE_REQUIRED_INCLUDES}) + + message(STATUS "Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})") + mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES) +endif() diff --git a/src/04kernel/cuda/include/kernel/cuda/functions.cuh b/src/04kernel/cuda/include/kernel/cuda/functions.cuh index 1917e7088..23fcf8f70 100644 --- a/src/04kernel/cuda/include/kernel/cuda/functions.cuh +++ b/src/04kernel/cuda/include/kernel/cuda/functions.cuh @@ -6,6 +6,8 @@ namespace refactor::kernel::cuda { int currentDevice(); void sync(); + + void setCudaDevice(int); void copyOut(void *dst, const void *src, size_t size); diff --git a/src/04kernel/cuda/src/functions.cu b/src/04kernel/cuda/src/functions.cu index 0fa84f175..adf98a98d 100644 --- a/src/04kernel/cuda/src/functions.cu +++ b/src/04kernel/cuda/src/functions.cu @@ -19,4 +19,8 @@ namespace refactor::kernel::cuda { CUDA_ASSERT(cudaMemcpy(dst, src, size, cudaMemcpyDeviceToHost)); } + void setCudaDevice(int id) { + cudaSetDevice(id); + } + }// namespace refactor::kernel::cuda diff --git a/src/04kernel/include/kernel/attributes/communication.h b/src/04kernel/include/kernel/attributes/communication.h new file mode 100644 index 000000000..7cfc46513 --- /dev/null +++ b/src/04kernel/include/kernel/attributes/communication.h @@ -0,0 +1,14 @@ +#ifndef KERNEL_COMMUNICATION_ATTRIBUTES_H +#define KERNEL_COMMUNICATION_ATTRIBUTES_H + +namespace refactor::kernel { + enum class AllReduceType { + Sum, + Avg, + Min, + Max, + Prod + }; +} + +#endif diff --git a/src/04kernel/include/kernel/collectors/all_reduce.h b/src/04kernel/include/kernel/collectors/all_reduce.h new file mode 100644 index 000000000..7245b2b43 --- /dev/null +++ b/src/04kernel/include/kernel/collectors/all_reduce.h @@ -0,0 +1,21 @@ +#ifndef KERNEL_COLLECTOR_ALL_REDUCE_H +#define KERNEL_COLLECTOR_ALL_REDUCE_H + +#include "../collector.h" +#include "kernel/attributes/communication.h" + +namespace refactor::kernel { + + struct AllReduceCollector final : public InfoCollector { + + AllReduceType type; + + constexpr AllReduceCollector(decltype(_target) target, AllReduceType type_) noexcept + : InfoCollector(target), type(type_) {} + + std::vector + filter(TensorRefs inputs, TensorRefs outputs) const final; + }; +}// namespace refactor::kernel + +#endif diff --git a/src/04kernel/src/collectors/all_reduce.cc b/src/04kernel/src/collectors/all_reduce.cc new file mode 100644 index 000000000..72aa9bb93 --- /dev/null +++ b/src/04kernel/src/collectors/all_reduce.cc @@ -0,0 +1,20 @@ +#include "kernel/collectors/all_reduce.h" +#include "../kernels/all_reduce/nccl_kernel.hh" +namespace refactor::kernel { + std::vector + AllReduceCollector::filter(TensorRefs inputs, TensorRefs outputs) const { + std::vector ans; + switch (_target) { + case decltype(_target)::Cpu: + break; + case decltype(_target)::Nvidia: + if (auto ptr = AllReduceNccl::build(type, inputs[0], outputs[0]); ptr) { + ans.emplace_back(std::move(ptr)); + } + break; + default: + UNREACHABLEX(void, "Unknown target"); + } + return ans; + } +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/all_reduce/nccl_kernel.cc b/src/04kernel/src/kernels/all_reduce/nccl_kernel.cc new file mode 100644 index 000000000..6308043f3 --- /dev/null +++ 
b/src/04kernel/src/kernels/all_reduce/nccl_kernel.cc @@ -0,0 +1,32 @@ +#include "nccl_kernel.hh" + +namespace refactor::kernel { + using K = AllReduceNccl; + using DT = DataType; + + K::AllReduceNccl(AllReduceType opType_, DT dataType_, size_t size_) noexcept + : opType(opType_), dataType(dataType_), size(size_) {} + + auto K::build(AllReduceType opType_, Tensor const &input, Tensor const &output) noexcept -> KernelBox { +#ifndef USE_CUDA + return nullptr; +#endif + if (input.elementsSize() != output.elementsSize() || + input.dataType != output.dataType) { + return nullptr; + } + + return std::make_unique(opType_, input.dataType, input.elementsSize()); + } + + auto K::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + + auto K::kernelTypeId() const noexcept -> size_t { return typeId(); } + auto K::description() const noexcept -> std::string_view { + return "Performing AllReduce using NCCL"; + } + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/all_reduce/nccl_kernel.cu b/src/04kernel/src/kernels/all_reduce/nccl_kernel.cu new file mode 100644 index 000000000..a58130338 --- /dev/null +++ b/src/04kernel/src/kernels/all_reduce/nccl_kernel.cu @@ -0,0 +1,20 @@ +#include "nccl_kernel.hh" +#include "../../utilities/cuda/nccl_communicator.hh" +#include +namespace refactor::kernel { + using K = AllReduceNccl; + using DT = DataType; + using namespace nccl; + + auto K::lower(Resources &res) const noexcept -> RoutineWorkspace{ + return [count = size, + redOp = getRedOp(opType), + ncclDataType = getNcclDataType(dataType)](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) { + auto communicator = res.fetch(); + auto input = inputs[0]; + auto output = outputs[0]; + checkNcclError(ncclAllReduce(input, output, count, ncclDataType, + redOp, communicator->get(), 0));// TODO: use default stream for now + }; + } +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/all_reduce/nccl_kernel.hh b/src/04kernel/src/kernels/all_reduce/nccl_kernel.hh new file mode 100644 index 000000000..d25f2325b --- /dev/null +++ b/src/04kernel/src/kernels/all_reduce/nccl_kernel.hh @@ -0,0 +1,28 @@ +#ifndef KERNEL_ALLREDUCE_NCCL_KERNEL_HH +#define KERNEL_ALLREDUCE_NCCL_KERNEL_HH + +#include "kernel/collectors/all_reduce.h" +#include "kernel/tensor.h" + +namespace refactor::kernel { + + struct AllReduceNccl final : public Kernel { + AllReduceType opType; + DataType dataType; + size_t size; + + AllReduceNccl(AllReduceType, DataType, size_t) noexcept; + + static KernelBox build(AllReduceType, Tensor const &, Tensor const &) noexcept; + static size_t typeId() noexcept; + + size_t kernelTypeId() const noexcept final; + std::string_view description() const noexcept final; +#ifdef USE_CUDA + RoutineWorkspace lower(Resources &) const noexcept final; +#endif + }; + +}// namespace refactor::kernel + +#endif// KERNEL_ALLREDUCE_NCCL_KERNEL_HH diff --git a/src/04kernel/src/utilities/cuda/nccl_communicator.cu b/src/04kernel/src/utilities/cuda/nccl_communicator.cu new file mode 100644 index 000000000..fb91b55f6 --- /dev/null +++ b/src/04kernel/src/utilities/cuda/nccl_communicator.cu @@ -0,0 +1,61 @@ +#include "common.h" +#include "nccl_communicator.hh" +#include +#include +#include +#include +#include + + +namespace refactor::kernel::nccl { + NcclCommunicator::NcclCommunicator(int worldSize, int rank) : worldSize_(worldSize), rank_(rank) { + const std::string filePath("./nccl_id.bin"); + + ncclUniqueId commId; + + if (rank == 0) { + 
checkNcclError(ncclGetUniqueId(&commId)); + std::ofstream ofs(filePath, std::ios::binary); + ofs.write((char *) &commId, sizeof(ncclUniqueId)); + + } else { + auto begin = std::chrono::steady_clock::now(); + while (!std::filesystem::exists(filePath)) { + auto now = std::chrono::steady_clock::now(); + ASSERT(now < begin + std::chrono::seconds(10), + "time limit (10s) exceeded."); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + std::ifstream ifs(filePath, std::ios::binary); + ifs.read((char *) &commId, sizeof(ncclUniqueId)); + } + checkNcclError(ncclCommInitRank(&comm, worldSize, commId, rank)); + + if (rank == 0) { + std::filesystem::remove(filePath); + } + + printf("Rank %d established NCCL communicator.\n", rank); + } + + NcclCommunicator::~NcclCommunicator() { + checkNcclError(ncclCommFinalize(comm)); + checkNcclError(ncclCommDestroy(comm)); + } + + auto NcclCommunicator::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + auto NcclCommunicator::build(int worldSize, int rank) noexcept -> runtime::ResourceBox { + return std::make_unique(worldSize, rank); + } + + auto NcclCommunicator::resourceTypeId() const noexcept -> size_t { + return typeId(); + } + auto NcclCommunicator::description() const noexcept -> std::string_view { + return "NcclCommunicator"; + } + +}// namespace refactor::kernel::nccl diff --git a/src/04kernel/src/utilities/cuda/nccl_communicator.hh b/src/04kernel/src/utilities/cuda/nccl_communicator.hh new file mode 100644 index 000000000..64389300a --- /dev/null +++ b/src/04kernel/src/utilities/cuda/nccl_communicator.hh @@ -0,0 +1,84 @@ +#ifndef NCCL_COMMUNICATOR_HH +#define NCCL_COMMUNICATOR_HH + +#include "kernel/attributes/communication.h" +#include "runtime/resource.h" +#include + +#define checkNcclError(call) \ + { \ + auto err = call; \ + if (ncclSuccess != err) { \ + fprintf(stderr, "NCCL error in %s:%i : %s.\n", __FILE__, __LINE__, \ + ncclGetErrorString(err)); \ + exit(EXIT_FAILURE); \ + } \ + } + + +namespace refactor::kernel::nccl { + + inline ncclRedOp_t getRedOp(kernel::AllReduceType t) { + switch (t) { + case kernel::AllReduceType::Sum: + return ncclSum; + case kernel::AllReduceType::Avg: + return ncclAvg; + case kernel::AllReduceType::Min: + return ncclMin; + case kernel::AllReduceType::Max: + return ncclMax; + case kernel::AllReduceType::Prod: + return ncclProd; + default: + return ncclSum; + } + } + + inline ncclDataType_t getNcclDataType(DataType dataType) { + switch (dataType) { + case DataType::F32: + return ncclFloat32; + case DataType::U8: + return ncclUint8; + case DataType::I8: + return ncclInt8; + case DataType::I32: + return ncclInt32; + case DataType::I64: + return ncclInt64; + case DataType::FP16: + return ncclFloat16; + case DataType::F64: + return ncclFloat64; + case DataType::U32: + return ncclUint32; + case DataType::BF16: + return ncclBfloat16; + default: + RUNTIME_ERROR("Datatype not supported by NCCL."); + } + } + + + class NcclCommunicator final : public runtime::Resource { + private: + ncclComm_t comm; + int const worldSize_, rank_; + + public: + NcclCommunicator(int worldSize, int rank); + ~NcclCommunicator(); + ncclComm_t get() { return comm; } + int getWorldSize() { return worldSize_; } + int getRank() { return rank_; } + static size_t typeId() noexcept; + static runtime::ResourceBox build(int worldSize, int rank) noexcept; + + size_t resourceTypeId() const noexcept final; + std::string_view description() const noexcept final; + }; + +}// namespace refactor::kernel::nccl + 
+#endif diff --git a/src/04kernel/test/kernels/all_reduce/test_allreduce_nccl.cpp b/src/04kernel/test/kernels/all_reduce/test_allreduce_nccl.cpp new file mode 100644 index 000000000..63c166089 --- /dev/null +++ b/src/04kernel/test/kernels/all_reduce/test_allreduce_nccl.cpp @@ -0,0 +1,62 @@ +#ifdef USE_CUDA + +#include "../src/kernels/all_reduce/nccl_kernel.hh" +#include "../src/utilities/cuda/nccl_communicator.hh" +#include "kernel/cuda/functions.cuh" +#include "hardware/device_manager.h" +#include +#include + +using namespace refactor; +using namespace kernel; +using namespace nccl; +using namespace cuda; +using namespace hardware; + +void allReduce(AllReduceType redType, int rank, int worldSize, std::vector data, std::vector ans) { + cuda::setCudaDevice(rank); + auto &dev = *device::init(Device::Type::Nvidia, rank, ""); + auto input = Tensor::share(DataType::F32, Shape{2}, LayoutType::NCHW); + auto output = Tensor::share(DataType::F32, Shape{2}, LayoutType::NCHW); + auto kernel = AllReduceNccl::build(redType, *input, *output); + ASSERT_TRUE(kernel); + auto res = runtime::Resources(); + res.fetchOrStore(worldSize, rank); + auto routine = kernel->lower(res).routine; + + + auto inGPU = dev.malloc(input->bytesSize()); + auto outGPU = dev.malloc(output->bytesSize()); + inGPU->copyFromHost(data.data(), input->bytesSize()); + + void const *inputs[]{*inGPU}; + void *outputs[]{*outGPU}; + routine(res, nullptr, inputs, outputs); + + std::vector result(output->elementsSize()); + outGPU->copyToHost(result.data(), output->bytesSize()); + for (auto i : range0_(result.size())) { + EXPECT_FLOAT_EQ(ans[i], result[i]); + } +} + + +TEST(kernel, NCCL_AllReduceSum) { + std::vector data[2] = {{2., 3.}, {5., 6.}}; + std::vector ans = {7., 9.}; + int worldSize = 2; + + std::vector threads; + for (int gpu = 0; gpu < worldSize; ++gpu) { + threads.emplace_back(allReduce, AllReduceType::Sum, + gpu, worldSize, data[gpu], ans); + } + for (auto &thread : threads) { + thread.join(); + } + + // Reset device context for following tests + auto &dev = *device::init(Device::Type::Nvidia, 0, ""); + dev.setContext(); +} +#endif diff --git a/src/05computation/include/computation/operators/all_reduce.h b/src/05computation/include/computation/operators/all_reduce.h new file mode 100644 index 000000000..47a7fce9d --- /dev/null +++ b/src/05computation/include/computation/operators/all_reduce.h @@ -0,0 +1,23 @@ +#ifndef COMPUTATION_ALL_REDUCE_H +#define COMPUTATION_ALL_REDUCE_H + +#include "../operator.h" +#include "kernel/attributes/communication.h" + +namespace refactor::computation { + + struct AllReduce final : public Operator { + kernel::AllReduceType type; + + constexpr explicit AllReduce(kernel::AllReduceType type_) noexcept + : Operator(), type(type_) {} + + static size_t typeId(kernel::AllReduceType) noexcept; + size_t opTypeId() const noexcept final; + std::string_view name() const noexcept final; + kernel::CollectorBox candidateKernels(Target) const final; + }; + +}// namespace refactor::computation + +#endif// COMPUTATION_ALL_REDUCE_H diff --git a/src/05computation/src/operators/all_reduce.cc b/src/05computation/src/operators/all_reduce.cc new file mode 100644 index 000000000..c5658c1c7 --- /dev/null +++ b/src/05computation/src/operators/all_reduce.cc @@ -0,0 +1,59 @@ +#include "computation/operators/all_reduce.h" +#include "kernel/collectors/all_reduce.h" + +namespace refactor::computation { + using Op = AllReduce; + using Ty = kernel::AllReduceType; + + auto Op::typeId(Ty type_) noexcept -> size_t { + switch (type_) { 
+ case Ty::Sum: { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + case Ty::Avg: { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + case Ty::Min: { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + case Ty::Max: { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + case Ty::Prod: { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + default: + UNREACHABLE(); + } + } + + auto Op::opTypeId() const noexcept -> size_t { return typeId(type); } + + auto Op::name() const noexcept -> std::string_view { + switch (type) { + case Ty::Sum: + return "AllReduceSum"; + case Ty::Avg: + return "AllReduceAvg"; + case Ty::Min: + return "AllReduceMin"; + case Ty::Max: + return "AllReduceMax"; + case Ty::Prod: + return "AllReduceProd"; + default: + UNREACHABLE(); + } + } + + auto Op::candidateKernels(Target target) const -> kernel::CollectorBox { + using Collector_ = kernel::AllReduceCollector; + return std::make_unique(target, type); + } + +}// namespace refactor::computation diff --git a/src/08communication/src/operators/all_reduce.cc b/src/08communication/src/operators/all_reduce.cc index 989ed6ad9..98eb339bf 100644 --- a/src/08communication/src/operators/all_reduce.cc +++ b/src/08communication/src/operators/all_reduce.cc @@ -1,19 +1,70 @@ #include "all_reduce.hh" #include "common.h" +#include "computation/operators/all_reduce.h" namespace refactor::communication { using Op = AllReduce; + using Ty = kernel::AllReduceType; - auto Op::build(ModelContext const &, std::string_view, Attributes) -> OpBox { - return OpBox(std::make_unique()); + Op::AllReduce(Ty type_) : Operator(), type(type_) {} + + auto Op::build(ModelContext const &, std::string_view opType, Attributes attributes) -> OpBox { + // clang-format off + auto type_ = + opType == "onnx::AllReduceAvg" ? Ty::Avg : + opType == "onnx::AllReduceSum" ? Ty::Sum : + opType == "onnx::AllReduceMin" ? Ty::Min : + opType == "onnx::AllReduceMax" ? Ty::Max : + opType == "onnx::AllReduceProd" ? 
Ty::Prod : + UNREACHABLEX(Ty, "Unsupported allReduce operator: {}", opType); + // clang-format on + return OpBox(std::make_unique(type_)); } - auto Op::typeId() -> size_t { - static uint8_t ID = 1; - return reinterpret_cast(&ID); + + auto Op::typeId(Ty type_) -> size_t { + switch (type_) { + case Ty::Sum: { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + case Ty::Avg: { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + case Ty::Min: { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + case Ty::Max: { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + case Ty::Prod: { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + default: + UNREACHABLE(); + } } - auto Op::opTypeId() const -> size_t { return typeId(); } - auto Op::opTypeName() const -> std::string_view { return "AllReduce"; } + auto Op::opTypeId() const -> size_t { return typeId(type); } + auto Op::opTypeName() const -> std::string_view { + switch (type) { + case Ty::Sum: + return "AllReduceSum"; + case Ty::Avg: + return "AllReduceAvg"; + case Ty::Min: + return "AllReduceMin"; + case Ty::Max: + return "AllReduceMax"; + case Ty::Prod: + return "AllReduceProd"; + default: + UNREACHABLE(); + } + } auto Op::infer(TensorRefs inputs, InferOptions const &) const -> InferResult { EXPECT_SIZE(1) @@ -23,4 +74,9 @@ namespace refactor::communication { extractDependency(inputs))}); } + computation::OpBox Op::lower(TensorRefs inputs) const { + + return std::make_unique(type); + } + }// namespace refactor::communication diff --git a/src/08communication/src/operators/all_reduce.hh b/src/08communication/src/operators/all_reduce.hh index a14bc74f4..d7701add7 100644 --- a/src/08communication/src/operators/all_reduce.hh +++ b/src/08communication/src/operators/all_reduce.hh @@ -2,20 +2,23 @@ #define COMMUNICATION_ALL_REDUCE_HH #include "frontend/operator.h" +#include "kernel/attributes/communication.h" namespace refactor::communication { using namespace frontend; struct AllReduce final : public Operator { + kernel::AllReduceType type; - constexpr AllReduce() noexcept = default; + AllReduce(kernel::AllReduceType); static OpBox build(ModelContext const &, std::string_view, Attributes); - static size_t typeId(); + static size_t typeId(kernel::AllReduceType); size_t opTypeId() const final; std::string_view opTypeName() const final; InferResult infer(TensorRefs, InferOptions const &) const final; + computation::OpBox lower(TensorRefs) const final; }; }// namespace refactor::communication
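
A note on the templated call sites in the hunks above: the kernel, collector, and resource factories all go through templated helpers, so the type arguments carry most of the meaning. Spelled out, and assuming the fetch/fetchOrStore helpers declared in runtime/resource.h follow the usual templated form, the key calls read:

    // nccl_kernel.cc -- the factory returns the concrete kernel in a KernelBox:
    return std::make_unique<AllReduceNccl>(opType_, input.dataType, input.elementsSize());

    // nccl_kernel.cu -- the lowered routine looks the communicator up by type:
    auto communicator = res.fetch<NcclCommunicator>();

    // nccl_communicator.cu -- the resource factory:
    return std::make_unique<NcclCommunicator>(worldSize, rank);

    // test_allreduce_nccl.cpp -- create the communicator for this rank on first use:
    res.fetchOrStore<NcclCommunicator>(worldSize, rank);

    // all_reduce.cc (collector) -- filter returns a list of candidate kernels:
    std::vector<KernelBox> ans;

    // candidateKernels / lower -- construct the collector and the lowered operator:
    return std::make_unique<Collector_>(target, type);
    return std::make_unique<computation::AllReduce>(type);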
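
The typeId()/opTypeId() implementations may look suspicious at first: every case of the AllReduce switch has the identical body, a local static cast to size_t. It is correct because the ID is the address of the static, not its value, and each case block owns a distinct static. A minimal sketch of the idiom (function names here are illustrative, not from the patch):

    #include <cstddef>
    #include <cstdint>

    // Two helpers with byte-for-byte identical bodies still return different IDs,
    // because each function owns a distinct static object at a distinct address.
    inline size_t typeIdA() {
        static uint8_t ID = 1;                  // the value never matters
        return reinterpret_cast<size_t>(&ID);   // the address is the ID
    }
    inline size_t typeIdB() {
        static uint8_t ID = 1;
        return reinterpret_cast<size_t>(&ID);   // != typeIdA()
    }

The same reasoning applies per case label in Op::typeId: AllReduceSum, Avg, Min, Max, and Prod each declare their own static inside their own case block, hence they report distinct IDs even though every case body reads the same.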
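
For readers unfamiliar with the runtime::Resource machinery this patch plugs into: a resource is stored at most once per Resources container and looked up by the size_t type ID described above. The real implementation lives in the runtime layer and is not part of this diff; the following is only a hypothetical minimal shape, to show why typeId() is static while resourceTypeId() is virtual:

    #include <cstddef>
    #include <map>
    #include <memory>
    #include <utility>

    struct Resource {
        virtual ~Resource() = default;
        virtual size_t resourceTypeId() const noexcept = 0;  // dynamic: which kind am I?
    };
    using ResourceBox = std::unique_ptr<Resource>;

    struct Resources {
        std::map<size_t, ResourceBox> pool;

        // fetchOrStore<T>(args...) keys the pool on T::typeId() (static, no instance
        // needed) and builds the resource on first use -- this is how the test creates
        // one NcclCommunicator per rank and the lowered routine finds it again.
        template<class T, class... Args>
        T *fetchOrStore(Args &&...args) {
            auto [it, inserted] = pool.try_emplace(T::typeId(), nullptr);
            if (inserted) it->second = T::build(std::forward<Args>(args)...);
            return static_cast<T *>(it->second.get());
        }

        template<class T>
        T *fetch() {
            auto it = pool.find(T::typeId());
            return it == pool.end() ? nullptr : static_cast<T *>(it->second.get());
        }
    };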
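
NcclCommunicator bootstraps the NCCL clique without MPI: rank 0 calls ncclGetUniqueId, writes the ID to ./nccl_id.bin, and the other ranks poll for that file (with a 10 s timeout) before calling ncclCommInitRank. A stripped-down standalone sketch of the same rendezvous, outside the project's Resource wrapper (file name and timeout mirror the patch; the error macros are local to this sketch):

    #include <chrono>
    #include <cstdio>
    #include <cstdlib>
    #include <filesystem>
    #include <fstream>
    #include <string>
    #include <thread>
    #include <cuda_runtime.h>
    #include <nccl.h>

    #define NCCL_CHECK(cmd)                                                        \
        do {                                                                       \
            ncclResult_t r = (cmd);                                                \
            if (r != ncclSuccess) {                                                \
                std::fprintf(stderr, "NCCL error %s:%d: %s\n", __FILE__, __LINE__, \
                             ncclGetErrorString(r));                               \
                std::exit(EXIT_FAILURE);                                           \
            }                                                                      \
        } while (0)

    #define CUDA_CHECK(cmd)                                                        \
        do {                                                                       \
            cudaError_t e = (cmd);                                                 \
            if (e != cudaSuccess) {                                                \
                std::fprintf(stderr, "CUDA error %s:%d: %s\n", __FILE__, __LINE__, \
                             cudaGetErrorString(e));                               \
                std::exit(EXIT_FAILURE);                                           \
            }                                                                      \
        } while (0)

    // File-based rendezvous: rank 0 publishes the ncclUniqueId, everyone else polls.
    ncclComm_t initCommViaFile(int worldSize, int rank) {
        const std::string path = "./nccl_id.bin";
        ncclUniqueId id;
        if (rank == 0) {
            NCCL_CHECK(ncclGetUniqueId(&id));
            std::ofstream(path, std::ios::binary)
                .write(reinterpret_cast<const char *>(&id), sizeof(id));
        } else {
            auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(10);
            while (!std::filesystem::exists(path)) {
                if (std::chrono::steady_clock::now() >= deadline) {
                    std::fprintf(stderr, "rank %d: timed out waiting for %s\n", rank, path.c_str());
                    std::exit(EXIT_FAILURE);
                }
                std::this_thread::sleep_for(std::chrono::milliseconds(100));
            }
            // Like the patch, this assumes rank 0 finishes writing before readers open the file.
            std::ifstream(path, std::ios::binary)
                .read(reinterpret_cast<char *>(&id), sizeof(id));
        }
        CUDA_CHECK(cudaSetDevice(rank));  // one GPU per rank, as in the test
        ncclComm_t comm;
        NCCL_CHECK(ncclCommInitRank(&comm, worldSize, id, rank));
        if (rank == 0) std::filesystem::remove(path);
        return comm;
    }

ncclCommInitRank blocks until all worldSize ranks have called it, which is why rank 0 can safely delete the ID file immediately afterwards, exactly as the constructor in nccl_communicator.cu does.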
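
The unit test drives the collective with one std::thread per GPU inside a single process. A condensed standalone equivalent, reusing the rendezvous helper and CHECK macros from the previous sketch and the same data as test_allreduce_nccl.cpp (the project's Tensor, Resources, and device::init plumbing is deliberately left out):

    #include <cassert>
    #include <thread>
    #include <vector>

    void allReduceSumOnRank(int rank, int worldSize,
                            std::vector<float> host, std::vector<float> expect) {
        ncclComm_t comm = initCommViaFile(worldSize, rank);

        float *in = nullptr, *out = nullptr;
        const size_t bytes = host.size() * sizeof(float);
        CUDA_CHECK(cudaMalloc(&in, bytes));
        CUDA_CHECK(cudaMalloc(&out, bytes));
        CUDA_CHECK(cudaMemcpy(in, host.data(), bytes, cudaMemcpyHostToDevice));

        // The whole lowered routine boils down to this one call, issued on the
        // default stream (stream 0), matching the TODO in nccl_kernel.cu.
        NCCL_CHECK(ncclAllReduce(in, out, host.size(), ncclFloat32, ncclSum, comm, 0));
        CUDA_CHECK(cudaStreamSynchronize(0));

        std::vector<float> result(host.size());
        CUDA_CHECK(cudaMemcpy(result.data(), out, bytes, cudaMemcpyDeviceToHost));
        for (size_t i = 0; i < result.size(); ++i) assert(result[i] == expect[i]);

        CUDA_CHECK(cudaFree(in));
        CUDA_CHECK(cudaFree(out));
        NCCL_CHECK(ncclCommDestroy(comm));
    }

    int main() {
        // Same data as the unit test: {2,3} + {5,6} -> {7,9} on both ranks.
        std::vector<float> data[2] = {{2.f, 3.f}, {5.f, 6.f}};
        std::vector<float> sum = {7.f, 9.f};
        const int worldSize = 2;

        std::vector<std::thread> workers;
        for (int rank = 0; rank < worldSize; ++rank)
            workers.emplace_back(allReduceSumOnRank, rank, worldSize, data[rank], sum);
        for (auto &w : workers) w.join();
        return 0;
    }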
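
On the "use default stream for now" TODO in nccl_kernel.cu: ncclAllReduce is asynchronous with respect to the host and is ordered on whatever cudaStream_t it is given, so moving off stream 0 later only changes the last argument of the call. A hypothetical variant with a dedicated stream, reusing the names captured by the lowered lambda and the CHECK macros from the sketches above:

    cudaStream_t stream = nullptr;
    CUDA_CHECK(cudaStreamCreate(&stream));
    NCCL_CHECK(ncclAllReduce(input, output, count, ncclDataType, redOp,
                             communicator->get(), stream));  // was: stream 0
    CUDA_CHECK(cudaStreamSynchronize(stream));               // or let later kernels queue behind it
    CUDA_CHECK(cudaStreamDestroy(stream));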