padding vocab_size when using pipeline parallelism
padding vocab_size when using pipeline parallelism

fix

fix
flybird11111 committed Mar 7, 2024
1 parent 2c2c3cd commit f3f454e
Showing 6 changed files with 21 additions and 7 deletions.
1 change: 0 additions & 1 deletion applications/Colossal-LLaMA-2/train.py
@@ -56,7 +56,6 @@ def format_numel_str(numel: int) -> str:

 def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
     dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
-    tensor = tensor.data
     tensor.div_(dist.get_world_size())
     return tensor

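For context, a hedged usage sketch (not part of the commit; the call-site names are assumptions) of how the cleaned-up helper averages a metric across data-parallel ranks:

import torch
import torch.distributed as dist

def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
    # Sum across all ranks in-place, then divide by the world size to get the mean.
    dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
    tensor.div_(dist.get_world_size())
    return tensor

# Hypothetical call site, assuming the process group is already initialized:
# avg_loss = all_reduce_mean(loss.detach().clone())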
4 changes: 4 additions & 0 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -937,6 +937,7 @@ def __init__(
         enable_jit_fused: bool = False,
         enable_sequence_parallelism: bool = False,
         enable_sequence_overlap: bool = False,
+        parallel_output: bool = True,
         num_microbatches: Optional[int] = None,
         microbatch_size: Optional[int] = None,
         initial_scale: float = 2**16,
@@ -961,6 +962,7 @@ def __init__(
         pp_style: str = "1f1b",
         num_model_chunks: int = 1,
         enable_metadata_cache: bool = True,
+        make_vocab_size_divisible_by: int = 128,
     ) -> None:
         super().__init__()
         assert (
@@ -1033,6 +1035,8 @@ def __init__(
             enable_jit_fused=self.enable_jit_fused,
             enable_sequence_parallelism=enable_sequence_parallelism,
             enable_sequence_overlap=enable_sequence_overlap,
+            parallel_output=parallel_output,
+            make_vocab_size_divisible_by=make_vocab_size_divisible_by,
         )
         self.amp_config = dict(
             initial_scale=initial_scale,
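A minimal usage sketch (assuming a distributed environment has already been launched and that model/optimizer exist; the parallel sizes shown are arbitrary) of how the two new constructor arguments flow into the plugin:

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=2,                         # pipeline parallelism enabled
    num_microbatches=4,
    parallel_output=True,              # forwarded to ShardConfig
    make_vocab_size_divisible_by=128,  # vocab padding multiple under pipeline parallelism
)
booster = Booster(plugin=plugin)
# model, optimizer, *_ = booster.boost(model, optimizer)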
9 changes: 5 additions & 4 deletions colossalai/shardformer/modeling/gpt2.py
@@ -783,11 +783,12 @@ def forward(
             scale = scale * (1 / float(self.layer_idx + 1))

         # use coloattention
-        attention = ColoAttention(
-            embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.attn_dropout.p, scale=scale
-        )
+        if not hasattr(self, "attention"):
+            self.attention = ColoAttention(
+                embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.attn_dropout.p, scale=scale
+            )

-        attn_output = attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type)
+        attn_output = self.attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type)

         attn_output = self.c_proj(attn_output)
         attn_output = self.resid_dropout(attn_output)
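The diff replaces a per-call ColoAttention construction with a cached attribute, so the module is built once on the first forward pass and reused afterwards. An illustrative, repo-independent sketch of that cache-on-first-use pattern:

class ExpensiveHelper:
    def __init__(self):
        print("built once")

    def __call__(self, x):
        return x * 2

class Worker:
    def process(self, x):
        # Build the helper lazily on the first call, then reuse the cached
        # instance, mirroring how the forward now caches self.attention.
        if not hasattr(self, "_helper"):
            self._helper = ExpensiveHelper()
        return self._helper(x)

w = Worker()
w.process(1)  # prints "built once"
w.process(2)  # reuses the cached helper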
5 changes: 3 additions & 2 deletions colossalai/shardformer/modeling/llama.py
@@ -481,8 +481,9 @@ def forward(
             flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
             attn_mask_type = AttnMaskType.paddedcausal

-        attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
-        attn_output = attention(
+        if not hasattr(self, "attention"):
+            self.attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
+        attn_output = self.attention(
             query_states,
             key_states,
             value_states,
7 changes: 7 additions & 0 deletions colossalai/shardformer/policies/gpt2.py
@@ -33,6 +33,13 @@ def preprocess(self):
             if vocab_size % world_size != 0:
                 new_vocab_size = vocab_size + world_size - vocab_size % world_size
                 self.model.resize_token_embeddings(new_vocab_size)
+        elif self.shard_config.pipeline_stage_manager is not None:
+            # padding vocab_size when using pipeline parallelism
+            new_vocab_size = vocab_size
+            multiple = self.shard_config.make_vocab_size_divisible_by
+            while (new_vocab_size % multiple) != 0:
+                new_vocab_size += 1
+            self.model.resize_token_embeddings(new_vocab_size)
         return self.model

     def module_policy(self):
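The new elif branch rounds vocab_size up to the next multiple of make_vocab_size_divisible_by whenever a pipeline stage manager is present. A stand-alone sketch (not from the repo) of the same rule, written as ceiling division instead of the loop:

def pad_vocab_size(vocab_size: int, multiple: int = 128) -> int:
    # Round up to the next multiple; same result as the while-loop above.
    return ((vocab_size + multiple - 1) // multiple) * multiple

assert pad_vocab_size(50257, 128) == 50304  # GPT-2's vocab, padded
assert pad_vocab_size(50304, 128) == 50304  # already divisible, unchanged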
2 changes: 2 additions & 0 deletions colossalai/shardformer/shard/shard_config.py
@@ -34,6 +34,8 @@ class ShardConfig:
     enable_all_optimization: bool = False
     enable_sequence_parallelism: bool = False
     enable_sequence_overlap: bool = False
+    parallel_output: bool = True
+    make_vocab_size_divisible_by: int = 128
     extra_kwargs: Dict[str, Any] = field(default_factory=dict)
     # pipeline_parallel_size: int
     # data_parallel_size: int
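A hypothetical illustration (the helper names are assumptions, not repo code) of how downstream code can consult the two new fields; parallel_output typically indicates whether logits stay sharded across tensor-parallel ranks, and make_vocab_size_divisible_by is the padding multiple used by the policy above:

def vocab_padding_multiple(shard_config) -> int:
    # Default of 128 matches the new ShardConfig field.
    return getattr(shard_config, "make_vocab_size_divisible_by", 128)

def keep_logits_sharded(shard_config) -> bool:
    return getattr(shard_config, "parallel_output", True)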
