Skip to content

Commit

Permalink
Merge pull request #194 from Materials-Consortia/add_CORS-middleware
Browse files Browse the repository at this point in the history
Add CORSMiddleware to both server implementations.
Use "default" settings, i.e., allow origin from anywhere (`'*'`).
Allow _only_ method `GET`.
  • Loading branch information
CasperWA committed Feb 28, 2020
2 parents 7574d14 + 368e590 commit e9dc886
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 0 deletions.
2 changes: 2 additions & 0 deletions optimade/server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from fastapi import FastAPI
from fastapi.exceptions import RequestValidationError
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.cors import CORSMiddleware

from .entry_collections import MongoCollection
from .config import CONFIG
Expand Down Expand Up @@ -64,6 +65,7 @@ def load_entries(endpoint_name: str, endpoint_collection: MongoCollection):

# Add various middleware
app.add_middleware(RedirectSlashedURLs)
app.add_middleware(CORSMiddleware, allow_origins=["*"])


# Add various exception handlers
Expand Down
2 changes: 2 additions & 0 deletions optimade/server/main_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from fastapi import FastAPI
from fastapi.exceptions import RequestValidationError
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.cors import CORSMiddleware

from .config import CONFIG
from .middleware import RedirectSlashedURLs
Expand Down Expand Up @@ -58,6 +59,7 @@

# Add various middleware
app.add_middleware(RedirectSlashedURLs)
app.add_middleware(CORSMiddleware, allow_origins=["*"])


# Add various exception handlers
Expand Down
62 changes: 62 additions & 0 deletions tests/server/test_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# pylint: disable=relative-beyond-top-level
import unittest

from .utils import SetClient


class CORSMiddlewareTest(SetClient, unittest.TestCase):

server = "regular"

def test_regular_CORS_request(self):
response = self.client.get("/info", headers={"Origin": "http://example.org"})
self.assertIn(
("access-control-allow-origin", "*"),
tuple(response.headers.items()),
msg=f"Access-Control-Allow-Origin header not found in response headers: {response.headers}",
)

def test_preflight_CORS_request(self):
headers = {
"Origin": "http://example.org",
"Access-Control-Request-Method": "GET",
}
response = self.client.options("/info", headers=headers)
for response_header in (
"Access-Control-Allow-Origin",
"Access-Control-Allow-Methods",
):
self.assertIn(
response_header.lower(),
list(response.headers.keys()),
msg=f"{response_header} header not found in response headers: {response.headers}",
)


class IndexCORSMiddlewareTest(SetClient, unittest.TestCase):

server = "index"

def test_regular_CORS_request(self):
response = self.client.get("/info", headers={"Origin": "http://example.org"})
self.assertIn(
("access-control-allow-origin", "*"),
tuple(response.headers.items()),
msg=f"Access-Control-Allow-Origin header not found in response headers: {response.headers}",
)

def test_preflight_CORS_request(self):
headers = {
"Origin": "http://example.org",
"Access-Control-Request-Method": "GET",
}
response = self.client.options("/info", headers=headers)
for response_header in (
"Access-Control-Allow-Origin",
"Access-Control-Allow-Methods",
):
self.assertIn(
response_header.lower(),
list(response.headers.keys()),
msg=f"{response_header} header not found in response headers: {response.headers}",
)

0 comments on commit e9dc886

Please sign in to comment.