diff --git a/libs/community/extended_testing_deps.txt b/libs/community/extended_testing_deps.txt index 515bcfc538265..8812425498c12 100644 --- a/libs/community/extended_testing_deps.txt +++ b/libs/community/extended_testing_deps.txt @@ -16,7 +16,6 @@ cloudpickle>=2.0.0 cohere>=4,<6 databricks-vectorsearch>=0.21,<0.22 datasets>=2.15.0,<3 -dedoc>=2.2.6,<3 dgml-utils>=0.3.0,<0.4 elasticsearch>=8.12.0,<9 esprima>=4.0.1,<5 diff --git a/libs/community/langchain_community/vectorstores/azuresearch.py b/libs/community/langchain_community/vectorstores/azuresearch.py index 9154f4493396d..45d5566fb56e0 100644 --- a/libs/community/langchain_community/vectorstores/azuresearch.py +++ b/libs/community/langchain_community/vectorstores/azuresearch.py @@ -5,6 +5,7 @@ import itertools import json import logging +import time import uuid from typing import ( TYPE_CHECKING, @@ -79,8 +80,9 @@ def _get_search_client( endpoint: str, - key: str, index_name: str, + key: Optional[str] = None, + azure_ad_access_token: Optional[str] = None, semantic_configuration_name: Optional[str] = None, fields: Optional[List[SearchField]] = None, vector_search: Optional[VectorSearch] = None, @@ -95,7 +97,7 @@ def _get_search_client( async_: bool = False, additional_search_client_options: Optional[Dict[str, Any]] = None, ) -> Union[SearchClient, AsyncSearchClient]: - from azure.core.credentials import AzureKeyCredential + from azure.core.credentials import AccessToken, AzureKeyCredential, TokenCredential from azure.core.exceptions import ResourceNotFoundError from azure.identity import DefaultAzureCredential, InteractiveBrowserCredential from azure.search.documents import SearchClient @@ -119,13 +121,23 @@ def _get_search_client( additional_search_client_options = additional_search_client_options or {} default_fields = default_fields or [] - if key is None: - credential = DefaultAzureCredential() - elif key.upper() == "INTERACTIVE": - credential = InteractiveBrowserCredential() - credential.get_token("https://search.azure.com/.default") + credential: Union[AzureKeyCredential, TokenCredential, InteractiveBrowserCredential] + + # Determine the appropriate credential to use + if key is not None: + if key.upper() == "INTERACTIVE": + credential = InteractiveBrowserCredential() + credential.get_token("https://search.azure.com/.default") + else: + credential = AzureKeyCredential(key) + elif azure_ad_access_token is not None: + credential = TokenCredential( + lambda *scopes, **kwargs: AccessToken( + azure_ad_access_token, int(time.time()) + 3600 + ) + ) else: - credential = AzureKeyCredential(key) + credential = DefaultAzureCredential() index_client: SearchIndexClient = SearchIndexClient( endpoint=endpoint, credential=credential, user_agent=user_agent ) @@ -253,6 +265,7 @@ def __init__( self, azure_search_endpoint: str, azure_search_key: str, + azure_ad_access_token: Optional[str], index_name: str, embedding_function: Union[Callable, Embeddings], search_type: str = "hybrid", @@ -321,8 +334,9 @@ def __init__( user_agent += " " + kwargs["user_agent"] self.client = _get_search_client( azure_search_endpoint, - azure_search_key, index_name, + azure_search_key, + azure_ad_access_token, semantic_configuration_name=semantic_configuration_name, fields=fields, vector_search=vector_search, @@ -336,8 +350,9 @@ def __init__( ) self.async_client = _get_search_client( azure_search_endpoint, - azure_search_key, index_name, + azure_search_key, + azure_ad_access_token, semantic_configuration_name=semantic_configuration_name, fields=fields, vector_search=vector_search, @@ -1387,6 +1402,7 @@ def from_texts( metadatas: Optional[List[dict]] = None, azure_search_endpoint: str = "", azure_search_key: str = "", + azure_ad_access_token: Optional[str] = None, index_name: str = "langchain-index", fields: Optional[List[SearchField]] = None, **kwargs: Any, @@ -1395,6 +1411,7 @@ def from_texts( azure_search = cls( azure_search_endpoint, azure_search_key, + azure_ad_access_token, index_name, embedding, fields=fields, @@ -1411,6 +1428,7 @@ async def afrom_texts( metadatas: Optional[List[dict]] = None, azure_search_endpoint: str = "", azure_search_key: str = "", + azure_ad_access_token: Optional[str] = None, index_name: str = "langchain-index", fields: Optional[List[SearchField]] = None, **kwargs: Any, @@ -1419,6 +1437,7 @@ async def afrom_texts( azure_search = cls( azure_search_endpoint, azure_search_key, + azure_ad_access_token, index_name, embedding, fields=fields, diff --git a/libs/community/tests/integration_tests/vectorstores/test_azuresearch.py b/libs/community/tests/integration_tests/vectorstores/test_azuresearch.py index 69bf73fb86b7a..051fe57b2310c 100644 --- a/libs/community/tests/integration_tests/vectorstores/test_azuresearch.py +++ b/libs/community/tests/integration_tests/vectorstores/test_azuresearch.py @@ -13,6 +13,7 @@ # Vector store settings vector_store_address: str = os.getenv("AZURE_SEARCH_ENDPOINT", "") vector_store_password: str = os.getenv("AZURE_SEARCH_ADMIN_KEY", "") +access_token: str = os.getenv("AZURE_SEARCH_ACCESS_TOKEN", "") index_name: str = "embeddings-vector-store-test" @@ -25,6 +26,7 @@ def similarity_search_test() -> None: vector_store: AzureSearch = AzureSearch( azure_search_endpoint=vector_store_address, azure_search_key=vector_store_password, + azure_ad_access_token=access_token, index_name=index_name, embedding_function=embeddings.embed_query, ) @@ -68,6 +70,7 @@ def test_semantic_hybrid_search() -> None: vector_store: AzureSearch = AzureSearch( azure_search_endpoint=vector_store_address, azure_search_key=vector_store_password, + azure_ad_access_token=access_token, index_name=index_name, embedding_function=embeddings.embed_query, semantic_configuration_name="default", diff --git a/libs/community/tests/unit_tests/vectorstores/test_azure_search.py b/libs/community/tests/unit_tests/vectorstores/test_azure_search.py index a06fbfd151b0b..0f54e08d501ec 100644 --- a/libs/community/tests/unit_tests/vectorstores/test_azure_search.py +++ b/libs/community/tests/unit_tests/vectorstores/test_azure_search.py @@ -32,6 +32,7 @@ def embed_query(self, text: str) -> List[float]: DEFAULT_INDEX_NAME = "langchain-index" DEFAULT_ENDPOINT = "https://my-search-service.search.windows.net" DEFAULT_KEY = "mykey" +DEFAULT_ACCESS_TOKEN = "myaccesstoken1" DEFAULT_EMBEDDING_MODEL = FakeEmbeddingsWithDimension() @@ -127,6 +128,7 @@ def create_vector_store( return AzureSearch( azure_search_endpoint=DEFAULT_ENDPOINT, azure_search_key=DEFAULT_KEY, + azure_ad_access_token=DEFAULT_ACCESS_TOKEN, index_name=DEFAULT_INDEX_NAME, embedding_function=DEFAULT_EMBEDDING_MODEL, additional_search_client_options=additional_search_client_options,