[Core] Support for Scheduling-defined Prefill-Decode Disaggregation feature #15

Open · wants to merge 13 commits into base: main
18 changes: 16 additions & 2 deletions llumnix/arg_utils.py
@@ -32,6 +32,7 @@ class EngineManagerArgs:
polling_interval: float = None

dispatch_policy: str = None
num_dispatch_instances: int = None

enable_migration: bool = None
enable_defrag: bool = None
@@ -61,6 +62,8 @@ class EngineManagerArgs:
last_stage_max_blocks: int = None
max_stages: int = None

enable_pd_disagg: bool = False

Contributor:

In this file, set the default value to None, and set the default value in config/default.py. We want to get the default value from only one place.
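A minimal sketch of the pattern the reviewer is asking for, assuming `_C` is the config node defined in llumnix/config/default.py and that dataclass field names map to upper-cased config keys (both assumptions, for illustration only):

import dataclasses

from llumnix.config.default import _C  # single source of default values

@dataclasses.dataclass
class EngineManagerArgsSketch:
    enable_pd_disagg: bool = None  # default lives in config/default.py, not here

    def __post_init__(self):
        # Fill every unset field from _C.MANAGER, assuming field names map to
        # upper-cased config keys (enable_pd_disagg -> ENABLE_PD_DISAGG).
        for attr in dataclasses.fields(self):
            if getattr(self, attr.name) is None:
                setattr(self, attr.name, getattr(_C.MANAGER, attr.name.upper()))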

def __post_init__(self):
for attr in dataclasses.fields(self):
if getattr(self, attr.name) is None:
@@ -69,15 +72,19 @@ def __post_init__(self):
def create_global_scheduler_configs(
self,
) -> Tuple[GlobalSchedulerConfig]:

# Create the GlobalScheduler Configuration.
global_scheduler_config = GlobalSchedulerConfig(self.initial_instances,
self.load_metric,
self.dispatch_policy,
self.num_dispatch_instances,
self.pair_migration_policy,
self.migrate_out_threshold,
self.enable_defrag,
self.scaling_policy,
self.scale_up_threshold,
self.scale_down_threshold)
self.scale_down_threshold,
self.enable_pd_disagg)
return global_scheduler_config

def create_migration_config(self) -> MigrationConfig:
@@ -135,6 +142,10 @@ def add_cli_args(
type=str,
choices=['balanced', 'load', 'queue', 'flood'],
help='request dispatch policy')
parser.add_argument('--num-available-dispatch-instances',
type=int,
default=None,
help='number of available instances for dispatching')

parser.add_argument('--enable-migration',
action='store_true',
@@ -216,5 +227,8 @@ def add_cli_args(
parser.add_argument('--max-stages',
type=int,
help='drop migration if the number of stages > max_stages')

parser.add_argument('--enable-pd-disagg',
action='store_true',
default=None,
help='enable prefill-decoding disaggregation')
# Note: argparse's type=bool treats any non-empty string (including "False") as True,
# so a store_true flag is used instead, matching --enable-migration above.
return parser
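For illustration, a short sketch of the round trip from the new CLI flags into the parser; it assumes add_cli_args can be called as a static helper on EngineManagerArgs, which is how the diff above reads:

import argparse

from llumnix.arg_utils import EngineManagerArgs

parser = argparse.ArgumentParser()
parser = EngineManagerArgs.add_cli_args(parser)
# Two new knobs from this PR: cap the dispatchable instances at 2 and turn on
# prefill-decoding disaggregation.
args = parser.parse_args(['--dispatch-policy', 'load',
                          '--num-available-dispatch-instances', '2',
                          '--enable-pd-disagg'])
print(args.num_available_dispatch_instances, args.enable_pd_disagg)  # 2 True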
7 changes: 6 additions & 1 deletion llumnix/backends/backend_interface.py
@@ -32,7 +32,7 @@ def is_sim_backend(status: "BackendType") -> bool:
class BackendInterface(ABC):
# Methods for inference
@abstractmethod
def add_request(self, request_id: str, server_info: ServerInfo,
def add_request(self, request_id: str, server_info: ServerInfo, expected_steps: int,
*args, **kwargs) -> None:
"""Adds a new inference request to the backend's processing queue.

@@ -42,6 +42,11 @@ def add_request(self, request_id: str, server_info: ServerInfo,
Args:
request_id: Request ID.
server_info: The information of the api server where the request comes from.
expected_steps: The expected number of steps for the request to run. The number of steps
represents the number of times 'engine.step()' has been called by the
backend instance for the request. Currently, `expected_steps` is used
to implement prefill-decoding disaggregation: for requests dispatched
to prefill instances, `expected_steps` is set to 1.
*args: Positional arguments that represent request-specific data.
**kwargs: Keyword arguments that contain metadata of the backend request
(request_id, arrival_time, etc.).
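To make the new contract concrete, a hedged sketch of the two calling conventions; `backend`, `prompt`, and `sampling_params` are illustrative placeholders, not names from this PR:

import math

# Prefill instance: run exactly one engine step (the prefill), after which the
# request is no longer scheduled on this instance.
backend.add_request(request_id, server_info, expected_steps=1,
                    prompt=prompt, sampling_params=sampling_params)

# No disaggregation: let the request decode on this instance to completion.
backend.add_request(request_id, server_info, expected_steps=math.inf,
                    prompt=prompt, sampling_params=sampling_params)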
10 changes: 6 additions & 4 deletions llumnix/backends/vllm/llm_engine.py
@@ -193,11 +193,12 @@ def update_instance_info(self, instance_info: InstanceInfo) -> None:
instance_info.num_blocks_last_running_request = self.instance_info.num_blocks_last_running_request
self.instance_info = instance_info

def add_request(self, request_id: str, server_info: ServerInfo, *args, **kwargs):
def add_request(self, request_id: str, server_info: ServerInfo, expected_steps: int, *args, **kwargs):
super().add_request(request_id, *args, **kwargs)
seq_group = self.scheduler.waiting[-1]
self.scheduler.waiting[-1] = SequenceGroupLlumnix(request_id, server_info, [seq_group.get_seqs()[0]], seq_group.sampling_params,
seq_group.metrics.arrival_time, seq_group.lora_request, seq_group.multi_modal_data)
self.scheduler.waiting[-1] = SequenceGroupLlumnix(request_id, server_info, expected_steps, [seq_group.get_seqs()[0]],
seq_group.sampling_params,
seq_group.metrics.arrival_time, seq_group.lora_request, seq_group.multi_modal_data)
self.scheduler.scheduler_lock.release()

def _put_request_outputs_to_server(self, request_outputs, server_infos: List[ServerInfo]) -> None:
@@ -250,10 +251,11 @@ def execute_worker_method(self, method, *args, **kwargs):
def add_request(self,
request_id: str,
server_info: ServerInfo,
expected_steps: int,
*args,
**kwargs) -> None:
# Store the server information of each request to put the request outputs back to the corresponding api server correctly.
self.engine.add_request(request_id, server_info, *args, **kwargs)
self.engine.add_request(request_id, server_info, expected_steps, *args, **kwargs)

def commit_dst_request(self, backend_request: SequenceGroupLlumnix) -> None:
seq = backend_request.get_seqs()[0]
14 changes: 14 additions & 0 deletions llumnix/backends/vllm/scheduler.py
@@ -15,6 +15,7 @@
import time
import threading
from typing import Dict, List, Optional, Tuple
from collections import deque

from vllm.core.block_manager_v1 import BlockSpaceManagerV1, BlockTable
from vllm.core.scheduler import (Scheduler, PreemptionMode, SequenceStatus, SequenceGroupMetadata, SchedulerOutputs)
@@ -205,6 +206,19 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
self.update_instance_info_callback(self._get_instance_info())
return seq_group_metadata_list, scheduler_outputs

def _schedule_running(self, running_queue: deque, *args, **kwargs):
    # Withhold requests that have already run for their expected number of
    # steps (e.g. prefill-only requests with expected_steps == 1) from the
    # next scheduling round on this instance.
    filtered_running_queue = deque()
    remove_running = deque()
    for seq_group in running_queue:
        if seq_group.output_len >= seq_group.expected_steps:
            remove_running.append(seq_group)
        else:
            filtered_running_queue.append(seq_group)
    remaining_running, running_scheduled = super()._schedule_running(filtered_running_queue, *args, **kwargs)
    # Put the withheld requests back into the remaining running queue so the
    # scheduler keeps tracking them instead of dropping them.
    for seq_group in remove_running:
        remaining_running.append(seq_group)
    return remaining_running, running_scheduled

def add_seq_group(self, *args, **kwargs):
# The scheduler lock is manually released at the end of the LLMEngineLlumnix.add_request function.
# pylint: disable=R1732
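The override filters rather than preempts: a request that has produced expected_steps outputs is withheld from this instance's schedule but kept in the running queue. A self-contained toy rendering of the same filtering rule (no vLLM types, purely illustrative):

from collections import deque
from dataclasses import dataclass

@dataclass
class ToyRequest:
    name: str
    output_len: int
    expected_steps: float

running = deque([ToyRequest('prefill-done', 1, 1),          # hit its cap
                 ToyRequest('decoding', 5, float('inf'))])  # no cap

to_schedule = deque(r for r in running if r.output_len < r.expected_steps)
withheld = deque(r for r in running if r.output_len >= r.expected_steps)

print([r.name for r in to_schedule])  # ['decoding']
print([r.name for r in withheld])     # ['prefill-done']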
4 changes: 2 additions & 2 deletions llumnix/backends/vllm/sequence.py
@@ -17,9 +17,9 @@


class SequenceGroupLlumnix(SequenceGroup, LlumnixRequest):
def __init__(self, request_id, server_info, *args, **kwargs) -> None:
def __init__(self, request_id, server_info, expected_steps: int, *args, **kwargs) -> None:
SequenceGroup.__init__(self, request_id, *args, **kwargs)
LlumnixRequest.__init__(self, request_id, server_info)
LlumnixRequest.__init__(self, request_id, server_info, expected_steps)

@property
def prompt_len(self) -> int:
8 changes: 8 additions & 0 deletions llumnix/config/default.py
@@ -75,6 +75,8 @@
_C.MANAGER.LOAD_METRIC = 'remaining_steps'
# Request dispatch policy
_C.MANAGER.DISPATCH_POLICY = 'load'
# Number of available dispatch instances. -1 indicates that all instances can be used for dispatching
_C.MANAGER.NUM_DISPATCH_INSTANCES = -1

# -----------------------------------------------------------------------------
# MIGRATION CONFIGURATION
@@ -122,3 +124,9 @@
_C.MANAGER.SCALE_UP_THRESHOLD = 10
# Scale down threshold
_C.MANAGER.SCALE_DOWN_THRESHOLD = 60

# -----------------------------------------------------------------------------
# PREFILL DECODING DISAGGREGATION CONFIGURATION
# -----------------------------------------------------------------------------
# Enable prefill decoding disaggregation
_C.MANAGER.ENABLE_PD_DISAGG = False
1 change: 0 additions & 1 deletion llumnix/entrypoints/llumnix_utils.py
@@ -141,7 +141,6 @@ def init_llumlets(engine_manager_args: EngineManagerArgs,

instance_ids = [random_uuid() for _ in range(engine_manager_args.initial_instances)]
migration_configs = engine_manager_args.create_migration_config()

for idx in range(engine_manager_args.initial_instances):
instance_id = instance_ids[idx]
if not engine_manager_args.profiling_result_file_path:
16 changes: 12 additions & 4 deletions llumnix/global_scheduler/dispatch_scheduler.py
@@ -24,11 +24,14 @@
class DispatchScheduler:
def __init__(self,
dispatch_policy: str,
instance_load_calculator: InstanceLoadCalculator) -> None:
instance_load_calculator: InstanceLoadCalculator,
num_dispatch_instances: int) -> None:
self.dispatch_policy = DispatchPolicyFactory.get_policy(dispatch_policy)
self.instance_load_calculator = instance_load_calculator
self.num_instances = 0
self.instance_id_set: Set[str] = set()
self.available_dispatch_instance_set: Set[str] = set()
self.num_dispatch_instances = num_dispatch_instances
# instance info args
self.instance_info: Dict[str, InstanceInfo] = {}
self.sorted_instance_infos: List[InstanceInfo] = None
@@ -56,22 +59,27 @@ def update_instance_infos(self,
def add_instance(self, instance_id: str) -> None:
self.instance_id_set.add(instance_id)
self.num_instances = len(self.instance_id_set)
self.instance_num_requests[instance_id] = 0
# A non-positive num_dispatch_instances means no cap: every instance may serve newly dispatched requests.
if self.num_dispatch_instances <= 0 or len(self.available_dispatch_instance_set) < self.num_dispatch_instances:
self.available_dispatch_instance_set.add(instance_id)
self.instance_num_requests[instance_id] = 0

def remove_instance(self, instance_id: str) -> None:
self.instance_id_set.remove(instance_id)
self.num_instances = len(self.instance_id_set)
del self.instance_num_requests[instance_id]
if instance_id in self.instance_num_requests:
del self.instance_num_requests[instance_id]

def _sort_instance_infos(self,
descending: bool = True) -> None:
instance_infos: List[InstanceInfo] = list(self.instance_info.values())
available_instance_infos = [info for info in instance_infos if info.instance_id in self.available_dispatch_instance_set]
if isinstance(self.dispatch_policy, Queue):
key_attr = 'num_waiting_requests'
else:
key_attr = 'instance_load_dispatch_scale'
self.sorted_instance_infos = sorted(
instance_infos,
available_instance_infos,
key=lambda instance_info: getattr(instance_info, key_attr),
reverse=descending
)
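A toy rendering of the add_instance capping rule above (standalone, not the real class): with a cap of 2, only the first two instances to register become dispatch targets; later ones join the cluster but receive no fresh requests (under PD disaggregation, presumably serving migrated decode work instead):

num_dispatch_instances = 2  # -1 or 0 would mean "no cap"
available_dispatch_instance_set = set()

for instance_id in ['instance_0', 'instance_1', 'instance_2', 'instance_3']:
    if (num_dispatch_instances <= 0
            or len(available_dispatch_instance_set) < num_dispatch_instances):
        available_dispatch_instance_set.add(instance_id)

print(sorted(available_dispatch_instance_set))  # ['instance_0', 'instance_1']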
16 changes: 11 additions & 5 deletions llumnix/global_scheduler/global_scheduler.py
@@ -12,6 +12,7 @@
# limitations under the License.

from typing import Dict, List, Tuple, Union, Iterable, Set
import math

from llumnix.logger import init_logger
from llumnix.internal_config import GlobalSchedulerConfig
@@ -30,12 +31,14 @@ def __init__(self,
# instance load and instance info args
self.load_metric = global_scheduler_config.load_metric
self.enable_defrag = global_scheduler_config.enable_defrag
self.enable_pd_disagg = global_scheduler_config.enable_pd_disagg
self.instance_load_calculator = InstanceLoadCalculator(load_metric=self.load_metric,
enable_defrag=self.enable_defrag)
# dispatch args
self.dispatch_policy = global_scheduler_config.dispatch_policy
self.dispatch_scheduler = DispatchScheduler(global_scheduler_config.dispatch_policy,
self.instance_load_calculator)
self.instance_load_calculator,
global_scheduler_config.num_dispatch_instances)
# migrate args
self.migration_scheduler = MigrationScheduler(global_scheduler_config.pair_migration_policy,
global_scheduler_config.migrate_out_load_threshold,
@@ -44,7 +47,8 @@
self.scaling_scheduler = ScalingScheduler(global_scheduler_config.scale_up_threshold,
global_scheduler_config.scale_down_threshold,
global_scheduler_config.scaling_policy,
self.instance_load_calculator)
self.instance_load_calculator,
global_scheduler_config.num_dispatch_instances)

self.num_instances = 0
self.instance_id_set: Set[str] = set()
@@ -56,16 +60,18 @@ def update_instance_infos(self, instance_infos: List[InstanceInfo]) -> None:
# Llumnix has different instance load computation methods for dispatch/migrate/scale.
instance_info.instance_load_dispatch_scale = self.instance_load_calculator.compute_instance_load(instance_info, action='dispatch')
instance_info.instance_load_migrate = self.instance_load_calculator.compute_instance_load(instance_info, action='migrate')
instance_info.instance_type = self.scaling_scheduler.get_instance_type_info(instance_info.instance_id)
self.instance_info[instance_info.instance_id] = instance_info

def dispatch(self) -> Tuple[str, float]:
self.dispatch_scheduler.update_instance_infos(self.instance_info)
instance_id = self.dispatch_scheduler.dispatch()
return instance_id
request_expected_steps = 1 if self.enable_pd_disagg else math.inf
return instance_id, request_expected_steps

def pair_migration(self) -> List[Tuple[str, str]]:
def pair_migration(self, pair_migration_type: str) -> List[Tuple[str, str]]:
self.migration_scheduler.update_instance_infos(self.instance_info)
migrate_instance_pairs = self.migration_scheduler.pair_migration()
migrate_instance_pairs = self.migration_scheduler.pair_migration(pair_migration_type)
return migrate_instance_pairs

def check_scale(self) -> Tuple[str, str]:
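Putting the pieces together, a hedged sketch of the manager-side call path implied by the new return value (the `llumlets` wiring shown is an assumption for illustration, not code from this PR):

import math

# The global scheduler now returns where to place the request and how many
# engine steps it is expected to run there.
instance_id, expected_steps = global_scheduler.dispatch()

# With --enable-pd-disagg, expected_steps == 1: the chosen prefill instance
# performs a single step, after which the request stops being scheduled there
# and can be moved to a decode instance. Otherwise, expected_steps == math.inf.
llumlets[instance_id].add_request(request_id, server_info, expected_steps,
                                  prompt, sampling_params)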