diff --git a/sdk/identity/azure-identity/CHANGELOG.md b/sdk/identity/azure-identity/CHANGELOG.md index 46f0d12cf816..3227a86e4bfb 100644 --- a/sdk/identity/azure-identity/CHANGELOG.md +++ b/sdk/identity/azure-identity/CHANGELOG.md @@ -29,6 +29,9 @@ the keyword argument `interactive_browser_tenant_id`, or set the environment variable `AZURE_TENANT_ID`. ([#11548](https://github.com/Azure/azure-sdk-for-python/issues/11548)) +- `SharedTokenCacheCredential` can be initialized with an `AuthenticationRecord` + provided by a user credential. + ([#11448](https://github.com/Azure/azure-sdk-for-python/issues/11448)) - The user authentication API added to `DeviceCodeCredential` and `InteractiveBrowserCredential` in 1.4.0b3 is available on `UsernamePasswordCredential` as well. diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py b/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py index b748b700e5bf..f48ba99bdd56 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py @@ -14,8 +14,7 @@ if TYPE_CHECKING: # pylint:disable=unused-import,ungrouped-imports - from typing import Any, Mapping - from azure.core.credentials import AccessToken + from typing import Any from .._internal import AadClientBase @@ -31,6 +30,8 @@ class SharedTokenCacheCredential(SharedTokenCacheBase): defines authorities for other clouds. :keyword str tenant_id: an Azure Active Directory tenant ID. Used to select an account when the cache contains tokens for multiple identities. + :keyword AuthenticationRecord authentication_record: an authentication record returned by a user credential such as + :class:`DeviceCodeCredential` or :class:`InteractiveBrowserCredential` """ @wrap_exceptions @@ -67,4 +68,4 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument def _get_auth_client(self, **kwargs): # type: (**Any) -> AadClientBase - return AadClient(tenant_id="common", client_id=AZURE_CLI_CLIENT_ID, **kwargs) + return AadClient(client_id=AZURE_CLI_CLIENT_ID, **kwargs) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py b/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py index 6fba6c09b986..045f3e637072 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py @@ -27,6 +27,7 @@ # pylint:disable=unused-import,ungrouped-imports from typing import Any, Iterable, List, Mapping, Optional from .._internal import AadClientBase + from azure.identity import AuthenticationRecord CacheItem = Mapping[str, str] @@ -86,13 +87,22 @@ class SharedTokenCacheBase(ABC): def __init__(self, username=None, **kwargs): # pylint:disable=unused-argument # type: (Optional[str], **Any) -> None - authority = kwargs.pop("authority", None) - self._authority = normalize_authority(authority) if authority else get_default_authority() - - environment = urlparse(self._authority).netloc - self._environment_aliases = KNOWN_ALIASES.get(environment) or frozenset((environment,)) - self._username = username - self._tenant_id = kwargs.pop("tenant_id", None) + self._auth_record = kwargs.pop("authentication_record", None) # type: Optional[AuthenticationRecord] + if self._auth_record: + # authenticate in the tenant that produced the record unless 'tenant_id' specifies another + authenticating_tenant = kwargs.pop("tenant_id", None) or self._auth_record.tenant_id + self._tenant_id = self._auth_record.tenant_id + self._authority = self._auth_record.authority + self._username = self._auth_record.username + self._environment_aliases = frozenset((self._authority,)) + else: + authenticating_tenant = "organizations" + authority = kwargs.pop("authority", None) + self._authority = normalize_authority(authority) if authority else get_default_authority() + environment = urlparse(self._authority).netloc + self._environment_aliases = KNOWN_ALIASES.get(environment) or frozenset((environment,)) + self._username = username + self._tenant_id = kwargs.pop("tenant_id", None) cache = kwargs.pop("_cache", None) # for ease of testing @@ -110,7 +120,7 @@ def __init__(self, username=None, **kwargs): # pylint:disable=unused-argument if cache: self._cache = cache self._client = self._get_auth_client( - authority=self._authority, cache=cache, **kwargs + authority=self._authority, tenant_id=authenticating_tenant, cache=cache, **kwargs ) # type: Optional[AadClientBase] else: self._client = None @@ -161,6 +171,14 @@ def _get_account(self, username=None, tenant_id=None): # cache is empty or contains no refresh token -> user needs to sign in raise CredentialUnavailableError(message=NO_ACCOUNTS) + if self._auth_record: + for account in accounts: + if account.get("home_account_id") == self._auth_record.home_account_id: + return account + raise CredentialUnavailableError( + message="The cache contains no account matching the given AuthenticationRecord." + ) + filtered_accounts = _filtered_accounts(accounts, username, tenant_id) if len(filtered_accounts) == 1: return filtered_accounts[0] diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py index ec021eba460d..c3ad6db261d9 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py @@ -29,6 +29,8 @@ class SharedTokenCacheCredential(SharedTokenCacheBase, AsyncCredentialBase): defines authorities for other clouds. :keyword str tenant_id: an Azure Active Directory tenant ID. Used to select an account when the cache contains tokens for multiple identities. + :keyword AuthenticationRecord authentication_record: an authentication record returned by a user credential such as + :class:`DeviceCodeCredential` or :class:`InteractiveBrowserCredential` """ async def __aenter__(self): @@ -74,4 +76,4 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # py raise CredentialUnavailableError(message=NO_TOKEN.format(account.get("username"))) def _get_auth_client(self, **kwargs: "Any") -> "AadClientBase": - return AadClient(tenant_id="common", client_id=AZURE_CLI_CLIENT_ID, **kwargs) + return AadClient(client_id=AZURE_CLI_CLIENT_ID, **kwargs) diff --git a/sdk/identity/azure-identity/tests/test_shared_cache_credential.py b/sdk/identity/azure-identity/tests/test_shared_cache_credential.py index daf5a73dd503..9d8cba6ca757 100644 --- a/sdk/identity/azure-identity/tests/test_shared_cache_credential.py +++ b/sdk/identity/azure-identity/tests/test_shared_cache_credential.py @@ -4,7 +4,11 @@ # ------------------------------------ from azure.core.exceptions import ClientAuthenticationError from azure.core.pipeline.policies import SansIOHTTPPolicy -from azure.identity import CredentialUnavailableError, KnownAuthorities, SharedTokenCacheCredential +from azure.identity import ( + AuthenticationRecord, + CredentialUnavailableError, + SharedTokenCacheCredential, +) from azure.identity._constants import EnvironmentVariables from azure.identity._internal.shared_token_cache import ( KNOWN_ALIASES, @@ -502,6 +506,112 @@ def test_authority_environment_variable(): assert token.token == expected_access_token +def test_authentication_record_empty_cache(): + record = AuthenticationRecord("tenant_id", "client_id", "authority", "home_account_id", "username") + transport = Mock(side_effect=Exception("the credential shouldn't send a request")) + credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=TokenCache()) + + with pytest.raises(CredentialUnavailableError): + credential.get_token("scope") + + +def test_authentication_record_no_match(): + tenant_id = "tenant-id" + client_id = "client-id" + authority = "localhost" + object_id = "object-id" + home_account_id = object_id + "." + tenant_id + username = "me" + record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username) + + transport = Mock(side_effect=Exception("the credential shouldn't send a request")) + cache = populated_cache( + get_account_event( + "not-" + username, "not-" + object_id, "different-" + tenant_id, client_id="not-" + client_id, + ), + ) + credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache) + + with pytest.raises(CredentialUnavailableError): + credential.get_token("scope") + + +def test_authentication_record(): + tenant_id = "tenant-id" + client_id = "client-id" + authority = "localhost" + object_id = "object-id" + home_account_id = object_id + "." + tenant_id + username = "me" + record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username) + + expected_access_token = "****" + expected_refresh_token = "**" + account = get_account_event( + username, object_id, tenant_id, authority=authority, client_id=client_id, refresh_token=expected_refresh_token + ) + cache = populated_cache(account) + + transport = validating_transport( + requests=[Request(authority=authority, required_data={"refresh_token": expected_refresh_token})], + responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))], + ) + credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache) + + token = credential.get_token("scope") + assert token.token == expected_access_token + + +def test_auth_record_multiple_accounts_for_username(): + tenant_id = "tenant-id" + client_id = "client-id" + authority = "localhost" + object_id = "object-id" + home_account_id = object_id + "." + tenant_id + username = "me" + record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username) + + expected_access_token = "****" + expected_refresh_token = "**" + expected_account = get_account_event( + username, object_id, tenant_id, authority=authority, client_id=client_id, refresh_token=expected_refresh_token + ) + cache = populated_cache( + expected_account, + get_account_event( # this account matches all but the record's tenant + username, + object_id, + "different-" + tenant_id, + authority=authority, + client_id=client_id, + refresh_token="not-" + expected_refresh_token, + ), + ) + + transport = validating_transport( + requests=[Request(authority=authority, required_data={"refresh_token": expected_refresh_token})], + responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))], + ) + credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache) + + token = credential.get_token("scope") + assert token.token == expected_access_token + + +def test_authentication_record_authenticating_tenant(): + """when given a record and 'tenant_id', the credential should authenticate in the latter""" + + expected_tenant_id = "tenant-id" + record = AuthenticationRecord("not- " + expected_tenant_id, "...", "...", "...", "...") + + with patch.object(SharedTokenCacheCredential, "_get_auth_client") as get_auth_client: + SharedTokenCacheCredential(authentication_record=record, _cache=TokenCache(), tenant_id=expected_tenant_id) + + assert get_auth_client.call_count == 1 + _, kwargs = get_auth_client.call_args + assert kwargs["tenant_id"] == expected_tenant_id + + def get_account_event( username, uid, utid, authority=None, client_id="client-id", refresh_token="refresh-token", scopes=None ): diff --git a/sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py b/sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py index c718eaf3481c..4552b4e8c9f3 100644 --- a/sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py @@ -7,7 +7,7 @@ from azure.core.exceptions import ClientAuthenticationError from azure.core.pipeline.policies import SansIOHTTPPolicy -from azure.identity import CredentialUnavailableError, KnownAuthorities +from azure.identity import AuthenticationRecord, CredentialUnavailableError from azure.identity.aio import SharedTokenCacheCredential from azure.identity._constants import EnvironmentVariables from azure.identity._internal.shared_token_cache import ( @@ -566,3 +566,113 @@ async def test_authority_environment_variable(): credential = SharedTokenCacheCredential(transport=transport, _cache=cache) token = await credential.get_token("scope") assert token.token == expected_access_token + + +@pytest.mark.asyncio +async def test_authentication_record_empty_cache(): + record = AuthenticationRecord("tenant_id", "client_id", "authority", "home_account_id", "username") + transport = Mock(side_effect=Exception("the credential shouldn't send a request")) + credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=TokenCache()) + + with pytest.raises(CredentialUnavailableError): + await credential.get_token("scope") + + +@pytest.mark.asyncio +async def test_authentication_record_no_match(): + tenant_id = "tenant-id" + client_id = "client-id" + authority = "localhost" + object_id = "object-id" + home_account_id = object_id + "." + tenant_id + username = "me" + record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username) + + transport = Mock(side_effect=Exception("the credential shouldn't send a request")) + cache = populated_cache( + get_account_event( + "not-" + username, "not-" + object_id, "different-" + tenant_id, client_id="not-" + client_id, + ), + ) + credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache) + + with pytest.raises(CredentialUnavailableError): + await credential.get_token("scope") + + +@pytest.mark.asyncio +async def test_authentication_record(): + tenant_id = "tenant-id" + client_id = "client-id" + authority = "localhost" + object_id = "object-id" + home_account_id = object_id + "." + tenant_id + username = "me" + record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username) + + expected_access_token = "****" + expected_refresh_token = "**" + account = get_account_event( + username, object_id, tenant_id, authority=authority, client_id=client_id, refresh_token=expected_refresh_token + ) + cache = populated_cache(account) + + transport = async_validating_transport( + requests=[Request(authority=authority, required_data={"refresh_token": expected_refresh_token})], + responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))], + ) + credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache) + + token = await credential.get_token("scope") + assert token.token == expected_access_token + + +@pytest.mark.asyncio +async def test_auth_record_multiple_accounts_for_username(): + tenant_id = "tenant-id" + client_id = "client-id" + authority = "localhost" + object_id = "object-id" + home_account_id = object_id + "." + tenant_id + username = "me" + record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username) + + expected_access_token = "****" + expected_refresh_token = "**" + expected_account = get_account_event( + username, object_id, tenant_id, authority=authority, client_id=client_id, refresh_token=expected_refresh_token + ) + cache = populated_cache( + expected_account, + get_account_event( # this account matches all but the record's tenant + username, + object_id, + "different-" + tenant_id, + authority=authority, + client_id=client_id, + refresh_token="not-" + expected_refresh_token, + ), + ) + + transport = async_validating_transport( + requests=[Request(authority=authority, required_data={"refresh_token": expected_refresh_token})], + responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))], + ) + credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache) + + token = await credential.get_token("scope") + assert token.token == expected_access_token + + +def test_authentication_record_authenticating_tenant(): + """when given a record and 'tenant_id', the credential should authenticate in the latter""" + + expected_tenant_id = "tenant-id" + record = AuthenticationRecord("not- " + expected_tenant_id, "...", "...", "...", "...") + + with patch.object(SharedTokenCacheCredential, "_get_auth_client") as get_auth_client: + SharedTokenCacheCredential(authentication_record=record, _cache=TokenCache(), tenant_id=expected_tenant_id) + + assert get_auth_client.call_count == 1 + _, kwargs = get_auth_client.call_args + assert kwargs["tenant_id"] == expected_tenant_id