Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
Xinyi-ECNU committed Sep 5, 2024
1 parent 5f11521 commit 16a05d1
Show file tree
Hide file tree
Showing 24 changed files with 635 additions and 193 deletions.
42 changes: 21 additions & 21 deletions llumnix/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import vllm
from vllm import *
# import vllm
# from vllm import *

from llumnix.server_info import ServerInfo
from llumnix.entrypoints.llumnix_utils import (launch_ray_cluster, connect_to_ray_cluster,
init_manager, init_llumlets)
from llumnix.arg_utils import EngineManagerArgs
from llumnix.llm_engine_manager import LLMEngineManager
from llumnix.llumlet.llumlet import Llumlet
# from llumnix.server_info import ServerInfo
# from llumnix.entrypoints.llumnix_utils import (launch_ray_cluster, connect_to_ray_cluster,
# init_manager, init_llumlets)
# from llumnix.arg_utils import EngineManagerArgs
# from llumnix.llm_engine_manager import LLMEngineManager
# from llumnix.llumlet.llumlet import Llumlet

from .version import __version__
# from .version import __version__

__all__ = [
"__version__",
"ServerInfo",
"launch_ray_cluster",
"connect_to_ray_cluster",
"init_manager",
"init_llumlets",
"EngineManagerArgs",
"LLMEngineManager",
"Llumlet"
]
# __all__ = [
# "__version__",
# "ServerInfo",
# "launch_ray_cluster",
# "connect_to_ray_cluster",
# "init_manager",
# "init_llumlets",
# "EngineManagerArgs",
# "LLMEngineManager",
# "Llumlet"
# ]

__all__.extend(getattr(vllm, "__all__", []))
# __all__.extend(getattr(vllm, "__all__", []))
18 changes: 16 additions & 2 deletions llumnix/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@
import argparse
from typing import Tuple

from llumnix.common.config import get_cfg
from llumnix.config import GlobalSchedulerConfig, MigrationConfig
from llumnix.logger import init_logger

logger = init_logger(__name__)

@dataclass
class EngineManagerArgs:
Expand Down Expand Up @@ -59,9 +62,15 @@ class EngineManagerArgs:
last_stage_max_blocks: int = 16
max_stages: int = 3

config_file: str = None
def create_global_scheduler_configs(
self,
) -> Tuple[GlobalSchedulerConfig]:

config_data = get_cfg()
config_data.merge_from_file(self.config_file)

# Create the GlobalScheduler Configuration.
global_scheduler_config = GlobalSchedulerConfig(self.initial_instances,
self.load_metric,
self.dispatch_policy,
Expand All @@ -70,7 +79,9 @@ def create_global_scheduler_configs(
self.enable_defrag,
self.scaling_policy,
self.scale_up_threshold,
self.scale_down_threshold)
self.scale_down_threshold,
config_data.PDD_CONFIG.ENABLE_PREFILL_DISAGGREATION,
config_data.PDD_CONFIG.PREFILL_INSTANCE_NUM)
return global_scheduler_config

def create_migration_config(self) -> MigrationConfig:
Expand Down Expand Up @@ -229,5 +240,8 @@ def add_cli_args(
type=int,
default=EngineManagerArgs.max_stages,
help='drop migration if the number of stages > max_stages')

parser.add_argument("--config-file",
type=str,
default=EngineManagerArgs.config_file,
help="path to the configuration file")
return parser
20 changes: 19 additions & 1 deletion llumnix/backends/backend_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def is_sim_backend(status: "BackendType") -> bool:
class BackendInterface(ABC):
# Methods for inference
@abstractmethod
def add_request(self, request_id: str, server_info: ServerInfo,
def add_request(self, request_id: str, server_info: ServerInfo, request_expected_steps: int,
*args, **kwargs) -> None:
"""Adds a new inference request to the backend's processing queue.
Expand All @@ -42,6 +42,9 @@ def add_request(self, request_id: str, server_info: ServerInfo,
Args:
request_id: Request ID.
server_info: The information of the api server where the request come.
request_expected_steps: The expected number of steps for the request to run. The number of steps
represents the total number of times 'engine.step()' has been called by the
backend instances for the request.
*args: Positional arguments that represent request-specific data.
**kwargs: Keyword arguments that contain metadata of the backend request
(request_id, arrival_time, etc.).
Expand Down Expand Up @@ -267,6 +270,21 @@ def commit_dst_request(self, backend_request: LlumnixRequest) -> None:
of the request.
"""
raise NotImplementedError

@abstractmethod
def update_strict_pre_migration(self, new_strict_pre_migration: bool) -> None:
    """Update whether the backend engine strictly enforces pre-migration.

    The new value is pushed to the backend engine only when the corresponding
    status held by the llumlet has changed.

    `strict_pre_migration` controls what happens to a request once
    `request.output_len` >= `request.request_expected_steps`. When it is True
    (the default), the instance enables migration for such requests. When it
    is False, migration does not occur and requests that reach
    `request_expected_steps` continue with inference on this instance.

    Args:
        new_strict_pre_migration: New strict pre-migration status for the
            backend engine.
    """
    raise NotImplementedError

@abstractmethod
def get_all_request_ids(self) -> List[str]:
Expand Down
22 changes: 14 additions & 8 deletions llumnix/backends/vllm/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,12 @@ def update_instance_info(self, instance_info: InstanceInfo) -> None:
instance_info.num_blocks_last_running_request = self.instance_info.num_blocks_last_running_request
self.instance_info = instance_info

def add_request(self, request_id: str, server_info: ServerInfo, *args, **kwargs):
def add_request(self, request_id: str, server_info: ServerInfo, request_expected_steps: int, *args, **kwargs):
super().add_request(request_id, *args, **kwargs)
logger.info("add_request")
seq_group = self.scheduler.waiting[-1]
self.scheduler.waiting[-1] = SequenceGroupLlumnix(request_id, server_info, [seq_group.get_seqs()[0]], seq_group.sampling_params,
seq_group.metrics.arrival_time, seq_group.lora_request, seq_group.multi_modal_data)
self.scheduler.waiting[-1] = SequenceGroupLlumnix(request_id, server_info, request_expected_steps, [seq_group.get_seqs()[0]], seq_group.sampling_params,
seq_group.metrics.arrival_time, seq_group.lora_request, seq_group.multi_modal_data)
self.scheduler.scheduler_lock.release()

def _put_request_output_to_server(self, request_outputs, server_infos: List[ServerInfo]) -> None:
Expand Down Expand Up @@ -181,13 +182,14 @@ def __init__(
placement_group: "PlacementGroup" = None,
node_id: str = None
) -> None:
self.strict_pre_migration = True
self.engine: LLMEngineLlumnix = LLMEngineLlumnix.from_engine_args(engine_args=engine_args,
migration_config=migration_config,
instance_id=instance_id,
placement_group=placement_group,
node_id=node_id)
# multi-instance args
self.engine.scheduler = SchedulerLlumnix(self.engine.scheduler_config, self.engine.cache_config, self.engine.lora_config)
self.engine.scheduler = SchedulerLlumnix(self.strict_pre_migration, self.engine.scheduler_config, self.engine.cache_config, self.engine.lora_config)
self.engine.scheduler.add_update_instance_info_callback(self.engine.update_instance_info)
self.engine.output_processor.scheduler = self.engine.scheduler
self.instance_id = instance_id
Expand All @@ -212,19 +214,21 @@ def execute_worker_method(self, method, *args, **kwargs):
def add_request(self,
request_id: str,
server_info: ServerInfo,
request_expected_steps: int,
*args,
**kwargs) -> None:
# Store the server information of each request to put the request outputs back to the corresponding api server correctly.
self.engine.add_request(request_id, server_info, *args, **kwargs)
self.engine.add_request(request_id, server_info, request_expected_steps, *args, **kwargs)


def commit_dst_request(self, backend_request: SequenceGroupLlumnix) -> None:
seq = backend_request.get_seqs()[0]
seq.seq_id = next(self.engine.seq_counter)
logger.info("add seq {} to block table".format(seq.seq_id))
pre_alloc_blocks = self.engine.scheduler.pre_alloc_cache_dict.pop(backend_request.request_id)
self.engine.scheduler.block_manager.add_block_table(pre_alloc_blocks, seq.seq_id)
self.add_running_request(backend_request)
backend_request.reset_migration_args()
self.add_running_request(backend_request)

def send_blocks(self, dst_ray_actor: "ray.actor.ActorHandle", src_blocks: List[int], dst_blocks: List[int]) -> None:
ray.get(dst_ray_actor.execute_engine_method.remote("_run_workers",
Expand All @@ -248,7 +252,10 @@ def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:

def get_running_queue(self ) -> List[SequenceGroupLlumnix]:
    """Return the running queue reported by the engine's scheduler."""
    scheduler = self.engine.scheduler
    return scheduler.get_running_queue()

def update_strict_pre_migration(self, new_migration_state: bool):
    """Switch strict pre-migration on or off for this backend.

    Does nothing when the requested state equals the current one; otherwise
    the new state is stored on the backend and forwarded to the engine's
    scheduler.
    """
    if self.strict_pre_migration == new_migration_state:
        return
    self.strict_pre_migration = new_migration_state
    scheduler = self.engine.scheduler
    scheduler.update_strict_pre_migration(new_migration_state)
def get_request_incremental_blocks(self, *args, **kwargs) -> List[int]:
    """Forward all arguments to the scheduler's `get_request_incremental_blocks` and return its result."""
    scheduler = self.engine.scheduler
    return scheduler.get_request_incremental_blocks(*args, **kwargs)

Expand Down Expand Up @@ -281,6 +288,5 @@ def free_dst_pre_alloc_cache(self, *args, **kwargs) -> None:

def free_src_request(self, backend_request: SequenceGroup) -> None:
    """Hand `backend_request` to the scheduler's `free_src_request` and return whatever it returns."""
    scheduler = self.engine.scheduler
    return scheduler.free_src_request(backend_request)

def get_all_request_ids(self) -> List[str]:
    """Return the request ids reported by the engine's scheduler."""
    scheduler = self.engine.scheduler
    return scheduler.get_all_request_ids()
23 changes: 22 additions & 1 deletion llumnix/backends/vllm/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from asyncio.log import logger
import time
import threading
import copy
from typing import Dict, List, Optional, Tuple

from vllm.core.block_manager_v1 import BlockSpaceManagerV1, BlockTable
Expand Down Expand Up @@ -45,7 +46,7 @@ def add_block_table(self, block_table: BlockTable, seq_id: int) -> None:
self.block_tables[seq_id] = block_table.copy()

class SchedulerLlumnix(Scheduler):
def __init__(self, *args, **kwargs) -> None:
def __init__(self, strict_pre_migration, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.block_manager = BlockManagerLlumnix(
block_size=self.cache_config.block_size,
Expand All @@ -56,6 +57,7 @@ def __init__(self, *args, **kwargs) -> None:
self.pre_alloc_cache_dict: Dict[str, BlockTable] = {}
self.scheduler_lock = threading.Lock()
self.migrating_out_request_last_stage: List[LlumnixRequest] = []
self.strict_pre_migration = strict_pre_migration

def add_update_instance_info_callback(self, update_instance_info_callback):
self.update_instance_info_callback = update_instance_info_callback
Expand Down Expand Up @@ -205,6 +207,25 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
self.update_instance_info_callback(self._get_instance_info())
return seq_group_metadata_list, scheduler_outputs

def _schedule_running(self, *args, **kwargs):
    """Schedule running requests, holding back those that reached their expected steps.

    When `self.strict_pre_migration` is set, every request in `self.running`
    with `expected_steps > 0` and `output_len >= expected_steps` is withheld
    from the parent scheduler for this step. Withheld requests are appended
    to the returned remaining-running queue so they stay tracked.

    The queue handed to the parent is a shallow, filtered copy of
    `self.running`: the request objects themselves are shared, not cloned.
    Deep-copying the queue on every step is costly and makes the scheduler
    track clones instead of the live request objects.

    Args:
        *args: Positional arguments of the parent `_schedule_running`. The
            first one (the running queue) is replaced by the filtered view of
            `self.running`; the rest are forwarded unchanged.
        **kwargs: Keyword arguments forwarded unchanged to the parent.

    Returns:
        The `(remaining_running, running_scheduled)` pair produced by the
        parent, with the withheld requests appended to `remaining_running`.
    """
    # Build the filtered queue in one pass, with the same container type as
    # self.running, instead of copying and then removing items one by one.
    schedulable = type(self.running)()
    held_back = []
    for seq_group in self.running:
        # The flag is checked first so the step counters are only read when
        # strict pre-migration is enabled, as before.
        if self.strict_pre_migration and 0 < seq_group.expected_steps <= seq_group.output_len:
            held_back.append(seq_group)
        else:
            schedulable.append(seq_group)
    # args[1:] is empty when no positional arguments were given, which avoids
    # the IndexError raised by indexing args[0] directly.
    remaining_running, running_scheduled = super()._schedule_running(schedulable, *args[1:], **kwargs)
    remaining_running.extend(held_back)
    return remaining_running, running_scheduled

@scheduler_lock
def update_strict_pre_migration(self, new_migration_state: bool) -> None:
    """Store the new strict pre-migration flag on the scheduler.

    Args:
        new_migration_state: Value assigned to `self.strict_pre_migration`.
    """
    # NOTE(review): `scheduler_lock` is a decorator defined outside this view;
    # presumably it serializes this update with other scheduler operations.
    # Confirm at its definition.
    self.strict_pre_migration = new_migration_state

def add_seq_group(self, *args, **kwargs):
# The scheduler lock is manually released at the end of the LLMEngineLlumnix.add_request function.
# pylint: disable=R1732
Expand Down
4 changes: 2 additions & 2 deletions llumnix/backends/vllm/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@


class SequenceGroupLlumnix(SequenceGroup, LlumnixRequest):
def __init__(self, request_id, server_info, *args, **kwargs) -> None:
def __init__(self, request_id, server_info, request_expected_steps: int, *args, **kwargs) -> None:
SequenceGroup.__init__(self, request_id, *args, **kwargs)
LlumnixRequest.__init__(self, request_id, server_info)
LlumnixRequest.__init__(self, request_id, server_info, request_expected_steps)

@property
def prompt_len(self) -> int:
Expand Down
Loading

0 comments on commit 16a05d1

Please sign in to comment.