Add FMHA T5x test (#442)
Adds the JAX T5x FMHA E2E system test to check for FMHA lowering
support. The test implements the following steps:

The FMHA lowering flag is now enabled by default; HLO dumping is enabled
to track the FMHA forward and backward instructions (see the flag sketch below).
Added the test to the _ci.yaml file and also added a nightly workflow
file for it. We will later add this test to performance benchmarking and
add its HLO to the baseline.
Also corrected the decoder sequence length, which should be a multiple
of 64.
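For reference, a minimal sketch of the flag setup the test relies on, using the same XLA flags the updated test-t5x.sh exports (the dump directory here is a placeholder):

```bash
# Dump HLO as text so the FMHA instructions can be inspected (path is an example).
export BASE_XLA_FLAGS="--xla_dump_hlo_as_text --xla_dump_to=/tmp/hlo"
# Enable cuDNN FMHA lowering and make fused attention use cuDNN's RNG.
export BASE_XLA_FLAGS_FMHA="--xla_gpu_fused_attention_use_cudnn_rng=true --xla_gpu_enable_cudnn_fmha=true"
# The script composes the FMHA flags with any pre-existing XLA_FLAGS.
export XLA_FLAGS="${BASE_XLA_FLAGS_FMHA} ${BASE_XLA_FLAGS} ${XLA_FLAGS:-}"
```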

The test was previously failing with an error related to
CUDNN_STATUS_BAD_PARAM. The fix was added upstream in
[PR openxla/xla#6872](openxla/xla#6872), which is now
merged, and the test passes.
See the [bug report](https://nvbugspro.nvidia.com/bug/4409713) for this error.

Run for these changes: [workflow run
link](https://github.com/NVIDIA/JAX-Toolbox/actions/runs/7894631992)

---------

Co-authored-by: Terry Kong <terryk@nvidia.com>
hmonishN and terrykong committed Mar 8, 2024
1 parent 9cced45 commit 8e8320f
Showing 2 changed files with 120 additions and 11 deletions.
55 changes: 53 additions & 2 deletions .github/container/test-t5x.sh
@@ -20,13 +20,15 @@ usage() {
echo " -e, --epochs Number of epochs to run, defaults to 7."
echo " --multiprocess Enable the multiprocess GPU mode."
echo " -o, --output NAME Name for the output folder, a temporary folder will be created if none specified."
echo " --save-hlo {0, 1} 1 to save the dumped hlo, 0 to remove the hlo dumped folder"
echo " --seed INT Random seed for deterministim. Defaults to 42."
echo " -s, --steps-per-epoch INT Steps per epoch. Detauls to 100"
echo " --enable-fmha {0, 1} 1 to enable fmha testing, 0 to run test without fmha; default is 0"
echo " -h, --help Print usage."
exit $1
}

args=$(getopt -o a:b:cd:e:ho:s: --long additional-args:,batch-size:,use-contrib-configs,dtype:,enable-te:,epochs:,help,multiprocess,output:,seed:,steps-per-epoch: -- "$@")
args=$(getopt -o a:b:cd:e:ho:s: --long additional-args:,batch-size:,use-contrib-configs,dtype:,enable-te:,enable-fmha:,epochs:,help,multiprocess,output:,seed:,save-hlo:,steps-per-epoch: -- "$@")
if [[ $? -ne 0 ]]; then
exit 1
fi
@@ -43,6 +45,8 @@ OUTPUT=$(mktemp -d)
SEED=42
STEPS_PER_EPOCH=100
ENABLE_TE=${ENABLE_TE:-0}
ENABLE_FMHA=${ENABLE_FMHA:-0}
SAVE_HLO=${SAVE_HLO:-1}

eval set -- "$args"
while [ : ]; do
@@ -67,6 +71,10 @@ while [ : ]; do
ENABLE_TE="$2"
shift 2
;;
--enable-fmha)
ENABLE_FMHA="$2"
shift 2
;;
-e | --epochs)
EPOCHS="$2"
shift 2
@@ -82,6 +90,10 @@ while [ : ]; do
OUTPUT="$2"
shift 2
;;
--save-hlo)
SAVE_HLO="$2"
shift 2
;;
--seed)
SEED="$2"
shift 2
@@ -105,6 +117,20 @@ if [[ $BATCH_SIZE == 0 ]]; then
usage 1
fi

# Set the HLO dump folder after the output folder is set.
HLO_DIR=${OUTPUT}/hlo
export BASE_XLA_FLAGS="${BASE_XLA_FLAGS:---xla_dump_hlo_as_text --xla_dump_to=${HLO_DIR}}"
export XLA_FLAGS="${BASE_XLA_FLAGS} ${XLA_FLAGS:-}"
echo "HLO will be dumped in ${HLO_DIR} dir."

## Set the XLA environment variables for FMHA.
if [[ "$ENABLE_FMHA" -eq "1" ]]; then
echo "Setting XLA FMHA Flags";
export BASE_XLA_FLAGS_FMHA="${BASE_XLA_FLAGS_FMHA:---xla_gpu_fused_attention_use_cudnn_rng=true --xla_gpu_enable_cudnn_fmha=true}"
export XLA_FLAGS="${BASE_XLA_FLAGS_FMHA} ${XLA_FLAGS:-}"
fi

echo "XLA FLAGS: $XLA_FLAGS"
## Set derived variables

TRAIN_STEPS=$(($EPOCHS * $STEPS_PER_EPOCH))
@@ -114,11 +140,13 @@ print_var BATCH_SIZE
print_var USE_CONTRIB_CONFIGS
print_var DTYPE
print_var ENABLE_TE
print_var ENABLE_FMHA
print_var EPOCHS
print_var OUTPUT
print_var MULTIPROCESS
print_var STEPS_PER_EPOCH
print_var TRAIN_STEPS
print_var SAVE_HLO

## Enter T5X source folder
T5X_DIR=$(dirname `python -c 'import t5x; print(*t5x.__path__)'`)
@@ -178,7 +206,7 @@ $(
import dummy_wikipedia
MIXTURE_OR_TASK_NAME = "dummy_wikipedia"
TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 114}
TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 128}
DROPOUT_RATE = 0.0
USE_CACHED_TASKS = False
TRAIN_STEPS = %gin.REQUIRED
@@ -206,3 +234,26 @@ ENABLE_TE=$ENABLE_TE python -m t5x.train \
$ADDITIONAL_ARGS \
$([[ $MULTIPROCESS != 0 ]] && echo --multiprocess_gpu)
echo "Output at ${OUTPUT}"

if [[ "$ENABLE_FMHA" -eq "1" ]]; then
## Check whether FMHA instructions are present in the dumped HLO files.
# Matches fused-attention instruction names such as fmha-bmm-scale-bias-softmax-bmm-backward.
fmha_regex="fmha(-bmm)?(-scale)?(-bias)?(-mask)?(-softmax)?(-dropout)?(-bmm)?(-backward)?"
result=$(grep -irlE "$fmha_regex" "${HLO_DIR}/"*.txt)

if [[ $SAVE_HLO -eq 0 ]]; then
rm -rf $HLO_DIR
echo "Removed dumped HLO directory!"
fi

if [ -z "$result" ]; then
echo "E: No FMHA instructions were found in the hlo files!"
exit 1
else
echo -e "Found FMHA instructions in the following HLO files: \n $result"
fi
else
if [[ $SAVE_HLO -eq 0 ]]; then
rm -rf $HLO_DIR
echo "Removed dumped HLO directory!"
fi
fi
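A hypothetical local invocation exercising the new options (the output path is an example; every flag shown is defined in the script above):

```bash
# Run the T5x test with FMHA enabled and keep the dumped HLO for inspection.
bash .github/container/test-t5x.sh \
  --enable-fmha 1 \
  --save-hlo 1 \
  --epochs 7 \
  --steps-per-epoch 100 \
  -o /tmp/t5x-fmha-run
```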
76 changes: 67 additions & 9 deletions .github/workflows/_test_upstream_t5x.yaml
@@ -38,7 +38,24 @@ jobs:
t5x-multi-gpu:
strategy:
matrix:
N_GPU: [1, 2, 4, 8]
include:
- TEST_NAME: "1P1G"
N_GPU: 1
ADDITIONAL_ARGS: ""
- TEST_NAME: "1P2G"
N_GPU: 2
ADDITIONAL_ARGS: ""
- TEST_NAME: "1P4G"
N_GPU: 4
ADDITIONAL_ARGS: ""
- TEST_NAME: "1P8G"
N_GPU: 8
- TEST_NAME: "1P1G_fmha"
N_GPU: 1
ADDITIONAL_ARGS: "--enable-fmha 1"
- TEST_NAME: "1P2G_fmha"
N_GPU: 2
ADDITIONAL_ARGS: "--enable-fmha 1"
fail-fast: false

runs-on: ubuntu-22.04
@@ -70,7 +87,7 @@ jobs:
shell: bash -x -e {0}
run: |
IMAGE="$(echo ${{inputs.T5X_IMAGE}} | sed 's/\//#/')"
TEST_CASE_NAME=1P${{ matrix.N_GPU }}G
TEST_CASE_NAME=${{ matrix.TEST_NAME }}
JOB_NAME=${{ inputs.FW_NAME }}-${GITHUB_RUN_ID}-${TEST_CASE_NAME}
LOG_FILE=/nfs/cluster/${JOB_NAME}.log
MODEL_PATH=/nfs/cluster/${JOB_NAME}
@@ -114,10 +131,11 @@ jobs:
--dtype bfloat16 \
--batch-size ${{ steps.meta.outputs.BATCH_SIZE }} \
--epochs 7 \
--steps-per-epoch 100
--steps-per-epoch 100 \
${{ matrix.ADDITIONAL_ARGS }}
EOF
)
echo "SLURM_JOB_ID=${JOB}" >> $GITHUB_OUTPUT
. .github/workflows/scripts/wait_for_slurm_job.sh
@@ -174,8 +192,47 @@ jobs:
t5x-multi-node:
strategy:
matrix:
N_GPU: [1, 2, 4, 8]
N_NODE: [1, 2]
include:
- TEST_NAME: "1G1N"
N_GPU: 1
N_NODE: 1
ADDITIONAL_ARGS: ""
- TEST_NAME: "2G1N"
N_GPU: 2
N_NODE: 1
ADDITIONAL_ARGS: ""
- TEST_NAME: "4G1N"
N_GPU: 4
N_NODE: 1
ADDITIONAL_ARGS: ""
- TEST_NAME: "8G1N"
N_GPU: 8
N_NODE: 1
ADDITIONAL_ARGS: ""
- TEST_NAME: "1G2N"
N_GPU: 1
N_NODE: 2
ADDITIONAL_ARGS: ""
- TEST_NAME: "2G2N"
N_GPU: 2
N_NODE: 2
ADDITIONAL_ARGS: ""
- TEST_NAME: "4G2N"
N_GPU: 4
N_NODE: 2
ADDITIONAL_ARGS: ""
- TEST_NAME: "8G2N"
N_GPU: 8
N_NODE: 2
ADDITIONAL_ARGS: ""
- TEST_NAME: "2G2N_fmha"
N_GPU: 2
N_NODE: 2
ADDITIONAL_ARGS: "--enable-fmha 1"
- TEST_NAME: "8G2N_fmha"
N_GPU: 8
N_NODE: 2
ADDITIONAL_ARGS: "--enable-fmha 1"
fail-fast: false

runs-on: ubuntu-22.04
@@ -207,9 +264,9 @@ jobs:
shell: bash -x -e {0}
run: |
IMAGE="$(echo ${{inputs.T5X_IMAGE}} | sed 's/\//#/')"
TEST_CASE_NAME=${{ matrix.N_GPU }}G${{ matrix.N_NODE }}N
TEST_CASE_NAME=${{ matrix.TEST_NAME }}
TOTAL_TASKS=$((${{ matrix.N_GPU }} * ${{ matrix.N_NODE }}))
JOB_NAME=${{ inputs.FW_NAME }}-${GITHUB_RUN_ID}-${TEST_CASE_NAME}
LOG_FILE=/nfs/cluster/${JOB_NAME}.log
MODEL_PATH=/nfs/cluster/${JOB_NAME}
BATCH_SIZE=$((${{ inputs.BATCH_SIZE_PER_GPU }} * ${{ matrix.N_GPU }} * ${{ matrix.N_NODE }}))
@@ -254,7 +311,8 @@ jobs:
--batch-size ${{ steps.meta.outputs.BATCH_SIZE }} \
--epochs 7 \
--steps-per-epoch 100 \
--multiprocess
--multiprocess \
${{ matrix.ADDITIONAL_ARGS }}
EOF
)
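To reproduce the CI check by hand, a minimal sketch of the verification the test script performs (the directory is a placeholder):

```bash
# Look for fused-attention instructions in the dumped HLO text files.
HLO_DIR=/tmp/t5x-fmha-run/hlo
if grep -irlE "fmha" "${HLO_DIR}/"*.txt; then
  echo "Found FMHA instructions in the files listed above."
else
  echo "E: No FMHA instructions were found in the HLO files!" >&2
  exit 1
fi
```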
