Skip to content

Commit

Permalink
feat: support for falcon, llama2 and gpt2(testing)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattzcarey committed Jul 20, 2023
1 parent 52214f2 commit 2fd2347
Show file tree
Hide file tree
Showing 18 changed files with 132 additions and 88 deletions.
6 changes: 3 additions & 3 deletions genoss/api/completions_routes.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from typing import Dict, Optional
from typing import Dict, List, Optional

from fastapi import APIRouter, Body, HTTPException
from fastapi.params import Depends
from pydantic import BaseModel

from genoss.auth.auth_handler import AuthHandler
from genoss.entities.chat.messages import Messages
from genoss.entities.chat.message import Message
from genoss.services.model_factory import ModelFactory
from logger import get_logger

Expand All @@ -16,7 +16,7 @@

class RequestBody(BaseModel):
model: str
messages: Messages
messages: List[Message]
temperature: Optional[float]


Expand Down
4 changes: 3 additions & 1 deletion genoss/api/embeddings_routes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from ast import List

from fastapi import APIRouter
from genoss.model.llm.local.gpt4all import Gpt4AllLLM

from genoss.llm.local.gpt4all import Gpt4AllLLM
from logger import get_logger

logger = get_logger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion genoss/entities/chat/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import uuid
from typing import Any, Dict

from genoss.entities.chat.messages import Message
from genoss.entities.chat.message import Message


class ChatCompletion:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List
from typing import Any, Dict

from pydantic import BaseModel, Field

Expand All @@ -15,13 +15,3 @@ class Message(BaseModel):

def to_dict(self) -> Dict[str, Any]:
return {"role": self.role, "content": self.content}


class Messages(BaseModel):
messages: List[Message]

def to_dict(self) -> List[Dict[str, Any]]:
return [message.to_dict() for message in self.messages]

def __getitem__(self, index):
return self.messages[index]
File renamed without changes.
4 changes: 2 additions & 2 deletions genoss/model/llm/fake_llm.py → genoss/llm/fake_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from langchain.llms import FakeListLLM

from genoss.entities.chat.chat_completion import ChatCompletion
from genoss.model.llm.base_genoss import BaseGenossLLM
from genoss.model.prompts.prompt_template import prompt_template
from genoss.llm.base_genoss import BaseGenossLLM
from genoss.prompts.prompt_template import prompt_template

FAKE_LLM_NAME = "fake"

Expand Down
56 changes: 56 additions & 0 deletions genoss/llm/hf_hub/base_hf_hub.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from abc import ABC
from typing import Dict, Optional

from fastapi import HTTPException
from langchain import HuggingFaceHub, LLMChain
from pydantic import Field

from genoss.entities.chat.chat_completion import ChatCompletion
from genoss.llm.base_genoss import BaseGenossLLM
from genoss.prompts.prompt_template import prompt_template


class BaseHuggingFaceHubLLM(BaseGenossLLM, ABC):
# Sub classes must define these
HUGGINGFACEHUB_API_TOKEN: Optional[str] = Field(None)
repo_id: Optional[str] = None

"""
Class for interacting with Hugging Face Inference APIs
"""

def __init__(self, api_key, *args, **kwargs):
super().__init__(*args, **kwargs)

if api_key is None:
raise HTTPException(status_code=403, detail="API key missing")

self.HUGGINGFACEHUB_API_TOKEN = api_key

def generate_answer(self, question: str) -> Dict:
"""
Generate answer from prompt
"""

print("Generating Answer...")

llm = HuggingFaceHub(
repo_id=self.repo_id, huggingfacehub_api_token=self.HUGGINGFACEHUB_API_TOKEN
) # type: ignore
llm_chain = LLMChain(prompt=prompt_template, llm=llm)

response_text = llm_chain(question)

answer = response_text["text"]

chat_completion = ChatCompletion(
model=self.name, question=question, answer=answer
)

return chat_completion.to_dict()

def generate_embedding(self, text: str):
"""Dummy method to satisfy base class requirement."""
raise NotImplementedError(
"This method is not used for Hugging Face Inference API."
)
14 changes: 14 additions & 0 deletions genoss/llm/hf_hub/falcon.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from genoss.llm.hf_hub.base_hf_hub import BaseHuggingFaceHubLLM


class HuggingFaceHubFalconLLM(BaseHuggingFaceHubLLM):
name: str = "falcon"
description: str = "Hugging Face Falcon Inference API"
repo_id = "tiiuae/falcon-40b"

"""
Class for interacting with Hugging Face Falcon Inference API
"""

def __init__(self, api_key, *args, **kwargs):
super().__init__(api_key, *args, **kwargs)
14 changes: 14 additions & 0 deletions genoss/llm/hf_hub/gpt2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from genoss.llm.hf_hub.base_hf_hub import BaseHuggingFaceHubLLM


class HuggingFaceHubGPT2LLM(BaseHuggingFaceHubLLM):
name: str = "gpt2"
description: str = "Hugging Face GPT2 Test Inference API"
repo_id = "gpt2"

"""
Class for interacting with Hugging Face GPT2 Inference API. Good for testing.
"""

def __init__(self, api_key, *args, **kwargs):
super().__init__(api_key, *args, **kwargs)
14 changes: 14 additions & 0 deletions genoss/llm/hf_hub/llama2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from genoss.llm.hf_hub.base_hf_hub import BaseHuggingFaceHubLLM


class HuggingFaceHubLlama2LLM(BaseHuggingFaceHubLLM):
name: str = "llama2"
description: str = "Hugging Face Llama2 Inference API"
repo_id = "Llama-2-70b-chat-hf"

"""
Class for interacting with Hugging Face Llama Inference API
"""

def __init__(self, api_key, *args, **kwargs):
super().__init__(api_key, *args, **kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from pydantic import validator

from genoss.model.llm.base_genoss import BaseGenossLLM
from genoss.llm.base_genoss import BaseGenossLLM


class BaseLocalLLM(BaseGenossLLM):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from __future__ import annotations
from tkinter.messagebox import QUESTION

from typing import Dict

Expand All @@ -8,8 +7,8 @@
from langchain.llms import GPT4All

from genoss.entities.chat.chat_completion import ChatCompletion
from genoss.model.llm.local.base_local import BaseLocalLLM
from genoss.model.prompts.prompt_template import prompt_template
from genoss.llm.local.base_local import BaseLocalLLM
from genoss.prompts.prompt_template import prompt_template


class Gpt4AllLLM(BaseLocalLLM):
Expand All @@ -25,11 +24,11 @@ def generate_answer(self, question: str) -> Dict:
)

llm_chain = LLMChain(llm=llm, prompt=prompt_template)
response_text = llm_chain(QUESTION)
response_text = llm_chain(question)
print("###################")
print(response_text)
answer = response_text["text"]
# TODO: fix, chat completion expects a list but message is a string...

chat_completion = ChatCompletion(
model=self.name, question=question, answer=answer
)
Expand Down
19 changes: 0 additions & 19 deletions genoss/model/llm/hf_hub/base_hf_hub.py

This file was deleted.

32 changes: 0 additions & 32 deletions genoss/model/llm/hf_hub/falcon.py

This file was deleted.

5 changes: 0 additions & 5 deletions genoss/model/prompts/prompt_template.py

This file was deleted.

5 changes: 5 additions & 0 deletions genoss/prompts/prompt_template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from langchain import PromptTemplate

system_prompt = "Question from user: {question}?, Answer from helpful chatbot:"

prompt_template = PromptTemplate(template=system_prompt, input_variables=["question"])
18 changes: 12 additions & 6 deletions genoss/services/model_factory.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
from typing import Optional

from genoss.model.llm.base_genoss import BaseGenossLLM
from genoss.model.llm.fake_llm import FAKE_LLM_NAME, FakeLLM
from genoss.model.llm.hf_hub.falcon import HuggingFaceHubFalconLLM
from genoss.model.llm.local.gpt4all import Gpt4AllLLM
from genoss.llm.base_genoss import BaseGenossLLM
from genoss.llm.fake_llm import FAKE_LLM_NAME, FakeLLM
from genoss.llm.hf_hub.falcon import HuggingFaceHubFalconLLM
from genoss.llm.hf_hub.gpt2 import HuggingFaceHubGPT2LLM
from genoss.llm.hf_hub.llama2 import HuggingFaceHubLlama2LLM
from genoss.llm.local.gpt4all import Gpt4AllLLM


class ModelFactory:
@staticmethod
def get_model_from_name(name: str, api_key=None) -> Optional[BaseGenossLLM]:
if name.lower().startswith("gpt"):
if name.lower().startswith("gpt4all"):
return Gpt4AllLLM()
if name.lower().startswith("falcon"):
if name.lower().startswith("hf-llama2"):
return HuggingFaceHubLlama2LLM(api_key=api_key)
if name.lower().startswith("hf-gpt2"):
return HuggingFaceHubGPT2LLM(api_key=api_key)
if name.lower().startswith("hf-falcon"):
return HuggingFaceHubFalconLLM(api_key=api_key)
elif name == FAKE_LLM_NAME:
return FakeLLM()
Expand Down
4 changes: 2 additions & 2 deletions tests/services/test_model_factory.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import unittest

from genoss.model.llm.fake_llm import FAKE_LLM_NAME, FakeLLM
from genoss.model.llm.local.gpt4all import Gpt4AllLLM
from genoss.llm.fake_llm import FAKE_LLM_NAME, FakeLLM
from genoss.llm.local.gpt4all import Gpt4AllLLM
from genoss.services.model_factory import ModelFactory


Expand Down

0 comments on commit 2fd2347

Please sign in to comment.