Consolidate hook management in LivyOperator #34431

Merged: 2 commits, Sep 18, 2023
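This PR replaces the operator's hand-rolled `_livy_hook` caching and `get_hook()` accessor with a `functools.cached_property` named `hook`; `get_hook()` is kept only as a deprecated alias, and the tests drop their `_livy_hook` injection accordingly. A minimal sketch of the pattern, using a toy class rather than Airflow's actual operator:

```python
from functools import cached_property


class SubmitOperator:
    """Toy stand-in for LivyOperator; the class and the returned hook are illustrative only."""

    def __init__(self, conn_id: str) -> None:
        self._conn_id = conn_id

    @cached_property
    def hook(self) -> dict:
        # Built once on first access, then memoized on the instance,
        # replacing the manual `if self._livy_hook is None` caching.
        print(f"building hook for {self._conn_id}")
        return {"conn_id": self._conn_id}


op = SubmitOperator("livy_default")
assert op.hook is op.hook  # "building hook ..." is printed only once
```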
56 changes: 30 additions & 26 deletions airflow/providers/apache/livy/operators/livy.py
@@ -17,11 +17,14 @@
"""This module contains the Apache Livy operator."""
from __future__ import annotations

from functools import cached_property
from time import sleep
from typing import TYPE_CHECKING, Any, Sequence

from deprecated.classic import deprecated

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.models import BaseOperator
from airflow.providers.apache.livy.hooks.livy import BatchState, LivyHook
from airflow.providers.apache.livy.triggers.livy import LivyTrigger
@@ -119,41 +122,43 @@ def __init__(
self._extra_options = extra_options or {}
self._extra_headers = extra_headers or {}

self._livy_hook: LivyHook | None = None
self._batch_id: int | str
self.retry_args = retry_args
self.deferrable = deferrable

def get_hook(self) -> LivyHook:
@cached_property
def hook(self) -> LivyHook:
"""
Get valid hook.

:return: hook
:return: LivyHook
"""
if self._livy_hook is None or not isinstance(self._livy_hook, LivyHook):
self._livy_hook = LivyHook(
livy_conn_id=self._livy_conn_id,
extra_headers=self._extra_headers,
extra_options=self._extra_options,
auth_type=self._livy_conn_auth_type,
)
return self._livy_hook
return LivyHook(
livy_conn_id=self._livy_conn_id,
extra_headers=self._extra_headers,
extra_options=self._extra_options,
auth_type=self._livy_conn_auth_type,
)

@deprecated(reason="use `hook` property instead.", category=AirflowProviderDeprecationWarning)
def get_hook(self) -> LivyHook:
"""Get valid hook."""
return self.hook

def execute(self, context: Context) -> Any:
self._batch_id = self.get_hook().post_batch(**self.spark_params)
self._batch_id = self.hook.post_batch(**self.spark_params)
self.log.info("Generated batch-id is %s", self._batch_id)

# Wait for the job to complete
if not self.deferrable:
if self._polling_interval > 0:
self.poll_for_termination(self._batch_id)
context["ti"].xcom_push(key="app_id", value=self.get_hook().get_batch(self._batch_id)["appId"])
context["ti"].xcom_push(key="app_id", value=self.hook.get_batch(self._batch_id)["appId"])
return self._batch_id

hook = self.get_hook()
state = hook.get_batch_state(self._batch_id, retry_args=self.retry_args)
state = self.hook.get_batch_state(self._batch_id, retry_args=self.retry_args)
self.log.debug("Batch with id %s is in state: %s", self._batch_id, state.value)
if state not in hook.TERMINAL_STATES:
if state not in self.hook.TERMINAL_STATES:
self.defer(
timeout=self.execution_timeout,
trigger=LivyTrigger(
@@ -168,11 +173,11 @@ def execute(self, context: Context) -> Any:
)
else:
self.log.info("Batch with id %s terminated with state: %s", self._batch_id, state.value)
hook.dump_batch_logs(self._batch_id)
self.hook.dump_batch_logs(self._batch_id)
if state != BatchState.SUCCESS:
raise AirflowException(f"Batch {self._batch_id} did not succeed")

context["ti"].xcom_push(key="app_id", value=self.get_hook().get_batch(self._batch_id)["appId"])
context["ti"].xcom_push(key="app_id", value=self.hook.get_batch(self._batch_id)["appId"])
return self._batch_id

def poll_for_termination(self, batch_id: int | str) -> None:
@@ -181,14 +186,13 @@ def poll_for_termination(self, batch_id: int | str) -> None:

:param batch_id: id of the batch session to monitor.
"""
hook = self.get_hook()
state = hook.get_batch_state(batch_id, retry_args=self.retry_args)
while state not in hook.TERMINAL_STATES:
state = self.hook.get_batch_state(batch_id, retry_args=self.retry_args)
while state not in self.hook.TERMINAL_STATES:
self.log.debug("Batch with id %s is in state: %s", batch_id, state.value)
sleep(self._polling_interval)
state = hook.get_batch_state(batch_id, retry_args=self.retry_args)
state = self.hook.get_batch_state(batch_id, retry_args=self.retry_args)
self.log.info("Batch with id %s terminated with state: %s", batch_id, state.value)
hook.dump_batch_logs(batch_id)
self.hook.dump_batch_logs(batch_id)
if state != BatchState.SUCCESS:
raise AirflowException(f"Batch {batch_id} did not succeed")

@@ -198,7 +202,7 @@ def on_kill(self) -> None:
def kill(self) -> None:
"""Delete the current batch session."""
if self._batch_id is not None:
self.get_hook().delete_batch(self._batch_id)
self.hook.delete_batch(self._batch_id)

def execute_complete(self, context: Context, event: dict[str, Any]) -> Any:
"""
@@ -218,5 +222,5 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> Any:
self.task_id,
event["response"],
)
context["ti"].xcom_push(key="app_id", value=self.get_hook().get_batch(event["batch_id"])["appId"])
context["ti"].xcom_push(key="app_id", value=self.hook.get_batch(event["batch_id"])["appId"])
return event["batch_id"]
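Caller-facing usage does not change: a DAG still instantiates `LivyOperator` as before, and the deprecated `get_hook()` simply returns the cached `hook` while emitting an `AirflowProviderDeprecationWarning`. A hedged usage sketch, assuming a recent Airflow 2.x with this provider installed; the DAG id, connection, and `sparkapp` path are placeholders:

```python
from datetime import datetime

from airflow import DAG
from airflow.providers.apache.livy.operators.livy import LivyOperator

# Placeholder DAG and arguments; "sparkapp" and the default connection are illustrative.
with DAG(dag_id="livy_hook_example", start_date=datetime(2023, 1, 1), schedule=None):
    submit = LivyOperator(
        task_id="livy_example",
        file="sparkapp",        # Spark application to submit as a Livy batch
        polling_interval=1,     # poll every second until the batch reaches a terminal state
    )

# After this PR, both accessors resolve to the same cached LivyHook instance;
# get_hook() additionally emits an AirflowProviderDeprecationWarning.
assert submit.hook is submit.get_hook()
```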
28 changes: 4 additions & 24 deletions tests/providers/apache/livy/operators/test_livy.py
@@ -24,7 +24,7 @@
from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.models.dag import DAG
from airflow.providers.apache.livy.hooks.livy import BatchState, LivyHook
from airflow.providers.apache.livy.hooks.livy import BatchState
from airflow.providers.apache.livy.operators.livy import LivyOperator
from airflow.utils import db, timezone

@@ -63,7 +63,6 @@ def side_effect(_, retry_args):
mock_livy.side_effect = side_effect

task = LivyOperator(file="sparkapp", polling_interval=1, dag=self.dag, task_id="livy_example")
task._livy_hook = task.get_hook()
task.poll_for_termination(BATCH_ID)

mock_livy.assert_called_with(BATCH_ID, retry_args=None)
@@ -87,7 +86,6 @@ def side_effect(_, retry_args):
mock_livy.side_effect = side_effect

task = LivyOperator(file="sparkapp", polling_interval=1, dag=self.dag, task_id="livy_example")
task._livy_hook = task.get_hook()

with pytest.raises(AirflowException):
task.poll_for_termination(BATCH_ID)
@@ -147,14 +145,6 @@ def test_deletion(self, mock_get_batch, mock_post, mock_delete):

mock_delete.assert_called_once_with(BATCH_ID)

def test_injected_hook(self):
def_hook = LivyHook(livy_conn_id="livyunittest")

task = LivyOperator(file="sparkapp", dag=self.dag, task_id="livy_example")
task._livy_hook = def_hook

assert task.get_hook() == def_hook

@patch(
"airflow.providers.apache.livy.operators.livy.LivyHook.get_batch_state",
return_value=BatchState.SUCCESS,
@@ -171,7 +161,7 @@ def test_log_dump(self, mock_get_batch, mock_post, mock_get_logs, mock_get, capl
polling_interval=1,
)
caplog.clear()
with caplog.at_level(level=logging.INFO, logger=task.get_hook().log.name):
with caplog.at_level(level=logging.INFO, logger=task.hook.log.name):
task.execute(context=self.mock_context)

assert "first_line" in caplog.messages
@@ -200,7 +190,6 @@ def side_effect(_, retry_args):
task = LivyOperator(
file="sparkapp", polling_interval=1, dag=self.dag, task_id="livy_example", deferrable=True
)
task._livy_hook = task.get_hook()
task.poll_for_termination(BATCH_ID)

mock_livy.assert_called_with(BATCH_ID, retry_args=None)
@@ -226,7 +215,6 @@ def side_effect(_, retry_args):
task = LivyOperator(
file="sparkapp", polling_interval=1, dag=self.dag, task_id="livy_example", deferrable=True
)
task._livy_hook = task.get_hook()

with pytest.raises(AirflowException):
task.poll_for_termination(BATCH_ID)
@@ -287,7 +275,7 @@ def test_execution_with_extra_options_deferrable(
)

task.execute(context=self.mock_context)
assert task.get_hook().extra_options == extra_options
assert task.hook.extra_options == extra_options

@patch("airflow.providers.apache.livy.operators.livy.LivyHook.delete_batch")
@patch("airflow.providers.apache.livy.operators.livy.LivyHook.post_batch", return_value=BATCH_ID)
@@ -315,14 +303,6 @@ def test_deletion_deferrable(

mock_delete.assert_called_once_with(BATCH_ID)

def test_injected_hook_deferrable(self):
def_hook = LivyHook(livy_conn_id="livyunittest")

task = LivyOperator(file="sparkapp", dag=self.dag, task_id="livy_example", deferrable=True)
task._livy_hook = def_hook

assert task.get_hook() == def_hook

@patch(
"airflow.providers.apache.livy.operators.livy.LivyHook.get_batch_state",
return_value=BatchState.SUCCESS,
@@ -341,7 +321,7 @@ def test_log_dump_deferrable(self, mock_get_batch, mock_post, mock_get_logs, moc
)
caplog.clear()

with caplog.at_level(level=logging.INFO, logger=task.get_hook().log.name):
with caplog.at_level(level=logging.INFO, logger=task.hook.log.name):
task.execute(context=self.mock_context)

assert "first_line" in caplog.messages