Skip to content

Commit

Permalink
Implement additional checks on token
Browse files Browse the repository at this point in the history
  • Loading branch information
manisandro committed Jul 9, 2024
1 parent 2182a57 commit 53fabae
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 3 deletions.
4 changes: 4 additions & 0 deletions schemas/sogis-mysoch-auth.json
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@
"tenant_header_value": {
"description": "Value of tenant header to set when redirecting on successfull mysoch-auth",
"type": "string"
},
"max_token_duration": {
"description": "Maximum allowed JWT token validity duration (exp - nbf), in seconds. Default: 60",
"type": "integer"
}
},
"required": [
Expand Down
66 changes: 63 additions & 3 deletions src/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import logging
import os
import sys
import time
from datetime import datetime
from time import time
from base64 import b64decode, urlsafe_b64encode
from flask import Flask, jsonify, request, abort, make_response, redirect
from flask_jwt_extended import create_access_token, set_access_cookies, unset_jwt_cookies
Expand Down Expand Up @@ -37,6 +38,26 @@
app.wsgi_app = TenantPrefixMiddleware(app.wsgi_app)
app.session_interface = TenantSessionInterface(os.environ)

class ExpiringSet():
def __init__(self, max_age_seconds):
self.age = max_age_seconds
self.container = {}

def contains(self, value):
if value not in self.container:
return False
if time() - self.container[value] > self.age:
del self.container[value]
return False

return True

def add(self, value):
self.container[value] = time()

prev_tokens = ExpiringSet(120)



@app.route('/login', methods=['GET'])
def login():
Expand All @@ -55,6 +76,11 @@ def login():
app.logger.info("login: No redirect URL")
abort(400, "No redirect URL")

if prev_tokens.contains(token):
app.logger.info("login: Token already used")
abort(400, "Token already used")
prev_tokens.add(token)

# Decode JWE
jwe_secret = urlsafe_b64encode(config.get("jwe_secret", "").encode()).decode()
jwt_secret = urlsafe_b64encode(config.get("jwt_secret", "").encode()).decode()
Expand All @@ -70,6 +96,21 @@ def login():
except:
abort(400, "Token decryption failed")

# Validate JWE header
app.logger.debug("JWE header: %s" % jwe.jose_header)
if jwe.jose_header.get('alg') != 'dir':
app.logger.info("login: Bad value for JWE alg")
abort(400)
if jwe.jose_header.get('enc') not in ['A128CBC-HS256', 'A256CBC-HS512']:
app.logger.info("login: Bad value for JWE enc")
abort(400)
if jwe.jose_header.get('typ') != 'JWT':
app.logger.info("login: Bad value for JWE typ")
abort(400)
if jwe.jose_header.get('cty') != 'JWT':
app.logger.info("login: Bad value for JWE cty")
abort(400)

# Verify and decode JWT
jwt = JWT()
# jwt.leeway = 100000000000000 # NOTE Testing only
Expand All @@ -79,13 +120,33 @@ def login():
app.logger.debug("Token validation failed: %s" % str(e))
abort(400, "Token validation failed")

# Validate JWT header
app.logger.debug("JWT header: %s" % jwt.header)
jwt_header = json.loads(jwt.header)
if jwt_header.get('alg') != 'HS256':
app.logger.info("login: Bad value for JWT alg")
abort(400)
if jwt_header.get('typ') != 'JWT':
app.logger.info("login: Bad value for JWT typ")
abort(400)

# Extract and validate JWT claims
claims = json.loads(jwt.claims)

app.logger.debug("Decoded claims %s" % jwt.claims)

# Verify iat and duration
if not claims.get("iat", None) or claims["iat"] > datetime.now().timestamp():
app.logger.info("login: Bad value for JWT claim iat")
abort(400)

if claims['exp'] - claims['nbf'] > config.get("max_token_duration", 60):
app.logger.info("login: Bad JWT token duration'")
abort(400)

# Verify ISS
if claims["iss"] not in config.get("allowed_iss", []):
app.logger.info("login: Bad value for iss")
app.logger.info("login: Bad value for JWT claim iss")
abort(400)

# Database
Expand All @@ -94,7 +155,6 @@ def login():
db = db_engine.db_engine(db_url)

# Verify login

userid = next(claims[userid_claim] for userid_claim in config.get("userid_claims", []) if claims.get(userid_claim))
displayname = next(claims[displayname_claim] for displayname_claim in config.get("displayname_claims", []) if claims.get(displayname_claim))
app.logger.debug("userid: %s, displayname: %s" % (userid, displayname))
Expand Down

0 comments on commit 53fabae

Please sign in to comment.