Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[EventHubs] Fix pylint and mypy #21939

Merged
merged 4 commits into from
Nov 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@

_LOGGER = logging.getLogger(__name__)
_Address = collections.namedtuple("Address", "hostname path")
_AccessToken = collections.namedtuple("AccessToken", "token expires_on")


def _parse_conn_str(conn_str, **kwargs):
Expand Down Expand Up @@ -127,8 +126,8 @@ def _parse_conn_str(conn_str, **kwargs):


def _generate_sas_token(uri, policy, key, expiry=None):
# type: (str, str, str, Optional[timedelta]) -> _AccessToken
"""Create a shared access signiture token as a string literal.
# type: (str, str, str, Optional[timedelta]) -> AccessToken
"""Create a shared access signature token as a string literal.
:returns: SAS token as string literal.
:rtype: str
"""
Expand All @@ -141,7 +140,7 @@ def _generate_sas_token(uri, policy, key, expiry=None):
encoded_key = key.encode("utf-8")

token = utils.create_sas_token(encoded_policy, encoded_key, encoded_uri, expiry)
return _AccessToken(token=token, expires_on=abs_expiry)
return AccessToken(token=token, expires_on=abs_expiry)


def _build_uri(address, entity):
Expand Down Expand Up @@ -169,11 +168,12 @@ def __init__(self, policy, key):
self.token_type = b"servicebus.windows.net:sastoken"

def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
# type: (str, Any) -> _AccessToken
# type: (str, Any) -> AccessToken
if not scopes:
raise ValueError("No token scope provided.")
return _generate_sas_token(scopes[0], self.policy, self.key)


class EventhubAzureNamedKeyTokenCredential(object):
"""The named key credential used for authentication.

Expand All @@ -187,7 +187,7 @@ def __init__(self, azure_named_key_credential):
self.token_type = b"servicebus.windows.net:sastoken"

def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
# type: (str, Any) -> _AccessToken
# type: (str, Any) -> AccessToken
if not scopes:
raise ValueError("No token scope provided.")
name, key = self._credential.named_key
Expand Down Expand Up @@ -217,6 +217,7 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
"""
return AccessToken(self.token, self.expiry)


class EventhubAzureSasTokenCredential(object):
"""The shared access token credential used for authentication
when AzureSasCredential is provided.
Expand Down Expand Up @@ -350,6 +351,7 @@ def _backoff(

def _management_request(self, mgmt_msg, op_type):
# type: (Message, bytes) -> Any
# pylint:disable=assignment-from-none
retried_times = 0
last_exception = None
while retried_times <= self._config.max_retries:
Expand All @@ -360,7 +362,7 @@ def _management_request(self, mgmt_msg, op_type):
try:
conn = self._conn_manager.get_connection(
self._address.hostname, mgmt_auth
) # pylint:disable=assignment-from-none
)
mgmt_client.open(connection=conn)
mgmt_msg.application_properties["security_token"] = mgmt_auth.token
response = mgmt_client.mgmt_request(
Expand Down
3 changes: 2 additions & 1 deletion sdk/eventhub/azure-eventhub/azure/eventhub/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def __init__(self, body=None):
# Internal usage only for transforming AmqpAnnotatedMessage to outgoing EventData
self._raw_amqp_message = AmqpAnnotatedMessage( # type: ignore
data_body=body, annotations={}, application_properties={}
)
)
self.message = (self._raw_amqp_message._message) # pylint:disable=protected-access
self._raw_amqp_message.header = AmqpMessageHeader()
self._raw_amqp_message.properties = AmqpMessageProperties()
Expand Down Expand Up @@ -171,6 +171,7 @@ def __str__(self):
@classmethod
def _from_message(cls, message, raw_amqp_message=None):
# type: (Message, Optional[AmqpAnnotatedMessage]) -> EventData
# pylint:disable=protected-access
"""Internal use only.

Creates an EventData object from a raw uamqp message and, if provided, AmqpAnnotatedMessage.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,8 @@
# --------------------------------------------------------------------------------------------

from typing import TYPE_CHECKING
from threading import Lock
from enum import Enum

from uamqp import Connection, TransportType, c_uamqp
from uamqp import Connection

if TYPE_CHECKING:
from uamqp.authentication import JWTTokenAuth
Expand Down
6 changes: 4 additions & 2 deletions sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,10 @@ def trace_message(event, parent_span=None):


def get_event_links(events):
# pylint:disable=isinstance-second-argument-not-valid-type
trace_events = (
events if isinstance(events, Iterable) else (events,)
) # pylint:disable=isinstance-second-argument-not-valid-type
)
links = []
try:
for event in trace_events: # type: ignore
Expand Down Expand Up @@ -296,6 +297,7 @@ def transform_outbound_single_message(message, message_type):

def decode_with_recurse(data, encoding="UTF-8"):
# type: (Any, str) -> Any
# pylint:disable=isinstance-second-argument-not-valid-type
"""
If data is of a compatible type, iterates through nested structure and decodes all binary
strings with provided encoding.
Expand All @@ -311,7 +313,7 @@ def decode_with_recurse(data, encoding="UTF-8"):
return data.decode(encoding)
if isinstance(data, Mapping):
decoded_mapping = {}
for k,v in data.items():
for k, v in data.items():
decoded_key = decode_with_recurse(k, encoding)
decoded_val = decode_with_recurse(v, encoding)
decoded_mapping[decoded_key] = decoded_val
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#--------------------------------------------------------------------------

import sys
import asyncio


def get_dict_with_loop_if_needed(loop):
if sys.version_info >= (3, 10):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(self, policy: str, key: str):
self.key = key
self.token_type = b"servicebus.windows.net:sastoken"

async def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
async def get_token(self, *scopes, **kwargs) -> AccessToken: # pylint:disable=unused-argument
if not scopes:
raise ValueError("No token scope provided.")
return _generate_sas_token(scopes[0], self.policy, self.key)
Expand All @@ -84,6 +84,7 @@ async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: # pylint
"""
return AccessToken(self.token, self.expiry)


class EventhubAzureNamedKeyTokenCredentialAsync(object):
"""The named key credential used for authentication.

Expand All @@ -96,7 +97,7 @@ def __init__(self, azure_named_key_credential):
self._credential = azure_named_key_credential
self.token_type = b"servicebus.windows.net:sastoken"

async def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
async def get_token(self, *scopes, **kwargs) -> AccessToken: # pylint:disable=unused-argument
if not scopes:
raise ValueError("No token scope provided.")
name, key = self._credential.named_key
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
# --------------------------------------------------------------------------------------------

from typing import TYPE_CHECKING
from asyncio import Lock

from uamqp import TransportType, c_uamqp
from uamqp.async_ops import ConnectionAsync

if TYPE_CHECKING:
Expand Down