Skip to content

Commit

Permalink
Add a test for 64-bit precision of IFFT on GPU.
Browse files Browse the repository at this point in the history
Fixes #23827. The precision fix was in openxla/xla#17598, which has now been integrated into JAX, but we add a test here based on the repro from #23827.

PiperOrigin-RevId: 680633622
  • Loading branch information
dfm authored and Google-ML-Automation committed Sep 30, 2024
1 parent 504bc43 commit ff1c2ac
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions tests/fft_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from jax._src import config
from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src.lib import version as jaxlib_version
from jax._src.numpy.util import promote_dtypes_complex
from jax._src.numpy.fft import _fft_norm

Expand Down Expand Up @@ -477,5 +478,16 @@ def testFftnormOverflow(self, norm, func_name, dtype):
np_norm = np.reciprocal(np_norm)
self.assertArraysAllClose(jax_norm, np_norm, rtol=3e-8, check_dtypes=False)

def testFftNormalizationPrecision(self):
# reported in https://github.com/jax-ml/jax/issues/23827
if not config.enable_x64.value:
raise self.skipTest("requires jax_enable_x64=true")
if jaxlib_version <= (0, 4, 33):
raise self.skipTest("requires jaxlib version > 0.4.33")
n = 31
a = np.ones((n, 15), dtype="complex128")
self.assertArraysAllClose(
jnp.fft.ifft(a, n=n, axis=1), np.fft.ifft(a, n=n, axis=1), atol=1e-14)

if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit ff1c2ac

Please sign in to comment.