Add initial integration tests.
notsyncing committed Sep 1, 2024
1 parent 4dd8d40 commit e2bda4c
Showing 12 changed files with 226 additions and 24 deletions.
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -56,6 +56,7 @@ test = [
    "pytest==7.4.1",
    "pytest-cov==4.1.0",
    "coverage[toml]==7.3.1",
    "openai==1.43.0"
]
doc = [
    "sphinx",
@@ -117,6 +118,7 @@ lint.ignore = [
[tool.ruff.lint.per-file-ignores]
"tests/**" = [
    "S101", # Use of `assert` detected
    "RUF001",
]
"**/__init__.py" = [
    "D104",
@@ -155,6 +157,8 @@ addopts = """
    --cov-config=pyproject.toml
    --cov-report=
"""
log_cli = true
log_cli_level = "INFO"

[tool.coverage.paths]
# Maps coverage measured in site-packages to source files in src
8 changes: 4 additions & 4 deletions src/azarrot/config.py
@@ -6,10 +6,10 @@

@dataclass
class ServerConfig:
    models_dir = Path("./models")
    working_dir = Path("./working")
    host = "127.0.0.1"
    port = 8080
    models_dir: Path = Path("./models")
    working_dir: Path = Path("./working")
    host: str = "127.0.0.1"
    port: int = 8080

    model_device_map: dict[str, str] = field(default_factory=dict)
    single_token_generation_timeout: int = 60000
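Before this change the ServerConfig attributes had no type annotations, so @dataclass did not treat them as fields and they could not be overridden through the constructor. With the annotations in place, callers can pass individual overrides, which is what the test fixture below relies on. A minimal sketch (the paths are illustrative):

from pathlib import Path

from azarrot.config import ServerConfig

# Override only the directories; host and port keep their defaults.
config = ServerConfig(
    models_dir=Path("/tmp/azarrot/models"),
    working_dir=Path("/tmp/azarrot/working"),
)
assert config.host == "127.0.0.1"
assert config.port == 8080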
9 changes: 5 additions & 4 deletions src/azarrot/models/model_manager.py
@@ -144,26 +144,27 @@ def get_models(self) -> list[Model]:
    def get_model(self, model_id: str) -> Model | None:
        return self._models.get(model_id)

    def load_huggingface_model(self, huggingface_id: str, backend_id: str, for_task: str) -> None:
    def load_huggingface_model(self, huggingface_id: str, backend_id: str, for_task: str, skip_if_loaded: bool = False) -> None:
        backend = self._backends.get(backend_id)

        if backend is None:
            raise ValueError(f"Unknown backend {backend_id}")

        if huggingface_id in self._models:
        if huggingface_id in self._models and not skip_if_loaded:
            raise ValueError(f"Model {huggingface_id} from huggingface is already loaded!")

        self._log.info("Downloading model %s from huggingface...", huggingface_id)

        model_path = Path(huggingface_hub.snapshot_download(huggingface_id))
        model_generation_variant = self.__determine_model_generation_variant(model_path)

        model = Model(
            id=huggingface_id,
            backend=backend_id,
            path=model_path,
            task=for_task,
            generation_variant=self.__determine_model_generation_variant(model_path),
            preset=DEFAULT_MODEL_PRESET,
            generation_variant=model_generation_variant,
            preset=DEFAULT_MODEL_PRESETS.get(model_generation_variant, DEFAULT_MODEL_PRESET),
            ipex_llm=None,
            create_time=datetime.now()
        )
67 changes: 53 additions & 14 deletions src/azarrot/server.py
@@ -1,12 +1,13 @@
import argparse
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING

import uvicorn
import yaml
from fastapi import FastAPI

from azarrot.backends.backend_base import BaseBackend
from azarrot.backends.ipex_llm_backend import IPEXLLMBackend
from azarrot.backends.openvino_backend import OpenVINOBackend
from azarrot.common_data import WorkingDirectories
@@ -17,9 +18,6 @@
from azarrot.models.model_manager import ModelManager
from azarrot.tools import GLOBAL_TOOL_MANAGER

if TYPE_CHECKING:
    from azarrot.backends.backend_base import BaseBackend

log = logging.getLogger(__name__)


@@ -113,12 +111,29 @@ def __create_working_directories(config: ServerConfig) -> WorkingDirectories:
    return WorkingDirectories(root=config.working_dir, uploaded_images=image_temp_path)


def main() -> None:
    logging.basicConfig(level=logging.INFO)
@dataclass
class Server:
    config: ServerConfig
    model_manager: ModelManager
    backend_pipe: BackendPipe
    backends: list[BaseBackend]
    frontends: list[OpenAIFrontend]
    api: FastAPI

    def start(self) -> None:
        log.info("Starting API server...")
        uvicorn.run(self.api, host=self.config.host, port=self.config.port)


def create_server(
    config: ServerConfig | None = None,
    enable_backends: list[type[BaseBackend]] | None = None
) -> Server:
    log.info("Azarrot is initializing...")

    config = __parse_arguments_and_load_config()
    if config is None:
        config = __parse_arguments_and_load_config()

    log.info("Current config:")

    for attr in dir(config):
@@ -129,19 +144,43 @@ def main() -> None:

    chat_template_manager = ChatTemplateManager(GLOBAL_TOOL_MANAGER)

    backends: list[BaseBackend] = [
        IPEXLLMBackend(config),
        OpenVINOBackend(config),
    ]
    backends: list[BaseBackend]

    if enable_backends is not None:
        backends = [b(config) for b in enable_backends]
    else:
        backends = [
            IPEXLLMBackend(config),
            OpenVINOBackend(config),
        ]

    log.info("Enabled backends: %s", backends)

    model_manager = ModelManager(config, backends)

    backend_pipe = BackendPipe(backends, chat_template_manager, GLOBAL_TOOL_MANAGER)

    log.info("Starting API server...")
    api = FastAPI()
    OpenAIFrontend(model_manager, backend_pipe, api, working_dirs)
    uvicorn.run(api, host=config.host, port=config.port)

    frontends = [
        OpenAIFrontend(model_manager, backend_pipe, api, working_dirs)
    ]

    return Server(
        config=config,
        model_manager=model_manager,
        backend_pipe=backend_pipe,
        backends=backends,
        frontends=frontends,
        api=api
    )


def main() -> None:
    logging.basicConfig(level=logging.INFO)

    server = create_server()
    server.start()


if __name__ == "__main__":
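Splitting the old main() into create_server() and Server.start() is what lets the integration tests build a server in-process with a chosen backend and run it on a background thread. A minimal sketch of the new entry points (the OpenVINO backend and the port override are illustrative choices, not part of this commit's tests):

from azarrot.backends.openvino_backend import OpenVINOBackend
from azarrot.config import ServerConfig
from azarrot.server import create_server

# Build a server with a single backend instead of the default IPEX-LLM + OpenVINO pair.
server = create_server(
    config=ServerConfig(port=8081),
    enable_backends=[OpenVINOBackend],
)
server.start()  # blocking; the test fixture below runs this on a daemon thread instead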
File renamed without changes.
File renamed without changes.
38 changes: 38 additions & 0 deletions tests/integration/ipex_llm/conftest.py
@@ -0,0 +1,38 @@
import logging
import tempfile
import time
from collections.abc import Generator
from pathlib import Path
from threading import Thread
from typing import Any

import pytest

from azarrot.backends.ipex_llm_backend import IPEXLLMBackend
from azarrot.config import ServerConfig
from azarrot.server import Server, create_server


@pytest.fixture(scope="module")
def ipex_llm_server() -> Generator[Server, Any, Any]:
    logging.basicConfig(level=logging.INFO)

    tmp_dir = tempfile.TemporaryDirectory()
    tmp_path = Path(tmp_dir.name).absolute()

    server = create_server(
        config=ServerConfig(
            models_dir=tmp_path / "models",
            working_dir=tmp_path / "working"
        ),
        enable_backends=[IPEXLLMBackend]
    )

    thread = Thread(target=server.start, daemon=True)
    thread.start()

    time.sleep(5)

    yield server

    tmp_dir.cleanup()
121 changes: 121 additions & 0 deletions tests/integration/ipex_llm/test_qwen2.py
@@ -0,0 +1,121 @@
from openai import OpenAI

from azarrot.backends.ipex_llm_backend import BACKEND_ID_IPEX_LLM
from azarrot.server import Server

QWEN2_CHAT_MODEL = "Qwen/Qwen2-7B-Instruct"


def test_qwen2_hello(ipex_llm_server: Server) -> None:
    ipex_llm_server.model_manager.load_huggingface_model(
        QWEN2_CHAT_MODEL, BACKEND_ID_IPEX_LLM, "text-generation",
        skip_if_loaded=True
    )

    client = OpenAI(
        base_url=f"http://{ipex_llm_server.config.host}:{ipex_llm_server.config.port}/v1",
        api_key="__TEST__"
    )

    completion = client.chat.completions.create(
        model=QWEN2_CHAT_MODEL,
        messages=[
            {"role": "system", "content": "你是一个乐于助人的智能助理。"},
            {"role": "user", "content": "你好!"}
        ],
        seed=100
    )

    result = completion.choices[0].message
    assert result is not None
    assert result.content is not None
    assert result.content.find("你好!有什么问题我可以帮助你解答吗?") > 0


def test_qwen2_tool_calling(ipex_llm_server: Server) -> None:
    ipex_llm_server.model_manager.load_huggingface_model(
        QWEN2_CHAT_MODEL, BACKEND_ID_IPEX_LLM, "text-generation",
        skip_if_loaded=True
    )

    client = OpenAI(
        base_url=f"http://{ipex_llm_server.config.host}:{ipex_llm_server.config.port}/v1",
        api_key="__TEST__"
    )

    tools = [
        {
            "type": "function",
            "function": {
                "description": "用于把两个数进行RRR运算的工具",
                "name": "rrr-calc",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "a": {
                            "type": "number",
                            "description": "第一个数"
                        },
                        "b": {
                            "type": "number",
                            "description": "第二个数"
                        }
                    },
                    "required": ["a", "b"]
                }
            }
        }
    ]

    messages = [
        {"role": "user", "content": "193与27的RRR运算结果是多少?"}
    ]

    completion = client.chat.completions.create(
        model=QWEN2_CHAT_MODEL,
        messages=messages, # pyright: ignore[reportArgumentType]
        tools=tools, # pyright: ignore[reportArgumentType]
        seed=100
    )

    result = completion.choices[0].message
    assert result is not None
    assert result.content is None
    assert result.tool_calls is not None
    assert len(result.tool_calls) == 1

    tool_call = result.tool_calls[0]
    assert tool_call.function.name == "rrr-calc"
    assert tool_call.function.arguments == '{"a": 193, "b": 27}'

    messages.append({
        "role": "assistant",
        "tool_calls": [ # pyright: ignore[reportArgumentType]
            {
                "id": tool_call.id,
                "type": "function",
                "function": {
                    "name": tool_call.function.name,
                    "arguments": tool_call.function.arguments
                }
            }
        ]
    })

    messages.append({
        "role": "tool",
        "content": "888",
        "tool_call_id": tool_call.id
    })

    completion = client.chat.completions.create(
        model=QWEN2_CHAT_MODEL,
        messages=messages, # pyright: ignore[reportArgumentType]
        tools=tools, # pyright: ignore[reportArgumentType]
        seed=100
    )

    result = completion.choices[0].message
    assert result is not None
    assert result.content is not None
    assert result.content.find("888") > 0
Empty file added tests/unit/models/__init__.py
Empty file.
Empty file.
3 changes: 1 addition & 2 deletions tox.ini
@@ -27,7 +27,7 @@ setenv =
    COVERAGE_FILE = reports{/}.coverage.{envname}
commands =
    # Run tests and doctests from .py files
    pytest --junitxml=reports/pytest.xml.{envname} {posargs}
    pytest --junitxml=reports/pytest.xml.{envname} tests/unit {posargs}


[testenv:combine-test-reports]
@@ -41,7 +41,6 @@ deps =
    coverage[toml]
commands =
    junitparser merge --glob reports/pytest.xml.* reports/pytest.xml
    coverage combine --keep
    coverage html


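With the default tox environment now limited to tests/unit, the new integration tests are kept out of the regular suite. A hedged sketch of running them directly (the exact invocation is not defined in this commit; it assumes the Qwen2 model can be downloaded and an IPEX-LLM-capable device is available):

import pytest

# Collect and run only the IPEX-LLM integration tests.
pytest.main(["tests/integration/ipex_llm", "-v"])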
