diff --git a/cuda/private/actions/compile.bzl b/cuda/private/actions/compile.bzl index 5129bdac..3091465a 100644 --- a/cuda/private/actions/compile.bzl +++ b/cuda/private/actions/compile.bzl @@ -75,6 +75,7 @@ def compile( cuda_toolchain, cuda_feature_config, common.cuda_archs_info, + common.sysroot, source_file = src.path, output_file = obj_file.path, host_compiler = host_compiler, diff --git a/cuda/private/actions/dlink.bzl b/cuda/private/actions/dlink.bzl index 412cf4da..eb4ab3f7 100644 --- a/cuda/private/actions/dlink.bzl +++ b/cuda/private/actions/dlink.bzl @@ -75,6 +75,7 @@ def _compiler_device_link( cuda_toolchain, cuda_feature_config, common.cuda_archs_info, + common.sysroot, output_file = obj_file.path, host_compiler = host_compiler, host_compile_flags = common.host_compile_flags, diff --git a/cuda/private/cuda_helper.bzl b/cuda/private/cuda_helper.bzl index 1d5855c4..02e5128e 100644 --- a/cuda/private/cuda_helper.bzl +++ b/cuda/private/cuda_helper.bzl @@ -174,6 +174,7 @@ def _get_cuda_archs_info(ctx): def _create_common_info( cuda_archs_info = None, + sysroot = None, includes = [], quote_includes = [], system_includes = [], @@ -194,6 +195,7 @@ def _create_common_info( Args: cuda_archs_info: `CudaArchsInfo`. + sysroot: The `sysroot`. includes: include paths. Can be used with `#include <...>` and `#include "..."`. quote_includes: include paths. Can be used with `#include "..."`. system_includes: include paths. Can be used with `#include <...>`. @@ -212,6 +214,7 @@ def _create_common_info( """ return struct( cuda_archs_info = cuda_archs_info, + sysroot = sysroot, includes = includes, quote_includes = quote_includes, system_includes = system_includes, @@ -279,8 +282,6 @@ def _create_common(ctx): host_defines = [] host_local_defines = [i for i in attr.host_local_defines] host_compile_flags = attr._default_host_copts[BuildSettingInfo].value + [i for i in attr.host_copts] - if cc_toolchain.sysroot: - host_compile_flags.append("--sysroot={}".format(cc_toolchain.sysroot)) host_link_flags = [] if hasattr(attr, "host_linkopts"): host_link_flags.extend([i for i in attr.host_linkopts]) @@ -295,6 +296,7 @@ def _create_common(ctx): return _create_common_info( cuda_archs_info = _get_cuda_archs_info(ctx), + sysroot = getattr(cc_toolchain, "sysroot", None), includes = includes, quote_includes = quote_includes, system_includes = system_includes, @@ -388,6 +390,7 @@ def _create_compile_variables( cuda_toolchain, feature_configuration, cuda_archs_info, + sysroot = None, source_file = None, output_file = None, host_compiler = None, @@ -407,6 +410,7 @@ def _create_compile_variables( cuda_toolchain: cuda_toolchain for which we are creating build variables. feature_configuration: Feature configuration to be queried. cuda_archs_info: `CudaArchsInfo` + sysroot: The `sysroot`. source_file: source file for the compilation. output_file: output file of the compilation. host_compiler: host compiler path. @@ -428,6 +432,7 @@ def _create_compile_variables( return struct( arch_specs = arch_specs, use_arch_native = len(arch_specs) == 0, + sysroot = sysroot, source_file = source_file, output_file = output_file, host_compiler = host_compiler, @@ -448,6 +453,7 @@ def _create_device_link_variables( cuda_toolchain, feature_configuration, cuda_archs_info, + sysroot = None, output_file = None, host_compiler = None, host_compile_flags = [], @@ -459,6 +465,7 @@ def _create_device_link_variables( cuda_toolchain: cuda_toolchain for which we are creating build variables. feature_configuration: Feature configuration to be queried. cuda_archs_info: `CudaArchsInfo` + sysroot: The `sysroot`. output_file: output file of the device linking. host_compiler: host compiler path. host_compile_flags: flags pass to host compiler. @@ -480,6 +487,7 @@ def _create_device_link_variables( return struct( arch_specs = arch_specs, use_arch_native = len(arch_specs) == 0, + sysroot = sysroot, output_file = output_file, host_compiler = host_compiler, host_compile_flags = host_compile_flags, diff --git a/cuda/private/toolchain_configs/clang.bzl b/cuda/private/toolchain_configs/clang.bzl index 844677c0..9c339538 100644 --- a/cuda/private/toolchain_configs/clang.bzl +++ b/cuda/private/toolchain_configs/clang.bzl @@ -205,6 +205,7 @@ def _impl(ctx): flag_set( actions = [ ACTION_NAMES.cuda_compile, + ACTION_NAMES.device_link, ], flag_groups = [ flag_group( diff --git a/cuda/private/toolchain_configs/nvcc.bzl b/cuda/private/toolchain_configs/nvcc.bzl index d2f0c50c..c93f05a6 100644 --- a/cuda/private/toolchain_configs/nvcc.bzl +++ b/cuda/private/toolchain_configs/nvcc.bzl @@ -382,6 +382,25 @@ def _impl(ctx): provides = ["compilation_mode"], ) + sysroot_feature = feature( + name = "sysroot", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.cuda_compile, + ACTION_NAMES.device_link, + ], + flag_groups = [ + flag_group( + flags = ["-Xcompiler", "--sysroot=%{sysroot}"], + expand_if_available = "sysroot", + ), + ], + ), + ], + ) + ptxas_flags_feature = feature( name = "ptxas_flags", enabled = True, @@ -498,6 +517,7 @@ def _impl(ctx): dbg_feature, opt_feature, fastbuild_feature, + sysroot_feature, ptxas_flags_feature, compiler_input_flags_feature, compiler_output_flags_feature,