diff --git a/colossalai/shardformer/modeling/deepseek.py b/colossalai/shardformer/modeling/deepseek.py
index 4b1b82b7c770..7bcdf6fc9892 100644
--- a/colossalai/shardformer/modeling/deepseek.py
+++ b/colossalai/shardformer/modeling/deepseek.py
@@ -109,6 +109,19 @@ def setup_process_groups(
         for p in self.experts.parameters():
             set_moe_tensor_ep_group(p, ep_group)
 
+        if self.config.n_shared_experts is not None:
+            self.shared_experts.gate_proj = Linear1D_Col.from_native_module(
+                self.shared_experts.gate_proj, self.tp_group, fp8_communication=self.fp8_communication
+            )
+
+            self.shared_experts.up_proj = Linear1D_Col.from_native_module(
+                self.shared_experts.up_proj, self.tp_group, fp8_communication=self.fp8_communication
+            )
+
+            self.shared_experts.down_proj = Linear1D_Row.from_native_module(
+                self.shared_experts.down_proj, self.tp_group, fp8_communication=self.fp8_communication
+            )
+
     @staticmethod
     def from_native_module(
         module,
diff --git a/tests/test_shardformer/test_model/test_shard_deepseek.py b/tests/test_shardformer/test_model/test_shard_deepseek.py
index d782a2a09604..4b92dbdee4bf 100644
--- a/tests/test_shardformer/test_model/test_shard_deepseek.py
+++ b/tests/test_shardformer/test_model/test_shard_deepseek.py
@@ -20,14 +20,15 @@
 NUM_BATCH = 8
 NUM_TOK_PER_BATCH, NUM_EXPERTS = 64, 4
 NUM_LAYERS = 4
-HIDDEN_SIZE_PER_HEAD = 4
+HIDDEN_SIZE_PER_HEAD = 8
 NUM_HEADS = 8
 TOP_K = 2
 
 
-def run_deepseek_commom(config: Tuple[int, ...]):
+def run_deepseek_commom(parallel_config: Tuple[int, ...]):
     Randomizer.reset_index()
-    stage, ep_size, pp_size, tp_size, sp_size = config
+    print(f"rank {dist.get_rank()} testing {parallel_config}")
+    stage, ep_size, pp_size, tp_size, sp_size = parallel_config
     world_size = dist.get_world_size()
     rank = dist.get_rank()
     dtype, precision = torch.bfloat16, "bf16"
@@ -65,6 +66,7 @@ def run_deepseek_commom(config: Tuple[int, ...]):
         attn_implementation="flash_attention_2",
         torch_dtype="float16",
         n_routed_experts=NUM_EXPERTS,
+        n_shared_experts=2,
         num_experts_per_tok=TOP_K,
         trust_remote_code=True,
     )
@@ -159,7 +161,7 @@ def run_deepseek_commom(config: Tuple[int, ...]):
         if rank == world_size - 1:
             shutil.rmtree(model_dir)
 
-    print(f"rank {dist.get_rank()} test passed")
+    print(f"rank {dist.get_rank()} passed {parallel_config}")
 
 
 @parameterize(
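
For reviewers: the modeling change applies the usual Megatron-style column-then-row split to the shared-expert MLP. `gate_proj` and `up_proj` are sharded column-wise (`Linear1D_Col`) and `down_proj` row-wise (`Linear1D_Row`), so the intermediate activations stay sharded across the TP group and a single all-reduce on the `down_proj` output restores the full result. Below is a minimal sketch of that arithmetic in plain PyTorch, not ColossalAI's actual implementation; the tensor shapes, the 2-way TP "world", and the loop standing in for per-rank execution are all illustrative assumptions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden, inter, world = 8, 16, 2        # hypothetical sizes; 2-way tensor parallelism

x = torch.randn(4, hidden)             # a batch of 4 token embeddings
w_gate = torch.randn(inter, hidden)    # gate_proj weight
w_up = torch.randn(inter, hidden)      # up_proj weight
w_down = torch.randn(hidden, inter)    # down_proj weight

# Unsharded reference: the SiLU-gated MLP used by DeepSeek's shared experts.
ref = F.linear(F.silu(F.linear(x, w_gate)) * F.linear(x, w_up), w_down)

# Tensor-parallel simulation: rank r holds a column slice of gate/up and the
# matching row slice of down. The per-rank partial outputs sum to the full
# result; in the real Linear1D_Row that sum is performed by an all-reduce.
out = torch.zeros_like(ref)
for r in range(world):
    cols = slice(r * inter // world, (r + 1) * inter // world)
    gate_r = F.linear(x, w_gate[cols])  # column-parallel gate_proj shard
    up_r = F.linear(x, w_up[cols])      # column-parallel up_proj shard
    out += F.linear(F.silu(gate_r) * up_r, w_down[:, cols])  # row-parallel down_proj shard

assert torch.allclose(ref, out, atol=1e-5)
print("sharded result matches unsharded reference")
```

This also suggests why the test sets `n_shared_experts=2` (to exercise the new sharding path at all) and bumps `HIDDEN_SIZE_PER_HEAD` from 4 to 8: plausibly to keep the sharded projection dimensions evenly divisible across the TP group sizes being tested.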