Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add chat_completion class #4

Merged
merged 3 commits into from
Jul 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions genoss/api/completions_routes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from fastapi import APIRouter, Body, HTTPException

from genoss.model.messages import Message
from genoss.services.model_factory import ModelFactory
from logger import get_logger
from typing import List, Dict, Optional
Expand All @@ -9,11 +11,6 @@
completions_router = APIRouter()


class Message(BaseModel):
role: str
content: str


class RequestBody(BaseModel):
model: str
messages: List[Message]
Expand Down
54 changes: 54 additions & 0 deletions genoss/model/chat_completion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from typing import Dict, Any

import time
import uuid

from genoss.api.completions_routes import Message


class Choice:
    """One completion choice in an OpenAI-style chat response."""

    def __init__(self, message: Message, finish_reason: str = "stop", index: int = 0):
        # The assistant message produced for this choice.
        self.message = message
        # Why generation ended (e.g. "stop").
        self.finish_reason = finish_reason
        # Position of this choice within the response's choices list.
        self.index = index

    def to_dict(self) -> Dict[str, Any]:
        """Serialize this choice into the OpenAI response fragment."""
        return dict(
            message=self.message.to_dict(),
            finish_reason=self.finish_reason,
            index=self.index,
        )


class Usage:
    """Token accounting block for a chat completion, OpenAI-style."""

    def __init__(self, prompt_tokens: int, completion_tokens: int, total_tokens: int):
        self.prompt_tokens = prompt_tokens
        self.completion_tokens = completion_tokens
        self.total_tokens = total_tokens

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the three usage counters to a plain dict."""
        counters = ("prompt_tokens", "completion_tokens", "total_tokens")
        return {key: getattr(self, key) for key in counters}


class ChatCompletion:
    """Builds an OpenAI-compatible chat-completion response payload."""

    def __init__(self, model: str, last_messages: list, answer: str):
        # Fresh unique identifier for this completion.
        self.id = str(uuid.uuid4())
        self.object = "chat.completion"
        # Creation time as a Unix timestamp.
        self.created = int(time.time())
        self.model = model
        # NOTE(review): lengths stand in for real token counts here —
        # prompt side counts messages, completion side counts characters.
        prompt_count = len(last_messages)
        completion_count = len(answer)
        self.usage = Usage(prompt_count, completion_count, prompt_count + completion_count)
        # Single choice wrapping the assistant's answer (defaults: "stop", index 0).
        self.choices = [Choice(Message(role="assistant", content=answer))]

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the completion to the OpenAI response schema."""
        payload = {
            "id": self.id,
            "object": self.object,
            "created": self.created,
            "model": self.model,
            "usage": self.usage.to_dict(),
        }
        payload["choices"] = [choice.to_dict() for choice in self.choices]
        return payload
38 changes: 9 additions & 29 deletions genoss/model/fake_llm.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,37 @@
from __future__ import annotations
import uuid
from typing import Dict

from langchain import PromptTemplate, LLMChain
from langchain.llms import FakeListLLM
from langchain.embeddings import GPT4AllEmbeddings, FakeEmbeddings
import time
from langchain.embeddings import FakeEmbeddings
from genoss.model.base_genoss_llm import BaseGenossLLM
from genoss.model.chat_completion import ChatCompletion

FAKE_LLM_NAME = "fake"


class FakeLLM(BaseGenossLLM):
name: str = "fake"
name: str = FAKE_LLM_NAME
description: str = "Fake LLM for testing purpose"
model_path: str = ""

def generate_answer(self, messages: list) -> Dict:
print("Generating Answer")
print(messages)
last_message = messages
last_messages = messages

llm = FakeListLLM(responses=["Hello from FakeLLM!"])
prompt_template = "Question from user: {question}?, Answer from bot:"
llm_chain = LLMChain(
llm=llm, prompt=PromptTemplate.from_template(prompt_template)
)
response_text = llm_chain(last_message)
response_text = llm_chain(last_messages)
print("###################")
print(response_text)
answer = response_text["text"]
chat_completion = ChatCompletion(model=self.name, answer=answer, last_messages=last_messages)

# Format the response to match OpenAI's format
unique_id = uuid.uuid4()
response = {
"id": unique_id, # You might want to generate a unique ID here
"object": "chat.completion",
"created": int(time.time()), # This gets the current Unix timestamp
"model": self.name,
"usage": {
"prompt_tokens": len(last_message), # This is a simplification
"completion_tokens": len(answer), # This is a simplification
"total_tokens": len(last_message)
+ len(answer), # This is a simplification
},
"choices": [
{
"message": {"role": "assistant", "content": answer},
"finish_reason": "stop", # This might not always be 'stop'
"index": 0,
}
],
}

return response
return chat_completion.to_dict()

def generate_embedding(self, text: str):
model = FakeEmbeddings()
Expand Down
30 changes: 5 additions & 25 deletions genoss/model/gpt4all_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from langchain import PromptTemplate, LLMChain
from langchain.llms import GPT4All
from langchain.embeddings import GPT4AllEmbeddings
import time
from genoss.model.base_genoss_llm import BaseGenossLLM
from genoss.model.chat_completion import ChatCompletion


class Gpt4AllLLM(BaseGenossLLM):
Expand All @@ -17,7 +17,7 @@ class Gpt4AllLLM(BaseGenossLLM):
def generate_answer(self, messages: list) -> Dict:
print("Generating Answer")
print(messages)
last_message = messages
last_messages = messages

llm = GPT4All(
model=self.model_path, # pyright: ignore reportPrivateUsage=none
Expand All @@ -26,33 +26,13 @@ def generate_answer(self, messages: list) -> Dict:
llm_chain = LLMChain(
llm=llm, prompt=PromptTemplate.from_template(prompt_template)
)
response_text = llm_chain(last_message)
response_text = llm_chain(last_messages)
print("###################")
print(response_text)
answer = response_text["text"]
chat_completion = ChatCompletion(model=self.name, answer=answer, last_messages=last_messages)

# Format the response to match OpenAI's format
response = {
"id": "chatcmpl-abc123", # You might want to generate a unique ID here
"object": "chat.completion",
"created": int(time.time()), # This gets the current Unix timestamp
"model": "gpt4all",
"usage": {
"prompt_tokens": len(last_message), # This is a simplification
"completion_tokens": len(answer), # This is a simplification
"total_tokens": len(last_message)
+ len(answer), # This is a simplification
},
"choices": [
{
"message": {"role": "assistant", "content": answer},
"finish_reason": "stop", # This might not always be 'stop'
"index": 0,
}
],
}

return response
return chat_completion.to_dict()

def generate_embedding(self, embedding: str | list[str]):
gpt4all_embd = GPT4AllEmbeddings() # pyright: ignore reportPrivateUsage=none
Expand Down
11 changes: 10 additions & 1 deletion genoss/model/messages.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Dict, Any

from pydantic import BaseModel, Field

Expand All @@ -13,6 +13,15 @@ class Message(BaseModel):
description="The contents of the message. content is required for all messages, and may be null for assistant messages with function calls.",
)

def to_dict(self) -> Dict[str, Any]:
    """Return the message as a plain role/content dict."""
    return dict(role=self.role, content=self.content)


class Messages(BaseModel):
    """Wrapper model holding an ordered list of chat messages."""

    messages: List[Message]

    def to_dict(self) -> List[Dict[str, Any]]:
        """Serialize every contained message to a plain dict, in order."""
        return [msg.to_dict() for msg in self.messages]
4 changes: 2 additions & 2 deletions genoss/services/model_factory.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Optional
from genoss.model.base_genoss_llm import BaseGenossLLM
from genoss.model.fake_llm import FakeLLM
from genoss.model.fake_llm import FakeLLM, FAKE_LLM_NAME
from genoss.model.gpt4all_llm import Gpt4AllLLM


Expand All @@ -9,6 +9,6 @@ class ModelFactory:
def get_model_from_name(name: str) -> Optional[BaseGenossLLM]:
    """Resolve a model name to a concrete LLM instance, or None if unknown.

    Names beginning with "gpt" (case-insensitive) map to Gpt4AllLLM;
    the fake model's exact name maps to FakeLLM.
    """
    if name.lower().startswith("gpt"):
        return Gpt4AllLLM()
    if name == FAKE_LLM_NAME:
        return FakeLLM()
    return None
4 changes: 2 additions & 2 deletions tests/services/test_model_factory.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import unittest

from genoss.model.fake_llm import FakeLLM
from genoss.model.fake_llm import FakeLLM, FAKE_LLM_NAME
from genoss.model.gpt4all_llm import Gpt4AllLLM
from genoss.services.model_factory import ModelFactory

Expand All @@ -11,7 +11,7 @@ def test_get_model_from_name_gpt(self):
self.assertIsInstance(model, Gpt4AllLLM)

def test_get_model_from_name_fake(self):
model = ModelFactory.get_model_from_name('fake')
model = ModelFactory.get_model_from_name(FAKE_LLM_NAME)
self.assertIsInstance(model, FakeLLM)

def test_get_model_from_name_unknown(self):
Expand Down