Skip to content

Commit

Permalink
feat: support dynamic response content types
Browse files Browse the repository at this point in the history
Switch these off of the accept header within a request. This allows
users to dynamically render a response based on the clients passed
value, with a fallback to whatever media type is specified by default.
The list is implicitly ordered by priority. I don't love that it's a
nested loop, however, I think that it should be fine, for relatively
few content types.
  • Loading branch information
bradykieffer committed Sep 17, 2024
1 parent 32d5df4 commit 64ce1c4
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 29 deletions.
10 changes: 6 additions & 4 deletions ninja/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,16 +457,18 @@ def create_response(
response.content = content
else:
response = HttpResponse(
content, status=status, content_type=self.get_content_type()
content, status=status, content_type=self.get_content_type(request)
)

return response

def create_temporal_response(self, request: HttpRequest) -> HttpResponse:
return HttpResponse("", content_type=self.get_content_type())
return HttpResponse("", content_type=self.get_content_type(request))

def get_content_type(self) -> str:
return f"{self.renderer.media_type}; charset={self.renderer.charset}"
def get_content_type(self, request: HttpRequest) -> str:
return (
f"{self.renderer.get_media_type(request)}; charset={self.renderer.charset}"
)

def get_openapi_schema(
self,
Expand Down
25 changes: 23 additions & 2 deletions ninja/renderers.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,42 @@
import itertools
import json
from typing import Any, Mapping, Optional, Type
from typing import Any, List, Mapping, Type

from django.http import HttpRequest
from django.http.request import parse_accept_header

from ninja.responses import NinjaJSONEncoder

__all__ = ["BaseRenderer", "JSONRenderer"]


class BaseRenderer:
media_type: Optional[str] = None
media_type: str
charset: str = "utf-8"

def get_media_type(self, request: HttpRequest) -> str:
return self.media_type

def render(self, request: HttpRequest, data: Any, *, response_status: int) -> Any:
raise NotImplementedError("Please implement .render() method")


class BaseDynamicRenderer(BaseRenderer):
media_types: List[str]

def get_media_type(self, request: HttpRequest) -> str:
accepted_media_types = parse_accept_header(request.headers.get("accept", "*/*"))
media_type_gen = (
media_type
for media_type, accepted_type in itertools.product(
self.media_types, accepted_media_types
)
if accepted_type.match(media_type)
)

return next(media_type_gen, self.media_type)


class JSONRenderer(BaseRenderer):
media_type = "application/json"
encoder_class: Type[json.JSONEncoder] = NinjaJSONEncoder
Expand Down
135 changes: 112 additions & 23 deletions tests/test_renderer.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,38 @@
import json
from io import StringIO
from unittest.mock import Mock

import pytest
from django.utils.encoding import force_str
from django.utils.xmlutils import SimplerXMLGenerator

from ninja import NinjaAPI
from ninja.renderers import BaseRenderer
from ninja.renderers import BaseDynamicRenderer, BaseRenderer
from ninja.responses import NinjaJSONEncoder
from ninja.testing import TestClient


def _to_xml(xml, data):
if isinstance(data, (list, tuple)):
for item in data:
xml.startElement("item", {})
_to_xml(xml, item)
xml.endElement("item")

elif isinstance(data, dict):
for key, value in data.items():
xml.startElement(key, {})
_to_xml(xml, value)
xml.endElement(key)

elif data is None:
# Don't output any value
pass

else:
xml.characters(force_str(data))


class XMLRenderer(BaseRenderer):
media_type = "text/xml"

Expand All @@ -17,41 +41,55 @@ def render(self, request, data, *, response_status):
xml = SimplerXMLGenerator(stream, "utf-8")
xml.startDocument()
xml.startElement("data", {})
self._to_xml(xml, data)
_to_xml(xml, data)
xml.endElement("data")
xml.endDocument()
return stream.getvalue()

def _to_xml(self, xml, data):
if isinstance(data, (list, tuple)):
for item in data:
xml.startElement("item", {})
self._to_xml(xml, item)
xml.endElement("item")

elif isinstance(data, dict):
for key, value in data.items():
xml.startElement(key, {})
self._to_xml(xml, value)
xml.endElement(key)
class CSVRenderer(BaseRenderer):
media_type = "text/csv"

elif data is None:
# Don't output any value
pass
def render(self, request, data, *, response_status):
content = [",".join(data[0].keys())]
for item in data:
content.append(",".join(item.values()))
return "\n".join(content)

else:
xml.characters(force_str(data))


class CSVRenderer(BaseRenderer):
media_type = "text/csv"
class DynamicRenderer(BaseDynamicRenderer):
media_type = "application/json"
media_types = ["application/json", "text/csv", "text/xml"]

def render(self, request, data, *, response_status):
accept = request.headers.get("accept", "application/json")

if accept.startswith("text/xml"):
return self.render_xml(data)
elif accept.startswith("text/csv"):
return self.render_csv(data)
else:
return self.render_json(data)

def render_csv(self, data):
content = [",".join(data[0].keys())]
for item in data:
content.append(",".join(item.values()))
return "\n".join(content)

def render_xml(self, data):
stream = StringIO()
xml = SimplerXMLGenerator(stream, "utf-8")
xml.startDocument()
xml.startElement("data", {})
_to_xml(xml, data)
xml.endElement("data")
xml.endDocument()
return stream.getvalue()

def render_json(self, data):
return json.dumps(data, cls=NinjaJSONEncoder)


def operation(request):
return [
Expand All @@ -62,10 +100,12 @@ def operation(request):

api_xml = NinjaAPI(renderer=XMLRenderer())
api_csv = NinjaAPI(renderer=CSVRenderer())
api_dynamic = NinjaAPI(renderer=DynamicRenderer())


api_xml.get("/test")(operation)
api_csv.get("/test")(operation)
api_dynamic.get("/test")(operation)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -94,10 +134,59 @@ def test_response_class(api, content_type, expected_content):
assert response.content.decode() == expected_content


def test_implment_render():
class FooRenderer(BaseRenderer):
@pytest.mark.parametrize(
"accept,expected_content",
[
(
"text/xml; charset=utf-8",
'<?xml version="1.0" encoding="utf-8"?>\n<data>'
"<item><name>Jonathan</name><lastname>Doe</lastname></item>"
"<item><name>Sarah</name><lastname>Calvin</lastname></item>"
"</data>",
),
(
"text/csv; charset=utf-8",
"name,lastname\nJonathan,Doe\nSarah,Calvin",
),
(
"application/json; charset=utf-8",
'[{"name": "Jonathan", "lastname": "Doe"}, {"name": "Sarah", "lastname": "Calvin"}]',
),
],
)
def test_dynamic_response_class(accept, expected_content):
client = TestClient(api_dynamic)
response = client.get("/test", headers={"Accept": accept})
assert response.status_code == 200
assert response["Content-Type"] == accept
assert response.content.decode() == expected_content


@pytest.mark.parametrize("Base", [BaseRenderer, BaseDynamicRenderer])
def test_implement_render(Base):
class FooRenderer(Base):
pass

renderer = FooRenderer()
with pytest.raises(NotImplementedError):
renderer.render(None, None, response_status=200)


@pytest.mark.parametrize(
"accept,expected_media_type",
[
("text/xml", "text/xml"),
("text/csv", "text/csv"),
("*/*", "text/xml"),
("blahblahblah", "text/xml"),
],
)
def test_get_media_type(accept, expected_media_type):
class FooRenderer(BaseDynamicRenderer):
media_type = "text/xml"
media_types = ["text/xml", "text/csv"]

request = Mock()
request.headers = {"accept": accept}

assert FooRenderer().get_media_type(request) == expected_media_type

0 comments on commit 64ce1c4

Please sign in to comment.