diff --git a/sdk/search/azure-search-documents/CHANGELOG.md b/sdk/search/azure-search-documents/CHANGELOG.md index 84a355394932..b086f97afa96 100644 --- a/sdk/search/azure-search-documents/CHANGELOG.md +++ b/sdk/search/azure-search-documents/CHANGELOG.md @@ -1,9 +1,10 @@ # Release History -## 11.3.0b1 (Unreleased) +## 11.3.0b1 (2021-07-07) ### Features Added +- Added AAD support - Added support for semantic search - Added normalizer support diff --git a/sdk/search/azure-search-documents/azure/search/documents/_headers_mixin.py b/sdk/search/azure-search-documents/azure/search/documents/_headers_mixin.py index c279412f31b9..1b09b013429f 100644 --- a/sdk/search/azure-search-documents/azure/search/documents/_headers_mixin.py +++ b/sdk/search/azure-search-documents/azure/search/documents/_headers_mixin.py @@ -18,6 +18,8 @@ def _headers(self): def _merge_client_headers(self, headers): # type(Optional[dict]) -> dict + if self._aad: + return headers headers = headers or {} combined = self._headers combined.update(headers) diff --git a/sdk/search/azure-search-documents/azure/search/documents/_search_client.py b/sdk/search/azure-search-documents/azure/search/documents/_search_client.py index 757ab8cfbf3b..8455626832c7 100644 --- a/sdk/search/azure-search-documents/azure/search/documents/_search_client.py +++ b/sdk/search/azure-search-documents/azure/search/documents/_search_client.py @@ -6,6 +6,7 @@ from typing import cast, List, TYPE_CHECKING import six +from azure.core.credentials import AzureKeyCredential from azure.core.tracing.decorator import distributed_trace from ._api_versions import DEFAULT_VERSION from ._generated import SearchClient as SearchIndexClient @@ -15,12 +16,13 @@ from ._paging import SearchItemPaged, SearchPageIterator from ._queries import AutocompleteQuery, SearchQuery, SuggestQuery from ._headers_mixin import HeadersMixin +from ._utils import get_authentication_policy from ._version import SDK_MONIKER if TYPE_CHECKING: # pylint:disable=unused-import,ungrouped-imports from typing import Any, Union - from azure.core.credentials import AzureKeyCredential + from azure.core.credentials import TokenCredential def odata(statement, **kwargs): @@ -59,7 +61,7 @@ class SearchClient(HeadersMixin): :param index_name: The name of the index to connect to :type index_name: str :param credential: A credential to authorize search client requests - :type credential: ~azure.core.credentials.AzureKeyCredential + :type credential: ~azure.core.credentials.AzureKeyCredential or ~azure.core.credentials.TokenCredential :keyword str api_version: The Search API version to use for requests. .. admonition:: Example: @@ -75,18 +77,30 @@ class SearchClient(HeadersMixin): _ODATA_ACCEPT = "application/json;odata.metadata=none" # type: str def __init__(self, endpoint, index_name, credential, **kwargs): - # type: (str, str, AzureKeyCredential, **Any) -> None + # type: (str, str, Union[AzureKeyCredential, TokenCredential], **Any) -> None self._api_version = kwargs.pop("api_version", DEFAULT_VERSION) self._endpoint = endpoint # type: str self._index_name = index_name # type: str - self._credential = credential # type: AzureKeyCredential - self._client = SearchIndexClient( - endpoint=endpoint, - index_name=index_name, - sdk_moniker=SDK_MONIKER, - api_version=self._api_version, **kwargs - ) # type: SearchIndexClient + self._credential = credential + if isinstance(credential, AzureKeyCredential): + self._aad = False + self._client = SearchIndexClient( + endpoint=endpoint, + index_name=index_name, + sdk_moniker=SDK_MONIKER, + api_version=self._api_version, **kwargs + ) # type: SearchIndexClient + else: + self._aad = True + authentication_policy = get_authentication_policy(credential) + self._client = SearchIndexClient( + endpoint=endpoint, + index_name=index_name, + authentication_policy=authentication_policy, + sdk_moniker=SDK_MONIKER, + api_version=self._api_version, **kwargs + ) # type: SearchIndexClient def __repr__(self): # type: () -> str diff --git a/sdk/search/azure-search-documents/azure/search/documents/_search_indexing_buffered_sender.py b/sdk/search/azure-search-documents/azure/search/documents/_search_indexing_buffered_sender.py index 15306b3c28a6..98cfbad49572 100644 --- a/sdk/search/azure-search-documents/azure/search/documents/_search_indexing_buffered_sender.py +++ b/sdk/search/azure-search-documents/azure/search/documents/_search_indexing_buffered_sender.py @@ -7,9 +7,10 @@ import time import threading +from azure.core.credentials import AzureKeyCredential from azure.core.tracing.decorator import distributed_trace from azure.core.exceptions import ServiceResponseTimeoutError -from ._utils import is_retryable_status_code +from ._utils import is_retryable_status_code, get_authentication_policy from .indexes import SearchIndexClient as SearchServiceClient from ._search_indexing_buffered_sender_base import SearchIndexingBufferedSenderBase from ._generated import SearchClient as SearchIndexClient @@ -22,7 +23,7 @@ if TYPE_CHECKING: # pylint:disable=unused-import,ungrouped-imports from typing import Any, Union - from azure.core.credentials import AzureKeyCredential + from azure.core.credentials import TokenCredential class SearchIndexingBufferedSender(SearchIndexingBufferedSenderBase, HeadersMixin): """A buffered sender for document indexing actions. @@ -32,7 +33,7 @@ class SearchIndexingBufferedSender(SearchIndexingBufferedSenderBase, HeadersMixi :param index_name: The name of the index to connect to :type index_name: str :param credential: A credential to authorize search client requests - :type credential: ~azure.core.credentials.AzureKeyCredential + :type credential: ~azure.core.credentials.AzureKeyCredential or ~azure.core.credentials.TokenCredential :keyword int auto_flush_interval: how many max seconds if between 2 flushes. This only takes effect when auto_flush is on. Default to 60 seconds. :keyword int initial_batch_action_count: The initial number of actions to group into a batch when @@ -52,19 +53,31 @@ class SearchIndexingBufferedSender(SearchIndexingBufferedSenderBase, HeadersMixi # pylint: disable=too-many-instance-attributes def __init__(self, endpoint, index_name, credential, **kwargs): - # type: (str, str, AzureKeyCredential, **Any) -> None + # type: (str, str, Union[AzureKeyCredential, TokenCredential], **Any) -> None super(SearchIndexingBufferedSender, self).__init__( endpoint=endpoint, index_name=index_name, credential=credential, **kwargs) self._index_documents_batch = IndexDocumentsBatch() - self._client = SearchIndexClient( - endpoint=endpoint, - index_name=index_name, - sdk_moniker=SDK_MONIKER, - api_version=self._api_version, **kwargs - ) # type: SearchIndexClient + if isinstance(credential, AzureKeyCredential): + self._aad = False + self._client = SearchIndexClient( + endpoint=endpoint, + index_name=index_name, + sdk_moniker=SDK_MONIKER, + api_version=self._api_version, **kwargs + ) # type: SearchIndexClient + else: + self._aad = True + authentication_policy = get_authentication_policy(credential) + self._client = SearchIndexClient( + endpoint=endpoint, + index_name=index_name, + authentication_policy=authentication_policy, + sdk_moniker=SDK_MONIKER, + api_version=self._api_version, **kwargs + ) # type: SearchIndexClient self._reset_timer() def _cleanup(self, flush=True): diff --git a/sdk/search/azure-search-documents/azure/search/documents/_search_indexing_buffered_sender_base.py b/sdk/search/azure-search-documents/azure/search/documents/_search_indexing_buffered_sender_base.py index 162997827474..b94d65c18578 100644 --- a/sdk/search/azure-search-documents/azure/search/documents/_search_indexing_buffered_sender_base.py +++ b/sdk/search/azure-search-documents/azure/search/documents/_search_indexing_buffered_sender_base.py @@ -34,7 +34,7 @@ def __init__(self, endpoint, index_name, credential, **kwargs): self._endpoint = endpoint # type: str self._index_name = index_name # type: str self._index_key = None - self._credential = credential # type: AzureKeyCredential + self._credential = credential self._on_new = kwargs.pop('on_new', None) self._on_progress = kwargs.pop('on_progress', None) self._on_error = kwargs.pop('on_error', None) diff --git a/sdk/search/azure-search-documents/azure/search/documents/_utils.py b/sdk/search/azure-search-documents/azure/search/documents/_utils.py index 0e4025f47679..0be8970be755 100644 --- a/sdk/search/azure-search-documents/azure/search/documents/_utils.py +++ b/sdk/search/azure-search-documents/azure/search/documents/_utils.py @@ -3,6 +3,14 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- +from azure.core.pipeline import policies + +CREDENTIAL_SCOPES = ['https://search.azure.com/.default'] + def is_retryable_status_code(status_code): # type: (int) -> bool return status_code in [422, 409, 503] + +def get_authentication_policy(credential): + authentication_policy = policies.BearerTokenCredentialPolicy(credential, *CREDENTIAL_SCOPES) + return authentication_policy diff --git a/sdk/search/azure-search-documents/azure/search/documents/aio/_search_client_async.py b/sdk/search/azure-search-documents/azure/search/documents/aio/_search_client_async.py index d529b0b60f08..620c82b17fb3 100644 --- a/sdk/search/azure-search-documents/azure/search/documents/aio/_search_client_async.py +++ b/sdk/search/azure-search-documents/azure/search/documents/aio/_search_client_async.py @@ -3,11 +3,13 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- -from typing import cast, List, TYPE_CHECKING +from typing import cast, List, TYPE_CHECKING, Union import six +from azure.core.credentials import AzureKeyCredential from azure.core.tracing.decorator_async import distributed_trace_async from ._paging import AsyncSearchItemPaged, AsyncSearchPageIterator +from ._utils_async import get_async_authentication_policy from .._generated.aio import SearchClient as SearchIndexClient from .._generated.models import IndexingResult from .._search_documents_error import RequestEntityTooLargeError @@ -20,7 +22,7 @@ if TYPE_CHECKING: # pylint:disable=unused-import,ungrouped-imports from typing import Any - from azure.core.credentials import AzureKeyCredential + from azure.core.credentials_async import AsyncTokenCredential class SearchClient(HeadersMixin): @@ -31,7 +33,7 @@ class SearchClient(HeadersMixin): :param index_name: The name of the index to connect to :type index_name: str :param credential: A credential to authorize search client requests - :type credential: ~azure.core.credentials.AzureKeyCredential + :type credential: ~azure.core.credentials.AzureKeyCredential or ~azure.core.credentials_async.AsyncTokenCredential :keyword str api_version: The Search API version to use for requests. .. admonition:: Example: @@ -46,20 +48,34 @@ class SearchClient(HeadersMixin): _ODATA_ACCEPT = "application/json;odata.metadata=none" # type: str - def __init__(self, endpoint, index_name, credential, **kwargs): - # type: (str, str, AzureKeyCredential, **Any) -> None - + def __init__(self, endpoint: str, + index_name: str, + credential: Union[AzureKeyCredential, "AsyncTokenCredential"], + **kwargs + ) -> None: self._api_version = kwargs.pop("api_version", DEFAULT_VERSION) self._index_documents_batch = IndexDocumentsBatch() self._endpoint = endpoint # type: str self._index_name = index_name # type: str - self._credential = credential # type: AzureKeyCredential - self._client = SearchIndexClient( - endpoint=endpoint, - index_name=index_name, - sdk_moniker=SDK_MONIKER, - api_version=self._api_version, **kwargs - ) # type: SearchIndexClient + self._credential = credential + if isinstance(credential, AzureKeyCredential): + self._aad = False + self._client = SearchIndexClient( + endpoint=endpoint, + index_name=index_name, + sdk_moniker=SDK_MONIKER, + api_version=self._api_version, **kwargs + ) # type: SearchIndexClient + else: + self._aad = True + authentication_policy = get_async_authentication_policy(credential) + self._client = SearchIndexClient( + endpoint=endpoint, + index_name=index_name, + authentication_policy=authentication_policy, + sdk_moniker=SDK_MONIKER, + api_version=self._api_version, **kwargs + ) # type: SearchIndexClient def __repr__(self): diff --git a/sdk/search/azure-search-documents/azure/search/documents/aio/_search_indexing_buffered_sender_async.py b/sdk/search/azure-search-documents/azure/search/documents/aio/_search_indexing_buffered_sender_async.py index b16017868cda..25217ebcf50f 100644 --- a/sdk/search/azure-search-documents/azure/search/documents/aio/_search_indexing_buffered_sender_async.py +++ b/sdk/search/azure-search-documents/azure/search/documents/aio/_search_indexing_buffered_sender_async.py @@ -3,12 +3,14 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- -from typing import cast, List, TYPE_CHECKING +from typing import cast, List, TYPE_CHECKING, Union import time +from azure.core.credentials import AzureKeyCredential from azure.core.tracing.decorator_async import distributed_trace_async from azure.core.exceptions import ServiceResponseTimeoutError from ._timer import Timer +from ._utils_async import get_async_authentication_policy from .._utils import is_retryable_status_code from .._search_indexing_buffered_sender_base import SearchIndexingBufferedSenderBase from .._generated.aio import SearchClient as SearchIndexClient @@ -21,7 +23,7 @@ if TYPE_CHECKING: # pylint:disable=unused-import,ungrouped-imports from typing import Any - from azure.core.credentials import AzureKeyCredential + from azure.core.credentials_async import AsyncTokenCredential class SearchIndexingBufferedSender(SearchIndexingBufferedSenderBase, HeadersMixin): @@ -32,7 +34,7 @@ class SearchIndexingBufferedSender(SearchIndexingBufferedSenderBase, HeadersMixi :param index_name: The name of the index to connect to :type index_name: str :param credential: A credential to authorize search client requests - :type credential: ~azure.core.credentials.AzureKeyCredential + :type credential: ~azure.core.credentials.AzureKeyCredential or ~azure.core.credentials_async.AsyncTokenCredential :keyword int auto_flush_interval: how many max seconds if between 2 flushes. This only takes effect when auto_flush is on. Default to 60 seconds. :keyword int initial_batch_action_count: The initial number of actions to group into a batch when @@ -50,20 +52,35 @@ class SearchIndexingBufferedSender(SearchIndexingBufferedSenderBase, HeadersMixi """ # pylint: disable=too-many-instance-attributes - def __init__(self, endpoint, index_name, credential, **kwargs): - # type: (str, str, AzureKeyCredential, **Any) -> None + def __init__(self, endpoint: str, + index_name: str, + credential: Union[AzureKeyCredential, "AsyncTokenCredential"], + **kwargs + ) -> None: super(SearchIndexingBufferedSender, self).__init__( endpoint=endpoint, index_name=index_name, credential=credential, **kwargs) self._index_documents_batch = IndexDocumentsBatch() - self._client = SearchIndexClient( - endpoint=endpoint, - index_name=index_name, - sdk_moniker=SDK_MONIKER, - api_version=self._api_version, **kwargs - ) # type: SearchIndexClient + if isinstance(credential, AzureKeyCredential): + self._aad = False + self._client = SearchIndexClient( + endpoint=endpoint, + index_name=index_name, + sdk_moniker=SDK_MONIKER, + api_version=self._api_version, **kwargs + ) # type: SearchIndexClient + else: + self._aad = True + authentication_policy = get_async_authentication_policy(credential) + self._client = SearchIndexClient( + endpoint=endpoint, + index_name=index_name, + authentication_policy=authentication_policy, + sdk_moniker=SDK_MONIKER, + api_version=self._api_version, **kwargs + ) # type: SearchIndexClient self._reset_timer() async def _cleanup(self, flush=True): diff --git a/sdk/search/azure-search-documents/azure/search/documents/aio/_utils_async.py b/sdk/search/azure-search-documents/azure/search/documents/aio/_utils_async.py new file mode 100644 index 000000000000..c5b5781f1f23 --- /dev/null +++ b/sdk/search/azure-search-documents/azure/search/documents/aio/_utils_async.py @@ -0,0 +1,12 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +from azure.core.pipeline import policies + +from .._utils import CREDENTIAL_SCOPES + +def get_async_authentication_policy(credential): + authentication_policy = policies.AsyncBearerTokenCredentialPolicy(credential, *CREDENTIAL_SCOPES) + return authentication_policy diff --git a/sdk/search/azure-search-documents/azure/search/documents/indexes/_search_index_client.py b/sdk/search/azure-search-documents/azure/search/documents/indexes/_search_index_client.py index c9c2e6ded8bf..d9f87a36a8c8 100644 --- a/sdk/search/azure-search-documents/azure/search/documents/indexes/_search_index_client.py +++ b/sdk/search/azure-search-documents/azure/search/documents/indexes/_search_index_client.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING from azure.core import MatchConditions +from azure.core.credentials import AzureKeyCredential from azure.core.tracing.decorator import distributed_trace from azure.core.paging import ItemPaged @@ -16,6 +17,7 @@ normalize_endpoint, ) from .._headers_mixin import HeadersMixin +from .._utils import get_authentication_policy from .._version import SDK_MONIKER from .._search_client import SearchClient from .models import SearchIndex, SynonymMap @@ -24,7 +26,7 @@ # pylint:disable=unused-import,ungrouped-imports from .models._models import AnalyzeTextOptions from typing import Any, Dict, List, Sequence, Union, Optional - from azure.core.credentials import AzureKeyCredential + from azure.core.credentials import TokenCredential class SearchIndexClient(HeadersMixin): @@ -33,21 +35,36 @@ class SearchIndexClient(HeadersMixin): :param endpoint: The URL endpoint of an Azure search service :type endpoint: str :param credential: A credential to authorize search client requests - :type credential: ~azure.core.credentials.AzureKeyCredential + :type credential: ~azure.core.credentials.AzureKeyCredential or ~azure.core.credentials.TokenCredential :keyword str api_version: The Search API version to use for requests. """ _ODATA_ACCEPT = "application/json;odata.metadata=minimal" # type: str def __init__(self, endpoint, credential, **kwargs): - # type: (str, AzureKeyCredential, **Any) -> None + # type: (str, Union[AzureKeyCredential, TokenCredential], **Any) -> None self._api_version = kwargs.pop("api_version", DEFAULT_VERSION) self._endpoint = normalize_endpoint(endpoint) # type: str - self._credential = credential # type: AzureKeyCredential - self._client = _SearchServiceClient( - endpoint=endpoint, sdk_moniker=SDK_MONIKER, api_version=self._api_version, **kwargs - ) # type: _SearchServiceClient + self._credential = credential + if isinstance(credential, AzureKeyCredential): + self._aad = False + self._client = _SearchServiceClient( + endpoint=endpoint, + sdk_moniker=SDK_MONIKER, + api_version=self._api_version, + **kwargs + ) # type: _SearchServiceClient + else: + self._aad = True + authentication_policy = get_authentication_policy(credential) + self._client = _SearchServiceClient( + endpoint=endpoint, + authentication_policy=authentication_policy, + sdk_moniker=SDK_MONIKER, + api_version=self._api_version, + **kwargs + ) # type: _SearchServiceClient def __enter__(self): # type: () -> SearchIndexClient diff --git a/sdk/search/azure-search-documents/azure/search/documents/indexes/_search_indexer_client.py b/sdk/search/azure-search-documents/azure/search/documents/indexes/_search_indexer_client.py index fd49ac624c3b..0a3e8a754951 100644 --- a/sdk/search/azure-search-documents/azure/search/documents/indexes/_search_indexer_client.py +++ b/sdk/search/azure-search-documents/azure/search/documents/indexes/_search_indexer_client.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING from azure.core import MatchConditions +from azure.core.credentials import AzureKeyCredential from azure.core.tracing.decorator import distributed_trace from ._generated import SearchClient as _SearchServiceClient @@ -17,13 +18,14 @@ from .models import SearchIndexerDataSourceConnection from .._api_versions import DEFAULT_VERSION from .._headers_mixin import HeadersMixin +from .._utils import get_authentication_policy from .._version import SDK_MONIKER if TYPE_CHECKING: # pylint:disable=unused-import,ungrouped-imports from ._generated.models import SearchIndexer, SearchIndexerStatus - from typing import Any, Optional, Sequence - from azure.core.credentials import AzureKeyCredential + from typing import Any, Optional, Sequence, Union + from azure.core.credentials import TokenCredential class SearchIndexerClient(HeadersMixin): # pylint: disable=R0904 @@ -32,7 +34,7 @@ class SearchIndexerClient(HeadersMixin): # pylint: disable=R0904 :param endpoint: The URL endpoint of an Azure search service :type endpoint: str :param credential: A credential to authorize search client requests - :type credential: ~azure.core.credentials.AzureKeyCredential + :type credential: ~azure.core.credentials.AzureKeyCredential or ~azure.core.credentials.TokenCredential :keyword str api_version: The Search API version to use for requests. """ @@ -40,14 +42,29 @@ class SearchIndexerClient(HeadersMixin): # pylint: disable=R0904 _ODATA_ACCEPT = "application/json;odata.metadata=minimal" # type: str def __init__(self, endpoint, credential, **kwargs): - # type: (str, AzureKeyCredential, **Any) -> None + # type: (str, Union[AzureKeyCredential, TokenCredential], **Any) -> None self._api_version = kwargs.pop("api_version", DEFAULT_VERSION) self._endpoint = normalize_endpoint(endpoint) # type: str - self._credential = credential # type: AzureKeyCredential - self._client = _SearchServiceClient( - endpoint=endpoint, sdk_moniker=SDK_MONIKER, api_version=self._api_version, **kwargs - ) # type: _SearchServiceClient + self._credential = credential + if isinstance(credential, AzureKeyCredential): + self._aad = False + self._client = _SearchServiceClient( + endpoint=endpoint, + sdk_moniker=SDK_MONIKER, + api_version=self._api_version, + **kwargs + ) # type: _SearchServiceClient + else: + self._aad = True + authentication_policy = get_authentication_policy(credential) + self._client = _SearchServiceClient( + endpoint=endpoint, + authentication_policy=authentication_policy, + sdk_moniker=SDK_MONIKER, + api_version=self._api_version, + **kwargs + ) # type: _SearchServiceClient def __enter__(self): # type: () -> SearchIndexerClient diff --git a/sdk/search/azure-search-documents/azure/search/documents/indexes/aio/_search_index_client.py b/sdk/search/azure-search-documents/azure/search/documents/indexes/aio/_search_index_client.py index 06634dfd2b57..d68de85b3107 100644 --- a/sdk/search/azure-search-documents/azure/search/documents/indexes/aio/_search_index_client.py +++ b/sdk/search/azure-search-documents/azure/search/documents/indexes/aio/_search_index_client.py @@ -3,14 +3,16 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Union from azure.core import MatchConditions +from azure.core.credentials import AzureKeyCredential from azure.core.tracing.decorator import distributed_trace from azure.core.tracing.decorator_async import distributed_trace_async from azure.core.async_paging import AsyncItemPaged from .._generated.aio import SearchClient as _SearchServiceClient from ...aio._search_client_async import SearchClient +from ...aio._utils_async import get_async_authentication_policy from .._utils import ( get_access_conditions, normalize_endpoint, @@ -27,8 +29,8 @@ # pylint:disable=unused-import,ungrouped-imports from .._generated.models import AnalyzeResult from ..models._models import AnalyzeTextOptions - from typing import Any, Dict, List, Union - from azure.core.credentials import AzureKeyCredential + from typing import Any, Dict, List + from azure.core.credentials_async import AsyncTokenCredential class SearchIndexClient(HeadersMixin): @@ -37,22 +39,38 @@ class SearchIndexClient(HeadersMixin): :param endpoint: The URL endpoint of an Azure search service :type endpoint: str :param credential: A credential to authorize search client requests - :type credential: ~azure.core.credentials.AzureKeyCredential + :type credential: ~azure.core.credentials.AzureKeyCredential or ~azure.core.credentials_async.AsyncTokenCredential :keyword str api_version: The Search API version to use for requests. """ _ODATA_ACCEPT = "application/json;odata.metadata=minimal" # type: str - def __init__(self, endpoint, credential, **kwargs): - # type: (str, AzureKeyCredential, **Any) -> None - + def __init__(self, endpoint: str, + credential: Union[AzureKeyCredential, "AsyncTokenCredential"], + **kwargs + ) -> None: self._api_version = kwargs.pop("api_version", DEFAULT_VERSION) self._endpoint = normalize_endpoint(endpoint) # type: str - self._credential = credential # type: AzureKeyCredential - self._client = _SearchServiceClient( - endpoint=endpoint, sdk_moniker=SDK_MONIKER, api_version=self._api_version, **kwargs - ) # type: _SearchServiceClient + self._credential = credential + if isinstance(credential, AzureKeyCredential): + self._aad = False + self._client = _SearchServiceClient( + endpoint=endpoint, + sdk_moniker=SDK_MONIKER, + api_version=self._api_version, + **kwargs + ) # type: _SearchServiceClient + else: + self._aad = True + authentication_policy = get_async_authentication_policy(credential) + self._client = _SearchServiceClient( + endpoint=endpoint, + authentication_policy=authentication_policy, + sdk_moniker=SDK_MONIKER, + api_version=self._api_version, + **kwargs + ) # type: _SearchServiceClient async def __aenter__(self): # type: () -> SearchIndexesClient diff --git a/sdk/search/azure-search-documents/azure/search/documents/indexes/aio/_search_indexer_client.py b/sdk/search/azure-search-documents/azure/search/documents/indexes/aio/_search_indexer_client.py index 1df655059e38..3c5382179130 100644 --- a/sdk/search/azure-search-documents/azure/search/documents/indexes/aio/_search_indexer_client.py +++ b/sdk/search/azure-search-documents/azure/search/documents/indexes/aio/_search_indexer_client.py @@ -3,9 +3,10 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Union from azure.core import MatchConditions +from azure.core.credentials import AzureKeyCredential from azure.core.tracing.decorator_async import distributed_trace_async from .._generated.aio import SearchClient as _SearchServiceClient @@ -20,12 +21,13 @@ from ..._api_versions import DEFAULT_VERSION from ..._headers_mixin import HeadersMixin from ..._version import SDK_MONIKER +from ...aio._utils_async import get_async_authentication_policy if TYPE_CHECKING: # pylint:disable=unused-import,ungrouped-imports from .._generated.models import SearchIndexer, SearchIndexerStatus from typing import Any, Optional, Sequence - from azure.core.credentials import AzureKeyCredential + from azure.core.credentials_async import AsyncTokenCredential class SearchIndexerClient(HeadersMixin): # pylint: disable=R0904 @@ -34,22 +36,38 @@ class SearchIndexerClient(HeadersMixin): # pylint: disable=R0904 :param endpoint: The URL endpoint of an Azure search service :type endpoint: str :param credential: A credential to authorize search client requests - :type credential: ~azure.core.credentials.AzureKeyCredential + :type credential: ~azure.core.credentials.AzureKeyCredential or ~azure.core.credentials_async.AsyncTokenCredential :keyword str api_version: The Search API version to use for requests. """ _ODATA_ACCEPT = "application/json;odata.metadata=minimal" # type: str - def __init__(self, endpoint, credential, **kwargs): - # type: (str, AzureKeyCredential, **Any) -> None - + def __init__(self, endpoint: str, + credential: Union[AzureKeyCredential, "AsyncTokenCredential"], + **kwargs + ) -> None: self._api_version = kwargs.pop("api_version", DEFAULT_VERSION) self._endpoint = normalize_endpoint(endpoint) # type: str - self._credential = credential # type: AzureKeyCredential - self._client = _SearchServiceClient( - endpoint=endpoint, sdk_moniker=SDK_MONIKER, api_version=self._api_version, **kwargs - ) # type: _SearchServiceClient + self._credential = credential + if isinstance(credential, AzureKeyCredential): + self._aad = False + self._client = _SearchServiceClient( + endpoint=endpoint, + sdk_moniker=SDK_MONIKER, + api_version=self._api_version, + **kwargs + ) # type: _SearchServiceClient + else: + self._aad = True + authentication_policy = get_async_authentication_policy(credential) + self._client = _SearchServiceClient( + endpoint=endpoint, + authentication_policy=authentication_policy, + sdk_moniker=SDK_MONIKER, + api_version=self._api_version, + **kwargs + ) # type: _SearchServiceClient async def __aenter__(self): # type: () -> SearchIndexersClient