Skip to content

Commit

Permalink
feat: make sysroot a feature
Browse files Browse the repository at this point in the history
  • Loading branch information
cloudhan committed Jan 28, 2025
1 parent 1b36a26 commit 28f6ffc
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 2 deletions.
1 change: 1 addition & 0 deletions cuda/private/actions/compile.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def compile(
cuda_toolchain,
cuda_feature_config,
common.cuda_archs_info,
common.sysroot,
source_file = src.path,
output_file = obj_file.path,
host_compiler = host_compiler,
Expand Down
1 change: 1 addition & 0 deletions cuda/private/actions/dlink.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def _compiler_device_link(
cuda_toolchain,
cuda_feature_config,
common.cuda_archs_info,
common.sysroot,
output_file = obj_file.path,
host_compiler = host_compiler,
host_compile_flags = common.host_compile_flags,
Expand Down
12 changes: 10 additions & 2 deletions cuda/private/cuda_helper.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def _get_cuda_archs_info(ctx):

def _create_common_info(
cuda_archs_info = None,
sysroot = None,
includes = [],
quote_includes = [],
system_includes = [],
Expand All @@ -194,6 +195,7 @@ def _create_common_info(
Args:
cuda_archs_info: `CudaArchsInfo`.
sysroot: The `sysroot`.
includes: include paths. Can be used with `#include <...>` and `#include "..."`.
quote_includes: include paths. Can be used with `#include "..."`.
system_includes: include paths. Can be used with `#include <...>`.
Expand All @@ -212,6 +214,7 @@ def _create_common_info(
"""
return struct(
cuda_archs_info = cuda_archs_info,
sysroot = sysroot,
includes = includes,
quote_includes = quote_includes,
system_includes = system_includes,
Expand Down Expand Up @@ -279,8 +282,6 @@ def _create_common(ctx):
host_defines = []
host_local_defines = [i for i in attr.host_local_defines]
host_compile_flags = attr._default_host_copts[BuildSettingInfo].value + [i for i in attr.host_copts]
if cc_toolchain.sysroot:
host_compile_flags.append("--sysroot={}".format(cc_toolchain.sysroot))
host_link_flags = []
if hasattr(attr, "host_linkopts"):
host_link_flags.extend([i for i in attr.host_linkopts])
Expand All @@ -295,6 +296,7 @@ def _create_common(ctx):

return _create_common_info(
cuda_archs_info = _get_cuda_archs_info(ctx),
sysroot = getattr(cc_toolchain, "sysroot", None),
includes = includes,
quote_includes = quote_includes,
system_includes = system_includes,
Expand Down Expand Up @@ -388,6 +390,7 @@ def _create_compile_variables(
cuda_toolchain,
feature_configuration,
cuda_archs_info,
sysroot = None,
source_file = None,
output_file = None,
host_compiler = None,
Expand All @@ -407,6 +410,7 @@ def _create_compile_variables(
cuda_toolchain: cuda_toolchain for which we are creating build variables.
feature_configuration: Feature configuration to be queried.
cuda_archs_info: `CudaArchsInfo`
sysroot: The `sysroot`.
source_file: source file for the compilation.
output_file: output file of the compilation.
host_compiler: host compiler path.
Expand All @@ -428,6 +432,7 @@ def _create_compile_variables(
return struct(
arch_specs = arch_specs,
use_arch_native = len(arch_specs) == 0,
sysroot = sysroot,
source_file = source_file,
output_file = output_file,
host_compiler = host_compiler,
Expand All @@ -448,6 +453,7 @@ def _create_device_link_variables(
cuda_toolchain,
feature_configuration,
cuda_archs_info,
sysroot = None,
output_file = None,
host_compiler = None,
host_compile_flags = [],
Expand All @@ -459,6 +465,7 @@ def _create_device_link_variables(
cuda_toolchain: cuda_toolchain for which we are creating build variables.
feature_configuration: Feature configuration to be queried.
cuda_archs_info: `CudaArchsInfo`
sysroot: The `sysroot`.
output_file: output file of the device linking.
host_compiler: host compiler path.
host_compile_flags: flags pass to host compiler.
Expand All @@ -480,6 +487,7 @@ def _create_device_link_variables(
return struct(
arch_specs = arch_specs,
use_arch_native = len(arch_specs) == 0,
sysroot = sysroot,
output_file = output_file,
host_compiler = host_compiler,
host_compile_flags = host_compile_flags,
Expand Down
1 change: 1 addition & 0 deletions cuda/private/toolchain_configs/clang.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ def _impl(ctx):
flag_set(
actions = [
ACTION_NAMES.cuda_compile,
ACTION_NAMES.device_link,
],
flag_groups = [
flag_group(
Expand Down
20 changes: 20 additions & 0 deletions cuda/private/toolchain_configs/nvcc.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,25 @@ def _impl(ctx):
provides = ["compilation_mode"],
)

sysroot_feature = feature(
name = "sysroot",
enabled = True,
flag_sets = [
flag_set(
actions = [
ACTION_NAMES.cuda_compile,
ACTION_NAMES.device_link,
],
flag_groups = [
flag_group(
flags = ["-Xcompiler", "--sysroot=%{sysroot}"],
expand_if_available = "sysroot",
),
],
),
],
)

ptxas_flags_feature = feature(
name = "ptxas_flags",
enabled = True,
Expand Down Expand Up @@ -498,6 +517,7 @@ def _impl(ctx):
dbg_feature,
opt_feature,
fastbuild_feature,
sysroot_feature,
ptxas_flags_feature,
compiler_input_flags_feature,
compiler_output_flags_feature,
Expand Down

0 comments on commit 28f6ffc

Please sign in to comment.