diff --git a/README.md b/README.md index d7ba30b..988a9de 100644 --- a/README.md +++ b/README.md @@ -30,14 +30,14 @@ users = { @auth.verify_password def verify_password(username, password): - if username in users: - return check_password_hash(users.get(username), password) - return False + if username in users and \ + check_password_hash(users.get(username), password): + return username @app.route('/') @auth.login_required def index(): - return "Hello, %s!" % auth.username() + return "Hello, %s!" % auth.current_user() if __name__ == '__main__': app.run() diff --git a/examples/basic_auth.py b/examples/basic_auth.py index 75aa865..692b213 100644 --- a/examples/basic_auth.py +++ b/examples/basic_auth.py @@ -23,15 +23,15 @@ @auth.verify_password def verify_password(username, password): - if username in users: - return check_password_hash(users.get(username), password) - return False + if username in users and check_password_hash(users.get(username), + password): + return username @app.route('/') @auth.login_required def index(): - return "Hello, %s!" % auth.username() + return "Hello, %s!" % auth.current_user() if __name__ == '__main__': diff --git a/examples/multi_auth.py b/examples/multi_auth.py index 0d79e47..3e47803 100644 --- a/examples/multi_auth.py +++ b/examples/multi_auth.py @@ -7,7 +7,7 @@ The root URL for this application can be accessed via basic auth, providing username and password, or via token auth, providing a bearer JWS token. """ -from flask import Flask, g +from flask import Flask from flask_httpauth import HTTPBasicAuth, HTTPTokenAuth, MultiAuth from werkzeug.security import generate_password_hash, check_password_hash from itsdangerous import TimedJSONWebSignatureSerializer as JWS @@ -34,31 +34,25 @@ @basic_auth.verify_password def verify_password(username, password): - g.user = None if username in users: if check_password_hash(users.get(username), password): - g.user = username - return True - return False + return username @token_auth.verify_token def verify_token(token): - g.user = None try: data = jws.loads(token) except: # noqa: E722 return False if 'username' in data: - g.user = data['username'] - return True - return False + return data['username'] @app.route('/') @multi_auth.login_required def index(): - return "Hello, %s!" % g.user + return "Hello, %s!" % multi_auth.current_user() if __name__ == '__main__': diff --git a/examples/token_auth.py b/examples/token_auth.py index 9b18678..e2579c8 100644 --- a/examples/token_auth.py +++ b/examples/token_auth.py @@ -12,7 +12,7 @@ The response should include the username, which is obtained from the token. """ -from flask import Flask, g +from flask import Flask from flask_httpauth import HTTPTokenAuth from itsdangerous import TimedJSONWebSignatureSerializer as Serializer @@ -32,21 +32,18 @@ @auth.verify_token def verify_token(token): - g.user = None try: data = token_serializer.loads(token) except: # noqa: E722 return False if 'username' in data: - g.user = data['username'] - return True - return False + return data['username'] @app.route('/') @auth.login_required def index(): - return "Hello, %s!" % g.user + return "Hello, %s!" % auth.current_user() if __name__ == '__main__': diff --git a/flask_httpauth.py b/flask_httpauth.py index 79143cf..fb70469 100644 --- a/flask_httpauth.py +++ b/flask_httpauth.py @@ -11,7 +11,7 @@ from functools import wraps from hashlib import md5 from random import Random, SystemRandom -from flask import request, make_response, session +from flask import request, make_response, session, g from werkzeug.datastructures import Authorization from werkzeug.security import safe_str_cmp @@ -134,11 +134,14 @@ def decorated(*args, **kwargs): password = self.get_auth_password(auth) user = self.authenticate(auth, password) - if not user or not self.authorize(role, user, auth): + if user in (False, None) or not self.authorize( + role, user, auth): # Clear TCP receive buffer of any pending data request.data return self.auth_error_callback() + g.flask_httpauth_user = user if user is not True \ + else auth.username if auth else None return f(*args, **kwargs) return decorated @@ -151,6 +154,10 @@ def username(self): return "" return request.authorization.username + def current_user(self): + if hasattr(g, 'flask_httpauth_user'): + return g.flask_httpauth_user + class HTTPBasicAuth(HTTPAuth): def __init__(self, scheme=None, realm=None): @@ -177,16 +184,16 @@ def authenticate(self, auth, stored_password): if self.verify_password_callback: return self.verify_password_callback(username, client_password) if not auth: - return False + return if self.hash_password_callback: try: client_password = self.hash_password_callback(client_password) except TypeError: client_password = self.hash_password_callback(username, client_password) - return client_password is not None and \ + return auth.username if client_password is not None and \ stored_password is not None and \ - safe_str_cmp(client_password, stored_password) + safe_str_cmp(client_password, stored_password) else None class HTTPDigestAuth(HTTPAuth): @@ -223,7 +230,7 @@ def default_generate_opaque(): def default_verify_opaque(opaque): session_opaque = session.get("auth_opaque") - if opaque is None or session_opaque is None: + if opaque is None or session_opaque is None: # pragma: no cover return False return safe_str_cmp(opaque, session_opaque) @@ -341,3 +348,7 @@ def decorated(*args, **kwargs): if f: return login_required_internal(f) return login_required_internal + + def current_user(self): + if hasattr(g, 'flask_httpauth_user'): # pragma: no cover + return g.flask_httpauth_user diff --git a/tests/test_basic_verify_password.py b/tests/test_basic_verify_password.py index 00718a7..b552ecb 100644 --- a/tests/test_basic_verify_password.py +++ b/tests/test_basic_verify_password.py @@ -5,6 +5,8 @@ class HTTPAuthTestCase(unittest.TestCase): + use_old_style_callback = False + def setUp(self): app = Flask(__name__) app.config['SECRET_KEY'] = 'my secret' @@ -13,18 +15,29 @@ def setUp(self): @basic_verify_auth.verify_password def basic_verify_auth_verify_password(username, password): - g.anon = False - if username == 'john': - return password == 'hello' - elif username == 'susan': - return password == 'bye' - elif username == '': - g.anon = True - return True - return False + if self.use_old_style_callback: + g.anon = False + if username == 'john': + return password == 'hello' + elif username == 'susan': + return password == 'bye' + elif username == '': + g.anon = True + return True + return False + else: + g.anon = False + if username == 'john' and password == 'hello': + return 'john' + elif username == 'susan' and password == 'bye': + return 'susan' + elif username == '': + g.anon = True + return '' @basic_verify_auth.error_handler def error_handler(): + self.assertIsNone(basic_verify_auth.current_user()) return 'error', 403 # use a custom error status @app.route('/') @@ -34,8 +47,12 @@ def index(): @app.route('/basic-verify') @basic_verify_auth.login_required def basic_verify_auth_route(): - return 'basic_verify_auth:' + basic_verify_auth.username() + \ - ' anon:' + str(g.anon) + if self.use_old_style_callback: + return 'basic_verify_auth:' + basic_verify_auth.username() + \ + ' anon:' + str(g.anon) + else: + return 'basic_verify_auth:' + \ + basic_verify_auth.current_user() + ' anon:' + str(g.anon) self.app = app self.basic_verify_auth = basic_verify_auth @@ -57,3 +74,7 @@ def test_verify_auth_login_invalid(self): '/basic-verify', headers={'Authorization': 'Basic ' + creds}) self.assertEqual(response.status_code, 403) self.assertTrue('WWW-Authenticate' in response.headers) + + +class HTTPAuthTestCaseOldStyle(HTTPAuthTestCase): + use_old_style_callback = True diff --git a/tests/test_multi.py b/tests/test_multi.py index d660267..f9d3253 100644 --- a/tests/test_multi.py +++ b/tests/test_multi.py @@ -44,7 +44,7 @@ def index(): @app.route('/protected') @multi_auth.login_required def auth_route(): - return 'access granted' + return 'access granted:' + str(multi_auth.current_user()) @app.route('/protected-with-role') @multi_auth.login_required(role='foo') @@ -65,7 +65,7 @@ def test_multi_auth_login_valid_basic(self): creds = base64.b64encode(b'john:hello').decode('utf-8') response = self.client.get( '/protected', headers={'Authorization': 'Basic ' + creds}) - self.assertEqual(response.data.decode('utf-8'), 'access granted') + self.assertEqual(response.data.decode('utf-8'), 'access granted:john') def test_multi_auth_login_invalid_basic(self): creds = base64.b64encode(b'john:bye').decode('utf-8') @@ -80,7 +80,7 @@ def test_multi_auth_login_valid_token(self): response = self.client.get( '/protected', headers={'Authorization': 'MyToken this-is-the-token!'}) - self.assertEqual(response.data.decode('utf-8'), 'access granted') + self.assertEqual(response.data.decode('utf-8'), 'access granted:None') def test_multi_auth_login_invalid_token(self): response = self.client.get( diff --git a/tests/test_roles.py b/tests/test_roles.py index 06bf4f7..4efa027 100644 --- a/tests/test_roles.py +++ b/tests/test_roles.py @@ -18,6 +18,8 @@ def roles_auth_verify_password(username, password): return password == 'hello' elif username == 'susan': return password == 'bye' + elif username == 'cindy': + return password == 'byebye' elif username == '': g.anon = True return True @@ -30,6 +32,8 @@ def get_user_roles(auth): return 'normal' elif username == 'susan': return ('normal', 'special') + elif username == 'cindy': + return None @roles_auth.error_handler def error_handler(): @@ -81,13 +85,20 @@ def test_verify_auth_login_valid_special(self): '/special', headers={'Authorization': 'Basic ' + creds}) self.assertEqual(response.data, b'special:susan') - def test_verify_auth_login_invalid_special(self): + def test_verify_auth_login_invalid_special_1(self): creds = base64.b64encode(b'john:hello').decode('utf-8') response = self.client.get( '/special', headers={'Authorization': 'Basic ' + creds}) self.assertEqual(response.status_code, 403) self.assertTrue('WWW-Authenticate' in response.headers) + def test_verify_auth_login_invalid_special_2(self): + creds = base64.b64encode(b'cindy:byebye').decode('utf-8') + response = self.client.get( + '/special', headers={'Authorization': 'Basic ' + creds}) + self.assertEqual(response.status_code, 403) + self.assertTrue('WWW-Authenticate' in response.headers) + def test_verify_auth_login_valid_normal_or_special_1(self): creds = base64.b64encode(b'susan:bye').decode('utf-8') response = self.client.get( diff --git a/tests/test_token.py b/tests/test_token.py index d3dab59..b9b5b05 100644 --- a/tests/test_token.py +++ b/tests/test_token.py @@ -12,7 +12,8 @@ def setUp(self): @token_auth.verify_token def verify_token(token): - return token == 'this-is-the-token!' + if token == 'this-is-the-token!': + return 'user' @token_auth.error_handler def error_handler(): @@ -25,7 +26,7 @@ def index(): @app.route('/protected') @token_auth.login_required def token_auth_route(): - return 'token_auth' + return 'token_auth:' + token_auth.current_user() self.app = app self.token_auth = token_auth @@ -47,13 +48,13 @@ def test_token_auth_login_valid(self): response = self.client.get( '/protected', headers={'Authorization': 'MyToken this-is-the-token!'}) - self.assertEqual(response.data.decode('utf-8'), 'token_auth') + self.assertEqual(response.data.decode('utf-8'), 'token_auth:user') def test_token_auth_login_valid_different_case(self): response = self.client.get( '/protected', headers={'Authorization': 'mytoken this-is-the-token!'}) - self.assertEqual(response.data.decode('utf-8'), 'token_auth') + self.assertEqual(response.data.decode('utf-8'), 'token_auth:user') def test_token_auth_login_invalid_token(self): response = self.client.get(