From 77ecaed6f2204284869c5ab30f3c97292a0904de Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Fri, 24 Mar 2023 12:16:45 +0000 Subject: [PATCH] On Windows, call ``os.add_dll_directory`` before trying to load CUDA backend --- pygenn/genn_model.py | 40 ++++++++++++++++++++++++---------------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/pygenn/genn_model.py b/pygenn/genn_model.py index 6d1bd84f2f..d50792f3e8 100644 --- a/pygenn/genn_model.py +++ b/pygenn/genn_model.py @@ -47,6 +47,7 @@ from psutil import cpu_count from setuptools import msvc from subprocess import check_call # to call make +import sys from textwrap import dedent from warnings import warn @@ -71,22 +72,6 @@ CurrentSource, CustomUpdate) from .model_preprocessor import prepare_snippet -# Loop through backends in preferential order -backend_modules = OrderedDict() -for b in ["CUDA", "SingleThreadedCPU", "OpenCL"]: - # Try and import - try: - m = import_module(".genn_wrapper." + b + "Backend", "pygenn") - # Ignore failed imports - likely due to non-supported backends - except ImportError as ex: - pass - # Raise any other errors - except: - raise - # Otherwise add to (ordered) dictionary - else: - backend_modules[b] = m - # If we're on windows if system() == "Windows": # Get environment and cache in class, convertings @@ -105,6 +90,29 @@ # **NOTE** shutil.which would be nicer, but isn't in Python < 3.3 _msbuild = find_executable("msbuild", _msvc_env["PATH"]) + # If Python version is newer than 3.8 and CUDA path is in environment + if sys.version_info >= (3, 8) and "CUDA_PATH" in environ: + # Add CUDA bin directory to DLL search directories + from os import add_dll_directory + add_dll_directory(path.join(environ["CUDA_PATH"], "bin")) + + +# Loop through backends in preferential order +backend_modules = OrderedDict() +for b in ["CUDA", "SingleThreadedCPU", "OpenCL"]: + # Try and import + try: + m = import_module(".genn_wrapper." + b + "Backend", "pygenn") + # Ignore failed imports - likely due to non-supported backends + except ImportError as ex: + pass + # Raise any other errors + except: + raise + # Otherwise add to (ordered) dictionary + else: + backend_modules[b] = m + GeNNType = namedtuple("GeNNType", ["np_dtype", "assign_ext_ptr_array", "assign_ext_ptr_single"]) class GeNNModel(object):