Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add JWK support for HMAC and RSA keys #202

Merged
merged 1 commit into from
Oct 24, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 144 additions & 14 deletions jwt/algorithms.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
import hashlib
import hmac
import json

from .compat import binary_type, constant_time_compare, is_string_type

from .compat import constant_time_compare, string_types
from .exceptions import InvalidKeyError
from .utils import der_to_raw_signature, raw_to_der_signature
from .utils import (
base64url_decode, base64url_encode, der_to_raw_signature,
force_bytes, force_unicode, from_base64url_uint, raw_to_der_signature,
to_base64url_uint
)

try:
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.serialization import (
load_pem_private_key, load_pem_public_key, load_ssh_public_key
)
from cryptography.hazmat.primitives.asymmetric.rsa import (
RSAPrivateKey, RSAPublicKey
RSAPrivateKey, RSAPublicKey, RSAPrivateNumbers, RSAPublicNumbers,
rsa_recover_prime_factors, rsa_crt_dmp1, rsa_crt_dmq1, rsa_crt_iqmp
)
from cryptography.hazmat.primitives.asymmetric.ec import (
EllipticCurvePrivateKey, EllipticCurvePublicKey
Expand Down Expand Up @@ -77,6 +84,20 @@ def verify(self, msg, key, sig):
"""
raise NotImplementedError

@staticmethod
def to_jwk(key_obj):
"""
Serializes a given RSA key into a JWK
"""
raise NotImplementedError

@staticmethod
def from_jwk(jwk):
"""
Deserializes a given RSA key from JWK back into a PublicKey or PrivateKey object
"""
raise NotImplementedError


class NoneAlgorithm(Algorithm):
"""
Expand Down Expand Up @@ -112,11 +133,7 @@ def __init__(self, hash_alg):
self.hash_alg = hash_alg

def prepare_key(self, key):
if not is_string_type(key):
raise TypeError('Expecting a string- or bytes-formatted key.')

if not isinstance(key, binary_type):
key = key.encode('utf-8')
key = force_bytes(key)

invalid_strings = [
b'-----BEGIN PUBLIC KEY-----',
Expand All @@ -131,6 +148,22 @@ def prepare_key(self, key):

return key

@staticmethod
def to_jwk(key_obj):
return json.dumps({
'k': force_unicode(base64url_encode(force_bytes(key_obj))),
'kty': 'oct'
})

@staticmethod
def from_jwk(jwk):
obj = json.loads(jwk)

if obj.get('kty') != 'oct':
raise InvalidKeyError('Not an HMAC key')

return base64url_decode(obj['k'])

def sign(self, msg, key):
return hmac.new(key, msg, self.hash_alg).digest()

Expand All @@ -156,9 +189,8 @@ def prepare_key(self, key):
isinstance(key, RSAPublicKey):
return key

if is_string_type(key):
if not isinstance(key, binary_type):
key = key.encode('utf-8')
if isinstance(key, string_types):
key = force_bytes(key)

try:
if key.startswith(b'ssh-rsa'):
Expand All @@ -172,6 +204,105 @@ def prepare_key(self, key):

return key

@staticmethod
def to_jwk(key_obj):
obj = None

if getattr(key_obj, 'private_numbers', None):
# Private key
numbers = key_obj.private_numbers()

obj = {
'kty': 'RSA',
'key_ops': ['sign'],
'n': force_unicode(to_base64url_uint(numbers.public_numbers.n)),
'e': force_unicode(to_base64url_uint(numbers.public_numbers.e)),
'd': force_unicode(to_base64url_uint(numbers.d)),
'p': force_unicode(to_base64url_uint(numbers.p)),
'q': force_unicode(to_base64url_uint(numbers.q)),
'dp': force_unicode(to_base64url_uint(numbers.dmp1)),
'dq': force_unicode(to_base64url_uint(numbers.dmq1)),
'qi': force_unicode(to_base64url_uint(numbers.iqmp))
}

elif getattr(key_obj, 'verifier', None):
# Public key
numbers = key_obj.public_numbers()

obj = {
'kty': 'RSA',
'key_ops': ['verify'],
'n': force_unicode(to_base64url_uint(numbers.n)),

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I cannot find specific documentation on encoding RSA keys, but the examples I have seen include modulus (n) and public exponent (e) values in both the private and public keys.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are correct. That was an oversight on my part. I missed the text that says:

In addition to the members used to represent RSA public keys, the
following members are used to represent RSA private keys.

'e': force_unicode(to_base64url_uint(numbers.e))
}
else:
raise InvalidKeyError('Not a public or private key')

return json.dumps(obj)

@staticmethod
def from_jwk(jwk):
try:
obj = json.loads(jwk)
except ValueError:
raise InvalidKeyError('Key is not valid JSON')

if obj.get('kty') != 'RSA':
raise InvalidKeyError('Not an RSA key')

if 'd' in obj and 'e' in obj and 'n' in obj:
Copy link

@clintonb clintonb Aug 28, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Neither 'e' nor 'n' is included the private key JWK created above.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great point. I definitely missed that. I'll fix it.

# Private key
if 'oth' in obj:
raise InvalidKeyError('Unsupported RSA private key: > 2 primes not supported')

other_props = ['p', 'q', 'dp', 'dq', 'qi']
props_found = [prop in obj for prop in other_props]
any_props_found = any(props_found)

if any_props_found and not all(props_found):
raise InvalidKeyError('RSA key must include all parameters if any are present besides d')

public_numbers = RSAPublicNumbers(
from_base64url_uint(obj['e']), from_base64url_uint(obj['n'])
)

if any_props_found:
numbers = RSAPrivateNumbers(
d=from_base64url_uint(obj['d']),
p=from_base64url_uint(obj['p']),
q=from_base64url_uint(obj['q']),
dmp1=from_base64url_uint(obj['dp']),
dmq1=from_base64url_uint(obj['dq']),
iqmp=from_base64url_uint(obj['qi']),
public_numbers=public_numbers
)
else:
d = from_base64url_uint(obj['d'])
p, q = rsa_recover_prime_factors(
public_numbers.n, d, public_numbers.e
)

numbers = RSAPrivateNumbers(
d=d,
p=p,
q=q,
dmp1=rsa_crt_dmp1(d, p),
dmq1=rsa_crt_dmq1(d, q),
iqmp=rsa_crt_iqmp(p, q),
public_numbers=public_numbers
)

return numbers.private_key(default_backend())
elif 'n' in obj and 'e' in obj:
# Public key
numbers = RSAPublicNumbers(
from_base64url_uint(obj['e']), from_base64url_uint(obj['n'])
)

return numbers.public_key(default_backend())
else:
raise InvalidKeyError('Not a public or private key')

def sign(self, msg, key):
signer = key.signer(
padding.PKCS1v15(),
Expand Down Expand Up @@ -213,9 +344,8 @@ def prepare_key(self, key):
isinstance(key, EllipticCurvePublicKey):
return key

if is_string_type(key):
if not isinstance(key, binary_type):
key = key.encode('utf-8')
if isinstance(key, string_types):
key = force_bytes(key)

# Attempt to load key. We don't know if it's
# a Signing Key or a Verifying Key, so we try
Expand Down
14 changes: 8 additions & 6 deletions jwt/api_jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .algorithms import Algorithm, get_default_algorithms # NOQA
from .compat import binary_type, string_types, text_type
from .exceptions import DecodeError, InvalidAlgorithmError, InvalidTokenError
from .utils import base64url_decode, base64url_encode, merge_dict
from .utils import base64url_decode, base64url_encode, force_bytes, merge_dict


class PyJWS(object):
Expand Down Expand Up @@ -82,11 +82,13 @@ def encode(self, payload, key, algorithm='HS256', headers=None,
self._validate_headers(headers)
header.update(headers)

json_header = json.dumps(
header,
separators=(',', ':'),
cls=json_encoder
).encode('utf-8')
json_header = force_bytes(
json.dumps(
header,
separators=(',', ':'),
cls=json_encoder
)
)

segments.append(base64url_encode(json_header))
segments.append(base64url_encode(payload))
Expand Down
28 changes: 23 additions & 5 deletions jwt/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
versions of python, and compatibility wrappers around optional packages.
"""
# flake8: noqa
import sys
import hmac
import struct

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: alphabetize these imports.

import sys


PY3 = sys.version_info[0] == 3
Expand All @@ -20,10 +21,6 @@
string_types = (text_type, binary_type)


def is_string_type(val):
return any([isinstance(val, typ) for typ in string_types])


def timedelta_total_seconds(delta):
try:
delta.total_seconds
Expand Down Expand Up @@ -56,3 +53,24 @@ def constant_time_compare(val1, val2):
result |= ord(x) ^ ord(y)

return result == 0

# Use int.to_bytes if it exists (Python 3)
if getattr(int, 'to_bytes', None):
def bytes_from_int(val):
remaining = val
byte_length = 0

while remaining != 0:
remaining = remaining >> 8
byte_length += 1

return val.to_bytes(byte_length, 'big', signed=False)
else:
def bytes_from_int(val):
buf = []
while val:
val, remainder = divmod(val, 256)
buf.append(remainder)

buf.reverse()
return struct.pack('%sB' % len(buf), *buf)
46 changes: 46 additions & 0 deletions jwt/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import base64
import binascii
import struct

from .compat import binary_type, bytes_from_int, text_type

try:
from cryptography.hazmat.primitives.asymmetric.utils import (
Expand All @@ -9,7 +12,28 @@
pass


def force_unicode(value):
if isinstance(value, binary_type):
return value.decode('utf-8')
elif isinstance(value, text_type):
return value
else:
raise TypeError('Expected a string value')


def force_bytes(value):
if isinstance(value, text_type):
return value.encode('utf-8')
elif isinstance(value, binary_type):
return value
else:
raise TypeError('Expected a string value')


def base64url_decode(input):
if isinstance(input, text_type):
input = input.encode('ascii')

rem = len(input) % 4

if rem > 0:
Expand All @@ -22,6 +46,28 @@ def base64url_encode(input):
return base64.urlsafe_b64encode(input).replace(b'=', b'')


def to_base64url_uint(val):
if val < 0:
raise ValueError('Must be a positive integer')

int_bytes = bytes_from_int(val)

if len(int_bytes) == 0:
int_bytes = b'\x00'

return base64url_encode(int_bytes)


def from_base64url_uint(val):
if isinstance(val, text_type):
val = val.encode('ascii')

data = base64url_decode(val)

buf = struct.unpack('%sB' % len(data), data)
return int(''.join(["%02x" % byte for byte in buf]), 16)


def merge_dict(original, updates):
if not updates:
return original
Expand Down
Loading