From ee068a76d9e4ead4adf710eb43617a9a2fb5dddb Mon Sep 17 00:00:00 2001 From: Archmonger <16909269+Archmonger@users.noreply.github.com> Date: Thu, 20 Jul 2023 17:32:43 -0700 Subject: [PATCH] AsyncSlicedFile --- src/whitenoise/responders.py | 87 +++++++++++++++++++++++++++--------- 1 file changed, 66 insertions(+), 21 deletions(-) diff --git a/src/whitenoise/responders.py b/src/whitenoise/responders.py index dce9318c..05bf96eb 100644 --- a/src/whitenoise/responders.py +++ b/src/whitenoise/responders.py @@ -50,11 +50,14 @@ class SlicedFile(BufferedIOBase): """ def __init__(self, fileobj, start, end): - fileobj.seek(start) self.fileobj = fileobj self.remaining = end - start + 1 + self.seeked = False def read(self, size=-1): + if not self.seeked: + self.fileobj.seek(self.start) + self.seeked = True if self.remaining <= 0: return b"" if size < 0: @@ -64,10 +67,37 @@ def read(self, size=-1): data = self.fileobj.read(size) self.remaining -= len(data) return data - + def close(self): self.fileobj.close() +class AsyncSlicedFile(BufferedIOBase): + """ + Variant of `SlicedFile` that works with async file objects. + """ + + def __init__(self, fileobj, start, end): + self.fileobj = fileobj + self.remaining = end - start + 1 + self.seeked = False + + async def read(self, size=-1): + if not self.seeked: + await self.fileobj.seek(self.start) + self.seeked = True + if self.remaining <= 0: + return b"" + if size < 0: + size = self.remaining + else: + size = min(size, self.remaining) + data = await self.fileobj.read(size) + self.remaining -= len(data) + return data + + async def close(self): + await self.fileobj.close() + class StaticFile: def __init__(self, path, headers, encodings=None, stat_cache=None): @@ -99,23 +129,8 @@ def get_response(self, method, request_headers): pass return Response(HTTPStatus.OK, headers, file_handle) - def get_range_response(self, range_header, base_headers, file_handle): - headers = [] - for item in base_headers: - if item[0] == "Content-Length": - size = int(item[1]) - else: - headers.append(item) - start, end = self.get_byte_range(range_header, size) - if start >= end: - return self.get_range_not_satisfiable_response(file_handle, size) - if file_handle is not None: - file_handle = SlicedFile(file_handle, start, end) - headers.append(("Content-Range", f"bytes {start}-{end}/{size}")) - headers.append(("Content-Length", str(end - start + 1))) - return Response(HTTPStatus.PARTIAL_CONTENT, headers, file_handle) - async def aget_response(self, method, request_headers): + """Variant of `get_response` that works with async HTTP requests.""" if method not in ("GET", "HEAD"): return NOT_ALLOWED_RESPONSE if self.is_not_modified(request_headers): @@ -138,7 +153,7 @@ async def aget_response(self, method, request_headers): pass return Response(HTTPStatus.OK, headers, file_handle) - async def aget_range_response(self, range_header, base_headers, file_handle): + def get_range_response(self, range_header, base_headers, file_handle): headers = [] for item in base_headers: if item[0] == "Content-Length": @@ -148,8 +163,25 @@ async def aget_range_response(self, range_header, base_headers, file_handle): start, end = self.get_byte_range(range_header, size) if start >= end: return self.get_range_not_satisfiable_response(file_handle, size) - if file_handle is not None and start != 0: - await file_handle.seek(start) + if file_handle is not None: + file_handle = SlicedFile(file_handle, start, end) + headers.append(("Content-Range", f"bytes {start}-{end}/{size}")) + headers.append(("Content-Length", str(end - start + 1))) + return Response(HTTPStatus.PARTIAL_CONTENT, headers, file_handle) + + async def aget_range_response(self, range_header, base_headers, file_handle): + """Variant of `get_range_response` that works with async file objects.""" + headers = [] + for item in base_headers: + if item[0] == "Content-Length": + size = int(item[1]) + else: + headers.append(item) + start, end = self.get_byte_range(range_header, size) + if start >= end: + return self.aget_range_not_satisfiable_response(file_handle, size) + if file_handle is not None: + file_handle = AsyncSlicedFile(file_handle, start, end) headers.append(("Content-Range", f"bytes {start}-{end}/{size}")) headers.append(("Content-Length", str(end - start + 1))) return Response(HTTPStatus.PARTIAL_CONTENT, headers, file_handle) @@ -192,6 +224,19 @@ def get_range_not_satisfiable_response(file_handle, size): None, ) + @staticmethod + async def aget_range_not_satisfiable_response(file_handle, size): + """Variant of `get_range_not_satisfiable_response` that works with + async file objects.""" + if file_handle is not None: + await file_handle.close() + return Response( + HTTPStatus.REQUESTED_RANGE_NOT_SATISFIABLE, + [("Content-Range", f"bytes */{size}")], + None, + ) + + @staticmethod def get_file_stats(path, encodings, stat_cache): # Primary file has an encoding of None