Skip to content

Commit

Permalink
use dynamic CUDA wheels on CUDA 11
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslamb committed Jan 22, 2025
1 parent 501c8ce commit 107c2ef
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 50 deletions.
30 changes: 12 additions & 18 deletions ci/build_wheel.sh
Original file line number Diff line number Diff line change
Expand Up @@ -21,24 +21,18 @@ rapids-generate-version > ./VERSION

cd "${package_dir}"

case "${RAPIDS_CUDA_VERSION}" in
12.*)
EXCLUDE_ARGS=(
--exclude "libcublas.so.12"
--exclude "libcublasLt.so.12"
--exclude "libcurand.so.10"
--exclude "libcusolver.so.11"
--exclude "libcusparse.so.12"
--exclude "libnvJitLink.so.12"
--exclude "libucp.so.0"
)
;;
11.*)
EXCLUDE_ARGS=(
--exclude "libucp.so.0"
)
;;
esac
EXCLUDE_ARGS=(
--exclude "libcublas.so.11"
--exclude "libcublas.so.12"
--exclude "libcublasLt.so.11"
--exclude "libcublasLt.so.12"
--exclude "libcurand.so.10"
--exclude "libcusolver.so.11"
--exclude "libcusparse.so.11"
--exclude "libcusparse.so.12"
--exclude "libnvJitLink.so.12"
--exclude "libucp.so.0"
)

if [[ ${package_name} != "libraft" ]]; then
EXCLUDE_ARGS+=(
Expand Down
11 changes: 0 additions & 11 deletions ci/build_wheel_libraft.sh
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,5 @@ export PIP_NO_BUILD_ISOLATION=0

RAPIDS_PY_CUDA_SUFFIX="$(rapids-wheel-ctk-name-gen ${RAPIDS_CUDA_VERSION})"

case "${RAPIDS_CUDA_VERSION}" in
12.*)
EXTRA_CMAKE_ARGS="-DUSE_CUDA_MATH_WHEELS=ON"
;;
11.*)
EXTRA_CMAKE_ARGS="-DUSE_CUDA_MATH_WHEELS=OFF"
;;
esac

export SKBUILD_CMAKE_ARGS="${EXTRA_CMAKE_ARGS}"

ci/build_wheel.sh libraft ${package_dir} cpp
ci/validate_wheel.sh ${package_dir} final_dist libraft
1 change: 1 addition & 0 deletions ci/validate_wheel.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ RAPIDS_CUDA_MAJOR="${RAPIDS_CUDA_VERSION%%.*}"
# some packages are much larger on CUDA 11 than on CUDA 12
PYDISTCHECK_ARGS=()
if [[ "${package_name}" == "libraft" ]]; then
# TODO(jameslamb): revise these thresholds
if [[ "${RAPIDS_CUDA_MAJOR}" == "11" ]]; then
PYDISTCHECK_ARGS+=(
--max-allowed-size-compressed '750M'
Expand Down
5 changes: 4 additions & 1 deletion dependencies.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -358,11 +358,14 @@ dependencies:
- nvidia-curand-cu12
- nvidia-cusolver-cu12
- nvidia-cusparse-cu12
# CUDA 11 does not provide wheels, so use the system libraries instead
- matrix:
cuda: "11.*"
use_cuda_wheels: "true"
packages:
- nvidia-cublas-cu11
- nvidia-curand-cu11
- nvidia-cusolver-cu11
- nvidia-cusparse-cu11
# if use_cuda_wheels=false is provided, do not add dependencies on any CUDA wheels
# (e.g. for DLFW and pip devcontainers)
- matrix:
Expand Down
31 changes: 11 additions & 20 deletions python/libraft/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ project(
LANGUAGES CXX
)

option(USE_CUDA_MATH_WHEELS "Use the CUDA math wheels instead of the system libraries" OFF)

# Check if raft is already available. If so, it is the user's responsibility to ensure that the
# CMake package is also available at build time of the Python raft package.
find_package(raft "${RAPIDS_VERSION}")
Expand All @@ -35,14 +33,8 @@ endif()
unset(raft_FOUND)

# --- CUDA --- #
find_package(CUDAToolkit REQUIRED)
set(CUDA_STATIC_RUNTIME ON)
set(CUDA_STATIC_MATH_LIBRARIES ON)
if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.0)
set(CUDA_STATIC_MATH_LIBRARIES OFF)
elseif(USE_CUDA_MATH_WHEELS)
message(FATAL_ERROR "Cannot use CUDA math wheels with CUDA < 12.0")
endif()
set(CUDA_STATIC_MATH_LIBRARIES OFF)

# --- RAFT ---#
set(BUILD_TESTS OFF)
Expand All @@ -52,14 +44,13 @@ set(RAFT_COMPILE_LIBRARY ON)

add_subdirectory(../../cpp raft-cpp)

if(NOT CUDA_STATIC_MATH_LIBRARIES AND USE_CUDA_MATH_WHEELS)
set_property(
TARGET raft_lib
PROPERTY INSTALL_RPATH
"$ORIGIN/../nvidia/cublas/lib"
"$ORIGIN/../nvidia/curand/lib"
"$ORIGIN/../nvidia/cusolver/lib"
"$ORIGIN/../nvidia/cusparse/lib"
"$ORIGIN/../nvidia/nvjitlink/lib"
)
endif()
# assumes libraft.so is installed 2 levels deep, e.g. site-packages/libraft/lib64/libraft.so
set_property(
TARGET raft_lib
PROPERTY INSTALL_RPATH
"$ORIGIN/../../nvidia/cublas/lib"
"$ORIGIN/../../nvidia/curand/lib"
"$ORIGIN/../../nvidia/cusolver/lib"
"$ORIGIN/../../nvidia/cusparse/lib"
"$ORIGIN/../../nvidia/nvjitlink/lib"
)

0 comments on commit 107c2ef

Please sign in to comment.