Skip to content

Commit

Permalink
AsyncSlicedFile
Browse files Browse the repository at this point in the history
  • Loading branch information
Archmonger committed Jul 21, 2023
1 parent 7af5b32 commit ee068a7
Showing 1 changed file with 66 additions and 21 deletions.
87 changes: 66 additions & 21 deletions src/whitenoise/responders.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,14 @@ class SlicedFile(BufferedIOBase):
"""

def __init__(self, fileobj, start, end):
fileobj.seek(start)
self.fileobj = fileobj
self.remaining = end - start + 1
self.seeked = False

def read(self, size=-1):
if not self.seeked:
self.fileobj.seek(self.start)
self.seeked = True
if self.remaining <= 0:
return b""
if size < 0:
Expand All @@ -64,10 +67,37 @@ def read(self, size=-1):
data = self.fileobj.read(size)
self.remaining -= len(data)
return data

def close(self):
self.fileobj.close()

class AsyncSlicedFile(BufferedIOBase):
"""
Variant of `SlicedFile` that works with async file objects.
"""

def __init__(self, fileobj, start, end):
self.fileobj = fileobj
self.remaining = end - start + 1
self.seeked = False

async def read(self, size=-1):
if not self.seeked:
await self.fileobj.seek(self.start)
self.seeked = True
if self.remaining <= 0:
return b""
if size < 0:
size = self.remaining
else:
size = min(size, self.remaining)
data = await self.fileobj.read(size)
self.remaining -= len(data)
return data

async def close(self):
await self.fileobj.close()


class StaticFile:
def __init__(self, path, headers, encodings=None, stat_cache=None):
Expand Down Expand Up @@ -99,23 +129,8 @@ def get_response(self, method, request_headers):
pass
return Response(HTTPStatus.OK, headers, file_handle)

def get_range_response(self, range_header, base_headers, file_handle):
headers = []
for item in base_headers:
if item[0] == "Content-Length":
size = int(item[1])
else:
headers.append(item)
start, end = self.get_byte_range(range_header, size)
if start >= end:
return self.get_range_not_satisfiable_response(file_handle, size)
if file_handle is not None:
file_handle = SlicedFile(file_handle, start, end)
headers.append(("Content-Range", f"bytes {start}-{end}/{size}"))
headers.append(("Content-Length", str(end - start + 1)))
return Response(HTTPStatus.PARTIAL_CONTENT, headers, file_handle)

async def aget_response(self, method, request_headers):
"""Variant of `get_response` that works with async HTTP requests."""
if method not in ("GET", "HEAD"):
return NOT_ALLOWED_RESPONSE
if self.is_not_modified(request_headers):
Expand All @@ -138,7 +153,7 @@ async def aget_response(self, method, request_headers):
pass
return Response(HTTPStatus.OK, headers, file_handle)

async def aget_range_response(self, range_header, base_headers, file_handle):
def get_range_response(self, range_header, base_headers, file_handle):
headers = []
for item in base_headers:
if item[0] == "Content-Length":
Expand All @@ -148,8 +163,25 @@ async def aget_range_response(self, range_header, base_headers, file_handle):
start, end = self.get_byte_range(range_header, size)
if start >= end:
return self.get_range_not_satisfiable_response(file_handle, size)
if file_handle is not None and start != 0:
await file_handle.seek(start)
if file_handle is not None:
file_handle = SlicedFile(file_handle, start, end)
headers.append(("Content-Range", f"bytes {start}-{end}/{size}"))
headers.append(("Content-Length", str(end - start + 1)))
return Response(HTTPStatus.PARTIAL_CONTENT, headers, file_handle)

async def aget_range_response(self, range_header, base_headers, file_handle):
"""Variant of `get_range_response` that works with async file objects."""
headers = []
for item in base_headers:
if item[0] == "Content-Length":
size = int(item[1])
else:
headers.append(item)
start, end = self.get_byte_range(range_header, size)
if start >= end:
return self.aget_range_not_satisfiable_response(file_handle, size)
if file_handle is not None:
file_handle = AsyncSlicedFile(file_handle, start, end)
headers.append(("Content-Range", f"bytes {start}-{end}/{size}"))
headers.append(("Content-Length", str(end - start + 1)))
return Response(HTTPStatus.PARTIAL_CONTENT, headers, file_handle)
Expand Down Expand Up @@ -192,6 +224,19 @@ def get_range_not_satisfiable_response(file_handle, size):
None,
)

@staticmethod
async def aget_range_not_satisfiable_response(file_handle, size):
"""Variant of `get_range_not_satisfiable_response` that works with
async file objects."""
if file_handle is not None:
await file_handle.close()
return Response(
HTTPStatus.REQUESTED_RANGE_NOT_SATISFIABLE,
[("Content-Range", f"bytes */{size}")],
None,
)


@staticmethod
def get_file_stats(path, encodings, stat_cache):
# Primary file has an encoding of None
Expand Down

0 comments on commit ee068a7

Please sign in to comment.