From 70b8c9a9c8696782d48b2c9121eed870323854f0 Mon Sep 17 00:00:00 2001 From: Ziming Huang <48115868+ZeldaHuang@users.noreply.github.com> Date: Mon, 26 Aug 2024 17:45:49 +0800 Subject: [PATCH] [CI] Add unittest for llumlet and backends (#14) --- docs/Quickstart.md | 6 +- tests/__init__.py | 12 + tests/backends/__init__.py | 12 + tests/backends/vllm/__init__.py | 12 + tests/backends/vllm/test_llm_engine.py | 91 ++++++++ tests/backends/vllm/test_migration.py | 144 ++++++++++++ tests/backends/vllm/test_scheduler.py | 207 ++++++++++++++++++ tests/backends/vllm/utils.py | 103 +++++++++ tests/entrypoints/__init__.py | 12 + tests/entrypoints/vllm/__init__.py | 12 + tests/global_scheduler/__init__.py | 12 + .../test_llm_engine_manager.py | 4 +- tests/llumlet/__init__.py | 12 + tests/llumlet/test_migration_coordinator.py | 110 ++++++++++ 14 files changed, 744 insertions(+), 5 deletions(-) create mode 100644 tests/__init__.py create mode 100644 tests/backends/__init__.py create mode 100644 tests/backends/vllm/__init__.py create mode 100644 tests/backends/vllm/test_llm_engine.py create mode 100644 tests/backends/vllm/test_migration.py create mode 100644 tests/backends/vllm/test_scheduler.py create mode 100644 tests/backends/vllm/utils.py create mode 100644 tests/entrypoints/__init__.py create mode 100644 tests/entrypoints/vllm/__init__.py create mode 100644 tests/global_scheduler/__init__.py create mode 100644 tests/llumlet/__init__.py create mode 100644 tests/llumlet/test_migration_coordinator.py diff --git a/docs/Quickstart.md b/docs/Quickstart.md index c046326..56609ff 100644 --- a/docs/Quickstart.md +++ b/docs/Quickstart.md @@ -15,7 +15,7 @@ cd llumnix make install ``` -If you want to use gloo as migration backend, please install [Bazel](https://github.com/bazelbuild/bazel) >= 5.1.0. Then, run `make pygloo` to install [pygloo](https://github.com/ZeldaHuang/pygloo). +If you want to use gloo as migration backend, please refer to [this link](https://github.com/ZeldaHuang/pygloo/blob/main/.github/workflows/ubuntu_basic.yml#L24C1-L26C1) to install [Bazel](https://github.com/bazelbuild/bazel) >= 5.1.0. Then, run `make pygloo` to install [pygloo](https://github.com/ZeldaHuang/pygloo). Note: Using conda is not recommended, as it cannot properly handle pygloo's dependency on gcc libstdc++.so.6: version GLIBCXX_3.4.30. @@ -45,7 +45,7 @@ python -m llumnix.entrypoints.vllm.api_server \ Upon starting the server, Llumnix's components are automatically configured. In addition to the server arguments provided above, it's necessary to specify both the Llumnix arguments and the vLLM arguments. For detailed configuration options, please consult the documentation for [Llumnix arguments](./Arguments.md) and [vLLM arguments](https://docs.vllm.ai/en/v0.4.2/models/engine_args.html). -2. Launch multiple servers and connect to the Llumnix cluster. Llumnix uses Ray to manage multiple vLLM servers and instances. You need to configure the following environment variables for Llumnix to correctly set up the cluster. +2. Launch multiple servers and connect to the Llumnix cluster. Llumnix uses Ray to manage multiple vLLM servers and instances. You need to configure the following environment variables for Llumnix to correctly set up the cluster. ``` # Configure on all nodes. export HEAD_NODE_IP=$HEAD_NODE_IP_ADDRESS @@ -66,7 +66,7 @@ When you include the --launch-ray-cluster option in Llumnix's serving deployment # Benchmarking -We provide a benchmarking example to help you get through the usage of Llumnix. +We provide a benchmarking example to help you get through the usage of Llumnix. First, you should start the server to launch Llumnix and backend LLM engine instances: ``` HEAD_NODE=1 python -m llumnix.entrypoints.vllm.api_server \ diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..4638bd9 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/backends/__init__.py b/tests/backends/__init__.py new file mode 100644 index 0000000..4638bd9 --- /dev/null +++ b/tests/backends/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/backends/vllm/__init__.py b/tests/backends/vllm/__init__.py new file mode 100644 index 0000000..4638bd9 --- /dev/null +++ b/tests/backends/vllm/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/backends/vllm/test_llm_engine.py b/tests/backends/vllm/test_llm_engine.py new file mode 100644 index 0000000..18e20c3 --- /dev/null +++ b/tests/backends/vllm/test_llm_engine.py @@ -0,0 +1,91 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from unittest.mock import MagicMock + +from vllm.sequence import (Logprob, SequenceGroupOutput, SequenceOutput, + SequenceStatus,SamplerOutput) +from vllm import EngineArgs +from vllm.engine.output_processor.single_step import SingleStepOutputProcessor +from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.utils import Counter + +from llumnix.backends.vllm.llm_engine import LLMEngineLlumnix +from llumnix.backends.vllm.executor import LlumnixRayGPUExecutor, SimGPUExecutor +from llumnix.backends.profiling import LatencyMemData + +from .utils import create_dummy_prompt, initialize_scheduler + + +class MockEngine(LLMEngineLlumnix): + def __init__(self, executor_class=None, *args, **kwargs): + self.scheduler = initialize_scheduler() + detokenizer = MagicMock(spec=Detokenizer) + stop_checker = MagicMock(spec=StopChecker) + seq_counter = Counter() + self.executor_class = executor_class + + self.output_processor = SingleStepOutputProcessor(self.scheduler.scheduler_config,detokenizer, self.scheduler, seq_counter, stop_checker) + + +def test_llm_engine_process_model_outputs(): + + llm_engine = MockEngine() + _, seq_group_0 = create_dummy_prompt( + "0", prompt_length=7, block_size=4 + ) + _, seq_group_1 = create_dummy_prompt( + "1", prompt_length=7, block_size=4 + ) + llm_engine.scheduler.add_seq_group(seq_group_0) + llm_engine.scheduler.add_seq_group(seq_group_1) + metas, out = llm_engine.scheduler.schedule() + + seqs = [seq_group_0.get_seqs()[0], seq_group_1.get_seqs()[0]] + + outputs = [ + SequenceGroupOutput( + samples=[ + SequenceOutput( + parent_seq_id=seq.seq_id, + output_token=1, + logprobs={1: Logprob(0.0)}, + ) + ], + prompt_logprobs=None, + ) for seq in seqs + ] + sampler_outputs = [SamplerOutput(outputs=outputs)] + + scheduled_seq_groups = out.scheduled_seq_groups + # normal case, all requests be processed + ret = llm_engine._process_model_outputs(sampler_outputs, scheduled_seq_groups,[], metas) + assert len(ret) == 2 + metas, out = llm_engine.scheduler.schedule() + scheduled_seq_groups = out.scheduled_seq_groups + seqs[0].status=SequenceStatus.WAITING + # migration case , requests stopping during last stage migration, stop process + ret = llm_engine._process_model_outputs(sampler_outputs, scheduled_seq_groups,[], metas) + assert len(ret) == 1 + +def test_llm_engine_from_engine_args(): + engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True) + llm_engine = MockEngine.from_engine_args(engine_args, instance_id="0", migration_config=None) + assert llm_engine.executor_class == LlumnixRayGPUExecutor + + latency_data = LatencyMemData({},{},{}) + llm_engine = MockEngine.from_engine_args(engine_args, instance_id="0", migration_config=None, latency_mem=latency_data) + assert llm_engine.executor_class == SimGPUExecutor \ No newline at end of file diff --git a/tests/backends/vllm/test_migration.py b/tests/backends/vllm/test_migration.py new file mode 100644 index 0000000..009df22 --- /dev/null +++ b/tests/backends/vllm/test_migration.py @@ -0,0 +1,144 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch +import time +import ray +from ray.util.queue import Queue as RayQueue +from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy + +from vllm import EngineArgs, SamplingParams +from vllm.utils import random_uuid + +from llumnix.backends.vllm.llm_engine import BackendVLLM +from llumnix.llumlet.llumlet import Llumlet +from llumnix.backends.utils import BackendType +from llumnix.config import MigrationConfig +from llumnix.server_info import ServerInfo + +from .test_llm_engine import MockEngine +from .utils import create_dummy_prompt + +TEST_PROMPTS = ["hello world, ", + "Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.\n", + "Write a short story about a robot that dreams for the first time.\n", + "Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies.\n", + "Swahili: 'The early bird catches the worm.'\n"] + +class MockBackendVLLM(BackendVLLM): + def __init__(self): + self.engine = MockEngine() + +class MockLlumlet(Llumlet): + def __init__(self): + self.instance_id = "0" + self.backend_engine = MockBackendVLLM() + +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason="Need at least 2 GPUs to run the test.") +def test_migration_correctness(): + ray.init(namespace="llumnix", ignore_reinit_error=True) + engine_args = EngineArgs(model="facebook/opt-125m",worker_use_ray=True) + id_rank_map = {"0":0,"1":1} + migration_config = MigrationConfig("LCFS", "gloo",16,1,4,5,20) + que = RayQueue(actor_options={ + "scheduling_strategy": NodeAffinitySchedulingStrategy( + node_id=ray.get_runtime_context().get_node_id(), + soft=False,) + }) + server_info = ServerInfo("0",que) + + llumlet_0:Llumlet = Llumlet.from_args( + False, + True, + ray.get_runtime_context().get_node_id(), + "0", + BackendType.VLLM, + 1, + migration_config, + engine_args,) + + llumlet_1:Llumlet = Llumlet.from_args( + False, + True, + ray.get_runtime_context().get_node_id(), + "1", + BackendType.VLLM, + 1, + migration_config, + engine_args, + ) + while True: + res = ray.get([llumlet_0.is_ready.remote(),llumlet_1.is_ready.remote()]) + if all(res): + break + ray.get([llumlet_0.execute_engine_method.remote("_run_workers","rebuild_migration_backend", id_rank_map, "llumnix"), + llumlet_1.execute_engine_method.remote("_run_workers","rebuild_migration_backend", id_rank_map, "llumnix")]) + # empty instance migrate out + res = ray.get(llumlet_0.migrate_out.remote("instance_1")) + assert not res + + # running without migration + def test_correctness(prompt): + sampling_params = SamplingParams(top_k=1, temperature=0, ignore_eos=True, max_tokens=100) + request_id0 = random_uuid() + llumlet_0.generate.remote(request_id0, server_info, prompt, sampling_params) + request_output_queue = que + origin_output = None + finished = False + while not finished: + qsize = ray.get(request_output_queue.actor.qsize.remote()) + request_outputs = ray.get(request_output_queue.actor.get_nowait_batch.remote(qsize)) + for request_output in request_outputs: + origin_output = request_output.outputs[0] + finished = request_output.finished + + request_id1 = random_uuid() + llumlet_0.generate.remote(request_id1, server_info, prompt, sampling_params) + # wait prefill done + while True: + if ray.get(llumlet_0.execute_engine_method.remote("get_last_running_request")): + break + # migrate request + res = ray.get(llumlet_0.migrate_out.remote("instance_1")) + assert len(res) == 1 + request_output_queue = que + output = None + finished = False + while not finished: + qsize = ray.get(request_output_queue.actor.qsize.remote()) + request_outputs = ray.get(request_output_queue.actor.get_nowait_batch.remote(qsize)) + for request_output in request_outputs: + if request_output.request_id != request_id1: + continue + output = request_output.outputs[0] + finished = request_output.finished + assert output.text == origin_output.text + assert output.cumulative_logprob == origin_output.cumulative_logprob + for prompt in TEST_PROMPTS: + test_correctness(prompt) + ray.shutdown() + +def test_clear_migration_states(): + llumlet = MockLlumlet() + llumlet.backend_engine.pre_alloc("0", 1) + num_gpu_blocks = 8 + block_size = 4 + + llumlet.clear_migration_states(is_migrate_in=True) + assert len(llumlet.backend_engine.pre_alloc("0", num_gpu_blocks)) == num_gpu_blocks + _, seq_group = create_dummy_prompt("0",7,block_size) + llumlet.backend_engine.add_migrating_out_request_last_stage(seq_group) + llumlet.clear_migration_states(is_migrate_in=False) + assert llumlet.backend_engine.get_last_running_request() is not None \ No newline at end of file diff --git a/tests/backends/vllm/test_scheduler.py b/tests/backends/vllm/test_scheduler.py new file mode 100644 index 0000000..8de0a92 --- /dev/null +++ b/tests/backends/vllm/test_scheduler.py @@ -0,0 +1,207 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +from vllm.sequence import Sequence +from vllm.sequence import Logprob +from llumnix.backends.vllm.scheduler import BlockManagerLlumnix +from .utils import create_dummy_prompt, initialize_scheduler + + +def get_sequence_groups(scheduler_output): + return [s.seq_group for s in scheduler_output.scheduled_seq_groups] + +def schedule_and_update_computed_tokens(scheduler): + metas, out = scheduler.schedule() + for s, meta in zip(out.scheduled_seq_groups, metas): + s.seq_group.update_num_computed_tokens(meta.token_chunk_size) + return metas, out + +def append_new_token(out, token_id: int): + seq_groups = get_sequence_groups(out) + for seq_group in seq_groups: + for seq in seq_group.get_seqs(): + seq.append_token_id(token_id, {token_id: Logprob(token_id)}) + +def test_manager_get_free_blocks(): + block_size = 4 + num_cpu_blocks = 4 + num_gpu_blocks = 4 + block_manager = BlockManagerLlumnix(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0) + before_allocate = block_manager.get_num_free_gpu_blocks() + block_table = block_manager.get_free_blocks(2) + after_allocate = block_manager.get_num_free_gpu_blocks() + assert after_allocate + 2 == before_allocate + block_manager._free_block_table(block_table) + after_free = block_manager.get_num_free_gpu_blocks() + assert after_free == before_allocate + +def test_manager_add_block_table(): + block_size = 4 + num_cpu_blocks = 4 + num_gpu_blocks = 4 + block_manager = BlockManagerLlumnix(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0) + block_table = block_manager.get_free_blocks(2) + seq = Sequence(1,"1",[0], block_size=block_size) + block_manager.add_block_table(block_table, seq.seq_id) + after_allocate = block_manager.get_num_free_gpu_blocks() + assert after_allocate + 2 == num_gpu_blocks + block_manager.free(seq) + after_free = block_manager.get_num_free_gpu_blocks() + assert after_free == num_gpu_blocks + +def test_scheduler_policy(): + scheduler = initialize_scheduler() + num_seq_group = 4 + block_size = 4 + for idx in range(1, num_seq_group + 1): + _, seq_group = create_dummy_prompt(str(idx), prompt_length=idx, block_size=block_size) + scheduler.add_seq_group(seq_group) + + # all seq_group in waiting queue + migrating_request = scheduler.get_last_running_request() + assert migrating_request == None + migrating_request = scheduler.get_shortest_running_request() + assert migrating_request == None + migrating_request = scheduler.get_longest_running_request() + assert migrating_request == None + # all seq_group in prefilling stage + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + migrating_request = scheduler.get_last_running_request() + assert migrating_request == None + migrating_request = scheduler.get_shortest_running_request() + assert migrating_request == None + migrating_request = scheduler.get_longest_running_request() + assert migrating_request == None + append_new_token(out, 1) + schedule_and_update_computed_tokens(scheduler) + # all in running queue + migrating_request = scheduler.get_last_running_request() + assert migrating_request.request_id == str(num_seq_group) + migrating_request = scheduler.get_shortest_running_request() + assert migrating_request.request_id == "1" + migrating_request = scheduler.get_longest_running_request() + assert migrating_request.request_id == str(num_seq_group) + +def test_scheduler_num_killed_request(): + scheduler = initialize_scheduler() + # tot 8 blocks + num_seq_group = 4 + block_size = 4 + for idx in range(1, num_seq_group + 1): + _, seq_group = create_dummy_prompt(str(idx), prompt_length=8, block_size=block_size) + scheduler.add_seq_group(seq_group) + # remain 0 blocks + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + append_new_token(out, 1) + assert scheduler._get_num_killed_requests() == 0 + # preempt 2 requests + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert scheduler._get_num_killed_requests() == 2 + +def test_scheduler_running_request(): + scheduler = initialize_scheduler() + num_seq_group = 4 + block_size = 4 + for idx in range(1, num_seq_group + 1): + _, seq_group = create_dummy_prompt(str(idx), prompt_length=idx, block_size=block_size) + scheduler.add_seq_group(seq_group) + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert scheduler.get_num_unfinished_seq_groups() == 4 + scheduler.remove_running_request("1") + assert scheduler.get_num_unfinished_seq_groups() == 3 + _, seq_group = create_dummy_prompt("5", prompt_length=idx, block_size=block_size) + scheduler.add_running_request(seq_group) + assert scheduler.get_num_unfinished_seq_groups() == 4 + +def test_scheduler_migrating_out_request_last_stage(): + scheduler = initialize_scheduler() + block_size = 4 + _, seq_group = create_dummy_prompt("1", prompt_length=1, block_size=block_size) + scheduler.add_migrating_out_request_last_stage(seq_group) + assert len(scheduler.pop_migrating_out_requests_last_stage()) == 1 + scheduler.add_migrating_out_request_last_stage(seq_group) + scheduler.remove_migrating_out_request_last_stage(seq_group) + assert len(scheduler.pop_migrating_out_requests_last_stage()) == 0 + +def test_scheduler_pre_alloc(): + # total 8 blocks + scheduler = initialize_scheduler() + blocks = scheduler.pre_alloc("1", 2) + assert len(blocks) == 2 + assert len(scheduler.pre_alloc_cache_dict["1"]) == 2 + blocks = scheduler.pre_alloc("1", 4) + assert len(blocks) == 4 + assert len(scheduler.pre_alloc_cache_dict["1"]) == 6 + blocks = scheduler.pre_alloc("2,", 4) + assert len(blocks) == 0 + +def test_scheduler_should_abort_migration(): + scheduler = initialize_scheduler() + # tot 8 blocks + block_size = 4 + _, seq_group_0 = create_dummy_prompt("0", prompt_length=7, block_size=block_size) + scheduler.add_seq_group(seq_group_0) + _, seq_group_1 = create_dummy_prompt("1", prompt_length=17, block_size=block_size) + scheduler.add_seq_group(seq_group_1) + # remain 0 blocks + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + append_new_token(out, 1) + + assert scheduler._get_num_killed_requests() == 0 + # assert scheduler.block_manager.get_num_free_gpu_blocks() == 0 + # all in running queue + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + append_new_token(out, 1) + assert scheduler._get_num_killed_requests() == 0 + migrating_request = scheduler.get_last_running_request() + last_stage_time = time.time() + assert migrating_request.request_id == "1" + # preempt request 1 + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + append_new_token(out, 1) + assert scheduler.should_abort_migration(seq_group_1, last_stage_time) == True + assert scheduler.should_abort_migration(seq_group_0, last_stage_time) == False + assert scheduler._get_num_killed_requests() == 1 + scheduler.remove_running_request(seq_group_0) + scheduler.free_src_request(seq_group_0) + # free request 0, requset 1 prefill + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + append_new_token(out, 1) + assert scheduler._get_num_killed_requests() == 0 + assert scheduler.should_abort_migration(seq_group_1, last_stage_time) == True + +def test_free_dst_pre_alloc_cache(): + scheduler = initialize_scheduler() + blocks = scheduler.pre_alloc("1", 2) + blocks = scheduler.pre_alloc("1", 4) + assert len(scheduler.pre_alloc_cache_dict["1"]) == 6 + scheduler.free_dst_pre_alloc_cache("1") + assert scheduler.pre_alloc_cache_dict.get("1",None) == None + assert scheduler.block_manager.get_num_free_gpu_blocks() == 8 + +def test_get_request_incremental_blocks(): + scheduler = initialize_scheduler() + block_size = 4 + _, seq_group = create_dummy_prompt("0", prompt_length=16, block_size=block_size) + scheduler.add_seq_group(seq_group) + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + incremental_blocks = scheduler.get_request_incremental_blocks(seq_group, 2) + assert len(incremental_blocks) == 2 \ No newline at end of file diff --git a/tests/backends/vllm/utils.py b/tests/backends/vllm/utils.py new file mode 100644 index 0000000..5c69ad0 --- /dev/null +++ b/tests/backends/vllm/utils.py @@ -0,0 +1,103 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from typing import Iterable, Optional, Tuple + +from vllm import SamplingParams +from vllm.lora.request import LoRARequest +from vllm.sequence import Logprob, Sequence, SequenceGroup +from vllm.config import SchedulerConfig, CacheConfig + +from llumnix.backends.vllm.scheduler import SchedulerLlumnix + +def initialize_scheduler(*, + max_num_seqs=1000, + max_token_budget=1000, + max_model_len=1000, + lora_config=None) -> SchedulerLlumnix: + block_size = 4 + scheduler_config = SchedulerConfig(max_token_budget, max_num_seqs, + max_model_len) + cache_config = CacheConfig(block_size, 1.0, 1, "auto") + cache_config.num_cpu_blocks = 8 + cache_config.num_gpu_blocks = 8 + scheduler = SchedulerLlumnix(scheduler_config, cache_config, lora_config) + return scheduler + +def create_dummy_prompt( + request_id: str, + prompt_length: int, + block_size: Optional[int] = None, + lora_request: Optional[LoRARequest] = None, + use_beam_search: bool = False, + best_of: int = 1, +) -> Tuple[Sequence, SequenceGroup]: + if not block_size: + block_size = prompt_length + + # Create dummy prompt sequence with tokens 0...block_size-1 + # and prompt "0 ... block_size". + prompt_tokens = list(range(prompt_length)) + prompt_str = " ".join([str(t) for t in prompt_tokens]) + prompt = Sequence(int(request_id), prompt_str, prompt_tokens, block_size) + seq_group = SequenceGroup( + request_id, [prompt], + SamplingParams(use_beam_search=use_beam_search, best_of=best_of), + time.time(), lora_request) + + return prompt, seq_group + + +def create_seq_group( + seq_prompt_len: int = 1024, + seq_output_lens: Iterable[int] = (128, ), + request_id: str = '0', + seq_id_start: int = 0, + sampling_params: Optional[SamplingParams] = None) -> SequenceGroup: + + assert len(seq_output_lens) > 0 + + if sampling_params is None: + sampling_params = SamplingParams() + + prompt_token_ids = [0] * seq_prompt_len + + seqs = [] + for seq_id_offset, output_len in enumerate(seq_output_lens): + seq = Sequence( + seq_id=seq_id_start + seq_id_offset, + prompt="", + prompt_token_ids=prompt_token_ids, + block_size=16, + ) + + for i in range(output_len): + seq.append_token_id( + token_id=i, + logprobs={i: Logprob(0.0)}, + ) + seqs.append(seq) + + seq_group = SequenceGroup( + request_id=request_id, + seqs=seqs, + sampling_params=sampling_params, + arrival_time=time.time(), + ) + + return seq_group + + +def round_up_to_next_block(seq_len: int, block_size: int) -> int: + return (seq_len + block_size - 1) // block_size diff --git a/tests/entrypoints/__init__.py b/tests/entrypoints/__init__.py new file mode 100644 index 0000000..4638bd9 --- /dev/null +++ b/tests/entrypoints/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/entrypoints/vllm/__init__.py b/tests/entrypoints/vllm/__init__.py new file mode 100644 index 0000000..4638bd9 --- /dev/null +++ b/tests/entrypoints/vllm/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/global_scheduler/__init__.py b/tests/global_scheduler/__init__.py new file mode 100644 index 0000000..4638bd9 --- /dev/null +++ b/tests/global_scheduler/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/global_scheduler/test_llm_engine_manager.py b/tests/global_scheduler/test_llm_engine_manager.py index c0edb41..6da0ae0 100644 --- a/tests/global_scheduler/test_llm_engine_manager.py +++ b/tests/global_scheduler/test_llm_engine_manager.py @@ -123,7 +123,7 @@ def test_init_manager(engine_manager): def test_init_llumlet(llumlet): assert llumlet is not None ray.get(llumlet.is_ready.remote()) - + # TODO(s5u13b): Add init_llumlets test. def test_scale_up_and_down(engine_manager): @@ -156,7 +156,7 @@ def test_connect_to_instances(): def test_generate_and_abort(engine_manager, llumlet): instance_id = ray.get(llumlet.get_instance_id.remote()) - ray.get(engine_manager.scale_up.remote(instance_id, llumlet)) + ray.get(engine_manager.scale_up.remote(instance_id, [llumlet])) request_id = random_uuid() num_requests = ray.get(llumlet.get_num_requests.remote()) assert num_requests == 0 diff --git a/tests/llumlet/__init__.py b/tests/llumlet/__init__.py new file mode 100644 index 0000000..4638bd9 --- /dev/null +++ b/tests/llumlet/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/llumlet/test_migration_coordinator.py b/tests/llumlet/test_migration_coordinator.py new file mode 100644 index 0000000..60dcde4 --- /dev/null +++ b/tests/llumlet/test_migration_coordinator.py @@ -0,0 +1,110 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import ray +from unittest.mock import MagicMock, patch + +from llumnix.llumlet.migration_coordinator import MigrationCoordinator +from llumnix.llumlet.migrating_request import MigratingRequest +from llumnix.backends.backend_interface import BackendInterface +from llumnix.llumlet.llumlet import MigrationStatus + +@ray.remote +def ray_remote_call(ret): + return ret + +def test_migrate_out_onestage(): + # Initialize Ray + ray.init(ignore_reinit_error=True) + + # Create mock objects + backend_engine = MagicMock(spec=BackendInterface) + migrate_in_ray_actor = MagicMock() + migrate_out_request = MigratingRequest(1, "test_request") + + # Create an instance of MigrationCoordinator + coordinator = MigrationCoordinator(backend_engine, 1, 3) + + # Mock method return values and test data + src_blocks = [1, 2, 3] + dst_blocks = [1, 2] + backend_engine.get_request_incremental_blocks.return_value = src_blocks + backend_engine.should_abort_migration.return_value = False + migrate_in_ray_actor.execute_migration_method.remote.return_value = ray_remote_call.remote(dst_blocks) + + # Test normal migration scenario + status = coordinator.migrate_out_onestage(migrate_in_ray_actor, migrate_out_request) + assert status == MigrationStatus.RUNNING + + # Test the last stage of migration + src_blocks = [3] + dst_blocks = [3] + backend_engine.get_request_incremental_blocks.return_value = src_blocks + backend_engine.should_abort_migration.return_value = False + migrate_in_ray_actor.execute_migration_method.remote.return_value = ray_remote_call.remote(dst_blocks) + status = coordinator.migrate_out_onestage(migrate_in_ray_actor, migrate_out_request) + assert status == MigrationStatus.FINISHED_DONE + + migrate_out_request = MigratingRequest(2, "test_request") + # Test migration aborted scenario + src_blocks = [1, 2, 3] + dst_blocks = [] + backend_engine.get_request_incremental_blocks.return_value = src_blocks + backend_engine.should_abort_migration.return_value = False + migrate_in_ray_actor.execute_migration_method.remote.return_value = ray_remote_call.remote(dst_blocks) + status = coordinator.migrate_out_onestage(migrate_in_ray_actor, migrate_out_request) + assert status == MigrationStatus.FINISHED_ABORTED + + migrate_out_request = MigratingRequest(3, "test_request") + src_blocks = [1, 2, 3] + dst_blocks = [1, 2] + backend_engine.get_request_incremental_blocks.return_value = src_blocks + backend_engine.should_abort_migration.return_value = True + migrate_in_ray_actor.execute_migration_method.remote.return_value = ray_remote_call.remote(dst_blocks) + status = coordinator.migrate_out_onestage(migrate_in_ray_actor, migrate_out_request) + assert status == MigrationStatus.FINISHED_ABORTED + ray.shutdown() + +@patch.object(MigrationCoordinator, 'migrate_out_onestage') +def test_migrate_out_multistage(migrate_out_onestage): + # Initialize Ray + ray.init(ignore_reinit_error=True) + + # Create mock objects + backend_engine = MagicMock(spec=BackendInterface) + migrate_in_ray_actor = MagicMock() + migrate_out_request = MigratingRequest(1, "test_request") + + # Create an instance of MigrationCoordinator + max_stages = 3 + coordinator = MigrationCoordinator(backend_engine, 1, max_stages) + migrate_in_ray_actor = MagicMock() + migrate_in_ray_actor.execute_engine_method = MagicMock() + migrate_in_ray_actor.execute_engine_method.remote = MagicMock() + migrate_in_ray_actor.execute_engine_method.remote.return_value = ray_remote_call.remote([1]) + migrate_in_ray_actor.execute_migration_method.remote.return_value = ray_remote_call.remote([1]) + coordinator.migrate_out_onestage.side_effect = [MigrationStatus.FINISHED_DONE] + status = coordinator.migrate_out_multistage(migrate_in_ray_actor, migrate_out_request) + assert coordinator.migrate_out_onestage.call_count == 1 + assert status == MigrationStatus.FINISHED_DONE + + max_stages = 3 + coordinator.migrate_out_onestage.side_effect = [MigrationStatus.RUNNING, + MigrationStatus.RUNNING, + MigrationStatus.RUNNING, + MigrationStatus.RUNNING] + status = coordinator.migrate_out_multistage(migrate_in_ray_actor, migrate_out_request) + assert coordinator.migrate_out_onestage.call_count == max_stages + 1 + assert status == MigrationStatus.FINISHED_ABORTED + + ray.shutdown() \ No newline at end of file