Skip to content

Commit

Permalink
[PY] feat: SSO - Deduplication (#1890)
Browse files Browse the repository at this point in the history
## Linked issues

closes: #1873

## Details

This PR focuses on the sso_dialog operations. Specifically dialogs use a
2 step waterfall process to manage token acquisition and deduplication.
These features generally exist in the JS version of this process, but
are handled pretty differently there.

Some tests are written to provide basic coverage of the file
sso_dialog.py.

#### Change details

Extended the file sso_dialog.py

Added test_sso_dialog.py

General fmt, lint and option config adjustments.

**code snippets**:

**screenshots**:

## Attestation Checklist

- [x] My code follows the style guidelines of this project

- I have checked for/fixed spelling, linting, and other errors
- I have commented my code for clarity
- I have made corresponding changes to the documentation (updating the
doc strings in the code is sufficient)
- My changes generate no new warnings
- I have added tests that validates my changes, and provides sufficient
test coverage. I have tested with:
  - Local testing
  - E2E testing in Teams
- New and existing unit tests pass locally with my changes

### Additional information

> Feel free to add other relevant information below
  • Loading branch information
BMS-geodev committed Sep 16, 2024
1 parent 4333547 commit ac0f3df
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 5 deletions.
3 changes: 2 additions & 1 deletion python/packages/ai/teams/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .auth_options import AuthOptions
from .oauth import OAuth, OAuthDialog, OAuthOptions
from .sign_in_response import AuthErrorReason, SignInResponse, SignInStatus
from .sso import ConfidentialClientApplicationOptions, SsoAuth, SsoOptions
from .sso import ConfidentialClientApplicationOptions, SsoAuth, SsoDialog, SsoOptions

__all__ = [
"SignInStatus",
Expand All @@ -22,5 +22,6 @@
"OAuthDialog",
"SsoAuth",
"SsoOptions",
"SsoDialog",
"ConfidentialClientApplicationOptions",
]
51 changes: 48 additions & 3 deletions python/packages/ai/teams/auth/sso/sso_dialog.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
WaterfallDialog,
WaterfallStepContext,
)
from botbuilder.schema import Activity, ActivityTypes
from botbuilder.schema import Activity, ActivityTypes, SignInConstants
from msal import ConfidentialClientApplication

from ...dialogs import Dialog
Expand Down Expand Up @@ -60,7 +60,9 @@ async def sign_in(self, context: TurnContext, state: StateT) -> Optional[str]:
if hasattr(res.result, "token"):
return cast(str, getattr(res.result, "token"))

return await self.sign_in(context, state)
if not getattr(state, "sign_in_retries", 0):
setattr(state, "sign_in_retries", 1)
return await self.sign_in(context, state)

return None

Expand All @@ -74,5 +76,48 @@ async def _step_one(self, context: WaterfallStepContext) -> DialogTurnResult:
async def _step_two(self, context: WaterfallStepContext) -> DialogTurnResult:
token_response = context.result

# TODO: Dedup token exchange responses
if token_response and await self._should_dedup(context.context):
return DialogTurnResult(DialogTurnStatus.Waiting)

return await context.end_dialog(token_response)

async def _should_dedup(self, context: TurnContext) -> bool:
"""
Checks if deduplication should be performed for token exchange.
"""
etag = context.activity.value.get("id")
store_item = {"eTag": etag}
key = self._get_storage_key(context)

try:
await self._options.storage.write({key: store_item})
except Exception as e:
if "eTag conflict" in str(e):
return True
raise e

return False

def _get_storage_key(self, context: TurnContext) -> str:
"""
Gets the storage key for storing the token exchange state.
"""
if not context or not context.activity or not context.activity.conversation:
raise ValueError("Invalid context, cannot get storage key!")

activity = context.activity
if not (
activity.type == ActivityTypes.invoke
and activity.name == SignInConstants.token_exchange_operation_name
):
raise ValueError(
"TokenExchangeState can only be used with Invokes of signin/tokenExchange."
)

value_id = activity.value.get("id")
if not value_id:
raise ValueError("Invalid signin/tokenExchange. Missing activity.value.id.")

channel_id = activity.channel_id
conversation_id = activity.conversation.id
return f"{channel_id}/{conversation_id}/{value_id}"
4 changes: 3 additions & 1 deletion python/packages/ai/teams/auth/sso/sso_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import List
from typing import Any, List

from dataclasses_json import DataClassJsonMixin, dataclass_json

Expand Down Expand Up @@ -36,6 +36,8 @@ class SsoOptions(DataClassJsonMixin):
# Whether auth should end upon receiving an invalid message.
# Only works in conversational bot scenario.

storage: Any = None


@dataclass_json
@dataclass
Expand Down
101 changes: 101 additions & 0 deletions python/packages/ai/tests/auth/test_sso_dialog.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
"""
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
"""

import unittest
from unittest.mock import AsyncMock, MagicMock

from botbuilder.dialogs import DialogTurnResult, DialogTurnStatus
from msal import ConfidentialClientApplication

from teams.auth import ConfidentialClientApplicationOptions, SsoDialog, SsoOptions
from teams.state import TurnState


class TestSsoDialog(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
self.storage_mock = AsyncMock()
self.msal_config = ConfidentialClientApplicationOptions(
client_id="client_id",
authority="https://login.microsoftonline.com/common",
client_secret="client_secret",
)

self.options = SsoOptions(
scopes=["User.Read"],
msal_config=self.msal_config,
sign_in_link="https://login.microsoftonline.com/common/oauth2/v2.0/authorize",
timeout=900000,
end_on_invalid_message=True,
storage=self.storage_mock,
)

self.msal_app = MagicMock(spec=ConfidentialClientApplication)
self.sso_dialog = SsoDialog("test_sso", self.options, self.msal_app)

self.context = self.create_mock_context()
self.state = await TurnState.load(self.context)

def create_mock_context(self):
context = MagicMock()
activity = MagicMock()
activity.type = "message"
activity.text = "dummy_text"
activity.channel_id = "msteams"
activity.from_property = MagicMock(id="user_id", aad_object_id="aad_object_id")
activity.conversation = MagicMock(id="conversation_id")
activity.value = {"token": "dummy_token", "id": "dummy_id"}
context.activity = activity
return context

async def test_step_one(self):
waterfall_step_context = MagicMock()
waterfall_step_context.begin_dialog = AsyncMock(
return_value=DialogTurnResult(DialogTurnStatus.Waiting)
)

result = await self.sso_dialog._step_one(waterfall_step_context)

waterfall_step_context.begin_dialog.assert_called_once_with("TeamsSsoPrompt")
self.assertEqual(result.status, DialogTurnStatus.Waiting)

async def test_step_two_no_dedup_conflict(self):
waterfall_step_context = MagicMock()
waterfall_step_context.result = {"token": "new_access_token"}

class TempState:
duplicate_token_exchange = False

waterfall_step_context.context.state.temp = TempState()

self.sso_dialog._should_dedup = AsyncMock(return_value=True)
waterfall_step_context.end_dialog = AsyncMock(
return_value=DialogTurnResult(DialogTurnStatus.Waiting)
)

result = await self.sso_dialog._step_two(waterfall_step_context)

self.sso_dialog._should_dedup.assert_called_once_with(waterfall_step_context.context)
self.assertFalse(waterfall_step_context.context.state.temp.duplicate_token_exchange)
self.assertEqual(result.status, DialogTurnStatus.Waiting)

async def test_step_two_dedup_conflict(self):
waterfall_step_context = MagicMock()
waterfall_step_context.result = {"token": "new_access_token"}

class TempState:
duplicate_token_exchange = True

waterfall_step_context.context.state.temp = TempState()

self.sso_dialog._should_dedup = AsyncMock(return_value=True)
waterfall_step_context.end_dialog = AsyncMock(
return_value=DialogTurnResult(DialogTurnStatus.Waiting)
)

result = await self.sso_dialog._step_two(waterfall_step_context)

self.sso_dialog._should_dedup.assert_called_once_with(waterfall_step_context.context)
self.assertTrue(waterfall_step_context.context.state.temp.duplicate_token_exchange)
self.assertEqual(result.status, DialogTurnStatus.Waiting)

0 comments on commit ac0f3df

Please sign in to comment.