diff --git a/google/cloud/storage/blob.py b/google/cloud/storage/blob.py index 0d663e775..4c493485f 100644 --- a/google/cloud/storage/blob.py +++ b/google/cloud/storage/blob.py @@ -1697,7 +1697,7 @@ def _get_writable_metadata(self): return object_metadata - def _get_upload_arguments(self, client, content_type): + def _get_upload_arguments(self, client, content_type, filename=None): """Get required arguments for performing an upload. The content type returned will be determined in order of precedence: @@ -1716,7 +1716,7 @@ def _get_upload_arguments(self, client, content_type): * An object metadata dictionary * The ``content_type`` as a string (according to precedence) """ - content_type = self._get_content_type(content_type) + content_type = self._get_content_type(content_type, filename=filename) headers = { **_get_default_headers(client._connection.user_agent, content_type), **_get_encryption_headers(self._encryption_key), diff --git a/google/cloud/storage/transfer_manager.py b/google/cloud/storage/transfer_manager.py index 0b65702d4..5cb9b6c46 100644 --- a/google/cloud/storage/transfer_manager.py +++ b/google/cloud/storage/transfer_manager.py @@ -26,6 +26,12 @@ from google.api_core import exceptions from google.cloud.storage import Client from google.cloud.storage import Blob +from google.cloud.storage.blob import _get_host_name +from google.cloud.storage.constants import _DEFAULT_TIMEOUT + +from google.resumable_media.requests.upload import XMLMPUContainer +from google.resumable_media.requests.upload import XMLMPUPart + warnings.warn( "The module `transfer_manager` is a preview feature. Functionality and API " @@ -35,7 +41,14 @@ TM_DEFAULT_CHUNK_SIZE = 32 * 1024 * 1024 DEFAULT_MAX_WORKERS = 8 - +METADATA_HEADER_TRANSLATION = { + "cacheControl": "Cache-Control", + "contentDisposition": "Content-Disposition", + "contentEncoding": "Content-Encoding", + "contentLanguage": "Content-Language", + "customTime": "x-goog-custom-time", + "storageClass": "x-goog-storage-class", +} # Constants to be passed in as `worker_type`. PROCESS = "process" @@ -198,7 +211,7 @@ def upload_many( futures.append( executor.submit( _call_method_on_maybe_pickled_blob, - _pickle_blob(blob) if needs_pickling else blob, + _pickle_client(blob) if needs_pickling else blob, "upload_from_filename" if isinstance(path_or_file, str) else "upload_from_file", @@ -343,7 +356,7 @@ def download_many( futures.append( executor.submit( _call_method_on_maybe_pickled_blob, - _pickle_blob(blob) if needs_pickling else blob, + _pickle_client(blob) if needs_pickling else blob, "download_to_filename" if isinstance(path_or_file, str) else "download_to_file", @@ -733,7 +746,6 @@ def download_chunks_concurrently( Checksumming (md5 or crc32c) is not supported for chunked operations. Any `checksum` parameter passed in to download_kwargs will be ignored. - :type bucket: 'google.cloud.storage.bucket.Bucket' :param bucket: The bucket which contains the blobs to be downloaded @@ -745,6 +757,12 @@ def download_chunks_concurrently( :param filename: The destination filename or path. + :type chunk_size: int + :param chunk_size: + The size in bytes of each chunk to send. The optimal chunk size for + maximum throughput may vary depending on the exact network environment + and size of the blob. + :type download_kwargs: dict :param download_kwargs: A dictionary of keyword arguments to pass to the download method. 
Refer @@ -809,7 +827,7 @@ def download_chunks_concurrently( pool_class, needs_pickling = _get_pool_class_and_requirements(worker_type) # Pickle the blob ahead of time (just once, not once per chunk) if needed. - maybe_pickled_blob = _pickle_blob(blob) if needs_pickling else blob + maybe_pickled_blob = _pickle_client(blob) if needs_pickling else blob futures = [] @@ -844,9 +862,249 @@ def download_chunks_concurrently( return None +def upload_chunks_concurrently( + filename, + blob, + content_type=None, + chunk_size=TM_DEFAULT_CHUNK_SIZE, + deadline=None, + worker_type=PROCESS, + max_workers=DEFAULT_MAX_WORKERS, + *, + checksum="md5", + timeout=_DEFAULT_TIMEOUT, +): + """Upload a single file in chunks, concurrently. + + This function uses the XML MPU API to initialize an upload and upload a + file in chunks, concurrently with a worker pool. + + The XML MPU API is significantly different from other uploads; please review + the documentation at https://cloud.google.com/storage/docs/multipart-uploads + before using this feature. + + The library will attempt to cancel uploads that fail due to an exception. + If the upload fails in a way that precludes cancellation, such as a + hardware failure, process termination, or power outage, then the incomplete + upload may persist indefinitely. To mitigate this, set the + `AbortIncompleteMultipartUpload` with a nonzero `Age` in bucket lifecycle + rules, or refer to the XML API documentation linked above to learn more + about how to list and delete incomplete uploads. + + Using this feature with multiple threads is unlikely to improve upload + performance under normal circumstances due to Python interpreter threading + behavior. The default is therefore to use processes instead of threads. + + ACL information cannot be sent with this function and should be set + separately with :class:`ObjectACL` methods. + + :type filename: str + :param filename: + The path to the file to upload. File-like objects are not supported. + + :type blob: `google.cloud.storage.Blob` + :param blob: + The blob to which to upload. + + :type content_type: str + :param content_type: (Optional) Type of content being uploaded. + + :type chunk_size: int + :param chunk_size: + The size in bytes of each chunk to send. The optimal chunk size for + maximum throughput may vary depending on the exact network environment + and size of the blob. The remote API has restrictions on the minimum + and maximum size allowable, see: https://cloud.google.com/storage/quotas#requests + + :type deadline: int + :param deadline: + The number of seconds to wait for all threads to resolve. If the + deadline is reached, all threads will be terminated regardless of their + progress and concurrent.futures.TimeoutError will be raised. This can be + left as the default of None (no deadline) for most use cases. + + :type worker_type: str + :param worker_type: + The worker type to use; one of google.cloud.storage.transfer_manager.PROCESS + or google.cloud.storage.transfer_manager.THREAD. + + Although the exact performance impact depends on the use case, in most + situations the PROCESS worker type will use more system resources (both + memory and CPU) and result in faster operations than THREAD workers. + + Because the subprocesses of the PROCESS worker type can't access memory + from the main process, Client objects have to be serialized and then + recreated in each subprocess.
The serialization of the Client object + for use in subprocesses is an approximation and may not capture every + detail of the Client object, especially if the Client was modified after + its initial creation or if `Client._http` was modified in any way. + + THREAD worker types are observed to be relatively efficient for + operations with many small files, but not for operations with large + files. PROCESS workers are recommended for large file operations. + + :type max_workers: int + :param max_workers: + The maximum number of workers to create to handle the workload. + + With PROCESS workers, a larger number of workers will consume more + system resources (memory and CPU) at once. + + How many workers is optimal depends heavily on the specific use case, + and the default is a conservative number that should work okay in most + cases without consuming excessive resources. + + :type checksum: str + :param checksum: + (Optional) The checksum scheme to use: either 'md5', 'crc32c' or None. + Each individual part is checksummed. At present, the selected checksum + rule is only applied to parts and a separate checksum of the entire + resulting blob is not computed. Please compute and compare the checksum + of the file to the resulting blob separately if needed, using the + 'crc32c' algorithm as per the XML MPU documentation. + + :type timeout: float or tuple + :param timeout: + (Optional) The amount of time, in seconds, to wait + for the server response. See: :ref:`configuring_timeouts` + + :raises: :exc:`concurrent.futures.TimeoutError` if deadline is exceeded. + """ + + bucket = blob.bucket + client = blob.client + transport = blob._get_transport(client) + + hostname = _get_host_name(client._connection) + url = "{hostname}/{bucket}/{blob}".format( + hostname=hostname, bucket=bucket.name, blob=blob.name + ) + + base_headers, object_metadata, content_type = blob._get_upload_arguments( + client, content_type, filename=filename + ) + headers = {**base_headers, **_headers_from_metadata(object_metadata)} + + if blob.user_project is not None: + headers["x-goog-user-project"] = blob.user_project + + # When a Customer Managed Encryption Key is used to encrypt Cloud Storage object + # at rest, object resource metadata will store the version of the Key Management + # Service cryptographic material. If a Blob instance with KMS Key metadata set is + # used to upload a new version of the object then the existing kmsKeyName version + # value can't be used in the upload request and the client instead ignores it. + if blob.kms_key_name is not None and "cryptoKeyVersions" not in blob.kms_key_name: + headers["x-goog-encryption-kms-key-name"] = blob.kms_key_name + + container = XMLMPUContainer(url, filename, headers=headers) + container.initiate(transport=transport, content_type=content_type) + upload_id = container.upload_id + + size = os.path.getsize(filename) + num_of_parts = -(size // -chunk_size) # Ceiling division + + pool_class, needs_pickling = _get_pool_class_and_requirements(worker_type) + # Pickle the client ahead of time (just once, not once per chunk) if needed.
+ maybe_pickled_client = _pickle_client(client) if needs_pickling else client + + futures = [] + + with pool_class(max_workers=max_workers) as executor: + + for part_number in range(1, num_of_parts + 1): + start = (part_number - 1) * chunk_size + end = min(part_number * chunk_size, size) + + futures.append( + executor.submit( + _upload_part, + maybe_pickled_client, + url, + upload_id, + filename, + start=start, + end=end, + part_number=part_number, + checksum=checksum, + headers=headers, + ) + ) + + concurrent.futures.wait( + futures, timeout=deadline, return_when=concurrent.futures.ALL_COMPLETED + ) + + try: + # Harvest results and raise exceptions. + for future in futures: + part_number, etag = future.result() + container.register_part(part_number, etag) + + container.finalize(blob._get_transport(client)) + except Exception: + container.cancel(blob._get_transport(client)) + raise + + +def _upload_part( + maybe_pickled_client, + url, + upload_id, + filename, + start, + end, + part_number, + checksum, + headers, +): + """Helper function that runs inside a thread or subprocess to upload a part. + + `maybe_pickled_client` is either a Client (for threads) or a specially + pickled Client (for processes) because the default pickling mangles Client + objects.""" + + if isinstance(maybe_pickled_client, Client): + client = maybe_pickled_client + else: + client = pickle.loads(maybe_pickled_client) + part = XMLMPUPart( + url, + upload_id, + filename, + start=start, + end=end, + part_number=part_number, + checksum=checksum, + headers=headers, + ) + part.upload(client._http) + return (part_number, part.etag) + + +def _headers_from_metadata(metadata): + """Helper function to translate object metadata into a header dictionary.""" + + headers = {} + # Handle standard writable metadata + for key, value in metadata.items(): + if key in METADATA_HEADER_TRANSLATION: + headers[METADATA_HEADER_TRANSLATION[key]] = value + # Handle custom metadata + if "metadata" in metadata: + for key, value in metadata["metadata"].items(): + headers["x-goog-meta-" + key] = value + return headers + + def _download_and_write_chunk_in_place( maybe_pickled_blob, filename, start, end, download_kwargs ): + """Helper function that runs inside a thread or subprocess. + + `maybe_pickled_blob` is either a Blob (for threads) or a specially pickled + Blob (for processes) because the default pickling mangles Client objects + which are attached to Blobs.""" + if isinstance(maybe_pickled_blob, Blob): blob = maybe_pickled_blob else: @@ -863,9 +1121,9 @@ def _call_method_on_maybe_pickled_blob( ): """Helper function that runs inside a thread or subprocess. - `maybe_pickled_blob` is either a blob (for threads) or a specially pickled - blob (for processes) because the default pickling mangles clients which are - attached to blobs.""" + `maybe_pickled_blob` is either a Blob (for threads) or a specially pickled + Blob (for processes) because the default pickling mangles Client objects + which are attached to Blobs.""" if isinstance(maybe_pickled_blob, Blob): blob = maybe_pickled_blob @@ -894,8 +1152,8 @@ def _reduce_client(cl): ) -def _pickle_blob(blob): - """Pickle a Blob (and its Bucket and Client) and return a bytestring.""" +def _pickle_client(obj): + """Pickle a Client or an object that owns a Client (like a Blob)""" # We need a custom pickler to process Client objects, which are attached to # Buckets (and therefore to Blobs in turn). 
Unfortunately, the Python @@ -907,7 +1165,7 @@ def _pickle_blob(blob): p = pickle.Pickler(f) p.dispatch_table = copyreg.dispatch_table.copy() p.dispatch_table[Client] = _reduce_client - p.dump(blob) + p.dump(obj) return f.getvalue() diff --git a/setup.py b/setup.py index e2b5cc7a4..a57f972ff 100644 --- a/setup.py +++ b/setup.py @@ -31,7 +31,7 @@ "google-auth >= 1.25.0, < 3.0dev", "google-api-core >= 1.31.5, <3.0.0dev,!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0", "google-cloud-core >= 2.3.0, < 3.0dev", - "google-resumable-media >= 2.3.2", + "google-resumable-media >= 2.6.0", "requests >= 2.18.0, < 3.0.0dev", ] extras = {"protobuf": ["protobuf<5.0.0dev"]} diff --git a/tests/system/conftest.py b/tests/system/conftest.py index 26d5c785e..fe90ceb80 100644 --- a/tests/system/conftest.py +++ b/tests/system/conftest.py @@ -46,6 +46,21 @@ ebh_bucket_iteration = 0 +_key_name_format = "projects/{}/locations/{}/keyRings/{}/cryptoKeys/{}" + +keyring_name = "gcs-test" +default_key_name = "gcs-test" +alt_key_name = "gcs-test-alternate" + + +def _kms_key_name(client, bucket, key_name): + return _key_name_format.format( + client.project, + bucket.location.lower(), + keyring_name, + key_name, + ) + @pytest.fixture(scope="session") def storage_client(): @@ -218,3 +233,27 @@ def file_data(): file_data["hash"] = _base64_md5hash(file_obj) return _file_data + + +@pytest.fixture(scope="session") +def kms_bucket_name(): + return _helpers.unique_name("gcp-systest-kms") + + +@pytest.fixture(scope="session") +def kms_bucket(storage_client, kms_bucket_name, no_mtls): + bucket = _helpers.retry_429_503(storage_client.create_bucket)(kms_bucket_name) + + yield bucket + + _helpers.delete_bucket(bucket) + + +@pytest.fixture(scope="session") +def kms_key_name(storage_client, kms_bucket): + return _kms_key_name(storage_client, kms_bucket, default_key_name) + + +@pytest.fixture(scope="session") +def alt_kms_key_name(storage_client, kms_bucket): + return _kms_key_name(storage_client, kms_bucket, alt_key_name) diff --git a/tests/system/test_transfer_manager.py b/tests/system/test_transfer_manager.py index bc7e0d31e..fc7bc2d51 100644 --- a/tests/system/test_transfer_manager.py +++ b/tests/system/test_transfer_manager.py @@ -16,6 +16,8 @@ import tempfile import os +import pytest + from google.cloud.storage import transfer_manager from google.cloud.storage._helpers import _base64_md5hash @@ -23,6 +25,16 @@ DEADLINE = 30 +encryption_key = "b23ff11bba187db8c37077e6af3b25b8" + + +def _check_blob_hash(blob, info): + md5_hash = blob.md5_hash + if not isinstance(md5_hash, bytes): + md5_hash = md5_hash.encode("utf-8") + + assert md5_hash == info["hash"] + def test_upload_many(shared_bucket, file_data, blobs_to_delete): FILE_BLOB_PAIRS = [ @@ -171,3 +183,208 @@ def test_download_chunks_concurrently(shared_bucket, file_data): ) with open(threaded_filename, "rb") as file_obj: assert _base64_md5hash(file_obj) == source_file["hash"] + + +def test_upload_chunks_concurrently(shared_bucket, file_data, blobs_to_delete): + source_file = file_data["big"] + filename = source_file["path"] + blob_name = "mpu_file" + upload_blob = shared_bucket.blob(blob_name) + chunk_size = 5 * 1024 * 1024 # Minimum supported by XML MPU API + assert os.path.getsize(filename) > chunk_size # Won't make a good test otherwise + + blobs_to_delete.append(upload_blob) + + transfer_manager.upload_chunks_concurrently( + filename, upload_blob, chunk_size=chunk_size, deadline=DEADLINE + ) + + with tempfile.NamedTemporaryFile() as tmp: + download_blob = shared_bucket.blob(blob_name) + 
download_blob.download_to_file(tmp) + tmp.seek(0) + + with open(source_file["path"], "rb") as sf: + source_contents = sf.read() + temp_contents = tmp.read() + assert source_contents == temp_contents + + # Also test threaded mode + blob_name = "mpu_threaded" + upload_blob = shared_bucket.blob(blob_name) + chunk_size = 5 * 1024 * 1024 # Minimum supported by XML MPU API + assert os.path.getsize(filename) > chunk_size # Won't make a good test otherwise + + transfer_manager.upload_chunks_concurrently( + filename, + upload_blob, + chunk_size=chunk_size, + deadline=DEADLINE, + worker_type=transfer_manager.THREAD, + ) + + with tempfile.NamedTemporaryFile() as tmp: + download_blob = shared_bucket.blob(blob_name) + download_blob.download_to_file(tmp) + tmp.seek(0) + + with open(source_file["path"], "rb") as sf: + source_contents = sf.read() + temp_contents = tmp.read() + assert source_contents == temp_contents + + +def test_upload_chunks_concurrently_with_metadata( + shared_bucket, file_data, blobs_to_delete +): + import datetime + from google.cloud._helpers import UTC + + now = datetime.datetime.utcnow().replace(tzinfo=UTC) + custom_metadata = {"key_a": "value_a", "key_b": "value_b"} + + METADATA = { + "cache_control": "private", + "content_disposition": "inline", + "content_language": "en-US", + "custom_time": now, + "metadata": custom_metadata, + "storage_class": "NEARLINE", + } + + source_file = file_data["big"] + filename = source_file["path"] + blob_name = "mpu_file_with_metadata" + upload_blob = shared_bucket.blob(blob_name) + + for key, value in METADATA.items(): + setattr(upload_blob, key, value) + + chunk_size = 5 * 1024 * 1024 # Minimum supported by XML MPU API + assert os.path.getsize(filename) > chunk_size # Won't make a good test otherwise + + transfer_manager.upload_chunks_concurrently( + filename, upload_blob, chunk_size=chunk_size, deadline=DEADLINE + ) + blobs_to_delete.append(upload_blob) + + with tempfile.NamedTemporaryFile() as tmp: + download_blob = shared_bucket.get_blob(blob_name) + + for key, value in METADATA.items(): + assert getattr(download_blob, key) == value + + download_blob.download_to_file(tmp) + tmp.seek(0) + + with open(source_file["path"], "rb") as sf: + source_contents = sf.read() + temp_contents = tmp.read() + assert source_contents == temp_contents + + +def test_upload_chunks_concurrently_with_content_encoding( + shared_bucket, file_data, blobs_to_delete +): + import gzip + + METADATA = { + "content_encoding": "gzip", + } + + source_file = file_data["big"] + filename = source_file["path"] + blob_name = "mpu_file_encoded" + upload_blob = shared_bucket.blob(blob_name) + + for key, value in METADATA.items(): + setattr(upload_blob, key, value) + + chunk_size = 5 * 1024 * 1024 # Minimum supported by XML MPU API + + with tempfile.NamedTemporaryFile() as tmp_gzip: + with open(filename, "rb") as f: + compressed_bytes = gzip.compress(f.read()) + + tmp_gzip.write(compressed_bytes) + tmp_gzip.seek(0) + transfer_manager.upload_chunks_concurrently( + tmp_gzip.name, upload_blob, chunk_size=chunk_size, deadline=DEADLINE + ) + blobs_to_delete.append(upload_blob) + + with tempfile.NamedTemporaryFile() as tmp: + download_blob = shared_bucket.get_blob(blob_name) + + for key, value in METADATA.items(): + assert getattr(download_blob, key) == value + + download_blob.download_to_file(tmp) + tmp.seek(0) + + with open(source_file["path"], "rb") as sf: + source_contents = sf.read() + temp_contents = tmp.read() + assert source_contents == temp_contents + + +def 
test_upload_chunks_concurrently_with_encryption_key( + shared_bucket, file_data, blobs_to_delete +): + source_file = file_data["big"] + filename = source_file["path"] + blob_name = "mpu_file_encrypted" + upload_blob = shared_bucket.blob(blob_name, encryption_key=encryption_key) + + chunk_size = 5 * 1024 * 1024 # Minimum supported by XML MPU API + assert os.path.getsize(filename) > chunk_size # Won't make a good test otherwise + + transfer_manager.upload_chunks_concurrently( + filename, upload_blob, chunk_size=chunk_size, deadline=DEADLINE + ) + blobs_to_delete.append(upload_blob) + + with tempfile.NamedTemporaryFile() as tmp: + download_blob = shared_bucket.get_blob(blob_name, encryption_key=encryption_key) + + download_blob.download_to_file(tmp) + tmp.seek(0) + + with open(source_file["path"], "rb") as sf: + source_contents = sf.read() + temp_contents = tmp.read() + assert source_contents == temp_contents + + with tempfile.NamedTemporaryFile() as tmp: + keyless_blob = shared_bucket.get_blob(blob_name) + + with pytest.raises(exceptions.BadRequest): + keyless_blob.download_to_file(tmp) + + +def test_upload_chunks_concurrently_with_kms( + kms_bucket, file_data, blobs_to_delete, kms_key_name +): + source_file = file_data["big"] + filename = source_file["path"] + blob_name = "mpu_file_kms" + blob = kms_bucket.blob(blob_name, kms_key_name=kms_key_name) + + chunk_size = 5 * 1024 * 1024 # Minimum supported by XML MPU API + assert os.path.getsize(filename) > chunk_size # Won't make a good test otherwise + + transfer_manager.upload_chunks_concurrently( + filename, blob, chunk_size=chunk_size, deadline=DEADLINE + ) + blobs_to_delete.append(blob) + blob.reload() + assert blob.kms_key_name.startswith(kms_key_name) + + with tempfile.NamedTemporaryFile() as tmp: + blob.download_to_file(tmp) + tmp.seek(0) + + with open(source_file["path"], "rb") as sf: + source_contents = sf.read() + temp_contents = tmp.read() + assert source_contents == temp_contents diff --git a/tests/unit/test_transfer_manager.py b/tests/unit/test_transfer_manager.py index 685f48579..f1d760043 100644 --- a/tests/unit/test_transfer_manager.py +++ b/tests/unit/test_transfer_manager.py @@ -18,6 +18,7 @@ from google.cloud.storage import transfer_manager from google.cloud.storage import Blob +from google.cloud.storage import Client from google.api_core import exceptions @@ -33,6 +34,9 @@ FAKE_ENCODING = "fake_gzip" DOWNLOAD_KWARGS = {"accept-encoding": FAKE_ENCODING} CHUNK_SIZE = 8 +HOSTNAME = "https://example.com" +URL = "https://example.com/bucket/blob" +USER_AGENT = "agent" # Used in subprocesses only, so excluded from coverage @@ -529,7 +533,7 @@ def test_download_chunks_concurrently(): blob_mock.download_to_filename.return_value = FAKE_RESULT - with mock.patch("__main__.open", mock.mock_open()): + with mock.patch("google.cloud.storage.transfer_manager.open", mock.mock_open()): result = transfer_manager.download_chunks_concurrently( blob_mock, FILENAME, @@ -554,7 +558,7 @@ def test_download_chunks_concurrently_raises_on_start_and_end(): MULTIPLE = 4 blob_mock.size = CHUNK_SIZE * MULTIPLE - with mock.patch("__main__.open", mock.mock_open()): + with mock.patch("google.cloud.storage.transfer_manager.open", mock.mock_open()): with pytest.raises(ValueError): transfer_manager.download_chunks_concurrently( blob_mock, @@ -587,7 +591,9 @@ def test_download_chunks_concurrently_passes_concurrency_options(): with mock.patch("concurrent.futures.ThreadPoolExecutor") as pool_patch, mock.patch( "concurrent.futures.wait" - ) as wait_patch, 
mock.patch("__main__.open", mock.mock_open()): + ) as wait_patch, mock.patch( + "google.cloud.storage.transfer_manager.open", mock.mock_open() + ): transfer_manager.download_chunks_concurrently( blob_mock, FILENAME, @@ -600,6 +606,182 @@ def test_download_chunks_concurrently_passes_concurrency_options(): wait_patch.assert_called_with(mock.ANY, timeout=DEADLINE, return_when=mock.ANY) +def test_upload_chunks_concurrently(): + bucket = mock.Mock() + bucket.name = "bucket" + bucket.client = _PickleableMockClient(identify_as_client=True) + transport = bucket.client._http + bucket.user_project = None + + blob = Blob("blob", bucket) + blob.content_type = FAKE_CONTENT_TYPE + + FILENAME = "file_a.txt" + SIZE = 2048 + + container_mock = mock.Mock() + container_mock.upload_id = "abcd" + part_mock = mock.Mock() + ETAG = "efgh" + part_mock.etag = ETAG + + with mock.patch("os.path.getsize", return_value=SIZE), mock.patch( + "google.cloud.storage.transfer_manager.XMLMPUContainer", + return_value=container_mock, + ), mock.patch( + "google.cloud.storage.transfer_manager.XMLMPUPart", return_value=part_mock + ): + transfer_manager.upload_chunks_concurrently( + FILENAME, + blob, + chunk_size=SIZE // 2, + worker_type=transfer_manager.THREAD, + ) + container_mock.initiate.assert_called_once_with( + transport=transport, content_type=blob.content_type + ) + container_mock.register_part.assert_any_call(1, ETAG) + container_mock.register_part.assert_any_call(2, ETAG) + container_mock.finalize.assert_called_once_with(bucket.client._http) + part_mock.upload.assert_called_with(transport) + + +def test_upload_chunks_concurrently_passes_concurrency_options(): + bucket = mock.Mock() + bucket.name = "bucket" + bucket.client = _PickleableMockClient(identify_as_client=True) + transport = bucket.client._http + bucket.user_project = None + + blob = Blob("blob", bucket) + + FILENAME = "file_a.txt" + SIZE = 2048 + + container_mock = mock.Mock() + container_mock.upload_id = "abcd" + + MAX_WORKERS = 7 + DEADLINE = 10 + + with mock.patch("os.path.getsize", return_value=SIZE), mock.patch( + "google.cloud.storage.transfer_manager.XMLMPUContainer", + return_value=container_mock, + ), mock.patch("concurrent.futures.ThreadPoolExecutor") as pool_patch, mock.patch( + "concurrent.futures.wait" + ) as wait_patch: + try: + transfer_manager.upload_chunks_concurrently( + FILENAME, + blob, + chunk_size=SIZE // 2, + worker_type=transfer_manager.THREAD, + max_workers=MAX_WORKERS, + deadline=DEADLINE, + ) + except ValueError: + pass # The futures don't actually work, so we expect this to abort. + # Conveniently, that gives us a chance to test the auto-delete + # exception handling feature. 
+ container_mock.cancel.assert_called_once_with(transport) + pool_patch.assert_called_with(max_workers=MAX_WORKERS) + wait_patch.assert_called_with(mock.ANY, timeout=DEADLINE, return_when=mock.ANY) + + +def test_upload_chunks_concurrently_with_metadata_and_encryption(): + import datetime + from google.cloud._helpers import UTC + from google.cloud._helpers import _RFC3339_MICROS + + now = datetime.datetime.utcnow().replace(tzinfo=UTC) + now_str = now.strftime(_RFC3339_MICROS) + + custom_metadata = {"key_a": "value_a", "key_b": "value_b"} + encryption_key = "b23ff11bba187db8c37077e6af3b25b8" + kms_key_name = "sample_key_name" + + METADATA = { + "cache_control": "private", + "content_disposition": "inline", + "content_language": "en-US", + "custom_time": now, + "metadata": custom_metadata, + "storage_class": "NEARLINE", + } + + bucket = mock.Mock() + bucket.name = "bucket" + bucket.client = _PickleableMockClient(identify_as_client=True) + transport = bucket.client._http + user_project = "my_project" + bucket.user_project = user_project + + blob = Blob("blob", bucket, kms_key_name=kms_key_name) + blob.content_type = FAKE_CONTENT_TYPE + + for key, value in METADATA.items(): + setattr(blob, key, value) + blob.metadata = {**custom_metadata} + blob.encryption_key = encryption_key + + FILENAME = "file_a.txt" + SIZE = 2048 + + container_mock = mock.Mock() + container_mock.upload_id = "abcd" + part_mock = mock.Mock() + ETAG = "efgh" + part_mock.etag = ETAG + container_cls_mock = mock.Mock(return_value=container_mock) + + invocation_id = "b9f8cbb0-6456-420c-819d-3f4ee3c0c455" + + with mock.patch("os.path.getsize", return_value=SIZE), mock.patch( + "google.cloud.storage.transfer_manager.XMLMPUContainer", new=container_cls_mock + ), mock.patch( + "google.cloud.storage.transfer_manager.XMLMPUPart", return_value=part_mock + ), mock.patch( + "google.cloud.storage._helpers._get_invocation_id", + return_value="gccl-invocation-id/" + invocation_id, + ): + transfer_manager.upload_chunks_concurrently( + FILENAME, + blob, + chunk_size=SIZE // 2, + worker_type=transfer_manager.THREAD, + ) + expected_headers = { + "Accept": "application/json", + "Accept-Encoding": "gzip, deflate", + "User-Agent": "agent", + "X-Goog-API-Client": "agent gccl-invocation-id/{}".format(invocation_id), + "content-type": FAKE_CONTENT_TYPE, + "x-upload-content-type": FAKE_CONTENT_TYPE, + "X-Goog-Encryption-Algorithm": "AES256", + "X-Goog-Encryption-Key": "YjIzZmYxMWJiYTE4N2RiOGMzNzA3N2U2YWYzYjI1Yjg=", + "X-Goog-Encryption-Key-Sha256": "B25Y4hgVlNXDliAklsNz9ykLk7qvgqDrSbdds5iu8r4=", + "Cache-Control": "private", + "Content-Disposition": "inline", + "Content-Language": "en-US", + "x-goog-storage-class": "NEARLINE", + "x-goog-custom-time": now_str, + "x-goog-meta-key_a": "value_a", + "x-goog-meta-key_b": "value_b", + "x-goog-user-project": "my_project", + "x-goog-encryption-kms-key-name": "sample_key_name", + } + container_cls_mock.assert_called_once_with( + URL, FILENAME, headers=expected_headers + ) + container_mock.initiate.assert_called_once_with( + transport=transport, content_type=blob.content_type + ) + container_mock.register_part.assert_any_call(1, ETAG) + container_mock.register_part.assert_any_call(2, ETAG) + container_mock.finalize.assert_called_once_with(transport) + part_mock.upload.assert_called_with(blob.client._http) + + class _PickleableMockBlob: def __init__( self, @@ -623,6 +805,28 @@ def download_to_file(self, *args, **kwargs): return "SUCCESS" +class _PickleableMockConnection: + @staticmethod + def 
get_api_base_url_for_mtls(): + return HOSTNAME + + user_agent = USER_AGENT + + +class _PickleableMockClient: + def __init__(self, identify_as_client=False): + self._http = "my_transport" # used as an identifier for "called_with" + self._connection = _PickleableMockConnection() + self.identify_as_client = identify_as_client + + @property + def __class__(self): + if self.identify_as_client: + return Client + else: + return _PickleableMockClient + + # Used in subprocesses only, so excluded from coverage def _validate_blob_token_in_subprocess_for_chunk( maybe_pickled_blob, filename, **kwargs @@ -642,7 +846,7 @@ def test_download_chunks_concurrently_with_processes(): with mock.patch( "google.cloud.storage.transfer_manager._download_and_write_chunk_in_place", new=_validate_blob_token_in_subprocess_for_chunk, - ), mock.patch("__main__.open", mock.mock_open()): + ), mock.patch("google.cloud.storage.transfer_manager.open", mock.mock_open()): result = transfer_manager.download_chunks_concurrently( blob, FILENAME, @@ -665,26 +869,44 @@ def test__LazyClient(): assert len(fake_cache) == 1 -def test__pickle_blob(): +def test__pickle_client(): # This test nominally has coverage, but doesn't assert that the essential - # copyreg behavior in _pickle_blob works. Unfortunately there doesn't seem + # copyreg behavior in _pickle_client works. Unfortunately there doesn't seem # to be a good way to check that without actually creating a Client, which # will spin up HTTP connections undesirably. This is more fully checked in - # the system tests, though. - pkl = transfer_manager._pickle_blob(FAKE_RESULT) + # the system tests. + pkl = transfer_manager._pickle_client(FAKE_RESULT) assert pickle.loads(pkl) == FAKE_RESULT def test__download_and_write_chunk_in_place(): pickled_mock = pickle.dumps(_PickleableMockBlob()) FILENAME = "file_a.txt" - with mock.patch("__main__.open", mock.mock_open()): + with mock.patch("google.cloud.storage.transfer_manager.open", mock.mock_open()): result = transfer_manager._download_and_write_chunk_in_place( pickled_mock, FILENAME, 0, 8, {} ) assert result == "SUCCESS" +def test__upload_part(): + pickled_mock = pickle.dumps(_PickleableMockClient()) + FILENAME = "file_a.txt" + UPLOAD_ID = "abcd" + ETAG = "efgh" + + part = mock.Mock() + part.etag = ETAG + with mock.patch( + "google.cloud.storage.transfer_manager.XMLMPUPart", return_value=part + ): + result = transfer_manager._upload_part( + pickled_mock, URL, UPLOAD_ID, FILENAME, 0, 256, 1, None, {"key", "value"} + ) + part.upload.assert_called_once() + assert result == (1, ETAG) + + def test__get_pool_class_and_requirements_error(): with pytest.raises(ValueError): transfer_manager._get_pool_class_and_requirements("garbage")
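
For reviewers, a minimal usage sketch of the new upload_chunks_concurrently API introduced above. The bucket name, object name, local file path, chunk size, and worker count below are illustrative assumptions only, not values taken from this change:

from google.cloud.storage import Client
from google.cloud.storage import transfer_manager

client = Client()
bucket = client.bucket("my-bucket")  # hypothetical bucket name
blob = bucket.blob("large-object")   # hypothetical destination object name

# Upload the local file in 32 MiB parts via the XML multipart upload API,
# using the default PROCESS worker pool.
transfer_manager.upload_chunks_concurrently(
    "/path/to/large-file",  # hypothetical local path; file-like objects are not supported
    blob,
    chunk_size=32 * 1024 * 1024,
    max_workers=8,
    deadline=300,
)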