Skip to content

Commit

Permalink
fix: add compatibility layer for attr.string_keyed_label_dict
Browse files Browse the repository at this point in the history
Also stop using `Label(...)` in local_cuda as the dev experience is
really awful when components_mapping fallbacks to `attr.string_dict`
  • Loading branch information
cloudhan committed Jan 8, 2025
1 parent 8c54b42 commit 29ed3f1
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 21 deletions.
5 changes: 4 additions & 1 deletion cuda/extensions.bzl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Entry point for extensions used by bzlmod."""

load("//cuda/private:compat.bzl", "components_mapping_compat")
load("//cuda/private:repositories.bzl", "cuda_component", "local_cuda")

cuda_component_tag = tag_class(attrs = {
Expand Down Expand Up @@ -30,10 +31,12 @@ cuda_component_tag = tag_class(attrs = {
cuda_toolkit_tag = tag_class(attrs = {
"name": attr.string(doc = "Name for the toolchain repository", default = "local_cuda"),
"toolkit_path": attr.string(doc = "Path to the CUDA SDK, if empty the environment variable CUDA_PATH will be used to deduce this path."),
"components_mapping": attr.string_keyed_label_dict(
"components_mapping": components_mapping_compat.attr(
doc = "A mapping from component names to component repos of a deliverable CUDA Toolkit. " +
"Only the repo part of the label is usefull",
),
"version": attr.string(),
"nvcc_version": attr.string(),
})

def _find_modules(module_ctx):
Expand Down
26 changes: 26 additions & 0 deletions cuda/private/compat.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
_is_attr_string_keyed_label_dict_available = getattr(attr, "string_keyed_label_dict", None) != None
_is_bzlmod_enabled = str(Label("//:invalid")).startswith("@@")

def _attr(*args, **kwargs):
"""Compatibility layer for attr.string_keyed_label_dict(...)"""
if _is_attr_string_keyed_label_dict_available:
return attr.string_keyed_label_dict(*args, **kwargs)
else:
return attr.string_dict(*args, **kwargs)

def _repo_str(repo_str_or_repo_label):
"""Get mapped repo as string.
Args:
repo_str_or_repo_label: `"@repo"` or `Label("@repo")` """
if type(repo_str_or_repo_label) == "Label":
canonical_repo_name = repo_str_or_repo_label.repo_name
repo_str = ("@@{}" if _is_bzlmod_enabled else "@{}").format(canonical_repo_name)
return repo_str
else:
return repo_str_or_repo_label

components_mapping_compat = struct(
attr = _attr,
repo_str = _repo_str,
)
45 changes: 29 additions & 16 deletions cuda/private/repositories.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe")
load("//cuda/private:compat.bzl", "components_mapping_compat")
load("//cuda/private:template_helper.bzl", "template_helper")
load("//cuda/private:templates/registry.bzl", "FULL_COMPONENT_NAME", "REGISTRY")
load("//cuda/private:toolchain.bzl", "register_detected_cuda_toolchains")
Expand Down Expand Up @@ -92,27 +93,33 @@ def _detect_deliverable_cuda_toolkit(repository_ctx):
if rc not in repository_ctx.attr.components_mapping:
fail('component "{}" is required.'.format(rc))

is_bzlmod_enabled = str(Label("//:invalid")).startswith("@@")
canonical_nvcc_repo_name = repository_ctx.attr.components_mapping["nvcc"].repo_name
nvcc_repo = ("@@{}" if is_bzlmod_enabled else "@{}").format(canonical_nvcc_repo_name)
nvcc_repo = components_mapping_compat.repo_str(repository_ctx.attr.components_mapping["nvcc"])

bin_ext = ".exe" if _is_windows(repository_ctx) else ""
nvlink = str(Label(nvcc_repo + "//:nvcc/bin/nvlink{}".format(bin_ext)))
link_stub = str(Label(nvcc_repo + "//:nvcc/bin/crt/link.stub"))
bin2c = str(Label(nvcc_repo + "//:nvcc/bin/bin2c{}".format(bin_ext)))
fatbinary = str(Label(nvcc_repo + "//:nvcc/bin/fatbinary{}".format(bin_ext)))
nvcc = "{}//:nvcc/bin/nvcc{}".format(nvcc_repo, bin_ext)
nvlink = "{}//:nvcc/bin/nvlink{}".format(nvcc_repo, bin_ext)
link_stub = "{}//:nvcc/bin/crt/link.stub".format(nvcc_repo)
bin2c = "{}//:nvcc/bin/bin2c{}".format(nvcc_repo, bin_ext)
fatbinary = "{}//:nvcc/bin/fatbinary{}".format(nvcc_repo, bin_ext)

nvcc_root = Label(nvcc_repo).workspace_root + "/nvcc"
nvcc_version_major, nvcc_version_minor = _get_nvcc_version(repository_ctx, nvcc_root)
cuda_version_str = repository_ctx.attr.version
if cuda_version_str == None or cuda_version_str == "":
fail("attr version is required.")

nvcc_version_str = repository_ctx.attr.nvcc_version
if nvcc_version_str == None or nvcc_version_str == "":
nvcc_version_str = cuda_version_str

cuda_version_major, cuda_version_minor = cuda_version_str.split(".")[:2]
nvcc_version_major, nvcc_version_minor = nvcc_version_str.split(".")[:2]

return struct(
path = nvcc_root,
# this should have been extracted from cuda.h, reuse nvcc for now
version_major = nvcc_version_major,
version_minor = nvcc_version_minor,
# this is extracted from `nvcc --version`
path = None, # scattered components
version_major = cuda_version_major,
version_minor = cuda_version_minor,
nvcc_version_major = nvcc_version_major,
nvcc_version_minor = nvcc_version_minor,
nvcc_label = nvcc,
nvlink_label = nvlink,
link_stub_label = link_stub,
bin2c_label = bin2c,
Expand Down Expand Up @@ -242,7 +249,9 @@ local_cuda = repository_rule(
implementation = _local_cuda_impl,
attrs = {
"toolkit_path": attr.string(mandatory = False),
"components_mapping": attr.string_keyed_label_dict(),
"components_mapping": components_mapping_compat.attr(),
"version": attr.string(),
"nvcc_version": attr.string(),
},
configure = True,
local = True,
Expand Down Expand Up @@ -329,19 +338,23 @@ def rules_cuda_dependencies():
],
)

def rules_cuda_toolchains(toolkit_path = None, components_mapping = None, register_toolchains = False):
def rules_cuda_toolchains(toolkit_path = None, components_mapping = None, version = None, nvcc_version = None, register_toolchains = False):
"""Populate the local_cuda repo.
Args:
toolkit_path: Optionally specify the path to CUDA toolkit. If not specified, it will be detected automatically.
components_mapping: dict mapping from component_name to its corresponding cuda_component's repo_name
version: str for cuda toolkit version. Required for deliverable toolkit only.
nvcc_version: str for nvcc version. Required for deliverable toolkit only. Fallback to version if omitted.
register_toolchains: Register the toolchains if enabled.
"""

local_cuda(
name = "local_cuda",
toolkit_path = toolkit_path,
components_mapping = components_mapping,
version = version,
nvcc_version = nvcc_version,
)

if register_toolchains:
Expand Down
8 changes: 4 additions & 4 deletions cuda/private/template_helper.bzl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
load("//cuda/private:compat.bzl", "components_mapping_compat")
load("//cuda/private:templates/registry.bzl", "REGISTRY")

def _to_forward_slash(s):
Expand Down Expand Up @@ -33,11 +34,10 @@ def _generate_local_cuda_build_impl(repository_ctx, libpath, components, is_loca
template_content.append(repository_ctx.read(frag))

if is_local_cuda and is_deliverable: # generate `@local_cuda//BUILD` for CTK with deliverables
for comp, label in components.items():
for comp in components:
for target in REGISTRY[comp]:
# canonical_repo_name = label.repo_name
apparent_repo_name = label.name
line = 'alias(name = "{target}", actual = "@{repo}//:{target}")'.format(target = target, repo = apparent_repo_name)
repo = components_mapping_compat.repo_str(components[comp])
line = 'alias(name = "{target}", actual = "{repo}//:{target}")'.format(target = target, repo = repo)
template_content.append(line)

# add an empty line to separate aliased targets from different components
Expand Down
1 change: 1 addition & 0 deletions tests/integration/toolchain_redist/MODULE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ cuda.toolkit(
"cudart": "@local_cuda_cudart",
"nvcc": "@local_cuda_nvcc",
},
version = "12.6",
)
use_repo(
cuda,
Expand Down

0 comments on commit 29ed3f1

Please sign in to comment.