Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve performance, especially in data with many CR-LF #137

Merged
merged 4 commits into from
Sep 28, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 61 additions & 65 deletions multipart/multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -976,29 +976,11 @@ def __init__(
# Setup marks. These are used to track the state of data received.
self.marks: dict[str, int] = {}

# TODO: Actually use this rather than the dumb version we currently use
# # Precompute the skip table for the Boyer-Moore-Horspool algorithm.
# skip = [len(boundary) for x in range(256)]
# for i in range(len(boundary) - 1):
# skip[ord_char(boundary[i])] = len(boundary) - i - 1
#
# # We use a tuple since it's a constant, and marginally faster.
# self.skip = tuple(skip)

# Save our boundary.
if isinstance(boundary, str): # pragma: no cover
boundary = boundary.encode("latin-1")
self.boundary = b"\r\n--" + boundary

# Get a set of characters that belong to our boundary.
self.boundary_chars = frozenset(self.boundary)

# We also create a lookbehind list.
# Note: the +8 is since we can have, at maximum, "\r\n--" + boundary +
# "--\r\n" at the final boundary, and the length of '\r\n--' and
# '--\r\n' is 8 bytes.
self.lookbehind = [NULL for _ in range(len(boundary) + 8)]

def write(self, data: bytes) -> int:
"""Write some data to the parser, which will perform size verification,
and then parse the data into the appropriate location (e.g. header,
Expand Down Expand Up @@ -1061,21 +1043,43 @@ def delete_mark(name: str, reset: bool = False) -> None:
# end of the buffer, and reset the mark, instead of deleting it. This
# is used at the end of the function to call our callbacks with any
# remaining data in this chunk.
def data_callback(name: str, remaining: bool = False) -> None:
def data_callback(name: str, end_i: int, remaining: bool = False) -> None:
marked_index = self.marks.get(name)
if marked_index is None:
return

# If we're getting remaining data, we ignore the current i value
# and just call with the remaining data.
if remaining:
self.callback(name, data, marked_index, length)
self.marks[name] = 0

# Otherwise, we call it from the mark to the current byte we're
# processing.
if (end_i <= marked_index):
# There is no additional data to send.
pass
elif (marked_index >= 0):
# We are emitting data from the local buffer.
self.callback(name, data, marked_index, end_i)
else:
# Some of the data comes from a partial boundary match
# and requires look-behind.
# We need to use self.flags (and not flags) because we care about
# the state when we entered the loop.
lookbehind_len = -marked_index
if (lookbehind_len <= len(boundary)):
self.callback(name, boundary, 0, lookbehind_len)
elif (self.flags & FLAG_PART_BOUNDARY):
lookback = boundary + b"\r\n"
self.callback(name, lookback, 0, lookbehind_len)
elif (self.flags & FLAG_LAST_BOUNDARY):
lookback = boundary + b"--\r\n"
self.callback(name, lookback, 0, lookbehind_len)
else: # pragma: no cover (error case)
self.logger.warning("Look-back buffer error")

if end_i > 0:
self.callback(name, data, 0, end_i)
# If we're getting remaining data, we have got all the data we
# can be certain is not a boundary, leaving only a partial boundary match.
if remaining:
self.marks[name] = end_i - length
else:
self.callback(name, data, marked_index, i)
self.marks.pop(name, None)

# For each byte...
Expand Down Expand Up @@ -1183,7 +1187,7 @@ def data_callback(name: str, remaining: bool = False) -> None:
raise e

# Call our callback with the header field.
data_callback("header_field")
data_callback("header_field", i)

# Move to parsing the header value.
state = MultipartState.HEADER_VALUE_START
Expand Down Expand Up @@ -1212,7 +1216,7 @@ def data_callback(name: str, remaining: bool = False) -> None:
# If we've got a CR, we're nearly done with our headers. Otherwise,
# we do nothing and just move past this character.
if c == CR:
data_callback("header_value")
data_callback("header_value", i)
self.callback("header_end")
state = MultipartState.HEADER_VALUE_ALMOST_DONE

Expand Down Expand Up @@ -1256,46 +1260,46 @@ def data_callback(name: str, remaining: bool = False) -> None:
# We're processing our part data right now. During this, we
# need to efficiently search for our boundary, since any data
# on any number of lines can be a part of the current data.
# We use the Boyer-Moore-Horspool algorithm to efficiently
# search through the remainder of the buffer looking for our
# boundary.

# Save the current value of our index. We use this in case we
# find part of a boundary, but it doesn't match fully.
prev_index = index

# Set up variables.
boundary_length = len(boundary)
boundary_end = boundary_length - 1
data_length = length
boundary_chars = self.boundary_chars

# If our index is 0, we're starting a new part, so start our
# search.
if index == 0:
# Search forward until we either hit the end of our buffer,
# or reach a character that's in our boundary.
i += boundary_end
while i < data_length - 1 and data[i] not in boundary_chars:
i += boundary_length

# Reset i back the length of our boundary, which is the
# earliest possible location that could be our match (i.e.
# if we've just broken out of our loop since we saw the
# last character in our boundary)
i -= boundary_end
# The most common case is likely to be that the whole
# boundary is present in the buffer.
# Calling `find` is much faster than iterating here.
i0 = data.find(boundary, i, data_length)
if i0 >= 0:
# We matched the whole boundary string.
index = boundary_length - 1
i = i0 + boundary_length - 1
else:
# No match found for whole string.
# There may be a partial boundary at the end of the
# data, which the find will not match.
# Since the length to be searched is limited to
# the boundary length, just perform a naive search.
i = max(i, data_length - boundary_length)

# Search forward until we either hit the end of our buffer,
# or reach a potential start of the boundary.
while i < data_length - 1 and data[i] != boundary[0]:
i += 1

c = data[i]

# Now, we have a couple of cases here. If our index is before
# the end of the boundary...
if index < boundary_length:
# If the character matches...
if boundary[index] == c:
# If we found a match for our boundary, we send the
# existing data.
if index == 0:
data_callback("part_data")

# The current character matches, so continue!
index += 1
else:
Expand Down Expand Up @@ -1332,6 +1336,8 @@ def data_callback(name: str, remaining: bool = False) -> None:
# Unset the part boundary flag.
flags &= ~FLAG_PART_BOUNDARY

# We have identified a boundary, callback for any data before it.
data_callback("part_data", i - index)
# Callback indicating that we've reached the end of
# a part, and are starting a new one.
self.callback("part_end")
Expand All @@ -1353,6 +1359,8 @@ def data_callback(name: str, remaining: bool = False) -> None:
elif flags & FLAG_LAST_BOUNDARY:
# We need a second hyphen here.
if c == HYPHEN:
# We have identified a boundary, callback for any data before it.
data_callback("part_data", i - index)
# Callback to end the current part, and then the
# message.
self.callback("part_end")
Expand All @@ -1362,26 +1370,14 @@ def data_callback(name: str, remaining: bool = False) -> None:
# No match, so reset index.
index = 0

# If we have an index, we need to keep this byte for later, in
# case we can't match the full boundary.
if index > 0:
self.lookbehind[index - 1] = c

# Otherwise, our index is 0. If the previous index is not, it
# means we reset something, and we need to take the data we
# thought was part of our boundary and send it along as actual
# data.
elif prev_index > 0:
# Callback to write the saved data.
lb_data = join_bytes(self.lookbehind)
self.callback("part_data", lb_data, 0, prev_index)

if index == 0 and prev_index > 0:
# Overwrite our previous index.
prev_index = 0

# Re-set our mark for part data.
set_mark("part_data")

# Re-consider the current character, since this could be
# the start of the boundary itself.
i -= 1
Expand Down Expand Up @@ -1410,9 +1406,9 @@ def data_callback(name: str, remaining: bool = False) -> None:
# that we haven't yet reached the end of this 'thing'. So, by setting
# the mark to 0, we cause any data callbacks that take place in future
# calls to this function to start from the beginning of that buffer.
data_callback("header_field", True)
data_callback("header_value", True)
data_callback("part_data", True)
data_callback("header_field", length, True)
data_callback("header_value", length, True)
data_callback("part_data", length - index, True)

# Save values to locals.
self.state = state
Expand Down
35 changes: 28 additions & 7 deletions tests/test_multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,6 +695,14 @@ def test_not_aligned(self):

http_tests.append({"name": fname, "test": test_data, "result": yaml_data})

# Datasets used for single-byte writing test.
single_byte_tests = [
"almost_match_boundary",
"almost_match_boundary_without_CR",
"almost_match_boundary_without_LF",
"almost_match_boundary_without_final_hyphen",
"single_field_single_file",
]

def split_all(val):
"""
Expand Down Expand Up @@ -843,17 +851,19 @@ def test_random_splitting(self):
self.assert_field(b"field", b"test1")
self.assert_file(b"file", b"file.txt", b"test2")

def test_feed_single_bytes(self):
@parametrize("param", [ t for t in http_tests if t["name"] in single_byte_tests])
def test_feed_single_bytes(self, param):
"""
This test parses a simple multipart body 1 byte at a time.
This test parses multipart bodies 1 byte at a time.
"""
# Load test data.
test_file = "single_field_single_file.http"
test_file = param["name"] + ".http"
boundary = param["result"]["boundary"]
with open(os.path.join(http_tests_dir, test_file), "rb") as f:
test_data = f.read()

# Create form parser.
self.make("boundary")
self.make(boundary)

# Write all bytes.
# NOTE: Can't simply do `for b in test_data`, since that gives
Expand All @@ -868,9 +878,20 @@ def test_feed_single_bytes(self):
# Assert we processed everything.
self.assertEqual(i, len(test_data))

# Assert that our file and field are here.
self.assert_field(b"field", b"test1")
self.assert_file(b"file", b"file.txt", b"test2")
# Assert that the parser gave us the appropriate fields/files.
for e in param["result"]["expected"]:
# Get our type and name.
type = e["type"]
name = e["name"].encode("latin-1")

if type == "field":
self.assert_field(name, e["data"])

elif type == "file":
self.assert_file(name, e["file_name"].encode("latin-1"), e["data"])

else:
assert False

def test_feed_blocks(self):
"""
Expand Down