diff --git a/perfmetrics/scripts/ml_tests/pytorch/run_model.sh b/perfmetrics/scripts/ml_tests/pytorch/run_model.sh index cc2d18538f..16e850d20f 100755 --- a/perfmetrics/scripts/ml_tests/pytorch/run_model.sh +++ b/perfmetrics/scripts/ml_tests/pytorch/run_model.sh @@ -14,6 +14,7 @@ # limitations under the License. PYTORCH_VESRION=$1 +NUM_EPOCHS=80 # Install golang wget -O go_tar.tar.gz https://go.dev/dl/go1.21.4.linux-amd64.tar.gz -q @@ -76,8 +77,13 @@ python -c 'import torch;torch.hub.list("facebookresearch/xcit:main")' # (TulsiShah) TODO: Pytorch 2.0 compile mode has issues (https://github.com/pytorch/pytorch/issues/94599), # which is fixed in pytorch version 2.1.0 (https://github.com/pytorch/pytorch/pull/100071). # We'll remove this workaround once we update our Docker image to use Pytorch 2.1.0 or greater version. +# Reducing the epochs as pytorch2 long haul tests are running on NVIDIA L4 machines, which lack the powerful GPU of +# the NVIDIA A100. So it is taking longer time to complete the training. We will set it back to 80 when the NVIDIA A100 GPU machine +# will be available. if [ ${PYTORCH_VESRION} == "v2" ]; then + NUM_EPOCHS=36 + allowed_functions_file="/opt/conda/lib/python3.10/site-packages/torch/_dynamo/allowed_functions.py" # Update the pytorch library code to bypass the kernel-cache echo "Updating the pytorch library code to Disallow_in_graph distributed API.." @@ -197,7 +203,7 @@ gsutil cp start_time.txt $ARTIFACTS_BUCKET_PATH/ --norm_last_layer False \ --use_fp16 False \ --clip_grad 0 \ - --epochs 80 \ + --epochs $NUM_EPOCHS \ --global_crops_scale 0.25 1.0 \ --local_crops_number 10 \ --local_crops_scale 0.05 0.25 \