-
Notifications
You must be signed in to change notification settings - Fork 63
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: support for falcon, llama2 and gpt2(testing)
- Loading branch information
1 parent
52214f2
commit 2fd2347
Showing
18 changed files
with
132 additions
and
88 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
from abc import ABC | ||
from typing import Dict, Optional | ||
|
||
from fastapi import HTTPException | ||
from langchain import HuggingFaceHub, LLMChain | ||
from pydantic import Field | ||
|
||
from genoss.entities.chat.chat_completion import ChatCompletion | ||
from genoss.llm.base_genoss import BaseGenossLLM | ||
from genoss.prompts.prompt_template import prompt_template | ||
|
||
|
||
class BaseHuggingFaceHubLLM(BaseGenossLLM, ABC): | ||
# Sub classes must define these | ||
HUGGINGFACEHUB_API_TOKEN: Optional[str] = Field(None) | ||
repo_id: Optional[str] = None | ||
|
||
""" | ||
Class for interacting with Hugging Face Inference APIs | ||
""" | ||
|
||
def __init__(self, api_key, *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
|
||
if api_key is None: | ||
raise HTTPException(status_code=403, detail="API key missing") | ||
|
||
self.HUGGINGFACEHUB_API_TOKEN = api_key | ||
|
||
def generate_answer(self, question: str) -> Dict: | ||
""" | ||
Generate answer from prompt | ||
""" | ||
|
||
print("Generating Answer...") | ||
|
||
llm = HuggingFaceHub( | ||
repo_id=self.repo_id, huggingfacehub_api_token=self.HUGGINGFACEHUB_API_TOKEN | ||
) # type: ignore | ||
llm_chain = LLMChain(prompt=prompt_template, llm=llm) | ||
|
||
response_text = llm_chain(question) | ||
|
||
answer = response_text["text"] | ||
|
||
chat_completion = ChatCompletion( | ||
model=self.name, question=question, answer=answer | ||
) | ||
|
||
return chat_completion.to_dict() | ||
|
||
def generate_embedding(self, text: str): | ||
"""Dummy method to satisfy base class requirement.""" | ||
raise NotImplementedError( | ||
"This method is not used for Hugging Face Inference API." | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from genoss.llm.hf_hub.base_hf_hub import BaseHuggingFaceHubLLM | ||
|
||
|
||
class HuggingFaceHubFalconLLM(BaseHuggingFaceHubLLM): | ||
name: str = "falcon" | ||
description: str = "Hugging Face Falcon Inference API" | ||
repo_id = "tiiuae/falcon-40b" | ||
|
||
""" | ||
Class for interacting with Hugging Face Falcon Inference API | ||
""" | ||
|
||
def __init__(self, api_key, *args, **kwargs): | ||
super().__init__(api_key, *args, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from genoss.llm.hf_hub.base_hf_hub import BaseHuggingFaceHubLLM | ||
|
||
|
||
class HuggingFaceHubGPT2LLM(BaseHuggingFaceHubLLM): | ||
name: str = "gpt2" | ||
description: str = "Hugging Face GPT2 Test Inference API" | ||
repo_id = "gpt2" | ||
|
||
""" | ||
Class for interacting with Hugging Face GPT2 Inference API. Good for testing. | ||
""" | ||
|
||
def __init__(self, api_key, *args, **kwargs): | ||
super().__init__(api_key, *args, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from genoss.llm.hf_hub.base_hf_hub import BaseHuggingFaceHubLLM | ||
|
||
|
||
class HuggingFaceHubLlama2LLM(BaseHuggingFaceHubLLM): | ||
name: str = "llama2" | ||
description: str = "Hugging Face Llama2 Inference API" | ||
repo_id = "Llama-2-70b-chat-hf" | ||
|
||
""" | ||
Class for interacting with Hugging Face Llama Inference API | ||
""" | ||
|
||
def __init__(self, api_key, *args, **kwargs): | ||
super().__init__(api_key, *args, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from langchain import PromptTemplate | ||
|
||
system_prompt = "Question from user: {question}?, Answer from helpful chatbot:" | ||
|
||
prompt_template = PromptTemplate(template=system_prompt, input_variables=["question"]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters