Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Redis connections after reconnect - consumer starts consuming the tasks after crash. #2007

Merged
merged 4 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions kombu/transport/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,7 @@ def __init__(self, *args, **kwargs):

if not self.ack_emulation: # disable visibility timeout
self.QoS = virtual.QoS

self._registered = False
self._queue_cycle = cycle_by_name(self.queue_order_strategy)()
self.Client = self._get_client()
self.ResponseError = self._get_response_error()
Expand All @@ -747,6 +747,9 @@ def __init__(self, *args, **kwargs):
raise

self.connection.cycle.add(self) # add to channel poller.
# and set to true after sucessfuly added channel to the poll.
self._registered = True

# copy errors, in case channel closed but threads still
# are still waiting for data.
self.connection_errors = self.connection.connection_errors
Expand Down Expand Up @@ -1201,7 +1204,10 @@ def _connparams(self, asynchronous=False):
class Connection(connection_cls):
def disconnect(self, *args):
super().disconnect(*args)
channel._on_connection_disconnect(self)
# We remove the connection from the poller
Nusnus marked this conversation as resolved.
Show resolved Hide resolved
# only if it has been added properly.
if channel._registered:
channel._on_connection_disconnect(self)
connection_cls = Connection

connparams['connection_class'] = connection_cls
Expand Down
156 changes: 156 additions & 0 deletions t/unit/transport/test_redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,17 +346,173 @@ class XTransport(Transport):
Channel = XChannel

conn = Connection(transport=XTransport)
conn.transport.cycle = Mock(name='cycle')
client.ping.side_effect = RuntimeError()
with pytest.raises(RuntimeError):
conn.channel()
pool.disconnect.assert_called_with()
pool.disconnect.reset_mock()
# Ensure that the channel without ensured connection to Redis
# won't be added to the cycle.
conn.transport.cycle.add.assert_not_called()
assert len(conn.transport.channels) == 0

pool_at_init = [None]
with pytest.raises(RuntimeError):
conn.channel()
pool.disconnect.assert_not_called()

def test_redis_connection_added_to_cycle_if_ping_succeeds(self):
"""Test should check the connection is added to the cycle only
if the ping to Redis was finished successfully."""
# given: mock pool and client
pool = Mock(name='pool')
client = Mock(name='client')

# override channel class with given mocks
class XChannel(Channel):
def __init__(self, *args, **kwargs):
self._pool = pool
super().__init__(*args, **kwargs)

def _get_client(self):
return lambda *_, **__: client

# override Channel in Transport with given channel
class XTransport(Transport):
Channel = XChannel

# when: create connection with overridden transport
conn = Connection(transport=XTransport)
conn.transport.cycle = Mock(name='cycle')
# create the channel
chan = conn.channel()
# then: check if ping was called
client.ping.assert_called_once()
# the connection was added to the cycle
conn.transport.cycle.add.assert_called_once()
assert len(conn.transport.channels) == 1
# the channel was flaged as registered into poller
assert chan._registered

def test_redis_on_disconnect_channel_only_if_was_registered(self):
"""Test shoud check if the _on_disconnect method is called only
if the channel was registered into the poller."""
# given: mock pool and client
pool = Mock(name='pool')
client = Mock(
name='client',
ping=Mock(return_value=True)
)

# create RedisConnectionMock class
# for the possibility to run disconnect method
class RedisConnectionMock:
def disconnect(self, *args):
pass

# override Channel method with given mocks
class XChannel(Channel):
connection_class = RedisConnectionMock

def __init__(self, *args, **kwargs):
self._pool = pool
# counter to check if the method was called
self.on_disconect_count = 0
super().__init__(*args, **kwargs)

def _get_client(self):
return lambda *_, **__: client

def _on_connection_disconnect(self, connection):
# increment the counter when the method is called
self.on_disconect_count += 1

# create the channel
chan = XChannel(Mock(
_used_channel_ids=[],
channel_max=1,
channels=[],
client=Mock(
transport_options={},
hostname="127.0.0.1",
virtual_host=None)))
# create the _connparams with overriden connection_class
connparams = chan._connparams(asynchronous=True)
# create redis.Connection
redis_connection = connparams['connection_class']()
# the connection was added to the cycle
chan.connection.cycle.add.assert_called_once()
# and the ping was called
client.ping.assert_called_once()
# the channel was registered
assert chan._registered
# than disconnect the Redis connection
redis_connection.disconnect()
# the on_disconnect counter should be incremented
assert chan.on_disconect_count == 1

def test_redis__on_disconnect_should_not_be_called_if_not_registered(self):
"""Test should check if the _on_disconnect method is not called because
the connection to Redis isn't established properly."""
# given: mock pool
pool = Mock(name='pool')
# client mock with ping method which return ConnectionError
from redis.exceptions import ConnectionError
client = Mock(
name='client',
ping=Mock(side_effect=ConnectionError())
)

# create RedisConnectionMock
# for the possibility to run disconnect method
class RedisConnectionMock:
def disconnect(self, *args):
pass

# override Channel method with given mocks
class XChannel(Channel):
connection_class = RedisConnectionMock

def __init__(self, *args, **kwargs):
self._pool = pool
# counter to check if the method was called
self.on_disconect_count = 0
super().__init__(*args, **kwargs)

def _get_client(self):
return lambda *_, **__: client

def _on_connection_disconnect(self, connection):
# increment the counter when the method is called
self.on_disconect_count += 1

# then: exception was risen
with pytest.raises(ConnectionError):
# when: create the channel
chan = XChannel(Mock(
_used_channel_ids=[],
channel_max=1,
channels=[],
client=Mock(
transport_options={},
hostname="127.0.0.1",
virtual_host=None)))
# create the _connparams with overriden connection_class
connparams = chan._connparams(asynchronous=True)
# create redis.Connection
redis_connection = connparams['connection_class']()
# the connection wasn't added to the cycle
chan.connection.cycle.add.assert_not_called()
# the ping was called once with the exception
client.ping.assert_called_once()
# the channel was not registered
assert not chan._registered
# then: disconnect the Redis connection
redis_connection.disconnect()
# the on_disconnect counter shouldn't be incremented
assert chan.on_disconect_count == 0

def test_get_redis_ConnectionError(self):
from redis.exceptions import ConnectionError

Expand Down