From a5b38e9cc543a8c82ca4908a8e250b2a8076b6a1 Mon Sep 17 00:00:00 2001
From: Edenzzzz
Date: Mon, 17 Jun 2024 03:33:41 +0000
Subject: [PATCH] fix sp_size bug

---
 colossalai/booster/plugin/hybrid_parallel_plugin.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 3874c8f1f2c2..a2b5db8c0fed 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1006,8 +1006,16 @@ def __init__(
             self.sequence_parallelism_mode in SUPPORT_SP_MODE
         ), f"Sequence parallelism mode {self.sequence_parallelism_mode} is not in the supported list {SUPPORT_SP_MODE}"
         if self.sequence_parallelism_mode in ["split_gather", "ring"]:
-            self.sp_size = sp_size if sp_size is not None else 1
-            self.dp_size = dist.get_world_size() // (tp_size * pp_size * self.sp_size)
+            assert (
+                tp_size > 1
+            ), f"Sequence parallelism mode {self.sequence_parallelism_mode} requires tensor parallelism (tp_size > 1)"
+            if sp_size is not None and sp_size != 1:
+                warnings.warn(
+                    f"sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}; the given sp_size will be ignored."
+                )
+            self.sp_size = 1
+            self.dp_size = dist.get_world_size() // (tp_size * pp_size)
+
         elif self.sequence_parallelism_mode in ["all_to_all"]:
             self.sp_size = dist.get_world_size() // pp_size if sp_size is None else sp_size
             self.dp_size = dist.get_world_size() // (self.sp_size * pp_size * tp_size)
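
For reference, a standalone sketch of the group-size arithmetic this patch
changes. derive_sizes is a hypothetical helper written for illustration, not
part of the ColossalAI API, and the GPU counts are assumed values. In the
"split_gather" and "ring" modes, sequence parallelism reuses the
tensor-parallel group, so sp_size is pinned to 1 and no longer divides the
data-parallel axis; in "all_to_all" mode, sp_size is an independent axis.

    def derive_sizes(world_size, tp_size, pp_size, sp_mode, sp_size=None):
        """Sketch of (sp_size, dp_size) under the patched logic."""
        if sp_mode in ("split_gather", "ring"):
            # SP piggybacks on the TP group: any user-supplied sp_size is
            # ignored, and dp no longer divides by it.
            assert tp_size > 1, f"mode {sp_mode} requires tensor parallelism"
            return 1, world_size // (tp_size * pp_size)
        if sp_mode == "all_to_all":
            # SP is its own axis: it divides the world alongside tp and pp.
            sp = world_size // pp_size if sp_size is None else sp_size
            return sp, world_size // (sp * pp_size * tp_size)
        raise ValueError(f"unknown sequence parallelism mode: {sp_mode}")

    # Assumed setup: 16 GPUs, tp_size=2, pp_size=2.
    print(derive_sizes(16, 2, 2, "split_gather", sp_size=2))  # (1, 4)
    print(derive_sizes(16, 1, 2, "all_to_all", sp_size=4))    # (4, 2)

Under the pre-patch code, the split_gather example would have kept sp_size=2
and returned dp_size = 16 // (2 * 2 * 2) = 2 instead of 4, over-shrinking the
data-parallel axis; that is the bug this patch fixes.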