Skip to content

Commit

Permalink
Add an option to not base64-encode SQS messages.
Browse files Browse the repository at this point in the history
Also simplify the base64 decoding logic so that we don't have to
run base64 decoding twice for every message.
  • Loading branch information
hathawsh authored and auvipy committed Mar 4, 2022
1 parent 22adaaa commit 9b505f4
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 18 deletions.
32 changes: 18 additions & 14 deletions kombu/transport/SQS.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,8 +395,11 @@ def _delete(self, queue, *args, **kwargs):
def _put(self, queue, message, **kwargs):
"""Put message onto queue."""
q_url = self._new_queue(queue)
kwargs = {'QueueUrl': q_url,
'MessageBody': AsyncMessage().encode(dumps(message))}
if self.sqs_base64_encoding:
body = AsyncMessage().encode(dumps(message))
else:
body = dumps(message)
kwargs = {'QueueUrl': q_url, 'MessageBody': body}
if queue.endswith('.fifo'):
if 'MessageGroupId' in message['properties']:
kwargs['MessageGroupId'] = \
Expand All @@ -420,22 +423,19 @@ def _put(self, queue, message, **kwargs):
c.send_message(**kwargs)

@staticmethod
def __b64_encoded(byte_string):
def _optional_b64_decode(byte_string):
try:
return base64.b64encode(
base64.b64decode(byte_string)
) == byte_string
data = base64.b64decode(byte_string)
if base64.b64encode(data) == byte_string:
return data
# else the base64 module found some embedded base64 content
# that should be ignored.
except Exception: # pylint: disable=broad-except
return False

def _message_to_python(self, message, queue_name, queue):
body = message['Body'].encode()
try:
if self.__b64_encoded(body):
body = base64.b64decode(body)
except TypeError:
pass
return byte_string

def _message_to_python(self, message, queue_name, queue):
body = self._optional_b64_decode(message['Body'].encode())
payload = loads(bytes_to_str(body))
if queue_name in self._noack_queues:
queue = self._new_queue(queue_name)
Expand Down Expand Up @@ -837,6 +837,10 @@ def wait_time_seconds(self):
return self.transport_options.get('wait_time_seconds',
self.default_wait_time_seconds)

@cached_property
def sqs_base64_encoding(self):
return self.transport_options.get('sqs_base64_encoding', True)


class Transport(virtual.Transport):
"""SQS Transport.
Expand Down
8 changes: 4 additions & 4 deletions t/unit/transport/test_SQS.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,13 +336,13 @@ def test_get_bulk_raises_empty(self):
with pytest.raises(Empty):
self.channel._get_bulk(self.queue_name)

def test_is_base64_encoded(self):
def test_optional_b64_decode(self):
raw = b'{"id": "4cc7438e-afd4-4f8f-a2f3-f46567e7ca77","task": "celery.task.PingTask",' \
b'"args": [],"kwargs": {},"retries": 0,"eta": "2009-11-17T12:30:56.527191"}' # noqa
b64_enc = base64.b64encode(raw)
assert self.channel._Channel__b64_encoded(b64_enc)
assert not self.channel._Channel__b64_encoded(raw)
assert not self.channel._Channel__b64_encoded(b"test123")
assert self.channel._optional_b64_decode(b64_enc) == raw
assert self.channel._optional_b64_decode(raw) == raw
assert self.channel._optional_b64_decode(b"test123") == b"test123"

def test_messages_to_python(self):
from kombu.asynchronous.aws.sqs.message import Message
Expand Down

0 comments on commit 9b505f4

Please sign in to comment.