[Feature] DeepseekMoE support #5871

Merged: 8 commits, Jul 5, 2024
2 changes: 1 addition & 1 deletion colossalai/cluster/process_group_mesh.py
@@ -147,7 +147,7 @@ def get_group(self, ranks_in_group: List[int], backend: Optional[str] = None) ->
ProcessGroup: The process group with the given ranks.
"""
ranks_in_group = sorted(ranks_in_group)
- if tuple(ranks_in_group) not in self._group_to_ranks:
+ if tuple(ranks_in_group) not in self._ranks_to_group:
group = dist.new_group(ranks_in_group, backend=backend)
self._ranks_to_group[tuple(ranks_in_group)] = group
self._group_to_ranks[group] = tuple(ranks_in_group)
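The one-line fix above matters because `_ranks_to_group` is keyed by rank tuples while `_group_to_ranks` is keyed by process-group handles; checking the wrong dict meant the cache never hit, so `dist.new_group` was called again for ranks that already had a group. A minimal sketch of the intended caching pattern (the field names follow the diff, but the surrounding class is simplified and hypothetical):

# Minimal sketch of the rank-tuple -> group cache restored by the fix above.
# `_ranks_to_group` / `_group_to_ranks` follow the names in the diff; the
# surrounding class is simplified and hypothetical.
from typing import Dict, List, Optional, Tuple

import torch.distributed as dist


class GroupCache:
    def __init__(self) -> None:
        # rank tuple -> ProcessGroup handle
        self._ranks_to_group: Dict[Tuple[int, ...], dist.ProcessGroup] = {}
        # ProcessGroup handle -> rank tuple (reverse lookup)
        self._group_to_ranks: Dict[dist.ProcessGroup, Tuple[int, ...]] = {}

    def get_group(self, ranks_in_group: List[int], backend: Optional[str] = None) -> dist.ProcessGroup:
        ranks = tuple(sorted(ranks_in_group))
        # The membership test must use the dict keyed by rank tuples; testing
        # against `_group_to_ranks` (keyed by group handles) never hits, so a
        # duplicate process group would be created on every call.
        if ranks not in self._ranks_to_group:
            group = dist.new_group(list(ranks), backend=backend)
            self._ranks_to_group[ranks] = group
            self._group_to_ranks[group] = ranks
        return self._ranks_to_group[ranks]
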
429 changes: 429 additions & 0 deletions colossalai/shardformer/modeling/deepseek.py

Large diffs are not rendered by default.

8 changes: 7 additions & 1 deletion colossalai/shardformer/policies/auto_policy.py
@@ -160,6 +160,13 @@ class PolicyLocation:
"transformers_modules.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"
),
# Deepseek
"transformers_modules.modeling_deepseek.DeepSeekModel": PolicyLocation(
file_name="deepseek", class_name="DeepseekModelPolicy"
),
"transformers_modules.modeling_deepseek.DeepseekForCausalLM": PolicyLocation(
file_name="deepseek", class_name="DeepseekForCausalLMPolicy"
),
# Falcon
"transformers.models.falcon.modeling_falcon.FalconModel": PolicyLocation(
file_name="falcon", class_name="FalconModelPolicy"
@@ -252,7 +259,6 @@ def get_autopolicy(model: nn.Module) -> Policy:
"""
full_name = _fullname(model)
policy_location = _POLICY_LIST.get(full_name, None)

if policy_location is None:
raise NotImplementedError(
f"Auto policy for {model.__class__.__qualname__} ({full_name}) is not implemented\n. Supported models are {list(_POLICY_LIST.keys())}"
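For context, the registry above is keyed by the fully qualified class name of the model. DeepSeek-MoE ships its modeling code via `trust_remote_code`, so its classes live under the dynamic `transformers_modules.*` namespace rather than `transformers.models.*`, which is why the new entries use that prefix. A rough sketch of the lookup flow; `lookup_policy` is a hypothetical helper and `_fullname` here is a simplified stand-in for the real one (which presumably also normalizes the repo-specific path segment so the key matches the registry entries above):

# Rough sketch of the auto-policy lookup; `PolicyLocation` fields follow the
# diff above, while `_fullname` and `lookup_policy` are simplified stand-ins.
from dataclasses import dataclass
from typing import Dict, Optional

import torch.nn as nn


@dataclass
class PolicyLocation:
    file_name: str   # module under colossalai.shardformer.policies
    class_name: str  # policy class inside that module


def _fullname(model: nn.Module) -> str:
    # For a trust_remote_code model this produces a key of the form
    # "transformers_modules....modeling_deepseek.DeepseekForCausalLM".
    klass = model.__class__
    return f"{klass.__module__}.{klass.__qualname__}"


def lookup_policy(model: nn.Module, policy_list: Dict[str, PolicyLocation]) -> Optional[PolicyLocation]:
    # Mirrors the body of get_autopolicy shown above.
    return policy_list.get(_fullname(model), None)
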
212 changes: 212 additions & 0 deletions colossalai/shardformer/policies/deepseek.py
@@ -0,0 +1,212 @@
import warnings
from functools import partial
from typing import Callable, Dict, List, Union

import torch.nn as nn
from torch import Tensor
from torch.nn import Module

from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
from colossalai.shardformer.modeling.deepseek import DeepseekPipelineForwards, EPDeepseekMoE
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = ["DeepseekPolicy", "DeepseekForCausalLMPolicy"]


class DeepseekPolicy(Policy):
def config_sanity_check(self):
pass

def preprocess(self):
if self.shard_config.enable_tensor_parallelism:
# Resize embedding
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size

if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)

return self.model

def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
policy = {}

if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
raise NotImplementedError(
"Deepseek dosen't support sequence parallelism now, will ignore the sequence parallelism flag."
)

if self.shard_config.enable_tensor_parallelism:
raise NotImplementedError("Tensor parallelism is not supported for Deepseek model now.")

if getattr(self.shard_config, "ep_group", None) is not None:
# expert parallel
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="mlp",
target_module=EPDeepseekMoE,
kwargs={"ep_group": self.shard_config.ep_group},
)
],
policy=policy,
target_key="DeepseekDecoderLayer",
)

# optimization configuration
if self.shard_config.enable_fused_normalization:
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="input_layernorm",
target_module=FusedRMSNorm,
),
SubModuleReplacementDescription(
suffix="post_attention_layernorm",
target_module=FusedRMSNorm,
),
],
policy=policy,
target_key="DeepseekDecoderLayer",
)

self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="norm",
target_module=FusedRMSNorm,
),
policy=policy,
target_key="DeepseekModel",
)

if self.shard_config.enable_flash_attention:
warnings.warn(
"Flash attention has already been replaced in deepseek, and now set enable_flash_attention = False."
)
self.shard_config.enable_flash_attention = False

return policy

def postprocess(self):
return self.model

def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
"""If under pipeline parallel setting, replacing the original forward method of huggingface
to customized forward method, and add this changing to policy."""
if self.pipeline_stage_manager:
stage_manager = self.pipeline_stage_manager
if self.model.__class__.__name__ == "DeepseekModel":
module = self.model
else:
module = self.model.model

layers_per_stage = stage_manager.distribute_layers(len(module.layers))
stage_index = stage_manager.get_stage_index(layers_per_stage)
method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=model_cls
)

return

def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
assert self.pipeline_stage_manager is not None

if self.model.__class__.__name__ == "DeepseekModel":
module = self.model
else:
module = self.model.model
stage_manager = self.pipeline_stage_manager

held_layers = []
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
if stage_manager.is_first_stage():
held_layers.append(module.embed_tokens)
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.norm)

return held_layers


class DeepseekModelPolicy(DeepseekPolicy):
def __init__(self) -> None:
super().__init__()

def module_policy(self):
policy = super().module_policy()
if self.pipeline_stage_manager:
# set None as default
self.set_pipeline_forward(
model_cls="DeepseekModel",
new_forward=DeepseekPipelineForwards.deepseek_model_forward,
policy=policy,
)
return policy

def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
held_layers = super().get_held_layers()
return held_layers

def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""No shared params in llama model"""
return []


class DeepseekForCausalLMPolicy(DeepseekPolicy):
def module_policy(self):
policy = super().module_policy()
# TODO: assign pg mesh from plugin to all modules
if self.shard_config.enable_tensor_parallelism:
# add a new item for causal lm
new_item = {
"DeepseekForCausalLM": ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True),
)
]
)
}
policy.update(new_item)

if self.pipeline_stage_manager:
# set None as default
self.set_pipeline_forward(
model_cls="DeepseekForCausalLM",
new_forward=DeepseekPipelineForwards.deepseek_for_causal_lm_forward,
policy=policy,
)

return policy

def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers()
if stage_manager.is_last_stage():
held_layers.append(self.model.lm_head)
return held_layers

def get_shared_params(self) -> List[Dict[int, Tensor]]:
deepseek_model = self.model.model
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
if (
id(deepseek_model.embed_tokens.weight) == id(self.model.lm_head.weight)
and self.pipeline_stage_manager.num_stages > 1
):
# tie weights
return [
{
0: deepseek_model.embed_tokens.weight,
self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
}
]
return []
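
To see how these policies are typically exercised: `MoeHybridParallelPlugin` exposes an `ep_group` through `shard_config`, and the policy then swaps every decoder layer's `mlp` for `EPDeepseekMoE`. A hedged usage sketch, assuming the standard Booster workflow and borrowing the checkpoint id from the test below; `launch_from_torch`, the optimizer choice, and the plugin arguments (ep_size, precision, lr) are illustrative, not prescriptive:

# Hedged usage sketch: wiring a trust_remote_code DeepSeek-MoE checkpoint
# through MoeHybridParallelPlugin so the expert-parallel replacement above
# (mlp -> EPDeepseekMoE) takes effect.
import torch
from transformers import AutoConfig, AutoModelForCausalLM

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin

colossalai.launch_from_torch()  # assumes the script is started via torchrun

config = AutoConfig.from_pretrained("deepseek-ai/deepseek-moe-16b-base", trust_remote_code=True)
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)

plugin = MoeHybridParallelPlugin(
    tp_size=1,        # tensor parallelism is not supported by this policy yet
    pp_size=1,
    ep_size=2,        # experts are sharded across an expert-parallel group of this size
    precision="bf16",
)
booster = Booster(plugin=plugin)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
model, optimizer, *_ = booster.boost(model, optimizer)
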
6 changes: 3 additions & 3 deletions colossalai/shardformer/policies/mixtral.py
@@ -194,16 +194,16 @@ def get_held_layers(self) -> List[Module]:
return held_layers

def get_shared_params(self) -> List[Dict[int, Tensor]]:
- llama_model = self.model.model
+ mixtral_model = self.model.model
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
if (
- id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
+ id(mixtral_model.embed_tokens.weight) == id(self.model.lm_head.weight)
and self.pipeline_stage_manager.num_stages > 1
):
# tie weights
return [
{
- 0: llama_model.embed_tokens.weight,
+ 0: mixtral_model.embed_tokens.weight,
self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
}
]
72 changes: 72 additions & 0 deletions tests/test_moe/test_deepseek_layer.py
@@ -0,0 +1,72 @@
from copy import deepcopy

import pytest
import torch
import torch.distributed as dist
from torch.testing import assert_close
from transformers import AutoConfig, AutoModel

import colossalai
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.shardformer.modeling.deepseek import EPDeepseekMoE
from colossalai.testing.utils import spawn

tokens, n_experts = 7, 4
hidden_size = 8
top_k = 2


def check_deepseek_moe_layer():
torch.cuda.set_device(dist.get_rank())
plugin = MoeHybridParallelPlugin(
precision="bf16",
tp_size=1,
pp_size=1,
ep_size=dist.get_world_size(),
)

config = AutoConfig.from_pretrained(
"deepseek-ai/deepseek-moe-16b-base",
num_hidden_layers=1,
n_routed_experts=n_experts,
num_experts_per_tok=top_k,
hidden_size=hidden_size,
intermediate_size=hidden_size * 2,
first_k_dense_replace=0,
num_attention_heads=2,
trust_remote_code=True,
)
torch.manual_seed(0)
# get the moe layer in auto model
orig_model = AutoModel.from_config(config, trust_remote_code=True).layers[0].mlp.cuda()
x = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda()
orig_output = orig_model(x)
model = deepcopy(orig_model)
model = EPDeepseekMoE.from_native_module(model, ep_group=plugin.ep_group)
ep_output = model(x)
assert_close(orig_output, ep_output)
orig_loss = orig_output.mean()
orig_loss.backward()
ep_loss = ep_output.mean()
ep_loss.backward()
assert_close(orig_loss, ep_loss)
name_to_p = {n: p for n, p in orig_model.named_parameters()}
for n, ep_p in model.named_parameters():
p = name_to_p[n]
if ep_p.grad is not None:
assert_close(p.grad, ep_p.grad)


def run_dist(rank: int, world_size: int, port: int):
colossalai.launch(rank, world_size, "localhost", port)
check_deepseek_moe_layer()


# @pytest.mark.parametrize("world_size", [2, 4])
@pytest.mark.parametrize("world_size", [2])
def test_deepseek_moe_layer(world_size: int):
spawn(run_dist, world_size)


if __name__ == "__main__":
test_deepseek_moe_layer(2)