Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow setup endpoint_url per-service in AWS Connection #34593

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 47 additions & 23 deletions airflow/providers/amazon/aws/hooks/base_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,11 @@ def _refresh_credentials(self) -> dict[str, Any]:
if assume_role_method not in ("assume_role", "assume_role_with_saml"):
raise NotImplementedError(f"assume_role_method={assume_role_method} not expected")

sts_client = self.basic_session.client("sts", config=self.config)
sts_client = self.basic_session.client(
"sts",
config=self.config,
endpoint_url=self.conn.get_service_endpoint_url("sts", sts_connection_assume=True),
)

if assume_role_method == "assume_role":
sts_response = self._assume_role(sts_client=sts_client)
Expand Down Expand Up @@ -558,10 +562,33 @@ def conn_config(self) -> AwsConnectionWrapper:
conn=connection, region_name=self._region_name, botocore_config=self._config, verify=self._verify
)

def _resolve_service_name(self, is_resource_type: bool = False) -> str:
"""Resolve service name based on type or raise an error."""
if exactly_one(self.client_type, self.resource_type):
# It is possible to write simple conditions, however it make mypy unhappy.
if self.client_type:
if is_resource_type:
raise LookupError("Requested `resource_type`, but `client_type` was set instead.")
return self.client_type
elif self.resource_type:
if not is_resource_type:
raise LookupError("Requested `client_type`, but `resource_type` was set instead.")
return self.resource_type

raise ValueError(
f"Either client_type={self.client_type!r} or "
f"resource_type={self.resource_type!r} must be provided, not both."
)

@property
def service_name(self) -> str:
"""Extracted botocore/boto3 service name from hook parameters."""
return self._resolve_service_name(is_resource_type=bool(self.resource_type))

@property
def service_config(self) -> dict:
service_name = self.client_type or self.resource_type
return self.conn_config.get_service_config(service_name)
"""Config for hook-specific service from AWS Connection."""
return self.conn_config.get_service_config(service_name=self.service_name)

@property
def region_name(self) -> str | None:
Expand Down Expand Up @@ -609,19 +636,20 @@ def get_client_type(
deferrable: bool = False,
) -> boto3.client:
"""Get the underlying boto3 client using boto3 session."""
client_type = self.client_type
service_name = self._resolve_service_name(is_resource_type=False)
session = self.get_session(region_name=region_name, deferrable=deferrable)
endpoint_url = self.conn_config.get_service_endpoint_url(service_name=service_name)
if not isinstance(session, boto3.session.Session):
return session.create_client(
client_type,
endpoint_url=self.conn_config.endpoint_url,
service_name=service_name,
endpoint_url=endpoint_url,
config=self._get_config(config),
verify=self.verify,
)

return session.client(
client_type,
endpoint_url=self.conn_config.endpoint_url,
service_name=service_name,
endpoint_url=endpoint_url,
config=self._get_config(config),
verify=self.verify,
)
Expand All @@ -632,11 +660,11 @@ def get_resource_type(
config: Config | None = None,
) -> boto3.resource:
"""Get the underlying boto3 resource using boto3 session."""
resource_type = self.resource_type
service_name = self._resolve_service_name(is_resource_type=True)
session = self.get_session(region_name=region_name)
return session.resource(
resource_type,
endpoint_url=self.conn_config.endpoint_url,
service_name=service_name,
endpoint_url=self.conn_config.get_service_endpoint_url(service_name=service_name),
config=self._get_config(config),
verify=self.verify,
)
Expand All @@ -648,15 +676,9 @@ def conn(self) -> BaseAwsConnection:

:return: boto3.client or boto3.resource
"""
if not exactly_one(self.client_type, self.resource_type):
raise ValueError(
f"Either client_type={self.client_type!r} or "
f"resource_type={self.resource_type!r} must be provided, not both."
)
elif self.client_type:
if self.client_type:
return self.get_client_type(region_name=self.region_name)
else:
return self.get_resource_type(region_name=self.region_name)
return self.get_resource_type(region_name=self.region_name)

@property
def async_conn(self):
Expand Down Expand Up @@ -730,7 +752,10 @@ def expand_role(self, role: str, region_name: str | None = None) -> str:
else:
session = self.get_session(region_name=region_name)
_client = session.client(
"iam", endpoint_url=self.conn_config.endpoint_url, config=self.config, verify=self.verify
service_name="iam",
endpoint_url=self.conn_config.get_service_endpoint_url("iam"),
config=self.config,
verify=self.verify,
)
return _client.get_role(RoleName=role)["Role"]["Arn"]

Expand Down Expand Up @@ -799,10 +824,9 @@ def test_connection(self):
"""
try:
session = self.get_session()
test_endpoint_url = self.conn_config.extra_config.get("test_endpoint_url")
conn_info = session.client(
"sts",
endpoint_url=test_endpoint_url,
service_name="sts",
endpoint_url=self.conn_config.get_service_endpoint_url("sts", sts_test_connection=True),
).get_caller_identity()
metadata = conn_info.pop("ResponseMetadata", {})
if metadata.get("HTTPStatusCode") != 200:
Expand Down
42 changes: 40 additions & 2 deletions airflow/providers/amazon/aws/utils/connection_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,48 @@ class AwsConnectionWrapper(LoggingMixin):
assume_role_method: str | None = field(init=False, default=None)
assume_role_kwargs: dict[str, Any] = field(init=False, default_factory=dict)

# Per AWS Service configuration dictionary where key is name of boto3 ``service_name``
service_config: dict[str, dict[str, Any]] = field(init=False, default_factory=dict)

@cached_property
def conn_repr(self):
return f"AWS Connection (conn_id={self.conn_id!r}, conn_type={self.conn_type!r})"

def get_service_config(self, service_name):
return self.extra_dejson.get("service_config", {}).get(service_name, {})
def get_service_config(self, service_name: str) -> dict[str, Any]:
"""Get AWS Service related config dictionary.

:param service_name: Name of botocore/boto3 service.
"""
return self.service_config.get(service_name, {})

def get_service_endpoint_url(
self, service_name: str, *, sts_connection_assume: bool = False, sts_test_connection: bool = False
) -> str | None:
service_config = self.get_service_config(service_name=service_name)
global_endpoint_url = self.endpoint_url

if service_name == "sts" and True in (sts_connection_assume, sts_test_connection):
# There are different logics exists historically for STS Client
# 1. For assume role we never use global endpoint_url
# 2. For test connection we also use undocumented `test_endpoint`\
# 3. For STS as service we might use endpoint_url (default for other services)
global_endpoint_url = None
if sts_connection_assume and sts_test_connection:
raise AirflowException(
"Can't resolve STS endpoint when both "
"`sts_connection` and `sts_test_connection` set to True."
)
elif sts_test_connection:
if "test_endpoint_url" in self.extra_config:
warnings.warn(
"extra['test_endpoint_url'] is deprecated and will be removed in a future release."
" Please set `endpoint_url` in `service_config.sts` within `extras`.",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
global_endpoint_url = self.extra_config["test_endpoint_url"]

return service_config.get("endpoint_url", global_endpoint_url)

def __post_init__(self, conn: Connection):
if isinstance(conn, type(self)):
Expand Down Expand Up @@ -182,6 +218,8 @@ def __post_init__(self, conn: Connection):
)

extra = deepcopy(conn.extra_dejson)
self.service_config = extra.get("service_config", {})

session_kwargs = extra.get("session_kwargs", {})
if session_kwargs:
warnings.warn(
Expand Down
36 changes: 35 additions & 1 deletion docs/apache-airflow-providers-amazon/connections/aws.rst
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,9 @@ Extra (optional)
* ``config_kwargs``: Additional **kwargs** used to construct a
`botocore.config.Config <https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html>`__.
To anonymously access public AWS resources (equivalent of `signature_version=botocore.UNSGINED`), set `"signature_version"="unsigned"` within `config_kwargs`.
* ``endpoint_url``: Endpoint URL for the connection.
* ``endpoint_url``: Global Endpoint URL for the connection. You could specify endpoint url per AWS service by utilize
``service_config``, for more details please refer to :ref:`howto/connection:aws:per-service-endpoint-configuration`

* ``verify``: Whether or not to verify SSL certificates.

The following extra parameters used for specific AWS services:
Expand Down Expand Up @@ -343,6 +345,38 @@ The following settings may be used within the ``assume_role_with_saml`` containe
Per-service configuration
^^^^^^^^^^^^^^^^^^^^^^^^^

.. _howto/connection:aws:per-service-endpoint-configuration:

AWS Service Endpoint URL configuration
""""""""""""""""""""""""""""""""""""""

To use ``endpoint_url`` per specific AWS service in the single connection you might setup it in service config.
For enforce to default ``botocore``/``boto3`` behaviour you might set value to ``null``.
The precedence rules are as follows:

1. ``endpoint_url`` specified per service level.
2. ``endpoint_url`` specified in root level of connection extra. Please note that **sts** client which are
uses in assume role or test connection do not use global parameter.
3. Default ``botocore``/``boto3`` behaviour


.. code-block:: json

{
"endpoint_url": "s3.amazonaws.com"
"service_config": {
"s3": {
"endpoint_url": "https://s3.eu-west-1.amazonaws.com"
},
"sts": {
"endpoint_url": "https://sts.eu-west-2.amazonaws.com"
},
"ec2": {
"endpoint_url": null
}
}
}

S3 Bucket configurations
""""""""""""""""""""""""

Expand Down
Loading