Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ACR] Fixing credential_scopes kwarg #19664

Merged
merged 7 commits into from
Jul 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,11 @@ def __init__(self, endpoint, **kwargs): # pylint: disable=missing-client-constr
if not endpoint.startswith("https://") and not endpoint.startswith("http://"):
endpoint = "https://" + endpoint
self._endpoint = endpoint
self.credential_scope = "https://management.core.windows.net/.default"
self._client = ContainerRegistry(
credential=None,
url=endpoint,
sdk_moniker=USER_AGENT,
authentication_policy=ExchangeClientAuthenticationPolicy(),
credential_scopes=kwargs.pop("credential_scopes", self.credential_scope),
**kwargs
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(self, credential, endpoint, **kwargs):
super(ContainerRegistryChallengePolicy, self).__init__()
self._credential = credential
if self._credential is None:
self._exchange_client = AnonymousACRExchangeClient(endpoint)
self._exchange_client = AnonymousACRExchangeClient(endpoint, **kwargs)
else:
self._exchange_client = ACRExchangeClient(endpoint, self._credential, **kwargs)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ class ContainerRegistryBaseClient(object):
:param str endpoint: Azure Container Registry endpoint
:param credential: AAD Token for authenticating requests with Azure
:type credential: :class:`azure.identity.DefaultTokenCredential`
:keyword authentication_scope: URL for credential authentication if different from the default
:paramtype authentication_scope: str
:keyword credential_scopes: URL for credential authentication if different from the default
:paramtype credential_scopes: List[str]
"""

def __init__(self, endpoint, credential, **kwargs):
Expand All @@ -40,7 +40,6 @@ def __init__(self, endpoint, credential, **kwargs):
url=endpoint,
sdk_moniker=USER_AGENT,
authentication_policy=auth_policy,
credential_scopes=kwargs.get("credential_scopes", "https://management.core.windows.net/.default"),
**kwargs
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ def __init__(self, endpoint, credential=None, **kwargs):
:param str endpoint: An ACR endpoint
:param credential: The credential with which to authenticate
:type credential: :class:`~azure.core.credentials.TokenCredential`
:keyword authentication_scope: URL for credential authentication if different from the default
:paramtype authentication_scope: str
:keyword credential_scopes: URL for credential authentication if different from the default
:paramtype credential_scopes: List[str]
:returns: None
:raises: None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,12 @@ def __init__(self, endpoint, credential, **kwargs):
if not endpoint.startswith("https://") and not endpoint.startswith("http://"):
endpoint = "https://" + endpoint
self._endpoint = endpoint
self.credential_scope = kwargs.get("authentication_scope", "https://management.core.windows.net/.default")
self.credential_scopes = kwargs.get("credential_scopes", ["https://management.core.windows.net/.default"])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this be a private attr?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It probably should, but the ExchangeClient models are not exposed publicly

self._client = ContainerRegistry(
credential=credential,
url=endpoint,
sdk_moniker=USER_AGENT,
authentication_policy=ExchangeClientAuthenticationPolicy(),
credential_scopes=self.credential_scope,
**kwargs
)
self._credential = credential
Expand All @@ -74,7 +73,7 @@ def get_refresh_token(self, service, **kwargs):
def exchange_aad_token_for_refresh_token(self, service=None, **kwargs):
# type: (str, Dict[str, Any]) -> str
refresh_token = self._client.authentication.exchange_aad_access_token_for_acr_refresh_token(
service=service, access_token=self._credential.get_token(self.credential_scope).token, **kwargs
service=service, access_token=self._credential.get_token(*self.credential_scopes).token, **kwargs
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is credential_scopes an iterable (the name implies that it is)? If so, does the keyword paramtype need to be updated from str to something like List[str]?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it should be a List[str].

)
return refresh_token.refresh_token

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,11 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential
if not endpoint.startswith("https://") and not endpoint.startswith("http://"):
endpoint = "https://" + endpoint
self._endpoint = endpoint
self._credential_scope = "https://management.core.windows.net/.default"
self._client = ContainerRegistry(
credential=None,
url=endpoint,
sdk_moniker=USER_AGENT,
authentication_policy=ExchangeClientAuthenticationPolicy(),
credential_scopes=kwargs.pop("credential_scopes", self._credential_scope),
**kwargs
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ class ContainerRegistryBaseClient(object):
:type endpoint: str
:param credential: AAD Token for authenticating requests with Azure
:type credential: :class:`~azure.identity.DefaultTokenCredential`
:keyword authentication_scope: URL for credential authentication if different from the default
:paramtype authentication_scope: str
:keyword credential_scopes: URL for credential authentication if different from the default
:paramtype credential_scopes: List[str]
"""

def __init__(self, endpoint: str, credential: Optional["AsyncTokenCredential"] = None, **kwargs) -> None:
Expand All @@ -39,7 +39,6 @@ def __init__(self, endpoint: str, credential: Optional["AsyncTokenCredential"] =
url=endpoint,
sdk_moniker=USER_AGENT,
authentication_policy=auth_policy,
credential_scopes=kwargs.get("credential_scopes", "https://management.core.windows.net/.default"),
**kwargs
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ def __init__(self, endpoint: str, credential: Optional["AsyncTokenCredential"] =
:type endpoint: str
:param credential: The credential with which to authenticate
:type credential: :class:`~azure.core.credentials_async.AsyncTokenCredential`
:keyword authentication_scope: URL for credential authentication if different from the default
:paramtype authentication_scope: str
:keyword credential_scopes: URL for credential authentication if different from the default
:paramtype credential_scopes: List[str]
:returns: None
:raises: None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,12 @@ def __init__(self, endpoint: str, credential: "AsyncTokencredential", **kwargs:
if not endpoint.startswith("https://") and not endpoint.startswith("http://"):
endpoint = "https://" + endpoint
self._endpoint = endpoint
self._credential_scope = kwargs.get("authentication_scope", "https://management.core.windows.net/.default")
self.credential_scopes = kwargs.get("credential_scopes", ["https://management.core.windows.net/.default"])
self._client = ContainerRegistry(
credential=credential,
url=endpoint,
sdk_moniker=USER_AGENT,
authentication_policy=ExchangeClientAuthenticationPolicy(),
credential_scopes=self._credential_scope,
**kwargs
)
self._credential = credential
Expand All @@ -67,7 +66,7 @@ async def get_refresh_token(self, service: str, **kwargs: Dict[str, Any]) -> str
return self._refresh_token

async def exchange_aad_token_for_refresh_token(self, service: str = None, **kwargs: Dict[str, Any]) -> str:
token = await self._credential.get_token(self._credential_scope)
token = await self._credential.get_token(*self.credential_scopes)
refresh_token = await self._client.authentication.exchange_aad_access_token_for_acr_refresh_token(
service, token.token, **kwargs
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@
../../core/azure-core
../azure-mgmt-containerregistry
aiohttp>=3.0; python_version >= '3.5'
azure-identity
azure-identity
msrestazure>=0.4.11
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,20 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import logging
import os

from azure.containerregistry.aio import (
# ContainerRepository,
ContainerRegistryClient,
)

from azure.core.credentials import AccessToken
from azure.identity.aio import DefaultAzureCredential
from azure.identity.aio import DefaultAzureCredential, ClientSecretCredential
from azure.identity import AzureAuthorityHosts

from testcase import ContainerRegistryTestClass, get_authorization_scope, get_authority

from testcase import ContainerRegistryTestClass
logger = logging.getLogger()


class AsyncFakeTokenCredential(object):
Expand All @@ -30,25 +35,27 @@ class AsyncContainerRegistryTestClass(ContainerRegistryTestClass):
def __init__(self, method_name):
super(AsyncContainerRegistryTestClass, self).__init__(method_name)

def get_credential(self):
def get_credential(self, authority=None, **kwargs):
if self.is_live:
return DefaultAzureCredential()
if authority != AzureAuthorityHosts.AZURE_PUBLIC_CLOUD:
return ClientSecretCredential(
tenant_id=os.environ["CONTAINERREGISTRY_TENANT_ID"],
client_id=os.environ["CONTAINERREGISTRY_CLIENT_ID"],
client_secret=os.environ["CONTAINERREGISTRY_CLIENT_SECRET"],
authority=authority
)
return DefaultAzureCredential(**kwargs)
return AsyncFakeTokenCredential()

def create_registry_client(self, endpoint, **kwargs):
return ContainerRegistryClient(
endpoint=endpoint,
credential=self.get_credential(),
**kwargs,
)

def create_container_repository(self, endpoint, name, **kwargs):
return ContainerRepository(
endpoint=endpoint,
name=name,
credential=self.get_credential(),
**kwargs,
)
authority = get_authority(endpoint)
audience = kwargs.pop("audience", None)
if not audience:
audience = get_authorization_scope(authority)
credential = self.get_credential(authority=authority)
return ContainerRegistryClient(endpoint=endpoint, credential=credential, credential_scopes=audience, **kwargs)

def create_anon_client(self, endpoint, **kwargs):
return ContainerRegistryClient(endpoint=endpoint, credential=None, **kwargs)
authority = get_authority(endpoint)
audience = get_authorization_scope(authority)
return ContainerRegistryClient(endpoint=endpoint, credential=None, credential_scopes=audience, **kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
class TestContainerRegistryClient(ContainerRegistryTestClass):
@acr_preparer()
def test_list_repository_names(self, containerregistry_anonregistry_endpoint):
if not self.is_public_endpoint(containerregistry_anonregistry_endpoint):
pytest.skip("Not a public endpoint")

client = self.create_anon_client(containerregistry_anonregistry_endpoint)
assert client._credential is None

Expand All @@ -40,6 +43,9 @@ def test_list_repository_names(self, containerregistry_anonregistry_endpoint):

@acr_preparer()
def test_list_repository_names_by_page(self, containerregistry_anonregistry_endpoint):
if not self.is_public_endpoint(containerregistry_anonregistry_endpoint):
pytest.skip("Not a public endpoint")

client = self.create_anon_client(containerregistry_anonregistry_endpoint)
assert client._credential is None

Expand All @@ -63,6 +69,9 @@ def test_list_repository_names_by_page(self, containerregistry_anonregistry_endp

@acr_preparer()
def test_get_repository_properties(self, containerregistry_anonregistry_endpoint):
if not self.is_public_endpoint(containerregistry_anonregistry_endpoint):
pytest.skip("Not a public endpoint")

client = self.create_anon_client(containerregistry_anonregistry_endpoint)
assert client._credential is None

Expand All @@ -73,6 +82,9 @@ def test_get_repository_properties(self, containerregistry_anonregistry_endpoint

@acr_preparer()
def test_list_manifest_properties(self, containerregistry_anonregistry_endpoint):
if not self.is_public_endpoint(containerregistry_anonregistry_endpoint):
pytest.skip("Not a public endpoint")

client = self.create_anon_client(containerregistry_anonregistry_endpoint)
assert client._credential is None

Expand All @@ -84,6 +96,9 @@ def test_list_manifest_properties(self, containerregistry_anonregistry_endpoint)

@acr_preparer()
def test_get_manifest_properties(self, containerregistry_anonregistry_endpoint):
if not self.is_public_endpoint(containerregistry_anonregistry_endpoint):
pytest.skip("Not a public endpoint")

client = self.create_anon_client(containerregistry_anonregistry_endpoint)
assert client._credential is None

Expand All @@ -95,6 +110,9 @@ def test_get_manifest_properties(self, containerregistry_anonregistry_endpoint):

@acr_preparer()
def test_list_tag_properties(self, containerregistry_anonregistry_endpoint):
if not self.is_public_endpoint(containerregistry_anonregistry_endpoint):
pytest.skip("Not a public endpoint")

client = self.create_anon_client(containerregistry_anonregistry_endpoint)
assert client._credential is None

Expand All @@ -106,6 +124,9 @@ def test_list_tag_properties(self, containerregistry_anonregistry_endpoint):

@acr_preparer()
def test_delete_repository(self, containerregistry_anonregistry_endpoint):
if not self.is_public_endpoint(containerregistry_anonregistry_endpoint):
pytest.skip("Not a public endpoint")

client = self.create_anon_client(containerregistry_anonregistry_endpoint)
assert client._credential is None

Expand All @@ -114,6 +135,9 @@ def test_delete_repository(self, containerregistry_anonregistry_endpoint):

@acr_preparer()
def test_delete_tag(self, containerregistry_anonregistry_endpoint):
if not self.is_public_endpoint(containerregistry_anonregistry_endpoint):
pytest.skip("Not a public endpoint")

client = self.create_anon_client(containerregistry_anonregistry_endpoint)
assert client._credential is None

Expand All @@ -122,6 +146,9 @@ def test_delete_tag(self, containerregistry_anonregistry_endpoint):

@acr_preparer()
def test_delete_manifest(self, containerregistry_anonregistry_endpoint):
if not self.is_public_endpoint(containerregistry_anonregistry_endpoint):
pytest.skip("Not a public endpoint")

client = self.create_anon_client(containerregistry_anonregistry_endpoint)
assert client._credential is None

Expand All @@ -130,6 +157,9 @@ def test_delete_manifest(self, containerregistry_anonregistry_endpoint):

@acr_preparer()
def test_update_repository_properties(self, containerregistry_anonregistry_endpoint):
if not self.is_public_endpoint(containerregistry_anonregistry_endpoint):
pytest.skip("Not a public endpoint")

client = self.create_anon_client(containerregistry_anonregistry_endpoint)

properties = client.get_repository_properties(HELLO_WORLD)
Expand All @@ -139,6 +169,9 @@ def test_update_repository_properties(self, containerregistry_anonregistry_endpo

@acr_preparer()
def test_update_tag_properties(self, containerregistry_anonregistry_endpoint):
if not self.is_public_endpoint(containerregistry_anonregistry_endpoint):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can do a pytest.mark.skipif instead, even though you're looking at a self parameter

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I get a NameError: name 'self' is not defined if I do this, also the parameter is passed in by the preparer so it wouldn't be available either.

pytest.skip("Not a public endpoint")

client = self.create_anon_client(containerregistry_anonregistry_endpoint)

properties = client.get_tag_properties(HELLO_WORLD, "latest")
Expand All @@ -148,9 +181,12 @@ def test_update_tag_properties(self, containerregistry_anonregistry_endpoint):

@acr_preparer()
def test_update_manifest_properties(self, containerregistry_anonregistry_endpoint):
if not self.is_public_endpoint(containerregistry_anonregistry_endpoint):
pytest.skip("Not a public endpoint")

client = self.create_anon_client(containerregistry_anonregistry_endpoint)

properties = client.get_manifest_properties(HELLO_WORLD, "latest")

with pytest.raises(ClientAuthenticationError):
client.update_manifest_properties(HELLO_WORLD, "latest", properties, can_delete=True)
client.update_manifest_properties(HELLO_WORLD, "latest", properties, can_delete=True)
Loading