Support for CUDNN 9.0? #24180
Yeah. Is there a reason you cannot update to cuDNN 9.1? It should be very easy to do:

It is probably possible to self-build a jaxlib with cuDNN 9.0 support, but you'd have to do that by adding another entry to the BUILD files for a 9.0 version. To do this you'd need to add that version here:

and then I think you would build jaxlib with GPU enabled with the following options enabled:

No promises, but that would probably work.
🙏 Will update the build for future JAX releases.
We're shipping a container with other libraries that are tested with cuDNN 9.0; only JAX is the outlier.
Thanks! Will try that if all else fails.
I will note that cuDNN promises backward but not forward compatibility:

So in principle, assuming you believe NVIDIA's promises to that effect (I guess you'd need your own testing to be sure), you can install 9.1 even for users who expect 9.0, and things should work.
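The backward-but-not-forward rule from the comment above can be sketched in Python as follows. This is an illustration only, assuming compatibility is decided by the major and minor version (patch ignored); `parse_version` and `runtime_is_compatible` are hypothetical helpers, not part of JAX, XLA, or cuDNN.

```python
from typing import NamedTuple


class Version(NamedTuple):
    major: int
    minor: int
    patch: int = 0


def parse_version(text: str) -> Version:
    """Parse a dotted version string such as "9.1" or "9.1.0"."""
    parts = [int(p) for p in text.split(".")]
    while len(parts) < 3:  # pad "9.1" to (9, 1, 0)
        parts.append(0)
    return Version(*parts[:3])


def runtime_is_compatible(built_against: str, runtime: str) -> bool:
    """Hypothetical rule: same major version, runtime minor >= build minor."""
    built = parse_version(built_against)
    rt = parse_version(runtime)
    return rt.major == built.major and rt.minor >= built.minor


# A jaxlib built for 9.1 running on a 9.0 runtime: not expected to work.
print(runtime_is_compatible("9.1", "9.0"))
# Libraries tested with 9.0 running on a 9.1 runtime: expected to work.
print(runtime_is_compatible("9.0", "9.1"))
```

Under this rule, upgrading the container to 9.1 keeps the 9.0-tested libraries within their promised range while satisfying a jaxlib built for 9.1; the reverse direction is what fails.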
I am using cuDNN 9.0 with CUDA 12.4, and I tried two things to make it work with JAX:
pip3 install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
which does not work, given that the jax-cuda-releases page only supports cuDNN 8.9 and 9.1. This worked in this environment a month ago. Now I see:
Did something change? Any suggestions on how to fix this problem if I don't want to update my CUDNN runtime?
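Before choosing between upgrading and self-building, it can help to confirm which cuDNN runtime the dynamic loader actually finds. A minimal sketch, assuming a Linux system where the library is named `libcudnn.so.9` or `libcudnn.so.8`, calls `cudnnGetVersion()` from cuDNN's C API through `ctypes`. The helper names are hypothetical, and libraries installed from pip wheels such as `nvidia-cudnn-cu12` may not be on the default search path, so the result can differ from the copy jaxlib ends up loading.

```python
import ctypes
import ctypes.util
from typing import Optional, Sequence, Tuple


def decode_cudnn_version(v: int) -> Tuple[int, int, int]:
    """Split the integer from cudnnGetVersion() into (major, minor, patch).

    cuDNN 9 encodes major*10000 + minor*100 + patch;
    cuDNN 8 and earlier encode major*1000 + minor*100 + patch.
    """
    if v >= 90000:
        return v // 10000, (v % 10000) // 100, v % 100
    return v // 1000, (v % 1000) // 100, v % 100


def loaded_cudnn_version(
    candidates: Sequence[str] = ("libcudnn.so.9", "libcudnn.so.8"),
) -> Optional[Tuple[int, int, int]]:
    """Return the version of the first cuDNN library that loads, else None."""
    names = list(candidates)
    found = ctypes.util.find_library("cudnn")  # may be None
    if found:
        names.append(found)
    for name in names:
        try:
            lib = ctypes.CDLL(name)
        except OSError:
            continue  # not installed or not on the loader path
        lib.cudnnGetVersion.restype = ctypes.c_size_t
        lib.cudnnGetVersion.argtypes = []
        return decode_cudnn_version(lib.cudnnGetVersion())
    return None


if __name__ == "__main__":
    print("cuDNN runtime:", loaded_cudnn_version())
```

A result of `(9, 0, x)` would confirm the 9.0 runtime described here, which a wheel built against 9.1 is not expected to accept.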