From f5c2748c3346bdebf445afd615657af8849345dd Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Sat, 9 Sep 2023 04:09:13 +0800 Subject: [PATCH] fix(providers/sql): respect soft_fail argument when exception is raised (#34199) --- airflow/providers/common/sql/sensors/sql.py | 28 ++++++-- .../providers/common/sql/sensors/test_sql.py | 72 +++++++++++++++---- 2 files changed, 80 insertions(+), 20 deletions(-) diff --git a/airflow/providers/common/sql/sensors/sql.py b/airflow/providers/common/sql/sensors/sql.py index 73505390fc0b8..7eab94e5582aa 100644 --- a/airflow/providers/common/sql/sensors/sql.py +++ b/airflow/providers/common/sql/sensors/sql.py @@ -18,7 +18,7 @@ from typing import Any, Sequence -from airflow import AirflowException +from airflow.exceptions import AirflowException, AirflowSkipException from airflow.hooks.base import BaseHook from airflow.providers.common.sql.hooks.sql import DbApiHook from airflow.sensors.base import BaseSensorOperator @@ -96,19 +96,37 @@ def poke(self, context: Any): records = hook.get_records(self.sql, self.parameters) if not records: if self.fail_on_empty: - raise AirflowException("No rows returned, raising as per fail_on_empty flag") + # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1 + message = "No rows returned, raising as per fail_on_empty flag" + if self.soft_fail: + raise AirflowSkipException(message) + raise AirflowException(message) else: return False + first_cell = records[0][0] if self.failure is not None: if callable(self.failure): if self.failure(first_cell): - raise AirflowException(f"Failure criteria met. self.failure({first_cell}) returned True") + # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1 + message = f"Failure criteria met. self.failure({first_cell}) returned True" + if self.soft_fail: + raise AirflowSkipException(message) + raise AirflowException(message) else: - raise AirflowException(f"self.failure is present, but not callable -> {self.failure}") + # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1 + message = f"self.failure is present, but not callable -> {self.failure}" + if self.soft_fail: + raise AirflowSkipException(message) + raise AirflowException(message) + if self.success is not None: if callable(self.success): return self.success(first_cell) else: - raise AirflowException(f"self.success is present, but not callable -> {self.success}") + # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1 + message = f"self.success is present, but not callable -> {self.success}" + if self.soft_fail: + raise AirflowSkipException(message) + raise AirflowException(message) return bool(first_cell) diff --git a/tests/providers/common/sql/sensors/test_sql.py b/tests/providers/common/sql/sensors/test_sql.py index b14c977bd36f5..7491d03e1a2ca 100644 --- a/tests/providers/common/sql/sensors/test_sql.py +++ b/tests/providers/common/sql/sensors/test_sql.py @@ -21,7 +21,7 @@ import pytest -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowSkipException from airflow.models.dag import DAG from airflow.providers.common.sql.hooks.sql import DbApiHook from airflow.providers.common.sql.sensors.sql import SqlSensor @@ -117,17 +117,26 @@ def test_sql_sensor_postgres_poke(self, mock_hook): mock_get_records.return_value = [["1"]] assert op.poke(None) + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + ) @mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook") - def test_sql_sensor_postgres_poke_fail_on_empty(self, mock_hook): + def test_sql_sensor_postgres_poke_fail_on_empty( + self, mock_hook, soft_fail: bool, expected_exception: AirflowException + ): op = SqlSensor( - task_id="sql_sensor_check", conn_id="postgres_default", sql="SELECT 1", fail_on_empty=True + task_id="sql_sensor_check", + conn_id="postgres_default", + sql="SELECT 1", + fail_on_empty=True, + soft_fail=soft_fail, ) mock_hook.get_connection.return_value.get_hook.return_value = mock.MagicMock(spec=DbApiHook) mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records mock_get_records.return_value = [] - with pytest.raises(AirflowException): + with pytest.raises(expected_exception): op.poke(None) @mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook") @@ -148,10 +157,19 @@ def test_sql_sensor_postgres_poke_success(self, mock_hook): mock_get_records.return_value = [["1"]] assert not op.poke(None) + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + ) @mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook") - def test_sql_sensor_postgres_poke_failure(self, mock_hook): + def test_sql_sensor_postgres_poke_failure( + self, mock_hook, soft_fail: bool, expected_exception: AirflowException + ): op = SqlSensor( - task_id="sql_sensor_check", conn_id="postgres_default", sql="SELECT 1", failure=lambda x: x in [1] + task_id="sql_sensor_check", + conn_id="postgres_default", + sql="SELECT 1", + failure=lambda x: x in [1], + soft_fail=soft_fail, ) mock_hook.get_connection.return_value.get_hook.return_value = mock.MagicMock(spec=DbApiHook) @@ -161,17 +179,23 @@ def test_sql_sensor_postgres_poke_failure(self, mock_hook): assert not op.poke(None) mock_get_records.return_value = [[1]] - with pytest.raises(AirflowException): + with pytest.raises(expected_exception): op.poke(None) + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + ) @mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook") - def test_sql_sensor_postgres_poke_failure_success(self, mock_hook): + def test_sql_sensor_postgres_poke_failure_success( + self, mock_hook, soft_fail: bool, expected_exception: AirflowException + ): op = SqlSensor( task_id="sql_sensor_check", conn_id="postgres_default", sql="SELECT 1", failure=lambda x: x in [1], success=lambda x: x in [2], + soft_fail=soft_fail, ) mock_hook.get_connection.return_value.get_hook.return_value = mock.MagicMock(spec=DbApiHook) @@ -181,20 +205,26 @@ def test_sql_sensor_postgres_poke_failure_success(self, mock_hook): assert not op.poke(None) mock_get_records.return_value = [[1]] - with pytest.raises(AirflowException): + with pytest.raises(expected_exception): op.poke(None) mock_get_records.return_value = [[2]] assert op.poke(None) + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + ) @mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook") - def test_sql_sensor_postgres_poke_failure_success_same(self, mock_hook): + def test_sql_sensor_postgres_poke_failure_success_same( + self, mock_hook, soft_fail: bool, expected_exception: AirflowException + ): op = SqlSensor( task_id="sql_sensor_check", conn_id="postgres_default", sql="SELECT 1", failure=lambda x: x in [1], success=lambda x: x in [1], + soft_fail=soft_fail, ) mock_hook.get_connection.return_value.get_hook.return_value = mock.MagicMock(spec=DbApiHook) @@ -204,40 +234,52 @@ def test_sql_sensor_postgres_poke_failure_success_same(self, mock_hook): assert not op.poke(None) mock_get_records.return_value = [[1]] - with pytest.raises(AirflowException): + with pytest.raises(expected_exception): op.poke(None) + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + ) @mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook") - def test_sql_sensor_postgres_poke_invalid_failure(self, mock_hook): + def test_sql_sensor_postgres_poke_invalid_failure( + self, mock_hook, soft_fail: bool, expected_exception: AirflowException + ): op = SqlSensor( task_id="sql_sensor_check", conn_id="postgres_default", sql="SELECT 1", failure=[1], + soft_fail=soft_fail, ) mock_hook.get_connection.return_value.get_hook.return_value = mock.MagicMock(spec=DbApiHook) mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records mock_get_records.return_value = [[1]] - with pytest.raises(AirflowException) as ctx: + with pytest.raises(expected_exception) as ctx: op.poke(None) assert "self.failure is present, but not callable -> [1]" == str(ctx.value) + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + ) @mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook") - def test_sql_sensor_postgres_poke_invalid_success(self, mock_hook): + def test_sql_sensor_postgres_poke_invalid_success( + self, mock_hook, soft_fail: bool, expected_exception: AirflowException + ): op = SqlSensor( task_id="sql_sensor_check", conn_id="postgres_default", sql="SELECT 1", success=[1], + soft_fail=soft_fail, ) mock_hook.get_connection.return_value.get_hook.return_value = mock.MagicMock(spec=DbApiHook) mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records mock_get_records.return_value = [[1]] - with pytest.raises(AirflowException) as ctx: + with pytest.raises(expected_exception) as ctx: op.poke(None) assert "self.success is present, but not callable -> [1]" == str(ctx.value)