Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor DataFusionInstanceLink usage #34514

Merged
merged 1 commit into from
Oct 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions airflow/providers/google/cloud/operators/datafusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from googleapiclient.errors import HttpError

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.google.cloud.hooks.datafusion import SUCCESS_STATES, DataFusionHook, PipelineStates
from airflow.providers.google.cloud.links.datafusion import (
DataFusionInstanceLink,
Expand All @@ -34,16 +34,25 @@
from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
from airflow.providers.google.cloud.triggers.datafusion import DataFusionStartPipelineTrigger
from airflow.providers.google.cloud.utils.datafusion import DataFusionPipelineType
from airflow.providers.google.cloud.utils.helpers import resource_path_to_dict

if TYPE_CHECKING:
from airflow.utils.context import Context


class DataFusionPipelineLinkHelper:
moiseenkov marked this conversation as resolved.
Show resolved Hide resolved
"""Helper class for Pipeline links."""
"""
Helper class for Pipeline links.

.. warning::
This class is deprecated. Consider using ``resource_path_to_dict()`` instead.
"""

@staticmethod
def get_project_id(instance):
    """
    Extract the project ID from a Data Fusion instance resource.

    .. warning::
        This helper is deprecated. Consider using ``resource_path_to_dict()`` instead.

    :param instance: Instance resource mapping; its ``"name"`` entry is read as a
        ``/``-separated resource path.
    :return: The first path segment that starts with ``"airflow"``.
    :raises StopIteration: If no path segment starts with ``"airflow"``.
    """
    # Imported locally so this method does not rely on a module-level ``import warnings``.
    import warnings

    # Emit the deprecation as a warning. Raising the warning class as an exception made
    # every call fail and left the lookup below unreachable.
    warnings.warn(
        "DataFusionPipelineLinkHelper is deprecated. Consider using resource_path_to_dict() instead.",
        AirflowProviderDeprecationWarning,
        stacklevel=2,
    )
    resource_name = instance["name"]
    return next(segment for segment in resource_name.split("/") if segment.startswith("airflow"))
Expand Down Expand Up @@ -114,7 +123,7 @@ def execute(self, context: Context) -> None:
instance = hook.wait_for_operation(operation)
self.log.info("Instance %s restarted successfully", self.instance_name)

project_id = self.project_id or DataFusionPipelineLinkHelper.get_project_id(instance)
project_id = resource_path_to_dict(resource_name=instance["name"])["projects"]
DataFusionInstanceLink.persist(
context=context,
task_instance=self,
Expand Down Expand Up @@ -272,7 +281,7 @@ def execute(self, context: Context) -> dict:
instance_name=self.instance_name, location=self.location, project_id=self.project_id
)

project_id = self.project_id or DataFusionPipelineLinkHelper.get_project_id(instance)
project_id = resource_path_to_dict(resource_name=instance["name"])["projects"]
DataFusionInstanceLink.persist(
context=context,
task_instance=self,
Expand Down Expand Up @@ -361,7 +370,7 @@ def execute(self, context: Context) -> None:
instance = hook.wait_for_operation(operation)
self.log.info("Instance %s updated successfully", self.instance_name)

project_id = self.project_id or DataFusionPipelineLinkHelper.get_project_id(instance)
project_id = resource_path_to_dict(resource_name=instance["name"])["projects"]
DataFusionInstanceLink.persist(
context=context,
task_instance=self,
Expand Down Expand Up @@ -432,7 +441,7 @@ def execute(self, context: Context) -> dict:
project_id=self.project_id,
)

project_id = self.project_id or DataFusionPipelineLinkHelper.get_project_id(instance)
project_id = resource_path_to_dict(resource_name=instance["name"])["projects"]
DataFusionInstanceLink.persist(
context=context,
task_instance=self,
Expand Down
21 changes: 21 additions & 0 deletions airflow/providers/google/cloud/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,24 @@
def normalize_directory_path(source_object: str | None) -> str | None:
    """
    Ensure a directory path ends with a slash.

    :param source_object: Directory path, or ``None``.
    :return: ``source_object`` with ``/`` appended; the value is returned unchanged when
        it is ``None``, empty, or already ends with ``/``.
    """
    # Nothing to append for falsy values or paths that already carry the trailing slash.
    if not source_object or source_object.endswith("/"):
        return source_object
    return source_object + "/"


def resource_path_to_dict(resource_name: str) -> dict[str, str]:
    """
    Convert a path-like GCP resource name into a dictionary.

    For example, the path ``projects/my-project/locations/my-location/instances/my-instance``
    is converted to::

        {"projects": "my-project", "locations": "my-location", "instances": "my-instance"}

    :param resource_name: Resource name made of ``/``-separated ``key/value`` pairs.
    :return: Mapping of every key segment to the value segment that follows it; an empty
        dict when ``resource_name`` is empty.
    :raises ValueError: If the path has an odd number of segments or contains an empty
        segment (leading, trailing or doubled ``/``).
    """
    if not resource_name:
        return {}
    path_items = resource_name.split("/")
    # An empty segment would otherwise be stored as an empty-string key or value
    # (e.g. "/a/b/c" -> {"": "a", "b": "c"}), so reject it together with odd-length paths.
    if len(path_items) % 2 or not all(path_items):
        raise ValueError(
            "Invalid resource_name. Expected the path-like name consisting of key/value pairs "
            "'key1/value1/key2/value2/...', for example 'projects/<project>/locations/<location>'."
        )
    # Zipping one iterator with itself pairs up consecutive segments: (key1, value1), ...
    iterator = iter(path_items)
    return dict(zip(iterator, iterator))
17 changes: 13 additions & 4 deletions tests/providers/google/cloud/operators/test_datafusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from airflow.providers.google.cloud.utils.datafusion import DataFusionPipelineType

HOOK_STR = "airflow.providers.google.cloud.operators.datafusion.DataFusionHook"
RESOURCE_PATH_TO_DICT_STR = "airflow.providers.google.cloud.operators.datafusion.resource_path_to_dict"

TASK_ID = "test_task"
LOCATION = "test-location"
Expand All @@ -54,9 +55,11 @@


class TestCloudDataFusionUpdateInstanceOperator:
@mock.patch(RESOURCE_PATH_TO_DICT_STR)
@mock.patch(HOOK_STR)
def test_execute_check_hook_call_should_execute_successfully(self, mock_hook):
def test_execute_check_hook_call_should_execute_successfully(self, mock_hook, mock_resource_to_dict):
update_maks = "instance.name"
mock_resource_to_dict.return_value = {"projects": PROJECT_ID}
op = CloudDataFusionUpdateInstanceOperator(
task_id="test_tasks",
instance_name=INSTANCE_NAME,
Expand All @@ -78,8 +81,10 @@ def test_execute_check_hook_call_should_execute_successfully(self, mock_hook):


class TestCloudDataFusionRestartInstanceOperator:
@mock.patch(RESOURCE_PATH_TO_DICT_STR)
@mock.patch(HOOK_STR)
def test_execute_check_hook_call_should_execute_successfully(self, mock_hook):
def test_execute_check_hook_call_should_execute_successfully(self, mock_hook, mock_resource_path_to_dict):
mock_resource_path_to_dict.return_value = {"projects": PROJECT_ID}
op = CloudDataFusionRestartInstanceOperator(
task_id="test_tasks",
instance_name=INSTANCE_NAME,
Expand All @@ -95,8 +100,10 @@ def test_execute_check_hook_call_should_execute_successfully(self, mock_hook):


class TestCloudDataFusionCreateInstanceOperator:
@mock.patch(RESOURCE_PATH_TO_DICT_STR)
@mock.patch(HOOK_STR)
def test_execute_check_hook_call_should_execute_successfully(self, mock_hook):
def test_execute_check_hook_call_should_execute_successfully(self, mock_hook, mock_resource_path_to_dict):
mock_resource_path_to_dict.return_value = {"projects": PROJECT_ID}
op = CloudDataFusionCreateInstanceOperator(
task_id="test_tasks",
instance_name=INSTANCE_NAME,
Expand Down Expand Up @@ -133,8 +140,10 @@ def test_execute_check_hook_call_should_execute_successfully(self, mock_hook):


class TestCloudDataFusionGetInstanceOperator:
@mock.patch(RESOURCE_PATH_TO_DICT_STR)
@mock.patch(HOOK_STR)
def test_execute_check_hook_call_should_execute_successfully(self, mock_hook):
def test_execute_check_hook_call_should_execute_successfully(self, mock_hook, mock_resource_path_to_dict):
mock_resource_path_to_dict.return_value = {"projects": PROJECT_ID}
op = CloudDataFusionGetInstanceOperator(
task_id="test_tasks",
instance_name=INSTANCE_NAME,
Expand Down
19 changes: 18 additions & 1 deletion tests/providers/google/cloud/utils/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,28 @@
# under the License.
from __future__ import annotations

from airflow.providers.google.cloud.utils.helpers import normalize_directory_path
import pytest

from airflow.providers.google.cloud.utils.helpers import normalize_directory_path, resource_path_to_dict


class TestHelpers:
    """Tests for ``normalize_directory_path`` and ``resource_path_to_dict``."""

    def test_normalize_directory_path(self):
        """A slash is appended only when missing; ``None`` passes through."""
        for given, expected in (("dir_path", "dir_path/"), ("dir_path/", "dir_path/")):
            assert normalize_directory_path(given) == expected
        assert normalize_directory_path(None) is None

    def test_resource_path_to_dict(self):
        """Consecutive path segments are paired into key/value entries."""
        expected = {"key1": "value1", "key2": "value2"}
        actual = resource_path_to_dict(resource_name="key1/value1/key2/value2")
        assert set(actual.items()) == set(expected.items())

    def test_resource_path_to_dict_empty(self):
        """An empty resource name yields an empty dict."""
        assert resource_path_to_dict(resource_name="") == {}

    def test_resource_path_to_dict_fail(self):
        """An odd number of path segments is rejected."""
        odd_path = "key/value/key"
        with pytest.raises(ValueError):
            resource_path_to_dict(resource_name=odd_path)