diff --git a/aio_pika/queue.py b/aio_pika/queue.py index 2f0f9d45..963c1975 100644 --- a/aio_pika/queue.py +++ b/aio_pika/queue.py @@ -442,13 +442,8 @@ async def close(self, *_: Any) -> Any: log.debug("Queue iterator %r closed", self) # Reject all messages - msg: Optional[IncomingMessage] = None - try: - while True: - msg = self._queue.get_nowait() - except asyncio.QueueEmpty: - if msg is None: - return + while not self._queue.empty(): + msg = self._queue.get_nowait() if self._amqp_queue.channel.is_closed: log.warning( @@ -456,17 +451,14 @@ async def close(self, *_: Any) -> Any: msg, self, ) - return - - if self._consume_kwargs.get("no_ack", False): + elif self._consume_kwargs.get("no_ack", False): log.warning( "Message %r lost for consumer with no_ack %r", msg, self, ) - return - - await msg.nack(requeue=True, multiple=True) + else: + await msg.nack(requeue=True, multiple=False) def __str__(self) -> str: return f"queue[{self._amqp_queue}](...)" diff --git a/tests/test_amqp.py b/tests/test_amqp.py index 33ddb1ca..dcee7f35 100644 --- a/tests/test_amqp.py +++ b/tests/test_amqp.py @@ -16,7 +16,7 @@ import aio_pika.exceptions from aio_pika import Channel, DeliveryMode, Message from aio_pika.abc import ( - AbstractConnection, AbstractIncomingMessage, MessageInfo, + AbstractConnection, AbstractIncomingMessage, MessageInfo, AbstractQueue, ) from aio_pika.exceptions import ( DeliveryError, MessageProcessError, ProbableAuthenticationError, @@ -1587,6 +1587,140 @@ async def test_heartbeat_disabling( async with connection: assert heartbeat == 0 + async def test_non_acked_messages_are_redelivered_to_queue( + self, + channel: aio_pika.Channel, + declare_queue: Callable, + declare_exchange: Callable, + ): + queue_name = get_random_name("test_connection") + routing_key = get_random_name() + + exchange = await declare_exchange( + "direct", auto_delete=True, channel=channel, + ) + + queue: AbstractQueue = await declare_queue( + queue_name, auto_delete=False, channel=channel, + ) + + await queue.bind(exchange, routing_key) + + # Publish 5 messages to queue + all_bodies = [] + for _ in range(0, 5): + body = bytes(shortuuid.uuid(), "utf-8") + all_bodies.append(body) + + assert await exchange.publish(Message(body), routing_key) + + # Create a subscription but only process first message + async with queue.iterator() as queue_iterator: + first_message = await anext(queue_iterator) + async with first_message.process(): + assert first_message.body == all_bodies[0] + + # Confirm other messages are still in queue + for i in range(1, 5): + incoming_message = await queue.get(timeout=5) + await incoming_message.ack() + + assert incoming_message.body == all_bodies[i] + + # Check if the queue is now empty + assert await queue.get(fail=False, timeout=.5) is None + + # Cleanup, delete the queue + await queue.delete() + + async def test_regression_only_messages_cancelled_subscription_are_nacked( + self, + channel: aio_pika.Channel, + declare_queue: Callable, + declare_exchange: Callable, + ): + queue_name1 = get_random_name("test_queue") + queue_name2 = get_random_name("test_queue") + routing_key1 = get_random_name() + routing_key2 = get_random_name() + + exchange = await declare_exchange( + "direct", auto_delete=True, channel=channel, + ) + + queue1: AbstractQueue = await declare_queue( + queue_name1, auto_delete=False, channel=channel, + ) + queue2: AbstractQueue = await declare_queue( + queue_name2, auto_delete=False, channel=channel, + ) + + await queue1.bind(exchange, routing_key1) + await queue2.bind(exchange, routing_key2) + + # Publish 5 messages to queue 1 + all_bodies1 = [] + for _ in range(0, 5): + body = bytes(shortuuid.uuid(), "utf-8") + all_bodies1.append(body) + + assert await exchange.publish(Message(body), routing_key1) + + # Publish 5 messages to queue 2 + all_bodies2 = [] + for _ in range(0, 5): + body = bytes(shortuuid.uuid(), "utf-8") + all_bodies2.append(body) + + assert await exchange.publish(Message(body), routing_key2) + + # Create a subscription to both queues but only process first message + queue_iterator1 = await queue1.iterator().__aenter__() + queue_iterator2 = await queue2.iterator().__aenter__() + + first_message1 = await anext(queue_iterator1) + async with first_message1.process(): + assert first_message1.body == all_bodies1[0] + + first_message2 = await anext(queue_iterator2) + async with first_message2.process(): + assert first_message2.body == all_bodies2[0] + # The order of exit here is important. + # Subscription to queue 1 is received first then to 2. + # Therefore, the delivery tags of subscription to queue 2 will be + # higher. + # So first we cancel the subscription to 2, to test if we + # accidentally also nacked the messages of queue 1. Then we cancel + # subscription to queue 1 to test. + + await queue_iterator2.__aexit__(None, None, None) + # To test if the wrong messages are nacked by stopping subscription to + # queue 2, we ack a message received from queue 1. + second_message1 = await anext(queue_iterator1) + async with second_message1.process(): + assert second_message1.body == all_bodies1[1] + + await queue_iterator1.__aexit__(None, None, None) + + # Confirm other messages are still in queue + for i in range(2, 5): + incoming_message = await queue1.get(timeout=5) + await incoming_message.ack() + assert incoming_message.body == all_bodies1[i] + + for i in range(1, 5): + incoming_message = await queue2.get(timeout=5) + await incoming_message.ack() + assert incoming_message.body == all_bodies2[i] + + # Check if the queue is now empty + assert await queue1.get(fail=False, timeout=.5) is None + assert await queue2.get(fail=False, timeout=.5) is None + + # Cleanup, delete the queue + await queue1.delete() + await queue2.delete() + class TestCaseAmqpNoConfirms(TestCaseAmqp): @staticmethod