[CI/Build] Avoid downloading all HF files in RemoteOpenAIServer #7836

Merged: 5 commits, merged Aug 26, 2024
Changes from 1 commit
25 changes: 15 additions & 10 deletions tests/utils.py
@@ -11,13 +11,14 @@

import openai
import requests
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer
from typing_extensions import ParamSpec

from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.model_executor.model_loader.loader import DefaultModelLoader
from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser, get_open_port, is_hip

@@ -60,36 +61,40 @@ class RemoteOpenAIServer:

def __init__(self,
model: str,
cli_args: List[str],
serve_args: List[str],
Member:
why rename the arg name btw?

Member Author (@DarkLight1337, Aug 25, 2024):
I meant to associate them with vllm serve specifically, not just any CLI. Perhaps it could be clearer.

Edit: Updated the name to vllm_serve_args

*,
env_dict: Optional[Dict[str, str]] = None,
auto_port: bool = True,
max_wait_seconds: Optional[float] = None) -> None:
if not model.startswith("/"):
# download the model if it's not a local path
# to exclude the model download time from the server start time
model = snapshot_download(model)
if auto_port:
if "-p" in cli_args or "--port" in cli_args:
if "-p" in serve_args or "--port" in serve_args:
raise ValueError("You have manually specified the port"
"when `auto_port=True`.")

cli_args = cli_args + ["--port", str(get_open_port())]
serve_args = serve_args + ["--port", str(get_open_port())]

parser = FlexibleArgumentParser(
description="vLLM's remote OpenAI server.")
parser = make_arg_parser(parser)
args = parser.parse_args(cli_args)
Member:
I think you just need args = parser.parse_args(["--model", model, *cli_args])

args = parser.parse_args(["--model", model, *serve_args])
self.host = str(args.host or 'localhost')
self.port = int(args.port)

# download the model before starting the server to avoid timeout
engine_args = AsyncEngineArgs.from_cli_args(args)
engine_config = engine_args.create_engine_config()
Member:
can we directly create a load config object? it should be simple in my opinion. just load format auto.

Member Author:
Let's fix the tests first before simplifying this.

Member Author (@DarkLight1337, Aug 25, 2024):
I'm not that familiar with the model loading code so may need some help regarding fixing the tests.
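A minimal sketch of that suggestion (an editor's illustration, not part of this diff): assuming vllm.config.LoadConfig defaults to load_format="auto", the loader could be built without going through a full engine config.

    from vllm.config import LoadConfig
    from vllm.model_executor.model_loader.loader import DefaultModelLoader

    # Assumes LoadConfig() defaults to load_format="auto"; _prepare_weights then
    # fetches only the weight files for `model` rather than the whole HF repo.
    dummy_loader = DefaultModelLoader(LoadConfig())
    dummy_loader._prepare_weights(model, revision=None, fall_back_to_pt=True)

This mirrors the _prepare_weights call added below, minus the intermediate AsyncEngineArgs/engine config step.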

dummy_loader = DefaultModelLoader(engine_config.load_config)
dummy_loader._prepare_weights(engine_config.model_config.model,
engine_config.model_config.revision,
fall_back_to_pt=True)

env = os.environ.copy()
# the current process might initialize cuda,
# to be safe, we should use spawn method
env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
if env_dict is not None:
env.update(env_dict)
self.proc = subprocess.Popen(["vllm", "serve"] + [model] + cli_args,
self.proc = subprocess.Popen(["vllm", "serve", model, *serve_args],
env=env,
stdout=sys.stdout,
stderr=sys.stderr)
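For context, a hedged usage sketch of the fixture (the model name, serve args, and get_client helper are illustrative assumptions, not shown in this diff):

    from tests.utils import RemoteOpenAIServer  # import path assumed

    # Weights are pre-downloaded in __init__, so the startup wait below only
    # covers server initialization, not the HF download.
    with RemoteOpenAIServer("facebook/opt-125m", ["--dtype", "half"]) as server:
        client = server.get_client()
        print(client.models.list())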
2 changes: 1 addition & 1 deletion vllm/engine/arg_utils.py
@@ -742,7 +742,7 @@ def from_cli_args(cls, args: argparse.Namespace):
engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
return engine_args

def create_engine_config(self, ) -> EngineConfig:
def create_engine_config(self) -> EngineConfig:
# gguf file needs a specific model loader and doesn't use hf_repo
if self.model.endswith(".gguf"):
self.quantization = self.load_format = "gguf"