Skip to content

Commit

Permalink
allow getting recoverable_connection_errors without an active transpo…
Browse files Browse the repository at this point in the history
…rt (#1471)

* allow getting recoverable_connection_errors without an active transport

* move redis transport errors to class

* move consul transport errors to class

* move etcd transport errors to class

* remove redis.Transport._get_errors and references in tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix flake8 errors

* add integration test for redis ConnectionError

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
pawl and pre-commit-ci[bot] committed Dec 30, 2021
1 parent b6b4408 commit 9c062bd
Show file tree
Hide file tree
Showing 7 changed files with 135 additions and 37 deletions.
8 changes: 4 additions & 4 deletions kombu/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -932,7 +932,7 @@ def recoverable_connection_errors(self):
but where the connection must be closed and re-established first.
"""
try:
return self.transport.recoverable_connection_errors
return self.get_transport_cls().recoverable_connection_errors
except AttributeError:
# There were no such classification before,
# and all errors were assumed to be recoverable,
Expand All @@ -948,19 +948,19 @@ def recoverable_channel_errors(self):
recovered from without re-establishing the connection.
"""
try:
return self.transport.recoverable_channel_errors
return self.get_transport_cls().recoverable_channel_errors
except AttributeError:
return ()

@cached_property
def connection_errors(self):
"""List of exceptions that may be raised by the connection."""
return self.transport.connection_errors
return self.get_transport_cls().connection_errors

@cached_property
def channel_errors(self):
"""List of exceptions that may be raised by the channel."""
return self.transport.channel_errors
return self.get_transport_cls().channel_errors

@property
def supports_heartbeats(self):
Expand Down
17 changes: 9 additions & 8 deletions kombu/transport/consul.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,24 +276,25 @@ class Transport(virtual.Transport):
driver_type = 'consul'
driver_name = 'consul'

def __init__(self, *args, **kwargs):
if consul is None:
raise ImportError('Missing python-consul library')

super().__init__(*args, **kwargs)

self.connection_errors = (
if consul:
connection_errors = (
virtual.Transport.connection_errors + (
consul.ConsulException, consul.base.ConsulException
)
)

self.channel_errors = (
channel_errors = (
virtual.Transport.channel_errors + (
consul.ConsulException, consul.base.ConsulException
)
)

def __init__(self, *args, **kwargs):
if consul is None:
raise ImportError('Missing python-consul library')

super().__init__(*args, **kwargs)

def verify_connection(self, connection):
port = connection.client.port or self.default_port
host = connection.client.hostname or DEFAULT_HOST
Expand Down
17 changes: 9 additions & 8 deletions kombu/transport/etcd.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,21 +242,22 @@ class Transport(virtual.Transport):
implements = virtual.Transport.implements.extend(
exchange_type=frozenset(['direct']))

if etcd:
connection_errors = (
virtual.Transport.connection_errors + (etcd.EtcdException, )
)

channel_errors = (
virtual.Transport.channel_errors + (etcd.EtcdException, )
)

def __init__(self, *args, **kwargs):
"""Create a new instance of etcd.Transport."""
if etcd is None:
raise ImportError('Missing python-etcd library')

super().__init__(*args, **kwargs)

self.connection_errors = (
virtual.Transport.connection_errors + (etcd.EtcdException, )
)

self.channel_errors = (
virtual.Transport.channel_errors + (etcd.EtcdException, )
)

def verify_connection(self, connection):
"""Verify the connection works."""
port = connection.client.port or self.default_port
Expand Down
9 changes: 3 additions & 6 deletions kombu/transport/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -1214,13 +1214,14 @@ class Transport(virtual.Transport):
exchange_type=frozenset(['direct', 'topic', 'fanout'])
)

if redis:
connection_errors, channel_errors = get_redis_error_classes()

def __init__(self, *args, **kwargs):
if redis is None:
raise ImportError('Missing redis library (pip install redis)')
super().__init__(*args, **kwargs)

# Get redis-py exceptions.
self.connection_errors, self.channel_errors = self._get_errors()
# All channels share the same poller.
self.cycle = MultiChannelPoller()

Expand Down Expand Up @@ -1265,10 +1266,6 @@ def on_readable(self, fileno):
"""Handle AIO event for one of our file descriptors."""
self.cycle.on_readable(fileno)

def _get_errors(self):
"""Utility to import redis-py's exceptions at runtime."""
return get_redis_error_classes()


if sentinel:
class SentinelManagedSSLConnection(
Expand Down
7 changes: 6 additions & 1 deletion t/integration/test_redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import redis

import kombu
from kombu.transport.redis import Transport

from .common import (BaseExchangeTypes, BaseMessage, BasePriority,
BasicFunctionality)
Expand Down Expand Up @@ -56,7 +57,11 @@ def test_failed_credentials():
@pytest.mark.env('redis')
@pytest.mark.flaky(reruns=5, reruns_delay=2)
class test_RedisBasicFunctionality(BasicFunctionality):
pass
def test_failed_connection__ConnectionError(self, invalid_connection):
# method raises transport exception
with pytest.raises(redis.exceptions.ConnectionError) as ex:
invalid_connection.connection
assert ex.type in Transport.connection_errors


@pytest.mark.env('redis')
Expand Down
94 changes: 91 additions & 3 deletions t/unit/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,9 @@ def test_is_evented(self):
assert not c.is_evented

def test_register_with_event_loop(self):
c = Connection(transport=Mock)
transport = Mock(name='transport')
transport.connection_errors = []
c = Connection(transport=transport)
loop = Mock(name='loop')
c.register_with_event_loop(loop)
c.transport.register_with_event_loop.assert_called_with(
Expand Down Expand Up @@ -477,15 +479,15 @@ class _ConnectionError(Exception):
def publish():
raise _ConnectionError('failed connection')

self.conn.transport.connection_errors = (_ConnectionError,)
self.conn.get_transport_cls().connection_errors = (_ConnectionError,)
ensured = self.conn.ensure(self.conn, publish)
with pytest.raises(OperationalError):
ensured()

def test_autoretry(self):
myfun = Mock()

self.conn.transport.connection_errors = (KeyError,)
self.conn.get_transport_cls().connection_errors = (KeyError,)

def on_call(*args, **kwargs):
myfun.side_effect = None
Expand Down Expand Up @@ -571,6 +573,18 @@ class MyTransport(Transport):
conn = Connection(transport=MyTransport)
assert conn.channel_errors == (KeyError, ValueError)

def test_channel_errors__exception_no_cache(self):
"""Ensure the channel_errors can be retrieved without an initialized
transport.
"""

class MyTransport(Transport):
channel_errors = (KeyError,)

conn = Connection(transport=MyTransport)
MyTransport.__init__ = Mock(side_effect=Exception)
assert conn.channel_errors == (KeyError,)

def test_connection_errors(self):

class MyTransport(Transport):
Expand All @@ -579,6 +593,80 @@ class MyTransport(Transport):
conn = Connection(transport=MyTransport)
assert conn.connection_errors == (KeyError, ValueError)

def test_connection_errors__exception_no_cache(self):
"""Ensure the connection_errors can be retrieved without an
initialized transport.
"""

class MyTransport(Transport):
connection_errors = (KeyError,)

conn = Connection(transport=MyTransport)
MyTransport.__init__ = Mock(side_effect=Exception)
assert conn.connection_errors == (KeyError,)

def test_recoverable_connection_errors(self):

class MyTransport(Transport):
recoverable_connection_errors = (KeyError, ValueError)

conn = Connection(transport=MyTransport)
assert conn.recoverable_connection_errors == (KeyError, ValueError)

def test_recoverable_connection_errors__fallback(self):
"""Ensure missing recoverable_connection_errors on the Transport does
not cause a fatal error.
"""

class MyTransport(Transport):
connection_errors = (KeyError,)
channel_errors = (ValueError,)

conn = Connection(transport=MyTransport)
assert conn.recoverable_connection_errors == (KeyError, ValueError)

def test_recoverable_connection_errors__exception_no_cache(self):
"""Ensure the recoverable_connection_errors can be retrieved without
an initialized transport.
"""

class MyTransport(Transport):
recoverable_connection_errors = (KeyError,)

conn = Connection(transport=MyTransport)
MyTransport.__init__ = Mock(side_effect=Exception)
assert conn.recoverable_connection_errors == (KeyError,)

def test_recoverable_channel_errors(self):

class MyTransport(Transport):
recoverable_channel_errors = (KeyError, ValueError)

conn = Connection(transport=MyTransport)
assert conn.recoverable_channel_errors == (KeyError, ValueError)

def test_recoverable_channel_errors__fallback(self):
"""Ensure missing recoverable_channel_errors on the Transport does not
cause a fatal error.
"""

class MyTransport(Transport):
pass

conn = Connection(transport=MyTransport)
assert conn.recoverable_channel_errors == ()

def test_recoverable_channel_errors__exception_no_cache(self):
"""Ensure the recoverable_channel_errors can be retrieved without an
initialized transport.
"""
class MyTransport(Transport):
recoverable_channel_errors = (KeyError,)

conn = Connection(transport=MyTransport)
MyTransport.__init__ = Mock(side_effect=Exception)
assert conn.recoverable_channel_errors == (KeyError,)

def test_multiple_urls_hostname(self):
conn = Connection(['example.com;amqp://example.com'])
assert conn.as_uri() == 'amqp://guest:**@example.com:5672//'
Expand Down
20 changes: 13 additions & 7 deletions t/unit/transport/test_redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,9 +269,8 @@ def pipeline(self):

class Transport(redis.Transport):
Channel = Channel

def _get_errors(self):
return ((KeyError,), (IndexError,))
connection_errors = (KeyError,)
channel_errors = (IndexError,)


class test_Channel:
Expand Down Expand Up @@ -907,15 +906,22 @@ def test_transport_on_readable(self):
redis.Transport.on_readable(transport, 13)
cycle.on_readable.assert_called_with(13)

def test_transport_get_errors(self):
assert redis.Transport._get_errors(self.connection.transport)
def test_transport_connection_errors(self):
"""Ensure connection_errors are populated."""
assert redis.Transport.connection_errors

def test_transport_channel_errors(self):
"""Ensure connection_errors are populated."""
assert redis.Transport.channel_errors

def test_transport_driver_version(self):
assert redis.Transport.driver_version(self.connection.transport)

def test_transport_get_errors_when_InvalidData_used(self):
def test_transport_errors_when_InvalidData_used(self):
from redis import exceptions

from kombu.transport.redis import get_redis_error_classes

class ID(Exception):
pass

Expand All @@ -924,7 +930,7 @@ class ID(Exception):
exceptions.InvalidData = ID
exceptions.DataError = None
try:
errors = redis.Transport._get_errors(self.connection.transport)
errors = get_redis_error_classes()
assert errors
assert ID in errors[1]
finally:
Expand Down

0 comments on commit 9c062bd

Please sign in to comment.