Skip to content

Commit

Permalink
add new method of determining token expiration instaed of relying on …
Browse files Browse the repository at this point in the history
…jwt.decode error code
  • Loading branch information
Steven Timm committed May 14, 2024
1 parent aa975b8 commit 60e336d
Showing 1 changed file with 49 additions and 25 deletions.
74 changes: 49 additions & 25 deletions src/decisionengine_modules/NERSC/sources/NerscSFApi.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"""
import json

# import time
import time
import os

import jwt
Expand Down Expand Up @@ -36,7 +36,7 @@ def __init__(self, config):
self.logger = self.logger.bind(
class_module=__name__.split(".")[-1],
)
self.localmap = {"uscms": "m2612", "fife": "m3249"}
self.localmap = {"uscms": "m2612", "fife": "m4599", "dunepro": "m3249" }
self.keys_list = ["hours_given", "hours_used", "id", "project_hours_given", "project_hours_used", "repo_name"]

def check_accesstoken(self, nersc_user):
Expand All @@ -62,43 +62,51 @@ def check_accesstoken(self, nersc_user):
except KeyError:
self.logger.error(f"Unknown user '{nersc_user}', exiting")
return None

print(nersc_user)
rawfile = params_dict["rawfile"]
pemfile = params_dict["private_key"]
clientidfile = params_dict["client_id_file"]
with open(clientidfile) as cifile:
client_id = cifile.read()
client_id = client_id.rstrip()

atoken = None
if not os.path.exists(rawfile):
self.logger.debug(f"{rawfile} does not exist. Need to generate")
else:
atoken = None
with open(rawfile) as afile:
atoken = afile.read()
atoken = atoken.rstrip()
# HK> If the access token is expired, the flow goes directly to except jwt.ExpiredSignatureError
try:
jwt.decode(atoken, options={"verify_signature": False})
self.logger.debug("Access Token not expired. Returning without generating a new access token")
return atoken # This means the existing access token is not expired.

except jwt.ExpiredSignatureError:
self.logger.debug("Access Token expired")

certs = pem.parse_file(pemfile)
private_key = str(certs[0])
client = OAuth2Session(
client_id=client_id, client_secret=private_key, token_endpoint_auth_method="private_key_jwt"
)
client.register_client_auth_method(PrivateKeyJWT(token_url))
resp = client.fetch_token(token_url, grant_type="client_credentials")
# HK> If the access token is expired, the flow goes directly to except jwt.ExpiredSign

if atoken is not None:
rvalue = jwt.decode(atoken, options={"verify_signature": False})
ctime = int(time.time())
diff = ctime - rvalue['exp']
print( diff )
else:
self.logger.debug("there is no access token file, setting diff high to indicate expired")
diff=10000000

newtoken = resp["access_token"]
if diff < 0:
self.logger.debug("Access Token not expired. Returning without generating a new access token")
return atoken # This means the existing access token is not expired.

with open(rawfile, "w") as myfile:
myfile.write(newtoken)
return newtoken
else:
self.logger.debug("Access Token expired")

certs = pem.parse_file(pemfile)
private_key = str(certs[0])
client = OAuth2Session(
client_id=client_id, client_secret=private_key, token_endpoint_auth_method="private_key_jwt"
)
client.register_client_auth_method(PrivateKeyJWT(token_url))
resp = client.fetch_token(token_url, grant_type="client_credentials")
newtoken = resp["access_token"]

with open(rawfile, "w") as myfile:
myfile.write(newtoken)
return newtoken

def get_headers2(self, access_token):
headers = {}
Expand All @@ -118,16 +126,32 @@ def requests_nersc(self, username):

def send_query(self):
results = []
print(self.constraints.get("usernames", []))
for username in self.constraints.get("usernames", []):
self.logger.debug("in send_query %s",username)
print(username)
returned_list = self.requests_nersc(username)
self.logger.debug(returned_list)
print(returned_list)
for each_dict in returned_list:
# HK> This if condition will choose only m3249 for fife and discard m3990
if each_dict["repo_name"] == self.localmap[username]:
local_dict = {each_key: each_dict[each_key] for each_key in self.keys_list}
local_dict["real_name"] = username
local_dict['real_name'] = username
results.append(local_dict)
return results

#self.localmap = {"uscms": "m2612", "fife": "m3249"}
#self.keys_list = [
# "hours_given", "hours_used", "id", "project_hours_given", "project_hours_used", "repo_name" ]

#+----+---------------+--------------+-------+-----------------------+----------------------+-------------+
#| | hours_given | hours_used | id | project_hours_given | project_hours_used | repo_name |
#|----+---------------+--------------+-------+-----------------------+----------------------+-------------|
#| 0 | 600000 | 473490 | 54807 | 600000 | 473946 | m2612 |
#| 1 | 19109.1 | 0 | 63322 | 95545.7 | 24722.7 | m3249 |
#+----+---------------+--------------+-------+-----------------------+----------------------+-------------+

def acquire(self):
self.logger.debug("in NerscSFApi acquire")
return {"Nersc_Allocation_SFAPI": pd.DataFrame(self.send_query())}
Expand Down

0 comments on commit 60e336d

Please sign in to comment.