Regression in JAX FP8 matmul fusion #24051
Comments
Some update on latest JAX 0.4.33. Using:

```python
e4m3_max = ml_dtypes.finfo(jnp.float8_e4m3fn).max

# "Dequantization" datatype (note: required to be BF16!)
dqt_dtype = jnp.bfloat16

# XLA requires a "dequantize/quantize" pattern to properly support scaled FP8 inputs/outputs.
def matmul_fn_with_scale(a_fp8, b_fp8, a_scale, b_scale, d_scale):
    # Dequantize x and y
    a_dqt = a_fp8.astype(dqt_dtype) * a_scale.astype(dqt_dtype)
    b_dqt = b_fp8.astype(dqt_dtype) * b_scale.astype(dqt_dtype)
    # Do the matmul (NOTE: adding transpose to simplify HLO).
    d_dqt = jax.lax.dot(a_dqt, b_dqt.T)
    # ReLU non-linearity. Note: applied before scaling.
    d_dqt = jax.nn.relu(d_dqt)
    # Rescale & clamp to -max/+max FP8 E4M3 values.
    d_dqt = d_dqt * d_scale.astype(dqt_dtype)
    # NOTE: clamping is NOT optional for proper pattern matching!
    d_dqt = jax.lax.clamp(dqt_dtype(-e4m3_max), d_dqt, dqt_dtype(e4m3_max))
    # (Re)Quantize the scaled matmul output.
    return d_dqt.astype(jnp.float8_e4m3fn)

# AOT compilation with JAX, inspecting the (final) HLO module generated.
fn_compiled = jax.jit(matmul_fn_with_scale).lower(a, b, scale_aval, scale_aval, scale_aval).compile()
```

generating:
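For completeness, a minimal sketch of how the inputs `a`, `b` and `scale_aval` used above could be constructed; the shapes, the scalar FP32 scales and the use of `jax.ShapeDtypeStruct` are assumptions, not taken from the original report:

```python
import jax
import jax.numpy as jnp
import ml_dtypes

# Hypothetical shapes (the issue does not state them); note b gets transposed in the matmul.
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(key_a, (32, 64), dtype=jnp.bfloat16).astype(jnp.float8_e4m3fn)
b = jax.random.normal(key_b, (16, 64), dtype=jnp.bfloat16).astype(jnp.float8_e4m3fn)
# Abstract value for the scalar scales, sufficient for AOT lowering with `.lower(...)`.
scale_aval = jax.ShapeDtypeStruct((), jnp.float32)
```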
But as soon as I remove the `relu`:

```python
e4m3_max = ml_dtypes.finfo(jnp.float8_e4m3fn).max

# "Dequantization" datatype (note: required to be BF16!)
dqt_dtype = jnp.bfloat16

# XLA requires a "dequantize/quantize" pattern to properly support scaled FP8 inputs/outputs.
def matmul_fn_with_scale(a_fp8, b_fp8, a_scale, b_scale, d_scale):
    # Dequantize x and y
    a_dqt = a_fp8.astype(dqt_dtype) * a_scale.astype(dqt_dtype)
    b_dqt = b_fp8.astype(dqt_dtype) * b_scale.astype(dqt_dtype)
    # Do the matmul (NOTE: adding transpose to simplify HLO).
    d_dqt = jax.lax.dot(a_dqt, b_dqt.T)
    # Rescale & clamp to -max/+max FP8 E4M3 values.
    d_dqt = d_dqt * d_scale.astype(dqt_dtype)
    # NOTE: clamping is NOT optional for proper pattern matching!
    d_dqt = jax.lax.clamp(dqt_dtype(-e4m3_max), d_dqt, dqt_dtype(e4m3_max))
    # (Re)Quantize the scaled matmul output.
    return d_dqt.astype(jnp.float8_e4m3fn)

# AOT compilation with JAX, inspecting the (final) HLO module generated.
fn_compiled = jax.jit(matmul_fn_with_scale).lower(a, b, scale_aval, scale_aval, scale_aval).compile()
```
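As a quick way to check whether the fusion kicked in, one can search the optimized HLO for the fused custom-call target mentioned later in this issue (a sketch, reusing `fn_compiled` from above):

```python
# The fused FP8 kernel shows up as a custom call with this target when the
# pattern is matched; if the fusion regressed, the substring is absent.
optimized_hlo = fn_compiled.as_text()
print("__cublas$lt$matmul$f8" in optimized_hlo)
```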
The main trick is to use `dqt_dtype = jnp.bfloat16` instead of `jnp.float32` for "dequantization". As presented in bug ticket jax-ml/jax#24051, there are still some issues with the latest JAX version when no `relu` or `abs-max` is used.
Thanks for reporting this issue. I expect that the core change in behavior is coming from the XLA side, but it would be useful to first check the input HLO (i.e. before calling `compile()`).
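For reference, a sketch of one way to dump both modules, assuming the same function and inputs as above:

```python
lowered = jax.jit(matmul_fn_with_scale).lower(a, b, scale_aval, scale_aval, scale_aval)
# Input module as emitted by JAX, before XLA optimization passes.
print(lowered.as_text())
# Final module after XLA compilation.
print(lowered.compile().as_text())
```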
Good point, here is the HLO before compilation corresponding to:

```python
e4m3_max = ml_dtypes.finfo(jnp.float8_e4m3fn).max

# "Dequantization" datatype (note: required to be BF16!)
dqt_dtype = jnp.bfloat16

# XLA requires a "dequantize/quantize" pattern to properly support scaled FP8 inputs/outputs.
def matmul_fn_with_scale(a_fp8, b_fp8, a_scale, b_scale, d_scale):
    # Dequantize x and y
    a_dqt = a_fp8.astype(dqt_dtype) * a_scale.astype(dqt_dtype)
    b_dqt = b_fp8.astype(dqt_dtype) * b_scale.astype(dqt_dtype)
    # Do the matmul (NOTE: adding transpose to simplify HLO).
    d_dqt = jax.lax.dot(a_dqt, b_dqt.T)
    # Rescale & clamp to -max/+max FP8 E4M3 values.
    d_dqt = d_dqt * d_scale.astype(dqt_dtype)
    # NOTE: clamping is NOT optional for proper pattern matching!
    d_dqt = jax.lax.clamp(dqt_dtype(-e4m3_max), d_dqt, dqt_dtype(e4m3_max))
    # (Re)Quantize the scaled matmul output.
    return d_dqt.astype(jnp.float8_e4m3fn)
```

JAX v0.4.31 (working)
JAX v0.4.33/v0.4.32 (regression)
Unless I am mistaken, they seem to be the same :)
I opened an XLA bug ticket: openxla/xla#17887
Description
For the Graphcore JAX Scalify FP8 research project, we have documented how the XLA compiler fuses FP8 matmul with inputs/outputs scaling as well as `abs-max` capture (Scalify aims at doing end-to-end automatic scale propagation in JAX, similar to autograd). See https://github.com/graphcore-research/jax-scalify/blob/main/docs/JAX%20FP8%20matmul%20tutorial.ipynb for the full notebook. The following code used to work with JAX 0.4.30:
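A rough sketch of that pattern, reconstructed from the snippets in the comments above with the `abs-max` capture added, so details may differ from the original code:

```python
import jax
import jax.numpy as jnp
import ml_dtypes

e4m3_max = ml_dtypes.finfo(jnp.float8_e4m3fn).max
dqt_dtype = jnp.bfloat16

def matmul_fn_with_scale_and_amax(a_fp8, b_fp8, a_scale, b_scale, d_scale):
    # Dequantize the FP8 inputs.
    a_dqt = a_fp8.astype(dqt_dtype) * a_scale.astype(dqt_dtype)
    b_dqt = b_fp8.astype(dqt_dtype) * b_scale.astype(dqt_dtype)
    # Matmul + ReLU in the "dequantization" dtype.
    d_dqt = jax.nn.relu(jax.lax.dot(a_dqt, b_dqt.T))
    # Abs-max capture on the unscaled output (used for updating scales).
    d_abs_max = jnp.max(jnp.abs(d_dqt))
    # Rescale, clamp to the FP8 E4M3 range and requantize.
    d_dqt = d_dqt * d_scale.astype(dqt_dtype)
    d_dqt = jax.lax.clamp(dqt_dtype(-e4m3_max), d_dqt, dqt_dtype(e4m3_max))
    return d_dqt.astype(jnp.float8_e4m3fn), d_abs_max
```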
leading to a single HLO fused custom call (with `__cublas$lt$matmul$f8` target):
From 0.4.31 and later versions, it generates the following sub-optimal HLO (note it is calling `dot` in FP32!):

Fusing FP8 scaling + `abs-max` capture is really key to obtaining good FP8 MFU numbers for training. I should probably open a similar bug ticket in the OpenXLA project, but it also raises questions on the JAX side:
System info (python version, jaxlib version, accelerator, etc.)
Regression from `jax==0.4.31` and later versions. Using H100 GPU hardware.
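For reference, recent JAX versions can print this environment summary directly:

```python
import jax

# Prints Python, jax/jaxlib versions and device information.
jax.print_environment_info()
```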