Support for CUDNN 9.0? #24180

Open
jbkyang-nvi opened this issue Oct 8, 2024 · 3 comments
Labels
enhancement New feature or request

Comments

@jbkyang-nvi

jbkyang-nvi commented Oct 8, 2024

I am using CUDNN 9.0 with CUDA 12.4 and I tried 2 things to make it work with Jax:

  1. Install with pip: pip3 install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html. This does not work because jax_cuda_releases only supports CUDNN 8.9 and 9.1.
    This worked in this environment a month ago. Now I see:
Loaded runtime CuDNN library: 9.0.0 but source was compiled with: 9.1.1.  CuDNN library needs to have matching major version and equal or higher minor version. If using a binary install, upgrade your CuDNN library.  If building from sources, make sure the library loaded at runtime is compatible with the version specified during compile configuration.
...
XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
  2. Build jaxlib myself following the build instructions, which fails with:
Traceback (most recent call last):
        File "/root/.cache/bazel/_bazel_root/ac0a09c0a1602a90816292620faa8c49/external/tsl/third_party/gpus/cuda/hermetic/cuda_json_init_repository.bzl", line 62, column 13, in _cuda_redist_json_impl
                fail(
Error in fail: The supported CUDNN versions are ["8.6", "8.9.4.25", "8.9.6", "8.9.7.29", "9.1.1", "9.2.0", "9.2.1", "9.3.0", "9.4.0"]. Please provide a supported version in HERMETIC_CUDNN_VERSION environment variable or add JSON URL for CUDNN version=9.0.0.
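
The version rule quoted in the first error message (matching major version, equal or higher minor version) can be sketched as a small check. This is only an illustration of the stated rule, not how XLA implements it; the function name runtime_satisfies_build is hypothetical.

```python
def runtime_satisfies_build(runtime: str, compiled: str) -> bool:
    """Apply the rule from the error message: same major, runtime minor >= compiled minor."""
    # Only the major and minor components matter for this rule; the patch level is ignored.
    run_major, run_minor = (int(part) for part in runtime.split(".")[:2])
    built_major, built_minor = (int(part) for part in compiled.split(".")[:2])
    return run_major == built_major and run_minor >= built_minor


# The situation reported above: runtime 9.0.0, binary compiled against 9.1.1.
print(runtime_satisfies_build("9.0.0", "9.1.1"))  # False
# Upgrading the runtime to 9.1.1 or later satisfies the rule.
print(runtime_satisfies_build("9.1.1", "9.1.1"))  # True
```

Under this rule, a 9.0.0 runtime cannot serve any binary compiled against a later 9.x minor version, which matches the failure shown above.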

Did something change? Any suggestions on how to fix this problem if I don't want to update my CUDNN runtime?

@jbkyang-nvi jbkyang-nvi added the enhancement New feature or request label Oct 8, 2024
@hawkinsp
Collaborator

hawkinsp commented Oct 8, 2024

Yeah jax_cuda_releases is a legacy thing. It will never be updated. The wheels are shipped on pypi these days.

Is there a reason you cannot update to CUDNN 9.1? It should be very easy to do: pip install nvidia-cudnn-cu12 will do it, even if you use a local installation of CUDA for everything else. Note that we recommend installing CUDA and CUDNN using the pip wheels; doing so is considerably easier.
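
To see which cuDNN pip wheel, if any, is present in the active Python environment, a minimal sketch using the standard library follows. It assumes the PyPI distribution name is nvidia-cudnn-cu12, and it only detects pip-installed wheels, not a system-wide cuDNN installation. The helper name installed_wheel_version is hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_wheel_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a pip distribution, or None if it is absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    # Assumption: "nvidia-cudnn-cu12" is the distribution name of the CUDA 12 cuDNN wheel.
    found = installed_wheel_version("nvidia-cudnn-cu12")
    print(found if found else "cuDNN pip wheel not installed in this environment")
```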

It is probably possible to self-build a jaxlib with CUDNN 9.0 support, but you'd have to do that by adding another entry to the BUILD files for a 9.0 version. To do this you'd need to add that version here:
https://github.com/google/tsl/blob/main/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl

and then I think you would build jaxlib with GPU enabled with the following options enabled:
--bazel_options=--override_repository=tsl=/path/to/your/tsl/fork --bazel_options=--repo_env=HERMETIC_CUDNN_VERSION="9.0.0" (where 9.0.0 is the version you added).

No promises, but that would probably work.
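
For scripting the build in a container, the two options above could be assembled as follows. This is a sketch: it only formats the two --bazel_options flags quoted in this comment, the helper name cudnn_override_options is hypothetical, and the rest of the jaxlib build command (entry point, GPU flag) is deliberately left out because it depends on the JAX version being built.

```python
from typing import List


def cudnn_override_options(tsl_fork_path: str, cudnn_version: str) -> List[str]:
    """Format the two build options suggested above for a custom hermetic cuDNN version."""
    return [
        # Point Bazel at a local TSL fork that lists the extra cuDNN version.
        f"--bazel_options=--override_repository=tsl={tsl_fork_path}",
        # Select that version; it must match the entry added to the fork.
        f"--bazel_options=--repo_env=HERMETIC_CUDNN_VERSION={cudnn_version}",
    ]


# Placeholder path, as in the comment above; replace it with the real fork location.
for option in cudnn_override_options("/path/to/your/tsl/fork", "9.0.0"):
    print(option)
```

Passing the options as separate list elements (for example to subprocess.run) avoids the shell quoting around the version string.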

@jbkyang-nvi
Author

> Yeah jax_cuda_releases is a legacy thing. It will never be updated. The wheels are shipped on pypi these days.

🙏 will update the build for future jax

> Is there a reason you cannot update to CUDNN 9.1? It should be very easy to do: pip install nvidia-cudnn-cu12 will do it, even if you use a local installation of CUDA for everything else. Note that we recommend installing CUDA and CUDNN using the pip wheels; doing so is considerably easier.

We're shipping a container with other libraries that are tested with CUDNN 9.0; only jax is the outlier.

> It is probably possible to self-build a jaxlib with CUDNN 9.0 support, but you'd have to do that by adding another entry to the BUILD files for a 9.0 version. To do this you'd need to add that version here: https://github.com/google/tsl/blob/main/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl

> and then I think you would build jaxlib with GPU enabled with the following options enabled: --bazel_options=--override_repository=tsl=/path/to/your/tsl/fork --bazel_options=--repo_env=HERMETIC_CUDNN_VERSION="9.0.0" (where 9.0.0 is the version you added).

> No promises, but that would probably work.

Thanks! Will try that if all else fails.

@hawkinsp
Collaborator

hawkinsp commented Oct 8, 2024

> > Is there a reason you cannot update to CUDNN 9.1? It should be very easy to do: pip install nvidia-cudnn-cu12 will do it, even if you use a local installation of CUDA for everything else. Note that we recommend installing CUDA and CUDNN using the pip wheels; doing so is considerably easier.

> We're shipping a container with other libraries that are tested with CUDNN 9.0; only jax is the outlier.

I will note that cudnn promises backwards but not forwards compatibility:
https://docs.nvidia.com/deeplearning/cudnn/latest/developer/forward-compatibility.html#cudnn-api-compatibility

So in principle, assuming you believe NVIDIA's promises to that effect (I guess you need your own testing to be sure), you can install 9.1 even for users that expect 9.0 and things should work.
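
As a sketch of that reasoning: if every library in the container only needs a runtime with the same major version and an equal or higher minor version than the one it was built against, a single runtime at the highest build version covers all of them. The code below assumes that rule holds (it should still be validated by testing), and the function name lowest_acceptable_runtime is hypothetical.

```python
from typing import Iterable, Tuple


def _major_minor(cudnn_version: str) -> Tuple[int, int]:
    """Extract (major, minor) from a version string such as '9.1.1'."""
    major, minor = (int(part) for part in cudnn_version.split(".")[:2])
    return major, minor


def lowest_acceptable_runtime(build_versions: Iterable[str]) -> str:
    """Pick the build version with the highest minor; a runtime at or above it serves all."""
    versions = list(build_versions)
    if not versions:
        raise ValueError("no build versions given")
    if len({_major_minor(v)[0] for v in versions}) != 1:
        # No single runtime can satisfy builds against different major versions.
        raise ValueError("build versions span more than one major version")
    return max(versions, key=_major_minor)


# Other libraries tested with 9.0.0, jax wheels compiled against 9.1.1.
print(lowest_acceptable_runtime(["9.0.0", "9.1.1"]))  # 9.1.1
```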
