diff --git a/django_unicorn/components/unicorn_template_response.py b/django_unicorn/components/unicorn_template_response.py index 09c342d4..a6363c9f 100644 --- a/django_unicorn/components/unicorn_template_response.py +++ b/django_unicorn/components/unicorn_template_response.py @@ -1,10 +1,10 @@ import logging from django.template.response import TemplateResponse -from django.utils.safestring import mark_safe import orjson from bs4 import BeautifulSoup +from bs4.dammit import EntitySubstitution from bs4.element import Tag from bs4.formatter import HTMLFormatter @@ -22,6 +22,9 @@ class UnsortedAttributes(HTMLFormatter): Prevent beautifulsoup from re-ordering attributes. """ + def __init__(self): + super().__init__(entity_substitution=EntitySubstitution.substitute_html) + def attributes(self, tag: Tag): for k, v in tag.attrs.items(): yield k, v @@ -115,7 +118,6 @@ def render(self): root_element.insert_after(t) rendered_template = UnicornTemplateResponse._desoupify(soup) - rendered_template = mark_safe(rendered_template) self.component.rendered(rendered_template) response.content = rendered_template diff --git a/django_unicorn/components/unicorn_view.py b/django_unicorn/components/unicorn_view.py index 138d9596..58e4d27d 100644 --- a/django_unicorn/components/unicorn_view.py +++ b/django_unicorn/components/unicorn_view.py @@ -10,7 +10,6 @@ from django.core.exceptions import ImproperlyConfigured from django.db.models import Model from django.http import HttpRequest -from django.utils.html import conditional_escape from django.views.generic.base import TemplateView from cachetools.lru import LRUCache @@ -341,14 +340,6 @@ def get_frontend_context_variables(self) -> str: if field_name in frontend_context_variables: del frontend_context_variables[field_name] - safe_fields = [] - # Keep a list of fields that are safe to not sanitize from `frontend_context_variables` - if hasattr(self, "Meta") and hasattr(self.Meta, "safe"): - if isinstance(self.Meta.safe, Sequence): - for field_name in self.Meta.safe: - if field_name in frontend_context_variables: - safe_fields.append(field_name) - # Add cleaned values to `frontend_content_variables` based on the widget in form's fields form = self._get_form(attributes) @@ -372,18 +363,6 @@ def get_frontend_context_variables(self) -> str: ): frontend_context_variables[key] = value - for ( - frontend_context_variable_key, - frontend_context_variable_value, - ) in frontend_context_variables.items(): - if ( - isinstance(frontend_context_variable_value, str) - and frontend_context_variable_key not in safe_fields - ): - frontend_context_variables[frontend_context_variable_key] = escape( - frontend_context_variable_value - ) - encoded_frontend_context_variables = serializer.dumps( frontend_context_variables ) diff --git a/django_unicorn/views/__init__.py b/django_unicorn/views/__init__.py index a810a8a7..ebaa8b16 100644 --- a/django_unicorn/views/__init__.py +++ b/django_unicorn/views/__init__.py @@ -1,11 +1,12 @@ import copy import logging from functools import wraps -from typing import Dict +from typing import Dict, Sequence from django.core.cache import caches from django.http import HttpRequest, JsonResponse from django.http.response import HttpResponseNotModified +from django.utils.safestring import mark_safe from django.views.decorators.csrf import csrf_protect from django.views.decorators.http import require_POST @@ -126,6 +127,20 @@ def _process_component_request( # Re-load frontend context variables to deal with non-serializable properties component_request.data = orjson.loads(component.get_frontend_context_variables()) + # Get set of attributes that should be marked as `safe` + safe_fields = [] + if hasattr(component, "Meta") and hasattr(component.Meta, "safe"): + if isinstance(component.Meta.safe, Sequence): + for field_name in component.Meta.safe: + if field_name in component._attributes().keys(): + safe_fields.append(field_name) + + # Mark safe attributes as such before rendering + for field_name in safe_fields: + value = getattr(component, field_name) + if isinstance(value, str): + setattr(component, field_name, mark_safe(value)) + # Send back all available data for reset or refresh actions updated_data = component_request.data diff --git a/tests/components/test_component.py b/tests/components/test_component.py index f143724b..02e1184d 100644 --- a/tests/components/test_component.py +++ b/tests/components/test_component.py @@ -82,39 +82,6 @@ def test_get_frontend_context_variables(component): assert frontend_context_variables_dict.get("name") == "World" -def test_get_frontend_context_variables_xss(component): - # Set component.name to a potential XSS attack - component.name = '' - - frontend_context_variables = component.get_frontend_context_variables() - frontend_context_variables_dict = orjson.loads(frontend_context_variables) - assert len(frontend_context_variables_dict) == 1 - assert ( - frontend_context_variables_dict.get("name") - == "<a><style>@keyframes x{}</style><a style="animation-name:x" onanimationend="alert(1)"></a>" - ) - - -def test_get_frontend_context_variables_safe(component): - # Set component.name to a potential XSS attack - component.name = '' - - class Meta: - safe = [ - "name", - ] - - setattr(component, "Meta", Meta()) - - frontend_context_variables = component.get_frontend_context_variables() - frontend_context_variables_dict = orjson.loads(frontend_context_variables) - assert len(frontend_context_variables_dict) == 1 - assert ( - frontend_context_variables_dict.get("name") - == '' - ) - - def test_get_context_data(component): context_data = component.get_context_data() assert ( diff --git a/tests/components/test_unicorn_template_response.py b/tests/components/test_unicorn_template_response.py index 30b908b7..e0f63332 100644 --- a/tests/components/test_unicorn_template_response.py +++ b/tests/components/test_unicorn_template_response.py @@ -1,7 +1,10 @@ import pytest from bs4 import BeautifulSoup -from django_unicorn.components.unicorn_template_response import get_root_element +from django_unicorn.components.unicorn_template_response import ( + UnicornTemplateResponse, + get_root_element, +) def test_get_root_element(): @@ -44,3 +47,14 @@ def test_get_root_element_no_element(): actual = get_root_element(soup) assert str(actual) == expected + + +def test_desoupify(): + html = "