From c67487952feef6ba8dc8fd01790a0d369d1f2981 Mon Sep 17 00:00:00 2001 From: Casper Welzel Andersen <43357585+CasperWA@users.noreply.github.com> Date: Mon, 4 May 2020 21:03:01 +0200 Subject: [PATCH] Redirect special OpenAPI endpoints for custom base URLs (#72) Redirect all non-major version prefix URLs to major-version prefix URLs for the special OpenAPI docs endpoints: - docs_url: /extensions/docs - redoc_url: /extensions/redoc - openapi_url: /extensions/openapi.json --- aiida_optimade/main.py | 11 +++++++---- aiida_optimade/middleware.py | 36 ++++++++++++++++++++++++++++++++++++ aiida_optimade/utils.py | 23 +++++++++++++++++++++++ tests/test_server.py | 2 +- 4 files changed, 67 insertions(+), 5 deletions(-) create mode 100644 aiida_optimade/middleware.py diff --git a/aiida_optimade/main.py b/aiida_optimade/main.py index f9a847eb..7e8e9aa1 100644 --- a/aiida_optimade/main.py +++ b/aiida_optimade/main.py @@ -16,22 +16,24 @@ from optimade.server.middleware import EnsureQueryParamIntegrity from optimade.server.routers.utils import BASE_URL_PREFIXES +from aiida_optimade.middleware import RedirectOpenApiDocs from aiida_optimade.routers import ( info, structures, ) +from aiida_optimade.utils import get_custom_base_url_path, OPEN_API_ENDPOINTS if CONFIG.debug: # pragma: no cover print("DEBUG MODE") - # Load AiiDA profile PROFILE_NAME = os.getenv("AIIDA_PROFILE") load_profile(PROFILE_NAME) if CONFIG.debug: # pragma: no cover print(f"AiiDA Profile: {PROFILE_NAME}") +DOCS_ENDPOINT_PREFIX = f"{get_custom_base_url_path()}{BASE_URL_PREFIXES['major']}" APP = FastAPI( title="OPTIMADE API for AiiDA", description=( @@ -43,15 +45,16 @@ "reproducible." ), version=__api_version__, - docs_url=f"{BASE_URL_PREFIXES['major']}/extensions/docs", - redoc_url=f"{BASE_URL_PREFIXES['major']}/extensions/redoc", - openapi_url=f"{BASE_URL_PREFIXES['major']}/extensions/openapi.json", + docs_url=f"{DOCS_ENDPOINT_PREFIX}{OPEN_API_ENDPOINTS['docs']}", + redoc_url=f"{DOCS_ENDPOINT_PREFIX}{OPEN_API_ENDPOINTS['redoc']}", + openapi_url=f"{DOCS_ENDPOINT_PREFIX}{OPEN_API_ENDPOINTS['openapi']}", ) # Add various middleware APP.add_middleware(CORSMiddleware, allow_origins=["*"]) APP.add_middleware(EnsureQueryParamIntegrity) +APP.add_middleware(RedirectOpenApiDocs) # Add various exception handlers diff --git a/aiida_optimade/middleware.py b/aiida_optimade/middleware.py new file mode 100644 index 00000000..125f77b9 --- /dev/null +++ b/aiida_optimade/middleware.py @@ -0,0 +1,36 @@ +import urllib.parse + +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import RedirectResponse + +from optimade.server.routers.utils import BASE_URL_PREFIXES + +from aiida_optimade.utils import OPEN_API_ENDPOINTS + + +class RedirectOpenApiDocs(BaseHTTPMiddleware): + """Redirect URLs from non-major version prefix URLs to major-version prefix URLs + + This is relevant for the OpenAPI JSON, Docs, and ReDocs URLs. + """ + + async def dispatch(self, request: Request, call_next): + parsed_url = urllib.parse.urlsplit(str(request.url)) + for endpoint in OPEN_API_ENDPOINTS.values(): + # Important to start with the longest (or full) URL prefix first. + for version_prefix in [ + BASE_URL_PREFIXES["patch"], + BASE_URL_PREFIXES["minor"], + ]: + if parsed_url.path.endswith(f"{version_prefix}{endpoint}"): + new_path = parsed_url.path.replace( + f"{version_prefix}", f"{BASE_URL_PREFIXES['major']}" + ) + redirect_url = ( + f"{parsed_url.scheme}://{parsed_url.netloc}{new_path}" + f"?{parsed_url.query}" + ) + return RedirectResponse(redirect_url) + response = await call_next(request) + return response diff --git a/aiida_optimade/utils.py b/aiida_optimade/utils.py index 154435bc..72ea504e 100644 --- a/aiida_optimade/utils.py +++ b/aiida_optimade/utils.py @@ -1,4 +1,14 @@ from typing import Tuple +import urllib.parse + +from optimade.server.config import CONFIG + + +OPEN_API_ENDPOINTS = { + "docs": "/extensions/docs", + "redoc": "/extensions/redoc", + "openapi": "/extensions/openapi.json", +} def retrieve_queryable_properties( @@ -30,3 +40,16 @@ def retrieve_queryable_properties( properties[name][extra_key] = value[extra_key] return properties, all_properties + + +def get_custom_base_url_path(): + """Return path part of custom base URL""" + if CONFIG.base_url is not None: + res = urllib.parse.urlparse(CONFIG.base_url).path + else: + res = urllib.parse.urlparse(CONFIG.base_url).path.decode() + + if res.endswith("/"): + res = res[:-1] + + return res diff --git a/tests/test_server.py b/tests/test_server.py index 2209d805..7511f4ea 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -14,7 +14,7 @@ # Use specific AiiDA profile if os.getenv("AIIDA_PROFILE", None) is None: - os.environ["AIIDA_PROFILE"] = "optimade_v1_aiida_sqla" + os.environ["AIIDA_PROFILE"] = "optimade_sqla" from optimade.models import ( ResponseMeta,