Skip to content

Commit

Permalink
add aad support (Azure#19604)
Browse files Browse the repository at this point in the history
* add aad support

* update

* update
  • Loading branch information
xiangyan99 authored and rakshith91 committed Jul 16, 2021
1 parent 6b76926 commit 4150ded
Show file tree
Hide file tree
Showing 13 changed files with 235 additions and 82 deletions.
3 changes: 2 additions & 1 deletion sdk/search/azure-search-documents/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# Release History

## 11.3.0b1 (Unreleased)
## 11.3.0b1 (2021-07-07)

### Features Added

- Added AAD support
- Added support for semantic search
- Added normalizer support

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ def _headers(self):

def _merge_client_headers(self, headers):
# type(Optional[dict]) -> dict
if self._aad:
return headers
headers = headers or {}
combined = self._headers
combined.update(headers)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import cast, List, TYPE_CHECKING
import six

from azure.core.credentials import AzureKeyCredential
from azure.core.tracing.decorator import distributed_trace
from ._api_versions import DEFAULT_VERSION
from ._generated import SearchClient as SearchIndexClient
Expand All @@ -15,12 +16,13 @@
from ._paging import SearchItemPaged, SearchPageIterator
from ._queries import AutocompleteQuery, SearchQuery, SuggestQuery
from ._headers_mixin import HeadersMixin
from ._utils import get_authentication_policy
from ._version import SDK_MONIKER

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any, Union
from azure.core.credentials import AzureKeyCredential
from azure.core.credentials import TokenCredential


def odata(statement, **kwargs):
Expand Down Expand Up @@ -59,7 +61,7 @@ class SearchClient(HeadersMixin):
:param index_name: The name of the index to connect to
:type index_name: str
:param credential: A credential to authorize search client requests
:type credential: ~azure.core.credentials.AzureKeyCredential
:type credential: ~azure.core.credentials.AzureKeyCredential or ~azure.core.credentials.TokenCredential
:keyword str api_version: The Search API version to use for requests.
.. admonition:: Example:
Expand All @@ -75,18 +77,30 @@ class SearchClient(HeadersMixin):
_ODATA_ACCEPT = "application/json;odata.metadata=none" # type: str

def __init__(self, endpoint, index_name, credential, **kwargs):
# type: (str, str, AzureKeyCredential, **Any) -> None
# type: (str, str, Union[AzureKeyCredential, TokenCredential], **Any) -> None

self._api_version = kwargs.pop("api_version", DEFAULT_VERSION)
self._endpoint = endpoint # type: str
self._index_name = index_name # type: str
self._credential = credential # type: AzureKeyCredential
self._client = SearchIndexClient(
endpoint=endpoint,
index_name=index_name,
sdk_moniker=SDK_MONIKER,
api_version=self._api_version, **kwargs
) # type: SearchIndexClient
self._credential = credential
if isinstance(credential, AzureKeyCredential):
self._aad = False
self._client = SearchIndexClient(
endpoint=endpoint,
index_name=index_name,
sdk_moniker=SDK_MONIKER,
api_version=self._api_version, **kwargs
) # type: SearchIndexClient
else:
self._aad = True
authentication_policy = get_authentication_policy(credential)
self._client = SearchIndexClient(
endpoint=endpoint,
index_name=index_name,
authentication_policy=authentication_policy,
sdk_moniker=SDK_MONIKER,
api_version=self._api_version, **kwargs
) # type: SearchIndexClient

def __repr__(self):
# type: () -> str
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
import time
import threading

from azure.core.credentials import AzureKeyCredential
from azure.core.tracing.decorator import distributed_trace
from azure.core.exceptions import ServiceResponseTimeoutError
from ._utils import is_retryable_status_code
from ._utils import is_retryable_status_code, get_authentication_policy
from .indexes import SearchIndexClient as SearchServiceClient
from ._search_indexing_buffered_sender_base import SearchIndexingBufferedSenderBase
from ._generated import SearchClient as SearchIndexClient
Expand All @@ -22,7 +23,7 @@
if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any, Union
from azure.core.credentials import AzureKeyCredential
from azure.core.credentials import TokenCredential

class SearchIndexingBufferedSender(SearchIndexingBufferedSenderBase, HeadersMixin):
"""A buffered sender for document indexing actions.
Expand All @@ -32,7 +33,7 @@ class SearchIndexingBufferedSender(SearchIndexingBufferedSenderBase, HeadersMixi
:param index_name: The name of the index to connect to
:type index_name: str
:param credential: A credential to authorize search client requests
:type credential: ~azure.core.credentials.AzureKeyCredential
:type credential: ~azure.core.credentials.AzureKeyCredential or ~azure.core.credentials.TokenCredential
:keyword int auto_flush_interval: how many max seconds if between 2 flushes. This only takes effect
when auto_flush is on. Default to 60 seconds.
:keyword int initial_batch_action_count: The initial number of actions to group into a batch when
Expand All @@ -52,19 +53,31 @@ class SearchIndexingBufferedSender(SearchIndexingBufferedSenderBase, HeadersMixi
# pylint: disable=too-many-instance-attributes

def __init__(self, endpoint, index_name, credential, **kwargs):
# type: (str, str, AzureKeyCredential, **Any) -> None
# type: (str, str, Union[AzureKeyCredential, TokenCredential], **Any) -> None
super(SearchIndexingBufferedSender, self).__init__(
endpoint=endpoint,
index_name=index_name,
credential=credential,
**kwargs)
self._index_documents_batch = IndexDocumentsBatch()
self._client = SearchIndexClient(
endpoint=endpoint,
index_name=index_name,
sdk_moniker=SDK_MONIKER,
api_version=self._api_version, **kwargs
) # type: SearchIndexClient
if isinstance(credential, AzureKeyCredential):
self._aad = False
self._client = SearchIndexClient(
endpoint=endpoint,
index_name=index_name,
sdk_moniker=SDK_MONIKER,
api_version=self._api_version, **kwargs
) # type: SearchIndexClient
else:
self._aad = True
authentication_policy = get_authentication_policy(credential)
self._client = SearchIndexClient(
endpoint=endpoint,
index_name=index_name,
authentication_policy=authentication_policy,
sdk_moniker=SDK_MONIKER,
api_version=self._api_version, **kwargs
) # type: SearchIndexClient
self._reset_timer()

def _cleanup(self, flush=True):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(self, endpoint, index_name, credential, **kwargs):
self._endpoint = endpoint # type: str
self._index_name = index_name # type: str
self._index_key = None
self._credential = credential # type: AzureKeyCredential
self._credential = credential
self._on_new = kwargs.pop('on_new', None)
self._on_progress = kwargs.pop('on_progress', None)
self._on_error = kwargs.pop('on_error', None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,14 @@
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from azure.core.pipeline import policies

CREDENTIAL_SCOPES = ['https://search.azure.com/.default']

def is_retryable_status_code(status_code):
# type: (int) -> bool
return status_code in [422, 409, 503]

def get_authentication_policy(credential):
authentication_policy = policies.BearerTokenCredentialPolicy(credential, *CREDENTIAL_SCOPES)
return authentication_policy
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from typing import cast, List, TYPE_CHECKING
from typing import cast, List, TYPE_CHECKING, Union
import six

from azure.core.credentials import AzureKeyCredential
from azure.core.tracing.decorator_async import distributed_trace_async
from ._paging import AsyncSearchItemPaged, AsyncSearchPageIterator
from ._utils_async import get_async_authentication_policy
from .._generated.aio import SearchClient as SearchIndexClient
from .._generated.models import IndexingResult
from .._search_documents_error import RequestEntityTooLargeError
Expand All @@ -20,7 +22,7 @@
if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any
from azure.core.credentials import AzureKeyCredential
from azure.core.credentials_async import AsyncTokenCredential


class SearchClient(HeadersMixin):
Expand All @@ -31,7 +33,7 @@ class SearchClient(HeadersMixin):
:param index_name: The name of the index to connect to
:type index_name: str
:param credential: A credential to authorize search client requests
:type credential: ~azure.core.credentials.AzureKeyCredential
:type credential: ~azure.core.credentials.AzureKeyCredential or ~azure.core.credentials_async.AsyncTokenCredential
:keyword str api_version: The Search API version to use for requests.
.. admonition:: Example:
Expand All @@ -46,20 +48,34 @@ class SearchClient(HeadersMixin):

_ODATA_ACCEPT = "application/json;odata.metadata=none" # type: str

def __init__(self, endpoint, index_name, credential, **kwargs):
# type: (str, str, AzureKeyCredential, **Any) -> None

def __init__(self, endpoint: str,
index_name: str,
credential: Union[AzureKeyCredential, "AsyncTokenCredential"],
**kwargs
) -> None:
self._api_version = kwargs.pop("api_version", DEFAULT_VERSION)
self._index_documents_batch = IndexDocumentsBatch()
self._endpoint = endpoint # type: str
self._index_name = index_name # type: str
self._credential = credential # type: AzureKeyCredential
self._client = SearchIndexClient(
endpoint=endpoint,
index_name=index_name,
sdk_moniker=SDK_MONIKER,
api_version=self._api_version, **kwargs
) # type: SearchIndexClient
self._credential = credential
if isinstance(credential, AzureKeyCredential):
self._aad = False
self._client = SearchIndexClient(
endpoint=endpoint,
index_name=index_name,
sdk_moniker=SDK_MONIKER,
api_version=self._api_version, **kwargs
) # type: SearchIndexClient
else:
self._aad = True
authentication_policy = get_async_authentication_policy(credential)
self._client = SearchIndexClient(
endpoint=endpoint,
index_name=index_name,
authentication_policy=authentication_policy,
sdk_moniker=SDK_MONIKER,
api_version=self._api_version, **kwargs
) # type: SearchIndexClient


def __repr__(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from typing import cast, List, TYPE_CHECKING
from typing import cast, List, TYPE_CHECKING, Union
import time

from azure.core.credentials import AzureKeyCredential
from azure.core.tracing.decorator_async import distributed_trace_async
from azure.core.exceptions import ServiceResponseTimeoutError
from ._timer import Timer
from ._utils_async import get_async_authentication_policy
from .._utils import is_retryable_status_code
from .._search_indexing_buffered_sender_base import SearchIndexingBufferedSenderBase
from .._generated.aio import SearchClient as SearchIndexClient
Expand All @@ -21,7 +23,7 @@
if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any
from azure.core.credentials import AzureKeyCredential
from azure.core.credentials_async import AsyncTokenCredential


class SearchIndexingBufferedSender(SearchIndexingBufferedSenderBase, HeadersMixin):
Expand All @@ -32,7 +34,7 @@ class SearchIndexingBufferedSender(SearchIndexingBufferedSenderBase, HeadersMixi
:param index_name: The name of the index to connect to
:type index_name: str
:param credential: A credential to authorize search client requests
:type credential: ~azure.core.credentials.AzureKeyCredential
:type credential: ~azure.core.credentials.AzureKeyCredential or ~azure.core.credentials_async.AsyncTokenCredential
:keyword int auto_flush_interval: how many max seconds if between 2 flushes. This only takes effect
when auto_flush is on. Default to 60 seconds.
:keyword int initial_batch_action_count: The initial number of actions to group into a batch when
Expand All @@ -50,20 +52,35 @@ class SearchIndexingBufferedSender(SearchIndexingBufferedSenderBase, HeadersMixi
"""
# pylint: disable=too-many-instance-attributes

def __init__(self, endpoint, index_name, credential, **kwargs):
# type: (str, str, AzureKeyCredential, **Any) -> None
def __init__(self, endpoint: str,
index_name: str,
credential: Union[AzureKeyCredential, "AsyncTokenCredential"],
**kwargs
) -> None:
super(SearchIndexingBufferedSender, self).__init__(
endpoint=endpoint,
index_name=index_name,
credential=credential,
**kwargs)
self._index_documents_batch = IndexDocumentsBatch()
self._client = SearchIndexClient(
endpoint=endpoint,
index_name=index_name,
sdk_moniker=SDK_MONIKER,
api_version=self._api_version, **kwargs
) # type: SearchIndexClient
if isinstance(credential, AzureKeyCredential):
self._aad = False
self._client = SearchIndexClient(
endpoint=endpoint,
index_name=index_name,
sdk_moniker=SDK_MONIKER,
api_version=self._api_version, **kwargs
) # type: SearchIndexClient
else:
self._aad = True
authentication_policy = get_async_authentication_policy(credential)
self._client = SearchIndexClient(
endpoint=endpoint,
index_name=index_name,
authentication_policy=authentication_policy,
sdk_moniker=SDK_MONIKER,
api_version=self._api_version, **kwargs
) # type: SearchIndexClient
self._reset_timer()

async def _cleanup(self, flush=True):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from azure.core.pipeline import policies

from .._utils import CREDENTIAL_SCOPES

def get_async_authentication_policy(credential):
authentication_policy = policies.AsyncBearerTokenCredentialPolicy(credential, *CREDENTIAL_SCOPES)
return authentication_policy
Loading

0 comments on commit 4150ded

Please sign in to comment.