Skip to content

Commit

Permalink
Fix #39: Bring add_cu() implementation in line with current Numba
Browse files Browse the repository at this point in the history
Numba 0.58 implements `add_cu()` differently from Numba 0.57 and earlier.
This commit brings the patched linker's implementation in pynvjitlink in
line with Numba 0.58 onwards, and raises the minimum supported Numba version
to 0.58, so that we don't have to support multiple different implementations.
  • Loading branch information
gmarkall committed Jan 18, 2024
1 parent d62cc1d commit 3130222
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions pynvjitlink/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
_numba_version_ok = False
_numba_error = None

required_numba_ver = (0, 57)
required_numba_ver = (0, 58)

mvc_docs_url = (
"https://numba.readthedocs.io/en/stable/cuda/" "minor_version_compatibility.html"
Expand All @@ -29,12 +29,9 @@

if _numba_version_ok:
from numba.core import config
from numba.cuda.cudadrv.driver import FILE_EXTENSION_MAP, Linker, LinkerError

if ver < (0, 58):
from numba.cuda.cudadrv.driver import NvrtcProgram
else:
from numba.cuda.cudadrv.nvrtc import NvrtcProgram
from numba.cuda.cudadrv import nvrtc
from numba.cuda.cudadrv.driver import (driver, FILE_EXTENSION_MAP, Linker,
LinkerError)
else:
# Prevent the definition of PatchedLinker failing if we have no Numba
# Linker - it won't be used anyway.
Expand Down Expand Up @@ -117,16 +114,20 @@ def add_file(self, path, kind):
raise LinkerError from e

def add_cu(self, cu, name):
program = NvrtcProgram(cu, name)
with driver.get_active_context() as ac:
dev = driver.get_device(ac.devnum)
cc = dev.compute_capability

ptx, log = nvrtc.compile(cu, name, cc)

if config.DUMP_ASSEMBLY:
print(("ASSEMBLY %s" % name).center(80, "-"))
print(program.ptx.decode())
print(ptx)
print("=" * 80)

# Link the program's PTX using the normal linker mechanism
ptx_name = os.path.splitext(name)[0] + ".ptx"
self.add_ptx(program.ptx.rstrip(b"\x00"), ptx_name)
self.add_ptx(ptx.encode(), ptx_name)

def complete(self):
try:
Expand Down

0 comments on commit 3130222

Please sign in to comment.