Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 8.15] allow embeddings vector to be used for mmr searching (#2620) #2639

Merged
merged 1 commit into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions elasticsearch/helpers/vectorstore/_async/vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ async def delete( # type: ignore[no-untyped-def]
async def search(
self,
*,
query: Optional[str],
query: Optional[str] = None,
query_vector: Optional[List[float]] = None,
k: int = 4,
num_candidates: int = 50,
Expand Down Expand Up @@ -344,8 +344,9 @@ async def _create_index_if_not_exists(self) -> None:
async def max_marginal_relevance_search(
self,
*,
embedding_service: AsyncEmbeddingService,
query: str,
query: Optional[str] = None,
query_embedding: Optional[List[float]] = None,
embedding_service: Optional[AsyncEmbeddingService] = None,
vector_field: str,
k: int = 4,
num_candidates: int = 20,
Expand All @@ -361,6 +362,8 @@ async def max_marginal_relevance_search(
among selected documents.

:param query (str, optional): Text to look up documents similar to. Required
    unless query_embedding is provided.
:param query_embedding: Input embedding vector. If given, the input query
    string is ignored.
:param k (int): Number of Documents to return. Defaults to 4.
:param fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
:param lambda_mult (float): Number between 0 and 1 that determines the degree
Expand All @@ -381,12 +384,22 @@ async def max_marginal_relevance_search(
remove_vector_query_field_from_metadata = False

# Embed the query
query_embedding = await embedding_service.embed_query(query)
if query_embedding:
query_vector = query_embedding
else:
if not query:
raise ValueError("specify either query or query_embedding to search")
elif embedding_service:
query_vector = await embedding_service.embed_query(query)
elif self.embedding_service:
query_vector = await self.embedding_service.embed_query(query)
else:
raise ValueError("specify embedding_service to search with query")

# Fetch the initial documents
got_hits = await self.search(
query=None,
query_vector=query_embedding,
query_vector=query_vector,
k=num_candidates,
fields=fields,
custom_query=custom_query,
Expand All @@ -397,7 +410,7 @@ async def max_marginal_relevance_search(

# Select documents using maximal marginal relevance
selected_indices = maximal_marginal_relevance(
query_embedding, got_embeddings, lambda_mult=lambda_mult, k=k
query_vector, got_embeddings, lambda_mult=lambda_mult, k=k
)
selected_hits = [got_hits[i] for i in selected_indices]

Expand Down
25 changes: 19 additions & 6 deletions elasticsearch/helpers/vectorstore/_sync/vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def delete( # type: ignore[no-untyped-def]
def search(
self,
*,
query: Optional[str],
query: Optional[str] = None,
query_vector: Optional[List[float]] = None,
k: int = 4,
num_candidates: int = 50,
Expand Down Expand Up @@ -341,8 +341,9 @@ def _create_index_if_not_exists(self) -> None:
def max_marginal_relevance_search(
self,
*,
embedding_service: EmbeddingService,
query: str,
query: Optional[str] = None,
query_embedding: Optional[List[float]] = None,
embedding_service: Optional[EmbeddingService] = None,
vector_field: str,
k: int = 4,
num_candidates: int = 20,
Expand All @@ -358,6 +359,8 @@ def max_marginal_relevance_search(
among selected documents.

:param query (str, optional): Text to look up documents similar to. Required
    unless query_embedding is provided.
:param query_embedding: Input embedding vector. If given, the input query
    string is ignored.
:param k (int): Number of Documents to return. Defaults to 4.
:param fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
:param lambda_mult (float): Number between 0 and 1 that determines the degree
Expand All @@ -378,12 +381,22 @@ def max_marginal_relevance_search(
remove_vector_query_field_from_metadata = False

# Embed the query
query_embedding = embedding_service.embed_query(query)
if query_embedding:
query_vector = query_embedding
else:
if not query:
raise ValueError("specify either query or query_embedding to search")
elif embedding_service:
query_vector = embedding_service.embed_query(query)
elif self.embedding_service:
query_vector = self.embedding_service.embed_query(query)
else:
raise ValueError("specify embedding_service to search with query")

# Fetch the initial documents
got_hits = self.search(
query=None,
query_vector=query_embedding,
query_vector=query_vector,
k=num_candidates,
fields=fields,
custom_query=custom_query,
Expand All @@ -394,7 +407,7 @@ def max_marginal_relevance_search(

# Select documents using maximal marginal relevance
selected_indices = maximal_marginal_relevance(
query_embedding, got_embeddings, lambda_mult=lambda_mult, k=k
query_vector, got_embeddings, lambda_mult=lambda_mult, k=k
)
selected_hits = [got_hits[i] for i in selected_indices]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -815,13 +815,55 @@ def test_bulk_args(self, sync_client_request_saving: Any, index: str) -> None:
# 1 for index exist, 1 for index create, 3 to index docs
assert len(store.client.transport.requests) == 5 # type: ignore

def test_max_marginal_relevance_search_errors(
    self, sync_client: Elasticsearch, index: str
) -> None:
    """Verify MMR search raises ValueError for missing query inputs or service."""
    docs = ["foo", "bar", "baz"]
    field = "vector_field"
    embedder = ConsistentFakeEmbeddings()

    # Store wired with an embedding service, so only the query/embedding
    # arguments themselves can be at fault.
    vs = VectorStore(
        index=index,
        retrieval_strategy=DenseVectorScriptScoreStrategy(),
        embedding_service=embedder,
        client=sync_client,
    )
    vs.add_texts(docs)

    # Neither a query string nor a precomputed embedding was supplied.
    with pytest.raises(
        ValueError, match="specify either query or query_embedding to search"
    ):
        vs.max_marginal_relevance_search(
            vector_field=field,
            k=3,
            num_candidates=3,
        )

    # A query string alone is not enough when no embedding service exists
    # (neither passed in nor configured on the store).
    bare_store = VectorStore(
        index=index,
        retrieval_strategy=DenseVectorScriptScoreStrategy(),
        client=sync_client,
    )
    with pytest.raises(
        ValueError, match="specify embedding_service to search with query"
    ):
        bare_store.max_marginal_relevance_search(
            query=docs[0],
            vector_field=field,
            k=3,
            num_candidates=3,
        )

def test_max_marginal_relevance_search(
self, sync_client: Elasticsearch, index: str
) -> None:
"""Test max marginal relevance search."""
texts = ["foo", "bar", "baz"]
vector_field = "vector_field"
text_field = "text_field"
query_embedding = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0]
embedding_service = ConsistentFakeEmbeddings()
store = VectorStore(
index=index,
Expand All @@ -833,8 +875,8 @@ def test_max_marginal_relevance_search(
)
store.add_texts(texts)

# search with query
mmr_output = store.max_marginal_relevance_search(
embedding_service=embedding_service,
query=texts[0],
vector_field=vector_field,
k=3,
Expand All @@ -843,8 +885,17 @@ def test_max_marginal_relevance_search(
sim_output = store.search(query=texts[0], k=3)
assert mmr_output == sim_output

# search with query embeddings
mmr_output = store.max_marginal_relevance_search(
query_embedding=query_embedding,
vector_field=vector_field,
k=3,
num_candidates=3,
)
sim_output = store.search(query_vector=query_embedding, k=3)
assert mmr_output == sim_output

mmr_output = store.max_marginal_relevance_search(
embedding_service=embedding_service,
query=texts[0],
vector_field=vector_field,
k=2,
Expand All @@ -855,7 +906,6 @@ def test_max_marginal_relevance_search(
assert mmr_output[1]["_source"][text_field] == texts[1]

mmr_output = store.max_marginal_relevance_search(
embedding_service=embedding_service,
query=texts[0],
vector_field=vector_field,
k=2,
Expand All @@ -868,7 +918,6 @@ def test_max_marginal_relevance_search(

# if fetch_k < k, then the output will be less than k
mmr_output = store.max_marginal_relevance_search(
embedding_service=embedding_service,
query=texts[0],
vector_field=vector_field,
k=3,
Expand Down
Loading