Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Key Vault] Add local-only mode to CryptographyClient #16565

Merged
merged 18 commits into from
Mar 10, 2021
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from ._key_validity import raise_if_time_invalid
from ._providers import get_local_cryptography_provider, NoLocalCryptography
from .. import KeyOperation
from .._models import KeyVaultKey
from .._models import JsonWebKey, KeyVaultKey
from .._shared import KeyVaultClientBase, parse_key_vault_id

if TYPE_CHECKING:
Expand Down Expand Up @@ -98,6 +98,7 @@ class CryptographyClient(KeyVaultClientBase):

def __init__(self, key, credential, **kwargs):
# type: (Union[KeyVaultKey, str], TokenCredential, **Any) -> None
self._jwk = kwargs.pop("_jwk", False)

if isinstance(key, KeyVaultKey):
self._key = key
Expand All @@ -106,25 +107,55 @@ def __init__(self, key, credential, **kwargs):
self._key = None
self._key_id = parse_key_vault_id(key)
self._keys_get_forbidden = None # type: Optional[bool]
elif self._jwk:
self._key = key
self._key_id = key.kid
else:
raise ValueError("'key' must be a KeyVaultKey instance or a key ID string including a version")

if not self._key_id.version:
if not (self._jwk or self._key_id.version):
raise ValueError("'key' must include a version")

self._local_provider = NoLocalCryptography()
self._initialized = False

super(CryptographyClient, self).__init__(vault_url=self._key_id.vault_url, credential=credential, **kwargs)
vault_url = "vault_url" if self._jwk else self._key_id.vault_url
mccoyp marked this conversation as resolved.
Show resolved Hide resolved
super(CryptographyClient, self).__init__(vault_url=vault_url, credential=credential, **kwargs)

@property
def key_id(self):
# type: () -> str
"""The full identifier of the client's key.

This property may be None when a client is constructed with `CryptographyClient.from_jwk`.
mccoyp marked this conversation as resolved.
Show resolved Hide resolved

:rtype: str
"""
return self._key_id.source_id
if not self._jwk:
return self._key_id.source_id
return self._key_id

@classmethod
def from_jwk(cls, jwk):
# type: (Union[JsonWebKey, dict]) -> CryptographyClient
"""Creates a client that can only perform cryptographic operations locally.

:param jwk: the key's cryptographic material, as a JsonWebKey or dictionary.
:type jwk: JsonWebKey or dict
:rtype: CryptographyClient

.. literalinclude:: ../tests/test_examples_crypto.py
:start-after: [START from_jwk]
:end-before: [END from_jwk]
:caption: Create a CryptographyClient from a JsonWebKey
:language: python
:dedent: 8
"""
if isinstance(jwk, JsonWebKey):
key = jwk
else:
key = JsonWebKey(**jwk)
mccoyp marked this conversation as resolved.
Show resolved Hide resolved
return cls(key, object(), _jwk=True)
mccoyp marked this conversation as resolved.
Show resolved Hide resolved

@distributed_trace
def _initialize(self, **kwargs):
Expand Down Expand Up @@ -186,6 +217,12 @@ def encrypt(self, algorithm, plaintext, **kwargs):
return self._local_provider.encrypt(algorithm, plaintext)
except Exception as ex: # pylint:disable=broad-except
_LOGGER.warning("Local encrypt operation failed: %s", ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG))
if self._jwk:
raise
elif self._jwk:
raise NotImplementedError(
'This key does not support the "encrypt" operation with algorithm "{}"'.format(algorithm)
)

operation_result = self._client.encrypt(
vault_base_url=self._key_id.vault_url,
Expand Down Expand Up @@ -240,6 +277,12 @@ def decrypt(self, algorithm, ciphertext, **kwargs):
return self._local_provider.decrypt(algorithm, ciphertext)
except Exception as ex: # pylint:disable=broad-except
_LOGGER.warning("Local decrypt operation failed: %s", ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG))
if self._jwk:
raise
elif self._jwk:
raise NotImplementedError(
'This key does not support the "decrypt" operation with algorithm "{}"'.format(algorithm)
)

operation_result = self._client.decrypt(
vault_base_url=self._key_id.vault_url,
Expand Down Expand Up @@ -277,6 +320,12 @@ def wrap_key(self, algorithm, key, **kwargs):
return self._local_provider.wrap_key(algorithm, key)
except Exception as ex: # pylint:disable=broad-except
_LOGGER.warning("Local wrap operation failed: %s", ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG))
if self._jwk:
raise
elif self._jwk:
raise NotImplementedError(
'This key does not support the "wrapKey" operation with algorithm "{}"'.format(algorithm)
)

operation_result = self._client.wrap_key(
vault_base_url=self._key_id.vault_url,
Expand Down Expand Up @@ -311,6 +360,12 @@ def unwrap_key(self, algorithm, encrypted_key, **kwargs):
return self._local_provider.unwrap_key(algorithm, encrypted_key)
except Exception as ex: # pylint:disable=broad-except
_LOGGER.warning("Local unwrap operation failed: %s", ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG))
if self._jwk:
raise
elif self._jwk:
raise NotImplementedError(
'This key does not support the "unwrapKey" operation with algorithm "{}"'.format(algorithm)
)

operation_result = self._client.unwrap_key(
vault_base_url=self._key_id.vault_url,
Expand Down Expand Up @@ -345,6 +400,12 @@ def sign(self, algorithm, digest, **kwargs):
return self._local_provider.sign(algorithm, digest)
except Exception as ex: # pylint:disable=broad-except
_LOGGER.warning("Local sign operation failed: %s", ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG))
if self._jwk:
raise
elif self._jwk:
raise NotImplementedError(
'This key does not support the "sign" operation with algorithm "{}"'.format(algorithm)
)

operation_result = self._client.sign(
vault_base_url=self._key_id.vault_url,
Expand Down Expand Up @@ -381,6 +442,12 @@ def verify(self, algorithm, digest, signature, **kwargs):
return self._local_provider.verify(algorithm, digest, signature)
except Exception as ex: # pylint:disable=broad-except
_LOGGER.warning("Local verify operation failed: %s", ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG))
if self._jwk:
raise
elif self._jwk:
raise NotImplementedError(
'This key does not support the "verify" operation with algorithm "{}"'.format(algorithm)
)

operation_result = self._client.verify(
vault_base_url=self._key_id.vault_url,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@

if TYPE_CHECKING:
# pylint:disable=unused-import
from .. import KeyVaultKey
from typing import Union
from .. import JsonWebKey, KeyVaultKey


class _UTC_TZ(tzinfo):
Expand All @@ -29,7 +30,7 @@ def dst(self, dt):


def raise_if_time_invalid(key):
# type: (KeyVaultKey) -> None
# type: (Union[JsonWebKey, KeyVaultKey]) -> None
mccoyp marked this conversation as resolved.
Show resolved Hide resolved
try:
nbf = key.properties.not_before
exp = key.properties.expires_on
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,28 @@
from .local_provider import LocalCryptographyProvider
from .rsa import RsaCryptographyProvider
from .symmetric import SymmetricCryptographyProvider
from ... import KeyType
from ... import JsonWebKey, KeyType

if TYPE_CHECKING:
from typing import Union
from ... import KeyVaultKey


def get_local_cryptography_provider(key):
# type: (KeyVaultKey) -> LocalCryptographyProvider
if key.key_type in (KeyType.ec, KeyType.ec_hsm):
# type: (Union[JsonWebKey, KeyVaultKey]) -> LocalCryptographyProvider
mccoyp marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(key, JsonWebKey):
key_type = key.kty
else:
key_type = key.key_type

if key_type in (KeyType.ec, KeyType.ec_hsm):
return EllipticCurveCryptographyProvider(key)
if key.key_type in (KeyType.rsa, KeyType.rsa_hsm):
if key_type in (KeyType.rsa, KeyType.rsa_hsm):
return RsaCryptographyProvider(key)
if key.key_type in (KeyType.oct, KeyType.oct_hsm):
if key_type in (KeyType.oct, KeyType.oct_hsm):
return SymmetricCryptographyProvider(key)

raise ValueError('Unsupported key type "{}"'.format(key.key_type))
raise ValueError('Unsupported key type "{}"'.format(key_type))


class NoLocalCryptography(LocalCryptographyProvider):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@

from .local_provider import LocalCryptographyProvider
from .._internal import EllipticCurveKey
from ... import KeyOperation, KeyType
from ... import JsonWebKey, KeyOperation, KeyType

if TYPE_CHECKING:
# pylint:disable=unused-import
from typing import Union
from .local_provider import Algorithm
from .._internal import Key
from ... import KeyVaultKey
Expand All @@ -19,10 +20,17 @@

class EllipticCurveCryptographyProvider(LocalCryptographyProvider):
def _get_internal_key(self, key):
# type: (KeyVaultKey) -> Key
if key.key_type not in (KeyType.ec, KeyType.ec_hsm):
# type: (Union[JsonWebKey, KeyVaultKey]) -> Key
if isinstance(key, JsonWebKey):
key_type = key.kty
jwk = key
else:
key_type = key.key_type
jwk = key.key

if key_type not in (KeyType.ec, KeyType.ec_hsm):
raise ValueError('"key" must be an EC or EC-HSM key')
return EllipticCurveKey.from_jwk(key.key)
return EllipticCurveKey.from_jwk(jwk)

def supports(self, operation, algorithm):
# type: (KeyOperation, Algorithm) -> bool
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from azure.core.exceptions import AzureError

from .. import DecryptResult, EncryptResult, SignResult, UnwrapResult, VerifyResult, WrapResult
from ... import KeyOperation
from ... import JsonWebKey, KeyOperation

try:
ABC = abc.ABC
Expand All @@ -17,7 +17,7 @@

if TYPE_CHECKING:
# pylint:disable=unused-import
from typing import Union
from typing import Optional, Union
from .._internal.key import Key
from .. import EncryptionAlgorithm, KeyWrapAlgorithm, SignatureAlgorithm
from ... import KeyVaultKey
Expand All @@ -27,14 +27,17 @@

class LocalCryptographyProvider(ABC):
def __init__(self, key):
# type: (KeyVaultKey) -> None
self._allowed_ops = frozenset(key.key_operations)
# type: (Union[JsonWebKey, KeyVaultKey]) -> None
if isinstance(key, JsonWebKey):
self._allowed_ops = frozenset(key.key_ops)
else:
self._allowed_ops = frozenset(key.key_operations)
self._internal_key = self._get_internal_key(key)
self._key = key

@abc.abstractmethod
def _get_internal_key(self, key):
# type: (KeyVaultKey) -> Key
# type: (Union[JsonWebKey, KeyVaultKey]) -> Key
pass

@abc.abstractmethod
Expand All @@ -44,11 +47,13 @@ def supports(self, operation, algorithm):

@property
def key_id(self):
# type: () -> str
# type: () -> Optional[str]
"""The full identifier of the provider's key.

:rtype: str
"""
if isinstance(self._key, JsonWebKey):
return self._key.kid
return self._key.id

def _raise_if_unsupported(self, operation, algorithm):
Expand All @@ -64,34 +69,34 @@ def encrypt(self, algorithm, plaintext):
# type: (EncryptionAlgorithm, bytes) -> EncryptResult
self._raise_if_unsupported(KeyOperation.encrypt, algorithm)
ciphertext = self._internal_key.encrypt(plaintext, algorithm=algorithm.value)
return EncryptResult(key_id=self._key.id, algorithm=algorithm, ciphertext=ciphertext)
return EncryptResult(key_id=self.key_id, algorithm=algorithm, ciphertext=ciphertext)

def decrypt(self, algorithm, ciphertext):
# type: (EncryptionAlgorithm, bytes) -> DecryptResult
self._raise_if_unsupported(KeyOperation.decrypt, algorithm)
plaintext = self._internal_key.decrypt(ciphertext, iv=None, algorithm=algorithm.value)
return DecryptResult(key_id=self._key.id, algorithm=algorithm, plaintext=plaintext)
return DecryptResult(key_id=self.key_id, algorithm=algorithm, plaintext=plaintext)

def wrap_key(self, algorithm, key):
# type: (KeyWrapAlgorithm, bytes) -> WrapResult
self._raise_if_unsupported(KeyOperation.wrap_key, algorithm)
encrypted_key = self._internal_key.wrap_key(key, algorithm=algorithm.value)
return WrapResult(key_id=self._key.id, algorithm=algorithm, encrypted_key=encrypted_key)
return WrapResult(key_id=self.key_id, algorithm=algorithm, encrypted_key=encrypted_key)

def unwrap_key(self, algorithm, encrypted_key):
# type: (KeyWrapAlgorithm, bytes) -> UnwrapResult
self._raise_if_unsupported(KeyOperation.unwrap_key, algorithm)
unwrapped_key = self._internal_key.unwrap_key(encrypted_key, algorithm=algorithm.value)
return UnwrapResult(key_id=self._key.id, algorithm=algorithm, key=unwrapped_key)
return UnwrapResult(key_id=self.key_id, algorithm=algorithm, key=unwrapped_key)

def sign(self, algorithm, digest):
# type: (SignatureAlgorithm, bytes) -> SignResult
self._raise_if_unsupported(KeyOperation.sign, algorithm)
signature = self._internal_key.sign(digest, algorithm=algorithm.value)
return SignResult(key_id=self._key.id, algorithm=algorithm, signature=signature)
return SignResult(key_id=self.key_id, algorithm=algorithm, signature=signature)

def verify(self, algorithm, digest, signature):
# type: (SignatureAlgorithm, bytes, bytes) -> VerifyResult
self._raise_if_unsupported(KeyOperation.verify, algorithm)
is_valid = self._internal_key.verify(digest, signature, algorithm=algorithm.value)
return VerifyResult(key_id=self._key.id, algorithm=algorithm, is_valid=is_valid)
return VerifyResult(key_id=self.key_id, algorithm=algorithm, is_valid=is_valid)
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@

from .local_provider import LocalCryptographyProvider
from .._internal import RsaKey
from ... import KeyOperation, KeyType
from ... import JsonWebKey, KeyOperation, KeyType

if TYPE_CHECKING:
# pylint:disable=unused-import
from typing import Union
from .local_provider import Algorithm
from .._internal import Key
from ... import KeyVaultKey
Expand All @@ -19,10 +20,17 @@

class RsaCryptographyProvider(LocalCryptographyProvider):
def _get_internal_key(self, key):
# type: (KeyVaultKey) -> Key
if key.key_type not in (KeyType.rsa, KeyType.rsa_hsm):
# type: (Union[JsonWebKey, KeyVaultKey]) -> Key
if isinstance(key, JsonWebKey):
key_type = key.kty
jwk = key
else:
key_type = key.key_type
jwk = key.key

if key_type not in (KeyType.rsa, KeyType.rsa_hsm):
raise ValueError('"key" must be an RSA or RSA-HSM key')
return RsaKey.from_jwk(key.key)
return RsaKey.from_jwk(jwk)

def supports(self, operation, algorithm):
# type: (KeyOperation, Algorithm) -> bool
Expand Down
Loading