Add read method to StorageStreamDownloader #24275

Merged · 16 commits · Aug 9, 2022
Changes from 5 commits
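For reference, a minimal usage sketch of the method this PR adds, mirroring the chunked-read pattern used in the tests below; the connection string and container/blob names are placeholder assumptions, not part of the diff:

from math import ceil

from azure.storage.blob import BlobClient

# conn_str and the container/blob names are placeholders for an existing block blob.
blob = BlobClient.from_connection_string(conn_str, container_name="mycontainer", blob_name="myblob")

stream = blob.download_blob(max_concurrency=2)  # returns a StorageStreamDownloader

read_size = 4 * 1024 * 1024  # read 4 MiB per call to the new read() method
result = bytearray()
for _ in range(ceil(stream.size / read_size)):
    result.extend(stream.read(read_size))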
66 changes: 65 additions & 1 deletion sdk/storage/azure-storage-blob/azure/storage/blob/_download.py
@@ -9,7 +9,7 @@
import time
import warnings
from io import BytesIO
from typing import Generic, Iterator, TypeVar
from typing import Generic, Iterator, Optional, TypeVar

import requests
from azure.core.exceptions import HttpResponseError, ServiceResponseError
@@ -335,6 +335,7 @@ def __init__(
self._non_empty_ranges = None
self._response = None
self._encryption_data = None
self._offset = 0

# The cls is passed in via download_cls to avoid conflicting arg name with Generic.__new__
# but needs to be changed to cls in the request options.
@@ -552,6 +553,69 @@ def chunks(self):
downloader=iter_downloader,
chunk_size=self._config.max_chunk_get_size)

def read(self, size: Optional[int] = -1) -> T:
"""
Read up to size bytes from the stream and return them. If size
is unspecified or is -1, all bytes will be read.
"""
if size == -1:
return self.readall()
if size == 0 or self.size == 0:
data = b''
if self._encoding:
return data.decode(self._encoding)
return data

stream = BytesIO()
remaining_size = size

# Start by reading from current_content if there is data left
if self._offset < len(self._current_content):
start = self._offset
end = start + min(remaining_size, len(self._current_content) - start)
read = stream.write(self._current_content[start:end])

remaining_size -= read
self._offset += read

if remaining_size > 0:
end_range = min(self._offset + remaining_size, self.size)
parallel = self._max_concurrency > 1
downloader = _ChunkDownloader(
client=self._clients.blob,
non_empty_ranges=self._non_empty_ranges,
total_size=remaining_size,
chunk_size=self._config.max_chunk_get_size,
current_progress=self._offset,
start_range=self._offset,
end_range=end_range,
stream=stream,
parallel=parallel,
validate_content=self._validate_content,
encryption_options=self._encryption_options,
encryption_data=self._encryption_data,
use_location=self._location_mode,
**self._request_options
)

if parallel:
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor(self._max_concurrency) as executor:
list(executor.map(
with_current_context(downloader.process_chunk),
downloader.get_chunk_offsets()
))
else:
for chunk in downloader.get_chunk_offsets():
downloader.process_chunk(chunk)

# Cap the offset at the blob size in case more bytes were requested than remain
self._offset = min(self._offset + remaining_size, self.size)

data = stream.getvalue()
if self._encoding:
return data.decode(self._encoding)
return data
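
Since read() reuses the downloader's encoding (the decode path just above), a downloader opened with an encoding returns text instead of bytes; a small sketch, assuming `blob` is a BlobClient for an existing UTF-8 text blob:

# `blob` is a placeholder BlobClient for an existing UTF-8 text blob.
stream = blob.download_blob(encoding='utf-8')
text = stream.read(512)  # returns str, decoded with the encoding given above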

def readall(self):
# type: () -> T
"""Download the contents of this blob.
sdk/storage/azure-storage-blob/azure/storage/blob/aio/_download_async.py
@@ -9,7 +9,7 @@
import warnings
from io import BytesIO
from itertools import islice
from typing import AsyncIterator, Generic, TypeVar
from typing import AsyncIterator, Generic, Optional, TypeVar

import asyncio
from aiohttp import ClientPayloadError
@@ -243,6 +243,7 @@ def __init__(
self._non_empty_ranges = None
self._response = None
self._encryption_data = None
self._offset = 0

self._initial_range = None
self._initial_offset = None
@@ -456,6 +457,88 @@ def chunks(self):
downloader=iter_downloader,
chunk_size=self._config.max_chunk_get_size)

async def read(self, size: Optional[int] = -1) -> T:
"""
Read up to size bytes from the stream and return them. If size
is unspecified or is -1, all bytes will be read.
"""
if size == -1:
return await self.readall()
if size == 0 or self.size == 0:
data = b''
if self._encoding:
return data.decode(self._encoding)
return data

stream = BytesIO()
remaining_size = size

# Start by reading from current_content if there is data left
if self._offset < len(self._current_content):
start = self._offset
end = start + min(remaining_size, len(self._current_content) - start)
read = stream.write(self._current_content[start:end])

remaining_size -= read
self._offset += read

if remaining_size > 0:
end_range = min(self._offset + remaining_size, self.size)
parallel = self._max_concurrency > 1
downloader = _AsyncChunkDownloader(
client=self._clients.blob,
non_empty_ranges=self._non_empty_ranges,
total_size=remaining_size,
chunk_size=self._config.max_chunk_get_size,
current_progress=self._offset,
start_range=self._offset,
end_range=end_range,
stream=stream,
parallel=parallel,
validate_content=self._validate_content,
encryption_options=self._encryption_options,
encryption_data=self._encryption_data,
use_location=self._location_mode,
**self._request_options
)

dl_tasks = downloader.get_chunk_offsets()
running_futures = [
asyncio.ensure_future(downloader.process_chunk(d))
for d in islice(dl_tasks, 0, self._max_concurrency)
]
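# Keep at most max_concurrency chunk downloads in flight; as each completes,
# the loop below schedules the next chunk offset until dl_tasks is exhausted.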
while running_futures:
# Wait for some download to finish before adding a new one
done, running_futures = await asyncio.wait(
running_futures, return_when=asyncio.FIRST_COMPLETED)
try:
for task in done:
task.result()
except HttpResponseError as error:
process_storage_error(error)
try:
next_chunk = next(dl_tasks)
except StopIteration:
break
else:
running_futures.add(asyncio.ensure_future(downloader.process_chunk(next_chunk)))

if running_futures:
# Wait for the remaining downloads to finish
done, _running_futures = await asyncio.wait(running_futures)
try:
for task in done:
task.result()
except HttpResponseError as error:
process_storage_error(error)

# Cap the offset at the blob size in case more bytes were requested than remain
self._offset = min(self._offset + remaining_size, self.size)

data = stream.getvalue()
if self._encoding:
return data.decode(self._encoding)
return data
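
The async variant is used the same way; a minimal sketch, assuming an existing blob and a placeholder connection string:

import asyncio
from math import ceil

from azure.storage.blob.aio import BlobClient

async def main():
    # conn_str and the container/blob names are placeholders for an existing block blob.
    blob = BlobClient.from_connection_string(conn_str, container_name="mycontainer", blob_name="myblob")
    async with blob:
        stream = await blob.download_blob(max_concurrency=2)
        read_size = 4 * 1024 * 1024
        result = bytearray()
        for _ in range(ceil(stream.size / read_size)):
            result.extend(await stream.read(read_size))
        return bytes(result)

asyncio.run(main())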

async def readall(self):
# type: () -> T
"""Download the contents of this blob.

Large diffs are not rendered by default.
26 changes: 26 additions & 0 deletions sdk/storage/azure-storage-blob/tests/test_blob_encryption.py
@@ -7,6 +7,7 @@
import tempfile
from io import StringIO, BytesIO
from json import loads
from math import ceil
from os import (
urandom,
path,
@@ -696,5 +697,30 @@ def test_get_blob_to_star(self, storage_account_name, storage_account_key):
self.assertEqual(self.bytes, stream_blob.read())
self.assertEqual(self.bytes.decode(), text_blob)

@pytest.mark.live_test_only
@BlobPreparer()
def test_get_blob_read(self, storage_account_name, storage_account_key):
self._setup(storage_account_name, storage_account_key)
self.bsc.require_encryption = True
self.bsc.key_encryption_key = KeyWrapper('key1')

data = b'12345' * 205 * 25 # 25625 bytes
blob = self.bsc.get_blob_client(self.container_name, self._get_blob_reference(BlobType.BLOCKBLOB))
blob.upload_blob(data, overwrite=True)
stream = blob.download_blob(max_concurrency=3)

# Act
result = bytearray()
read_size = 3000
num_chunks = int(ceil(len(data) / read_size))
for i in range(num_chunks):
content = stream.read(read_size)
start = i * read_size
end = start + read_size
assert data[start:end] == content
result.extend(content)

# Assert
assert result == data

# ------------------------------------------------------------------------------
sdk/storage/azure-storage-blob/tests/test_blob_encryption_async.py
@@ -6,6 +6,7 @@

from io import StringIO, BytesIO
from json import loads
from math import ceil
from os import (
urandom,
path,
@@ -750,5 +751,31 @@ async def test_get_blob_to_star_async(self, storage_account_name, storage_account_key):
self.assertEqual(self.bytes, stream_blob.read())
self.assertEqual(self.bytes.decode(), text_blob)

@pytest.mark.live_test_only
@BlobPreparer()
async def test_get_blob_read(self, storage_account_name, storage_account_key):
await self._setup(storage_account_name, storage_account_key)
self.bsc.require_encryption = True
self.bsc.key_encryption_key = KeyWrapper('key1')

data = b'12345' * 205 * 25 # 25625 bytes
blob = self.bsc.get_blob_client(self.container_name, self._get_blob_reference(BlobType.BLOCKBLOB))
await blob.upload_blob(data, overwrite=True)
stream = await blob.download_blob(max_concurrency=3)

# Act
result = bytearray()
read_size = 3000
num_chunks = int(ceil(len(data) / read_size))
for i in range(num_chunks):
content = await stream.read(read_size)
start = i * read_size
end = start + read_size
assert data[start:end] == content
result.extend(content)

# Assert
assert result == data


# ------------------------------------------------------------------------------
35 changes: 35 additions & 0 deletions sdk/storage/azure-storage-blob/tests/test_blob_encryption_v2.py
@@ -7,6 +7,7 @@
import base64
import os
from json import dumps, loads
from math import ceil

import pytest
from azure.core import MatchConditions
@@ -756,6 +757,40 @@ def test_get_blob_using_chunks_iter(self, storage_account_name, storage_account_key):
# Assert
self.assertEqual(len(content), total)

@pytest.mark.live_test_only
@BlobPreparer()
def test_get_blob_using_read(self, storage_account_name, storage_account_key):
self._setup(storage_account_name, storage_account_key)
kek = KeyWrapper('key1')
bsc = BlobServiceClient(
self.account_url(storage_account_name, "blob"),
credential=storage_account_key,
max_single_get_size=4 * MiB,
max_chunk_get_size=4 * MiB,
require_encryption=True,
encryption_version='2.0',
key_encryption_key=kek)

blob = bsc.get_blob_client(self.container_name, self._get_blob_reference())
data = b'abcde' * 4 * MiB # 20 MiB
blob.upload_blob(data, overwrite=True)

# Act
stream = blob.download_blob(max_concurrency=2)

result = bytearray()
read_size = 5 * MiB
num_chunks = int(ceil(len(data) / read_size))
for i in range(num_chunks):
content = stream.read(read_size)
start = i * read_size
end = start + read_size
assert data[start:end] == content
result.extend(content)

# Assert
assert result == data

@pytest.mark.skip(reason="Intended for manual testing due to blob size.")
@pytest.mark.live_test_only
@BlobPreparer()
sdk/storage/azure-storage-blob/tests/test_blob_encryption_v2_async.py
@@ -7,6 +7,7 @@
import base64
import os
from json import dumps, loads
from math import ceil

import pytest
from azure.core import MatchConditions
@@ -756,6 +757,40 @@ async def test_get_blob_using_chunks_iter(self, storage_account_name, storage_account_key):
# Assert
self.assertEqual(len(content), total)

@pytest.mark.live_test_only
@BlobPreparer()
async def test_get_blob_using_read(self, storage_account_name, storage_account_key):
await self._setup(storage_account_name, storage_account_key)
kek = KeyWrapper('key1')
bsc = BlobServiceClient(
self.account_url(storage_account_name, "blob"),
credential=storage_account_key,
max_single_get_size=4 * MiB,
max_chunk_get_size=4 * MiB,
require_encryption=True,
encryption_version='2.0',
key_encryption_key=kek)

blob = bsc.get_blob_client(self.container_name, self._get_blob_reference())
data = b'abcde' * 4 * MiB # 20 MiB
await blob.upload_blob(data, overwrite=True)

# Act
stream = await blob.download_blob(max_concurrency=2)

result = bytearray()
read_size = 5 * MiB
num_chunks = int(ceil(len(data) / read_size))
for i in range(num_chunks):
content = await stream.read(read_size)
start = i * read_size
end = start + read_size
assert data[start:end] == content
result.extend(content)

# Assert
assert result == data

@pytest.mark.skip(reason="Intended for manual testing due to blob size.")
@pytest.mark.live_test_only
@BlobPreparer()