Skip to content

Commit

Permalink
Remove unnecessary compatibility shims for Python 2 (#498)
Browse files Browse the repository at this point in the history
As the project is Python 3 only, can remove the compatibility shims in
compat.py.

Type checking has been simplified where it can:
  - str is iterable
  - bytes is iterable
  - use isinstance instead of issubclass

The remaining function bytes_from_int() has been moved to utils.py.
  • Loading branch information
jdufresne committed Jun 19, 2020
1 parent 07210ee commit dc8dc7d
Show file tree
Hide file tree
Showing 12 changed files with 46 additions and 83 deletions.
2 changes: 1 addition & 1 deletion docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ API Reference
Use ``verify_exp`` instead


:param str|iterable audience: optional, the value for ``verify_aud`` check
:param iterable audience: optional, the value for ``verify_aud`` check
:param str issuer: optional, the value for ``verify_iss`` check
:param int|float leeway: a time margin in seconds for the expiration check
:param bool verify_expiration:
Expand Down
9 changes: 7 additions & 2 deletions jwt/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,13 @@ def decode_payload(args):
raise OSError("Cannot read from stdin: terminal not a TTY")

token = token.encode("utf-8")
data = decode(token, key=args.key, verify=args.verify,
audience=args.audience, issuer=args.issuer)
data = decode(
token,
key=args.key,
verify=args.verify,
audience=args.audience,
issuer=args.issuer,
)

return json.dumps(data)

Expand Down
7 changes: 3 additions & 4 deletions jwt/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import hmac
import json

from .compat import constant_time_compare, string_types
from .exceptions import InvalidKeyError
from .utils import (
base64url_decode,
Expand Down Expand Up @@ -216,7 +215,7 @@ def sign(self, msg, key):
return hmac.new(key, msg, self.hash_alg).digest()

def verify(self, msg, key, sig):
return constant_time_compare(sig, self.sign(msg, key))
return hmac.compare_digest(sig, self.sign(msg, key))


if has_crypto: # noqa: C901
Expand All @@ -238,7 +237,7 @@ def prepare_key(self, key):
if isinstance(key, RSAPrivateKey) or isinstance(key, RSAPublicKey):
return key

if isinstance(key, string_types):
if isinstance(key, (bytes, str)):
key = force_bytes(key)

try:
Expand Down Expand Up @@ -395,7 +394,7 @@ def prepare_key(self, key):
):
return key

if isinstance(key, string_types):
if isinstance(key, (bytes, str)):
key = force_bytes(key)

# Attempt to load key. We don't know if it's
Expand Down
10 changes: 5 additions & 5 deletions jwt/api_jws.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import binascii
import json
import warnings
from collections.abc import Mapping

from .algorithms import requires_cryptography # NOQA
from .algorithms import Algorithm, get_default_algorithms, has_crypto
from .compat import Mapping, binary_type, string_types, text_type
from .exceptions import (
DecodeError,
InvalidAlgorithmError,
Expand Down Expand Up @@ -178,12 +178,12 @@ def get_unverified_header(self, jwt):
return headers

def _load(self, jwt):
if isinstance(jwt, text_type):
if isinstance(jwt, str):
jwt = jwt.encode("utf-8")

if not issubclass(type(jwt), binary_type):
if not isinstance(jwt, bytes):
raise DecodeError(
"Invalid token type. Token must be a {}".format(binary_type)
"Invalid token type. Token must be a {}".format(bytes)
)

try:
Expand Down Expand Up @@ -249,7 +249,7 @@ def _validate_headers(self, headers):
self._validate_kid(headers["kid"])

def _validate_kid(self, kid):
if not isinstance(kid, string_types):
if not isinstance(kid, (bytes, str)):
raise InvalidTokenError("Key ID header parameter must be a string")


Expand Down
12 changes: 6 additions & 6 deletions jwt/api_jwt.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import json
import warnings
from calendar import timegm
from collections.abc import Iterable, Mapping
from datetime import datetime, timedelta

from .algorithms import Algorithm, get_default_algorithms # NOQA
from .api_jws import PyJWS
from .compat import Iterable, Mapping, string_types
from .exceptions import (
DecodeError,
ExpiredSignatureError,
Expand Down Expand Up @@ -149,8 +149,8 @@ def _validate_claims(
if isinstance(leeway, timedelta):
leeway = leeway.total_seconds()

if not isinstance(audience, (string_types, type(None), Iterable)):
raise TypeError("audience must be a string, iterable, or None")
if not isinstance(audience, (type(None), Iterable)):
raise TypeError("audience must be an iterable or None")

self._validate_required_claims(payload, options)

Expand Down Expand Up @@ -220,14 +220,14 @@ def _validate_aud(self, payload, audience):

audience_claims = payload["aud"]

if isinstance(audience_claims, string_types):
if isinstance(audience_claims, (bytes, str)):
audience_claims = [audience_claims]
if not isinstance(audience_claims, list):
raise InvalidAudienceError("Invalid claim format in token")
if any(not isinstance(c, string_types) for c in audience_claims):
if any(not isinstance(c, (bytes, str)) for c in audience_claims):
raise InvalidAudienceError("Invalid claim format in token")

if isinstance(audience, string_types):
if isinstance(audience, (bytes, str)):
audience = [audience]

if not any(aud in audience_claims for aud in audience):
Expand Down
30 changes: 0 additions & 30 deletions jwt/compat.py

This file was deleted.

5 changes: 2 additions & 3 deletions jwt/contrib/algorithms/py_ecdsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import ecdsa

from jwt.algorithms import Algorithm
from jwt.compat import string_types, text_type


class ECAlgorithm(Algorithm):
Expand All @@ -33,8 +32,8 @@ def prepare_key(self, key):
):
return key

if isinstance(key, string_types):
if isinstance(key, text_type):
if isinstance(key, (bytes, str)):
if isinstance(key, str):
key = key.encode("utf-8")

# Attempt to load key. We don't know if it's
Expand Down
5 changes: 2 additions & 3 deletions jwt/contrib/algorithms/py_ed25519.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
)

from jwt.algorithms import Algorithm
from jwt.compat import string_types, text_type


class Ed25519Algorithm(Algorithm):
Expand All @@ -33,8 +32,8 @@ def prepare_key(self, key):
if isinstance(key, (Ed25519PrivateKey, Ed25519PublicKey)):
return key

if isinstance(key, string_types):
if isinstance(key, text_type):
if isinstance(key, (bytes, str)):
if isinstance(key, str):
key = key.encode("utf-8")
str_key = key.decode("utf-8")

Expand Down
5 changes: 2 additions & 3 deletions jwt/contrib/algorithms/pycrypto.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from Crypto.Signature import PKCS1_v1_5

from jwt.algorithms import Algorithm
from jwt.compat import string_types, text_type


class RSAAlgorithm(Algorithm):
Expand All @@ -30,8 +29,8 @@ def prepare_key(self, key):
if isinstance(key, RSA._RSAobj):
return key

if isinstance(key, string_types):
if isinstance(key, text_type):
if isinstance(key, (bytes, str)):
if isinstance(key, str):
key = key.encode("utf-8")

key = RSA.importKey(key)
Expand Down
25 changes: 17 additions & 8 deletions jwt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
import binascii
import struct

from .compat import binary_type, bytes_from_int, text_type

try:
from cryptography.hazmat.primitives.asymmetric.utils import (
decode_dss_signature,
Expand All @@ -14,25 +12,25 @@


def force_unicode(value):
if isinstance(value, binary_type):
if isinstance(value, bytes):
return value.decode("utf-8")
elif isinstance(value, text_type):
elif isinstance(value, str):
return value
else:
raise TypeError("Expected a string value")


def force_bytes(value):
if isinstance(value, text_type):
if isinstance(value, str):
return value.encode("utf-8")
elif isinstance(value, binary_type):
elif isinstance(value, bytes):
return value
else:
raise TypeError("Expected a string value")


def base64url_decode(input):
if isinstance(input, text_type):
if isinstance(input, str):
input = input.encode("ascii")

rem = len(input) % 4
Expand Down Expand Up @@ -60,7 +58,7 @@ def to_base64url_uint(val):


def from_base64url_uint(val):
if isinstance(val, text_type):
if isinstance(val, str):
val = val.encode("ascii")

data = base64url_decode(val)
Expand Down Expand Up @@ -92,6 +90,17 @@ def bytes_to_number(string):
return int(binascii.b2a_hex(string), 16)


def bytes_from_int(val):
remaining = val
byte_length = 0

while remaining != 0:
remaining = remaining >> 8
byte_length += 1

return val.to_bytes(byte_length, "big", signed=False)


def der_to_raw_signature(der_sig, curve):
num_bits = curve.key_size
num_bytes = (num_bits + 7) // 8
Expand Down
2 changes: 1 addition & 1 deletion tests/test_api_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def test_decode_with_invalid_audience_param_throws_exception(self, jwt):
jwt.decode(example_jwt, secret, audience=1)

exception = context.value
assert str(exception) == "audience must be a string, iterable, or None"
assert str(exception) == "audience must be an iterable or None"

def test_decode_with_nonlist_aud_claim_throws_exception(self, jwt):
secret = "secret"
Expand Down
17 changes: 0 additions & 17 deletions tests/test_compat.py

This file was deleted.

0 comments on commit dc8dc7d

Please sign in to comment.