From ff1c2ac152b6fa5e07724417b83de6b711ab5104 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 30 Sep 2024 10:37:23 -0700 Subject: [PATCH] Add a test for 64-bit precision of IFFT on GPU. Fixes https://github.com/jax-ml/jax/issues/23827. The precision fix was in https://github.com/openxla/xla/pull/17598, which has now been integrated into JAX, but we add a test here based on the repro from https://github.com/jax-ml/jax/issues/23827. PiperOrigin-RevId: 680633622 --- tests/fft_test.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/fft_test.py b/tests/fft_test.py index 30e82c54336a..e64fa4db1277 100644 --- a/tests/fft_test.py +++ b/tests/fft_test.py @@ -26,6 +26,7 @@ from jax._src import config from jax._src import dtypes from jax._src import test_util as jtu +from jax._src.lib import version as jaxlib_version from jax._src.numpy.util import promote_dtypes_complex from jax._src.numpy.fft import _fft_norm @@ -477,5 +478,16 @@ def testFftnormOverflow(self, norm, func_name, dtype): np_norm = np.reciprocal(np_norm) self.assertArraysAllClose(jax_norm, np_norm, rtol=3e-8, check_dtypes=False) + def testFftNormalizationPrecision(self): + # reported in https://github.com/jax-ml/jax/issues/23827 + if not config.enable_x64.value: + raise self.skipTest("requires jax_enable_x64=true") + if jaxlib_version <= (0, 4, 33): + raise self.skipTest("requires jaxlib version > 0.4.33") + n = 31 + a = np.ones((n, 15), dtype="complex128") + self.assertArraysAllClose( + jnp.fft.ifft(a, n=n, axis=1), np.fft.ifft(a, n=n, axis=1), atol=1e-14) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())