[hotfix] moe hybrid parallelism benchmark & follow-up fix #6048

Merged
merged 9 commits into from
Sep 10, 2024
Changes from all commits
35 changes: 21 additions & 14 deletions colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -64,13 +64,18 @@ def __init__(
forced_dtype: Optional[torch.dtype] = None,
overlap_allgather: bool = False,
):
pg_param_list = {
dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())),
moe_dp_group: list(filter(is_moe_tensor, model.parameters())),
}
if dp_process_group is moe_dp_group:
pg_param_list = {
dp_process_group: list(model.parameters()),
}
else:
pg_param_list = {
dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())),
moe_dp_group: list(filter(is_moe_tensor, model.parameters())),
}

if len(pg_param_list[dp_process_group]) == 0 or len(pg_param_list[moe_dp_group]) == 0:
raise ValueError("No parameters found in dp_process_group or moe_dp_group")
if len(pg_param_list[moe_dp_group]) == 0:
raise ValueError("No parameters found in moe_dp_group, please consider using HybridParallelPlugin instead")

super().__init__(
model=model,
@@ -407,24 +412,25 @@ def configure(
and self.enable_sequence_parallelism
and self.sequence_parallelism_mode == "all_to_all"
)

# sync gradients across DP * SP ranks
if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
dp_group = self.pg_mesh.create_group_along_axis([self.moe_dp_axis, self.ep_axis, self.sp_axis])
else:
dp_group = self.dp_group

if use_ddp:
self.logger.warning(
f"Will have to check all params are used in pytorch DDP since not all experts are always activated",
ranks=[0],
)
self.ddp_config["find_unused_parameters"] = True

if dist.get_process_group_ranks(self.dp_group) != dist.get_process_group_ranks(self.moe_dp_group):
if dist.get_process_group_ranks(dp_group) != dist.get_process_group_ranks(self.moe_dp_group):
raise ValueError(
f"if pytorch ddp is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(self.dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to use HybridParallelPlugin (i.e. set ep_size = 1) or set zero_stage > 0"
f"if pytorch DDP is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to modify your config to bypass DDP \nhint: check the above ddp condition to by pass this"
)

# sync gradients across DP * SP ranks
if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
dp_group = self.pg_mesh.create_group_along_axis([self.moe_dp_axis, self.ep_axis, self.sp_axis])
else:
dp_group = self.dp_group
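For context on find_unused_parameters: with MoE, experts that receive no tokens in a step produce no gradients, which is what trips DDP's reducer. A minimal, non-distributed sketch of that effect (assumed toy model, not from this PR):

import torch
import torch.nn as nn

experts = nn.ModuleList([nn.Linear(4, 4) for _ in range(2)])
x = torch.randn(3, 4)
experts[0](x).sum().backward()              # only expert 0 is "activated" this step
print(experts[0].weight.grad is not None)   # True
print(experts[1].weight.grad is None)       # True -- unused expert, hence find_unused_parameters=True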

model = HybridParallelModule(
module=model,
precision=self.precision,
@@ -466,6 +472,7 @@ def configure(
tp_process_group=self.tp_group,
)
else:
is_zero = True
if self.dp_size <= 1:
self.logger.warning(
"Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
7 changes: 2 additions & 5 deletions colossalai/moe/_operation.py
@@ -308,7 +308,7 @@ def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
assert len(grad_outputs) == 1
grad = grad_outputs[0]
if ctx.ep_size != 1:
grad = grad * ctx.ep_size
grad.mul_(ctx.ep_size)
return grad, None


@@ -328,7 +328,7 @@ def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
assert len(grad_outputs) == 1
grad = grad_outputs[0]
if ctx.ep_size != 1:
grad = grad / ctx.ep_size
grad.div_(ctx.ep_size)
return grad, None


@@ -449,7 +449,4 @@ def all_to_all_uneven(
overlap: bool = False,
fp8_communication: bool = False,
):
assert (
inputs.requires_grad
), "Input must require grad to assure that backward is executed, otherwise it might hang the program."
return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap, fp8_communication)
81 changes: 78 additions & 3 deletions colossalai/shardformer/modeling/deepseek.py
@@ -3,7 +3,7 @@

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed import ProcessGroup
from torch.nn import CrossEntropyLoss
from transformers.cache_utils import Cache, DynamicCache
@@ -28,11 +28,13 @@
from colossalai.shardformer.layer._operation import (
all_to_all_comm,
gather_forward_split_backward,
linear_with_async_comm,
split_forward_gather_backward,
)
from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row
from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row, ParallelModule
from colossalai.shardformer.shard import ShardConfig
from colossalai.shardformer.shard.utils import set_tensors_to_none
from colossalai.tensor.d_tensor.api import shard_rowwise, sharded_tensor_to_existing_param
from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group


@@ -58,7 +60,7 @@ def backward(ctx, grad_output):
return grad_output, grad_loss


class EPDeepseekMoE(nn.Module):
class EPDeepseekMoE(ParallelModule):
def __init__(self):
raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")

@@ -214,6 +216,79 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return output_hidden_states


class DeepseekMoEGate_Col(ParallelModule):
def parallel_linear(self, hidden_states):
assert (
hidden_states.shape[-1] == self.weight.shape[-1]
), "Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.".format(
hidden_states.shape, self.weight.shape, self.weight.shape[-1]
)

output = linear_with_async_comm(
hidden_states, self.weight, None, self.process_group, True, fp8_communication=self.fp8_communication
)

# All-gather across the partitions.
output = gather_forward_split_backward(
output, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
)
return output
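parallel_linear above follows the usual column-parallel pattern: each rank computes logits for its shard of experts, then the shards are all-gathered along the last dim. A minimal single-process sketch of why the gather reconstructs the full output (shapes are assumed examples, not from this PR):

import torch
import torch.nn.functional as F

x = torch.randn(5, 16)                         # [tokens, hidden]
w = torch.randn(8, 16)                         # full gate weight: [n_experts, hidden]
w0, w1 = w.chunk(2, dim=0)                     # each "rank" holds half the output features
partial = [F.linear(x, w0), F.linear(x, w1)]   # per-rank partial logits
gathered = torch.cat(partial, dim=-1)          # stands in for the all-gather along dim=-1
print(torch.allclose(F.linear(x, w), gathered, atol=1e-6))  # True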

def forward(self, hidden_states):
bsz, seq_len, h = hidden_states.shape
### compute gating score
hidden_states = hidden_states.view(-1, h)
logits = self.parallel_linear(hidden_states)
if self.scoring_func == "softmax":
scores = logits.softmax(dim=-1)
else:
raise NotImplementedError(f"insupportable scoring function for MoE gating: {self.scoring_func}")

### select top-k experts
topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)

### norm gate to sum 1
if self.top_k > 1 and self.norm_topk_prob:
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
topk_weight = topk_weight / denominator

### expert-level computation auxiliary loss
if self.training and self.alpha > 0.0:
scores_for_aux = scores
aux_topk = self.top_k
# always compute aux loss based on the naive greedy topk method
topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
if self.seq_aux:
scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
ce.scatter_add_(
1, topk_idx_for_aux_loss, torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)
).div_(seq_len * aux_topk / self.n_routed_experts)
aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
else:
mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
ce = mask_ce.float().mean(0)
Pi = scores_for_aux.mean(0)
fi = ce * self.n_routed_experts
aux_loss = (Pi * fi).sum() * self.alpha
else:
aux_loss = None

return topk_idx, topk_weight, aux_loss

@staticmethod
def from_native_module(
module, process_group: ProcessGroup, config, gather_output, fp8_communication
) -> "DeepseekMoEGate_Col":
LazyInitContext.materialize(module)
module.process_group = process_group
module.fp8_communication = fp8_communication
sharded_weight = shard_rowwise(module.weight.data, process_group)
sharded_tensor_to_existing_param(sharded_weight, module.weight)
module.__class__ = DeepseekMoEGate_Col
return module

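The gate's routing path above boils down to softmax scores, a top-k selection, and an optional renormalization of the kept weights. A minimal sketch of that path in isolation (token/expert counts are assumed examples):

import torch

scores = torch.softmax(torch.randn(4, 8), dim=-1)              # 4 tokens, 8 experts
topk_weight, topk_idx = torch.topk(scores, k=2, dim=-1, sorted=False)
topk_weight = topk_weight / (topk_weight.sum(dim=-1, keepdim=True) + 1e-20)
print(topk_idx.shape, topk_weight.sum(dim=-1))                  # torch.Size([4, 2]), ~1.0 per token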

class DeepseekPipelineForwards:
"""
This class serves as a micro library for forward function substitution of Llama models
4 changes: 2 additions & 2 deletions colossalai/shardformer/modeling/mixtral.py
@@ -36,7 +36,7 @@
gather_forward_split_backward,
split_forward_gather_backward,
)
from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row
from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row, ParallelModule
from colossalai.shardformer.shard import ShardConfig
from colossalai.shardformer.shard.utils import set_tensors_to_none
from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group
@@ -49,7 +49,7 @@
_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)


class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
class EPMixtralSparseMoeBlock(ParallelModule):
def __init__(self, *args, **kwargs):
raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")

36 changes: 31 additions & 5 deletions colossalai/shardformer/policies/deepseek.py
@@ -10,6 +10,7 @@
from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D
from colossalai.shardformer.layer.linear import Linear1D_Row
from colossalai.shardformer.modeling.deepseek import (
DeepseekMoEGate_Col,
DeepseekPipelineForwards,
EPDeepseekMoE,
get_deepseek_flash_attention_forward,
@@ -56,16 +57,24 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
sp_size = self.shard_config.sequence_parallel_size or None
sp_group = self.shard_config.sequence_parallel_process_group or None
sp_partial_derived = sp_mode in ["split_gather", "ring"]
tp_size = self.shard_config.tensor_parallel_size

# modified for both SP and TP
num_q_heads = self.model.config.num_attention_heads
num_kv_heads = getattr(self.model.config, "num_key_value_heads", None)
if sp_mode == "all_to_all":
num_q_heads //= sp_size
decoder_attribute_replacement = {
"num_heads": self.model.config.num_attention_heads // sp_size,
"num_heads": num_q_heads,
}
if getattr(self.model.config, "num_key_value_heads", False):
decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size
num_kv_heads //= sp_size
decoder_attribute_replacement["num_key_value_heads"] = num_kv_heads

policy[attn_cls] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
)

if self.shard_config.enable_sequence_parallelism:
if self.pipeline_stage_manager is not None:
# NOTE: we are replacing model forward for both sequence parallelism and pipeline parallelism
@@ -97,6 +106,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
else:
if self.tie_weight:
embedding_cls = PaddingEmbedding

if self.shard_config.enable_tensor_parallelism:
# tensor parallelism for non-moe params
assert (
@@ -107,10 +117,15 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
), f"The number of key_value heads must be divisible by tensor parallel size."
decoder_attribute_replacement = {
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
"self_attn.num_key_value_heads": self.model.config.num_key_value_heads
// self.shard_config.tensor_parallel_size,
}
num_q_heads //= tp_size
decoder_attribute_replacement = {
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attn.num_heads": num_q_heads,
}
if num_kv_heads:
num_kv_heads //= tp_size
decoder_attribute_replacement["self_attn.num_key_value_heads"] = num_kv_heads

policy["DeepseekDecoderLayer"] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
Expand All @@ -135,8 +150,19 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
target_module=Linear1D_Row,
kwargs={"fp8_communication": self.shard_config.fp8_communication},
),
SubModuleReplacementDescription(
suffix="mlp.gate",
target_module=DeepseekMoEGate_Col,
kwargs={
"gather_output": True,
"fp8_communication": self.shard_config.fp8_communication,
"config": self.model.config,
},
ignore_if_not_exist=True,
),
],
)

if embedding_cls is not None:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
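The attribute replacements above (and the mirrored changes in policies/mixtral.py below) divide head counts across the parallel group, touching num_key_value_heads only when the model defines it (GQA). A minimal sketch of the arithmetic with assumed example values:

num_attention_heads, num_key_value_heads, tp_size = 32, 8, 4   # assumed example config
num_q_heads = num_attention_heads // tp_size                   # 8 query heads per rank
num_kv_heads = num_key_value_heads // tp_size if num_key_value_heads else None  # 2 KV heads per rank
print(num_q_heads, num_kv_heads)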
22 changes: 16 additions & 6 deletions colossalai/shardformer/policies/mixtral.py
@@ -51,12 +51,20 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
sp_size = self.shard_config.sequence_parallel_size or None
sp_group = self.shard_config.sequence_parallel_process_group or None
sp_partial_derived = sp_mode in ["split_gather", "ring"]
tp_size = self.shard_config.tensor_parallel_size

# modified for both SP and TP
num_q_heads = self.model.config.num_attention_heads
num_kv_heads = getattr(self.model.config, "num_key_value_heads", None)

if sp_mode == "all_to_all":
num_q_heads //= sp_size
decoder_attribute_replacement = {
"num_heads": self.model.config.num_attention_heads // sp_size,
"num_heads": num_q_heads,
}
if getattr(self.model.config, "num_key_value_heads", False):
decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size
num_kv_heads //= sp_size
decoder_attribute_replacement["num_key_value_heads"] = num_kv_heads

policy[attn_cls] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
@@ -101,12 +109,14 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
assert (
self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
), f"The number of key_value heads must be divisible by tensor parallel size."
num_q_heads //= tp_size
decoder_attribute_replacement = {
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
"self_attn.num_key_value_heads": self.model.config.num_key_value_heads
// self.shard_config.tensor_parallel_size,
"self_attn.num_heads": num_q_heads,
}
if num_kv_heads:
num_kv_heads //= tp_size
decoder_attribute_replacement["self_attn.num_key_value_heads"] = num_kv_heads

policy[MixtralDecoderLayer] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
Expand All @@ -131,7 +141,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
target_module=Linear1D_Row,
kwargs={"fp8_communication": self.shard_config.fp8_communication},
),
SubModuleReplacementDescription( # or replicate?
SubModuleReplacementDescription(
suffix="block_sparse_moe.gate",
target_module=Linear1D_Col,
kwargs={"gather_output": True, "fp8_communication": self.shard_config.fp8_communication},