From 94a35ba7416804881973f6a5296b430bdcf2832d Mon Sep 17 00:00:00 2001 From: cojenco Date: Wed, 31 May 2023 09:22:30 -0700 Subject: [PATCH] feat: allow exceptions to be included in batch responses (#1043) * feat: allow exceptions to be included in batch responses * fix docstring * address comments and update tests * more tests --- google/cloud/storage/batch.py | 39 ++++++++++++--- google/cloud/storage/client.py | 11 ++++- tests/unit/test_batch.py | 88 ++++++++++++++++++++++++++++++++++ 3 files changed, 130 insertions(+), 8 deletions(-) diff --git a/google/cloud/storage/batch.py b/google/cloud/storage/batch.py index 599aa3a7f..54ef55cd3 100644 --- a/google/cloud/storage/batch.py +++ b/google/cloud/storage/batch.py @@ -133,11 +133,18 @@ class Batch(Connection): :type client: :class:`google.cloud.storage.client.Client` :param client: The client to use for making connections. + + :type raise_exception: bool + :param raise_exception: + (Optional) Defaults to True. If True, instead of adding exceptions + to the list of return responses, the final exception will be raised. + Note that exceptions are unwrapped after all operations are complete + in success or failure, and only the last exception is raised. """ _MAX_BATCH_SIZE = 1000 - def __init__(self, client): + def __init__(self, client, raise_exception=True): api_endpoint = client._connection.API_BASE_URL client_info = client._connection._client_info super(Batch, self).__init__( @@ -145,6 +152,8 @@ def __init__(self, client): ) self._requests = [] self._target_objects = [] + self._responses = [] + self._raise_exception = raise_exception def _do_request( self, method, url, headers, data, target_object, timeout=_DEFAULT_TIMEOUT @@ -219,24 +228,34 @@ def _prepare_batch_request(self): _, body = payload.split("\n\n", 1) return dict(multi._headers), body, timeout - def _finish_futures(self, responses): + def _finish_futures(self, responses, raise_exception=True): """Apply all the batch responses to the futures created. :type responses: list of (headers, payload) tuples. :param responses: List of headers and payloads from each response in the batch. + :type raise_exception: bool + :param raise_exception: + (Optional) Defaults to True. If True, instead of adding exceptions + to the list of return responses, the final exception will be raised. + Note that exceptions are unwrapped after all operations are complete + in success or failure, and only the last exception is raised. + :raises: :class:`ValueError` if no requests have been deferred. """ # If a bad status occurs, we track it, but don't raise an exception # until all futures have been populated. + # If raise_exception=False, we add exceptions to the list of responses. exception_args = None if len(self._target_objects) != len(responses): # pragma: NO COVER raise ValueError("Expected a response for every request.") for target_object, subresponse in zip(self._target_objects, responses): - if not 200 <= subresponse.status_code < 300: + # For backwards compatibility, only the final exception will be raised. + # Set raise_exception=False to include all exceptions to the list of return responses. + if not 200 <= subresponse.status_code < 300 and raise_exception: exception_args = exception_args or subresponse elif target_object is not None: try: @@ -247,9 +266,16 @@ def _finish_futures(self, responses): if exception_args is not None: raise exceptions.from_http_response(exception_args) - def finish(self): + def finish(self, raise_exception=True): """Submit a single `multipart/mixed` request with deferred requests. + :type raise_exception: bool + :param raise_exception: + (Optional) Defaults to True. If True, instead of adding exceptions + to the list of return responses, the final exception will be raised. + Note that exceptions are unwrapped after all operations are complete + in success or failure, and only the last exception is raised. + :rtype: list of tuples :returns: one ``(headers, payload)`` tuple per deferred request. """ @@ -269,7 +295,8 @@ def finish(self): raise exceptions.from_http_response(response) responses = list(_unpack_batch_response(response)) - self._finish_futures(responses) + self._finish_futures(responses, raise_exception=raise_exception) + self._responses = responses return responses def current(self): @@ -283,7 +310,7 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): try: if exc_type is None: - self.finish() + self.finish(raise_exception=self._raise_exception) finally: self._client._pop_batch() diff --git a/google/cloud/storage/client.py b/google/cloud/storage/client.py index bcb0b59ef..042e8b2ef 100644 --- a/google/cloud/storage/client.py +++ b/google/cloud/storage/client.py @@ -307,17 +307,24 @@ def bucket(self, bucket_name, user_project=None): """ return Bucket(client=self, name=bucket_name, user_project=user_project) - def batch(self): + def batch(self, raise_exception=True): """Factory constructor for batch object. .. note:: This will not make an HTTP request; it simply instantiates a batch object owned by this client. + :type raise_exception: bool + :param raise_exception: + (Optional) Defaults to True. If True, instead of adding exceptions + to the list of return responses, the final exception will be raised. + Note that exceptions are unwrapped after all operations are complete + in success or failure, and only the last exception is raised. + :rtype: :class:`google.cloud.storage.batch.Batch` :returns: The batch object created. """ - return Batch(client=self) + return Batch(client=self, raise_exception=raise_exception) def _get_resource( self, diff --git a/tests/unit/test_batch.py b/tests/unit/test_batch.py index 72b54769f..37f8b8190 100644 --- a/tests/unit/test_batch.py +++ b/tests/unit/test_batch.py @@ -334,6 +334,7 @@ def test_finish_nonempty(self): result = batch.finish() self.assertEqual(len(result), len(batch._requests)) + self.assertEqual(len(result), len(batch._responses)) response1, response2, response3 = result @@ -438,6 +439,55 @@ def test_finish_nonempty_with_status_failure(self): self._check_subrequest_payload(chunks[0], "GET", url, {}) self._check_subrequest_payload(chunks[1], "GET", url, {}) + def test_finish_no_raise_exception(self): + url = "http://api.example.com/other_api" + expected_response = _make_response( + content=_TWO_PART_MIME_RESPONSE_WITH_FAIL, + headers={"content-type": 'multipart/mixed; boundary="DEADBEEF="'}, + ) + http = _make_requests_session([expected_response]) + connection = _Connection(http=http) + client = _Client(connection) + batch = self._make_one(client) + batch.API_BASE_URL = "http://api.example.com" + target1 = _MockObject() + target2 = _MockObject() + + batch._do_request("GET", url, {}, None, target1, timeout=42) + batch._do_request("GET", url, {}, None, target2, timeout=420) + + # Make sure futures are not populated. + self.assertEqual( + [future for future in batch._target_objects], [target1, target2] + ) + + batch.finish(raise_exception=False) + + self.assertEqual(len(batch._requests), 2) + self.assertEqual(len(batch._responses), 2) + + # Make sure NotFound exception is added to responses and target2 + self.assertEqual(target1._properties, {"foo": 1, "bar": 2}) + self.assertEqual(target2._properties, {"error": {"message": "Not Found"}}) + + expected_url = f"{batch.API_BASE_URL}/batch/storage/v1" + http.request.assert_called_once_with( + method="POST", + url=expected_url, + headers=mock.ANY, + data=mock.ANY, + timeout=420, # the last request timeout prevails + ) + + _, request_body, _, boundary = self._get_mutlipart_request(http) + + chunks = self._get_payload_chunks(boundary, request_body) + self.assertEqual(len(chunks), 2) + self._check_subrequest_payload(chunks[0], "GET", url, {}) + self._check_subrequest_payload(chunks[1], "GET", url, {}) + self.assertEqual(batch._responses[0].status_code, 200) + self.assertEqual(batch._responses[1].status_code, 404) + def test_finish_nonempty_non_multipart_response(self): url = "http://api.example.com/other_api" http = _make_requests_session([_make_response()]) @@ -497,6 +547,7 @@ def test_as_context_mgr_wo_error(self): self.assertEqual(list(client._batch_stack), []) self.assertEqual(len(batch._requests), 3) + self.assertEqual(len(batch._responses), 3) self.assertEqual(batch._requests[0][0], "POST") self.assertEqual(batch._requests[1][0], "PATCH") self.assertEqual(batch._requests[2][0], "DELETE") @@ -505,6 +556,43 @@ def test_as_context_mgr_wo_error(self): self.assertEqual(target2._properties, {"foo": 1, "bar": 3}) self.assertEqual(target3._properties, b"") + def test_as_context_mgr_no_raise_exception(self): + from google.cloud.storage.client import Client + + url = "http://api.example.com/other_api" + expected_response = _make_response( + content=_TWO_PART_MIME_RESPONSE_WITH_FAIL, + headers={"content-type": 'multipart/mixed; boundary="DEADBEEF="'}, + ) + http = _make_requests_session([expected_response]) + project = "PROJECT" + credentials = _make_credentials() + client = Client(project=project, credentials=credentials) + client._http_internal = http + + self.assertEqual(list(client._batch_stack), []) + + target1 = _MockObject() + target2 = _MockObject() + + with self._make_one(client, raise_exception=False) as batch: + self.assertEqual(list(client._batch_stack), [batch]) + batch._make_request("GET", url, {}, target_object=target1) + batch._make_request("GET", url, {}, target_object=target2) + + self.assertEqual(list(client._batch_stack), []) + self.assertEqual(len(batch._requests), 2) + self.assertEqual(len(batch._responses), 2) + self.assertEqual(batch._requests[0][0], "GET") + self.assertEqual(batch._requests[1][0], "GET") + self.assertEqual(batch._target_objects, [target1, target2]) + + # Make sure NotFound exception is added to responses and target2 + self.assertEqual(batch._responses[0].status_code, 200) + self.assertEqual(batch._responses[1].status_code, 404) + self.assertEqual(target1._properties, {"foo": 1, "bar": 2}) + self.assertEqual(target2._properties, {"error": {"message": "Not Found"}}) + def test_as_context_mgr_w_error(self): from google.cloud.storage.batch import _FutureDict from google.cloud.storage.client import Client