[Feature] MoE Ulysses Support (#5918)
* MoE SP support

* Fix MoE SP bug

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2 people authored and botbw committed Jul 19, 2024
1 parent c8bf268 commit 633849f
Showing 6 changed files with 570 additions and 71 deletions.
52 changes: 46 additions & 6 deletions colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -1,4 +1,6 @@
import warnings
from collections import defaultdict
from copy import deepcopy
from types import MethodType
from typing import Callable, Optional, OrderedDict, Tuple

@@ -22,6 +24,8 @@
)
from colossalai.checkpoint_io import MoECheckpointIO
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.interface.optimizer import DistributedOptim
from colossalai.nn.optimizer import cast_to_distributed
from colossalai.tensor.moe_tensor.api import is_moe_tensor


@@ -114,21 +118,25 @@ def __init__(self, ep_size: int, moe_tp_size: int = 1, force_overlap_comm=False,
self.ddp_config["find_unused_parameters"] = True

world_size = dist.get_world_size()
self.moe_dp_size = world_size // (self.pp_size * ep_size * moe_tp_size)
self.moe_dp_size = world_size // (self.pp_size * ep_size * moe_tp_size * self.sp_size)
self.ep_size = ep_size
self.moe_tp_size = moe_tp_size

if self.pp_size * self.moe_dp_size * self.ep_size * self.moe_tp_size != world_size:
if self.pp_size * self.moe_dp_size * self.ep_size * self.moe_tp_size * self.sp_size != world_size:
raise ValueError(
f"world_size={world_size} is not divisible by pp_size={self.pp_size} * moe_dp_size={self.moe_dp_size} * ep_size={self.ep_size} * moe_tp_size={self.moe_tp_size}"
)

self._init_moe_param_comm()
# self._init_moe_param_comm()

self.logger.info(f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=}", ranks=[0])

# set ep_group after super init
# TODO do it in a better way
self.moe_dp_group = self.pp_group
self.ep_group = self.pp_group
self.moe_tp_group = self.pp_group

self.shard_config.ep_group = self.ep_group
self.shard_config.moe_dp_group = self.moe_dp_group
self.shard_config.moe_tp_group = self.moe_tp_group
@@ -205,15 +213,32 @@ def configure(
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
param_info = get_param_info(optimizer)

# TODO: Support Galore + ZeRO
self.zero_stage
deepcopy(self.zero_config)
# Replace with distributed implementation if exists
optimizer = cast_to_distributed(optimizer)

if not isinstance(model, ModelWrapper):
use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
self.dp_size == 1
and self.pp_size == 1
and self.enable_sequence_parallelism
and self.sequence_parallelism_mode == "all_to_all"
)
if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
else:
dp_group = self.dp_group
model = HybridParallelModule(
module=model,
precision=self.precision,
shard_config=self.shard_config,
dp_group=self.dp_group,
dp_group=dp_group,
tp_group=self.tp_group,
sp_group=self.sp_group,
use_ddp=self.use_ddp,
use_ddp=use_ddp,
ddp_config=self.ddp_config,
custom_policy=self.custom_policy,
)
@@ -224,6 +249,7 @@ def configure(
reinitialize_optimizer(optimizer, model)

if self.zero_stage == 0:
is_zero = False
if self.precision in ["fp16", "bf16"]:
optimizer = HybridParallelAMPOptimizer(
optimizer,
@@ -236,14 +262,21 @@
)
else:
optimizer = HybridParallelNaiveOptimizer(
optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info
optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism,
param_info=param_info,
max_norm=self.max_norm,
pp_process_group=self.pp_group,
tp_process_group=self.tp_group,
)
else:
if not (self.dp_size > 1 or self.moe_dp_size > 1):
warnings.warn(
"Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
"If you do not intend to use cpu_offload, please consider set zero_stage=0."
)
assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
optimizer = MoeHybridParallelZeroOptimizer(
optimizer,
model,
@@ -262,4 +295,11 @@ def configure(
# inject update_master_params
model.update_master_params = MethodType(optimizer.update_master_params, model)

# Setup optimizers that require global states
optim = optimizer.optim
if isinstance(optim, DistributedOptim):
shard_to_param = optimizer.get_master_to_working_map() if is_zero else {}
padding_map = optimizer.get_param_padding_map() if is_zero else defaultdict(int)
optim.setup_distributed(self.tp_group, self.dp_group, shard_to_param, padding_map, is_zero)

return model, optimizer, criterion, dataloader, lr_scheduler
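
For reference, a minimal sketch of the device-count bookkeeping that the updated check above enforces; the sizes below are hypothetical examples, not values taken from this commit:

# Hypothetical configuration for illustration only (not from this commit).
world_size = 16
pp_size, ep_size, moe_tp_size, sp_size = 2, 2, 1, 2

# With Ulysses (all_to_all) sequence parallelism, sp_size now enters the product,
# so the MoE data-parallel size shrinks accordingly.
moe_dp_size = world_size // (pp_size * ep_size * moe_tp_size * sp_size)  # -> 2

assert pp_size * moe_dp_size * ep_size * moe_tp_size * sp_size == world_size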
8 changes: 6 additions & 2 deletions colossalai/cluster/process_group_mesh.py
@@ -209,7 +209,7 @@ def create_group_along_axis(
axis: Union[int, List[int]],
indices_at_axis: Optional[Union[List[int], List[List[int]]]] = None,
backend: Optional[str] = None,
return_ranks_by_group: bool = False
return_ranks_by_group: bool = False,
) -> Union[ProcessGroup, List[Tuple[int, ...]]]:
"""Create all process groups along the given axis, and return the one which the current process belongs to.
@@ -257,7 +257,11 @@ def create_group_along_axis(
return target_group

def get_group_along_axis(
self, axis: Union[int, List[int]], indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None, return_ranks_by_group: bool = False
self,
axis: Union[int, List[int]],
indices_at_axis: Optional[List[int]] = None,
backend: Optional[str] = None,
return_ranks_by_group: bool = False,
) -> Union[ProcessGroup, List[Tuple[int, ...]]]:
"""Get the process group along the given axis which the current process belongs to. If the process group doesn't exist, it will be created.