-
Notifications
You must be signed in to change notification settings - Fork 29
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add vectara support in Eidolon. 1. Adds vectara logic unit for simple agentic rag 2. Adds vectara agent utilizing built in chat interface. Will add docs in follow-up pr if we like the contract
- Loading branch information
Showing
11 changed files
with
7,134 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
import copy | ||
import json | ||
import os | ||
from typing import Annotated | ||
from urllib.parse import urljoin | ||
|
||
from fastapi import Body | ||
from httpx import AsyncClient | ||
from httpx_sse import EventSource | ||
from pydantic import BaseModel, Field | ||
|
||
from eidolon_ai_client.events import StringOutputEvent, StartStreamContextEvent, ObjectOutputEvent, \ | ||
EndStreamContextEvent, AgentStateEvent | ||
from eidolon_ai_sdk.agent.agent import register_action | ||
from eidolon_ai_sdk.system.processes import MongoDoc | ||
from eidolon_ai_sdk.system.reference_model import Specable | ||
|
||
|
||
os.environ.setdefault("VECTARA_API_KEY", "test") | ||
|
||
|
||
class VectaraAgentSpec(BaseModel): | ||
""" | ||
An agent backed by Vectara. Requires the VECTARA_API_KEY environment variable to be set for authentication. | ||
""" | ||
|
||
corpus_key: str | ||
description: str = "Search documents related to {{ corpus_key }}" | ||
vectara_url: str = "https://api.vectara.io/" | ||
body_overrides: dict = Field({}, description="Arguments to use when creating / continuing a chat. See https://docs.vectara.com/docs/rest-api/create-chat for more information.") | ||
|
||
|
||
# We need to store chatid / processid mappings since vectara doesn't have metatdata / query concepts | ||
class VectaraDoc(MongoDoc): | ||
collection = "vectara_docs" | ||
process_id: str | ||
vectara_chat_id: str | ||
metadata: dict = {} | ||
|
||
|
||
class VectaraAgent(Specable[VectaraAgentSpec]): | ||
@property | ||
def _token(self): | ||
return os.environ["VECTARA_API_KEY"] | ||
|
||
@property | ||
def _headers(self): | ||
return { | ||
'Content-Type': 'application/json', | ||
'Accept': 'text/event-stream', | ||
'x-api-key': self._token | ||
} | ||
|
||
def _url(self, suffix): | ||
return urljoin(self.spec.vectara_url, suffix) | ||
|
||
@register_action("initialized", "idle", description=lambda agent, _: agent.spec.description) | ||
async def converse(self, process_id, question: Annotated[str, Body()]): | ||
body = copy.deepcopy(self.spec.body_overrides) | ||
body.setdefault("search", {}).setdefault("corpora", [{}]) | ||
for corpus in body["search"]["corpora"]: | ||
corpus.setdefault("corpus_key", self.spec.corpus_key) | ||
body["query"] = question | ||
body["stream_response"] = True | ||
|
||
doc = await VectaraDoc.find_one(query=dict(process_id=process_id)) | ||
async with AsyncClient() as client: | ||
response = await client.post( | ||
url=self._url("/v2/chats" if not doc else f"/v2/chats/{doc.vectara_chat_id}/turns"), | ||
headers=self._headers, | ||
json=body, | ||
) | ||
response.raise_for_status() | ||
|
||
yield StartStreamContextEvent(context_id="response_info", title="Response Information") | ||
try: | ||
async for sse_event in EventSource(response).aiter_sse(): | ||
if sse_event.event == "chat_info": | ||
if not doc: | ||
data = json.loads(sse_event.data) | ||
doc = await VectaraDoc.create(process_id=process_id, vectara_chat_id=data['chat_id']) | ||
elif sse_event.event == "generation_chunk": | ||
data = json.loads(sse_event.data) | ||
yield StringOutputEvent(content=data['generation_chunk']) | ||
elif sse_event.event == "search_results": | ||
for result in json.loads(sse_event.data)["search_results"]: | ||
yield ObjectOutputEvent(stream_context="response_info", content=result) | ||
elif sse_event.event == "factual_consistency_score": | ||
yield ObjectOutputEvent(stream_context="response_info", content=json.loads(sse_event.data)) | ||
finally: | ||
yield EndStreamContextEvent(context_id="response_info") | ||
|
||
yield AgentStateEvent(state="idle") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
import os | ||
from urllib.parse import urljoin | ||
|
||
from httpx import AsyncClient | ||
from pydantic import BaseModel, Field | ||
|
||
from eidolon_ai_sdk.apu.logic_unit import LogicUnit, llm_function | ||
from eidolon_ai_sdk.system.reference_model import Specable | ||
|
||
|
||
class VectaraSearchSpec(BaseModel): | ||
""" | ||
A logic unit for searching in Vectara. Requires the VECTARA_API_KEY environment variable to be set for authentication. | ||
""" | ||
|
||
corpus_key: str = Field(description="The corpus key to search in.") | ||
description: str = Field("Search documents related to {corpus_key}.", description="Description of the tool presented to LLM. Will be formatted with corpus_key.") | ||
vectara_url: str = "https://api.vectara.io/" | ||
|
||
|
||
class VectaraSearch(Specable[VectaraSearchSpec], LogicUnit): | ||
@property | ||
def _token(self): | ||
return os.environ["VECTARA_API_KEY"] | ||
|
||
@property | ||
def _headers(self): | ||
return { | ||
'Accept': 'application/json', | ||
'x-api-key': self._token | ||
} | ||
|
||
def _url(self, suffix): | ||
return urljoin(self.spec.vectara_url, suffix) | ||
|
||
@llm_function(description=lambda lu, _: lu.spec.description.format(corpus_key=lu.spec.corpus_key)) | ||
async def query(self, query: str, limit: int = 10, offset: int = 0): | ||
async with AsyncClient() as client: | ||
response = await client.post( | ||
url=self._url(f"/v2/corpora/{self.spec.corpus_key}/query"), | ||
headers=self._headers, | ||
json=dict( | ||
query=query, | ||
search=dict( | ||
limit=limit, | ||
offset=offset, | ||
) | ||
), | ||
) | ||
response.raise_for_status() | ||
response_body = response.json() | ||
content = [dict(text=r.get("text"), document_id=r.get("document_id")) for r in response_body["search_results"]] | ||
documents = {r.get("document_id"): r.get("document_metadata", {}).get("title") for r in response_body["search_results"]} | ||
return dict(search_results=content, documents=documents) | ||
|
||
@llm_function() | ||
async def read_document(self, document_id: str): | ||
async with AsyncClient() as client: | ||
response = await client.get( | ||
url=self._url(f"/v2/corpora/{self.spec.corpus_key}/documents/{document_id}"), | ||
headers=self._headers, | ||
) | ||
response.raise_for_status() | ||
response_body = response.json() | ||
|
||
sections = [] | ||
for part in response_body["parts"]: | ||
if part["metadata"].get("is_title", False): | ||
sections.append((dict(title=part["text"], content=[]))) | ||
if "title_level" in part["metadata"]: | ||
sections[-1]["title_level"] = part["metadata"]["title_level"] | ||
else: | ||
if not sections: | ||
sections.append(dict(title="", content=[])) | ||
sections[-1]["content"].append(part["text"]) | ||
for section in sections: | ||
section["content"] = "".join(section["content"]) | ||
|
||
return dict( | ||
document_id=document_id, | ||
metadata=response_body["metadata"], | ||
sections=sections, | ||
) |
Oops, something went wrong.