[shardformer] fix gathering output when using tensor parallelism #5431

Merged · 15 commits · Mar 18, 2024
4 changes: 4 additions & 0 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -939,6 +939,7 @@ def __init__(
         enable_jit_fused: bool = False,
         enable_sequence_parallelism: bool = False,
         enable_sequence_overlap: bool = False,
+        parallel_output: bool = True,
         num_microbatches: Optional[int] = None,
         microbatch_size: Optional[int] = None,
         initial_scale: float = 2**16,
@@ -963,6 +964,7 @@ def __init__(
         pp_style: str = "1f1b",
         num_model_chunks: int = 1,
         enable_metadata_cache: bool = True,
+        make_vocab_size_divisible_by: int = 128,
     ) -> None:
         super().__init__()
         assert (
@@ -1035,6 +1037,8 @@ def __init__(
            enable_jit_fused=self.enable_jit_fused,
            enable_sequence_parallelism=enable_sequence_parallelism,
            enable_sequence_overlap=enable_sequence_overlap,
+           parallel_output=parallel_output,
+           make_vocab_size_divisible_by=make_vocab_size_divisible_by,
        )
        self.amp_config = dict(
            initial_scale=initial_scale,
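For context, a minimal usage sketch of the two new plugin arguments (values are illustrative; `tp_size`, `pp_size`, and the `Booster` wiring are pre-existing API, not part of this PR):

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Illustrative configuration; only the two commented kwargs are new in this PR.
plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=2,
    parallel_output=True,              # keep LM logits sharded across the tensor-parallel group
    make_vocab_size_divisible_by=128,  # vocab padding granularity under pipeline parallelism
)
booster = Booster(plugin=plugin)
```

Both new keyword arguments are forwarded unchanged into the plugin's `ShardConfig`, as the third hunk above shows.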
9 changes: 5 additions & 4 deletions colossalai/shardformer/modeling/gpt2.py
@@ -793,11 +793,12 @@ def forward(
             scale = scale * (1 / float(self.layer_idx + 1))

         # use coloattention
-        attention = ColoAttention(
-            embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.attn_dropout.p, scale=scale
-        )
+        if not hasattr(self, "attention"):
+            self.attention = ColoAttention(
+                embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.attn_dropout.p, scale=scale
+            )

-        attn_output = attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type)
+        attn_output = self.attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type)

         attn_output = self.c_proj(attn_output)
         attn_output = self.resid_dropout(attn_output)
5 changes: 3 additions & 2 deletions colossalai/shardformer/modeling/llama.py
@@ -485,8 +485,9 @@ def forward(
             flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
             attn_mask_type = AttnMaskType.paddedcausal

-        attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
-        attn_output = attention(
+        if not hasattr(self, "attention"):
+            self.attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
+        attn_output = self.attention(
             query_states,
             key_states,
             value_states,
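Both modeling hunks apply the same change: `ColoAttention` is now built lazily on first use and cached on the module, rather than re-instantiated on every forward pass. A toy sketch of the pattern in isolation (generic names, not the PR's code):

```python
class Layer:
    """Hasattr-based lazy initialization: build an expensive helper once, reuse it afterwards."""

    def forward(self, x):
        if not hasattr(self, "helper"):          # first call only
            self.helper = self._build_helper()   # cached on the instance
        return self.helper, x

    def _build_helper(self):
        print("building helper")                 # illustrative side effect
        return object()


layer = Layer()
layer.forward(1)  # prints "building helper"
layer.forward(2)  # reuses the cached helper; nothing printed
```

Note that whatever constructor arguments are captured on the first call (for GPT-2, `scale` and the dropout probability) are the ones reused on later calls.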
7 changes: 7 additions & 0 deletions colossalai/shardformer/policies/gpt2.py
@@ -38,6 +38,13 @@ def preprocess(self):
         if vocab_size % world_size != 0:
             new_vocab_size = vocab_size + world_size - vocab_size % world_size
             self.model.resize_token_embeddings(new_vocab_size)
+        elif self.shard_config.pipeline_stage_manager is not None:
+            # pad vocab_size when using pipeline parallelism
+            new_vocab_size = vocab_size
+            multiple = self.shard_config.make_vocab_size_divisible_by
+            while (new_vocab_size % multiple) != 0:
+                new_vocab_size += 1
+            self.model.resize_token_embeddings(new_vocab_size)
         return self.model

     def module_policy(self):
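The padding loop rounds the vocabulary size up to the next multiple of `make_vocab_size_divisible_by`. For illustration only (not code from this PR), the same result in closed form, using GPT-2's 50257-token vocabulary and the default multiple of 128:

```python
def pad_vocab_size(vocab_size: int, multiple: int = 128) -> int:
    # Round vocab_size up to the next multiple of `multiple` (ceiling division).
    return ((vocab_size + multiple - 1) // multiple) * multiple

assert pad_vocab_size(50257, 128) == 50304  # GPT-2: 50257 -> 50304 (= 393 * 128)
assert pad_vocab_size(50304, 128) == 50304  # already a multiple: unchanged
```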
3 changes: 2 additions & 1 deletion colossalai/shardformer/shard/shard_config.py
@@ -34,7 +34,8 @@ class ShardConfig:
     enable_all_optimization: bool = False
     enable_sequence_parallelism: bool = False
     enable_sequence_overlap: bool = False
-    parallel_output = True
+    parallel_output: bool = True
+    make_vocab_size_divisible_by: int = 128
     extra_kwargs: Dict[str, Any] = field(default_factory=dict)
     # pipeline_parallel_size: int
     # data_parallel_size: int
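`parallel_output` is the flag the PR title refers to: when it is `False`, the vocab-sharded LM-head logits should be gathered so every tensor-parallel rank sees the full output; when `True` (the default) they stay sharded. A hedged sketch of what that gather amounts to, written against plain `torch.distributed` rather than the shardformer internals (the function name and signature are hypothetical):

```python
import torch
import torch.distributed as dist

def maybe_gather_logits(logits: torch.Tensor, parallel_output: bool, tp_group) -> torch.Tensor:
    """Gather vocab-sharded logits across the TP group unless parallel_output keeps them sharded."""
    if parallel_output or tp_group is None or dist.get_world_size(tp_group) == 1:
        return logits  # leave the vocabulary dimension sharded on each rank
    shards = [torch.empty_like(logits) for _ in range(dist.get_world_size(tp_group))]
    dist.all_gather(shards, logits.contiguous(), group=tp_group)
    return torch.cat(shards, dim=-1)  # reassemble the full vocabulary dimension
```

`make_vocab_size_divisible_by` complements this: the GPT-2 policy hunk above uses it to pad the token embedding to a multiple of that value when pipeline parallelism is enabled.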