Skip to content

Commit

Permalink
SharedTokenCacheCredential takes an optional AuthenticationRecord (#1…
Browse files Browse the repository at this point in the history
  • Loading branch information
chlowell authored Jun 5, 2020
1 parent b15cede commit 994c77d
Show file tree
Hide file tree
Showing 6 changed files with 258 additions and 14 deletions.
3 changes: 3 additions & 0 deletions sdk/identity/azure-identity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
the keyword argument `interactive_browser_tenant_id`, or set the environment
variable `AZURE_TENANT_ID`.
([#11548](https://github.com/Azure/azure-sdk-for-python/issues/11548))
- `SharedTokenCacheCredential` can be initialized with an `AuthenticationRecord`
provided by a user credential.
([#11448](https://github.com/Azure/azure-sdk-for-python/issues/11448))
- The user authentication API added to `DeviceCodeCredential` and
`InteractiveBrowserCredential` in 1.4.0b3 is available on
`UsernamePasswordCredential` as well.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any, Mapping
from azure.core.credentials import AccessToken
from typing import Any
from .._internal import AadClientBase


Expand All @@ -31,6 +30,8 @@ class SharedTokenCacheCredential(SharedTokenCacheBase):
defines authorities for other clouds.
:keyword str tenant_id: an Azure Active Directory tenant ID. Used to select an account when the cache contains
tokens for multiple identities.
:keyword AuthenticationRecord authentication_record: an authentication record returned by a user credential such as
:class:`DeviceCodeCredential` or :class:`InteractiveBrowserCredential`
"""

@wrap_exceptions
Expand Down Expand Up @@ -67,4 +68,4 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument

def _get_auth_client(self, **kwargs):
# type: (**Any) -> AadClientBase
return AadClient(tenant_id="common", client_id=AZURE_CLI_CLIENT_ID, **kwargs)
return AadClient(client_id=AZURE_CLI_CLIENT_ID, **kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
# pylint:disable=unused-import,ungrouped-imports
from typing import Any, Iterable, List, Mapping, Optional
from .._internal import AadClientBase
from azure.identity import AuthenticationRecord

CacheItem = Mapping[str, str]

Expand Down Expand Up @@ -86,13 +87,22 @@ class SharedTokenCacheBase(ABC):
def __init__(self, username=None, **kwargs): # pylint:disable=unused-argument
# type: (Optional[str], **Any) -> None

authority = kwargs.pop("authority", None)
self._authority = normalize_authority(authority) if authority else get_default_authority()

environment = urlparse(self._authority).netloc
self._environment_aliases = KNOWN_ALIASES.get(environment) or frozenset((environment,))
self._username = username
self._tenant_id = kwargs.pop("tenant_id", None)
self._auth_record = kwargs.pop("authentication_record", None) # type: Optional[AuthenticationRecord]
if self._auth_record:
# authenticate in the tenant that produced the record unless 'tenant_id' specifies another
authenticating_tenant = kwargs.pop("tenant_id", None) or self._auth_record.tenant_id
self._tenant_id = self._auth_record.tenant_id
self._authority = self._auth_record.authority
self._username = self._auth_record.username
self._environment_aliases = frozenset((self._authority,))
else:
authenticating_tenant = "organizations"
authority = kwargs.pop("authority", None)
self._authority = normalize_authority(authority) if authority else get_default_authority()
environment = urlparse(self._authority).netloc
self._environment_aliases = KNOWN_ALIASES.get(environment) or frozenset((environment,))
self._username = username
self._tenant_id = kwargs.pop("tenant_id", None)

cache = kwargs.pop("_cache", None) # for ease of testing

Expand All @@ -110,7 +120,7 @@ def __init__(self, username=None, **kwargs): # pylint:disable=unused-argument
if cache:
self._cache = cache
self._client = self._get_auth_client(
authority=self._authority, cache=cache, **kwargs
authority=self._authority, tenant_id=authenticating_tenant, cache=cache, **kwargs
) # type: Optional[AadClientBase]
else:
self._client = None
Expand Down Expand Up @@ -161,6 +171,14 @@ def _get_account(self, username=None, tenant_id=None):
# cache is empty or contains no refresh token -> user needs to sign in
raise CredentialUnavailableError(message=NO_ACCOUNTS)

if self._auth_record:
for account in accounts:
if account.get("home_account_id") == self._auth_record.home_account_id:
return account
raise CredentialUnavailableError(
message="The cache contains no account matching the given AuthenticationRecord."
)

filtered_accounts = _filtered_accounts(accounts, username, tenant_id)
if len(filtered_accounts) == 1:
return filtered_accounts[0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class SharedTokenCacheCredential(SharedTokenCacheBase, AsyncCredentialBase):
defines authorities for other clouds.
:keyword str tenant_id: an Azure Active Directory tenant ID. Used to select an account when the cache contains
tokens for multiple identities.
:keyword AuthenticationRecord authentication_record: an authentication record returned by a user credential such as
:class:`DeviceCodeCredential` or :class:`InteractiveBrowserCredential`
"""

async def __aenter__(self):
Expand Down Expand Up @@ -74,4 +76,4 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # py
raise CredentialUnavailableError(message=NO_TOKEN.format(account.get("username")))

def _get_auth_client(self, **kwargs: "Any") -> "AadClientBase":
return AadClient(tenant_id="common", client_id=AZURE_CLI_CLIENT_ID, **kwargs)
return AadClient(client_id=AZURE_CLI_CLIENT_ID, **kwargs)
112 changes: 111 additions & 1 deletion sdk/identity/azure-identity/tests/test_shared_cache_credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
# ------------------------------------
from azure.core.exceptions import ClientAuthenticationError
from azure.core.pipeline.policies import SansIOHTTPPolicy
from azure.identity import CredentialUnavailableError, KnownAuthorities, SharedTokenCacheCredential
from azure.identity import (
AuthenticationRecord,
CredentialUnavailableError,
SharedTokenCacheCredential,
)
from azure.identity._constants import EnvironmentVariables
from azure.identity._internal.shared_token_cache import (
KNOWN_ALIASES,
Expand Down Expand Up @@ -502,6 +506,112 @@ def test_authority_environment_variable():
assert token.token == expected_access_token


def test_authentication_record_empty_cache():
record = AuthenticationRecord("tenant_id", "client_id", "authority", "home_account_id", "username")
transport = Mock(side_effect=Exception("the credential shouldn't send a request"))
credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=TokenCache())

with pytest.raises(CredentialUnavailableError):
credential.get_token("scope")


def test_authentication_record_no_match():
tenant_id = "tenant-id"
client_id = "client-id"
authority = "localhost"
object_id = "object-id"
home_account_id = object_id + "." + tenant_id
username = "me"
record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username)

transport = Mock(side_effect=Exception("the credential shouldn't send a request"))
cache = populated_cache(
get_account_event(
"not-" + username, "not-" + object_id, "different-" + tenant_id, client_id="not-" + client_id,
),
)
credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache)

with pytest.raises(CredentialUnavailableError):
credential.get_token("scope")


def test_authentication_record():
tenant_id = "tenant-id"
client_id = "client-id"
authority = "localhost"
object_id = "object-id"
home_account_id = object_id + "." + tenant_id
username = "me"
record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username)

expected_access_token = "****"
expected_refresh_token = "**"
account = get_account_event(
username, object_id, tenant_id, authority=authority, client_id=client_id, refresh_token=expected_refresh_token
)
cache = populated_cache(account)

transport = validating_transport(
requests=[Request(authority=authority, required_data={"refresh_token": expected_refresh_token})],
responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))],
)
credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache)

token = credential.get_token("scope")
assert token.token == expected_access_token


def test_auth_record_multiple_accounts_for_username():
tenant_id = "tenant-id"
client_id = "client-id"
authority = "localhost"
object_id = "object-id"
home_account_id = object_id + "." + tenant_id
username = "me"
record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username)

expected_access_token = "****"
expected_refresh_token = "**"
expected_account = get_account_event(
username, object_id, tenant_id, authority=authority, client_id=client_id, refresh_token=expected_refresh_token
)
cache = populated_cache(
expected_account,
get_account_event( # this account matches all but the record's tenant
username,
object_id,
"different-" + tenant_id,
authority=authority,
client_id=client_id,
refresh_token="not-" + expected_refresh_token,
),
)

transport = validating_transport(
requests=[Request(authority=authority, required_data={"refresh_token": expected_refresh_token})],
responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))],
)
credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache)

token = credential.get_token("scope")
assert token.token == expected_access_token


def test_authentication_record_authenticating_tenant():
"""when given a record and 'tenant_id', the credential should authenticate in the latter"""

expected_tenant_id = "tenant-id"
record = AuthenticationRecord("not- " + expected_tenant_id, "...", "...", "...", "...")

with patch.object(SharedTokenCacheCredential, "_get_auth_client") as get_auth_client:
SharedTokenCacheCredential(authentication_record=record, _cache=TokenCache(), tenant_id=expected_tenant_id)

assert get_auth_client.call_count == 1
_, kwargs = get_auth_client.call_args
assert kwargs["tenant_id"] == expected_tenant_id


def get_account_event(
username, uid, utid, authority=None, client_id="client-id", refresh_token="refresh-token", scopes=None
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from azure.core.exceptions import ClientAuthenticationError
from azure.core.pipeline.policies import SansIOHTTPPolicy
from azure.identity import CredentialUnavailableError, KnownAuthorities
from azure.identity import AuthenticationRecord, CredentialUnavailableError
from azure.identity.aio import SharedTokenCacheCredential
from azure.identity._constants import EnvironmentVariables
from azure.identity._internal.shared_token_cache import (
Expand Down Expand Up @@ -566,3 +566,113 @@ async def test_authority_environment_variable():
credential = SharedTokenCacheCredential(transport=transport, _cache=cache)
token = await credential.get_token("scope")
assert token.token == expected_access_token


@pytest.mark.asyncio
async def test_authentication_record_empty_cache():
record = AuthenticationRecord("tenant_id", "client_id", "authority", "home_account_id", "username")
transport = Mock(side_effect=Exception("the credential shouldn't send a request"))
credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=TokenCache())

with pytest.raises(CredentialUnavailableError):
await credential.get_token("scope")


@pytest.mark.asyncio
async def test_authentication_record_no_match():
tenant_id = "tenant-id"
client_id = "client-id"
authority = "localhost"
object_id = "object-id"
home_account_id = object_id + "." + tenant_id
username = "me"
record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username)

transport = Mock(side_effect=Exception("the credential shouldn't send a request"))
cache = populated_cache(
get_account_event(
"not-" + username, "not-" + object_id, "different-" + tenant_id, client_id="not-" + client_id,
),
)
credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache)

with pytest.raises(CredentialUnavailableError):
await credential.get_token("scope")


@pytest.mark.asyncio
async def test_authentication_record():
tenant_id = "tenant-id"
client_id = "client-id"
authority = "localhost"
object_id = "object-id"
home_account_id = object_id + "." + tenant_id
username = "me"
record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username)

expected_access_token = "****"
expected_refresh_token = "**"
account = get_account_event(
username, object_id, tenant_id, authority=authority, client_id=client_id, refresh_token=expected_refresh_token
)
cache = populated_cache(account)

transport = async_validating_transport(
requests=[Request(authority=authority, required_data={"refresh_token": expected_refresh_token})],
responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))],
)
credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache)

token = await credential.get_token("scope")
assert token.token == expected_access_token


@pytest.mark.asyncio
async def test_auth_record_multiple_accounts_for_username():
tenant_id = "tenant-id"
client_id = "client-id"
authority = "localhost"
object_id = "object-id"
home_account_id = object_id + "." + tenant_id
username = "me"
record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username)

expected_access_token = "****"
expected_refresh_token = "**"
expected_account = get_account_event(
username, object_id, tenant_id, authority=authority, client_id=client_id, refresh_token=expected_refresh_token
)
cache = populated_cache(
expected_account,
get_account_event( # this account matches all but the record's tenant
username,
object_id,
"different-" + tenant_id,
authority=authority,
client_id=client_id,
refresh_token="not-" + expected_refresh_token,
),
)

transport = async_validating_transport(
requests=[Request(authority=authority, required_data={"refresh_token": expected_refresh_token})],
responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))],
)
credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache)

token = await credential.get_token("scope")
assert token.token == expected_access_token


def test_authentication_record_authenticating_tenant():
"""when given a record and 'tenant_id', the credential should authenticate in the latter"""

expected_tenant_id = "tenant-id"
record = AuthenticationRecord("not- " + expected_tenant_id, "...", "...", "...", "...")

with patch.object(SharedTokenCacheCredential, "_get_auth_client") as get_auth_client:
SharedTokenCacheCredential(authentication_record=record, _cache=TokenCache(), tenant_id=expected_tenant_id)

assert get_auth_client.call_count == 1
_, kwargs = get_auth_client.call_args
assert kwargs["tenant_id"] == expected_tenant_id

0 comments on commit 994c77d

Please sign in to comment.