add vLLM #21

Merged · 8 commits · Feb 23, 2024
Changes from all commits
README.md (26 changes: 25 additions & 1 deletion)
@@ -26,7 +26,8 @@ python -m fastserve
 
 ## Usage/Examples
 
-### Serve Mistral-7B with Llama-cpp
+
+### Serve LLMs with Llama-cpp
 
 ```python
 from fastserve.models import ServeLlamaCpp
@@ -38,6 +39,29 @@ serve.run_server()
 ```
 
 or, run `python -m fastserve.models --model llama-cpp --model_path openhermes-2-mistral-7b.Q5_K_M.gguf` from terminal.
 
+
+### Serve vLLM
+
+```python
+from fastserve.models import ServeVLLM
+
+app = ServeVLLM("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
+app.run_server()
+```
+
+You can use the FastServe client that will automatically apply the chat template for you -
+
+```python
+from fastserve.client import vLLMClient
+from rich import print
+
+client = vLLMClient("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
+response = client.chat("Write a python function to resize image to 224x224", keep_context=True)
+# print(client.context)
+print(response["outputs"][0]["text"])
+```
+
+
 ### Serve SDXL Turbo
 
 ```python
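The README's note that the client automatically applies the chat template refers to the `tokenizer.apply_chat_template` call in `src/fastserve/client/vllm.py` below. As a rough sketch of what that step does, assuming only that `transformers` is installed and reusing the TinyLlama checkpoint already named in the README:

```python
# Sketch only: reproduces the prompt formatting the client performs before
# posting to the server; it is an illustration, not part of this PR's diff.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
messages = [{"role": "user", "content": "Write a python function to resize image to 224x224"}]

# Renders the message list into the single prompt string sent as "prompt";
# the exact role tags depend on the model's chat template.
print(tokenizer.apply_chat_template(messages, tokenize=False))
```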
Empty file.
src/fastserve/client/vllm.py (47 changes: 47 additions & 0 deletions)
@@ -0,0 +1,47 @@
import logging

import requests


class Client:
    def __init__(self):
        pass


class vLLMClient(Client):
    def __init__(self, model: str, base_url="http://localhost:8000/endpoint"):
        from transformers import AutoTokenizer

        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model)
        self.context = []
        self.base_url = base_url

    def chat(self, prompt: str, keep_context=False):
        new_msg = {"role": "user", "content": prompt}
        if keep_context:
            self.context.append(new_msg)
            messages = self.context
        else:
            messages = [new_msg]

        logging.info(messages)
        chat = self.tokenizer.apply_chat_template(messages, tokenize=False)
        headers = {
            "accept": "application/json",
            "Content-Type": "application/json",
        }
        data = {
            "prompt": chat,
            "temperature": 0.8,
            "top_p": 1,
            "max_tokens": 500,
            "stop": [],
        }

        response = requests.post(self.base_url, headers=headers, json=data).json()
        if keep_context:
            self.context.append(
                {"role": "assistant", "content": response["outputs"][0]["text"]}
            )
        return response
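The client above is a thin wrapper around a single POST request to the serving endpoint. A minimal sketch of the equivalent raw call, assuming a FastServe vLLM server is running locally on the default `http://localhost:8000/endpoint` route used by the client:

```python
# Sketch of the HTTP call vLLMClient.chat() issues (minus the chat templating).
import requests

payload = {
    "prompt": "Write a python function to resize image to 224x224",
    "temperature": 0.8,
    "top_p": 1,
    "max_tokens": 500,
    "stop": [],
}
response = requests.post(
    "http://localhost:8000/endpoint",
    headers={"accept": "application/json", "Content-Type": "application/json"},
    json=payload,
).json()
print(response["outputs"][0]["text"])
```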
src/fastserve/models/__init__.py (1 change: 1 addition & 0 deletions)
@@ -5,3 +5,4 @@
 from fastserve.models.llama_cpp import ServeLlamaCpp as ServeLlamaCpp
 from fastserve.models.sdxl_turbo import ServeSDXLTurbo as ServeSDXLTurbo
 from fastserve.models.ssd import ServeSSD1B as ServeSSD1B
+from fastserve.models.vllm import ServeVLLM as ServeVLLM
src/fastserve/models/vllm.py (73 changes: 46 additions & 27 deletions)
@@ -1,46 +1,65 @@
-import os
-from typing import List
+import logging
+from typing import Any, List, Optional
 
-from fastapi import FastAPI
 from pydantic import BaseModel
-from vllm import LLM, SamplingParams
 
-tensor_parallel_size = int(os.environ.get("DEVICES", "1"))
-print("tensor_parallel_size: ", tensor_parallel_size)
+from fastserve.core import FastServe
 
-llm = LLM("meta-llama/Llama-2-7b-hf", tensor_parallel_size=tensor_parallel_size)
+logger = logging.getLogger(__name__)
 
 
 class PromptRequest(BaseModel):
-    prompt: str
-    temperature: float = 1
+    prompt: str = "Write a python function to resize image to 224x224"
+    temperature: float = 0.8
     top_p: float = 1.0
     max_tokens: int = 200
     stop: List[str] = []
 
 
 class ResponseModel(BaseModel):
     prompt: str
-    prompt_token_ids: List  # The token IDs of the prompt.
-    outputs: List[str]  # The output sequences of the request.
+    prompt_token_ids: Optional[List] = None  # The token IDs of the prompt.
+    text: str  # The output sequences of the request.
     finished: bool  # Whether the whole request is finished.
 
 
-app = FastAPI()
+class ServeVLLM(FastServe):
+    def __init__(
+        self,
+        model,
+        batch_size=1,
+        timeout=0.0,
+        *args,
+        **kwargs,
+    ):
+        from vllm import LLM
+
+        self.llm = LLM(model)
+        self.args = args
+        self.kwargs = kwargs
+        super().__init__(
+            batch_size,
+            timeout,
+            input_schema=PromptRequest,
+            # response_schema=ResponseModel,
+        )
+
+    def __call__(self, request: PromptRequest) -> Any:
+        from vllm import SamplingParams
+
+        sampling_params = SamplingParams(
+            temperature=request.temperature,
+            top_p=request.top_p,
+            max_tokens=request.max_tokens,
+        )
+        result = self.llm.generate(request.prompt, sampling_params=sampling_params)
+        logger.info(result)
+        return result
 
-@app.post("/serve", response_model=ResponseModel)
-def serve(request: PromptRequest):
-    sampling_params = SamplingParams(
-        max_tokens=request.max_tokens,
-        temperature=request.temperature,
-        stop=request.stop,
-    )
+    def handle(self, batch: List[PromptRequest]) -> List:
+        responses = []
+        for request in batch:
+            output = self(request)
+            responses.extend(output)
 
-    result = llm.generate(request.prompt, sampling_params=sampling_params)[0]
-    response = ResponseModel(
-        prompt=request.prompt,
-        prompt_token_ids=result.prompt_token_ids,
-        outputs=result.outputs,
-        finished=result.finished,
-    )
-    return response
+        return responses
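For context, a minimal sketch of how the new wrapper is exercised, assuming vLLM is installed and reusing the TinyLlama checkpoint from the README changes above; the prompt and token budget are placeholders:

```python
# Sketch only: runs the wrapper in-process rather than behind the HTTP server.
from fastserve.models import ServeVLLM
from fastserve.models.vllm import PromptRequest

app = ServeVLLM("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

# handle() loops over the batch and returns vLLM RequestOutput objects.
outputs = app.handle([PromptRequest(prompt="Write a haiku about GPUs", max_tokens=64)])
print(outputs[0].outputs[0].text)

# app.run_server() would instead expose the same logic over HTTP for vLLMClient.
```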
src/fastserve/utils.py (24 changes: 24 additions & 0 deletions)
@@ -23,3 +23,27 @@ def get_ui_folder():
     path = os.path.join(os.path.dirname(__file__), "../ui")
     path = os.path.abspath(path)
     return path
+
+
+def download_file(url: str, dest: str):
+    import requests
+    from tqdm import tqdm
+
+    if dest is None:
+        # derive the filename from the URL when no destination is given
+        dest = os.path.abspath(os.path.basename(url))
+
+    response = requests.get(url, stream=True)
+    response.raise_for_status()
+    total_size = int(response.headers.get("content-length", 0))
+    block_size = 1024
+    with open(dest, "wb") as file, tqdm(
+        desc=dest,
+        total=total_size,
+        unit="iB",
+        unit_scale=True,
+        unit_divisor=1024,
+    ) as bar:
+        for data in response.iter_content(block_size):
+            file.write(data)
+            bar.update(len(data))
+    return dest
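The new `download_file` helper streams a file to disk with a tqdm progress bar. A short usage sketch; the URL here is a placeholder, not an asset referenced by this PR:

```python
from fastserve.utils import download_file

# Hypothetical URL; substitute a real artifact such as a GGUF checkpoint.
path = download_file(
    "https://example.com/openhermes-2-mistral-7b.Q5_K_M.gguf",
    "openhermes-2-mistral-7b.Q5_K_M.gguf",
)
print(path)
```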