diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 6ad4ba610..2d039122b 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -28,7 +28,7 @@ concurrency: jobs: cpp-build: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-build.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-build.yaml@branch-24.06 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -38,7 +38,7 @@ jobs: python-build: needs: [cpp-build] secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-python-build.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/conda-python-build.yaml@branch-24.06 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -49,7 +49,7 @@ jobs: if: github.ref_type == 'branch' needs: [python-build] secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@branch-24.06 with: arch: "amd64" branch: ${{ inputs.branch }} @@ -62,7 +62,7 @@ jobs: upload-conda: needs: [cpp-build, python-build] secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-upload-packages.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/conda-upload-packages.yaml@branch-24.06 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -70,7 +70,7 @@ jobs: sha: ${{ inputs.sha }} wheel-build-pylibwholegraph: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@branch-24.06 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -80,7 +80,7 @@ jobs: wheel-publish-pylibwholegraph: needs: wheel-build-pylibwholegraph secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-publish.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/wheels-publish.yaml@branch-24.06 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index 2c2578e04..339646eca 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -21,41 +21,41 @@ jobs: - wheel-build-pylibwholegraph - wheel-test-pylibwholegraph secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/pr-builder.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/pr-builder.yaml@branch-24.06 checks: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/checks.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/checks.yaml@branch-24.06 with: enable_check_generated_files: false conda-cpp-build: needs: checks secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-build.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-build.yaml@branch-24.06 with: build_type: pull-request node_type: cpu16 conda-cpp-tests: needs: conda-cpp-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-tests.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-tests.yaml@branch-24.06 with: build_type: pull-request conda-python-build: needs: conda-cpp-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-python-build.yaml@branch-24.04 + uses: 
rapidsai/shared-workflows/.github/workflows/conda-python-build.yaml@branch-24.06 with: build_type: pull-request conda-python-tests: needs: conda-python-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-python-tests.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/conda-python-tests.yaml@branch-24.06 with: build_type: pull-request docs-build: needs: conda-python-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@branch-24.06 with: build_type: pull-request arch: "amd64" @@ -64,14 +64,14 @@ jobs: wheel-build-pylibwholegraph: needs: checks secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@branch-24.06 with: build_type: pull-request script: ci/build_wheel.sh wheel-test-pylibwholegraph: needs: wheel-build-pylibwholegraph secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@branch-24.06 with: build_type: pull-request script: ci/test_wheel.sh diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 489348971..348476641 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -16,7 +16,7 @@ on: jobs: conda-cpp-tests: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-tests.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-tests.yaml@branch-24.06 with: build_type: nightly branch: ${{ inputs.branch }} @@ -24,7 +24,7 @@ jobs: sha: ${{ inputs.sha }} conda-pytorch-tests: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-python-tests.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/conda-python-tests.yaml@branch-24.06 with: build_type: nightly branch: ${{ inputs.branch }} @@ -32,7 +32,7 @@ jobs: sha: ${{ inputs.sha }} wheel-tests-pylibwholegraph: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@branch-24.06 with: build_type: nightly branch: ${{ inputs.branch }} diff --git a/VERSION b/VERSION index 4a2fe8aa5..0bff6981a 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -24.04.00 +24.06.00 diff --git a/ci/build_docs.sh b/ci/build_docs.sh index 88774f94a..61fa8ec22 100755 --- a/ci/build_docs.sh +++ b/ci/build_docs.sh @@ -22,7 +22,7 @@ rapids-print-env rapids-logger "Downloading artifacts from previous jobs" CPP_CHANNEL=$(rapids-download-conda-from-s3 cpp) -export RAPIDS_VERSION_NUMBER="24.04" +export RAPIDS_VERSION_NUMBER="24.06" export RAPIDS_DOCS_DIR="$(mktemp -d)" rapids-mamba-retry install \ diff --git a/ci/test_python.sh b/ci/test_python.sh index 0efa5e8e3..dd56e7b92 100755 --- a/ci/test_python.sh +++ b/ci/test_python.sh @@ -49,6 +49,7 @@ PACKAGES="pylibwholegraph" rapids-mamba-retry install \ --channel "${CPP_CHANNEL}" \ --channel "${PYTHON_CHANNEL}" \ + 'mkl<2024.1.0' \ "${PACKAGES}" rapids-logger "Check GPU usage" diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml index af28b2b52..45fc02021 100644 --- a/conda/environments/all_cuda-118_arch-x86_64.yaml +++ b/conda/environments/all_cuda-118_arch-x86_64.yaml @@ -24,8 +24,8 @@ dependencies: - graphviz - ipykernel - ipython -- 
libraft-headers==24.4.* -- librmm==24.4.* +- libraft-headers==24.6.* +- librmm==24.6.* - nanobind>=0.2.0 - nbsphinx - nccl diff --git a/conda/environments/all_cuda-122_arch-x86_64.yaml b/conda/environments/all_cuda-122_arch-x86_64.yaml index 6486f500f..dd33e60c1 100644 --- a/conda/environments/all_cuda-122_arch-x86_64.yaml +++ b/conda/environments/all_cuda-122_arch-x86_64.yaml @@ -25,8 +25,8 @@ dependencies: - graphviz - ipykernel - ipython -- libraft-headers==24.4.* -- librmm==24.4.* +- libraft-headers==24.6.* +- librmm==24.6.* - nanobind>=0.2.0 - nbsphinx - nccl diff --git a/conda/recipes/libwholegraph/conda_build_config.yaml b/conda/recipes/libwholegraph/conda_build_config.yaml index aad996394..52573b012 100644 --- a/conda/recipes/libwholegraph/conda_build_config.yaml +++ b/conda/recipes/libwholegraph/conda_build_config.yaml @@ -19,11 +19,8 @@ doxygen_version: nccl_version: - ">=2.9.9" -gtest_version: - - ">=1.13.0" +c_stdlib: + - sysroot -gmock_version: - - ">=1.13.0" - -sysroot_version: +c_stdlib_version: - "2.17" diff --git a/conda/recipes/libwholegraph/meta.yaml b/conda/recipes/libwholegraph/meta.yaml index fd1b3dfa9..e4c400e60 100644 --- a/conda/recipes/libwholegraph/meta.yaml +++ b/conda/recipes/libwholegraph/meta.yaml @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2023, NVIDIA CORPORATION. +# Copyright (c) 2019-2024, NVIDIA CORPORATION. {% set version = environ['RAPIDS_PACKAGE_VERSION'].lstrip('v') + environ.get('VERSION_SUFFIX', '') %} {% set minor_version = version.split('.')[0] + '.' + version.split('.')[1] %} @@ -49,7 +49,7 @@ requirements: - cuda-version ={{ cuda_version }} - cmake {{ cmake_version }} - ninja - - sysroot_{{ target_platform }} {{ sysroot_version }} + - {{ stdlib("c") }} host: {% if cuda_major == "11" %} - cudatoolkit @@ -59,8 +59,6 @@ requirements: {% endif %} - cuda-version ={{ cuda_version }} - doxygen {{ doxygen_version }} - - gmock {{ gtest_version }} - - gtest {{ gtest_version }} - libraft ={{ minor_version }} - libraft-headers ={{ minor_version }} - librmm ={{ minor_version }} @@ -134,8 +132,6 @@ outputs: {% else %} - cuda-cudart {% endif %} - - gmock {{ gtest_version }} - - gtest {{ gtest_version }} about: home: https://rapids.ai/ license: Apache-2.0 diff --git a/conda/recipes/pylibwholegraph/conda_build_config.yaml b/conda/recipes/pylibwholegraph/conda_build_config.yaml index d45aacc92..41050978a 100644 --- a/conda/recipes/pylibwholegraph/conda_build_config.yaml +++ b/conda/recipes/pylibwholegraph/conda_build_config.yaml @@ -16,5 +16,8 @@ cmake_version: scikit_build_core_version: - ">=0.7.0" -sysroot_version: +c_stdlib: + - sysroot + +c_stdlib_version: - "2.17" diff --git a/conda/recipes/pylibwholegraph/meta.yaml b/conda/recipes/pylibwholegraph/meta.yaml index 1caa9573f..829350851 100644 --- a/conda/recipes/pylibwholegraph/meta.yaml +++ b/conda/recipes/pylibwholegraph/meta.yaml @@ -54,7 +54,7 @@ requirements: - cmake {{ cmake_version }} - ninja - doxygen =1.8.20 - - sysroot_{{ target_platform }} {{ sysroot_version }} + - {{ stdlib("c") }} host: - cuda-version ={{ cuda_version }} {% if cuda_major == "11" %} diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index dc75bd99c..b3fdc6d74 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -1,5 +1,5 @@ #============================================================================= -# Copyright (c) 2018-2023, NVIDIA CORPORATION. +# Copyright (c) 2018-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -14,7 +14,7 @@ # limitations under the License. #============================================================================= -set(RAPIDS_VERSION "24.04") +set(RAPIDS_VERSION "24.06") set(WHOLEGRAPH_VERSION "${RAPIDS_VERSION}.00") cmake_minimum_required(VERSION 3.23.1 FATAL_ERROR) @@ -213,7 +213,8 @@ endif() # optionally build tests if(BUILD_TESTS AND CMAKE_PROJECT_NAME STREQUAL PROJECT_NAME) - include(./cmake/thirdparty/get_gtest.cmake) + include(${rapids-cmake-dir}/cpm/gtest.cmake) + rapids_cpm_gtest(BUILD_STATIC) include(CTest) # calls enable_testing() add_subdirectory(tests) diff --git a/cpp/Doxyfile b/cpp/Doxyfile index e480d8ef4..3e4e9e53f 100644 --- a/cpp/Doxyfile +++ b/cpp/Doxyfile @@ -38,7 +38,7 @@ PROJECT_NAME = "WholeGraph C API" # could be handy for archiving the generated documentation or if some version # control system is used. -PROJECT_NUMBER = 24.04 +PROJECT_NUMBER = 24.06 # Using the PROJECT_BRIEF tag one can provide an optional one line description # for a project that appears at the top of each page and should give viewer a diff --git a/cpp/cmake/thirdparty/get_gtest.cmake b/cpp/cmake/thirdparty/get_gtest.cmake deleted file mode 100644 index cdc2c5d88..000000000 --- a/cpp/cmake/thirdparty/get_gtest.cmake +++ /dev/null @@ -1,24 +0,0 @@ -#============================================================================= -# Copyright (c) 2021-2022, NVIDIA CORPORATION. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -#============================================================================= - -function(find_and_configure_gtest) - - include(${rapids-cmake-dir}/cpm/gtest.cmake) - rapids_cpm_gtest() - -endfunction() - -find_and_configure_gtest() diff --git a/cpp/include/wholememory/wholememory.h b/cpp/include/wholememory/wholememory.h index 66bd993fd..08f16213f 100644 --- a/cpp/include/wholememory/wholememory.h +++ b/cpp/include/wholememory/wholememory.h @@ -387,6 +387,12 @@ wholememory_error_code_t wholememory_store_to_file(wholememory_handle_t wholemem size_t file_entry_size, const char* local_file_name); +/** + * @param comm : WholeMemory Comm + * @return : bool + */ +bool wholememory_is_intranode_communicator(wholememory_comm_t comm); + bool wholememory_is_build_with_nvshmem(); #ifdef WITH_NVSHMEM_SUPPORT wholememory_error_code_t wholememory_get_nvshmem_reference( diff --git a/cpp/src/graph_ops/append_unique_func.cuh b/cpp/src/graph_ops/append_unique_func.cuh index ff623a22b..761fabb63 100644 --- a/cpp/src/graph_ops/append_unique_func.cuh +++ b/cpp/src/graph_ops/append_unique_func.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -316,7 +316,7 @@ void graph_append_unique_func(void* target_nodes_ptr, <<>>(value_id, bucket_count_ptr); WM_CUDA_CHECK(cudaGetLastError()); wholememory_ops::wm_thrust_allocator thrust_allocator(p_env_fns); - thrust::exclusive_scan(thrust::cuda::par(thrust_allocator).on(stream), + thrust::exclusive_scan(thrust::cuda::par_nosync(thrust_allocator).on(stream), bucket_count_ptr, bucket_count_ptr + num_bucket_count, (int*)bucket_prefix_sum_ptr); diff --git a/cpp/src/wholegraph_ops/sample_comm.cuh b/cpp/src/wholegraph_ops/sample_comm.cuh index 6bf4c66e9..29cd1c472 100644 --- a/cpp/src/wholegraph_ops/sample_comm.cuh +++ b/cpp/src/wholegraph_ops/sample_comm.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -57,4 +57,41 @@ __global__ void sample_all_kernel(wholememory_gref_t wm_csr_row_ptr, } } } + +__device__ __forceinline__ int log2_up_device(int x) +{ + if (x <= 2) return x - 1; + return 32 - __clz(x - 1); +} +template +struct ExpandWithOffsetFunc { + const IdType* indptr; + IdType* indptr_shift; + int length; + __host__ __device__ auto operator()(int64_t tIdx) + { + indptr_shift[tIdx] = indptr[tIdx % length] + tIdx / length; + } +}; + +template +struct ReduceForDegrees { + WMIdType* rowoffsets; + DegreeType* in_degree_ptr; + int length; + __host__ __device__ auto operator()(int64_t tIdx) + { + in_degree_ptr[tIdx] = rowoffsets[tIdx + length] - rowoffsets[tIdx]; + } +}; + +template +struct MinInDegreeFanout { + int max_sample_count; + __host__ __device__ auto operator()(DegreeType degree) + { + return min(static_cast(degree), max_sample_count); + } +}; + } // namespace wholegraph_ops diff --git a/cpp/src/wholegraph_ops/unweighted_sample_without_replacement.cpp b/cpp/src/wholegraph_ops/unweighted_sample_without_replacement.cpp index 89fb2a9f9..b835a4bb5 100644 --- a/cpp/src/wholegraph_ops/unweighted_sample_without_replacement.cpp +++ b/cpp/src/wholegraph_ops/unweighted_sample_without_replacement.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -41,7 +41,8 @@ wholememory_error_code_t wholegraph_csr_unweighted_sample_without_replacement( } WHOLEMEMORY_EXPECTS_NOTHROW(!csr_row_ptr_has_handle || csr_row_ptr_memory_type == WHOLEMEMORY_MT_CHUNKED || - csr_row_ptr_memory_type == WHOLEMEMORY_MT_CONTINUOUS, + csr_row_ptr_memory_type == WHOLEMEMORY_MT_CONTINUOUS || + csr_row_ptr_memory_type == WHOLEMEMORY_MT_DISTRIBUTED, "Memory type not supported."); bool const csr_col_ptr_has_handle = wholememory_tensor_has_handle(wm_csr_col_ptr_tensor); wholememory_memory_type_t csr_col_ptr_memory_type = WHOLEMEMORY_MT_NONE; @@ -51,7 +52,8 @@ wholememory_error_code_t wholegraph_csr_unweighted_sample_without_replacement( } WHOLEMEMORY_EXPECTS_NOTHROW(!csr_col_ptr_has_handle || csr_col_ptr_memory_type == WHOLEMEMORY_MT_CHUNKED || - csr_col_ptr_memory_type == WHOLEMEMORY_MT_CONTINUOUS, + csr_col_ptr_memory_type == WHOLEMEMORY_MT_CONTINUOUS || + csr_row_ptr_memory_type == WHOLEMEMORY_MT_DISTRIBUTED, "Memory type not supported."); auto csr_row_ptr_tensor_description = @@ -108,6 +110,40 @@ wholememory_error_code_t wholegraph_csr_unweighted_sample_without_replacement( void* center_nodes = wholememory_tensor_get_data_pointer(center_nodes_tensor); void* output_sample_offset = wholememory_tensor_get_data_pointer(output_sample_offset_tensor); + if (csr_col_ptr_memory_type == WHOLEMEMORY_MT_DISTRIBUTED && + csr_row_ptr_memory_type == WHOLEMEMORY_MT_DISTRIBUTED) { + wholememory_distributed_backend_t distributed_backend_row = wholememory_get_distributed_backend( + wholememory_tensor_get_memory_handle(wm_csr_row_ptr_tensor)); + wholememory_distributed_backend_t distributed_backend_col = wholememory_get_distributed_backend( + wholememory_tensor_get_memory_handle(wm_csr_col_ptr_tensor)); + if (distributed_backend_col == WHOLEMEMORY_DB_NCCL && + distributed_backend_row == WHOLEMEMORY_DB_NCCL) { + wholememory_handle_t wm_csr_row_ptr_handle = + wholememory_tensor_get_memory_handle(wm_csr_row_ptr_tensor); + wholememory_handle_t wm_csr_col_ptr_handle = + wholememory_tensor_get_memory_handle(wm_csr_col_ptr_tensor); + return wholegraph_ops::wholegraph_csr_unweighted_sample_without_replacement_nccl( + wm_csr_row_ptr_handle, + wm_csr_col_ptr_handle, + csr_row_ptr_tensor_description, + csr_col_ptr_tensor_description, + center_nodes, + center_nodes_desc, + max_sample_count, + output_sample_offset, + output_sample_offset_desc, + output_dest_memory_context, + output_center_localid_memory_context, + output_edge_gid_memory_context, + random_seed, + p_env_fns, + static_cast(stream)); + } else { + WHOLEMEMORY_ERROR("Only NCCL communication backend is supported for sampling."); + return WHOLEMEMORY_INVALID_INPUT; + } + } + wholememory_gref_t wm_csr_row_ptr_gref, wm_csr_col_ptr_gref; WHOLEMEMORY_RETURN_ON_FAIL( wholememory_tensor_get_global_reference(wm_csr_row_ptr_tensor, &wm_csr_row_ptr_gref)); diff --git a/cpp/src/wholegraph_ops/unweighted_sample_without_replacement_func.cuh b/cpp/src/wholegraph_ops/unweighted_sample_without_replacement_func.cuh index 291b26b2d..2ee08ce58 100644 --- a/cpp/src/wholegraph_ops/unweighted_sample_without_replacement_func.cuh +++ b/cpp/src/wholegraph_ops/unweighted_sample_without_replacement_func.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -123,12 +123,6 @@ __global__ void large_sample_kernel( } } -__device__ __forceinline__ int log2_up_device(int x) -{ - if (x <= 2) return x - 1; - return 32 - __clz(x - 1); -} - template + +#include +#include + +#include "unweighted_sample_without_replacement_nccl_func.cuh" +#include "wholememory_ops/register.hpp" + +namespace wholegraph_ops { + +REGISTER_DISPATCH_TWO_TYPES(UnweightedSampleWithoutReplacementCSRNCCL, + wholegraph_csr_unweighted_sample_without_replacement_nccl_func, + SINT3264, + SINT3264) + +wholememory_error_code_t wholegraph_csr_unweighted_sample_without_replacement_nccl( + wholememory_handle_t csr_row_wholememory_handle, + wholememory_handle_t csr_col_wholememory_handle, + wholememory_tensor_description_t wm_csr_row_ptr_desc, + wholememory_tensor_description_t wm_csr_col_ptr_desc, + void* center_nodes, + wholememory_array_description_t center_nodes_desc, + int max_sample_count, + void* output_sample_offset, + wholememory_array_description_t output_sample_offset_desc, + void* output_dest_memory_context, + void* output_center_localid_memory_context, + void* output_edge_gid_memory_context, + unsigned long long random_seed, + wholememory_env_func_t* p_env_fns, + cudaStream_t stream) +{ + try { + DISPATCH_TWO_TYPES(center_nodes_desc.dtype, + wm_csr_col_ptr_desc.dtype, + UnweightedSampleWithoutReplacementCSRNCCL, + csr_row_wholememory_handle, + csr_col_wholememory_handle, + wm_csr_row_ptr_desc, + wm_csr_col_ptr_desc, + center_nodes, + center_nodes_desc, + max_sample_count, + output_sample_offset, + output_sample_offset_desc, + output_dest_memory_context, + output_center_localid_memory_context, + output_edge_gid_memory_context, + random_seed, + p_env_fns, + stream); + + } catch (const wholememory::cuda_error& rle) { + // WHOLEMEMORY_FAIL_NOTHROW("%s", rle.what()); + return WHOLEMEMORY_LOGIC_ERROR; + } catch (const wholememory::logic_error& le) { + return WHOLEMEMORY_LOGIC_ERROR; + } catch (...) { + return WHOLEMEMORY_LOGIC_ERROR; + } + return WHOLEMEMORY_SUCCESS; +} + +} // namespace wholegraph_ops diff --git a/cpp/src/wholegraph_ops/unweighted_sample_without_replacement_nccl_func.cuh b/cpp/src/wholegraph_ops/unweighted_sample_without_replacement_nccl_func.cuh new file mode 100644 index 000000000..18f4db21a --- /dev/null +++ b/cpp/src/wholegraph_ops/unweighted_sample_without_replacement_nccl_func.cuh @@ -0,0 +1,388 @@ +/* + * Copyright (c) 2019-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "wholememory_ops/output_memory_handle.hpp" +#include "wholememory_ops/temp_memory_handle.hpp" +#include "wholememory_ops/thrust_allocator.hpp" + +#include "cuda_macros.hpp" +#include "error.hpp" +#include "sample_comm.cuh" + +#include "wholememory_ops/gather_op_impl.h" +using wholememory_ops::wholememory_gather_nccl; +#define WARP_SIZE 32 + +namespace wholegraph_ops { + +template +__global__ void unweighted_sample_without_replacement_nccl_kernel( + const IdType* input_nodes, + const WMIdType* csr_row_ptr_sta, + const DegreeType* in_degree, + const int input_node_count, + const int max_sample_count, + raft::random::detail::DeviceState rngstate, + const int* sample_offset, + wholememory_array_description_t sample_offset_desc, + int* src_lid, + int64_t* output_edge_gid_ptr) +{ + int gidx = threadIdx.x + blockIdx.x * blockDim.x; + raft::random::detail::PCGenerator rng(rngstate, (uint64_t)gidx); + int input_idx = blockIdx.x; + if (input_idx >= input_node_count) return; + + WMIdType start = csr_row_ptr_sta[input_idx]; + WMIdType end = start + in_degree[input_idx]; + int neighbor_count = (int)(in_degree[input_idx]); + if (neighbor_count <= 0) return; + int offset = sample_offset[input_idx]; + // use all neighbors if neighbors less than max_sample_count + if (neighbor_count <= max_sample_count) { + for (int sample_id = threadIdx.x; sample_id < neighbor_count; sample_id += blockDim.x) { + output_edge_gid_ptr[offset + sample_id] = start + sample_id; + if (src_lid) src_lid[offset + sample_id] = input_idx; + } + return; + } + uint64_t sa_p[ITEMS_PER_THREAD]; + int M = max_sample_count; + int N = neighbor_count; + // UnWeightedIndexSampleWithOutReplacement(M, N, + // sa_p, rng); + typedef cub::BlockRadixSort BlockRadixSort; + struct IntArray { + int value[BLOCK_DIM * ITEMS_PER_THREAD]; + }; + struct SampleSharedData { + IntArray s; + IntArray p; + IntArray q; + IntArray chain; + IntArray last_chain_tmp; + }; + __shared__ union { + typename BlockRadixSort::TempStorage temp_storage; + SampleSharedData sample_shared_data; + } shared_data; +#pragma unroll + for (int i = 0; i < ITEMS_PER_THREAD; i++) { + int idx = i * BLOCK_DIM + threadIdx.x; + raft::random::detail::UniformDistParams params; + params.start = 0; + params.end = 1; + int32_t rand_num; + raft::random::detail::custom_next(rng, &rand_num, params, 0, 0); + int32_t r = idx < M ? 
rand_num % (N - idx) : N; + sa_p[i] = ((uint64_t)r << 32UL) | idx; + } + __syncthreads(); + BlockRadixSort(shared_data.temp_storage).SortBlockedToStriped(sa_p); + __syncthreads(); +#pragma unroll + for (int i = 0; i < ITEMS_PER_THREAD; i++) { + int idx = i * BLOCK_DIM + threadIdx.x; + int s = (int)(sa_p[i] >> 32UL); + shared_data.sample_shared_data.s.value[idx] = s; + int p = sa_p[i] & 0xFFFFFFFF; + shared_data.sample_shared_data.p.value[idx] = p; + if (idx < M) shared_data.sample_shared_data.q.value[p] = idx; + shared_data.sample_shared_data.chain.value[idx] = idx; + } + __syncthreads(); +#pragma unroll + for (int i = 0; i < ITEMS_PER_THREAD; i++) { + int idx = i * BLOCK_DIM + threadIdx.x; + int si = shared_data.sample_shared_data.s.value[idx]; + int si1 = shared_data.sample_shared_data.s.value[idx + 1]; + if (idx < M && (idx == M - 1 || si != si1) && si >= N - M) { + shared_data.sample_shared_data.chain.value[N - si - 1] = + shared_data.sample_shared_data.p.value[idx]; + } + } + __syncthreads(); + for (int step = 0; step < log2_up_device(M); ++step) { +#pragma unroll + for (int i = 0; i < ITEMS_PER_THREAD; i++) { + int idx = i * BLOCK_DIM + threadIdx.x; + shared_data.sample_shared_data.last_chain_tmp.value[idx] = + shared_data.sample_shared_data.chain.value[idx]; + } + __syncthreads(); +#pragma unroll + for (int i = 0; i < ITEMS_PER_THREAD; i++) { + int idx = i * BLOCK_DIM + threadIdx.x; + if (idx < M) { + shared_data.sample_shared_data.chain.value[idx] = + shared_data.sample_shared_data.last_chain_tmp + .value[shared_data.sample_shared_data.last_chain_tmp.value[idx]]; + } + } + __syncthreads(); + } +#pragma unroll + for (int i = 0; i < ITEMS_PER_THREAD; i++) { + int idx = i * BLOCK_DIM + threadIdx.x; + shared_data.sample_shared_data.last_chain_tmp.value[idx] = + N - shared_data.sample_shared_data.chain.value[idx] - 1; + } + __syncthreads(); +#pragma unroll + for (int i = 0; i < ITEMS_PER_THREAD; i++) { + int idx = i * BLOCK_DIM + threadIdx.x; + int ai; + if (idx < M) { + int qi = shared_data.sample_shared_data.q.value[idx]; + if (idx == 0 || qi == 0 || + shared_data.sample_shared_data.s.value[qi] != + shared_data.sample_shared_data.s.value[qi - 1]) { + ai = shared_data.sample_shared_data.s.value[qi]; + } else { + int prev_i = shared_data.sample_shared_data.p.value[qi - 1]; + ai = shared_data.sample_shared_data.last_chain_tmp.value[prev_i]; + } + sa_p[i] = ai; + } + } + // Output +#pragma unroll + for (int i = 0; i < ITEMS_PER_THREAD; i++) { + int idx = i * BLOCK_DIM + threadIdx.x; + int ai = sa_p[i]; + if (idx < M) { + output_edge_gid_ptr[offset + idx] = (int64_t)(start + ai); + if (src_lid) src_lid[offset + idx] = (LocalIdType)input_idx; + } + } +} + +template +void wholegraph_csr_unweighted_sample_without_replacement_nccl_func( + wholememory_handle_t csr_row_wholememory_handle, + wholememory_handle_t csr_col_wholememory_handle, + wholememory_tensor_description_t wm_csr_row_ptr_desc, + wholememory_tensor_description_t wm_csr_col_ptr_desc, + void* center_nodes, + wholememory_array_description_t center_nodes_desc, + int max_sample_count, + void* output_sample_offset, + wholememory_array_description_t output_sample_offset_desc, + void* output_dest_memory_context, + void* output_center_localid_memory_context, + void* output_edge_gid_memory_context, + unsigned long long random_seed, + wholememory_env_func_t* p_env_fns, + cudaStream_t stream) +{ + int center_node_count = center_nodes_desc.size; + WHOLEMEMORY_EXPECTS(wm_csr_row_ptr_desc.dtype == WHOLEMEMORY_DT_INT64, + 
"wholegraph_csr_unweighted_sample_without_replacement_nccl_func(). " + "wm_csr_row_ptr_desc.dtype != WHOLEMEMORY_DT_INT64, " + "wm_csr_row_ptr_desc.dtype = %d", + wm_csr_row_ptr_desc.dtype); + + WHOLEMEMORY_EXPECTS(output_sample_offset_desc.dtype == WHOLEMEMORY_DT_INT, + "wholegraph_csr_unweighted_sample_without_replacement_nccl_func(). " + "output_sample_offset_desc.dtype != WHOLEMEMORY_DT_INT, " + "output_sample_offset_desc.dtype = %d", + output_sample_offset_desc.dtype); + + auto double_center_node_count = center_node_count + center_node_count; + wholememory_ops::temp_memory_handle center_nodes_buf(p_env_fns); + IdType* center_nodes_expandshift_one = static_cast( + center_nodes_buf.device_malloc(double_center_node_count, center_nodes_desc.dtype)); + // fill center_nodes_shift_one with [center_nodes, center_nodes+1] + wholememory_ops::wm_thrust_allocator thrust_allocator(p_env_fns); + thrust::counting_iterator iota(0); + thrust::for_each(thrust::cuda::par_nosync(thrust_allocator).on(stream), + iota, + iota + double_center_node_count, + ExpandWithOffsetFunc{ + (const IdType*)center_nodes, center_nodes_expandshift_one, center_node_count}); + + // gathering [rowoffsets, rowoffsets+1] + wholememory_ops::temp_memory_handle output_buf(p_env_fns); + int64_t* center_nodes_indptr = + static_cast(output_buf.device_malloc(double_center_node_count, WHOLEMEMORY_DT_INT64)); + wholememory_array_description_t center_nodes_expandshift_one_desc{ + double_center_node_count, 0, center_nodes_desc.dtype}; + wholememory_matrix_description_t center_nodes_indptr_desc{ + {double_center_node_count, 1}, 1, 0, wm_csr_row_ptr_desc.dtype}; + wholememory_matrix_description_t wm_csr_row_ptr_mat_desc; + wholememory_convert_tensor_desc_to_matrix(&wm_csr_row_ptr_mat_desc, &wm_csr_row_ptr_desc); + wholememory_ops::wholememory_gather_nccl(csr_row_wholememory_handle, + wm_csr_row_ptr_mat_desc, + center_nodes_expandshift_one, + center_nodes_expandshift_one_desc, + center_nodes_indptr, + center_nodes_indptr_desc, + p_env_fns, + stream, + -1); + // find the in_degree (subtraction) and sample count + // temporarily store sampled_csr_ptr_buf (# of degrees/samples per node) in int32; + // can be changed to int8_t/16_t later + wholememory_ops::temp_memory_handle sampled_csr_ptr_buf(p_env_fns); + int* in_degree = + static_cast(sampled_csr_ptr_buf.device_malloc(center_node_count + 1, WHOLEMEMORY_DT_INT)); + thrust::for_each( + thrust::cuda::par_nosync(thrust_allocator).on(stream), + iota, + iota + center_node_count, + ReduceForDegrees{center_nodes_indptr, in_degree, center_node_count}); + // prefix sum to get the output_sample_offset (depending on min(max_sample_count and in_degree)) + int sampled_count = max_sample_count <= 0 ? 
std::numeric_limits::max() : max_sample_count; + thrust::transform_exclusive_scan(thrust::cuda::par_nosync(thrust_allocator).on(stream), + in_degree, + in_degree + center_node_count + 1, + (int*)output_sample_offset, + MinInDegreeFanout{sampled_count}, + 0, + thrust::plus()); + // start local sampling + int count; + WM_CUDA_CHECK(cudaMemcpyAsync(&count, + ((int*)output_sample_offset) + center_node_count, + sizeof(int), + cudaMemcpyDeviceToHost, + stream)); + WM_CUDA_CHECK(cudaStreamSynchronize(stream)); + + int64_t* output_edge_gid_ptr = nullptr; + wholememory_ops::temp_memory_handle edge_gid_buffer_mh(p_env_fns); + if (output_edge_gid_memory_context) { + wholememory_ops::output_memory_handle gen_output_edge_gid_buffer_mh( + p_env_fns, output_edge_gid_memory_context); + output_edge_gid_ptr = + (int64_t*)gen_output_edge_gid_buffer_mh.device_malloc(count, WHOLEMEMORY_DT_INT64); + } else { + output_edge_gid_ptr = (int64_t*)edge_gid_buffer_mh.device_malloc(count, WHOLEMEMORY_DT_INT64); + } + + wholememory_ops::output_memory_handle gen_output_dest_buffer_mh(p_env_fns, + output_dest_memory_context); + WMIdType* output_dest_node_ptr = + (WMIdType*)gen_output_dest_buffer_mh.device_malloc(count, wm_csr_col_ptr_desc.dtype); + + int* output_center_localid_ptr = nullptr; + if (output_center_localid_memory_context) { + wholememory_ops::output_memory_handle gen_output_center_localid_buffer_mh( + p_env_fns, output_center_localid_memory_context); + output_center_localid_ptr = + (int*)gen_output_center_localid_buffer_mh.device_malloc(count, WHOLEMEMORY_DT_INT); + } + raft::random::RngState _rngstate(random_seed, 0, raft::random::GeneratorType::GenPC); + raft::random::detail::DeviceState rngstate(_rngstate); + { + typedef void (*unweighted_sample_func_type)( + const IdType* input_nodes, + const int64_t* center_nodes_indptr, + const int* in_degree, + const int input_node_count, + const int max_sample_count, + raft::random::detail::DeviceState rngstate, + const int* sample_offset, + wholememory_array_description_t sample_offset_desc, + int* src_lid, + int64_t* output_edge_gid_ptr); + static const unweighted_sample_func_type func_array[32] = { + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + 
unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel}; + static const int warp_count_array[32] = {1, 1, 1, 2, 2, 2, 4, 4, 4, 4, 4, 4, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8}; + int func_idx = (max_sample_count - 1) / 32; + func_array[func_idx]<<>>( + (const IdType*)center_nodes, + (const int64_t*)center_nodes_indptr, + (const int*)in_degree, + center_node_count, + sampled_count, + rngstate, + (const int*)output_sample_offset, + output_sample_offset_desc, + (int*)output_center_localid_ptr, + (int64_t*)output_edge_gid_ptr); + + wholememory_matrix_description_t wm_csr_col_ptr_mat_desc; + wholememory_matrix_description_t output_dest_node_ptr_desc{ + {count, 1}, 1, 0, wm_csr_col_ptr_desc.dtype}; + wholememory_array_description_t output_edge_gid_ptr_desc{count, 0, WHOLEMEMORY_DT_INT64}; + wholememory_convert_tensor_desc_to_matrix(&wm_csr_col_ptr_mat_desc, &wm_csr_col_ptr_desc); + wholememory_ops::wholememory_gather_nccl(csr_col_wholememory_handle, + wm_csr_col_ptr_mat_desc, + output_edge_gid_ptr, + output_edge_gid_ptr_desc, + output_dest_node_ptr, + output_dest_node_ptr_desc, + p_env_fns, + stream, + -1); + } + WM_CUDA_CHECK(cudaGetLastError()); + WM_CUDA_CHECK(cudaStreamSynchronize(stream)); +} +} // namespace wholegraph_ops diff --git a/cpp/src/wholegraph_ops/weighted_sample_without_replacement_func.cuh b/cpp/src/wholegraph_ops/weighted_sample_without_replacement_func.cuh index de75d7394..057d4c0c4 100644 --- a/cpp/src/wholegraph_ops/weighted_sample_without_replacement_func.cuh +++ b/cpp/src/wholegraph_ops/weighted_sample_without_replacement_func.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -462,7 +462,7 @@ void wholegraph_csr_weighted_sample_without_replacement_func( // prefix sum wholememory_ops::wm_thrust_allocator thrust_allocator(p_env_fns); - thrust::exclusive_scan(thrust::cuda::par(thrust_allocator).on(stream), + thrust::exclusive_scan(thrust::cuda::par_nosync(thrust_allocator).on(stream), tmp_sample_count_mem_pointer, tmp_sample_count_mem_pointer + center_node_count + 1, static_cast(output_sample_offset)); @@ -500,7 +500,7 @@ void wholegraph_csr_weighted_sample_without_replacement_func( raft::random::detail::DeviceState rngstate(_rngstate); if (max_sample_count > sample_count_threshold) { wholememory_ops::wm_thrust_allocator tmp_thrust_allocator(p_env_fns); - thrust::exclusive_scan(thrust::cuda::par(tmp_thrust_allocator).on(stream), + thrust::exclusive_scan(thrust::cuda::par_nosync(tmp_thrust_allocator).on(stream), tmp_neighbor_counts_mem_pointer, tmp_neighbor_counts_mem_pointer + center_node_count + 1, tmp_neighbor_counts_mem_pointer); diff --git a/cpp/src/wholememory/file_io.cpp b/cpp/src/wholememory/file_io.cpp index 0a627eed2..31b87c144 100644 --- a/cpp/src/wholememory/file_io.cpp +++ b/cpp/src/wholememory/file_io.cpp @@ -15,12 +15,17 @@ */ #include "file_io.h" +#include +#include + #include +#include #include #include #include #include +#include #include #include "communicator.hpp" @@ -92,7 +97,7 @@ static size_t get_handle_partial_size(size_t handle_size, * @param suggested_buffer_size : Suggested buffer size to read. * @param wm_rank : WholeMemory rank. * @param wm_world_size : WholeMemory world size. - * @param round_robin_size : continuous embedding size of a rank using round robin shard stratehy. + * @param round_robin_size : continuous embedding size of a rank using round robin shard strategy. */ static void read_file_list_to_local_memory_roundrobin(char* local_ptr, size_t local_size, @@ -384,6 +389,476 @@ static void read_file_list_to_local_memory(char* local_ptr, "Rank=%d done reading total %ld bytes from needed files.", wm_rank, total_read_bytes); } +/*! + * Read from file list to local memory of WholeMemory. File list are binary files, which are + * considered to be concatenated together. All ranks in WholeMemory will read the files in parallel + * and load each part into local memory of each rank. WholeMemory will use round-robin sharding + * strategy. + * @param local_ptr : Pointer to local memory of WholeMemory + * @param local_size : Local memory size + * @param local_offset : The offset of local memory in WholeMemory. + * @param entry_size : The entry size of each data entry. + * @param memory_entry_stride : The stride of each entry in WholeMemory + * @param memory_offset : The start offset to place the read data. Should be in range [0, + * memory_entry_stride) + * @param file_count : Total file count of the file list + * @param file_names : File names of the file list. + * @param file_sizes : Sizes of each file. + * @param suggested_buffer_size : Suggested buffer size to read. + * @param wm_rank : WholeMemory rank. + * @param wm_world_size : WholeMemory world size. + * @param round_robin_size : continuous embedding size of a rank using round robin shard strategy. + * @param dev_id : the device bound to the rank. 
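+ * @note The number of reader threads per rank is taken from the WG_LOAD_THREADS_PER_RANK
+ *       environment variable; values that fail to parse, are below 1, or exceed
+ *       std::thread::hardware_concurrency() fall back to 1.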
+ */ +static void read_file_list_to_local_memory_roundrobin_with_multi_threads( + char* local_ptr, + size_t local_size, + size_t local_offset, + size_t entry_size, + size_t memory_entry_stride, + size_t memory_offset, + int file_count, + const char** file_names, + const std::vector& file_sizes, + size_t suggested_buffer_size, + int wm_rank, + int wm_world_size, + int round_robin_size, + int dev_id) +{ + int threads_per_rank = 1; + const char* threads_per_rank_env_var = std::getenv("WG_LOAD_THREADS_PER_RANK"); + if (threads_per_rank_env_var != nullptr) { + try { + threads_per_rank = std::stoi(threads_per_rank_env_var); + } catch (const std::invalid_argument& e) { + threads_per_rank = 1; + WHOLEMEMORY_WARN( + "Environment variable WG_LOAD_THREADS_PER_RANK value %s is not valid,use the default %d", + threads_per_rank_env_var, + threads_per_rank); + } + if (threads_per_rank < 1 || threads_per_rank > std::thread::hardware_concurrency()) { + threads_per_rank = 1; + WHOLEMEMORY_WARN( + "Environment variable WG_LOAD_THREADS_PER_RANK value %s is not valid,use the default %d", + threads_per_rank_env_var, + threads_per_rank); + } + } + size_t buffer_size; + size_t buffer_entry_count = 1; + if (suggested_buffer_size < entry_size) { + buffer_size = entry_size; + } else { + buffer_entry_count = suggested_buffer_size / entry_size; + buffer_size = buffer_entry_count * entry_size; + } + + std::atomic_size_t total_read_entry = 0; + + if (memory_offset >= memory_entry_stride) + WHOLEMEMORY_ERROR("memory offset %lu should be less than memory entry stride %lu.", + memory_offset, + memory_entry_stride); + size_t total_file_sizes = 0; + for (int i = 0; i < file_count; i++) + total_file_sizes += file_sizes[i]; + size_t total_file_entry_count = total_file_sizes / entry_size; + if (round_robin_size <= 0 || round_robin_size > total_file_entry_count / wm_world_size) + WHOLEMEMORY_ERROR("illegal round_robin_size."); + + char* local_write_ptr = local_ptr + memory_offset % memory_entry_stride; + + size_t local_entry_memory_start_index = wm_rank * round_robin_size; + size_t local_entry_file_start_index = + local_entry_memory_start_index - memory_offset / memory_entry_stride; + int extra_entry = total_file_entry_count % (wm_world_size * round_robin_size); + int local_extra_entry = (extra_entry > (wm_rank + 1) * round_robin_size) + ? round_robin_size + : extra_entry - wm_rank * round_robin_size; + local_extra_entry = local_extra_entry > 0 ? 
local_extra_entry : 0; + size_t local_entry_count = + total_file_entry_count / (wm_world_size * round_robin_size) * round_robin_size; + + if (wm_rank == 0) { + local_entry_count -= memory_offset / memory_entry_stride; + local_write_ptr += (memory_offset / memory_entry_stride) * memory_entry_stride; + } + + int64_t local_round_robin_count = local_entry_count / round_robin_size; + + auto read_file_thread_fun = [=, &total_read_entry](int thread_id, int thread_num) { + WM_CUDA_CHECK(cudaSetDevice(dev_id)); + std::vector file_read_buffer(buffer_size); + + int64_t round_robin_count_per_thread = (local_round_robin_count + thread_num - 1) / thread_num; + int64_t round_robin_count_this_thread = + std::max(0L, + std::min(round_robin_count_per_thread, + local_round_robin_count - round_robin_count_per_thread * thread_id)); + int64_t local_entry_count_this_thread = round_robin_count_this_thread * round_robin_size; + if (thread_id == thread_num - 1) { + // last thread + local_entry_count_this_thread += local_extra_entry; + } + + if (local_entry_count_this_thread == 0) return; + int64_t start_round_robin_id_in_local = thread_id * round_robin_count_per_thread; + + if (round_robin_count_this_thread == 0) { + // last thread + if (round_robin_count_per_thread != 1) { + WHOLEMEMORY_ERROR("round_robin_count_per_thread should be 1,but get %d \n", + round_robin_count_per_thread); + } + start_round_robin_id_in_local = local_round_robin_count; + } + + size_t local_entry_file_start_index_this_thread = + local_entry_file_start_index + + start_round_robin_id_in_local * wm_world_size * round_robin_size; + char* this_thread_write_ptr = + local_write_ptr + start_round_robin_id_in_local * round_robin_size * memory_entry_stride; + + size_t total_read_entry_this_thread = 0; + size_t next_entry_gap = local_entry_file_start_index_this_thread; + size_t next_continuous_entry_count = + round_robin_size > local_entry_count_this_thread - total_read_entry_this_thread + ? 
local_entry_count_this_thread - total_read_entry_this_thread + : round_robin_size; + size_t read_file_begin_entry_off = 0; + for (int i = 0; i < file_count; i++) { + size_t file_entry_count = file_sizes[i] / entry_size; + if (file_entry_count <= next_entry_gap) { + next_entry_gap -= file_entry_count; + continue; + } + size_t read_size_from_cur_file = 0; + read_file_begin_entry_off = 0; + //$open file get fp + FILE* fp = fopen(file_names[i], "rb"); + if (fp == nullptr) { WHOLEMEMORY_ERROR("Open file %s for read failed.", file_names[i]); } + /*|***read_file_begin_entry_off***|***entry_gap***|***cur_file_read_entry_count***|******|*/ + + while (read_file_begin_entry_off < file_entry_count) { + //$fseek by remain_entry_gap + if (read_file_begin_entry_off + next_entry_gap >= file_entry_count) { + next_entry_gap = (read_file_begin_entry_off + next_entry_gap) - file_entry_count; + break; + } + size_t file_read_start_offset = next_entry_gap * entry_size; + if (fseeko(fp, file_read_start_offset, SEEK_CUR) != 0) { + WHOLEMEMORY_ERROR("File %s seek to %ld failed.", file_names[i], file_read_start_offset); + } + + size_t cur_file_read_entry_count; + if (read_file_begin_entry_off + next_entry_gap + next_continuous_entry_count > + file_entry_count) { + cur_file_read_entry_count = file_entry_count - read_file_begin_entry_off - next_entry_gap; + total_read_entry_this_thread += cur_file_read_entry_count; + read_file_begin_entry_off = file_entry_count; + next_continuous_entry_count -= cur_file_read_entry_count; + next_entry_gap = 0; + } else { + cur_file_read_entry_count = next_continuous_entry_count; + total_read_entry_this_thread += cur_file_read_entry_count; + read_file_begin_entry_off += cur_file_read_entry_count + next_entry_gap; + next_continuous_entry_count = + round_robin_size > local_entry_count_this_thread - total_read_entry_this_thread + ? 
local_entry_count_this_thread - total_read_entry_this_thread + : round_robin_size; + next_entry_gap = (wm_world_size - 1) * round_robin_size; + } + read_size_from_cur_file += cur_file_read_entry_count * entry_size; + // read cur_file_read_entry_count of embeddings + size_t cur_file_read_entry = cur_file_read_entry_count; + while (cur_file_read_entry_count > 0) { + size_t read_entry_count = std::min(cur_file_read_entry_count, buffer_entry_count); + int ret = fread(file_read_buffer.data(), entry_size, read_entry_count, fp); + if (ret != read_entry_count) { + WHOLEMEMORY_ERROR( + "File %s line %d: reading from file %s, read_entry_count=%ld, entry_size=%ld, " + "returned %d, error=%s\n", + __FILE__, + __LINE__, + file_names[i], + read_entry_count, + entry_size, + ret, + strerror(errno)); + } + if (entry_size != memory_entry_stride) { + WM_CUDA_CHECK(cudaMemcpy2D(this_thread_write_ptr, + memory_entry_stride, + file_read_buffer.data(), + entry_size, + entry_size, + read_entry_count, + cudaMemcpyDefault)); + } else { + WM_CUDA_CHECK(cudaMemcpy(this_thread_write_ptr, + file_read_buffer.data(), + read_entry_count * entry_size, + cudaMemcpyDefault)); + } + this_thread_write_ptr += read_entry_count * memory_entry_stride; + cur_file_read_entry_count -= read_entry_count; + } + if (total_read_entry_this_thread > local_entry_count_this_thread) { + WHOLEMEMORY_ERROR( + "file read error from rank %d, thread_id %d, should read %lu entries, infact %lu " + "entries.", + wm_rank, + thread_id, + local_entry_count, + local_entry_count_this_thread); + break; + } else if (total_read_entry_this_thread == local_entry_count_this_thread) { + break; + } + } + + fclose(fp); + WHOLEMEMORY_INFO("Rank=%d thread_id=%d ,done Reading %ld bytes from file %s size=%ld", + wm_rank, + thread_id, + read_size_from_cur_file, + file_names[i], + file_sizes[i]); + + if (total_read_entry_this_thread == local_entry_count_this_thread) break; + } + total_read_entry.fetch_add(total_read_entry_this_thread); + }; + + WHOLEMEMORY_INFO("Rank=%d use %d threads to read file.", wm_rank, threads_per_rank); + + if (threads_per_rank > 1) { + std::vector read_file_threads; + read_file_threads.reserve(threads_per_rank); + for (int i = 0; i < threads_per_rank; i++) { + read_file_threads.emplace_back(read_file_thread_fun, i, threads_per_rank); + } + + for (auto&& thread : read_file_threads) { + thread.join(); + } + } else { + read_file_thread_fun(0, 1); + } + + WHOLEMEMORY_INFO("Rank=%d done Reading %ld entries, infact read %ld entries", + wm_rank, + total_read_entry.load(), + local_entry_count); +}; + +/*! + * Read from file list to local memory of WholeMemory. File list are binary files, which are + * considered to be concatenated together. All ranks in WholeMemory will read the files in parallel + * and load each part into local memory of each rank. + * @param local_ptr : Pointer to local memory of WholeMemory + * @param local_size : Local memory size + * @param local_offset : The offset of local memory in WholeMemory. + * @param entry_size : The entry size of each data entry. + * @param memory_entry_stride : The stride of each entry in WholeMemory + * @param memory_offset : The start offset to place the read data. Should be in range [0, + * memory_entry_stride) + * @param file_count : Total file count of the file list + * @param file_names : File names of the file list. + * @param file_sizes : Sizes of each file. + * @param suggested_buffer_size : Suggested buffer size to read. + * @param wm_rank : WholeMemory rank. 
+ * @param wm_world_size : WholeMemory world size. + * @param dev_id : the device bound to the rank. + */ +static void read_file_list_to_local_memory_with_multi_threads(char* local_ptr, + size_t local_size, + size_t local_offset, + size_t entry_size, + size_t memory_entry_stride, + size_t memory_offset, + int file_count, + const char** file_names, + const std::vector& file_sizes, + size_t suggested_buffer_size, + int wm_rank, + int wm_world_size, + int dev_id) +{ + int threads_per_rank = 1; + const char* threads_per_rank_env_var = std::getenv("WG_LOAD_THREADS_PER_RANK"); + if (threads_per_rank_env_var != nullptr) { + try { + threads_per_rank = std::stoi(threads_per_rank_env_var); + } catch (const std::invalid_argument& e) { + threads_per_rank = 1; + WHOLEMEMORY_WARN( + "Environment variable WG_LOAD_THREADS_PER_RANK value %s is not valid,use the default %d", + threads_per_rank_env_var, + threads_per_rank); + } + if (threads_per_rank < 1 || threads_per_rank > std::thread::hardware_concurrency()) { + threads_per_rank = 1; + WHOLEMEMORY_WARN( + "Environment variable WG_LOAD_THREADS_PER_RANK value %s is not valid,use the default %d", + threads_per_rank_env_var, + threads_per_rank); + } + } + size_t buffer_size; + size_t buffer_entry_count = 1; + if (suggested_buffer_size < entry_size) { + buffer_size = entry_size; + } else { + buffer_entry_count = suggested_buffer_size / entry_size; + buffer_size = buffer_entry_count * entry_size; + } + + size_t local_entry_memory_start_index = local_offset / memory_entry_stride; + size_t local_entry_file_start_index = + local_entry_memory_start_index - memory_offset / memory_entry_stride; + size_t local_entry_count = local_size / memory_entry_stride; + char* local_write_ptr = local_ptr + memory_offset % memory_entry_stride; + if (wm_rank == 0) { + local_entry_count -= memory_offset / memory_entry_stride; + local_write_ptr += (memory_offset / memory_entry_stride) * memory_entry_stride; + } + std::atomic_size_t total_read_bytes = 0; + + auto read_file_thread_fun = [=, &total_read_bytes](int thread_id, int thread_num) { + WM_CUDA_CHECK(cudaSetDevice(dev_id)); + const size_t entry_count_per_thread = (local_entry_count + thread_num - 1) / thread_num; + const size_t entry_count_this_thread = + std::min(entry_count_per_thread, local_entry_count - entry_count_per_thread * thread_id); + const size_t entry_file_start_index_this_thread = + local_entry_file_start_index + thread_id * entry_count_per_thread; + char* this_thread_write_ptr = + local_write_ptr + entry_count_per_thread * thread_id * memory_entry_stride; + + std::vector file_read_buffer(buffer_size); + + if (entry_count_this_thread <= 0) return; + size_t file_entry_offset = 0; + size_t read_size_this_thread = 0; + + for (int i = 0; i < file_count; i++) { + size_t file_entry_count = file_sizes[i] / entry_size; + // already outside reading window + if (file_entry_offset >= (entry_file_start_index_this_thread + entry_count_this_thread)) + break; + + // in reading window + if (file_entry_offset + file_entry_count > entry_file_start_index_this_thread) { + size_t file_read_start_offset = 0; + FILE* fp = fopen(file_names[i], "rb"); + if (fp == nullptr) { WHOLEMEMORY_ERROR("Open file %s for read failed.", file_names[i]); } + // maybe in window end, remove possible tailing data that don't belong to current rank. 
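+ // This thread's window covers entries [entry_file_start_index_this_thread,
+ //  entry_file_start_index_this_thread + entry_count_this_thread); clamp the per-file
+ //  read count so it never runs past the end of that window.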
+ size_t to_read_file_entry_count = std::min( + file_entry_count, + entry_file_start_index_this_thread + entry_count_this_thread - file_entry_offset); + // if in window begin, remove possible data that belongs to previous rank and skip disk + // data. + if (file_entry_offset < entry_file_start_index_this_thread) { + size_t skip_entry_count = entry_file_start_index_this_thread - file_entry_offset; + + file_read_start_offset = skip_entry_count * entry_size; + + if (fseeko(fp, file_read_start_offset, SEEK_SET) != 0) { + WHOLEMEMORY_ERROR( + "File %s seek to %ld failed.", file_names[i], skip_entry_count * entry_size); + } + to_read_file_entry_count -= skip_entry_count; + } + // now all data in file_entry_count need to be read. + size_t bytes_to_read = to_read_file_entry_count * entry_size; + size_t left_entry_count = to_read_file_entry_count; + while (left_entry_count > 0) { + size_t read_entry_count = std::min(left_entry_count, buffer_entry_count); + + int ret = fread(file_read_buffer.data(), entry_size, read_entry_count, fp); + if (ret != read_entry_count) { + WHOLEMEMORY_ERROR( + "File %s line %d: reading from file %s, read_entry_count=%ld, entry_size=%ld, " + "returned %d, error=%s\n", + __FILE__, + __LINE__, + file_names[i], + read_entry_count, + entry_size, + ret, + strerror(errno)); + } + + if (entry_size != memory_entry_stride) { + WM_CUDA_CHECK(cudaMemcpy2D(this_thread_write_ptr, + memory_entry_stride, + file_read_buffer.data(), + entry_size, + entry_size, + read_entry_count, + cudaMemcpyDefault)); + } else { + WHOLEMEMORY_INFO( + "Rank:%d, threadid:%d, cuda Memcpy : this_thread_write_ptr:%p, " + "file_read_buffer.data():%p, read_entry_count:%d, entry_size:%d\n", + wm_rank, + thread_id, + this_thread_write_ptr, + file_read_buffer.data(), + read_entry_count, + entry_size); + WM_CUDA_CHECK(cudaMemcpy(this_thread_write_ptr, + file_read_buffer.data(), + read_entry_count * entry_size, + cudaMemcpyDefault)); + } + this_thread_write_ptr += read_entry_count * memory_entry_stride; + + left_entry_count -= read_entry_count; + } + fclose(fp); + WHOLEMEMORY_INFO( + "Rank=%d thread_id=%d done Reading %ld bytes from file %s size=%ld, starting from " + "offset=%ld.", + wm_rank, + thread_id, + bytes_to_read, + file_names[i], + file_sizes[i], + file_read_start_offset); + total_read_bytes.fetch_add(bytes_to_read); + read_size_this_thread += bytes_to_read; + } + + file_entry_offset += file_entry_count; + } + + WHOLEMEMORY_INFO("Rank=%d thread_id=%d done Reading %ld bytes from needed files size.", + wm_rank, + thread_id, + read_size_this_thread); + }; + WHOLEMEMORY_INFO("Rank=%d use %d threads to read file.", wm_rank, threads_per_rank); + + if (threads_per_rank > 1) { + std::vector read_file_threads; + read_file_threads.reserve(threads_per_rank); + for (int i = 0; i < threads_per_rank; i++) { + read_file_threads.emplace_back(read_file_thread_fun, i, threads_per_rank); + } + + for (auto&& thread : read_file_threads) { + thread.join(); + } + } else { + read_file_thread_fun(0, 1); + } + WHOLEMEMORY_INFO( + "Rank=%d done reading total %ld bytes from needed files.", wm_rank, total_read_bytes.load()); +} + /*! * Read from file list to local memory of WholeMemory using DirectIO. Using DirectIO may have better * performance by bypassing system cache if it is bottleneck. File list are binary files, which are @@ -403,7 +878,7 @@ static void read_file_list_to_local_memory(char* local_ptr, * @param suggested_buffer_size : Suggested buffer size to read. * @param wm_rank : WholeMemory rank. 
* @param wm_world_size : WholeMemory world size. - * @param round_robin_size : continuous embedding size of a rank using round robin shard stratehy. + * @param round_robin_size : continuous embedding size of a rank using round robin shard strategy. */ static void read_file_list_to_local_memory_roundrobin_directio( char* local_ptr, @@ -475,15 +950,7 @@ static void read_file_list_to_local_memory_roundrobin_directio( WHOLEMEMORY_FAIL_NOTHROW( "block_size=%ld for file %s, but alignment is %ld", block_size, file_names[i], kAlignSize); } - /*if ((round_robin_size * entry_size) % block_size != 0) { - WHOLEMEMORY_FAIL_NOTHROW( - "per rank round-robin size (%d x %ld) is not mutiple of block_size (%d)for file %d.", - round_robin_size, - entry_size, - block_size, - i - ); - }*/ + size_t buffer_block_count = suggested_buffer_size / block_size; int fd = open(file_names[i], O_DIRECT | O_RDONLY); if (fd < 0) { WHOLEMEMORY_FAIL_NOTHROW("Open file %s with direct io failed.", file_names[i]); } @@ -813,6 +1280,583 @@ static void read_file_list_to_local_memory_directio(char* local_ptr, free(block_buffer); } +/*! + * Read from file list to local memory of WholeMemory using DirectIO. Using DirectIO may have better + * performance by bypassing system cache if it is bottleneck. File list are binary files, which are + * considered to be concatenated together. All ranks in WholeMemory will read the files in parallel + * and load each part into local memory of each rank. + * @param local_ptr : Pointer to local memory of WholeMemory + * @param local_size : Local memory size + * @param local_offset : The offset of local memory in WholeMemory. + * @param entry_size : The entry size of each data entry. + * @param memory_entry_stride : The stride of each entry in WholeMemory + * @param memory_offset : The start offset to place the read data. Should be in range [0, + * memory_entry_stride) + * @param file_count : Total file count of the file list + * @param file_names : File names of the file list. + * @param file_sizes : Sizes of each file. + * @param suggested_buffer_size : Suggested buffer size to read. + * @param wm_rank : WholeMemory rank. + * @param wm_world_size : WholeMemory world size. + * @param dev_id : the device bound to the rank. 
+ */ +static void read_file_list_to_local_memory_directio_with_multi_thread( + char* local_ptr, + size_t local_size, + size_t local_offset, + size_t entry_size, + size_t memory_entry_stride, + size_t memory_offset, + int file_count, + const char** file_names, + const std::vector& file_sizes, + size_t suggested_buffer_size, + int wm_rank, + int wm_world_size, + int dev_id) +{ + if (memory_offset + entry_size > memory_entry_stride) { + WHOLEMEMORY_FAIL_NOTHROW("Direct io mode only support reading all entries."); + } + size_t local_entry_start_index = local_offset / memory_entry_stride; + size_t local_entry_count = local_size / memory_entry_stride; + char* local_write_ptr = local_ptr + memory_offset % memory_entry_stride; + + static size_t kAlignSize = 16 * 1024 * 1024; + suggested_buffer_size = round_up_unsafe(suggested_buffer_size, kAlignSize); + + int threads_per_rank = 1; + const char* threads_per_rank_env_var = std::getenv("WG_LOAD_THREADS_PER_RANK"); + if (threads_per_rank_env_var != nullptr) { + try { + threads_per_rank = std::stoi(threads_per_rank_env_var); + } catch (const std::invalid_argument& e) { + threads_per_rank = 1; + WHOLEMEMORY_WARN( + "Environment variable WG_LOAD_THREADS_PER_RANK value %s is not valid,use the default %d", + threads_per_rank_env_var, + threads_per_rank); + } + if (threads_per_rank < 1 || threads_per_rank > std::thread::hardware_concurrency()) { + threads_per_rank = 1; + WHOLEMEMORY_WARN( + "Environment variable WG_LOAD_THREADS_PER_RANK value %s is not valid,use the default %d", + threads_per_rank_env_var, + threads_per_rank); + } + } + + auto read_file_thread_fun = [=](int thread_id, int thread_num) { + WM_CUDA_CHECK(cudaSetDevice(dev_id)); + + char* block_buffer; + WHOLEMEMORY_CHECK_NOTHROW(posix_memalign(reinterpret_cast(&block_buffer), + kAlignSize, + suggested_buffer_size) == 0); + + const size_t entry_count_per_thread = (local_entry_count + thread_num - 1) / thread_num; + const size_t entry_count_this_thread = + std::min(entry_count_per_thread, local_entry_count - entry_count_per_thread * thread_id); + const size_t entry_file_start_index_this_thread = + local_entry_start_index + thread_id * entry_count_per_thread; + const size_t this_thread_entry_start_index = entry_file_start_index_this_thread; + char* this_thread_write_ptr = + local_write_ptr + entry_count_per_thread * thread_id * memory_entry_stride; + + if (entry_count_this_thread <= 0) return; + + size_t file_entry_offset = 0; + size_t read_entry_count = 0; + + for (int i = 0; i < file_count; i++) { + size_t file_entry_count = file_sizes[i] / entry_size; + // already outside reading window + if (file_entry_offset >= this_thread_entry_start_index + entry_count_this_thread) break; + // reading window not reached + if (file_entry_offset + file_entry_count <= this_thread_entry_start_index) { + file_entry_offset += file_entry_count; + continue; + } + // in reading window + auto block_size = StatFileBlockSize(file_names[i]); + if (block_size == 0 || block_size == (size_t)-1 || kAlignSize % block_size != 0) { + WHOLEMEMORY_FAIL_NOTHROW("block_size=%ld for file %s, but alignment is %ld", + block_size, + file_names[i], + kAlignSize); + } + size_t buffer_block_count = suggested_buffer_size / block_size; + int fd = open(file_names[i], O_DIRECT | O_RDONLY); + if (fd < 0) { + WHOLEMEMORY_FAIL_NOTHROW("Open file %s with direct io failed.", file_names[i]); + } + + // maybe in window end, remove possible tailing data that don't belong to current rank. 
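+        // O_DIRECT reads must start on a block_size boundary, so the loop below rounds
+        // the first byte of interest down to a block boundary, then discards
+        // skip_head_size leading bytes (and any drop_tail_size trailing bytes past
+        // file_read_end) from each buffer before copying entries into WholeMemory.
+        // Example with illustrative numbers: block_size = 4096 and a logical start at
+        // byte 5120 give file_block_read_offset = 4096 and skip_head_size = 1024.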
+ size_t to_read_file_entry_count = + std::min(file_entry_count, + this_thread_entry_start_index + entry_count_this_thread - file_entry_offset); + + size_t file_read_end = to_read_file_entry_count * entry_size; + // if in window begin, remove possible data that belongs to previous rank and skip disk + // data. + size_t file_read_start = 0; + if (file_entry_offset < this_thread_entry_start_index) { + size_t skip_entry_count = this_thread_entry_start_index - file_entry_offset; + to_read_file_entry_count -= skip_entry_count; + file_read_start = skip_entry_count * entry_size; + } + + size_t file_block_read_offset = file_read_start / block_size * block_size; + size_t skip_head_size = file_read_start - file_block_read_offset; + + char* local_mem_write_entry_for_file = + this_thread_write_ptr + read_entry_count * memory_entry_stride; + size_t first_mem_entry_offset = 0; + size_t useful_data_bytes_read = 0; + size_t physical_data_bytes_read = 0; + while (file_block_read_offset < file_read_end) { + size_t left_size = file_read_end - file_block_read_offset; + size_t left_block_count = div_rounding_up_unsafe(left_size, block_size); + size_t read_block_count = std::min(left_block_count, buffer_block_count); + size_t physical_read_size = read_block_count * block_size; + physical_data_bytes_read += physical_read_size; + + ssize_t pread_size = pread64(fd, block_buffer, physical_read_size, file_block_read_offset); + if (pread_size != physical_read_size && + file_block_read_offset + pread_size != file_sizes[i]) { + WHOLEMEMORY_FAIL_NOTHROW( + "rank=%d, pread_size=%ld, physical_read_size=%ld, file_block_read_offset=%ld, " + "file_sizes[i]=%ld, file=%s", + wm_rank, + pread_size, + physical_read_size, + file_block_read_offset, + file_sizes[i], + file_names[i]); + } + physical_read_size = pread_size; + + size_t drop_tail_size = 0; + if (file_block_read_offset + physical_read_size > file_read_end) { + drop_tail_size = file_block_read_offset + physical_read_size - file_read_end; + } + + char* useful_data_ptr = block_buffer + skip_head_size; + size_t useful_data_size = physical_read_size - skip_head_size - drop_tail_size; + + useful_data_bytes_read += useful_data_size; + + if (first_mem_entry_offset != 0) { + // process head + size_t entry_left_size = entry_size - first_mem_entry_offset; + WM_CUDA_CHECK_NO_THROW(cudaMemcpy(local_mem_write_entry_for_file + first_mem_entry_offset, + useful_data_ptr, + entry_left_size, + cudaMemcpyDefault)); + local_mem_write_entry_for_file += memory_entry_stride; + useful_data_ptr += entry_left_size; + useful_data_size -= entry_left_size; + entry_left_size = 0; + } + + size_t full_entry_count = useful_data_size / entry_size; + size_t full_entry_size = full_entry_count * entry_size; + + if (full_entry_size > 0) { + if (entry_size != memory_entry_stride) { + WM_CUDA_CHECK(cudaMemcpy2D(local_mem_write_entry_for_file, + memory_entry_stride, + useful_data_ptr, + entry_size, + entry_size, + full_entry_count, + cudaMemcpyDefault)); + } else { + WM_CUDA_CHECK(cudaMemcpy( + local_mem_write_entry_for_file, useful_data_ptr, full_entry_size, cudaMemcpyDefault)); + } + local_mem_write_entry_for_file += memory_entry_stride * full_entry_count; + useful_data_ptr += full_entry_size; + useful_data_size -= full_entry_size; + } + + size_t tail_entry_size = useful_data_size % entry_size; + first_mem_entry_offset = tail_entry_size; + + if (tail_entry_size != 0) { + // process tail + WM_CUDA_CHECK_NO_THROW(cudaMemcpy( + local_mem_write_entry_for_file, useful_data_ptr, tail_entry_size, 
cudaMemcpyDefault)); + // first_mem_entry_offset = tail_entry_size; + } + + file_block_read_offset += physical_read_size; + skip_head_size = 0; + } + + WHOLEMEMORY_INFO( + "Rank=%d threadid=%d done Reading %ld useful bytes by reading %ld block bytes using " + "DirectIO from file " + "%s size=%ld.", + wm_rank, + thread_id, + useful_data_bytes_read, + physical_data_bytes_read, + file_names[i], + file_sizes[i]); + + close(fd); + file_entry_offset += file_entry_count; + read_entry_count += to_read_file_entry_count; + } + free(block_buffer); + }; + + if (threads_per_rank != 1) { + std::vector read_file_threads; + read_file_threads.reserve(threads_per_rank); + for (int i = 0; i < threads_per_rank; i++) { + read_file_threads.emplace_back(read_file_thread_fun, i, threads_per_rank); + } + + for (auto&& thread : read_file_threads) { + thread.join(); + } + } else { + read_file_thread_fun(0, 1); + } +} + +/*! + * Read from file list to local memory of WholeMemory using DirectIO. Using DirectIO may have better + * performance by bypassing system cache if it is bottleneck. File list are binary files, which are + * considered to be concatenated together. All ranks in WholeMemory will read the files in parallel + * and load each part into local memory of each rank. Wholememory uses round-robin sharding strategy + * here. + * @param local_ptr : Pointer to local memory of WholeMemory + * @param local_size : Local memory size + * @param local_offset : The offset of local memory in WholeMemory. + * @param entry_size : The entry size of each data entry. + * @param memory_entry_stride : The stride of each entry in WholeMemory + * @param memory_offset : The start offset to place the read data. Should be in range [0, + * memory_entry_stride) + * @param file_count : Total file count of the file list + * @param file_names : File names of the file list. + * @param file_sizes : Sizes of each file. + * @param suggested_buffer_size : Suggested buffer size to read. + * @param wm_rank : WholeMemory rank. + * @param wm_world_size : WholeMemory world size. + * @param round_robin_size : continuous embedding size of a rank using round robin shard strategy. + * @param dev_id : the device bound to the rank. 
+ */ +static void read_file_list_to_local_memory_roundrobin_directio_with_multi_threads( + char* local_ptr, + size_t local_size, + size_t local_offset, + size_t entry_size, + size_t memory_entry_stride, + size_t memory_offset, + int file_count, + const char** file_names, + const std::vector& file_sizes, + size_t suggested_buffer_size, + int wm_rank, + int wm_world_size, + int round_robin_size, + int dev_id) +{ + int threads_per_rank = 1; + const char* threads_per_rank_env_var = std::getenv("WG_LOAD_THREADS_PER_RANK"); + if (threads_per_rank_env_var != nullptr) { + try { + threads_per_rank = std::stoi(threads_per_rank_env_var); + } catch (const std::invalid_argument& e) { + threads_per_rank = 1; + WHOLEMEMORY_WARN( + "Environment variable WG_LOAD_THREADS_PER_RANK value %s is not valid,use the default %d", + threads_per_rank_env_var, + threads_per_rank); + } + if (threads_per_rank < 1 || threads_per_rank > std::thread::hardware_concurrency()) { + threads_per_rank = 1; + WHOLEMEMORY_WARN( + "Environment variable WG_LOAD_THREADS_PER_RANK value %s is not valid,use the default %d", + threads_per_rank_env_var, + threads_per_rank); + } + } + + if (memory_offset + entry_size > memory_entry_stride) + WHOLEMEMORY_FAIL_NOTHROW("Direct io mode only support reading all entries."); + + static size_t kAlignSize = 16 * 1024 * 1024; + suggested_buffer_size = round_up_unsafe(suggested_buffer_size, kAlignSize); + + size_t total_file_sizes = 0; + for (int i = 0; i < file_count; i++) + total_file_sizes += file_sizes[i]; + size_t total_file_entry_count = total_file_sizes / entry_size; + if (round_robin_size <= 0 || round_robin_size > total_file_entry_count / wm_world_size) + WHOLEMEMORY_ERROR("illegal round_robin_size."); + char* local_write_ptr = local_ptr + memory_offset % memory_entry_stride; + + size_t local_entry_memory_start_index = wm_rank * round_robin_size; + size_t local_entry_file_start_index = + local_entry_memory_start_index - memory_offset / memory_entry_stride; + + int extra_entry = total_file_entry_count % (wm_world_size * round_robin_size); + int local_extra_entry = (extra_entry > (wm_rank + 1) * round_robin_size) + ? round_robin_size + : extra_entry - wm_rank * round_robin_size; + local_extra_entry = local_extra_entry > 0 ? 
local_extra_entry : 0; + size_t local_entry_count = + total_file_entry_count / (wm_world_size * round_robin_size) * round_robin_size; + std::atomic_size_t total_read_entry = 0; + if (wm_rank == 0) { + local_entry_count -= memory_offset / memory_entry_stride; + local_write_ptr += (memory_offset / memory_entry_stride) * memory_entry_stride; + } + + int64_t local_round_robin_count = local_entry_count / round_robin_size; + + auto read_file_thread_fun = [=, &total_read_entry](int thread_id, int thread_num) { + WM_CUDA_CHECK(cudaSetDevice(dev_id)); + + char* block_buffer; + WHOLEMEMORY_CHECK_NOTHROW(posix_memalign(reinterpret_cast<void**>(&block_buffer), + kAlignSize, + suggested_buffer_size) == 0); + int64_t round_robin_count_per_thread = (local_round_robin_count + thread_num - 1) / thread_num; + int64_t round_robin_count_this_thread = + std::max(0L, + std::min(round_robin_count_per_thread, + local_round_robin_count - round_robin_count_per_thread * thread_id)); + int64_t local_entry_count_this_thread = round_robin_count_this_thread * round_robin_size; + if (thread_id == thread_num - 1) { + // last thread + local_entry_count_this_thread += local_extra_entry; + } + + if (local_entry_count_this_thread == 0) return; + int64_t start_round_robin_id_in_local = thread_id * round_robin_count_per_thread; + + if (round_robin_count_this_thread == 0) { + // last thread + if (round_robin_count_per_thread != 1) { + WHOLEMEMORY_ERROR("round_robin_count_per_thread should be 1, but got %ld\n", + round_robin_count_per_thread); + } + start_round_robin_id_in_local = local_round_robin_count; + } + + size_t local_entry_file_start_index_this_thread = + local_entry_file_start_index + + start_round_robin_id_in_local * wm_world_size * round_robin_size; + char* this_thread_write_ptr = + local_write_ptr + start_round_robin_id_in_local * round_robin_size * memory_entry_stride; + + size_t total_read_entry_this_thread = 0; + + size_t next_entry_gap = local_entry_file_start_index_this_thread; + size_t next_continuous_entry_count = + round_robin_size > local_entry_count_this_thread - total_read_entry_this_thread + ?
local_entry_count_this_thread - total_read_entry_this_thread + : round_robin_size; + size_t read_file_begin_entry_off = 0; + for (int i = 0; i < file_count; i++) { + size_t file_entry_count = file_sizes[i] / entry_size; + if (file_entry_count <= next_entry_gap) { + next_entry_gap -= file_entry_count; + continue; + } + + auto block_size = StatFileBlockSize(file_names[i]); + if (block_size == 0 || block_size == (size_t)-1 || kAlignSize % block_size != 0) { + WHOLEMEMORY_FAIL_NOTHROW("block_size=%ld for file %s, but alignment is %ld", + block_size, + file_names[i], + kAlignSize); + } + + size_t buffer_block_count = suggested_buffer_size / block_size; + int fd = open(file_names[i], O_DIRECT | O_RDONLY); + if (fd < 0) { + WHOLEMEMORY_FAIL_NOTHROW("Open file %s with direct io failed.", file_names[i]); + } + + size_t read_size_from_cur_file = 0; + size_t useful_data_bytes_read = 0; + read_file_begin_entry_off = 0; + + /*|***read_file_begin_entry_off***|***entry_gap***|***cur_file_read_entry_count***|******|*/ + while (read_file_begin_entry_off < file_entry_count) { + if (read_file_begin_entry_off + next_entry_gap >= file_entry_count) { + next_entry_gap = (read_file_begin_entry_off + next_entry_gap) - file_entry_count; + break; + } + size_t cur_file_read_entry_count; + if (read_file_begin_entry_off + next_entry_gap + next_continuous_entry_count > + file_entry_count) { + cur_file_read_entry_count = file_entry_count - read_file_begin_entry_off - next_entry_gap; + } else { + cur_file_read_entry_count = next_continuous_entry_count; + } + + // read concerned vars + size_t cur_read_entry_start = read_file_begin_entry_off + next_entry_gap; + size_t cur_read_byte_start = (cur_read_entry_start * entry_size) / block_size * block_size; + size_t cur_read_byte_end = (cur_read_entry_start + cur_file_read_entry_count) * entry_size; + size_t skip_head_size = cur_read_entry_start * entry_size - cur_read_byte_start; + // write concerned vars + char* local_mem_write_entry_for_file = + this_thread_write_ptr + total_read_entry_this_thread * memory_entry_stride; + size_t first_mem_entry_offset = 0; + + while (cur_read_byte_start < cur_read_byte_end) { + size_t left_size = cur_read_byte_end - cur_read_byte_start; + size_t left_block_count = div_rounding_up_unsafe(left_size, block_size); + size_t read_block_count = std::min(left_block_count, buffer_block_count); + size_t physical_read_size = read_block_count * block_size; + // physical_data_bytes_read += physical_read_size; + read_size_from_cur_file += physical_read_size; + + ssize_t pread_size = pread64(fd, block_buffer, physical_read_size, cur_read_byte_start); + if (pread_size != physical_read_size && + cur_read_byte_start + pread_size != file_sizes[i]) { + WHOLEMEMORY_FAIL_NOTHROW( + "rank=%d, pread_size=%ld, physical_read_size=%ld, file_block_read_offset=%ld, " + "file_sizes[i]=%ld, file=%s", + wm_rank, + pread_size, + physical_read_size, + cur_read_byte_start, + file_sizes[i], + file_names[i]); + } + physical_read_size = pread_size; + size_t drop_tail_size = 0; + if (cur_read_byte_start + physical_read_size > cur_read_byte_end) { + drop_tail_size = cur_read_byte_start + physical_read_size - cur_read_byte_end; + } + + char* useful_data_ptr = block_buffer + skip_head_size; + size_t useful_data_size = physical_read_size - skip_head_size - drop_tail_size; + useful_data_bytes_read += useful_data_size; + + if (first_mem_entry_offset != 0) { + size_t entry_left_size = entry_size - first_mem_entry_offset; + WM_CUDA_CHECK_NO_THROW( + 
cudaMemcpy(local_mem_write_entry_for_file + first_mem_entry_offset, + useful_data_ptr, + entry_left_size, + cudaMemcpyDefault)); + local_mem_write_entry_for_file += memory_entry_stride; + useful_data_ptr += entry_left_size; + useful_data_size -= entry_left_size; + entry_left_size = 0; + } + + size_t full_entry_count = useful_data_size / entry_size; + size_t full_entry_size = full_entry_count * entry_size; + + if (full_entry_size > 0) { + if (entry_size != memory_entry_stride) { + WM_CUDA_CHECK(cudaMemcpy2D(local_mem_write_entry_for_file, + memory_entry_stride, + useful_data_ptr, + entry_size, + entry_size, + full_entry_count, + cudaMemcpyDefault)); + } else { + WM_CUDA_CHECK(cudaMemcpy(local_mem_write_entry_for_file, + useful_data_ptr, + full_entry_size, + cudaMemcpyDefault)); + } + local_mem_write_entry_for_file += memory_entry_stride * full_entry_count; + useful_data_ptr += full_entry_size; + useful_data_size -= full_entry_size; + } + + size_t tail_entry_size = useful_data_size % entry_size; + first_mem_entry_offset = tail_entry_size; + if (tail_entry_size != 0) { + // process tail + WM_CUDA_CHECK_NO_THROW(cudaMemcpy( + local_mem_write_entry_for_file, useful_data_ptr, tail_entry_size, cudaMemcpyDefault)); + } + + cur_read_byte_start += physical_read_size; + skip_head_size = 0; + } + + total_read_entry_this_thread += cur_file_read_entry_count; + // read_size_from_cur_file += cur_file_read_entry_count * entry_size; + if (read_file_begin_entry_off + next_entry_gap + next_continuous_entry_count > + file_entry_count) { + read_file_begin_entry_off = file_entry_count; + next_continuous_entry_count -= cur_file_read_entry_count; + next_entry_gap = 0; + } else { + read_file_begin_entry_off += cur_file_read_entry_count + next_entry_gap; + next_continuous_entry_count = + round_robin_size > local_entry_count_this_thread - total_read_entry_this_thread + ? 
local_entry_count_this_thread - total_read_entry_this_thread + : round_robin_size; + next_entry_gap = (wm_world_size - 1) * round_robin_size; + } + if (total_read_entry_this_thread > local_entry_count_this_thread) { + WHOLEMEMORY_ERROR( + "file read error from rank %d, thread_id=%d should read %lu entries, in fact %lu " + "entries.", + wm_rank, + thread_id, + local_entry_count_this_thread, + total_read_entry_this_thread); + break; + } else if (total_read_entry_this_thread == local_entry_count_this_thread) { + break; + } + } + close(fd); + WHOLEMEMORY_INFO( + "Rank=%d thread_id=%d done Reading %ld useful bytes out of %ld bytes read from " + "file %s size=%ld " + "using direct IO", + wm_rank, + thread_id, + useful_data_bytes_read, + read_size_from_cur_file, + file_names[i], + file_sizes[i]); + if (total_read_entry_this_thread == local_entry_count_this_thread) break; + } + total_read_entry.fetch_add(total_read_entry_this_thread); + }; + + WHOLEMEMORY_INFO("Rank=%d use %d threads to read file.", wm_rank, threads_per_rank); + + if (threads_per_rank > 1) { + std::vector<std::thread> read_file_threads; + read_file_threads.reserve(threads_per_rank); + for (int i = 0; i < threads_per_rank; i++) { + read_file_threads.emplace_back(read_file_thread_fun, i, threads_per_rank); + } + + for (auto&& thread : read_file_threads) { + thread.join(); + } + } else { + read_file_thread_fun(0, 1); + } + + WHOLEMEMORY_INFO("Rank=%d done Reading %ld entries, local_entry_count=%ld", + wm_rank, + total_read_entry.load(), + local_entry_count); +} + wholememory_error_code_t load_file_to_handle(wholememory_handle_t wholememory_handle, size_t memory_offset, size_t memory_entry_stride, @@ -935,59 +1979,65 @@ wholememory_error_code_t load_file_to_handle(wholememory_handle_t wholememory_ha } if (!use_direct_io) { if (round_robin_size == 0) { - read_file_list_to_local_memory(local_ptr, - local_size, - local_offset, - entry_size, - memory_entry_stride, - memory_offset, - file_count, - file_names, - file_sizes, - suggested_buffer_size, - wm_rank); + read_file_list_to_local_memory_with_multi_threads(local_ptr, + local_size, + local_offset, + entry_size, + memory_entry_stride, + memory_offset, + file_count, + file_names, + file_sizes, + suggested_buffer_size, + wm_rank, + wm_world_size, + wm_comm->dev_id); } else { - read_file_list_to_local_memory_roundrobin(local_ptr, - local_size, - local_offset, - entry_size, - memory_entry_stride, - memory_offset, - file_count, - file_names, - file_sizes, - suggested_buffer_size, - wm_rank, - wm_world_size, - round_robin_size); + read_file_list_to_local_memory_roundrobin_with_multi_threads(local_ptr, + local_size, + local_offset, + entry_size, + memory_entry_stride, + memory_offset, + file_count, + file_names, + file_sizes, + suggested_buffer_size, + wm_rank, + wm_world_size, + round_robin_size, + wm_comm->dev_id); } } else { if (round_robin_size == 0) { - read_file_list_to_local_memory_directio(local_ptr, - local_size, - local_offset, - entry_size, - memory_entry_stride, - memory_offset, - file_count, - file_names, - file_sizes, - suggested_buffer_size, - wm_rank); + read_file_list_to_local_memory_directio_with_multi_thread(local_ptr, + local_size, + local_offset, + entry_size, + memory_entry_stride, + memory_offset, + file_count, + file_names, + file_sizes, + suggested_buffer_size, + wm_rank, + wm_world_size, + wm_comm->dev_id); } else { - read_file_list_to_local_memory_roundrobin_directio(local_ptr, - local_size, - local_offset, - entry_size, - memory_entry_stride, - memory_offset, -
file_count, - file_names, - file_sizes, - suggested_buffer_size, - wm_rank, - wm_world_size, - round_robin_size); + read_file_list_to_local_memory_roundrobin_directio_with_multi_threads(local_ptr, + local_size, + local_offset, + entry_size, + memory_entry_stride, + memory_offset, + file_count, + file_names, + file_sizes, + suggested_buffer_size, + wm_rank, + wm_world_size, + round_robin_size, + wm_comm->dev_id); } } diff --git a/cpp/src/wholememory/memory_handle.cpp b/cpp/src/wholememory/memory_handle.cpp index 8024b461a..ca8b0ad75 100644 --- a/cpp/src/wholememory/memory_handle.cpp +++ b/cpp/src/wholememory/memory_handle.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -99,6 +99,12 @@ class wholememory_impl { if (local_ptr != nullptr) *local_ptr = local_partition_memory_pointer_; if (local_size != nullptr) *local_size = rank_partition_strategy_.local_mem_size; if (local_offset != nullptr) *local_offset = rank_partition_strategy_.local_mem_offset; + if (location_ == WHOLEMEMORY_ML_HOST && (type_ == WHOLEMEMORY_MT_CONTINUOUS) && + (!(comm_->is_intranode()))) { + WHOLEMEMORY_WARN( + " Multi-node continuous type wholememory can only be accessed by GPU threads but not CPU " + "threads, regardless of whether the location of wholememory is host."); + } } virtual bool get_rank_memory(void** rank_memory_ptr, size_t* rank_memory_size, diff --git a/cpp/src/wholememory/wholememory.cpp b/cpp/src/wholememory/wholememory.cpp index 59dcc89bb..180da2f01 100644 --- a/cpp/src/wholememory/wholememory.cpp +++ b/cpp/src/wholememory/wholememory.cpp @@ -261,6 +261,11 @@ wholememory_error_code_t wholememory_load_from_hdfs_file(wholememory_handle_t wh return WHOLEMEMORY_NOT_IMPLEMENTED; } +bool wholememory_is_intranode_communicator(wholememory_comm_t comm) +{ + return wholememory::is_intranode_communicator(comm); +} + bool wholememory_is_build_with_nvshmem() { #ifdef WITH_NVSHMEM_SUPPORT diff --git a/cpp/src/wholememory_ops/functions/exchange_embeddings_nccl_func.cu b/cpp/src/wholememory_ops/functions/exchange_embeddings_nccl_func.cu index 7cb96bcb4..88d7f331c 100644 --- a/cpp/src/wholememory_ops/functions/exchange_embeddings_nccl_func.cu +++ b/cpp/src/wholememory_ops/functions/exchange_embeddings_nccl_func.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -126,7 +126,7 @@ void dedup_indice_and_gradients_temp_func(int64_t* run_count, int* dev_mapping_sequence = static_cast(mapping_sequence_handle.device_malloc(raw_count * 2, WHOLEMEMORY_DT_INT)); int* dev_indice_mapping = dev_mapping_sequence + raw_count; - thrust::sequence(thrust::cuda::par(allocator).on(stream), + thrust::sequence(thrust::cuda::par_nosync(allocator).on(stream), dev_mapping_sequence, dev_mapping_sequence + raw_count, 0); diff --git a/cpp/src/wholememory_ops/functions/exchange_ids_nccl_func.cu b/cpp/src/wholememory_ops/functions/exchange_ids_nccl_func.cu index 53df31be0..137b10470 100644 --- a/cpp/src/wholememory_ops/functions/exchange_ids_nccl_func.cu +++ b/cpp/src/wholememory_ops/functions/exchange_ids_nccl_func.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -59,8 +59,10 @@ void exchange_ids_temp_func(const void* indices_before_sort, int64_t* seq_indices = reinterpret_cast(allocator.allocate( wholememory_get_memory_element_count_from_array(&indices_desc) * sizeof(int64_t))); - thrust::sequence( - thrust::cuda::par(allocator).on(stream), seq_indices, seq_indices + indices_desc.size, 0); + thrust::sequence(thrust::cuda::par_nosync(allocator).on(stream), + seq_indices, + seq_indices + indices_desc.size, + 0); // use UTypeT to put minus indices at last. using UTypeT = typename UnsignedType::UType; const UTypeT* indices_to_sort = static_cast(indices_before_sort); diff --git a/cpp/src/wholememory_ops/functions/gather_func.cu b/cpp/src/wholememory_ops/functions/gather_func.cu index 0b79f0f15..271245d78 100644 --- a/cpp/src/wholememory_ops/functions/gather_func.cu +++ b/cpp/src/wholememory_ops/functions/gather_func.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,6 +24,8 @@ wholememory_error_code_t gather_integer_int32_func(wholememory_gref_t embedding_ wholememory_matrix_description_t embedding_desc, void* indices, wholememory_array_description_t indices_desc, + bool gather_with_sorted_ids, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, @@ -32,6 +34,8 @@ wholememory_error_code_t gather_integer_int64_func(wholememory_gref_t embedding_ wholememory_matrix_description_t embedding_desc, void* indices, wholememory_array_description_t indices_desc, + bool gather_with_sorted_ids, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, @@ -40,6 +44,8 @@ wholememory_error_code_t gather_floating_int32_func(wholememory_gref_t embedding wholememory_matrix_description_t embedding_desc, void* indices, wholememory_array_description_t indices_desc, + bool gather_with_sorted_ids, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, @@ -48,6 +54,8 @@ wholememory_error_code_t gather_floating_int64_func(wholememory_gref_t embedding wholememory_matrix_description_t embedding_desc, void* indices, wholememory_array_description_t indices_desc, + bool gather_with_sorted_ids, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, @@ -76,6 +84,75 @@ wholememory_error_code_t gather_func(wholememory_gref_t embedding_gref, wholememory_matrix_description_t, void* indices, wholememory_array_description_t, + bool, + void*, + void*, + wholememory_matrix_description_t, + cudaStream_t, + int) = nullptr; + if (embedding_is_float) { + if (indices_desc.dtype == WHOLEMEMORY_DT_INT) { + p_gather_func = gather_floating_int32_func; + } else { + p_gather_func = gather_floating_int64_func; + } + } else { + if (indices_desc.dtype == WHOLEMEMORY_DT_INT) { + p_gather_func = gather_integer_int32_func; + } else { + p_gather_func = gather_integer_int64_func; + } + } + return p_gather_func(embedding_gref, + embedding_desc, + indices, + indices_desc, + false, + nullptr, + output, + output_desc, + stream, + gather_sms); + } catch (const wholememory::cuda_error& rle) { + return WHOLEMEMORY_LOGIC_ERROR; + } catch (const wholememory::logic_error& le) { + return 
WHOLEMEMORY_LOGIC_ERROR; + } catch (...) { + return WHOLEMEMORY_LOGIC_ERROR; + } + return WHOLEMEMORY_SUCCESS; +} + +wholememory_error_code_t gather_with_sorted_ids_func( + wholememory_gref_t embedding_gref, + wholememory_matrix_description_t embedding_desc, + void* indices, + wholememory_array_description_t indices_desc, + void* raw_indices, + wholememory_array_description_t raw_indices_desc, + void* output, + wholememory_matrix_description_t output_desc, + cudaStream_t stream, + int gather_sms) +{ + try { + bool embedding_is_float = wholememory_dtype_is_floating_number(embedding_desc.dtype); + WHOLEMEMORY_CHECK(embedding_is_float || + wholememory_dtype_is_integer_number(embedding_desc.dtype)); + bool output_is_float = wholememory_dtype_is_floating_number(output_desc.dtype); + WHOLEMEMORY_CHECK(output_is_float || wholememory_dtype_is_integer_number(output_desc.dtype)); + WHOLEMEMORY_EXPECTS( + embedding_is_float == output_is_float, + "embedding and output should be same number type, e.g. floating number or integer number."); + if (indices_desc.size == 0) { return WHOLEMEMORY_SUCCESS; } + WHOLEMEMORY_CHECK(indices_desc.size == raw_indices_desc.size); + WHOLEMEMORY_CHECK(indices_desc.dtype == raw_indices_desc.dtype); + wholememory_error_code_t (*p_gather_func)(wholememory_gref_t, + wholememory_matrix_description_t, + void* indices, + wholememory_array_description_t, + bool, + void*, void*, wholememory_matrix_description_t, cudaStream_t, @@ -97,6 +174,8 @@ wholememory_error_code_t gather_func(wholememory_gref_t embedding_gref, embedding_desc, indices, indices_desc, + true, + raw_indices, output, output_desc, stream, diff --git a/cpp/src/wholememory_ops/functions/gather_func_impl_floating_data_int32_indices.cu b/cpp/src/wholememory_ops/functions/gather_func_impl_floating_data_int32_indices.cu index c7679c508..a67ac0040 100644 --- a/cpp/src/wholememory_ops/functions/gather_func_impl_floating_data_int32_indices.cu +++ b/cpp/src/wholememory_ops/functions/gather_func_impl_floating_data_int32_indices.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
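The gather_with_sorted_ids_func path added above is, at heart, a permutation gather: the requested ids are sorted once so the embedding table is walked in (mostly) ascending order, and the saved original positions are used to put each gathered row back into the output slot the caller asked for. A minimal CPU-side sketch of that idea, with illustrative names and containers rather than the library API:

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

void gather_sorted_cpu(const std::vector<std::vector<float>>& table,
                       const std::vector<int64_t>& indices,
                       std::vector<std::vector<float>>& output)
{
  // raw_pos plays the role of raw_indices: the original position of each id.
  std::vector<int64_t> raw_pos(indices.size());
  std::iota(raw_pos.begin(), raw_pos.end(), 0);
  // Sort positions by the id they refer to (cub::DeviceRadixSort::SortPairs on device).
  std::sort(raw_pos.begin(), raw_pos.end(),
            [&](int64_t a, int64_t b) { return indices[a] < indices[b]; });
  output.resize(indices.size());
  for (size_t i = 0; i < raw_pos.size(); ++i) {
    // Table rows are visited in ascending id order for locality, but each row
    // still lands at the position the caller requested.
    output[raw_pos[i]] = table[indices[raw_pos[i]]];
  }
}

The device path follows the same shape: sort_indices_func produces the sorted ids together with their original positions, and the gather kernels write each row to output[raw_indices[output_idx]] instead of output[output_idx].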
@@ -27,13 +27,23 @@ void gather_floating_int32_temp_func(wholememory_gref_t embedding_gref, wholememory_matrix_description_t embedding_desc, void* indices, int64_t indice_count, + bool gather_with_sorted_ids, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, int gather_sms) { - gather_temp_func( - embedding_gref, embedding_desc, indices, indice_count, output, output_desc, stream, gather_sms); + gather_temp_func(embedding_gref, + embedding_desc, + indices, + indice_count, + gather_with_sorted_ids, + raw_indices, + output, + output_desc, + stream, + gather_sms); } REGISTER_DISPATCH_TWO_TYPES(GatherFuncFloatingInt32, @@ -45,6 +55,8 @@ wholememory_error_code_t gather_floating_int32_func(wholememory_gref_t embedding wholememory_matrix_description_t embedding_desc, void* indices, wholememory_array_description_t indices_desc, + bool gather_with_sorted_ids, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, @@ -63,6 +75,8 @@ wholememory_error_code_t gather_floating_int32_func(wholememory_gref_t embedding static_cast(indices) + indices_desc.storage_offset * wholememory_dtype_get_element_size(indices_desc.dtype), indices_desc.size, + gather_with_sorted_ids, + raw_indices, output, output_desc, stream, diff --git a/cpp/src/wholememory_ops/functions/gather_func_impl_floating_data_int64_indices.cu b/cpp/src/wholememory_ops/functions/gather_func_impl_floating_data_int64_indices.cu index af9d6d6ec..159aaf9a6 100644 --- a/cpp/src/wholememory_ops/functions/gather_func_impl_floating_data_int64_indices.cu +++ b/cpp/src/wholememory_ops/functions/gather_func_impl_floating_data_int64_indices.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -27,13 +27,23 @@ void gather_floating_int64_temp_func(wholememory_gref_t embedding_gref, wholememory_matrix_description_t embedding_desc, void* indices, int64_t indice_count, + bool gather_with_sorted_ids, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, int gather_sms) { - gather_temp_func( - embedding_gref, embedding_desc, indices, indice_count, output, output_desc, stream, gather_sms); + gather_temp_func(embedding_gref, + embedding_desc, + indices, + indice_count, + gather_with_sorted_ids, + raw_indices, + output, + output_desc, + stream, + gather_sms); } REGISTER_DISPATCH_TWO_TYPES(GatherFuncFloatingInt64, @@ -45,6 +55,8 @@ wholememory_error_code_t gather_floating_int64_func(wholememory_gref_t embedding wholememory_matrix_description_t embedding_desc, void* indices, wholememory_array_description_t indices_desc, + bool gather_with_sorted_ids, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, @@ -63,6 +75,8 @@ wholememory_error_code_t gather_floating_int64_func(wholememory_gref_t embedding static_cast(indices) + indices_desc.storage_offset * wholememory_dtype_get_element_size(indices_desc.dtype), indices_desc.size, + gather_with_sorted_ids, + raw_indices, output, output_desc, stream, diff --git a/cpp/src/wholememory_ops/functions/gather_func_impl_integer_data_int32_indices.cu b/cpp/src/wholememory_ops/functions/gather_func_impl_integer_data_int32_indices.cu index bdb7c0be8..9943cb14b 100644 --- a/cpp/src/wholememory_ops/functions/gather_func_impl_integer_data_int32_indices.cu +++ b/cpp/src/wholememory_ops/functions/gather_func_impl_integer_data_int32_indices.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -27,13 +27,23 @@ void gather_integer_int32_temp_func(wholememory_gref_t embedding_gref, wholememory_matrix_description_t embedding_desc, void* indices, int64_t indice_count, + bool gather_with_sorted_ids, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, int gather_sms) { - gather_temp_func( - embedding_gref, embedding_desc, indices, indice_count, output, output_desc, stream, gather_sms); + gather_temp_func(embedding_gref, + embedding_desc, + indices, + indice_count, + gather_with_sorted_ids, + raw_indices, + output, + output_desc, + stream, + gather_sms); } REGISTER_DISPATCH_TWO_TYPES(GatherFuncIntegerInt32, @@ -45,6 +55,8 @@ wholememory_error_code_t gather_integer_int32_func(wholememory_gref_t embedding_ wholememory_matrix_description_t embedding_desc, void* indices, wholememory_array_description_t indices_desc, + bool gather_with_sorted_ids, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, @@ -63,6 +75,8 @@ wholememory_error_code_t gather_integer_int32_func(wholememory_gref_t embedding_ static_cast(indices) + indices_desc.storage_offset * wholememory_dtype_get_element_size(indices_desc.dtype), indices_desc.size, + gather_with_sorted_ids, + raw_indices, output, output_desc, stream, diff --git a/cpp/src/wholememory_ops/functions/gather_func_impl_integer_data_int64_indices.cu b/cpp/src/wholememory_ops/functions/gather_func_impl_integer_data_int64_indices.cu index 6a6c7f330..b06ebad9f 100644 --- a/cpp/src/wholememory_ops/functions/gather_func_impl_integer_data_int64_indices.cu +++ b/cpp/src/wholememory_ops/functions/gather_func_impl_integer_data_int64_indices.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -27,13 +27,23 @@ void gather_integer_int64_temp_func(wholememory_gref_t embedding_gref, wholememory_matrix_description_t embedding_desc, void* indices, int64_t indice_count, + bool gather_with_sorted_ids, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, int gather_sms) { - gather_temp_func( - embedding_gref, embedding_desc, indices, indice_count, output, output_desc, stream, gather_sms); + gather_temp_func(embedding_gref, + embedding_desc, + indices, + indice_count, + gather_with_sorted_ids, + raw_indices, + output, + output_desc, + stream, + gather_sms); } REGISTER_DISPATCH_TWO_TYPES(GatherFuncIntegerInt64, @@ -45,6 +55,8 @@ wholememory_error_code_t gather_integer_int64_func(wholememory_gref_t embedding_ wholememory_matrix_description_t embedding_desc, void* indices, wholememory_array_description_t indices_desc, + bool gather_with_sorted_ids, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, @@ -63,6 +75,8 @@ wholememory_error_code_t gather_integer_int64_func(wholememory_gref_t embedding_ static_cast(indices) + indices_desc.storage_offset * wholememory_dtype_get_element_size(indices_desc.dtype), indices_desc.size, + gather_with_sorted_ids, + raw_indices, output, output_desc, stream, diff --git a/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh b/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh index 87c89d9c2..a4979f7be 100644 --- a/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh +++ b/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -255,6 +255,8 @@ __global__ void gather_func_kernel(wholememory_gref_t embedding_gref, wholememory_matrix_description_t embedding_desc, const IndexT* indices, int64_t indice_count, + bool gather_with_sorted_ids, + const IndexT* raw_indices, OutputT* output, wholememory_matrix_description_t output_desc) { @@ -284,7 +286,9 @@ __global__ void gather_func_kernel(wholememory_gref_t embedding_gref, for (int64_t output_idx = warp_id; output_idx < indice_count; output_idx += gridDim.x * (blockDim.x / 32)) { - OutputT* output_ptr = output + output_desc.storage_offset + output_stride * output_idx; + int64_t raw_output_idx = + gather_with_sorted_ids ? 
(int64_t)(raw_indices[output_idx]) : output_idx; + OutputT* output_ptr = output + output_desc.storage_offset + output_stride * raw_output_idx; if (!use_shm) { my_shared = output_ptr; } int64_t embedding_table_idx = indices[output_idx]; if (embedding_table_idx < 0) continue; @@ -309,11 +313,73 @@ __global__ void gather_func_kernel(wholememory_gref_t embedding_gref, return; } +template +struct IsPowerOfTwo { + static constexpr bool value = (N > 0) && ((N & (N - 1)) == 0); +}; + +template +__global__ void gather_func_sub_warp_kernel(wholememory_gref_t embedding_gref, + wholememory_matrix_description_t embedding_desc, + const IndexT* indices, + int64_t indice_count, + bool gather_with_sorted_ids, + const IndexT* raw_indices, + OutputT* output, + wholememory_matrix_description_t output_desc) +{ + static_assert(IsPowerOfTwo::value && SUB_WARP_SIZE < 32, + "SUB_WARP_SIZE must be the power of 2,and smaller than 32."); + + auto block = cooperative_groups::this_thread_block(); + + auto subwarp = cooperative_groups::tiled_partition(block); + int sub_warp_id = subwarp.meta_group_size() * blockIdx.x + subwarp.meta_group_rank(); + int sub_warp_num = subwarp.meta_group_size() * gridDim.x; + + int lane_id_in_sub_warp = subwarp.thread_rank(); + wholememory::device_reference embedding_dev_ref(embedding_gref); + + int embedding_size = embedding_desc.sizes[1]; + int64_t embedding_stride = embedding_desc.stride; + int64_t output_stride = output_desc.stride; + + typed_data_vector embeddings; + typed_data_vector outputs; + for (int64_t output_idx = sub_warp_id; output_idx < indice_count; output_idx += sub_warp_num) { + int64_t raw_output_idx = + gather_with_sorted_ids ? (int64_t)(raw_indices[output_idx]) : output_idx; + OutputT* output_ptr = output + output_desc.storage_offset + output_stride * raw_output_idx; + IndexT embedding_table_idx = indices[output_idx]; + if (embedding_table_idx < 0) continue; + int64_t embedding_offset = + embedding_desc.storage_offset + embedding_table_idx * embedding_stride; + + for (int emb_idx = lane_id_in_sub_warp * ALIGNMENT; emb_idx < embedding_size; + emb_idx += ALIGNMENT * SUB_WARP_SIZE) { + mov_data(&embeddings, + &embedding_dev_ref[embedding_offset + emb_idx]); +#pragma unroll + for (int sub_idx = 0; sub_idx < ALIGNMENT; sub_idx++) { + typed_data_vector_at(outputs, sub_idx) = + convert_type(typed_data_vector_at(embeddings, sub_idx)); + } + mov_data(output_ptr + emb_idx, &outputs); + } + } +} + template void gather_temp_func(wholememory_gref_t embedding_gref, wholememory_matrix_description_t embedding_desc, void* indices, int64_t indice_count, + bool gather_with_sorted_ids, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, @@ -336,8 +402,11 @@ void gather_temp_func(wholememory_gref_t embedding_gref, wholememory_matrix_description_t, const IndexT*, int64_t, + bool, + const IndexT*, OutputT*, wholememory_matrix_description_t) = nullptr; + switch (alignment) { case 16: { kernel_fn = gather_func_kernel; @@ -367,10 +436,79 @@ void gather_temp_func(wholememory_gref_t embedding_gref, int block_size = 1024; int block_count = indice_count > 1568 ? 
1568 : indice_count; if (gather_sms != -1) block_count = gather_sms; + + // for small embedding size ,use subwarp to gather + int min_threads_per_embedding = embedding_desc.sizes[1] / alignment; + if (min_threads_per_embedding < 32) { +#define SWITCH_GATHER_FUNC_WITH_ALIGNMENT(KERNEL_NAME, SUB_WARP_SIZE) \ + switch (alignment) { \ + case 16: { \ + kernel_fn = KERNEL_NAME; \ + break; \ + } \ + case 8: { \ + kernel_fn = KERNEL_NAME; \ + break; \ + } \ + case 4: { \ + kernel_fn = KERNEL_NAME; \ + break; \ + } \ + case 2: { \ + kernel_fn = KERNEL_NAME; \ + break; \ + } \ + case 1: { \ + kernel_fn = KERNEL_NAME; \ + break; \ + } \ + default: { \ + WHOLEMEMORY_FAIL("gather func alignment=%d.", alignment); \ + return; \ + } \ + } + + int threads_per_embedding = 16; + if (min_threads_per_embedding >= 16) { + SWITCH_GATHER_FUNC_WITH_ALIGNMENT(gather_func_sub_warp_kernel, 16); + threads_per_embedding = 16; + } else if (min_threads_per_embedding < 16 && min_threads_per_embedding >= 8) { + SWITCH_GATHER_FUNC_WITH_ALIGNMENT(gather_func_sub_warp_kernel, 8); + threads_per_embedding = 8; + } else if (min_threads_per_embedding < 8 && min_threads_per_embedding >= 4) { + SWITCH_GATHER_FUNC_WITH_ALIGNMENT(gather_func_sub_warp_kernel, 4); + threads_per_embedding = 4; + } else if (min_threads_per_embedding < 4 && min_threads_per_embedding >= 2) { + SWITCH_GATHER_FUNC_WITH_ALIGNMENT(gather_func_sub_warp_kernel, 2); + threads_per_embedding = 2; + } else { + SWITCH_GATHER_FUNC_WITH_ALIGNMENT(gather_func_sub_warp_kernel, 1); + threads_per_embedding = 1; + } + +#undef SWITCH_GATHER_FUNC_WITH_ALIGNMENT + block_size = 128; + int max_blocks_per_sm = 8; + WM_CUDA_CHECK( + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, kernel_fn, block_size, 0)); + + int sm_count = 100; + int device_id = 0; + WM_CUDA_CHECK(cudaGetDevice(&device_id)); + WM_CUDA_CHECK(cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device_id)); + + // block_count = indice_count > 1568 ? 1568 : indice_count; + int min_embedding_per_block = block_size / threads_per_embedding; + block_count = min((int)(indice_count + min_embedding_per_block - 1) / min_embedding_per_block, + sm_count * max_blocks_per_sm * 4); + if (gather_sms != -1) block_count = gather_sms * max_blocks_per_sm; + } kernel_fn<<>>(embedding_gref, embedding_desc, static_cast(indices), indice_count, + gather_with_sorted_ids, + static_cast(raw_indices), static_cast(output), output_desc); WM_CUDA_CHECK(cudaGetLastError()); diff --git a/cpp/src/wholememory_ops/functions/gather_scatter_func.h b/cpp/src/wholememory_ops/functions/gather_scatter_func.h index 0c0b9e4a4..374ea2b39 100644 --- a/cpp/src/wholememory_ops/functions/gather_scatter_func.h +++ b/cpp/src/wholememory_ops/functions/gather_scatter_func.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
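For rows that need fewer than a full warp's worth of vector loads, the code above switches from one warp per row to a cooperative-groups tile whose width is derived from embedding_size / alignment. The selection amounts to picking the largest power of two no greater than 16 that still fits; a standalone sketch with a hypothetical helper name:

// Mirrors the switch ladder above: choose the sub-warp (tile) width used by
// gather_func_sub_warp_kernel for a given embedding width and vector alignment
// (alignment counts elements per vectorized access, not bytes).
inline int pick_sub_warp_size(int embedding_size, int alignment)
{
  const int min_threads_per_embedding = embedding_size / alignment;
  for (int tile = 16; tile > 1; tile /= 2) {
    if (min_threads_per_embedding >= tile) { return tile; }
  }
  return 1;
}

For example, a 64-element float row read with 4-element (16-byte) vectors gives 64 / 4 = 16, so the full 16-lane tile is used, while a 7-element row read element-by-element gives 7 and falls back to a 4-lane tile.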
@@ -30,6 +30,18 @@ wholememory_error_code_t gather_func(wholememory_gref_t embedding_gref, cudaStream_t stream, int gather_sms = -1); +wholememory_error_code_t gather_with_sorted_ids_func( + wholememory_gref_t embedding_gref, + wholememory_matrix_description_t embedding_desc, + void* indices, + wholememory_array_description_t indices_desc, + void* raw_indices, + wholememory_array_description_t raw_indices_desc, + void* output, + wholememory_matrix_description_t output_desc, + cudaStream_t stream, + int gather_sms); + wholememory_error_code_t scatter_func(const void* input, wholememory_matrix_description_t input_desc, void* indices, diff --git a/cpp/src/wholememory_ops/functions/map_indices_func.cu b/cpp/src/wholememory_ops/functions/map_indices_func.cu index 97d6ca868..1a1418179 100644 --- a/cpp/src/wholememory_ops/functions/map_indices_func.cu +++ b/cpp/src/wholememory_ops/functions/map_indices_func.cu @@ -58,7 +58,7 @@ void storage_idx2wm_emb_idx_temp_fn(void* indice_ptr, if (block_num > 1568) block_num = 1568; IndexT* indice = static_cast(indice_ptr); IndexT* mapped_indice = static_cast(mapped_indice_ptr); - storage_idx2wm_emb_idx_kernel<<>>( + storage_idx2wm_emb_idx_kernel<<>>( indice, mapped_indice, indice_size, world_size, entry_per_rank, round_robin_size); WM_CUDA_CHECK(cudaStreamSynchronize(stream)); return; diff --git a/cpp/src/wholememory_ops/functions/nvshmem_gather_scatter_func.cuh b/cpp/src/wholememory_ops/functions/nvshmem_gather_scatter_func.cuh index ea905cd93..a0091c31c 100644 --- a/cpp/src/wholememory_ops/functions/nvshmem_gather_scatter_func.cuh +++ b/cpp/src/wholememory_ops/functions/nvshmem_gather_scatter_func.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -80,7 +80,7 @@ void sort_index_in_pair(const void* indices_before_sort, IndexT* seq_indices = reinterpret_cast(allocator.allocate(indice_count * sizeof(IndexT))); thrust::sequence( - thrust::cuda::par(allocator).on(stream), seq_indices, seq_indices + indice_count, 0); + thrust::cuda::par_nosync(allocator).on(stream), seq_indices, seq_indices + indice_count, 0); // TODO: use unsigned type (wm_ops::UTypeT) can put all negative indices at last. But maybe // later... using UTypeT = typename UnsignedType::UType; auto indices_to_sort = static_cast(indices_before_sort); diff --git a/cpp/src/wholememory_ops/functions/sort_indices_func.cu b/cpp/src/wholememory_ops/functions/sort_indices_func.cu new file mode 100644 index 000000000..4cbbb0837 --- /dev/null +++ b/cpp/src/wholememory_ops/functions/sort_indices_func.cu @@ -0,0 +1,125 @@ +/* + * Copyright (c) 2019-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "sort_indices_func.h" + +#include +#include + +#include "cuda_macros.hpp" +#include "error.hpp" +#include "logger.hpp" +#include "wholememory_ops/register.hpp" + +namespace wholememory_ops { + +template +struct UnsignedType {}; + +template <> +struct UnsignedType { + using UType = unsigned int; +}; + +template <> +struct UnsignedType { + using UType = uint64_t; +}; + +template +void sort_indices_temp_func(const void* indices_before_sort, + wholememory_array_description_t indices_desc, + void* indices_after_sort, + void* raw_indices, + wm_thrust_allocator* p_thrust_allocator, + wholememory_env_func_t* p_env_fns, + cudaStream_t stream) +{ + auto index_type = indices_desc.dtype; + WHOLEMEMORY_CHECK(indices_desc.storage_offset == 0); + WHOLEMEMORY_CHECK(index_type == WHOLEMEMORY_DT_INT || index_type == WHOLEMEMORY_DT_INT64); + wm_thrust_allocator& allocator = *p_thrust_allocator; + + IndexT* seq_indices = reinterpret_cast(allocator.allocate( + wholememory_get_memory_element_count_from_array(&indices_desc) * sizeof(IndexT))); + thrust::sequence(thrust::cuda::par_nosync(allocator).on(stream), + seq_indices, + seq_indices + indices_desc.size, + 0); + // use UTypeT to put minus indices at last. + using UTypeT = typename UnsignedType::UType; + const UTypeT* indices_to_sort = static_cast(indices_before_sort); + UTypeT* sorted_indice = static_cast(indices_after_sort); + void* cub_temp_storage = nullptr; + size_t temp_storage_bytes = 0; + cub::DeviceRadixSort::SortPairs(cub_temp_storage, + temp_storage_bytes, + indices_to_sort, + sorted_indice, + seq_indices, + static_cast(raw_indices), + indices_desc.size, + 0, + sizeof(UTypeT) * 8, + stream); + cub_temp_storage = allocator.allocate(temp_storage_bytes); + cub::DeviceRadixSort::SortPairs(cub_temp_storage, + temp_storage_bytes, + indices_to_sort, + sorted_indice, + seq_indices, + static_cast(raw_indices), + indices_desc.size, + 0, + sizeof(UTypeT) * 8, + stream); + allocator.deallocate(reinterpret_cast(seq_indices), + wholememory_get_memory_size_from_array(&indices_desc)); + allocator.deallocate(static_cast(cub_temp_storage), temp_storage_bytes); +} + +REGISTER_DISPATCH_ONE_TYPE(SortIndices, sort_indices_temp_func, SINT3264) + +wholememory_error_code_t sort_indices_func(const void* indices_before_sort, + wholememory_array_description_t indice_desc, + void* indices_after_sort, + void* raw_indices, + wm_thrust_allocator* p_thrust_allocator, + wholememory_env_func_t* p_env_fns, + cudaStream_t stream) +{ + try { + DISPATCH_ONE_TYPE(indice_desc.dtype, + SortIndices, + indices_before_sort, + indice_desc, + indices_after_sort, + raw_indices, + p_thrust_allocator, + p_env_fns, + stream); + } catch (wholememory::cuda_error& wce) { + WHOLEMEMORY_ERROR("sort_indices_func CUDA LOGIC Error %s\n", wce.what()); + return WHOLEMEMORY_CUDA_ERROR; + } catch (wholememory::logic_error& wle) { + WHOLEMEMORY_ERROR("sort_indices_func LOGIC Error %s\n", wle.what()); + return WHOLEMEMORY_LOGIC_ERROR; + } catch (...) { + return WHOLEMEMORY_UNKNOW_ERROR; + } + return WHOLEMEMORY_SUCCESS; +} + +} // namespace wholememory_ops diff --git a/cpp/src/wholememory_ops/functions/sort_indices_func.h b/cpp/src/wholememory_ops/functions/sort_indices_func.h new file mode 100644 index 000000000..98a7932cb --- /dev/null +++ b/cpp/src/wholememory_ops/functions/sort_indices_func.h @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2019-2024, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include + +#include +#include + +namespace wholememory_ops { + +wholememory_error_code_t sort_indices_func(const void* indices_before_sort, + wholememory_array_description_t indice_desc, + void* indices_after_sort, + void* raw_indices, + wm_thrust_allocator* p_thrust_allocator, + wholememory_env_func_t* p_env_fns, + cudaStream_t stream); + +} // namespace wholememory_ops diff --git a/cpp/src/wholememory_ops/gather_op.cpp b/cpp/src/wholememory_ops/gather_op.cpp index a6b2e97b5..98d41d222 100644 --- a/cpp/src/wholememory_ops/gather_op.cpp +++ b/cpp/src/wholememory_ops/gather_op.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,11 +27,13 @@ wholememory_error_code_t wholememory_gather(wholememory_tensor_t wholememory_ten void* stream, int gather_sms) { - bool const has_handle = wholememory_tensor_has_handle(wholememory_tensor); - wholememory_memory_type_t memory_type = WHOLEMEMORY_MT_NONE; + bool const has_handle = wholememory_tensor_has_handle(wholememory_tensor); + wholememory_memory_type_t memory_type = WHOLEMEMORY_MT_NONE; + wholememory_memory_location_t memory_location = WHOLEMEMORY_ML_NONE; if (has_handle) { - memory_type = - wholememory_get_memory_type(wholememory_tensor_get_memory_handle(wholememory_tensor)); + auto memory_handle = wholememory_tensor_get_memory_handle(wholememory_tensor); + memory_type = wholememory_get_memory_type(memory_handle); + memory_location = wholememory_get_memory_location(memory_handle); } wholememory_matrix_description_t matrix_description; auto tensor_description = *wholememory_tensor_get_tensor_description(wholememory_tensor); @@ -98,12 +100,18 @@ wholememory_error_code_t wholememory_gather(wholememory_tensor_t wholememory_ten wholememory_gref_t gref; WHOLEMEMORY_RETURN_ON_FAIL(wholememory_tensor_get_global_reference(wholememory_tensor, &gref)); + int64_t entry_size = + tensor_description.sizes[1] * wholememory_dtype_get_element_size(tensor_description.dtype); + bool gather_with_sorted_ids = + (memory_location == WHOLEMEMORY_ML_HOST) && (entry_size <= 512) && + (memory_type == WHOLEMEMORY_MT_CHUNKED || memory_type == WHOLEMEMORY_MT_CONTINUOUS); return wholememory_ops::wholememory_gather_mapped(gref, matrix_description, indices, indices_desc, output, output_desc, + gather_with_sorted_ids, p_env_fns, static_cast(stream), gather_sms); diff --git a/cpp/src/wholememory_ops/gather_op_impl.h b/cpp/src/wholememory_ops/gather_op_impl.h index 6f85d6410..21896ff24 100644 --- a/cpp/src/wholememory_ops/gather_op_impl.h +++ b/cpp/src/wholememory_ops/gather_op_impl.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,6 +27,7 @@ wholememory_error_code_t wholememory_gather_mapped( wholememory_array_description_t indice_desc, void* output, wholememory_matrix_description_t output_desc, + bool gather_with_sorted_ids, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int gather_sms); diff --git a/cpp/src/wholememory_ops/gather_op_impl_mapped.cu b/cpp/src/wholememory_ops/gather_op_impl_mapped.cu index 38e64919d..849005860 100644 --- a/cpp/src/wholememory_ops/gather_op_impl_mapped.cu +++ b/cpp/src/wholememory_ops/gather_op_impl_mapped.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,6 +20,9 @@ #include "cuda_macros.hpp" #include "wholememory_ops/functions/gather_scatter_func.h" +#include "wholememory_ops/functions/sort_indices_func.h" +#include "wholememory_ops/temp_memory_handle.hpp" +#include "wholememory_ops/thrust_allocator.hpp" namespace wholememory_ops { @@ -30,18 +33,46 @@ wholememory_error_code_t wholememory_gather_mapped( wholememory_array_description_t indice_desc, void* output, wholememory_matrix_description_t output_desc, + bool gather_with_sorted_ids, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int gather_sms) { - WHOLEMEMORY_RETURN_ON_FAIL(gather_func(wholememory_gref, - wholememory_desc, - indices, - indice_desc, - output, - output_desc, - stream, - gather_sms)); + if (gather_with_sorted_ids) { + wm_thrust_allocator thrust_allocator(p_env_fns); + temp_memory_handle dev_indices_after_sort(p_env_fns); + void* dev_indices_after_sort_ptr = + dev_indices_after_sort.device_malloc(indice_desc.size, indice_desc.dtype); + temp_memory_handle dev_raw_indices(p_env_fns); + void* dev_raw_indices_ptr = dev_raw_indices.device_malloc(indice_desc.size, indice_desc.dtype); + auto raw_indices_desc = wholememory_create_array_desc(indice_desc.size, 0, indice_desc.dtype); + WHOLEMEMORY_RETURN_ON_FAIL(sort_indices_func(indices, + indice_desc, + dev_indices_after_sort_ptr, + dev_raw_indices_ptr, + &thrust_allocator, + p_env_fns, + stream)); + WHOLEMEMORY_RETURN_ON_FAIL(gather_with_sorted_ids_func(wholememory_gref, + wholememory_desc, + dev_indices_after_sort_ptr, + indice_desc, + dev_raw_indices_ptr, + raw_indices_desc, + output, + output_desc, + stream, + gather_sms)); + } else { + WHOLEMEMORY_RETURN_ON_FAIL(gather_func(wholememory_gref, + wholememory_desc, + indices, + indice_desc, + output, + output_desc, + stream, + gather_sms)); + } WM_CUDA_DEBUG_SYNC_STREAM(stream); return WHOLEMEMORY_SUCCESS; } diff --git a/cpp/src/wholememory_ops/temp_memory_handle.hpp b/cpp/src/wholememory_ops/temp_memory_handle.hpp index 7f74677ba..408d3bfa1 100644 --- a/cpp/src/wholememory_ops/temp_memory_handle.hpp +++ b/cpp/src/wholememory_ops/temp_memory_handle.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
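The new path in gather_op_impl_mapped.cu sorts the requested indices first, gathers in sorted order, and uses the saved original positions (raw_indices) to place rows back where the caller asked for them. An illustrative NumPy model of that idea, not the CUDA implementation:

```python
# Sketch only: sorted gather followed by a scatter back to the requested order.
import numpy as np

def gather_with_sorted_ids(embedding: np.ndarray, indices: np.ndarray) -> np.ndarray:
    order = np.argsort(indices, kind="stable")   # plays the role of cub SortPairs
    sorted_indices = indices[order]              # ascending ids: better host-memory locality
    gathered = embedding[sorted_indices]         # near-sequential reads from host memory
    output = np.empty_like(gathered)
    output[order] = gathered                     # restore the caller's ordering
    return output

emb = np.arange(20, dtype=np.float32).reshape(10, 2)
idx = np.array([7, 2, 2, 9, 0], dtype=np.int64)
assert np.array_equal(gather_with_sorted_ids(emb, idx), emb[idx])
```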
@@ -31,7 +31,7 @@ class temp_memory_handle { ~temp_memory_handle() { free_memory(); } void* device_malloc(size_t elt_count, wholememory_dtype_t data_type) { - free_memory(); + free_data(); wholememory_tensor_description_t tensor_description; get_tensor_description(&tensor_description, elt_count, data_type); ptr_ = temp_mem_fns_->malloc_fn( @@ -40,7 +40,7 @@ class temp_memory_handle { } void* host_malloc(size_t elt_count, wholememory_dtype_t data_type) { - free_memory(); + free_data(); wholememory_tensor_description_t tensor_description; get_tensor_description(&tensor_description, elt_count, data_type); ptr_ = temp_mem_fns_->malloc_fn( @@ -49,7 +49,7 @@ class temp_memory_handle { } void* pinned_malloc(size_t elt_count, wholememory_dtype_t data_type) { - free_memory(); + free_data(); wholememory_tensor_description_t tensor_description; get_tensor_description(&tensor_description, elt_count, data_type); ptr_ = temp_mem_fns_->malloc_fn( @@ -57,6 +57,13 @@ class temp_memory_handle { return ptr_; } [[nodiscard]] void* pointer() const { return ptr_; } + void free_data() + { + if (ptr_ != nullptr) { + temp_mem_fns_->free_fn(memory_context_, temp_mem_fns_->global_context); + ptr_ = nullptr; + } + } void free_memory() { if (ptr_ != nullptr) { diff --git a/cpp/tests/wholememory_ops/wholememory_gather_tests.cu b/cpp/tests/wholememory_ops/wholememory_gather_tests.cu index 330587481..ada9c87e1 100644 --- a/cpp/tests/wholememory_ops/wholememory_gather_tests.cu +++ b/cpp/tests/wholememory_ops/wholememory_gather_tests.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -301,6 +301,16 @@ INSTANTIATE_TEST_SUITE_P( WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) .set_memory_location(WHOLEMEMORY_ML_HOST), + WholeMemoryGatherTestParam() + .set_memory_type(WHOLEMEMORY_MT_CONTINUOUS) + .set_memory_location(WHOLEMEMORY_ML_HOST) + .set_embedding_dim(1) + .set_indices_type(WHOLEMEMORY_DT_INT64), + WholeMemoryGatherTestParam() + .set_memory_type(WHOLEMEMORY_MT_CHUNKED) + .set_memory_location(WHOLEMEMORY_ML_HOST) + .set_embedding_dim(1) + .set_indices_type(WHOLEMEMORY_DT_INT64), WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_CONTINUOUS) .set_embedding_dim(11) @@ -311,6 +321,16 @@ INSTANTIATE_TEST_SUITE_P( .set_embedding_dim(11) .set_embedding_stride(12) .set_indices_count(100005), + WholeMemoryGatherTestParam() + .set_memory_type(WHOLEMEMORY_MT_CHUNKED) + .set_embedding_dim(1) + .set_embedding_stride(1) + .set_indices_count(100005), + WholeMemoryGatherTestParam() + .set_memory_type(WHOLEMEMORY_MT_CHUNKED) + .set_embedding_dim(1) + .set_embedding_stride(2) + .set_indices_count(100005), WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) .set_embedding_dim(11) diff --git a/dependencies.yaml b/dependencies.yaml index 1e6edbe65..d20ccf9bc 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -74,8 +74,8 @@ dependencies: - cxx-compiler - cython>=3.0.0 - &doxygen doxygen==1.9.1 - - libraft-headers==24.4.* - - librmm==24.4.* + - libraft-headers==24.6.* + - librmm==24.6.* - nanobind>=0.2.0 - nccl - scikit-build-core>=0.7.0 diff --git a/fetch_rapids.cmake b/fetch_rapids.cmake index 3f9023810..4226d5b23 100644 --- a/fetch_rapids.cmake +++ b/fetch_rapids.cmake @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at @@ -12,7 +12,7 @@ # the License. 
# ============================================================================= if(NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}/CUGRAPH_RAPIDS.cmake) - file(DOWNLOAD https://raw.githubusercontent.com/rapidsai/rapids-cmake/branch-24.04/RAPIDS.cmake + file(DOWNLOAD https://raw.githubusercontent.com/rapidsai/rapids-cmake/branch-24.06/RAPIDS.cmake ${CMAKE_CURRENT_BINARY_DIR}/CUGRAPH_RAPIDS.cmake ) endif() diff --git a/python/pylibwholegraph/CMakeLists.txt b/python/pylibwholegraph/CMakeLists.txt index a10513c31..d22e3d51c 100644 --- a/python/pylibwholegraph/CMakeLists.txt +++ b/python/pylibwholegraph/CMakeLists.txt @@ -16,7 +16,7 @@ cmake_minimum_required(VERSION 3.26.4 FATAL_ERROR) -set(RAPIDS_VERSION "24.04") +set(RAPIDS_VERSION "24.06") set(WHOLEGRAPH_VERSION "${RAPIDS_VERSION}.00") include(FetchContent) diff --git a/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx b/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx index 7cbffadd4..feffa9162 100644 --- a/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx +++ b/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx @@ -184,7 +184,7 @@ cdef extern from "wholememory/wholememory.h": cdef wholememory_distributed_backend_t wholememory_communicator_get_distributed_backend( wholememory_comm_t comm) - + cdef bool wholememory_is_intranode_communicator(wholememory_comm_t comm) cpdef enum WholeMemoryErrorCode: Success = WHOLEMEMORY_SUCCESS @@ -1113,6 +1113,10 @@ cdef class PyWholeMemoryFlattenDlpack: cdef wholememory_comm_t comm cdef int world_rank cdef int world_size + if self.device_type == MlHost and mem_type == MtContinuous: + check_wholememory_error_code(wholememory_get_communicator(&comm, handle.wholememory_handle)) + if wholememory_is_intranode_communicator(comm) == False : + raise ValueError('Multi-node continuous type wholememory does not support host_view. Only supports host_view=false regardless of whether location is host or not.') global_size = wholememory_get_total_size(handle.wholememory_handle) if global_size % elt_size != 0: raise ValueError('global_size=%d not multiple of elt_size=%d' % (global_size, elt_size)) diff --git a/python/pylibwholegraph/pylibwholegraph/test_utils/test_comm.py b/python/pylibwholegraph/pylibwholegraph/test_utils/test_comm.py index 0c822be0b..438a485e1 100644 --- a/python/pylibwholegraph/pylibwholegraph/test_utils/test_comm.py +++ b/python/pylibwholegraph/pylibwholegraph/test_utils/test_comm.py @@ -192,5 +192,7 @@ def int_to_wholememory_type(value: int): return wmb.WholeMemoryMemoryType.MtContinuous if value == 1: return wmb.WholeMemoryMemoryType.MtChunked + if value == 2: + return wmb.WholeMemoryMemoryType.MtDistributed else: raise ValueError("invalid int_to_wholememory_type value") diff --git a/python/pylibwholegraph/pylibwholegraph/tests/pylibwholegraph/test_wholememory_io.py b/python/pylibwholegraph/pylibwholegraph/tests/pylibwholegraph/test_wholememory_io.py index d9543fbcc..e9bed3a5b 100644 --- a/python/pylibwholegraph/pylibwholegraph/tests/pylibwholegraph/test_wholememory_io.py +++ b/python/pylibwholegraph/pylibwholegraph/tests/pylibwholegraph/test_wholememory_io.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2023, NVIDIA CORPORATION. +# Copyright (c) 2019-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
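The new check in wholememory_binding.pyx rejects a host view of CONTINUOUS memory when the communicator spans more than one node, since no single host mapping can cover remote ranks. A tiny, purely illustrative guard with the same shape; the helper name is hypothetical, not the binding's API:

```python
# Hypothetical helper, only to illustrate the rule enforced in the .pyx above.
def host_view_allowed(memory_type: str, is_intranode: bool) -> bool:
    if memory_type == "continuous" and not is_intranode:
        return False   # multi-node continuous memory has no single host mapping
    return True

assert host_view_allowed("continuous", is_intranode=True)
assert not host_view_allowed("continuous", is_intranode=False)
```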
# You may obtain a copy of the License at @@ -55,6 +55,7 @@ def load_routine_func( embedding_dim, embedding_stride, storage_offset, + round_robin_size=0, ): wm_comm, _ = init_torch_env_and_create_wm_comm( world_rank, world_size, world_rank, world_size @@ -63,11 +64,54 @@ def load_routine_func( data_type = wmb.WholeMemoryDataType.DtInt file_list = [None] * file_part_count - per_rank_entry = wmb.determine_partition_plan(embedding_entry_count, world_size) - rank_start_entry = min(per_rank_entry * world_rank, embedding_entry_count) - rank_end_entry = min(per_rank_entry * (world_rank + 1), embedding_entry_count) + extra_embedding_count = embedding_entry_count + if round_robin_size != 0: + first_rank_extra_embedding_entry_count = embedding_entry_count % ( + world_size * round_robin_size + ) + first_rank_extra_embedding_entry_count = min( + first_rank_extra_embedding_entry_count, round_robin_size + ) + extra_embedding_count = ( + embedding_entry_count + - embedding_entry_count % (world_size * round_robin_size) + + first_rank_extra_embedding_entry_count * world_size + ) + + per_rank_entry = wmb.determine_partition_plan(extra_embedding_count, world_size) + rank_start_entry = min(per_rank_entry * world_rank, extra_embedding_count) + rank_end_entry = min(per_rank_entry * (world_rank + 1), extra_embedding_count) rank_entry_count = rank_end_entry - rank_start_entry + if round_robin_size != 0: + first_rank_extra_embedding_entry_count = embedding_entry_count % ( + world_size * round_robin_size + ) + per_rank_entry_round_robin = per_rank_entry - per_rank_entry % round_robin_size + if first_rank_extra_embedding_entry_count < round_robin_size: + if world_rank == 0: + rank_entry_count = ( + per_rank_entry_round_robin + first_rank_extra_embedding_entry_count + ) + else: + rank_entry_count = per_rank_entry_round_robin + else: + rank_entry_count = ( + per_rank_entry_round_robin + - round_robin_size + + min( + round_robin_size, + max( + 0, + ( + first_rank_extra_embedding_entry_count + - world_rank * round_robin_size + ), + ), + ) + ) + rank_end_entry = rank_start_entry + rank_entry_count + reference_local_tensor = cpu_embedding_tensor_base[ rank_start_entry:rank_end_entry, : ].cuda() @@ -87,23 +131,29 @@ def load_routine_func( continue wholememory_root_tensor = wmb.create_wholememory_matrix( data_type, - embedding_entry_count, + extra_embedding_count, embedding_dim + storage_offset, embedding_stride, wm_comm, mt, ml, ) + wholememory_tensor = wholememory_root_tensor.get_sub_tensor( [-1, storage_offset], [-1, -1] ) - wholememory_tensor.from_filelist(file_list) + wholememory_tensor.from_filelist(file_list, round_robin_size) local_tensor, local_offset = wholememory_tensor.get_local_tensor( torch_import_from_dlpack, wmb.WholeMemoryMemoryLocation.MlDevice, world_rank, ) + if round_robin_size != 0: + assert local_tensor.shape[0] == per_rank_entry + + local_tensor = local_tensor[:rank_entry_count] + assert local_tensor.dim() == 2 assert local_tensor.shape[0] == rank_entry_count assert local_tensor.shape[1] == embedding_dim @@ -122,17 +172,41 @@ def load_routine_func( @pytest.mark.parametrize("embedding_dim", [16, 31, 33]) @pytest.mark.parametrize("embedding_stride", [16, 32, 64]) @pytest.mark.parametrize("storage_offset", [0, 3]) +@pytest.mark.parametrize("round_robin_size", [256, 1024, 0]) def test_wholememory_load( file_part_count, embedding_entry_count, embedding_dim, embedding_stride, storage_offset, + round_robin_size, ): if embedding_stride < storage_offset + embedding_dim: pytest.skip( "Skipping due to 
embedding_stride, embedding_dim and storage_offset configuration not valid." ) + if round_robin_size != 0 and storage_offset != 0: + pytest.skip( + "Skipping due to round_robin_size!=0 and storage offset !=0 , the configuration is not valid." + ) + global gpu_count + if not gpu_count: + gpu_count = 1 + extra_embedding_count = embedding_entry_count + if round_robin_size != 0: + first_rank_extra_embedding_entry_count = embedding_entry_count % ( + gpu_count * round_robin_size + ) + first_rank_extra_embedding_entry_count = min( + first_rank_extra_embedding_entry_count, round_robin_size + ) + + extra_embedding_count = ( + embedding_entry_count + - embedding_entry_count % (gpu_count * round_robin_size) + + first_rank_extra_embedding_entry_count * gpu_count + ) + cpu_embedding_tensor_base = torch.randint( -1000000000, 1000000000, @@ -154,6 +228,26 @@ def test_wholememory_load( "%s_part_%d_of_%d" % (file_name_prefix, i, file_part_count) ) + if round_robin_size != 0: + entry_per_rank = wmb.determine_partition_plan(extra_embedding_count, gpu_count) + + cpu_embedding_tensor_base_extra = torch.empty( + (extra_embedding_count, embedding_dim), dtype=torch.int, device="cpu" + ) + global_indices = torch.arange(0, embedding_entry_count, device="cpu") + indices_to_robin_id = global_indices // round_robin_size + indices_to_rank = indices_to_robin_id % gpu_count + indices_to_robin_id_offset = indices_to_robin_id // gpu_count + target_id = ( + indices_to_rank * entry_per_rank + + indices_to_robin_id_offset * round_robin_size + + global_indices % round_robin_size + ) + cpu_embedding_tensor_base_extra[target_id] = cpu_embedding_tensor_base + + cpu_embedding_tensor_base = cpu_embedding_tensor_base_extra + cpu_embedding_tensor_base.contiguous() + cpu_embedding_tensor_base = cpu_embedding_tensor_base.share_memory_() load_routine_func_partial = partial( @@ -165,9 +259,9 @@ def test_wholememory_load( embedding_dim=embedding_dim, embedding_stride=embedding_stride, storage_offset=storage_offset, + round_robin_size=round_robin_size, ) - global gpu_count multiprocess_run(gpu_count, load_routine_func_partial) for i in range(file_part_count): diff --git a/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_unweighted_sample_without_replacement.py b/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_unweighted_sample_without_replacement.py index 1953419f5..366a4298b 100644 --- a/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_unweighted_sample_without_replacement.py +++ b/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_unweighted_sample_without_replacement.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2023, NVIDIA CORPORATION. +# Copyright (c) 2019-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
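The round-robin layout exercised by the load test above assigns consecutive blocks of round_robin_size rows to ranks in turn; target_id is where each global row lands in the concatenated per-rank storage. A small worked example, with entry_per_rank picked by hand for the sketch:

```python
# Illustrative check of the round-robin index mapping used in the test.
import numpy as np

round_robin_size, gpu_count, entry_per_rank = 2, 2, 4
global_indices = np.arange(8)

robin_id = global_indices // round_robin_size    # which round-robin block
rank = robin_id % gpu_count                      # rank that owns the block
slot = robin_id // gpu_count                     # block position within that rank
target_id = (rank * entry_per_rank
             + slot * round_robin_size
             + global_indices % round_robin_size)

print(target_id)  # [0 1 4 5 2 3 6 7]: rank 0 holds global rows 0,1,4,5; rank 1 holds 2,3,6,7
```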
# You may obtain a copy of the License at @@ -364,7 +364,7 @@ def routine_func(world_rank: int, world_size: int, **kwargs): @pytest.mark.parametrize("center_node_dtype", [torch.int32, torch.int64]) @pytest.mark.parametrize("col_id_dtype", [0, 1]) @pytest.mark.parametrize("wholememory_location", ([0, 1])) -@pytest.mark.parametrize("wholememory_type", ([0, 1])) +@pytest.mark.parametrize("wholememory_type", ([0, 1, 2])) @pytest.mark.parametrize("need_center_local_output", [True, False]) @pytest.mark.parametrize("need_edge_output", [True, False]) def test_wholegraph_unweighted_sample( diff --git a/python/pylibwholegraph/pylibwholegraph/torch/embedding.py b/python/pylibwholegraph/pylibwholegraph/torch/embedding.py index 8ad83bd77..8abc92be9 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/embedding.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/embedding.py @@ -407,7 +407,7 @@ def create_embedding( cache_policy: Union[WholeMemoryCachePolicy, None] = None, random_init: bool = False, gather_sms: int = -1, - round_robin_size=0, + round_robin_size: int = 0, ): r""" Create embedding @@ -419,6 +419,7 @@ def create_embedding( :param optimizer: optimizer :param cache_policy: cache policy :param gather_sms: the number of SMs used in gather process + :param round_robin_size: continuous embedding size of a rank using round robin shard strategy :return: WholeMemoryEmbedding """ if optimizer is None: @@ -491,6 +492,7 @@ def create_embedding_from_filelist( :param optimizer: optimizer :param cache_policy: cache policy :param gather_sms: the number of SMs used in gather process + :param round_robin_size: continuous embedding size of a rank using round robin shard strategy :return: """ if isinstance(filelist, str): diff --git a/python/pylibwholegraph/pylibwholegraph/torch/tensor.py b/python/pylibwholegraph/pylibwholegraph/torch/tensor.py index ee62e9964..84ee59eee 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/tensor.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/tensor.py @@ -67,7 +67,7 @@ def gather(self, embedding_count = indice.shape[0] current_cuda_device = "cuda:%d" % (torch.cuda.current_device(),) output_dtype = ( - force_dtype if force_dtype is not None else self.embedding_tensor.dtype + force_dtype if force_dtype is not None else self.dtype ) output_tensor = torch.empty( [embedding_count, embedding_dim], @@ -156,6 +156,7 @@ def from_filelist(self, filelist: Union[List[str], str], round_robin_size: int = """ Load WholeMemory Tensor from file lists :param filelist: file list to load from + :param round_robin_size: continuous embedding size of a rank using round robin shard strategy :return: None """ if isinstance(filelist, str): diff --git a/python/pylibwholegraph/pylibwholegraph/torch/wholegraph_env.py b/python/pylibwholegraph/pylibwholegraph/torch/wholegraph_env.py index aba4b6bea..d083a8abc 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/wholegraph_env.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/wholegraph_env.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2023, NVIDIA CORPORATION. +# Copyright (c) 2019-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -75,6 +75,11 @@ def free(self): torch_cpp_ext_lib.destroy_output_context(self.get_handle()) self.handle = 0 + def free_data(self): + self.tensor = None + if torch_cpp_ext_loaded and self.get_handle() != 0: + torch_cpp_ext_lib.free_context_data(self.get_handle()) + def torch_create_memory_context_env_fn( global_context: TorchEmptyGlobalContext, @@ -121,7 +126,7 @@ def torch_malloc_env_fn( def torch_free_env_fn( memory_context: TorchMemoryContext, global_context: TorchEmptyGlobalContext ): - memory_context.free() + memory_context.free_data() class ExtContextWrapper(object): diff --git a/python/pylibwholegraph/pylibwholegraph/torch_cpp_ext/torch_env_func_ptrs.cpp b/python/pylibwholegraph/pylibwholegraph/torch_cpp_ext/torch_env_func_ptrs.cpp index 15d2e5160..be0385d9c 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch_cpp_ext/torch_env_func_ptrs.cpp +++ b/python/pylibwholegraph/pylibwholegraph/torch_cpp_ext/torch_env_func_ptrs.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -60,4 +60,8 @@ void destroy_output_context(void* output_context) { destroy_torch_memory_context_func(output_context, nullptr); } +void free_context_data(void* output_context) { + torch_common_free_func(output_context, nullptr); +} + } // namespace wholegraph_torch diff --git a/python/pylibwholegraph/pylibwholegraph/torch_cpp_ext/wholegraph_torch_ext.cpp b/python/pylibwholegraph/pylibwholegraph/torch_cpp_ext/wholegraph_torch_ext.cpp index f1dcbecdb..d805d24a5 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch_cpp_ext/wholegraph_torch_ext.cpp +++ b/python/pylibwholegraph/pylibwholegraph/torch_cpp_ext/wholegraph_torch_ext.cpp @@ -1,3 +1,18 @@ +/* + * Copyright (c) 2019-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
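The Python-side change mirrors the C++ one: the env free callback now calls free_data(), which releases the tensor payload but leaves the memory-context object alive for the caller to reuse or destroy later. A minimal sketch of that ownership split, illustrative rather than the real extension calls:

```python
# Sketch of free() vs free_data() semantics on a Torch memory context.
class MemoryContext:
    def __init__(self):
        self.tensor = None   # payload set by the malloc callback
        self.handle = 1      # stands in for the C++ context object

    def free_data(self):
        self.tensor = None   # drop only the payload; the context stays valid

    def free(self):
        self.free_data()
        self.handle = 0      # destroy the context itself

ctx = MemoryContext()
ctx.tensor = [0.0] * 4
ctx.free_data()              # what the env free callback now does per temporary allocation
assert ctx.handle != 0       # caller-owned context survives for reuse
ctx.free()                   # explicit teardown when truly done
```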
+ */ #include #include @@ -24,6 +39,11 @@ void wrapped_destroy_output_context(int64_t output_context) wholegraph_torch::destroy_output_context(reinterpret_cast(output_context)); } +void wrapped_free_context_data(int64_t output_context) +{ + wholegraph_torch::free_context_data(reinterpret_cast(output_context), nullptr); +} + torch::Tensor get_torch_tensor_from_output_context(int64_t output_context) { auto* torch_output_context = @@ -39,6 +59,7 @@ PYBIND11_MODULE(pylibwholegraph_torch_ext, m) m.def("get_stream", &wrapped_get_stream, "Get current CUDA stream."); m.def("create_output_context", &wrapped_create_output_context, "Create output memory context."); m.def("destroy_output_context", &wrapped_destroy_output_context, "Destroy output memory context."); + m.def("free_context_data", &wrapped_free_context_data, "Free data in output memory context."); m.def("get_tensor_from_context", &get_torch_tensor_from_output_context, "Get PyTorch Tensor from output memory context"); diff --git a/scripts/checks/copyright.py b/scripts/checks/copyright.py deleted file mode 100644 index a26109c57..000000000 --- a/scripts/checks/copyright.py +++ /dev/null @@ -1,229 +0,0 @@ -# Copyright (c) 2019-2023, NVIDIA CORPORATION. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import argparse -import datetime -import os -import re -import sys - -import git - -from fileutils import modifiedFiles - -FilesToCheck = [ - re.compile(r"[.](cmake|cpp|cu|cuh|h|hpp|sh|pxd|py|pyx)$"), - re.compile(r"CMakeLists[.]txt$"), - re.compile(r"setup[.]cfg$"), - re.compile(r"meta[.]yaml$"), -] -ExemptFiles = [ - re.compile(r"versioneer[.]py"), - re.compile(r".*[.]json$"), - re.compile(r"src/io/gzstream[.]hpp$"), -] - -# this will break starting at year 10000, which is probably OK :) -CheckSimple = re.compile( - r"Copyright *(?:\(c\))? *(\d{4}),? *NVIDIA C(?:ORPORATION|orporation)" -) -CheckDouble = re.compile( - r"Copyright *(?:\(c\))? *(\d{4})-(\d{4}),? *NVIDIA C(?:ORPORATION|orporation)" # noqa: E501 -) - - -def checkThisFile(f): - if isinstance(f, git.Diff): - if f.deleted_file or f.b_blob.size == 0: - return False - f = f.b_path - elif not os.path.exists(f) or os.stat(f).st_size == 0: - # This check covers things like symlinks which point to files that DNE - return False - for exempt in ExemptFiles: - if exempt.search(f): - return False - for checker in FilesToCheck: - if checker.search(f): - return True - return False - - -def getCopyrightYears(line): - res = CheckSimple.search(line) - if res: - return int(res.group(1)), int(res.group(1)) - res = CheckDouble.search(line) - if res: - return int(res.group(1)), int(res.group(2)) - return None, None - - -def replaceCurrentYear(line, start, end): - # first turn a simple regex into double (if applicable). 
then update years - res = CheckSimple.sub(r"Copyright (c) \1-\1, NVIDIA CORPORATION", line) - res = CheckDouble.sub( - rf"Copyright (c) {start:04d}-{end:04d}, NVIDIA CORPORATION", - res, - ) - return res - - -def checkCopyright(f, update_current_year): - """Checks for copyright headers and their years.""" - errs = [] - thisYear = datetime.datetime.now().year - lineNum = 0 - crFound = False - yearMatched = False - - if isinstance(f, git.Diff): - path = f.b_path - lines = f.b_blob.data_stream.read().decode().splitlines(keepends=True) - else: - path = f - with open(f, encoding="utf-8") as fp: - lines = fp.readlines() - - for line in lines: - lineNum += 1 - start, end = getCopyrightYears(line) - if start is None: - continue - crFound = True - if start > end: - e = [ - path, - lineNum, - "First year after second year in the copyright " - "header (manual fix required)", - None, - ] - errs.append(e) - elif thisYear < start or thisYear > end: - e = [ - path, - lineNum, - f"Current year {thisYear} not included in the copyright header {start}-{end}", - None, - ] - if thisYear < start: - e[-1] = replaceCurrentYear(line, thisYear, end) - if thisYear > end: - e[-1] = replaceCurrentYear(line, start, thisYear) - errs.append(e) - else: - yearMatched = True - # copyright header itself not found - if not crFound: - e = [ - path, - 0, - "Copyright header missing or formatted incorrectly " - "(manual fix required)", - None, - ] - errs.append(e) - # even if the year matches a copyright header, make the check pass - if yearMatched: - errs = [] - - if update_current_year: - errs_update = [x for x in errs if x[-1] is not None] - if len(errs_update) > 0: - lines_changed = ", ".join(str(x[1]) for x in errs_update) - print(f"File: {path}. Changing line(s) {lines_changed}") - for _, lineNum, __, replacement in errs_update: - lines[lineNum - 1] = replacement - with open(path, "w", encoding="utf-8") as out_file: - out_file.writelines(lines) - - return errs - - -def getAllFilesUnderDir(root, pathFilter=None): - retList = [] - for dirpath, dirnames, filenames in os.walk(root): - for fn in filenames: - filePath = os.path.join(dirpath, fn) - if pathFilter(filePath): - retList.append(filePath) - return retList - - -def checkCopyright_main(): - """ - Checks for copyright headers in all the modified files. 
In case of local - repo, this script will just look for uncommitted files and in case of CI - it compares between branches "$PR_TARGET_BRANCH" and "current-pr-branch" - """ - retVal = 0 - - argparser = argparse.ArgumentParser( - "Checks for a consistent copyright header in git's modified files" - ) - argparser.add_argument( - "--update-current-year", - dest="update_current_year", - action="store_true", - required=False, - help="If set, " - "update the current year if a header is already " - "present and well formatted.", - ) - argparser.add_argument( - "--git-modified-only", - dest="git_modified_only", - action="store_true", - required=False, - help="If set, " "only files seen as modified by git will be " "processed.", - ) - - args, dirs = argparser.parse_known_args() - - if args.git_modified_only: - files = [f for f in modifiedFiles() if checkThisFile(f)] - else: - files = [] - for d in [os.path.abspath(d) for d in dirs]: - if not os.path.isdir(d): - raise ValueError(f"{d} is not a directory.") - files += getAllFilesUnderDir(d, pathFilter=checkThisFile) - - errors = [] - for f in files: - errors += checkCopyright(f, args.update_current_year) - - if len(errors) > 0: - if any(e[-1] is None for e in errors): - print("Copyright headers incomplete in some of the files!") - for e in errors: - print(" %s:%d Issue: %s" % (e[0], e[1], e[2])) - print("") - n_fixable = sum(1 for e in errors if e[-1] is not None) - file_from_repo = os.path.relpath(os.path.abspath(__file__)) - if n_fixable > 0 and not args.update_current_year: - print( - f"You can run `python {file_from_repo} --git-modified-only " - "--update-current-year` and stage the results in git to " - f"fix {n_fixable} of these errors.\n" - ) - retVal = 1 - - return retVal - - -if __name__ == "__main__": - sys.exit(checkCopyright_main())
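For the record, the year-normalization step of the deleted scripts/checks/copyright.py can be reproduced from the two regexes it defined; the snippet below only demonstrates what the removed check did when --update-current-year was passed.

```python
# Demonstration built from the regexes in the deleted script above.
import re

CheckSimple = re.compile(
    r"Copyright *(?:\(c\))? *(\d{4}),? *NVIDIA C(?:ORPORATION|orporation)")
CheckDouble = re.compile(
    r"Copyright *(?:\(c\))? *(\d{4})-(\d{4}),? *NVIDIA C(?:ORPORATION|orporation)")

def replace_current_year(line: str, start: int, end: int) -> str:
    # widen a single-year header to a range, then normalize the range to start-end
    line = CheckSimple.sub(r"Copyright (c) \1-\1, NVIDIA CORPORATION", line)
    return CheckDouble.sub(
        rf"Copyright (c) {start:04d}-{end:04d}, NVIDIA CORPORATION", line)

print(replace_current_year("# Copyright (c) 2019, NVIDIA CORPORATION.", 2019, 2024))
# -> # Copyright (c) 2019-2024, NVIDIA CORPORATION.
```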