[CI] Add test for migration backend and worker
KuilongCui committed Aug 29, 2024
1 parent be0d326 commit 8dd876f
Showing 2 changed files with 250 additions and 0 deletions.
111 changes: 111 additions & 0 deletions tests/backends/vllm/test_migration_backend.py
@@ -0,0 +1,111 @@
# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import pytest
import torch
import ray

from vllm.engine.arg_utils import EngineArgs
from vllm.utils import get_distributed_init_method, get_ip, get_open_port

from llumnix.backends.vllm.worker import MigrationWorker
from llumnix.arg_utils import EngineManagerArgs
from llumnix.utils import random_uuid

from tests.backends.vllm.test_worker import create_worker

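# MigrationWorker subclass that exposes direct read/write access to the GPU KV
# cache so the test can seed known data and verify the migrated blocks.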
class MockMigrationWorker(MigrationWorker):
def set_gpu_cache(self, data):
for layer_idx in range(self.cache_engine.num_layers):
self.gpu_cache[layer_idx].copy_(data[layer_idx])
torch.cuda.synchronize()

def get_gpu_cache(self):
torch.cuda.synchronize()
return self.gpu_cache

@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("backend", ['rpc', 'gloo', 'nccl'])
def test_migrate_cache(backend):
ray.init(namespace="llumnix", ignore_reinit_error=True)

engine_config = EngineArgs(model='facebook/opt-125m', max_model_len=8, enforce_eager=True).create_engine_config()
    migration_config = EngineManagerArgs(migration_cache_blocks=3, migration_num_layers=5).create_migration_config()
    migration_config.migration_backend = backend

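    # Launch two mock workers; each Ray actor requests one GPU, so they run on separate devices.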
worker0 = create_worker(rank=0, local_rank=0, engine_config=engine_config,
worker_module_name="tests.backends.vllm.test_migration_backend",
worker_class_name="MockMigrationWorker")
worker1 = create_worker(rank=0, local_rank=0, engine_config=engine_config,
worker_module_name="tests.backends.vllm.test_migration_backend",
worker_class_name="MockMigrationWorker")

ray.get(worker0.execute_method.remote('init_device'))
ray.get(worker1.execute_method.remote('init_device'))

num_gpu_blocks = 8
ray.get(worker0.execute_method.remote('initialize_cache', num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=0))
ray.get(worker1.execute_method.remote('initialize_cache', num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=0))

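    # Initialize migration on both workers under distinct instance ids.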
worker0_id = random_uuid()
ray.get(worker0.execute_method.remote(
'init_migration',
instance_id=worker0_id,
        migration_config=migration_config,
src_worker_handle_list=[worker0],
node_id=ray.get_runtime_context().get_node_id()))

worker1_id = random_uuid()
ray.get(worker1.execute_method.remote(
'init_migration',
instance_id=worker1_id,
        migration_config=migration_config,
src_worker_handle_list=[worker1],
node_id=ray.get_runtime_context().get_node_id()))

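    # Rebuild the migration backend so both instances join one group, then warm it up.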
instance_rank = {worker0_id: 0, worker1_id: 1}
group_name = random_uuid()
assert all(ray.get([worker0.execute_method.remote('rebuild_migration_backend',
instance_rank=instance_rank, group_name=group_name),
worker1.execute_method.remote('rebuild_migration_backend',
instance_rank=instance_rank, group_name=group_name)]))
assert all(ray.get([worker0.execute_method.remote('warmup'),
worker1.execute_method.remote('warmup')]))

num_layers = engine_config.model_config.get_num_layers(engine_config.parallel_config)
head_size = engine_config.model_config.get_head_size()
num_heads = engine_config.model_config.get_num_kv_heads(engine_config.parallel_config)
block_size = engine_config.cache_config.block_size

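    # Seed worker0's KV cache with random key/value data for every layer and block.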
dummy_data = torch.randn(size=(num_layers, 2, num_gpu_blocks, block_size*num_heads*head_size))
ray.get(worker0.execute_method.remote('set_gpu_cache', data=dummy_data))
worker0_data = ray.get(worker0.execute_method.remote('get_gpu_cache'))

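    # Pull every block from worker0 into a shuffled destination layout on worker1.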
dst_blocks = list(range(num_gpu_blocks))
random.shuffle(dst_blocks)
src_to_dst = {idx: block_num for idx, block_num in enumerate(dst_blocks)}
ray.get(worker1.execute_method.remote(
'migrate_cache',
src_worker_handle_list=[worker0],
src_blocks=list(src_to_dst.keys()),
dst_blocks=list(src_to_dst.values())))

worker1_data = ray.get(worker1.execute_method.remote('get_gpu_cache'))

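    # Each migrated block's key/value tensors on worker1 should match the source block on worker0.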
for layer_idx in range(num_layers):
for src_idx, dst_idx in src_to_dst.items():
assert torch.allclose(worker0_data[layer_idx][0][src_idx], worker1_data[layer_idx][0][dst_idx])
assert torch.allclose(worker0_data[layer_idx][1][src_idx], worker1_data[layer_idx][1][dst_idx])

ray.shutdown()
139 changes: 139 additions & 0 deletions tests/backends/vllm/test_worker.py
@@ -0,0 +1,139 @@
# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch
import ray
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy

from vllm.engine.arg_utils import EngineArgs
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.worker.cache_engine import CacheEngine
from vllm.config import EngineConfig
from vllm.executor.ray_gpu_executor import RayWorkerWrapper

from llumnix.arg_utils import EngineManagerArgs
from llumnix.utils import random_uuid


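# Launch a RayWorkerWrapper actor pinned to the local node with one GPU; by
# default it wraps the llumnix MigrationWorker.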
def create_worker(rank: int, local_rank: int, engine_config: EngineConfig,
worker_module_name="llumnix.backends.vllm.worker",
worker_class_name="MigrationWorker"):
scheduling_strategy = NodeAffinitySchedulingStrategy(
node_id=ray.get_runtime_context().get_node_id(),
soft=False,
)

worker = ray.remote(
num_cpus=0,
num_gpus=1,
scheduling_strategy=scheduling_strategy
)(RayWorkerWrapper).remote(
worker_module_name=worker_module_name,
worker_class_name=worker_class_name,
trust_remote_code=True
)

worker.init_worker.remote(
model_config=engine_config.model_config,
parallel_config=engine_config.parallel_config,
scheduler_config=engine_config.scheduler_config,
device_config=engine_config.device_config,
cache_config=engine_config.cache_config,
load_config=engine_config.load_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=get_distributed_init_method(get_ip(), get_open_port()),
lora_config=engine_config.lora_config,
vision_language_config=engine_config.vision_language_config,
        is_driver_worker=False
)

return worker

@pytest.mark.parametrize("backend", ['rpc', 'gloo', 'nccl'])
def test_reserve_memory_for_migration(backend):
ray.init(namespace="llumnix", ignore_reinit_error=True)

engine_config = EngineArgs(model='facebook/opt-125m', max_model_len=8, enforce_eager=True).create_engine_config()
    migration_config = EngineManagerArgs(migration_cache_blocks=1).create_migration_config()
    migration_config.migration_backend = backend
worker = create_worker(rank=0, local_rank=0, engine_config=engine_config)
ray.get(worker.execute_method.remote('init_device'))

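    # Expected reservation: migration_cache_blocks blocks covering migration_num_layers
    # out of the model's num_layers layers.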
block_size = CacheEngine.get_cache_block_size(engine_config.cache_config, engine_config.model_config,
engine_config.parallel_config)
num_layers = engine_config.model_config.get_num_layers(engine_config.parallel_config)
    occupy_memory = migration_config.migration_cache_blocks * block_size * migration_config.migration_num_layers // num_layers

    migration_cache_size = ray.get(worker.execute_method.remote(
        'reserve_memory_for_migration',
        migration_config=migration_config,
        model_config=engine_config.model_config,
        cache_config=engine_config.cache_config,
        parallel_config=engine_config.parallel_config))
assert migration_cache_size == occupy_memory

ray.shutdown()

@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("backend", ['rpc', 'gloo', 'nccl'])
def test_rebuild_migration_backend(backend):
ray.init(namespace="llumnix", ignore_reinit_error=True)

engine_config = EngineArgs(model='facebook/opt-125m', max_model_len=8, enforce_eager=True).create_engine_config()
    migration_config = EngineManagerArgs(migration_cache_blocks=1).create_migration_config()
    migration_config.migration_backend = backend

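    # Bring up a single worker and build a one-member migration group.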
worker0 = create_worker(rank=0, local_rank=0, engine_config=engine_config)
worker0_id = random_uuid()
ray.get(worker0.execute_method.remote('init_device'))
ray.get(worker0.execute_method.remote('initialize_cache', num_gpu_blocks=8, num_cpu_blocks=0))
ray.get(worker0.execute_method.remote(
'init_migration',
instance_id=worker0_id,
        migration_config=migration_config,
src_worker_handle_list=[worker0],
node_id=ray.get_runtime_context().get_node_id()))
instance_rank = {worker0_id: 0}
assert ray.get(worker0.execute_method.remote('rebuild_migration_backend', instance_rank=instance_rank,
group_name=random_uuid()))
assert ray.get(worker0.execute_method.remote('warmup'))

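    # Add a second worker and rebuild the migration group with both instances.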
worker1 = create_worker(rank=0, local_rank=0, engine_config=engine_config)
worker1_id = random_uuid()
ray.get(worker1.execute_method.remote('init_device'))
ray.get(worker1.execute_method.remote('initialize_cache', num_gpu_blocks=8, num_cpu_blocks=0))
ray.get(worker1.execute_method.remote(
'init_migration',
instance_id=worker1_id,
        migration_config=migration_config,
src_worker_handle_list=[worker1],
node_id=ray.get_runtime_context().get_node_id()))

instance_rank = {worker1_id: 1, worker0_id: 0}
group_name = random_uuid()
assert all(ray.get([worker0.execute_method.remote('rebuild_migration_backend',
instance_rank=instance_rank, group_name=group_name),
worker1.execute_method.remote('rebuild_migration_backend',
instance_rank=instance_rank, group_name=group_name)]))
assert all(ray.get([worker0.execute_method.remote('warmup'),
worker1.execute_method.remote('warmup')]))

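    # Kill worker1 and confirm the group can be rebuilt with only the surviving worker.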
ray.kill(worker1)

instance_rank = {worker0_id: 0}
assert ray.get(worker0.execute_method.remote('rebuild_migration_backend', instance_rank=instance_rank,
group_name=random_uuid()))
assert ray.get(worker0.execute_method.remote('warmup'))

ray.shutdown()
