Skip to content

Commit

Permalink
fix(providers/sql): respect soft_fail argument when exception is rais…
Browse files Browse the repository at this point in the history
…ed (#34199)
  • Loading branch information
Lee-W committed Sep 8, 2023
1 parent c5016f7 commit f5c2748
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 20 deletions.
28 changes: 23 additions & 5 deletions airflow/providers/common/sql/sensors/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from typing import Any, Sequence

from airflow import AirflowException
from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.hooks.base import BaseHook
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.sensors.base import BaseSensorOperator
Expand Down Expand Up @@ -96,19 +96,37 @@ def poke(self, context: Any):
records = hook.get_records(self.sql, self.parameters)
if not records:
if self.fail_on_empty:
raise AirflowException("No rows returned, raising as per fail_on_empty flag")
# TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
message = "No rows returned, raising as per fail_on_empty flag"
if self.soft_fail:
raise AirflowSkipException(message)
raise AirflowException(message)
else:
return False

first_cell = records[0][0]
if self.failure is not None:
if callable(self.failure):
if self.failure(first_cell):
raise AirflowException(f"Failure criteria met. self.failure({first_cell}) returned True")
# TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
message = f"Failure criteria met. self.failure({first_cell}) returned True"
if self.soft_fail:
raise AirflowSkipException(message)
raise AirflowException(message)
else:
raise AirflowException(f"self.failure is present, but not callable -> {self.failure}")
# TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
message = f"self.failure is present, but not callable -> {self.failure}"
if self.soft_fail:
raise AirflowSkipException(message)
raise AirflowException(message)

if self.success is not None:
if callable(self.success):
return self.success(first_cell)
else:
raise AirflowException(f"self.success is present, but not callable -> {self.success}")
# TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
message = f"self.success is present, but not callable -> {self.success}"
if self.soft_fail:
raise AirflowSkipException(message)
raise AirflowException(message)
return bool(first_cell)
72 changes: 57 additions & 15 deletions tests/providers/common/sql/sensors/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import pytest

from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.models.dag import DAG
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.providers.common.sql.sensors.sql import SqlSensor
Expand Down Expand Up @@ -117,17 +117,26 @@ def test_sql_sensor_postgres_poke(self, mock_hook):
mock_get_records.return_value = [["1"]]
assert op.poke(None)

@pytest.mark.parametrize(
"soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
)
@mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook")
def test_sql_sensor_postgres_poke_fail_on_empty(self, mock_hook):
def test_sql_sensor_postgres_poke_fail_on_empty(
self, mock_hook, soft_fail: bool, expected_exception: AirflowException
):
op = SqlSensor(
task_id="sql_sensor_check", conn_id="postgres_default", sql="SELECT 1", fail_on_empty=True
task_id="sql_sensor_check",
conn_id="postgres_default",
sql="SELECT 1",
fail_on_empty=True,
soft_fail=soft_fail,
)

mock_hook.get_connection.return_value.get_hook.return_value = mock.MagicMock(spec=DbApiHook)
mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records

mock_get_records.return_value = []
with pytest.raises(AirflowException):
with pytest.raises(expected_exception):
op.poke(None)

@mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook")
Expand All @@ -148,10 +157,19 @@ def test_sql_sensor_postgres_poke_success(self, mock_hook):
mock_get_records.return_value = [["1"]]
assert not op.poke(None)

@pytest.mark.parametrize(
"soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
)
@mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook")
def test_sql_sensor_postgres_poke_failure(self, mock_hook):
def test_sql_sensor_postgres_poke_failure(
self, mock_hook, soft_fail: bool, expected_exception: AirflowException
):
op = SqlSensor(
task_id="sql_sensor_check", conn_id="postgres_default", sql="SELECT 1", failure=lambda x: x in [1]
task_id="sql_sensor_check",
conn_id="postgres_default",
sql="SELECT 1",
failure=lambda x: x in [1],
soft_fail=soft_fail,
)

mock_hook.get_connection.return_value.get_hook.return_value = mock.MagicMock(spec=DbApiHook)
Expand All @@ -161,17 +179,23 @@ def test_sql_sensor_postgres_poke_failure(self, mock_hook):
assert not op.poke(None)

mock_get_records.return_value = [[1]]
with pytest.raises(AirflowException):
with pytest.raises(expected_exception):
op.poke(None)

@pytest.mark.parametrize(
"soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
)
@mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook")
def test_sql_sensor_postgres_poke_failure_success(self, mock_hook):
def test_sql_sensor_postgres_poke_failure_success(
self, mock_hook, soft_fail: bool, expected_exception: AirflowException
):
op = SqlSensor(
task_id="sql_sensor_check",
conn_id="postgres_default",
sql="SELECT 1",
failure=lambda x: x in [1],
success=lambda x: x in [2],
soft_fail=soft_fail,
)

mock_hook.get_connection.return_value.get_hook.return_value = mock.MagicMock(spec=DbApiHook)
Expand All @@ -181,20 +205,26 @@ def test_sql_sensor_postgres_poke_failure_success(self, mock_hook):
assert not op.poke(None)

mock_get_records.return_value = [[1]]
with pytest.raises(AirflowException):
with pytest.raises(expected_exception):
op.poke(None)

mock_get_records.return_value = [[2]]
assert op.poke(None)

@pytest.mark.parametrize(
"soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
)
@mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook")
def test_sql_sensor_postgres_poke_failure_success_same(self, mock_hook):
def test_sql_sensor_postgres_poke_failure_success_same(
self, mock_hook, soft_fail: bool, expected_exception: AirflowException
):
op = SqlSensor(
task_id="sql_sensor_check",
conn_id="postgres_default",
sql="SELECT 1",
failure=lambda x: x in [1],
success=lambda x: x in [1],
soft_fail=soft_fail,
)

mock_hook.get_connection.return_value.get_hook.return_value = mock.MagicMock(spec=DbApiHook)
Expand All @@ -204,40 +234,52 @@ def test_sql_sensor_postgres_poke_failure_success_same(self, mock_hook):
assert not op.poke(None)

mock_get_records.return_value = [[1]]
with pytest.raises(AirflowException):
with pytest.raises(expected_exception):
op.poke(None)

@pytest.mark.parametrize(
"soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
)
@mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook")
def test_sql_sensor_postgres_poke_invalid_failure(self, mock_hook):
def test_sql_sensor_postgres_poke_invalid_failure(
self, mock_hook, soft_fail: bool, expected_exception: AirflowException
):
op = SqlSensor(
task_id="sql_sensor_check",
conn_id="postgres_default",
sql="SELECT 1",
failure=[1],
soft_fail=soft_fail,
)

mock_hook.get_connection.return_value.get_hook.return_value = mock.MagicMock(spec=DbApiHook)
mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records

mock_get_records.return_value = [[1]]
with pytest.raises(AirflowException) as ctx:
with pytest.raises(expected_exception) as ctx:
op.poke(None)
assert "self.failure is present, but not callable -> [1]" == str(ctx.value)

@pytest.mark.parametrize(
"soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
)
@mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook")
def test_sql_sensor_postgres_poke_invalid_success(self, mock_hook):
def test_sql_sensor_postgres_poke_invalid_success(
self, mock_hook, soft_fail: bool, expected_exception: AirflowException
):
op = SqlSensor(
task_id="sql_sensor_check",
conn_id="postgres_default",
sql="SELECT 1",
success=[1],
soft_fail=soft_fail,
)

mock_hook.get_connection.return_value.get_hook.return_value = mock.MagicMock(spec=DbApiHook)
mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records

mock_get_records.return_value = [[1]]
with pytest.raises(AirflowException) as ctx:
with pytest.raises(expected_exception) as ctx:
op.poke(None)
assert "self.success is present, but not callable -> [1]" == str(ctx.value)

Expand Down

0 comments on commit f5c2748

Please sign in to comment.