diff --git a/optimade/server/main.py b/optimade/server/main.py index 104272df6..24aaa6c74 100644 --- a/optimade/server/main.py +++ b/optimade/server/main.py @@ -9,6 +9,7 @@ from fastapi import FastAPI from fastapi.exceptions import RequestValidationError from starlette.exceptions import HTTPException as StarletteHTTPException +from starlette.middleware.cors import CORSMiddleware from .entry_collections import MongoCollection from .config import CONFIG @@ -64,6 +65,7 @@ def load_entries(endpoint_name: str, endpoint_collection: MongoCollection): # Add various middleware app.add_middleware(RedirectSlashedURLs) +app.add_middleware(CORSMiddleware, allow_origins=["*"]) # Add various exception handlers diff --git a/optimade/server/main_index.py b/optimade/server/main_index.py index 6b29d253d..7feda6063 100644 --- a/optimade/server/main_index.py +++ b/optimade/server/main_index.py @@ -9,6 +9,7 @@ from fastapi import FastAPI from fastapi.exceptions import RequestValidationError from starlette.exceptions import HTTPException as StarletteHTTPException +from starlette.middleware.cors import CORSMiddleware from .config import CONFIG from .middleware import RedirectSlashedURLs @@ -58,6 +59,7 @@ # Add various middleware app.add_middleware(RedirectSlashedURLs) +app.add_middleware(CORSMiddleware, allow_origins=["*"]) # Add various exception handlers diff --git a/tests/server/test_middleware.py b/tests/server/test_middleware.py new file mode 100644 index 000000000..608c2afb2 --- /dev/null +++ b/tests/server/test_middleware.py @@ -0,0 +1,62 @@ +# pylint: disable=relative-beyond-top-level +import unittest + +from .utils import SetClient + + +class CORSMiddlewareTest(SetClient, unittest.TestCase): + + server = "regular" + + def test_regular_CORS_request(self): + response = self.client.get("/info", headers={"Origin": "http://example.org"}) + self.assertIn( + ("access-control-allow-origin", "*"), + tuple(response.headers.items()), + msg=f"Access-Control-Allow-Origin header not found in response headers: {response.headers}", + ) + + def test_preflight_CORS_request(self): + headers = { + "Origin": "http://example.org", + "Access-Control-Request-Method": "GET", + } + response = self.client.options("/info", headers=headers) + for response_header in ( + "Access-Control-Allow-Origin", + "Access-Control-Allow-Methods", + ): + self.assertIn( + response_header.lower(), + list(response.headers.keys()), + msg=f"{response_header} header not found in response headers: {response.headers}", + ) + + +class IndexCORSMiddlewareTest(SetClient, unittest.TestCase): + + server = "index" + + def test_regular_CORS_request(self): + response = self.client.get("/info", headers={"Origin": "http://example.org"}) + self.assertIn( + ("access-control-allow-origin", "*"), + tuple(response.headers.items()), + msg=f"Access-Control-Allow-Origin header not found in response headers: {response.headers}", + ) + + def test_preflight_CORS_request(self): + headers = { + "Origin": "http://example.org", + "Access-Control-Request-Method": "GET", + } + response = self.client.options("/info", headers=headers) + for response_header in ( + "Access-Control-Allow-Origin", + "Access-Control-Allow-Methods", + ): + self.assertIn( + response_header.lower(), + list(response.headers.keys()), + msg=f"{response_header} header not found in response headers: {response.headers}", + )