Fix FP8 matmul fusion with recent JAX versions (0.4.31+). (#136)
The main trick is to use `dqt_dtype = jnp.bfloat16` instead of `jnp.float32` for "dequantization".
As reported in bug ticket jax-ml/jax#24051, there are still some issues with the latest JAX versions
when no `relu` or `abs-max` is used.
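
For reference, the updated notebook cells in this commit boil down to the pattern below. This is a minimal sketch of the BF16 "dequantize → matmul → requantize" recipe shown in the diff, assuming `a_fp8`/`b_fp8` are pre-quantized FP8 arrays and the scales are FP32 scalars, as in the tutorial:

```python
import jax
import jax.numpy as jnp
import ml_dtypes

e4m3_max = ml_dtypes.finfo(jnp.float8_e4m3fn).max
# "Dequantization" datatype: BF16 instead of FP32, so that XLA still matches the FP8 fusion.
dqt_dtype = jnp.bfloat16

def matmul_fn_with_scale(a_fp8, b_fp8, a_scale, b_scale, d_scale):
    # Dequantize the FP8 inputs in BF16.
    a_dqt = a_fp8.astype(dqt_dtype) * a_scale.astype(dqt_dtype)
    b_dqt = b_fp8.astype(dqt_dtype) * b_scale.astype(dqt_dtype)
    # Matmul on the dequantized operands (transpose to reduce on the last axis).
    d_dqt = jax.lax.dot(a_dqt, b_dqt.T)
    # Rescale and clamp to the FP8 E4M3 range (clamping is required for pattern matching).
    d_dqt = d_dqt * d_scale.astype(dqt_dtype)
    d_dqt = jax.lax.clamp(dqt_dtype(-e4m3_max), d_dqt, dqt_dtype(e4m3_max))
    # Requantize the scaled output back to FP8.
    return d_dqt.astype(jnp.float8_e4m3fn)
```

With this pattern, XLA rewrites the whole cell into a single `__cublas$lt$matmul$f8` custom call, as visible in the HLO dumps further down the diff.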
balancap authored Oct 2, 2024
1 parent e139789 commit 365e5d9
Showing 1 changed file with 62 additions and 42 deletions.
104 changes: 62 additions & 42 deletions docs/JAX FP8 matmul tutorial.ipynb
@@ -34,11 +34,25 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"id": "51775bad-18ad-49b7-9371-930b3704a294",
"metadata": {},
"outputs": [],
"source": []
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Notebook JAX version: 0.4.31\n",
"Notebook JAX device: cuda:0\n"
]
}
],
"source": [
"import jax\n",
"\n",
"print(f\"Notebook JAX version: {jax.__version__}\")\n",
"print(f\"Notebook JAX device: {jax.devices()[0]}\")"
]
},
{
"cell_type": "markdown",
@@ -52,7 +66,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"id": "fb62c752-f7ba-4714-8605-88e2afcff88f",
"metadata": {},
"outputs": [
@@ -110,15 +124,15 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"id": "9be90f27-5520-45f6-a42d-b309572e6e91",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-10-01 14:39:16.245591: W external/xla/xla/service/gpu/nvptx_compiler.cc:765] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.5.82). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.\n"
"2024-10-02 08:31:01.744162: W external/xla/xla/service/gpu/nvptx_compiler.cc:836] The NVIDIA driver's CUDA version is 12.2 which is older than the PTX compiler version (12.5.82). Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.\n"
]
},
{
@@ -159,7 +173,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"id": "7edfa758-bf4e-49fa-8c5d-5dc9c0c2c346",
"metadata": {},
"outputs": [
@@ -206,7 +220,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"id": "72d805ea-89b6-457d-9558-ff31fdd23d35",
"metadata": {},
"outputs": [
@@ -280,7 +294,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"id": "1ed9d08e-b18a-4fe7-bcba-72b95ddf6e68",
"metadata": {},
"outputs": [
@@ -325,17 +339,17 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 8,
"id": "b9a608d7-6cf8-457b-8275-bdcacc9b06fe",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"HloModule jit_matmul_fn_with_scale, is_scheduled=true, entry_computation_layout={(f8e4m3fn[32,64]{1,0}, f8e4m3fn[128,64]{1,0}, f32[], f32[], f32[])->f8e4m3fn[32,128]{1,0}}, allow_spmd_sharding_propagation_to_parameters={true,true,true,true,true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs=\"230c40ffa1e1e3ba7f06e4a65ac9e2bd\"}\n",
"HloModule jit_matmul_fn_with_scale, is_scheduled=true, entry_computation_layout={(f8e4m3fn[32,64]{1,0}, f8e4m3fn[128,64]{1,0}, f32[], f32[], f32[])->f8e4m3fn[32,128]{1,0}}, allow_spmd_sharding_propagation_to_parameters={true,true,true,true,true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs=\"880fbc3fe38d16fac872dc7542132e26\"}\n",
"\n",
"ENTRY %main.22 (Arg_0.1.0: f8e4m3fn[32,64], Arg_1.2.0: f8e4m3fn[128,64], Arg_2.3.0: f32[], Arg_3.4.0: f32[], Arg_4.5.0: f32[]) -> f8e4m3fn[32,128] {\n",
"ENTRY %main.25 (Arg_0.1.0: f8e4m3fn[32,64], Arg_1.2.0: f8e4m3fn[128,64], Arg_2.3.0: f32[], Arg_3.4.0: f32[], Arg_4.5.0: f32[]) -> f8e4m3fn[32,128] {\n",
" %constant_1 = f32[] constant(1)\n",
" %Arg_4.5.0 = f32[] parameter(4)\n",
" %Arg_3.4.0 = f32[] parameter(3)\n",
Expand All @@ -352,22 +366,24 @@
],
"source": [
"e4m3_max = ml_dtypes.finfo(jnp.float8_e4m3fn).max\n",
"# \"Dequantization\" datatype (note: required to be BF16!)\n",
"dqt_dtype = jnp.bfloat16\n",
"\n",
"# XLA requires a \"dequantize/quantize\" pattern to properly support scaled FP8 inputs/outputs. \n",
"def matmul_fn_with_scale(a_fp8, b_fp8, a_scale, b_scale, d_scale):\n",
" # Dequantize x and y\n",
" a_fp32 = a_fp8.astype(jnp.float32) * a_scale\n",
" b_fp32 = b_fp8.astype(jnp.float32) * b_scale\n",
" a_dqt = a_fp8.astype(dqt_dtype) * a_scale.astype(dqt_dtype)\n",
" b_dqt = b_fp8.astype(dqt_dtype) * b_scale.astype(dqt_dtype)\n",
" \n",
" # Do the matmul (NOTE: adding transpose to reduce on last axis).\n",
" d_fp32 = jax.lax.dot(a_fp32, b_fp32.T)\n",
" d_dqt = jax.lax.dot(a_dqt, b_dqt.T)\n",
" \n",
" # Rescale & clamp to -max/+max FP8 E4M3 values.\n",
" d_fp32 = d_fp32 * d_scale\n",
" d_dqt = d_dqt * d_scale.astype(dqt_dtype)\n",
" # NOTE: clamping is NOT optional for proper pattern matching!\n",
" d_fp32 = jax.lax.clamp(jnp.float32(-e4m3_max), d_fp32, jnp.float32(e4m3_max))\n",
" d_dqt = jax.lax.clamp(dqt_dtype(-e4m3_max), d_dqt, dqt_dtype(e4m3_max))\n",
" # (Re)Quantize the scaled matmul output.\n",
" return d_fp32.astype(jnp.float8_e4m3fn)\n",
" return d_dqt.astype(jnp.float8_e4m3fn)\n",
"\n",
"# AOT compilation with JAX, inspecting the (final) HLO module generated.\n",
"fn_compiled = jax.jit(matmul_fn_with_scale).lower(a, b, scale_aval, scale_aval, scale_aval).compile()\n",
Expand All @@ -387,17 +403,17 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 9,
"id": "44f28bbb-d4c6-4170-a736-76d667d73f97",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"HloModule jit_matmul_fn_with_scale, is_scheduled=true, entry_computation_layout={(f8e4m3fn[32,64]{1,0}, f8e4m3fn[128,64]{1,0}, f32[], f32[], f32[])->f8e4m3fn[32,128]{1,0}}, allow_spmd_sharding_propagation_to_parameters={true,true,true,true,true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs=\"f1fb5db9dad54941d7d17e04fdbe9515\"}\n",
"HloModule jit_matmul_fn_with_scale, is_scheduled=true, entry_computation_layout={(f8e4m3fn[32,64]{1,0}, f8e4m3fn[128,64]{1,0}, f32[], f32[], f32[])->f8e4m3fn[32,128]{1,0}}, allow_spmd_sharding_propagation_to_parameters={true,true,true,true,true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs=\"ba54f58f7ec56c7beda9299cd16bb7b2\"}\n",
"\n",
"ENTRY %main.28 (Arg_0.1.0: f8e4m3fn[32,64], Arg_1.2.0: f8e4m3fn[128,64], Arg_2.3.0: f32[], Arg_3.4.0: f32[], Arg_4.5.0: f32[]) -> f8e4m3fn[32,128] {\n",
"ENTRY %main.31 (Arg_0.1.0: f8e4m3fn[32,64], Arg_1.2.0: f8e4m3fn[128,64], Arg_2.3.0: f32[], Arg_3.4.0: f32[], Arg_4.5.0: f32[]) -> f8e4m3fn[32,128] {\n",
" %constant_1_0 = f32[] constant(1)\n",
" %Arg_4.5.0 = f32[] parameter(4)\n",
" %Arg_3.4.0 = f32[] parameter(3)\n",
Expand All @@ -414,24 +430,26 @@
],
"source": [
"e4m3_max = ml_dtypes.finfo(jnp.float8_e4m3fn).max\n",
"# \"Dequantization\" datatype (note: required to be BF16!)\n",
"dqt_dtype = jnp.bfloat16\n",
"\n",
"# XLA requires a \"dequantize/quantize\" pattern to properly support scaled FP8 inputs/outputs. \n",
"def matmul_fn_with_scale(a_fp8, b_fp8, a_scale, b_scale, d_scale):\n",
" # Dequantize x and y\n",
" a_fp32 = a_fp8.astype(jnp.float32) * a_scale\n",
" b_fp32 = b_fp8.astype(jnp.float32) * b_scale\n",
" a_dqt = a_fp8.astype(dqt_dtype) * a_scale.astype(dqt_dtype)\n",
" b_dqt = b_fp8.astype(dqt_dtype) * b_scale.astype(dqt_dtype)\n",
" \n",
" # Do the matmul (NOTE: adding transpose to simplify HLO).\n",
" d_fp32 = jax.lax.dot(a_fp32, b_fp32.T)\n",
" d_dqt = jax.lax.dot(a_dqt, b_dqt.T)\n",
" # ReLU non-linearity. Note: applied before scaling.\n",
" d_fp32 = jax.nn.relu(d_fp32)\n",
" d_dqt = jax.nn.relu(d_dqt)\n",
" \n",
" # Rescale & clamp to -max/+max FP8 E4M3 values.\n",
" d_fp32 = d_fp32 * d_scale\n",
" d_dqt = d_dqt * d_scale.astype(dqt_dtype)\n",
" # NOTE: clamping is NOT optional for proper pattern matching!\n",
" d_fp32 = jax.lax.clamp(jnp.float32(-e4m3_max), d_fp32, jnp.float32(e4m3_max))\n",
" d_dqt = jax.lax.clamp(dqt_dtype(-e4m3_max), d_dqt, dqt_dtype(e4m3_max))\n",
" # (Re)Quantize the scaled matmul output.\n",
" return d_fp32.astype(jnp.float8_e4m3fn)\n",
" return d_dqt.astype(jnp.float8_e4m3fn)\n",
"\n",
"# AOT compilation with JAX, inspecting the (final) HLO module generated.\n",
"fn_compiled = jax.jit(matmul_fn_with_scale).lower(a, b, scale_aval, scale_aval, scale_aval).compile()\n",
Expand All @@ -449,7 +467,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 10,
"id": "2ca21eae-8b0c-454b-b670-1ef0d5935a5c",
"metadata": {},
"outputs": [
@@ -504,17 +522,17 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 16,
"id": "a65cf3be-c465-49ae-9e90-2ada54dba84a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"HloModule jit_matmul_fn_with_scale, is_scheduled=true, entry_computation_layout={(f8e4m3fn[32,64]{1,0}, f8e4m3fn[128,64]{1,0}, f32[], f32[], f32[])->(f8e4m3fn[32,128]{1,0}, f32[])}, allow_spmd_sharding_propagation_to_parameters={true,true,true,true,true}, allow_spmd_sharding_propagation_to_output={true,true}, frontend_attributes={fingerprint_before_lhs=\"5d38b8087de7ebb664888f640beb2017\"}\n",
"HloModule jit_matmul_fn_with_scale, is_scheduled=true, entry_computation_layout={(f8e4m3fn[32,64]{1,0}, f8e4m3fn[128,64]{1,0}, f32[], f32[], f32[])->(f8e4m3fn[32,128]{1,0}, f32[])}, allow_spmd_sharding_propagation_to_parameters={true,true,true,true,true}, allow_spmd_sharding_propagation_to_output={true,true}, frontend_attributes={fingerprint_before_lhs=\"206494040898ad9e7c872e73f922a9e5\"}\n",
"\n",
"ENTRY %main.36 (Arg_0.1.0: f8e4m3fn[32,64], Arg_1.2.0: f8e4m3fn[128,64], Arg_2.3.0: f32[], Arg_3.4.0: f32[], Arg_4.5.0: f32[]) -> (f8e4m3fn[32,128], f32[]) {\n",
"ENTRY %main.40 (Arg_0.1.0: f8e4m3fn[32,64], Arg_1.2.0: f8e4m3fn[128,64], Arg_2.3.0: f32[], Arg_3.4.0: f32[], Arg_4.5.0: f32[]) -> (f8e4m3fn[32,128], f32[]) {\n",
" %constant_1_0 = f32[] constant(1)\n",
" %Arg_4.5.0 = f32[] parameter(4)\n",
" %Arg_3.4.0 = f32[] parameter(3)\n",
Expand All @@ -524,7 +542,7 @@
" %cublas-gemm.2.clone.1.0 = (f8e4m3fn[32,128]{1,0}, f32[], s8[33554432]{0}) custom-call(f8e4m3fn[32,64]{1,0} %Arg_0.1.0, f8e4m3fn[128,64]{1,0} %Arg_1.2.0, f32[] %Arg_2.3.0, f32[] %Arg_3.4.0, f32[] %constant_1_0, /*index=5*/f32[] %Arg_4.5.0), custom_call_target=\"__cublas$lt$matmul$f8\"\n",
" %get-tuple-element.1.0 = f32[] get-tuple-element((f8e4m3fn[32,128]{1,0}, f32[], s8[33554432]{0}) %cublas-gemm.2.clone.1.0), index=1\n",
" %get-tuple-element.4 = f8e4m3fn[32,128]{1,0} get-tuple-element((f8e4m3fn[32,128]{1,0}, f32[], s8[33554432]{0}) %cublas-gemm.2.clone.1.0), index=0\n",
" ROOT %tuple.35.0 = (f8e4m3fn[32,128]{1,0}, f32[]) tuple(f8e4m3fn[32,128]{1,0} %get-tuple-element.4, f32[] %get-tuple-element.1.0)\n",
" ROOT %tuple.39.0 = (f8e4m3fn[32,128]{1,0}, f32[]) tuple(f8e4m3fn[32,128]{1,0} %get-tuple-element.4, f32[] %get-tuple-element.1.0)\n",
"}\n",
"\n",
"\n"
Expand All @@ -533,26 +551,28 @@
],
"source": [
"e4m3_max = ml_dtypes.finfo(jnp.float8_e4m3fn).max\n",
"# \"Dequantization\" datatype (note: required to be BF16!)\n",
"dqt_dtype = jnp.bfloat16\n",
"\n",
"# XLA requires a \"dequantize/quantize\" pattern to properly support scaled FP8 inputs/outputs. \n",
"def matmul_fn_with_scale(a_fp8, b_fp8, a_scale, b_scale, d_scale):\n",
" # Dequantize x and y\n",
" a_fp32 = a_fp8.astype(jnp.float32) * a_scale\n",
" b_fp32 = b_fp8.astype(jnp.float32) * b_scale\n",
" a_dqt = a_fp8.astype(dqt_dtype) * a_scale.astype(dqt_dtype)\n",
" b_dqt = b_fp8.astype(dqt_dtype) * b_scale.astype(dqt_dtype)\n",
" \n",
" # Do the matmul (NOTE: adding transpose to simplify HLO).\n",
" d_fp32 = jax.lax.dot(a_fp32, b_fp32.T)\n",
" d_dqt = jax.lax.dot(a_dqt, b_dqt.T)\n",
" # ReLU non-linearity. Note: needs to be before the scaling.\n",
" d_fp32 = jax.nn.relu(d_fp32)\n",
" d_dqt = jax.nn.relu(d_dqt)\n",
" # Delayed rescaling: capture the raw output scaling for latter.\n",
" out_scale = jnp.max(jnp.abs(d_fp32))\n",
" out_scale = jnp.max(jnp.abs(d_dqt)).astype(jnp.float32)\n",
"\n",
" # Rescale & clamp to -max/+max FP8 E4M3 values.\n",
" d_fp32 = d_fp32 * d_scale\n",
" d_dqt = d_dqt * d_scale.astype(dqt_dtype)\n",
" # NOTE: clamping is NOT optional for proper pattern matching!\n",
" d_fp32 = jax.lax.clamp(jnp.float32(-e4m3_max), d_fp32, jnp.float32(e4m3_max))\n",
" d_dqt = jax.lax.clamp(dqt_dtype(-e4m3_max), d_dqt, dqt_dtype(e4m3_max))\n",
" # (Re)Quantize the scaled matmul output.\n",
" return d_fp32.astype(jnp.float8_e4m3fn), out_scale\n",
" return d_dqt.astype(jnp.float8_e4m3fn), out_scale\n",
"\n",
"# AOT compilation with JAX, inspecting the (final) HLO module generated.\n",
"fn_compiled = jax.jit(matmul_fn_with_scale).lower(a, b, scale_aval, scale_aval, scale_aval).compile()\n",
Expand All @@ -570,7 +590,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 12,
"id": "20d4d088-6563-44c2-86a1-ab2c34fe4e8e",
"metadata": {},
"outputs": [
