Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support ptx code type for program #317

Merged
merged 17 commits into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 18 additions & 37 deletions cuda_core/cuda/core/experimental/_linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE

import ctypes
import warnings
import weakref
from contextlib import contextmanager
from dataclasses import dataclass
from typing import List, Optional
from warnings import warn

from cuda.core.experimental._device import Device
from cuda.core.experimental._module import ObjectCode
Expand All @@ -23,11 +23,11 @@


# Note: this function is reused in the tests
def _decide_nvjitlink_or_driver():
def _decide_nvjitlink_or_driver() -> bool:
"""Returns True if falling back to the cuLink* driver APIs."""
global _driver_ver, _driver, _nvjitlink
if _driver or _nvjitlink:
return
return _driver is not None

_driver_ver = handle_return(driver.cuDriverGetVersion())
_driver_ver = (_driver_ver // 1000, (_driver_ver % 1000) // 10)
Expand All @@ -43,7 +43,7 @@ def _decide_nvjitlink_or_driver():
_nvjitlink = None

if _nvjitlink is None:
warnings.warn(
warn(
"nvJitLink is not installed or too old (<12.3). Therefore it is not usable "
"and the culink APIs will be used instead.",
stacklevel=3,
Expand Down Expand Up @@ -98,78 +98,59 @@ class LinkerOptions:
will be used.
max_register_count : int, optional
Maximum register count.
Maps to: ``-maxrregcount=<N>``.
time : bool, optional
Print timing information to the info log.
Maps to ``-time``.
Default: False.
verbose : bool, optional
Print verbose messages to the info log.
Maps to ``-verbose``.
Default: False.
link_time_optimization : bool, optional
Perform link time optimization.
Maps to: ``-lto``.
Default: False.
ptx : bool, optional
Emit PTX after linking instead of CUBIN; only supported with ``-lto``.
Maps to ``-ptx``.
Emit PTX after linking instead of CUBIN; only supported with ``link_time_optimization=True``.
Default: False.
optimization_level : int, optional
Set optimization level. Only 0 and 3 are accepted.
Maps to ``-O<N>``.
debug : bool, optional
Generate debug information.
Maps to ``-g``.
Default: False.
lineinfo : bool, optional
Generate line information.
Maps to ``-lineinfo``.
Default: False.
ftz : bool, optional
Flush denormal values to zero.
Maps to ``-ftz=<n>``.
Default: False.
prec_div : bool, optional
Use precise division.
Maps to ``-prec-div=<n>``.
Default: True.
prec_sqrt : bool, optional
Use precise square root.
Maps to ``-prec-sqrt=<n>``.
Default: True.
fma : bool, optional
Use fast multiply-add.
Maps to ``-fma=<n>``.
Default: True.
kernels_used : List[str], optional
Pass a list of kernels that are used; any not in the list can be removed. This option can be specified multiple
times.
Maps to ``-kernels-used=<name>``.
variables_used : List[str], optional
Pass a list of variables that are used; any not in the list can be removed.
Maps to ``-variables-used=<name>``.
optimize_unused_variables : bool, optional
Assume that if a variable is not referenced in device code, it can be removed.
Maps to: ``-optimize-unused-variables``.
Default: False.
xptxas : List[str], optional
ptxas_options : List[str], optional
Pass options to PTXAS.
Maps to: ``-Xptxas=<opt>``.
split_compile : int, optional
Split compilation maximum thread count. Use 0 to use all available processors. Value of 1 disables split
compilation (default).
Maps to ``-split-compile=<N>``.
Default: 1.
split_compile_extended : int, optional
A more aggressive form of split compilation available in LTO mode only. Accepts a maximum thread count value.
Use 0 to use all available processors. Value of 1 disables extended split compilation (default). Note: This
option can potentially impact performance of the compiled binary.
Maps to ``-split-compile-extended=<N>``.
Default: 1.
no_cache : bool, optional
Do not cache the intermediate steps of nvJitLink.
Maps to ``-no-cache``.
Default: False.
"""

Expand All @@ -189,7 +170,7 @@ class LinkerOptions:
kernels_used: Optional[List[str]] = None
variables_used: Optional[List[str]] = None
optimize_unused_variables: Optional[bool] = None
xptxas: Optional[List[str]] = None
ptxas_options: Optional[List[str]] = None
split_compile: Optional[int] = None
split_compile_extended: Optional[int] = None
no_cache: Optional[bool] = None
Expand Down Expand Up @@ -239,8 +220,8 @@ def _init_nvjitlink(self):
self.formatted_options.append(f"-variables-used={variable}")
if self.optimize_unused_variables is not None:
self.formatted_options.append("-optimize-unused-variables")
if self.xptxas is not None:
for opt in self.xptxas:
if self.ptxas_options is not None:
for opt in self.ptxas_options:
self.formatted_options.append(f"-Xptxas={opt}")
if self.split_compile is not None:
self.formatted_options.append(f"-split-compile={self.split_compile}")
Expand Down Expand Up @@ -290,21 +271,21 @@ def _init_driver(self):
self.formatted_options.append(1)
self.option_keys.append(_driver.CUjit_option.CU_JIT_GENERATE_LINE_INFO)
if self.ftz is not None:
raise ValueError("ftz option is deprecated in the driver API")
warn("ftz option is deprecated in the driver API", DeprecationWarning, stacklevel=3)
if self.prec_div is not None:
raise ValueError("prec_div option is deprecated in the driver API")
warn("prec_div option is deprecated in the driver API", DeprecationWarning, stacklevel=3)
if self.prec_sqrt is not None:
raise ValueError("prec_sqrt option is deprecated in the driver API")
warn("prec_sqrt option is deprecated in the driver API", DeprecationWarning, stacklevel=3)
if self.fma is not None:
raise ValueError("fma options is deprecated in the driver API")
warn("fma options is deprecated in the driver API", DeprecationWarning, stacklevel=3)
if self.kernels_used is not None:
raise ValueError("kernels_used is deprecated in the driver API")
warn("kernels_used is deprecated in the driver API", DeprecationWarning, stacklevel=3)
if self.variables_used is not None:
raise ValueError("variables_used is deprecated in the driver API")
warn("variables_used is deprecated in the driver API", DeprecationWarning, stacklevel=3)
if self.optimize_unused_variables is not None:
raise ValueError("optimize_unused_variables is deprecated in the driver API")
if self.xptxas is not None:
raise ValueError("xptxas option is not supported by the driver API")
warn("optimize_unused_variables is deprecated in the driver API", DeprecationWarning, stacklevel=3)
if self.ptxas_options is not None:
raise ValueError("ptxas_options option is not supported by the driver API")
if self.split_compile is not None:
raise ValueError("split_compile option is not supported by the driver API")
if self.split_compile_extended is not None:
Expand Down
Loading
Loading