2D ring backward + llama passed
Edenzzzz committed Aug 14, 2024
1 parent e6bcde2 commit 31f8e34
Showing 3 changed files with 17 additions and 17 deletions.
2 changes: 1 addition & 1 deletion colossalai/shardformer/layer/attn.py
@@ -1090,7 +1090,7 @@ def _other_ring_backward(ring_num_idx, dq):
         if not is_packed:
             dq, dk, dv = [x.view(b, sq, *x.shape[-2:]) for x in (dq, dk, dv)]

-        return (dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None)
+        return (dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None)

     @staticmethod
     def prepare_varlen_batch(
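Note on the attn.py change: the backward of a torch.autograd.Function must return exactly one value per argument of forward, with None for inputs that carry no gradient, so trimming the trailing Nones from 16 to 14 suggests the ring-attention forward now takes two fewer inputs. Below is a minimal, self-contained sketch of that contract; the toy function and its arguments are illustrative and do not reproduce RingAttention's real signature.

import torch

# Sketch of the autograd contract behind the changed return tuple:
# backward() must return one value per forward() argument (besides ctx),
# using None for inputs that need no gradient. The extra args below are
# hypothetical stand-ins for RingAttention's non-tensor options.
class ToyAttnFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, softmax_scale, causal):
        ctx.save_for_backward(q, k, v)
        ctx.softmax_scale = softmax_scale
        # Placeholder math; a real kernel would compute attention here.
        return q * softmax_scale

    @staticmethod
    def backward(ctx, dout):
        q, k, v = ctx.saved_tensors
        dq = dout * ctx.softmax_scale
        dk = torch.zeros_like(k)
        dv = torch.zeros_like(v)
        # Five forward args -> five return values: grads for q, k, v,
        # then None for softmax_scale and causal.
        return dq, dk, dv, None, None


q = torch.randn(2, 8, 4, 16, requires_grad=True)
k, v = torch.randn_like(q), torch.randn_like(q)
ToyAttnFn.apply(q, k, v, 0.25, True).sum().backward()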
8 changes: 4 additions & 4 deletions colossalai/shardformer/layer/loss.py
@@ -182,11 +182,11 @@ def dist_cross_entropy(
    split_labels_here = seq_len // sp_size == logits.size(seq_dim)  # ring attn splits labels before forward

    if sp_mode == "ring_attn":
-        # For Ring Attention, labels should be split and shifted by RingAttention.prepare_varlen_batch()
-        # and parallel_output must be True
-        if sp_rank == sp_size - 1:
+        # For Zigzag Ring Attention, labels should've been split and
+        # shifted by RingAttention.prepare_varlen_batch()
+        if sp_rank == 0:
            logits = logits[..., :-1, :]
-            logits = torch.cat([logits, torch.zeros_like(logits[:, :1, :])], dim=seq_dim)
+            logits = torch.cat([logits, torch.full_like(logits[:, :1, :], _IGNORE_IDX)], dim=seq_dim)
    elif is_sp:
        # Shift only once: either before splitting or in the last rank without splitting
        if split_labels_here or (sp_rank == sp_size - 1):
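Note on the loss.py change: the next-token shift now happens on sp_rank == 0 (the rank holding the first zigzag chunk), and the dropped position is padded with _IGNORE_IDX rather than zeros; either way, the matching label is expected to carry the ignore index, so the padded slot never contributes to the loss while every rank keeps the same local sequence length. A standalone sketch of that shift-and-pad step follows; ignore_idx, seq_dim, and the tensor shapes are illustrative stand-ins, not the real dist_cross_entropy code.

import torch
import torch.nn.functional as F

# Illustrative shift-and-pad: drop the last logit, pad the sequence back to
# its original length, and mark the matching label with an ignore index so
# the pad never adds loss.
ignore_idx = -100          # same role as _IGNORE_IDX in the diff
seq_dim = 1

b, s, vocab = 2, 8, 32
logits = torch.randn(b, s, vocab)
labels = torch.randint(0, vocab, (b, s))

# Next-token shift on the rank that owns the first chunk: drop the last
# position, then pad so the local length is unchanged.
logits = logits[:, :-1, :]
pad = torch.full_like(logits[:, :1, :], float(ignore_idx))
logits = torch.cat([logits, pad], dim=seq_dim)

# The label at the padded position must be the ignore index.
labels = torch.cat([labels[:, 1:], torch.full_like(labels[:, :1], ignore_idx)], dim=seq_dim)

loss = F.cross_entropy(logits.reshape(-1, vocab), labels.reshape(-1), ignore_index=ignore_idx)
print(loss.item())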
24 changes: 12 additions & 12 deletions tests/test_shardformer/test_model/test_shard_llama.py
@@ -154,18 +154,18 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    "test_config",
    [
        # Double Ring Attention
-        # {
-        #     "tp_size": 1,
-        #     "pp_size": 1,
-        #     "sp_size": 4,
-        #     "num_microbatches": 1,
-        #     "enable_sequence_parallelism": True,
-        #     "sequence_parallelism_mode": "ring_attn",
-        #     "use_lazy_init": True,
-        #     "zero_stage": 0,
-        #     "precision": "fp16",
-        #     "initial_scale": 1,
-        # },
+        {
+            "tp_size": 1,
+            "pp_size": 1,
+            "sp_size": 4,
+            "num_microbatches": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "ring_attn",
+            "use_lazy_init": True,
+            "zero_stage": 0,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
        # Ring Attention + PP
        {
            "tp_size": 1,
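Note on the test change: the previously commented-out Double Ring Attention config (sp_size=4 with sequence_parallelism_mode="ring_attn") is re-enabled, matching the commit title. As a rough illustration of how such a config table drives a parametrized test, here is a hedged pytest-style sketch; pytest.mark.parametrize and test_llama_sharding stand in for the repository's own @parameterize helper and test body.

import pytest

# Each dict in the table becomes one parametrized test case; the entry below
# mirrors the re-enabled Double Ring Attention config.
TEST_CONFIGS = [
    {
        "tp_size": 1,
        "pp_size": 1,
        "sp_size": 4,
        "num_microbatches": 1,
        "enable_sequence_parallelism": True,
        "sequence_parallelism_mode": "ring_attn",
        "use_lazy_init": True,
        "zero_stage": 0,
        "precision": "fp16",
        "initial_scale": 1,
    },
]


@pytest.mark.parametrize("test_config", TEST_CONFIGS)
def test_llama_sharding(test_config):
    # A real run would build the plugin/booster from test_config and call
    # check_forward_backward; here we only assert the config is well formed.
    assert test_config["sp_size"] >= 1
    assert test_config["sequence_parallelism_mode"] == "ring_attn"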
