Skip to content

Commit

Permalink
[EventHub] add SAS token auth capabilities to EventHub (#13354)
Browse files Browse the repository at this point in the history
* Add SAS token support to EventHub for connection string via 'sharedaccesssignature'
* Adds changelog/docs/samples/tests, for now, utilize the old-style of test structure for sync vs async instead of preparer until import issue is resolved.
  • Loading branch information
KieranBrantnerMagee authored Sep 4, 2020
1 parent 635b820 commit 39192b7
Show file tree
Hide file tree
Showing 7 changed files with 184 additions and 12 deletions.
5 changes: 4 additions & 1 deletion sdk/eventhub/azure-eventhub/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# Release History

## 5.2.0b2 (Unreleased)
## 5.2.0 (2020-09-08)

**New Features**

- Connection strings used with `from_connection_string` methods now supports using the `SharedAccessSignature` key in leiu of `sharedaccesskey` and `sharedaccesskeyname`, taking the string of the properly constructed token as value.

## 5.2.0b1 (2020-07-06)

Expand Down
54 changes: 49 additions & 5 deletions sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from uamqp import AMQPClient, Message, authentication, constants, errors, compat, utils
import six
from azure.core.credentials import AccessToken

from .exceptions import _handle_exception, ClientClosedError, ConnectError
from ._configuration import Configuration
Expand All @@ -43,11 +44,13 @@


def _parse_conn_str(conn_str, kwargs):
# type: (str, Dict[str, Any]) -> Tuple[str, str, str, str]
# type: (str, Dict[str, Any]) -> Tuple[str, Optional[str], Optional[str], str, Optional[str], Optional[int]]
endpoint = None
shared_access_key_name = None
shared_access_key = None
entity_path = None # type: Optional[str]
shared_access_signature = None # type: Optional[str]
shared_access_signature_expiry = None # type: Optional[int]
eventhub_name = kwargs.pop("eventhub_name", None) # type: Optional[str]
for element in conn_str.split(";"):
key, _, value = element.partition("=")
Expand All @@ -61,7 +64,16 @@ def _parse_conn_str(conn_str, kwargs):
shared_access_key = value
elif key.lower() == "entitypath":
entity_path = value
if not all([endpoint, shared_access_key_name, shared_access_key]):
elif key.lower() == "sharedaccesssignature":
shared_access_signature = value
try:
# Expiry can be stored in the "se=<timestamp>" clause of the token. ('&'-separated key-value pairs)
# type: ignore
shared_access_signature_expiry = int(shared_access_signature.split('se=')[1].split('&')[0])
except (IndexError, TypeError, ValueError): # Fallback since technically expiry is optional.
# An arbitrary, absurdly large number, since you can't renew.
shared_access_signature_expiry = int(time.time() * 2)
if not (all((endpoint, shared_access_key_name, shared_access_key)) or all((endpoint, shared_access_signature))):
raise ValueError(
"Invalid connection string. Should be in the format: "
"Endpoint=sb://<FQDN>/;SharedAccessKeyName=<KeyName>;SharedAccessKey=<KeyValue>"
Expand All @@ -72,7 +84,12 @@ def _parse_conn_str(conn_str, kwargs):
host = cast(str, endpoint)[left_slash_pos + 2 :]
else:
host = str(endpoint)
return host, str(shared_access_key_name), str(shared_access_key), entity
return (host,
str(shared_access_key_name) if shared_access_key_name else None,
str(shared_access_key) if shared_access_key else None,
entity,
str(shared_access_signature) if shared_access_signature else None,
shared_access_signature_expiry)


def _generate_sas_token(uri, policy, key, expiry=None):
Expand Down Expand Up @@ -124,6 +141,30 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
return _generate_sas_token(scopes[0], self.policy, self.key)


class EventHubSASTokenCredential(object):
"""The shared access token credential used for authentication.
:param str token: The shared access token string
:param int expiry: The epoch timestamp
"""
def __init__(self, token, expiry):
# type: (str, int) -> None
"""
:param str token: The shared access token string
:param float expiry: The epoch timestamp
"""
self.token = token
self.expiry = expiry
self.token_type = b"servicebus.windows.net:sastoken"

def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
# type: (str, Any) -> AccessToken
"""
This method is automatically called when token is about to expire.
"""
return AccessToken(self.token, self.expiry)


class ClientBase(object): # pylint:disable=too-many-instance-attributes
def __init__(self, fully_qualified_namespace, eventhub_name, credential, **kwargs):
# type: (str, str, TokenCredential, Any) -> None
Expand All @@ -148,10 +189,13 @@ def __init__(self, fully_qualified_namespace, eventhub_name, credential, **kwarg
@staticmethod
def _from_connection_string(conn_str, **kwargs):
# type: (str, Any) -> Dict[str, Any]
host, policy, key, entity = _parse_conn_str(conn_str, kwargs)
host, policy, key, entity, token, token_expiry = _parse_conn_str(conn_str, kwargs)
kwargs["fully_qualified_namespace"] = host
kwargs["eventhub_name"] = entity
kwargs["credential"] = EventHubSharedKeyCredential(policy, key)
if token and token_expiry:
kwargs["credential"] = EventHubSASTokenCredential(token, token_expiry)
elif policy and key:
kwargs["credential"] = EventHubSharedKeyCredential(policy, key)
return kwargs

def _create_auth(self):
Expand Down
2 changes: 1 addition & 1 deletion sdk/eventhub/azure-eventhub/azure/eventhub/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
# Licensed under the MIT License.
# ------------------------------------

VERSION = "5.2.0b2"
VERSION = "5.2.0"
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
Message,
AMQPClientAsync,
)
from azure.core.credentials import AccessToken

from .._client_base import ClientBase, _generate_sas_token, _parse_conn_str
from .._utils import utc_from_timestamp
Expand Down Expand Up @@ -62,6 +63,28 @@ async def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
return _generate_sas_token(scopes[0], self.policy, self.key)


class EventHubSASTokenCredential(object):
"""The shared access token credential used for authentication.
:param str token: The shared access token string
:param int expiry: The epoch timestamp
"""
def __init__(self, token: str, expiry: int) -> None:
"""
:param str token: The shared access token string
:param int expiry: The epoch timestamp
"""
self.token = token
self.expiry = expiry
self.token_type = b"servicebus.windows.net:sastoken"

async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: # pylint:disable=unused-argument
"""
This method is automatically called when token is about to expire.
"""
return AccessToken(self.token, self.expiry)


class ClientBaseAsync(ClientBase):
def __init__(
self,
Expand All @@ -86,10 +109,13 @@ def __enter__(self):

@staticmethod
def _from_connection_string(conn_str: str, **kwargs) -> Dict[str, Any]:
host, policy, key, entity = _parse_conn_str(conn_str, kwargs)
host, policy, key, entity, token, token_expiry = _parse_conn_str(conn_str, kwargs)
kwargs["fully_qualified_namespace"] = host
kwargs["eventhub_name"] = entity
kwargs["credential"] = EventHubSharedKeyCredential(policy, key)
if token and token_expiry:
kwargs["credential"] = EventHubSASTokenCredential(token, token_expiry)
elif policy and key:
kwargs["credential"] = EventHubSharedKeyCredential(policy, key)
return kwargs

async def _create_auth_async(self) -> authentication.JWTTokenAsync:
Expand Down
8 changes: 8 additions & 0 deletions sdk/eventhub/azure-eventhub/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,3 +224,11 @@ def connstr_senders(live_eventhub):
for s in senders:
s.close()
client.close()

# Note: This is duplicated between here and the basic conftest, so that it does not throw warnings if you're
# running locally to this SDK. (Everything works properly, pytest just makes a bit of noise.)
def pytest_configure(config):
# register an additional marker
config.addinivalue_line(
"markers", "liveTest: mark test to be a live test only"
)
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,19 @@

import pytest
import asyncio
import datetime
import time

from azure.identity.aio import EnvironmentCredential
from azure.eventhub import EventData
from azure.eventhub.aio import EventHubConsumerClient, EventHubProducerClient
from azure.eventhub.aio import EventHubConsumerClient, EventHubProducerClient, EventHubSharedKeyCredential
from azure.eventhub.aio._client_base_async import EventHubSASTokenCredential

from devtools_testutils import AzureMgmtTestCase, CachedResourceGroupPreparer
from tests.eventhub_preparer import (
CachedEventHubNamespacePreparer,
CachedEventHubPreparer
)

@pytest.mark.liveTest
@pytest.mark.asyncio
Expand Down Expand Up @@ -43,3 +51,51 @@ def on_event(partition_context, event):
assert on_event.called is True
assert on_event.partition_id == "0"
assert list(on_event.event.body)[0] == 'A single message'.encode('utf-8')


class AsyncEventHubAuthTests(AzureMgmtTestCase):

@pytest.mark.liveTest
@pytest.mark.live_test_only
@CachedResourceGroupPreparer(name_prefix='eventhubtest')
@CachedEventHubNamespacePreparer(name_prefix='eventhubtest')
@CachedEventHubPreparer(name_prefix='eventhubtest')
async def test_client_sas_credential_async(self,
eventhub,
eventhub_namespace,
eventhub_namespace_key_name,
eventhub_namespace_primary_key,
eventhub_namespace_connection_string,
**kwargs):
# This should "just work" to validate known-good.
hostname = "{}.servicebus.windows.net".format(eventhub_namespace.name)
producer_client = EventHubProducerClient.from_connection_string(eventhub_namespace_connection_string, eventhub_name = eventhub.name)

async with producer_client:
batch = await producer_client.create_batch(partition_id='0')
batch.add(EventData(body='A single message'))
await producer_client.send_batch(batch)

# This should also work, but now using SAS tokens.
credential = EventHubSharedKeyCredential(eventhub_namespace_key_name, eventhub_namespace_primary_key)
hostname = "{}.servicebus.windows.net".format(eventhub_namespace.name)
auth_uri = "sb://{}/{}".format(hostname, eventhub.name)
token = (await credential.get_token(auth_uri)).token
producer_client = EventHubProducerClient(fully_qualified_namespace=hostname,
eventhub_name=eventhub.name,
credential=EventHubSASTokenCredential(token, time.time() + 3000))

async with producer_client:
batch = await producer_client.create_batch(partition_id='0')
batch.add(EventData(body='A single message'))
await producer_client.send_batch(batch)

# Finally let's do it with SAS token + conn str
token_conn_str = "Endpoint=sb://{}/;SharedAccessSignature={};".format(hostname, token.decode())
conn_str_producer_client = EventHubProducerClient.from_connection_string(token_conn_str,
eventhub_name=eventhub.name)

async with conn_str_producer_client:
batch = await conn_str_producer_client.create_batch(partition_id='0')
batch.add(EventData(body='A single message'))
await conn_str_producer_client.send_batch(batch)
39 changes: 37 additions & 2 deletions sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
import pytest
import time
import threading
import datetime

from azure.identity import EnvironmentCredential
from azure.eventhub import EventData, EventHubProducerClient, EventHubConsumerClient

from azure.eventhub import EventData, EventHubProducerClient, EventHubConsumerClient, EventHubSharedKeyCredential
from azure.eventhub._client_base import EventHubSASTokenCredential

@pytest.mark.liveTest
def test_client_secret_credential(live_eventhub):
Expand Down Expand Up @@ -46,3 +47,37 @@ def on_event(partition_context, event):
assert on_event.called is True
assert on_event.partition_id == "0"
assert list(on_event.event.body)[0] == 'A single message'.encode('utf-8')

@pytest.mark.liveTest
def test_client_sas_credential(live_eventhub):
# This should "just work" to validate known-good.
hostname = live_eventhub['hostname']
producer_client = EventHubProducerClient.from_connection_string(live_eventhub['connection_str'], eventhub_name = live_eventhub['event_hub'])

with producer_client:
batch = producer_client.create_batch(partition_id='0')
batch.add(EventData(body='A single message'))
producer_client.send_batch(batch)

# This should also work, but now using SAS tokens.
credential = EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key'])
auth_uri = "sb://{}/{}".format(hostname, live_eventhub['event_hub'])
token = credential.get_token(auth_uri).token
producer_client = EventHubProducerClient(fully_qualified_namespace=hostname,
eventhub_name=live_eventhub['event_hub'],
credential=EventHubSASTokenCredential(token, time.time() + 3000))

with producer_client:
batch = producer_client.create_batch(partition_id='0')
batch.add(EventData(body='A single message'))
producer_client.send_batch(batch)

# Finally let's do it with SAS token + conn str
token_conn_str = "Endpoint=sb://{}/;SharedAccessSignature={};".format(hostname, token.decode())
conn_str_producer_client = EventHubProducerClient.from_connection_string(token_conn_str,
eventhub_name=live_eventhub['event_hub'])

with conn_str_producer_client:
batch = conn_str_producer_client.create_batch(partition_id='0')
batch.add(EventData(body='A single message'))
conn_str_producer_client.send_batch(batch)

0 comments on commit 39192b7

Please sign in to comment.