[Bug]: Minimum JAX version is 0.4.16 #4127
Comments
Thanks @DavidMStraub for opening this; I will look into it.
We added a hard pin to the JAX version (0.4.27) in #4129. We usually don't unpin it or replace the pin with version bounds, because the JAX API is unstable and we find breaking changes in every patch release. Please let us know if that helps, @DavidMStraub.
Thanks for the quick fix! Yes, I think this will work. However, I am wondering whether you should also have …
Ah, thanks for pointing this out! This function was used for the …
PyBaMM Version
24.1
Python Version
3.10.12
Describe the bug
When a JAX version between 0.4.0 and 0.4.15 is installed and PyBaMM is installed without the [jax] optional extra (which can easily happen, e.g., in an environment where pybamm[all]==24.1 is installed after pybamm[jax]==23.9 was installed earlier), PyBaMM cannot be imported. This is because pybamm.have_jax() only checks for a version greater than JAX_VERSION = "0.4" (and therefore returns True), but the import from jax.extend import linear_util as lu only works with JAX >= 0.4.16. (See #3671, where this import was changed to silence a deprecation warning.)
I suggest adding a minimum JAX version (currently 0.4.16) and comparing against it in have_jax.
Steps to Reproduce
Relevant log output