[shardformer] fix gathering output when using tensor parallelism #5431

Merged · 15 commits · Mar 18, 2024
colossalai/booster/plugin/hybrid_parallel_plugin.py (10 changes: 9 additions & 1 deletion)
@@ -199,7 +199,12 @@ def get_param_info(optim: Optimizer):

if optim is None:
return {}
-param_info = {"param_groups": [], "param2id": {}, "id2param": {}, "param2shape": {}}
+param_info = {
+    "param_groups": [],
+    "param2id": {},
+    "id2param": {},
+    "param2shape": {},
+}
start_index = 0
for group in optim.param_groups:
packed_group = {k: v for k, v in group.items() if k != "params"}
@@ -899,6 +904,7 @@ class HybridParallelPlugin(PipelinePluginBase):
enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
+parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Default to True.
num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline.
@@ -939,6 +945,7 @@ def __init__(
enable_jit_fused: bool = False,
enable_sequence_parallelism: bool = False,
enable_sequence_overlap: bool = False,
+parallel_output: bool = True,
num_microbatches: Optional[int] = None,
microbatch_size: Optional[int] = None,
initial_scale: float = 2**16,
@@ -1035,6 +1042,7 @@ def __init__(
enable_jit_fused=self.enable_jit_fused,
enable_sequence_parallelism=enable_sequence_parallelism,
enable_sequence_overlap=enable_sequence_overlap,
+parallel_output=parallel_output,
)
self.amp_config = dict(
initial_scale=initial_scale,
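For context, the new flag is threaded straight from the plugin constructor into the Shardformer config. Below is a minimal sketch of how a caller would use it, assuming the usual ColossalAI launch has already initialized the distributed environment (e.g. a `torchrun` launch on at least two ranks); the tiny GPT-2 config and optimizer are illustrative placeholders, not part of this PR:

```python
import torch
from transformers import GPT2Config, GPT2LMHeadModel

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Assumes this script is launched with torchrun on >= 2 ranks.
colossalai.launch_from_torch(config={})

# Small GPT-2 purely for illustration; vocab_size picked to divide evenly by tp_size.
model = GPT2LMHeadModel(GPT2Config(n_layer=2, n_head=4, n_embd=128, vocab_size=50304))
optimizer = torch.optim.Adam(model.parameters())

plugin = HybridParallelPlugin(
    tp_size=2,              # shard the model across 2 tensor-parallel ranks
    pp_size=1,
    parallel_output=False,  # new flag: gather the vocab-sharded logits before returning them
)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)
```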
colossalai/shardformer/modeling/gpt2.py (16 changes: 12 additions & 4 deletions)
@@ -25,6 +25,7 @@
from colossalai.shardformer.shard import ShardConfig

from ..layer import cross_entropy_1d
+from ..layer._operation import gather_forward_split_backward


class GPT2PipelineForwards:
@@ -337,6 +338,9 @@ def gpt2_lmhead_model_forward(
else:
loss = loss_fct(shift_logits, shift_labels)

+if not shard_config.parallel_output:
+    lm_logits = gather_forward_split_backward(lm_logits, -1, shard_config.tensor_parallel_process_group)

if not return_dict:
output = (lm_logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
@@ -793,11 +797,12 @@ def forward(
scale = scale * (1 / float(self.layer_idx + 1))

# use coloattention
-attention = ColoAttention(
-    embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.attn_dropout.p, scale=scale
-)
+if not hasattr(self, "attention"):
+    self.attention = ColoAttention(
+        embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.attn_dropout.p, scale=scale
+    )

-attn_output = attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type)
+attn_output = self.attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type)

attn_output = self.c_proj(attn_output)
attn_output = self.resid_dropout(attn_output)
@@ -1083,6 +1088,9 @@ def forward(
else:
loss = loss_fct(shift_logits, shift_labels)

+if not shard_config.parallel_output:
+    lm_logits = gather_forward_split_backward(lm_logits, -1, shard_config.tensor_parallel_process_group)

if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
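Both GPT-2 forward paths now gather the logits when `parallel_output` is off. The tensor-parallel LM head leaves the logits sharded along the vocabulary (last) dimension, so `gather_forward_split_backward` all-gathers them across the tensor-parallel group while staying autograd-aware: in the backward pass each rank keeps only its own slice of the incoming gradient. A rough conceptual sketch of that behaviour follows; it is not the library's implementation:

```python
import torch
import torch.distributed as dist


class _GatherForwardSplitBackward(torch.autograd.Function):
    """Conceptual sketch: all-gather along `dim` in forward, return only the
    local chunk of the gradient in backward. Assumes torch.distributed is initialized."""

    @staticmethod
    def forward(ctx, tensor, dim, group):
        ctx.dim = dim
        ctx.group = group
        world_size = dist.get_world_size(group)
        chunks = [torch.empty_like(tensor) for _ in range(world_size)]
        dist.all_gather(chunks, tensor.contiguous(), group=group)
        return torch.cat(chunks, dim=dim)

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size(ctx.group)
        rank = dist.get_rank(ctx.group)
        # Each rank keeps the gradient slice that corresponds to its local shard.
        return grad_output.chunk(world_size, dim=ctx.dim)[rank], None, None


# Usage in spirit: lm_logits = _GatherForwardSplitBackward.apply(lm_logits, -1, tp_group)
```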
colossalai/shardformer/modeling/llama.py (11 changes: 6 additions & 5 deletions)
@@ -16,7 +16,7 @@
from colossalai.shardformer.shard import ShardConfig

from ..layer import cross_entropy_1d
-from ..layer._operation import _gather
+from ..layer._operation import gather_forward_split_backward

try:
from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
@@ -290,7 +290,7 @@ def llama_for_causal_lm_forward(
loss = loss_fct(shift_logits, shift_labels)

if not shard_config.parallel_output:
-    logits = _gather(logits, -1, shard_config.tensor_parallel_process_group)
+    logits = gather_forward_split_backward(logits, -1, shard_config.tensor_parallel_process_group)

if not return_dict:
output = (logits,) + outputs[1:]
@@ -485,8 +485,9 @@ def forward(
flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
attn_mask_type = AttnMaskType.paddedcausal

-attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
-attn_output = attention(
+if not hasattr(self, "attention"):
+    self.attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
+attn_output = self.attention(
query_states,
key_states,
value_states,
@@ -593,7 +594,7 @@ def forward(
loss = loss_fct(shift_logits, shift_labels)

if not shard_config.parallel_output:
-    logits = _gather(logits, -1, shard_config.tensor_parallel_process_group)
+    logits = gather_forward_split_backward(logits, -1, shard_config.tensor_parallel_process_group)

if not return_dict:
output = (logits,) + outputs[1:]
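Besides the gather change, both `gpt2.py` and `llama.py` stop constructing a fresh `ColoAttention` wrapper on every forward call; the wrapper is now created lazily on first use and cached on the module behind a `hasattr` guard. The same pattern on a toy module, with illustrative names only (not the Shardformer code):

```python
import torch
import torch.nn.functional as F


class SimpleAttention:
    """Stand-in for a stateless attention helper such as ColoAttention."""

    def __init__(self, embed_dim: int, num_heads: int):
        assert embed_dim % num_heads == 0
        self.embed_dim = embed_dim
        self.num_heads = num_heads

    def __call__(self, q, k, v):
        # q, k, v: (batch, heads, seq, head_dim); requires PyTorch >= 2.0
        return F.scaled_dot_product_attention(q, k, v)


class ToyBlock(torch.nn.Module):
    def forward(self, q, k, v):
        # Lazy init: build the helper once on the first call and cache it on self,
        # instead of reconstructing it on every forward pass.
        if not hasattr(self, "attention"):
            self.attention = SimpleAttention(embed_dim=q.shape[1] * q.shape[-1], num_heads=q.shape[1])
        return self.attention(q, k, v)


q = k = v = torch.randn(2, 4, 8, 16)  # (batch, heads, seq, head_dim)
print(ToyBlock()(q, k, v).shape)      # torch.Size([2, 4, 8, 16])
```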
colossalai/shardformer/policies/base_policy.py (2 changes: 1 addition & 1 deletion)
@@ -242,4 +242,4 @@ def get_stage_index(
end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1]
stage_indices.append([start_idx, end_idx])

-return stage_indices[0] if num_model_chunks == 1 else stage_indices
+return stage_indices[0] if num_model_chunks == 1 else stage_indices
colossalai/shardformer/shard/shard_config.py (4 changes: 3 additions & 1 deletion)
@@ -34,8 +34,10 @@ class ShardConfig:
enable_all_optimization: bool = False
enable_sequence_parallelism: bool = False
enable_sequence_overlap: bool = False
-parallel_output = True
+parallel_output: bool = True
extra_kwargs: Dict[str, Any] = field(default_factory=dict)
# TODO padding vocab
# make_vocab_size_divisible_by: int = 128
# pipeline_parallel_size: int
# data_parallel_size: int
# tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
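The `ShardConfig` change looks cosmetic but is the piece that makes the new flag configurable. `ShardConfig` is a dataclass, and a bare `parallel_output = True` with no annotation is only a class attribute, not a dataclass field, so `ShardConfig(parallel_output=...)` would raise a `TypeError`; adding the `: bool` annotation turns it into a real field that the plugin can set. A small standalone demonstration with toy classes (not the real `ShardConfig`):

```python
from dataclasses import dataclass, fields


@dataclass
class ToyConfig:
    enable_tp: bool = False
    parallel_output = True          # no annotation -> plain class attribute, NOT a field


@dataclass
class FixedConfig:
    enable_tp: bool = False
    parallel_output: bool = True    # annotated -> real field, settable via the constructor


print([f.name for f in fields(ToyConfig)])    # ['enable_tp']  (parallel_output is missing)
print(FixedConfig(parallel_output=False))     # FixedConfig(enable_tp=False, parallel_output=False)
```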
tests/test_booster/test_plugin/test_3d_plugin.py (2 changes: 1 addition & 1 deletion)
@@ -260,7 +260,7 @@ def run_grad_acc_test(test_args):
origin_model, origin_optimizer, dataloader=dataloader
)
for p1, p2 in zip(model.unwrap().parameters(), origin_model.unwrap().parameters()):
-assert_close(p1.to(p2.dtype), p2, atol=1e-2, rtol=1e-2)
+assert_close(p1.to(p2.dtype), p2, atol=1e-2, rtol=1e-2)


def run_dist(rank, world_size, port, early_stop: bool = True):
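Net effect for callers, summarized from the diffs above: with `parallel_output=True` (the default) each tensor-parallel rank returns only its vocabulary shard of the logits, which is what the fused `cross_entropy_1d` loss consumes; with `parallel_output=False` the logits are gathered along the last dimension so every rank sees the full vocabulary, which is what generation and other Hugging Face-style post-processing usually expect. The shape numbers below are illustrative, not taken from the PR:

```python
# Illustrative shapes only (tp_size=2, vocab_size=32000 are made-up numbers).
tp_size, vocab_size = 2, 32000
batch, seq = 4, 128

# parallel_output=True: each rank holds a vocabulary shard of the logits.
sharded_shape = (batch, seq, vocab_size // tp_size)   # (4, 128, 16000)

# parallel_output=False: logits are all-gathered along dim -1 on every rank.
gathered_shape = (batch, seq, vocab_size)             # (4, 128, 32000)

assert gathered_shape[-1] == sharded_shape[-1] * tp_size
```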