From 95ef0cbdb1c20fa45029bc695d604e618ba2f777 Mon Sep 17 00:00:00 2001 From: swathipil <76007337+swathipil@users.noreply.github.com> Date: Tue, 10 Aug 2021 09:52:35 -0700 Subject: [PATCH] raise error for loop if Python 3.10 (#261) * raise error for loop * update sample * adams comments * adams comments * fix * update impl * add loop props * update test * create folder * comment test and add later * add async loop test to async samples * remove unnecessary imports --- dev_requirements.txt | 1 + .../test_azure_iothub_cli_extension.py | 4 +- samples/asynctests/test_loop_param_async.py | 59 ++++++++++++++++ uamqp/async_ops/client_async.py | 68 ++++++++----------- uamqp/async_ops/connection_async.py | 21 +++--- uamqp/async_ops/mgmt_operation_async.py | 10 +-- uamqp/async_ops/receiver_async.py | 14 ++-- uamqp/async_ops/sender_async.py | 12 ++-- uamqp/async_ops/session_async.py | 13 ++-- uamqp/async_ops/utils.py | 13 ++++ uamqp/authentication/cbs_auth_async.py | 18 +++-- 11 files changed, 157 insertions(+), 76 deletions(-) create mode 100644 samples/asynctests/test_loop_param_async.py create mode 100644 uamqp/async_ops/utils.py diff --git a/dev_requirements.txt b/dev_requirements.txt index 9fa2caed8..504d8c294 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -1,6 +1,7 @@ cython==0.29.21 setuptools>=27.1.2 wheel>=0.32.0 +pytest==6.2.4; python_version >= '3.10' pytest==5.4.1; python_version >= '3.6' pytest==4.6.9; python_version == '2.7' pytest-asyncio==0.10.0; python_version >= '3.6' diff --git a/samples/asynctests/test_azure_iothub_cli_extension.py b/samples/asynctests/test_azure_iothub_cli_extension.py index 4fe92597a..c68ea5d34 100644 --- a/samples/asynctests/test_azure_iothub_cli_extension.py +++ b/samples/asynctests/test_azure_iothub_cli_extension.py @@ -93,7 +93,7 @@ def executor(target, consumer_group, enqueued_time, device_id=None, properties=N loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - future = asyncio.gather(*coroutines, loop=loop, return_exceptions=True) + future = asyncio.gather(*coroutines, return_exceptions=True) result = None try: @@ -112,7 +112,7 @@ def stop_and_suppress_eloop(): except KeyboardInterrupt: print('Stopping event monitor...') remaining_tasks = [t for t in asyncio.Task.all_tasks() if not t.done()] - remaining_future = asyncio.gather(*remaining_tasks, loop=loop, return_exceptions=True) + remaining_future = asyncio.gather(*remaining_tasks, return_exceptions=True) try: loop.run_until_complete(asyncio.wait_for(remaining_future, 5)) except concurrent.futures.TimeoutError: diff --git a/samples/asynctests/test_loop_param_async.py b/samples/asynctests/test_loop_param_async.py new file mode 100644 index 000000000..acda0b765 --- /dev/null +++ b/samples/asynctests/test_loop_param_async.py @@ -0,0 +1,59 @@ +import sys +import pytest + +import asyncio +from uamqp.async_ops.mgmt_operation_async import MgmtOperationAsync +from uamqp.async_ops.receiver_async import MessageReceiverAsync +from uamqp.authentication.cbs_auth_async import CBSAsyncAuthMixin +from uamqp.async_ops.sender_async import MessageSenderAsync +from uamqp.async_ops.client_async import ( + AMQPClientAsync, + SendClientAsync, + ReceiveClientAsync, + ConnectionAsync, +) + +@pytest.mark.asyncio +@pytest.mark.skipif(sys.version_info < (3, 10), reason="raise error if loop passed in >=3.10") +async def test_error_loop_arg_async(): + with pytest.raises(ValueError) as e: + AMQPClientAsync("fake_addr", loop=asyncio.get_event_loop()) + assert "no longer supports loop" in e + client_async = AMQPClientAsync("sb://resourcename.servicebus.windows.net/") + assert len(client_async._internal_kwargs) == 0 # pylint:disable=protected-access + + with pytest.raises(ValueError) as e: + SendClientAsync("fake_addr", loop=asyncio.get_event_loop()) + assert "no longer supports loop" in e + client_async = SendClientAsync("sb://resourcename.servicebus.windows.net/") + assert len(client_async._internal_kwargs) == 0 # pylint:disable=protected-access + + with pytest.raises(ValueError) as e: + ReceiveClientAsync("fake_addr", loop=asyncio.get_event_loop()) + assert "no longer supports loop" in e + client_async = ReceiveClientAsync("sb://resourcename.servicebus.windows.net/") + assert len(client_async._internal_kwargs) == 0 # pylint:disable=protected-access + + with pytest.raises(ValueError) as e: + ConnectionAsync("fake_addr", sasl='fake_sasl', loop=asyncio.get_event_loop()) + assert "no longer supports loop" in e + + with pytest.raises(ValueError) as e: + MgmtOperationAsync("fake_addr", loop=asyncio.get_event_loop()) + assert "no longer supports loop" in e + + with pytest.raises(ValueError) as e: + MessageReceiverAsync("fake_addr", "session", "target", "on_message_received", loop=asyncio.get_event_loop()) + assert "no longer supports loop" in e + + with pytest.raises(ValueError) as e: + MessageSenderAsync("fake_addr", "source", "target", loop=asyncio.get_event_loop()) + assert "no longer supports loop" in e + + async def auth_async_loop(): + auth_async = CBSAsyncAuthMixin() + with pytest.raises(ValueError) as e: + await auth_async.create_authenticator_async("fake_conn", loop=asyncio.get_event_loop()) + assert "no longer supports loop" in e + loop = asyncio.get_event_loop() + loop.run_until_complete(auth_async_loop()) diff --git a/uamqp/async_ops/client_async.py b/uamqp/async_ops/client_async.py index 504fda28b..aa39a2495 100644 --- a/uamqp/async_ops/client_async.py +++ b/uamqp/async_ops/client_async.py @@ -13,11 +13,11 @@ import uuid from uamqp import address, authentication, client, constants, errors, compat, c_uamqp -from uamqp.utils import get_running_loop from uamqp.async_ops.connection_async import ConnectionAsync from uamqp.async_ops.receiver_async import MessageReceiverAsync from uamqp.async_ops.sender_async import MessageSenderAsync from uamqp.async_ops.session_async import SessionAsync +from uamqp.async_ops.utils import get_dict_with_loop_if_needed try: TimeoutException = TimeoutError @@ -43,8 +43,6 @@ class AMQPClientAsync(client.AMQPClient): :param client_name: The name for the client, also known as the Container ID. If no name is provided, a random GUID will be used. :type client_name: str or bytes - :param loop: A user specified event loop. - :type loop: ~asycnio.AbstractEventLoop :param debug: Whether to turn on network trace logs. If `True`, trace logs will be logged at INFO level. Default is `False`. :type debug: bool @@ -105,14 +103,7 @@ def __init__( keep_alive_interval=None, **kwargs): - if loop: - self.loop = loop - else: - try: - if not self.loop: # from sub class instance - self.loop = get_running_loop() - except AttributeError: - self.loop = get_running_loop() + self._internal_kwargs = get_dict_with_loop_if_needed(loop) super(AMQPClientAsync, self).__init__( remote_address, @@ -146,9 +137,9 @@ async def _keep_alive_async(self): _logger.info("Keeping %r connection alive. %r", self.__class__.__name__, self._connection.container_id) - await asyncio.shield(self._connection.work_async(), loop=self.loop) + await asyncio.shield(self._connection.work_async(), **self._internal_kwargs) start_time = current_time - await asyncio.sleep(1, loop=self.loop) + await asyncio.sleep(1, **self._internal_kwargs) except Exception as e: # pylint: disable=broad-except _logger.info("Connection keep-alive for %r failed: %r.", self.__class__.__name__, e) @@ -163,7 +154,7 @@ async def _client_ready_async(self): # pylint: disable=no-self-use async def _client_run_async(self): """Perform a single Connection iteration.""" - await asyncio.shield(self._connection.work_async(), loop=self.loop) + await asyncio.shield(self._connection.work_async(), **self._internal_kwargs) async def _redirect_async(self, redirect, auth): """Redirect the client endpoint using a Link DETACH redirect @@ -177,7 +168,7 @@ async def _redirect_async(self, redirect, auth): # pylint: disable=protected-access if not self._connection._cbs: _logger.info("Closing non-CBS session.") - await asyncio.shield(self._session.destroy_async(), loop=self.loop) + await asyncio.shield(self._session.destroy_async(), **self._internal_kwargs) self._session = None self._auth = auth self._hostname = self._remote_address.hostname @@ -197,8 +188,8 @@ async def _build_session_async(self): outgoing_window=self._outgoing_window, handle_max=self._handle_max, on_attach=self._on_attach, - loop=self.loop), - loop=self.loop) + **self._internal_kwargs), + **self._internal_kwargs) self._session = self._auth._session # pylint: disable=protected-access elif self._connection._cbs: self._session = self._auth._session # pylint: disable=protected-access @@ -209,7 +200,11 @@ async def _build_session_async(self): outgoing_window=self._outgoing_window, handle_max=self._handle_max, on_attach=self._on_attach, - loop=self.loop) + **self._internal_kwargs) + + @property + def loop(self): + return self._internal_kwargs.get("loop") async def open_async(self, connection=None): """Asynchronously open the client. The client can create a new Connection @@ -242,10 +237,10 @@ async def open_async(self, connection=None): remote_idle_timeout_empty_frame_send_ratio=self._remote_idle_timeout_empty_frame_send_ratio, error_policy=self._error_policy, debug=self._debug_trace, - loop=self.loop) + **self._internal_kwargs) await self._build_session_async() if self._keep_alive_interval: - self._keep_alive_thread = asyncio.ensure_future(self._keep_alive_async(), loop=self.loop) + self._keep_alive_thread = asyncio.ensure_future(self._keep_alive_async(), **self._internal_kwargs) finally: if self._ext_connection: connection.release_async() @@ -267,13 +262,13 @@ async def close_async(self): return # already closed. if not self._connection._cbs: # pylint: disable=protected-access _logger.info("Closing non-CBS session.") - await asyncio.shield(self._session.destroy_async(), loop=self.loop) + await asyncio.shield(self._session.destroy_async(), **self._internal_kwargs) else: _logger.info("CBS session pending %r.", self._connection.container_id) self._session = None if not self._ext_connection: _logger.info("Closing exclusive connection %r.", self._connection.container_id) - await asyncio.shield(self._connection.destroy_async(), loop=self.loop) + await asyncio.shield(self._connection.destroy_async(), **self._internal_kwargs) else: _logger.info("Shared connection remaining open.") self._connection = None @@ -314,7 +309,7 @@ async def mgmt_request_async(self, message, operation, op_type=None, node=None, :rtype: ~uamqp.message.Message """ while not await self.auth_complete_async(): - await asyncio.sleep(0.05, loop=self.loop) + await asyncio.sleep(0.05, **self._internal_kwargs) response = await asyncio.shield( self._session.mgmt_request_async( message, @@ -325,7 +320,7 @@ async def mgmt_request_async(self, message, operation, op_type=None, node=None, encoding=self._encoding, debug=self._debug_trace, **kwargs), - loop=self.loop) + **self._internal_kwargs) return response async def auth_complete_async(self): @@ -396,8 +391,6 @@ class SendClientAsync(client.SendClient, AMQPClientAsync): :param client_name: The name for the client, also known as the Container ID. If no name is provided, a random GUID will be used. :type client_name: str or bytes - :param loop: A user specified event loop. - :type loop: ~asycnio.AbstractEventLoop :param debug: Whether to turn on network trace logs. If `True`, trace logs will be logged at INFO level. Default is `False`. :type debug: bool @@ -474,7 +467,7 @@ def __init__( error_policy=None, keep_alive_interval=None, **kwargs): - self.loop = loop or get_running_loop() + self._internal_kwargs = get_dict_with_loop_if_needed(loop) client.SendClient.__init__( self, target, @@ -488,7 +481,7 @@ def __init__( # AMQP object settings self.sender_type = MessageSenderAsync - self._pending_messages_lock = asyncio.Lock(loop=self.loop) + self._pending_messages_lock = asyncio.Lock(**self._internal_kwargs) async def _client_ready_async(self): """Determine whether the client is ready to start sending messages. @@ -513,8 +506,8 @@ async def _client_ready_async(self): error_policy=self._error_policy, encoding=self._encoding, desired_capabilities=self._desired_capabilities, - loop=self.loop) - await asyncio.shield(self.message_handler.open_async(), loop=self.loop) + **self._internal_kwargs) + await asyncio.shield(self.message_handler.open_async(), **self._internal_kwargs) return False if self.message_handler.get_state() == constants.MessageSenderState.Error: raise errors.MessageHandlerError( @@ -528,7 +521,8 @@ async def _client_ready_async(self): async def _transfer_message_async(self, message, timeout): sent = await asyncio.shield( self.message_handler.send_async(message, self._on_message_sent, timeout=timeout), - loop=self.loop) + **self._internal_kwargs + ) if not sent: _logger.info("Message not sent, raising RuntimeError.") raise RuntimeError("Message sender failed to add message data to outgoing queue.") @@ -567,7 +561,7 @@ async def _client_run_async(self): """ # pylint: disable=protected-access await self.message_handler.work_async() - await asyncio.shield(self._connection.work_async(), loop=self.loop) + await asyncio.shield(self._connection.work_async(), **self._internal_kwargs) if self._connection._state == c_uamqp.ConnectionState.DISCARDING: raise errors.ConnectionClose(constants.ErrorCodes.InternalServerError) self._waiting_messages = 0 @@ -692,8 +686,6 @@ class ReceiveClientAsync(client.ReceiveClient, AMQPClientAsync): :param client_name: The name for the client, also known as the Container ID. If no name is provided, a random GUID will be used. :type client_name: str or bytes - :param loop: A user specified event loop. - :type loop: ~asycnio.AbstractEventLoop :param debug: Whether to turn on network trace logs. If `True`, trace logs will be logged at INFO level. Default is `False`. :type debug: bool @@ -782,7 +774,7 @@ def __init__( error_policy=None, keep_alive_interval=None, **kwargs): - self.loop = loop or get_running_loop() + self._internal_kwargs = get_dict_with_loop_if_needed(loop) client.ReceiveClient.__init__( self, source, @@ -823,8 +815,8 @@ async def _client_ready_async(self): error_policy=self._error_policy, encoding=self._encoding, desired_capabilities=self._desired_capabilities, - loop=self.loop) - await asyncio.shield(self.message_handler.open_async(), loop=self.loop) + ) + await asyncio.shield(self.message_handler.open_async(), **self._internal_kwargs) return False if self.message_handler.get_state() == constants.MessageReceiverState.Error: raise errors.MessageHandlerError( @@ -850,7 +842,7 @@ async def _client_run_async(self): now = self._counter.get_current_ms() if self._last_activity_timestamp and not self._was_message_received: # If no messages are coming through, back off a little to keep CPU use low. - await asyncio.sleep(0.05, loop=self.loop) + await asyncio.sleep(0.05, **self._internal_kwargs) if self._timeout > 0: timespan = now - self._last_activity_timestamp if timespan >= self._timeout: diff --git a/uamqp/async_ops/connection_async.py b/uamqp/async_ops/connection_async.py index 1fd379f8c..997e0a014 100644 --- a/uamqp/async_ops/connection_async.py +++ b/uamqp/async_ops/connection_async.py @@ -4,12 +4,13 @@ # license information. #-------------------------------------------------------------------------- +import sys import asyncio import logging import uamqp from uamqp import c_uamqp, connection -from uamqp.utils import get_running_loop +from uamqp.async_ops.utils import get_dict_with_loop_if_needed _logger = logging.getLogger(__name__) @@ -56,8 +57,6 @@ class ConnectionAsync(connection.Connection): :param encoding: The encoding to use for parameters supplied as strings. Default is 'UTF-8' :type encoding: str - :param loop: A user specified event loop. - :type loop: ~asyncio.AbstractEventLoop """ def __init__(self, hostname, sasl, @@ -71,7 +70,7 @@ def __init__(self, hostname, sasl, debug=False, encoding='UTF-8', loop=None): - self.loop = loop or get_running_loop() + self._internal_kwargs = get_dict_with_loop_if_needed(loop) super(ConnectionAsync, self).__init__( hostname, sasl, container_id=container_id, @@ -83,7 +82,7 @@ def __init__(self, hostname, sasl, error_policy=error_policy, debug=debug, encoding=encoding) - self._async_lock = asyncio.Lock(loop=self.loop) + self._async_lock = asyncio.Lock(**self._internal_kwargs) async def __aenter__(self): """Open the Connection in an async context manager.""" @@ -105,8 +104,12 @@ async def _close_async(self): self.auth.close() _logger.info("Connection shutdown complete %r.", self.container_id) + @property + def loop(self): + return self._internal_kwargs.get("loop") + async def lock_async(self, timeout=3.0): - await asyncio.wait_for(self._async_lock.acquire(), timeout=timeout, loop=self.loop) + await asyncio.wait_for(self._async_lock.acquire(), timeout=timeout, **self._internal_kwargs) def release_async(self): try: @@ -135,12 +138,12 @@ async def work_async(self): if self._closing: _logger.debug("Connection unlocked but shutting down.") return - await asyncio.sleep(0, loop=self.loop) + await asyncio.sleep(0, **self._internal_kwargs) self._conn.do_work() except asyncio.TimeoutError: _logger.debug("Connection %r timed out while waiting for lock acquisition.", self.container_id) finally: - await asyncio.sleep(0, loop=self.loop) + await asyncio.sleep(0, **self._internal_kwargs) self.release_async() async def sleep_async(self, seconds): @@ -151,7 +154,7 @@ async def sleep_async(self, seconds): """ try: await self.lock_async() - await asyncio.sleep(seconds, loop=self.loop) + await asyncio.sleep(seconds, **self._internal_kwargs) except asyncio.TimeoutError: _logger.debug("Connection %r timed out while waiting for lock acquisition.", self.container_id) finally: diff --git a/uamqp/async_ops/mgmt_operation_async.py b/uamqp/async_ops/mgmt_operation_async.py index 1a28b8778..c22d2bd42 100644 --- a/uamqp/async_ops/mgmt_operation_async.py +++ b/uamqp/async_ops/mgmt_operation_async.py @@ -8,9 +8,9 @@ import uuid from uamqp import Message, constants, errors -from uamqp.utils import get_running_loop #from uamqp.session import Session from uamqp.mgmt_operation import MgmtOperation +from uamqp.async_ops.utils import get_dict_with_loop_if_needed try: TimeoutException = TimeoutError @@ -42,8 +42,6 @@ class MgmtOperationAsync(MgmtOperation): :param encoding: The encoding to use for parameters supplied as strings. Default is 'UTF-8' :type encoding: str - :param loop: A user specified event loop. - :type loop: ~asycnio.AbstractEventLoop """ def __init__(self, @@ -54,7 +52,7 @@ def __init__(self, description_fields=b'statusDescription', encoding='UTF-8', loop=None): - self.loop = loop or get_running_loop() + self._internal_kwargs = get_dict_with_loop_if_needed(loop) super(MgmtOperationAsync, self).__init__( session, target=target, @@ -63,6 +61,10 @@ def __init__(self, description_fields=description_fields, encoding=encoding) + @property + def loop(self): + return self._internal_kwargs.get("loop") + async def execute_async(self, operation, op_type, message, timeout=0): """Execute a request and wait on a response asynchronously. diff --git a/uamqp/async_ops/receiver_async.py b/uamqp/async_ops/receiver_async.py index 9630a3383..45eb5800c 100644 --- a/uamqp/async_ops/receiver_async.py +++ b/uamqp/async_ops/receiver_async.py @@ -8,7 +8,7 @@ import asyncio from uamqp import constants, errors, receiver -from uamqp.utils import get_running_loop +from uamqp.async_ops.utils import get_dict_with_loop_if_needed _logger = logging.getLogger(__name__) @@ -71,8 +71,6 @@ class MessageReceiverAsync(receiver.MessageReceiver): :param encoding: The encoding to use for parameters supplied as strings. Default is 'UTF-8' :type encoding: str - :param loop: A user specified event loop. - :type loop: ~asycnio.AbstractEventLoop """ def __init__(self, session, source, target, @@ -88,7 +86,7 @@ def __init__(self, session, source, target, encoding='UTF-8', desired_capabilities=None, loop=None): - self.loop = loop or get_running_loop() + self._internal_kwargs = get_dict_with_loop_if_needed(loop) super(MessageReceiverAsync, self).__init__( session, source, target, on_message_received, @@ -112,6 +110,10 @@ async def __aexit__(self, *args): """Close the MessageReceiver when exiting an async context manager.""" await self.destroy_async() + @property + def loop(self): + return self._internal_kwargs.get("loop") + async def destroy_async(self): """Asynchronously close both the Receiver and the Link. Clean up any C objects.""" self.destroy() @@ -133,7 +135,7 @@ async def open_async(self): async def work_async(self): """Update the link status.""" - await asyncio.sleep(0, loop=self.loop) + await asyncio.sleep(0, **self._internal_kwargs) self._link.do_work() async def reset_link_credit_async(self, link_credit, **kwargs): @@ -142,7 +144,7 @@ async def reset_link_credit_async(self, link_credit, **kwargs): :param link_credit: The link credit amount that is requested. :type link_credit: int """ - await asyncio.sleep(0, loop=self.loop) + await asyncio.sleep(0, **self._internal_kwargs) drain = kwargs.get("drain", False) self._link.reset_link_credit(link_credit, drain) diff --git a/uamqp/async_ops/sender_async.py b/uamqp/async_ops/sender_async.py index 2c622a526..d8f65f663 100644 --- a/uamqp/async_ops/sender_async.py +++ b/uamqp/async_ops/sender_async.py @@ -8,7 +8,7 @@ import asyncio from uamqp import constants, errors, sender -from uamqp.utils import get_running_loop +from uamqp.async_ops.utils import get_dict_with_loop_if_needed _logger = logging.getLogger(__name__) @@ -76,8 +76,6 @@ class MessageSenderAsync(sender.MessageSender): :param encoding: The encoding to use for parameters supplied as strings. Default is 'UTF-8' :type encoding: str - :param loop: A user specified event loop. - :type loop: ~asycnio.AbstractEventLoop """ def __init__(self, session, source, target, @@ -92,7 +90,7 @@ def __init__(self, session, source, target, encoding='UTF-8', desired_capabilities=None, loop=None): - self.loop = loop or get_running_loop() + self._internal_kwargs = get_dict_with_loop_if_needed(loop) super(MessageSenderAsync, self).__init__( session, source, target, name=name, @@ -115,6 +113,10 @@ async def __aexit__(self, *args): """Close the MessageSender when exiting an async context manager.""" await self.destroy_async() + @property + def loop(self): + return self._internal_kwargs.get("loop") + async def destroy_async(self): """Asynchronously close both the Sender and the Link. Clean up any C objects.""" self.destroy() @@ -167,7 +169,7 @@ async def send_async(self, message, callback, timeout=0): async def work_async(self): """Update the link status.""" - await asyncio.sleep(0, loop=self.loop) + await asyncio.sleep(0, **self._internal_kwargs) self._link.do_work() async def close_async(self): diff --git a/uamqp/async_ops/session_async.py b/uamqp/async_ops/session_async.py index 6c82c2057..bc7fb54c1 100644 --- a/uamqp/async_ops/session_async.py +++ b/uamqp/async_ops/session_async.py @@ -7,8 +7,8 @@ import logging from uamqp import constants, errors, session -from uamqp.utils import get_running_loop from uamqp.async_ops.mgmt_operation_async import MgmtOperationAsync +from uamqp.async_ops.utils import get_dict_with_loop_if_needed _logger = logging.getLogger(__name__) @@ -36,8 +36,6 @@ class SessionAsync(session.Session): :param on_attach: A callback function to be run on receipt of an ATTACH frame. The function must take 4 arguments: source, target, properties and error. :type on_attach: func[~uamqp.address.Source, ~uamqp.address.Target, dict, ~uamqp.errors.AMQPConnectionError] - :param loop: A user specified event loop. - :type loop: ~asycnio.AbstractEventLoop """ def __init__(self, connection, @@ -46,7 +44,7 @@ def __init__(self, connection, handle_max=None, on_attach=None, loop=None): - self.loop = loop or get_running_loop() + self._internal_kwargs = get_dict_with_loop_if_needed(loop) super(SessionAsync, self).__init__( connection, incoming_window=incoming_window, @@ -62,6 +60,10 @@ async def __aexit__(self, *args): """Close and destroy sesion on exiting an async context manager.""" await self.destroy_async() + @property + def loop(self): + return self._internal_kwargs.get("loop") + async def mgmt_request_async(self, message, operation, op_type=None, node=b'$management', **kwargs): """Asynchronously run a request/response operation. These are frequently used for management tasks against a $management node, however any node name can be @@ -100,7 +102,8 @@ async def mgmt_request_async(self, message, operation, op_type=None, node=b'$man try: mgmt_link = self._mgmt_links[node] except KeyError: - mgmt_link = MgmtOperationAsync(self, target=node, loop=self.loop, **kwargs) + kwargs.update(self._internal_kwargs) + mgmt_link = MgmtOperationAsync(self, target=node, **kwargs) self._mgmt_links[node] = mgmt_link while not mgmt_link.open and not mgmt_link.mgmt_error: await self._connection.work_async() diff --git a/uamqp/async_ops/utils.py b/uamqp/async_ops/utils.py new file mode 100644 index 000000000..65a94448b --- /dev/null +++ b/uamqp/async_ops/utils.py @@ -0,0 +1,13 @@ +#------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +#-------------------------------------------------------------------------- + +import sys +from uamqp.utils import get_running_loop + +def get_dict_with_loop_if_needed(loop): + if sys.version_info >= (3, 10) and loop: + raise ValueError("Starting Python 3.10, asyncio no longer supports loop as a parameter.") + return {'loop': loop or get_running_loop()} if sys.version_info < (3, 10) else {} diff --git a/uamqp/authentication/cbs_auth_async.py b/uamqp/authentication/cbs_auth_async.py index 043a869af..af7a0bb5e 100644 --- a/uamqp/authentication/cbs_auth_async.py +++ b/uamqp/authentication/cbs_auth_async.py @@ -11,9 +11,9 @@ import logging from uamqp import c_uamqp, compat, constants, errors -from uamqp.utils import get_running_loop from uamqp.async_ops import SessionAsync from uamqp.constants import TransportType +from uamqp.async_ops.utils import get_dict_with_loop_if_needed from .cbs_auth import CBSAuthMixin, SASTokenAuth, JWTTokenAuth, TokenRetryPolicy @@ -36,6 +36,10 @@ def is_coroutine(get_token): class CBSAsyncAuthMixin(CBSAuthMixin): """Mixin to handle sending and refreshing CBS auth tokens asynchronously.""" + @property + def loop(self): + return self._internal_kwargs.get("loop") + async def create_authenticator_async(self, connection, debug=False, loop=None, **kwargs): """Create the async AMQP session and the CBS channel with which to negotiate the token. @@ -46,13 +50,12 @@ async def create_authenticator_async(self, connection, debug=False, loop=None, * :param debug: Whether to emit network trace logging events for the CBS session. Default is `False`. Logging events are set at INFO level. :type debug: bool - :param loop: A user specified event loop. - :type loop: ~asycnio.AbstractEventLoop :rtype: uamqp.c_uamqp.CBSTokenAuth """ - self.loop = loop or get_running_loop() + self._internal_kwargs = get_dict_with_loop_if_needed(loop) self._connection = connection - self._session = SessionAsync(connection, loop=self.loop, **kwargs) + kwargs.update(self._internal_kwargs) + self._session = SessionAsync(connection, **kwargs) try: self._cbs_auth = c_uamqp.CBSTokenAuth( @@ -114,7 +117,7 @@ async def handle_token_async(self): _logger.info("Authentication status: %r, description: %r", error_code, error_description) _logger.info("Authentication Put-Token failed. Retrying.") self.retries += 1 # pylint: disable=no-member - await asyncio.sleep(self._retry_policy.backoff, loop=self.loop) + await asyncio.sleep(self._retry_policy.backoff, **self._internal_kwargs) self._cbs_auth.authenticate() in_progress = True elif auth_status == constants.CBSAuthStatus.Failure: @@ -277,7 +280,8 @@ def __init__(self, audience, uri, async def create_authenticator_async(self, connection, debug=False, loop=None, **kwargs): await self.update_token() - return await super(JWTTokenAsync, self).create_authenticator_async(connection, debug, loop, **kwargs) + kwargs.update(self._internal_kwargs) + return await super(JWTTokenAsync, self).create_authenticator_async(connection, debug, **kwargs) async def update_token(self): access_token = await self.get_token()