From 31f8e34aad1c759bbace17f0446d158732bfd7de Mon Sep 17 00:00:00 2001
From: Edenzzzz
Date: Wed, 14 Aug 2024 08:19:29 +0000
Subject: [PATCH] 2D ring backward + llama passed

---
 colossalai/shardformer/layer/attn.py |  2 +-
 colossalai/shardformer/layer/loss.py |  8 +++----
 .../test_model/test_shard_llama.py   | 24 +++++++++----------
 3 files changed, 17 insertions(+), 17 deletions(-)

diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index e0fcd3cef876..4ff028a24520 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -1090,7 +1090,7 @@ def _other_ring_backward(ring_num_idx, dq):
         if not is_packed:
             dq, dk, dv = [x.view(b, sq, *x.shape[-2:]) for x in (dq, dk, dv)]
 
-        return (dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None)
+        return (dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None)
 
     @staticmethod
     def prepare_varlen_batch(
diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py
index 64732f1e4dfa..12df824d1c0c 100644
--- a/colossalai/shardformer/layer/loss.py
+++ b/colossalai/shardformer/layer/loss.py
@@ -182,11 +182,11 @@ def dist_cross_entropy(
     split_labels_here = seq_len // sp_size == logits.size(seq_dim)  # ring attn splits labels before forward
 
     if sp_mode == "ring_attn":
-        # For Ring Attention, labels should be split and shifted by RingAttention.prepare_varlen_batch()
-        # and parallel_output must be True
-        if sp_rank == sp_size - 1:
+        # For Zigzag Ring Attention, labels should've been split and
+        # shifted by RingAttention.prepare_varlen_batch()
+        if sp_rank == 0:
             logits = logits[..., :-1, :]
-            logits = torch.cat([logits, torch.zeros_like(logits[:, :1, :])], dim=seq_dim)
+            logits = torch.cat([logits, torch.full_like(logits[:, :1, :], _IGNORE_IDX)], dim=seq_dim)
     elif is_sp:
         # Shift only once: either before splitting or in the last rank without splitting
         if split_labels_here or (sp_rank == sp_size - 1):
diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py
index 35a706831102..34bb9e414dde 100644
--- a/tests/test_shardformer/test_model/test_shard_llama.py
+++ b/tests/test_shardformer/test_model/test_shard_llama.py
@@ -154,18 +154,18 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     "test_config",
     [
         # Double Ring Attention
-        # {
-        #     "tp_size": 1,
-        #     "pp_size": 1,
-        #     "sp_size": 4,
-        #     "num_microbatches": 1,
-        #     "enable_sequence_parallelism": True,
-        #     "sequence_parallelism_mode": "ring_attn",
-        #     "use_lazy_init": True,
-        #     "zero_stage": 0,
-        #     "precision": "fp16",
-        #     "initial_scale": 1,
-        # },
+        {
+            "tp_size": 1,
+            "pp_size": 1,
+            "sp_size": 4,
+            "num_microbatches": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "ring_attn",
+            "use_lazy_init": True,
+            "zero_stage": 0,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
         # Ring Attention + PP
         {
             "tp_size": 1,
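
Note on the loss.py hunk (illustration only, not part of the patch): a minimal standalone sketch of the logits shift-and-pad it performs for zigzag ring attention. Under the zigzag split, sp_rank 0 ends up holding the last chunk of the global sequence, so the final position of its local logits has no next-token target; the hunk drops that position and pads the sequence back to its original local length with the ignore value, keeping shapes aligned with the labels already split and shifted by RingAttention.prepare_varlen_batch(). The tensor shapes, seq_dim, and the _IGNORE_IDX value below are assumptions made for the sketch.

    import torch

    _IGNORE_IDX = -100   # assumed to match the ignore index used in loss.py
    seq_dim = 1          # assumed layout: (batch, local_seq_len, vocab)
    logits = torch.randn(2, 4, 8)  # toy local logits on sp_rank 0

    # Drop the last position: on this rank it has no shifted next-token label.
    logits = logits[..., :-1, :]
    # Pad back to the original local length with the ignore value; the matching
    # label position is expected to carry the ignore index, so the padded logit
    # never contributes to the cross-entropy loss.
    pad = torch.full_like(logits[:, :1, :], _IGNORE_IDX)
    logits = torch.cat([logits, pad], dim=seq_dim)

    assert logits.shape == (2, 4, 8)  # local shape is unchanged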