Skip to content

Commit

Permalink
community: Azure Search Vector Store is missing Access Token Authentication (langchain-ai#24330)
Browse files Browse the repository at this point in the history

Added Azure AD access token authentication for Azure Search as an alternative to API key authentication.
Fixes Issue: langchain-ai#24263
Dependencies: None
Twitter: @levalencia

@baskaryan

Could you please review? This is my first time creating a PR that fixes code.

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
  • Loading branch information
levalencia and efriis committed Aug 26, 2024
1 parent 49b0bc7 commit 99f9a66
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 11 deletions.
1 change: 0 additions & 1 deletion libs/community/extended_testing_deps.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ cloudpickle>=2.0.0
cohere>=4,<6
databricks-vectorsearch>=0.21,<0.22
datasets>=2.15.0,<3
dedoc>=2.2.6,<3
dgml-utils>=0.3.0,<0.4
elasticsearch>=8.12.0,<9
esprima>=4.0.1,<5
Expand Down
39 changes: 29 additions & 10 deletions libs/community/langchain_community/vectorstores/azuresearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import itertools
import json
import logging
import time
import uuid
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -79,8 +80,9 @@

def _get_search_client(
endpoint: str,
key: str,
index_name: str,
key: Optional[str] = None,
azure_ad_access_token: Optional[str] = None,
semantic_configuration_name: Optional[str] = None,
fields: Optional[List[SearchField]] = None,
vector_search: Optional[VectorSearch] = None,
Expand All @@ -95,7 +97,7 @@ def _get_search_client(
async_: bool = False,
additional_search_client_options: Optional[Dict[str, Any]] = None,
) -> Union[SearchClient, AsyncSearchClient]:
from azure.core.credentials import AzureKeyCredential
from azure.core.credentials import AccessToken, AzureKeyCredential, TokenCredential
from azure.core.exceptions import ResourceNotFoundError
from azure.identity import DefaultAzureCredential, InteractiveBrowserCredential
from azure.search.documents import SearchClient
Expand All @@ -119,13 +121,23 @@ def _get_search_client(

additional_search_client_options = additional_search_client_options or {}
default_fields = default_fields or []
if key is None:
credential = DefaultAzureCredential()
elif key.upper() == "INTERACTIVE":
credential = InteractiveBrowserCredential()
credential.get_token("https://search.azure.com/.default")
credential: Union[AzureKeyCredential, TokenCredential, InteractiveBrowserCredential]

# Determine the appropriate credential to use
if key is not None:
if key.upper() == "INTERACTIVE":
credential = InteractiveBrowserCredential()
credential.get_token("https://search.azure.com/.default")
else:
credential = AzureKeyCredential(key)
elif azure_ad_access_token is not None:
credential = TokenCredential(
lambda *scopes, **kwargs: AccessToken(
azure_ad_access_token, int(time.time()) + 3600
)
)
else:
credential = AzureKeyCredential(key)
credential = DefaultAzureCredential()
index_client: SearchIndexClient = SearchIndexClient(
endpoint=endpoint, credential=credential, user_agent=user_agent
)
Expand Down Expand Up @@ -253,6 +265,7 @@ def __init__(
self,
azure_search_endpoint: str,
azure_search_key: str,
azure_ad_access_token: Optional[str],
index_name: str,
embedding_function: Union[Callable, Embeddings],
search_type: str = "hybrid",
Expand Down Expand Up @@ -321,8 +334,9 @@ def __init__(
user_agent += " " + kwargs["user_agent"]
self.client = _get_search_client(
azure_search_endpoint,
azure_search_key,
index_name,
azure_search_key,
azure_ad_access_token,
semantic_configuration_name=semantic_configuration_name,
fields=fields,
vector_search=vector_search,
Expand All @@ -336,8 +350,9 @@ def __init__(
)
self.async_client = _get_search_client(
azure_search_endpoint,
azure_search_key,
index_name,
azure_search_key,
azure_ad_access_token,
semantic_configuration_name=semantic_configuration_name,
fields=fields,
vector_search=vector_search,
Expand Down Expand Up @@ -1387,6 +1402,7 @@ def from_texts(
metadatas: Optional[List[dict]] = None,
azure_search_endpoint: str = "",
azure_search_key: str = "",
azure_ad_access_token: Optional[str] = None,
index_name: str = "langchain-index",
fields: Optional[List[SearchField]] = None,
**kwargs: Any,
Expand All @@ -1395,6 +1411,7 @@ def from_texts(
azure_search = cls(
azure_search_endpoint,
azure_search_key,
azure_ad_access_token,
index_name,
embedding,
fields=fields,
Expand All @@ -1411,6 +1428,7 @@ async def afrom_texts(
metadatas: Optional[List[dict]] = None,
azure_search_endpoint: str = "",
azure_search_key: str = "",
azure_ad_access_token: Optional[str] = None,
index_name: str = "langchain-index",
fields: Optional[List[SearchField]] = None,
**kwargs: Any,
Expand All @@ -1419,6 +1437,7 @@ async def afrom_texts(
azure_search = cls(
azure_search_endpoint,
azure_search_key,
azure_ad_access_token,
index_name,
embedding,
fields=fields,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# Vector store settings
vector_store_address: str = os.getenv("AZURE_SEARCH_ENDPOINT", "")
vector_store_password: str = os.getenv("AZURE_SEARCH_ADMIN_KEY", "")
access_token: str = os.getenv("AZURE_SEARCH_ACCESS_TOKEN", "")
index_name: str = "embeddings-vector-store-test"


Expand All @@ -25,6 +26,7 @@ def similarity_search_test() -> None:
vector_store: AzureSearch = AzureSearch(
azure_search_endpoint=vector_store_address,
azure_search_key=vector_store_password,
azure_ad_access_token=access_token,
index_name=index_name,
embedding_function=embeddings.embed_query,
)
Expand Down Expand Up @@ -68,6 +70,7 @@ def test_semantic_hybrid_search() -> None:
vector_store: AzureSearch = AzureSearch(
azure_search_endpoint=vector_store_address,
azure_search_key=vector_store_password,
azure_ad_access_token=access_token,
index_name=index_name,
embedding_function=embeddings.embed_query,
semantic_configuration_name="default",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def embed_query(self, text: str) -> List[float]:
DEFAULT_INDEX_NAME = "langchain-index"
DEFAULT_ENDPOINT = "https://my-search-service.search.windows.net"
DEFAULT_KEY = "mykey"
DEFAULT_ACCESS_TOKEN = "myaccesstoken1"
DEFAULT_EMBEDDING_MODEL = FakeEmbeddingsWithDimension()


Expand Down Expand Up @@ -127,6 +128,7 @@ def create_vector_store(
return AzureSearch(
azure_search_endpoint=DEFAULT_ENDPOINT,
azure_search_key=DEFAULT_KEY,
azure_ad_access_token=DEFAULT_ACCESS_TOKEN,
index_name=DEFAULT_INDEX_NAME,
embedding_function=DEFAULT_EMBEDDING_MODEL,
additional_search_client_options=additional_search_client_options,
Expand Down

0 comments on commit 99f9a66

Please sign in to comment.