diff --git a/CHANGES/517.bugfix b/CHANGES/517.bugfix new file mode 100644 index 00000000..9707802d --- /dev/null +++ b/CHANGES/517.bugfix @@ -0,0 +1 @@ +No longer lose characters when decoding incorrect percent-sequences (like ``%e2%82%f8``). All non-decodable percent-sequences are now preserved. diff --git a/setup.cfg b/setup.cfg index 23da3f7a..bb54e139 100644 --- a/setup.cfg +++ b/setup.cfg @@ -12,7 +12,7 @@ junit_suite_name = yarl_test_suite [flake8] -ignore = E301,E302,E704,W503,W504,F811 +ignore = E203,E301,E302,E704,W503,W504,F811 max-line-length = 88 [mypy] diff --git a/tests/test_quoting.py b/tests/test_quoting.py index 1c9e87c5..8f9e54d0 100644 --- a/tests/test_quoting.py +++ b/tests/test_quoting.py @@ -317,12 +317,27 @@ def test_unquote_unsafe4(unquoter): assert unquoter(unsafe="@")("a@b") == "a%40b" -def test_unquote_non_ascii(unquoter): - assert unquoter()("%F8") == "%F8" +@pytest.mark.parametrize( + ("input", "expected"), + [ + ("%e2%82", "%e2%82"), + ("%e2%82ac", "%e2%82ac"), + ("%e2%82%f8", "%e2%82%f8"), + ("%e2%82%2b", "%e2%82+"), + ("%e2%82%e2%82%ac", "%e2%82€"), + ("%e2%82%e2%82", "%e2%82%e2%82"), + ], +) +def test_unquote_non_utf8(unquoter, input, expected): + assert unquoter()(input) == expected + + +def test_unquote_unsafe_non_utf8(unquoter): + assert unquoter(unsafe="\n")("%e2%82%0a") == "%e2%82%0A" -def test_unquote_non_ascii_non_tailing(unquoter): - assert unquoter()("%F8ab") == "%F8ab" +def test_unquote_plus_non_utf8(unquoter): + assert unquoter(qs=True)("%e2%82%2b") == "%e2%82%2B" def test_quote_non_ascii(quoter): diff --git a/yarl/_quoting_c.pyx b/yarl/_quoting_c.pyx index cfe1d3cb..1b8bea25 100644 --- a/yarl/_quoting_c.pyx +++ b/yarl/_quoting_c.pyx @@ -5,7 +5,7 @@ from libc.string cimport memcpy, memset from cpython.exc cimport PyErr_NoMemory from cpython.mem cimport PyMem_Malloc, PyMem_Realloc, PyMem_Free -from cpython.unicode cimport PyUnicode_DecodeASCII +from cpython.unicode cimport PyUnicode_DecodeASCII, 
PyUnicode_DecodeUTF8Stateful from string import ascii_letters, digits @@ -20,7 +20,6 @@ cdef str QS = '+&=;' DEF BUF_SIZE = 8 * 1024 # 8KiB cdef char BUFFER[BUF_SIZE] - cdef inline Py_UCS4 _to_hex(uint8_t v): if v < 10: return (v+0x30) # ord('0') == 0x30 @@ -295,44 +294,60 @@ cdef class _Unquoter: cdef str _do_unquote(self, str val): if len(val) == 0: return val - cdef str last_pct = '' - cdef bytearray pcts = bytearray() cdef list ret = [] + cdef char buffer[4] + cdef Py_ssize_t buflen = 0 + cdef Py_ssize_t consumed cdef str unquoted cdef Py_UCS4 ch = 0 - cdef int idx = 0 - cdef int length = len(val) + cdef Py_ssize_t idx = 0 + cdef Py_ssize_t length = len(val) + cdef Py_ssize_t start_pct while idx < length: ch = val[idx] idx += 1 - if pcts: - try: - unquoted = pcts.decode('utf8') - except UnicodeDecodeError: - pass - else: + if ch == '%' and idx <= length - 2: + ch = _restore_ch(val[idx], val[idx + 1]) + if ch != -1: + idx += 2 + assert buflen < 4 + buffer[buflen] = ch + buflen += 1 + try: + unquoted = PyUnicode_DecodeUTF8Stateful(buffer, buflen, + NULL, &consumed) + except UnicodeDecodeError: + start_pct = idx - buflen * 3 + buffer[0] = ch + buflen = 1 + ret.append(val[start_pct : idx - 3]) + try: + unquoted = PyUnicode_DecodeUTF8Stateful(buffer, buflen, + NULL, &consumed) + except UnicodeDecodeError: + buflen = 0 + ret.append(val[idx - 3 : idx]) + continue + if not unquoted: + assert consumed == 0 + continue + assert consumed == buflen + buflen = 0 if self._qs and unquoted in '+=&;': ret.append(self._qs_quoter(unquoted)) elif unquoted in self._unsafe: ret.append(self._quoter(unquoted)) else: ret.append(unquoted) - del pcts[:] - - if ch == '%' and idx <= length - 2: - ch = _restore_ch(val[idx], val[idx + 1]) - if ch != -1: - pcts.append(ch) - last_pct = val[idx - 1 : idx + 2] - idx += 2 continue else: ch = '%' - if pcts: - ret.append(last_pct) # %F8ab - last_pct = '' + if buflen: + start_pct = idx - 1 - buflen * 3 + ret.append(val[start_pct : idx - 1]) + buflen 
= 0 if ch == '+': if not self._qs or ch in self._unsafe: @@ -350,16 +365,7 @@ cdef class _Unquoter: ret.append(ch) - if pcts: - try: - unquoted = pcts.decode('utf8') - except UnicodeDecodeError: - ret.append(last_pct) # %F8 - else: - if self._qs and unquoted in '+=&;': - ret.append(self._qs_quoter(unquoted)) - elif unquoted in self._unsafe: - ret.append(self._quoter(unquoted)) - else: - ret.append(unquoted) + if buflen: + ret.append(val[length - buflen * 3 : length]) + return ''.join(ret) diff --git a/yarl/_quoting_py.py b/yarl/_quoting_py.py index a6d28c08..d6f33e15 100644 --- a/yarl/_quoting_py.py +++ b/yarl/_quoting_py.py @@ -1,3 +1,4 @@ +import codecs import re from string import ascii_letters, ascii_lowercase, digits from typing import Optional, cast @@ -16,6 +17,8 @@ _IS_HEX = re.compile(b"[A-Z0-9][A-Z0-9]") _IS_HEX_STR = re.compile("[A-Fa-f0-9][A-Fa-f0-9]") +utf8_decoder = codecs.getincrementaldecoder("utf-8") + class _Quoter: def __init__( @@ -127,19 +130,30 @@ def __call__(self, val: Optional[str]) -> Optional[str]: raise TypeError("Argument should be str") if not val: return "" - last_pct = "" - pcts = bytearray() + decoder = cast(codecs.BufferedIncrementalDecoder, utf8_decoder()) ret = [] idx = 0 while idx < len(val): ch = val[idx] idx += 1 - if pcts: - try: - unquoted = pcts.decode("utf8") - except UnicodeDecodeError: - pass - else: + if ch == "%" and idx <= len(val) - 2: + pct = val[idx : idx + 2] + if _IS_HEX_STR.fullmatch(pct): + b = bytes([int(pct, base=16)]) + idx += 2 + try: + unquoted = decoder.decode(b) + except UnicodeDecodeError: + start_pct = idx - 3 - len(decoder.buffer) * 3 + ret.append(val[start_pct : idx - 3]) + decoder.reset() + try: + unquoted = decoder.decode(b) + except UnicodeDecodeError: + ret.append(val[idx - 3 : idx]) + continue + if not unquoted: + continue if self._qs and unquoted in "+=&;": to_add = self._qs_quoter(unquoted) if to_add is None: # pragma: no cover @@ -152,19 +166,12 @@ def __call__(self, val: Optional[str]) -> 
Optional[str]: ret.append(to_add) else: ret.append(unquoted) - del pcts[:] - - if ch == "%" and idx <= len(val) - 2: - pct = val[idx : idx + 2] # noqa: E203 - if _IS_HEX_STR.fullmatch(pct): - pcts.append(int(pct, base=16)) - last_pct = "%" + pct - idx += 2 continue - if pcts: - ret.append(last_pct) # %F8ab - last_pct = "" + if decoder.buffer: + start_pct = idx - 1 - len(decoder.buffer) * 3 + ret.append(val[start_pct : idx - 1]) + decoder.reset() if ch == "+": if not self._qs or ch in self._unsafe: @@ -182,24 +189,9 @@ def __call__(self, val: Optional[str]) -> Optional[str]: ret.append(ch) - if pcts: - try: - unquoted = pcts.decode("utf8") - except UnicodeDecodeError: - ret.append(last_pct) # %F8 - else: - if self._qs and unquoted in "+=&;": - to_add = self._qs_quoter(unquoted) - if to_add is None: # pragma: no cover - raise RuntimeError("Cannot quote None") - ret.append(to_add) - elif unquoted in self._unsafe: - to_add = self._qs_quoter(unquoted) - if to_add is None: # pragma: no cover - raise RuntimeError("Cannot quote None") - ret.append(to_add) - else: - ret.append(unquoted) + if decoder.buffer: + ret.append(val[-len(decoder.buffer) * 3 :]) + ret2 = "".join(ret) if ret2 == val: return val