From e670d051c47638ec37ea51b88111e48724520aea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?McCoy=20Pati=C3=B1o?= Date: Mon, 8 Mar 2021 16:58:02 -0800 Subject: [PATCH] Thanks, Charles! --- .../azure/keyvault/keys/crypto/_client.py | 34 ++++++++++++++----- .../keys/crypto/_providers/local_provider.py | 2 +- .../azure/keyvault/keys/crypto/aio/_client.py | 33 +++++++++++++----- .../tests/test_crypto_client.py | 16 ++++----- .../tests/test_crypto_client_async.py | 16 ++++----- .../tests/test_examples_crypto.py | 6 ++-- 6 files changed, 65 insertions(+), 42 deletions(-) diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_client.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_client.py index b89364e9149c..06ffba9b00e0 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_client.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_client.py @@ -123,11 +123,18 @@ def __init__(self, key, credential, **kwargs): if not (self._jwk or self._key_id.version): raise ValueError("'key' must include a version") - self._local_provider = NoLocalCryptography() - self._initialized = False + if self._jwk: + try: + self._local_provider = get_local_cryptography_provider(self._key) + self._initialized = True + except Exception as ex: # pylint:disable=broad-except + raise ValueError("The provided jwk is not valid for local cryptography: {}".format(ex)) + else: + self._local_provider = NoLocalCryptography() + self._initialized = False - vault_url = "vault_url" if self._jwk else self._key_id.vault_url - super(CryptographyClient, self).__init__(vault_url=vault_url, credential=credential, **kwargs) + self._vault_url = "vault_url" if self._jwk else self._key_id.vault_url + super(CryptographyClient, self).__init__(vault_url=self._vault_url, credential=credential, **kwargs) @property def key_id(self): @@ -142,6 +149,17 @@ def key_id(self): return self._key_id.source_id return self._key.kid + @property + def vault_url(self): + # type: () -> Optional[str] + """The base vault URL of the client's key. + + This property may be None when a client is constructed with :func:`from_jwk`. + + :rtype: str + """ + return self._vault_url + @classmethod def from_jwk(cls, jwk): # type: (Union[JsonWebKey, dict]) -> CryptographyClient @@ -158,11 +176,9 @@ def from_jwk(cls, jwk): :language: python :dedent: 8 """ - if isinstance(jwk, JsonWebKey): - key = jwk - else: - key = JsonWebKey(**jwk) - return cls(key, object(), _jwk=True) + if not isinstance(jwk, JsonWebKey): + jwk = JsonWebKey(**jwk) + return cls(jwk, object(), _jwk=True) @distributed_trace def _initialize(self, **kwargs): diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_providers/local_provider.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_providers/local_provider.py index a018e92c1831..96355c30dba4 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_providers/local_provider.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_providers/local_provider.py @@ -28,7 +28,7 @@ class LocalCryptographyProvider(ABC): def __init__(self, key, **kwargs): # type: (JsonWebKey, **Any) -> None - self._allowed_ops = frozenset(key.key_ops) + self._allowed_ops = frozenset(key.key_ops or []) self._internal_key = self._get_internal_key(key) self._key = key self._key_id = kwargs.pop("_key_id", None) or key.kid diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/aio/_client.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/aio/_client.py index 8b9e47d77dfb..e73130da4f69 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/aio/_client.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/aio/_client.py @@ -78,11 +78,18 @@ def __init__(self, key: "Union[KeyVaultKey, str]", credential: "AsyncTokenCreden if not (self._jwk or self._key_id.version): raise ValueError("'key' must include a version") - self._local_provider = NoLocalCryptography() - self._initialized = False + if self._jwk: + try: + self._local_provider = get_local_cryptography_provider(self._key) + self._initialized = True + except Exception as ex: # pylint:disable=broad-except + raise ValueError("The provided jwk is not valid for local cryptography: {}".format(ex)) + else: + self._local_provider = NoLocalCryptography() + self._initialized = False - vault_url = "vault_url" if self._jwk else self._key_id.vault_url - super().__init__(vault_url=vault_url, credential=credential, **kwargs) + self._vault_url = "vault_url" if self._jwk else self._key_id.vault_url + super().__init__(vault_url=self._vault_url, credential=credential, **kwargs) @property def key_id(self) -> "Optional[str]": @@ -96,6 +103,16 @@ def key_id(self) -> "Optional[str]": return self._key_id.source_id return self._key.kid + @property + def vault_url(self) -> "Optional[str]": + """The base vault URL of the client's key. + + This property may be None when a client is constructed with :func:`from_jwk`. + + :rtype: str + """ + return self._vault_url + @classmethod def from_jwk(cls, jwk: "Union[JsonWebKey, dict]") -> "CryptographyClient": """Creates a client that can only perform cryptographic operations locally. @@ -111,11 +128,9 @@ def from_jwk(cls, jwk: "Union[JsonWebKey, dict]") -> "CryptographyClient": :language: python :dedent: 8 """ - if isinstance(jwk, JsonWebKey): - key = jwk - else: - key = JsonWebKey(**jwk) - return cls(key, object(), _jwk=True) + if not isinstance(jwk, JsonWebKey): + jwk = JsonWebKey(**jwk) + return cls(jwk, object(), _jwk=True) @distributed_trace_async async def _initialize(self, **kwargs): diff --git a/sdk/keyvault/azure-keyvault-keys/tests/test_crypto_client.py b/sdk/keyvault/azure-keyvault-keys/tests/test_crypto_client.py index 608de82acb24..8884de5b3c1c 100644 --- a/sdk/keyvault/azure-keyvault-keys/tests/test_crypto_client.py +++ b/sdk/keyvault/azure-keyvault-keys/tests/test_crypto_client.py @@ -581,20 +581,16 @@ def test_local_only_mode_no_service_calls(): """A local-only CryptographyClient shouldn't call the service if an operation can't be performed locally""" mock_client = mock.Mock() - jwk = JsonWebKey() + jwk = JsonWebKey(kty="RSA", key_ops=[], n=b"10011", e=b"10001") client = CryptographyClient.from_jwk(jwk=jwk) client._client = mock_client - supports_nothing = mock.Mock(supports=mock.Mock(return_value=False)) - with mock.patch( - CryptographyClient.__module__ + ".get_local_cryptography_provider", lambda *args, **kwargs: supports_nothing - ): - with pytest.raises(NotImplementedError): - client.decrypt(EncryptionAlgorithm.rsa_oaep, b"...") + with pytest.raises(NotImplementedError): + client.decrypt(EncryptionAlgorithm.rsa_oaep, b"...") assert mock_client.decrypt.call_count == 0 with pytest.raises(NotImplementedError): - client.encrypt(EncryptionAlgorithm.rsa_oaep, b"...") + client.encrypt(EncryptionAlgorithm.a256_gcm, b"...") assert mock_client.encrypt.call_count == 0 with pytest.raises(NotImplementedError): @@ -602,7 +598,7 @@ def test_local_only_mode_no_service_calls(): assert mock_client.sign.call_count == 0 with pytest.raises(NotImplementedError): - client.verify(SignatureAlgorithm.rs256, b"...", b"...") + client.verify(SignatureAlgorithm.es256, b"...", b"...") assert mock_client.verify.call_count == 0 with pytest.raises(NotImplementedError): @@ -610,7 +606,7 @@ def test_local_only_mode_no_service_calls(): assert mock_client.unwrap_key.call_count == 0 with pytest.raises(NotImplementedError): - client.wrap_key(KeyWrapAlgorithm.rsa_oaep, b"...") + client.wrap_key(KeyWrapAlgorithm.aes_256, b"...") assert mock_client.wrap_key.call_count == 0 diff --git a/sdk/keyvault/azure-keyvault-keys/tests/test_crypto_client_async.py b/sdk/keyvault/azure-keyvault-keys/tests/test_crypto_client_async.py index 5c00ff1b03ee..785094774e65 100644 --- a/sdk/keyvault/azure-keyvault-keys/tests/test_crypto_client_async.py +++ b/sdk/keyvault/azure-keyvault-keys/tests/test_crypto_client_async.py @@ -606,20 +606,16 @@ async def test_local_only_mode_no_service_calls(): """A local-only CryptographyClient shouldn't call the service if an operation can't be performed locally""" mock_client = mock.Mock() - jwk = JsonWebKey() + jwk = JsonWebKey(kty="RSA", key_ops=[], n=b"10011", e=b"10001") client = CryptographyClient.from_jwk(jwk=jwk) client._client = mock_client - supports_nothing = mock.Mock(supports=mock.Mock(return_value=False)) - with mock.patch( - CryptographyClient.__module__ + ".get_local_cryptography_provider", lambda *args, **kwargs: supports_nothing - ): - with pytest.raises(NotImplementedError): - await client.decrypt(EncryptionAlgorithm.rsa_oaep, b"...") + with pytest.raises(NotImplementedError): + await client.decrypt(EncryptionAlgorithm.rsa_oaep, b"...") assert mock_client.decrypt.call_count == 0 with pytest.raises(NotImplementedError): - await client.encrypt(EncryptionAlgorithm.rsa_oaep, b"...") + await client.encrypt(EncryptionAlgorithm.a256_gcm, b"...") assert mock_client.encrypt.call_count == 0 with pytest.raises(NotImplementedError): @@ -627,7 +623,7 @@ async def test_local_only_mode_no_service_calls(): assert mock_client.sign.call_count == 0 with pytest.raises(NotImplementedError): - await client.verify(SignatureAlgorithm.rs256, b"...", b"...") + await client.verify(SignatureAlgorithm.es256, b"...", b"...") assert mock_client.verify.call_count == 0 with pytest.raises(NotImplementedError): @@ -635,7 +631,7 @@ async def test_local_only_mode_no_service_calls(): assert mock_client.unwrap_key.call_count == 0 with pytest.raises(NotImplementedError): - await client.wrap_key(KeyWrapAlgorithm.rsa_oaep, b"...") + await client.wrap_key(KeyWrapAlgorithm.aes_256, b"...") assert mock_client.wrap_key.call_count == 0 diff --git a/sdk/keyvault/azure-keyvault-keys/tests/test_examples_crypto.py b/sdk/keyvault/azure-keyvault-keys/tests/test_examples_crypto.py index eb4413e61b13..4a24eb3c1c93 100644 --- a/sdk/keyvault/azure-keyvault-keys/tests/test_examples_crypto.py +++ b/sdk/keyvault/azure-keyvault-keys/tests/test_examples_crypto.py @@ -4,7 +4,7 @@ # ------------------------------------ import functools -from azure.keyvault.keys import JsonWebKey, KeyClient +from azure.keyvault.keys import JsonWebKey, KeyClient, KeyOperation from azure.keyvault.keys.crypto import CryptographyClient from azure.keyvault.keys._shared import HttpChallengeCache from devtools_testutils import PowerShellPreparer @@ -21,11 +21,11 @@ def test_create_client_from_jwk(): # [START from_jwk] # create a CryptographyClient using a JsonWebKey instance - key = JsonWebKey(kty="RSA") + key = JsonWebKey(kty="RSA", key_ops=[KeyOperation.decrypt], n=b"10011", e=b"10001") crypto_client = CryptographyClient.from_jwk(jwk=key) # or a dictionary with JsonWebKey properties - key_dict = {"kty":"RSA"} + key_dict = {"kty":"RSA", "key_ops":[KeyOperation.decrypt], "n":b"10011", "e":b"10001"} crypto_client = CryptographyClient.from_jwk(jwk=key_dict) # [END from_jwk]