From a5e30053dd174cefffa61d3e861634d5f59164ea Mon Sep 17 00:00:00 2001 From: Ray Douglass Date: Fri, 15 Mar 2024 12:05:59 -0400 Subject: [PATCH 01/15] DOC v24.06 Updates [skip ci] --- .github/workflows/build.yaml | 12 ++++++------ .github/workflows/pr.yaml | 18 +++++++++--------- .github/workflows/test.yaml | 6 +++--- VERSION | 2 +- ci/build_docs.sh | 4 ++-- .../environments/all_cuda-118_arch-x86_64.yaml | 4 ++-- .../environments/all_cuda-122_arch-x86_64.yaml | 4 ++-- cpp/CMakeLists.txt | 4 ++-- cpp/Doxyfile | 2 +- dependencies.yaml | 4 ++-- fetch_rapids.cmake | 4 ++-- python/pylibwholegraph/CMakeLists.txt | 4 ++-- 12 files changed, 34 insertions(+), 34 deletions(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 6ad4ba610..2d039122b 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -28,7 +28,7 @@ concurrency: jobs: cpp-build: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-build.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-build.yaml@branch-24.06 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -38,7 +38,7 @@ jobs: python-build: needs: [cpp-build] secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-python-build.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/conda-python-build.yaml@branch-24.06 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -49,7 +49,7 @@ jobs: if: github.ref_type == 'branch' needs: [python-build] secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@branch-24.06 with: arch: "amd64" branch: ${{ inputs.branch }} @@ -62,7 +62,7 @@ jobs: upload-conda: needs: [cpp-build, python-build] secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-upload-packages.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/conda-upload-packages.yaml@branch-24.06 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -70,7 +70,7 @@ jobs: sha: ${{ inputs.sha }} wheel-build-pylibwholegraph: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@branch-24.06 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -80,7 +80,7 @@ jobs: wheel-publish-pylibwholegraph: needs: wheel-build-pylibwholegraph secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-publish.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/wheels-publish.yaml@branch-24.06 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index 2c2578e04..339646eca 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -21,41 +21,41 @@ jobs: - wheel-build-pylibwholegraph - wheel-test-pylibwholegraph secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/pr-builder.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/pr-builder.yaml@branch-24.06 checks: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/checks.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/checks.yaml@branch-24.06 with: enable_check_generated_files: false conda-cpp-build: needs: 
checks secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-build.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-build.yaml@branch-24.06 with: build_type: pull-request node_type: cpu16 conda-cpp-tests: needs: conda-cpp-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-tests.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-tests.yaml@branch-24.06 with: build_type: pull-request conda-python-build: needs: conda-cpp-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-python-build.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/conda-python-build.yaml@branch-24.06 with: build_type: pull-request conda-python-tests: needs: conda-python-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-python-tests.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/conda-python-tests.yaml@branch-24.06 with: build_type: pull-request docs-build: needs: conda-python-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@branch-24.06 with: build_type: pull-request arch: "amd64" @@ -64,14 +64,14 @@ jobs: wheel-build-pylibwholegraph: needs: checks secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@branch-24.06 with: build_type: pull-request script: ci/build_wheel.sh wheel-test-pylibwholegraph: needs: wheel-build-pylibwholegraph secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@branch-24.06 with: build_type: pull-request script: ci/test_wheel.sh diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 489348971..348476641 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -16,7 +16,7 @@ on: jobs: conda-cpp-tests: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-tests.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-tests.yaml@branch-24.06 with: build_type: nightly branch: ${{ inputs.branch }} @@ -24,7 +24,7 @@ jobs: sha: ${{ inputs.sha }} conda-pytorch-tests: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-python-tests.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/conda-python-tests.yaml@branch-24.06 with: build_type: nightly branch: ${{ inputs.branch }} @@ -32,7 +32,7 @@ jobs: sha: ${{ inputs.sha }} wheel-tests-pylibwholegraph: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@branch-24.06 with: build_type: nightly branch: ${{ inputs.branch }} diff --git a/VERSION b/VERSION index 4a2fe8aa5..0bff6981a 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -24.04.00 +24.06.00 diff --git a/ci/build_docs.sh b/ci/build_docs.sh index 698754b28..55970c488 100755 --- a/ci/build_docs.sh +++ b/ci/build_docs.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. 
set -euo pipefail @@ -22,7 +22,7 @@ rapids-print-env rapids-logger "Downloading artifacts from previous jobs" CPP_CHANNEL=$(rapids-download-conda-from-s3 cpp) -export RAPIDS_VERSION_NUMBER="24.04" +export RAPIDS_VERSION_NUMBER="24.06" export RAPIDS_DOCS_DIR="$(mktemp -d)" rapids-mamba-retry install \ diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml index 190c5d319..c0fa7bbfe 100644 --- a/conda/environments/all_cuda-118_arch-x86_64.yaml +++ b/conda/environments/all_cuda-118_arch-x86_64.yaml @@ -24,8 +24,8 @@ dependencies: - graphviz - ipykernel - ipython -- libraft-headers==24.4.* -- librmm==24.4.* +- libraft-headers==24.6.* +- librmm==24.6.* - nanobind>=0.2.0 - nbsphinx - nccl diff --git a/conda/environments/all_cuda-122_arch-x86_64.yaml b/conda/environments/all_cuda-122_arch-x86_64.yaml index 732e9fcff..134d72146 100644 --- a/conda/environments/all_cuda-122_arch-x86_64.yaml +++ b/conda/environments/all_cuda-122_arch-x86_64.yaml @@ -25,8 +25,8 @@ dependencies: - graphviz - ipykernel - ipython -- libraft-headers==24.4.* -- librmm==24.4.* +- libraft-headers==24.6.* +- librmm==24.6.* - nanobind>=0.2.0 - nbsphinx - nccl diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index dc75bd99c..788cc2131 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -1,5 +1,5 @@ #============================================================================= -# Copyright (c) 2018-2023, NVIDIA CORPORATION. +# Copyright (c) 2018-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ # limitations under the License. #============================================================================= -set(RAPIDS_VERSION "24.04") +set(RAPIDS_VERSION "24.06") set(WHOLEGRAPH_VERSION "${RAPIDS_VERSION}.00") cmake_minimum_required(VERSION 3.23.1 FATAL_ERROR) diff --git a/cpp/Doxyfile b/cpp/Doxyfile index e480d8ef4..3e4e9e53f 100644 --- a/cpp/Doxyfile +++ b/cpp/Doxyfile @@ -38,7 +38,7 @@ PROJECT_NAME = "WholeGraph C API" # could be handy for archiving the generated documentation or if some version # control system is used. -PROJECT_NUMBER = 24.04 +PROJECT_NUMBER = 24.06 # Using the PROJECT_BRIEF tag one can provide an optional one line description # for a project that appears at the top of each page and should give viewer a diff --git a/dependencies.yaml b/dependencies.yaml index c699ba389..c76671a87 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -74,8 +74,8 @@ dependencies: - cxx-compiler - cython>=3.0.0 - &doxygen doxygen==1.9.1 - - libraft-headers==24.4.* - - librmm==24.4.* + - libraft-headers==24.6.* + - librmm==24.6.* - nanobind>=0.2.0 - nccl - scikit-build diff --git a/fetch_rapids.cmake b/fetch_rapids.cmake index 3f9023810..4226d5b23 100644 --- a/fetch_rapids.cmake +++ b/fetch_rapids.cmake @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at @@ -12,7 +12,7 @@ # the License. 
# ============================================================================= if(NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}/CUGRAPH_RAPIDS.cmake) - file(DOWNLOAD https://raw.githubusercontent.com/rapidsai/rapids-cmake/branch-24.04/RAPIDS.cmake + file(DOWNLOAD https://raw.githubusercontent.com/rapidsai/rapids-cmake/branch-24.06/RAPIDS.cmake ${CMAKE_CURRENT_BINARY_DIR}/CUGRAPH_RAPIDS.cmake ) endif() diff --git a/python/pylibwholegraph/CMakeLists.txt b/python/pylibwholegraph/CMakeLists.txt index 34a788f55..6fc6f134f 100644 --- a/python/pylibwholegraph/CMakeLists.txt +++ b/python/pylibwholegraph/CMakeLists.txt @@ -1,5 +1,5 @@ #============================================================================= -# Copyright (c) 2018-2023, NVIDIA CORPORATION. +# Copyright (c) 2018-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,7 +16,7 @@ cmake_minimum_required(VERSION 3.26.4 FATAL_ERROR) -set(RAPIDS_VERSION "24.04") +set(RAPIDS_VERSION "24.06") set(WHOLEGRAPH_VERSION "${RAPIDS_VERSION}.00") include(FetchContent) From 5c4ffcef066e8e4ad4daba6570b00f320ec8b8e1 Mon Sep 17 00:00:00 2001 From: linhu-nv <141609318+linhu-nv@users.noreply.github.com> Date: Mon, 22 Apr 2024 21:48:02 +0800 Subject: [PATCH 02/15] fix CI issue due to pytorch and mkl version conflict (#162) mkl2024.1.0 conflicts with pytorch, this PR constrains mkl<2024.1.0 in the test scripts where pytorch is required. Authors: - https://github.com/linhu-nv Approvers: - Brad Rees (https://github.com/BradReesWork) - Jake Awe (https://github.com/AyodeAwe) URL: https://github.com/rapidsai/wholegraph/pull/162 --- ci/test_python.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/ci/test_python.sh b/ci/test_python.sh index 0efa5e8e3..dd56e7b92 100755 --- a/ci/test_python.sh +++ b/ci/test_python.sh @@ -49,6 +49,7 @@ PACKAGES="pylibwholegraph" rapids-mamba-retry install \ --channel "${CPP_CHANNEL}" \ --channel "${PYTHON_CHANNEL}" \ + 'mkl<2024.1.0' \ "${PACKAGES}" rapids-logger "Check GPU usage" From 8624b409c57164f9aed4e5191d58f503d821b57f Mon Sep 17 00:00:00 2001 From: linhu-nv <141609318+linhu-nv@users.noreply.github.com> Date: Mon, 22 Apr 2024 23:09:46 +0800 Subject: [PATCH 03/15] fix CI issue due to pytorch and mkl version conflict (#162) mkl2024.1.0 conflicts with pytorch, this PR constrains mkl<2024.1.0 in the test scripts where pytorch is required. 
Authors: - https://github.com/linhu-nv Approvers: - Brad Rees (https://github.com/BradReesWork) - Jake Awe (https://github.com/AyodeAwe) URL: https://github.com/rapidsai/wholegraph/pull/162 From f8cadcf422a367442e1aa091be7e988e8b002163 Mon Sep 17 00:00:00 2001 From: linhu-nv <141609318+linhu-nv@users.noreply.github.com> Date: Tue, 30 Apr 2024 18:21:25 +0800 Subject: [PATCH 04/15] remove unnecessary sync between thrust ops and host threads (#160) fix to issue 148[https://github.com/rapidsai/wholegraph/issues/148](url), remove unnecessary sync between thrust ops and host cpu threads Authors: - https://github.com/linhu-nv Approvers: - Chuang Zhu (https://github.com/chuangz0) URL: https://github.com/rapidsai/wholegraph/pull/160 --- cpp/src/graph_ops/append_unique_func.cuh | 4 ++-- .../unweighted_sample_without_replacement_func.cuh | 4 ++-- .../weighted_sample_without_replacement_func.cuh | 6 +++--- .../functions/exchange_embeddings_nccl_func.cu | 4 ++-- .../wholememory_ops/functions/exchange_ids_nccl_func.cu | 8 +++++--- .../functions/nvshmem_gather_scatter_func.cuh | 4 ++-- 6 files changed, 16 insertions(+), 14 deletions(-) diff --git a/cpp/src/graph_ops/append_unique_func.cuh b/cpp/src/graph_ops/append_unique_func.cuh index ff623a22b..761fabb63 100644 --- a/cpp/src/graph_ops/append_unique_func.cuh +++ b/cpp/src/graph_ops/append_unique_func.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -316,7 +316,7 @@ void graph_append_unique_func(void* target_nodes_ptr, <<>>(value_id, bucket_count_ptr); WM_CUDA_CHECK(cudaGetLastError()); wholememory_ops::wm_thrust_allocator thrust_allocator(p_env_fns); - thrust::exclusive_scan(thrust::cuda::par(thrust_allocator).on(stream), + thrust::exclusive_scan(thrust::cuda::par_nosync(thrust_allocator).on(stream), bucket_count_ptr, bucket_count_ptr + num_bucket_count, (int*)bucket_prefix_sum_ptr); diff --git a/cpp/src/wholegraph_ops/unweighted_sample_without_replacement_func.cuh b/cpp/src/wholegraph_ops/unweighted_sample_without_replacement_func.cuh index 291b26b2d..be0b261be 100644 --- a/cpp/src/wholegraph_ops/unweighted_sample_without_replacement_func.cuh +++ b/cpp/src/wholegraph_ops/unweighted_sample_without_replacement_func.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -337,7 +337,7 @@ void wholegraph_csr_unweighted_sample_without_replacement_func( // prefix sum wholememory_ops::wm_thrust_allocator thrust_allocator(p_env_fns); - thrust::exclusive_scan(thrust::cuda::par(thrust_allocator).on(stream), + thrust::exclusive_scan(thrust::cuda::par_nosync(thrust_allocator).on(stream), tmp_sample_count_mem_pointer, tmp_sample_count_mem_pointer + center_node_count + 1, (int*)output_sample_offset); diff --git a/cpp/src/wholegraph_ops/weighted_sample_without_replacement_func.cuh b/cpp/src/wholegraph_ops/weighted_sample_without_replacement_func.cuh index de75d7394..057d4c0c4 100644 --- a/cpp/src/wholegraph_ops/weighted_sample_without_replacement_func.cuh +++ b/cpp/src/wholegraph_ops/weighted_sample_without_replacement_func.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -462,7 +462,7 @@ void wholegraph_csr_weighted_sample_without_replacement_func( // prefix sum wholememory_ops::wm_thrust_allocator thrust_allocator(p_env_fns); - thrust::exclusive_scan(thrust::cuda::par(thrust_allocator).on(stream), + thrust::exclusive_scan(thrust::cuda::par_nosync(thrust_allocator).on(stream), tmp_sample_count_mem_pointer, tmp_sample_count_mem_pointer + center_node_count + 1, static_cast(output_sample_offset)); @@ -500,7 +500,7 @@ void wholegraph_csr_weighted_sample_without_replacement_func( raft::random::detail::DeviceState rngstate(_rngstate); if (max_sample_count > sample_count_threshold) { wholememory_ops::wm_thrust_allocator tmp_thrust_allocator(p_env_fns); - thrust::exclusive_scan(thrust::cuda::par(tmp_thrust_allocator).on(stream), + thrust::exclusive_scan(thrust::cuda::par_nosync(tmp_thrust_allocator).on(stream), tmp_neighbor_counts_mem_pointer, tmp_neighbor_counts_mem_pointer + center_node_count + 1, tmp_neighbor_counts_mem_pointer); diff --git a/cpp/src/wholememory_ops/functions/exchange_embeddings_nccl_func.cu b/cpp/src/wholememory_ops/functions/exchange_embeddings_nccl_func.cu index 7cb96bcb4..88d7f331c 100644 --- a/cpp/src/wholememory_ops/functions/exchange_embeddings_nccl_func.cu +++ b/cpp/src/wholememory_ops/functions/exchange_embeddings_nccl_func.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -126,7 +126,7 @@ void dedup_indice_and_gradients_temp_func(int64_t* run_count, int* dev_mapping_sequence = static_cast(mapping_sequence_handle.device_malloc(raw_count * 2, WHOLEMEMORY_DT_INT)); int* dev_indice_mapping = dev_mapping_sequence + raw_count; - thrust::sequence(thrust::cuda::par(allocator).on(stream), + thrust::sequence(thrust::cuda::par_nosync(allocator).on(stream), dev_mapping_sequence, dev_mapping_sequence + raw_count, 0); diff --git a/cpp/src/wholememory_ops/functions/exchange_ids_nccl_func.cu b/cpp/src/wholememory_ops/functions/exchange_ids_nccl_func.cu index 53df31be0..137b10470 100644 --- a/cpp/src/wholememory_ops/functions/exchange_ids_nccl_func.cu +++ b/cpp/src/wholememory_ops/functions/exchange_ids_nccl_func.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -59,8 +59,10 @@ void exchange_ids_temp_func(const void* indices_before_sort, int64_t* seq_indices = reinterpret_cast(allocator.allocate( wholememory_get_memory_element_count_from_array(&indices_desc) * sizeof(int64_t))); - thrust::sequence( - thrust::cuda::par(allocator).on(stream), seq_indices, seq_indices + indices_desc.size, 0); + thrust::sequence(thrust::cuda::par_nosync(allocator).on(stream), + seq_indices, + seq_indices + indices_desc.size, + 0); // use UTypeT to put minus indices at last. 
using UTypeT = typename UnsignedType::UType; const UTypeT* indices_to_sort = static_cast(indices_before_sort); diff --git a/cpp/src/wholememory_ops/functions/nvshmem_gather_scatter_func.cuh b/cpp/src/wholememory_ops/functions/nvshmem_gather_scatter_func.cuh index ea905cd93..a0091c31c 100644 --- a/cpp/src/wholememory_ops/functions/nvshmem_gather_scatter_func.cuh +++ b/cpp/src/wholememory_ops/functions/nvshmem_gather_scatter_func.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -80,7 +80,7 @@ void sort_index_in_pair(const void* indices_before_sort, IndexT* seq_indices = reinterpret_cast(allocator.allocate(indice_count * sizeof(IndexT))); thrust::sequence( - thrust::cuda::par(allocator).on(stream), seq_indices, seq_indices + indice_count, 0); + thrust::cuda::par_nosync(allocator).on(stream), seq_indices, seq_indices + indice_count, 0); // TODO: use unsigned type (wm_ops::UTypeT) can put all negative indices at last. But maybe // later... using UTypeT = typename UnsignedType::UType; auto indices_to_sort = static_cast(indices_before_sort); From b453a1dad2f559cc2fae23c1feb8ffc3b7f59ea6 Mon Sep 17 00:00:00 2001 From: linhu-nv <141609318+linhu-nv@users.noreply.github.com> Date: Tue, 30 Apr 2024 18:21:52 +0800 Subject: [PATCH 05/15] allow temp_memory_handler to allocate memory for multiple times (#161) fix to[ issue 76](https://github.com/rapidsai/wholegraph/issues/76), which allows temp_memory_handler to allocate memory for multiple times. Authors: - https://github.com/linhu-nv Approvers: - Chuang Zhu (https://github.com/chuangz0) URL: https://github.com/rapidsai/wholegraph/pull/161 --- .../wholememory_ops/temp_memory_handle.hpp | 15 +++++++++---- .../pylibwholegraph/torch/wholegraph_env.py | 9 ++++++-- .../torch_cpp_ext/torch_env_func_ptrs.cpp | 6 +++++- .../torch_cpp_ext/wholegraph_torch_ext.cpp | 21 +++++++++++++++++++ 4 files changed, 44 insertions(+), 7 deletions(-) diff --git a/cpp/src/wholememory_ops/temp_memory_handle.hpp b/cpp/src/wholememory_ops/temp_memory_handle.hpp index 7f74677ba..408d3bfa1 100644 --- a/cpp/src/wholememory_ops/temp_memory_handle.hpp +++ b/cpp/src/wholememory_ops/temp_memory_handle.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -31,7 +31,7 @@ class temp_memory_handle { ~temp_memory_handle() { free_memory(); } void* device_malloc(size_t elt_count, wholememory_dtype_t data_type) { - free_memory(); + free_data(); wholememory_tensor_description_t tensor_description; get_tensor_description(&tensor_description, elt_count, data_type); ptr_ = temp_mem_fns_->malloc_fn( @@ -40,7 +40,7 @@ class temp_memory_handle { } void* host_malloc(size_t elt_count, wholememory_dtype_t data_type) { - free_memory(); + free_data(); wholememory_tensor_description_t tensor_description; get_tensor_description(&tensor_description, elt_count, data_type); ptr_ = temp_mem_fns_->malloc_fn( @@ -49,7 +49,7 @@ class temp_memory_handle { } void* pinned_malloc(size_t elt_count, wholememory_dtype_t data_type) { - free_memory(); + free_data(); wholememory_tensor_description_t tensor_description; get_tensor_description(&tensor_description, elt_count, data_type); ptr_ = temp_mem_fns_->malloc_fn( @@ -57,6 +57,13 @@ class temp_memory_handle { return ptr_; } [[nodiscard]] void* pointer() const { return ptr_; } + void free_data() + { + if (ptr_ != nullptr) { + temp_mem_fns_->free_fn(memory_context_, temp_mem_fns_->global_context); + ptr_ = nullptr; + } + } void free_memory() { if (ptr_ != nullptr) { diff --git a/python/pylibwholegraph/pylibwholegraph/torch/wholegraph_env.py b/python/pylibwholegraph/pylibwholegraph/torch/wholegraph_env.py index aba4b6bea..d083a8abc 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/wholegraph_env.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/wholegraph_env.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2023, NVIDIA CORPORATION. +# Copyright (c) 2019-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -75,6 +75,11 @@ def free(self): torch_cpp_ext_lib.destroy_output_context(self.get_handle()) self.handle = 0 + def free_data(self): + self.tensor = None + if torch_cpp_ext_loaded and self.get_handle() != 0: + torch_cpp_ext_lib.free_context_data(self.get_handle()) + def torch_create_memory_context_env_fn( global_context: TorchEmptyGlobalContext, @@ -121,7 +126,7 @@ def torch_malloc_env_fn( def torch_free_env_fn( memory_context: TorchMemoryContext, global_context: TorchEmptyGlobalContext ): - memory_context.free() + memory_context.free_data() class ExtContextWrapper(object): diff --git a/python/pylibwholegraph/pylibwholegraph/torch_cpp_ext/torch_env_func_ptrs.cpp b/python/pylibwholegraph/pylibwholegraph/torch_cpp_ext/torch_env_func_ptrs.cpp index 15d2e5160..be0385d9c 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch_cpp_ext/torch_env_func_ptrs.cpp +++ b/python/pylibwholegraph/pylibwholegraph/torch_cpp_ext/torch_env_func_ptrs.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -60,4 +60,8 @@ void destroy_output_context(void* output_context) { destroy_torch_memory_context_func(output_context, nullptr); } +void free_context_data(void* output_context) { + torch_common_free_func(output_context, nullptr); +} + } // namespace wholegraph_torch diff --git a/python/pylibwholegraph/pylibwholegraph/torch_cpp_ext/wholegraph_torch_ext.cpp b/python/pylibwholegraph/pylibwholegraph/torch_cpp_ext/wholegraph_torch_ext.cpp index f1dcbecdb..d805d24a5 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch_cpp_ext/wholegraph_torch_ext.cpp +++ b/python/pylibwholegraph/pylibwholegraph/torch_cpp_ext/wholegraph_torch_ext.cpp @@ -1,3 +1,18 @@ +/* + * Copyright (c) 2019-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include #include @@ -24,6 +39,11 @@ void wrapped_destroy_output_context(int64_t output_context) wholegraph_torch::destroy_output_context(reinterpret_cast(output_context)); } +void wrapped_free_context_data(int64_t output_context) +{ + wholegraph_torch::free_context_data(reinterpret_cast(output_context), nullptr); +} + torch::Tensor get_torch_tensor_from_output_context(int64_t output_context) { auto* torch_output_context = @@ -39,6 +59,7 @@ PYBIND11_MODULE(pylibwholegraph_torch_ext, m) m.def("get_stream", &wrapped_get_stream, "Get current CUDA stream."); m.def("create_output_context", &wrapped_create_output_context, "Create output memory context."); m.def("destroy_output_context", &wrapped_destroy_output_context, "Destroy output memory context."); + m.def("free_context_data", &wrapped_free_context_data, "Free data in output memory context."); m.def("get_tensor_from_context", &get_torch_tensor_from_output_context, "Get PyTorch Tensor from output memory context"); From 7d7043e099d3d633f2680cf613ad0858f9c8da26 Mon Sep 17 00:00:00 2001 From: Chuang Zhu <111838961+chuangz0@users.noreply.github.com> Date: Tue, 30 Apr 2024 18:22:38 +0800 Subject: [PATCH 06/15] support read file with multi threads and add test_wholememory_io for round-roubin read (#163) Authors: - Chuang Zhu (https://github.com/chuangz0) Approvers: - https://github.com/linhu-nv - Brad Rees (https://github.com/BradReesWork) URL: https://github.com/rapidsai/wholegraph/pull/163 --- cpp/src/wholememory/file_io.cpp | 1164 ++++++++++++++++- .../pylibwholegraph/test_wholememory_io.py | 108 +- 2 files changed, 1208 insertions(+), 64 deletions(-) diff --git a/cpp/src/wholememory/file_io.cpp b/cpp/src/wholememory/file_io.cpp index 0a627eed2..1ad7e85aa 100644 --- a/cpp/src/wholememory/file_io.cpp +++ b/cpp/src/wholememory/file_io.cpp @@ -15,12 +15,17 @@ */ #include "file_io.h" +#include +#include + #include +#include #include #include #include #include +#include #include #include "communicator.hpp" @@ -384,6 +389,476 @@ static void read_file_list_to_local_memory(char* local_ptr, "Rank=%d done reading total %ld bytes from needed files.", wm_rank, total_read_bytes); } +/*! + * Read from file list to local memory of WholeMemory. 
File list are binary files, which are + * considered to be concatenated together. All ranks in WholeMemory will read the files in parallel + * and load each part into local memory of each rank. WholeMemory will use round-robin sharding + * strategy. + * @param local_ptr : Pointer to local memory of WholeMemory + * @param local_size : Local memory size + * @param local_offset : The offset of local memory in WholeMemory. + * @param entry_size : The entry size of each data entry. + * @param memory_entry_stride : The stride of each entry in WholeMemory + * @param memory_offset : The start offset to place the read data. Should be in range [0, + * memory_entry_stride) + * @param file_count : Total file count of the file list + * @param file_names : File names of the file list. + * @param file_sizes : Sizes of each file. + * @param suggested_buffer_size : Suggested buffer size to read. + * @param wm_rank : WholeMemory rank. + * @param wm_world_size : WholeMemory world size. + * @param round_robin_size : continuous embedding size of a rank using round robin shard stratehy. + * @param dev_id : the device bound to the rank. + */ +static void read_file_list_to_local_memory_roundrobin_with_multi_threads( + char* local_ptr, + size_t local_size, + size_t local_offset, + size_t entry_size, + size_t memory_entry_stride, + size_t memory_offset, + int file_count, + const char** file_names, + const std::vector& file_sizes, + size_t suggested_buffer_size, + int wm_rank, + int wm_world_size, + int round_robin_size, + int dev_id) +{ + int threads_per_rank = 1; + const char* threads_per_rank_env_var = std::getenv("WG_LOAD_THREADS_PER_RANK"); + if (threads_per_rank_env_var != nullptr) { + try { + threads_per_rank = std::stoi(threads_per_rank_env_var); + } catch (const std::invalid_argument& e) { + threads_per_rank = 1; + WHOLEMEMORY_WARN( + "Environment variable WG_LOAD_THREADS_PER_RANK value %s is not valid,use the default %d", + threads_per_rank_env_var, + threads_per_rank); + } + if (threads_per_rank < 1 || threads_per_rank > std::thread::hardware_concurrency()) { + threads_per_rank = 1; + WHOLEMEMORY_WARN( + "Environment variable WG_LOAD_THREADS_PER_RANK value %s is not valid,use the default %d", + threads_per_rank_env_var, + threads_per_rank); + } + } + size_t buffer_size; + size_t buffer_entry_count = 1; + if (suggested_buffer_size < entry_size) { + buffer_size = entry_size; + } else { + buffer_entry_count = suggested_buffer_size / entry_size; + buffer_size = buffer_entry_count * entry_size; + } + + std::atomic_size_t total_read_entry = 0; + + if (memory_offset >= memory_entry_stride) + WHOLEMEMORY_ERROR("memory offset %lu should be less than memory entry stride %lu.", + memory_offset, + memory_entry_stride); + size_t total_file_sizes = 0; + for (int i = 0; i < file_count; i++) + total_file_sizes += file_sizes[i]; + size_t total_file_entry_count = total_file_sizes / entry_size; + if (round_robin_size <= 0 || round_robin_size > total_file_entry_count / wm_world_size) + WHOLEMEMORY_ERROR("illegal round_robin_size."); + + char* local_write_ptr = local_ptr + memory_offset % memory_entry_stride; + + size_t local_entry_memory_start_index = wm_rank * round_robin_size; + size_t local_entry_file_start_index = + local_entry_memory_start_index - memory_offset / memory_entry_stride; + int extra_entry = total_file_entry_count % (wm_world_size * round_robin_size); + int local_extra_entry = (extra_entry > (wm_rank + 1) * round_robin_size) + ? 
round_robin_size + : extra_entry - wm_rank * round_robin_size; + local_extra_entry = local_extra_entry > 0 ? local_extra_entry : 0; + size_t local_entry_count = + total_file_entry_count / (wm_world_size * round_robin_size) * round_robin_size; + + if (wm_rank == 0) { + local_entry_count -= memory_offset / memory_entry_stride; + local_write_ptr += (memory_offset / memory_entry_stride) * memory_entry_stride; + } + + int64_t local_round_robin_count = local_entry_count / round_robin_size; + + auto read_file_thread_fun = [=, &total_read_entry](int thread_id, int thread_num) { + WM_CUDA_CHECK(cudaSetDevice(dev_id)); + std::vector file_read_buffer(buffer_size); + + int64_t round_robin_count_per_thread = (local_round_robin_count + thread_num - 1) / thread_num; + int64_t round_robin_count_this_thread = + std::max(0L, + std::min(round_robin_count_per_thread, + local_round_robin_count - round_robin_count_per_thread * thread_id)); + int64_t local_entry_count_this_thread = round_robin_count_this_thread * round_robin_size; + if (thread_id == thread_num - 1) { + // last thread + local_entry_count_this_thread += local_extra_entry; + } + + if (local_entry_count_this_thread == 0) return; + int64_t start_round_robin_id_in_local = thread_id * round_robin_count_per_thread; + + if (round_robin_count_this_thread == 0) { + // last thread + if (round_robin_count_per_thread != 1) { + WHOLEMEMORY_ERROR("round_robin_count_per_thread should be 1,but get %d \n", + round_robin_count_per_thread); + } + start_round_robin_id_in_local = local_round_robin_count; + } + + size_t local_entry_file_start_index_this_thread = + local_entry_file_start_index + + start_round_robin_id_in_local * wm_world_size * round_robin_size; + char* this_thread_write_ptr = + local_write_ptr + start_round_robin_id_in_local * round_robin_size * memory_entry_stride; + + size_t total_read_entry_this_thread = 0; + size_t next_entry_gap = local_entry_file_start_index_this_thread; + size_t next_continuous_entry_count = + round_robin_size > local_entry_count_this_thread - total_read_entry_this_thread + ? 
local_entry_count_this_thread - total_read_entry_this_thread + : round_robin_size; + size_t read_file_begin_entry_off = 0; + for (int i = 0; i < file_count; i++) { + size_t file_entry_count = file_sizes[i] / entry_size; + if (file_entry_count <= next_entry_gap) { + next_entry_gap -= file_entry_count; + continue; + } + size_t read_size_from_cur_file = 0; + read_file_begin_entry_off = 0; + //$open file get fp + FILE* fp = fopen(file_names[i], "rb"); + if (fp == nullptr) { WHOLEMEMORY_ERROR("Open file %s for read failed.", file_names[i]); } + /*|***read_file_begin_entry_off***|***entry_gap***|***cur_file_read_entry_count***|******|*/ + + while (read_file_begin_entry_off < file_entry_count) { + //$fseek by remain_entry_gap + if (read_file_begin_entry_off + next_entry_gap >= file_entry_count) { + next_entry_gap = (read_file_begin_entry_off + next_entry_gap) - file_entry_count; + break; + } + size_t file_read_start_offset = next_entry_gap * entry_size; + if (fseeko(fp, file_read_start_offset, SEEK_CUR) != 0) { + WHOLEMEMORY_ERROR("File %s seek to %ld failed.", file_names[i], file_read_start_offset); + } + + size_t cur_file_read_entry_count; + if (read_file_begin_entry_off + next_entry_gap + next_continuous_entry_count > + file_entry_count) { + cur_file_read_entry_count = file_entry_count - read_file_begin_entry_off - next_entry_gap; + total_read_entry_this_thread += cur_file_read_entry_count; + read_file_begin_entry_off = file_entry_count; + next_continuous_entry_count -= cur_file_read_entry_count; + next_entry_gap = 0; + } else { + cur_file_read_entry_count = next_continuous_entry_count; + total_read_entry_this_thread += cur_file_read_entry_count; + read_file_begin_entry_off += cur_file_read_entry_count + next_entry_gap; + next_continuous_entry_count = + round_robin_size > local_entry_count_this_thread - total_read_entry_this_thread + ? 
local_entry_count_this_thread - total_read_entry_this_thread + : round_robin_size; + next_entry_gap = (wm_world_size - 1) * round_robin_size; + } + read_size_from_cur_file += cur_file_read_entry_count * entry_size; + // read cur_file_read_entry_count of embeddings + size_t cur_file_read_entry = cur_file_read_entry_count; + while (cur_file_read_entry_count > 0) { + size_t read_entry_count = std::min(cur_file_read_entry_count, buffer_entry_count); + int ret = fread(file_read_buffer.data(), entry_size, read_entry_count, fp); + if (ret != read_entry_count) { + WHOLEMEMORY_ERROR( + "File %s line %d: reading from file %s, read_entry_count=%ld, entry_size=%ld, " + "returned %d, error=%s\n", + __FILE__, + __LINE__, + file_names[i], + read_entry_count, + entry_size, + ret, + strerror(errno)); + } + if (entry_size != memory_entry_stride) { + WM_CUDA_CHECK(cudaMemcpy2D(this_thread_write_ptr, + memory_entry_stride, + file_read_buffer.data(), + entry_size, + entry_size, + read_entry_count, + cudaMemcpyDefault)); + } else { + WM_CUDA_CHECK(cudaMemcpy(this_thread_write_ptr, + file_read_buffer.data(), + read_entry_count * entry_size, + cudaMemcpyDefault)); + } + this_thread_write_ptr += read_entry_count * memory_entry_stride; + cur_file_read_entry_count -= read_entry_count; + } + if (total_read_entry_this_thread > local_entry_count_this_thread) { + WHOLEMEMORY_ERROR( + "file read error from rank %d, thread_id %d, should read %lu entries, infact %lu " + "entries.", + wm_rank, + thread_id, + local_entry_count, + local_entry_count_this_thread); + break; + } else if (total_read_entry_this_thread == local_entry_count_this_thread) { + break; + } + } + + fclose(fp); + WHOLEMEMORY_INFO("Rank=%d thread_id=%d ,done Reading %ld bytes from file %s size=%ld", + wm_rank, + thread_id, + read_size_from_cur_file, + file_names[i], + file_sizes[i]); + + if (total_read_entry_this_thread == local_entry_count_this_thread) break; + } + total_read_entry.fetch_add(total_read_entry_this_thread); + }; + + WHOLEMEMORY_INFO("Rank=%d use %d threads to read file.", wm_rank, threads_per_rank); + + if (threads_per_rank > 1) { + std::vector read_file_threads; + read_file_threads.reserve(threads_per_rank); + for (int i = 0; i < threads_per_rank; i++) { + read_file_threads.emplace_back(read_file_thread_fun, i, threads_per_rank); + } + + for (auto&& thread : read_file_threads) { + thread.join(); + } + } else { + read_file_thread_fun(0, 1); + } + + WHOLEMEMORY_INFO("Rank=%d done Reading %ld entries, infact read %ld entries", + wm_rank, + total_read_entry.load(), + local_entry_count); +}; + +/*! + * Read from file list to local memory of WholeMemory. File list are binary files, which are + * considered to be concatenated together. All ranks in WholeMemory will read the files in parallel + * and load each part into local memory of each rank. + * @param local_ptr : Pointer to local memory of WholeMemory + * @param local_size : Local memory size + * @param local_offset : The offset of local memory in WholeMemory. + * @param entry_size : The entry size of each data entry. + * @param memory_entry_stride : The stride of each entry in WholeMemory + * @param memory_offset : The start offset to place the read data. Should be in range [0, + * memory_entry_stride) + * @param file_count : Total file count of the file list + * @param file_names : File names of the file list. + * @param file_sizes : Sizes of each file. + * @param suggested_buffer_size : Suggested buffer size to read. + * @param wm_rank : WholeMemory rank. 
+ * @param wm_world_size : WholeMemory world size. + * @param dev_id : the device bound to the rank. + */ +static void read_file_list_to_local_memory_with_multi_threads(char* local_ptr, + size_t local_size, + size_t local_offset, + size_t entry_size, + size_t memory_entry_stride, + size_t memory_offset, + int file_count, + const char** file_names, + const std::vector& file_sizes, + size_t suggested_buffer_size, + int wm_rank, + int wm_world_size, + int dev_id) +{ + int threads_per_rank = 1; + const char* threads_per_rank_env_var = std::getenv("WG_LOAD_THREADS_PER_RANK"); + if (threads_per_rank_env_var != nullptr) { + try { + threads_per_rank = std::stoi(threads_per_rank_env_var); + } catch (const std::invalid_argument& e) { + threads_per_rank = 1; + WHOLEMEMORY_WARN( + "Environment variable WG_LOAD_THREADS_PER_RANK value %s is not valid,use the default %d", + threads_per_rank_env_var, + threads_per_rank); + } + if (threads_per_rank < 1 || threads_per_rank > std::thread::hardware_concurrency()) { + threads_per_rank = 1; + WHOLEMEMORY_WARN( + "Environment variable WG_LOAD_THREADS_PER_RANK value %s is not valid,use the default %d", + threads_per_rank_env_var, + threads_per_rank); + } + } + size_t buffer_size; + size_t buffer_entry_count = 1; + if (suggested_buffer_size < entry_size) { + buffer_size = entry_size; + } else { + buffer_entry_count = suggested_buffer_size / entry_size; + buffer_size = buffer_entry_count * entry_size; + } + + size_t local_entry_memory_start_index = local_offset / memory_entry_stride; + size_t local_entry_file_start_index = + local_entry_memory_start_index - memory_offset / memory_entry_stride; + size_t local_entry_count = local_size / memory_entry_stride; + char* local_write_ptr = local_ptr + memory_offset % memory_entry_stride; + if (wm_rank == 0) { + local_entry_count -= memory_offset / memory_entry_stride; + local_write_ptr += (memory_offset / memory_entry_stride) * memory_entry_stride; + } + std::atomic_size_t total_read_bytes = 0; + + auto read_file_thread_fun = [=, &total_read_bytes](int thread_id, int thread_num) { + WM_CUDA_CHECK(cudaSetDevice(dev_id)); + const size_t entry_count_per_thread = (local_entry_count + thread_num - 1) / thread_num; + const size_t entry_count_this_thread = + std::min(entry_count_per_thread, local_entry_count - entry_count_per_thread * thread_id); + const size_t entry_file_start_index_this_thread = + local_entry_file_start_index + thread_id * entry_count_per_thread; + char* this_thread_write_ptr = + local_write_ptr + entry_count_per_thread * thread_id * memory_entry_stride; + + std::vector file_read_buffer(buffer_size); + + if (entry_count_this_thread <= 0) return; + size_t file_entry_offset = 0; + size_t read_size_this_thread = 0; + + for (int i = 0; i < file_count; i++) { + size_t file_entry_count = file_sizes[i] / entry_size; + // already outside reading window + if (file_entry_offset >= (entry_file_start_index_this_thread + entry_count_this_thread)) + break; + + // in reading window + if (file_entry_offset + file_entry_count > entry_file_start_index_this_thread) { + size_t file_read_start_offset = 0; + FILE* fp = fopen(file_names[i], "rb"); + if (fp == nullptr) { WHOLEMEMORY_ERROR("Open file %s for read failed.", file_names[i]); } + // maybe in window end, remove possible tailing data that don't belong to current rank. 
+ size_t to_read_file_entry_count = std::min( + file_entry_count, + entry_file_start_index_this_thread + entry_count_this_thread - file_entry_offset); + // if in window begin, remove possible data that belongs to previous rank and skip disk + // data. + if (file_entry_offset < entry_file_start_index_this_thread) { + size_t skip_entry_count = entry_file_start_index_this_thread - file_entry_offset; + + file_read_start_offset = skip_entry_count * entry_size; + + if (fseeko(fp, file_read_start_offset, SEEK_SET) != 0) { + WHOLEMEMORY_ERROR( + "File %s seek to %ld failed.", file_names[i], skip_entry_count * entry_size); + } + to_read_file_entry_count -= skip_entry_count; + } + // now all data in file_entry_count need to be read. + size_t bytes_to_read = to_read_file_entry_count * entry_size; + size_t left_entry_count = to_read_file_entry_count; + while (left_entry_count > 0) { + size_t read_entry_count = std::min(left_entry_count, buffer_entry_count); + + int ret = fread(file_read_buffer.data(), entry_size, read_entry_count, fp); + if (ret != read_entry_count) { + WHOLEMEMORY_ERROR( + "File %s line %d: reading from file %s, read_entry_count=%ld, entry_size=%ld, " + "returned %d, error=%s\n", + __FILE__, + __LINE__, + file_names[i], + read_entry_count, + entry_size, + ret, + strerror(errno)); + } + + if (entry_size != memory_entry_stride) { + WM_CUDA_CHECK(cudaMemcpy2D(this_thread_write_ptr, + memory_entry_stride, + file_read_buffer.data(), + entry_size, + entry_size, + read_entry_count, + cudaMemcpyDefault)); + } else { + WHOLEMEMORY_INFO( + "Rank:%d, threadid:%d, cuda Memcpy : this_thread_write_ptr:%p, " + "file_read_buffer.data():%p, read_entry_count:%d, entry_size:%d\n", + wm_rank, + thread_id, + this_thread_write_ptr, + file_read_buffer.data(), + read_entry_count, + entry_size); + WM_CUDA_CHECK(cudaMemcpy(this_thread_write_ptr, + file_read_buffer.data(), + read_entry_count * entry_size, + cudaMemcpyDefault)); + } + this_thread_write_ptr += read_entry_count * memory_entry_stride; + + left_entry_count -= read_entry_count; + } + fclose(fp); + WHOLEMEMORY_INFO( + "Rank=%d thread_id=%d done Reading %ld bytes from file %s size=%ld, starting from " + "offset=%ld.", + wm_rank, + thread_id, + bytes_to_read, + file_names[i], + file_sizes[i], + file_read_start_offset); + total_read_bytes.fetch_add(bytes_to_read); + read_size_this_thread += bytes_to_read; + } + + file_entry_offset += file_entry_count; + } + + WHOLEMEMORY_INFO("Rank=%d thread_id=%d done Reading %ld bytes from needed files size.", + wm_rank, + thread_id, + read_size_this_thread); + }; + WHOLEMEMORY_INFO("Rank=%d use %d threads to read file.", wm_rank, threads_per_rank); + + if (threads_per_rank > 1) { + std::vector read_file_threads; + read_file_threads.reserve(threads_per_rank); + for (int i = 0; i < threads_per_rank; i++) { + read_file_threads.emplace_back(read_file_thread_fun, i, threads_per_rank); + } + + for (auto&& thread : read_file_threads) { + thread.join(); + } + } else { + read_file_thread_fun(0, 1); + } + WHOLEMEMORY_INFO( + "Rank=%d done reading total %ld bytes from needed files.", wm_rank, total_read_bytes.load()); +} + /*! * Read from file list to local memory of WholeMemory using DirectIO. Using DirectIO may have better * performance by bypassing system cache if it is bottleneck. 
File list are binary files, which are @@ -475,15 +950,7 @@ static void read_file_list_to_local_memory_roundrobin_directio( WHOLEMEMORY_FAIL_NOTHROW( "block_size=%ld for file %s, but alignment is %ld", block_size, file_names[i], kAlignSize); } - /*if ((round_robin_size * entry_size) % block_size != 0) { - WHOLEMEMORY_FAIL_NOTHROW( - "per rank round-robin size (%d x %ld) is not mutiple of block_size (%d)for file %d.", - round_robin_size, - entry_size, - block_size, - i - ); - }*/ + size_t buffer_block_count = suggested_buffer_size / block_size; int fd = open(file_names[i], O_DIRECT | O_RDONLY); if (fd < 0) { WHOLEMEMORY_FAIL_NOTHROW("Open file %s with direct io failed.", file_names[i]); } @@ -813,6 +1280,583 @@ static void read_file_list_to_local_memory_directio(char* local_ptr, free(block_buffer); } +/*! + * Read from file list to local memory of WholeMemory using DirectIO. Using DirectIO may have better + * performance by bypassing system cache if it is bottleneck. File list are binary files, which are + * considered to be concatenated together. All ranks in WholeMemory will read the files in parallel + * and load each part into local memory of each rank. + * @param local_ptr : Pointer to local memory of WholeMemory + * @param local_size : Local memory size + * @param local_offset : The offset of local memory in WholeMemory. + * @param entry_size : The entry size of each data entry. + * @param memory_entry_stride : The stride of each entry in WholeMemory + * @param memory_offset : The start offset to place the read data. Should be in range [0, + * memory_entry_stride) + * @param file_count : Total file count of the file list + * @param file_names : File names of the file list. + * @param file_sizes : Sizes of each file. + * @param suggested_buffer_size : Suggested buffer size to read. + * @param wm_rank : WholeMemory rank. + * @param wm_world_size : WholeMemory world size. + * @param dev_id : the device bound to the rank. 
+ */ +static void read_file_list_to_local_memory_directio_with_multi_thread( + char* local_ptr, + size_t local_size, + size_t local_offset, + size_t entry_size, + size_t memory_entry_stride, + size_t memory_offset, + int file_count, + const char** file_names, + const std::vector& file_sizes, + size_t suggested_buffer_size, + int wm_rank, + int wm_world_size, + int dev_id) +{ + if (memory_offset + entry_size > memory_entry_stride) { + WHOLEMEMORY_FAIL_NOTHROW("Direct io mode only support reading all entries."); + } + size_t local_entry_start_index = local_offset / memory_entry_stride; + size_t local_entry_count = local_size / memory_entry_stride; + char* local_write_ptr = local_ptr + memory_offset % memory_entry_stride; + + static size_t kAlignSize = 16 * 1024 * 1024; + suggested_buffer_size = round_up_unsafe(suggested_buffer_size, kAlignSize); + + int threads_per_rank = 1; + const char* threads_per_rank_env_var = std::getenv("WG_LOAD_THREADS_PER_RANK"); + if (threads_per_rank_env_var != nullptr) { + try { + threads_per_rank = std::stoi(threads_per_rank_env_var); + } catch (const std::invalid_argument& e) { + threads_per_rank = 1; + WHOLEMEMORY_WARN( + "Environment variable WG_LOAD_THREADS_PER_RANK value %s is not valid,use the default %d", + threads_per_rank_env_var, + threads_per_rank); + } + if (threads_per_rank < 1 || threads_per_rank > std::thread::hardware_concurrency()) { + threads_per_rank = 1; + WHOLEMEMORY_WARN( + "Environment variable WG_LOAD_THREADS_PER_RANK value %s is not valid,use the default %d", + threads_per_rank_env_var, + threads_per_rank); + } + } + + auto read_file_thread_fun = [=](int thread_id, int thread_num) { + WM_CUDA_CHECK(cudaSetDevice(dev_id)); + + char* block_buffer; + WHOLEMEMORY_CHECK_NOTHROW(posix_memalign(reinterpret_cast(&block_buffer), + kAlignSize, + suggested_buffer_size) == 0); + + const size_t entry_count_per_thread = (local_entry_count + thread_num - 1) / thread_num; + const size_t entry_count_this_thread = + std::min(entry_count_per_thread, local_entry_count - entry_count_per_thread * thread_id); + const size_t entry_file_start_index_this_thread = + local_entry_start_index + thread_id * entry_count_per_thread; + const size_t this_thread_entry_start_index = entry_file_start_index_this_thread; + char* this_thread_write_ptr = + local_write_ptr + entry_count_per_thread * thread_id * memory_entry_stride; + + if (entry_count_this_thread <= 0) return; + + size_t file_entry_offset = 0; + size_t read_entry_count = 0; + + for (int i = 0; i < file_count; i++) { + size_t file_entry_count = file_sizes[i] / entry_size; + // already outside reading window + if (file_entry_offset >= this_thread_entry_start_index + entry_count_this_thread) break; + // reading window not reached + if (file_entry_offset + file_entry_count <= this_thread_entry_start_index) { + file_entry_offset += file_entry_count; + continue; + } + // in reading window + auto block_size = StatFileBlockSize(file_names[i]); + if (block_size == 0 || block_size == (size_t)-1 || kAlignSize % block_size != 0) { + WHOLEMEMORY_FAIL_NOTHROW("block_size=%ld for file %s, but alignment is %ld", + block_size, + file_names[i], + kAlignSize); + } + size_t buffer_block_count = suggested_buffer_size / block_size; + int fd = open(file_names[i], O_DIRECT | O_RDONLY); + if (fd < 0) { + WHOLEMEMORY_FAIL_NOTHROW("Open file %s with direct io failed.", file_names[i]); + } + + // maybe in window end, remove possible tailing data that don't belong to current rank. 
+ size_t to_read_file_entry_count = + std::min(file_entry_count, + this_thread_entry_start_index + entry_count_this_thread - file_entry_offset); + + size_t file_read_end = to_read_file_entry_count * entry_size; + // if in window begin, remove possible data that belongs to previous rank and skip disk + // data. + size_t file_read_start = 0; + if (file_entry_offset < this_thread_entry_start_index) { + size_t skip_entry_count = this_thread_entry_start_index - file_entry_offset; + to_read_file_entry_count -= skip_entry_count; + file_read_start = skip_entry_count * entry_size; + } + + size_t file_block_read_offset = file_read_start / block_size * block_size; + size_t skip_head_size = file_read_start - file_block_read_offset; + + char* local_mem_write_entry_for_file = + this_thread_write_ptr + read_entry_count * memory_entry_stride; + size_t first_mem_entry_offset = 0; + size_t useful_data_bytes_read = 0; + size_t physical_data_bytes_read = 0; + while (file_block_read_offset < file_read_end) { + size_t left_size = file_read_end - file_block_read_offset; + size_t left_block_count = div_rounding_up_unsafe(left_size, block_size); + size_t read_block_count = std::min(left_block_count, buffer_block_count); + size_t physical_read_size = read_block_count * block_size; + physical_data_bytes_read += physical_read_size; + + ssize_t pread_size = pread64(fd, block_buffer, physical_read_size, file_block_read_offset); + if (pread_size != physical_read_size && + file_block_read_offset + pread_size != file_sizes[i]) { + WHOLEMEMORY_FAIL_NOTHROW( + "rank=%d, pread_size=%ld, physical_read_size=%ld, file_block_read_offset=%ld, " + "file_sizes[i]=%ld, file=%s", + wm_rank, + pread_size, + physical_read_size, + file_block_read_offset, + file_sizes[i], + file_names[i]); + } + physical_read_size = pread_size; + + size_t drop_tail_size = 0; + if (file_block_read_offset + physical_read_size > file_read_end) { + drop_tail_size = file_block_read_offset + physical_read_size - file_read_end; + } + + char* useful_data_ptr = block_buffer + skip_head_size; + size_t useful_data_size = physical_read_size - skip_head_size - drop_tail_size; + + useful_data_bytes_read += useful_data_size; + + if (first_mem_entry_offset != 0) { + // process head + size_t entry_left_size = entry_size - first_mem_entry_offset; + WM_CUDA_CHECK_NO_THROW(cudaMemcpy(local_mem_write_entry_for_file + first_mem_entry_offset, + useful_data_ptr, + entry_left_size, + cudaMemcpyDefault)); + local_mem_write_entry_for_file += memory_entry_stride; + useful_data_ptr += entry_left_size; + useful_data_size -= entry_left_size; + entry_left_size = 0; + } + + size_t full_entry_count = useful_data_size / entry_size; + size_t full_entry_size = full_entry_count * entry_size; + + if (full_entry_size > 0) { + if (entry_size != memory_entry_stride) { + WM_CUDA_CHECK(cudaMemcpy2D(local_mem_write_entry_for_file, + memory_entry_stride, + useful_data_ptr, + entry_size, + entry_size, + full_entry_count, + cudaMemcpyDefault)); + } else { + WM_CUDA_CHECK(cudaMemcpy( + local_mem_write_entry_for_file, useful_data_ptr, full_entry_size, cudaMemcpyDefault)); + } + local_mem_write_entry_for_file += memory_entry_stride * full_entry_count; + useful_data_ptr += full_entry_size; + useful_data_size -= full_entry_size; + } + + size_t tail_entry_size = useful_data_size % entry_size; + first_mem_entry_offset = tail_entry_size; + + if (tail_entry_size != 0) { + // process tail + WM_CUDA_CHECK_NO_THROW(cudaMemcpy( + local_mem_write_entry_for_file, useful_data_ptr, tail_entry_size, 
cudaMemcpyDefault)); + // first_mem_entry_offset = tail_entry_size; + } + + file_block_read_offset += physical_read_size; + skip_head_size = 0; + } + + WHOLEMEMORY_INFO( + "Rank=%d threadid=%d done Reading %ld useful bytes by reading %ld block bytes using " + "DirectIO from file " + "%s size=%ld.", + wm_rank, + thread_id, + useful_data_bytes_read, + physical_data_bytes_read, + file_names[i], + file_sizes[i]); + + close(fd); + file_entry_offset += file_entry_count; + read_entry_count += to_read_file_entry_count; + } + free(block_buffer); + }; + + if (threads_per_rank != 1) { + std::vector read_file_threads; + read_file_threads.reserve(threads_per_rank); + for (int i = 0; i < threads_per_rank; i++) { + read_file_threads.emplace_back(read_file_thread_fun, i, threads_per_rank); + } + + for (auto&& thread : read_file_threads) { + thread.join(); + } + } else { + read_file_thread_fun(0, 1); + } +} + +/*! + * Read from file list to local memory of WholeMemory using DirectIO. Using DirectIO may have better + * performance by bypassing system cache if it is bottleneck. File list are binary files, which are + * considered to be concatenated together. All ranks in WholeMemory will read the files in parallel + * and load each part into local memory of each rank. Wholememory uses round-robin sharding strategy + * here. + * @param local_ptr : Pointer to local memory of WholeMemory + * @param local_size : Local memory size + * @param local_offset : The offset of local memory in WholeMemory. + * @param entry_size : The entry size of each data entry. + * @param memory_entry_stride : The stride of each entry in WholeMemory + * @param memory_offset : The start offset to place the read data. Should be in range [0, + * memory_entry_stride) + * @param file_count : Total file count of the file list + * @param file_names : File names of the file list. + * @param file_sizes : Sizes of each file. + * @param suggested_buffer_size : Suggested buffer size to read. + * @param wm_rank : WholeMemory rank. + * @param wm_world_size : WholeMemory world size. + * @param round_robin_size : continuous embedding size of a rank using round robin shard stratehy. + * @param dev_id : the device bound to the rank. 
+ */
+static void read_file_list_to_local_memory_roundrobin_directio_with_multi_threads(
+  char* local_ptr,
+  size_t local_size,
+  size_t local_offset,
+  size_t entry_size,
+  size_t memory_entry_stride,
+  size_t memory_offset,
+  int file_count,
+  const char** file_names,
+  const std::vector<size_t>& file_sizes,
+  size_t suggested_buffer_size,
+  int wm_rank,
+  int wm_world_size,
+  int round_robin_size,
+  int dev_id)
+{
+  int threads_per_rank = 1;
+  const char* threads_per_rank_env_var = std::getenv("WG_LOAD_THREADS_PER_RANK");
+  if (threads_per_rank_env_var != nullptr) {
+    try {
+      threads_per_rank = std::stoi(threads_per_rank_env_var);
+    } catch (const std::invalid_argument& e) {
+      threads_per_rank = 1;
+      WHOLEMEMORY_WARN(
+        "Environment variable WG_LOAD_THREADS_PER_RANK value %s is not valid, use the default %d",
+        threads_per_rank_env_var,
+        threads_per_rank);
+    }
+    if (threads_per_rank < 1 || threads_per_rank > std::thread::hardware_concurrency()) {
+      threads_per_rank = 1;
+      WHOLEMEMORY_WARN(
+        "Environment variable WG_LOAD_THREADS_PER_RANK value %s is not valid, use the default %d",
+        threads_per_rank_env_var,
+        threads_per_rank);
+    }
+  }
+
+  if (memory_offset + entry_size > memory_entry_stride)
+    WHOLEMEMORY_FAIL_NOTHROW("Direct IO mode only supports reading all entries.");
+
+  static size_t kAlignSize = 16 * 1024 * 1024;
+  suggested_buffer_size = round_up_unsafe(suggested_buffer_size, kAlignSize);
+
+  size_t total_file_sizes = 0;
+  for (int i = 0; i < file_count; i++)
+    total_file_sizes += file_sizes[i];
+  size_t total_file_entry_count = total_file_sizes / entry_size;
+  if (round_robin_size <= 0 || round_robin_size > total_file_entry_count / wm_world_size)
+    WHOLEMEMORY_ERROR("illegal round_robin_size.");
+  char* local_write_ptr = local_ptr + memory_offset % memory_entry_stride;
+
+  size_t local_entry_memory_start_index = wm_rank * round_robin_size;
+  size_t local_entry_file_start_index =
+    local_entry_memory_start_index - memory_offset / memory_entry_stride;
+
+  int extra_entry = total_file_entry_count % (wm_world_size * round_robin_size);
+  int local_extra_entry = (extra_entry > (wm_rank + 1) * round_robin_size)
+                            ? round_robin_size
+                            : extra_entry - wm_rank * round_robin_size;
+  local_extra_entry = local_extra_entry > 0 ?
local_extra_entry : 0;
+  size_t local_entry_count =
+    total_file_entry_count / (wm_world_size * round_robin_size) * round_robin_size;
+  std::atomic_size_t total_read_entry = 0;
+  if (wm_rank == 0) {
+    local_entry_count -= memory_offset / memory_entry_stride;
+    local_write_ptr += (memory_offset / memory_entry_stride) * memory_entry_stride;
+  }
+
+  int64_t local_round_robin_count = local_entry_count / round_robin_size;
+
+  auto read_file_thread_fun = [=, &total_read_entry](int thread_id, int thread_num) {
+    WM_CUDA_CHECK(cudaSetDevice(dev_id));
+
+    char* block_buffer;
+    WHOLEMEMORY_CHECK_NOTHROW(posix_memalign(reinterpret_cast<void**>(&block_buffer),
+                                             kAlignSize,
+                                             suggested_buffer_size) == 0);
+    int64_t round_robin_count_per_thread = (local_round_robin_count + thread_num - 1) / thread_num;
+    int64_t round_robin_count_this_thread =
+      std::max(0L,
+               std::min(round_robin_count_per_thread,
+                        local_round_robin_count - round_robin_count_per_thread * thread_id));
+    int64_t local_entry_count_this_thread = round_robin_count_this_thread * round_robin_size;
+    if (thread_id == thread_num - 1) {
+      // last thread
+      local_entry_count_this_thread += local_extra_entry;
+    }
+
+    if (local_entry_count_this_thread == 0) return;
+    int64_t start_round_robin_id_in_local = thread_id * round_robin_count_per_thread;
+
+    if (round_robin_count_this_thread == 0) {
+      // last thread
+      if (round_robin_count_per_thread != 1) {
+        WHOLEMEMORY_ERROR("round_robin_count_per_thread should be 1, but got %ld\n",
+                          round_robin_count_per_thread);
+      }
+      start_round_robin_id_in_local = local_round_robin_count;
+    }
+
+    size_t local_entry_file_start_index_this_thread =
+      local_entry_file_start_index +
+      start_round_robin_id_in_local * wm_world_size * round_robin_size;
+    char* this_thread_write_ptr =
+      local_write_ptr + start_round_robin_id_in_local * round_robin_size * memory_entry_stride;
+
+    size_t total_read_entry_this_thread = 0;
+
+    size_t next_entry_gap = local_entry_file_start_index_this_thread;
+    size_t next_continuous_entry_count =
+      round_robin_size > local_entry_count_this_thread - total_read_entry_this_thread
+        ?
local_entry_count_this_thread - total_read_entry_this_thread + : round_robin_size; + size_t read_file_begin_entry_off = 0; + for (int i = 0; i < file_count; i++) { + size_t file_entry_count = file_sizes[i] / entry_size; + if (file_entry_count <= next_entry_gap) { + next_entry_gap -= file_entry_count; + continue; + } + + auto block_size = StatFileBlockSize(file_names[i]); + if (block_size == 0 || block_size == (size_t)-1 || kAlignSize % block_size != 0) { + WHOLEMEMORY_FAIL_NOTHROW("block_size=%ld for file %s, but alignment is %ld", + block_size, + file_names[i], + kAlignSize); + } + + size_t buffer_block_count = suggested_buffer_size / block_size; + int fd = open(file_names[i], O_DIRECT | O_RDONLY); + if (fd < 0) { + WHOLEMEMORY_FAIL_NOTHROW("Open file %s with direct io failed.", file_names[i]); + } + + size_t read_size_from_cur_file = 0; + size_t useful_data_bytes_read = 0; + read_file_begin_entry_off = 0; + + /*|***read_file_begin_entry_off***|***entry_gap***|***cur_file_read_entry_count***|******|*/ + while (read_file_begin_entry_off < file_entry_count) { + if (read_file_begin_entry_off + next_entry_gap >= file_entry_count) { + next_entry_gap = (read_file_begin_entry_off + next_entry_gap) - file_entry_count; + break; + } + size_t cur_file_read_entry_count; + if (read_file_begin_entry_off + next_entry_gap + next_continuous_entry_count > + file_entry_count) { + cur_file_read_entry_count = file_entry_count - read_file_begin_entry_off - next_entry_gap; + } else { + cur_file_read_entry_count = next_continuous_entry_count; + } + + // read concerned vars + size_t cur_read_entry_start = read_file_begin_entry_off + next_entry_gap; + size_t cur_read_byte_start = (cur_read_entry_start * entry_size) / block_size * block_size; + size_t cur_read_byte_end = (cur_read_entry_start + cur_file_read_entry_count) * entry_size; + size_t skip_head_size = cur_read_entry_start * entry_size - cur_read_byte_start; + // write concerned vars + char* local_mem_write_entry_for_file = + this_thread_write_ptr + total_read_entry_this_thread * memory_entry_stride; + size_t first_mem_entry_offset = 0; + + while (cur_read_byte_start < cur_read_byte_end) { + size_t left_size = cur_read_byte_end - cur_read_byte_start; + size_t left_block_count = div_rounding_up_unsafe(left_size, block_size); + size_t read_block_count = std::min(left_block_count, buffer_block_count); + size_t physical_read_size = read_block_count * block_size; + // physical_data_bytes_read += physical_read_size; + read_size_from_cur_file += physical_read_size; + + ssize_t pread_size = pread64(fd, block_buffer, physical_read_size, cur_read_byte_start); + if (pread_size != physical_read_size && + cur_read_byte_start + pread_size != file_sizes[i]) { + WHOLEMEMORY_FAIL_NOTHROW( + "rank=%d, pread_size=%ld, physical_read_size=%ld, file_block_read_offset=%ld, " + "file_sizes[i]=%ld, file=%s", + wm_rank, + pread_size, + physical_read_size, + cur_read_byte_start, + file_sizes[i], + file_names[i]); + } + physical_read_size = pread_size; + size_t drop_tail_size = 0; + if (cur_read_byte_start + physical_read_size > cur_read_byte_end) { + drop_tail_size = cur_read_byte_start + physical_read_size - cur_read_byte_end; + } + + char* useful_data_ptr = block_buffer + skip_head_size; + size_t useful_data_size = physical_read_size - skip_head_size - drop_tail_size; + useful_data_bytes_read += useful_data_size; + + if (first_mem_entry_offset != 0) { + size_t entry_left_size = entry_size - first_mem_entry_offset; + WM_CUDA_CHECK_NO_THROW( + 
cudaMemcpy(local_mem_write_entry_for_file + first_mem_entry_offset, + useful_data_ptr, + entry_left_size, + cudaMemcpyDefault)); + local_mem_write_entry_for_file += memory_entry_stride; + useful_data_ptr += entry_left_size; + useful_data_size -= entry_left_size; + entry_left_size = 0; + } + + size_t full_entry_count = useful_data_size / entry_size; + size_t full_entry_size = full_entry_count * entry_size; + + if (full_entry_size > 0) { + if (entry_size != memory_entry_stride) { + WM_CUDA_CHECK(cudaMemcpy2D(local_mem_write_entry_for_file, + memory_entry_stride, + useful_data_ptr, + entry_size, + entry_size, + full_entry_count, + cudaMemcpyDefault)); + } else { + WM_CUDA_CHECK(cudaMemcpy(local_mem_write_entry_for_file, + useful_data_ptr, + full_entry_size, + cudaMemcpyDefault)); + } + local_mem_write_entry_for_file += memory_entry_stride * full_entry_count; + useful_data_ptr += full_entry_size; + useful_data_size -= full_entry_size; + } + + size_t tail_entry_size = useful_data_size % entry_size; + first_mem_entry_offset = tail_entry_size; + if (tail_entry_size != 0) { + // process tail + WM_CUDA_CHECK_NO_THROW(cudaMemcpy( + local_mem_write_entry_for_file, useful_data_ptr, tail_entry_size, cudaMemcpyDefault)); + } + + cur_read_byte_start += physical_read_size; + skip_head_size = 0; + } + + total_read_entry_this_thread += cur_file_read_entry_count; + // read_size_from_cur_file += cur_file_read_entry_count * entry_size; + if (read_file_begin_entry_off + next_entry_gap + next_continuous_entry_count > + file_entry_count) { + read_file_begin_entry_off = file_entry_count; + next_continuous_entry_count -= cur_file_read_entry_count; + next_entry_gap = 0; + } else { + read_file_begin_entry_off += cur_file_read_entry_count + next_entry_gap; + next_continuous_entry_count = + round_robin_size > local_entry_count_this_thread - total_read_entry_this_thread + ? 
local_entry_count_this_thread - total_read_entry_this_thread + : round_robin_size; + next_entry_gap = (wm_world_size - 1) * round_robin_size; + } + if (total_read_entry_this_thread > local_entry_count_this_thread) { + WHOLEMEMORY_ERROR( + "file read error from rank %d, thread_id=%d should read %lu entries, infact %lu " + "entries.", + wm_rank, + thread_id, + local_entry_count_this_thread, + total_read_entry_this_thread); + break; + } else if (total_read_entry_this_thread == local_entry_count_this_thread) { + break; + } + } + close(fd); + WHOLEMEMORY_INFO( + "Rank=%d thread_id=%d done Reading useful %ld bytes by totally reading %ld bytes from " + "file %s size=%ld " + "using direct IO", + wm_rank, + thread_id, + useful_data_bytes_read, + read_size_from_cur_file, + file_names[i], + file_sizes[i]); + if (total_read_entry_this_thread == local_entry_count_this_thread) break; + } + total_read_entry.fetch_add(total_read_entry_this_thread); + }; + + WHOLEMEMORY_INFO("Rank=%d use %d threads to read file.", wm_rank, threads_per_rank); + + if (threads_per_rank > 1) { + std::vector read_file_threads; + read_file_threads.reserve(threads_per_rank); + for (int i = 0; i < threads_per_rank; i++) { + read_file_threads.emplace_back(read_file_thread_fun, i, threads_per_rank); + } + + for (auto&& thread : read_file_threads) { + thread.join(); + } + } else { + read_file_thread_fun(0, 1); + } + + WHOLEMEMORY_INFO("Rank=%d done Reading %ld entries, infact read %ld entries", + wm_rank, + total_read_entry.load(), + local_entry_count); +} + wholememory_error_code_t load_file_to_handle(wholememory_handle_t wholememory_handle, size_t memory_offset, size_t memory_entry_stride, @@ -935,59 +1979,65 @@ wholememory_error_code_t load_file_to_handle(wholememory_handle_t wholememory_ha } if (!use_direct_io) { if (round_robin_size == 0) { - read_file_list_to_local_memory(local_ptr, - local_size, - local_offset, - entry_size, - memory_entry_stride, - memory_offset, - file_count, - file_names, - file_sizes, - suggested_buffer_size, - wm_rank); + read_file_list_to_local_memory_with_multi_threads(local_ptr, + local_size, + local_offset, + entry_size, + memory_entry_stride, + memory_offset, + file_count, + file_names, + file_sizes, + suggested_buffer_size, + wm_rank, + wm_world_size, + wm_comm->dev_id); } else { - read_file_list_to_local_memory_roundrobin(local_ptr, - local_size, - local_offset, - entry_size, - memory_entry_stride, - memory_offset, - file_count, - file_names, - file_sizes, - suggested_buffer_size, - wm_rank, - wm_world_size, - round_robin_size); + read_file_list_to_local_memory_roundrobin_with_multi_threads(local_ptr, + local_size, + local_offset, + entry_size, + memory_entry_stride, + memory_offset, + file_count, + file_names, + file_sizes, + suggested_buffer_size, + wm_rank, + wm_world_size, + round_robin_size, + wm_comm->dev_id); } } else { if (round_robin_size == 0) { - read_file_list_to_local_memory_directio(local_ptr, - local_size, - local_offset, - entry_size, - memory_entry_stride, - memory_offset, - file_count, - file_names, - file_sizes, - suggested_buffer_size, - wm_rank); + read_file_list_to_local_memory_directio_with_multi_thread(local_ptr, + local_size, + local_offset, + entry_size, + memory_entry_stride, + memory_offset, + file_count, + file_names, + file_sizes, + suggested_buffer_size, + wm_rank, + wm_world_size, + wm_comm->dev_id); } else { - read_file_list_to_local_memory_roundrobin_directio(local_ptr, - local_size, - local_offset, - entry_size, - memory_entry_stride, - memory_offset, - 
file_count, - file_names, - file_sizes, - suggested_buffer_size, - wm_rank, - wm_world_size, - round_robin_size); + read_file_list_to_local_memory_roundrobin_directio_with_multi_threads(local_ptr, + local_size, + local_offset, + entry_size, + memory_entry_stride, + memory_offset, + file_count, + file_names, + file_sizes, + suggested_buffer_size, + wm_rank, + wm_world_size, + round_robin_size, + wm_comm->dev_id); } } diff --git a/python/pylibwholegraph/pylibwholegraph/tests/pylibwholegraph/test_wholememory_io.py b/python/pylibwholegraph/pylibwholegraph/tests/pylibwholegraph/test_wholememory_io.py index d9543fbcc..e9bed3a5b 100644 --- a/python/pylibwholegraph/pylibwholegraph/tests/pylibwholegraph/test_wholememory_io.py +++ b/python/pylibwholegraph/pylibwholegraph/tests/pylibwholegraph/test_wholememory_io.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2023, NVIDIA CORPORATION. +# Copyright (c) 2019-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -55,6 +55,7 @@ def load_routine_func( embedding_dim, embedding_stride, storage_offset, + round_robin_size=0, ): wm_comm, _ = init_torch_env_and_create_wm_comm( world_rank, world_size, world_rank, world_size @@ -63,11 +64,54 @@ def load_routine_func( data_type = wmb.WholeMemoryDataType.DtInt file_list = [None] * file_part_count - per_rank_entry = wmb.determine_partition_plan(embedding_entry_count, world_size) - rank_start_entry = min(per_rank_entry * world_rank, embedding_entry_count) - rank_end_entry = min(per_rank_entry * (world_rank + 1), embedding_entry_count) + extra_embedding_count = embedding_entry_count + if round_robin_size != 0: + first_rank_extra_embedding_entry_count = embedding_entry_count % ( + world_size * round_robin_size + ) + first_rank_extra_embedding_entry_count = min( + first_rank_extra_embedding_entry_count, round_robin_size + ) + extra_embedding_count = ( + embedding_entry_count + - embedding_entry_count % (world_size * round_robin_size) + + first_rank_extra_embedding_entry_count * world_size + ) + + per_rank_entry = wmb.determine_partition_plan(extra_embedding_count, world_size) + rank_start_entry = min(per_rank_entry * world_rank, extra_embedding_count) + rank_end_entry = min(per_rank_entry * (world_rank + 1), extra_embedding_count) rank_entry_count = rank_end_entry - rank_start_entry + if round_robin_size != 0: + first_rank_extra_embedding_entry_count = embedding_entry_count % ( + world_size * round_robin_size + ) + per_rank_entry_round_robin = per_rank_entry - per_rank_entry % round_robin_size + if first_rank_extra_embedding_entry_count < round_robin_size: + if world_rank == 0: + rank_entry_count = ( + per_rank_entry_round_robin + first_rank_extra_embedding_entry_count + ) + else: + rank_entry_count = per_rank_entry_round_robin + else: + rank_entry_count = ( + per_rank_entry_round_robin + - round_robin_size + + min( + round_robin_size, + max( + 0, + ( + first_rank_extra_embedding_entry_count + - world_rank * round_robin_size + ), + ), + ) + ) + rank_end_entry = rank_start_entry + rank_entry_count + reference_local_tensor = cpu_embedding_tensor_base[ rank_start_entry:rank_end_entry, : ].cuda() @@ -87,23 +131,29 @@ def load_routine_func( continue wholememory_root_tensor = wmb.create_wholememory_matrix( data_type, - embedding_entry_count, + extra_embedding_count, embedding_dim + storage_offset, embedding_stride, wm_comm, mt, ml, ) + wholememory_tensor = 
wholememory_root_tensor.get_sub_tensor( [-1, storage_offset], [-1, -1] ) - wholememory_tensor.from_filelist(file_list) + wholememory_tensor.from_filelist(file_list, round_robin_size) local_tensor, local_offset = wholememory_tensor.get_local_tensor( torch_import_from_dlpack, wmb.WholeMemoryMemoryLocation.MlDevice, world_rank, ) + if round_robin_size != 0: + assert local_tensor.shape[0] == per_rank_entry + + local_tensor = local_tensor[:rank_entry_count] + assert local_tensor.dim() == 2 assert local_tensor.shape[0] == rank_entry_count assert local_tensor.shape[1] == embedding_dim @@ -122,17 +172,41 @@ def load_routine_func( @pytest.mark.parametrize("embedding_dim", [16, 31, 33]) @pytest.mark.parametrize("embedding_stride", [16, 32, 64]) @pytest.mark.parametrize("storage_offset", [0, 3]) +@pytest.mark.parametrize("round_robin_size", [256, 1024, 0]) def test_wholememory_load( file_part_count, embedding_entry_count, embedding_dim, embedding_stride, storage_offset, + round_robin_size, ): if embedding_stride < storage_offset + embedding_dim: pytest.skip( "Skipping due to embedding_stride, embedding_dim and storage_offset configuration not valid." ) + if round_robin_size != 0 and storage_offset != 0: + pytest.skip( + "Skipping due to round_robin_size!=0 and storage offset !=0 , the configuration is not valid." + ) + global gpu_count + if not gpu_count: + gpu_count = 1 + extra_embedding_count = embedding_entry_count + if round_robin_size != 0: + first_rank_extra_embedding_entry_count = embedding_entry_count % ( + gpu_count * round_robin_size + ) + first_rank_extra_embedding_entry_count = min( + first_rank_extra_embedding_entry_count, round_robin_size + ) + + extra_embedding_count = ( + embedding_entry_count + - embedding_entry_count % (gpu_count * round_robin_size) + + first_rank_extra_embedding_entry_count * gpu_count + ) + cpu_embedding_tensor_base = torch.randint( -1000000000, 1000000000, @@ -154,6 +228,26 @@ def test_wholememory_load( "%s_part_%d_of_%d" % (file_name_prefix, i, file_part_count) ) + if round_robin_size != 0: + entry_per_rank = wmb.determine_partition_plan(extra_embedding_count, gpu_count) + + cpu_embedding_tensor_base_extra = torch.empty( + (extra_embedding_count, embedding_dim), dtype=torch.int, device="cpu" + ) + global_indices = torch.arange(0, embedding_entry_count, device="cpu") + indices_to_robin_id = global_indices // round_robin_size + indices_to_rank = indices_to_robin_id % gpu_count + indices_to_robin_id_offset = indices_to_robin_id // gpu_count + target_id = ( + indices_to_rank * entry_per_rank + + indices_to_robin_id_offset * round_robin_size + + global_indices % round_robin_size + ) + cpu_embedding_tensor_base_extra[target_id] = cpu_embedding_tensor_base + + cpu_embedding_tensor_base = cpu_embedding_tensor_base_extra + cpu_embedding_tensor_base.contiguous() + cpu_embedding_tensor_base = cpu_embedding_tensor_base.share_memory_() load_routine_func_partial = partial( @@ -165,9 +259,9 @@ def test_wholememory_load( embedding_dim=embedding_dim, embedding_stride=embedding_stride, storage_offset=storage_offset, + round_robin_size=round_robin_size, ) - global gpu_count multiprocess_run(gpu_count, load_routine_func_partial) for i in range(file_part_count): From d77f6a027ffd2afe9906504e5652d3eddeda66f9 Mon Sep 17 00:00:00 2001 From: Philip Hyunsu Cho Date: Mon, 6 May 2024 10:40:21 -0700 Subject: [PATCH 07/15] Migrate to `{{ stdlib("c") }}` (#164) The `sysroot*` syntax is getting phased out (conda-forge/conda-forge.github.io#2102). 
The recommendation is to move to `{{ stdlib("c") }}`. Ref https://github.com/rapidsai/build-planning/issues/39 Authors: - Philip Hyunsu Cho (https://github.com/hcho3) Approvers: - Bradley Dice (https://github.com/bdice) - https://github.com/jakirkham - Ray Douglass (https://github.com/raydouglass) URL: https://github.com/rapidsai/wholegraph/pull/164 --- conda/recipes/libwholegraph/conda_build_config.yaml | 5 ++++- conda/recipes/libwholegraph/meta.yaml | 4 ++-- conda/recipes/pylibwholegraph/conda_build_config.yaml | 5 ++++- conda/recipes/pylibwholegraph/meta.yaml | 2 +- 4 files changed, 11 insertions(+), 5 deletions(-) diff --git a/conda/recipes/libwholegraph/conda_build_config.yaml b/conda/recipes/libwholegraph/conda_build_config.yaml index aad996394..ae2c2c714 100644 --- a/conda/recipes/libwholegraph/conda_build_config.yaml +++ b/conda/recipes/libwholegraph/conda_build_config.yaml @@ -25,5 +25,8 @@ gtest_version: gmock_version: - ">=1.13.0" -sysroot_version: +c_stdlib: + - sysroot + +c_stdlib_version: - "2.17" diff --git a/conda/recipes/libwholegraph/meta.yaml b/conda/recipes/libwholegraph/meta.yaml index fd1b3dfa9..9f9b6e5b0 100644 --- a/conda/recipes/libwholegraph/meta.yaml +++ b/conda/recipes/libwholegraph/meta.yaml @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2023, NVIDIA CORPORATION. +# Copyright (c) 2019-2024, NVIDIA CORPORATION. {% set version = environ['RAPIDS_PACKAGE_VERSION'].lstrip('v') + environ.get('VERSION_SUFFIX', '') %} {% set minor_version = version.split('.')[0] + '.' + version.split('.')[1] %} @@ -49,7 +49,7 @@ requirements: - cuda-version ={{ cuda_version }} - cmake {{ cmake_version }} - ninja - - sysroot_{{ target_platform }} {{ sysroot_version }} + - {{ stdlib("c") }} host: {% if cuda_major == "11" %} - cudatoolkit diff --git a/conda/recipes/pylibwholegraph/conda_build_config.yaml b/conda/recipes/pylibwholegraph/conda_build_config.yaml index d45aacc92..41050978a 100644 --- a/conda/recipes/pylibwholegraph/conda_build_config.yaml +++ b/conda/recipes/pylibwholegraph/conda_build_config.yaml @@ -16,5 +16,8 @@ cmake_version: scikit_build_core_version: - ">=0.7.0" -sysroot_version: +c_stdlib: + - sysroot + +c_stdlib_version: - "2.17" diff --git a/conda/recipes/pylibwholegraph/meta.yaml b/conda/recipes/pylibwholegraph/meta.yaml index 1caa9573f..829350851 100644 --- a/conda/recipes/pylibwholegraph/meta.yaml +++ b/conda/recipes/pylibwholegraph/meta.yaml @@ -54,7 +54,7 @@ requirements: - cmake {{ cmake_version }} - ninja - doxygen =1.8.20 - - sysroot_{{ target_platform }} {{ sysroot_version }} + - {{ stdlib("c") }} host: - cuda-version ={{ cuda_version }} {% if cuda_major == "11" %} From 4187cfad1885ba492c7347e660a19e2b196bf0e0 Mon Sep 17 00:00:00 2001 From: Vyas Ramasubramani Date: Thu, 9 May 2024 18:41:49 -0700 Subject: [PATCH 08/15] Always use a static gtest (#167) Contributes to https://github.com/rapidsai/build-planning/issues/32 Authors: - Vyas Ramasubramani (https://github.com/vyasr) Approvers: - https://github.com/damontecres - Ray Douglass (https://github.com/raydouglass) URL: https://github.com/rapidsai/wholegraph/pull/167 --- .../libwholegraph/conda_build_config.yaml | 6 ----- conda/recipes/libwholegraph/meta.yaml | 4 ---- cpp/CMakeLists.txt | 3 ++- cpp/cmake/thirdparty/get_gtest.cmake | 24 ------------------- 4 files changed, 2 insertions(+), 35 deletions(-) delete mode 100644 cpp/cmake/thirdparty/get_gtest.cmake diff --git a/conda/recipes/libwholegraph/conda_build_config.yaml b/conda/recipes/libwholegraph/conda_build_config.yaml index ae2c2c714..52573b012 100644 --- 
a/conda/recipes/libwholegraph/conda_build_config.yaml +++ b/conda/recipes/libwholegraph/conda_build_config.yaml @@ -19,12 +19,6 @@ doxygen_version: nccl_version: - ">=2.9.9" -gtest_version: - - ">=1.13.0" - -gmock_version: - - ">=1.13.0" - c_stdlib: - sysroot diff --git a/conda/recipes/libwholegraph/meta.yaml b/conda/recipes/libwholegraph/meta.yaml index 9f9b6e5b0..e4c400e60 100644 --- a/conda/recipes/libwholegraph/meta.yaml +++ b/conda/recipes/libwholegraph/meta.yaml @@ -59,8 +59,6 @@ requirements: {% endif %} - cuda-version ={{ cuda_version }} - doxygen {{ doxygen_version }} - - gmock {{ gtest_version }} - - gtest {{ gtest_version }} - libraft ={{ minor_version }} - libraft-headers ={{ minor_version }} - librmm ={{ minor_version }} @@ -134,8 +132,6 @@ outputs: {% else %} - cuda-cudart {% endif %} - - gmock {{ gtest_version }} - - gtest {{ gtest_version }} about: home: https://rapids.ai/ license: Apache-2.0 diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 788cc2131..b3fdc6d74 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -213,7 +213,8 @@ endif() # optionally build tests if(BUILD_TESTS AND CMAKE_PROJECT_NAME STREQUAL PROJECT_NAME) - include(./cmake/thirdparty/get_gtest.cmake) + include(${rapids-cmake-dir}/cpm/gtest.cmake) + rapids_cpm_gtest(BUILD_STATIC) include(CTest) # calls enable_testing() add_subdirectory(tests) diff --git a/cpp/cmake/thirdparty/get_gtest.cmake b/cpp/cmake/thirdparty/get_gtest.cmake deleted file mode 100644 index cdc2c5d88..000000000 --- a/cpp/cmake/thirdparty/get_gtest.cmake +++ /dev/null @@ -1,24 +0,0 @@ -#============================================================================= -# Copyright (c) 2021-2022, NVIDIA CORPORATION. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -#============================================================================= - -function(find_and_configure_gtest) - - include(${rapids-cmake-dir}/cpm/gtest.cmake) - rapids_cpm_gtest() - -endfunction() - -find_and_configure_gtest() From afb6a112e7d663b3acc1674230211f6293813bcd Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Fri, 17 May 2024 10:11:02 -0400 Subject: [PATCH 09/15] Remove scripts/checks/copyright.py (#149) This script appears to have never been used, and with the usage of `pre-commit-hooks`' `verify-copyright` hook is now obsolete. Authors: - Kyle Edwards (https://github.com/KyleFromNVIDIA) Approvers: - Bradley Dice (https://github.com/bdice) - Brad Rees (https://github.com/BradReesWork) URL: https://github.com/rapidsai/wholegraph/pull/149 --- scripts/checks/copyright.py | 229 ------------------------------------ 1 file changed, 229 deletions(-) delete mode 100644 scripts/checks/copyright.py diff --git a/scripts/checks/copyright.py b/scripts/checks/copyright.py deleted file mode 100644 index a26109c57..000000000 --- a/scripts/checks/copyright.py +++ /dev/null @@ -1,229 +0,0 @@ -# Copyright (c) 2019-2023, NVIDIA CORPORATION. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import argparse -import datetime -import os -import re -import sys - -import git - -from fileutils import modifiedFiles - -FilesToCheck = [ - re.compile(r"[.](cmake|cpp|cu|cuh|h|hpp|sh|pxd|py|pyx)$"), - re.compile(r"CMakeLists[.]txt$"), - re.compile(r"setup[.]cfg$"), - re.compile(r"meta[.]yaml$"), -] -ExemptFiles = [ - re.compile(r"versioneer[.]py"), - re.compile(r".*[.]json$"), - re.compile(r"src/io/gzstream[.]hpp$"), -] - -# this will break starting at year 10000, which is probably OK :) -CheckSimple = re.compile( - r"Copyright *(?:\(c\))? *(\d{4}),? *NVIDIA C(?:ORPORATION|orporation)" -) -CheckDouble = re.compile( - r"Copyright *(?:\(c\))? *(\d{4})-(\d{4}),? *NVIDIA C(?:ORPORATION|orporation)" # noqa: E501 -) - - -def checkThisFile(f): - if isinstance(f, git.Diff): - if f.deleted_file or f.b_blob.size == 0: - return False - f = f.b_path - elif not os.path.exists(f) or os.stat(f).st_size == 0: - # This check covers things like symlinks which point to files that DNE - return False - for exempt in ExemptFiles: - if exempt.search(f): - return False - for checker in FilesToCheck: - if checker.search(f): - return True - return False - - -def getCopyrightYears(line): - res = CheckSimple.search(line) - if res: - return int(res.group(1)), int(res.group(1)) - res = CheckDouble.search(line) - if res: - return int(res.group(1)), int(res.group(2)) - return None, None - - -def replaceCurrentYear(line, start, end): - # first turn a simple regex into double (if applicable). 
then update years - res = CheckSimple.sub(r"Copyright (c) \1-\1, NVIDIA CORPORATION", line) - res = CheckDouble.sub( - rf"Copyright (c) {start:04d}-{end:04d}, NVIDIA CORPORATION", - res, - ) - return res - - -def checkCopyright(f, update_current_year): - """Checks for copyright headers and their years.""" - errs = [] - thisYear = datetime.datetime.now().year - lineNum = 0 - crFound = False - yearMatched = False - - if isinstance(f, git.Diff): - path = f.b_path - lines = f.b_blob.data_stream.read().decode().splitlines(keepends=True) - else: - path = f - with open(f, encoding="utf-8") as fp: - lines = fp.readlines() - - for line in lines: - lineNum += 1 - start, end = getCopyrightYears(line) - if start is None: - continue - crFound = True - if start > end: - e = [ - path, - lineNum, - "First year after second year in the copyright " - "header (manual fix required)", - None, - ] - errs.append(e) - elif thisYear < start or thisYear > end: - e = [ - path, - lineNum, - f"Current year {thisYear} not included in the copyright header {start}-{end}", - None, - ] - if thisYear < start: - e[-1] = replaceCurrentYear(line, thisYear, end) - if thisYear > end: - e[-1] = replaceCurrentYear(line, start, thisYear) - errs.append(e) - else: - yearMatched = True - # copyright header itself not found - if not crFound: - e = [ - path, - 0, - "Copyright header missing or formatted incorrectly " - "(manual fix required)", - None, - ] - errs.append(e) - # even if the year matches a copyright header, make the check pass - if yearMatched: - errs = [] - - if update_current_year: - errs_update = [x for x in errs if x[-1] is not None] - if len(errs_update) > 0: - lines_changed = ", ".join(str(x[1]) for x in errs_update) - print(f"File: {path}. Changing line(s) {lines_changed}") - for _, lineNum, __, replacement in errs_update: - lines[lineNum - 1] = replacement - with open(path, "w", encoding="utf-8") as out_file: - out_file.writelines(lines) - - return errs - - -def getAllFilesUnderDir(root, pathFilter=None): - retList = [] - for dirpath, dirnames, filenames in os.walk(root): - for fn in filenames: - filePath = os.path.join(dirpath, fn) - if pathFilter(filePath): - retList.append(filePath) - return retList - - -def checkCopyright_main(): - """ - Checks for copyright headers in all the modified files. 
In case of local - repo, this script will just look for uncommitted files and in case of CI - it compares between branches "$PR_TARGET_BRANCH" and "current-pr-branch" - """ - retVal = 0 - - argparser = argparse.ArgumentParser( - "Checks for a consistent copyright header in git's modified files" - ) - argparser.add_argument( - "--update-current-year", - dest="update_current_year", - action="store_true", - required=False, - help="If set, " - "update the current year if a header is already " - "present and well formatted.", - ) - argparser.add_argument( - "--git-modified-only", - dest="git_modified_only", - action="store_true", - required=False, - help="If set, " "only files seen as modified by git will be " "processed.", - ) - - args, dirs = argparser.parse_known_args() - - if args.git_modified_only: - files = [f for f in modifiedFiles() if checkThisFile(f)] - else: - files = [] - for d in [os.path.abspath(d) for d in dirs]: - if not os.path.isdir(d): - raise ValueError(f"{d} is not a directory.") - files += getAllFilesUnderDir(d, pathFilter=checkThisFile) - - errors = [] - for f in files: - errors += checkCopyright(f, args.update_current_year) - - if len(errors) > 0: - if any(e[-1] is None for e in errors): - print("Copyright headers incomplete in some of the files!") - for e in errors: - print(" %s:%d Issue: %s" % (e[0], e[1], e[2])) - print("") - n_fixable = sum(1 for e in errors if e[-1] is not None) - file_from_repo = os.path.relpath(os.path.abspath(__file__)) - if n_fixable > 0 and not args.update_current_year: - print( - f"You can run `python {file_from_repo} --git-modified-only " - "--update-current-year` and stage the results in git to " - f"fix {n_fixable} of these errors.\n" - ) - retVal = 1 - - return retVal - - -if __name__ == "__main__": - sys.exit(checkCopyright_main()) From a0870854ac8a301ca16d5e62db7a683c4bb150f7 Mon Sep 17 00:00:00 2001 From: Chuang Zhu <111838961+chuangz0@users.noreply.github.com> Date: Thu, 23 May 2024 20:42:54 +0800 Subject: [PATCH 10/15] Fix host view for mnnvl (#166) Authors: - Chuang Zhu (https://github.com/chuangz0) Approvers: - https://github.com/linhu-nv - Brad Rees (https://github.com/BradReesWork) URL: https://github.com/rapidsai/wholegraph/pull/166 --- cpp/include/wholememory/wholememory.h | 6 ++++++ cpp/src/wholememory/memory_handle.cpp | 8 +++++++- cpp/src/wholememory/wholememory.cpp | 5 +++++ .../pylibwholegraph/binding/wholememory_binding.pyx | 6 +++++- 4 files changed, 23 insertions(+), 2 deletions(-) diff --git a/cpp/include/wholememory/wholememory.h b/cpp/include/wholememory/wholememory.h index 66bd993fd..08f16213f 100644 --- a/cpp/include/wholememory/wholememory.h +++ b/cpp/include/wholememory/wholememory.h @@ -387,6 +387,12 @@ wholememory_error_code_t wholememory_store_to_file(wholememory_handle_t wholemem size_t file_entry_size, const char* local_file_name); +/** + * @param comm : WholeMemory Comm + * @return : bool + */ +bool wholememory_is_intranode_communicator(wholememory_comm_t comm); + bool wholememory_is_build_with_nvshmem(); #ifdef WITH_NVSHMEM_SUPPORT wholememory_error_code_t wholememory_get_nvshmem_reference( diff --git a/cpp/src/wholememory/memory_handle.cpp b/cpp/src/wholememory/memory_handle.cpp index 8024b461a..ca8b0ad75 100644 --- a/cpp/src/wholememory/memory_handle.cpp +++ b/cpp/src/wholememory/memory_handle.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -99,6 +99,12 @@ class wholememory_impl { if (local_ptr != nullptr) *local_ptr = local_partition_memory_pointer_; if (local_size != nullptr) *local_size = rank_partition_strategy_.local_mem_size; if (local_offset != nullptr) *local_offset = rank_partition_strategy_.local_mem_offset; + if (location_ == WHOLEMEMORY_ML_HOST && (type_ == WHOLEMEMORY_MT_CONTINUOUS) && + (!(comm_->is_intranode()))) { + WHOLEMEMORY_WARN( + " Multi-node continuous type wholememory can only be accessed by GPU threads but not CPU " + "threads, regardless of whether the location of wholememory is host."); + } } virtual bool get_rank_memory(void** rank_memory_ptr, size_t* rank_memory_size, diff --git a/cpp/src/wholememory/wholememory.cpp b/cpp/src/wholememory/wholememory.cpp index 59dcc89bb..180da2f01 100644 --- a/cpp/src/wholememory/wholememory.cpp +++ b/cpp/src/wholememory/wholememory.cpp @@ -261,6 +261,11 @@ wholememory_error_code_t wholememory_load_from_hdfs_file(wholememory_handle_t wh return WHOLEMEMORY_NOT_IMPLEMENTED; } +bool wholememory_is_intranode_communicator(wholememory_comm_t comm) +{ + return wholememory::is_intranode_communicator(comm); +} + bool wholememory_is_build_with_nvshmem() { #ifdef WITH_NVSHMEM_SUPPORT diff --git a/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx b/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx index 7cbffadd4..feffa9162 100644 --- a/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx +++ b/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx @@ -184,7 +184,7 @@ cdef extern from "wholememory/wholememory.h": cdef wholememory_distributed_backend_t wholememory_communicator_get_distributed_backend( wholememory_comm_t comm) - + cdef bool wholememory_is_intranode_communicator(wholememory_comm_t comm) cpdef enum WholeMemoryErrorCode: Success = WHOLEMEMORY_SUCCESS @@ -1113,6 +1113,10 @@ cdef class PyWholeMemoryFlattenDlpack: cdef wholememory_comm_t comm cdef int world_rank cdef int world_size + if self.device_type == MlHost and mem_type == MtContinuous: + check_wholememory_error_code(wholememory_get_communicator(&comm, handle.wholememory_handle)) + if wholememory_is_intranode_communicator(comm) == False : + raise ValueError('Multi-node continuous type wholememory does not support host_view. 
Only supports host_view=false regardless of whether location is host or not.') global_size = wholememory_get_total_size(handle.wholememory_handle) if global_size % elt_size != 0: raise ValueError('global_size=%d not multiple of elt_size=%d' % (global_size, elt_size)) From 7352f1c541eb697ebf75b8fdf9751119692d439b Mon Sep 17 00:00:00 2001 From: Chuang Zhu <111838961+chuangz0@users.noreply.github.com> Date: Thu, 23 May 2024 20:43:33 +0800 Subject: [PATCH 11/15] subwarp version gather op for small embedding size (#165) Authors: - Chuang Zhu (https://github.com/chuangz0) Approvers: - https://github.com/linhu-nv - Brad Rees (https://github.com/BradReesWork) URL: https://github.com/rapidsai/wholegraph/pull/165 --- .../functions/gather_scatter_func.cuh | 126 +++++++++++++++++- .../wholememory_gather_tests.cu | 12 +- 2 files changed, 136 insertions(+), 2 deletions(-) diff --git a/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh b/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh index 87c89d9c2..c7983a6dc 100644 --- a/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh +++ b/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -309,6 +309,62 @@ __global__ void gather_func_kernel(wholememory_gref_t embedding_gref, return; } +template +struct IsPowerOfTwo { + static constexpr bool value = (N > 0) && ((N & (N - 1)) == 0); +}; + +template +__global__ void gather_func_sub_warp_kernel(wholememory_gref_t embedding_gref, + wholememory_matrix_description_t embedding_desc, + const IndexT* indices, + int64_t indice_count, + OutputT* output, + wholememory_matrix_description_t output_desc) +{ + static_assert(IsPowerOfTwo::value && SUB_WARP_SIZE < 32, + "SUB_WARP_SIZE must be the power of 2,and smaller than 32."); + + auto block = cooperative_groups::this_thread_block(); + + auto subwarp = cooperative_groups::tiled_partition(block); + int sub_warp_id = subwarp.meta_group_size() * blockIdx.x + subwarp.meta_group_rank(); + int sub_warp_num = subwarp.meta_group_size() * gridDim.x; + + int lane_id_in_sub_warp = subwarp.thread_rank(); + wholememory::device_reference embedding_dev_ref(embedding_gref); + + int embedding_size = embedding_desc.sizes[1]; + int64_t embedding_stride = embedding_desc.stride; + int64_t output_stride = output_desc.stride; + + typed_data_vector embeddings; + typed_data_vector outputs; + for (int64_t output_idx = sub_warp_id; output_idx < indice_count; output_idx += sub_warp_num) { + OutputT* output_ptr = output + output_desc.storage_offset + output_stride * output_idx; + IndexT embedding_table_idx = indices[output_idx]; + if (embedding_table_idx < 0) continue; + int64_t embedding_offset = + embedding_desc.storage_offset + embedding_table_idx * embedding_stride; + + for (int emb_idx = lane_id_in_sub_warp * ALIGNMENT; emb_idx < embedding_size; + emb_idx += ALIGNMENT * SUB_WARP_SIZE) { + mov_data(&embeddings, + &embedding_dev_ref[embedding_offset + emb_idx]); +#pragma unroll + for (int sub_idx = 0; sub_idx < ALIGNMENT; sub_idx++) { + typed_data_vector_at(outputs, sub_idx) = + convert_type(typed_data_vector_at(embeddings, sub_idx)); + } + mov_data(output_ptr + emb_idx, &outputs); + } + } +} + template void gather_temp_func(wholememory_gref_t embedding_gref, wholememory_matrix_description_t embedding_desc, @@ -338,6 
+394,7 @@ void gather_temp_func(wholememory_gref_t embedding_gref, int64_t, OutputT*, wholememory_matrix_description_t) = nullptr; + switch (alignment) { case 16: { kernel_fn = gather_func_kernel; @@ -367,6 +424,73 @@ void gather_temp_func(wholememory_gref_t embedding_gref, int block_size = 1024; int block_count = indice_count > 1568 ? 1568 : indice_count; if (gather_sms != -1) block_count = gather_sms; + + // for small embedding size ,use subwarp to gather + int min_threads_per_embedding = embedding_desc.sizes[1] / alignment; + if (min_threads_per_embedding < 32) { +#define SWITCH_GATHER_FUNC_WITH_ALIGNMENT(KERNEL_NAME, SUB_WARP_SIZE) \ + switch (alignment) { \ + case 16: { \ + kernel_fn = KERNEL_NAME; \ + break; \ + } \ + case 8: { \ + kernel_fn = KERNEL_NAME; \ + break; \ + } \ + case 4: { \ + kernel_fn = KERNEL_NAME; \ + break; \ + } \ + case 2: { \ + kernel_fn = KERNEL_NAME; \ + break; \ + } \ + case 1: { \ + kernel_fn = KERNEL_NAME; \ + break; \ + } \ + default: { \ + WHOLEMEMORY_FAIL("gather func alignment=%d.", alignment); \ + return; \ + } \ + } + + int threads_per_embedding = 16; + if (min_threads_per_embedding >= 16) { + SWITCH_GATHER_FUNC_WITH_ALIGNMENT(gather_func_sub_warp_kernel, 16); + threads_per_embedding = 16; + } else if (min_threads_per_embedding < 16 && min_threads_per_embedding >= 8) { + SWITCH_GATHER_FUNC_WITH_ALIGNMENT(gather_func_sub_warp_kernel, 8); + threads_per_embedding = 8; + } else if (min_threads_per_embedding < 8 && min_threads_per_embedding >= 4) { + SWITCH_GATHER_FUNC_WITH_ALIGNMENT(gather_func_sub_warp_kernel, 4); + threads_per_embedding = 4; + } else if (min_threads_per_embedding < 4 && min_threads_per_embedding >= 2) { + SWITCH_GATHER_FUNC_WITH_ALIGNMENT(gather_func_sub_warp_kernel, 2); + threads_per_embedding = 2; + } else { + SWITCH_GATHER_FUNC_WITH_ALIGNMENT(gather_func_sub_warp_kernel, 1); + threads_per_embedding = 1; + } + +#undef SWITCH_GATHER_FUNC_WITH_ALIGNMENT + block_size = 128; + int max_blocks_per_sm = 8; + WM_CUDA_CHECK( + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, kernel_fn, block_size, 0)); + + int sm_count = 100; + int device_id = 0; + WM_CUDA_CHECK(cudaGetDevice(&device_id)); + WM_CUDA_CHECK(cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device_id)); + + // block_count = indice_count > 1568 ? 1568 : indice_count; + int min_embedding_per_block = block_size / threads_per_embedding; + block_count = min((int)(indice_count + min_embedding_per_block - 1) / min_embedding_per_block, + sm_count * max_blocks_per_sm * 4); + if (gather_sms != -1) block_count = gather_sms * max_blocks_per_sm; + } kernel_fn<<>>(embedding_gref, embedding_desc, static_cast(indices), diff --git a/cpp/tests/wholememory_ops/wholememory_gather_tests.cu b/cpp/tests/wholememory_ops/wholememory_gather_tests.cu index 330587481..fad314db9 100644 --- a/cpp/tests/wholememory_ops/wholememory_gather_tests.cu +++ b/cpp/tests/wholememory_ops/wholememory_gather_tests.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -311,6 +311,16 @@ INSTANTIATE_TEST_SUITE_P( .set_embedding_dim(11) .set_embedding_stride(12) .set_indices_count(100005), + WholeMemoryGatherTestParam() + .set_memory_type(WHOLEMEMORY_MT_CHUNKED) + .set_embedding_dim(1) + .set_embedding_stride(1) + .set_indices_count(100005), + WholeMemoryGatherTestParam() + .set_memory_type(WHOLEMEMORY_MT_CHUNKED) + .set_embedding_dim(1) + .set_embedding_stride(2) + .set_indices_count(100005), WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) .set_embedding_dim(11) From 4ff587121688336c589925425d9e58f2b5393da8 Mon Sep 17 00:00:00 2001 From: linhu-nv <141609318+linhu-nv@users.noreply.github.com> Date: Tue, 28 May 2024 20:33:44 +0800 Subject: [PATCH 12/15] quick fix to a map_indice bug && add comment for parameter round_robin_size (#172) quick fix to a map_indice bug && add comment for parameter round_robin_size Authors: - https://github.com/linhu-nv Approvers: - Chuang Zhu (https://github.com/chuangz0) URL: https://github.com/rapidsai/wholegraph/pull/172 --- cpp/src/wholememory/file_io.cpp | 8 ++++---- cpp/src/wholememory_ops/functions/map_indices_func.cu | 2 +- python/pylibwholegraph/pylibwholegraph/torch/embedding.py | 4 +++- python/pylibwholegraph/pylibwholegraph/torch/tensor.py | 1 + 4 files changed, 9 insertions(+), 6 deletions(-) diff --git a/cpp/src/wholememory/file_io.cpp b/cpp/src/wholememory/file_io.cpp index 1ad7e85aa..31b87c144 100644 --- a/cpp/src/wholememory/file_io.cpp +++ b/cpp/src/wholememory/file_io.cpp @@ -97,7 +97,7 @@ static size_t get_handle_partial_size(size_t handle_size, * @param suggested_buffer_size : Suggested buffer size to read. * @param wm_rank : WholeMemory rank. * @param wm_world_size : WholeMemory world size. - * @param round_robin_size : continuous embedding size of a rank using round robin shard stratehy. + * @param round_robin_size : continuous embedding size of a rank using round robin shard strategy. */ static void read_file_list_to_local_memory_roundrobin(char* local_ptr, size_t local_size, @@ -407,7 +407,7 @@ static void read_file_list_to_local_memory(char* local_ptr, * @param suggested_buffer_size : Suggested buffer size to read. * @param wm_rank : WholeMemory rank. * @param wm_world_size : WholeMemory world size. - * @param round_robin_size : continuous embedding size of a rank using round robin shard stratehy. + * @param round_robin_size : continuous embedding size of a rank using round robin shard strategy. * @param dev_id : the device bound to the rank. */ static void read_file_list_to_local_memory_roundrobin_with_multi_threads( @@ -878,7 +878,7 @@ static void read_file_list_to_local_memory_with_multi_threads(char* local_ptr, * @param suggested_buffer_size : Suggested buffer size to read. * @param wm_rank : WholeMemory rank. * @param wm_world_size : WholeMemory world size. - * @param round_robin_size : continuous embedding size of a rank using round robin shard stratehy. + * @param round_robin_size : continuous embedding size of a rank using round robin shard strategy. */ static void read_file_list_to_local_memory_roundrobin_directio( char* local_ptr, @@ -1546,7 +1546,7 @@ static void read_file_list_to_local_memory_directio_with_multi_thread( * @param suggested_buffer_size : Suggested buffer size to read. * @param wm_rank : WholeMemory rank. * @param wm_world_size : WholeMemory world size. - * @param round_robin_size : continuous embedding size of a rank using round robin shard stratehy. + * @param round_robin_size : continuous embedding size of a rank using round robin shard strategy. 
* @param dev_id : the device bound to the rank. */ static void read_file_list_to_local_memory_roundrobin_directio_with_multi_threads( diff --git a/cpp/src/wholememory_ops/functions/map_indices_func.cu b/cpp/src/wholememory_ops/functions/map_indices_func.cu index 97d6ca868..1a1418179 100644 --- a/cpp/src/wholememory_ops/functions/map_indices_func.cu +++ b/cpp/src/wholememory_ops/functions/map_indices_func.cu @@ -58,7 +58,7 @@ void storage_idx2wm_emb_idx_temp_fn(void* indice_ptr, if (block_num > 1568) block_num = 1568; IndexT* indice = static_cast(indice_ptr); IndexT* mapped_indice = static_cast(mapped_indice_ptr); - storage_idx2wm_emb_idx_kernel<<>>( + storage_idx2wm_emb_idx_kernel<<>>( indice, mapped_indice, indice_size, world_size, entry_per_rank, round_robin_size); WM_CUDA_CHECK(cudaStreamSynchronize(stream)); return; diff --git a/python/pylibwholegraph/pylibwholegraph/torch/embedding.py b/python/pylibwholegraph/pylibwholegraph/torch/embedding.py index 8ad83bd77..8abc92be9 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/embedding.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/embedding.py @@ -407,7 +407,7 @@ def create_embedding( cache_policy: Union[WholeMemoryCachePolicy, None] = None, random_init: bool = False, gather_sms: int = -1, - round_robin_size=0, + round_robin_size: int = 0, ): r""" Create embedding @@ -419,6 +419,7 @@ def create_embedding( :param optimizer: optimizer :param cache_policy: cache policy :param gather_sms: the number of SMs used in gather process + :param round_robin_size: continuous embedding size of a rank using round robin shard strategy :return: WholeMemoryEmbedding """ if optimizer is None: @@ -491,6 +492,7 @@ def create_embedding_from_filelist( :param optimizer: optimizer :param cache_policy: cache policy :param gather_sms: the number of SMs used in gather process + :param round_robin_size: continuous embedding size of a rank using round robin shard strategy :return: """ if isinstance(filelist, str): diff --git a/python/pylibwholegraph/pylibwholegraph/torch/tensor.py b/python/pylibwholegraph/pylibwholegraph/torch/tensor.py index ee62e9964..cb4923f41 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/tensor.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/tensor.py @@ -156,6 +156,7 @@ def from_filelist(self, filelist: Union[List[str], str], round_robin_size: int = """ Load WholeMemory Tensor from file lists :param filelist: file list to load from + :param round_robin_size: continuous embedding size of a rank using round robin shard strategy :return: None """ if isinstance(filelist, str): From da0a412ab02587d80897f56ece1610b71fcb9ab4 Mon Sep 17 00:00:00 2001 From: linhu-nv <141609318+linhu-nv@users.noreply.github.com> Date: Tue, 28 May 2024 20:35:48 +0800 Subject: [PATCH 13/15] a quick fix to wholememory tensor gather default data type (#173) A quick fixes to this issue (https://github.com/rapidsai/wholegraph/issues/168). Set correct default wholememory tensor gather results data type. 
Authors: - https://github.com/linhu-nv Approvers: - Chuang Zhu (https://github.com/chuangz0) URL: https://github.com/rapidsai/wholegraph/pull/173 --- python/pylibwholegraph/pylibwholegraph/torch/tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pylibwholegraph/pylibwholegraph/torch/tensor.py b/python/pylibwholegraph/pylibwholegraph/torch/tensor.py index cb4923f41..84ee59eee 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/tensor.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/tensor.py @@ -67,7 +67,7 @@ def gather(self, embedding_count = indice.shape[0] current_cuda_device = "cuda:%d" % (torch.cuda.current_device(),) output_dtype = ( - force_dtype if force_dtype is not None else self.embedding_tensor.dtype + force_dtype if force_dtype is not None else self.dtype ) output_tensor = torch.empty( [embedding_count, embedding_dim], From 19210806718330f3397c75ef619c83c79102368e Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Wed, 29 May 2024 21:54:49 +0800 Subject: [PATCH 14/15] Sort indices before gathering (#174) In continuous/chunked host mode, sorting indices before gathering can improve bandwidth by enhancing memory locality. Authors: - https://github.com/zhuofan1123 Approvers: - https://github.com/linhu-nv - Brad Rees (https://github.com/BradReesWork) URL: https://github.com/rapidsai/wholegraph/pull/174 --- .../wholememory_ops/functions/gather_func.cu | 81 +++++++++++- ...r_func_impl_floating_data_int32_indices.cu | 20 ++- ...r_func_impl_floating_data_int64_indices.cu | 20 ++- ...er_func_impl_integer_data_int32_indices.cu | 20 ++- ...er_func_impl_integer_data_int64_indices.cu | 20 ++- .../functions/gather_scatter_func.cuh | 18 ++- .../functions/gather_scatter_func.h | 14 +- .../functions/sort_indices_func.cu | 125 ++++++++++++++++++ .../functions/sort_indices_func.h | 34 +++++ cpp/src/wholememory_ops/gather_op.cpp | 18 ++- cpp/src/wholememory_ops/gather_op_impl.h | 3 +- .../wholememory_ops/gather_op_impl_mapped.cu | 49 +++++-- .../wholememory_gather_tests.cu | 10 ++ 13 files changed, 401 insertions(+), 31 deletions(-) create mode 100644 cpp/src/wholememory_ops/functions/sort_indices_func.cu create mode 100644 cpp/src/wholememory_ops/functions/sort_indices_func.h diff --git a/cpp/src/wholememory_ops/functions/gather_func.cu b/cpp/src/wholememory_ops/functions/gather_func.cu index 0b79f0f15..271245d78 100644 --- a/cpp/src/wholememory_ops/functions/gather_func.cu +++ b/cpp/src/wholememory_ops/functions/gather_func.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -24,6 +24,8 @@ wholememory_error_code_t gather_integer_int32_func(wholememory_gref_t embedding_ wholememory_matrix_description_t embedding_desc, void* indices, wholememory_array_description_t indices_desc, + bool gather_with_sorted_ids, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, @@ -32,6 +34,8 @@ wholememory_error_code_t gather_integer_int64_func(wholememory_gref_t embedding_ wholememory_matrix_description_t embedding_desc, void* indices, wholememory_array_description_t indices_desc, + bool gather_with_sorted_ids, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, @@ -40,6 +44,8 @@ wholememory_error_code_t gather_floating_int32_func(wholememory_gref_t embedding wholememory_matrix_description_t embedding_desc, void* indices, wholememory_array_description_t indices_desc, + bool gather_with_sorted_ids, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, @@ -48,6 +54,8 @@ wholememory_error_code_t gather_floating_int64_func(wholememory_gref_t embedding wholememory_matrix_description_t embedding_desc, void* indices, wholememory_array_description_t indices_desc, + bool gather_with_sorted_ids, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, @@ -76,6 +84,75 @@ wholememory_error_code_t gather_func(wholememory_gref_t embedding_gref, wholememory_matrix_description_t, void* indices, wholememory_array_description_t, + bool, + void*, + void*, + wholememory_matrix_description_t, + cudaStream_t, + int) = nullptr; + if (embedding_is_float) { + if (indices_desc.dtype == WHOLEMEMORY_DT_INT) { + p_gather_func = gather_floating_int32_func; + } else { + p_gather_func = gather_floating_int64_func; + } + } else { + if (indices_desc.dtype == WHOLEMEMORY_DT_INT) { + p_gather_func = gather_integer_int32_func; + } else { + p_gather_func = gather_integer_int64_func; + } + } + return p_gather_func(embedding_gref, + embedding_desc, + indices, + indices_desc, + false, + nullptr, + output, + output_desc, + stream, + gather_sms); + } catch (const wholememory::cuda_error& rle) { + return WHOLEMEMORY_LOGIC_ERROR; + } catch (const wholememory::logic_error& le) { + return WHOLEMEMORY_LOGIC_ERROR; + } catch (...) { + return WHOLEMEMORY_LOGIC_ERROR; + } + return WHOLEMEMORY_SUCCESS; +} + +wholememory_error_code_t gather_with_sorted_ids_func( + wholememory_gref_t embedding_gref, + wholememory_matrix_description_t embedding_desc, + void* indices, + wholememory_array_description_t indices_desc, + void* raw_indices, + wholememory_array_description_t raw_indices_desc, + void* output, + wholememory_matrix_description_t output_desc, + cudaStream_t stream, + int gather_sms) +{ + try { + bool embedding_is_float = wholememory_dtype_is_floating_number(embedding_desc.dtype); + WHOLEMEMORY_CHECK(embedding_is_float || + wholememory_dtype_is_integer_number(embedding_desc.dtype)); + bool output_is_float = wholememory_dtype_is_floating_number(output_desc.dtype); + WHOLEMEMORY_CHECK(output_is_float || wholememory_dtype_is_integer_number(output_desc.dtype)); + WHOLEMEMORY_EXPECTS( + embedding_is_float == output_is_float, + "embedding and output should be same number type, e.g. 
floating number or integer number."); + if (indices_desc.size == 0) { return WHOLEMEMORY_SUCCESS; } + WHOLEMEMORY_CHECK(indices_desc.size == raw_indices_desc.size); + WHOLEMEMORY_CHECK(indices_desc.dtype == raw_indices_desc.dtype); + wholememory_error_code_t (*p_gather_func)(wholememory_gref_t, + wholememory_matrix_description_t, + void* indices, + wholememory_array_description_t, + bool, + void*, void*, wholememory_matrix_description_t, cudaStream_t, @@ -97,6 +174,8 @@ wholememory_error_code_t gather_func(wholememory_gref_t embedding_gref, embedding_desc, indices, indices_desc, + true, + raw_indices, output, output_desc, stream, diff --git a/cpp/src/wholememory_ops/functions/gather_func_impl_floating_data_int32_indices.cu b/cpp/src/wholememory_ops/functions/gather_func_impl_floating_data_int32_indices.cu index c7679c508..a67ac0040 100644 --- a/cpp/src/wholememory_ops/functions/gather_func_impl_floating_data_int32_indices.cu +++ b/cpp/src/wholememory_ops/functions/gather_func_impl_floating_data_int32_indices.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,13 +27,23 @@ void gather_floating_int32_temp_func(wholememory_gref_t embedding_gref, wholememory_matrix_description_t embedding_desc, void* indices, int64_t indice_count, + bool gather_with_sorted_ids, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, int gather_sms) { - gather_temp_func( - embedding_gref, embedding_desc, indices, indice_count, output, output_desc, stream, gather_sms); + gather_temp_func(embedding_gref, + embedding_desc, + indices, + indice_count, + gather_with_sorted_ids, + raw_indices, + output, + output_desc, + stream, + gather_sms); } REGISTER_DISPATCH_TWO_TYPES(GatherFuncFloatingInt32, @@ -45,6 +55,8 @@ wholememory_error_code_t gather_floating_int32_func(wholememory_gref_t embedding wholememory_matrix_description_t embedding_desc, void* indices, wholememory_array_description_t indices_desc, + bool gather_with_sorted_ids, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, @@ -63,6 +75,8 @@ wholememory_error_code_t gather_floating_int32_func(wholememory_gref_t embedding static_cast(indices) + indices_desc.storage_offset * wholememory_dtype_get_element_size(indices_desc.dtype), indices_desc.size, + gather_with_sorted_ids, + raw_indices, output, output_desc, stream, diff --git a/cpp/src/wholememory_ops/functions/gather_func_impl_floating_data_int64_indices.cu b/cpp/src/wholememory_ops/functions/gather_func_impl_floating_data_int64_indices.cu index af9d6d6ec..159aaf9a6 100644 --- a/cpp/src/wholememory_ops/functions/gather_func_impl_floating_data_int64_indices.cu +++ b/cpp/src/wholememory_ops/functions/gather_func_impl_floating_data_int64_indices.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -27,13 +27,23 @@ void gather_floating_int64_temp_func(wholememory_gref_t embedding_gref, wholememory_matrix_description_t embedding_desc, void* indices, int64_t indice_count, + bool gather_with_sorted_ids, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, int gather_sms) { - gather_temp_func( - embedding_gref, embedding_desc, indices, indice_count, output, output_desc, stream, gather_sms); + gather_temp_func(embedding_gref, + embedding_desc, + indices, + indice_count, + gather_with_sorted_ids, + raw_indices, + output, + output_desc, + stream, + gather_sms); } REGISTER_DISPATCH_TWO_TYPES(GatherFuncFloatingInt64, @@ -45,6 +55,8 @@ wholememory_error_code_t gather_floating_int64_func(wholememory_gref_t embedding wholememory_matrix_description_t embedding_desc, void* indices, wholememory_array_description_t indices_desc, + bool gather_with_sorted_ids, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, @@ -63,6 +75,8 @@ wholememory_error_code_t gather_floating_int64_func(wholememory_gref_t embedding static_cast(indices) + indices_desc.storage_offset * wholememory_dtype_get_element_size(indices_desc.dtype), indices_desc.size, + gather_with_sorted_ids, + raw_indices, output, output_desc, stream, diff --git a/cpp/src/wholememory_ops/functions/gather_func_impl_integer_data_int32_indices.cu b/cpp/src/wholememory_ops/functions/gather_func_impl_integer_data_int32_indices.cu index bdb7c0be8..9943cb14b 100644 --- a/cpp/src/wholememory_ops/functions/gather_func_impl_integer_data_int32_indices.cu +++ b/cpp/src/wholememory_ops/functions/gather_func_impl_integer_data_int32_indices.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -27,13 +27,23 @@ void gather_integer_int32_temp_func(wholememory_gref_t embedding_gref, wholememory_matrix_description_t embedding_desc, void* indices, int64_t indice_count, + bool gather_with_sorted_ids, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, int gather_sms) { - gather_temp_func( - embedding_gref, embedding_desc, indices, indice_count, output, output_desc, stream, gather_sms); + gather_temp_func(embedding_gref, + embedding_desc, + indices, + indice_count, + gather_with_sorted_ids, + raw_indices, + output, + output_desc, + stream, + gather_sms); } REGISTER_DISPATCH_TWO_TYPES(GatherFuncIntegerInt32, @@ -45,6 +55,8 @@ wholememory_error_code_t gather_integer_int32_func(wholememory_gref_t embedding_ wholememory_matrix_description_t embedding_desc, void* indices, wholememory_array_description_t indices_desc, + bool gather_with_sorted_ids, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, @@ -63,6 +75,8 @@ wholememory_error_code_t gather_integer_int32_func(wholememory_gref_t embedding_ static_cast(indices) + indices_desc.storage_offset * wholememory_dtype_get_element_size(indices_desc.dtype), indices_desc.size, + gather_with_sorted_ids, + raw_indices, output, output_desc, stream, diff --git a/cpp/src/wholememory_ops/functions/gather_func_impl_integer_data_int64_indices.cu b/cpp/src/wholememory_ops/functions/gather_func_impl_integer_data_int64_indices.cu index 6a6c7f330..b06ebad9f 100644 --- a/cpp/src/wholememory_ops/functions/gather_func_impl_integer_data_int64_indices.cu +++ b/cpp/src/wholememory_ops/functions/gather_func_impl_integer_data_int64_indices.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -27,13 +27,23 @@ void gather_integer_int64_temp_func(wholememory_gref_t embedding_gref, wholememory_matrix_description_t embedding_desc, void* indices, int64_t indice_count, + bool gather_with_sorted_ids, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, int gather_sms) { - gather_temp_func( - embedding_gref, embedding_desc, indices, indice_count, output, output_desc, stream, gather_sms); + gather_temp_func(embedding_gref, + embedding_desc, + indices, + indice_count, + gather_with_sorted_ids, + raw_indices, + output, + output_desc, + stream, + gather_sms); } REGISTER_DISPATCH_TWO_TYPES(GatherFuncIntegerInt64, @@ -45,6 +55,8 @@ wholememory_error_code_t gather_integer_int64_func(wholememory_gref_t embedding_ wholememory_matrix_description_t embedding_desc, void* indices, wholememory_array_description_t indices_desc, + bool gather_with_sorted_ids, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, @@ -63,6 +75,8 @@ wholememory_error_code_t gather_integer_int64_func(wholememory_gref_t embedding_ static_cast(indices) + indices_desc.storage_offset * wholememory_dtype_get_element_size(indices_desc.dtype), indices_desc.size, + gather_with_sorted_ids, + raw_indices, output, output_desc, stream, diff --git a/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh b/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh index c7983a6dc..a4979f7be 100644 --- a/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh +++ b/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh @@ -255,6 +255,8 @@ __global__ void gather_func_kernel(wholememory_gref_t embedding_gref, wholememory_matrix_description_t embedding_desc, const IndexT* indices, int64_t indice_count, + bool gather_with_sorted_ids, + const IndexT* raw_indices, OutputT* output, wholememory_matrix_description_t output_desc) { @@ -284,7 +286,9 @@ __global__ void gather_func_kernel(wholememory_gref_t embedding_gref, for (int64_t output_idx = warp_id; output_idx < indice_count; output_idx += gridDim.x * (blockDim.x / 32)) { - OutputT* output_ptr = output + output_desc.storage_offset + output_stride * output_idx; + int64_t raw_output_idx = + gather_with_sorted_ids ? (int64_t)(raw_indices[output_idx]) : output_idx; + OutputT* output_ptr = output + output_desc.storage_offset + output_stride * raw_output_idx; if (!use_shm) { my_shared = output_ptr; } int64_t embedding_table_idx = indices[output_idx]; if (embedding_table_idx < 0) continue; @@ -323,6 +327,8 @@ __global__ void gather_func_sub_warp_kernel(wholememory_gref_t embedding_gref, wholememory_matrix_description_t embedding_desc, const IndexT* indices, int64_t indice_count, + bool gather_with_sorted_ids, + const IndexT* raw_indices, OutputT* output, wholememory_matrix_description_t output_desc) { @@ -345,7 +351,9 @@ __global__ void gather_func_sub_warp_kernel(wholememory_gref_t embedding_gref, typed_data_vector embeddings; typed_data_vector outputs; for (int64_t output_idx = sub_warp_id; output_idx < indice_count; output_idx += sub_warp_num) { - OutputT* output_ptr = output + output_desc.storage_offset + output_stride * output_idx; + int64_t raw_output_idx = + gather_with_sorted_ids ? 
(int64_t)(raw_indices[output_idx]) : output_idx; + OutputT* output_ptr = output + output_desc.storage_offset + output_stride * raw_output_idx; IndexT embedding_table_idx = indices[output_idx]; if (embedding_table_idx < 0) continue; int64_t embedding_offset = @@ -370,6 +378,8 @@ void gather_temp_func(wholememory_gref_t embedding_gref, wholememory_matrix_description_t embedding_desc, void* indices, int64_t indice_count, + bool gather_with_sorted_ids, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, @@ -392,6 +402,8 @@ void gather_temp_func(wholememory_gref_t embedding_gref, wholememory_matrix_description_t, const IndexT*, int64_t, + bool, + const IndexT*, OutputT*, wholememory_matrix_description_t) = nullptr; @@ -495,6 +507,8 @@ void gather_temp_func(wholememory_gref_t embedding_gref, embedding_desc, static_cast(indices), indice_count, + gather_with_sorted_ids, + static_cast(raw_indices), static_cast(output), output_desc); WM_CUDA_CHECK(cudaGetLastError()); diff --git a/cpp/src/wholememory_ops/functions/gather_scatter_func.h b/cpp/src/wholememory_ops/functions/gather_scatter_func.h index 0c0b9e4a4..374ea2b39 100644 --- a/cpp/src/wholememory_ops/functions/gather_scatter_func.h +++ b/cpp/src/wholememory_ops/functions/gather_scatter_func.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -30,6 +30,18 @@ wholememory_error_code_t gather_func(wholememory_gref_t embedding_gref, cudaStream_t stream, int gather_sms = -1); +wholememory_error_code_t gather_with_sorted_ids_func( + wholememory_gref_t embedding_gref, + wholememory_matrix_description_t embedding_desc, + void* indices, + wholememory_array_description_t indices_desc, + void* raw_indices, + wholememory_array_description_t raw_indices_desc, + void* output, + wholememory_matrix_description_t output_desc, + cudaStream_t stream, + int gather_sms); + wholememory_error_code_t scatter_func(const void* input, wholememory_matrix_description_t input_desc, void* indices, diff --git a/cpp/src/wholememory_ops/functions/sort_indices_func.cu b/cpp/src/wholememory_ops/functions/sort_indices_func.cu new file mode 100644 index 000000000..4cbbb0837 --- /dev/null +++ b/cpp/src/wholememory_ops/functions/sort_indices_func.cu @@ -0,0 +1,125 @@ +/* + * Copyright (c) 2019-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "sort_indices_func.h" + +#include +#include + +#include "cuda_macros.hpp" +#include "error.hpp" +#include "logger.hpp" +#include "wholememory_ops/register.hpp" + +namespace wholememory_ops { + +template +struct UnsignedType {}; + +template <> +struct UnsignedType { + using UType = unsigned int; +}; + +template <> +struct UnsignedType { + using UType = uint64_t; +}; + +template +void sort_indices_temp_func(const void* indices_before_sort, + wholememory_array_description_t indices_desc, + void* indices_after_sort, + void* raw_indices, + wm_thrust_allocator* p_thrust_allocator, + wholememory_env_func_t* p_env_fns, + cudaStream_t stream) +{ + auto index_type = indices_desc.dtype; + WHOLEMEMORY_CHECK(indices_desc.storage_offset == 0); + WHOLEMEMORY_CHECK(index_type == WHOLEMEMORY_DT_INT || index_type == WHOLEMEMORY_DT_INT64); + wm_thrust_allocator& allocator = *p_thrust_allocator; + + IndexT* seq_indices = reinterpret_cast(allocator.allocate( + wholememory_get_memory_element_count_from_array(&indices_desc) * sizeof(IndexT))); + thrust::sequence(thrust::cuda::par_nosync(allocator).on(stream), + seq_indices, + seq_indices + indices_desc.size, + 0); + // use UTypeT to put minus indices at last. + using UTypeT = typename UnsignedType::UType; + const UTypeT* indices_to_sort = static_cast(indices_before_sort); + UTypeT* sorted_indice = static_cast(indices_after_sort); + void* cub_temp_storage = nullptr; + size_t temp_storage_bytes = 0; + cub::DeviceRadixSort::SortPairs(cub_temp_storage, + temp_storage_bytes, + indices_to_sort, + sorted_indice, + seq_indices, + static_cast(raw_indices), + indices_desc.size, + 0, + sizeof(UTypeT) * 8, + stream); + cub_temp_storage = allocator.allocate(temp_storage_bytes); + cub::DeviceRadixSort::SortPairs(cub_temp_storage, + temp_storage_bytes, + indices_to_sort, + sorted_indice, + seq_indices, + static_cast(raw_indices), + indices_desc.size, + 0, + sizeof(UTypeT) * 8, + stream); + allocator.deallocate(reinterpret_cast(seq_indices), + wholememory_get_memory_size_from_array(&indices_desc)); + allocator.deallocate(static_cast(cub_temp_storage), temp_storage_bytes); +} + +REGISTER_DISPATCH_ONE_TYPE(SortIndices, sort_indices_temp_func, SINT3264) + +wholememory_error_code_t sort_indices_func(const void* indices_before_sort, + wholememory_array_description_t indice_desc, + void* indices_after_sort, + void* raw_indices, + wm_thrust_allocator* p_thrust_allocator, + wholememory_env_func_t* p_env_fns, + cudaStream_t stream) +{ + try { + DISPATCH_ONE_TYPE(indice_desc.dtype, + SortIndices, + indices_before_sort, + indice_desc, + indices_after_sort, + raw_indices, + p_thrust_allocator, + p_env_fns, + stream); + } catch (wholememory::cuda_error& wce) { + WHOLEMEMORY_ERROR("sort_indices_func CUDA LOGIC Error %s\n", wce.what()); + return WHOLEMEMORY_CUDA_ERROR; + } catch (wholememory::logic_error& wle) { + WHOLEMEMORY_ERROR("sort_indices_func LOGIC Error %s\n", wle.what()); + return WHOLEMEMORY_LOGIC_ERROR; + } catch (...) { + return WHOLEMEMORY_UNKNOW_ERROR; + } + return WHOLEMEMORY_SUCCESS; +} + +} // namespace wholememory_ops diff --git a/cpp/src/wholememory_ops/functions/sort_indices_func.h b/cpp/src/wholememory_ops/functions/sort_indices_func.h new file mode 100644 index 000000000..98a7932cb --- /dev/null +++ b/cpp/src/wholememory_ops/functions/sort_indices_func.h @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2019-2024, NVIDIA CORPORATION. 
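The new sort_indices_func pairs each index with its original position (a 0..N-1 sequence built with thrust::sequence) and sorts the pairs with cub::DeviceRadixSort::SortPairs; the keys are reinterpreted as unsigned so that negative (invalid) indices sort to the end. The sorted keys become the gather indices, and the permuted sequence becomes raw_indices, i.e. the map from sorted position back to original position. A host-side analogue with illustrative names (the authoritative implementation is the CUDA code above):

    #include <algorithm>
    #include <cstddef>
    #include <cstdint>
    #include <numeric>
    #include <vector>

    // Host analogue of the sort step: produce the sorted indices plus the
    // permutation mapping each sorted slot back to its original slot.
    // Note: the GPU version sorts the keys as unsigned integers so negative
    // indices land at the end; this plain signed sort does not reproduce that.
    void sort_indices_sketch(const std::vector<std::int64_t>& indices,
                             std::vector<std::int64_t>& sorted_indices,
                             std::vector<std::int64_t>& raw_positions)
    {
      raw_positions.resize(indices.size());
      std::iota(raw_positions.begin(), raw_positions.end(), 0);
      std::sort(raw_positions.begin(), raw_positions.end(),
                [&](std::int64_t a, std::int64_t b) { return indices[a] < indices[b]; });
      sorted_indices.resize(indices.size());
      for (std::size_t i = 0; i < indices.size(); ++i) {
        sorted_indices[i] = indices[raw_positions[i]];
      }
    }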
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include + +#include +#include + +namespace wholememory_ops { + +wholememory_error_code_t sort_indices_func(const void* indices_before_sort, + wholememory_array_description_t indice_desc, + void* indices_after_sort, + void* raw_indices, + wm_thrust_allocator* p_thrust_allocator, + wholememory_env_func_t* p_env_fns, + cudaStream_t stream); + +} // namespace wholememory_ops diff --git a/cpp/src/wholememory_ops/gather_op.cpp b/cpp/src/wholememory_ops/gather_op.cpp index a6b2e97b5..98d41d222 100644 --- a/cpp/src/wholememory_ops/gather_op.cpp +++ b/cpp/src/wholememory_ops/gather_op.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,11 +27,13 @@ wholememory_error_code_t wholememory_gather(wholememory_tensor_t wholememory_ten void* stream, int gather_sms) { - bool const has_handle = wholememory_tensor_has_handle(wholememory_tensor); - wholememory_memory_type_t memory_type = WHOLEMEMORY_MT_NONE; + bool const has_handle = wholememory_tensor_has_handle(wholememory_tensor); + wholememory_memory_type_t memory_type = WHOLEMEMORY_MT_NONE; + wholememory_memory_location_t memory_location = WHOLEMEMORY_ML_NONE; if (has_handle) { - memory_type = - wholememory_get_memory_type(wholememory_tensor_get_memory_handle(wholememory_tensor)); + auto memory_handle = wholememory_tensor_get_memory_handle(wholememory_tensor); + memory_type = wholememory_get_memory_type(memory_handle); + memory_location = wholememory_get_memory_location(memory_handle); } wholememory_matrix_description_t matrix_description; auto tensor_description = *wholememory_tensor_get_tensor_description(wholememory_tensor); @@ -98,12 +100,18 @@ wholememory_error_code_t wholememory_gather(wholememory_tensor_t wholememory_ten wholememory_gref_t gref; WHOLEMEMORY_RETURN_ON_FAIL(wholememory_tensor_get_global_reference(wholememory_tensor, &gref)); + int64_t entry_size = + tensor_description.sizes[1] * wholememory_dtype_get_element_size(tensor_description.dtype); + bool gather_with_sorted_ids = + (memory_location == WHOLEMEMORY_ML_HOST) && (entry_size <= 512) && + (memory_type == WHOLEMEMORY_MT_CHUNKED || memory_type == WHOLEMEMORY_MT_CONTINUOUS); return wholememory_ops::wholememory_gather_mapped(gref, matrix_description, indices, indices_desc, output, output_desc, + gather_with_sorted_ids, p_env_fns, static_cast(stream), gather_sms); diff --git a/cpp/src/wholememory_ops/gather_op_impl.h b/cpp/src/wholememory_ops/gather_op_impl.h index 6f85d6410..21896ff24 100644 --- a/cpp/src/wholememory_ops/gather_op_impl.h +++ b/cpp/src/wholememory_ops/gather_op_impl.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. 
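The new dispatch in gather_op.cpp enables the sorted path only when the embedding table lives in host memory, uses the continuous or chunked memory type, and each row is at most 512 bytes, where entry_size is the row width times the element size. For example, a 128-dimensional float32 embedding is 128 * 4 = 512 bytes and still takes the sorted path, while a 256-dimensional float32 row (1024 bytes) falls back to the direct gather. The 512-byte cutoff appears to be an empirical threshold, consistent with the commit message's claim that sorting mainly improves bandwidth for host-resident continuous/chunked tables.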
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,6 +27,7 @@ wholememory_error_code_t wholememory_gather_mapped( wholememory_array_description_t indice_desc, void* output, wholememory_matrix_description_t output_desc, + bool gather_with_sorted_ids, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int gather_sms); diff --git a/cpp/src/wholememory_ops/gather_op_impl_mapped.cu b/cpp/src/wholememory_ops/gather_op_impl_mapped.cu index 38e64919d..849005860 100644 --- a/cpp/src/wholememory_ops/gather_op_impl_mapped.cu +++ b/cpp/src/wholememory_ops/gather_op_impl_mapped.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,6 +20,9 @@ #include "cuda_macros.hpp" #include "wholememory_ops/functions/gather_scatter_func.h" +#include "wholememory_ops/functions/sort_indices_func.h" +#include "wholememory_ops/temp_memory_handle.hpp" +#include "wholememory_ops/thrust_allocator.hpp" namespace wholememory_ops { @@ -30,18 +33,46 @@ wholememory_error_code_t wholememory_gather_mapped( wholememory_array_description_t indice_desc, void* output, wholememory_matrix_description_t output_desc, + bool gather_with_sorted_ids, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int gather_sms) { - WHOLEMEMORY_RETURN_ON_FAIL(gather_func(wholememory_gref, - wholememory_desc, - indices, - indice_desc, - output, - output_desc, - stream, - gather_sms)); + if (gather_with_sorted_ids) { + wm_thrust_allocator thrust_allocator(p_env_fns); + temp_memory_handle dev_indices_after_sort(p_env_fns); + void* dev_indices_after_sort_ptr = + dev_indices_after_sort.device_malloc(indice_desc.size, indice_desc.dtype); + temp_memory_handle dev_raw_indices(p_env_fns); + void* dev_raw_indices_ptr = dev_raw_indices.device_malloc(indice_desc.size, indice_desc.dtype); + auto raw_indices_desc = wholememory_create_array_desc(indice_desc.size, 0, indice_desc.dtype); + WHOLEMEMORY_RETURN_ON_FAIL(sort_indices_func(indices, + indice_desc, + dev_indices_after_sort_ptr, + dev_raw_indices_ptr, + &thrust_allocator, + p_env_fns, + stream)); + WHOLEMEMORY_RETURN_ON_FAIL(gather_with_sorted_ids_func(wholememory_gref, + wholememory_desc, + dev_indices_after_sort_ptr, + indice_desc, + dev_raw_indices_ptr, + raw_indices_desc, + output, + output_desc, + stream, + gather_sms)); + } else { + WHOLEMEMORY_RETURN_ON_FAIL(gather_func(wholememory_gref, + wholememory_desc, + indices, + indice_desc, + output, + output_desc, + stream, + gather_sms)); + } WM_CUDA_DEBUG_SYNC_STREAM(stream); return WHOLEMEMORY_SUCCESS; } diff --git a/cpp/tests/wholememory_ops/wholememory_gather_tests.cu b/cpp/tests/wholememory_ops/wholememory_gather_tests.cu index fad314db9..ada9c87e1 100644 --- a/cpp/tests/wholememory_ops/wholememory_gather_tests.cu +++ b/cpp/tests/wholememory_ops/wholememory_gather_tests.cu @@ -301,6 +301,16 @@ INSTANTIATE_TEST_SUITE_P( WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) .set_memory_location(WHOLEMEMORY_ML_HOST), + WholeMemoryGatherTestParam() + .set_memory_type(WHOLEMEMORY_MT_CONTINUOUS) + .set_memory_location(WHOLEMEMORY_ML_HOST) + .set_embedding_dim(1) + .set_indices_type(WHOLEMEMORY_DT_INT64), + WholeMemoryGatherTestParam() + .set_memory_type(WHOLEMEMORY_MT_CHUNKED) + .set_memory_location(WHOLEMEMORY_ML_HOST) + 
.set_embedding_dim(1) + .set_indices_type(WHOLEMEMORY_DT_INT64), WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_CONTINUOUS) .set_embedding_dim(11) From ae3748aab1294ce1761dd43a773c5d123b30ea40 Mon Sep 17 00:00:00 2001 From: Chang Liu Date: Wed, 29 May 2024 06:56:49 -0700 Subject: [PATCH 15/15] Add initial support of distributed sampling (#171) This PR introduces support for distributed graph sampling (via NCCL backend). The initial implementation focuses on the uniform neighbor sampler. We are going to extend it to support other samplers in future. Highlights: - Distributed Graph Storage: Now, the graph structure (represented by `row_ptr` and `col_indx` tensors) can be stored as wholememory arrays in a distributed fashion with even distribution across ranks (support both `cpu` and `cuda` storage type). - Distributed Sampling: The sampling process leverages the existing wholegraph gather function to collect the sampled nodes and edges across all ranks. - Uniform Neighbor Sampler Support: Currently, only the uniform neighbor sampler is supported. cc. @linhu-nv @dongxuy04 @BradReesWork @nvcastet @TristonC Authors: - Chang Liu (https://github.com/chang-l) Approvers: - https://github.com/linhu-nv - Brad Rees (https://github.com/BradReesWork) URL: https://github.com/rapidsai/wholegraph/pull/171 --- cpp/src/wholegraph_ops/sample_comm.cuh | 39 +- .../unweighted_sample_without_replacement.cpp | 42 +- ...ighted_sample_without_replacement_func.cuh | 6 - ...weighted_sample_without_replacement_impl.h | 19 +- ...ed_sample_without_replacement_impl_nccl.cu | 79 ++++ ...d_sample_without_replacement_nccl_func.cuh | 388 ++++++++++++++++++ .../pylibwholegraph/test_utils/test_comm.py | 2 + ...h_unweighted_sample_without_replacement.py | 4 +- 8 files changed, 566 insertions(+), 13 deletions(-) create mode 100644 cpp/src/wholegraph_ops/unweighted_sample_without_replacement_impl_nccl.cu create mode 100644 cpp/src/wholegraph_ops/unweighted_sample_without_replacement_nccl_func.cuh diff --git a/cpp/src/wholegraph_ops/sample_comm.cuh b/cpp/src/wholegraph_ops/sample_comm.cuh index 6bf4c66e9..29cd1c472 100644 --- a/cpp/src/wholegraph_ops/sample_comm.cuh +++ b/cpp/src/wholegraph_ops/sample_comm.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
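The NCCL sampling path added in this patch works in three steps: it first gathers row_ptr[v] and row_ptr[v + 1] for every center node with the existing NCCL gather (the helpers added to sample_comm.cuh below build the doubled index list and take the difference of the two halves to get each node's in-degree), then runs a local kernel that samples up to max_sample_count edge positions per node, and finally converts the sampled global edge ids into destination node ids with a second NCCL gather over the column array. A host-side sketch of the degree step only, with illustrative names:

    #include <cstdint>
    #include <vector>

    // Sketch of the in-degree computation: the gathered row-pointer buffer holds
    // row_ptr[v] for all center nodes followed by row_ptr[v + 1] for the same
    // nodes, so the degree of node i is the difference of the two halves.
    std::vector<int> degrees_sketch(const std::vector<std::int64_t>& gathered_indptr,
                                    int num_centers)
    {
      std::vector<int> in_degree(num_centers);
      for (int i = 0; i < num_centers; ++i) {
        in_degree[i] = static_cast<int>(gathered_indptr[i + num_centers] - gathered_indptr[i]);
      }
      return in_degree;
    }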
@@ -57,4 +57,41 @@ __global__ void sample_all_kernel(wholememory_gref_t wm_csr_row_ptr, } } } + +__device__ __forceinline__ int log2_up_device(int x) +{ + if (x <= 2) return x - 1; + return 32 - __clz(x - 1); +} +template +struct ExpandWithOffsetFunc { + const IdType* indptr; + IdType* indptr_shift; + int length; + __host__ __device__ auto operator()(int64_t tIdx) + { + indptr_shift[tIdx] = indptr[tIdx % length] + tIdx / length; + } +}; + +template +struct ReduceForDegrees { + WMIdType* rowoffsets; + DegreeType* in_degree_ptr; + int length; + __host__ __device__ auto operator()(int64_t tIdx) + { + in_degree_ptr[tIdx] = rowoffsets[tIdx + length] - rowoffsets[tIdx]; + } +}; + +template +struct MinInDegreeFanout { + int max_sample_count; + __host__ __device__ auto operator()(DegreeType degree) + { + return min(static_cast(degree), max_sample_count); + } +}; + } // namespace wholegraph_ops diff --git a/cpp/src/wholegraph_ops/unweighted_sample_without_replacement.cpp b/cpp/src/wholegraph_ops/unweighted_sample_without_replacement.cpp index 89fb2a9f9..b835a4bb5 100644 --- a/cpp/src/wholegraph_ops/unweighted_sample_without_replacement.cpp +++ b/cpp/src/wholegraph_ops/unweighted_sample_without_replacement.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -41,7 +41,8 @@ wholememory_error_code_t wholegraph_csr_unweighted_sample_without_replacement( } WHOLEMEMORY_EXPECTS_NOTHROW(!csr_row_ptr_has_handle || csr_row_ptr_memory_type == WHOLEMEMORY_MT_CHUNKED || - csr_row_ptr_memory_type == WHOLEMEMORY_MT_CONTINUOUS, + csr_row_ptr_memory_type == WHOLEMEMORY_MT_CONTINUOUS || + csr_row_ptr_memory_type == WHOLEMEMORY_MT_DISTRIBUTED, "Memory type not supported."); bool const csr_col_ptr_has_handle = wholememory_tensor_has_handle(wm_csr_col_ptr_tensor); wholememory_memory_type_t csr_col_ptr_memory_type = WHOLEMEMORY_MT_NONE; @@ -51,7 +52,8 @@ wholememory_error_code_t wholegraph_csr_unweighted_sample_without_replacement( } WHOLEMEMORY_EXPECTS_NOTHROW(!csr_col_ptr_has_handle || csr_col_ptr_memory_type == WHOLEMEMORY_MT_CHUNKED || - csr_col_ptr_memory_type == WHOLEMEMORY_MT_CONTINUOUS, + csr_col_ptr_memory_type == WHOLEMEMORY_MT_CONTINUOUS || + csr_row_ptr_memory_type == WHOLEMEMORY_MT_DISTRIBUTED, "Memory type not supported."); auto csr_row_ptr_tensor_description = @@ -108,6 +110,40 @@ wholememory_error_code_t wholegraph_csr_unweighted_sample_without_replacement( void* center_nodes = wholememory_tensor_get_data_pointer(center_nodes_tensor); void* output_sample_offset = wholememory_tensor_get_data_pointer(output_sample_offset_tensor); + if (csr_col_ptr_memory_type == WHOLEMEMORY_MT_DISTRIBUTED && + csr_row_ptr_memory_type == WHOLEMEMORY_MT_DISTRIBUTED) { + wholememory_distributed_backend_t distributed_backend_row = wholememory_get_distributed_backend( + wholememory_tensor_get_memory_handle(wm_csr_row_ptr_tensor)); + wholememory_distributed_backend_t distributed_backend_col = wholememory_get_distributed_backend( + wholememory_tensor_get_memory_handle(wm_csr_col_ptr_tensor)); + if (distributed_backend_col == WHOLEMEMORY_DB_NCCL && + distributed_backend_row == WHOLEMEMORY_DB_NCCL) { + wholememory_handle_t wm_csr_row_ptr_handle = + wholememory_tensor_get_memory_handle(wm_csr_row_ptr_tensor); + wholememory_handle_t wm_csr_col_ptr_handle = + wholememory_tensor_get_memory_handle(wm_csr_col_ptr_tensor); + return 
wholegraph_ops::wholegraph_csr_unweighted_sample_without_replacement_nccl( + wm_csr_row_ptr_handle, + wm_csr_col_ptr_handle, + csr_row_ptr_tensor_description, + csr_col_ptr_tensor_description, + center_nodes, + center_nodes_desc, + max_sample_count, + output_sample_offset, + output_sample_offset_desc, + output_dest_memory_context, + output_center_localid_memory_context, + output_edge_gid_memory_context, + random_seed, + p_env_fns, + static_cast(stream)); + } else { + WHOLEMEMORY_ERROR("Only NCCL communication backend is supported for sampling."); + return WHOLEMEMORY_INVALID_INPUT; + } + } + wholememory_gref_t wm_csr_row_ptr_gref, wm_csr_col_ptr_gref; WHOLEMEMORY_RETURN_ON_FAIL( wholememory_tensor_get_global_reference(wm_csr_row_ptr_tensor, &wm_csr_row_ptr_gref)); diff --git a/cpp/src/wholegraph_ops/unweighted_sample_without_replacement_func.cuh b/cpp/src/wholegraph_ops/unweighted_sample_without_replacement_func.cuh index be0b261be..2ee08ce58 100644 --- a/cpp/src/wholegraph_ops/unweighted_sample_without_replacement_func.cuh +++ b/cpp/src/wholegraph_ops/unweighted_sample_without_replacement_func.cuh @@ -123,12 +123,6 @@ __global__ void large_sample_kernel( } } -__device__ __forceinline__ int log2_up_device(int x) -{ - if (x <= 2) return x - 1; - return 32 - __clz(x - 1); -} - template + +#include +#include + +#include "unweighted_sample_without_replacement_nccl_func.cuh" +#include "wholememory_ops/register.hpp" + +namespace wholegraph_ops { + +REGISTER_DISPATCH_TWO_TYPES(UnweightedSampleWithoutReplacementCSRNCCL, + wholegraph_csr_unweighted_sample_without_replacement_nccl_func, + SINT3264, + SINT3264) + +wholememory_error_code_t wholegraph_csr_unweighted_sample_without_replacement_nccl( + wholememory_handle_t csr_row_wholememory_handle, + wholememory_handle_t csr_col_wholememory_handle, + wholememory_tensor_description_t wm_csr_row_ptr_desc, + wholememory_tensor_description_t wm_csr_col_ptr_desc, + void* center_nodes, + wholememory_array_description_t center_nodes_desc, + int max_sample_count, + void* output_sample_offset, + wholememory_array_description_t output_sample_offset_desc, + void* output_dest_memory_context, + void* output_center_localid_memory_context, + void* output_edge_gid_memory_context, + unsigned long long random_seed, + wholememory_env_func_t* p_env_fns, + cudaStream_t stream) +{ + try { + DISPATCH_TWO_TYPES(center_nodes_desc.dtype, + wm_csr_col_ptr_desc.dtype, + UnweightedSampleWithoutReplacementCSRNCCL, + csr_row_wholememory_handle, + csr_col_wholememory_handle, + wm_csr_row_ptr_desc, + wm_csr_col_ptr_desc, + center_nodes, + center_nodes_desc, + max_sample_count, + output_sample_offset, + output_sample_offset_desc, + output_dest_memory_context, + output_center_localid_memory_context, + output_edge_gid_memory_context, + random_seed, + p_env_fns, + stream); + + } catch (const wholememory::cuda_error& rle) { + // WHOLEMEMORY_FAIL_NOTHROW("%s", rle.what()); + return WHOLEMEMORY_LOGIC_ERROR; + } catch (const wholememory::logic_error& le) { + return WHOLEMEMORY_LOGIC_ERROR; + } catch (...) { + return WHOLEMEMORY_LOGIC_ERROR; + } + return WHOLEMEMORY_SUCCESS; +} + +} // namespace wholegraph_ops diff --git a/cpp/src/wholegraph_ops/unweighted_sample_without_replacement_nccl_func.cuh b/cpp/src/wholegraph_ops/unweighted_sample_without_replacement_nccl_func.cuh new file mode 100644 index 000000000..18f4db21a --- /dev/null +++ b/cpp/src/wholegraph_ops/unweighted_sample_without_replacement_nccl_func.cuh @@ -0,0 +1,388 @@ +/* + * Copyright (c) 2019-2024, NVIDIA CORPORATION. 
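The kernel defined in this new header keeps all neighbors when a node's degree is at most max_sample_count; otherwise each thread draws random keys for its candidate slots, the block radix-sorts the (key, position) pairs, and a short chain-fixup pass yields max_sample_count distinct neighbor positions, that is, a sample without replacement. A simplified host analogue of the selection it is intended to produce (the kernel's actual draw and collision-resolution scheme differs and is shown below):

    #include <algorithm>
    #include <numeric>
    #include <random>
    #include <vector>

    // Simplified host analogue: choose M of N neighbor slots without
    // replacement by shuffling (equivalent to sorting by random keys) and
    // keeping the first M. The CUDA kernel uses per-thread PCG draws plus a
    // cub::BlockRadixSort and a fixup chain to achieve the same goal on-device.
    std::vector<int> sample_without_replacement_sketch(int N, int M, std::mt19937& rng)
    {
      std::vector<int> slots(N);
      std::iota(slots.begin(), slots.end(), 0);
      std::shuffle(slots.begin(), slots.end(), rng);
      slots.resize(std::min(M, N));
      return slots;
    }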
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "wholememory_ops/output_memory_handle.hpp" +#include "wholememory_ops/temp_memory_handle.hpp" +#include "wholememory_ops/thrust_allocator.hpp" + +#include "cuda_macros.hpp" +#include "error.hpp" +#include "sample_comm.cuh" + +#include "wholememory_ops/gather_op_impl.h" +using wholememory_ops::wholememory_gather_nccl; +#define WARP_SIZE 32 + +namespace wholegraph_ops { + +template +__global__ void unweighted_sample_without_replacement_nccl_kernel( + const IdType* input_nodes, + const WMIdType* csr_row_ptr_sta, + const DegreeType* in_degree, + const int input_node_count, + const int max_sample_count, + raft::random::detail::DeviceState rngstate, + const int* sample_offset, + wholememory_array_description_t sample_offset_desc, + int* src_lid, + int64_t* output_edge_gid_ptr) +{ + int gidx = threadIdx.x + blockIdx.x * blockDim.x; + raft::random::detail::PCGenerator rng(rngstate, (uint64_t)gidx); + int input_idx = blockIdx.x; + if (input_idx >= input_node_count) return; + + WMIdType start = csr_row_ptr_sta[input_idx]; + WMIdType end = start + in_degree[input_idx]; + int neighbor_count = (int)(in_degree[input_idx]); + if (neighbor_count <= 0) return; + int offset = sample_offset[input_idx]; + // use all neighbors if neighbors less than max_sample_count + if (neighbor_count <= max_sample_count) { + for (int sample_id = threadIdx.x; sample_id < neighbor_count; sample_id += blockDim.x) { + output_edge_gid_ptr[offset + sample_id] = start + sample_id; + if (src_lid) src_lid[offset + sample_id] = input_idx; + } + return; + } + uint64_t sa_p[ITEMS_PER_THREAD]; + int M = max_sample_count; + int N = neighbor_count; + // UnWeightedIndexSampleWithOutReplacement(M, N, + // sa_p, rng); + typedef cub::BlockRadixSort BlockRadixSort; + struct IntArray { + int value[BLOCK_DIM * ITEMS_PER_THREAD]; + }; + struct SampleSharedData { + IntArray s; + IntArray p; + IntArray q; + IntArray chain; + IntArray last_chain_tmp; + }; + __shared__ union { + typename BlockRadixSort::TempStorage temp_storage; + SampleSharedData sample_shared_data; + } shared_data; +#pragma unroll + for (int i = 0; i < ITEMS_PER_THREAD; i++) { + int idx = i * BLOCK_DIM + threadIdx.x; + raft::random::detail::UniformDistParams params; + params.start = 0; + params.end = 1; + int32_t rand_num; + raft::random::detail::custom_next(rng, &rand_num, params, 0, 0); + int32_t r = idx < M ? 
rand_num % (N - idx) : N; + sa_p[i] = ((uint64_t)r << 32UL) | idx; + } + __syncthreads(); + BlockRadixSort(shared_data.temp_storage).SortBlockedToStriped(sa_p); + __syncthreads(); +#pragma unroll + for (int i = 0; i < ITEMS_PER_THREAD; i++) { + int idx = i * BLOCK_DIM + threadIdx.x; + int s = (int)(sa_p[i] >> 32UL); + shared_data.sample_shared_data.s.value[idx] = s; + int p = sa_p[i] & 0xFFFFFFFF; + shared_data.sample_shared_data.p.value[idx] = p; + if (idx < M) shared_data.sample_shared_data.q.value[p] = idx; + shared_data.sample_shared_data.chain.value[idx] = idx; + } + __syncthreads(); +#pragma unroll + for (int i = 0; i < ITEMS_PER_THREAD; i++) { + int idx = i * BLOCK_DIM + threadIdx.x; + int si = shared_data.sample_shared_data.s.value[idx]; + int si1 = shared_data.sample_shared_data.s.value[idx + 1]; + if (idx < M && (idx == M - 1 || si != si1) && si >= N - M) { + shared_data.sample_shared_data.chain.value[N - si - 1] = + shared_data.sample_shared_data.p.value[idx]; + } + } + __syncthreads(); + for (int step = 0; step < log2_up_device(M); ++step) { +#pragma unroll + for (int i = 0; i < ITEMS_PER_THREAD; i++) { + int idx = i * BLOCK_DIM + threadIdx.x; + shared_data.sample_shared_data.last_chain_tmp.value[idx] = + shared_data.sample_shared_data.chain.value[idx]; + } + __syncthreads(); +#pragma unroll + for (int i = 0; i < ITEMS_PER_THREAD; i++) { + int idx = i * BLOCK_DIM + threadIdx.x; + if (idx < M) { + shared_data.sample_shared_data.chain.value[idx] = + shared_data.sample_shared_data.last_chain_tmp + .value[shared_data.sample_shared_data.last_chain_tmp.value[idx]]; + } + } + __syncthreads(); + } +#pragma unroll + for (int i = 0; i < ITEMS_PER_THREAD; i++) { + int idx = i * BLOCK_DIM + threadIdx.x; + shared_data.sample_shared_data.last_chain_tmp.value[idx] = + N - shared_data.sample_shared_data.chain.value[idx] - 1; + } + __syncthreads(); +#pragma unroll + for (int i = 0; i < ITEMS_PER_THREAD; i++) { + int idx = i * BLOCK_DIM + threadIdx.x; + int ai; + if (idx < M) { + int qi = shared_data.sample_shared_data.q.value[idx]; + if (idx == 0 || qi == 0 || + shared_data.sample_shared_data.s.value[qi] != + shared_data.sample_shared_data.s.value[qi - 1]) { + ai = shared_data.sample_shared_data.s.value[qi]; + } else { + int prev_i = shared_data.sample_shared_data.p.value[qi - 1]; + ai = shared_data.sample_shared_data.last_chain_tmp.value[prev_i]; + } + sa_p[i] = ai; + } + } + // Output +#pragma unroll + for (int i = 0; i < ITEMS_PER_THREAD; i++) { + int idx = i * BLOCK_DIM + threadIdx.x; + int ai = sa_p[i]; + if (idx < M) { + output_edge_gid_ptr[offset + idx] = (int64_t)(start + ai); + if (src_lid) src_lid[offset + idx] = (LocalIdType)input_idx; + } + } +} + +template +void wholegraph_csr_unweighted_sample_without_replacement_nccl_func( + wholememory_handle_t csr_row_wholememory_handle, + wholememory_handle_t csr_col_wholememory_handle, + wholememory_tensor_description_t wm_csr_row_ptr_desc, + wholememory_tensor_description_t wm_csr_col_ptr_desc, + void* center_nodes, + wholememory_array_description_t center_nodes_desc, + int max_sample_count, + void* output_sample_offset, + wholememory_array_description_t output_sample_offset_desc, + void* output_dest_memory_context, + void* output_center_localid_memory_context, + void* output_edge_gid_memory_context, + unsigned long long random_seed, + wholememory_env_func_t* p_env_fns, + cudaStream_t stream) +{ + int center_node_count = center_nodes_desc.size; + WHOLEMEMORY_EXPECTS(wm_csr_row_ptr_desc.dtype == WHOLEMEMORY_DT_INT64, + 
"wholegraph_csr_unweighted_sample_without_replacement_nccl_func(). " + "wm_csr_row_ptr_desc.dtype != WHOLEMEMORY_DT_INT64, " + "wm_csr_row_ptr_desc.dtype = %d", + wm_csr_row_ptr_desc.dtype); + + WHOLEMEMORY_EXPECTS(output_sample_offset_desc.dtype == WHOLEMEMORY_DT_INT, + "wholegraph_csr_unweighted_sample_without_replacement_nccl_func(). " + "output_sample_offset_desc.dtype != WHOLEMEMORY_DT_INT, " + "output_sample_offset_desc.dtype = %d", + output_sample_offset_desc.dtype); + + auto double_center_node_count = center_node_count + center_node_count; + wholememory_ops::temp_memory_handle center_nodes_buf(p_env_fns); + IdType* center_nodes_expandshift_one = static_cast( + center_nodes_buf.device_malloc(double_center_node_count, center_nodes_desc.dtype)); + // fill center_nodes_shift_one with [center_nodes, center_nodes+1] + wholememory_ops::wm_thrust_allocator thrust_allocator(p_env_fns); + thrust::counting_iterator iota(0); + thrust::for_each(thrust::cuda::par_nosync(thrust_allocator).on(stream), + iota, + iota + double_center_node_count, + ExpandWithOffsetFunc{ + (const IdType*)center_nodes, center_nodes_expandshift_one, center_node_count}); + + // gathering [rowoffsets, rowoffsets+1] + wholememory_ops::temp_memory_handle output_buf(p_env_fns); + int64_t* center_nodes_indptr = + static_cast(output_buf.device_malloc(double_center_node_count, WHOLEMEMORY_DT_INT64)); + wholememory_array_description_t center_nodes_expandshift_one_desc{ + double_center_node_count, 0, center_nodes_desc.dtype}; + wholememory_matrix_description_t center_nodes_indptr_desc{ + {double_center_node_count, 1}, 1, 0, wm_csr_row_ptr_desc.dtype}; + wholememory_matrix_description_t wm_csr_row_ptr_mat_desc; + wholememory_convert_tensor_desc_to_matrix(&wm_csr_row_ptr_mat_desc, &wm_csr_row_ptr_desc); + wholememory_ops::wholememory_gather_nccl(csr_row_wholememory_handle, + wm_csr_row_ptr_mat_desc, + center_nodes_expandshift_one, + center_nodes_expandshift_one_desc, + center_nodes_indptr, + center_nodes_indptr_desc, + p_env_fns, + stream, + -1); + // find the in_degree (subtraction) and sample count + // temporarily store sampled_csr_ptr_buf (# of degrees/samples per node) in int32; + // can be changed to int8_t/16_t later + wholememory_ops::temp_memory_handle sampled_csr_ptr_buf(p_env_fns); + int* in_degree = + static_cast(sampled_csr_ptr_buf.device_malloc(center_node_count + 1, WHOLEMEMORY_DT_INT)); + thrust::for_each( + thrust::cuda::par_nosync(thrust_allocator).on(stream), + iota, + iota + center_node_count, + ReduceForDegrees{center_nodes_indptr, in_degree, center_node_count}); + // prefix sum to get the output_sample_offset (depending on min(max_sample_count and in_degree)) + int sampled_count = max_sample_count <= 0 ? 
std::numeric_limits::max() : max_sample_count; + thrust::transform_exclusive_scan(thrust::cuda::par_nosync(thrust_allocator).on(stream), + in_degree, + in_degree + center_node_count + 1, + (int*)output_sample_offset, + MinInDegreeFanout{sampled_count}, + 0, + thrust::plus()); + // start local sampling + int count; + WM_CUDA_CHECK(cudaMemcpyAsync(&count, + ((int*)output_sample_offset) + center_node_count, + sizeof(int), + cudaMemcpyDeviceToHost, + stream)); + WM_CUDA_CHECK(cudaStreamSynchronize(stream)); + + int64_t* output_edge_gid_ptr = nullptr; + wholememory_ops::temp_memory_handle edge_gid_buffer_mh(p_env_fns); + if (output_edge_gid_memory_context) { + wholememory_ops::output_memory_handle gen_output_edge_gid_buffer_mh( + p_env_fns, output_edge_gid_memory_context); + output_edge_gid_ptr = + (int64_t*)gen_output_edge_gid_buffer_mh.device_malloc(count, WHOLEMEMORY_DT_INT64); + } else { + output_edge_gid_ptr = (int64_t*)edge_gid_buffer_mh.device_malloc(count, WHOLEMEMORY_DT_INT64); + } + + wholememory_ops::output_memory_handle gen_output_dest_buffer_mh(p_env_fns, + output_dest_memory_context); + WMIdType* output_dest_node_ptr = + (WMIdType*)gen_output_dest_buffer_mh.device_malloc(count, wm_csr_col_ptr_desc.dtype); + + int* output_center_localid_ptr = nullptr; + if (output_center_localid_memory_context) { + wholememory_ops::output_memory_handle gen_output_center_localid_buffer_mh( + p_env_fns, output_center_localid_memory_context); + output_center_localid_ptr = + (int*)gen_output_center_localid_buffer_mh.device_malloc(count, WHOLEMEMORY_DT_INT); + } + raft::random::RngState _rngstate(random_seed, 0, raft::random::GeneratorType::GenPC); + raft::random::detail::DeviceState rngstate(_rngstate); + { + typedef void (*unweighted_sample_func_type)( + const IdType* input_nodes, + const int64_t* center_nodes_indptr, + const int* in_degree, + const int input_node_count, + const int max_sample_count, + raft::random::detail::DeviceState rngstate, + const int* sample_offset, + wholememory_array_description_t sample_offset_desc, + int* src_lid, + int64_t* output_edge_gid_ptr); + static const unweighted_sample_func_type func_array[32] = { + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + 
unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel, + unweighted_sample_without_replacement_nccl_kernel}; + static const int warp_count_array[32] = {1, 1, 1, 2, 2, 2, 4, 4, 4, 4, 4, 4, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8}; + int func_idx = (max_sample_count - 1) / 32; + func_array[func_idx]<<>>( + (const IdType*)center_nodes, + (const int64_t*)center_nodes_indptr, + (const int*)in_degree, + center_node_count, + sampled_count, + rngstate, + (const int*)output_sample_offset, + output_sample_offset_desc, + (int*)output_center_localid_ptr, + (int64_t*)output_edge_gid_ptr); + + wholememory_matrix_description_t wm_csr_col_ptr_mat_desc; + wholememory_matrix_description_t output_dest_node_ptr_desc{ + {count, 1}, 1, 0, wm_csr_col_ptr_desc.dtype}; + wholememory_array_description_t output_edge_gid_ptr_desc{count, 0, WHOLEMEMORY_DT_INT64}; + wholememory_convert_tensor_desc_to_matrix(&wm_csr_col_ptr_mat_desc, &wm_csr_col_ptr_desc); + wholememory_ops::wholememory_gather_nccl(csr_col_wholememory_handle, + wm_csr_col_ptr_mat_desc, + output_edge_gid_ptr, + output_edge_gid_ptr_desc, + output_dest_node_ptr, + output_dest_node_ptr_desc, + p_env_fns, + stream, + -1); + } + WM_CUDA_CHECK(cudaGetLastError()); + WM_CUDA_CHECK(cudaStreamSynchronize(stream)); +} +} // namespace wholegraph_ops diff --git a/python/pylibwholegraph/pylibwholegraph/test_utils/test_comm.py b/python/pylibwholegraph/pylibwholegraph/test_utils/test_comm.py index 0c822be0b..438a485e1 100644 --- a/python/pylibwholegraph/pylibwholegraph/test_utils/test_comm.py +++ b/python/pylibwholegraph/pylibwholegraph/test_utils/test_comm.py @@ -192,5 +192,7 @@ def int_to_wholememory_type(value: int): return wmb.WholeMemoryMemoryType.MtContinuous if value == 1: return wmb.WholeMemoryMemoryType.MtChunked + if value == 2: + return wmb.WholeMemoryMemoryType.MtDistributed else: raise ValueError("invalid int_to_wholememory_type value") diff --git a/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_unweighted_sample_without_replacement.py b/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_unweighted_sample_without_replacement.py index 1953419f5..366a4298b 100644 --- a/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_unweighted_sample_without_replacement.py +++ b/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_unweighted_sample_without_replacement.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2023, NVIDIA CORPORATION. +# Copyright (c) 2019-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -364,7 +364,7 @@ def routine_func(world_rank: int, world_size: int, **kwargs): @pytest.mark.parametrize("center_node_dtype", [torch.int32, torch.int64]) @pytest.mark.parametrize("col_id_dtype", [0, 1]) @pytest.mark.parametrize("wholememory_location", ([0, 1])) -@pytest.mark.parametrize("wholememory_type", ([0, 1])) +@pytest.mark.parametrize("wholememory_type", ([0, 1, 2])) @pytest.mark.parametrize("need_center_local_output", [True, False]) @pytest.mark.parametrize("need_edge_output", [True, False]) def test_wholegraph_unweighted_sample(