Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add any additional claims to AuthenticationRequiredError #17136

Merged
merged 3 commits into from
Mar 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions sdk/identity/azure-identity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
- Renamed `CertificateCredential` keyword argument `certificate_bytes` to
`certificate_data`

### Added
- The `AuthenticationRequiredError.claims` property provides any additional
claims required by a user credential's `authenticate()` method

## 1.6.0b1 (2021-02-09)
### Changed
- Raised minimum msal version to 1.7.0
Expand Down
19 changes: 16 additions & 3 deletions sdk/identity/azure-identity/azure/identity/_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,17 @@ class CredentialUnavailableError(ClientAuthenticationError):


class AuthenticationRequiredError(CredentialUnavailableError):
"""Interactive authentication is required to acquire a token."""
"""Interactive authentication is required to acquire a token.

def __init__(self, scopes, message=None, error_details=None, **kwargs):
# type: (Iterable[str], Optional[str], Optional[str], **Any) -> None
This error is raised only by interactive user credentials configured not to automatically prompt for user
interaction as needed. Its properties provide additional information that may be required to authenticate. The
control_interactive_prompts sample demonstrates handling this error by calling a credential's "authenticate"
method.
"""

def __init__(self, scopes, message=None, error_details=None, claims=None, **kwargs):
# type: (Iterable[str], Optional[str], Optional[str], Optional[str], **Any) -> None
self._claims = claims
self._scopes = scopes
self._error_details = error_details
if not message:
Expand All @@ -36,3 +43,9 @@ def error_details(self):
# type: () -> Optional[str]
"""Additional authentication error details from Azure Active Directory"""
return self._error_details

@property
def claims(self):
# type: () -> Optional[str]
"""Additional claims required in the next authentication"""
return self._claims
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,8 @@ def authenticate(self, **kwargs):
:keyword Iterable[str] scopes: scopes to request during authentication, such as those provided by
:func:`AuthenticationRequiredError.scopes`. If provided, successful authentication will cache an access token
for these scopes.
:keyword str claims: additional claims required in the token, such as those provided by
:func:`AuthenticationRequiredError.claims`
:rtype: ~azure.identity.AuthenticationRecord
:raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message``
attribute gives a reason.
Expand All @@ -182,24 +184,23 @@ def authenticate(self, **kwargs):
def _acquire_token_silent(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
result = None
claims = kwargs.get("claims")
if self._auth_record:
app = self._get_app()
for account in app.get_accounts(username=self._auth_record.username):
if account.get("home_account_id") != self._auth_record.home_account_id:
continue

now = int(time.time())
result = app.acquire_token_silent_with_error(
list(scopes), account=account, claims_challenge=kwargs.get("claims")
)
result = app.acquire_token_silent_with_error(list(scopes), account=account, claims_challenge=claims)
if result and "access_token" in result and "expires_in" in result:
return AccessToken(result["access_token"], now + int(result["expires_in"]))

# if we get this far, result is either None or the content of an AAD error response
if result:
details = result.get("error_description") or result.get("error")
raise AuthenticationRequiredError(scopes, error_details=details)
raise AuthenticationRequiredError(scopes)
raise AuthenticationRequiredError(scopes, error_details=details, claims=claims)
raise AuthenticationRequiredError(scopes, claims=claims)

def _get_app(self):
# type: () -> msal.PublicClientApplication
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
print("This sample expects environment variable 'VAULT_URL' to be set with the URL of a Key Vault.")
sys.exit(1)


# If it's important for your application to prompt for authentication only at certain times,
# create the credential with disable_automatic_authentication=True. This configures the credential to raise
# when interactive authentication is required, instead of immediately beginning that authentication.
Expand All @@ -30,9 +29,9 @@
secret_names = [s.name for s in client.list_properties_of_secrets()]
except AuthenticationRequiredError as ex:
# Interactive authentication is necessary to authorize the client's request. The exception carries the
# requested authentication scopes. If you pass these to 'authenticate', it will cache an access token
# for those scopes.
credential.authenticate(scopes=ex.scopes)
# requested authentication scopes as well as any additional claims the service requires. If you pass
# both to 'authenticate', it will cache an access token for the necessary scopes.
credential.authenticate(scopes=ex.scopes, claims=ex.claims)

# the client operation should now succeed
secret_names = [s.name for s in client.list_properties_of_secrets()]
14 changes: 10 additions & 4 deletions sdk/identity/azure-identity/tests/test_browser_credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,12 +229,11 @@ def test_cannot_bind_redirect_uri():


def test_claims_challenge():
"""get_token should pass any claims challenge to MSAL token acquisition APIs"""
"""get_token and authenticate should pass any claims challenge to MSAL token acquisition APIs"""

expected_claims = '{"access_token": {"essential": "true"}'

oauth_state = "..."
auth_code_response = {"code": "authorization-code", "state": [oauth_state]}
auth_code_response = {"code": "authorization-code", "state": ["..."]}
server_class = Mock(return_value=Mock(wait_for_redirect=lambda: auth_code_response))

msal_acquire_token_result = dict(
Expand All @@ -250,12 +249,19 @@ def test_claims_challenge():
msal_app.acquire_token_by_auth_code_flow.return_value = msal_acquire_token_result

with patch(WEBBROWSER_OPEN, lambda _: True):
credential.get_token("scope", claims=expected_claims)
credential.authenticate(scopes=["scope"], claims=expected_claims)

assert msal_app.acquire_token_by_auth_code_flow.call_count == 1
args, kwargs = msal_app.acquire_token_by_auth_code_flow.call_args
assert kwargs["claims_challenge"] == expected_claims

with patch(WEBBROWSER_OPEN, lambda _: True):
credential.get_token("scope", claims=expected_claims)

assert msal_app.acquire_token_by_auth_code_flow.call_count == 2
args, kwargs = msal_app.acquire_token_by_auth_code_flow.call_args
assert kwargs["claims_challenge"] == expected_claims

msal_app.get_accounts.return_value = [{"home_account_id": credential._auth_record.home_account_id}]
msal_app.acquire_token_silent_with_error.return_value = msal_acquire_token_result
credential.get_token("scope", claims=expected_claims)
Expand Down
11 changes: 9 additions & 2 deletions sdk/identity/azure-identity/tests/test_device_code_credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def test_client_capabilities():


def test_claims_challenge():
"""get_token should pass any claims challenge to MSAL token acquisition APIs"""
"""get_token and authenticate should pass any claims challenge to MSAL token acquisition APIs"""

msal_acquire_token_result = dict(
build_aad_response(access_token="**", id_token=build_id_token()),
Expand All @@ -292,12 +292,19 @@ def test_claims_challenge():
msal_app = get_mock_app()
msal_app.initiate_device_flow.return_value = {"message": "it worked"}
msal_app.acquire_token_by_device_flow.return_value = msal_acquire_token_result
credential.get_token("scope", claims=expected_claims)

credential.authenticate(scopes=["scope"], claims=expected_claims)

assert msal_app.acquire_token_by_device_flow.call_count == 1
args, kwargs = msal_app.acquire_token_by_device_flow.call_args
assert kwargs["claims_challenge"] == expected_claims

credential.get_token("scope", claims=expected_claims)

assert msal_app.acquire_token_by_device_flow.call_count == 2
args, kwargs = msal_app.acquire_token_by_device_flow.call_args
assert kwargs["claims_challenge"] == expected_claims

msal_app.get_accounts.return_value = [{"home_account_id": credential._auth_record.home_account_id}]
msal_app.acquire_token_silent_with_error.return_value = msal_acquire_token_result
credential.get_token("scope", claims=expected_claims)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,14 @@ def test_disable_automatic_authentication():
)

scope = "scope"
expected_claims = "..."
with pytest.raises(AuthenticationRequiredError) as ex:
credential.get_token(scope)
credential.get_token(scope, claims=expected_claims)

# the exception should carry the requested scopes and any error message from AAD
# the exception should carry the requested scopes and claims, and any error message from AAD
assert ex.value.scopes == (scope,)
assert ex.value.error_details == expected_details
assert ex.value.claims == expected_claims


def test_scopes_round_trip():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def test_client_capabilities():


def test_claims_challenge():
"""get_token should pass any claims challenge to MSAL token acquisition APIs"""
"""get_token should and authenticate pass any claims challenge to MSAL token acquisition APIs"""

msal_acquire_token_result = dict(
build_aad_response(access_token="**", id_token=build_id_token()),
Expand All @@ -176,12 +176,18 @@ def test_claims_challenge():
with patch.object(UsernamePasswordCredential, "_get_app") as get_mock_app:
msal_app = get_mock_app()
msal_app.acquire_token_by_username_password.return_value = msal_acquire_token_result
credential.get_token("scope", claims=expected_claims)

credential.authenticate(scopes=["scope"], claims=expected_claims)
assert msal_app.acquire_token_by_username_password.call_count == 1
args, kwargs = msal_app.acquire_token_by_username_password.call_args
assert kwargs["claims_challenge"] == expected_claims

credential.get_token("scope", claims=expected_claims)

assert msal_app.acquire_token_by_username_password.call_count == 2
args, kwargs = msal_app.acquire_token_by_username_password.call_args
assert kwargs["claims_challenge"] == expected_claims

msal_app.get_accounts.return_value = [{"home_account_id": credential._auth_record.home_account_id}]
msal_app.acquire_token_silent_with_error.return_value = msal_acquire_token_result
credential.get_token("scope", claims=expected_claims)
Expand Down