From 459bd2170969af76ab3646b9a9be157ed9dc3349 Mon Sep 17 00:00:00 2001
From: Aniket Maurya
Date: Fri, 23 Feb 2024 00:09:50 +0000
Subject: [PATCH] refactor vLLM serving into a ServeVLLM class; add a
 download_file util

---
 src/fastserve/models/vllm.py | 84 +++++++++++++++++++++++++-----------
 src/fastserve/utils.py       | 27 +++++++++++-
 2 files changed, 84 insertions(+), 27 deletions(-)

diff --git a/src/fastserve/models/vllm.py b/src/fastserve/models/vllm.py
index 3f57fe2..dcf9b9f 100644
--- a/src/fastserve/models/vllm.py
+++ b/src/fastserve/models/vllm.py
@@ -1,46 +1,78 @@
-import os
-from typing import List
+import logging
+from typing import Any, List, Optional
 
-from fastapi import FastAPI
 from pydantic import BaseModel
-from vllm import LLM, SamplingParams
 
-tensor_parallel_size = int(os.environ.get("DEVICES", "1"))
-print("tensor_parallel_size: ", tensor_parallel_size)
+from fastserve.core import FastServe
+
+logger = logging.getLogger(__name__)
+
+# Default to the HF model id the old module-level code served.
+DEFAULT_MODEL = "meta-llama/Llama-2-7b-hf"
 
-llm = LLM("meta-llama/Llama-2-7b-hf", tensor_parallel_size=tensor_parallel_size)
 
 class PromptRequest(BaseModel):
-    prompt: str
-    temperature: float = 1
+    prompt: str = "Llamas are cute animals"
+    temperature: float = 0.8
+    top_p: float = 1.0  # SamplingParams requires 0 < top_p <= 1
     max_tokens: int = 200
     stop: List[str] = []
 
 
 class ResponseModel(BaseModel):
     prompt: str
-    prompt_token_ids: List  # The token IDs of the prompt.
-    outputs: List[str]  # The output sequences of the request.
+    prompt_token_ids: Optional[List] = None  # The token IDs of the prompt.
+    text: str  # The generated text of the request.
     finished: bool  # Whether the whole request is finished.
 
 
-app = FastAPI()
+class ServeVLLM(FastServe):
+    def __init__(
+        self,
+        model_path=DEFAULT_MODEL,
+        batch_size=1,
+        timeout=0.0,
+        *args,
+        **kwargs,
+    ):
+        # Imported lazily so fastserve does not require vllm at import time.
+        from vllm import LLM
+
+        # model_path may be a local path or a HF model id; vLLM raises its
+        # own error if the model cannot be resolved.
+        self.llm = LLM(model_path)
+        self.model_path = model_path
+        self.args = args
+        self.kwargs = kwargs
+        super().__init__(
+            batch_size,
+            timeout,
+            input_schema=PromptRequest,
+            response_schema=ResponseModel,
+        )
+
+    def __call__(self, request: PromptRequest) -> Any:
+        from vllm import SamplingParams
+
+        sampling_params = SamplingParams(
+            temperature=request.temperature,
+            top_p=request.top_p,
+            max_tokens=request.max_tokens,
+            stop=request.stop,
+        )
+        # LLM.generate takes a prompt (or list of prompts) and returns a list
+        # of RequestOutput; we pass a single prompt, so take the first item.
+        output = self.llm.generate(request.prompt, sampling_params=sampling_params)[0]
+        logger.info(output)
+        return output
 
-@app.post("/serve", response_model=ResponseModel)
-def serve(request: PromptRequest):
-    sampling_params = SamplingParams(
-        max_tokens=request.max_tokens,
-        temperature=request.temperature,
-        stop=request.stop,
-    )
+    def handle(self, batch: List[PromptRequest]) -> List[ResponseModel]:
+        responses = []
+        for request in batch:
+            output = self(request)
+            response = ResponseModel(
+                prompt=request.prompt,
+                prompt_token_ids=output.prompt_token_ids,
+                text=output.outputs[0].text,
+                finished=output.finished,
+            )
+            responses.append(response)
 
-    result = llm.generate(request.prompt, sampling_params=sampling_params)[0]
-    response = ResponseModel(
-        prompt=request.prompt,
-        prompt_token_ids=result.prompt_token_ids,
-        outputs=result.outputs,
-        finished=result.finished,
-    )
-    return response
+        return responses
diff --git a/src/fastserve/utils.py b/src/fastserve/utils.py
index 55deeb6..32612d2 100644
--- a/src/fastserve/utils.py
+++ b/src/fastserve/utils.py
@@ -1,6 +1,6 @@
 import os
-from typing import Any
+from typing import Any, Optional
 
 from pydantic import BaseModel
@@ -23,3 +23,28 @@ def get_ui_folder():
     path = os.path.join(os.path.dirname(__file__), "../ui")
     path = os.path.abspath(path)
     return path
+
+
+def download_file(url: str, dest: Optional[str] = None) -> str:
+    import requests
+    from tqdm import tqdm
+
+    if dest is None:
+        # Save under the URL's file name in the current working directory.
+        dest = os.path.abspath(os.path.basename(url))
+
+    response = requests.get(url, stream=True)
+    response.raise_for_status()
+    total_size = int(response.headers.get('content-length', 0))
+    block_size = 1024
+    with open(dest, 'wb') as file, tqdm(
+        desc=dest,
+        total=total_size,
+        unit='iB',
+        unit_scale=True,
+        unit_divisor=1024,
+    ) as bar:
+        for data in response.iter_content(block_size):
+            file.write(data)
+            bar.update(len(data))
+    return dest
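
A minimal usage sketch for the new ServeVLLM class (illustration only, not
part of the patch). It assumes the FastServe base class exposes run_server(),
as the other model servers in this repo do, and that vllm is installed with a
GPU available:

    from fastserve.models.vllm import ServeVLLM

    # Start the batched serving loop; the model id is the default above and
    # can be swapped for any model vLLM can load.
    serve = ServeVLLM(model_path="meta-llama/Llama-2-7b-hf", batch_size=1)
    serve.run_server()  # then POST a PromptRequest JSON body to the endpoint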
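
A quick sketch of the new download_file helper as well; the URL here is
illustrative, not a real artifact:

    from fastserve.utils import download_file

    # Streams the response in 1024-byte blocks with a tqdm progress bar.
    # With dest omitted, the file is saved in the current directory under
    # the URL's basename and the absolute path is returned.
    path = download_file("https://example.com/models/tiny-model.bin")
    print(path)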