Skip to content

Commit

Permalink
take version from pika
Browse files Browse the repository at this point in the history
  • Loading branch information
nozik committed Dec 20, 2021
1 parent 773c289 commit fdba606
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from logging import getLogger
from typing import Any, Collection, Dict, Optional

import pkg_resources
import pika
import wrapt
from packaging import version
from pika.adapters import BlockingConnection
Expand All @@ -34,7 +34,18 @@
_FUNCTIONS_TO_UNINSTRUMENT = ["basic_publish"]


def _consumer_callback_attribute_name() -> str:
pika_version = version.parse(pika.__version__)
return (
"on_message_callback"
if pika_version >= version.parse("1.0.0")
else "consumer_cb"
)


class PikaInstrumentor(BaseInstrumentor): # type: ignore
CONSUMER_CALLBACK_ATTR = _consumer_callback_attribute_name()

# pylint: disable=attribute-defined-outside-init
@staticmethod
def _instrument_blocking_channel_consumers(
Expand All @@ -43,9 +54,7 @@ def _instrument_blocking_channel_consumers(
consume_hook: utils.HookT = utils.dummy_callback,
) -> Any:
for consumer_tag, consumer_info in channel._consumer_infos.items():
callback_attr = (
PikaInstrumentor._consumer_callback_attribute_name()
)
callback_attr = PikaInstrumentor.CONSUMER_CALLBACK_ATTR
consumer_callback = getattr(consumer_info, callback_attr)
decorated_callback = utils._decorate_callback(
consumer_callback,
Expand Down Expand Up @@ -132,27 +141,14 @@ def uninstrument_channel(channel: BlockingChannel) -> None:
return

for consumers_tag, client_info in channel._consumer_infos.items():
callback_attr = (
PikaInstrumentor._consumer_callback_attribute_name()
)
callback_attr = PikaInstrumentor.CONSUMER_CALLBACK_ATTR
consumer_callback = getattr(client_info, callback_attr)
if hasattr(consumer_callback, "_original_callback"):
channel._consumer_infos[
consumers_tag
] = consumer_callback._original_callback
PikaInstrumentor._uninstrument_channel_functions(channel)

@staticmethod
def _consumer_callback_attribute_name() -> str:
pika_version = version.parse(
pkg_resources.get_distribution("pika").version
)
return (
"on_message_callback"
if pika_version >= version.parse("1.0.0")
else "consumer_cb"
)

def _decorate_channel_function(
self,
tracer_provider: Optional[TracerProvider],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class TestPika(TestCase):
def setUp(self) -> None:
self.channel = mock.MagicMock(spec=Channel)
consumer_info = mock.MagicMock()
callback_attr = PikaInstrumentor._consumer_callback_attribute_name()
callback_attr = PikaInstrumentor.CONSUMER_CALLBACK_ATTR
setattr(consumer_info, callback_attr, mock.MagicMock())
self.channel._consumer_infos = {"consumer-tag": consumer_info}
self.mock_callback = mock.MagicMock()
Expand Down Expand Up @@ -73,7 +73,7 @@ def test_instrument_consumers(
self, decorate_callback: mock.MagicMock
) -> None:
tracer = mock.MagicMock(spec=Tracer)
callback_attr = PikaInstrumentor._consumer_callback_attribute_name()
callback_attr = PikaInstrumentor.CONSUMER_CALLBACK_ATTR
expected_decoration_calls = [
mock.call(
getattr(value, callback_attr), tracer, key, dummy_callback
Expand Down

0 comments on commit fdba606

Please sign in to comment.