From e2bda4c4b92795e871d4861549d6fb4446ee7988 Mon Sep 17 00:00:00 2001 From: notsyncing Date: Sun, 1 Sep 2024 22:22:43 +0800 Subject: [PATCH] Add initial integration tests. --- pyproject.toml | 4 + src/azarrot/config.py | 8 +- src/azarrot/models/model_manager.py | 9 +- src/azarrot/server.py | 67 ++++++++-- tests/{models => integration}/__init__.py | 0 .../ipex_llm}/__init__.py | 0 tests/integration/ipex_llm/conftest.py | 38 ++++++ tests/integration/ipex_llm/test_qwen2.py | 121 ++++++++++++++++++ tests/unit/models/__init__.py | 0 tests/unit/models/supports/__init__.py | 0 .../supports/test_qwen2_chat_support.py | 0 tox.ini | 3 +- 12 files changed, 226 insertions(+), 24 deletions(-) rename tests/{models => integration}/__init__.py (100%) rename tests/{models/supports => integration/ipex_llm}/__init__.py (100%) create mode 100644 tests/integration/ipex_llm/conftest.py create mode 100644 tests/integration/ipex_llm/test_qwen2.py create mode 100644 tests/unit/models/__init__.py create mode 100644 tests/unit/models/supports/__init__.py rename tests/{ => unit}/models/supports/test_qwen2_chat_support.py (100%) diff --git a/pyproject.toml b/pyproject.toml index 4923910..e8156ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,6 +56,7 @@ test = [ "pytest==7.4.1", "pytest-cov==4.1.0", "coverage[toml]==7.3.1", + "openai==1.43.0" ] doc = [ "sphinx", @@ -117,6 +118,7 @@ lint.ignore = [ [tool.ruff.lint.per-file-ignores] "tests/**" = [ "S101", # Use of `assert` detected + "RUF001", ] "**/__init__.py" = [ "D104", @@ -155,6 +157,8 @@ addopts = """ --cov-config=pyproject.toml --cov-report= """ +log_cli = true +log_cli_level = "INFO" [tool.coverage.paths] # Maps coverage measured in site-packages to source files in src diff --git a/src/azarrot/config.py b/src/azarrot/config.py index 5e15724..2d50093 100644 --- a/src/azarrot/config.py +++ b/src/azarrot/config.py @@ -6,10 +6,10 @@ @dataclass class ServerConfig: - models_dir = Path("./models") - working_dir = Path("./working") - host = "127.0.0.1" - port = 8080 + models_dir: Path = Path("./models") + working_dir: Path = Path("./working") + host: str = "127.0.0.1" + port: int = 8080 model_device_map: dict[str, str] = field(default_factory=dict) single_token_generation_timeout: int = 60000 diff --git a/src/azarrot/models/model_manager.py b/src/azarrot/models/model_manager.py index 9b037c8..dc72b3e 100644 --- a/src/azarrot/models/model_manager.py +++ b/src/azarrot/models/model_manager.py @@ -144,26 +144,27 @@ def get_models(self) -> list[Model]: def get_model(self, model_id: str) -> Model | None: return self._models.get(model_id) - def load_huggingface_model(self, huggingface_id: str, backend_id: str, for_task: str) -> None: + def load_huggingface_model(self, huggingface_id: str, backend_id: str, for_task: str, skip_if_loaded=False) -> None: backend = self._backends.get(backend_id) if backend is None: raise ValueError(f"Unknown backend {backend_id}") - if huggingface_id in self._models: + if huggingface_id in self._models and not skip_if_loaded: raise ValueError(f"Model {huggingface_id} from huggingface is already loaded!") self._log.info("Downloading model %s from huggingface...", huggingface_id) model_path = Path(huggingface_hub.snapshot_download(huggingface_id)) + model_generation_variant = self.__determine_model_generation_variant(model_path) model = Model( id=huggingface_id, backend=backend_id, path=model_path, task=for_task, - generation_variant=self.__determine_model_generation_variant(model_path), - preset=DEFAULT_MODEL_PRESET, + generation_variant=model_generation_variant, + preset=DEFAULT_MODEL_PRESETS.get(model_generation_variant, DEFAULT_MODEL_PRESET), ipex_llm=None, create_time=datetime.now() ) diff --git a/src/azarrot/server.py b/src/azarrot/server.py index 6ea0a0e..86ddc01 100644 --- a/src/azarrot/server.py +++ b/src/azarrot/server.py @@ -1,12 +1,13 @@ import argparse import logging +from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING import uvicorn import yaml from fastapi import FastAPI +from azarrot.backends.backend_base import BaseBackend from azarrot.backends.ipex_llm_backend import IPEXLLMBackend from azarrot.backends.openvino_backend import OpenVINOBackend from azarrot.common_data import WorkingDirectories @@ -17,9 +18,6 @@ from azarrot.models.model_manager import ModelManager from azarrot.tools import GLOBAL_TOOL_MANAGER -if TYPE_CHECKING: - from azarrot.backends.backend_base import BaseBackend - log = logging.getLogger(__name__) @@ -113,12 +111,29 @@ def __create_working_directories(config: ServerConfig) -> WorkingDirectories: return WorkingDirectories(root=config.working_dir, uploaded_images=image_temp_path) -def main() -> None: - logging.basicConfig(level=logging.INFO) +@dataclass +class Server: + config: ServerConfig + model_manager: ModelManager + backend_pipe: BackendPipe + backends: list[BaseBackend] + frontends: list[OpenAIFrontend] + api: FastAPI + + def start(self) -> None: + log.info("Starting API server...") + uvicorn.run(self.api, host=self.config.host, port=self.config.port) + +def create_server( + config: ServerConfig | None = None, + enable_backends: list[type[BaseBackend]] | None = None +) -> Server: log.info("Azarrot is initializing...") - config = __parse_arguments_and_load_config() + if config is None: + config = __parse_arguments_and_load_config() + log.info("Current config:") for attr in dir(config): @@ -129,19 +144,43 @@ def main() -> None: chat_template_manager = ChatTemplateManager(GLOBAL_TOOL_MANAGER) - backends: list[BaseBackend] = [ - IPEXLLMBackend(config), - OpenVINOBackend(config), - ] + backends: list[BaseBackend] + + if enable_backends is not None: + backends = [b(config) for b in enable_backends] + else: + backends = [ + IPEXLLMBackend(config), + OpenVINOBackend(config), + ] + + log.info("Enabled backends: %s", backends) model_manager = ModelManager(config, backends) backend_pipe = BackendPipe(backends, chat_template_manager, GLOBAL_TOOL_MANAGER) - log.info("Starting API server...") api = FastAPI() - OpenAIFrontend(model_manager, backend_pipe, api, working_dirs) - uvicorn.run(api, host=config.host, port=config.port) + + frontends = [ + OpenAIFrontend(model_manager, backend_pipe, api, working_dirs) + ] + + return Server( + config=config, + model_manager=model_manager, + backend_pipe=backend_pipe, + backends=backends, + frontends=frontends, + api=api + ) + + +def main() -> None: + logging.basicConfig(level=logging.INFO) + + server = create_server() + server.start() if __name__ == "__main__": diff --git a/tests/models/__init__.py b/tests/integration/__init__.py similarity index 100% rename from tests/models/__init__.py rename to tests/integration/__init__.py diff --git a/tests/models/supports/__init__.py b/tests/integration/ipex_llm/__init__.py similarity index 100% rename from tests/models/supports/__init__.py rename to tests/integration/ipex_llm/__init__.py diff --git a/tests/integration/ipex_llm/conftest.py b/tests/integration/ipex_llm/conftest.py new file mode 100644 index 0000000..e0c3e01 --- /dev/null +++ b/tests/integration/ipex_llm/conftest.py @@ -0,0 +1,38 @@ +import logging +import tempfile +import time +from collections.abc import Generator +from pathlib import Path +from threading import Thread +from typing import Any + +import pytest + +from azarrot.backends.ipex_llm_backend import IPEXLLMBackend +from azarrot.config import ServerConfig +from azarrot.server import Server, create_server + + +@pytest.fixture(scope="module") +def ipex_llm_server() -> Generator[Server, Any, Any]: + logging.basicConfig(level=logging.INFO) + + tmp_dir = tempfile.TemporaryDirectory() + tmp_path = Path(tmp_dir.name).absolute() + + server = create_server( + config=ServerConfig( + models_dir=tmp_path / "models", + working_dir=tmp_path / "working" + ), + enable_backends=[IPEXLLMBackend] + ) + + thread = Thread(target=server.start, daemon=True) + thread.start() + + time.sleep(5) + + yield server + + tmp_dir.cleanup() diff --git a/tests/integration/ipex_llm/test_qwen2.py b/tests/integration/ipex_llm/test_qwen2.py new file mode 100644 index 0000000..4b9661e --- /dev/null +++ b/tests/integration/ipex_llm/test_qwen2.py @@ -0,0 +1,121 @@ +from openai import OpenAI + +from azarrot.backends.ipex_llm_backend import BACKEND_ID_IPEX_LLM +from azarrot.server import Server + +QWEN2_CHAT_MODEL = "Qwen/Qwen2-7B-Instruct" + + +def test_qwen2_hello(ipex_llm_server: Server) -> None: + ipex_llm_server.model_manager.load_huggingface_model( + QWEN2_CHAT_MODEL, BACKEND_ID_IPEX_LLM, "text-generation", + skip_if_loaded=True + ) + + client = OpenAI( + base_url=f"http://{ipex_llm_server.config.host}:{ipex_llm_server.config.port}/v1", + api_key="__TEST__" + ) + + completion = client.chat.completions.create( + model=QWEN2_CHAT_MODEL, + messages=[ + {"role": "system", "content": "你是一个乐于助人的智能助理。"}, + {"role": "user", "content": "你好!"} + ], + seed=100 + ) + + result = completion.choices[0].message + assert result is not None + assert result.content is not None + assert result.content.find("你好!有什么问题我可以帮助你解答吗?") > 0 + + +def test_qwen2_tool_calling(ipex_llm_server: Server) -> None: + ipex_llm_server.model_manager.load_huggingface_model( + QWEN2_CHAT_MODEL, BACKEND_ID_IPEX_LLM, "text-generation", + skip_if_loaded=True + ) + + client = OpenAI( + base_url=f"http://{ipex_llm_server.config.host}:{ipex_llm_server.config.port}/v1", + api_key="__TEST__" + ) + + tools = [ + { + "type": "function", + "function": { + "description": "用于把两个数进行RRR运算的工具", + "name": "rrr-calc", + "parameters": { + "type": "object", + "properties": { + "a": { + "type": "number", + "description": "第一个数" + }, + "b": { + "type": "number", + "description": "第二个数" + } + }, + "required": ["a", "b"] + } + } + } + ] + + messages = [ + {"role": "user", "content": "193与27的RRR运算结果是多少?"} + ] + + completion = client.chat.completions.create( + model=QWEN2_CHAT_MODEL, + messages=messages, # pyright: ignore[reportArgumentType] + tools=tools, # pyright: ignore[reportArgumentType] + seed=100 + ) + + result = completion.choices[0].message + assert result is not None + assert result.content is None + assert result.tool_calls is not None + assert len(result.tool_calls) == 1 + + tool_call = result.tool_calls[0] + assert tool_call.function.name == "rrr-calc" + assert tool_call.function.arguments == '{"a": 193, "b": 27}' + + messages.append({ + "role": "assistant", + "tool_calls": [ # pyright: ignore[reportArgumentType] + { + "id": tool_call.id, + "type": "function", + "function": { + "name": tool_call.function.name, + "arguments": tool_call.function.arguments + } + } + ] + }) + + messages.append({ + "role": "tool", + "content": "888", + "tool_call_id": tool_call.id + }) + + completion = client.chat.completions.create( + model=QWEN2_CHAT_MODEL, + messages=messages, # pyright: ignore[reportArgumentType] + tools=tools, # pyright: ignore[reportArgumentType] + seed=100 + ) + + result = completion.choices[0].message + assert result is not None + assert result.content is not None + assert result.content.find("888") > 0 diff --git a/tests/unit/models/__init__.py b/tests/unit/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/models/supports/__init__.py b/tests/unit/models/supports/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/models/supports/test_qwen2_chat_support.py b/tests/unit/models/supports/test_qwen2_chat_support.py similarity index 100% rename from tests/models/supports/test_qwen2_chat_support.py rename to tests/unit/models/supports/test_qwen2_chat_support.py diff --git a/tox.ini b/tox.ini index 1d56ffd..5642c78 100644 --- a/tox.ini +++ b/tox.ini @@ -27,7 +27,7 @@ setenv = COVERAGE_FILE = reports{/}.coverage.{envname} commands = # Run tests and doctests from .py files - pytest --junitxml=reports/pytest.xml.{envname} {posargs} + pytest --junitxml=reports/pytest.xml.{envname} tests/unit {posargs} [testenv:combine-test-reports] @@ -41,7 +41,6 @@ deps = coverage[toml] commands = junitparser merge --glob reports/pytest.xml.* reports/pytest.xml - coverage combine --keep coverage html