Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Storage] Add base test classes that support the test proxy #24937

Merged
merged 1 commit into from
Jun 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions tools/azure-sdk-tools/devtools_testutils/storage/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
from .api_version_policy import ApiVersionAssertPolicy
from .service_versions import service_version_map, ServiceVersion, is_version_before
from .testcase import StorageTestCase, LogCaptured
from .testcase import StorageTestCase, StorageRecordedTestCase, LogCaptured

__all__ = ["ApiVersionAssertPolicy", "service_version_map", "StorageTestCase", "ServiceVersion", "is_version_before",
"LogCaptured"]
__all__ = [
"ApiVersionAssertPolicy",
"service_version_map",
"StorageTestCase",
"StorageRecordedTestCase",
"ServiceVersion",
"is_version_before",
"LogCaptured"
]
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .asynctestcase import AsyncStorageTestCase
from .asynctestcase import AsyncStorageTestCase, AsyncStorageRecordedTestCase

__all__ = ["AsyncStorageTestCase"]
__all__ = ["AsyncStorageTestCase", "AsyncStorageRecordedTestCase"]
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import functools

from .. import StorageTestCase
from .. import StorageTestCase, StorageRecordedTestCase
from ...fake_credentials_async import AsyncFakeCredential

from azure_devtools.scenario_tests.patches import mock_in_unit_test
Expand Down Expand Up @@ -67,3 +67,34 @@ def generate_oauth_token(self):

def generate_fake_token(self):
return AsyncFakeCredential()


class AsyncStorageRecordedTestCase(StorageRecordedTestCase):

@staticmethod
def await_prepared_test(test_fn):
"""Synchronous wrapper for async test methods. Used to avoid making changes
upstream to AbstractPreparer (which doesn't await the functions it wraps)
"""

@functools.wraps(test_fn)
def run(test_class_instance, *args, **kwargs):
trim_kwargs_from_test_function(test_fn, kwargs)
loop = asyncio.get_event_loop()
return loop.run_until_complete(test_fn(test_class_instance, **kwargs))

return run

def generate_oauth_token(self):
if self.is_live:
from azure.identity.aio import ClientSecretCredential

return ClientSecretCredential(
self.get_settings_value("TENANT_ID"),
self.get_settings_value("CLIENT_ID"),
self.get_settings_value("CLIENT_SECRET"),
)
return self.generate_fake_token()

def generate_fake_token(self):
return AsyncFakeCredential()
176 changes: 170 additions & 6 deletions tools/azure-sdk-tools/devtools_testutils/storage/testcase.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from __future__ import division

from datetime import datetime, timedelta
from io import StringIO
import logging
import math
import os
Expand All @@ -15,17 +16,14 @@
import time
import zlib

from devtools_testutils import AzureTestCase
import pytest

from devtools_testutils import AzureTestCase, AzureRecordedTestCase

from .processors import XMSRequestIDBody
from . import ApiVersionAssertPolicy, service_version_map
from .. import FakeTokenCredential

try:
from cStringIO import StringIO # Python 2
except ImportError:
from io import StringIO

try:
from azure.storage.blob import generate_account_sas, AccountSasPermissions, ResourceTypes
except:
Expand All @@ -39,6 +37,19 @@
ENABLE_LOGGING = True


def generate_sas_token():
fake_key = "a" * 30 + "b" * 30

return "?" + generate_account_sas(
account_name="test", # name of the storage account
account_key=fake_key, # key for the storage account
resource_types=ResourceTypes(object=True),
permission=AccountSasPermissions(read=True, list=True),
start=datetime.now() - timedelta(hours=24),
expiry=datetime.now() + timedelta(days=8),
)


class StorageTestCase(AzureTestCase):
def __init__(self, *args, **kwargs):
super(StorageTestCase, self).__init__(*args, **kwargs)
Expand Down Expand Up @@ -209,6 +220,159 @@ def create_storage_client_from_conn_str(self, client, *args, **kwargs):
return client.from_connection_string(*args, **kwargs)


class StorageRecordedTestCase(AzureRecordedTestCase):

def setup_class(cls):
cls.logger = logging.getLogger("azure.storage")
cls.sas_token = generate_sas_token()

def setup_method(self, _):
self.configure_logging()

def connection_string(self, account_name, key):
return (
"DefaultEndpointsProtocol=https;AcCounTName="
+ account_name
+ ";AccOuntKey="
+ str(key)
+ ";EndpoIntSuffix=core.windows.net"
)

def account_url(self, storage_account, storage_type):
"""Return an url of storage account.

:param str storage_account: Storage account name
:param str storage_type: The Storage type part of the URL. Should be "blob", or "queue", etc.
"""
protocol = os.environ.get("PROTOCOL", "https")
suffix = os.environ.get("ACCOUNT_URL_SUFFIX", "core.windows.net")
return f"{protocol}://{storage_account}.{storage_type}.{suffix}"

def configure_logging(self):
enable_logging = ENABLE_LOGGING

self.enable_logging() if enable_logging else self.disable_logging()

def enable_logging(self):
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter(LOGGING_FORMAT))
self.logger.handlers = [handler]
self.logger.setLevel(logging.DEBUG)
self.logger.propagate = True
self.logger.disabled = False

def disable_logging(self):
self.logger.propagate = False
self.logger.disabled = True
self.logger.handlers = []

def get_random_bytes(self, size):
# recordings don't like random stuff. making this more
# deterministic.
return b"a" * size

def get_random_text_data(self, size):
"""Returns random unicode text data exceeding the size threshold for
chunking blob upload."""
checksum = zlib.adler32(self.qualified_test_name.encode()) & 0xFFFFFFFF
rand = random.Random(checksum)
text = u""
words = [u"hello", u"world", u"python", u"啊齄丂狛狜"]
while len(text) < size:
index = int(rand.random() * (len(words) - 1))
text = text + u" " + words[index]

return text

@staticmethod
def _set_test_proxy(service, settings):
if settings.USE_PROXY:
service.set_proxy(
settings.PROXY_HOST,
settings.PROXY_PORT,
settings.PROXY_USER,
settings.PROXY_PASSWORD,
)

def assertNamedItemInContainer(self, container, item_name, msg=None):
def _is_string(obj):
return isinstance(obj, str)

for item in container:
if _is_string(item):
if item == item_name:
return
elif isinstance(item, dict):
if item_name == item["name"]:
return
elif item.name == item_name:
return
elif hasattr(item, "snapshot") and item.snapshot == item_name:
return

error_message = f"{repr(item_name)} not found in {[str(c) for c in container]}"
pytest.fail(error_message)

def assertNamedItemNotInContainer(self, container, item_name, msg=None):
for item in container:
if item.name == item_name:
error_message = f"{repr(item_name)} unexpectedly found in {repr(container)}"
pytest.fail(error_message)

def assert_upload_progress(self, size, max_chunk_size, progress, unknown_size=False):
"""Validates that the progress chunks align with our chunking procedure."""
total = None if unknown_size else size
small_chunk_size = size % max_chunk_size
assert len(progress) == math.ceil(size / max_chunk_size)
for i in progress:
assert i[0] % max_chunk_size == 0 or i[0] % max_chunk_size == small_chunk_size
assert i[1] == total

def assert_download_progress(self, size, max_chunk_size, max_get_size, progress):
"""Validates that the progress chunks align with our chunking procedure."""
if size <= max_get_size:
assert len(progress) == 1
assert progress[0][0], size
assert progress[0][1], size
else:
small_chunk_size = (size - max_get_size) % max_chunk_size
assert len(progress) == 1 + math.ceil((size - max_get_size) / max_chunk_size)

assert progress[0][0], max_get_size
assert progress[0][1], size
for i in progress[1:]:
assert i[0] % max_chunk_size == 0 or i[0] % max_chunk_size == small_chunk_size
assert i[1] == size

def generate_oauth_token(self):
if self.is_live:
from azure.identity import ClientSecretCredential

return ClientSecretCredential(
self.get_settings_value("TENANT_ID"),
self.get_settings_value("CLIENT_ID"),
self.get_settings_value("CLIENT_SECRET"),
)
return self.generate_fake_token()

def generate_fake_token(self):
return FakeTokenCredential()

def _get_service_version(self, **kwargs):
env_version = service_version_map.get(os.environ.get("AZURE_LIVE_TEST_SERVICE_VERSION", "LATEST"))
return kwargs.pop("service_version", env_version)

def create_storage_client(self, client, *args, **kwargs):
kwargs["api_version"] = self._get_service_version(**kwargs)
kwargs["_additional_pipeline_policies"] = [ApiVersionAssertPolicy(kwargs["api_version"])]
return client(*args, **kwargs)

def create_storage_client_from_conn_str(self, client, *args, **kwargs):
kwargs["api_version"] = self._get_service_version(**kwargs)
kwargs["_additional_pipeline_policies"] = [ApiVersionAssertPolicy(kwargs["api_version"])]
return client.from_connection_string(*args, **kwargs)


class LogCaptured(object):
def __init__(self, test_case=None):
# accept the test case so that we may reset logging after capturing logs
Expand Down