diff --git a/backend/common/authorizer.py b/backend/common/authorizer.py index 970b8bb324637..74eaa2c7e993a 100644 --- a/backend/common/authorizer.py +++ b/backend/common/authorizer.py @@ -2,14 +2,22 @@ from functools import lru_cache import requests +from requests.adapters import HTTPAdapter +from urllib3 import Retry from backend.common.corpora_config import CorporaAuthConfig from backend.common.utils.http_exceptions import UnauthorizedError from backend.common.utils.jwt import jwt_decode, get_unverified_header -auth0_session_with_retry = requests.Session() -# TODO: these read the configuration on initialization and cause problems with tests -# retry_config = Retry(total=3, backoff_factor=1, status_forcelist=CorporaAuthConfig().retry_status_forcelist) -# auth0_session_with_retry.mount("https://", HTTPAdapter(max_retries=retry_config)) +_auth0_session_with_retry = None + + +def get_auth0_session_with_retry(): + global _auth0_session_with_retry + if _auth0_session_with_retry is None: + _auth0_session_with_retry = requests.Session() + retry_config = Retry(total=3, backoff_factor=1, status_forcelist=CorporaAuthConfig().retry_status_forcelist) + _auth0_session_with_retry.mount("https://", HTTPAdapter(max_retries=retry_config)) + return _auth0_session_with_retry def assert_authorized_token(token: str, audience: str = None) -> dict: @@ -46,7 +54,7 @@ def assert_authorized_token(token: str, audience: str = None) -> dict: def get_userinfo_from_auth0(token: str) -> dict: auth_config = CorporaAuthConfig() - res = auth0_session_with_retry.get(auth_config.api_userinfo_url, headers={"Authorization": f"Bearer {token}"}) + res = get_auth0_session_with_retry().get(auth_config.api_userinfo_url, headers={"Authorization": f"Bearer {token}"}) res.raise_for_status() return res.json() @@ -57,7 +65,7 @@ def get_openid_config(openid_provider: str): :param openid_provider: the openid provider's domain. :return: the openid configuration """ - res = auth0_session_with_retry.get("{op}/.well-known/openid-configuration".format(op=openid_provider)) + res = get_auth0_session_with_retry().get("{op}/.well-known/openid-configuration".format(op=openid_provider)) res.raise_for_status() return res.json() @@ -69,5 +77,5 @@ def get_public_keys(openid_provider: str): :param openid_provider: the openid provider's domain. :return: Public Keys """ - keys = auth0_session_with_retry.get(get_openid_config(openid_provider)["jwks_uri"]).json()["keys"] + keys = get_auth0_session_with_retry().get(get_openid_config(openid_provider)["jwks_uri"]).json()["keys"] return {key["kid"]: key for key in keys} diff --git a/backend/layers/business/business.py b/backend/layers/business/business.py index 30c31861a19c6..35f4a20c40872 100644 --- a/backend/layers/business/business.py +++ b/backend/layers/business/business.py @@ -3,7 +3,6 @@ import logging from typing import Iterable, Optional, Tuple -from backend.common.providers.crossref_provider import CrossrefDOINotFoundException, CrossrefException from backend.layers.business.business_interface import BusinessLogicInterface from backend.layers.business.entities import ( CollectionMetadataUpdate, @@ -50,7 +49,11 @@ Link, ) from backend.layers.persistence.persistence_interface import DatabaseProviderInterface -from backend.layers.thirdparty.crossref_provider import CrossrefProviderInterface +from backend.layers.thirdparty.crossref_provider import ( + CrossrefDOINotFoundException, + CrossrefException, + CrossrefProviderInterface, +) from backend.layers.thirdparty.s3_provider import S3ProviderInterface from backend.layers.thirdparty.step_function_provider import StepFunctionProviderInterface from backend.layers.thirdparty.uri_provider import UriProviderInterface diff --git a/tests/unit/backend/layers/api/test_portal_api.py b/tests/unit/backend/layers/api/test_portal_api.py index 5df85b2523b7d..cef7a155c0ae8 100644 --- a/tests/unit/backend/layers/api/test_portal_api.py +++ b/tests/unit/backend/layers/api/test_portal_api.py @@ -24,7 +24,7 @@ from furl import furl -from backend.common.providers.crossref_provider import CrossrefDOINotFoundException, CrossrefFetchException +from backend.layers.thirdparty.crossref_provider import CrossrefDOINotFoundException, CrossrefFetchException from backend.portal.api.collections_common import verify_collection_body from tests.unit.backend.layers.common.base_test import ( DatasetArtifactUpdate, diff --git a/tests/unit/backend/layers/business/test_business.py b/tests/unit/backend/layers/business/test_business.py index 5542efe1c9e33..6e4c710778483 100644 --- a/tests/unit/backend/layers/business/test_business.py +++ b/tests/unit/backend/layers/business/test_business.py @@ -4,7 +4,11 @@ from unittest.mock import Mock from uuid import uuid4 -from backend.common.providers.crossref_provider import CrossrefDOINotFoundException, CrossrefException +from backend.layers.thirdparty.crossref_provider import ( + CrossrefDOINotFoundException, + CrossrefException, + CrossrefProviderInterface, +) from backend.layers.business.business import ( BusinessLogic, CollectionMetadataUpdate, @@ -36,7 +40,6 @@ ) from backend.layers.persistence.persistence import DatabaseProvider from backend.layers.persistence.persistence_mock import DatabaseProviderMock -from backend.layers.thirdparty.crossref_provider import CrossrefProviderInterface from backend.layers.thirdparty.s3_provider import S3ProviderInterface from backend.layers.thirdparty.step_function_provider import StepFunctionProviderInterface from backend.layers.thirdparty.uri_provider import FileInfo, UriProviderInterface