diff --git a/django_unicorn/components/unicorn_template_response.py b/django_unicorn/components/unicorn_template_response.py index 09c342d4..a6363c9f 100644 --- a/django_unicorn/components/unicorn_template_response.py +++ b/django_unicorn/components/unicorn_template_response.py @@ -1,10 +1,10 @@ import logging from django.template.response import TemplateResponse -from django.utils.safestring import mark_safe import orjson from bs4 import BeautifulSoup +from bs4.dammit import EntitySubstitution from bs4.element import Tag from bs4.formatter import HTMLFormatter @@ -22,6 +22,9 @@ class UnsortedAttributes(HTMLFormatter): Prevent beautifulsoup from re-ordering attributes. """ + def __init__(self): + super().__init__(entity_substitution=EntitySubstitution.substitute_html) + def attributes(self, tag: Tag): for k, v in tag.attrs.items(): yield k, v @@ -115,7 +118,6 @@ def render(self): root_element.insert_after(t) rendered_template = UnicornTemplateResponse._desoupify(soup) - rendered_template = mark_safe(rendered_template) self.component.rendered(rendered_template) response.content = rendered_template diff --git a/django_unicorn/components/unicorn_view.py b/django_unicorn/components/unicorn_view.py index 138d9596..58e4d27d 100644 --- a/django_unicorn/components/unicorn_view.py +++ b/django_unicorn/components/unicorn_view.py @@ -10,7 +10,6 @@ from django.core.exceptions import ImproperlyConfigured from django.db.models import Model from django.http import HttpRequest -from django.utils.html import conditional_escape from django.views.generic.base import TemplateView from cachetools.lru import LRUCache @@ -341,14 +340,6 @@ def get_frontend_context_variables(self) -> str: if field_name in frontend_context_variables: del frontend_context_variables[field_name] - safe_fields = [] - # Keep a list of fields that are safe to not sanitize from `frontend_context_variables` - if hasattr(self, "Meta") and hasattr(self.Meta, "safe"): - if isinstance(self.Meta.safe, Sequence): - for field_name in self.Meta.safe: - if field_name in frontend_context_variables: - safe_fields.append(field_name) - # Add cleaned values to `frontend_content_variables` based on the widget in form's fields form = self._get_form(attributes) @@ -372,18 +363,6 @@ def get_frontend_context_variables(self) -> str: ): frontend_context_variables[key] = value - for ( - frontend_context_variable_key, - frontend_context_variable_value, - ) in frontend_context_variables.items(): - if ( - isinstance(frontend_context_variable_value, str) - and frontend_context_variable_key not in safe_fields - ): - frontend_context_variables[frontend_context_variable_key] = escape( - frontend_context_variable_value - ) - encoded_frontend_context_variables = serializer.dumps( frontend_context_variables ) diff --git a/django_unicorn/views/__init__.py b/django_unicorn/views/__init__.py index a810a8a7..ebaa8b16 100644 --- a/django_unicorn/views/__init__.py +++ b/django_unicorn/views/__init__.py @@ -1,11 +1,12 @@ import copy import logging from functools import wraps -from typing import Dict +from typing import Dict, Sequence from django.core.cache import caches from django.http import HttpRequest, JsonResponse from django.http.response import HttpResponseNotModified +from django.utils.safestring import mark_safe from django.views.decorators.csrf import csrf_protect from django.views.decorators.http import require_POST @@ -126,6 +127,20 @@ def _process_component_request( # Re-load frontend context variables to deal with non-serializable properties component_request.data = orjson.loads(component.get_frontend_context_variables()) + # Get set of attributes that should be marked as `safe` + safe_fields = [] + if hasattr(component, "Meta") and hasattr(component.Meta, "safe"): + if isinstance(component.Meta.safe, Sequence): + for field_name in component.Meta.safe: + if field_name in component._attributes().keys(): + safe_fields.append(field_name) + + # Mark safe attributes as such before rendering + for field_name in safe_fields: + value = getattr(component, field_name) + if isinstance(value, str): + setattr(component, field_name, mark_safe(value)) + # Send back all available data for reset or refresh actions updated_data = component_request.data diff --git a/tests/components/test_component.py b/tests/components/test_component.py index f143724b..02e1184d 100644 --- a/tests/components/test_component.py +++ b/tests/components/test_component.py @@ -82,39 +82,6 @@ def test_get_frontend_context_variables(component): assert frontend_context_variables_dict.get("name") == "World" -def test_get_frontend_context_variables_xss(component): - # Set component.name to a potential XSS attack - component.name = '' - - frontend_context_variables = component.get_frontend_context_variables() - frontend_context_variables_dict = orjson.loads(frontend_context_variables) - assert len(frontend_context_variables_dict) == 1 - assert ( - frontend_context_variables_dict.get("name") - == "<a><style>@keyframes x{}</style><a style="animation-name:x" onanimationend="alert(1)"></a>" - ) - - -def test_get_frontend_context_variables_safe(component): - # Set component.name to a potential XSS attack - component.name = '' - - class Meta: - safe = [ - "name", - ] - - setattr(component, "Meta", Meta()) - - frontend_context_variables = component.get_frontend_context_variables() - frontend_context_variables_dict = orjson.loads(frontend_context_variables) - assert len(frontend_context_variables_dict) == 1 - assert ( - frontend_context_variables_dict.get("name") - == '' - ) - - def test_get_context_data(component): context_data = component.get_context_data() assert ( diff --git a/tests/components/test_unicorn_template_response.py b/tests/components/test_unicorn_template_response.py index 30b908b7..e0f63332 100644 --- a/tests/components/test_unicorn_template_response.py +++ b/tests/components/test_unicorn_template_response.py @@ -1,7 +1,10 @@ import pytest from bs4 import BeautifulSoup -from django_unicorn.components.unicorn_template_response import get_root_element +from django_unicorn.components.unicorn_template_response import ( + UnicornTemplateResponse, + get_root_element, +) def test_get_root_element(): @@ -44,3 +47,14 @@ def test_get_root_element_no_element(): actual = get_root_element(soup) assert str(actual) == expected + + +def test_desoupify(): + html = "
<a><style>@keyframes x{}</style><a style="animation-name:x" onanimationend="alert(1)"></a>!\n
\n\n" + expected = "
<a><style>@keyframes x{}</style><a style=\"animation-name:x\" onanimationend=\"alert(1)\"></a>!\n
\n" + + soup = BeautifulSoup(html, "html.parser") + + actual = UnicornTemplateResponse._desoupify(soup) + + assert expected == actual diff --git a/tests/templates/test_component_kwargs.html b/tests/templates/test_component_kwargs.html index 66396cf9..301274bf 100644 --- a/tests/templates/test_component_kwargs.html +++ b/tests/templates/test_component_kwargs.html @@ -1,3 +1,3 @@
- ->{{ hello }}<- + {{ hello }}
\ No newline at end of file diff --git a/tests/templates/test_component_kwargs_with_html_entity.html b/tests/templates/test_component_kwargs_with_html_entity.html new file mode 100644 index 00000000..dc4c9a1b --- /dev/null +++ b/tests/templates/test_component_kwargs_with_html_entity.html @@ -0,0 +1,3 @@ +
+ ->{{ hello }}<- +
\ No newline at end of file diff --git a/tests/templates/test_component_variable.html b/tests/templates/test_component_variable.html new file mode 100644 index 00000000..adf2bbcc --- /dev/null +++ b/tests/templates/test_component_variable.html @@ -0,0 +1,3 @@ +
+ {{ hello }} +
\ No newline at end of file diff --git a/tests/templatetags/test_unicorn_render.py b/tests/templatetags/test_unicorn_render.py index f112cea3..96daf712 100644 --- a/tests/templatetags/test_unicorn_render.py +++ b/tests/templatetags/test_unicorn_render.py @@ -23,6 +23,15 @@ def __init__(self, *args, **kwargs): self.hello = kwargs.get("test_kwarg") +class FakeComponentKwargsWithHtmlEntity(UnicornView): + template_name = "templates/test_component_kwargs_with_html_entity.html" + hello = "world" + + def __init__(self, *args, **kwargs): + super().__init__(**kwargs) + self.hello = kwargs.get("test_kwarg") + + class FakeComponentModel(UnicornView): template_name = "templates/test_component_model.html" model_id = None @@ -55,7 +64,7 @@ def test_unicorn_render_kwarg(): context = {} actual = unicorn_node.render(context) - assert "->tested!<-" in actual + assert "tested!" in actual def test_unicorn_render_context_variable(): @@ -67,7 +76,19 @@ def test_unicorn_render_context_variable(): context = {"test_var": {"nested": "variable!"}} actual = unicorn_node.render(context) - assert "->variable!<-" in actual + assert "variable!" in actual + + +def test_unicorn_render_with_invalid_html(): + token = Token( + TokenType.TEXT, + "unicorn 'tests.templatetags.test_unicorn_render.FakeComponentKwargsWithHtmlEntity' test_kwarg=test_var.nested", + ) + unicorn_node = unicorn(None, token) + context = {"test_var": {"nested": "variable!"}} + actual = unicorn_node.render(context) + + assert "->variable!<-" in actual def test_unicorn_render_parent(settings): diff --git a/tests/views/message/test_sync_input.py b/tests/views/message/test_sync_input.py index d411e953..719f56d1 100644 --- a/tests/views/message/test_sync_input.py +++ b/tests/views/message/test_sync_input.py @@ -1,5 +1,3 @@ -import orjson - from tests.views.message.utils import post_and_get_response diff --git a/tests/views/test_process_component_request.py b/tests/views/test_process_component_request.py new file mode 100644 index 00000000..9ab26b20 --- /dev/null +++ b/tests/views/test_process_component_request.py @@ -0,0 +1,51 @@ +from django_unicorn.components import UnicornView +from tests.views.message.utils import post_and_get_response + + +class FakeComponent(UnicornView): + template_name = "templates/test_component_variable.html" + + hello = "" + + +class FakeComponentSafe(UnicornView): + template_name = "templates/test_component_variable.html" + + hello = "" + + class Meta: + safe = ("hello",) + + +def test_html_entities_encoded(client): + data = {"hello": "test"} + action_queue = [ + {"payload": {"name": "hello", "value": "test1"}, "type": "syncInput",} + ] + response = post_and_get_response( + client, + url="/message/tests.views.test_process_component_request.FakeComponent", + data=data, + action_queue=action_queue, + ) + + assert not response["errors"] + assert response["data"].get("hello") == "test1" + assert "<b>test1</b>" in response["dom"] + + +def test_safe_html_entities_not_encoded(client): + data = {"hello": "test"} + action_queue = [ + {"payload": {"name": "hello", "value": "test1"}, "type": "syncInput",} + ] + response = post_and_get_response( + client, + url="/message/tests.views.test_process_component_request.FakeComponentSafe", + data=data, + action_queue=action_queue, + ) + + assert not response["errors"] + assert response["data"].get("hello") == "test1" + assert "test1" in response["dom"]