Commit

fix sp_size bug
Edenzzzz committed Jun 17, 2024
1 parent 996944d commit a5b38e9
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1006,8 +1006,16 @@ def __init__(
                 self.sequence_parallelism_mode in SUPPORT_SP_MODE
             ), f"Sequence parallelism mode {self.sequence_parallelism_mode} is not in the supported list {SUPPORT_SP_MODE}"
             if self.sequence_parallelism_mode in ["split_gather", "ring"]:
-                self.sp_size = sp_size if sp_size is not None else 1
-                self.dp_size = dist.get_world_size() // (tp_size * pp_size * self.sp_size)
+                assert (
+                    tp_size > 1
+                ), f"Sequence parallelism mode {self.sequence_parallelism_mode} requires tensor parallelism (tp_size > 1)"
+                if sp_size != 1:
+                    warnings.warn(
+                        f"sp_size is the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}; the given sp_size will be ignored."
+                    )
+                self.sp_size = 1
+                self.dp_size = dist.get_world_size() // (tp_size * pp_size)
+
             elif self.sequence_parallelism_mode in ["all_to_all"]:
                 self.sp_size = dist.get_world_size() // pp_size if sp_size is None else sp_size
                 self.dp_size = dist.get_world_size() // (self.sp_size * pp_size * tp_size)
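
For context, a minimal sketch of the size bookkeeping this hunk changes (the resolve_parallel_sizes helper and the example numbers below are illustrative only, not part of ColossalAI): in the split_gather and ring modes the sequence is sharded across the existing tensor-parallel group, so the plugin now ignores any user-supplied sp_size and no longer divides dp_size by it.

    # Illustrative re-implementation of the group-size arithmetic above, so the
    # effect of the fix can be checked without launching a distributed job.
    # resolve_parallel_sizes is a hypothetical helper, not a ColossalAI API.
    import warnings


    def resolve_parallel_sizes(world_size, tp_size, pp_size, sp_size, sp_mode):
        if sp_mode in ("split_gather", "ring"):
            # These modes reuse the tensor-parallel group for the sequence dimension,
            # so sp_size must not shrink the data-parallel group.
            assert tp_size > 1, f"{sp_mode} requires tensor parallelism (tp_size > 1)"
            if sp_size != 1:
                warnings.warn(f"sp_size is ignored in mode {sp_mode}; the TP group is reused")
            sp_size = 1
            dp_size = world_size // (tp_size * pp_size)
        elif sp_mode == "all_to_all":
            sp_size = world_size // pp_size if sp_size is None else sp_size
            dp_size = world_size // (sp_size * pp_size * tp_size)
        else:
            raise ValueError(f"Unsupported sequence parallelism mode: {sp_mode}")
        return dp_size, sp_size


    # With 8 ranks, tp_size=2, pp_size=2, mode "split_gather" and a user-passed sp_size=2:
    # before the fix dp_size came out as 8 // (2 * 2 * 2) = 1; after the fix it is
    # 8 // (2 * 2) = 2 and sp_size is fixed to 1 internally.
    print(resolve_parallel_sizes(8, tp_size=2, pp_size=2, sp_size=2, sp_mode="split_gather"))  # (2, 1)

The all_to_all branch is left as in the diff: there the sequence-parallel ranks are distinct from the tensor-parallel ranks, so sp_size genuinely divides the world size.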
