diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile
index 9d35e3f97f..594ba8c3c4 100644
--- a/.devcontainer/Dockerfile
+++ b/.devcontainer/Dockerfile
@@ -5,6 +5,13 @@ ARG PYTHON_PACKAGE_MANAGER=conda
 
 FROM ${BASE} as pip-base
 
+RUN apt update -y \
+ && DEBIAN_FRONTEND=noninteractive apt install -y --no-install-recommends \
+    # faiss dependencies
+    libblas-dev \
+    liblapack-dev \
+ && rm -rf /tmp/* /var/tmp/* /var/cache/apt/* /var/lib/apt/lists/*;
+
 ENV DEFAULT_VIRTUAL_ENV=rapids
 
 FROM ${BASE} as conda-base
diff --git a/.devcontainer/cuda11.8-conda/devcontainer.json b/.devcontainer/cuda11.8-conda/devcontainer.json
index 2682510ed1..536537f07f 100644
--- a/.devcontainer/cuda11.8-conda/devcontainer.json
+++ b/.devcontainer/cuda11.8-conda/devcontainer.json
@@ -5,12 +5,17 @@
     "args": {
       "CUDA": "11.8",
       "PYTHON_PACKAGE_MANAGER": "conda",
-      "BASE": "rapidsai/devcontainers:24.04-cpp-llvm16-cuda11.8-mambaforge-ubuntu22.04"
+      "BASE": "rapidsai/devcontainers:24.06-cpp-cuda11.8-mambaforge-ubuntu22.04"
     }
   },
+  "runArgs": [
+    "--rm",
+    "--name",
+    "${localEnv:USER}-rapids-${localWorkspaceFolderBasename}-24.06-cuda11.8-conda"
+  ],
   "hostRequirements": {"gpu": "optional"},
   "features": {
-    "ghcr.io/rapidsai/devcontainers/features/rapids-build-utils:24.4": {}
+    "ghcr.io/rapidsai/devcontainers/features/rapids-build-utils:24.6": {}
   },
   "overrideFeatureInstallOrder": [
diff --git a/.devcontainer/cuda11.8-pip/devcontainer.json b/.devcontainer/cuda11.8-pip/devcontainer.json
index de039eeb11..92e7613a9b 100644
--- a/.devcontainer/cuda11.8-pip/devcontainer.json
+++ b/.devcontainer/cuda11.8-pip/devcontainer.json
@@ -5,22 +5,27 @@
     "args": {
       "CUDA": "11.8",
       "PYTHON_PACKAGE_MANAGER": "pip",
-      "BASE": "rapidsai/devcontainers:24.04-cpp-cuda11.8-ubuntu22.04"
+      "BASE": "rapidsai/devcontainers:24.06-cpp-cuda11.8-ubuntu22.04"
     }
   },
+  "runArgs": [
+    "--rm",
+    "--name",
+    "${localEnv:USER}-rapids-${localWorkspaceFolderBasename}-24.06-cuda11.8-pip"
+  ],
   "hostRequirements": {"gpu": "optional"},
   "features": {
-    "ghcr.io/rapidsai/devcontainers/features/ucx:24.4": {
-      "version": "1.14.1"
+    "ghcr.io/rapidsai/devcontainers/features/ucx:24.6": {
+      "version": "1.15.0"
     },
-    "ghcr.io/rapidsai/devcontainers/features/cuda:24.4": {
+    "ghcr.io/rapidsai/devcontainers/features/cuda:24.6": {
       "version": "11.8",
       "installcuBLAS": true,
       "installcuSOLVER": true,
       "installcuRAND": true,
       "installcuSPARSE": true
     },
-    "ghcr.io/rapidsai/devcontainers/features/rapids-build-utils:24.4": {}
+    "ghcr.io/rapidsai/devcontainers/features/rapids-build-utils:24.6": {}
   },
   "overrideFeatureInstallOrder": [
     "ghcr.io/rapidsai/devcontainers/features/ucx",
diff --git a/.devcontainer/cuda12.2-conda/devcontainer.json b/.devcontainer/cuda12.2-conda/devcontainer.json
index 4b24d94dd1..948680eaf6 100644
--- a/.devcontainer/cuda12.2-conda/devcontainer.json
+++ b/.devcontainer/cuda12.2-conda/devcontainer.json
@@ -5,12 +5,17 @@
     "args": {
       "CUDA": "12.2",
       "PYTHON_PACKAGE_MANAGER": "conda",
-      "BASE": "rapidsai/devcontainers:24.04-cpp-mambaforge-ubuntu22.04"
+      "BASE": "rapidsai/devcontainers:24.06-cpp-mambaforge-ubuntu22.04"
     }
   },
+  "runArgs": [
+    "--rm",
+    "--name",
+    "${localEnv:USER}-rapids-${localWorkspaceFolderBasename}-24.06-cuda12.2-conda"
+  ],
   "hostRequirements": {"gpu": "optional"},
   "features": {
-    "ghcr.io/rapidsai/devcontainers/features/rapids-build-utils:24.4": {}
+    "ghcr.io/rapidsai/devcontainers/features/rapids-build-utils:24.6": {}
   },
   "overrideFeatureInstallOrder": [
"ghcr.io/rapidsai/devcontainers/features/rapids-build-utils" diff --git a/.devcontainer/cuda12.2-pip/devcontainer.json b/.devcontainer/cuda12.2-pip/devcontainer.json index 489546cb21..cd287569d8 100644 --- a/.devcontainer/cuda12.2-pip/devcontainer.json +++ b/.devcontainer/cuda12.2-pip/devcontainer.json @@ -5,22 +5,27 @@ "args": { "CUDA": "12.2", "PYTHON_PACKAGE_MANAGER": "pip", - "BASE": "rapidsai/devcontainers:24.04-cpp-cuda12.2-ubuntu22.04" + "BASE": "rapidsai/devcontainers:24.06-cpp-cuda12.2-ubuntu22.04" } }, + "runArgs": [ + "--rm", + "--name", + "${localEnv:USER}-rapids-${localWorkspaceFolderBasename}-24.06-cuda12.2-pip" + ], "hostRequirements": {"gpu": "optional"}, "features": { - "ghcr.io/rapidsai/devcontainers/features/ucx:24.4": { - "version": "1.14.1" + "ghcr.io/rapidsai/devcontainers/features/ucx:24.6": { + "version": "1.15.0" }, - "ghcr.io/rapidsai/devcontainers/features/cuda:24.4": { + "ghcr.io/rapidsai/devcontainers/features/cuda:24.6": { "version": "12.2", "installcuBLAS": true, "installcuSOLVER": true, "installcuRAND": true, "installcuSPARSE": true }, - "ghcr.io/rapidsai/devcontainers/features/rapids-build-utils:24.4": {} + "ghcr.io/rapidsai/devcontainers/features/rapids-build-utils:24.6": {} }, "overrideFeatureInstallOrder": [ "ghcr.io/rapidsai/devcontainers/features/ucx", diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index fc4fcd458b..d1cc52592c 100755 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -11,11 +11,14 @@ python/setup.py @rapidsai/raft-cmake-codeowners build.sh @rapidsai/raft-cmake-codeowners **/build.sh @rapidsai/raft-cmake-codeowners -#build/ops code owners -.github/ @rapidsai/ops-codeowners -ci/ @rapidsai/ops-codeowners -conda/ @rapidsai/ops-codeowners -**/Dockerfile @rapidsai/ops-codeowners -**/.dockerignore @rapidsai/ops-codeowners -docker/ @rapidsai/ops-codeowners -dependencies.yaml @rapidsai/ops-codeowners +#CI code owners +/.github/ @rapidsai/ci-codeowners +/ci/ @rapidsai/ci-codeowners +/.pre-commit-config.yaml @rapidsai/ci-codeowners + +#packaging code owners +/.devcontainers/ @rapidsai/packaging-codeowners +/conda/ @rapidsai/packaging-codeowners +/dependencies.yaml @rapidsai/packaging-codeowners +/build.sh @rapidsai/packaging-codeowners +pyproject.toml @rapidsai/packaging-codeowners diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index bd8b13d21e..e013d4f1c5 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -28,7 +28,7 @@ concurrency: jobs: cpp-build: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-build.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-build.yaml@branch-24.06 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -37,7 +37,7 @@ jobs: python-build: needs: [cpp-build] secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-python-build.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/conda-python-build.yaml@branch-24.06 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -46,7 +46,7 @@ jobs: upload-conda: needs: [cpp-build, python-build] secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-upload-packages.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/conda-upload-packages.yaml@branch-24.06 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -57,7 +57,7 @@ jobs: if: github.ref_type == 'branch' needs: python-build secrets: 
inherit - uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@branch-24.06 with: arch: "amd64" branch: ${{ inputs.branch }} @@ -69,7 +69,7 @@ jobs: sha: ${{ inputs.sha }} wheel-build-pylibraft: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@branch-24.06 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -79,7 +79,7 @@ jobs: wheel-publish-pylibraft: needs: wheel-build-pylibraft secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-publish.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/wheels-publish.yaml@branch-24.06 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -87,9 +87,8 @@ jobs: date: ${{ inputs.date }} package-name: pylibraft wheel-build-raft-dask: - needs: wheel-publish-pylibraft secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@branch-24.06 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -99,7 +98,7 @@ jobs: wheel-publish-raft-dask: needs: wheel-build-raft-dask secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-publish.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/wheels-publish.yaml@branch-24.06 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index ada46141a7..c2d9556859 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -25,29 +25,29 @@ jobs: - wheel-tests-raft-dask - devcontainer secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/pr-builder.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/pr-builder.yaml@branch-24.06 checks: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/checks.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/checks.yaml@branch-24.06 with: enable_check_generated_files: false conda-cpp-build: needs: checks secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-build.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-build.yaml@branch-24.06 with: build_type: pull-request node_type: cpu16 conda-cpp-tests: needs: conda-cpp-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-tests.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-tests.yaml@branch-24.06 with: build_type: pull-request conda-cpp-checks: needs: conda-cpp-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-post-build-checks.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-post-build-checks.yaml@branch-24.06 with: build_type: pull-request enable_check_symbols: true @@ -55,19 +55,19 @@ jobs: conda-python-build: needs: conda-cpp-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-python-build.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/conda-python-build.yaml@branch-24.06 with: build_type: pull-request conda-python-tests: needs: conda-python-build secrets: inherit - uses: 
rapidsai/shared-workflows/.github/workflows/conda-python-tests.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/conda-python-tests.yaml@branch-24.06 with: build_type: pull-request docs-build: needs: conda-python-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@branch-24.06 with: build_type: pull-request node_type: "gpu-v100-latest-1" @@ -77,34 +77,34 @@ jobs: wheel-build-pylibraft: needs: checks secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@branch-24.06 with: build_type: pull-request script: ci/build_wheel_pylibraft.sh wheel-tests-pylibraft: needs: wheel-build-pylibraft secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@branch-24.06 with: build_type: pull-request script: ci/test_wheel_pylibraft.sh wheel-build-raft-dask: needs: wheel-tests-pylibraft secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@branch-24.06 with: build_type: pull-request script: "ci/build_wheel_raft_dask.sh" wheel-tests-raft-dask: needs: wheel-build-raft-dask secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@branch-24.06 with: build_type: pull-request script: ci/test_wheel_raft_dask.sh devcontainer: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/build-in-devcontainer.yaml@fix/devcontainer-json-location + uses: rapidsai/shared-workflows/.github/workflows/build-in-devcontainer.yaml@branch-24.06 with: arch: '["amd64"]' cuda: '["12.2"]' diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 2a557a8b84..18094cc05a 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -16,7 +16,7 @@ on: jobs: conda-cpp-checks: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-post-build-checks.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-post-build-checks.yaml@branch-24.06 with: build_type: nightly branch: ${{ inputs.branch }} @@ -26,7 +26,7 @@ jobs: symbol_exclusions: _ZN\d+raft_cutlass conda-cpp-tests: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-tests.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-tests.yaml@branch-24.06 with: build_type: nightly branch: ${{ inputs.branch }} @@ -34,7 +34,7 @@ jobs: sha: ${{ inputs.sha }} conda-python-tests: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-python-tests.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/conda-python-tests.yaml@branch-24.06 with: build_type: nightly branch: ${{ inputs.branch }} @@ -42,7 +42,7 @@ jobs: sha: ${{ inputs.sha }} wheel-tests-pylibraft: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@branch-24.04 + uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@branch-24.06 with: build_type: nightly branch: ${{ inputs.branch }} @@ -51,7 +51,7 @@ jobs: script: ci/test_wheel_pylibraft.sh wheel-tests-raft-dask: secrets: inherit - uses: 
rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@branch-24.04
+    uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@branch-24.06
     with:
       build_type: nightly
       branch: ${{ inputs.branch }}
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6a4da6197e..e0599dae8a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,82 @@
+# raft 24.06.00 (5 Jun 2024)
+
+## 🚨 Breaking Changes
+
+- Rename raft-ann-bench module to raft_ann_bench ([#2333](https://github.com/rapidsai/raft/pull/2333)) [@KyleFromNVIDIA](https://github.com/KyleFromNVIDIA)
+- Scaling workspace resources ([#2322](https://github.com/rapidsai/raft/pull/2322)) [@achirkin](https://github.com/achirkin)
+- [REVIEW] Adjust UCX dependencies ([#2304](https://github.com/rapidsai/raft/pull/2304)) [@pentschev](https://github.com/pentschev)
+- Convert device_memory_resource* to device_async_resource_ref ([#2269](https://github.com/rapidsai/raft/pull/2269)) [@harrism](https://github.com/harrism)
+
+## 🐛 Bug Fixes
+
+- Fix import of VERSION file in raft-ann-bench ([#2338](https://github.com/rapidsai/raft/pull/2338)) [@KyleFromNVIDIA](https://github.com/KyleFromNVIDIA)
+- Rename raft-ann-bench module to raft_ann_bench ([#2333](https://github.com/rapidsai/raft/pull/2333)) [@KyleFromNVIDIA](https://github.com/KyleFromNVIDIA)
+- Support building faiss main statically ([#2323](https://github.com/rapidsai/raft/pull/2323)) [@robertmaynard](https://github.com/robertmaynard)
+- Refactor spectral scale_obs to use existing normalization function ([#2319](https://github.com/rapidsai/raft/pull/2319)) [@ChuckHastings](https://github.com/ChuckHastings)
+- Correct initializer list order found by cuvs ([#2317](https://github.com/rapidsai/raft/pull/2317)) [@robertmaynard](https://github.com/robertmaynard)
+- ANN_BENCH: enable move semantics for configured_raft_resources ([#2311](https://github.com/rapidsai/raft/pull/2311)) [@achirkin](https://github.com/achirkin)
+- Revert "Build C++ wheel ([#2264](https://github.com/rapidsai/raft/pull/2264))" ([#2305](https://github.com/rapidsai/raft/pull/2305)) [@vyasr](https://github.com/vyasr)
+- Revert "Add `compile-library` by default on pylibraft build" ([#2300](https://github.com/rapidsai/raft/pull/2300)) [@vyasr](https://github.com/vyasr)
+- Add VERSION to raft-ann-bench package ([#2299](https://github.com/rapidsai/raft/pull/2299)) [@KyleFromNVIDIA](https://github.com/KyleFromNVIDIA)
+- Remove nonexistent job from workflow ([#2298](https://github.com/rapidsai/raft/pull/2298)) [@vyasr](https://github.com/vyasr)
+- `libucx` should be run dependency of `raft-dask` ([#2296](https://github.com/rapidsai/raft/pull/2296)) [@divyegala](https://github.com/divyegala)
+- Fix clang intrinsic warning ([#2292](https://github.com/rapidsai/raft/pull/2292)) [@aaronmondal](https://github.com/aaronmondal)
+- Replace too long index file name with hash in ANN bench ([#2280](https://github.com/rapidsai/raft/pull/2280)) [@tfeher](https://github.com/tfeher)
+- Fix build command for C++ compilation ([#2270](https://github.com/rapidsai/raft/pull/2270)) [@lowener](https://github.com/lowener)
+- Fix a compilation error in CAGRA when enabling log output ([#2262](https://github.com/rapidsai/raft/pull/2262)) [@enp1s0](https://github.com/enp1s0)
+- Correct member initialization order ([#2254](https://github.com/rapidsai/raft/pull/2254)) [@robertmaynard](https://github.com/robertmaynard)
+- Fix time computation in CAGRA notebook ([#2231](https://github.com/rapidsai/raft/pull/2231)) [@lowener](https://github.com/lowener)
+
+## 📖 Documentation
+
+- Fix citation info ([#2318](https://github.com/rapidsai/raft/pull/2318)) [@enp1s0](https://github.com/enp1s0)
+
+## 🚀 New Features
+
+- Scaling workspace resources ([#2322](https://github.com/rapidsai/raft/pull/2322)) [@achirkin](https://github.com/achirkin)
+- ANN_BENCH: AnnGPU::uses_stream() for optional algo GPU sync ([#2314](https://github.com/rapidsai/raft/pull/2314)) [@achirkin](https://github.com/achirkin)
+- [FEA] Split Bitset code ([#2295](https://github.com/rapidsai/raft/pull/2295)) [@lowener](https://github.com/lowener)
+- [FEA] support of prefiltered brute force ([#2294](https://github.com/rapidsai/raft/pull/2294)) [@rhdong](https://github.com/rhdong)
+- Always use a static gtest and gbench ([#2265](https://github.com/rapidsai/raft/pull/2265)) [@robertmaynard](https://github.com/robertmaynard)
+- Build C++ wheel ([#2264](https://github.com/rapidsai/raft/pull/2264)) [@vyasr](https://github.com/vyasr)
+- InnerProduct Distance Metric for CAGRA search ([#2260](https://github.com/rapidsai/raft/pull/2260)) [@tarang-jain](https://github.com/tarang-jain)
+- [FEA] Add support for `select_k` on CSR matrix ([#2140](https://github.com/rapidsai/raft/pull/2140)) [@rhdong](https://github.com/rhdong)
+
+## 🛠️ Improvements
+
+- ANN_BENCH: common AnnBase::index_type ([#2315](https://github.com/rapidsai/raft/pull/2315)) [@achirkin](https://github.com/achirkin)
+- ANN_BENCH: split instances of RaftCagra into multiple files ([#2313](https://github.com/rapidsai/raft/pull/2313)) [@achirkin](https://github.com/achirkin)
+- ANN_BENCH: a global pool of result buffers across benchmark cases ([#2312](https://github.com/rapidsai/raft/pull/2312)) [@achirkin](https://github.com/achirkin)
+- Remove the shared state and the mutex from NVTX internals ([#2310](https://github.com/rapidsai/raft/pull/2310)) [@achirkin](https://github.com/achirkin)
+- docs: update README.md ([#2308](https://github.com/rapidsai/raft/pull/2308)) [@eltociear](https://github.com/eltociear)
+- [REVIEW] Reenable raft-dask wheel tests requiring UCX-Py ([#2307](https://github.com/rapidsai/raft/pull/2307)) [@pentschev](https://github.com/pentschev)
+- [REVIEW] Adjust UCX dependencies ([#2304](https://github.com/rapidsai/raft/pull/2304)) [@pentschev](https://github.com/pentschev)
+- Overhaul ops-codeowners ([#2303](https://github.com/rapidsai/raft/pull/2303)) [@raydouglass](https://github.com/raydouglass)
+- Make thrust nosync execution policy the default thrust policy ([#2302](https://github.com/rapidsai/raft/pull/2302)) [@abc99lr](https://github.com/abc99lr)
+- InnerProduct testing for CAGRA+HNSW ([#2297](https://github.com/rapidsai/raft/pull/2297)) [@divyegala](https://github.com/divyegala)
+- Enable warnings as errors for Python tests ([#2288](https://github.com/rapidsai/raft/pull/2288)) [@mroeschke](https://github.com/mroeschke)
+- Normalize dataset vectors in the CAGRA InnerProduct tests ([#2287](https://github.com/rapidsai/raft/pull/2287)) [@enp1s0](https://github.com/enp1s0)
+- Use dynamic version for raft-ann-bench ([#2285](https://github.com/rapidsai/raft/pull/2285)) [@KyleFromNVIDIA](https://github.com/KyleFromNVIDIA)
+- Make 'librmm' a 'host' dependency for conda packages ([#2284](https://github.com/rapidsai/raft/pull/2284)) [@jameslamb](https://github.com/jameslamb)
+- Fix comments in cpp/include/raft/neighbors/cagra_serialize.cuh ([#2283](https://github.com/rapidsai/raft/pull/2283)) [@jiangyinzuo](https://github.com/jiangyinzuo)
+- Only use functions in the limited API ([#2282](https://github.com/rapidsai/raft/pull/2282)) [@vyasr](https://github.com/vyasr)
+- define 'ucx' pytest marker ([#2281](https://github.com/rapidsai/raft/pull/2281)) [@jameslamb](https://github.com/jameslamb)
+- Migrate to `{{ stdlib("c") }}` ([#2278](https://github.com/rapidsai/raft/pull/2278)) [@hcho3](https://github.com/hcho3)
+- add --rm and --name to devcontainer run args ([#2275](https://github.com/rapidsai/raft/pull/2275)) [@trxcllnt](https://github.com/trxcllnt)
+- Update pip devcontainers to UCX v1.15.0 ([#2274](https://github.com/rapidsai/raft/pull/2274)) [@trxcllnt](https://github.com/trxcllnt)
+- `#ifdef` out pragma deprecation warning messages ([#2271](https://github.com/rapidsai/raft/pull/2271)) [@trxcllnt](https://github.com/trxcllnt)
+- Convert device_memory_resource* to device_async_resource_ref ([#2269](https://github.com/rapidsai/raft/pull/2269)) [@harrism](https://github.com/harrism)
+- Update the developer's guide with new copyright hook ([#2266](https://github.com/rapidsai/raft/pull/2266)) [@KyleFromNVIDIA](https://github.com/KyleFromNVIDIA)
+- Improve coalesced reduction performance for tall and thin matrices (up to 2.6x faster) ([#2259](https://github.com/rapidsai/raft/pull/2259)) [@Nyrio](https://github.com/Nyrio)
+- Adds missing files to `update-version.sh` ([#2255](https://github.com/rapidsai/raft/pull/2255)) [@AyodeAwe](https://github.com/AyodeAwe)
+- Enable all tests for `arm64` jobs ([#2248](https://github.com/rapidsai/raft/pull/2248)) [@galipremsagar](https://github.com/galipremsagar)
+- Update nvtx3 link in cmake ([#2246](https://github.com/rapidsai/raft/pull/2246)) [@lowener](https://github.com/lowener)
+- Add CAGRA-Q subspace dim = 4 support ([#2244](https://github.com/rapidsai/raft/pull/2244)) [@enp1s0](https://github.com/enp1s0)
+- Get rid of `cuco::sentinel` namespace ([#2243](https://github.com/rapidsai/raft/pull/2243)) [@PointKernel](https://github.com/PointKernel)
+- Replace usages of raw `get_upstream` with `get_upstream_resource()` ([#2207](https://github.com/rapidsai/raft/pull/2207)) [@miscco](https://github.com/miscco)
+- Set the import mode for dask tests ([#2142](https://github.com/rapidsai/raft/pull/2142)) [@vyasr](https://github.com/vyasr)
+- Add UCXX support ([#1983](https://github.com/rapidsai/raft/pull/1983)) [@pentschev](https://github.com/pentschev)
+
 # raft 24.04.00 (10 Apr 2024)
 
 ## 🐛 Bug Fixes
diff --git a/README.md b/README.md
index 7833a5cfa3..fc56859557 100755
--- a/README.md
+++ b/README.md
@@ -27,7 +27,7 @@
 - [RAFT Reference Documentation](https://docs.rapids.ai/api/raft/stable/): API Documentation.
 - [RAFT Getting Started](./docs/source/quick_start.md): Getting started with RAFT.
 - [Build and Install RAFT](./docs/source/build.md): Instructions for installing and building RAFT.
-- [Example Notebooks](./notebooks): Example jupyer notebooks
+- [Example Notebooks](./notebooks): Example jupyter notebooks
 - [RAPIDS Community](https://rapids.ai/community.html): Get help, contribute, and collaborate.
 - [GitHub repository](https://github.com/rapidsai/raft): Download the RAFT source code.
 - [Issue tracker](https://github.com/rapidsai/raft/issues): Report issues or request features.
@@ -293,7 +293,7 @@ You can also install the conda packages individually using the `mamba` command a
 mamba install -c rapidsai -c conda-forge -c nvidia libraft libraft-headers cuda-version=12.0
 ```
 
-If installing the C++ APIs please see [using libraft](https://docs.rapids.ai/api/raft/nightly/using_libraft/) for more information on using the pre-compiled shared library. You can also refer to the [example C++ template project](https://github.com/rapidsai/raft/tree/branch-24.04/cpp/template) for a ready-to-go CMake configuration that you can drop into your project and build against installed RAFT development artifacts above.
+If installing the C++ APIs please see [using libraft](https://docs.rapids.ai/api/raft/nightly/using_libraft/) for more information on using the pre-compiled shared library. You can also refer to the [example C++ template project](https://github.com/rapidsai/raft/tree/branch-24.06/cpp/template) for a ready-to-go CMake configuration that you can drop into your project and build against installed RAFT development artifacts above.
 
 ### Installing Python through Pip
@@ -354,10 +354,8 @@ If citing CAGRA, please consider the following bibtex:
 @misc{ootomo2023cagra,
     title={CAGRA: Highly Parallel Graph Construction and Approximate Nearest Neighbor Search for GPUs},
     author={Hiroyuki Ootomo and Akira Naruse and Corey Nolet and Ray Wang and Tamas Feher and Yong Wang},
-    year={2023},
-    eprint={2308.15136},
-    archivePrefix={arXiv},
-    primaryClass={cs.DS}
+    year={2024},
+    series = {ICDE '24}
 }
 ```
@@ -365,13 +363,14 @@ If citing the k-selection routines, please consider the following bibtex:
 ```bibtex
 @proceedings{10.1145/3581784,
-    title = {SC '23: Proceedings of the International Conference for High Performance Computing, Networking, Storage and Analysis},
+    title = {Parallel Top-K Algorithms on GPU: A Comprehensive Study and New Methods},
+    author = {Jingrong Zhang and Akira Naruse and Xipeng Li and Yong Wang},
     year = {2023},
     isbn = {9798400701092},
     publisher = {Association for Computing Machinery},
     address = {New York, NY, USA},
-    abstract = {Started in 1988, the SC Conference has become the annual nexus for researchers and practitioners from academia, industry and government to share information and foster collaborations to advance the state of the art in High Performance Computing (HPC), Networking, Storage, and Analysis.},
-    location = {, Denver, CO, USA, }
+    location = {Denver, CO, USA},
+    series = {SC '23}
 }
 ```
@@ -394,4 +393,4 @@ If citing the nearest neighbors descent API, please consider the following bibte
     location = {Virtual Event, Queensland, Australia},
     series = {CIKM '21}
 }
-```
\ No newline at end of file
+```
diff --git a/VERSION b/VERSION
index 4a2fe8aa57..0bff6981a3 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-24.04.00
+24.06.00
diff --git a/build.sh b/build.sh
index 45c7d1380f..148d23c9c1 100755
--- a/build.sh
+++ b/build.sh
@@ -305,7 +305,7 @@ if hasArg --allgpuarch; then
     BUILD_ALL_GPU_ARCH=1
 fi
 
-if hasArg --compile-lib || hasArg pylibraft || (( ${NUMARGS} == 0 )); then
+if hasArg --compile-lib || (( ${NUMARGS} == 0 )); then
     COMPILE_LIBRARY=ON
     CMAKE_TARGET="${CMAKE_TARGET};raft_lib"
 fi
@@ -405,7 +405,7 @@ fi
 ################################################################################
 # Configure for building all C++ targets
 
-if (( ${NUMARGS} == 0 )) || hasArg libraft || hasArg docs || hasArg tests || hasArg bench-prims || hasArg bench-ann || ((${COMPILE_LIBRARY} == ON )); then
+if (( ${NUMARGS} == 0 )) || hasArg libraft || hasArg docs || hasArg tests || hasArg bench-prims || hasArg bench-ann; then
     if (( ${BUILD_ALL_GPU_ARCH} == 0 )); then
         RAFT_CMAKE_CUDA_ARCHITECTURES="NATIVE"
         echo "Building for the architecture of the GPU in the system..."
diff --git a/ci/build_wheel.sh b/ci/build_wheel.sh index 5d06e46303..e3e7ce9c89 100755 --- a/ci/build_wheel.sh +++ b/ci/build_wheel.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. set -euo pipefail @@ -7,6 +7,10 @@ package_name=$1 package_dir=$2 underscore_package_name=$(echo "${package_name}" | tr "-" "_") +# Clear out system ucx files to ensure that we're getting ucx from the wheel. +rm -rf /usr/lib64/ucx +rm -rf /usr/lib64/libuc* + source rapids-configure-sccache source rapids-date-string @@ -38,9 +42,11 @@ fi if [[ ${package_name} == "raft-dask" ]]; then sed -r -i "s/pylibraft==(.*)\"/pylibraft${PACKAGE_CUDA_SUFFIX}==\1${alpha_spec}\"/g" ${pyproject_file} + sed -r -i "s/libucx(.*)\"/libucx${PACKAGE_CUDA_SUFFIX}\1${alpha_spec}\"/g" ${pyproject_file} sed -r -i "s/ucx-py==(.*)\"/ucx-py${PACKAGE_CUDA_SUFFIX}==\1${alpha_spec}\"/g" ${pyproject_file} sed -r -i "s/rapids-dask-dependency==(.*)\"/rapids-dask-dependency==\1${alpha_spec}\"/g" ${pyproject_file} sed -r -i "s/dask-cuda==(.*)\"/dask-cuda==\1${alpha_spec}\"/g" ${pyproject_file} + sed -r -i "s/distributed-ucxx==(.*)\"/distributed-ucxx${PACKAGE_CUDA_SUFFIX}==\1${alpha_spec}\"/g" ${pyproject_file} else sed -r -i "s/rmm(.*)\"/rmm${PACKAGE_CUDA_SUFFIX}\1${alpha_spec}\"/g" ${pyproject_file} fi @@ -56,6 +62,6 @@ cd "${package_dir}" python -m pip wheel . -w dist -vvv --no-deps --disable-pip-version-check mkdir -p final_dist -python -m auditwheel repair -w final_dist dist/* +python -m auditwheel repair -w final_dist --exclude "libucp.so.0" dist/* RAPIDS_PY_WHEEL_NAME="${underscore_package_name}_${RAPIDS_PY_CUDA_SUFFIX}" rapids-upload-wheels-to-s3 final_dist diff --git a/ci/build_wheel_pylibraft.sh b/ci/build_wheel_pylibraft.sh index ec30a28b92..895c311f46 100755 --- a/ci/build_wheel_pylibraft.sh +++ b/ci/build_wheel_pylibraft.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. set -euo pipefail diff --git a/ci/build_wheel_raft_dask.sh b/ci/build_wheel_raft_dask.sh index 5ae12303d0..feba2d7a5b 100755 --- a/ci/build_wheel_raft_dask.sh +++ b/ci/build_wheel_raft_dask.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. 
set -euo pipefail diff --git a/ci/release/update-version.sh b/ci/release/update-version.sh index 636f637d0c..9554a7dde8 100755 --- a/ci/release/update-version.sh +++ b/ci/release/update-version.sh @@ -37,6 +37,8 @@ function sed_runner() { } sed_runner "s/set(RAPIDS_VERSION .*)/set(RAPIDS_VERSION \"${NEXT_SHORT_TAG}\")/g" cpp/template/cmake/thirdparty/fetch_rapids.cmake +sed_runner 's/'"find_and_configure_ucxx(VERSION .*"'/'"find_and_configure_ucxx(VERSION ${NEXT_UCX_PY_SHORT_TAG_PEP440}"'/g' python/raft-dask/cmake/thirdparty/get_ucxx.cmake +sed_runner 's/'"branch-.*"'/'"branch-${NEXT_UCX_PY_SHORT_TAG_PEP440}"'/g' python/raft-dask/cmake/thirdparty/get_ucxx.cmake # Centralized version file update echo "${NEXT_FULL_TAG}" > VERSION @@ -50,7 +52,7 @@ DEPENDENCIES=( rmm-cu11 rmm-cu12 rapids-dask-dependency - # ucx-py is handled separately below + # ucx-py and ucxx are handled separately below ) for FILE in dependencies.yaml conda/environments/*.yaml; do for DEP in "${DEPENDENCIES[@]}"; do @@ -59,6 +61,10 @@ for FILE in dependencies.yaml conda/environments/*.yaml; do sed_runner "/-.* ucx-py==/ s/==.*/==${NEXT_UCX_PY_SHORT_TAG_PEP440}\.*/g" ${FILE}; sed_runner "/-.* ucx-py-cu11==/ s/==.*/==${NEXT_UCX_PY_SHORT_TAG_PEP440}\.*/g" ${FILE}; sed_runner "/-.* ucx-py-cu12==/ s/==.*/==${NEXT_UCX_PY_SHORT_TAG_PEP440}\.*/g" ${FILE}; + sed_runner "/-.* libucxx==/ s/==.*/==${NEXT_UCX_PY_SHORT_TAG_PEP440}\.*/g" ${FILE}; + sed_runner "/-.* distributed-ucxx==/ s/==.*/==${NEXT_UCX_PY_SHORT_TAG_PEP440}\.*/g" ${FILE}; + sed_runner "/-.* distributed-ucxx-cu11==/ s/==.*/==${NEXT_UCX_PY_SHORT_TAG_PEP440}\.*/g" ${FILE}; + sed_runner "/-.* distributed-ucxx-cu12==/ s/==.*/==${NEXT_UCX_PY_SHORT_TAG_PEP440}\.*/g" ${FILE}; done for FILE in python/*/pyproject.toml; do for DEP in "${DEPENDENCIES[@]}"; do @@ -68,6 +74,7 @@ for FILE in python/*/pyproject.toml; do done sed_runner "/^ucx_py_version:$/ {n;s/.*/ - \"${NEXT_UCX_PY_VERSION}\"/}" conda/recipes/raft-dask/conda_build_config.yaml +sed_runner "/^ucxx_version:$/ {n;s/.*/ - \"${NEXT_UCX_PY_VERSION}\"/}" conda/recipes/raft-dask/conda_build_config.yaml for FILE in .github/workflows/*.yaml; do sed_runner "/shared-workflows/ s/@.*/@branch-${NEXT_SHORT_TAG}/g" "${FILE}" @@ -85,5 +92,7 @@ sed_runner "s|branch-[0-9][0-9].[0-9][0-9]|branch-${NEXT_SHORT_TAG}|g" README.md find .devcontainer/ -type f -name devcontainer.json -print0 | while IFS= read -r -d '' filename; do sed_runner "s@rapidsai/devcontainers:[0-9.]*@rapidsai/devcontainers:${NEXT_SHORT_TAG}@g" "${filename}" sed_runner "s@rapidsai/devcontainers/features/ucx:[0-9.]*@rapidsai/devcontainers/features/ucx:${NEXT_SHORT_TAG_PEP440}@" "${filename}" + sed_runner "s@rapidsai/devcontainers/features/cuda:[0-9.]*@rapidsai/devcontainers/features/cuda:${NEXT_SHORT_TAG_PEP440}@" "${filename}" sed_runner "s@rapidsai/devcontainers/features/rapids-build-utils:[0-9.]*@rapidsai/devcontainers/features/rapids-build-utils:${NEXT_SHORT_TAG_PEP440}@" "${filename}" + sed_runner "s@rapids-\${localWorkspaceFolderBasename}-${CURRENT_SHORT_TAG}@rapids-\${localWorkspaceFolderBasename}-${NEXT_SHORT_TAG}@g" "${filename}" done diff --git a/ci/run_raft_dask_pytests.sh b/ci/run_raft_dask_pytests.sh index 46cd211d2e..07d0b5baa0 100755 --- a/ci/run_raft_dask_pytests.sh +++ b/ci/run_raft_dask_pytests.sh @@ -6,4 +6,4 @@ set -euo pipefail # Support invoking run_raft_dask_pytests.sh outside the script directory cd "$(dirname "$(realpath "${BASH_SOURCE[0]}")")"/../python/raft-dask/raft_dask -pytest --cache-clear "$@" test +pytest --cache-clear --import-mode=append 
"$@" test diff --git a/ci/test_python.sh b/ci/test_python.sh index f5b188ca0b..59da1f0bc4 100755 --- a/ci/test_python.sh +++ b/ci/test_python.sh @@ -59,5 +59,23 @@ rapids-logger "pytest raft-dask" --cov-report=xml:"${RAPIDS_COVERAGE_DIR}/raft-dask-coverage.xml" \ --cov-report=term +rapids-logger "pytest raft-dask (ucx-py only)" +./ci/run_raft_dask_pytests.sh \ + --junitxml="${RAPIDS_TESTS_DIR}/junit-raft-dask-ucx.xml" \ + --cov-config=../.coveragerc \ + --cov=raft_dask \ + --cov-report=xml:"${RAPIDS_COVERAGE_DIR}/raft-dask-ucx-coverage.xml" \ + --cov-report=term \ + --run_ucx + +rapids-logger "pytest raft-dask (ucxx only)" +./ci/run_raft_dask_pytests.sh \ + --junitxml="${RAPIDS_TESTS_DIR}/junit-raft-dask-ucxx.xml" \ + --cov-config=../.coveragerc \ + --cov=raft_dask \ + --cov-report=xml:"${RAPIDS_COVERAGE_DIR}/raft-dask-ucxx-coverage.xml" \ + --cov-report=term \ + --run_ucxx + rapids-logger "Test script exiting with value: $EXITCODE" exit ${EXITCODE} diff --git a/ci/test_wheel_pylibraft.sh b/ci/test_wheel_pylibraft.sh index d990a0e6c2..b38f5a690b 100755 --- a/ci/test_wheel_pylibraft.sh +++ b/ci/test_wheel_pylibraft.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. set -euo pipefail @@ -10,9 +10,4 @@ RAPIDS_PY_WHEEL_NAME="pylibraft_${RAPIDS_PY_CUDA_SUFFIX}" rapids-download-wheels # echo to expand wildcard before adding `[extra]` requires for pip python -m pip install $(echo ./dist/pylibraft*.whl)[test] -# Run smoke tests for aarch64 pull requests -if [[ "$(arch)" == "aarch64" && "${RAPIDS_BUILD_TYPE}" == "pull-request" ]]; then - python ./ci/wheel_smoke_test_pylibraft.py -else - python -m pytest ./python/pylibraft/pylibraft/test -fi +python -m pytest ./python/pylibraft/pylibraft/test diff --git a/ci/test_wheel_raft_dask.sh b/ci/test_wheel_raft_dask.sh index b70563b7a1..bd531e7e85 100755 --- a/ci/test_wheel_raft_dask.sh +++ b/ci/test_wheel_raft_dask.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. set -euo pipefail @@ -11,12 +11,15 @@ RAPIDS_PY_WHEEL_NAME="raft_dask_${RAPIDS_PY_CUDA_SUFFIX}" rapids-download-wheels RAPIDS_PY_WHEEL_NAME="pylibraft_${RAPIDS_PY_CUDA_SUFFIX}" rapids-download-wheels-from-s3 ./local-pylibraft-dep python -m pip install --no-deps ./local-pylibraft-dep/pylibraft*.whl -# echo to expand wildcard before adding `[extra]` requires for pip -python -m pip install $(echo ./dist/raft_dask*.whl)[test] +python -m pip install "raft_dask-${RAPIDS_PY_CUDA_SUFFIX}[test]>=0.0.0a0" --find-links dist/ -# Run smoke tests for aarch64 pull requests -if [[ "$(arch)" == "aarch64" && "${RAPIDS_BUILD_TYPE}" == "pull-request" ]]; then - python ./ci/wheel_smoke_test_raft_dask.py -else - python -m pytest ./python/raft-dask/raft_dask/test -fi +test_dir="python/raft-dask/raft_dask/test" + +rapids-logger "pytest raft-dask" +python -m pytest --import-mode=append ${test_dir} + +rapids-logger "pytest raft-dask (ucx-py only)" +python -m pytest --import-mode=append ${test_dir} --run_ucx + +rapids-logger "pytest raft-dask (ucxx only)" +python -m pytest --import-mode=append ${test_dir} --run_ucxx diff --git a/ci/wheel_smoke_test_pylibraft.py b/ci/wheel_smoke_test_pylibraft.py deleted file mode 100644 index c0df2fe45c..0000000000 --- a/ci/wheel_smoke_test_pylibraft.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import numpy as np -from scipy.spatial.distance import cdist - -from pylibraft.common import Handle, Stream, device_ndarray -from pylibraft.distance import pairwise_distance - - -if __name__ == "__main__": - metric = "euclidean" - n_rows = 1337 - n_cols = 1337 - - input1 = np.random.random_sample((n_rows, n_cols)) - input1 = np.asarray(input1, order="C").astype(np.float64) - - output = np.zeros((n_rows, n_rows), dtype=np.float64) - - expected = cdist(input1, input1, metric) - - expected[expected <= 1e-5] = 0.0 - - input1_device = device_ndarray(input1) - output_device = None - - s2 = Stream() - handle = Handle(stream=s2) - ret_output = pairwise_distance( - input1_device, input1_device, output_device, metric, handle=handle - ) - handle.sync() - - output_device = ret_output - - actual = output_device.copy_to_host() - - actual[actual <= 1e-5] = 0.0 - - assert np.allclose(expected, actual, rtol=1e-4) diff --git a/ci/wheel_smoke_test_raft_dask.py b/ci/wheel_smoke_test_raft_dask.py deleted file mode 100644 index 5709ac901c..0000000000 --- a/ci/wheel_smoke_test_raft_dask.py +++ /dev/null @@ -1,107 +0,0 @@ -# Copyright (c) 2019-2023, NVIDIA CORPORATION. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -from dask.distributed import Client, get_worker, wait -from dask_cuda import LocalCUDACluster, initialize - -from raft_dask.common import ( - Comms, - local_handle, - perform_test_comm_split, - perform_test_comms_allgather, - perform_test_comms_allreduce, - perform_test_comms_bcast, - perform_test_comms_device_multicast_sendrecv, - perform_test_comms_device_send_or_recv, - perform_test_comms_device_sendrecv, - perform_test_comms_gather, - perform_test_comms_gatherv, - perform_test_comms_reduce, - perform_test_comms_reducescatter, - perform_test_comms_send_recv, -) - -import os -os.environ["UCX_LOG_LEVEL"] = "error" - - -def func_test_send_recv(sessionId, n_trials): - handle = local_handle(sessionId, dask_worker=get_worker()) - return perform_test_comms_send_recv(handle, n_trials) - - -def func_test_collective(func, sessionId, root): - handle = local_handle(sessionId, dask_worker=get_worker()) - return func(handle, root) - - -if __name__ == "__main__": - # initial setup - cluster = LocalCUDACluster(protocol="tcp", scheduler_port=0) - client = Client(cluster) - - n_trials = 5 - root_location = "client" - - # p2p test for ucx - cb = Comms(comms_p2p=True, verbose=True) - cb.init() - - dfs = [ - client.submit( - func_test_send_recv, - cb.sessionId, - n_trials, - pure=False, - workers=[w], - ) - for w in cb.worker_addresses - ] - - wait(dfs, timeout=5) - - assert list(map(lambda x: x.result(), dfs)) - - cb.destroy() - - # collectives test for nccl - - cb = Comms( - verbose=True, client=client, nccl_root_location=root_location - ) - cb.init() - - for k, v in cb.worker_info(cb.worker_addresses).items(): - - dfs = [ - client.submit( - func_test_collective, - perform_test_comms_allgather, - cb.sessionId, - v["rank"], - pure=False, - workers=[w], - ) - for w in cb.worker_addresses - ] - wait(dfs, timeout=5) - - assert all([x.result() for x in dfs]) - - cb.destroy() - - # final client and cluster teardown - client.close() - cluster.close() diff --git a/conda/environments/all_cuda-118_arch-aarch64.yaml b/conda/environments/all_cuda-118_arch-aarch64.yaml index e27532a489..590c3eb68b 100644 --- a/conda/environments/all_cuda-118_arch-aarch64.yaml +++ b/conda/environments/all_cuda-118_arch-aarch64.yaml @@ -20,12 +20,11 @@ dependencies: - cupy>=12.0.0 - cxx-compiler - cython>=3.0.0 -- dask-cuda==24.4.* +- dask-cuda==24.6.* +- distributed-ucxx==0.38.* - doxygen>=1.8.20 - gcc_linux-aarch64=11.* -- gmock>=1.13.0 - graphviz -- gtest>=1.13.0 - ipython - joblib>=0.11 - libcublas-dev=11.11.3.6 @@ -36,6 +35,7 @@ dependencies: - libcusolver=11.4.1.48 - libcusparse-dev=11.7.5.86 - libcusparse=11.7.5.86 +- libucxx==0.38.* - nccl>=2.9.9 - ninja - numba>=0.57 @@ -46,16 +46,14 @@ dependencies: - pydata-sphinx-theme - pytest-cov - pytest==7.* -- rapids-dask-dependency==24.4.* +- rapids-dask-dependency==24.6.* - recommonmark -- rmm==24.4.* +- rmm==24.6.* - scikit-build-core>=0.7.0 - scikit-learn - scipy - sphinx-copybutton - sphinx-markdown-tables - sysroot_linux-aarch64==2.17 -- ucx-proc=*=gpu -- ucx-py==0.37.* -- ucx>=1.15.0,<1.16.0 +- ucx-py==0.38.* name: all_cuda-118_arch-aarch64 diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml index bf535c5c04..00ed8fa65e 100644 --- a/conda/environments/all_cuda-118_arch-x86_64.yaml +++ b/conda/environments/all_cuda-118_arch-x86_64.yaml @@ -20,12 +20,11 @@ dependencies: - cupy>=12.0.0 - cxx-compiler - cython>=3.0.0 -- dask-cuda==24.4.* +- dask-cuda==24.6.* +- distributed-ucxx==0.38.* - doxygen>=1.8.20 - 
gcc_linux-64=11.* -- gmock>=1.13.0 - graphviz -- gtest>=1.13.0 - ipython - joblib>=0.11 - libcublas-dev=11.11.3.6 @@ -36,6 +35,7 @@ dependencies: - libcusolver=11.4.1.48 - libcusparse-dev=11.7.5.86 - libcusparse=11.7.5.86 +- libucxx==0.38.* - nccl>=2.9.9 - ninja - numba>=0.57 @@ -46,16 +46,14 @@ dependencies: - pydata-sphinx-theme - pytest-cov - pytest==7.* -- rapids-dask-dependency==24.4.* +- rapids-dask-dependency==24.6.* - recommonmark -- rmm==24.4.* +- rmm==24.6.* - scikit-build-core>=0.7.0 - scikit-learn - scipy - sphinx-copybutton - sphinx-markdown-tables - sysroot_linux-64==2.17 -- ucx-proc=*=gpu -- ucx-py==0.37.* -- ucx>=1.15.0,<1.16.0 +- ucx-py==0.38.* name: all_cuda-118_arch-x86_64 diff --git a/conda/environments/all_cuda-122_arch-aarch64.yaml b/conda/environments/all_cuda-122_arch-aarch64.yaml index 8ea3843841..f1f346706d 100644 --- a/conda/environments/all_cuda-122_arch-aarch64.yaml +++ b/conda/environments/all_cuda-122_arch-aarch64.yaml @@ -21,18 +21,18 @@ dependencies: - cupy>=12.0.0 - cxx-compiler - cython>=3.0.0 -- dask-cuda==24.4.* +- dask-cuda==24.6.* +- distributed-ucxx==0.38.* - doxygen>=1.8.20 - gcc_linux-aarch64=11.* -- gmock>=1.13.0 - graphviz -- gtest>=1.13.0 - ipython - joblib>=0.11 - libcublas-dev - libcurand-dev - libcusolver-dev - libcusparse-dev +- libucxx==0.38.* - nccl>=2.9.9 - ninja - numba>=0.57 @@ -42,16 +42,14 @@ dependencies: - pydata-sphinx-theme - pytest-cov - pytest==7.* -- rapids-dask-dependency==24.4.* +- rapids-dask-dependency==24.6.* - recommonmark -- rmm==24.4.* +- rmm==24.6.* - scikit-build-core>=0.7.0 - scikit-learn - scipy - sphinx-copybutton - sphinx-markdown-tables - sysroot_linux-aarch64==2.17 -- ucx-proc=*=gpu -- ucx-py==0.37.* -- ucx>=1.15.0,<1.16.0 +- ucx-py==0.38.* name: all_cuda-122_arch-aarch64 diff --git a/conda/environments/all_cuda-122_arch-x86_64.yaml b/conda/environments/all_cuda-122_arch-x86_64.yaml index a3f6f7e99f..505a4f1a97 100644 --- a/conda/environments/all_cuda-122_arch-x86_64.yaml +++ b/conda/environments/all_cuda-122_arch-x86_64.yaml @@ -21,18 +21,18 @@ dependencies: - cupy>=12.0.0 - cxx-compiler - cython>=3.0.0 -- dask-cuda==24.4.* +- dask-cuda==24.6.* +- distributed-ucxx==0.38.* - doxygen>=1.8.20 - gcc_linux-64=11.* -- gmock>=1.13.0 - graphviz -- gtest>=1.13.0 - ipython - joblib>=0.11 - libcublas-dev - libcurand-dev - libcusolver-dev - libcusparse-dev +- libucxx==0.38.* - nccl>=2.9.9 - ninja - numba>=0.57 @@ -42,16 +42,14 @@ dependencies: - pydata-sphinx-theme - pytest-cov - pytest==7.* -- rapids-dask-dependency==24.4.* +- rapids-dask-dependency==24.6.* - recommonmark -- rmm==24.4.* +- rmm==24.6.* - scikit-build-core>=0.7.0 - scikit-learn - scipy - sphinx-copybutton - sphinx-markdown-tables - sysroot_linux-64==2.17 -- ucx-proc=*=gpu -- ucx-py==0.37.* -- ucx>=1.15.0,<1.16.0 +- ucx-py==0.38.* name: all_cuda-122_arch-x86_64 diff --git a/conda/environments/bench_ann_cuda-118_arch-aarch64.yaml b/conda/environments/bench_ann_cuda-118_arch-aarch64.yaml index 0e0385ceeb..7315f82c13 100644 --- a/conda/environments/bench_ann_cuda-118_arch-aarch64.yaml +++ b/conda/environments/bench_ann_cuda-118_arch-aarch64.yaml @@ -30,6 +30,7 @@ dependencies: - libcusolver=11.4.1.48 - libcusparse-dev=11.7.5.86 - libcusparse=11.7.5.86 +- libucxx==0.38.* - matplotlib - nccl>=2.9.9 - ninja @@ -38,7 +39,7 @@ dependencies: - openblas - pandas - pyyaml -- rmm==24.4.* +- rmm==24.6.* - scikit-build-core>=0.7.0 - sysroot_linux-aarch64==2.17 name: bench_ann_cuda-118_arch-aarch64 diff --git a/conda/environments/bench_ann_cuda-118_arch-x86_64.yaml 
b/conda/environments/bench_ann_cuda-118_arch-x86_64.yaml index dfe76a2948..ff973acc0c 100644 --- a/conda/environments/bench_ann_cuda-118_arch-x86_64.yaml +++ b/conda/environments/bench_ann_cuda-118_arch-x86_64.yaml @@ -30,6 +30,7 @@ dependencies: - libcusolver=11.4.1.48 - libcusparse-dev=11.7.5.86 - libcusparse=11.7.5.86 +- libucxx==0.38.* - matplotlib - nccl>=2.9.9 - ninja @@ -38,7 +39,7 @@ dependencies: - openblas - pandas - pyyaml -- rmm==24.4.* +- rmm==24.6.* - scikit-build-core>=0.7.0 - sysroot_linux-64==2.17 name: bench_ann_cuda-118_arch-x86_64 diff --git a/conda/environments/bench_ann_cuda-120_arch-aarch64.yaml b/conda/environments/bench_ann_cuda-120_arch-aarch64.yaml index 0a6567c646..056550fc07 100644 --- a/conda/environments/bench_ann_cuda-120_arch-aarch64.yaml +++ b/conda/environments/bench_ann_cuda-120_arch-aarch64.yaml @@ -27,6 +27,7 @@ dependencies: - libcurand-dev - libcusolver-dev - libcusparse-dev +- libucxx==0.38.* - matplotlib - nccl>=2.9.9 - ninja @@ -34,7 +35,7 @@ dependencies: - openblas - pandas - pyyaml -- rmm==24.4.* +- rmm==24.6.* - scikit-build-core>=0.7.0 - sysroot_linux-aarch64==2.17 name: bench_ann_cuda-120_arch-aarch64 diff --git a/conda/environments/bench_ann_cuda-120_arch-x86_64.yaml b/conda/environments/bench_ann_cuda-120_arch-x86_64.yaml index a89d5317b6..41a48f4a12 100644 --- a/conda/environments/bench_ann_cuda-120_arch-x86_64.yaml +++ b/conda/environments/bench_ann_cuda-120_arch-x86_64.yaml @@ -27,6 +27,7 @@ dependencies: - libcurand-dev - libcusolver-dev - libcusparse-dev +- libucxx==0.38.* - matplotlib - nccl>=2.9.9 - ninja @@ -34,7 +35,7 @@ dependencies: - openblas - pandas - pyyaml -- rmm==24.4.* +- rmm==24.6.* - scikit-build-core>=0.7.0 - sysroot_linux-64==2.17 name: bench_ann_cuda-120_arch-x86_64 diff --git a/conda/recipes/libraft/conda_build_config.yaml b/conda/recipes/libraft/conda_build_config.yaml index 9c39da4507..bb9c715e3a 100644 --- a/conda/recipes/libraft/conda_build_config.yaml +++ b/conda/recipes/libraft/conda_build_config.yaml @@ -10,7 +10,10 @@ cuda_compiler: cuda11_compiler: - nvcc -sysroot_version: +c_stdlib: + - sysroot + +c_stdlib_version: - "2.17" cmake_version: @@ -19,12 +22,6 @@ cmake_version: nccl_version: - ">=2.9.9" -gbench_version: - - "==1.8.0" - -gtest_version: - - ">=1.13.0" - glog_version: - ">=0.6.0" diff --git a/conda/recipes/libraft/meta.yaml b/conda/recipes/libraft/meta.yaml index 55f326dc53..a075308500 100644 --- a/conda/recipes/libraft/meta.yaml +++ b/conda/recipes/libraft/meta.yaml @@ -58,12 +58,13 @@ outputs: - cuda-version ={{ cuda_version }} - cmake {{ cmake_version }} - ninja - - sysroot_{{ target_platform }} {{ sysroot_version }} + - {{ stdlib("c") }} host: - cuda-version ={{ cuda_version }} {% if cuda_major != "11" %} - cuda-cudart-dev {% endif %} + - librmm ={{ minor_version }} run: - {{ pin_compatible('cuda-version', max_pin='x', min_pin='x') }} {% if cuda_major == "11" %} @@ -93,6 +94,7 @@ outputs: requirements: host: - cuda-version ={{ cuda_version }} + - librmm ={{ minor_version }} run: - {{ pin_subpackage('libraft-headers-only', exact=True) }} - librmm ={{ minor_version }} @@ -150,7 +152,7 @@ outputs: - cuda-version ={{ cuda_version }} - cmake {{ cmake_version }} - ninja - - sysroot_{{ target_platform }} {{ sysroot_version }} + - {{ stdlib("c") }} host: - {{ pin_subpackage('libraft-headers', exact=True) }} - cuda-version ={{ cuda_version }} @@ -212,7 +214,7 @@ outputs: - cuda-version ={{ cuda_version }} - cmake {{ cmake_version }} - ninja - - sysroot_{{ target_platform }} {{ sysroot_version }} + - {{ 
stdlib("c") }} host: - {{ pin_subpackage('libraft-headers', exact=True) }} - cuda-version ={{ cuda_version }} @@ -278,7 +280,7 @@ outputs: - cuda-version ={{ cuda_version }} - cmake {{ cmake_version }} - ninja - - sysroot_{{ target_platform }} {{ sysroot_version }} + - {{ stdlib("c") }} host: # We must include both libraft and libraft-static to prevent the test # builds from packaging those libraries. However, tests only depend on @@ -304,9 +306,6 @@ outputs: - libcusolver-dev - libcusparse-dev {% endif %} - - benchmark {{ gbench_version }} - - gmock {{ gtest_version }} - - gtest {{ gtest_version }} run: - {{ pin_compatible('cuda-version', max_pin='x', min_pin='x') }} {% if cuda_major == "11" %} @@ -319,9 +318,6 @@ outputs: - libcusparse {% endif %} - {{ pin_subpackage('libraft', exact=True) }} - - benchmark {{ gbench_version }} - - gmock {{ gtest_version }} - - gtest {{ gtest_version }} about: home: https://rapids.ai/ license: Apache-2.0 @@ -353,7 +349,7 @@ outputs: - cuda-version ={{ cuda_version }} - cmake {{ cmake_version }} - ninja - - sysroot_{{ target_platform }} {{ sysroot_version }} + - {{ stdlib("c") }} host: - {{ pin_subpackage('libraft', exact=True) }} - {{ pin_subpackage('libraft-headers', exact=True) }} diff --git a/conda/recipes/pylibraft/conda_build_config.yaml b/conda/recipes/pylibraft/conda_build_config.yaml index e28b98da7f..e3ca633eb9 100644 --- a/conda/recipes/pylibraft/conda_build_config.yaml +++ b/conda/recipes/pylibraft/conda_build_config.yaml @@ -10,7 +10,10 @@ cuda_compiler: cuda11_compiler: - nvcc -sysroot_version: +c_stdlib: + - sysroot + +c_stdlib_version: - "2.17" cmake_version: diff --git a/conda/recipes/pylibraft/meta.yaml b/conda/recipes/pylibraft/meta.yaml index e524a68f9e..cbeaec3b55 100644 --- a/conda/recipes/pylibraft/meta.yaml +++ b/conda/recipes/pylibraft/meta.yaml @@ -39,7 +39,7 @@ requirements: - cuda-version ={{ cuda_version }} - cmake {{ cmake_version }} - ninja - - sysroot_{{ target_platform }} {{ sysroot_version }} + - {{ stdlib("c") }} host: {% if cuda_major == "11" %} - cuda-python >=11.7.1,<12.0a0 diff --git a/conda/recipes/raft-ann-bench-cpu/conda_build_config.yaml b/conda/recipes/raft-ann-bench-cpu/conda_build_config.yaml index 93a5532962..4de3b98f48 100644 --- a/conda/recipes/raft-ann-bench-cpu/conda_build_config.yaml +++ b/conda/recipes/raft-ann-bench-cpu/conda_build_config.yaml @@ -4,7 +4,10 @@ c_compiler_version: cxx_compiler_version: - 11 -sysroot_version: +c_stdlib: + - sysroot + +c_stdlib_version: - "2.17" cmake_version: diff --git a/conda/recipes/raft-ann-bench-cpu/meta.yaml b/conda/recipes/raft-ann-bench-cpu/meta.yaml index fce85d5ffc..d0748fdb16 100644 --- a/conda/recipes/raft-ann-bench-cpu/meta.yaml +++ b/conda/recipes/raft-ann-bench-cpu/meta.yaml @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # Usage: # conda build . 
-c conda-forge -c nvidia -c rapidsai @@ -42,7 +42,7 @@ requirements: - {{ compiler('cxx') }} - cmake {{ cmake_version }} - ninja - - sysroot_{{ target_platform }} {{ sysroot_version }} + - {{ stdlib("c") }} host: - glog {{ glog_version }} diff --git a/conda/recipes/raft-ann-bench/conda_build_config.yaml b/conda/recipes/raft-ann-bench/conda_build_config.yaml index da0b893c1d..cf025a06a4 100644 --- a/conda/recipes/raft-ann-bench/conda_build_config.yaml +++ b/conda/recipes/raft-ann-bench/conda_build_config.yaml @@ -10,7 +10,10 @@ cuda_compiler: cuda11_compiler: - nvcc -sysroot_version: +c_stdlib: + - sysroot + +c_stdlib_version: - "2.17" cmake_version: @@ -19,9 +22,6 @@ cmake_version: nccl_version: - ">=2.9.9" -gtest_version: - - ">=1.13.0" - glog_version: - ">=0.6.0" diff --git a/conda/recipes/raft-ann-bench/meta.yaml b/conda/recipes/raft-ann-bench/meta.yaml index ec24501475..8a6a3d033d 100644 --- a/conda/recipes/raft-ann-bench/meta.yaml +++ b/conda/recipes/raft-ann-bench/meta.yaml @@ -57,7 +57,7 @@ requirements: - cuda-version ={{ cuda_version }} - cmake {{ cmake_version }} - ninja - - sysroot_{{ target_platform }} {{ sysroot_version }} + - {{ stdlib("c") }} host: - python diff --git a/conda/recipes/raft-dask/conda_build_config.yaml b/conda/recipes/raft-dask/conda_build_config.yaml index d2bdcbb351..b157e41753 100644 --- a/conda/recipes/raft-dask/conda_build_config.yaml +++ b/conda/recipes/raft-dask/conda_build_config.yaml @@ -10,14 +10,17 @@ cuda_compiler: cuda11_compiler: - nvcc -sysroot_version: - - "2.17" +c_stdlib: + - sysroot -ucx_version: - - ">=1.15.0,<1.16.0" +c_stdlib_version: + - "2.17" ucx_py_version: - - "0.37.*" + - "0.38.*" + +ucxx_version: + - "0.38.*" cmake_version: - ">=3.26.4" diff --git a/conda/recipes/raft-dask/meta.yaml b/conda/recipes/raft-dask/meta.yaml index 6910905d07..af22c8853e 100644 --- a/conda/recipes/raft-dask/meta.yaml +++ b/conda/recipes/raft-dask/meta.yaml @@ -39,7 +39,7 @@ requirements: - cuda-version ={{ cuda_version }} - cmake {{ cmake_version }} - ninja - - sysroot_{{ target_platform }} {{ sysroot_version }} + - {{ stdlib("c") }} host: {% if cuda_major == "11" %} - cuda-python >=11.7.1,<12.0a0 @@ -56,9 +56,8 @@ requirements: - rmm ={{ minor_version }} - scikit-build-core >=0.7.0 - setuptools - - ucx {{ ucx_version }} - - ucx-proc=*=gpu - ucx-py {{ ucx_py_version }} + - ucxx {{ ucxx_version }} run: {% if cuda_major == "11" %} - cudatoolkit @@ -73,9 +72,8 @@ requirements: - pylibraft {{ version }} - python x.x - rmm ={{ minor_version }} - - ucx {{ ucx_version }} - - ucx-proc=*=gpu - ucx-py {{ ucx_py_version }} + - distributed-ucxx {{ ucxx_version }} tests: requirements: diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index cbae4bfb3f..39472cae67 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -185,12 +185,13 @@ if(NOT BUILD_CPU_ONLY) endif() if(BUILD_TESTS) - include(cmake/thirdparty/get_gtest.cmake) + include(${rapids-cmake-dir}/cpm/gtest.cmake) + rapids_cpm_gtest(BUILD_STATIC) endif() if(BUILD_PRIMS_BENCH OR BUILD_ANN_BENCH) include(${rapids-cmake-dir}/cpm/gbench.cmake) - rapids_cpm_gbench() + rapids_cpm_gbench(BUILD_STATIC) endif() if(BUILD_CAGRA_HNSWLIB) @@ -274,7 +275,7 @@ else() "\" OFF)" [=[ -target_link_libraries(raft::raft INTERFACE $<$:CUDA::nvToolsExt>) +target_link_libraries(raft::raft INTERFACE $<$:CUDA::nvtx3>) target_compile_definitions(raft::raft INTERFACE $<$:NVTX_ENABLED>) ]=] @@ -564,7 +565,6 @@ if(RAFT_COMPILE_LIBRARY) src/spatial/knn/detail/fused_l2_knn_int32_t_float.cu 
src/spatial/knn/detail/fused_l2_knn_int64_t_float.cu src/spatial/knn/detail/fused_l2_knn_uint32_t_float.cu - src/util/memory_pool.cpp ) set_target_properties( raft_objs @@ -650,12 +650,21 @@ rapids_find_generate_module( INSTALL_EXPORT_SET raft-distributed-exports ) -rapids_export_package(BUILD ucx raft-distributed-exports) -rapids_export_package(INSTALL ucx raft-distributed-exports) +rapids_export_package( + BUILD ucxx raft-distributed-exports COMPONENTS ucxx python GLOBAL_TARGETS ucxx::ucxx ucxx::python +) +rapids_export_package( + INSTALL ucxx raft-distributed-exports COMPONENTS ucxx python GLOBAL_TARGETS ucxx::ucxx + ucxx::python +) rapids_export_package(BUILD NCCL raft-distributed-exports) rapids_export_package(INSTALL NCCL raft-distributed-exports) -target_link_libraries(raft_distributed INTERFACE ucx::ucp NCCL::NCCL) +# ucx is a requirement for raft_distributed, but its config is not safe to be found multiple times, +# so rather than exporting a package dependency on it above we rely on consumers to find it +# themselves. Once https://github.com/rapidsai/ucxx/issues/173 is resolved we can export it above +# again. +target_link_libraries(raft_distributed INTERFACE ucx::ucp ucxx::ucxx NCCL::NCCL) # ################################################################################################## # * install targets----------------------------------------------------------- @@ -816,26 +825,26 @@ rapids_export( # * shared test/bench headers ------------------------------------------------ if(BUILD_TESTS OR BUILD_PRIMS_BENCH) - include(internal/CMakeLists.txt) + add_subdirectory(internal) endif() # ################################################################################################## # * build test executable ---------------------------------------------------- if(BUILD_TESTS) - include(test/CMakeLists.txt) + add_subdirectory(test) endif() # ################################################################################################## # * build benchmark executable ----------------------------------------------- if(BUILD_PRIMS_BENCH) - include(bench/prims/CMakeLists.txt) + add_subdirectory(bench/prims/) endif() # ################################################################################################## # * build ann benchmark executable ----------------------------------------------- if(BUILD_ANN_BENCH) - include(bench/ann/CMakeLists.txt) + add_subdirectory(bench/ann/) endif() diff --git a/cpp/bench/ann/CMakeLists.txt b/cpp/bench/ann/CMakeLists.txt index ee84f7515a..f489cc62c6 100644 --- a/cpp/bench/ann/CMakeLists.txt +++ b/cpp/bench/ann/CMakeLists.txt @@ -12,6 +12,8 @@ # the License. 
# ============================================================================= +list(APPEND CMAKE_MODULE_PATH "${RAFT_SOURCE_DIR}") + # ################################################################################################## # * benchmark options ------------------------------------------------------------------------------ @@ -40,48 +42,26 @@ option(RAFT_ANN_BENCH_SINGLE_EXE find_package(Threads REQUIRED) +set(RAFT_ANN_BENCH_USE_FAISS ON) +set(RAFT_FAISS_ENABLE_GPU ON) +set(RAFT_USE_FAISS_STATIC ON) + if(BUILD_CPU_ONLY) # Include necessary logging dependencies - include(cmake/thirdparty/get_fmt.cmake) - include(cmake/thirdparty/get_spdlog.cmake) - + include(cmake/thirdparty/get_fmt) + include(cmake/thirdparty/get_spdlog) set(RAFT_FAISS_ENABLE_GPU OFF) - set(RAFT_ANN_BENCH_USE_FAISS_GPU_FLAT OFF) - set(RAFT_ANN_BENCH_USE_FAISS_GPU_IVF_FLAT OFF) - set(RAFT_ANN_BENCH_USE_FAISS_GPU_IVF_PQ OFF) set(RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT OFF) set(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ OFF) set(RAFT_ANN_BENCH_USE_RAFT_CAGRA OFF) set(RAFT_ANN_BENCH_USE_RAFT_BRUTE_FORCE OFF) set(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB OFF) set(RAFT_ANN_BENCH_USE_GGNN OFF) -else() +elseif(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0.0) # Disable faiss benchmarks on CUDA 12 since faiss is not yet CUDA 12-enabled. # https://github.com/rapidsai/raft/issues/1627 - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0.0) - set(RAFT_FAISS_ENABLE_GPU OFF) - set(RAFT_ANN_BENCH_USE_FAISS_GPU_FLAT OFF) - set(RAFT_ANN_BENCH_USE_FAISS_GPU_IVF_FLAT OFF) - set(RAFT_ANN_BENCH_USE_FAISS_GPU_IVF_PQ OFF) - set(RAFT_ANN_BENCH_USE_FAISS_CPU_FLAT OFF) - set(RAFT_ANN_BENCH_USE_FAISS_CPU_IVF_PQ OFF) - set(RAFT_ANN_BENCH_USE_FAISS_CPU_IVF_FLAT OFF) - else() - set(RAFT_FAISS_ENABLE_GPU ON) - endif() -endif() - -set(RAFT_ANN_BENCH_USE_FAISS OFF) -if(RAFT_ANN_BENCH_USE_FAISS_GPU_FLAT - OR RAFT_ANN_BENCH_USE_FAISS_GPU_IVF_PQ - OR RAFT_ANN_BENCH_USE_FAISS_GPU_IVF_FLAT - OR RAFT_ANN_BENCH_USE_FAISS_CPU_FLAT - OR RAFT_ANN_BENCH_USE_FAISS_CPU_IVF_PQ - OR RAFT_ANN_BENCH_USE_FAISS_CPU_IVF_FLAT -) - set(RAFT_ANN_BENCH_USE_FAISS ON) - set(RAFT_USE_FAISS_STATIC ON) + set(RAFT_FAISS_ENABLE_GPU OFF) endif() set(RAFT_ANN_BENCH_USE_RAFT OFF) @@ -98,21 +78,17 @@ endif() # * Fetch requirements ------------------------------------------------------------- if(RAFT_ANN_BENCH_USE_HNSWLIB OR RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) - include(cmake/thirdparty/get_hnswlib.cmake) + include(cmake/thirdparty/get_hnswlib) endif() -include(cmake/thirdparty/get_nlohmann_json.cmake) +include(cmake/thirdparty/get_nlohmann_json) if(RAFT_ANN_BENCH_USE_GGNN) - include(cmake/thirdparty/get_ggnn.cmake) + include(cmake/thirdparty/get_ggnn) endif() if(RAFT_ANN_BENCH_USE_FAISS) - # We need to ensure that faiss has all the conda information. 
So we currently use the very ugly - # hammer of `link_libraries` to ensure that all targets in this directory and the faiss directory - # will have the conda includes/link dirs - link_libraries($) - include(cmake/thirdparty/get_faiss.cmake) + include(cmake/thirdparty/get_faiss) endif() # ################################################################################################## @@ -173,8 +149,6 @@ function(ConfigureAnnBench) $<$:${RAFT_CTK_MATH_DEPENDENCIES}> $ $ - -static-libgcc - -static-libstdc++ $<$:fmt::fmt-header-only> $<$:spdlog::spdlog_header_only> ) @@ -225,7 +199,7 @@ endfunction() if(RAFT_ANN_BENCH_USE_HNSWLIB) ConfigureAnnBench( - NAME HNSWLIB PATH bench/ann/src/hnswlib/hnswlib_benchmark.cpp LINKS hnswlib::hnswlib + NAME HNSWLIB PATH src/hnswlib/hnswlib_benchmark.cpp LINKS hnswlib::hnswlib ) endif() @@ -235,8 +209,8 @@ if(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ) NAME RAFT_IVF_PQ PATH - bench/ann/src/raft/raft_benchmark.cu - $<$:bench/ann/src/raft/raft_ivf_pq.cu> + src/raft/raft_benchmark.cu + src/raft/raft_ivf_pq.cu LINKS raft::compiled ) @@ -247,8 +221,8 @@ if(RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT) NAME RAFT_IVF_FLAT PATH - bench/ann/src/raft/raft_benchmark.cu - $<$:bench/ann/src/raft/raft_ivf_flat.cu> + src/raft/raft_benchmark.cu + src/raft/raft_ivf_flat.cu LINKS raft::compiled ) @@ -256,7 +230,7 @@ endif() if(RAFT_ANN_BENCH_USE_RAFT_BRUTE_FORCE) ConfigureAnnBench( - NAME RAFT_BRUTE_FORCE PATH bench/ann/src/raft/raft_benchmark.cu LINKS raft::compiled + NAME RAFT_BRUTE_FORCE PATH src/raft/raft_benchmark.cu LINKS raft::compiled ) endif() @@ -265,8 +239,11 @@ if(RAFT_ANN_BENCH_USE_RAFT_CAGRA) NAME RAFT_CAGRA PATH - bench/ann/src/raft/raft_benchmark.cu - $<$:bench/ann/src/raft/raft_cagra.cu> + src/raft/raft_benchmark.cu + src/raft/raft_cagra_float.cu + src/raft/raft_cagra_half.cu + src/raft/raft_cagra_int8_t.cu + src/raft/raft_cagra_uint8_t.cu LINKS raft::compiled ) @@ -274,76 +251,63 @@ endif() if(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) ConfigureAnnBench( - NAME RAFT_CAGRA_HNSWLIB PATH bench/ann/src/raft/raft_cagra_hnswlib.cu LINKS raft::compiled + NAME RAFT_CAGRA_HNSWLIB PATH src/raft/raft_cagra_hnswlib.cu LINKS raft::compiled hnswlib::hnswlib ) endif() -set(RAFT_FAISS_TARGETS faiss::faiss) -if(TARGET faiss::faiss_avx2) - set(RAFT_FAISS_TARGETS faiss::faiss_avx2) -endif() - message("RAFT_FAISS_TARGETS: ${RAFT_FAISS_TARGETS}") message("CUDAToolkit_LIBRARY_DIR: ${CUDAToolkit_LIBRARY_DIR}") if(RAFT_ANN_BENCH_USE_FAISS_CPU_FLAT) ConfigureAnnBench( - NAME FAISS_CPU_FLAT PATH bench/ann/src/faiss/faiss_cpu_benchmark.cpp LINKS + NAME FAISS_CPU_FLAT PATH src/faiss/faiss_cpu_benchmark.cpp LINKS ${RAFT_FAISS_TARGETS} ) endif() if(RAFT_ANN_BENCH_USE_FAISS_CPU_IVF_FLAT) ConfigureAnnBench( - NAME FAISS_CPU_IVF_FLAT PATH bench/ann/src/faiss/faiss_cpu_benchmark.cpp LINKS + NAME FAISS_CPU_IVF_FLAT PATH src/faiss/faiss_cpu_benchmark.cpp LINKS ${RAFT_FAISS_TARGETS} ) endif() if(RAFT_ANN_BENCH_USE_FAISS_CPU_IVF_PQ) ConfigureAnnBench( - NAME FAISS_CPU_IVF_PQ PATH bench/ann/src/faiss/faiss_cpu_benchmark.cpp LINKS + NAME FAISS_CPU_IVF_PQ PATH src/faiss/faiss_cpu_benchmark.cpp LINKS ${RAFT_FAISS_TARGETS} ) endif() -if(RAFT_ANN_BENCH_USE_FAISS_GPU_IVF_FLAT) +if(RAFT_ANN_BENCH_USE_FAISS_GPU_IVF_FLAT AND RAFT_FAISS_ENABLE_GPU) ConfigureAnnBench( - NAME FAISS_GPU_IVF_FLAT PATH bench/ann/src/faiss/faiss_gpu_benchmark.cu LINKS + NAME FAISS_GPU_IVF_FLAT PATH src/faiss/faiss_gpu_benchmark.cu LINKS ${RAFT_FAISS_TARGETS} ) endif() -if(RAFT_ANN_BENCH_USE_FAISS_GPU_IVF_PQ) +if(RAFT_ANN_BENCH_USE_FAISS_GPU_IVF_PQ AND 
RAFT_FAISS_ENABLE_GPU) ConfigureAnnBench( - NAME FAISS_GPU_IVF_PQ PATH bench/ann/src/faiss/faiss_gpu_benchmark.cu LINKS + NAME FAISS_GPU_IVF_PQ PATH src/faiss/faiss_gpu_benchmark.cu LINKS ${RAFT_FAISS_TARGETS} ) endif() -if(RAFT_ANN_BENCH_USE_FAISS_GPU_FLAT) +if(RAFT_ANN_BENCH_USE_FAISS_GPU_FLAT AND RAFT_FAISS_ENABLE_GPU) ConfigureAnnBench( - NAME FAISS_GPU_FLAT PATH bench/ann/src/faiss/faiss_gpu_benchmark.cu LINKS ${RAFT_FAISS_TARGETS} + NAME FAISS_GPU_FLAT PATH src/faiss/faiss_gpu_benchmark.cu LINKS ${RAFT_FAISS_TARGETS} ) endif() if(RAFT_ANN_BENCH_USE_GGNN) - include(cmake/thirdparty/get_glog.cmake) - ConfigureAnnBench(NAME GGNN PATH bench/ann/src/ggnn/ggnn_benchmark.cu LINKS glog::glog ggnn::ggnn) + include(cmake/thirdparty/get_glog) + ConfigureAnnBench(NAME GGNN PATH src/ggnn/ggnn_benchmark.cu LINKS glog::glog ggnn::ggnn) endif() # ################################################################################################## # * Dynamically-loading ANN_BENCH executable ------------------------------------------------------- if(RAFT_ANN_BENCH_SINGLE_EXE) - add_executable(ANN_BENCH bench/ann/src/common/benchmark.cpp) - - # Build and link static version of the GBench to keep ANN_BENCH self-contained. - get_target_property(TMP_PROP benchmark::benchmark SOURCES) - add_library(benchmark_static STATIC ${TMP_PROP}) - get_target_property(TMP_PROP benchmark::benchmark INCLUDE_DIRECTORIES) - target_include_directories(benchmark_static PUBLIC ${TMP_PROP}) - get_target_property(TMP_PROP benchmark::benchmark LINK_LIBRARIES) - target_link_libraries(benchmark_static PUBLIC ${TMP_PROP}) + add_executable(ANN_BENCH src/common/benchmark.cpp) target_include_directories(ANN_BENCH PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) @@ -351,7 +315,7 @@ if(RAFT_ANN_BENCH_SINGLE_EXE) ANN_BENCH PRIVATE raft::raft nlohmann_json::nlohmann_json - benchmark_static + benchmark::benchmark dl -static-libgcc fmt::fmt-header-only diff --git a/cpp/bench/ann/src/common/ann_types.hpp b/cpp/bench/ann/src/common/ann_types.hpp index c6213059dc..b010063dee 100644 --- a/cpp/bench/ann/src/common/ann_types.hpp +++ b/cpp/bench/ann/src/common/ann_types.hpp @@ -73,6 +73,8 @@ struct AlgoProperty { class AnnBase { public: + using index_type = size_t; + inline AnnBase(Metric metric, int dim) : metric_(metric), dim_(dim) {} virtual ~AnnBase() noexcept = default; @@ -98,7 +100,16 @@ class AnnGPU { * end. */ [[nodiscard]] virtual auto get_sync_stream() const noexcept -> cudaStream_t = 0; - virtual ~AnnGPU() noexcept = default; + /** + * By default a GPU algorithm uses a fixed stream to order GPU operations. + * However, an algorithm may need to synchronize with the host at the end of its execution. + * In that case, also synchronizing with a benchmark event would put it at disadvantage. + * + * We can disable event sync by passing `false` here + * - ONLY IF THE ALGORITHM HAS PRODUCED ITS OUTPUT BY THE TIME IT SYNCHRONIZES WITH CPU. + */ + [[nodiscard]] virtual auto uses_stream() const noexcept -> bool { return true; } + virtual ~AnnGPU() noexcept = default; }; template @@ -118,8 +129,11 @@ class ANN : public AnnBase { virtual void set_search_param(const AnnSearchParam& param) = 0; // TODO: this assumes that an algorithm can always return k results. // This is not always possible. 
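The uses_stream() hook introduced above lets a wrapper opt out of the benchmark's event-based stream synchronization. A minimal sketch of how an implementation might use it (hypothetical wrapper, other required ANN/AnnGPU members omitted); as the comment above stresses, this is only safe when search() has already delivered its output to the caller by the time it returns:

// Sketch only: a GPU algorithm that synchronizes with the host inside search().
template <typename T>
class HostSyncedAlgo : public ANN<T>, public AnnGPU {
 public:
  [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override { return stream_; }
  // Results are already visible to the host when search() returns,
  // so the extra benchmark event sync can be skipped.
  [[nodiscard]] auto uses_stream() const noexcept -> bool override { return false; }
  // ... build(), search(), save(), load(), etc. omitted ...
 private:
  cudaStream_t stream_{nullptr};
};
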
- virtual void search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const = 0; + virtual void search(const T* queries, + int batch_size, + int k, + AnnBase::index_type* neighbors, + float* distances) const = 0; virtual void save(const std::string& file) const = 0; virtual void load(const std::string& file) = 0; diff --git a/cpp/bench/ann/src/common/benchmark.hpp b/cpp/bench/ann/src/common/benchmark.hpp index d7bcd17a00..8762ccd1fe 100644 --- a/cpp/bench/ann/src/common/benchmark.hpp +++ b/cpp/bench/ann/src/common/benchmark.hpp @@ -280,10 +280,16 @@ void bench_search(::benchmark::State& state, /** * Each thread will manage its own outputs */ - std::shared_ptr> distances = - std::make_shared>(current_algo_props->query_memory_type, k * query_set_size); - std::shared_ptr> neighbors = - std::make_shared>(current_algo_props->query_memory_type, k * query_set_size); + using index_type = AnnBase::index_type; + constexpr size_t kAlignResultBuf = 64; + size_t result_elem_count = k * query_set_size; + result_elem_count = + ((result_elem_count + kAlignResultBuf - 1) / kAlignResultBuf) * kAlignResultBuf; + auto& result_buf = + get_result_buffer_from_global_pool(result_elem_count * (sizeof(float) + sizeof(index_type))); + auto* neighbors_ptr = + reinterpret_cast(result_buf.data(current_algo_props->query_memory_type)); + auto* distances_ptr = reinterpret_cast(neighbors_ptr + result_elem_count); { nvtx_case nvtx{state.name()}; @@ -305,8 +311,8 @@ void bench_search(::benchmark::State& state, algo->search(query_set + batch_offset * dataset->dim(), n_queries, k, - neighbors->data + out_offset * k, - distances->data + out_offset * k); + neighbors_ptr + out_offset * k, + distances_ptr + out_offset * k); } catch (const std::exception& e) { state.SkipWithError("Benchmark loop: " + std::string(e.what())); break; @@ -338,12 +344,13 @@ void bench_search(::benchmark::State& state, // Each thread calculates recall on their partition of queries. // evaluate recall if (dataset->max_k() >= k) { - const std::int32_t* gt = dataset->gt_set(); - const std::uint32_t max_k = dataset->max_k(); - buf neighbors_host = neighbors->move(MemoryType::Host); - std::size_t rows = std::min(queries_processed, query_set_size); - std::size_t match_count = 0; - std::size_t total_count = rows * static_cast(k); + const std::int32_t* gt = dataset->gt_set(); + const std::uint32_t max_k = dataset->max_k(); + result_buf.transfer_data(MemoryType::Host, current_algo_props->query_memory_type); + auto* neighbors_host = reinterpret_cast(result_buf.data(MemoryType::Host)); + std::size_t rows = std::min(queries_processed, query_set_size); + std::size_t match_count = 0; + std::size_t total_count = rows * static_cast(k); // We go through the groundtruth with same stride as the benchmark loop. size_t out_offset = 0; @@ -354,7 +361,7 @@ void bench_search(::benchmark::State& state, size_t i_out_idx = out_offset + i; if (i_out_idx < rows) { for (std::uint32_t j = 0; j < k; j++) { - auto act_idx = std::int32_t(neighbors_host.data[i_out_idx * k + j]); + auto act_idx = std::int32_t(neighbors_host[i_out_idx * k + j]); for (std::uint32_t l = 0; l < k; l++) { auto exp_idx = gt[i_orig_idx * max_k + l]; if (act_idx == exp_idx) { @@ -717,7 +724,7 @@ inline auto run_main(int argc, char** argv) -> int // to a shared library it depends on (dynamic benchmark executable). 
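To make the shared result-buffer sizing above concrete, the same round-up computed for one illustrative configuration (assuming a 64-bit build, where AnnBase::index_type is 8 bytes):

// k = 10 neighbors, query_set_size = 1000 queries
size_t result_elem_count = 10 * 1000;                      // 10000 elements
// round up to a multiple of kAlignResultBuf = 64:
result_elem_count = ((10000 + 63) / 64) * 64;              // = 157 * 64 = 10048
// one pooled buffer holds both outputs: indices first, then distances
size_t bytes = 10048 * (sizeof(size_t) + sizeof(float));   // 10048 * 12 = 120576 bytes
// neighbors_ptr starts at element 0; distances_ptr starts at element 10048
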
current_algo.reset(); current_algo_props.reset(); - reset_global_stream_pool(); + reset_global_device_resources(); return 0; } }; // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/common/util.hpp b/cpp/bench/ann/src/common/util.hpp index 6cdff316e9..96185c79eb 100644 --- a/cpp/bench/ann/src/common/util.hpp +++ b/cpp/bench/ann/src/common/util.hpp @@ -56,57 +56,6 @@ inline thread_local int benchmark_thread_id = 0; */ inline thread_local int benchmark_n_threads = 1; -template -struct buf { - MemoryType memory_type; - std::size_t size; - T* data; - buf(MemoryType memory_type, std::size_t size) - : memory_type(memory_type), size(size), data(nullptr) - { - switch (memory_type) { -#ifndef BUILD_CPU_ONLY - case MemoryType::Device: { - cudaMalloc(reinterpret_cast(&data), size * sizeof(T)); - cudaMemset(data, 0, size * sizeof(T)); - } break; -#endif - default: { - data = reinterpret_cast(malloc(size * sizeof(T))); - std::memset(data, 0, size * sizeof(T)); - } - } - } - ~buf() noexcept - { - if (data == nullptr) { return; } - switch (memory_type) { -#ifndef BUILD_CPU_ONLY - case MemoryType::Device: { - cudaFree(data); - } break; -#endif - default: { - free(data); - } - } - } - - [[nodiscard]] auto move(MemoryType target_memory_type) -> buf - { - buf r{target_memory_type, size}; -#ifndef BUILD_CPU_ONLY - if ((memory_type == MemoryType::Device && target_memory_type != MemoryType::Device) || - (memory_type != MemoryType::Device && target_memory_type == MemoryType::Device)) { - cudaMemcpy(r.data, data, size * sizeof(T), cudaMemcpyDefault); - return r; - } -#endif - std::swap(data, r.data); - return r; - } -}; - struct cuda_timer { private: std::optional stream_; @@ -118,7 +67,9 @@ struct cuda_timer { static inline auto extract_stream(AnnT* algo) -> std::optional { auto gpu_ann = dynamic_cast(algo); - if (gpu_ann != nullptr) { return std::make_optional(gpu_ann->get_sync_stream()); } + if (gpu_ann != nullptr && gpu_ann->uses_stream()) { + return std::make_optional(gpu_ann->get_sync_stream()); + } return std::nullopt; } @@ -242,16 +193,102 @@ inline auto get_stream_from_global_pool() -> cudaStream_t #endif } +struct result_buffer { + explicit result_buffer(size_t size, cudaStream_t stream) : size_{size}, stream_{stream} + { + if (size_ == 0) { return; } + data_host_ = malloc(size_); +#ifndef BUILD_CPU_ONLY + cudaMallocAsync(&data_device_, size_, stream_); + cudaStreamSynchronize(stream_); +#endif + } + result_buffer() = delete; + result_buffer(result_buffer&&) = delete; + result_buffer& operator=(result_buffer&&) = delete; + result_buffer(const result_buffer&) = delete; + result_buffer& operator=(const result_buffer&) = delete; + ~result_buffer() noexcept + { + if (size_ == 0) { return; } +#ifndef BUILD_CPU_ONLY + cudaFreeAsync(data_device_, stream_); + cudaStreamSynchronize(stream_); +#endif + free(data_host_); + } + + [[nodiscard]] auto size() const noexcept { return size_; } + [[nodiscard]] auto data(ann::MemoryType loc) const noexcept + { + switch (loc) { + case MemoryType::Device: return data_device_; + default: return data_host_; + } + } + + void transfer_data(ann::MemoryType dst, ann::MemoryType src) + { + auto dst_ptr = data(dst); + auto src_ptr = data(src); + if (dst_ptr == src_ptr) { return; } +#ifndef BUILD_CPU_ONLY + cudaMemcpyAsync(dst_ptr, src_ptr, size_, cudaMemcpyDefault, stream_); + cudaStreamSynchronize(stream_); +#endif + } + + private: + size_t size_{0}; + cudaStream_t stream_ = nullptr; + void* data_host_ = nullptr; + void* data_device_ = nullptr; +}; + +namespace detail { 
+inline std::vector> global_result_buffer_pool(0); +inline std::mutex grp_mutex; +} // namespace detail + +/** + * Get a result buffer associated with the current benchmark thread. + * + * Note, the allocations are reused between the benchmark cases. + * This reduces the setup overhead and number of times the context is being blocked + * (this is relevant if there is a persistent kernel running across multiples benchmark cases). + */ +inline auto get_result_buffer_from_global_pool(size_t size) -> result_buffer& +{ + auto stream = get_stream_from_global_pool(); + auto& rb = [stream, size]() -> result_buffer& { + std::lock_guard guard(detail::grp_mutex); + if (static_cast(detail::global_result_buffer_pool.size()) < benchmark_n_threads) { + detail::global_result_buffer_pool.resize(benchmark_n_threads); + } + auto& rb = detail::global_result_buffer_pool[benchmark_thread_id]; + if (!rb || rb->size() < size) { rb = std::make_unique(size, stream); } + return *rb; + }(); + + memset(rb.data(MemoryType::Host), 0, size); +#ifndef BUILD_CPU_ONLY + cudaMemsetAsync(rb.data(MemoryType::Device), 0, size, stream); + cudaStreamSynchronize(stream); +#endif + return rb; +} + /** - * Delete all streams in the global pool. + * Delete all streams and memory allocations in the global pool. * It's called at the end of the `main` function - before global/static variables and cuda context * is destroyed - to make sure they are destroyed gracefully and correctly seen by analysis tools * such as nsys. */ -inline void reset_global_stream_pool() +inline void reset_global_device_resources() { #ifndef BUILD_CPU_ONLY std::lock_guard guard(detail::gsp_mutex); + detail::global_result_buffer_pool.resize(0); detail::global_stream_pool.resize(0); #endif } diff --git a/cpp/bench/ann/src/faiss/faiss_cpu_wrapper.h b/cpp/bench/ann/src/faiss/faiss_cpu_wrapper.h index 407f7148df..3caca15b7f 100644 --- a/cpp/bench/ann/src/faiss/faiss_cpu_wrapper.h +++ b/cpp/bench/ann/src/faiss/faiss_cpu_wrapper.h @@ -88,8 +88,11 @@ class FaissCpu : public ANN { // TODO: if the number of results is less than k, the remaining elements of 'neighbors' // will be filled with (size_t)-1 - void search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const final; + void search(const T* queries, + int batch_size, + int k, + AnnBase::index_type* neighbors, + float* distances) const final; AlgoProperty get_preference() const override { @@ -169,7 +172,7 @@ void FaissCpu::set_search_param(const AnnSearchParam& param) template void FaissCpu::search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const + const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const { static_assert(sizeof(size_t) == sizeof(faiss::idx_t), "sizes of size_t and faiss::idx_t are different"); diff --git a/cpp/bench/ann/src/faiss/faiss_gpu_wrapper.h b/cpp/bench/ann/src/faiss/faiss_gpu_wrapper.h index 633098fd1d..2effe631e5 100644 --- a/cpp/bench/ann/src/faiss/faiss_gpu_wrapper.h +++ b/cpp/bench/ann/src/faiss/faiss_gpu_wrapper.h @@ -111,8 +111,11 @@ class FaissGpu : public ANN, public AnnGPU { // TODO: if the number of results is less than k, the remaining elements of 'neighbors' // will be filled with (size_t)-1 - void search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const final; + void search(const T* queries, + int batch_size, + int k, + AnnBase::index_type* neighbors, + float* distances) const final; [[nodiscard]] auto get_sync_stream() const 
noexcept -> cudaStream_t override { @@ -196,7 +199,7 @@ void FaissGpu::build(const T* dataset, size_t nrow) template void FaissGpu::search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const + const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const { static_assert(sizeof(size_t) == sizeof(faiss::idx_t), "sizes of size_t and faiss::idx_t are different"); diff --git a/cpp/bench/ann/src/ggnn/ggnn_wrapper.cuh b/cpp/bench/ann/src/ggnn/ggnn_wrapper.cuh index c89f02d974..59cf3df806 100644 --- a/cpp/bench/ann/src/ggnn/ggnn_wrapper.cuh +++ b/cpp/bench/ann/src/ggnn/ggnn_wrapper.cuh @@ -58,8 +58,11 @@ class Ggnn : public ANN, public AnnGPU { void build(const T* dataset, size_t nrow) override { impl_->build(dataset, nrow); } void set_search_param(const AnnSearchParam& param) override { impl_->set_search_param(param); } - void search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override + void search(const T* queries, + int batch_size, + int k, + AnnBase::index_type* neighbors, + float* distances) const override { impl_->search(queries, batch_size, k, neighbors, distances); } @@ -123,8 +126,11 @@ class GgnnImpl : public ANN, public AnnGPU { void build(const T* dataset, size_t nrow) override; void set_search_param(const AnnSearchParam& param) override; - void search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override; + void search(const T* queries, + int batch_size, + int k, + AnnBase::index_type* neighbors, + float* distances) const override; [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override { return stream_; } void save(const std::string& file) const override; @@ -243,7 +249,7 @@ void GgnnImpl::set_search_param(const AnnSearc template void GgnnImpl::search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const + const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const { static_assert(sizeof(size_t) == sizeof(int64_t), "sizes of size_t and GGNN's KeyT are different"); if (k != KQuery) { diff --git a/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h b/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h index a8f7dd824f..5743632bf4 100644 --- a/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h +++ b/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h @@ -79,8 +79,11 @@ class HnswLib : public ANN { void build(const T* dataset, size_t nrow) override; void set_search_param(const AnnSearchParam& param) override; - void search( - const T* query, int batch_size, int k, size_t* indices, float* distances) const override; + void search(const T* query, + int batch_size, + int k, + AnnBase::index_type* indices, + float* distances) const override; void save(const std::string& path_to_index) const override; void load(const std::string& path_to_index) override; @@ -97,7 +100,10 @@ class HnswLib : public ANN { void set_base_layer_only() { appr_alg_->base_layer_only = true; } private: - void get_search_knn_results_(const T* query, int k, size_t* indices, float* distances) const; + void get_search_knn_results_(const T* query, + int k, + AnnBase::index_type* indices, + float* distances) const; std::shared_ptr::type>> appr_alg_; std::shared_ptr::type>> space_; @@ -176,7 +182,7 @@ void HnswLib::set_search_param(const AnnSearchParam& param_) template void HnswLib::search( - const T* query, int batch_size, int k, size_t* indices, float* distances) const + const T* query, int batch_size, int k, 
AnnBase::index_type* indices, float* distances) const { auto f = [&](int i) { // hnsw can only handle a single vector at a time. @@ -217,7 +223,7 @@ void HnswLib::load(const std::string& path_to_index) template void HnswLib::get_search_knn_results_(const T* query, int k, - size_t* indices, + AnnBase::index_type* indices, float* distances) const { auto result = appr_alg_->searchKnn(query, k); diff --git a/cpp/bench/ann/src/raft/raft_ann_bench_utils.h b/cpp/bench/ann/src/raft/raft_ann_bench_utils.h index 40c1ecfa5e..9b086fdb23 100644 --- a/cpp/bench/ann/src/raft/raft_ann_bench_utils.h +++ b/cpp/bench/ann/src/raft/raft_ann_bench_utils.h @@ -19,14 +19,19 @@ #include #include +#include +#include #include #include #include +#include #include #include #include +#include #include +#include #include #include @@ -70,13 +75,14 @@ inline auto rmm_oom_callback(std::size_t bytes, void*) -> bool */ class shared_raft_resources { public: - using pool_mr_type = rmm::mr::pool_memory_resource; - using mr_type = rmm::mr::failure_callback_resource_adaptor; + using pool_mr_type = rmm::mr::pool_memory_resource; + using mr_type = rmm::mr::failure_callback_resource_adaptor; + using large_mr_type = rmm::mr::managed_memory_resource; shared_raft_resources() try : orig_resource_{rmm::mr::get_current_device_resource()}, pool_resource_(orig_resource_, 1024 * 1024 * 1024ull), - resource_(&pool_resource_, rmm_oom_callback, nullptr) { + resource_(&pool_resource_, rmm_oom_callback, nullptr), large_mr_() { rmm::mr::set_current_device_resource(&resource_); } catch (const std::exception& e) { auto cuda_status = cudaGetLastError(); @@ -99,10 +105,16 @@ class shared_raft_resources { ~shared_raft_resources() noexcept { rmm::mr::set_current_device_resource(orig_resource_); } + auto get_large_memory_resource() noexcept + { + return static_cast(&large_mr_); + } + private: rmm::mr::device_memory_resource* orig_resource_; pool_mr_type pool_resource_; mr_type resource_; + large_mr_type large_mr_; }; /** @@ -121,8 +133,16 @@ class configured_raft_resources { * It's used by the copy constructor. */ explicit configured_raft_resources(const std::shared_ptr& shared_res) - : shared_res_{shared_res}, res_{rmm::cuda_stream_view(get_stream_from_global_pool())} + : shared_res_{shared_res}, + res_{std::make_unique( + rmm::cuda_stream_view(get_stream_from_global_pool()))} { + // set the large workspace resource to the raft handle, but without the deleter + // (this resource is managed by the shared_res). + raft::resource::set_large_workspace_resource( + *res_, + std::shared_ptr(shared_res_->get_large_memory_resource(), + raft::void_op{})); } /** Default constructor creates all resources anew. 
*/ @@ -130,9 +150,9 @@ class configured_raft_resources { { } - configured_raft_resources(configured_raft_resources&&) = default; - configured_raft_resources& operator=(configured_raft_resources&&) = default; - ~configured_raft_resources() = default; + configured_raft_resources(configured_raft_resources&&); + configured_raft_resources& operator=(configured_raft_resources&&); + ~configured_raft_resources() = default; configured_raft_resources(const configured_raft_resources& res) : configured_raft_resources{res.shared_res_} { @@ -143,11 +163,11 @@ class configured_raft_resources { return *this; } - operator raft::resources&() noexcept { return res_; } - operator const raft::resources&() const noexcept { return res_; } + operator raft::resources&() noexcept { return *res_; } + operator const raft::resources&() const noexcept { return *res_; } /** Get the main stream */ - [[nodiscard]] auto get_sync_stream() const noexcept { return resource::get_cuda_stream(res_); } + [[nodiscard]] auto get_sync_stream() const noexcept { return resource::get_cuda_stream(*res_); } private: /** The resources shared among multiple raft handles / threads. */ @@ -156,7 +176,80 @@ class configured_raft_resources { * Until we make the use of copies of raft::resources thread-safe, each benchmark wrapper must * have its own copy of it. */ - raft::device_resources res_; + std::unique_ptr res_ = std::make_unique(); }; +inline configured_raft_resources::configured_raft_resources(configured_raft_resources&&) = default; +inline configured_raft_resources& configured_raft_resources::operator=( + configured_raft_resources&&) = default; + +/** A helper to refine the neighbors when the data is on device or on host. */ +template +void refine_helper(const raft::resources& res, + DatasetT dataset, + QueriesT queries, + CandidatesT candidates, + int k, + AnnBase::index_type* neighbors, + float* distances, + raft::distance::DistanceType metric) +{ + using data_type = typename DatasetT::value_type; + using index_type = AnnBase::index_type; + using extents_type = index_type; // device-side refine requires this + + static_assert(std::is_same_v); + static_assert(std::is_same_v); + static_assert(std::is_same_v); + + extents_type batch_size = queries.extent(0); + extents_type dim = queries.extent(1); + extents_type k0 = candidates.extent(1); + + if (raft::get_device_for_address(dataset.data_handle()) >= 0) { + auto dataset_device = raft::make_device_matrix_view( + dataset.data_handle(), dataset.extent(0), dataset.extent(1)); + auto queries_device = raft::make_device_matrix_view( + queries.data_handle(), batch_size, dim); + auto candidates_device = raft::make_device_matrix_view( + candidates.data_handle(), batch_size, k0); + auto neighbors_device = + raft::make_device_matrix_view(neighbors, batch_size, k); + auto distances_device = + raft::make_device_matrix_view(distances, batch_size, k); + + raft::neighbors::refine(res, + dataset_device, + queries_device, + candidates_device, + neighbors_device, + distances_device, + metric); + } else { + auto dataset_host = raft::make_host_matrix_view( + dataset.data_handle(), dataset.extent(0), dataset.extent(1)); + auto queries_host = raft::make_host_matrix(batch_size, dim); + auto candidates_host = raft::make_host_matrix(batch_size, k0); + auto neighbors_host = raft::make_host_matrix(batch_size, k); + auto distances_host = raft::make_host_matrix(batch_size, k); + + auto stream = resource::get_cuda_stream(res); + raft::copy(queries_host.data_handle(), queries.data_handle(), queries_host.size(), 
stream); + raft::copy( + candidates_host.data_handle(), candidates.data_handle(), candidates_host.size(), stream); + + raft::resource::sync_stream(res); // wait for the queries and candidates + raft::neighbors::refine(res, + dataset_host, + queries_host.view(), + candidates_host.view(), + neighbors_host.view(), + distances_host.view(), + metric); + + raft::copy(neighbors, neighbors_host.data_handle(), neighbors_host.size(), stream); + raft::copy(distances, distances_host.data_handle(), distances_host.size(), stream); + } +} + } // namespace raft::bench::ann diff --git a/cpp/include/raft/util/memory_pool-ext.hpp b/cpp/bench/ann/src/raft/raft_cagra_float.cu similarity index 63% rename from cpp/include/raft/util/memory_pool-ext.hpp rename to cpp/bench/ann/src/raft/raft_cagra_float.cu index 030a9c681e..058f5bf34a 100644 --- a/cpp/include/raft/util/memory_pool-ext.hpp +++ b/cpp/bench/ann/src/raft/raft_cagra_float.cu @@ -13,16 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include "raft_cagra_wrapper.h" -#pragma once -#include // rmm::mr::device_memory_resource - -#include // size_t -#include // std::unique_ptr - -namespace raft { - -std::unique_ptr get_pool_memory_resource( - rmm::mr::device_memory_resource*& mr, size_t initial_size); - -} // namespace raft +namespace raft::bench::ann { +template class RaftCagra; +} // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_cagra_half.cu b/cpp/bench/ann/src/raft/raft_cagra_half.cu new file mode 100644 index 0000000000..a015819ec5 --- /dev/null +++ b/cpp/bench/ann/src/raft/raft_cagra_half.cu @@ -0,0 +1,20 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "raft_cagra_wrapper.h" + +namespace raft::bench::ann { +template class RaftCagra; +} // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_cagra_hnswlib.cu b/cpp/bench/ann/src/raft/raft_cagra_hnswlib.cu index 709b08db76..d9ef1d74a3 100644 --- a/cpp/bench/ann/src/raft/raft_cagra_hnswlib.cu +++ b/cpp/bench/ann/src/raft/raft_cagra_hnswlib.cu @@ -20,6 +20,7 @@ #include #include +#include #define JSON_DIAGNOSTICS 1 #include @@ -89,10 +90,11 @@ int main(int argc, char** argv) // and is initially sized to half of free device memory. 
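The hunk below switches the process-wide RMM device resource to the pool for the duration of run_main and then restores the previous resource. A small RAII sketch of the same idiom (not part of this diff, just an illustration of the design choice; set_current_device_resource returns the resource it replaces):

// Sketch: restore the previous device resource automatically on scope exit.
struct scoped_device_resource {
  explicit scoped_device_resource(rmm::mr::device_memory_resource* mr)
    : old_{rmm::mr::set_current_device_resource(mr)} {}
  ~scoped_device_resource() { rmm::mr::set_current_device_resource(old_); }
  rmm::mr::device_memory_resource* old_;
};
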
rmm::mr::pool_memory_resource pool_mr{ &cuda_mr, rmm::percent_of_free_device_memory(50)}; - rmm::mr::set_current_device_resource( - &pool_mr); // Updates the current device resource pointer to `pool_mr` - rmm::mr::device_memory_resource* mr = - rmm::mr::get_current_device_resource(); // Points to `pool_mr` - return raft::bench::ann::run_main(argc, argv); + // Updates the current device resource pointer to `pool_mr` + auto old_mr = rmm::mr::set_current_device_resource(&pool_mr); + auto ret = raft::bench::ann::run_main(argc, argv); + // Restores the current device resource pointer to its previous value + rmm::mr::set_current_device_resource(old_mr); + return ret; } #endif diff --git a/cpp/bench/ann/src/raft/raft_cagra_hnswlib_wrapper.h b/cpp/bench/ann/src/raft/raft_cagra_hnswlib_wrapper.h index ed9c120ed4..1c4b847d1a 100644 --- a/cpp/bench/ann/src/raft/raft_cagra_hnswlib_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_cagra_hnswlib_wrapper.h @@ -41,10 +41,11 @@ class RaftCagraHnswlib : public ANN, public AnnGPU { void set_search_param(const AnnSearchParam& param) override; - // TODO: if the number of results is less than k, the remaining elements of 'neighbors' - // will be filled with (size_t)-1 - void search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override; + void search(const T* queries, + int batch_size, + int k, + AnnBase::index_type* neighbors, + float* distances) const override; [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override { @@ -99,7 +100,7 @@ void RaftCagraHnswlib::load(const std::string& file) template void RaftCagraHnswlib::search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const + const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const { hnswlib_search_.search(queries, batch_size, k, neighbors, distances); } diff --git a/cpp/bench/ann/src/raft/raft_cagra_int8_t.cu b/cpp/bench/ann/src/raft/raft_cagra_int8_t.cu new file mode 100644 index 0000000000..be3b83ee60 --- /dev/null +++ b/cpp/bench/ann/src/raft/raft_cagra_int8_t.cu @@ -0,0 +1,20 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "raft_cagra_wrapper.h" + +namespace raft::bench::ann { +template class RaftCagra; +} // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_cagra.cu b/cpp/bench/ann/src/raft/raft_cagra_uint8_t.cu similarity index 85% rename from cpp/bench/ann/src/raft/raft_cagra.cu rename to cpp/bench/ann/src/raft/raft_cagra_uint8_t.cu index c0c1352a43..c9679e404d 100644 --- a/cpp/bench/ann/src/raft/raft_cagra.cu +++ b/cpp/bench/ann/src/raft/raft_cagra_uint8_t.cu @@ -17,7 +17,4 @@ namespace raft::bench::ann { template class RaftCagra; -template class RaftCagra; -template class RaftCagra; -template class RaftCagra; } // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h index 70fd22001e..0b892dec35 100644 --- a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h @@ -36,7 +36,7 @@ #include #include -#include +#include #include #include @@ -96,12 +96,16 @@ class RaftCagra : public ANN, public AnnGPU { void set_search_dataset(const T* dataset, size_t nrow) override; - // TODO: if the number of results is less than k, the remaining elements of 'neighbors' - // will be filled with (size_t)-1 - void search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override; - void search_base( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const; + void search(const T* queries, + int batch_size, + int k, + AnnBase::index_type* neighbors, + float* distances) const override; + void search_base(const T* queries, + int batch_size, + int k, + AnnBase::index_type* neighbors, + float* distances) const; [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override { @@ -138,7 +142,7 @@ class RaftCagra : public ANN, public AnnGPU { std::shared_ptr> dataset_; std::shared_ptr> input_dataset_v_; - inline rmm::mr::device_memory_resource* get_mr(AllocatorType mem_type) + inline rmm::device_async_resource_ref get_mr(AllocatorType mem_type) { switch (mem_type) { case (AllocatorType::HostPinned): return &mr_pinned_; @@ -272,15 +276,18 @@ std::unique_ptr> RaftCagra::copy() template void RaftCagra::search_base( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const + const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const { + static_assert(std::is_integral_v); + static_assert(std::is_integral_v); + IdxT* neighbors_IdxT; - rmm::device_uvector neighbors_storage(0, resource::get_cuda_stream(handle_)); - if constexpr (std::is_same_v) { - neighbors_IdxT = neighbors; + std::optional> neighbors_storage{std::nullopt}; + if constexpr (sizeof(IdxT) == sizeof(AnnBase::index_type)) { + neighbors_IdxT = reinterpret_cast(neighbors); } else { - neighbors_storage.resize(batch_size * k, resource::get_cuda_stream(handle_)); - neighbors_IdxT = neighbors_storage.data(); + neighbors_storage.emplace(batch_size * k, resource::get_cuda_stream(handle_)); + neighbors_IdxT = neighbors_storage->data(); } auto queries_view = @@ -291,76 +298,36 @@ void RaftCagra::search_base( raft::neighbors::cagra::search( handle_, search_params_, *index_, queries_view, neighbors_view, distances_view); - if constexpr (!std::is_same_v) { + if constexpr (sizeof(IdxT) != sizeof(AnnBase::index_type)) { raft::linalg::unaryOp(neighbors, neighbors_IdxT, batch_size * k, - raft::cast_op(), + raft::cast_op(), raft::resource::get_cuda_stream(handle_)); } } template void RaftCagra::search( - 
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const + const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const { auto k0 = static_cast(refine_ratio_ * k); const bool disable_refinement = k0 <= static_cast(k); const raft::resources& res = handle_; - auto stream = resource::get_cuda_stream(res); if (disable_refinement) { search_base(queries, batch_size, k, neighbors, distances); } else { - auto candidate_ixs = raft::make_device_matrix(res, batch_size, k0); - auto candidate_dists = raft::make_device_matrix(res, batch_size, k0); - search_base(queries, - batch_size, - k0, - reinterpret_cast(candidate_ixs.data_handle()), - candidate_dists.data_handle()); - - if (raft::get_device_for_address(input_dataset_v_->data_handle()) >= 0) { - auto queries_v = - raft::make_device_matrix_view(queries, batch_size, dimension_); - auto neighours_v = raft::make_device_matrix_view( - reinterpret_cast(neighbors), batch_size, k); - auto distances_v = raft::make_device_matrix_view(distances, batch_size, k); - raft::neighbors::refine( - res, - *input_dataset_v_, - queries_v, - raft::make_const_mdspan(candidate_ixs.view()), - neighours_v, - distances_v, - index_->metric()); - } else { - auto dataset_host = raft::make_host_matrix_view( - input_dataset_v_->data_handle(), input_dataset_v_->extent(0), input_dataset_v_->extent(1)); - auto queries_host = raft::make_host_matrix(batch_size, dimension_); - auto candidates_host = raft::make_host_matrix(batch_size, k0); - auto neighbors_host = raft::make_host_matrix(batch_size, k); - auto distances_host = raft::make_host_matrix(batch_size, k); - - raft::copy(queries_host.data_handle(), queries, queries_host.size(), stream); - raft::copy( - candidates_host.data_handle(), candidate_ixs.data_handle(), candidates_host.size(), stream); - - raft::resource::sync_stream(res); // wait for the queries and candidates - raft::neighbors::refine(res, - dataset_host, - queries_host.view(), - candidates_host.view(), - neighbors_host.view(), - distances_host.view(), - index_->metric()); - - raft::copy(neighbors, - reinterpret_cast(neighbors_host.data_handle()), - neighbors_host.size(), - stream); - raft::copy(distances, distances_host.data_handle(), distances_host.size(), stream); - } + auto queries_v = + raft::make_device_matrix_view(queries, batch_size, dimension_); + auto candidate_ixs = + raft::make_device_matrix(res, batch_size, k0); + auto candidate_dists = + raft::make_device_matrix(res, batch_size, k0); + search_base( + queries, batch_size, k0, candidate_ixs.data_handle(), candidate_dists.data_handle()); + refine_helper( + res, *input_dataset_v_, queries_v, candidate_ixs, k, neighbors, distances, index_->metric()); } } } // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h b/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h index 7f2996d77a..83a3a63aba 100644 --- a/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h @@ -61,10 +61,11 @@ class RaftIvfFlatGpu : public ANN, public AnnGPU { void set_search_param(const AnnSearchParam& param) override; - // TODO: if the number of results is less than k, the remaining elements of 'neighbors' - // will be filled with (size_t)-1 - void search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override; + void search(const T* queries, + int batch_size, + int k, + AnnBase::index_type* neighbors, + float* distances) const override; [[nodiscard]] auto 
get_sync_stream() const noexcept -> cudaStream_t override { @@ -131,10 +132,34 @@ std::unique_ptr> RaftIvfFlatGpu::copy() template void RaftIvfFlatGpu::search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const + const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const { - static_assert(sizeof(size_t) == sizeof(IdxT), "IdxT is incompatible with size_t"); - raft::neighbors::ivf_flat::search( - handle_, search_params_, *index_, queries, batch_size, k, (IdxT*)neighbors, distances); + static_assert(std::is_integral_v); + static_assert(std::is_integral_v); + + IdxT* neighbors_IdxT; + std::optional> neighbors_storage{std::nullopt}; + if constexpr (sizeof(IdxT) == sizeof(AnnBase::index_type)) { + neighbors_IdxT = reinterpret_cast(neighbors); + } else { + neighbors_storage.emplace(batch_size * k, resource::get_cuda_stream(handle_)); + neighbors_IdxT = neighbors_storage->data(); + } + raft::neighbors::ivf_flat::search(handle_, + search_params_, + *index_, + queries, + batch_size, + k, + neighbors_IdxT, + distances, + resource::get_workspace_resource(handle_)); + if constexpr (sizeof(IdxT) != sizeof(AnnBase::index_type)) { + raft::linalg::unaryOp(neighbors, + neighbors_IdxT, + batch_size * k, + raft::cast_op(), + raft::resource::get_cuda_stream(handle_)); + } } } // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h b/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h index 5d8b682264..7201467969 100644 --- a/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h @@ -32,9 +32,6 @@ #include #include -#include -#include - #include namespace raft::bench::ann { @@ -64,10 +61,16 @@ class RaftIvfPQ : public ANN, public AnnGPU { void set_search_param(const AnnSearchParam& param) override; void set_search_dataset(const T* dataset, size_t nrow) override; - // TODO: if the number of results is less than k, the remaining elements of 'neighbors' - // will be filled with (size_t)-1 - void search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override; + void search(const T* queries, + int batch_size, + int k, + AnnBase::index_type* neighbors, + float* distances) const override; + void search_base(const T* queries, + int batch_size, + int k, + AnnBase::index_type* neighbors, + float* distances) const; [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override { @@ -140,68 +143,61 @@ void RaftIvfPQ::set_search_dataset(const T* dataset, size_t nrow) dataset_ = raft::make_device_matrix_view(dataset, nrow, index_->dim()); } +template +void RaftIvfPQ::search_base( + const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const +{ + static_assert(std::is_integral_v); + static_assert(std::is_integral_v); + + IdxT* neighbors_IdxT; + std::optional> neighbors_storage{std::nullopt}; + if constexpr (sizeof(IdxT) == sizeof(AnnBase::index_type)) { + neighbors_IdxT = reinterpret_cast(neighbors); + } else { + neighbors_storage.emplace(batch_size * k, resource::get_cuda_stream(handle_)); + neighbors_IdxT = neighbors_storage->data(); + } + + auto queries_view = + raft::make_device_matrix_view(queries, batch_size, dimension_); + auto neighbors_view = + raft::make_device_matrix_view(neighbors_IdxT, batch_size, k); + auto distances_view = raft::make_device_matrix_view(distances, batch_size, k); + + raft::neighbors::ivf_pq::search( + handle_, search_params_, *index_, queries_view, neighbors_view, 
distances_view); + + if constexpr (sizeof(IdxT) != sizeof(AnnBase::index_type)) { + raft::linalg::unaryOp(neighbors, + neighbors_IdxT, + batch_size * k, + raft::cast_op(), + raft::resource::get_cuda_stream(handle_)); + } +} + template void RaftIvfPQ::search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const + const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const { - if (refine_ratio_ > 1.0f) { - uint32_t k0 = static_cast(refine_ratio_ * k); - auto queries_v = - raft::make_device_matrix_view(queries, batch_size, index_->dim()); - auto distances_tmp = raft::make_device_matrix(handle_, batch_size, k0); - auto candidates = raft::make_device_matrix(handle_, batch_size, k0); - - raft::neighbors::ivf_pq::search( - handle_, search_params_, *index_, queries_v, candidates.view(), distances_tmp.view()); - - if (raft::get_device_for_address(dataset_.data_handle()) >= 0) { - auto queries_v = - raft::make_device_matrix_view(queries, batch_size, index_->dim()); - auto neighbors_v = raft::make_device_matrix_view((IdxT*)neighbors, batch_size, k); - auto distances_v = raft::make_device_matrix_view(distances, batch_size, k); - - raft::neighbors::refine(handle_, - dataset_, - queries_v, - candidates.view(), - neighbors_v, - distances_v, - index_->metric()); - } else { - auto queries_host = raft::make_host_matrix(batch_size, index_->dim()); - auto candidates_host = raft::make_host_matrix(batch_size, k0); - auto neighbors_host = raft::make_host_matrix(batch_size, k); - auto distances_host = raft::make_host_matrix(batch_size, k); - - auto stream = resource::get_cuda_stream(handle_); - raft::copy(queries_host.data_handle(), queries, queries_host.size(), stream); - raft::copy( - candidates_host.data_handle(), candidates.data_handle(), candidates_host.size(), stream); - - auto dataset_v = raft::make_host_matrix_view( - dataset_.data_handle(), dataset_.extent(0), dataset_.extent(1)); - - raft::resource::sync_stream(handle_); // wait for the queries and candidates - raft::neighbors::refine(handle_, - dataset_v, - queries_host.view(), - candidates_host.view(), - neighbors_host.view(), - distances_host.view(), - index_->metric()); - - raft::copy(neighbors, (size_t*)neighbors_host.data_handle(), neighbors_host.size(), stream); - raft::copy(distances, distances_host.data_handle(), distances_host.size(), stream); - } + auto k0 = static_cast(refine_ratio_ * k); + const bool disable_refinement = k0 <= static_cast(k); + const raft::resources& res = handle_; + + if (disable_refinement) { + search_base(queries, batch_size, k, neighbors, distances); } else { auto queries_v = - raft::make_device_matrix_view(queries, batch_size, index_->dim()); - auto neighbors_v = - raft::make_device_matrix_view((IdxT*)neighbors, batch_size, k); - auto distances_v = raft::make_device_matrix_view(distances, batch_size, k); - - raft::neighbors::ivf_pq::search( - handle_, search_params_, *index_, queries_v, neighbors_v, distances_v); + raft::make_device_matrix_view(queries, batch_size, dimension_); + auto candidate_ixs = + raft::make_device_matrix(res, batch_size, k0); + auto candidate_dists = + raft::make_device_matrix(res, batch_size, k0); + search_base( + queries, batch_size, k0, candidate_ixs.data_handle(), candidate_dists.data_handle()); + refine_helper( + res, dataset_, queries_v, candidate_ixs, k, neighbors, distances, index_->metric()); } } } // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_wrapper.h b/cpp/bench/ann/src/raft/raft_wrapper.h 
index 586b81ae06..2c996058b2 100644 --- a/cpp/bench/ann/src/raft/raft_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_wrapper.h @@ -56,10 +56,11 @@ class RaftGpu : public ANN, public AnnGPU { void set_search_param(const AnnSearchParam& param) override; - // TODO: if the number of results is less than k, the remaining elements of 'neighbors' - // will be filled with (size_t)-1 - void search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const final; + void search(const T* queries, + int batch_size, + int k, + AnnBase::index_type* neighbors, + float* distances) const final; // to enable dataset access from GPU memory AlgoProperty get_preference() const override @@ -133,15 +134,16 @@ void RaftGpu::load(const std::string& file) template void RaftGpu::search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const + const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const { auto queries_view = raft::make_device_matrix_view(queries, batch_size, this->dim_); - auto neighbors_view = raft::make_device_matrix_view(neighbors, batch_size, k); + auto neighbors_view = + raft::make_device_matrix_view(neighbors, batch_size, k); auto distances_view = raft::make_device_matrix_view(distances, batch_size, k); - raft::neighbors::brute_force::search( + raft::neighbors::brute_force::search( handle_, *index_, queries_view, neighbors_view, distances_view); } diff --git a/cpp/bench/prims/CMakeLists.txt b/cpp/bench/prims/CMakeLists.txt index 9f23c44a5c..0771a60e58 100644 --- a/cpp/bench/prims/CMakeLists.txt +++ b/cpp/bench/prims/CMakeLists.txt @@ -75,31 +75,31 @@ endfunction() if(BUILD_PRIMS_BENCH) ConfigureBench( - NAME CORE_BENCH PATH bench/prims/core/bitset.cu bench/prims/core/copy.cu bench/prims/main.cpp + NAME CORE_BENCH PATH core/bitset.cu core/copy.cu main.cpp ) ConfigureBench( - NAME CLUSTER_BENCH PATH bench/prims/cluster/kmeans_balanced.cu bench/prims/cluster/kmeans.cu - bench/prims/main.cpp OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY + NAME CLUSTER_BENCH PATH cluster/kmeans_balanced.cu cluster/kmeans.cu + main.cpp OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY ) ConfigureBench( - NAME TUNE_DISTANCE PATH bench/prims/distance/tune_pairwise/kernel.cu - bench/prims/distance/tune_pairwise/bench.cu bench/prims/main.cpp + NAME TUNE_DISTANCE PATH distance/tune_pairwise/kernel.cu + distance/tune_pairwise/bench.cu main.cpp ) ConfigureBench( NAME DISTANCE_BENCH PATH - bench/prims/distance/distance_cosine.cu - bench/prims/distance/distance_exp_l2.cu - bench/prims/distance/distance_l1.cu - bench/prims/distance/distance_unexp_l2.cu - bench/prims/distance/fused_l2_nn.cu - bench/prims/distance/masked_nn.cu - bench/prims/distance/kernels.cu - bench/prims/main.cpp + distance/distance_cosine.cu + distance/distance_exp_l2.cu + distance/distance_l1.cu + distance/distance_unexp_l2.cu + distance/fused_l2_nn.cu + distance/masked_nn.cu + distance/kernels.cu + main.cpp OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY @@ -109,63 +109,64 @@ if(BUILD_PRIMS_BENCH) NAME LINALG_BENCH PATH - bench/prims/linalg/add.cu - bench/prims/linalg/map_then_reduce.cu - bench/prims/linalg/matrix_vector_op.cu - bench/prims/linalg/norm.cu - bench/prims/linalg/normalize.cu - bench/prims/linalg/reduce_cols_by_key.cu - bench/prims/linalg/reduce_rows_by_key.cu - bench/prims/linalg/reduce.cu - bench/prims/linalg/sddmm.cu - bench/prims/main.cpp + linalg/add.cu + linalg/map_then_reduce.cu + linalg/matrix_vector_op.cu + linalg/norm.cu + linalg/normalize.cu + 
linalg/reduce_cols_by_key.cu + linalg/reduce_rows_by_key.cu + linalg/reduce.cu + linalg/sddmm.cu + main.cpp ) ConfigureBench( - NAME MATRIX_BENCH PATH bench/prims/matrix/argmin.cu bench/prims/matrix/gather.cu - bench/prims/matrix/select_k.cu bench/prims/main.cpp OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY + NAME MATRIX_BENCH PATH matrix/argmin.cu matrix/gather.cu + matrix/select_k.cu main.cpp OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY ) ConfigureBench( - NAME RANDOM_BENCH PATH bench/prims/random/make_blobs.cu bench/prims/random/permute.cu - bench/prims/random/rng.cu bench/prims/random/subsample.cu bench/prims/main.cpp + NAME RANDOM_BENCH PATH random/make_blobs.cu random/permute.cu + random/rng.cu random/subsample.cu main.cpp ) ConfigureBench( NAME SPARSE_BENCH PATH - bench/prims/sparse/bitmap_to_csr.cu - bench/prims/sparse/convert_csr.cu - bench/prims/main.cpp + sparse/bitmap_to_csr.cu + sparse/convert_csr.cu + sparse/select_k_csr.cu + main.cpp ) ConfigureBench( NAME NEIGHBORS_BENCH PATH - bench/prims/neighbors/knn/brute_force_float_int64_t.cu - bench/prims/neighbors/knn/brute_force_float_uint32_t.cu - bench/prims/neighbors/knn/cagra_float_uint32_t.cu - bench/prims/neighbors/knn/ivf_flat_filter_float_int64_t.cu - bench/prims/neighbors/knn/ivf_flat_float_int64_t.cu - bench/prims/neighbors/knn/ivf_flat_int8_t_int64_t.cu - bench/prims/neighbors/knn/ivf_flat_uint8_t_int64_t.cu - bench/prims/neighbors/knn/ivf_pq_float_int64_t.cu - bench/prims/neighbors/knn/ivf_pq_filter_float_int64_t.cu - bench/prims/neighbors/knn/ivf_pq_int8_t_int64_t.cu - bench/prims/neighbors/knn/ivf_pq_uint8_t_int64_t.cu - src/neighbors/detail/ivf_pq_search_filtering_float_int64_t.cu - src/neighbors/detail/ivf_pq_compute_similarity_float_float_bitset64.cu - src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false_bitset64.cu - src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true_bitset64.cu - src/neighbors/detail/ivf_pq_compute_similarity_float_half_bitset64.cu - src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false_bitset64.cu - src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true_bitset64.cu - src/neighbors/detail/ivf_pq_compute_similarity_half_half_bitset64.cu - bench/prims/neighbors/refine_float_int64_t.cu - bench/prims/neighbors/refine_uint8_t_int64_t.cu - bench/prims/main.cpp + neighbors/knn/brute_force_float_int64_t.cu + neighbors/knn/brute_force_float_uint32_t.cu + neighbors/knn/cagra_float_uint32_t.cu + neighbors/knn/ivf_flat_filter_float_int64_t.cu + neighbors/knn/ivf_flat_float_int64_t.cu + neighbors/knn/ivf_flat_int8_t_int64_t.cu + neighbors/knn/ivf_flat_uint8_t_int64_t.cu + neighbors/knn/ivf_pq_float_int64_t.cu + neighbors/knn/ivf_pq_filter_float_int64_t.cu + neighbors/knn/ivf_pq_int8_t_int64_t.cu + neighbors/knn/ivf_pq_uint8_t_int64_t.cu + ${RAFT_SOURCE_DIR}/src/neighbors/detail/ivf_pq_search_filtering_float_int64_t.cu + ${RAFT_SOURCE_DIR}/src/neighbors/detail/ivf_pq_compute_similarity_float_float_bitset64.cu + ${RAFT_SOURCE_DIR}/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false_bitset64.cu + ${RAFT_SOURCE_DIR}/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true_bitset64.cu + ${RAFT_SOURCE_DIR}/src/neighbors/detail/ivf_pq_compute_similarity_float_half_bitset64.cu + ${RAFT_SOURCE_DIR}/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false_bitset64.cu + ${RAFT_SOURCE_DIR}/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true_bitset64.cu + ${RAFT_SOURCE_DIR}/src/neighbors/detail/ivf_pq_compute_similarity_half_half_bitset64.cu + 
neighbors/refine_float_int64_t.cu + neighbors/refine_uint8_t_int64_t.cu + main.cpp OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY diff --git a/cpp/bench/prims/common/benchmark.hpp b/cpp/bench/prims/common/benchmark.hpp index 4ecad6df3d..3ce43cc1e7 100644 --- a/cpp/bench/prims/common/benchmark.hpp +++ b/cpp/bench/prims/common/benchmark.hpp @@ -28,6 +28,7 @@ #include #include #include +#include #include #include diff --git a/cpp/bench/prims/matrix/gather.cu b/cpp/bench/prims/matrix/gather.cu index 078f9e6198..876e47525c 100644 --- a/cpp/bench/prims/matrix/gather.cu +++ b/cpp/bench/prims/matrix/gather.cu @@ -24,6 +24,7 @@ #include #include +#include #include namespace raft::bench::matrix { diff --git a/cpp/bench/prims/neighbors/knn.cuh b/cpp/bench/prims/neighbors/knn.cuh index aea7168142..6499078623 100644 --- a/cpp/bench/prims/neighbors/knn.cuh +++ b/cpp/bench/prims/neighbors/knn.cuh @@ -27,10 +27,12 @@ #include #include +#include #include #include #include #include +#include #include @@ -101,7 +103,7 @@ struct device_resource { if (managed_) { delete res_; } } - [[nodiscard]] auto get() const -> rmm::mr::device_memory_resource* { return res_; } + [[nodiscard]] auto get() const -> rmm::device_async_resource_ref { return res_; } private: const bool managed_; @@ -158,8 +160,15 @@ struct ivf_flat_knn { IdxT* out_idxs) { search_params.n_probes = 20; - raft::neighbors::ivf_flat::search( - handle, search_params, *index, search_items, ps.n_queries, ps.k, out_idxs, out_dists); + raft::neighbors::ivf_flat::search(handle, + search_params, + *index, + search_items, + ps.n_queries, + ps.k, + out_idxs, + out_dists, + resource::get_workspace_resource(handle)); } }; diff --git a/cpp/bench/prims/random/subsample.cu b/cpp/bench/prims/random/subsample.cu index 4c8ca2bf31..70a9c65e0d 100644 --- a/cpp/bench/prims/random/subsample.cu +++ b/cpp/bench/prims/random/subsample.cu @@ -27,6 +27,7 @@ #include #include +#include #include #include diff --git a/cpp/bench/prims/sparse/select_k_csr.cu b/cpp/bench/prims/sparse/select_k_csr.cu new file mode 100644 index 0000000000..a91e6c8514 --- /dev/null +++ b/cpp/bench/prims/sparse/select_k_csr.cu @@ -0,0 +1,287 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include + +namespace raft::bench::sparse { + +template +struct bench_param { + index_t n_rows; + index_t n_cols; + index_t top_k; + float sparsity; + bool select_min = true; + bool customized_indices = false; +}; + +template +inline auto operator<<(std::ostream& os, const bench_param& params) -> std::ostream& +{ + os << params.n_rows << "#" << params.n_cols << "#" << params.top_k << "#" << params.sparsity; + return os; +} + +template +struct SelectKCsrTest : public fixture { + SelectKCsrTest(const bench_param& p) + : fixture(true), + params(p), + handle(stream), + values_d(0, stream), + indptr_d(0, stream), + indices_d(0, stream), + customized_indices_d(0, stream), + dst_values_d(0, stream), + dst_indices_d(0, stream) + { + std::vector dense_values_h(params.n_rows * params.n_cols); + nnz = create_sparse_matrix(params.n_rows, params.n_cols, params.sparsity, dense_values_h); + + std::vector indices_h(nnz); + std::vector customized_indices_h(nnz); + std::vector indptr_h(params.n_rows + 1); + + convert_to_csr(dense_values_h, params.n_rows, params.n_cols, indices_h, indptr_h); + + std::vector dst_values_h(params.n_rows * params.top_k, static_cast(2.0f)); + std::vector dst_indices_h(params.n_rows * params.top_k, + static_cast(params.n_rows * params.n_cols * 100)); + + dst_values_d.resize(params.n_rows * params.top_k, stream); + dst_indices_d.resize(params.n_rows * params.top_k, stream); + values_d.resize(nnz, stream); + + if (nnz) { + auto blobs_values = raft::make_device_matrix(handle, 1, nnz); + auto labels = raft::make_device_vector(handle, 1); + + raft::random::make_blobs(blobs_values.data_handle(), + labels.data_handle(), + 1, + nnz, + 1, + stream, + false, + nullptr, + nullptr, + value_t(1.0), + false, + value_t(-10.0f), + value_t(10.0f), + uint64_t(2024)); + raft::copy(values_d.data(), blobs_values.data_handle(), nnz, stream); + resource::sync_stream(handle); + } + + indices_d.resize(nnz, stream); + indptr_d.resize(params.n_rows + 1, stream); + + update_device(indices_d.data(), indices_h.data(), indices_h.size(), stream); + update_device(indptr_d.data(), indptr_h.data(), indptr_h.size(), stream); + + if (params.customized_indices) { + customized_indices_d.resize(nnz, stream); + update_device(customized_indices_d.data(), + customized_indices_h.data(), + customized_indices_h.size(), + stream); + } + } + + index_t create_sparse_matrix(index_t m, index_t n, value_t sparsity, std::vector& matrix) + { + index_t total_elements = static_cast(m * n); + index_t num_ones = static_cast((total_elements * 1.0f) * sparsity); + index_t res = num_ones; + + for (index_t i = 0; i < total_elements; ++i) { + matrix[i] = false; + } + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<> dis_idx(0, total_elements - 1); + + while (num_ones > 0) { + size_t index = dis_idx(gen); + if (matrix[index] == false) { + matrix[index] = true; + num_ones--; + } + } + return res; + } + + void convert_to_csr(std::vector& matrix, + index_t rows, + index_t cols, + std::vector& indices, + std::vector& indptr) + { + index_t offset_indptr = 0; + index_t offset_values = 0; + indptr[offset_indptr++] = 0; + + for (index_t i = 0; i < rows; ++i) { + for (index_t j = 0; j < cols; ++j) { + if (matrix[i * cols + j]) { + indices[offset_values] = static_cast(j); + offset_values++; + } + } + indptr[offset_indptr++] = 
static_cast(offset_values); + } + } + + template + std::optional get_opt_var(data_t x) + { + if (params.customized_indices) { + return x; + } else { + return std::nullopt; + } + } + + void run_benchmark(::benchmark::State& state) override + { + std::ostringstream label_stream; + label_stream << params; + state.SetLabel(label_stream.str()); + + auto in_val_structure = raft::make_device_compressed_structure_view( + indptr_d.data(), + indices_d.data(), + params.n_rows, + params.n_cols, + static_cast(indices_d.size())); + + auto in_val = + raft::make_device_csr_matrix_view(values_d.data(), in_val_structure); + + std::optional> in_idx; + + in_idx = get_opt_var( + raft::make_device_vector_view(customized_indices_d.data(), nnz)); + + auto out_val = raft::make_device_matrix_view( + dst_values_d.data(), params.n_rows, params.top_k); + auto out_idx = raft::make_device_matrix_view( + dst_indices_d.data(), params.n_rows, params.top_k); + + raft::sparse::matrix::select_k(handle, in_val, in_idx, out_val, out_idx, params.select_min); + resource::sync_stream(handle); + loop_on_state(state, [this, &in_val, &in_idx, &out_val, &out_idx]() { + raft::sparse::matrix::select_k( + handle, in_val, in_idx, out_val, out_idx, params.select_min, false); + resource::sync_stream(handle); + }); + } + + protected: + const raft::device_resources handle; + + bench_param params; + index_t nnz; + + rmm::device_uvector values_d; + rmm::device_uvector indptr_d; + rmm::device_uvector indices_d; + rmm::device_uvector customized_indices_d; + + rmm::device_uvector dst_values_d; + rmm::device_uvector dst_indices_d; +}; // struct SelectKCsrTest + +template +const std::vector> getInputs() +{ + std::vector> param_vec; + struct TestParams { + index_t m; + index_t n; + index_t k; + }; + + const std::vector params_group{ + {20000, 500, 1}, {20000, 500, 2}, {20000, 500, 4}, {20000, 500, 8}, + {20000, 500, 16}, {20000, 500, 32}, {20000, 500, 64}, {20000, 500, 128}, + {20000, 500, 256}, + + {1000, 10000, 1}, {1000, 10000, 2}, {1000, 10000, 4}, {1000, 10000, 8}, + {1000, 10000, 16}, {1000, 10000, 32}, {1000, 10000, 64}, {1000, 10000, 128}, + {1000, 10000, 256}, + + {100, 100000, 1}, {100, 100000, 2}, {100, 100000, 4}, {100, 100000, 8}, + {100, 100000, 16}, {100, 100000, 32}, {100, 100000, 64}, {100, 100000, 128}, + {100, 100000, 256}, + + {10, 1000000, 1}, {10, 1000000, 2}, {10, 1000000, 4}, {10, 1000000, 8}, + {10, 1000000, 16}, {10, 1000000, 32}, {10, 1000000, 64}, {10, 1000000, 128}, + {10, 1000000, 256}, + + {10, 1000000, 1}, {10, 1000000, 2}, {10, 1000000, 4}, {10, 1000000, 8}, + {10, 1000000, 16}, {10, 1000000, 32}, {10, 1000000, 64}, {10, 1000000, 128}, + {10, 1000000, 256}, + + {10, 1000000, 1}, {10, 1000000, 16}, {10, 1000000, 64}, {10, 1000000, 128}, + {10, 1000000, 256}, + + {10, 1000000, 1}, {10, 1000000, 16}, {10, 1000000, 64}, {10, 1000000, 128}, + {10, 1000000, 256}, {1000, 10000, 1}, {1000, 10000, 16}, {1000, 10000, 64}, + {1000, 10000, 128}, {1000, 10000, 256}, + + {10, 1000000, 1}, {10, 1000000, 16}, {10, 1000000, 64}, {10, 1000000, 128}, + {10, 1000000, 256}, {1000, 10000, 1}, {1000, 10000, 16}, {1000, 10000, 64}, + {1000, 10000, 128}, {1000, 10000, 256}}; + + param_vec.reserve(params_group.size()); + for (TestParams params : params_group) { + param_vec.push_back(bench_param({params.m, params.n, params.k, 0.1})); + } + for (TestParams params : params_group) { + param_vec.push_back(bench_param({params.m, params.n, params.k, 0.2})); + } + for (TestParams params : params_group) { + param_vec.push_back(bench_param({params.m, 
params.n, params.k, 0.5})); + } + return param_vec; +} + +RAFT_BENCH_REGISTER((SelectKCsrTest), "", getInputs()); + +} // namespace raft::bench::sparse diff --git a/cpp/cmake/modules/ConfigureCUDA.cmake b/cpp/cmake/modules/ConfigureCUDA.cmake index ea8a077b0c..b364d8418d 100644 --- a/cpp/cmake/modules/ConfigureCUDA.cmake +++ b/cpp/cmake/modules/ConfigureCUDA.cmake @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 2018-2023, NVIDIA CORPORATION. +# Copyright (c) 2018-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at @@ -13,8 +13,8 @@ # ============================================================================= if(DISABLE_DEPRECATION_WARNINGS) - list(APPEND RAFT_CXX_FLAGS -Wno-deprecated-declarations) - list(APPEND RAFT_CUDA_FLAGS -Xcompiler=-Wno-deprecated-declarations) + list(APPEND RAFT_CXX_FLAGS -Wno-deprecated-declarations -DRAFT_HIDE_DEPRECATION_WARNINGS) + list(APPEND RAFT_CUDA_FLAGS -Xcompiler=-Wno-deprecated-declarations -DRAFT_HIDE_DEPRECATION_WARNINGS) endif() # Be very strict when compiling with GCC as host compiler (and thus more lenient when compiling with diff --git a/cpp/cmake/patches/faiss_override.json b/cpp/cmake/patches/faiss_override.json new file mode 100644 index 0000000000..19dad362b9 --- /dev/null +++ b/cpp/cmake/patches/faiss_override.json @@ -0,0 +1,9 @@ +{ + "packages" : { + "faiss" : { + "version": "1.7.4", + "git_url": "https://github.com/facebookresearch/faiss.git", + "git_tag": "main" + } + } +} diff --git a/cpp/cmake/patches/ggnn_override.json b/cpp/cmake/patches/ggnn_override.json new file mode 100644 index 0000000000..768fae8b0c --- /dev/null +++ b/cpp/cmake/patches/ggnn_override.json @@ -0,0 +1,16 @@ +{ + "packages" : { + "ggnn" : { + "version": "0.5", + "git_url": "https://github.com/cgtuebingen/ggnn.git", + "git_tag": "release_${version}", + "patches" : [ + { + "file" : "${current_json_dir}/ggnn.diff", + "issue" : "Correct compilation issues", + "fixed_in" : "" + } + ] + } + } +} diff --git a/cpp/cmake/patches/hnswlib_override.json b/cpp/cmake/patches/hnswlib_override.json new file mode 100644 index 0000000000..d6ab8a18a5 --- /dev/null +++ b/cpp/cmake/patches/hnswlib_override.json @@ -0,0 +1,16 @@ +{ + "packages" : { + "hnswlib" : { + "version": "0.6.2", + "git_url": "https://github.com/nmslib/hnswlib.git", + "git_tag": "v${version}", + "patches" : [ + { + "file" : "${current_json_dir}/hnswlib.diff", + "issue" : "Correct compilation issues", + "fixed_in" : "" + } + ] + } + } +} diff --git a/cpp/cmake/thirdparty/get_faiss.cmake b/cpp/cmake/thirdparty/get_faiss.cmake index 85829554ae..288da763bf 100644 --- a/cpp/cmake/thirdparty/get_faiss.cmake +++ b/cpp/cmake/thirdparty/get_faiss.cmake @@ -1,5 +1,5 @@ #============================================================================= -# Copyright (c) 2021-2023, NVIDIA CORPORATION. +# Copyright (c) 2021-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -15,96 +15,104 @@ #============================================================================= function(find_and_configure_faiss) - set(oneValueArgs VERSION REPOSITORY PINNED_TAG BUILD_STATIC_LIBS EXCLUDE_FROM_ALL ENABLE_GPU) - cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" - "${multiValueArgs}" ${ARGN} ) + set(oneValueArgs VERSION REPOSITORY PINNED_TAG BUILD_STATIC_LIBS EXCLUDE_FROM_ALL ENABLE_GPU) + cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" + "${multiValueArgs}" ${ARGN} ) + + rapids_find_generate_module(faiss + HEADER_NAMES faiss/IndexFlat.h + LIBRARY_NAMES faiss + ) + + set(patch_dir "${CMAKE_CURRENT_FUNCTION_LIST_DIR}/../patches") + rapids_cpm_package_override("${patch_dir}/faiss_override.json") + + include("${rapids-cmake-dir}/cpm/detail/package_details.cmake") + rapids_cpm_package_details(faiss version repository tag shallow exclude) + + include("${rapids-cmake-dir}/cpm/detail/generate_patch_command.cmake") + rapids_cpm_generate_patch_command(faiss ${version} patch_command) + + set(BUILD_SHARED_LIBS ON) + if (PKG_BUILD_STATIC_LIBS) + set(BUILD_SHARED_LIBS OFF) + set(CPM_DOWNLOAD_faiss ON) + endif() + + include(cmake/modules/FindAVX) + # Link against AVX CPU lib if it exists + set(RAFT_FAISS_OPT_LEVEL "generic") + if(CXX_AVX2_FOUND) + set(RAFT_FAISS_OPT_LEVEL "avx2") + endif() + + rapids_cpm_find(faiss ${version} + GLOBAL_TARGETS faiss faiss_avx2 faiss_gpu faiss::faiss faiss::faiss_avx2 + CPM_ARGS + GIT_REPOSITORY ${repository} + GIT_TAG ${tag} + GIT_SHALLOW ${shallow} ${patch_command} + EXCLUDE_FROM_ALL ${exclude} + OPTIONS + "FAISS_ENABLE_GPU ${PKG_ENABLE_GPU}" + "FAISS_ENABLE_PYTHON OFF" + "FAISS_OPT_LEVEL ${RAFT_FAISS_OPT_LEVEL}" + "FAISS_USE_CUDA_TOOLKIT_STATIC ${CUDA_STATIC_RUNTIME}" + "BUILD_TESTING OFF" + "CMAKE_MESSAGE_LOG_LEVEL VERBOSE" + ) + + include("${rapids-cmake-dir}/cpm/detail/display_patch_status.cmake") + rapids_cpm_display_patch_status(faiss) + + if(TARGET faiss AND NOT TARGET faiss::faiss) + add_library(faiss::faiss ALIAS faiss) + # We need to ensure that faiss has all the conda information. So we use this approach so that + # faiss will have the conda includes/link dirs + target_link_libraries(faiss PRIVATE $) + endif() + if(TARGET faiss_avx2 AND NOT TARGET faiss::faiss_avx2) + add_library(faiss::faiss_avx2 ALIAS faiss_avx2) + # We need to ensure that faiss has all the conda information. So we use this approach so that + # faiss will have the conda includes/link dirs + target_link_libraries(faiss_avx2 PRIVATE $) + endif() + if(TARGET faiss_gpu AND NOT TARGET faiss::faiss_gpu) + add_library(faiss::faiss_gpu ALIAS faiss_gpu) + # We need to ensure that faiss has all the conda information. So we use this approach so that + # faiss will have the conda includes/link dirs + target_link_libraries(faiss_gpu PRIVATE $) + endif() + + if(faiss_ADDED) + rapids_export(BUILD faiss + EXPORT_SET faiss-targets + GLOBAL_TARGETS ${RAFT_FAISS_EXPORT_GLOBAL_TARGETS} + NAMESPACE faiss::) + endif() + + # Need to tell CMake to rescan the link group of faiss::faiss_gpu and faiss + # so that we get proper link order when they are static + # + # We don't look at the existence of `faiss_avx2` as it will always exist + # even when CXX_AVX2_FOUND is false. In addition for arm builds the + # faiss_avx2 is marked as `EXCLUDE_FROM_ALL` so we don't want to add + # a dependency to it. Adding a dependency will cause it to compile, + # and fail due to invalid compiler flags.
+ if(PKG_ENABLE_GPU AND PKG_BUILD_STATIC_LIBS AND CXX_AVX2_FOUND) + set(RAFT_FAISS_TARGETS "$,faiss::faiss_avx2>" PARENT_SCOPE) + elseif(PKG_ENABLE_GPU AND PKG_BUILD_STATIC_LIBS) + set(RAFT_FAISS_TARGETS "$,faiss::faiss>" PARENT_SCOPE) + elseif(CXX_AVX2_FOUND) + set(RAFT_FAISS_TARGETS faiss::faiss_avx2 PARENT_SCOPE) + else() + set(RAFT_FAISS_TARGETS faiss::faiss PARENT_SCOPE) + endif() - rapids_find_generate_module(faiss - HEADER_NAMES faiss/IndexFlat.h - LIBRARY_NAMES faiss - ) - - set(BUILD_SHARED_LIBS ON) - if (PKG_BUILD_STATIC_LIBS) - set(BUILD_SHARED_LIBS OFF) - set(CPM_DOWNLOAD_faiss ON) - endif() - - include(cmake/modules/FindAVX.cmake) - - # Link against AVX CPU lib if it exists - set(RAFT_FAISS_GLOBAL_TARGETS faiss::faiss) - set(RAFT_FAISS_EXPORT_GLOBAL_TARGETS faiss) - set(RAFT_FAISS_OPT_LEVEL "generic") - if(CXX_AVX_FOUND) - set(RAFT_FAISS_OPT_LEVEL "avx2") - list(APPEND RAFT_FAISS_GLOBAL_TARGETS faiss::faiss_avx2) - list(APPEND RAFT_FAISS_EXPORT_GLOBAL_TARGETS faiss_avx2) - endif() - - rapids_cpm_find(faiss ${PKG_VERSION} - GLOBAL_TARGETS ${RAFT_FAISS_GLOBAL_TARGETS} - CPM_ARGS - GIT_REPOSITORY ${PKG_REPOSITORY} - GIT_TAG ${PKG_PINNED_TAG} - EXCLUDE_FROM_ALL ${PKG_EXCLUDE_FROM_ALL} - OPTIONS - "FAISS_ENABLE_GPU ${PKG_ENABLE_GPU}" - "FAISS_ENABLE_PYTHON OFF" - "FAISS_OPT_LEVEL ${RAFT_FAISS_OPT_LEVEL}" - "FAISS_USE_CUDA_TOOLKIT_STATIC ${CUDA_STATIC_RUNTIME}" - "BUILD_TESTING OFF" - "CMAKE_MESSAGE_LOG_LEVEL VERBOSE" - ) - - if(TARGET faiss AND NOT TARGET faiss::faiss) - add_library(faiss::faiss ALIAS faiss) - endif() - - if(CXX_AVX_FOUND) - - if(TARGET faiss_avx2 AND NOT TARGET faiss::faiss_avx2) - add_library(faiss::faiss_avx2 ALIAS faiss_avx2) - endif() - endif() - - - if(faiss_ADDED) - rapids_export(BUILD faiss - EXPORT_SET faiss-targets - GLOBAL_TARGETS ${RAFT_FAISS_EXPORT_GLOBAL_TARGETS} - NAMESPACE faiss::) - endif() - - # We generate the faiss-config files when we built faiss locally, so always do `find_dependency` - rapids_export_package(BUILD OpenMP raft-ann-bench-exports) # faiss uses openMP but doesn't export a need for it - rapids_export_package(BUILD faiss raft-ann-bench-exports GLOBAL_TARGETS ${RAFT_FAISS_GLOBAL_TARGETS} ${RAFT_FAISS_EXPORT_GLOBAL_TARGETS}) - rapids_export_package(INSTALL faiss raft-ann-bench-exports GLOBAL_TARGETS ${RAFT_FAISS_GLOBAL_TARGETS} ${RAFT_FAISS_EXPORT_GLOBAL_TARGETS}) - - # Tell cmake where it can find the generated faiss-config.cmake we wrote. 
- include("${rapids-cmake-dir}/export/find_package_root.cmake") - rapids_export_find_package_root(BUILD faiss [=[${CMAKE_CURRENT_LIST_DIR}]=] - EXPORT_SET raft-ann-bench-exports) endfunction() -if(NOT RAFT_FAISS_GIT_TAG) - # TODO: Remove this once faiss supports FAISS_USE_CUDA_TOOLKIT_STATIC - # (https://github.com/facebookresearch/faiss/pull/2446) - set(RAFT_FAISS_GIT_TAG fea/statically-link-ctk) - # set(RAFT_FAISS_GIT_TAG bde7c0027191f29c9dadafe4f6e68ca0ee31fb30) -endif() - -if(NOT RAFT_FAISS_GIT_REPOSITORY) - # TODO: Remove this once faiss supports FAISS_USE_CUDA_TOOLKIT_STATIC - # (https://github.com/facebookresearch/faiss/pull/2446) - set(RAFT_FAISS_GIT_REPOSITORY https://github.com/cjnolet/faiss.git) - # set(RAFT_FAISS_GIT_REPOSITORY https://github.com/facebookresearch/faiss.git) -endif() - -find_and_configure_faiss(VERSION 1.7.4 - REPOSITORY ${RAFT_FAISS_GIT_REPOSITORY} - PINNED_TAG ${RAFT_FAISS_GIT_TAG} - BUILD_STATIC_LIBS ${RAFT_USE_FAISS_STATIC} - EXCLUDE_FROM_ALL ${RAFT_EXCLUDE_FAISS_FROM_ALL} - ENABLE_GPU ${RAFT_FAISS_ENABLE_GPU}) +find_and_configure_faiss( + BUILD_STATIC_LIBS ${RAFT_USE_FAISS_STATIC} + ENABLE_GPU ${RAFT_FAISS_ENABLE_GPU} +) diff --git a/cpp/cmake/thirdparty/get_ggnn.cmake b/cpp/cmake/thirdparty/get_ggnn.cmake index 8137ef84eb..d8af4971a7 100644 --- a/cpp/cmake/thirdparty/get_ggnn.cmake +++ b/cpp/cmake/thirdparty/get_ggnn.cmake @@ -15,29 +15,31 @@ #============================================================================= function(find_and_configure_ggnn) - set(oneValueArgs VERSION REPOSITORY PINNED_TAG) - cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" - "${multiValueArgs}" ${ARGN} ) + include(${rapids-cmake-dir}/cpm/package_override.cmake) + set(patch_dir "${CMAKE_CURRENT_FUNCTION_LIST_DIR}/../patches") + rapids_cpm_package_override("${patch_dir}/ggnn_override.json") - set(patch_files_to_run "${CMAKE_CURRENT_SOURCE_DIR}/cmake/patches/ggnn.diff") - set(patch_issues_to_ref "fix compile issues") - set(patch_script "${CMAKE_BINARY_DIR}/rapids-cmake/patches/ggnn/patch.cmake") - set(log_file "${CMAKE_BINARY_DIR}/rapids-cmake/patches/ggnn/log") - string(TIMESTAMP current_year "%Y" UTC) - configure_file(${rapids-cmake-dir}/cpm/patches/command_template.cmake.in "${patch_script}" - @ONLY) + include("${rapids-cmake-dir}/cpm/detail/package_details.cmake") + rapids_cpm_package_details(ggnn version repository tag shallow exclude) + + include("${rapids-cmake-dir}/cpm/detail/generate_patch_command.cmake") + rapids_cpm_generate_patch_command(ggnn ${version} patch_command) rapids_cpm_find( - ggnn ${PKG_VERSION} + ggnn ${version} GLOBAL_TARGETS ggnn::ggnn CPM_ARGS - GIT_REPOSITORY ${PKG_REPOSITORY} - GIT_TAG ${PKG_PINNED_TAG} - GIT_SHALLOW TRUE + GIT_REPOSITORY ${repository} + GIT_TAG ${tag} + GIT_SHALLOW ${shallow} ${patch_command} + EXCLUDE_FROM_ALL ${exclude} DOWNLOAD_ONLY ON - PATCH_COMMAND ${CMAKE_COMMAND} -P ${patch_script} ) + + include("${rapids-cmake-dir}/cpm/detail/display_patch_status.cmake") + rapids_cpm_display_patch_status(ggnn) + if(NOT TARGET ggnn::ggnn) add_library(ggnn INTERFACE) target_include_directories(ggnn INTERFACE "$") @@ -45,14 +47,4 @@ function(find_and_configure_ggnn) endif() endfunction() -if(NOT RAFT_GGNN_GIT_TAG) - set(RAFT_GGNN_GIT_TAG release_0.5) -endif() - -if(NOT RAFT_GGNN_GIT_REPOSITORY) - set(RAFT_GGNN_GIT_REPOSITORY https://github.com/cgtuebingen/ggnn.git) -endif() -find_and_configure_ggnn(VERSION 0.5 - REPOSITORY ${RAFT_GGNN_GIT_REPOSITORY} - PINNED_TAG ${RAFT_GGNN_GIT_TAG} - ) +find_and_configure_ggnn() diff --git 
a/cpp/cmake/thirdparty/get_hnswlib.cmake b/cpp/cmake/thirdparty/get_hnswlib.cmake index 4d28e9a064..6ef493336f 100644 --- a/cpp/cmake/thirdparty/get_hnswlib.cmake +++ b/cpp/cmake/thirdparty/get_hnswlib.cmake @@ -15,78 +15,74 @@ #============================================================================= function(find_and_configure_hnswlib) - set(oneValueArgs VERSION REPOSITORY PINNED_TAG EXCLUDE_FROM_ALL) - cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" - "${multiValueArgs}" ${ARGN} ) + set(oneValueArgs) - set(patch_files_to_run "${CMAKE_CURRENT_SOURCE_DIR}/cmake/patches/hnswlib.diff") - set(patch_issues_to_ref "fix compile issues") - set(patch_script "${CMAKE_BINARY_DIR}/rapids-cmake/patches/hnswlib/patch.cmake") - set(log_file "${CMAKE_BINARY_DIR}/rapids-cmake/patches/hnswlib/log") - string(TIMESTAMP current_year "%Y" UTC) - configure_file(${rapids-cmake-dir}/cpm/patches/command_template.cmake.in "${patch_script}" - @ONLY) + include(${rapids-cmake-dir}/cpm/package_override.cmake) + set(patch_dir "${CMAKE_CURRENT_FUNCTION_LIST_DIR}/../patches") + rapids_cpm_package_override("${patch_dir}/hnswlib_override.json") + + include("${rapids-cmake-dir}/cpm/detail/package_details.cmake") + rapids_cpm_package_details(hnswlib version repository tag shallow exclude) + + include("${rapids-cmake-dir}/cpm/detail/generate_patch_command.cmake") + rapids_cpm_generate_patch_command(hnswlib ${version} patch_command) rapids_cpm_find( - hnswlib ${PKG_VERSION} - GLOBAL_TARGETS hnswlib::hnswlib - BUILD_EXPORT_SET raft-exports - INSTALL_EXPORT_SET raft-exports + hnswlib ${version} + GLOBAL_TARGETS hnswlib hnswlib::hnswlib CPM_ARGS - GIT_REPOSITORY ${PKG_REPOSITORY} - GIT_TAG ${PKG_PINNED_TAG} - GIT_SHALLOW TRUE + GIT_REPOSITORY ${repository} + GIT_TAG ${tag} + GIT_SHALLOW ${shallow} ${patch_command} + EXCLUDE_FROM_ALL ${exclude} DOWNLOAD_ONLY ON - PATCH_COMMAND ${CMAKE_COMMAND} -P ${patch_script} ) + + include("${rapids-cmake-dir}/cpm/detail/display_patch_status.cmake") + rapids_cpm_display_patch_status(hnswlib) + if(NOT TARGET hnswlib::hnswlib) add_library(hnswlib INTERFACE ) add_library(hnswlib::hnswlib ALIAS hnswlib) target_include_directories(hnswlib INTERFACE "$" "$") + endif() - if(NOT PKG_EXCLUDE_FROM_ALL) - install(TARGETS hnswlib EXPORT hnswlib-exports) + if(hnswlib_ADDED) + # write build export rules + install(TARGETS hnswlib EXPORT hnswlib-exports) + if(NOT exclude) install(DIRECTORY "${hnswlib_SOURCE_DIR}/hnswlib/" DESTINATION include/hnswlib) # write install export rules rapids_export( INSTALL hnswlib - VERSION ${PKG_VERSION} + VERSION ${version} EXPORT_SET hnswlib-exports GLOBAL_TARGETS hnswlib NAMESPACE hnswlib::) endif() - # write build export rules rapids_export( BUILD hnswlib - VERSION ${PKG_VERSION} + VERSION ${version} EXPORT_SET hnswlib-exports GLOBAL_TARGETS hnswlib NAMESPACE hnswlib::) - include("${rapids-cmake-dir}/export/find_package_root.cmake") + include("${rapids-cmake-dir}/export/package.cmake") + rapids_export_package(INSTALL hnswlib raft-exports VERSION ${version} GLOBAL_TARGETS hnswlib hnswlib::hnswlib) + rapids_export_package(BUILD hnswlib raft-exports VERSION ${version} GLOBAL_TARGETS hnswlib hnswlib::hnswlib) + # When using RAFT from the build dir, ensure hnswlib is also found in RAFT's build dir. 
This # line adds `set(hnswlib_ROOT "${CMAKE_CURRENT_LIST_DIR}")` to build/raft-dependencies.cmake + include("${rapids-cmake-dir}/export/find_package_root.cmake") rapids_export_find_package_root( BUILD hnswlib [=[${CMAKE_CURRENT_LIST_DIR}]=] EXPORT_SET raft-exports ) endif() endfunction() - -if(NOT RAFT_HNSWLIB_GIT_TAG) - set(RAFT_HNSWLIB_GIT_TAG v0.6.2) -endif() - -if(NOT RAFT_HNSWLIB_GIT_REPOSITORY) - set(RAFT_HNSWLIB_GIT_REPOSITORY https://github.com/nmslib/hnswlib.git) -endif() -find_and_configure_hnswlib(VERSION 0.6.2 - REPOSITORY ${RAFT_HNSWLIB_GIT_REPOSITORY} - PINNED_TAG ${RAFT_HNSWLIB_GIT_TAG} - EXCLUDE_FROM_ALL OFF - ) +find_and_configure_hnswlib() diff --git a/cpp/include/raft/cluster/detail/kmeans_balanced.cuh b/cpp/include/raft/cluster/detail/kmeans_balanced.cuh index 6d3f430e88..0a5a3ba5aa 100644 --- a/cpp/include/raft/cluster/detail/kmeans_balanced.cuh +++ b/cpp/include/raft/cluster/detail/kmeans_balanced.cuh @@ -43,15 +43,14 @@ #include #include -#include -#include #include -#include +#include #include #include #include +#include #include #include @@ -91,7 +90,7 @@ inline std::enable_if_t> predict_core( const MathT* dataset_norm, IdxT n_rows, LabelT* labels, - rmm::mr::device_memory_resource* mr) + rmm::device_async_resource_ref mr) { auto stream = resource::get_cuda_stream(handle); switch (params.metric) { @@ -263,10 +262,9 @@ void calc_centers_and_sizes(const raft::resources& handle, const LabelT* labels, bool reset_counters, MappingOpT mapping_op, - rmm::mr::device_memory_resource* mr = nullptr) + rmm::device_async_resource_ref mr) { auto stream = resource::get_cuda_stream(handle); - if (mr == nullptr) { mr = resource::get_workspace_resource(handle); } if (!reset_counters) { raft::linalg::matrixVectorOp( @@ -322,12 +320,12 @@ void compute_norm(const raft::resources& handle, IdxT dim, IdxT n_rows, MappingOpT mapping_op, - rmm::mr::device_memory_resource* mr = nullptr) + std::optional mr = std::nullopt) { common::nvtx::range fun_scope("compute_norm"); auto stream = resource::get_cuda_stream(handle); - if (mr == nullptr) { mr = resource::get_workspace_resource(handle); } - rmm::device_uvector mapped_dataset(0, stream, mr); + rmm::device_uvector mapped_dataset( + 0, stream, mr.value_or(resource::get_workspace_resource(handle))); const MathT* dataset_ptr = nullptr; @@ -338,7 +336,7 @@ void compute_norm(const raft::resources& handle, linalg::unaryOp(mapped_dataset.data(), dataset, n_rows * dim, mapping_op, stream); - dataset_ptr = (const MathT*)mapped_dataset.data(); + dataset_ptr = static_cast(mapped_dataset.data()); } raft::linalg::rowNorm( @@ -376,22 +374,22 @@ void predict(const raft::resources& handle, IdxT n_rows, LabelT* labels, MappingOpT mapping_op, - rmm::mr::device_memory_resource* mr = nullptr, - const MathT* dataset_norm = nullptr) + std::optional mr = std::nullopt, + const MathT* dataset_norm = nullptr) { auto stream = resource::get_cuda_stream(handle); common::nvtx::range fun_scope( "predict(%zu, %u)", static_cast(n_rows), n_clusters); - if (mr == nullptr) { mr = resource::get_workspace_resource(handle); } + auto mem_res = mr.value_or(resource::get_workspace_resource(handle)); auto [max_minibatch_size, _mem_per_row] = calc_minibatch_size(n_clusters, n_rows, dim, params.metric, std::is_same_v); rmm::device_uvector cur_dataset( - std::is_same_v ? 0 : max_minibatch_size * dim, stream, mr); + std::is_same_v ? 
0 : max_minibatch_size * dim, stream, mem_res); bool need_compute_norm = dataset_norm == nullptr && (params.metric == raft::distance::DistanceType::L2Expanded || params.metric == raft::distance::DistanceType::L2SqrtExpanded); rmm::device_uvector cur_dataset_norm( - need_compute_norm ? max_minibatch_size : 0, stream, mr); + need_compute_norm ? max_minibatch_size : 0, stream, mem_res); const MathT* dataset_norm_ptr = nullptr; auto cur_dataset_ptr = cur_dataset.data(); for (IdxT offset = 0; offset < n_rows; offset += max_minibatch_size) { @@ -407,7 +405,7 @@ void predict(const raft::resources& handle, // Compute the norm now if it hasn't been pre-computed. if (need_compute_norm) { compute_norm( - handle, cur_dataset_norm.data(), cur_dataset_ptr, dim, minibatch_size, mapping_op, mr); + handle, cur_dataset_norm.data(), cur_dataset_ptr, dim, minibatch_size, mapping_op, mem_res); dataset_norm_ptr = cur_dataset_norm.data(); } else if (dataset_norm != nullptr) { dataset_norm_ptr = dataset_norm + offset; @@ -422,7 +420,7 @@ void predict(const raft::resources& handle, dataset_norm_ptr, minibatch_size, labels + offset, - mr); + mem_res); } } @@ -530,7 +528,7 @@ auto adjust_centers(MathT* centers, MathT threshold, MappingOpT mapping_op, rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* device_memory) -> bool + rmm::device_async_resource_ref device_memory) -> bool { common::nvtx::range fun_scope( "adjust_centers(%zu, %u)", static_cast(n_rows), n_clusters); @@ -628,7 +626,7 @@ void balancing_em_iters(const raft::resources& handle, uint32_t balancing_pullback, MathT balancing_threshold, MappingOpT mapping_op, - rmm::mr::device_memory_resource* device_memory) + rmm::device_async_resource_ref device_memory) { auto stream = resource::get_cuda_stream(handle); uint32_t balancing_counter = balancing_pullback; @@ -711,7 +709,7 @@ void build_clusters(const raft::resources& handle, LabelT* cluster_labels, CounterT* cluster_sizes, MappingOpT mapping_op, - rmm::mr::device_memory_resource* device_memory, + rmm::device_async_resource_ref device_memory, const MathT* dataset_norm = nullptr) { auto stream = resource::get_cuda_stream(handle); @@ -853,8 +851,8 @@ auto build_fine_clusters(const raft::resources& handle, IdxT fine_clusters_nums_max, MathT* cluster_centers, MappingOpT mapping_op, - rmm::mr::device_memory_resource* managed_memory, - rmm::mr::device_memory_resource* device_memory) -> IdxT + rmm::device_async_resource_ref managed_memory, + rmm::device_async_resource_ref device_memory) -> IdxT { auto stream = resource::get_cuda_stream(handle); rmm::device_uvector mc_trainset_ids_buf(mesocluster_size_max, stream, managed_memory); @@ -971,7 +969,7 @@ void build_hierarchical(const raft::resources& handle, // TODO: Remove the explicit managed memory- we shouldn't be creating this on the user's behalf. 
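The hunks above migrate these kmeans_balanced helpers from raw rmm::mr::device_memory_resource* arguments to rmm::device_async_resource_ref, using std::optional plus value_or() to fall back to the handle's workspace resource when the caller passes nothing. A minimal sketch of that pattern, assuming only the public raft::resource and rmm headers; the alloc_scratch helper is illustrative and not part of RAFT:

#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/device_memory_resource.hpp>
#include <raft/core/resources.hpp>
#include <rmm/device_uvector.hpp>
#include <rmm/resource_ref.hpp>

#include <cstddef>
#include <optional>

// Callers may pass an explicit memory resource; otherwise the handle's
// workspace resource is used, mirroring predict()/compute_norm() above.
inline rmm::device_uvector<float> alloc_scratch(
  const raft::resources& handle,
  std::size_t n,
  std::optional<rmm::device_async_resource_ref> mr = std::nullopt)
{
  auto stream  = raft::resource::get_cuda_stream(handle);
  auto mem_res = mr.value_or(raft::resource::get_workspace_resource(handle));
  return rmm::device_uvector<float>(n, stream, mem_res);
}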
rmm::mr::managed_memory_resource managed_memory; - rmm::mr::device_memory_resource* device_memory = resource::get_workspace_resource(handle); + rmm::device_async_resource_ref device_memory = resource::get_workspace_resource(handle); auto [max_minibatch_size, mem_per_row] = calc_minibatch_size(n_clusters, n_rows, dim, params.metric, std::is_same_v); diff --git a/cpp/include/raft/cluster/kmeans_balanced.cuh b/cpp/include/raft/cluster/kmeans_balanced.cuh index 8cd7730814..a1a182608b 100644 --- a/cpp/include/raft/cluster/kmeans_balanced.cuh +++ b/cpp/include/raft/cluster/kmeans_balanced.cuh @@ -358,7 +358,8 @@ void calc_centers_and_sizes(const raft::resources& handle, X.extent(0), labels.data_handle(), reset_counters, - mapping_op); + mapping_op, + resource::get_workspace_resource(handle)); } } // namespace helpers diff --git a/cpp/include/raft/cluster/specializations.cuh b/cpp/include/raft/cluster/specializations.cuh index 9588a7f329..e85b05575f 100644 --- a/cpp/include/raft/cluster/specializations.cuh +++ b/cpp/include/raft/cluster/specializations.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,8 +15,10 @@ */ #pragma once +#ifndef RAFT_HIDE_DEPRECATION_WARNINGS #pragma message( \ __FILE__ \ " is deprecated and will be removed." \ " Including specializations is not necessary any more." \ " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") +#endif diff --git a/cpp/include/raft/common/cub_wrappers.cuh b/cpp/include/raft/common/cub_wrappers.cuh index dd8fc2d103..239d6e08f6 100644 --- a/cpp/include/raft/common/cub_wrappers.cuh +++ b/cpp/include/raft/common/cub_wrappers.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,9 +24,11 @@ #pragma once +#ifndef RAFT_HIDE_DEPRECATION_WARNINGS #pragma message(__FILE__ \ " is deprecated and will be removed in a future release." \ " Please note that there is no equivalent in RAFT's public API" " so this file will eventually be removed altogether.") +#endif #include diff --git a/cpp/include/raft/common/device_loads_stores.cuh b/cpp/include/raft/common/device_loads_stores.cuh index 6c62cd70cc..53724f4ae1 100644 --- a/cpp/include/raft/common/device_loads_stores.cuh +++ b/cpp/include/raft/common/device_loads_stores.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,8 +24,10 @@ #pragma once +#ifndef RAFT_HIDE_DEPRECATION_WARNINGS #pragma message(__FILE__ \ " is deprecated and will be removed in a future release." \ " Please use the raft/util version instead.") +#endif #include diff --git a/cpp/include/raft/common/scatter.cuh b/cpp/include/raft/common/scatter.cuh index 72de79a596..dcbd46b236 100644 --- a/cpp/include/raft/common/scatter.cuh +++ b/cpp/include/raft/common/scatter.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,8 +24,10 @@ #pragma once +#ifndef RAFT_HIDE_DEPRECATION_WARNINGS #pragma message(__FILE__ \ " is deprecated and will be removed in a future release." \ " Please use the raft/matrix version instead.") +#endif #include diff --git a/cpp/include/raft/common/seive.hpp b/cpp/include/raft/common/seive.hpp index 433b032b0f..56b41a41f4 100644 --- a/cpp/include/raft/common/seive.hpp +++ b/cpp/include/raft/common/seive.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,8 +24,10 @@ #pragma once +#ifndef RAFT_HIDE_DEPRECATION_WARNINGS #pragma message(__FILE__ \ " is deprecated and will be removed in a future release." \ " Please use the raft/util version instead.") +#endif #include diff --git a/cpp/include/raft/comms/detail/std_comms.hpp b/cpp/include/raft/comms/detail/std_comms.hpp index 6e7ff7106f..cb1accc95e 100644 --- a/cpp/include/raft/comms/detail/std_comms.hpp +++ b/cpp/include/raft/comms/detail/std_comms.hpp @@ -34,6 +34,7 @@ #include #include #include +#include #include #include @@ -49,6 +50,17 @@ namespace raft { namespace comms { namespace detail { +using ucp_endpoint_array_t = std::shared_ptr; +using ucxx_endpoint_array_t = std::shared_ptr; +using ucp_worker_t = ucp_worker_h; +using ucxx_worker_t = ucxx::Worker*; + +struct ucx_objects_t { + public: + std::variant endpoints; + std::variant worker; +}; + class std_comms : public comms_iface { public: std_comms() = delete; @@ -64,8 +76,7 @@ class std_comms : public comms_iface { * @param subcomms_ucp use ucp for subcommunicators */ std_comms(ncclComm_t nccl_comm, - ucp_worker_h ucp_worker, - std::shared_ptr eps, + ucx_objects_t ucx_objects, int num_ranks, int rank, rmm::cuda_stream_view stream, @@ -76,9 +87,8 @@ class std_comms : public comms_iface { num_ranks_(num_ranks), rank_(rank), subcomms_ucp_(subcomms_ucp), + ucx_objects_(ucx_objects), own_nccl_comm_(false), - ucp_worker_(ucp_worker), - ucp_eps_(eps), next_request_id_(0) { initialize(); @@ -205,96 +215,209 @@ class std_comms : public comms_iface { void isend(const void* buf, size_t size, int dest, int tag, request_t* request) const { - ASSERT(ucp_worker_ != nullptr, "ERROR: UCX comms not initialized on communicator."); + if (std::holds_alternative(ucx_objects_.worker)) { + get_request_id(request); - get_request_id(request); - ucp_ep_h ep_ptr = (*ucp_eps_)[dest]; + ucxx::Endpoint* ep_ptr = (*std::get(ucx_objects_.endpoints))[dest]; - ucp_request* ucp_req = (ucp_request*)malloc(sizeof(ucp_request)); + ucp_tag_t ucp_tag = build_message_tag(get_rank(), tag); + auto ucxx_req = ep_ptr->tagSend(const_cast(buf), size, ucxx::Tag(ucp_tag)); - this->ucp_handler_.ucp_isend(ucp_req, ep_ptr, buf, size, tag, default_tag_mask, get_rank()); + requests_in_flight_.insert(std::make_pair(*request, ucxx_req)); + } else { + ASSERT(std::get(ucx_objects_.worker) != nullptr, + "ERROR: UCX comms not initialized on communicator."); - requests_in_flight_.insert(std::make_pair(*request, ucp_req)); - } + get_request_id(request); + ucp_ep_h ep_ptr = (*std::get(ucx_objects_.endpoints))[dest]; - void irecv(void* buf, size_t size, int source, int tag, request_t* request) const - { - ASSERT(ucp_worker_ != nullptr, "ERROR: UCX comms not initialized on communicator."); + 
ucp_request* ucp_req = (ucp_request*)malloc(sizeof(ucp_request)); - get_request_id(request); + this->ucp_handler_.ucp_isend(ucp_req, ep_ptr, buf, size, tag, default_tag_mask, get_rank()); - ucp_ep_h ep_ptr = (*ucp_eps_)[source]; - - ucp_tag_t tag_mask = default_tag_mask; - - ucp_request* ucp_req = (ucp_request*)malloc(sizeof(ucp_request)); - ucp_handler_.ucp_irecv(ucp_req, ucp_worker_, ep_ptr, buf, size, tag, tag_mask, source); - - requests_in_flight_.insert(std::make_pair(*request, ucp_req)); + requests_in_flight_.insert(std::make_pair(*request, ucp_req)); + } } - void waitall(int count, request_t array_of_requests[]) const + void irecv(void* buf, size_t size, int source, int tag, request_t* request) const { - ASSERT(ucp_worker_ != nullptr, "ERROR: UCX comms not initialized on communicator."); + if (std::holds_alternative(ucx_objects_.worker)) { + get_request_id(request); - std::vector requests; - requests.reserve(count); + ucxx::Endpoint* ep_ptr = (*std::get(ucx_objects_.endpoints))[source]; - time_t start = time(NULL); + ucp_tag_t ucp_tag = build_message_tag(get_rank(), tag); + auto ucxx_req = + ep_ptr->tagRecv(buf, size, ucxx::Tag(ucp_tag), ucxx::TagMask(default_tag_mask)); - for (int i = 0; i < count; ++i) { - auto req_it = requests_in_flight_.find(array_of_requests[i]); - ASSERT(requests_in_flight_.end() != req_it, - "ERROR: waitall on invalid request: %d", - array_of_requests[i]); - requests.push_back(req_it->second); - free_requests_.insert(req_it->first); - requests_in_flight_.erase(req_it); - } - - while (requests.size() > 0) { - time_t now = time(NULL); + requests_in_flight_.insert(std::make_pair(*request, ucxx_req)); + } else { + ASSERT(std::get(ucx_objects_.worker) != nullptr, + "ERROR: UCX comms not initialized on communicator."); - // Timeout if we have not gotten progress or completed any requests - // in 10 or more seconds. - ASSERT(now - start < 10, "Timed out waiting for requests."); + get_request_id(request); - for (std::vector::iterator it = requests.begin(); it != requests.end();) { - bool restart = false; // resets the timeout when any progress was made + ucp_ep_h ep_ptr = (*std::get(ucx_objects_.endpoints))[source]; - // Causes UCP to progress through the send/recv message queue - while (ucp_worker_progress(ucp_worker_) != 0) { - restart = true; - } + ucp_tag_t tag_mask = default_tag_mask; - auto req = *it; + ucp_request* ucp_req = (ucp_request*)malloc(sizeof(ucp_request)); + ucp_handler_.ucp_irecv(ucp_req, + std::get(ucx_objects_.worker), + ep_ptr, + buf, + size, + tag, + tag_mask, + source); - // If the message needs release, we know it will be sent/received - // asynchronously, so we will need to track and verify its state - if (req->needs_release) { - ASSERT(UCS_PTR_IS_PTR(req->req), "UCX Request Error. Request is not valid UCX pointer"); - ASSERT(!UCS_PTR_IS_ERR(req->req), "UCX Request Error: %d\n", UCS_PTR_STATUS(req->req)); - ASSERT(req->req->completed == 1 || req->req->completed == 0, - "request->completed not a valid value: %d\n", - req->req->completed); - } + requests_in_flight_.insert(std::make_pair(*request, ucp_req)); + } + } - // If a message was sent synchronously (eg. completed before - // `isend`/`irecv` completed) or an asynchronous message - // is complete, we can go ahead and clean it up. 
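Every UCX call site rewritten here follows the same dispatch shape: ucx_objects_t keeps either a UCXX worker or a raw UCP worker in a std::variant, and isend/irecv/waitall branch with std::holds_alternative before std::get-ing the active alternative. A self-contained sketch of that pattern, using stand-in types rather than the real ucxx::Worker and ucp_worker_h handles:

#include <iostream>
#include <memory>
#include <variant>

// Stand-ins for ucxx::Worker* and ucp_worker_h, for illustration only.
struct ucxx_worker_stub { void progress() { std::cout << "ucxx progress\n"; } };
struct ucp_worker_stub  { void progress() { std::cout << "ucp progress\n"; } };

struct ucx_objects_stub {
  std::variant<std::shared_ptr<ucxx_worker_stub>, std::shared_ptr<ucp_worker_stub>> worker;
};

// Branch on the active alternative, as std_comms::isend/irecv/waitall do.
void progress_once(ucx_objects_stub& ucx)
{
  if (std::holds_alternative<std::shared_ptr<ucxx_worker_stub>>(ucx.worker)) {
    std::get<std::shared_ptr<ucxx_worker_stub>>(ucx.worker)->progress();
  } else {
    std::get<std::shared_ptr<ucp_worker_stub>>(ucx.worker)->progress();
  }
}

int main()
{
  ucx_objects_stub ucx{std::make_shared<ucxx_worker_stub>()};
  progress_once(ucx);  // prints "ucxx progress"
}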
- if (!req->needs_release || req->req->completed == 1) { - restart = true; + void waitall(int count, request_t array_of_requests[]) const + { + if (std::holds_alternative(ucx_objects_.worker)) { + ucxx_worker_t worker = std::get(ucx_objects_.worker); + + std::vector> requests; + requests.reserve(count); + + time_t start = time(NULL); + + for (int i = 0; i < count; ++i) { + auto req_it = requests_in_flight_.find(array_of_requests[i]); + ASSERT(requests_in_flight_.end() != req_it, + "ERROR: waitall on invalid request: %d", + array_of_requests[i]); + requests.push_back(std::get>(req_it->second)); + free_requests_.insert(req_it->first); + requests_in_flight_.erase(req_it); + } - // perform cleanup - ucp_handler_.free_ucp_request(req); + while (requests.size() > 0) { + time_t now = time(NULL); + + // Timeout if we have not gotten progress or completed any requests + // in 10 or more seconds. + ASSERT(now - start < 10, "Timed out waiting for requests."); + + for (std::vector>::iterator it = requests.begin(); + it != requests.end();) { + bool restart = false; // resets the timeout when any progress was made + + if (worker->isProgressThreadRunning()) { + // Wait for a UCXX progress thread roundtrip + ucxx::utils::CallbackNotifier callbackNotifierPre{}; + worker->registerGenericPre([&callbackNotifierPre]() { callbackNotifierPre.set(); }); + callbackNotifierPre.wait(); + + ucxx::utils::CallbackNotifier callbackNotifierPost{}; + worker->registerGenericPost([&callbackNotifierPost]() { callbackNotifierPost.set(); }); + callbackNotifierPost.wait(); + } else { + // Causes UCXX to progress through the send/recv message queue + while (!worker->progress()) { + restart = true; + } + } + + auto req = *it; + + // If the message needs release, we know it will be sent/received + // asynchronously, so we will need to track and verify its state + if (req->isCompleted()) { + auto status = req->getStatus(); + ASSERT(req->getStatus() == UCS_OK, + "UCX Request Error: %d (%s)\n", + status, + ucs_status_string(status)); + } + + // If a message was sent synchronously (eg. completed before + // `isend`/`irecv` completed) or an asynchronous message + // is complete, we can go ahead and clean it up. + if (req->isCompleted()) { + restart = true; + + auto status = req->getStatus(); + ASSERT(req->getStatus() == UCS_OK, + "UCX Request Error: %d (%s)\n", + status, + ucs_status_string(status)); + + // remove from pending requests + it = requests.erase(it); + } else { + ++it; + } + // if any progress was made, reset the timeout start time + if (restart) { start = time(NULL); } + } + } + } else { + ucp_worker_t worker = std::get(ucx_objects_.worker); + ASSERT(worker != nullptr, "ERROR: UCX comms not initialized on communicator."); + + std::vector requests; + requests.reserve(count); + + time_t start = time(NULL); + + for (int i = 0; i < count; ++i) { + auto req_it = requests_in_flight_.find(array_of_requests[i]); + ASSERT(requests_in_flight_.end() != req_it, + "ERROR: waitall on invalid request: %d", + array_of_requests[i]); + requests.push_back(std::get(req_it->second)); + free_requests_.insert(req_it->first); + requests_in_flight_.erase(req_it); + } - // remove from pending requests - it = requests.erase(it); - } else { - ++it; + while (requests.size() > 0) { + time_t now = time(NULL); + + // Timeout if we have not gotten progress or completed any requests + // in 10 or more seconds. 
+ ASSERT(now - start < 10, "Timed out waiting for requests."); + + for (std::vector::iterator it = requests.begin(); it != requests.end();) { + bool restart = false; // resets the timeout when any progress was made + + // Causes UCP to progress through the send/recv message queue + while (ucp_worker_progress(worker) != 0) { + restart = true; + } + + auto req = *it; + + // If the message needs release, we know it will be sent/received + // asynchronously, so we will need to track and verify its state + if (req->needs_release) { + ASSERT(UCS_PTR_IS_PTR(req->req), "UCX Request Error. Request is not valid UCX pointer"); + ASSERT(!UCS_PTR_IS_ERR(req->req), "UCX Request Error: %d\n", UCS_PTR_STATUS(req->req)); + ASSERT(req->req->completed == 1 || req->req->completed == 0, + "request->completed not a valid value: %d\n", + req->req->completed); + } + + // If a message was sent synchronously (eg. completed before + // `isend`/`irecv` completed) or an asynchronous message + // is complete, we can go ahead and clean it up. + if (!req->needs_release || req->req->completed == 1) { + restart = true; + + // perform cleanup + ucp_handler_.free_ucp_request(req); + + // remove from pending requests + it = requests.erase(it); + } else { + ++it; + } + // if any progress was made, reset the timeout start time + if (restart) { start = time(NULL); } } - // if any progress was made, reset the timeout start time - if (restart) { start = time(NULL); } } } } @@ -524,10 +647,11 @@ class std_comms : public comms_iface { bool own_nccl_comm_; comms_ucp_handler ucp_handler_; - ucp_worker_h ucp_worker_; - std::shared_ptr ucp_eps_; + ucx_objects_t ucx_objects_; mutable request_t next_request_id_; - mutable std::unordered_map requests_in_flight_; + mutable std::unordered_map>> + requests_in_flight_; mutable std::unordered_set free_requests_; }; } // namespace detail diff --git a/cpp/include/raft/comms/detail/ucp_helper.hpp b/cpp/include/raft/comms/detail/ucp_helper.hpp index 5896248c1d..65e1957e54 100644 --- a/cpp/include/raft/comms/detail/ucp_helper.hpp +++ b/cpp/include/raft/comms/detail/ucp_helper.hpp @@ -46,9 +46,7 @@ struct ucx_context { class ucp_request { public: struct ucx_context* req; - bool needs_release = true; - int other_rank = -1; - bool is_send_request = false; + bool needs_release = true; }; // by default, match the whole tag @@ -72,17 +70,16 @@ static void recv_callback(void* request, ucs_status_t status, ucp_tag_recv_info_ context->completed = 1; } +ucp_tag_t build_message_tag(int rank, int tag) +{ + // keeping the rank in the lower bits enables debugging. + return ((uint32_t)tag << 31) | (uint32_t)rank; +} + /** * Helper class for interacting with ucp. */ class comms_ucp_handler { - private: - ucp_tag_t build_message_tag(int rank, int tag) const - { - // keeping the rank in the lower bits enables debugging. 
- return ((uint32_t)tag << 31) | (uint32_t)rank; - } - public: /** * @brief Frees any memory underlying the given ucp request object @@ -132,9 +129,7 @@ class comms_ucp_handler { req->needs_release = false; } - req->other_rank = rank; - req->is_send_request = true; - req->req = ucp_req; + req->req = ucp_req; } /** @@ -156,10 +151,8 @@ class comms_ucp_handler { struct ucx_context* ucp_req = (struct ucx_context*)recv_result; - req->req = ucp_req; - req->needs_release = true; - req->is_send_request = false; - req->other_rank = sender_rank; + req->req = ucp_req; + req->needs_release = true; ASSERT(!UCS_PTR_IS_ERR(recv_result), "unable to receive UCX data message (%d)\n", diff --git a/cpp/include/raft/comms/std_comms.hpp b/cpp/include/raft/comms/std_comms.hpp index c81b19c9ba..667c8be285 100644 --- a/cpp/include/raft/comms/std_comms.hpp +++ b/cpp/include/raft/comms/std_comms.hpp @@ -24,6 +24,7 @@ #include #include +#include #include @@ -81,6 +82,8 @@ void build_comms_nccl_only(resources* handle, ncclComm_t nccl_comm, int num_rank * * @param handle raft::resources for injecting the comms * @param nccl_comm initialized NCCL communicator to use for collectives + * @param is_ucxx whether `ucp_worker` and `eps` objects are UCXX (true) or + * pure UCX (false). * @param ucp_worker of local process * Note: This is purposefully left as void* so that the ucp_worker_h * doesn't need to be exposed through the cython layer @@ -112,30 +115,55 @@ void build_comms_nccl_only(resources* handle, ncclComm_t nccl_comm, int num_rank * comm.sync_stream(resource::get_cuda_stream(handle)); * @endcode */ -void build_comms_nccl_ucx( - resources* handle, ncclComm_t nccl_comm, void* ucp_worker, void* eps, int num_ranks, int rank) +void build_comms_nccl_ucx(resources* handle, + ncclComm_t nccl_comm, + bool is_ucxx, + void* ucp_worker, + void* eps, + int num_ranks, + int rank) { - auto eps_sp = std::make_shared(new ucp_ep_h[num_ranks]); + detail::ucx_objects_t ucx_objects; + if (is_ucxx) { + ucx_objects.endpoints = std::make_shared(new ucxx::Endpoint*[num_ranks]); + ucx_objects.worker = static_cast(ucp_worker); + } else { + ucx_objects.endpoints = std::make_shared(new ucp_ep_h[num_ranks]); + ucx_objects.worker = static_cast(ucp_worker); + } auto size_t_ep_arr = reinterpret_cast(eps); for (int i = 0; i < num_ranks; i++) { - size_t ptr = size_t_ep_arr[i]; - auto ucp_ep_v = reinterpret_cast(*eps_sp); - - if (ptr != 0) { - auto eps_ptr = reinterpret_cast(size_t_ep_arr[i]); - ucp_ep_v[i] = eps_ptr; + size_t ptr = size_t_ep_arr[i]; + + if (is_ucxx) { + auto ucp_ep_v = reinterpret_cast( + *std::get(ucx_objects.endpoints)); + + if (ptr != 0) { + auto eps_ptr = reinterpret_cast(size_t_ep_arr[i]); + ucp_ep_v[i] = eps_ptr; + } else { + ucp_ep_v[i] = nullptr; + } } else { - ucp_ep_v[i] = nullptr; + auto ucp_ep_v = + reinterpret_cast(*std::get(ucx_objects.endpoints)); + + if (ptr != 0) { + auto eps_ptr = reinterpret_cast(size_t_ep_arr[i]); + ucp_ep_v[i] = eps_ptr; + } else { + ucp_ep_v[i] = nullptr; + } } } cudaStream_t stream = resource::get_cuda_stream(*handle); - auto communicator = - std::make_shared(std::unique_ptr(new raft::comms::std_comms( - nccl_comm, (ucp_worker_h)ucp_worker, eps_sp, num_ranks, rank, stream))); + auto communicator = std::make_shared(std::unique_ptr( + new raft::comms::std_comms(nccl_comm, ucx_objects, num_ranks, rank, stream))); resource::set_comms(*handle, communicator); } diff --git a/cpp/include/raft/core/bitmap.cuh b/cpp/include/raft/core/bitmap.cuh index 829c84ed25..2c23a77e47 100644 --- 
a/cpp/include/raft/core/bitmap.cuh +++ b/cpp/include/raft/core/bitmap.cuh @@ -16,112 +16,30 @@ #pragma once +#include #include #include #include #include #include -namespace raft::core { -/** - * @defgroup bitmap Bitmap - * @{ - */ -/** - * @brief View of a RAFT Bitmap. - * - * This lightweight structure which represents and manipulates a two-dimensional bitmap matrix view - * with row major order. This class provides functionality for handling a matrix where each element - * is represented as a bit in a bitmap. - * - * @tparam bitmap_t Underlying type of the bitmap array. Default is uint32_t. - * @tparam index_t Indexing type used. Default is uint32_t. - */ -template -struct bitmap_view : public bitset_view { - static_assert((std::is_same::value || - std::is_same::value), - "The bitmap_t must be uint32_t or uint64_t."); - /** - * @brief Create a bitmap view from a device raw pointer. - * - * @param bitmap_ptr Device raw pointer - * @param rows Number of row in the matrix. - * @param cols Number of col in the matrix. - */ - _RAFT_HOST_DEVICE bitmap_view(bitmap_t* bitmap_ptr, index_t rows, index_t cols) - : bitset_view(bitmap_ptr, rows * cols), rows_(rows), cols_(cols) - { - } - - /** - * @brief Create a bitmap view from a device vector view of the bitset. - * - * @param bitmap_span Device vector view of the bitmap - * @param rows Number of row in the matrix. - * @param cols Number of col in the matrix. - */ - _RAFT_HOST_DEVICE bitmap_view(raft::device_vector_view bitmap_span, - index_t rows, - index_t cols) - : bitset_view(bitmap_span, rows * cols), rows_(rows), cols_(cols) - { - } +#include - private: - // Hide the constructors of bitset_view. - _RAFT_HOST_DEVICE bitmap_view(bitmap_t* bitmap_ptr, index_t bitmap_len) - : bitset_view(bitmap_ptr, bitmap_len) - { - } - - _RAFT_HOST_DEVICE bitmap_view(raft::device_vector_view bitmap_span, - index_t bitmap_len) - : bitset_view(bitmap_span, bitmap_len) - { - } - - public: - /** - * @brief Device function to test if a given row and col are set in the bitmap. - * - * @param row Row index of the bit to test - * @param col Col index of the bit to test - * @return bool True if index has not been unset in the bitset - */ - inline _RAFT_DEVICE auto test(const index_t row, const index_t col) const -> bool - { - return test(row * cols_ + col); - } - - /** - * @brief Device function to set a given row and col to set_value in the bitset. 
- * - * @param row Row index of the bit to set - * @param col Col index of the bit to set - * @param new_value Value to set the bit to (true or false) - */ - inline _RAFT_DEVICE void set(const index_t row, const index_t col, bool new_value) const - { - set(row * cols_ + col, &new_value); - } - - /** - * @brief Get the total number of rows - * @return index_t The total number of rows - */ - inline _RAFT_HOST_DEVICE index_t get_n_rows() const { return rows_; } - - /** - * @brief Get the total number of columns - * @return index_t The total number of columns - */ - inline _RAFT_HOST_DEVICE index_t get_n_cols() const { return cols_; } +namespace raft::core { - private: - index_t rows_; - index_t cols_; -}; +template +_RAFT_HOST_DEVICE inline bool bitmap_view::test(const index_t row, + const index_t col) const +{ + return test(row * cols_ + col); +} + +template +_RAFT_HOST_DEVICE void bitmap_view::set(const index_t row, + const index_t col, + bool new_value) const +{ + set(row * cols_ + col, &new_value); +} -/** @} */ } // end namespace raft::core diff --git a/cpp/include/raft/core/bitmap.hpp b/cpp/include/raft/core/bitmap.hpp new file mode 100644 index 0000000000..5c77866164 --- /dev/null +++ b/cpp/include/raft/core/bitmap.hpp @@ -0,0 +1,123 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace raft::core { +/** + * @defgroup bitmap Bitmap + * @{ + */ +/** + * @brief View of a RAFT Bitmap. + * + * This lightweight structure which represents and manipulates a two-dimensional bitmap matrix view + * with row major order. This class provides functionality for handling a matrix where each element + * is represented as a bit in a bitmap. + * + * @tparam bitmap_t Underlying type of the bitmap array. Default is uint32_t. + * @tparam index_t Indexing type used. Default is uint32_t. + */ +template +struct bitmap_view : public bitset_view { + static_assert((std::is_same::type, uint32_t>::value || + std::is_same::type, uint64_t>::value), + "The bitmap_t must be uint32_t or uint64_t."); + /** + * @brief Create a bitmap view from a device raw pointer. + * + * @param bitmap_ptr Device raw pointer + * @param rows Number of row in the matrix. + * @param cols Number of col in the matrix. + */ + _RAFT_HOST_DEVICE bitmap_view(bitmap_t* bitmap_ptr, index_t rows, index_t cols) + : bitset_view(bitmap_ptr, rows * cols), rows_(rows), cols_(cols) + { + } + + /** + * @brief Create a bitmap view from a device vector view of the bitset. + * + * @param bitmap_span Device vector view of the bitmap + * @param rows Number of row in the matrix. + * @param cols Number of col in the matrix. + */ + _RAFT_HOST_DEVICE bitmap_view(raft::device_vector_view bitmap_span, + index_t rows, + index_t cols) + : bitset_view(bitmap_span, rows * cols), rows_(rows), cols_(cols) + { + } + + private: + // Hide the constructors of bitset_view. 
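bitmap_view is a thin row-major wrapper over bitset_view: test() and set() linearize (row, col) into row * cols_ + col and defer to the underlying bitset. A minimal usage sketch, assuming a CUDA translation unit; the mark_cell kernel and example() wrapper are illustrative and not part of this header:

#include <raft/core/bitmap.cuh>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>
#include <rmm/device_uvector.hpp>

#include <cuda_runtime.h>

#include <cstdint>

// Set a single (row, col) bit of a row-major bitmap from device code.
__global__ void mark_cell(raft::core::bitmap_view<uint32_t, uint32_t> bm,
                          uint32_t row,
                          uint32_t col)
{
  if (threadIdx.x == 0 && blockIdx.x == 0) { bm.set(row, col, true); }
}

void example(const raft::resources& handle, uint32_t n_rows, uint32_t n_cols)
{
  auto stream = raft::resource::get_cuda_stream(handle);
  // One bit per matrix element, packed into 32-bit words.
  uint32_t n_words = (n_rows * n_cols + 31) / 32;
  rmm::device_uvector<uint32_t> words(n_words, stream);
  cudaMemsetAsync(words.data(), 0, n_words * sizeof(uint32_t), stream);  // all bits unset

  raft::core::bitmap_view<uint32_t, uint32_t> bm(words.data(), n_rows, n_cols);
  mark_cell<<<1, 1, 0, stream>>>(bm, 2, 5);  // bm.test(2, 5) is now true on device
}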
+ _RAFT_HOST_DEVICE bitmap_view(bitmap_t* bitmap_ptr, index_t bitmap_len) + : bitset_view(bitmap_ptr, bitmap_len) + { + } + + _RAFT_HOST_DEVICE bitmap_view(raft::device_vector_view bitmap_span, + index_t bitmap_len) + : bitset_view(bitmap_span, bitmap_len) + { + } + + public: + /** + * @brief Device function to test if a given row and col are set in the bitmap. + * + * @param row Row index of the bit to test + * @param col Col index of the bit to test + * @return bool True if index has not been unset in the bitset + */ + inline _RAFT_HOST_DEVICE bool test(const index_t row, const index_t col) const; + + /** + * @brief Device function to set a given row and col to set_value in the bitset. + * + * @param row Row index of the bit to set + * @param col Col index of the bit to set + * @param new_value Value to set the bit to (true or false) + */ + inline _RAFT_HOST_DEVICE void set(const index_t row, const index_t col, bool new_value) const; + + /** + * @brief Get the total number of rows + * @return index_t The total number of rows + */ + inline _RAFT_HOST_DEVICE index_t get_n_rows() const { return rows_; } + + /** + * @brief Get the total number of columns + * @return index_t The total number of columns + */ + inline _RAFT_HOST_DEVICE index_t get_n_cols() const { return cols_; } + + private: + index_t rows_; + index_t cols_; +}; + +/** @} */ +} // end namespace raft::core diff --git a/cpp/include/raft/core/bitset.cuh b/cpp/include/raft/core/bitset.cuh index 53fd586ed2..d7eedee92e 100644 --- a/cpp/include/raft/core/bitset.cuh +++ b/cpp/include/raft/core/bitset.cuh @@ -16,7 +16,8 @@ #pragma once -#include // native_popc +#include +#include #include #include #include @@ -28,372 +29,147 @@ #include namespace raft::core { -/** - * @defgroup bitset Bitset - * @{ - */ -/** - * @brief View of a RAFT Bitset. - * - * This lightweight structure stores a pointer to a bitset in device memory with it's length. - * It provides a test() device function to check if a given index is set in the bitset. - * - * @tparam bitset_t Underlying type of the bitset array. Default is uint32_t. - * @tparam index_t Indexing type used. Default is uint32_t. - */ -template -struct bitset_view { - static constexpr index_t bitset_element_size = sizeof(bitset_t) * 8; - - _RAFT_HOST_DEVICE bitset_view(bitset_t* bitset_ptr, index_t bitset_len) - : bitset_ptr_{bitset_ptr}, bitset_len_{bitset_len} - { - } - /** - * @brief Create a bitset view from a device vector view of the bitset. - * - * @param bitset_span Device vector view of the bitset - * @param bitset_len Number of bits in the bitset - */ - _RAFT_HOST_DEVICE bitset_view(raft::device_vector_view bitset_span, - index_t bitset_len) - : bitset_ptr_{bitset_span.data_handle()}, bitset_len_{bitset_len} - { - } - /** - * @brief Device function to test if a given index is set in the bitset. - * - * @param sample_index Single index to test - * @return bool True if index has not been unset in the bitset - */ - inline _RAFT_DEVICE auto test(const index_t sample_index) const -> bool - { - const bitset_t bit_element = bitset_ptr_[sample_index / bitset_element_size]; - const index_t bit_index = sample_index % bitset_element_size; - const bool is_bit_set = (bit_element & (bitset_t{1} << bit_index)) != 0; - return is_bit_set; - } - /** - * @brief Device function to test if a given index is set in the bitset. 
- * - * @param sample_index Single index to test - * @return bool True if index has not been unset in the bitset - */ - inline _RAFT_DEVICE auto operator[](const index_t sample_index) const -> bool - { - return test(sample_index); - } - /** - * @brief Device function to set a given index to set_value in the bitset. - * - * @param sample_index index to set - * @param set_value Value to set the bit to (true or false) - */ - inline _RAFT_DEVICE void set(const index_t sample_index, bool set_value) const - { - const index_t bit_element = sample_index / bitset_element_size; - const index_t bit_index = sample_index % bitset_element_size; - const bitset_t bitmask = bitset_t{1} << bit_index; - if (set_value) { - atomicOr(bitset_ptr_ + bit_element, bitmask); - } else { - const bitset_t bitmask2 = ~bitmask; - atomicAnd(bitset_ptr_ + bit_element, bitmask2); - } - } - - /** - * @brief Get the device pointer to the bitset. - */ - inline _RAFT_HOST_DEVICE auto data() -> bitset_t* { return bitset_ptr_; } - inline _RAFT_HOST_DEVICE auto data() const -> const bitset_t* { return bitset_ptr_; } - /** - * @brief Get the number of bits of the bitset representation. - */ - inline _RAFT_HOST_DEVICE auto size() const -> index_t { return bitset_len_; } - - /** - * @brief Get the number of elements used by the bitset representation. - */ - inline _RAFT_HOST_DEVICE auto n_elements() const -> index_t - { - return raft::ceildiv(bitset_len_, bitset_element_size); - } - - inline auto to_mdspan() -> raft::device_vector_view - { - return raft::make_device_vector_view(bitset_ptr_, n_elements()); - } - inline auto to_mdspan() const -> raft::device_vector_view - { - return raft::make_device_vector_view(bitset_ptr_, n_elements()); - } - - private: - bitset_t* bitset_ptr_; - index_t bitset_len_; -}; - -/** - * @brief RAFT Bitset. - * - * This structure encapsulates a bitset in device memory. It provides a view() method to get a - * device-usable lightweight view of the bitset. - * Each index is represented by a single bit in the bitset. The total number of bytes used is - * ceil(bitset_len / 8). - * @tparam bitset_t Underlying type of the bitset array. Default is uint32_t. - * @tparam index_t Indexing type used. Default is uint32_t. - */ -template -struct bitset { - static constexpr index_t bitset_element_size = sizeof(bitset_t) * 8; - - /** - * @brief Construct a new bitset object with a list of indices to unset. - * - * @param res RAFT resources - * @param mask_index List of indices to unset in the bitset - * @param bitset_len Length of the bitset - * @param default_value Default value to set the bits to. Default is true. - */ - bitset(const raft::resources& res, - raft::device_vector_view mask_index, - index_t bitset_len, - bool default_value = true) - : bitset_{std::size_t(raft::ceildiv(bitset_len, bitset_element_size)), - raft::resource::get_cuda_stream(res)}, - bitset_len_{bitset_len} - { - reset(res, default_value); - set(res, mask_index, !default_value); - } - /** - * @brief Construct a new bitset object - * - * @param res RAFT resources - * @param bitset_len Length of the bitset - * @param default_value Default value to set the bits to. Default is true. 
- */ - bitset(const raft::resources& res, index_t bitset_len, bool default_value = true) - : bitset_{std::size_t(raft::ceildiv(bitset_len, bitset_element_size)), - resource::get_cuda_stream(res)}, - bitset_len_{bitset_len} - { - reset(res, default_value); - } - // Disable copy constructor - bitset(const bitset&) = delete; - bitset(bitset&&) = default; - bitset& operator=(const bitset&) = delete; - bitset& operator=(bitset&&) = default; - - /** - * @brief Create a device-usable view of the bitset. - * - * @return bitset_view - */ - inline auto view() -> raft::core::bitset_view - { - return bitset_view(to_mdspan(), bitset_len_); - } - [[nodiscard]] inline auto view() const -> raft::core::bitset_view - { - return bitset_view(to_mdspan(), bitset_len_); - } - - /** - * @brief Get the device pointer to the bitset. - */ - inline auto data() -> bitset_t* { return bitset_.data(); } - inline auto data() const -> const bitset_t* { return bitset_.data(); } - /** - * @brief Get the number of bits of the bitset representation. - */ - inline auto size() const -> index_t { return bitset_len_; } - - /** - * @brief Get the number of elements used by the bitset representation. - */ - inline auto n_elements() const -> index_t - { - return raft::ceildiv(bitset_len_, bitset_element_size); - } - - /** @brief Get an mdspan view of the current bitset */ - inline auto to_mdspan() -> raft::device_vector_view - { - return raft::make_device_vector_view(bitset_.data(), n_elements()); - } - [[nodiscard]] inline auto to_mdspan() const -> raft::device_vector_view - { - return raft::make_device_vector_view(bitset_.data(), n_elements()); - } - - /** @brief Resize the bitset. If the requested size is larger, new memory is allocated and set to - * the default value. - * @param res RAFT resources - * @param new_bitset_len new size of the bitset - * @param default_value default value to initialize the new bits to - */ - void resize(const raft::resources& res, index_t new_bitset_len, bool default_value = true) - { - auto old_size = raft::ceildiv(bitset_len_, bitset_element_size); - auto new_size = raft::ceildiv(new_bitset_len, bitset_element_size); - bitset_.resize(new_size); - bitset_len_ = new_bitset_len; - if (old_size < new_size) { - // If the new size is larger, set the new bits to the default value - - thrust::fill_n(resource::get_thrust_policy(res), - bitset_.data() + old_size, - new_size - old_size, - default_value ? ~bitset_t{0} : bitset_t{0}); - } - } - - /** - * @brief Test a list of indices in a bitset. - * - * @tparam output_t Output type of the test. Default is bool. - * @param res RAFT resources - * @param queries List of indices to test - * @param output List of outputs - */ - template - void test(const raft::resources& res, - raft::device_vector_view queries, - raft::device_vector_view output) const - { - RAFT_EXPECTS(output.extent(0) == queries.extent(0), "Output and queries must be same size"); - auto bitset_view = view(); - raft::linalg::map( - res, - output, - [bitset_view] __device__(index_t query) { return output_t(bitset_view.test(query)); }, - queries); - } - /** - * @brief Set a list of indices in a bitset to set_value. 
- * - * @param res RAFT resources - * @param mask_index indices to remove from the bitset - * @param set_value Value to set the bits to (true or false) - */ - void set(const raft::resources& res, - raft::device_vector_view mask_index, - bool set_value = false) - { - auto this_bitset_view = view(); - thrust::for_each_n(resource::get_thrust_policy(res), - mask_index.data_handle(), - mask_index.extent(0), - [this_bitset_view, set_value] __device__(const index_t sample_index) { - this_bitset_view.set(sample_index, set_value); - }); - } - /** - * @brief Flip all the bits in a bitset. - * @param res RAFT resources - */ - void flip(const raft::resources& res) - { - auto bitset_span = this->to_mdspan(); - raft::linalg::map( - res, - bitset_span, - [] __device__(bitset_t element) { return bitset_t(~element); }, - raft::make_const_mdspan(bitset_span)); - } - /** - * @brief Reset the bits in a bitset. - * - * @param res RAFT resources - * @param default_value Value to set the bits to (true or false) - */ - void reset(const raft::resources& res, bool default_value = true) - { - thrust::fill_n(resource::get_thrust_policy(res), - bitset_.data(), - n_elements(), +template +_RAFT_HOST_DEVICE inline bool bitset_view::test(const index_t sample_index) const +{ + const bitset_t bit_element = bitset_ptr_[sample_index / bitset_element_size]; + const index_t bit_index = sample_index % bitset_element_size; + const bool is_bit_set = (bit_element & (bitset_t{1} << bit_index)) != 0; + return is_bit_set; +} + +template +_RAFT_HOST_DEVICE bool bitset_view::operator[](const index_t sample_index) const +{ + return test(sample_index); +} + +template +_RAFT_HOST_DEVICE void bitset_view::set(const index_t sample_index, + bool set_value) const +{ + const index_t bit_element = sample_index / bitset_element_size; + const index_t bit_index = sample_index % bitset_element_size; + const bitset_t bitmask = bitset_t{1} << bit_index; + if (set_value) { + atomicOr(bitset_ptr_ + bit_element, bitmask); + } else { + const bitset_t bitmask2 = ~bitmask; + atomicAnd(bitset_ptr_ + bit_element, bitmask2); + } +} + +template +_RAFT_HOST_DEVICE inline index_t bitset_view::n_elements() const +{ + return raft::ceildiv(bitset_len_, bitset_element_size); +} + +template +bitset::bitset(const raft::resources& res, + raft::device_vector_view mask_index, + index_t bitset_len, + bool default_value) + : bitset_{std::size_t(raft::ceildiv(bitset_len, bitset_element_size)), + raft::resource::get_cuda_stream(res)}, + bitset_len_{bitset_len} +{ + reset(res, default_value); + set(res, mask_index, !default_value); +} + +template +bitset::bitset(const raft::resources& res, + index_t bitset_len, + bool default_value) + : bitset_{std::size_t(raft::ceildiv(bitset_len, bitset_element_size)), + raft::resource::get_cuda_stream(res)}, + bitset_len_{bitset_len} +{ + reset(res, default_value); +} + +template +index_t bitset::n_elements() const +{ + return raft::ceildiv(bitset_len_, bitset_element_size); +} + +template +void bitset::resize(const raft::resources& res, + index_t new_bitset_len, + bool default_value) +{ + auto old_size = raft::ceildiv(bitset_len_, bitset_element_size); + auto new_size = raft::ceildiv(new_bitset_len, bitset_element_size); + bitset_.resize(new_size); + bitset_len_ = new_bitset_len; + if (old_size < new_size) { + // If the new size is larger, set the new bits to the default value + thrust::fill_n(raft::resource::get_thrust_policy(res), + bitset_.data() + old_size, + new_size - old_size, default_value ? 
~bitset_t{0} : bitset_t{0}); } - /** - * @brief Returns the number of bits set to true in count_gpu_scalar. - * - * @param[in] res RAFT resources - * @param[out] count_gpu_scalar Device scalar to store the count - */ - void count(const raft::resources& res, raft::device_scalar_view count_gpu_scalar) - { - auto n_elements_ = n_elements(); - auto count_gpu = - raft::make_device_vector_view(count_gpu_scalar.data_handle(), 1); - auto bitset_matrix_view = raft::make_device_matrix_view( - bitset_.data(), n_elements_, 1); - - bitset_t n_last_element = (bitset_len_ % bitset_element_size); - bitset_t last_element_mask = - n_last_element ? (bitset_t)((bitset_t{1} << n_last_element) - bitset_t{1}) : ~bitset_t{0}; - raft::linalg::coalesced_reduction( - res, - bitset_matrix_view, - count_gpu, - index_t{0}, - false, - [last_element_mask, n_elements_] __device__(bitset_t element, index_t index) { - index_t result = 0; - if constexpr (bitset_element_size == 64) { - if (index == n_elements_ - 1) - result = index_t(raft::detail::popc(element & last_element_mask)); - else - result = index_t(raft::detail::popc(element)); - } else { // Needed because popc is not overloaded for 16 and 8 bit elements - if (index == n_elements_ - 1) - result = index_t(raft::detail::popc(uint32_t{element} & last_element_mask)); - else - result = index_t(raft::detail::popc(uint32_t{element})); - } - - return result; - }); - } - /** - * @brief Returns the number of bits set to true. - * - * @param res RAFT resources - * @return index_t Number of bits set to true - */ - auto count(const raft::resources& res) -> index_t - { - auto count_gpu_scalar = raft::make_device_scalar(res, 0.0); - count(res, count_gpu_scalar.view()); - index_t count_cpu = 0; - raft::update_host( - &count_cpu, count_gpu_scalar.data_handle(), 1, resource::get_cuda_stream(res)); - resource::sync_stream(res); - return count_cpu; - } - /** - * @brief Checks if any of the bits are set to true in the bitset. - * @param res RAFT resources - */ - bool any(const raft::resources& res) { return count(res) > 0; } - /** - * @brief Checks if all of the bits are set to true in the bitset. - * @param res RAFT resources - */ - bool all(const raft::resources& res) { return count(res) == bitset_len_; } - /** - * @brief Checks if none of the bits are set to true in the bitset. 
- * @param res RAFT resources - */ - bool none(const raft::resources& res) { return count(res) == 0; } - - private: - raft::device_uvector bitset_; - index_t bitset_len_; -}; +} + +template +template +void bitset::test(const raft::resources& res, + raft::device_vector_view queries, + raft::device_vector_view output) const +{ + RAFT_EXPECTS(output.extent(0) == queries.extent(0), "Output and queries must be same size"); + auto bitset_view = view(); + raft::linalg::map( + res, + output, + [bitset_view] __device__(index_t query) { return bitset_view.test(query); }, + queries); +} + +template +void bitset::set(const raft::resources& res, + raft::device_vector_view mask_index, + bool set_value) +{ + auto this_bitset_view = view(); + thrust::for_each_n(raft::resource::get_thrust_policy(res), + mask_index.data_handle(), + mask_index.extent(0), + [this_bitset_view, set_value] __device__(const index_t sample_index) { + this_bitset_view.set(sample_index, set_value); + }); +} + +template +void bitset::flip(const raft::resources& res) +{ + auto bitset_span = this->to_mdspan(); + raft::linalg::map( + res, + bitset_span, + [] __device__(bitset_t element) { return bitset_t(~element); }, + raft::make_const_mdspan(bitset_span)); +} + +template +void bitset::reset(const raft::resources& res, bool default_value) +{ + thrust::fill_n(raft::resource::get_thrust_policy(res), + bitset_.data(), + n_elements(), + default_value ? ~bitset_t{0} : bitset_t{0}); +} + +template +void bitset::count(const raft::resources& res, + raft::device_scalar_view count_gpu_scalar) +{ + auto values = + raft::make_device_vector_view(bitset_.data(), n_elements()); + raft::detail::popc(res, values, bitset_len_, count_gpu_scalar); +} -/** @} */ } // end namespace raft::core diff --git a/cpp/include/raft/core/bitset.hpp b/cpp/include/raft/core/bitset.hpp new file mode 100644 index 0000000000..0df12f25e6 --- /dev/null +++ b/cpp/include/raft/core/bitset.hpp @@ -0,0 +1,275 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +namespace raft::core { +/** + * @defgroup bitset Bitset + * @{ + */ +/** + * @brief View of a RAFT Bitset. + * + * This lightweight structure stores a pointer to a bitset in device memory with it's length. + * It provides a test() device function to check if a given index is set in the bitset. + * + * @tparam bitset_t Underlying type of the bitset array. Default is uint32_t. + * @tparam index_t Indexing type used. Default is uint32_t. + */ +template +struct bitset_view { + static constexpr index_t bitset_element_size = sizeof(bitset_t) * 8; + + _RAFT_HOST_DEVICE bitset_view(bitset_t* bitset_ptr, index_t bitset_len) + : bitset_ptr_{bitset_ptr}, bitset_len_{bitset_len} + { + } + /** + * @brief Create a bitset view from a device vector view of the bitset. 
+ * + * @param bitset_span Device vector view of the bitset + * @param bitset_len Number of bits in the bitset + */ + _RAFT_HOST_DEVICE bitset_view(raft::device_vector_view bitset_span, + index_t bitset_len) + : bitset_ptr_{bitset_span.data_handle()}, bitset_len_{bitset_len} + { + } + /** + * @brief Device function to test if a given index is set in the bitset. + * + * @param sample_index Single index to test + * @return bool True if index has not been unset in the bitset + */ + inline _RAFT_HOST_DEVICE auto test(const index_t sample_index) const -> bool; + /** + * @brief Device function to test if a given index is set in the bitset. + * + * @param sample_index Single index to test + * @return bool True if index has not been unset in the bitset + */ + inline _RAFT_HOST_DEVICE auto operator[](const index_t sample_index) const -> bool; + /** + * @brief Device function to set a given index to set_value in the bitset. + * + * @param sample_index index to set + * @param set_value Value to set the bit to (true or false) + */ + inline _RAFT_HOST_DEVICE void set(const index_t sample_index, bool set_value) const; + + /** + * @brief Get the device pointer to the bitset. + */ + inline _RAFT_HOST_DEVICE auto data() -> bitset_t* { return bitset_ptr_; } + inline _RAFT_HOST_DEVICE auto data() const -> const bitset_t* { return bitset_ptr_; } + /** + * @brief Get the number of bits of the bitset representation. + */ + inline _RAFT_HOST_DEVICE auto size() const -> index_t { return bitset_len_; } + + /** + * @brief Get the number of elements used by the bitset representation. + */ + inline _RAFT_HOST_DEVICE auto n_elements() const -> index_t; + + inline auto to_mdspan() -> raft::device_vector_view + { + return raft::make_device_vector_view(bitset_ptr_, n_elements()); + } + inline auto to_mdspan() const -> raft::device_vector_view + { + return raft::make_device_vector_view(bitset_ptr_, n_elements()); + } + + private: + bitset_t* bitset_ptr_; + index_t bitset_len_; +}; + +/** + * @brief RAFT Bitset. + * + * This structure encapsulates a bitset in device memory. It provides a view() method to get a + * device-usable lightweight view of the bitset. + * Each index is represented by a single bit in the bitset. The total number of bytes used is + * ceil(bitset_len / 8). + * @tparam bitset_t Underlying type of the bitset array. Default is uint32_t. + * @tparam index_t Indexing type used. Default is uint32_t. + */ +template +struct bitset { + static constexpr index_t bitset_element_size = sizeof(bitset_t) * 8; + + /** + * @brief Construct a new bitset object with a list of indices to unset. + * + * @param res RAFT resources + * @param mask_index List of indices to unset in the bitset + * @param bitset_len Length of the bitset + * @param default_value Default value to set the bits to. Default is true. + */ + bitset(const raft::resources& res, + raft::device_vector_view mask_index, + index_t bitset_len, + bool default_value = true); + + /** + * @brief Construct a new bitset object + * + * @param res RAFT resources + * @param bitset_len Length of the bitset + * @param default_value Default value to set the bits to. Default is true. + */ + bitset(const raft::resources& res, index_t bitset_len, bool default_value = true); + // Disable copy constructor + bitset(const bitset&) = delete; + bitset(bitset&&) = default; + bitset& operator=(const bitset&) = delete; + bitset& operator=(bitset&&) = default; + + /** + * @brief Create a device-usable view of the bitset. 
+ * + * @return bitset_view + */ + inline auto view() -> raft::core::bitset_view + { + return bitset_view(to_mdspan(), bitset_len_); + } + [[nodiscard]] inline auto view() const -> raft::core::bitset_view + { + return bitset_view(to_mdspan(), bitset_len_); + } + + /** + * @brief Get the device pointer to the bitset. + */ + inline auto data() -> bitset_t* { return bitset_.data(); } + inline auto data() const -> const bitset_t* { return bitset_.data(); } + /** + * @brief Get the number of bits of the bitset representation. + */ + inline auto size() const -> index_t { return bitset_len_; } + + /** + * @brief Get the number of elements used by the bitset representation. + */ + inline auto n_elements() const -> index_t; + + /** @brief Get an mdspan view of the current bitset */ + inline auto to_mdspan() -> raft::device_vector_view + { + return raft::make_device_vector_view(bitset_.data(), n_elements()); + } + [[nodiscard]] inline auto to_mdspan() const -> raft::device_vector_view + { + return raft::make_device_vector_view(bitset_.data(), n_elements()); + } + + /** @brief Resize the bitset. If the requested size is larger, new memory is allocated and set to + * the default value. + * @param res RAFT resources + * @param new_bitset_len new size of the bitset + * @param default_value default value to initialize the new bits to + */ + void resize(const raft::resources& res, index_t new_bitset_len, bool default_value = true); + + /** + * @brief Test a list of indices in a bitset. + * + * @tparam output_t Output type of the test. Default is bool. + * @param res RAFT resources + * @param queries List of indices to test + * @param output List of outputs + */ + template + void test(const raft::resources& res, + raft::device_vector_view queries, + raft::device_vector_view output) const; + /** + * @brief Set a list of indices in a bitset to set_value. + * + * @param res RAFT resources + * @param mask_index indices to remove from the bitset + * @param set_value Value to set the bits to (true or false) + */ + void set(const raft::resources& res, + raft::device_vector_view mask_index, + bool set_value = false); + /** + * @brief Flip all the bits in a bitset. + * @param res RAFT resources + */ + void flip(const raft::resources& res); + /** + * @brief Reset the bits in a bitset. + * + * @param res RAFT resources + * @param default_value Value to set the bits to (true or false) + */ + void reset(const raft::resources& res, bool default_value = true); + /** + * @brief Returns the number of bits set to true in count_gpu_scalar. + * + * @param[in] res RAFT resources + * @param[out] count_gpu_scalar Device scalar to store the count + */ + void count(const raft::resources& res, raft::device_scalar_view count_gpu_scalar); + /** + * @brief Returns the number of bits set to true. + * + * @param res RAFT resources + * @return index_t Number of bits set to true + */ + auto count(const raft::resources& res) -> index_t + { + auto count_gpu_scalar = raft::make_device_scalar(res, 0.0); + count(res, count_gpu_scalar.view()); + index_t count_cpu = 0; + raft::update_host( + &count_cpu, count_gpu_scalar.data_handle(), 1, resource::get_cuda_stream(res)); + resource::sync_stream(res); + return count_cpu; + } + /** + * @brief Checks if any of the bits are set to true in the bitset. + * @param res RAFT resources + */ + bool any(const raft::resources& res) { return count(res) > 0; } + /** + * @brief Checks if all of the bits are set to true in the bitset. 
+ * @param res RAFT resources + */ + bool all(const raft::resources& res) { return count(res) == bitset_len_; } + /** + * @brief Checks if none of the bits are set to true in the bitset. + * @param res RAFT resources + */ + bool none(const raft::resources& res) { return count(res) == 0; } + + private: + raft::device_uvector bitset_; + index_t bitset_len_; +}; + +/** @} */ +} // end namespace raft::core diff --git a/cpp/include/raft/core/detail/logger.hpp b/cpp/include/raft/core/detail/logger.hpp index 532aee4d90..f3f52b46ae 100644 --- a/cpp/include/raft/core/detail/logger.hpp +++ b/cpp/include/raft/core/detail/logger.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,8 +15,10 @@ */ #pragma once +#ifndef RAFT_HIDE_DEPRECATION_WARNINGS #pragma message(__FILE__ \ " is deprecated and will be removed in future releases." \ " Please use the version instead.") +#endif #include diff --git a/cpp/include/raft/core/detail/nvtx.hpp b/cpp/include/raft/core/detail/nvtx.hpp index 82db75de84..253d8e5b93 100644 --- a/cpp/include/raft/core/detail/nvtx.hpp +++ b/cpp/include/raft/core/detail/nvtx.hpp @@ -24,23 +24,19 @@ #include #include -#include +#include #include #include -#include #include namespace raft::common::nvtx::detail { /** - * @brief An internal struct to store associated state with the color - * generator + * @brief An internal struct to to initialize the color generator */ -struct color_gen_state { - /** collection of all tagged colors generated so far */ - static inline std::unordered_map all_colors_; - /** mutex for accessing the above map */ - static inline std::mutex map_mutex_; +struct color_gen { + /** This determines how many bits of the hash to use for the generator */ + using hash_type = uint16_t; /** saturation */ static inline constexpr float kS = 0.9f; /** value */ @@ -109,32 +105,22 @@ inline auto hsv2rgb(float h, float s, float v) -> uint32_t /** * @brief Helper method to generate 'visually distinct' colors. * Inspired from https://martin.ankerl.com/2009/12/09/how-to-create-random-colors-programmatically/ - * However, if an associated tag is passed, it will look up in its history for - * any generated color against this tag and if found, just returns it, else - * generates a new color, assigns a tag to it and stores it for future usage. + * It calculates a hash of the passed string and uses the result to generate + * distinct yet deterministic colors. * Such a thing is very useful for nvtx markers where the ranges associated * with a specific tag should ideally get the same color for the purpose of * visualizing it on nsight-systems timeline. - * @param tag look for any previously generated colors with this tag or - * associate the currently generated color with it + * @param tag a string used as an input to generate a distinct color. 
* @return returns 32b RGB integer with alpha channel set of 0xff */ inline auto generate_next_color(const std::string& tag) -> uint32_t { - // std::unordered_map color_gen_state::all_colors_; - // std::mutex color_gen_state::map_mutex_; - - std::lock_guard guard(color_gen_state::map_mutex_); - if (!tag.empty()) { - auto itr = color_gen_state::all_colors_.find(tag); - if (itr != color_gen_state::all_colors_.end()) { return itr->second; } - } - auto h = static_cast(rand()) / static_cast(RAND_MAX); - h += color_gen_state::kInvPhi; + auto x = static_cast(std::hash{}(tag)); + auto u = std::numeric_limits::max(); + auto h = static_cast(x) / static_cast(u); + h += color_gen::kInvPhi; if (h >= 1.f) h -= 1.f; - auto rgb = hsv2rgb(h, color_gen_state::kS, color_gen_state::kV); - if (!tag.empty()) { color_gen_state::all_colors_[tag] = rgb; } - return rgb; + return hsv2rgb(h, color_gen::kS, color_gen::kV); } template diff --git a/cpp/include/raft/core/detail/popc.cuh b/cpp/include/raft/core/detail/popc.cuh new file mode 100644 index 0000000000..d74b68b715 --- /dev/null +++ b/cpp/include/raft/core/detail/popc.cuh @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include + +namespace raft::detail { + +/** + * @brief Count the number of bits that are set to 1 in a vector. + * + * @tparam value_t the value type of the vector. + * @tparam index_t the index type of vector and scalar. + * + * @param[in] res raft handle for managing expensive resources + * @param[in] values Device vector of bit fields whose set bits are counted. + * @param[in] max_len Maximum number of bits to count. + * @param[out] counter Number of bits that are set to 1. + */ +template +void popc(const raft::resources& res, + device_vector_view values, + index_t max_len, + raft::device_scalar_view counter) +{ + auto values_size = values.size(); + auto values_matrix = raft::make_device_matrix_view( + values.data_handle(), values_size, 1); + auto counter_vector = raft::make_device_vector_view(counter.data_handle(), 1); + + static constexpr index_t len_per_item = sizeof(value_t) * 8; + + value_t tail_len = (max_len % len_per_item); + value_t tail_mask = tail_len ?
(value_t)((value_t{1} << tail_len) - value_t{1}) : ~value_t{0}; + raft::linalg::coalesced_reduction( + res, + values_matrix, + counter_vector, + index_t{0}, + false, + [tail_mask, values_size] __device__(value_t value, index_t index) { + index_t result = 0; + if constexpr (len_per_item == 64) { + if (index == values_size - 1) + result = index_t(raft::detail::popc(value & tail_mask)); + else + result = index_t(raft::detail::popc(value)); + } else { // Needed because popc is not overloaded for 16 and 8 bit elements + if (index == values_size - 1) + result = index_t(raft::detail::popc(uint32_t{value} & tail_mask)); + else + result = index_t(raft::detail::popc(uint32_t{value})); + } + + return result; + }); +} + +} // end namespace raft::detail \ No newline at end of file diff --git a/cpp/include/raft/core/device_container_policy.hpp b/cpp/include/raft/core/device_container_policy.hpp index 8c6eff582b..18d8b77364 100644 --- a/cpp/include/raft/core/device_container_policy.hpp +++ b/cpp/include/raft/core/device_container_policy.hpp @@ -31,7 +31,8 @@ #include #include -#include +#include +#include #include @@ -117,7 +118,7 @@ class device_uvector { */ explicit device_uvector(std::size_t size, rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) + rmm::device_async_resource_ref mr) : data_{size, stream, mr} { } @@ -164,19 +165,11 @@ class device_uvector_policy { public: auto create(raft::resources const& res, size_t n) -> container_type { - if (mr_ == nullptr) { - // NB: not using the workspace resource by default! - // The workspace resource is for short-lived temporary allocations. - return container_type(n, resource::get_cuda_stream(res)); - } else { - return container_type(n, resource::get_cuda_stream(res), mr_); - } + return container_type(n, resource::get_cuda_stream(res), mr_); } constexpr device_uvector_policy() = default; - constexpr explicit device_uvector_policy(rmm::mr::device_memory_resource* mr) noexcept : mr_(mr) - { - } + explicit device_uvector_policy(rmm::device_async_resource_ref mr) noexcept : mr_(mr) {} [[nodiscard]] constexpr auto access(container_type& c, size_t n) const noexcept -> reference { @@ -192,7 +185,7 @@ class device_uvector_policy { [[nodiscard]] auto make_accessor_policy() const noexcept { return const_accessor_policy{}; } private: - rmm::mr::device_memory_resource* mr_{nullptr}; + rmm::device_async_resource_ref mr_{rmm::mr::get_current_device_resource()}; }; } // namespace raft diff --git a/cpp/include/raft/core/device_mdarray.hpp b/cpp/include/raft/core/device_mdarray.hpp index 855642cd76..a34f6e2e02 100644 --- a/cpp/include/raft/core/device_mdarray.hpp +++ b/cpp/include/raft/core/device_mdarray.hpp @@ -21,6 +21,8 @@ #include #include +#include + #include namespace raft { @@ -107,7 +109,7 @@ template auto make_device_mdarray(raft::resources const& handle, - rmm::mr::device_memory_resource* mr, + rmm::device_async_resource_ref mr, extents exts) { using mdarray_t = device_mdarray; diff --git a/cpp/include/raft/core/device_resources.hpp b/cpp/include/raft/core/device_resources.hpp index 366e387fdd..856ecc96d7 100644 --- a/cpp/include/raft/core/device_resources.hpp +++ b/cpp/include/raft/core/device_resources.hpp @@ -37,6 +37,7 @@ #include #include +#include #include @@ -120,7 +121,7 @@ class device_resources : public resources { cusparseHandle_t get_cusparse_handle() const { return resource::get_cusparse_handle(*this); } - rmm::exec_policy& get_thrust_policy() const { return resource::get_thrust_policy(*this); } + rmm::exec_policy_nosync& 
get_thrust_policy() const { return resource::get_thrust_policy(*this); } /** * @brief synchronize a stream on the current container diff --git a/cpp/include/raft/core/resource/device_memory_resource.hpp b/cpp/include/raft/core/resource/device_memory_resource.hpp index 9aa9e4fb85..b785010a0a 100644 --- a/cpp/include/raft/core/resource/device_memory_resource.hpp +++ b/cpp/include/raft/core/resource/device_memory_resource.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,6 +35,16 @@ namespace raft::resource { * @{ */ +class device_memory_resource : public resource { + public: + explicit device_memory_resource(std::shared_ptr mr) : mr_(mr) {} + ~device_memory_resource() override = default; + auto get_resource() -> void* override { return mr_.get(); } + + private: + std::shared_ptr mr_; +}; + class limiting_memory_resource : public resource { public: limiting_memory_resource(std::shared_ptr mr, @@ -66,6 +76,29 @@ class limiting_memory_resource : public resource { } }; +/** + * Factory that knows how to construct a specific raft::resource to populate + * the resources instance. + */ +class large_workspace_resource_factory : public resource_factory { + public: + explicit large_workspace_resource_factory( + std::shared_ptr mr = {nullptr}) + : mr_{mr ? mr + : std::shared_ptr{ + rmm::mr::get_current_device_resource(), void_op{}}} + { + } + auto get_resource_type() -> resource_type override + { + return resource_type::LARGE_WORKSPACE_RESOURCE; + } + auto make_resource() -> resource* override { return new device_memory_resource(mr_); } + + private: + std::shared_ptr mr_; +}; + /** * Factory that knows how to construct a specific raft::resource to populate * the resources instance. @@ -144,7 +177,7 @@ class workspace_resource_factory : public resource_factory { // Note, the workspace does not claim all this memory from the start, so it's still usable by // the main resource as well. // This limit is merely an order for algorithm internals to plan the batching accordingly. 
- return total_size / 2; + return total_size / 4; } }; @@ -241,6 +274,21 @@ inline void set_workspace_to_global_resource( workspace_resource_factory::default_plain_resource(), allocation_limit, std::nullopt)); }; +inline auto get_large_workspace_resource(resources const& res) -> rmm::mr::device_memory_resource* +{ + if (!res.has_resource_factory(resource_type::LARGE_WORKSPACE_RESOURCE)) { + res.add_resource_factory(std::make_shared()); + } + return res.get_resource(resource_type::LARGE_WORKSPACE_RESOURCE); +}; + +inline void set_large_workspace_resource(resources const& res, + std::shared_ptr mr = { + nullptr}) +{ + res.add_resource_factory(std::make_shared(mr)); +}; + /** @} */ } // namespace raft::resource diff --git a/cpp/include/raft/core/resource/resource_types.hpp b/cpp/include/raft/core/resource/resource_types.hpp index d2021728c4..d9126251c9 100644 --- a/cpp/include/raft/core/resource/resource_types.hpp +++ b/cpp/include/raft/core/resource/resource_types.hpp @@ -28,23 +28,24 @@ namespace raft::resource { */ enum resource_type { // device-specific resource types - CUBLAS_HANDLE = 0, // cublas handle - CUSOLVER_DN_HANDLE, // cusolver dn handle - CUSOLVER_SP_HANDLE, // cusolver sp handle - CUSPARSE_HANDLE, // cusparse handle - CUDA_STREAM_VIEW, // view of a cuda stream - CUDA_STREAM_POOL, // cuda stream pool - CUDA_STREAM_SYNC_EVENT, // cuda event for syncing streams - COMMUNICATOR, // raft communicator - SUB_COMMUNICATOR, // raft sub communicator - DEVICE_PROPERTIES, // cuda device properties - DEVICE_ID, // cuda device id - STREAM_VIEW, // view of a cuda stream or a placeholder in - // CUDA-free builds - THRUST_POLICY, // thrust execution policy - WORKSPACE_RESOURCE, // rmm device memory resource - CUBLASLT_HANDLE, // cublasLt handle - CUSTOM, // runtime-shared default-constructible resource + CUBLAS_HANDLE = 0, // cublas handle + CUSOLVER_DN_HANDLE, // cusolver dn handle + CUSOLVER_SP_HANDLE, // cusolver sp handle + CUSPARSE_HANDLE, // cusparse handle + CUDA_STREAM_VIEW, // view of a cuda stream + CUDA_STREAM_POOL, // cuda stream pool + CUDA_STREAM_SYNC_EVENT, // cuda event for syncing streams + COMMUNICATOR, // raft communicator + SUB_COMMUNICATOR, // raft sub communicator + DEVICE_PROPERTIES, // cuda device properties + DEVICE_ID, // cuda device id + STREAM_VIEW, // view of a cuda stream or a placeholder in + // CUDA-free builds + THRUST_POLICY, // thrust execution policy + WORKSPACE_RESOURCE, // rmm device memory resource for small temporary allocations + CUBLASLT_HANDLE, // cublasLt handle + CUSTOM, // runtime-shared default-constructible resource + LARGE_WORKSPACE_RESOURCE, // rmm device memory resource for somewhat large temporary allocations LAST_KEY // reserved for the last key }; diff --git a/cpp/include/raft/core/resource/thrust_policy.hpp b/cpp/include/raft/core/resource/thrust_policy.hpp index f81898be8a..c728f0a00e 100644 --- a/cpp/include/raft/core/resource/thrust_policy.hpp +++ b/cpp/include/raft/core/resource/thrust_policy.hpp @@ -24,7 +24,7 @@ namespace raft::resource { class thrust_policy_resource : public resource { public: thrust_policy_resource(rmm::cuda_stream_view stream_view) - : thrust_policy_(std::make_unique(stream_view)) + : thrust_policy_(std::make_unique(stream_view)) { } void* get_resource() override { return thrust_policy_.get(); } @@ -32,7 +32,7 @@ class thrust_policy_resource : public resource { ~thrust_policy_resource() override {} private: - std::unique_ptr thrust_policy_; + std::unique_ptr thrust_policy_; }; /** @@ -60,13 +60,13 @@ class 
thrust_policy_resource_factory : public resource_factory { * @param res raft res object for managing resources * @return thrust execution policy */ -inline rmm::exec_policy& get_thrust_policy(resources const& res) +inline rmm::exec_policy_nosync& get_thrust_policy(resources const& res) { if (!res.has_resource_factory(resource_type::THRUST_POLICY)) { rmm::cuda_stream_view stream = get_cuda_stream(res); res.add_resource_factory(std::make_shared(stream)); } - return *res.get_resource(resource_type::THRUST_POLICY); + return *res.get_resource(resource_type::THRUST_POLICY); }; /** diff --git a/cpp/include/raft/distance/detail/masked_nn.cuh b/cpp/include/raft/distance/detail/masked_nn.cuh index 3e3699766f..951e030cbd 100644 --- a/cpp/include/raft/distance/detail/masked_nn.cuh +++ b/cpp/include/raft/distance/detail/masked_nn.cuh @@ -256,9 +256,8 @@ void masked_l2_nn_impl(raft::resources const& handle, static_assert(P::Mblk == 64, "masked_l2_nn_impl only supports a policy with 64 rows per block."); // Get stream and workspace memory resource - rmm::mr::device_memory_resource* ws_mr = - dynamic_cast(resource::get_workspace_resource(handle)); auto stream = resource::get_cuda_stream(handle); + auto ws_mr = resource::get_workspace_resource(handle); // Acquire temporary buffers and initialize to zero: // 1) Adjacency matrix bitfield diff --git a/cpp/include/raft/distance/specializations.cuh b/cpp/include/raft/distance/specializations.cuh index ed0b6848ae..cba059154f 100644 --- a/cpp/include/raft/distance/specializations.cuh +++ b/cpp/include/raft/distance/specializations.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,8 +15,10 @@ */ #pragma once +#ifndef RAFT_HIDE_DEPRECATION_WARNINGS #pragma message( \ __FILE__ \ " is deprecated and will be removed." \ " Including specializations is not necessary any more." \ " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") +#endif diff --git a/cpp/include/raft/distance/specializations/distance.cuh b/cpp/include/raft/distance/specializations/distance.cuh index ed0b6848ae..cba059154f 100644 --- a/cpp/include/raft/distance/specializations/distance.cuh +++ b/cpp/include/raft/distance/specializations/distance.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,8 +15,10 @@ */ #pragma once +#ifndef RAFT_HIDE_DEPRECATION_WARNINGS #pragma message( \ __FILE__ \ " is deprecated and will be removed." \ " Including specializations is not necessary any more." \ " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") +#endif diff --git a/cpp/include/raft/distance/specializations/fused_l2_nn_min.cuh b/cpp/include/raft/distance/specializations/fused_l2_nn_min.cuh index 9588a7f329..e85b05575f 100644 --- a/cpp/include/raft/distance/specializations/fused_l2_nn_min.cuh +++ b/cpp/include/raft/distance/specializations/fused_l2_nn_min.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
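A note on the deprecation guards added throughout this change: wrapping each deprecation #pragma message in #ifndef RAFT_HIDE_DEPRECATION_WARNINGS / #endif lets downstream builds silence the notices without patching the headers. A minimal consumer-side sketch follows (the header below is just one of the deprecated files touched here; passing -DRAFT_HIDE_DEPRECATION_WARNINGS on the compiler command line has the same effect):

// Sketch only: define the macro before including any deprecated RAFT header
// so the #ifndef RAFT_HIDE_DEPRECATION_WARNINGS guard skips the #pragma message.
#define RAFT_HIDE_DEPRECATION_WARNINGS
#include <raft/distance/specializations/fused_l2_nn_min.cuh>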
@@ -15,8 +15,10 @@ */ #pragma once +#ifndef RAFT_HIDE_DEPRECATION_WARNINGS #pragma message( \ __FILE__ \ " is deprecated and will be removed." \ " Including specializations is not necessary any more." \ " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") +#endif diff --git a/cpp/include/raft/lap/lap.cuh b/cpp/include/raft/lap/lap.cuh index f7828294cd..b06cd113c1 100644 --- a/cpp/include/raft/lap/lap.cuh +++ b/cpp/include/raft/lap/lap.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,9 +24,11 @@ #pragma once +#ifndef RAFT_HIDE_DEPRECATION_WARNINGS #pragma message(__FILE__ \ " is deprecated and will be removed in a future release." \ " Please use the raft/solver version instead.") +#endif #include diff --git a/cpp/include/raft/lap/lap.hpp b/cpp/include/raft/lap/lap.hpp index 5472422053..0f1ad14ed5 100644 --- a/cpp/include/raft/lap/lap.hpp +++ b/cpp/include/raft/lap/lap.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,8 +24,10 @@ #pragma once +#ifndef RAFT_HIDE_DEPRECATION_WARNINGS #pragma message(__FILE__ \ " is deprecated and will be removed in a future release." \ " Please use the cuh version instead.") +#endif #include diff --git a/cpp/include/raft/linalg/detail/coalesced_reduction-inl.cuh b/cpp/include/raft/linalg/detail/coalesced_reduction-inl.cuh index d580ea72c1..9f3be7ce0e 100644 --- a/cpp/include/raft/linalg/detail/coalesced_reduction-inl.cuh +++ b/cpp/include/raft/linalg/detail/coalesced_reduction-inl.cuh @@ -28,11 +28,18 @@ namespace raft { namespace linalg { namespace detail { -template +template struct ReductionThinPolicy { - static constexpr int LogicalWarpSize = warpSize; - static constexpr int RowsPerBlock = rpb; - static constexpr int ThreadsPerBlock = LogicalWarpSize * RowsPerBlock; + static_assert(tpb % warpSize == 0); + + static constexpr int LogicalWarpSize = warpSize; + static constexpr int ThreadsPerBlock = tpb; + static constexpr int RowsPerLogicalWarp = rpw; + static constexpr int NumLogicalWarps = ThreadsPerBlock / LogicalWarpSize; + static constexpr int RowsPerBlock = NumLogicalWarps * RowsPerLogicalWarp; + + // Whether D (run-time arg) will be smaller than warpSize (compile-time parameter) + static constexpr bool NoSequentialReduce = noLoop; }; template (blockIdx.x)); - if (i >= N) return; + /* The strategy to achieve near-SOL memory bandwidth differs based on D: + * - For small D, we need to process multiple rows per logical warp in order to have + * multiple loads per thread and increase bytes in flight and amortize latencies. + * - For large D, we start with a sequential reduction. The compiler partially unrolls + * that loop (e.g. first a loop of stride 16, then 8, 4, and 1). 
+ */ + IdxType i0 = threadIdx.y + (Policy::RowsPerBlock * static_cast(blockIdx.x)); + if (i0 >= N) return; - OutType acc = init; - for (IdxType j = threadIdx.x; j < D; j += Policy::LogicalWarpSize) { - acc = reduce_op(acc, main_op(data[j + (D * i)], j)); + OutType acc[Policy::RowsPerLogicalWarp]; +#pragma unroll + for (int k = 0; k < Policy::RowsPerLogicalWarp; k++) { + acc[k] = init; } - acc = raft::logicalWarpReduce(acc, reduce_op); - if (threadIdx.x == 0) { + + if constexpr (Policy::NoSequentialReduce) { + IdxType j = threadIdx.x; + if (j < D) { +#pragma unroll + for (IdxType k = 0; k < Policy::RowsPerLogicalWarp; k++) { + // Only the first row is known to be within bounds. Clamp to avoid out-of-mem read. + const IdxType i = raft::min(i0 + k * Policy::NumLogicalWarps, N - 1); + acc[k] = reduce_op(acc[k], main_op(data[j + (D * i)], j)); + } + } + } else { + for (IdxType j = threadIdx.x; j < D; j += Policy::LogicalWarpSize) { +#pragma unroll + for (IdxType k = 0; k < Policy::RowsPerLogicalWarp; k++) { + const IdxType i = raft::min(i0 + k * Policy::NumLogicalWarps, N - 1); + acc[k] = reduce_op(acc[k], main_op(data[j + (D * i)], j)); + } + } + } + + /* This vector reduction has two benefits compared to naive separate reductions: + * - It avoids the LSU bottleneck when the number of columns is around 32 (e.g. for 32, 5 shuffles + * are required and there is no initial sequential reduction to amortize that cost). + * - It distributes the outputs to multiple threads, enabling a coalesced store when the number of + * rows per logical warp and logical warp size are equal. + */ + raft::logicalWarpReduceVector( + acc, threadIdx.x, reduce_op); + + constexpr int reducOutVecWidth = + std::max(1, Policy::RowsPerLogicalWarp / Policy::LogicalWarpSize); + constexpr int reducOutGroupSize = + std::max(1, Policy::LogicalWarpSize / Policy::RowsPerLogicalWarp); + constexpr int reducNumGroups = Policy::LogicalWarpSize / reducOutGroupSize; + + if (threadIdx.x % reducOutGroupSize == 0) { + const int groupId = threadIdx.x / reducOutGroupSize; if (inplace) { - dots[i] = final_op(reduce_op(dots[i], acc)); +#pragma unroll + for (int k = 0; k < reducOutVecWidth; k++) { + const int reductionId = k * reducNumGroups + groupId; + const IdxType i = i0 + reductionId * Policy::NumLogicalWarps; + if (i < N) { dots[i] = final_op(reduce_op(dots[i], acc[k])); } + } } else { - dots[i] = final_op(acc); +#pragma unroll + for (int k = 0; k < reducOutVecWidth; k++) { + const int reductionId = k * reducNumGroups + groupId; + const IdxType i = i0 + reductionId * Policy::NumLogicalWarps; + if (i < N) { dots[i] = final_op(acc[k]); } + } } } } @@ -89,8 +149,12 @@ void coalescedReductionThin(OutType* dots, FinalLambda final_op = raft::identity_op()) { common::nvtx::range fun_scope( - "coalescedReductionThin<%d,%d>", Policy::LogicalWarpSize, Policy::RowsPerBlock); - dim3 threads(Policy::LogicalWarpSize, Policy::RowsPerBlock, 1); + "coalescedReductionThin<%d,%d,%d,%d>", + Policy::LogicalWarpSize, + Policy::ThreadsPerBlock, + Policy::RowsPerLogicalWarp, + static_cast(Policy::NoSequentialReduce)); + dim3 threads(Policy::LogicalWarpSize, Policy::NumLogicalWarps, 1); dim3 blocks(ceildiv(N, Policy::RowsPerBlock), 1, 1); coalescedReductionThinKernel <<>>(dots, data, D, N, init, main_op, reduce_op, final_op, inplace); @@ -115,19 +179,28 @@ void coalescedReductionThinDispatcher(OutType* dots, FinalLambda final_op = raft::identity_op()) { if (D <= IdxType(2)) { - coalescedReductionThin>( + coalescedReductionThin>( dots, data, D, N, init, stream, 
inplace, main_op, reduce_op, final_op); } else if (D <= IdxType(4)) { - coalescedReductionThin>( + coalescedReductionThin>( dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); } else if (D <= IdxType(8)) { - coalescedReductionThin>( + coalescedReductionThin>( dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); } else if (D <= IdxType(16)) { - coalescedReductionThin>( + coalescedReductionThin>( + dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + } else if (D <= IdxType(32)) { + coalescedReductionThin>( + dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + } else if (D < IdxType(128)) { + coalescedReductionThin>( dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); } else { - coalescedReductionThin>( + // For D=128 (included) and above, the 4x-unrolled loading loop is used + // and multiple rows per warp are counter-productive in terms of cache-friendliness + // and register use. + coalescedReductionThin>( dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); } } @@ -319,10 +392,10 @@ void coalescedReductionThickDispatcher(OutType* dots, // Note: multiple elements per thread to take advantage of the sequential reduction and loop // unrolling if (D < IdxType(32768)) { - coalescedReductionThick, ReductionThinPolicy<32, 4>>( + coalescedReductionThick, ReductionThinPolicy<32, 128, 1>>( dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); } else { - coalescedReductionThick, ReductionThinPolicy<32, 4>>( + coalescedReductionThick, ReductionThinPolicy<32, 128, 1>>( dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); } } diff --git a/cpp/include/raft/linalg/detail/gemm.hpp b/cpp/include/raft/linalg/detail/gemm.hpp index 245f8eb4b0..236c840040 100644 --- a/cpp/include/raft/linalg/detail/gemm.hpp +++ b/cpp/include/raft/linalg/detail/gemm.hpp @@ -15,9 +15,11 @@ */ #pragma once +#ifndef RAFT_HIDE_DEPRECATION_WARNINGS #pragma message(__FILE__ \ " is deprecated and will be removed in a future release." \ " Use cublaslt_wrappers.hpp if you really need this low-level api.") +#endif #include "cublaslt_wrappers.hpp" diff --git a/cpp/include/raft/linalg/gemm.cuh b/cpp/include/raft/linalg/gemm.cuh index c9dcbda5cc..7b8d35706b 100644 --- a/cpp/include/raft/linalg/gemm.cuh +++ b/cpp/include/raft/linalg/gemm.cuh @@ -18,9 +18,11 @@ #pragma once +#ifndef RAFT_HIDE_DEPRECATION_WARNINGS #pragma message(__FILE__ \ " is deprecated and will be removed in a future release." \ " Use raft/linalg/gemm.hpp instead.") +#endif #include "detail/gemm.hpp" #include "gemm.hpp" // Part of the API transferred to the non-deprecated file diff --git a/cpp/include/raft/linalg/lanczos.cuh b/cpp/include/raft/linalg/lanczos.cuh index 04e9980583..0117a8e1d4 100644 --- a/cpp/include/raft/linalg/lanczos.cuh +++ b/cpp/include/raft/linalg/lanczos.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,9 +24,11 @@ #pragma once +#ifndef RAFT_HIDE_DEPRECATION_WARNINGS #pragma message(__FILE__ \ " is deprecated and will be removed in a future release." 
\ " Please use the sparse solvers version instead.") +#endif #include diff --git a/cpp/include/raft/linalg/normalize.cuh b/cpp/include/raft/linalg/normalize.cuh index 1f60860c8c..de5f4e62ce 100644 --- a/cpp/include/raft/linalg/normalize.cuh +++ b/cpp/include/raft/linalg/normalize.cuh @@ -18,9 +18,11 @@ #include "detail/normalize.cuh" +#include #include #include #include +#include namespace raft { namespace linalg { diff --git a/cpp/include/raft/matrix/detail/select_k-ext.cuh b/cpp/include/raft/matrix/detail/select_k-ext.cuh index 506cbffcb9..6db1a5acac 100644 --- a/cpp/include/raft/matrix/detail/select_k-ext.cuh +++ b/cpp/include/raft/matrix/detail/select_k-ext.cuh @@ -20,9 +20,6 @@ #include #include // RAFT_EXPLICIT -#include // rmm:cuda_stream_view -#include // rmm::mr::device_memory_resource - #include // __half #include // uint32_t diff --git a/cpp/include/raft/matrix/detail/select_radix.cuh b/cpp/include/raft/matrix/detail/select_radix.cuh index 36a346fda3..2207b0216e 100644 --- a/cpp/include/raft/matrix/detail/select_radix.cuh +++ b/cpp/include/raft/matrix/detail/select_radix.cuh @@ -29,9 +29,9 @@ #include #include +#include #include -#include -#include +#include #include #include @@ -442,14 +442,76 @@ _RAFT_DEVICE void last_filter(const T* in_buf, } } -template +template +_RAFT_DEVICE void set_buf_pointers(const T* in, + const IdxT* in_idx, + char* bufs, + IdxT buf_len, + int pass, + const T*& in_buf, + const IdxT*& in_idx_buf, + T*& out_buf, + IdxT*& out_idx_buf) +{ + // bufs consists of 4 pieces in order: buf1, buf2, idx_buf1, idx_buf2 + if (pass == 0) { + in_buf = in; + in_idx_buf = nullptr; + out_buf = nullptr; + out_idx_buf = nullptr; + } else if (pass == 1) { + in_buf = in; + in_idx_buf = in_idx; + out_buf = reinterpret_cast(bufs); + out_idx_buf = reinterpret_cast(bufs + sizeof(T) * 2 * buf_len); + } else if (pass % 2 == 0) { + in_buf = reinterpret_cast(bufs); + in_idx_buf = reinterpret_cast(bufs + sizeof(T) * 2 * buf_len); + out_buf = const_cast(in_buf + buf_len); + out_idx_buf = const_cast(in_idx_buf + buf_len); + } else { + out_buf = reinterpret_cast(bufs); + out_idx_buf = reinterpret_cast(bufs + sizeof(T) * 2 * buf_len); + in_buf = out_buf + buf_len; + in_idx_buf = out_idx_buf + buf_len; + } +} + +template +_RAFT_DEVICE void set_buf_pointers(const T* in, + const IdxT* in_idx, + char* bufs, + IdxT buf_len, + const int pass, + const T*& out_buf, + const IdxT*& out_idx_buf) +{ + // bufs consists of 4 pieces in order: buf1, buf2, idx_buf1, idx_buf2 + if (pass == 0) { + out_buf = nullptr; + out_idx_buf = nullptr; + } else if (pass == 1) { + out_buf = reinterpret_cast(bufs); + out_idx_buf = reinterpret_cast(bufs + sizeof(T) * 2 * buf_len); + } else if (pass % 2 == 0) { + out_buf = const_cast(reinterpret_cast(bufs) + buf_len); + out_idx_buf = + const_cast(reinterpret_cast(bufs + sizeof(T) * 2 * buf_len) + buf_len); + } else { + out_buf = reinterpret_cast(bufs); + out_idx_buf = reinterpret_cast(bufs + sizeof(T) * 2 * buf_len); + } +} + +template RAFT_KERNEL last_filter_kernel(const T* in, const IdxT* in_idx, - const T* in_buf, - const IdxT* in_idx_buf, + char* bufs, + size_t offset, T* out, IdxT* out_idx, const IdxT len, + const IdxT* len_i, const IdxT k, Counter* counters, const bool select_min) @@ -458,22 +520,31 @@ RAFT_KERNEL last_filter_kernel(const T* in, Counter* counter = counters + batch_id; IdxT previous_len = counter->previous_len; + if (previous_len == 0) { return; } + + const IdxT l_len = len_or_indptr ? 
len : (len_i[batch_id + 1] - len_i[batch_id]); + const IdxT l_offset = len_or_indptr ? (offset + batch_id) * len : len_i[batch_id]; + const IdxT buf_len = calc_buf_len(len); - if (previous_len > buf_len || in_buf == in) { - in_buf = in + batch_id * len; - in_idx_buf = in_idx ? (in_idx + batch_id * len) : nullptr; - previous_len = len; - } else { - in_buf += batch_id * buf_len; - in_idx_buf += batch_id * buf_len; - } - out += batch_id * k; - out_idx += batch_id * k; + + const T* in_buf = nullptr; + const IdxT* in_idx_buf = nullptr; + bufs += batch_id * buf_len * 2 * (sizeof(T) + sizeof(IdxT)); constexpr int pass = calc_num_passes() - 1; constexpr int start_bit = calc_start_bit(pass); + set_buf_pointers(in + l_offset, in_idx + l_offset, bufs, buf_len, pass, in_buf, in_idx_buf); + + if (previous_len > buf_len || in_buf == in + l_offset) { + in_buf = in + l_offset; + in_idx_buf = in_idx ? (in_idx + l_offset) : nullptr; + previous_len = l_len; + } + out += batch_id * k; + out_idx += batch_id * k; + const auto kth_value_bits = counter->kth_value_bits; const IdxT num_of_kth_needed = counter->k; IdxT* p_out_cnt = &counter->out_cnt; @@ -510,6 +581,29 @@ RAFT_KERNEL last_filter_kernel(const T* in, f); } +template +_RAFT_DEVICE _RAFT_FORCEINLINE void copy_in_val( + T* dest, const T* src, S len, IdxT k, const bool select_min) +{ + S idx = S(threadIdx.x); + S stride = S(blockDim.x); + const T default_val = select_min ? upper_bound() : lower_bound(); + for (S i = idx; i < k; i += stride) { + dest[i] = i < len ? src[i] : default_val; + } +} + +template +_RAFT_DEVICE _RAFT_FORCEINLINE void copy_in_idx(T* dest, const T* src, S len) +{ + S idx = S(threadIdx.x); + S stride = S(blockDim.x); + + for (S i = idx; i < len; i += stride) { + dest[i] = src ? src[i] : i; + } +} + /** * * It is expected to call this kernel multiple times (passes), in each pass we process a radix, @@ -545,13 +639,16 @@ RAFT_KERNEL last_filter_kernel(const T* in, * rather than from `in_buf`. The benefit is that we can save the cost of writing candidates and * their indices. */ -template +template RAFT_KERNEL radix_kernel(const T* in, const IdxT* in_idx, - const T* in_buf, - const IdxT* in_idx_buf, - T* out_buf, - IdxT* out_idx_buf, + char* bufs, + size_t offset, T* out, IdxT* out_idx, Counter* counters, @@ -567,21 +664,38 @@ RAFT_KERNEL radix_kernel(const T* in, IdxT current_k; IdxT previous_len; IdxT current_len; + + const IdxT l_len = len_or_indptr ? len : (len_i[batch_id + 1] - len_i[batch_id]); + const IdxT l_offset = len_or_indptr ? (offset + batch_id) * len : len_i[batch_id]; + if (pass == 0) { current_k = k; - previous_len = len; + previous_len = l_len; // Need to do this so setting counter->previous_len for the next pass is correct. // This value is meaningless for pass 0, but it's fine because pass 0 won't be the // last pass in this implementation so pass 0 won't hit the "if (pass == // num_passes - 1)" branch. // Maybe it's better to reload counter->previous_len and use it rather than // current_len in last_filter() - current_len = len; + current_len = l_len; } else { current_k = counter->k; current_len = counter->len; previous_len = counter->previous_len; } + if constexpr (!len_or_indptr) { + if (pass == 0 && l_len <= k) { + copy_in_val(out + batch_id * k, in + l_offset, l_len, k, select_min); + copy_in_idx(out_idx + batch_id * k, (in_idx ? 
(in_idx + l_offset) : nullptr), l_len); + if (threadIdx.x == 0) { + counter->previous_len = 0; + counter->len = 0; + } + __syncthreads(); + return; + } + } + if (current_len == 0) { return; } // When k=len, early_stop will be true at pass 0. It means filter_and_histogram() should handle @@ -590,20 +704,33 @@ RAFT_KERNEL radix_kernel(const T* in, const bool early_stop = (current_len == current_k); const IdxT buf_len = calc_buf_len(len); + const T* in_buf; + const IdxT* in_idx_buf; + T* out_buf; + IdxT* out_idx_buf; + bufs += batch_id * buf_len * 2 * (sizeof(T) + sizeof(IdxT)); + + set_buf_pointers(in + l_offset, + (in_idx ? (in_idx + l_offset) : nullptr), + bufs, + buf_len, + pass, + in_buf, + in_idx_buf, + out_buf, + out_idx_buf); + // "previous_len > buf_len" means previous pass skips writing buffer if (pass == 0 || pass == 1 || previous_len > buf_len) { - in_buf = in + batch_id * len; - in_idx_buf = in_idx ? (in_idx + batch_id * len) : nullptr; - previous_len = len; - } else { - in_buf += batch_id * buf_len; - in_idx_buf += batch_id * buf_len; + in_buf = in + l_offset; + in_idx_buf = in_idx ? (in_idx + l_offset) : nullptr; + previous_len = l_len; } // in case we have individual len for each query defined we want to make sure // that we only iterate valid elements. if (len_i != nullptr) { - const IdxT max_len = max(len_i[batch_id], k); + const IdxT max_len = max(l_len, k); if (max_len < previous_len) previous_len = max_len; } @@ -611,9 +738,6 @@ RAFT_KERNEL radix_kernel(const T* in, if (pass == 0 || current_len > buf_len) { out_buf = nullptr; out_idx_buf = nullptr; - } else { - out_buf += batch_id * buf_len; - out_idx_buf += batch_id * buf_len; } out += batch_id * k; out_idx += batch_id * k; @@ -640,7 +764,6 @@ RAFT_KERNEL radix_kernel(const T* in, unsigned int finished = atomicInc(&counter->finished_block_cnt, gridDim.x - 1); isLastBlock = (finished == (gridDim.x - 1)); } - if (__syncthreads_or(isLastBlock)) { if (early_stop) { if (threadIdx.x == 0) { @@ -676,7 +799,7 @@ RAFT_KERNEL radix_kernel(const T* in, out_idx_buf ? out_idx_buf : in_idx_buf, out, out_idx, - out_buf ? current_len : len, + out_buf ? 
current_len : l_len, k, counter, select_min, @@ -726,7 +849,7 @@ unsigned calc_grid_dim(int batch_size, IdxT len, int sm_cnt) int active_blocks; RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &active_blocks, radix_kernel, BlockSize, 0)); + &active_blocks, radix_kernel, BlockSize, 0)); active_blocks *= sm_cnt; IdxT best_num_blocks = 0; @@ -757,78 +880,7 @@ unsigned calc_grid_dim(int batch_size, IdxT len, int sm_cnt) return best_num_blocks; } -template -_RAFT_HOST void set_buf_pointers(const T* in, - const IdxT* in_idx, - T* buf1, - IdxT* idx_buf1, - T* buf2, - IdxT* idx_buf2, - int pass, - const T*& in_buf, - const IdxT*& in_idx_buf, - T*& out_buf, - IdxT*& out_idx_buf) -{ - if (pass == 0) { - in_buf = in; - in_idx_buf = nullptr; - out_buf = nullptr; - out_idx_buf = nullptr; - } else if (pass == 1) { - in_buf = in; - in_idx_buf = in_idx; - out_buf = buf1; - out_idx_buf = idx_buf1; - } else if (pass % 2 == 0) { - in_buf = buf1; - in_idx_buf = idx_buf1; - out_buf = buf2; - out_idx_buf = idx_buf2; - } else { - in_buf = buf2; - in_idx_buf = idx_buf2; - out_buf = buf1; - out_idx_buf = idx_buf1; - } -} - -template -_RAFT_DEVICE void set_buf_pointers(const T* in, - const IdxT* in_idx, - char* bufs, - IdxT buf_len, - int pass, - const T*& in_buf, - const IdxT*& in_idx_buf, - T*& out_buf, - IdxT*& out_idx_buf) -{ - // bufs consists of 4 pieces in order: buf1, buf2, idx_buf1, idx_buf2 - if (pass == 0) { - in_buf = in; - in_idx_buf = nullptr; - out_buf = nullptr; - out_idx_buf = nullptr; - } else if (pass == 1) { - in_buf = in; - in_idx_buf = in_idx; - out_buf = reinterpret_cast(bufs); - out_idx_buf = reinterpret_cast(bufs + sizeof(T) * 2 * buf_len); - } else if (pass % 2 == 0) { - in_buf = reinterpret_cast(bufs); - in_idx_buf = reinterpret_cast(bufs + sizeof(T) * 2 * buf_len); - out_buf = const_cast(in_buf + buf_len); - out_idx_buf = const_cast(in_idx_buf + buf_len); - } else { - out_buf = reinterpret_cast(bufs); - out_idx_buf = reinterpret_cast(bufs + sizeof(T) * 2 * buf_len); - in_buf = out_buf + buf_len; - in_idx_buf = out_idx_buf + buf_len; - } -} - -template +template void radix_topk(const T* in, const IdxT* in_idx, int batch_size, @@ -842,15 +894,13 @@ void radix_topk(const T* in, unsigned grid_dim, int sm_cnt, rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) + rmm::device_async_resource_ref mr) { // TODO: is it possible to relax this restriction? 
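// --- Illustrative aside, not part of this diff -----------------------------
// The set_buf_pointers overloads above slice one contiguous per-row scratch
// allocation into the four ping-pong buffers (buf1, buf2, idx_buf1, idx_buf2),
// which is why the single `bufs` workspace allocated further below holds
// buf_len * 2 * (sizeof(T) + sizeof(IdxT)) bytes per batch row. A minimal
// host-side sketch of that layout arithmetic, using hypothetical helper names:
#include <cstddef>

template <typename T, typename IdxT>
struct row_buf_layout {
  std::size_t val1;   // byte offset of buf1 (first value buffer)
  std::size_t val2;   // byte offset of buf2 (second value buffer)
  std::size_t idx1;   // byte offset of idx_buf1 (first index buffer)
  std::size_t idx2;   // byte offset of idx_buf2 (second index buffer)
  std::size_t bytes;  // total scratch bytes per batch row
};

template <typename T, typename IdxT>
row_buf_layout<T, IdxT> make_row_buf_layout(std::size_t buf_len)
{
  row_buf_layout<T, IdxT> l{};
  l.val1  = 0;
  l.val2  = sizeof(T) * buf_len;              // buf2 follows buf1
  l.idx1  = sizeof(T) * 2 * buf_len;          // index buffers follow both value buffers
  l.idx2  = l.idx1 + sizeof(IdxT) * buf_len;
  l.bytes = buf_len * 2 * (sizeof(T) + sizeof(IdxT));
  return l;
}
// ----------------------------------------------------------------------------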
static_assert(calc_num_passes() > 1); constexpr int num_buckets = calc_num_buckets(); - if (mr == nullptr) { mr = rmm::mr::get_current_device_resource(); } - - auto kernel = radix_kernel; + auto kernel = radix_kernel; const size_t max_chunk_size = calc_chunk_size(batch_size, len, sm_cnt, kernel, false); if (max_chunk_size != static_cast(batch_size)) { @@ -862,55 +912,33 @@ void radix_topk(const T* in, rmm::device_uvector> counters(max_chunk_size, stream, mr); rmm::device_uvector histograms(max_chunk_size * num_buckets, stream, mr); - rmm::device_uvector buf1(max_chunk_size * buf_len, stream, mr); - rmm::device_uvector idx_buf1(max_chunk_size * buf_len, stream, mr); - rmm::device_uvector buf2(max_chunk_size * buf_len, stream, mr); - rmm::device_uvector idx_buf2(max_chunk_size * buf_len, stream, mr); + + rmm::device_uvector bufs( + max_chunk_size * buf_len * 2 * (sizeof(T) + sizeof(IdxT)), stream, mr); for (size_t offset = 0; offset < static_cast(batch_size); offset += max_chunk_size) { int chunk_size = std::min(max_chunk_size, batch_size - offset); RAFT_CUDA_TRY( cudaMemsetAsync(counters.data(), 0, counters.size() * sizeof(Counter), stream)); RAFT_CUDA_TRY(cudaMemsetAsync(histograms.data(), 0, histograms.size() * sizeof(IdxT), stream)); - auto kernel = radix_kernel; + auto kernel = radix_kernel; - const T* chunk_in = in + offset * len; - const IdxT* chunk_in_idx = in_idx ? (in_idx + offset * len) : nullptr; - T* chunk_out = out + offset * k; - IdxT* chunk_out_idx = out_idx + offset * k; - const IdxT* chunk_len_i = len_i ? (len_i + offset) : nullptr; - - const T* in_buf = nullptr; - const IdxT* in_idx_buf = nullptr; - T* out_buf = nullptr; - IdxT* out_idx_buf = nullptr; + T* chunk_out = out + offset * k; + IdxT* chunk_out_idx = out_idx + offset * k; + const IdxT* chunk_len_i = len_i ? 
(len_i + offset) : nullptr; dim3 blocks(grid_dim, chunk_size); constexpr int num_passes = calc_num_passes(); for (int pass = 0; pass < num_passes; ++pass) { - set_buf_pointers(chunk_in, - chunk_in_idx, - buf1.data(), - idx_buf1.data(), - buf2.data(), - idx_buf2.data(), - pass, - in_buf, - in_idx_buf, - out_buf, - out_idx_buf); - if (fused_last_filter && pass == num_passes - 1) { - kernel = radix_kernel; + kernel = radix_kernel; } - kernel<<>>(chunk_in, - chunk_in_idx, - in_buf, - in_idx_buf, - out_buf, - out_idx_buf, + kernel<<>>(in, + in_idx, + bufs.data(), + offset, chunk_out, chunk_out_idx, counters.data(), @@ -924,16 +952,18 @@ void radix_topk(const T* in, } if (!fused_last_filter) { - last_filter_kernel<<>>(chunk_in, - chunk_in_idx, - out_buf, - out_idx_buf, - chunk_out, - chunk_out_idx, - len, - k, - counters.data(), - select_min); + last_filter_kernel + <<>>(in, + in_idx, + bufs.data(), + offset, + chunk_out, + chunk_out_idx, + len, + chunk_len_i, + k, + counters.data(), + select_min); RAFT_CUDA_TRY(cudaPeekAtLastError()); } } @@ -1015,7 +1045,7 @@ _RAFT_DEVICE void filter_and_histogram_for_one_block(const T* in_buf, } } -template +template RAFT_KERNEL radix_topk_one_block_kernel(const T* in, const IdxT* in_idx, const IdxT len, @@ -1024,30 +1054,48 @@ RAFT_KERNEL radix_topk_one_block_kernel(const T* in, T* out, IdxT* out_idx, const bool select_min, - char* bufs) + char* bufs, + size_t offset) { constexpr int num_buckets = calc_num_buckets(); __shared__ Counter counter; __shared__ IdxT histogram[num_buckets]; + const size_t batch_id = blockIdx.x; // size_t to avoid multiplication overflow + + IdxT l_len = len; + IdxT l_offset = (offset + batch_id) * len; + if constexpr (!len_or_indptr) { + l_offset = len_i[batch_id]; + l_len = len_i[batch_id + 1] - l_offset; + } + if (threadIdx.x == 0) { counter.k = k; - counter.len = len; - counter.previous_len = len; + counter.len = l_len; + counter.previous_len = l_len; counter.kth_value_bits = 0; counter.out_cnt = 0; counter.out_back_cnt = 0; } __syncthreads(); - const size_t batch_id = blockIdx.x; // size_t to avoid multiplication overflow - in += batch_id * len; - if (in_idx) { in_idx += batch_id * len; } + in += l_offset; + if (in_idx) { in_idx += l_offset; } out += batch_id * k; out_idx += batch_id * k; const IdxT buf_len = calc_buf_len(len); bufs += batch_id * buf_len * 2 * (sizeof(T) + sizeof(IdxT)); + if constexpr (!len_or_indptr) { + if (l_len <= k) { + copy_in_val(out, in, l_len, k, select_min); + copy_in_idx(out_idx, in_idx, l_len); + __syncthreads(); + return; + } + } + constexpr int num_passes = calc_num_passes(); for (int pass = 0; pass < num_passes; ++pass) { const T* in_buf; @@ -1073,7 +1121,7 @@ RAFT_KERNEL radix_topk_one_block_kernel(const T* in, // in case we have individual len for each query defined we want to make sure // that we only iterate valid elements. if (len_i != nullptr) { - const IdxT max_len = max(len_i[batch_id], k); + const IdxT max_len = max(l_len, k); if (max_len < previous_len) previous_len = max_len; } @@ -1102,7 +1150,7 @@ RAFT_KERNEL radix_topk_one_block_kernel(const T* in, out_buf ? out_idx_buf : in_idx, out, out_idx, - out_buf ? current_len : len, + out_buf ? current_len : l_len, k, &counter, select_min, @@ -1117,7 +1165,7 @@ RAFT_KERNEL radix_topk_one_block_kernel(const T* in, // counters and global histograms, can be kept in shared memory and cheap sync operations can be // used. It's used when len is relatively small or when the number of blocks per row calculated by // `calc_grid_dim()` is 1. 
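// --- Illustrative aside, not part of this diff -----------------------------
// Both the multi-block and the one-block kernels now address each batch row
// through a per-row offset and length: with `len_or_indptr == true` every row
// has `len` contiguous elements, while with `len_or_indptr == false` the rows
// are CSR-style segments described by the `len_i` indptr. A standalone sketch
// of that addressing with hypothetical values (the chunk `offset` is taken as 0):
#include <cstdint>
#include <cstdio>

int main()
{
  using IdxT = std::int64_t;
  const IdxT len      = 8;              // uniform row length (dense mode)
  const IdxT indptr[] = {0, 3, 3, 9};   // CSR indptr: batch_size + 1 = 4 entries

  for (int batch_id = 0; batch_id < 3; ++batch_id) {
    // dense mode: len_or_indptr == true
    const IdxT dense_len    = len;
    const IdxT dense_offset = static_cast<IdxT>(batch_id) * len;
    // CSR mode: len_or_indptr == false
    const IdxT csr_len    = indptr[batch_id + 1] - indptr[batch_id];
    const IdxT csr_offset = indptr[batch_id];
    std::printf("row %d: dense(off=%lld, len=%lld)  csr(off=%lld, len=%lld)\n",
                batch_id,
                static_cast<long long>(dense_offset),
                static_cast<long long>(dense_len),
                static_cast<long long>(csr_offset),
                static_cast<long long>(csr_len));
  }
  return 0;
}
// ----------------------------------------------------------------------------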
-template +template void radix_topk_one_block(const T* in, const IdxT* in_idx, int batch_size, @@ -1129,11 +1177,11 @@ void radix_topk_one_block(const T* in, const IdxT* len_i, int sm_cnt, rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) + rmm::device_async_resource_ref mr) { static_assert(calc_num_passes() > 1); - auto kernel = radix_topk_one_block_kernel; + auto kernel = radix_topk_one_block_kernel; const IdxT buf_len = calc_buf_len(len); const size_t max_chunk_size = calc_chunk_size(batch_size, len, sm_cnt, kernel, true); @@ -1144,15 +1192,16 @@ void radix_topk_one_block(const T* in, for (size_t offset = 0; offset < static_cast(batch_size); offset += max_chunk_size) { int chunk_size = std::min(max_chunk_size, batch_size - offset); const IdxT* chunk_len_i = len_i ? (len_i + offset) : nullptr; - kernel<<>>(in + offset * len, - in_idx ? (in_idx + offset * len) : nullptr, + kernel<<>>(in, + in_idx, len, chunk_len_i, k, out + offset * k, out_idx + offset * k, select_min, - bufs.data()); + bufs.data(), + offset); } } @@ -1182,6 +1231,10 @@ void radix_topk_one_block(const T* in, * it affects the number of passes and number of buckets. * @tparam BlockSize * Number of threads in a kernel thread block. + * @tparam len_or_indptr + * Flag to interpret `len_i` as either direct row lengths (true) or CSR format + * index pointers (false). When true, each `len_i` element denotes the length of a row. When + * false, `len_i` represents the index pointers of a CSR matrix and has size `batch_size + 1`. + * * @param[in] res container of reusable resources * @param[in] in @@ -1212,9 +1265,12 @@ void radix_topk_one_block(const T* in, * same. That is, when the value range of input data is narrow. In such case, there could be a * large number of inputs for the last filter, hence using multiple thread blocks is beneficial. * @param len_i - * optional array of size (batch_size) providing lengths for each individual row + * Optional array used differently based on `len_or_indptr`: + * When `len_or_indptr` is true, `len_i` provides the length of each row, and its size is `batch_size`. + * When `len_or_indptr` is false, `len_i` works like an indptr for a CSR matrix. The length of each + * row is (`len_i[row_id + 1] - len_i[row_id]`), and the size of `len_i` is `batch_size + 1`.
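+ *
+ * For illustration only (this example is not part of the original header), a call that
+ * selects the smallest `k` values of each CSR segment might look as follows. The template
+ * arguments are hypothetical, the trailing boolean is assumed to be the `len_or_indptr`
+ * parameter set to `false`, and `in`, `out`, `out_idx`, and `indptr` stand for
+ * pre-allocated device buffers:
+ * @code{.cpp}
+ *   // 3 segments described by a device-side indptr of 4 offsets; k = 4, select_min = true
+ *   select_k<float, int64_t, 11, 512, false>(
+ *     res, in, nullptr, 3, len, 4, out, out_idx, true, true, indptr);
+ * @endcode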
*/ -template +template void select_k(raft::resources const& res, const T* in, const IdxT* in_idx, @@ -1227,9 +1283,12 @@ void select_k(raft::resources const& res, bool fused_last_filter, const IdxT* len_i) { + RAFT_EXPECTS(!(!len_or_indptr && (len_i == nullptr)), + "When `len_or_indptr` is false, `len_i` must not be nullptr!"); + auto stream = resource::get_cuda_stream(res); auto mr = resource::get_workspace_resource(res); - if (k == len) { + if (k == len && len_or_indptr) { RAFT_CUDA_TRY( cudaMemcpyAsync(out, in, sizeof(T) * batch_size * len, cudaMemcpyDeviceToDevice, stream)); if (in_idx) { @@ -1248,29 +1307,29 @@ void select_k(raft::resources const& res, constexpr int items_per_thread = 32; if (len <= BlockSize * items_per_thread) { - impl::radix_topk_one_block( + impl::radix_topk_one_block( in, in_idx, batch_size, len, k, out, out_idx, select_min, len_i, sm_cnt, stream, mr); } else { unsigned grid_dim = impl::calc_grid_dim(batch_size, len, sm_cnt); if (grid_dim == 1) { - impl::radix_topk_one_block( + impl::radix_topk_one_block( in, in_idx, batch_size, len, k, out, out_idx, select_min, len_i, sm_cnt, stream, mr); } else { - impl::radix_topk(in, - in_idx, - batch_size, - len, - k, - out, - out_idx, - select_min, - fused_last_filter, - len_i, - grid_dim, - sm_cnt, - stream, - mr); + impl::radix_topk(in, + in_idx, + batch_size, + len, + k, + out, + out_idx, + select_min, + fused_last_filter, + len_i, + grid_dim, + sm_cnt, + stream, + mr); } } } diff --git a/cpp/include/raft/matrix/detail/select_warpsort.cuh b/cpp/include/raft/matrix/detail/select_warpsort.cuh index 572558153d..7da659291c 100644 --- a/cpp/include/raft/matrix/detail/select_warpsort.cuh +++ b/cpp/include/raft/matrix/detail/select_warpsort.cuh @@ -27,8 +27,9 @@ #include #include +#include #include -#include +#include #include #include @@ -754,22 +755,32 @@ template