Commit: update
aniketmaurya committed Feb 23, 2024
1 parent 5e6f5f3 commit 459bd21
Showing 2 changed files with 84 additions and 27 deletions.
84 changes: 58 additions & 26 deletions src/fastserve/models/vllm.py
@@ -1,46 +1,78 @@
+import logging
 import os
-from typing import List
+from typing import Any, List, Optional
 
-from fastapi import FastAPI
+from llama_cpp import Llama
 from pydantic import BaseModel
-from vllm import LLM, SamplingParams
 
-tensor_parallel_size = int(os.environ.get("DEVICES", "1"))
-print("tensor_parallel_size: ", tensor_parallel_size)
+from fastserve.core import FastServe
 
+logger = logging.getLogger(__name__)
+
+DEFAULT_MODEL = "openhermes-2-mistral-7b.Q6_K.gguf"
 
-llm = LLM("meta-llama/Llama-2-7b-hf", tensor_parallel_size=tensor_parallel_size)
 
 
 class PromptRequest(BaseModel):
-    prompt: str
-    temperature: float = 1
+    prompt: str = "Llamas are cute animal"
+    temperature: float = 0.8
     top_p: float = 0.0
     max_tokens: int = 200
     stop: List[str] = []
 
 
 class ResponseModel(BaseModel):
     prompt: str
-    prompt_token_ids: List  # The token IDs of the prompt.
-    outputs: List[str]  # The output sequences of the request.
+    prompt_token_ids: Optional[List] = None  # The token IDs of the prompt.
+    text: str  # The output sequences of the request.
     finished: bool  # Whether the whole request is finished.
 
 
-app = FastAPI()
+class ServeVLLM(FastServe):
+    def __init__(
+        self,
+        model_path=DEFAULT_MODEL,
+        batch_size=1,
+        timeout=0.0,
+        *args,
+        **kwargs,
+    ):
+        from vllm import LLM, SamplingParams
+
+        if not os.path.exists(model_path):
+            raise FileNotFoundError(f"{model_path} not found.")
+
+        self.llm = LLM(model_path)
+        self.model_path = model_path
+        self.args = args
+        self.kwargs = kwargs
+        super().__init__(
+            batch_size,
+            timeout,
+            input_schema=PromptRequest,
+            response_schema=ResponseModel,
+        )
+
+    def __call__(self, request: PromptRequest) -> Any:
+        from vllm import SamplingParams
+
+        sampling_params = SamplingParams(temperature=request.temperature, top_p=request.top_p)
+        result = self.llm(request.prompt, sampling_params=sampling_params)
+        logger.info(result)
+        return result
+
+    def handle(self, batch: List[PromptRequest]) -> List[ResponseModel]:
+        responses = []
+        for request in batch:
+            output = self(request)
+
-@app.post("/serve", response_model=ResponseModel)
-def serve(request: PromptRequest):
-    sampling_params = SamplingParams(
-        max_tokens=request.max_tokens,
-        temperature=request.temperature,
-        stop=request.stop,
-    )
+            response = ResponseModel(
+                **{
+                    "prompt": request.prompt,
+                    "text": output["choices"][0]["text"],
+                    "finished": True,
+                }
+            )
+            responses.append(response)
 
-    result = llm.generate(request.prompt, sampling_params=sampling_params)[0]
-    response = ResponseModel(
-        prompt=request.prompt,
-        prompt_token_ids=result.prompt_token_ids,
-        outputs=result.outputs,
-        finished=result.finished,
-    )
-    return response
+        return responses
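
For context, a minimal usage sketch of the new ServeVLLM class (not part of this commit): the model path is a placeholder, and it assumes the FastServe base class exposes a run_server() method.

# Hypothetical usage sketch (not part of this commit).
# Assumes FastServe provides run_server() and a model file exists at model_path.
from fastserve.models.vllm import ServeVLLM

serve = ServeVLLM(
    model_path="path/to/model",  # placeholder; DEFAULT_MODEL is used when omitted
    batch_size=2,
    timeout=0.5,
)
serve.run_server()  # serve batched PromptRequest -> ResponseModel over HTTP
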
27 changes: 26 additions & 1 deletion src/fastserve/utils.py
@@ -1,6 +1,5 @@
 import os
-from typing import Any
 
 from pydantic import BaseModel
 
 
@@ -23,3 +22,29 @@ def get_ui_folder():
     path = os.path.join(os.path.dirname(__file__), "../ui")
     path = os.path.abspath(path)
     return path
+
+
+def download_file(url:str, dest:str):
+    import requests
+    from tqdm import tqdm
+    from huggingface_hub import HfApi, ModelFilter
+
+
+    if dest is None:
+        dest = os.path.abspath(os.path.basename(dest))
+
+    response = requests.get(url, stream=True)
+    response.raise_for_status()
+    total_size = int(response.headers.get('content-length', 0))
+    block_size = 1024
+    with open(dest, 'wb') as file, tqdm(
+        desc=dest,
+        total=total_size,
+        unit='iB',
+        unit_scale=True,
+        unit_divisor=1024,
+    ) as bar:
+        for data in response.iter_content(block_size):
+            file.write(data)
+            bar.update(len(data))
+    return dest
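
A quick usage sketch for the new download_file helper (not part of this commit; the URL and filename below are placeholders):

# Hypothetical usage (URL and destination are placeholders).
from fastserve.utils import download_file

path = download_file(
    url="https://example.com/openhermes-2-mistral-7b.Q6_K.gguf",  # placeholder URL
    dest="openhermes-2-mistral-7b.Q6_K.gguf",  # local file to write
)
print(f"Downloaded to {path}")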
