Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Xinyi-ECNU committed Sep 10, 2024
1 parent 78f71c5 commit 937afce
Show file tree
Hide file tree
Showing 9 changed files with 150 additions and 98 deletions.
7 changes: 4 additions & 3 deletions llumnix/global_scheduler/global_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,13 @@ def __init__(self,
# migrate args
self.migration_scheduler = MigrationScheduler(global_scheduler_config.pair_migration_policy,
global_scheduler_config.migrate_out_load_threshold,
self.instance_load_calculator,
global_scheduler_config.num_dispatch_instances)
self.instance_load_calculator)
# auto-scaling args
self.scaling_scheduler = ScalingScheduler(global_scheduler_config.scale_up_threshold,
global_scheduler_config.scale_down_threshold,
global_scheduler_config.scaling_policy,
self.instance_load_calculator)
self.instance_load_calculator,
global_scheduler_config.num_dispatch_instances)

self.num_instances = 0
self.instance_id_set: Set[str] = set()
Expand All @@ -60,6 +60,7 @@ def update_instance_infos(self, instance_infos: List[InstanceInfo]) -> None:
# Llumnix has different instance load computation methods for dispatch/migrate/scale.
instance_info.instance_load_dispatch_scale = self.instance_load_calculator.compute_instance_load(instance_info, action='dispatch')
instance_info.instance_load_migrate = self.instance_load_calculator.compute_instance_load(instance_info, action='migrate')
instance_info.instance_type = self.scaling_scheduler.get_instance_type_info(instance_info.instance_id)
self.instance_info[instance_info.instance_id] = instance_info

def dispatch(self) -> str:
Expand Down
123 changes: 73 additions & 50 deletions llumnix/global_scheduler/migration_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from llumnix.logger import init_logger
from llumnix.instance_info import InstanceInfo, InstanceLoadCalculator
from llumnix.global_scheduler.scaling_scheduler import InstanceType

logger = init_logger(__name__)

Expand All @@ -30,19 +31,11 @@ class PairMigrationConstraints(str, Enum):
DECODING_2_DECODING = "DECODING_2_DECODING"
PREFILL_2_DECODING = "PREFILL_2_DECODING"

class InstanceType(str, Enum):
NO_CONSTRAINTS = "NO_CONSTRAINTS"

# Specific to Prefill-Decoding disaggregation.
PREFILL = "prefill"
DECODE = "decode"

class MigrationScheduler:
def __init__(self,
pair_migration_policy: str,
migrate_out_load_threshold: float,
instance_load_calculator: InstanceLoadCalculator,
constraint_prefill_instance_num: int) -> None:
instance_load_calculator: InstanceLoadCalculator) -> None:
self.migrate_out_load_threshold = migrate_out_load_threshold
self.instance_load_calculator = instance_load_calculator
self.enable_defrag = instance_load_calculator.enable_defrag
Expand All @@ -59,38 +52,19 @@ def __init__(self,

self.num_instances = 0
self.instance_id_set: Set[str] = set()
self.instance_type_id_set: Dict[InstanceType, Set[str]] = {instance_type: set() for instance_type in InstanceType}
self.constraint_prefill_instance_num = constraint_prefill_instance_num
# instance info args
self.instance_info: Dict[str, InstanceInfo] = None
self.sorted_instance_infos: Dict[str, List[InstanceInfo]] = {instance_type: [] for instance_type in InstanceType}
self.sorted_instance_infos: List[InstanceInfo] = None

def pair_migration(self, pair_migration_type: str) -> List[Tuple[str, str]]:
    """Compute (migrate-out, migrate-in) instance pairs for one scheduling round.

    Args:
        pair_migration_type: A PairMigrationConstraints value selecting which
            instance types may act as migration source/destination.

    Returns:
        A list of (src_instance_id, dst_instance_id) pairs chosen by the
        configured pair-migration policy.
    """
    # Sort ascending by migration load so the filtering policies can take the
    # head as migrate-in candidates and the reversed tail as migrate-out ones.
    self._sort_instance_infos(descending=False)
    sorted_src_instance_infos, sorted_dst_instance_infos = self._get_migration_instance_infos(pair_migration_type)
    return self.pair_migration_policy.pair_migration(sorted_src_instance_infos, sorted_dst_instance_infos)

def _get_migration_instance_infos(self, pair_migration_type: str) -> Tuple[List[InstanceInfo], List[InstanceInfo]]:
    """Split the pre-sorted instance infos into migration sources and destinations.

    The per-constraint branching formerly inlined here now lives in the
    FilteringInstanceInfosPolicy hierarchy; this method only resolves and
    applies the right policy.

    Args:
        pair_migration_type: A PairMigrationConstraints value; it selects the
            filtering policy (constrained vs. relaxed) via the factory.

    Returns:
        Tuple of (sorted_src_instance_infos, sorted_dst_instance_infos).
    """
    filter_instance_infos_policy = FilteringInstanceInfosPolicyFactory.get_policy(
        pair_migration_type, migrate_out_load_threshold=self.migrate_out_load_threshold)
    return filter_instance_infos_policy.filter_instances(self.sorted_instance_infos, pair_migration_type)

def update_instance_infos(self,
instance_info: Dict[str, InstanceInfo]) -> None:
Expand All @@ -99,30 +73,79 @@ def update_instance_infos(self,
def add_instance(self, instance_id: str) -> None:
    """Register a newly scaled-up instance with the migration scheduler.

    Only the id set and its cached size are tracked here; per-type
    (prefill/decode) bookkeeping is handled by the scaling scheduler.
    """
    self.instance_id_set.add(instance_id)
    self.num_instances = len(self.instance_id_set)

def remove_instance(self, instance_id: str) -> None:
    """Drop a scaled-down instance and refresh the cached instance count.

    Raises:
        KeyError: If instance_id was never added (set.remove semantics).
    """
    self.instance_id_set.remove(instance_id)
    self.num_instances = len(self.instance_id_set)

def _sort_instance_infos(self,
                         descending: bool = True) -> None:
    """Sort all known instance infos by migration load into self.sorted_instance_infos.

    Args:
        descending: If True, instances with the highest
            instance_load_migrate come first.
    """
    instance_infos: List[InstanceInfo] = list(self.instance_info.values())
    # All migration decisions key off the migrate-specific load metric.
    key_attr = 'instance_load_migrate'
    self.sorted_instance_infos = sorted(
        instance_infos,
        key=lambda instance_info: getattr(instance_info, key_attr),
        reverse=descending
    )

class FilteringInstanceInfosPolicy(ABC):
    """Base policy that splits sorted instance infos into migration sources and destinations.

    Subclasses decide which instances may migrate requests out (sources) and
    which may receive them (destinations).
    """

    def __init__(self,
                 migrate_out_load_threshold: float) -> None:
        # Load above this value marks an instance as a migrate-out candidate.
        self.migrate_out_load_threshold = migrate_out_load_threshold
        # Maps each pair-migration constraint to the (src, dst) instance types it allows.
        self.filter_instances_rules = {
            PairMigrationConstraints.NO_CONSTRAINTS: (InstanceType.NO_CONSTRAINTS, InstanceType.NO_CONSTRAINTS),
            PairMigrationConstraints.DECODING_2_DECODING: (InstanceType.DECODE, InstanceType.DECODE),
            PairMigrationConstraints.PREFILL_2_DECODING: (InstanceType.PREFILL, InstanceType.DECODE),
        }

    def filter_instances(self, sorted_instance_infos: List[InstanceInfo],
                         pair_migration_type: str = None) -> Tuple[List[InstanceInfo], List[InstanceInfo]]:
        """Return (src_instance_infos, dst_instance_infos) for the given constraint.

        Args:
            sorted_instance_infos: Instance infos pre-sorted by migration load
                (ascending, per the caller's _sort_instance_infos(descending=False)).
            pair_migration_type: PairMigrationConstraints key into the rules table.

        Raises:
            KeyError: If pair_migration_type has no filtering rule.
        """
        src_type, dst_type = self.filter_instances_rules[pair_migration_type]
        filtered_src_instance_infos = [info for info in sorted_instance_infos if info.instance_type == src_type]
        filtered_dst_instance_infos = [info for info in sorted_instance_infos if info.instance_type == dst_type]
        src_instance_infos = self.filter_src_instances(filtered_src_instance_infos)
        dst_instance_infos = self.filter_dst_instances(filtered_dst_instance_infos)
        return src_instance_infos, dst_instance_infos

    @abstractmethod
    def filter_src_instances(self, filtered_instance_infos: List[InstanceInfo]) -> List[InstanceInfo]:
        """Select and order the migrate-out candidates."""
        raise NotImplementedError

    @abstractmethod
    def filter_dst_instances(self, filtered_instance_infos: List[InstanceInfo]) -> List[InstanceInfo]:
        """Select and order the migrate-in candidates."""
        raise NotImplementedError

class FilterConstrained(FilteringInstanceInfosPolicy):
    """Symmetric filtering: both sides must satisfy the load/killed-request constraints."""

    def filter_src_instances(self, filtered_instance_infos: List[InstanceInfo]) -> List[InstanceInfo]:
        """Keep overloaded instances (killed requests, or load above threshold).

        Input order is reversed, so with an ascending input the highest-load
        instances come first.
        """
        src_instance_infos = [i for i in reversed(filtered_instance_infos)
                              if i.num_killed_requests > 0 or i.instance_load_migrate > self.migrate_out_load_threshold]
        return src_instance_infos

    def filter_dst_instances(self, filtered_instance_infos: List[InstanceInfo]) -> List[InstanceInfo]:
        """Keep underloaded instances (no killed requests and load below threshold), in input order."""
        dst_instance_infos = [i for i in filtered_instance_infos
                              if i.num_killed_requests == 0 and i.instance_load_migrate < self.migrate_out_load_threshold]
        return dst_instance_infos

class FilterRelaxed(FilteringInstanceInfosPolicy):
    """Relaxed filtering for PREFILL_2_DECODING: every source qualifies.

    Currently used to pick decode instances that receive requests migrating
    off prefill instances.
    """

    def filter_src_instances(self, filtered_instance_infos: List[InstanceInfo]) -> List[InstanceInfo]:
        """All candidates qualify as sources; input order is reversed (highest load first for ascending input)."""
        src_instance_infos = list(reversed(filtered_instance_infos))
        return src_instance_infos

    def filter_dst_instances(self, filtered_instance_infos: List[InstanceInfo]) -> List[InstanceInfo]:
        """Destinations only need to be free of killed requests; no load threshold applies."""
        dst_instance_infos = [i for i in filtered_instance_infos
                              if i.num_killed_requests == 0]
        return dst_instance_infos

class FilteringInstanceInfosPolicyFactory:
    """Maps each PairMigrationConstraints value to its filtering-policy class."""

    _POLICY_REGISTRY = {
        PairMigrationConstraints.NO_CONSTRAINTS: FilterConstrained,
        PairMigrationConstraints.DECODING_2_DECODING: FilterConstrained,
        PairMigrationConstraints.PREFILL_2_DECODING: FilterRelaxed,
    }

    @classmethod
    def get_policy(cls, policy_name: PairMigrationConstraints, **kwargs) -> FilteringInstanceInfosPolicy:
        """Instantiate the policy registered for policy_name, forwarding kwargs to its constructor.

        Raises:
            KeyError: If policy_name is not registered.
        """
        policy_cls = cls._POLICY_REGISTRY[policy_name]
        return policy_cls(**kwargs)

class PairMigrationPolicy(ABC):
def __init__(self,
Expand Down
33 changes: 32 additions & 1 deletion llumnix/global_scheduler/scaling_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,29 @@

from typing import Dict, List, Tuple, Set
from abc import ABC, abstractmethod
from enum import Enum
import numpy as np

from llumnix.logger import init_logger
from llumnix.instance_info import InstanceInfo, InstanceLoadCalculator

logger = init_logger(__name__)

class InstanceType(str, Enum):
    """Role of a serving instance within the cluster."""

    # Homogeneous deployment: any instance may serve both prefill and decode.
    NO_CONSTRAINTS = "NO_CONSTRAINTS"

    # Specific to Prefill-Decoding disaggregation.
    # NOTE(review): member values mix upper and lower case; confirm nothing
    # round-trips these strings before normalizing them.
    PREFILL = "prefill"
    DECODE = "decode"


class ScalingScheduler:
def __init__(self,
scale_up_threshold: float,
scale_down_threshold: float,
scaling_policy: str,
instance_load_calculator: InstanceLoadCalculator) -> None:
instance_load_calculator: InstanceLoadCalculator,
maximum_prefill_instance_num: int) -> None:
self.scale_up_threshold = scale_up_threshold
self.scale_down_threshold = scale_down_threshold
self.scaling_policy = ScalePolicyFactory.get_policy(scaling_policy,
Expand All @@ -35,10 +44,14 @@ def __init__(self,

self.num_instances = 0
self.instance_id_set: Set[str] = set()
self.maximum_prefill_instance_num = maximum_prefill_instance_num
# instance info args
self.instance_info: Dict[str, InstanceInfo] = None
self.sorted_instance_infos: List[InstanceInfo] = None

# TODO(Xinyi): Tag instance type for scheduler, should be extended to auto-scaling for prefill/decoding instances.
self.instance_type_id_set: Dict[InstanceType, Set[str]] = {instance_type: set() for instance_type in InstanceType}

def check_scale(self) -> Tuple[str, str]:
scale_up_num = 0
scale_down_num = 0
Expand All @@ -63,6 +76,18 @@ def update_instance_infos(self,
def add_instance(self, instance_id: str) -> InstanceType:
    """Register a new instance, assign its type, and return that type.

    When a prefill cap is configured (> 0), the first
    maximum_prefill_instance_num instances become PREFILL and the rest
    DECODE; otherwise every instance is NO_CONSTRAINTS.

    Returns:
        The InstanceType assigned to instance_id. (The previous -> None
        annotation was wrong: get_instance_type_info relies on this value.)
    """
    self.instance_id_set.add(instance_id)
    self.num_instances = len(self.instance_id_set)
    if self.maximum_prefill_instance_num > 0:
        # Fill the prefill pool first, then spill over to decode.
        if len(self.instance_type_id_set[InstanceType.PREFILL]) < self.maximum_prefill_instance_num:
            instance_type = InstanceType.PREFILL
        else:
            instance_type = InstanceType.DECODE
    else:
        instance_type = InstanceType.NO_CONSTRAINTS
    self.instance_type_id_set[instance_type].add(instance_id)
    return instance_type

def remove_instance(self, instance_id: str) -> None:
self.instance_id_set.remove(instance_id)
Expand All @@ -79,6 +104,12 @@ def get_empty_instance_info(self) -> InstanceInfo:
dummy_intance_info.num_available_gpu_blocks_waiting = np.inf
return dummy_intance_info

def get_instance_type_info(self, instance_id: str) -> InstanceType:
    """Return the InstanceType recorded for instance_id.

    Unknown ids are registered on the fly via add_instance, which assigns
    and returns their type. (The previous -> InstanceInfo annotation was
    wrong: every path returns an InstanceType.)
    """
    for instance_type in InstanceType:
        if instance_id in self.instance_type_id_set[instance_type]:
            return instance_type
    return self.add_instance(instance_id)

class ScalePolicy(ABC):
def __init__(self,
instance_load_calculator: InstanceLoadCalculator) -> None:
Expand Down
2 changes: 2 additions & 0 deletions llumnix/instance_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def __init__(self,
num_blocks_first_waiting_request: int = 0,
waiting_time_first_waiting_request: int = 0,
num_blocks_all_waiting_requests: int = 0,
instance_type: str = "",
inference_type: str = "",
num_batched_tokens: int = 0) -> None:
self.num_total_gpu_blocks = num_total_gpu_blocks
Expand All @@ -54,6 +55,7 @@ def __init__(self,
# For global scheduling.
self.instance_load_migrate = -np.inf
self.instance_load_dispatch_scale = -np.inf
self.instance_type = instance_type

# For record statistics, assigned in scheduler.
self.inference_type = inference_type
Expand Down
6 changes: 3 additions & 3 deletions llumnix/llumlet/migration_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def migrate_out_onestage(self, migrate_in_ray_actor: "ray.actor.ActorHandle", m
# live migration, transfer all blocks except last one(currently updating)
migration_status = MigrationStatus.RUNNING
is_last_stage = (len(incremental_blocks) <= self.last_stage_max_blocks)
if not is_last_stage and migrate_out_request.blocking_migration:
if not (is_last_stage or migrate_out_request.blocking_migration):
src_blocks = incremental_blocks[:-1]
stage_block_num = len(incremental_blocks) - 1
dst_blocks = ray.get(migrate_in_ray_actor.execute_migration_method \
Expand All @@ -70,7 +70,7 @@ def migrate_out_onestage(self, migrate_in_ray_actor: "ray.actor.ActorHandle", m

if len(dst_blocks) != len(src_blocks):
# migrate-in instance failed to prev alloc
if is_last_stage or not migrate_out_request.blocking_migration:
if is_last_stage or migrate_out_request.blocking_migration:
self.backend_engine.add_running_request(migrate_out_request)
self.backend_engine.remove_migrating_out_request_last_stage(migrate_out_request)
migration_status = MigrationStatus.FINISHED_ABORTED
Expand All @@ -80,7 +80,7 @@ def migrate_out_onestage(self, migrate_in_ray_actor: "ray.actor.ActorHandle", m
migrate_out_request.stage_num_blocks_list.append(stage_block_num)
# TODO(ZeldaHuang): send_blocks in migrate_in_pre_alloc/migrate_in_last_stage
self.backend_engine.send_blocks(migrate_in_ray_actor, src_blocks, dst_blocks)
if not is_last_stage and migrate_out_request.blocking_migration and migrate_out_request.should_abort_migration():
if not (is_last_stage or migrate_out_request.blocking_migration) and migrate_out_request.should_abort_migration():
# migrate-out request abort by scheduler during send/recv
migration_status = MigrationStatus.FINISHED_ABORTED

Expand Down
6 changes: 3 additions & 3 deletions llumnix/llumlet/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ def prompt_len(self) -> int:
def output_len(self) -> int:
raise NotImplementedError

# Whether the migration of request is divided into multiple stages. For requests that have already reached
# the expected steps, the migration will completed within one stage.
# Whether the migration of request is completed within one stage. For requests that have already reached
# the expected steps, blocking_migration is True.
@property
def blocking_migration(self) -> bool:
return self.output_len < self.expected_steps
return self.output_len >= self.expected_steps

def should_abort_migration(self) -> bool:
return self.output_len == 0 \
Expand Down
3 changes: 3 additions & 0 deletions tests/global_scheduler/test_llm_engine_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from llumnix.arg_utils import EngineManagerArgs
from llumnix.llm_engine_manager import LLMEngineManager, MANAGER_ACTOR_NAME
from llumnix.instance_info import InstanceInfo
from llumnix.global_scheduler.scaling_scheduler import InstanceType

from tests.utils import setup_ray_env

Expand Down Expand Up @@ -204,6 +205,7 @@ def get_instance_info_migrate_in(instance_id):
instance_info.num_available_gpu_blocks = np.inf
instance_info.num_running_requests = 1
instance_info.num_blocks_first_waiting_request = 0
instance_info.instance_type = InstanceType.NO_CONSTRAINTS
return instance_info

def get_instance_info_migrate_out(instance_id):
Expand All @@ -212,6 +214,7 @@ def get_instance_info_migrate_out(instance_id):
instance_info.num_available_gpu_blocks = 0
instance_info.num_running_requests = 1
instance_info.num_blocks_first_waiting_request = np.inf
instance_info.instance_type = InstanceType.NO_CONSTRAINTS
return instance_info

def test_update_instance_info_loop_and_migrate(setup_ray_env, engine_manager):
Expand Down
Loading

0 comments on commit 937afce

Please sign in to comment.