From 9b505f429746a735aefeff55fb76a7eeeb7fa8d7 Mon Sep 17 00:00:00 2001
From: Shane Hathaway
Date: Wed, 2 Mar 2022 22:23:31 -0700
Subject: [PATCH] Add an option to not base64-encode SQS messages.

Also simplify the base64 decoding logic so that we don't have to run
base64 decoding twice for every message.
---
 kombu/transport/SQS.py       | 32 ++++++++++++++++++--------------
 t/unit/transport/test_SQS.py |  8 ++++----
 2 files changed, 22 insertions(+), 18 deletions(-)

diff --git a/kombu/transport/SQS.py b/kombu/transport/SQS.py
index 6e9b8dfa1..7e8c44a66 100644
--- a/kombu/transport/SQS.py
+++ b/kombu/transport/SQS.py
@@ -395,8 +395,11 @@ def _delete(self, queue, *args, **kwargs):
     def _put(self, queue, message, **kwargs):
         """Put message onto queue."""
         q_url = self._new_queue(queue)
-        kwargs = {'QueueUrl': q_url,
-                  'MessageBody': AsyncMessage().encode(dumps(message))}
+        if self.sqs_base64_encoding:
+            body = AsyncMessage().encode(dumps(message))
+        else:
+            body = dumps(message)
+        kwargs = {'QueueUrl': q_url, 'MessageBody': body}
         if queue.endswith('.fifo'):
             if 'MessageGroupId' in message['properties']:
                 kwargs['MessageGroupId'] = \
@@ -420,22 +423,19 @@ def _put(self, queue, message, **kwargs):
         c.send_message(**kwargs)
 
     @staticmethod
-    def __b64_encoded(byte_string):
+    def _optional_b64_decode(byte_string):
         try:
-            return base64.b64encode(
-                base64.b64decode(byte_string)
-            ) == byte_string
+            data = base64.b64decode(byte_string)
+            if base64.b64encode(data) == byte_string:
+                return data
+            # else the base64 module found some embedded base64 content
+            # that should be ignored.
         except Exception:  # pylint: disable=broad-except
-            return False
-
-    def _message_to_python(self, message, queue_name, queue):
-        body = message['Body'].encode()
-        try:
-            if self.__b64_encoded(body):
-                body = base64.b64decode(body)
-        except TypeError:
             pass
+        return byte_string
 
+    def _message_to_python(self, message, queue_name, queue):
+        body = self._optional_b64_decode(message['Body'].encode())
         payload = loads(bytes_to_str(body))
         if queue_name in self._noack_queues:
             queue = self._new_queue(queue_name)
@@ -837,6 +837,10 @@ def wait_time_seconds(self):
         return self.transport_options.get('wait_time_seconds',
                                           self.default_wait_time_seconds)
 
+    @cached_property
+    def sqs_base64_encoding(self):
+        return self.transport_options.get('sqs_base64_encoding', True)
+
 
 class Transport(virtual.Transport):
     """SQS Transport.
diff --git a/t/unit/transport/test_SQS.py b/t/unit/transport/test_SQS.py
index 6056dd3d8..ea261659a 100644
--- a/t/unit/transport/test_SQS.py
+++ b/t/unit/transport/test_SQS.py
@@ -336,13 +336,13 @@ def test_get_bulk_raises_empty(self):
         with pytest.raises(Empty):
             self.channel._get_bulk(self.queue_name)
 
-    def test_is_base64_encoded(self):
+    def test_optional_b64_decode(self):
         raw = b'{"id": "4cc7438e-afd4-4f8f-a2f3-f46567e7ca77","task": "celery.task.PingTask",' \
               b'"args": [],"kwargs": {},"retries": 0,"eta": "2009-11-17T12:30:56.527191"}'  # noqa
         b64_enc = base64.b64encode(raw)
-        assert self.channel._Channel__b64_encoded(b64_enc)
-        assert not self.channel._Channel__b64_encoded(raw)
-        assert not self.channel._Channel__b64_encoded(b"test123")
+        assert self.channel._optional_b64_decode(b64_enc) == raw
+        assert self.channel._optional_b64_decode(raw) == raw
+        assert self.channel._optional_b64_decode(b"test123") == b"test123"
 
     def test_messages_to_python(self):
         from kombu.asynchronous.aws.sqs.message import Message