installation is not working #31
Instructions to get halfway (Python 3.10):
pip install jaxlib==0.4.11+cuda12.cudnn89 -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
# now, when importing we get
# ModuleNotFoundError: No module named 'ml_dtypes._ml_dtypes_ext'
# solve as per https://developer.apple.com/forums/thread/737890
pip install ml_dtypes==0.2.0
# now, numpy is not working
pip install -U numpy --force-reinstall
# now, it can be run.
However, when I now open Python, I get:
from jax import numpy as jnp
a = jnp.zeros(5)
# external/xla/xla/stream_executor/cuda/cuda_dnn.cc:407] There was an error before creating cudnn handle (302): cudaGetErrorName symbol not found. : cudaGetErrorString symbol not found.
So it is still not usable.
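Because the steps above swap jaxlib, ml_dtypes and numpy versions several times, it can help to confirm what actually ended up installed before importing jax at all. A minimal sketch using only the standard library; the helper name installed_versions is hypothetical. It reads package metadata without importing the packages, so it still works when import jax itself fails, and the jaxlib entry should include the local build tag (e.g. +cuda12.cudnn89) of whichever wheel was installed:

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("jax", "jaxlib", "ml_dtypes", "numpy")):
    """Return {distribution name: installed version string, or None if absent}."""
    found = {}
    for name in packages:
        try:
            # Reads the installed distribution's metadata; does not import it.
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name:10s} {ver or 'not installed'}")
```

Pasting this output into the issue makes it easier to see whether, for example, the ml_dtypes pin from the step above survived the later numpy reinstall.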
Use
So, I downloaded the wheel and installed
As far as I can see, that should be compatible with jaxlib==0.4.11 (based on the source code). If I run it, I still get
Could you please set the environment variable TF_CPP_MIN_LOG_LEVEL before importing jax:
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
import jax
jax.numpy.array([0])
There used to be some useful DLL info in that log; I'm not sure what it shows now, though. Might be worth a try.
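The ordering in the suggestion above matters: the variable should be set before jax (and therefore jaxlib) is imported so that the C++ logging picks it up. A slightly expanded sketch of the same check that also prints which backend jax selected; jax.default_backend() and jax.devices() are standard jax functions, and the broad except around the import is only there so import-time failures like the ones reported in this thread are printed instead of hiding the rest of the output:

```python
import os

# Must be set before jax/jaxlib are imported. "0" asks the C++ logging to
# emit all severities, including INFO lines.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

try:
    import jax
except Exception as exc:  # e.g. a broken ml_dtypes or a jax/jaxlib mismatch
    print(f"import jax failed: {exc!r}")
else:
    print("backend:", jax.default_backend())
    print("devices:", jax.devices())
    # On the affected setup, the first real computation is where the
    # cuDNN initialization error (XlaRuntimeError) surfaces.
    print(jax.numpy.array([0]))
```

If the backend line says cpu, jax did not pick up the CUDA build at all, which is a different problem from the cuDNN initialization failure.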
So, I reinstalled everything from scratch, just to make sure it's not caused by some old environment I had tried:
conda create -n jax
conda activate jax
conda install numpy scipy jupyter
# this should download the same file, just putting it in for reproducibility
pip install jaxlib==0.4.11+cuda12.cudnn89 -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
pip install jax==0.4.13
conda install nvidiatoolkit
And then I ran the script above. I get the following output:
It starts by not finding CUDA, but then it does seem to find it. The full traceback is here:
XlaRuntimeError Traceback (most recent call last)
Cell In[3], line 1
----> 1 a = jax.numpy.zeros(512)
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\numpy\lax_numpy.py:2153, in zeros(shape, dtype)
2151 dtypes.check_user_dtype_supported(dtype, "zeros")
2152 shape = canonicalize_shape(shape)
-> 2153 return lax.full(shape, 0, _jnp_dtype(dtype))
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\lax\lax.py:1206, in full(shape, fill_value, dtype)
1204 dtype = dtypes.canonicalize_dtype(dtype or _dtype(fill_value))
1205 fill_value = _convert_element_type(fill_value, dtype, weak_type)
-> 1206 return broadcast(fill_value, shape)
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\lax\lax.py:768, in broadcast(operand, sizes)
754 """Broadcasts an array, adding new leading dimensions
755
756 Args:
(...)
765 jax.lax.broadcast_in_dim : add new dimensions at any location in the array shape.
766 """
767 dims = tuple(range(len(sizes), len(sizes) + np.ndim(operand)))
--> 768 return broadcast_in_dim(operand, tuple(sizes) + np.shape(operand), dims)
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\lax\lax.py:797, in broadcast_in_dim(operand, shape, broadcast_dimensions)
795 else:
796 dyn_shape, static_shape = [], shape # type: ignore
--> 797 return broadcast_in_dim_p.bind(
798 operand, *dyn_shape, shape=tuple(static_shape),
799 broadcast_dimensions=tuple(broadcast_dimensions))
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\core.py:380, in Primitive.bind(self, *args, **params)
377 def bind(self, *args, **params):
378 assert (not config.jax_enable_checks or
379 all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 380 return self.bind_with_trace(find_top_trace(args), args, params)
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\core.py:383, in Primitive.bind_with_trace(self, trace, args, params)
382 def bind_with_trace(self, trace, args, params):
--> 383 out = trace.process_primitive(self, map(trace.full_raise, args), params)
384 return map(full_lower, out) if self.multiple_results else full_lower(out)
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\core.py:815, in EvalTrace.process_primitive(self, primitive, tracers, params)
814 def process_primitive(self, primitive, tracers, params):
--> 815 return primitive.impl(*tracers, **params)
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\dispatch.py:132, in apply_primitive(prim, *args, **params)
130 try:
131 in_avals, in_shardings = util.unzip2([arg_spec(a) for a in args])
--> 132 compiled_fun = xla_primitive_callable(
133 prim, in_avals, OrigShardings(in_shardings), **params)
134 except pxla.DeviceAssignmentMismatchError as e:
135 fails, = e.args
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\util.py:284, in cache.<locals>.wrap.<locals>.wrapper(*args, **kwargs)
282 return f(*args, **kwargs)
283 else:
--> 284 return cached(config._trace_context(), *args, **kwargs)
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\util.py:277, in cache.<locals>.wrap.<locals>.cached(_, *args, **kwargs)
275 @functools.lru_cache(max_size)
276 def cached(_, *args, **kwargs):
--> 277 return f(*args, **kwargs)
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\dispatch.py:223, in xla_primitive_callable(prim, in_avals, orig_in_shardings, **params)
221 return out,
222 donated_invars = (False,) * len(in_avals)
--> 223 compiled = _xla_callable_uncached(
224 lu.wrap_init(prim_fun), prim.name, donated_invars, False, in_avals,
225 orig_in_shardings)
226 if not prim.multiple_results:
227 return lambda *args, **kw: compiled(*args, **kw)[0]
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\dispatch.py:253, in _xla_callable_uncached(fun, name, donated_invars, keep_unused, in_avals, orig_in_shardings)
248 def _xla_callable_uncached(fun: lu.WrappedFun, name, donated_invars,
249 keep_unused, in_avals, orig_in_shardings):
250 computation = sharded_lowering(
251 fun, name, donated_invars, keep_unused, True, in_avals, orig_in_shardings,
252 lowering_platform=None)
--> 253 return computation.compile().unsafe_call
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\interpreters\pxla.py:2323, in MeshComputation.compile(self, compiler_options)
2320 executable = MeshExecutable.from_trivial_jaxpr(
2321 **self.compile_args)
2322 else:
-> 2323 executable = UnloadedMeshExecutable.from_hlo(
2324 self._name,
2325 self._hlo,
2326 **self.compile_args,
2327 compiler_options=compiler_options)
2328 if compiler_options is None:
2329 self._executable = executable
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\interpreters\pxla.py:2645, in UnloadedMeshExecutable.from_hlo(***failed resolving arguments***)
2642 mesh = i.mesh # type: ignore
2643 break
-> 2645 xla_executable, compile_options = _cached_compilation(
2646 hlo, name, mesh, spmd_lowering,
2647 tuple_args, auto_spmd_lowering, allow_prop_to_outputs,
2648 tuple(host_callbacks), backend, da, pmap_nreps,
2649 compiler_options_keys, compiler_options_values)
2651 if hasattr(backend, "compile_replicated"):
2652 semantics_in_shardings = SemanticallyEqualShardings(in_shardings) # type: ignore
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\interpreters\pxla.py:2555, in _cached_compilation(computation, name, mesh, spmd_lowering, tuple_args, auto_spmd_lowering, _allow_propagation_to_outputs, host_callbacks, backend, da, pmap_nreps, compiler_options_keys, compiler_options_values)
2550 return None, compile_options
2552 with dispatch.log_elapsed_time(
2553 "Finished XLA compilation of {fun_name} in {elapsed_time} sec",
2554 fun_name=name, event=dispatch.BACKEND_COMPILE_EVENT):
-> 2555 xla_executable = dispatch.compile_or_get_cached(
2556 backend, computation, dev, compile_options, host_callbacks)
2557 return xla_executable, compile_options
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\dispatch.py:497, in compile_or_get_cached(backend, computation, devices, compile_options, host_callbacks)
493 use_compilation_cache = (compilation_cache.is_initialized() and
494 backend.platform in supported_platforms)
496 if not use_compilation_cache:
--> 497 return backend_compile(backend, computation, compile_options,
498 host_callbacks)
500 cache_key = compilation_cache.get_cache_key(
501 computation, devices, compile_options, backend)
503 cached_executable = _cache_read(module_name, cache_key, compile_options,
504 backend)
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\profiler.py:314, in annotate_function.<locals>.wrapper(*args, **kwargs)
311 @wraps(func)
312 def wrapper(*args, **kwargs):
313 with TraceAnnotation(name, **decorator_kwargs):
--> 314 return func(*args, **kwargs)
315 return wrapper
File C:\ProgramData\Anaconda3\envs\jax\lib\site-packages\jax\_src\dispatch.py:465, in backend_compile(backend, module, options, host_callbacks)
460 return backend.compile(built_c, compile_options=options,
461 host_callbacks=host_callbacks)
462 # Some backends don't have `host_callbacks` option yet
463 # TODO(sharadmv): remove this fallback when all backends allow `compile`
464 # to take in `host_callbacks`
--> 465 return backend.compile(built_c, compile_options=options)
XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
Once we get it to work, I could make a conda environment file that hopefully works without having to go through all the same steps. Would you be interested in including that?
Yes, it does. I also just checked, looking with
Gives, among others:
I also installed CuPy, and that works without problems.
Then, does cudnn*.dll exist under that directory?
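One way to answer that without browsing folders by hand is to scan a set of candidate directories for files matching cudnn*.dll. A minimal sketch; the helper names are hypothetical, and searching the PATH entries plus CUDA_PATH/bin is an assumption about where a Windows install usually puts these DLLs (in an activated conda environment, the environment's Library\bin directory is normally on PATH as well), not something stated in the thread:

```python
import os
from pathlib import Path


def candidate_dirs():
    """Every PATH entry, plus CUDA_PATH/bin when CUDA_PATH is set."""
    dirs = [Path(p) for p in os.environ.get("PATH", "").split(os.pathsep) if p]
    cuda_path = os.environ.get("CUDA_PATH")
    if cuda_path:
        dirs.append(Path(cuda_path) / "bin")
    return dirs


def find_cudnn_dlls(dirs=None):
    """Return sorted paths of files matching cudnn*.dll directly inside dirs."""
    hits = set()
    for d in candidate_dirs() if dirs is None else dirs:
        try:
            if d.is_dir():
                # Non-recursive, like the Windows DLL search over PATH entries.
                hits.update(p for p in d.glob("cudnn*.dll") if p.is_file())
        except OSError:
            continue  # unreadable or malformed entry; skip it
    return sorted(hits)


if __name__ == "__main__":
    found = find_cudnn_dlls()
    print("\n".join(map(str, found)) or "no cudnn*.dll found on PATH or CUDA_PATH")
```

An empty result while cupy still works would suggest cupy is loading cuDNN from somewhere jaxlib does not search, or not using cuDNN at all.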
Hi there,
First of all, thank you for supporting Windows! I've used this build before with great success. However, at the moment it's not working, nor can I find a way to get an older version to work.
I'm trying to set up jax on a Windows PC with conda, but the provided instructions no longer work. I also can't really get any other version to work.
I'm installing on a laptop; this is the output from nvidia-smi:
I tried:
This obviously might not work for CUDA 12.0. However, if I run it with
pip install jax[pip_cuda12] -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
I get the same result.
I also tried this with python==3.11, python==3.10, or python==3.9. Same result. When I just download a jaxlib wheel, it also does not work: sometimes I get a bit further, but no computations can be done and I run into "AttributeError: module 'ml_dtypes' has no attribute 'float8_e4m3b11'".
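That AttributeError suggests the jax/jaxlib code being imported refers to a dtype name that the installed ml_dtypes release does not define, i.e. the two packages are out of step. A quick way to check which names are actually available, as a minimal sketch; the helper missing_names is hypothetical, and the name list is taken from the error message above:

```python
import importlib


def missing_names(module_name, names):
    """Return the subset of names the module lacks, or None if it won't import."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:  # includes ModuleNotFoundError for a broken extension
        return None
    return [n for n in names if not hasattr(module, n)]


if __name__ == "__main__":
    result = missing_names("ml_dtypes", ["float8_e4m3b11"])
    if result is None:
        print("ml_dtypes could not be imported")
    elif result:
        print("ml_dtypes is missing:", result)
    else:
        print("ml_dtypes provides all requested names")
```

A non-empty result points toward pinning ml_dtypes to a release that matches the installed jaxlib rather than changing the Python version.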
What should the Python version be? And what would be the right command?
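For the nvidia-smi details mentioned above, the GPU name and driver version can also be captured in a copy-pasteable form with nvidia-smi's --query-gpu and --format options. The plain nvidia-smi header additionally shows the highest CUDA version the driver supports, which is one of the things that decides whether a cuda12 wheel can run at all. A minimal sketch; the helper names are hypothetical, and it only parses the two fields it requests:

```python
import shutil
import subprocess


def parse_gpu_query(text):
    """Parse 'name, driver_version' lines from --format=csv,noheader output."""
    rows = []
    for line in text.splitlines():
        if line.strip():
            # Split on the last comma so the driver version is always field 2.
            name, driver = (field.strip() for field in line.rsplit(",", 1))
            rows.append((name, driver))
    return rows


def query_gpus():
    """Return [(gpu name, driver version), ...], or None if nvidia-smi is absent."""
    exe = shutil.which("nvidia-smi")
    if exe is None:
        return None
    out = subprocess.run(
        [exe, "--query-gpu=name,driver_version", "--format=csv,noheader"],
        capture_output=True, text=True, check=True,
    ).stdout
    return parse_gpu_query(out)
```

Calling query_gpus() on the affected laptop returns one tuple per GPU; None means nvidia-smi is not on PATH in that shell.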