Skip to content

Commit

Permalink
Add STOPPED to the failure cases for Sagemaker Training Jobs
Browse files Browse the repository at this point in the history
  • Loading branch information
ferruzzi committed Sep 23, 2024
1 parent 3d2f9c1 commit 51de3b1
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
3 changes: 2 additions & 1 deletion airflow/providers/amazon/aws/hooks/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ class SageMakerHook(AwsBaseHook):
endpoint_non_terminal_states = {"Creating", "Updating", "SystemUpdating", "RollingBack", "Deleting"}
pipeline_non_terminal_states = {"Executing", "Stopping"}
failed_states = {"Failed"}
+ training_failed_states = {*failed_states, "Stopped"}

def __init__(self, *args, **kwargs):
super().__init__(client_type="sagemaker", *args, **kwargs)
Expand Down Expand Up @@ -309,7 +310,7 @@ def create_training_job(
self.check_training_status_with_log(
config["TrainingJobName"],
self.non_terminal_states,
- self.failed_states,
+ self.training_failed_states,
wait_for_completion,
check_interval,
max_ingestion_time,
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/sensors/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def non_terminal_states(self):
return SageMakerHook.non_terminal_states

def failed_states(self):
- return SageMakerHook.failed_states
+ return SageMakerHook.training_failed_states

def get_sagemaker_response(self):
if self.print_log:
Expand Down

0 comments on commit 51de3b1

Please sign in to comment.