Skip to content

Commit

Permalink
Make auth slightly more robust
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisburr committed Sep 26, 2023
1 parent 5ae253e commit eb74725
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 26 deletions.
58 changes: 38 additions & 20 deletions src/diracx/client/_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from datetime import datetime
import json
import requests
import logging

from pathlib import Path
from typing import Any, Dict, List, Optional, cast
Expand Down Expand Up @@ -38,6 +39,9 @@ def patch_sdk():
"""


logger = logging.getLogger(__name__)


class DiracTokenCredential(TokenCredential):
"""Tailor get_token() for our context"""

Expand Down Expand Up @@ -98,12 +102,21 @@ def on_request(
return

if not self._token:
credentials = json.loads(self._credential.location.read_text())
self._token = self._credential.get_token(
"", refresh_token=credentials["refresh_token"]
)

request.http_request.headers["Authorization"] = f"Bearer {self._token.token}"
try:
credentials = json.loads(self._credential.location.read_text())
except Exception:
logger.warning(
"Cannot load credentials from %s", self._credential.location
)
else:
self._token = self._credential.get_token(
"", refresh_token=credentials["refresh_token"]
)

if self._token:
request.http_request.headers[
"Authorization"
] = f"Bearer {self._token.token}"


class DiracClient(DiracGenerated):
Expand Down Expand Up @@ -160,6 +173,7 @@ def refresh_token(
)

if response.status_code != 200:
location.unlink()
raise RuntimeError(
f"An issue occured while refreshing your access token: {response.json()['detail']}"
)
Expand Down Expand Up @@ -192,24 +206,28 @@ def get_token(location: Path, token: AccessToken | None) -> AccessToken | None:
raise RuntimeError("credentials are not set")

# Load the existing credentials
if not token:
credentials = json.loads(location.read_text())
token = AccessToken(
cast(str, credentials.get("access_token")),
cast(int, credentials.get("expires_on")),
)

# We check the validity of the token
# If not valid, then return None to inform the caller that a new token
# is needed
if not is_token_valid(token):
return None

return token
try:
if not token:
credentials = json.loads(location.read_text())
token = AccessToken(
cast(str, credentials.get("access_token")),
cast(int, credentials.get("expires_on")),
)
except Exception:
logger.warning("Cannot load credentials from %s", location)
pass
else:
# We check the validity of the token
# If not valid, then return None to inform the caller that a new token
# is needed
if is_token_valid(token):
return token
return None


def is_token_valid(token: AccessToken) -> bool:
"""Condition to get a new token"""
# TODO: Should we check against the userinfo endpoint?
return (
datetime.utcfromtimestamp(token.expires_on) - datetime.utcnow()
).total_seconds() > 300
25 changes: 19 additions & 6 deletions src/diracx/client/aio/_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize
"""
import json
import logging
from types import TracebackType
from pathlib import Path
from typing import Any, List, Optional
Expand All @@ -24,6 +25,8 @@
"DiracClient",
] # Add all objects you want publicly available to users at this package level

logger = logging.getLogger(__name__)


def patch_sdk():
"""Do not remove from this file.
Expand Down Expand Up @@ -104,19 +107,29 @@ async def on_request(
credentials: dict[str, Any]

try:
# TODO: Use httpx and await this call
self._token = get_token(self._credential.location, self._token)
except RuntimeError:
# If we are here, it means the credentials path does not exist
# we suppose it is not needed to perform the request
return

if not self._token:
credentials = json.loads(self._credential.location.read_text())
self._token = await self._credential.get_token(
"", refresh_token=credentials["refresh_token"]
)

request.http_request.headers["Authorization"] = f"Bearer {self._token.token}"
try:
credentials = json.loads(self._credential.location.read_text())
except Exception:
logger.warning(
"Cannot load credentials from %s", self._credential.location
)
else:
self._token = await self._credential.get_token(
"", refresh_token=credentials["refresh_token"]
)

if self._token:
request.http_request.headers[
"Authorization"
] = f"Bearer {self._token.token}"


class DiracClient(DiracGenerated):
Expand Down

0 comments on commit eb74725

Please sign in to comment.