Skip to content

Commit

Permalink
Expose leeway in clients
Browse files Browse the repository at this point in the history
Commit 3da1fdc introduced a "leeway" parameter for proactive token
refreshing. Expose this parameter in clients (e.g., requests client) to
allow configuring it by the library's users.
  • Loading branch information
michalismeng committed Apr 16, 2024
1 parent 610622e commit 64655bf
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 11 deletions.
6 changes: 3 additions & 3 deletions authlib/integrations/httpx_client/oauth2_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __init__(self, client_id=None, client_secret=None,
revocation_endpoint_auth_method=None,
scope=None, redirect_uri=None,
token=None, token_placement='header',
update_token=None, **kwargs):
update_token=None, leeway=60, **kwargs):

# extract httpx.Client kwargs
client_kwargs = self._extract_session_request_params(kwargs)
Expand All @@ -75,7 +75,7 @@ def __init__(self, client_id=None, client_secret=None,
revocation_endpoint_auth_method=revocation_endpoint_auth_method,
scope=scope, redirect_uri=redirect_uri,
token=token, token_placement=token_placement,
update_token=update_token, **kwargs
update_token=update_token, leeway=leeway, **kwargs
)

async def request(self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs):
Expand Down Expand Up @@ -106,7 +106,7 @@ async def stream(self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAUL

async def ensure_active_token(self, token):
async with self._token_refresh_lock:
if self.token.is_expired():
if self.token.is_expired(leeway=self.leeway):
refresh_token = token.get('refresh_token')
url = self.metadata.get('token_endpoint')
if refresh_token and url:
Expand Down
7 changes: 4 additions & 3 deletions authlib/integrations/requests_client/assertion_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

class AssertionAuth(OAuth2Auth):
def ensure_active_token(self):
if not self.token or self.token.is_expired() and self.client:
if self.client and (not self.token or self.token.is_expired(self.client.leeway)):
return self.client.refresh_token()


Expand All @@ -25,15 +25,16 @@ class AssertionSession(AssertionClient, Session):
DEFAULT_GRANT_TYPE = JWT_BEARER_GRANT_TYPE

def __init__(self, token_endpoint, issuer, subject, audience=None, grant_type=None,
claims=None, token_placement='header', scope=None, default_timeout=None, **kwargs):
claims=None, token_placement='header', scope=None, default_timeout=None,
leeway=60, **kwargs):
Session.__init__(self)
self.default_timeout = default_timeout
update_session_configure(self, kwargs)
AssertionClient.__init__(
self, session=self,
token_endpoint=token_endpoint, issuer=issuer, subject=subject,
audience=audience, grant_type=grant_type, claims=claims,
token_placement=token_placement, scope=scope, **kwargs
token_placement=token_placement, scope=scope, leeway=leeway, **kwargs
)

def request(self, method, url, withhold_token=False, auth=None, **kwargs):
Expand Down
7 changes: 5 additions & 2 deletions authlib/integrations/requests_client/oauth2_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ class OAuth2Session(OAuth2Client, Session):
values: "header", "body", "uri".
:param update_token: A function for you to update token. It accept a
:class:`OAuth2Token` as parameter.
:param leeway: Time window in seconds before the actual expiration of the
authentication token, that the token is considered expired and will
be refreshed.
:param default_timeout: If settled, every requests will have a default timeout.
"""
client_auth_class = OAuth2ClientAuth
Expand All @@ -79,7 +82,7 @@ def __init__(self, client_id=None, client_secret=None,
revocation_endpoint_auth_method=None,
scope=None, state=None, redirect_uri=None,
token=None, token_placement='header',
update_token=None, default_timeout=None, **kwargs):
update_token=None, leeway=60, default_timeout=None, **kwargs):
Session.__init__(self)
self.default_timeout = default_timeout
update_session_configure(self, kwargs)
Expand All @@ -91,7 +94,7 @@ def __init__(self, client_id=None, client_secret=None,
revocation_endpoint_auth_method=revocation_endpoint_auth_method,
scope=scope, state=state, redirect_uri=redirect_uri,
token=token, token_placement=token_placement,
update_token=update_token, **kwargs
update_token=update_token, leeway=leeway, **kwargs
)

def fetch_access_token(self, url=None, **kwargs):
Expand Down
10 changes: 8 additions & 2 deletions authlib/oauth2/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ class OAuth2Client:
values: "header", "body", "uri".
:param update_token: A function for you to update token. It accept a
:class:`OAuth2Token` as parameter.
:param leeway: Time window in seconds before the actual expiration of the
authentication token, that the token is considered expired and will
be refreshed.
"""
client_auth_class = ClientAuth
token_auth_class = TokenAuth
Expand All @@ -52,7 +55,8 @@ def __init__(self, session, client_id=None, client_secret=None,
token_endpoint_auth_method=None,
revocation_endpoint_auth_method=None,
scope=None, state=None, redirect_uri=None, code_challenge_method=None,
token=None, token_placement='header', update_token=None, **metadata):
token=None, token_placement='header', update_token=None, leeway=60,
**metadata):

self.session = session
self.client_id = client_id
Expand Down Expand Up @@ -97,6 +101,8 @@ def __init__(self, session, client_id=None, client_secret=None,
}
self._auth_methods = {}

self.leeway = leeway

def register_client_auth_method(self, auth):
"""Extend client authenticate for token endpoint.
Expand Down Expand Up @@ -263,7 +269,7 @@ def refresh_token(self, url=None, refresh_token=None, body='',
def ensure_active_token(self, token=None):
if token is None:
token = self.token
if not token.is_expired():
if not token.is_expired(leeway=self.leeway):
return True
refresh_token = token.get('refresh_token')
url = self.metadata.get('token_endpoint')
Expand Down
3 changes: 2 additions & 1 deletion authlib/oauth2/rfc7521/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class AssertionClient:

def __init__(self, session, token_endpoint, issuer, subject,
audience=None, grant_type=None, claims=None,
token_placement='header', scope=None, **kwargs):
token_placement='header', scope=None, leeway=60, **kwargs):

self.session = session

Expand All @@ -38,6 +38,7 @@ def __init__(self, session, token_endpoint, issuer, subject,
if self.token_auth_class is not None:
self.token_auth = self.token_auth_class(None, token_placement, self)
self._kwargs = kwargs
self.leeway = leeway

@property
def token(self):
Expand Down
4 changes: 4 additions & 0 deletions docs/client/oauth2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,10 @@ it has expired::
>>> openid_configuration = requests.get("https://example.org/.well-known/openid-configuration").json()
>>> session = OAuth2Session(…, token_endpoint=openid_configuration["token_endpoint"])

By default, the token will be refreshed 60 seconds before its actual expiry time, to avoid clock skew issues.
You can control this behaviour by setting the ``leeway`` parameter of the :class:`~requests_client.OAuth2Session`
class.

Manually refreshing tokens
~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
12 changes: 12 additions & 0 deletions tests/clients/test_requests/test_oauth2_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,18 @@ def test_token_status(self):

self.assertTrue(sess.token.is_expired)

def test_token_status2(self):
token = dict(access_token='a', token_type='bearer', expires_in=10)
sess = OAuth2Session('foo', token=token, leeway=15)

self.assertTrue(sess.token.is_expired(sess.leeway))

def test_token_status3(self):
token = dict(access_token='a', token_type='bearer', expires_in=10)
sess = OAuth2Session('foo', token=token, leeway=5)

self.assertFalse(sess.token.is_expired(sess.leeway))

def test_token_expired(self):
token = dict(access_token='a', token_type='bearer', expires_at=100)
sess = OAuth2Session('foo', token=token)
Expand Down

0 comments on commit 64655bf

Please sign in to comment.