From 51de3b155c3d82b5410bb49c6845f9580b12a84f Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Mon, 23 Sep 2024 13:56:05 -0700 Subject: [PATCH] Add STOPPED to the failure cases for SageMaker Training Jobs --- airflow/providers/amazon/aws/hooks/sagemaker.py | 3 ++- airflow/providers/amazon/aws/sensors/sagemaker.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/airflow/providers/amazon/aws/hooks/sagemaker.py b/airflow/providers/amazon/aws/hooks/sagemaker.py index af131697a5e8d..2c0f4fb25edc5 100644 --- a/airflow/providers/amazon/aws/hooks/sagemaker.py +++ b/airflow/providers/amazon/aws/hooks/sagemaker.py @@ -155,6 +155,7 @@ class SageMakerHook(AwsBaseHook): endpoint_non_terminal_states = {"Creating", "Updating", "SystemUpdating", "RollingBack", "Deleting"} pipeline_non_terminal_states = {"Executing", "Stopping"} failed_states = {"Failed"} + training_failed_states = {*failed_states, "Stopped"} def __init__(self, *args, **kwargs): super().__init__(client_type="sagemaker", *args, **kwargs) @@ -309,7 +310,7 @@ def create_training_job( self.check_training_status_with_log( config["TrainingJobName"], self.non_terminal_states, - self.failed_states, + self.training_failed_states, wait_for_completion, check_interval, max_ingestion_time, diff --git a/airflow/providers/amazon/aws/sensors/sagemaker.py b/airflow/providers/amazon/aws/sensors/sagemaker.py index b01e24cd5b815..af07c504aa29d 100644 --- a/airflow/providers/amazon/aws/sensors/sagemaker.py +++ b/airflow/providers/amazon/aws/sensors/sagemaker.py @@ -238,7 +238,7 @@ def non_terminal_states(self): return SageMakerHook.non_terminal_states def failed_states(self): - return SageMakerHook.failed_states + return SageMakerHook.training_failed_states def get_sagemaker_response(self): if self.print_log: