Skip to content

Commit

Permalink
Support for user-supplied SAS Tokens (#1140)
Browse files Browse the repository at this point in the history
* Support for user-supplied SAS Tokens now available via the `sastoken` parameter when constructing an `IoTHubSession` or a `ProvisioningSession`
* New `.update_sastoken()` method added to both `IoTHubSession` and `ProvisioningSession` to facilitate user-supplied SAS token replacement
* Removed the provisional `sastoken_fn` callback parameter from constructor of both `IoTHubSession` and `ProvisioningSession` as it is no longer necessary
* Added support for connection strings that contain a SAS token via the `.from_connection_string` factory method
* Introduced a new exception, `CredentialError`, which indicates an expired SAS token
* Removed `ExternalSasTokenGenerator` class as it is no longer necessary because `sastoken_fn` is no longer supported
* Renamed `InternalSasTokenGenerator` class to simply `SasTokenGenerator`
* Moved SAS token TTL configuration to the `.generate_sastoken()` method of `SasTokenGenerator` instead of instantiation
* Removed usage of `SasTokenProvider` from underlying MQTT and HTTP clients, however the implementation still is defined (for now)
* Updated docstrings
* Added unit tests for connection strings containing `SharedAccessSignature` and `GatewayHostName`
* Added previously missing unit tests for `sastoken_ttl`
* Enabled support for SAS auth in E2E tests
* Migration guide changes
  • Loading branch information
cartertinney authored May 30, 2023
1 parent f3e84fc commit c6920dc
Show file tree
Hide file tree
Showing 23 changed files with 1,197 additions and 1,421 deletions.
1 change: 1 addition & 0 deletions azure-iot-device/azure/iot/device/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
IoTEdgeEnvironmentError,
ProvisioningServiceError,
SessionError,
CredentialError,
IoTHubClientError,
MQTTError,
MQTTConnectionFailedError,
Expand Down
1 change: 0 additions & 1 deletion azure-iot-device/azure/iot/device/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ def __init__(
self.proxy_options = proxy_options

# Auth
self.sastoken_provider = sastoken_provider
self.ssl_context = ssl_context

# MQTT
Expand Down
27 changes: 16 additions & 11 deletions azure-iot-device/azure/iot/device/connection_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@
X509,
]

# TODO: does this module need revision for V3?


class ConnectionString(object):
"""Key/value mappings for connection details.
Expand Down Expand Up @@ -93,18 +91,25 @@ def _parse_connection_string(connection_string):
def _validate_keys(d):
"""Raise ValueError if incorrect combination of keys in dict d"""
host_name = d.get(HOST_NAME)
shared_access_key_name = d.get(SHARED_ACCESS_KEY_NAME)
shared_access_key = d.get(SHARED_ACCESS_KEY)
shared_access_signature = d.get(SHARED_ACCESS_SIGNATURE)
device_id = d.get(DEVICE_ID)
x509 = d.get(X509)

if shared_access_key and x509 and x509.lower() == "true":
# Validate only one type of auth included
auth_count = 0
if shared_access_key:
auth_count += 1
if x509 and x509.lower() == "true":
auth_count += 1
if shared_access_signature:
auth_count += 1

if auth_count > 1:
raise ValueError("Invalid Connection String - Mixed authentication scheme")
elif auth_count < 1:
raise ValueError("Invalid Connection String - No authentication scheme")

# This logic could be expanded to return the category of ConnectionString
if host_name and device_id and (shared_access_key or x509):
pass
elif host_name and shared_access_key and shared_access_key_name:
pass
else:
raise ValueError("Invalid Connection String - Incomplete")
# Validate connection details
if not host_name or not device_id:
raise ValueError("Invalid Connection String - Missing connection details")
1 change: 1 addition & 0 deletions azure-iot-device/azure/iot/device/custom_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

_P = ParamSpec("_P")
_R = TypeVar("_R")
# TODO: This is currently unused. Remove when we're sure it's no longer necessary
FunctionOrCoroutine = Union[Callable[_P, _R], Callable[_P, Awaitable[_R]]]


Expand Down
6 changes: 6 additions & 0 deletions azure-iot-device/azure/iot/device/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ class SessionError(Exception):
pass


class CredentialError(Exception):
"""Represents a failure from an invalid auth credential"""

pass


# Service Exceptions
class IoTHubError(Exception):
"""Represents a failure reported by IoT Hub"""
Expand Down
27 changes: 20 additions & 7 deletions azure-iot-device/azure/iot/device/iothub_http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from . import exceptions as exc
from . import config, constant, user_agent
from . import http_path_iothub as http_path
from . import sastoken as st

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -79,7 +80,19 @@ def __init__(self, client_config: config.IoTHubClientConfig) -> None:

self._session = _create_client_session(client_config.hostname)
self._ssl_context = client_config.ssl_context
self._sastoken_provider = client_config.sastoken_provider
self._sastoken: Optional[st.SasToken] = None
# NOTE: Because SAS tokens have a lifespan that expires, after which they need to be
# replaced, they are not set at instantiation, and are instead set manually via the
# `.set_sastoken()` method to allow for flexible higher level SAS logic.

def set_sastoken(self, sastoken: st.SasToken):
"""Set the current SasToken being used for authentication"""
self._sastoken = sastoken
# NOTE: This actually gets set on the underlying mqtt client during a `.connect()`
# since credentials need to be set whether or not SasTokens are being used.
# NOTE: There is currently no validation on this token expiry. If this module
# continues being used into the future, this should be added, but for now, as this
# module may end up being removed, I did not do extra work to implement that logic.

async def shutdown(self):
"""Shut down the client
Expand Down Expand Up @@ -124,8 +137,8 @@ async def invoke_direct_method(
HEADER_EDGE_MODULE_ID: self._edge_module_id, # TODO: I assume this isn't supposed to be URI encoded just like in MQTT?
}
# If using SAS auth, pass the auth header
if self._sastoken_provider:
headers[HEADER_AUTHORIZATION] = str(self._sastoken_provider.get_current_sastoken())
if self._sastoken:
headers[HEADER_AUTHORIZATION] = str(self._sastoken)

logger.debug(
"Sending direct method invocation request to {device_id}/{module_id}".format(
Expand Down Expand Up @@ -177,8 +190,8 @@ async def get_storage_info_for_blob(self, *, blob_name: str) -> StorageInfo:
# NOTE: Other headers are auto-generated by aiohttp
headers = {HEADER_USER_AGENT: urllib.parse.quote_plus(self._user_agent_string)}
# If using SAS auth, pass the auth header
if self._sastoken_provider:
headers[HEADER_AUTHORIZATION] = str(self._sastoken_provider.get_current_sastoken())
if self._sastoken:
headers[HEADER_AUTHORIZATION] = str(self._sastoken)

logger.debug("Sending storage info request to IoTHub...")
async with self._session.post(
Expand Down Expand Up @@ -229,8 +242,8 @@ async def notify_blob_upload_status(
# NOTE: Other headers are auto-generated by aiohttp
headers = {HEADER_USER_AGENT: urllib.parse.quote_plus(self._user_agent_string)}
# If using SAS auth, pass the auth header
if self._sastoken_provider:
headers[HEADER_AUTHORIZATION] = str(self._sastoken_provider.get_current_sastoken())
if self._sastoken:
headers[HEADER_AUTHORIZATION] = str(self._sastoken)

logger.debug("Sending blob upload notification to IoTHub...")
async with self._session.post(
Expand Down
51 changes: 37 additions & 14 deletions azure-iot-device/azure/iot/device/iothub_mqtt_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,12 @@
from .custom_typing import TwinPatch, Twin
from . import config, constant, user_agent, models
from . import exceptions as exc
from . import request_response as rr
from . import mqtt_client as mqtt
from . import mqtt_topic_iothub as mqtt_topic
from . import sastoken as st
from . import request_response as rr

# TODO: update docstrings with correct class paths once repo structured better
# TODO: If we're truly done with keeping SAS credentials fresh, we don't need to use SasTokenProvider,
# and we could just simply use a single token or generator instead.

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -47,12 +46,15 @@ def __init__(
product_info=client_config.product_info,
)

# SAS (Optional)
self._sastoken_provider = client_config.sastoken_provider
# SAS
self._sastoken: Optional[st.SasToken] = None
# NOTE: Because SAS tokens have a lifespan that expires, after which they need to be
# replaced, they are not set at instantiation, and are instead set manually via the
# `.set_sastoken()` method to allow for flexible higher level SAS logic.

# MQTT Configuration
self._mqtt_client = _create_mqtt_client(self._client_id, client_config)
# NOTE: credentials are set upon `.start()`
# NOTE: credentials are set upon `.connect()`.

# Add filters for receive topics delivering data used internally
twin_response_topic = mqtt_topic.get_twin_response_topic_for_subscribe()
Expand Down Expand Up @@ -174,20 +176,22 @@ async def _process_twin_responses(self) -> None:
)
)

def set_sastoken(self, sastoken: st.SasToken):
"""Set the current SasToken being used for authentication"""
self._sastoken = sastoken
# NOTE: This actually gets set on the underlying mqtt client during a `.connect()`
# since credentials need to be set whether or not SasTokens are being used.
#
# NOTE: There isn't currently a defined path to "un-set" the token - it could be added
# by allowing for 'None' to be passed through, although that may not be desirable
# semantics.

async def start(self) -> None:
"""Start up the client.
- Must be invoked before any other methods.
- If already started, will not (meaningfully) do anything.
"""
# Set credentials
if self._sastoken_provider:
logger.debug("Using SASToken as password")
password = str(self._sastoken_provider.get_current_sastoken())
else:
logger.debug("No password used")
password = None
self._mqtt_client.set_credentials(self._username, password)
# Start background tasks
if not self._process_twin_responses_bg_task:
self._process_twin_responses_bg_task = asyncio.create_task(
Expand Down Expand Up @@ -219,12 +223,31 @@ async def stop(self) -> None:
# BaseException in Python 3.7
if isinstance(result, Exception) and not isinstance(result, asyncio.CancelledError):
raise result
# TODO: Find a way to remove this sleep.
# Not having it causes leak in some E2E tests. It appears that for some reason the
# IoTHubMQTTClient is holding some object references slightly longer than it should be,
# causing a spurious memory leak (the memory is freed, just not immediately).
# This only occurs when using operations that rely on the request/response infrastructure
# (e.g. twin). Presumably it has something to do with the ._process_twin_responses_bg_task
await asyncio.sleep(0.1)

async def connect(self) -> None:
"""Connect to IoTHub
:raises: MQTTConnectionFailedError if there is a failure connecting
:raises: CredentialError if the current SasToken has expired
"""
# Set credentials
if self._sastoken:
logger.debug("Using SASToken as password")
if self._sastoken.is_expired():
raise exc.CredentialError("SAS Token expired - set a new one")
password = str(self._sastoken)
else:
logger.debug("No password used")
password = None
self._mqtt_client.set_credentials(self._username, password)

# Connect
logger.debug("Connecting to IoTHub...")
await self._mqtt_client.connect()
Expand Down
Loading

0 comments on commit c6920dc

Please sign in to comment.