Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Key Vault] Add local-only mode to CryptographyClient #16565

Merged
merged 18 commits into from
Mar 10, 2021
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,16 @@
from ._key_validity import raise_if_time_invalid
from ._providers import get_local_cryptography_provider, NoLocalCryptography
from .. import KeyOperation
from .._models import KeyVaultKey
from .._models import JsonWebKey, KeyVaultKey
from .._shared import KeyVaultClientBase, parse_key_vault_id

if TYPE_CHECKING:
# pylint:disable=unused-import
# pylint:disable=unused-import,ungrouped-imports
from datetime import datetime
from typing import Any, Optional, Union
from azure.core.credentials import TokenCredential
from . import EncryptionAlgorithm, KeyWrapAlgorithm, SignatureAlgorithm
from .._shared import KeyVaultResourceId

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -98,33 +100,69 @@ class CryptographyClient(KeyVaultClientBase):

def __init__(self, key, credential, **kwargs):
# type: (Union[KeyVaultKey, str], TokenCredential, **Any) -> None
self._jwk = kwargs.pop("_jwk", False)
self._not_before = None # type: Optional[datetime]
self._expires_on = None # type: Optional[datetime]
self._key_id = None # type: Optional[KeyVaultResourceId]

if isinstance(key, KeyVaultKey):
self._key = key
self._key = key.key
self._key_id = parse_key_vault_id(key.id)
if key.properties._attributes: # pylint:disable=protected-access
mccoyp marked this conversation as resolved.
Show resolved Hide resolved
self._not_before = key.properties.not_before
self._expires_on = key.properties.expires_on
elif isinstance(key, six.string_types):
self._key = None
self._key_id = parse_key_vault_id(key)
self._keys_get_forbidden = None # type: Optional[bool]
elif self._jwk:
self._key = key
else:
raise ValueError("'key' must be a KeyVaultKey instance or a key ID string including a version")

if not self._key_id.version:
if not (self._jwk or self._key_id.version):
raise ValueError("'key' must include a version")

self._local_provider = NoLocalCryptography()
self._initialized = False

super(CryptographyClient, self).__init__(vault_url=self._key_id.vault_url, credential=credential, **kwargs)
vault_url = "vault_url" if self._jwk else self._key_id.vault_url
mccoyp marked this conversation as resolved.
Show resolved Hide resolved
super(CryptographyClient, self).__init__(vault_url=vault_url, credential=credential, **kwargs)

@property
def key_id(self):
# type: () -> str
# type: () -> Optional[str]
"""The full identifier of the client's key.

This property may be None when a client is constructed with :func:`from_jwk`.

:rtype: str
"""
return self._key_id.source_id
if not self._jwk:
return self._key_id.source_id
return self._key.kid

@classmethod
def from_jwk(cls, jwk):
# type: (Union[JsonWebKey, dict]) -> CryptographyClient
"""Creates a client that can only perform cryptographic operations locally.

:param jwk: the key's cryptographic material, as a JsonWebKey or dictionary.
:type jwk: JsonWebKey or dict
:rtype: CryptographyClient

.. literalinclude:: ../tests/test_examples_crypto.py
:start-after: [START from_jwk]
:end-before: [END from_jwk]
:caption: Create a CryptographyClient from a JsonWebKey
:language: python
:dedent: 8
"""
if isinstance(jwk, JsonWebKey):
key = jwk
else:
key = JsonWebKey(**jwk)
mccoyp marked this conversation as resolved.
Show resolved Hide resolved
return cls(key, object(), _jwk=True)
mccoyp marked this conversation as resolved.
Show resolved Hide resolved

@distributed_trace
def _initialize(self, **kwargs):
Expand All @@ -138,15 +176,15 @@ def _initialize(self, **kwargs):
key_bundle = self._client.get_key(
self._key_id.vault_url, self._key_id.name, self._key_id.version, **kwargs
)
self._key = KeyVaultKey._from_key_bundle(key_bundle) # pylint:disable=protected-access
self._key = KeyVaultKey._from_key_bundle(key_bundle).key # pylint:disable=protected-access
except HttpResponseError as ex:
# if we got a 403, we don't have keys/get permission and won't try to get the key again
# (other errors may be transient)
self._keys_get_forbidden = ex.status_code == 403

# if we have the key material, create a local crypto provider with it
if self._key:
self._local_provider = get_local_cryptography_provider(self._key)
self._local_provider = get_local_cryptography_provider(self._key, _key_id=self.key_id)
self._initialized = True
else:
# try to get the key again next time unless we know we're forbidden to do so
Expand Down Expand Up @@ -181,11 +219,17 @@ def encrypt(self, algorithm, plaintext, **kwargs):
self._initialize(**kwargs)

if self._local_provider.supports(KeyOperation.encrypt, algorithm):
raise_if_time_invalid(self._key)
raise_if_time_invalid(self._not_before, self._expires_on)
try:
return self._local_provider.encrypt(algorithm, plaintext)
except Exception as ex: # pylint:disable=broad-except
_LOGGER.warning("Local encrypt operation failed: %s", ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG))
if self._jwk:
raise
elif self._jwk:
raise NotImplementedError(
'This key does not support the "encrypt" operation with algorithm "{}"'.format(algorithm)
)

operation_result = self._client.encrypt(
vault_base_url=self._key_id.vault_url,
Expand Down Expand Up @@ -240,6 +284,12 @@ def decrypt(self, algorithm, ciphertext, **kwargs):
return self._local_provider.decrypt(algorithm, ciphertext)
except Exception as ex: # pylint:disable=broad-except
_LOGGER.warning("Local decrypt operation failed: %s", ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG))
if self._jwk:
raise
elif self._jwk:
raise NotImplementedError(
'This key does not support the "decrypt" operation with algorithm "{}"'.format(algorithm)
)

operation_result = self._client.decrypt(
vault_base_url=self._key_id.vault_url,
Expand Down Expand Up @@ -272,11 +322,17 @@ def wrap_key(self, algorithm, key, **kwargs):
"""
self._initialize(**kwargs)
if self._local_provider.supports(KeyOperation.wrap_key, algorithm):
raise_if_time_invalid(self._key)
raise_if_time_invalid(self._not_before, self._expires_on)
try:
return self._local_provider.wrap_key(algorithm, key)
except Exception as ex: # pylint:disable=broad-except
_LOGGER.warning("Local wrap operation failed: %s", ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG))
if self._jwk:
raise
elif self._jwk:
raise NotImplementedError(
'This key does not support the "wrapKey" operation with algorithm "{}"'.format(algorithm)
)

operation_result = self._client.wrap_key(
vault_base_url=self._key_id.vault_url,
Expand Down Expand Up @@ -311,6 +367,12 @@ def unwrap_key(self, algorithm, encrypted_key, **kwargs):
return self._local_provider.unwrap_key(algorithm, encrypted_key)
except Exception as ex: # pylint:disable=broad-except
_LOGGER.warning("Local unwrap operation failed: %s", ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG))
if self._jwk:
raise
elif self._jwk:
raise NotImplementedError(
'This key does not support the "unwrapKey" operation with algorithm "{}"'.format(algorithm)
)

operation_result = self._client.unwrap_key(
vault_base_url=self._key_id.vault_url,
Expand Down Expand Up @@ -340,11 +402,17 @@ def sign(self, algorithm, digest, **kwargs):
"""
self._initialize(**kwargs)
if self._local_provider.supports(KeyOperation.sign, algorithm):
raise_if_time_invalid(self._key)
raise_if_time_invalid(self._not_before, self._expires_on)
try:
return self._local_provider.sign(algorithm, digest)
except Exception as ex: # pylint:disable=broad-except
_LOGGER.warning("Local sign operation failed: %s", ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG))
if self._jwk:
raise
elif self._jwk:
raise NotImplementedError(
'This key does not support the "sign" operation with algorithm "{}"'.format(algorithm)
)

operation_result = self._client.sign(
vault_base_url=self._key_id.vault_url,
Expand Down Expand Up @@ -381,6 +449,12 @@ def verify(self, algorithm, digest, signature, **kwargs):
return self._local_provider.verify(algorithm, digest, signature)
except Exception as ex: # pylint:disable=broad-except
_LOGGER.warning("Local verify operation failed: %s", ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG))
if self._jwk:
raise
elif self._jwk:
raise NotImplementedError(
'This key does not support the "verify" operation with algorithm "{}"'.format(algorithm)
)

operation_result = self._client.verify(
vault_base_url=self._key_id.vault_url,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

if TYPE_CHECKING:
# pylint:disable=unused-import
from .. import KeyVaultKey
from typing import Optional


class _UTC_TZ(tzinfo):
Expand All @@ -28,20 +28,12 @@ def dst(self, dt):
_UTC = _UTC_TZ()


def raise_if_time_invalid(key):
# type: (KeyVaultKey) -> None
try:
nbf = key.properties.not_before
exp = key.properties.expires_on
except AttributeError:
# we consider the key valid because a user must have deliberately created it
# (if it came from Key Vault, it would have those attributes)
return

def raise_if_time_invalid(not_before, expires_on):
# type: (Optional[datetime], Optional[datetime]) -> None
now = datetime.now(_UTC)
if (nbf and exp) and not nbf <= now <= exp:
raise ValueError("This client's key is useable only between {} and {} (UTC)".format(nbf, exp))
if nbf and nbf > now:
raise ValueError("This client's key is not useable until {} (UTC)".format(nbf))
if exp and exp <= now:
raise ValueError("This client's key expired at {} (UTC)".format(exp))
if (not_before and expires_on) and not not_before <= now <= expires_on:
raise ValueError("This client's key is useable only between {} and {} (UTC)".format(not_before, expires_on))
if not_before and not_before > now:
raise ValueError("This client's key is not useable until {} (UTC)".format(not_before))
if expires_on and expires_on <= now:
raise ValueError("This client's key expires_onired at {} (UTC)".format(expires_on))
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,20 @@
from ... import KeyType

if TYPE_CHECKING:
from ... import KeyVaultKey
from typing import Any
from ... import JsonWebKey


def get_local_cryptography_provider(key):
# type: (KeyVaultKey) -> LocalCryptographyProvider
if key.key_type in (KeyType.ec, KeyType.ec_hsm):
return EllipticCurveCryptographyProvider(key)
if key.key_type in (KeyType.rsa, KeyType.rsa_hsm):
return RsaCryptographyProvider(key)
if key.key_type in (KeyType.oct, KeyType.oct_hsm):
return SymmetricCryptographyProvider(key)
def get_local_cryptography_provider(key, **kwargs):
# type: (JsonWebKey, **Any) -> LocalCryptographyProvider
if key.kty in (KeyType.ec, KeyType.ec_hsm):
return EllipticCurveCryptographyProvider(key, **kwargs)
if key.kty in (KeyType.rsa, KeyType.rsa_hsm):
return RsaCryptographyProvider(key, **kwargs)
if key.kty in (KeyType.oct, KeyType.oct_hsm):
return SymmetricCryptographyProvider(key, **kwargs)

raise ValueError('Unsupported key type "{}"'.format(key.key_type))
raise ValueError('Unsupported key type "{}"'.format(key.kty))


class NoLocalCryptography(LocalCryptographyProvider):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,17 @@
# pylint:disable=unused-import
from .local_provider import Algorithm
from .._internal import Key
from ... import KeyVaultKey
from ... import JsonWebKey

_PRIVATE_KEY_OPERATIONS = frozenset((KeyOperation.decrypt, KeyOperation.sign, KeyOperation.unwrap_key))


class EllipticCurveCryptographyProvider(LocalCryptographyProvider):
def _get_internal_key(self, key):
# type: (KeyVaultKey) -> Key
if key.key_type not in (KeyType.ec, KeyType.ec_hsm):
# type: (JsonWebKey) -> Key
if key.kty not in (KeyType.ec, KeyType.ec_hsm):
raise ValueError('"key" must be an EC or EC-HSM key')
return EllipticCurveKey.from_jwk(key.key)
return EllipticCurveKey.from_jwk(key)

def supports(self, operation, algorithm):
# type: (KeyOperation, Algorithm) -> bool
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,25 @@

if TYPE_CHECKING:
# pylint:disable=unused-import
from typing import Union
from typing import Any, Optional, Union
from .._internal.key import Key
from .. import EncryptionAlgorithm, KeyWrapAlgorithm, SignatureAlgorithm
from ... import KeyVaultKey
from ... import JsonWebKey

Algorithm = Union[EncryptionAlgorithm, KeyWrapAlgorithm, SignatureAlgorithm]


class LocalCryptographyProvider(ABC):
def __init__(self, key):
# type: (KeyVaultKey) -> None
self._allowed_ops = frozenset(key.key_operations)
def __init__(self, key, **kwargs):
# type: (JsonWebKey, **Any) -> None
self._allowed_ops = frozenset(key.key_ops)
self._internal_key = self._get_internal_key(key)
self._key = key
self._key_id = kwargs.pop("_key_id", None) or key.kid

@abc.abstractmethod
def _get_internal_key(self, key):
# type: (KeyVaultKey) -> Key
# type: (JsonWebKey) -> Key
pass

@abc.abstractmethod
Expand All @@ -44,12 +45,12 @@ def supports(self, operation, algorithm):

@property
def key_id(self):
# type: () -> str
# type: () -> Optional[str]
"""The full identifier of the provider's key.

:rtype: str
"""
return self._key.id
return self._key_id
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will be either a valid Key Vault identifier or whatever the user gave us:

Suggested change
return self._key_id
return self._key.get("kid")

Does that remove the need to ask a caller to give us _key_id separately?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My thinking here is that a KeyVaultKey provided to CryptographyClient could have a valid Key Vault identifier for its id, but that the JsonWebKey could have a different kid. I assumed it made more sense to return the Key Vault identifier in the operation result if possible, which is why I added the _key_id keyword argument

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That would happen only when someone does KeyVaultKey(key_id, jwk=jwk) with key_id != jwk["kid"]. Granted, that is possible, but seems unlikely and given KeyVaultKey's docs ask for a Key Vault identifier for key_id, it seems reasonable not to special case it. My $0.02 🤑


def _raise_if_unsupported(self, operation, algorithm):
# type: (KeyOperation, Algorithm) -> None
Expand All @@ -64,34 +65,34 @@ def encrypt(self, algorithm, plaintext):
# type: (EncryptionAlgorithm, bytes) -> EncryptResult
self._raise_if_unsupported(KeyOperation.encrypt, algorithm)
ciphertext = self._internal_key.encrypt(plaintext, algorithm=algorithm.value)
return EncryptResult(key_id=self._key.id, algorithm=algorithm, ciphertext=ciphertext)
return EncryptResult(key_id=self._key_id, algorithm=algorithm, ciphertext=ciphertext)

def decrypt(self, algorithm, ciphertext):
# type: (EncryptionAlgorithm, bytes) -> DecryptResult
self._raise_if_unsupported(KeyOperation.decrypt, algorithm)
plaintext = self._internal_key.decrypt(ciphertext, iv=None, algorithm=algorithm.value)
return DecryptResult(key_id=self._key.id, algorithm=algorithm, plaintext=plaintext)
return DecryptResult(key_id=self._key_id, algorithm=algorithm, plaintext=plaintext)

def wrap_key(self, algorithm, key):
# type: (KeyWrapAlgorithm, bytes) -> WrapResult
self._raise_if_unsupported(KeyOperation.wrap_key, algorithm)
encrypted_key = self._internal_key.wrap_key(key, algorithm=algorithm.value)
return WrapResult(key_id=self._key.id, algorithm=algorithm, encrypted_key=encrypted_key)
return WrapResult(key_id=self._key_id, algorithm=algorithm, encrypted_key=encrypted_key)

def unwrap_key(self, algorithm, encrypted_key):
# type: (KeyWrapAlgorithm, bytes) -> UnwrapResult
self._raise_if_unsupported(KeyOperation.unwrap_key, algorithm)
unwrapped_key = self._internal_key.unwrap_key(encrypted_key, algorithm=algorithm.value)
return UnwrapResult(key_id=self._key.id, algorithm=algorithm, key=unwrapped_key)
return UnwrapResult(key_id=self._key_id, algorithm=algorithm, key=unwrapped_key)

def sign(self, algorithm, digest):
# type: (SignatureAlgorithm, bytes) -> SignResult
self._raise_if_unsupported(KeyOperation.sign, algorithm)
signature = self._internal_key.sign(digest, algorithm=algorithm.value)
return SignResult(key_id=self._key.id, algorithm=algorithm, signature=signature)
return SignResult(key_id=self._key_id, algorithm=algorithm, signature=signature)

def verify(self, algorithm, digest, signature):
# type: (SignatureAlgorithm, bytes, bytes) -> VerifyResult
self._raise_if_unsupported(KeyOperation.verify, algorithm)
is_valid = self._internal_key.verify(digest, signature, algorithm=algorithm.value)
return VerifyResult(key_id=self._key.id, algorithm=algorithm, is_valid=is_valid)
return VerifyResult(key_id=self._key_id, algorithm=algorithm, is_valid=is_valid)
Loading