[CI/Build] Avoid downloading all HF files in RemoteOpenAIServer
#7836
Changes from 1 commit
Commits: 7f63693, e553a77, 4b0c805, c13996b, ffcd139
```diff
@@ -11,13 +11,14 @@
 import openai
 import requests
-from huggingface_hub import snapshot_download
 from transformers import AutoTokenizer
 from typing_extensions import ParamSpec
 
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
+from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.entrypoints.openai.cli_args import make_arg_parser
+from vllm.model_executor.model_loader.loader import DefaultModelLoader
 from vllm.platforms import current_platform
 from vllm.utils import FlexibleArgumentParser, get_open_port, is_hip
```
```diff
@@ -60,36 +61,40 @@ class RemoteOpenAIServer:
     def __init__(self,
                  model: str,
-                 cli_args: List[str],
+                 serve_args: List[str],
                  *,
                  env_dict: Optional[Dict[str, str]] = None,
                  auto_port: bool = True,
                  max_wait_seconds: Optional[float] = None) -> None:
-        if not model.startswith("/"):
-            # download the model if it's not a local path
-            # to exclude the model download time from the server start time
-            model = snapshot_download(model)
         if auto_port:
-            if "-p" in cli_args or "--port" in cli_args:
+            if "-p" in serve_args or "--port" in serve_args:
                 raise ValueError("You have manually specified the port"
                                  "when `auto_port=True`.")
 
-            cli_args = cli_args + ["--port", str(get_open_port())]
+            serve_args = serve_args + ["--port", str(get_open_port())]
 
         parser = FlexibleArgumentParser(
             description="vLLM's remote OpenAI server.")
         parser = make_arg_parser(parser)
-        args = parser.parse_args(cli_args)
```

**Review comment:** I think you just need
```diff
+        args = parser.parse_args(["--model", model, *serve_args])
         self.host = str(args.host or 'localhost')
         self.port = int(args.port)
 
+        # download the model before starting the server to avoid timeout
+        engine_args = AsyncEngineArgs.from_cli_args(args)
+        engine_config = engine_args.create_engine_config()
```
**Review comment:** Can we directly create a load config object? It should be simple, in my opinion: just load format auto.

**Reply:** Let's fix the tests first before simplifying this.

**Reply:** I'm not that familiar with the model loading code, so I may need some help with fixing the tests.
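As a rough illustration of the reviewer's suggestion, the load config could be built directly instead of creating a full engine config. This is only a sketch: it assumes `vllm.config.LoadConfig` with its default `auto` load format, `facebook/opt-125m` stands in for whatever model the test uses, and constructor details may differ between vLLM versions.

```python
# Sketch of the suggested simplification (see assumptions above):
# construct a minimal LoadConfig rather than a full engine config.
from vllm.config import LoadConfig
from vllm.model_executor.model_loader.loader import DefaultModelLoader

model = "facebook/opt-125m"  # illustrative model name

load_config = LoadConfig(load_format="auto")
loader = DefaultModelLoader(load_config)
# Fetch only the files vLLM will actually load, not the whole HF snapshot.
loader._prepare_weights(model, revision=None, fall_back_to_pt=True)
```

The change as committed keeps the engine-config path for now, per the follow-up comments about fixing the tests first.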
```diff
+        dummy_loader = DefaultModelLoader(engine_config.load_config)
+        dummy_loader._prepare_weights(engine_config.model_config.model,
+                                      engine_config.model_config.revision,
+                                      fall_back_to_pt=True)
 
         env = os.environ.copy()
         # the current process might initialize cuda,
         # to be safe, we should use spawn method
         env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
         if env_dict is not None:
             env.update(env_dict)
-        self.proc = subprocess.Popen(["vllm", "serve"] + [model] + cli_args,
+        self.proc = subprocess.Popen(["vllm", "serve", model, *serve_args],
                                      env=env,
                                      stdout=sys.stdout,
                                      stderr=sys.stderr)
```
**Review comment:** Why rename the argument, by the way?

**Reply:** I meant to associate the arguments with `vllm serve` specifically, not just any CLI. Perhaps it could be clearer.

Edit: Updated the name to `vllm_serve_args`.
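For context on how the renamed argument is used, a test would typically drive the helper roughly as below. This is a hedged sketch: it assumes the context-manager support and `get_client()` helper that `tests/utils.py` provides, and the model name and serve flags are only illustrative.

```python
from tests.utils import RemoteOpenAIServer

model = "facebook/opt-125m"           # illustrative test model
serve_args = ["--dtype", "bfloat16"]  # forwarded to `vllm serve <model> ...`

# Starting the server downloads only the needed weight files up front,
# so the startup wait does not include a full HF snapshot download.
with RemoteOpenAIServer(model, serve_args) as server:
    client = server.get_client()  # OpenAI-compatible client pointed at the server
    completion = client.completions.create(model=model,
                                           prompt="Hello, my name is",
                                           max_tokens=5)
    print(completion.choices[0].text)
```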