Skip to content

Commit

Permalink
Merge pull request #579 from genn-team/add_cuda_dll_path
Browse files Browse the repository at this point in the history
Add CUDA DLL path
  • Loading branch information
neworderofjamie authored Apr 19, 2023
2 parents 800bd65 + 77ecaed commit 7afc4c8
Showing 1 changed file with 24 additions and 16 deletions.
40 changes: 24 additions & 16 deletions pygenn/genn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from psutil import cpu_count
from setuptools import msvc
from subprocess import check_call # to call make
import sys
from textwrap import dedent
from warnings import warn

Expand All @@ -71,22 +72,6 @@
CurrentSource, CustomUpdate)
from .model_preprocessor import prepare_snippet

# Loop through backends in preferential order
backend_modules = OrderedDict()
for b in ["CUDA", "SingleThreadedCPU", "OpenCL"]:
# Try and import
try:
m = import_module(".genn_wrapper." + b + "Backend", "pygenn")
# Ignore failed imports - likely due to non-supported backends
except ImportError as ex:
pass
# Raise any other errors
except:
raise
# Otherwise add to (ordered) dictionary
else:
backend_modules[b] = m

# If we're on windows
if system() == "Windows":
# Get environment and cache in class, convertings
Expand All @@ -105,6 +90,29 @@
# **NOTE** shutil.which would be nicer, but isn't in Python < 3.3
_msbuild = find_executable("msbuild", _msvc_env["PATH"])

# If Python version is newer than 3.8 and CUDA path is in environment
if sys.version_info >= (3, 8) and "CUDA_PATH" in environ:
# Add CUDA bin directory to DLL search directories
from os import add_dll_directory
add_dll_directory(path.join(environ["CUDA_PATH"], "bin"))


# Loop through backends in preferential order
backend_modules = OrderedDict()
for b in ["CUDA", "SingleThreadedCPU", "OpenCL"]:
# Try and import
try:
m = import_module(".genn_wrapper." + b + "Backend", "pygenn")
# Ignore failed imports - likely due to non-supported backends
except ImportError as ex:
pass
# Raise any other errors
except:
raise
# Otherwise add to (ordered) dictionary
else:
backend_modules[b] = m

GeNNType = namedtuple("GeNNType", ["np_dtype", "assign_ext_ptr_array", "assign_ext_ptr_single"])

class GeNNModel(object):
Expand Down

0 comments on commit 7afc4c8

Please sign in to comment.