
Regression in JAX FP8 matmul fusion #24051

Open
balancap opened this issue Oct 1, 2024 · 4 comments
Labels: bug (Something isn't working)


balancap commented Oct 1, 2024

Description

For the Graphcore JAX Scalify FP8 research project, we have documented how the XLA compiler fuses FP8 matmuls with input/output scaling as well as abs-max capture (Scalify aims at end-to-end automatic scale propagation in JAX, similar to autograd). See https://github.com/graphcore-research/jax-scalify/blob/main/docs/JAX%20FP8%20matmul%20tutorial.ipynb for the full notebook.

The following code used to work with JAX 0.4.30:

import jax
import jax.numpy as jnp
import ml_dtypes

# Max representable FP8 E4M3 value (448).
e4m3_max = ml_dtypes.finfo(jnp.float8_e4m3fn).max

def matmul_fn_with_scale(a_fp8, b_fp8, a_scale, b_scale, d_scale):
    # Dequantize a and b
    a_fp32 = a_fp8.astype(jnp.float32) * a_scale
    b_fp32 = b_fp8.astype(jnp.float32) * b_scale

    # Do the matmul (NOTE: adding transpose to reduce on last axis).
    d_fp32 = jax.lax.dot(a_fp32, b_fp32.T)

    # Rescale & clamp to -max/+max FP8 E4M3 values.
    d_fp32 = d_fp32 * d_scale
    # NOTE: clamping is NOT optional for proper pattern matching!
    d_fp32 = jax.lax.clamp(jnp.float32(-e4m3_max), d_fp32, jnp.float32(e4m3_max))
    # (Re)Quantize the scaled matmul output.
    return d_fp32.astype(jnp.float8_e4m3fn)

leading to a single fused HLO custom call (with the __cublas$lt$matmul$f8 target):

ENTRY %main.22 (Arg_0.1.0: f8e4m3fn[32,64], Arg_1.2.0: f8e4m3fn[128,64], Arg_2.3.0: f32[], Arg_3.4.0: f32[], Arg_4.5.0: f32[]) -> f8e4m3fn[32,128] {
  %constant_1 = f32[] constant(1)
  %Arg_4.5.0 = f32[] parameter(4)
  %Arg_3.4.0 = f32[] parameter(3)
  %Arg_2.3.0 = f32[] parameter(2)
  %Arg_1.2.0 = f8e4m3fn[128,64]{1,0} parameter(1)
  %Arg_0.1.0 = f8e4m3fn[32,64]{1,0} parameter(0)
  %cublas-gemm.clone.1.0 = (f8e4m3fn[32,128]{1,0}, s8[33554432]{0}) custom-call(f8e4m3fn[32,64]{1,0} %Arg_0.1.0, f8e4m3fn[128,64]{1,0} %Arg_1.2.0, f32[] %Arg_2.3.0, f32[] %Arg_3.4.0, f32[] %constant_1, /*index=5*/f32[] %Arg_4.5.0), custom_call_target="__cublas$lt$matmul$f8"
  ROOT %get-tuple-element.1 = f8e4m3fn[32,128]{1,0} get-tuple-element((f8e4m3fn[32,128]{1,0}, s8[33554432]{0}) %cublas-gemm.clone.1.0), index=0
}

From 0.4.31 onwards, it generates the following sub-optimal HLO (note that it calls dot in FP32!):

%gemm_fusion_dot.17_computation (parameter_0: f8e4m3fn[32,64], parameter_1: f32[], parameter_2: f8e4m3fn[128,64], parameter_3: f32[]) -> f32[32,128] {
  %parameter_0 = f8e4m3fn[32,64]{1,0} parameter(0)
  %convert.3 = f32[32,64]{1,0} convert(f8e4m3fn[32,64]{1,0} %parameter_0)
  %parameter_1 = f32[] parameter(1)
  %broadcast.8 = f32[32,64]{1,0} broadcast(f32[] %parameter_1), dimensions={}
  %multiply.3 = f32[32,64]{1,0} multiply(f32[32,64]{1,0} %convert.3, f32[32,64]{1,0} %broadcast.8)
  %parameter_2 = f8e4m3fn[128,64]{1,0} parameter(2)
  %convert.4 = f32[128,64]{1,0} convert(f8e4m3fn[128,64]{1,0} %parameter_2)
  %parameter_3 = f32[] parameter(3)
  %broadcast.10 = f32[128,64]{1,0} broadcast(f32[] %parameter_3), dimensions={}
  %multiply.4 = f32[128,64]{1,0} multiply(f32[128,64]{1,0} %convert.4, f32[128,64]{1,0} %broadcast.10)
  ROOT %dot.1 = f32[32,128]{1,0} dot(f32[32,64]{1,0} %multiply.3, f32[128,64]{1,0} %multiply.4), lhs_contracting_dims={1}, rhs_contracting_dims={1}
}

%fused_convert (param_0.3: f32[], param_1.3: f32[32,128]) -> f8e4m3fn[32,128] {
  %constant_8_1 = f32[] constant(-448)
  %broadcast.12.1 = f32[32,128]{1,0} broadcast(f32[] %constant_8_1), dimensions={}
  %param_1.3 = f32[32,128]{1,0} parameter(1)
  %constant_0_1 = f32[] constant(0)
  %broadcast.13.1 = f32[32,128]{1,0} broadcast(f32[] %constant_0_1), dimensions={}
  %maximum.2.1 = f32[32,128]{1,0} maximum(f32[32,128]{1,0} %param_1.3, f32[32,128]{1,0} %broadcast.13.1)
  %param_0.3 = f32[] parameter(0)
  %broadcast.15.1 = f32[32,128]{1,0} broadcast(f32[] %param_0.3), dimensions={}
  %multiply.5.1 = f32[32,128]{1,0} multiply(f32[32,128]{1,0} %maximum.2.1, f32[32,128]{1,0} %broadcast.15.1)
  %constant_6_1 = f32[] constant(448)
  %broadcast.16.1 = f32[32,128]{1,0} broadcast(f32[] %constant_6_1), dimensions={}
  %clamp.1.1 = f32[32,128]{1,0} clamp(f32[32,128]{1,0} %broadcast.12.1, f32[32,128]{1,0} %multiply.5.1, f32[32,128]{1,0} %broadcast.16.1)
  ROOT %convert.5.1 = f8e4m3fn[32,128]{1,0} convert(f32[32,128]{1,0} %clamp.1.1)
}

ENTRY %main.28 (Arg_0.1.0: f8e4m3fn[32,64], Arg_1.2.0: f8e4m3fn[128,64], Arg_2.3.0: f32[], Arg_3.4.0: f32[], Arg_4.5.0: f32[]) -> f8e4m3fn[32,128] {
  %Arg_4.5.0 = f32[] parameter(4)
  %Arg_3.4.0 = f32[] parameter(3)
  %Arg_2.3.0 = f32[] parameter(2)
  %Arg_1.2.0 = f8e4m3fn[128,64]{1,0} parameter(1)
  %Arg_0.1.0 = f8e4m3fn[32,64]{1,0} parameter(0)
  %gemm_fusion_dot.17.0 = f32[32,128]{1,0} fusion(f8e4m3fn[32,64]{1,0} %Arg_0.1.0, f32[] %Arg_2.3.0, f8e4m3fn[128,64]{1,0} %Arg_1.2.0, f32[] %Arg_3.4.0), kind=kCustom, calls=%gemm_fusion_dot.17_computation
  ROOT %loop_convert_fusion = f8e4m3fn[32,128]{1,0} fusion(f32[] %Arg_4.5.0, f32[32,128]{1,0} %gemm_fusion_dot.17.0), kind=kLoop, calls=%fused_convert
}

Fusing FP8 scaling + abs-max capture is really key to obtaining good FP8 MFU numbers for training.
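
For context, here is a minimal sketch of the abs-max capture pattern mentioned above (an illustrative variant of the function above, not the exact notebook code): the abs-max of the unscaled matmul output is returned alongside the quantized result, so that XLA can fuse its computation into the same cuBLASLt FP8 call.

# Illustrative sketch only: same matmul as above, additionally capturing the
# abs-max of the unscaled output (typically used to update d_scale in delayed scaling).
def matmul_fn_with_scale_and_amax(a_fp8, b_fp8, a_scale, b_scale, d_scale):
    a_fp32 = a_fp8.astype(jnp.float32) * a_scale
    b_fp32 = b_fp8.astype(jnp.float32) * b_scale
    d_fp32 = jax.lax.dot(a_fp32, b_fp32.T)
    # Abs-max is captured BEFORE output rescaling, matching the expected FP8 pattern.
    d_amax = jnp.max(jnp.abs(d_fp32))
    d_fp32 = d_fp32 * d_scale
    d_fp32 = jax.lax.clamp(jnp.float32(-e4m3_max), d_fp32, jnp.float32(e4m3_max))
    return d_fp32.astype(jnp.float8_e4m3fn), d_amax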

I should probably open a similar bug ticket in the OpenXLA project, but it also raises questions on the JAX side:

  • It's non-trivial to write FP8 matmul code that properly fuses scaling + abs-max; I believe documenting it would be very beneficial (happy to help on that side, based on the work in this notebook);
  • As proper fusion is critical, should JAX have test coverage to make sure there is no regression on this side? (A sketch of such a check is shown below.)
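
A minimal sketch of what such a regression check could look like (a hypothetical test, not an existing JAX test; input shapes are taken from the HLO dumps above):

# Hypothetical regression test: compile the FP8 matmul and check that it lowers
# to the fused cuBLASLt FP8 custom call.
def test_fp8_matmul_lowers_to_cublaslt_fp8():
    a = jax.ShapeDtypeStruct((32, 64), jnp.float8_e4m3fn)
    b = jax.ShapeDtypeStruct((128, 64), jnp.float8_e4m3fn)
    scale = jax.ShapeDtypeStruct((), jnp.float32)
    compiled = jax.jit(matmul_fn_with_scale).lower(a, b, scale, scale, scale).compile()
    assert "__cublas$lt$matmul$f8" in compiled.as_text()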

System info (python version, jaxlib version, accelerator, etc.)

The regression appears in jax==0.4.31 and later versions.
Using H100 GPU hardware.

balancap added the bug label on Oct 1, 2024

balancap commented Oct 2, 2024

An update on the latest JAX 0.4.33: using bfloat16 as the "dequantization" dtype, this piece of code works:

e4m3_max = ml_dtypes.finfo(jnp.float8_e4m3fn).max
# "Dequantization" datatype (note: required to be BF16!)
dqt_dtype = jnp.bfloat16

# XLA requires a "dequantize/quantize" pattern to properly support scaled FP8 inputs/outputs.
def matmul_fn_with_scale(a_fp8, b_fp8, a_scale, b_scale, d_scale):
    # Dequantize a and b
    a_dqt = a_fp8.astype(dqt_dtype) * a_scale.astype(dqt_dtype)
    b_dqt = b_fp8.astype(dqt_dtype) * b_scale.astype(dqt_dtype)

    # Do the matmul (NOTE: adding transpose to simplify HLO).
    d_dqt = jax.lax.dot(a_dqt, b_dqt.T)
    # ReLU non-linearity. Note: applied before scaling.
    d_dqt = jax.nn.relu(d_dqt)

    # Rescale & clamp to -max/+max FP8 E4M3 values.
    d_dqt = d_dqt * d_scale.astype(dqt_dtype)
    # NOTE: clamping is NOT optional for proper pattern matching!
    d_dqt = jax.lax.clamp(dqt_dtype(-e4m3_max), d_dqt, dqt_dtype(e4m3_max))
    # (Re)Quantize the scaled matmul output.
    return d_dqt.astype(jnp.float8_e4m3fn)

# Example input avals (assumed; shapes taken from the HLO dumps below).
a = jax.ShapeDtypeStruct((32, 64), jnp.float8_e4m3fn)
b = jax.ShapeDtypeStruct((128, 64), jnp.float8_e4m3fn)
scale_aval = jax.ShapeDtypeStruct((), jnp.float32)

# AOT compilation with JAX, inspecting the (final) HLO module generated.
fn_compiled = jax.jit(matmul_fn_with_scale).lower(a, b, scale_aval, scale_aval, scale_aval).compile()

generating:

HloModule jit_matmul_fn_with_scale, is_scheduled=true, entry_computation_layout={(f8e4m3fn[32,64]{1,0}, f8e4m3fn[128,64]{1,0}, f32[], f32[], f32[])->f8e4m3fn[32,128]{1,0}}, allow_spmd_sharding_propagation_to_parameters={true,true,true,true,true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs="cb91ee8d36851f9c763ab7e356c24970"}

ENTRY %main.31 (Arg_0.1.0: f8e4m3fn[32,64], Arg_1.2.0: f8e4m3fn[128,64], Arg_2.3.0: f32[], Arg_3.4.0: f32[], Arg_4.5.0: f32[]) -> f8e4m3fn[32,128] {
  %constant_1_0 = f32[] constant(1)
  %Arg_4.5.0 = f32[] parameter(4)
  %Arg_3.4.0 = f32[] parameter(3)
  %Arg_2.3.0 = f32[] parameter(2)
  %Arg_1.2.0 = f8e4m3fn[128,64]{1,0} parameter(1)
  %Arg_0.1.0 = f8e4m3fn[32,64]{1,0} parameter(0)
  %cublas-gemm.2.clone.1.0 = (f8e4m3fn[32,128]{1,0}, s8[33554432]{0}) custom-call(f8e4m3fn[32,64]{1,0} %Arg_0.1.0, f8e4m3fn[128,64]{1,0} %Arg_1.2.0, f32[] %Arg_2.3.0, f32[] %Arg_3.4.0, f32[] %constant_1_0, /*index=5*/f32[] %Arg_4.5.0), custom_call_target="__cublas$lt$matmul$f8"
  ROOT %get-tuple-element.1 = f8e4m3fn[32,128]{1,0} get-tuple-element((f8e4m3fn[32,128]{1,0}, s8[33554432]{0}) %cublas-gemm.2.clone.1.0), index=0
}

But as soon as I remove jax.nn.relu, it goes back to generating long, sub-optimal HLO (as in the original report):

e4m3_max = ml_dtypes.finfo(jnp.float8_e4m3fn).max
# "Dequantization" datatype (note: required to be BF16!)
dqt_dtype = jnp.bfloat16

# XLA requires a "dequantize/quantize" pattern to properly support scaled FP8 inputs/outputs. 
def matmul_fn_with_scale(a_fp8, b_fp8, a_scale, b_scale, d_scale):
    # Dequantize a and b
    a_dqt = a_fp8.astype(dqt_dtype) * a_scale.astype(dqt_dtype)
    b_dqt = b_fp8.astype(dqt_dtype) * b_scale.astype(dqt_dtype)
    
    # Do the matmul (NOTE: adding transpose to simplify HLO).
    d_dqt = jax.lax.dot(a_dqt, b_dqt.T)
    
    # Rescale & clamp to -max/+max FP8 E4M3 values.
    d_dqt = d_dqt * d_scale.astype(dqt_dtype)
    # NOTE: clamping is NOT optional for proper pattern matching!
    d_dqt = jax.lax.clamp(dqt_dtype(-e4m3_max), d_dqt, dqt_dtype(e4m3_max))
    # (Re)Quantize the scaled matmul output.
    return d_dqt.astype(jnp.float8_e4m3fn)

# AOT compilation with JAX, inspecting the (final) HLO module generated.
fn_compiled = jax.jit(matmul_fn_with_scale).lower(a, b, scale_aval, scale_aval, scale_aval).compile()
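
For reference, a quick way to check whether the FP8 gemm actually fused, using the fn_compiled object above:

# Prints True when the fused cuBLASLt FP8 custom call is present in the compiled module.
print("__cublas$lt$matmul$f8" in fn_compiled.as_text())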

lyprince pushed a commit to graphcore-research/jax-scalify that referenced this issue Oct 2, 2024
The main trick is to use `dqt_dtype = jnp.bfloat16` instead of `jnp.float32` for "dequantization".
As presented in bug ticket jax-ml/jax#24051, there are still some issues with the latest JAX versions
when no `relu` or `abs-max` is used.
balancap added a commit to graphcore-research/jax-scalify that referenced this issue Oct 2, 2024
The main trick is to use `dqt_dtype = jnp.bfloat16` instead of `jnp.float32` for "dequantization".
As presented in bug ticket jax-ml/jax#24051, there are still some issues with the latest JAX versions
when no `relu` or `abs-max` is used.

dfm commented Oct 2, 2024

Thanks for reporting this issue. I expect that the core change in behavior is coming from the XLA side, but it would be useful to first check the input HLO (i.e. before calling .compile()) to see if there are any obvious changes between JAX versions. Can you take a look at that, then we can see if any XLA experts have good suggestions for why you're seeing this change in behavior?


balancap commented Oct 2, 2024

Good point, here is the HLO before compilation corresponding to:

e4m3_max = ml_dtypes.finfo(jnp.float8_e4m3fn).max
# "Dequantization" datatype (note: required to be BF16!)
dqt_dtype = jnp.bfloat16

# XLA requires a "dequantize/quantize" pattern to properly support scaled FP8 inputs/outputs. 
def matmul_fn_with_scale(a_fp8, b_fp8, a_scale, b_scale, d_scale):
    # Dequantize a and b
    a_dqt = a_fp8.astype(dqt_dtype) * a_scale.astype(dqt_dtype)
    b_dqt = b_fp8.astype(dqt_dtype) * b_scale.astype(dqt_dtype)
    
    # Do the matmul (NOTE: adding transpose to simplify HLO).
    d_dqt = jax.lax.dot(a_dqt, b_dqt.T)
    
    # Rescale & clamp to -max/+max FP8 E4M3 values.
    d_dqt = d_dqt * d_scale.astype(dqt_dtype)
    # NOTE: clamping is NOT optional for proper pattern matching!
    d_dqt = jax.lax.clamp(dqt_dtype(-e4m3_max), d_dqt, dqt_dtype(e4m3_max))
    # (Re)Quantize the scaled matmul output.
    return d_dqt.astype(jnp.float8_e4m3fn)
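
(For reference, both dumps below come from the lowered-but-not-compiled module, obtained roughly as follows, reusing the same a, b and scale_aval avals as before:)

# Pre-compilation StableHLO, i.e. before XLA's optimization passes run.
lowered = jax.jit(matmul_fn_with_scale).lower(a, b, scale_aval, scale_aval, scale_aval)
print(lowered.as_text())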

JAX v0.4.31 (working)

module @jit_matmul_fn_with_scale attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<32x64xf8E4M3FN> {mhlo.layout_mode = "default"}, %arg1: tensor<128x64xf8E4M3FN> {mhlo.layout_mode = "default"}, %arg2: tensor<f32> {mhlo.layout_mode = "default"}, %arg3: tensor<f32> {mhlo.layout_mode = "default"}, %arg4: tensor<f32> {mhlo.layout_mode = "default"}) -> (tensor<32x128xf8E4M3FN> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.convert %arg0 : (tensor<32x64xf8E4M3FN>) -> tensor<32x64xbf16>
    %1 = stablehlo.convert %arg2 : (tensor<f32>) -> tensor<bf16>
    %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor<bf16>) -> tensor<32x64xbf16>
    %3 = stablehlo.multiply %0, %2 : tensor<32x64xbf16>
    %4 = stablehlo.convert %arg1 : (tensor<128x64xf8E4M3FN>) -> tensor<128x64xbf16>
    %5 = stablehlo.convert %arg3 : (tensor<f32>) -> tensor<bf16>
    %6 = stablehlo.broadcast_in_dim %5, dims = [] : (tensor<bf16>) -> tensor<128x64xbf16>
    %7 = stablehlo.multiply %4, %6 : tensor<128x64xbf16>
    %8 = stablehlo.transpose %7, dims = [1, 0] : (tensor<128x64xbf16>) -> tensor<64x128xbf16>
    %9 = stablehlo.dot_general %3, %8, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<32x64xbf16>, tensor<64x128xbf16>) -> tensor<32x128xbf16>
    %10 = stablehlo.convert %arg4 : (tensor<f32>) -> tensor<bf16>
    %11 = stablehlo.broadcast_in_dim %10, dims = [] : (tensor<bf16>) -> tensor<32x128xbf16>
    %12 = stablehlo.multiply %9, %11 : tensor<32x128xbf16>
    %cst = stablehlo.constant dense<-4.480000e+02> : tensor<bf16>
    %cst_0 = stablehlo.constant dense<4.480000e+02> : tensor<bf16>
    %13 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<bf16>) -> tensor<32x128xbf16>
    %14 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor<bf16>) -> tensor<32x128xbf16>
    %15 = stablehlo.clamp %13, %12, %14 : tensor<32x128xbf16>
    %16 = stablehlo.convert %15 : (tensor<32x128xbf16>) -> tensor<32x128xf8E4M3FN>
    return %16 : tensor<32x128xf8E4M3FN>
  }
}

JAX v0.4.33/v0.4.32 (regression)

module @jit_matmul_fn_with_scale attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<32x64xf8E4M3FN> {mhlo.layout_mode = "default"}, %arg1: tensor<128x64xf8E4M3FN> {mhlo.layout_mode = "default"}, %arg2: tensor<f32> {mhlo.layout_mode = "default"}, %arg3: tensor<f32> {mhlo.layout_mode = "default"}, %arg4: tensor<f32> {mhlo.layout_mode = "default"}) -> (tensor<32x128xf8E4M3FN> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.convert %arg0 : (tensor<32x64xf8E4M3FN>) -> tensor<32x64xbf16>
    %1 = stablehlo.convert %arg2 : (tensor<f32>) -> tensor<bf16>
    %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor<bf16>) -> tensor<32x64xbf16>
    %3 = stablehlo.multiply %0, %2 : tensor<32x64xbf16>
    %4 = stablehlo.convert %arg1 : (tensor<128x64xf8E4M3FN>) -> tensor<128x64xbf16>
    %5 = stablehlo.convert %arg3 : (tensor<f32>) -> tensor<bf16>
    %6 = stablehlo.broadcast_in_dim %5, dims = [] : (tensor<bf16>) -> tensor<128x64xbf16>
    %7 = stablehlo.multiply %4, %6 : tensor<128x64xbf16>
    %8 = stablehlo.transpose %7, dims = [1, 0] : (tensor<128x64xbf16>) -> tensor<64x128xbf16>
    %9 = stablehlo.dot_general %3, %8, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<32x64xbf16>, tensor<64x128xbf16>) -> tensor<32x128xbf16>
    %10 = stablehlo.convert %arg4 : (tensor<f32>) -> tensor<bf16>
    %11 = stablehlo.broadcast_in_dim %10, dims = [] : (tensor<bf16>) -> tensor<32x128xbf16>
    %12 = stablehlo.multiply %9, %11 : tensor<32x128xbf16>
    %cst = stablehlo.constant dense<-4.480000e+02> : tensor<bf16>
    %cst_0 = stablehlo.constant dense<4.480000e+02> : tensor<bf16>
    %13 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<bf16>) -> tensor<32x128xbf16>
    %14 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor<bf16>) -> tensor<32x128xbf16>
    %15 = stablehlo.clamp %13, %12, %14 : tensor<32x128xbf16>
    %16 = stablehlo.convert %15 : (tensor<32x128xbf16>) -> tensor<32x128xf8E4M3FN>
    return %16 : tensor<32x128xf8E4M3FN>
  }
}

Unless I am mistaken, they seem to be the same :)
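
(A quick programmatic check confirming this, assuming the two dumps above are saved to the hypothetical files named below:)

import difflib

# Diff the two saved StableHLO dumps; an empty diff means they are identical.
with open("lowered_jax_0_4_31.mlir") as f1, open("lowered_jax_0_4_33.mlir") as f2:
    diff = list(difflib.unified_diff(f1.readlines(), f2.readlines()))
print("identical" if not diff else "".join(diff))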


balancap commented Oct 3, 2024

I opened an XLA bug ticket: openxla/xla#17887
