Skip to content

Commit

Permalink
take version from pika
Browse files Browse the repository at this point in the history
  • Loading branch information
nozik committed Dec 20, 2021
1 parent 773c289 commit 1f8882b
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
from logging import getLogger
from typing import Any, Collection, Dict, Optional

import pkg_resources
import wrapt
from packaging import version
import pika
from pika.adapters import BlockingConnection
from pika.adapters.blocking_connection import BlockingChannel

Expand All @@ -34,7 +34,20 @@
_FUNCTIONS_TO_UNINSTRUMENT = ["basic_publish"]


def _consumer_callback_attribute_name() -> str:
pika_version = version.parse(
pika.__version__
)
return (
"on_message_callback"
if pika_version >= version.parse("1.0.0")
else "consumer_cb"
)


class PikaInstrumentor(BaseInstrumentor): # type: ignore
CONSUMER_CALLBACK_ATTR = _consumer_callback_attribute_name()

# pylint: disable=attribute-defined-outside-init
@staticmethod
def _instrument_blocking_channel_consumers(
Expand All @@ -44,7 +57,7 @@ def _instrument_blocking_channel_consumers(
) -> Any:
for consumer_tag, consumer_info in channel._consumer_infos.items():
callback_attr = (
PikaInstrumentor._consumer_callback_attribute_name()
PikaInstrumentor.CONSUMER_CALLBACK_ATTR
)
consumer_callback = getattr(consumer_info, callback_attr)
decorated_callback = utils._decorate_callback(
Expand Down Expand Up @@ -133,7 +146,7 @@ def uninstrument_channel(channel: BlockingChannel) -> None:

for consumers_tag, client_info in channel._consumer_infos.items():
callback_attr = (
PikaInstrumentor._consumer_callback_attribute_name()
PikaInstrumentor.CONSUMER_CALLBACK_ATTR
)
consumer_callback = getattr(client_info, callback_attr)
if hasattr(consumer_callback, "_original_callback"):
Expand All @@ -142,17 +155,6 @@ def uninstrument_channel(channel: BlockingChannel) -> None:
] = consumer_callback._original_callback
PikaInstrumentor._uninstrument_channel_functions(channel)

@staticmethod
def _consumer_callback_attribute_name() -> str:
pika_version = version.parse(
pkg_resources.get_distribution("pika").version
)
return (
"on_message_callback"
if pika_version >= version.parse("1.0.0")
else "consumer_cb"
)

def _decorate_channel_function(
self,
tracer_provider: Optional[TracerProvider],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class TestPika(TestCase):
def setUp(self) -> None:
self.channel = mock.MagicMock(spec=Channel)
consumer_info = mock.MagicMock()
callback_attr = PikaInstrumentor._consumer_callback_attribute_name()
callback_attr = PikaInstrumentor.CONSUMER_CALLBACK_ATTR
setattr(consumer_info, callback_attr, mock.MagicMock())
self.channel._consumer_infos = {"consumer-tag": consumer_info}
self.mock_callback = mock.MagicMock()
Expand Down Expand Up @@ -73,7 +73,7 @@ def test_instrument_consumers(
self, decorate_callback: mock.MagicMock
) -> None:
tracer = mock.MagicMock(spec=Tracer)
callback_attr = PikaInstrumentor._consumer_callback_attribute_name()
callback_attr = PikaInstrumentor.CONSUMER_CALLBACK_ATTR
expected_decoration_calls = [
mock.call(
getattr(value, callback_attr), tracer, key, dummy_callback
Expand Down

0 comments on commit 1f8882b

Please sign in to comment.