diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py index be7a8673ff186..a4f6fcc6b6940 100644 --- a/airflow/providers/google/cloud/operators/bigquery.py +++ b/airflow/providers/google/cloud/operators/bigquery.py @@ -365,19 +365,21 @@ def execute(self, context: Context) -> None: # type: ignore[override] job = self._submit_job(hook, job_id="") context["ti"].xcom_push(key="job_id", value=job.job_id) - self.defer( - timeout=self.execution_timeout, - trigger=BigQueryValueCheckTrigger( - conn_id=self.gcp_conn_id, - job_id=job.job_id, - project_id=hook.project_id, - sql=self.sql, - pass_value=self.pass_value, - tolerance=self.tol, - poll_interval=self.poll_interval, - ), - method_name="execute_complete", - ) + if job.running(): + self.defer( + timeout=self.execution_timeout, + trigger=BigQueryValueCheckTrigger( + conn_id=self.gcp_conn_id, + job_id=job.job_id, + project_id=hook.project_id, + sql=self.sql, + pass_value=self.pass_value, + tolerance=self.tol, + poll_interval=self.poll_interval, + ), + method_name="execute_complete", + ) + self.log.info("Current state of job %s is %s", job.job_id, job.state) def execute_complete(self, context: Context, event: dict[str, Any]) -> None: """ diff --git a/tests/providers/google/cloud/operators/test_bigquery.py b/tests/providers/google/cloud/operators/test_bigquery.py index 34c1a803cf8eb..54be8c4550226 100644 --- a/tests/providers/google/cloud/operators/test_bigquery.py +++ b/tests/providers/google/cloud/operators/test_bigquery.py @@ -1748,6 +1748,33 @@ def test_bigquery_value_check_async(self, mock_hook, create_task_instance_of_ope exc.value.trigger, BigQueryValueCheckTrigger ), "Trigger is not a BigQueryValueCheckTrigger" + @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryValueCheckOperator.execute") + @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryValueCheckOperator.defer") + @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") + def test_bigquery_value_check_operator_async_finish_before_deferred( + self, mock_hook, mock_defer, mock_execute, create_task_instance_of_operator + ): + job_id = "123456" + hash_ = "hash" + real_job_id = f"{job_id}_{hash_}" + + mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=False) + mock_hook.return_value.insert_job.return_value.running.return_value = False + + ti = create_task_instance_of_operator( + BigQueryValueCheckOperator, + dag_id="dag_id", + task_id="check_value", + sql="SELECT COUNT(*) FROM Any", + pass_value=2, + use_legacy_sql=True, + deferrable=True, + ) + + ti.task.execute(MagicMock()) + assert not mock_defer.called + assert mock_execute.called + @pytest.mark.parametrize( "kwargs, expected", [