diff --git a/tests/test_sphericaldf.py b/tests/test_sphericaldf.py index de50d1067..a800ba6f9 100644 --- a/tests/test_sphericaldf.py +++ b/tests/test_sphericaldf.py @@ -3,7 +3,7 @@ WIN32 = platform.system() == "Windows" if not WIN32: # Enable 64bit for JAX - from jax.config import config + from jax import config config.update("jax_enable_x64", True) import numpy