diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 488bab6356f1..0a11c57e3f51 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -238,12 +238,7 @@ def attention( # sanity check if attention_mask is not None: assert torch.is_floating_point(attention_mask), "attention_mask should be a floating point tensor." - if attention_mask_type in ( - AttnMaskType.CUSTOM, - AttnMaskType.CAUSAL, - AttnMaskType.PADDED, - AttnMaskType.PADDED_CAUSAL, - ): + if attention_mask_type in (AttnMaskType.CUSTOM, AttnMaskType.CAUSAL): assert ( cu_seqlens_q is None and cu_seqlens_kv is None @@ -254,9 +249,18 @@ def attention( ) if attention_mask_type == AttnMaskType.CUSTOM: assert not torch.all(attention_mask != 0, dim=-1).any() - else: - # if attention_mask is None, attention_mask_type should be the default value - assert attention_mask_type == AttnMaskType.CUSTOM + elif attention_mask_type in ( + AttnMaskType.PADDED, + AttnMaskType.PADDED_CAUSAL, + ): + assert ( + cu_seqlens_q is not None + and cu_seqlens_kv is not None + and max_seqlen_q is not None + and max_seqlen_kv is not None + and q_indices is not None + and kv_indices is not None + ) # kernel dispatch mask_type = attention_mask_type if attention_mask is not None else None attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type) @@ -398,24 +402,16 @@ def _rescale_out_lse(out, block_out, lse, block_lse): # new_lse = max_scale + torch.log(1 + torch.exp(min_scale - max_scale)) new_lse = lse + torch.log(1 + torch.exp(block_lse - lse)) + assert not (new_lse.isnan().any() or new_lse.isinf().any()), f"lse is nan: {new_lse}" new_block_lse = torch.exp(block_lse - new_lse) out.copy_(torch.exp(lse - new_lse) * out + new_block_lse * block_out) lse.copy_(new_lse) - assert _not_nan(new_lse), new_lse - assert _not_nan(new_block_lse), new_block_lse - assert _not_nan(out), out # block_out = block_out.float() - # out.copy_(out - F.sigmoid(block_lse - lse) * (out - block_out)) - # lse.copy_(lse - F.logsigmoid(lse - block_lse)) # assert not lse.isnan().any(), lse # assert not out.isnan().any(), out -def _not_nan(x): - return not (x.isnan().any() or x.isinf().any()) - - class RingAttention(torch.autograd.Function): """Implements the Ring Attention from `Ring Attention with Blockwise Transformers for Near-Infinite Context` (https://arxiv.org/abs/2310.01889). @@ -469,7 +465,7 @@ def attention( deterministic (bool, optional): Whether to force deterministic backward pass. See https://github.com/Dao-AILab/flash-attention/issues/349 return_softmax (bool, optional): Whether to return the softmax denominator (logsumexp). Returns: - out: Output tensor. Shape should be [B, Heads, Sq, D] + out: Output tensor. Shape should be [B, Heads, Sq, D] or [T, Heads, D] softmax_lse: (if return_softmax is True) Softmax denominator (logsumexp). Shape should be [B, Heads, Sq] """ @@ -495,23 +491,12 @@ def attention( # (Ex: https://github.com/zhuzilin/ring-flash-attention/blob/49a50141bdce4e76418afe2051646c9a771fe867/test/test_zigzag_ring_flash_attn_varlen_func.py#L43) # Left some logics here; to be supported depending on demands. elif AttnMaskType.PADDED_CAUSAL: - # TODO: compute cu_seqlens locally using valid_positions - assert attention_mask is not None, "Padded attention requires inputing valid token positions!" - # Sequences are padded to the same length in a training round, so reuse the mask info. 
- if ( - RingAttention.ATTENTION_MASK - and (RingAttention.ATTENTION_MASK.shape == attention_mask.shape) - and (RingAttention.ATTENTION_MASK == attention_mask).all() - ): - cu_seqlens_q = cu_seqlens_kv = RingAttention.CU_SEQLENS - max_seqlen_q = max_seqlen_kv = RingAttention.MAX_SEQLEN - else: - max_seqlen, cu_seqlens, valid_positions = get_pad_info(attention_mask) - RingAttention.CU_SEQLENS = cu_seqlens - RingAttention.MAX_SEQLEN = max_seqlen - RingAttention.ATTENTION_MASK = attention_mask - # To [T, H, D] where T is the number of non-zero tokens - q, k, v = [_unpad_input(x, valid_positions) for x in (q, k, v)] + assert ( + cu_seq_lens_q is not None + and cu_seq_lens_kv is not None + and max_seq_len_q is not None + and max_seq_len_kv is not None + ), "Packed mode requires pre-computed cu_seqlens and max_seqlens." out, softmax_lse = RingAttention.apply( q, @@ -529,11 +514,7 @@ def attention( return_softmax, ) - if attention_mask_type == AttnMaskType.PADDED_CAUSAL: - # Pad and reshape back - # [T, N, D] -> [B, H, Sq, D] - out = _pad_input(out, valid_positions, b, sq) - else: + if not attention_mask_type == AttnMaskType.PADDED_CAUSAL: out = out.transpose(1, 2) # [B, Sq, H, D] -> [B, H, Sq, D] if return_softmax: @@ -575,8 +556,6 @@ def forward( b, h, sq, d = q.shape # (B, H, Sq, D) -> (B, Sq, H, D) q, k, v = [x.transpose(1, 2) for x in (q, k, v)] - assert _not_nan(q), q - assert _not_nan(k), k kv_comms = [RingComm(sp_group) for _ in range(2)] sp_size = kv_comms[0].world_size sp_rank = kv_comms[0].rank @@ -638,7 +617,6 @@ def forward( kv_block = kv_buffers[i % 2] # (2, B * Sq // 2, H, D) kv_block = kv_block.view(2, b * sq, h, d)[:, : b * sq // 2].clone() - assert _not_nan(kv_block), f"rank {dist.get_rank()} step {i} kv_block {kv_block}" ( _, _, @@ -665,7 +643,6 @@ def forward( # Drop the first half of q q_block = q.view(b * sq, h, d)[b * sq // 2 :] kv_block = kv_buffers[i % 2].view(2, b * sq, h, d).clone() - assert _not_nan(kv_block), f"rank {dist.get_rank()} step {i} kv_block {kv_block}" ( _, @@ -696,11 +673,7 @@ def forward( block_softmax_lse[i % 2] = ( block_softmax_lse[i % 2].transpose(1, 2).contiguous().unsqueeze(-1).float() ) # (B, H, Sq) -> (B, Sq, H, 1) - assert block_out[i % 2].shape[:-1] == block_softmax_lse[i % 2].shape[:-1] - assert _not_nan( - block_softmax_lse[i % 2] - ), f"rank {sp_rank} step {i} softmax_lse is nan: {block_softmax_lse[i % 2]}" # Overlap output correction with next flash attn kernel if i == 0: @@ -767,7 +740,6 @@ def backward(ctx, dout, _): assert ( out.shape == dout.shape == (b, sq, h, d) ), f"out {out.shape} and dout {dout.shape} should have shape ({b}, {sq}, {h}, {d}) instead" - assert _not_nan(dout), f"dout is nan" # Sequence parallel args sp_group = ctx.sp_group sp_rank = dist.get_rank(sp_group) @@ -887,10 +859,6 @@ def backward(ctx, dout, _): # Wait for mobile kv grad accumulators dkv_comm.wait() - assert _not_nan(dq_block), f"rank {dist.get_rank()} step {i} dq_block is nan" - assert _not_nan(dkv_recv), f"rank {dist.get_rank()} step {i} dkv_buffers is nan" - assert _not_nan(dq) - if i <= sp_rank: # q blocks "surrounded" by kv blocks dkv_recv[0][:, : sq // 2] += dk_block[:, : sq // 2] # (B, Sq // 2, H, D) @@ -899,14 +867,10 @@ def backward(ctx, dout, _): # q blocks "surrounding" kv blocks dkv_recv[0] += dk_block dkv_recv[1] += dv_block - if dist.get_rank() == 0: - torch.save(dkv_recv, f"colo_step_{i}.pt") + dkv_comm.send_recv(send_tensor=dkv_recv, recv_tensor=dkv_send) dkv_comm.wait() dkv_recv = dkv_send dq, dk, dv = [x.view(b, sq, h, d).transpose(1, 
2).to(q.dtype) for x in (dq, *dkv_recv)]
-        assert _not_nan(dq), f"dq is nan"
-        assert _not_nan(dk), f"dk is nan"
-        assert _not_nan(dv), f"dv is nan"
         return (dq, dk, dv, None, None, None, None, None, None, None, None, None, None)
diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py
index d4206cbe6f94..31da5b96aae4 100644
--- a/colossalai/shardformer/layer/utils.py
+++ b/colossalai/shardformer/layer/utils.py
@@ -295,9 +295,9 @@ def zigzag_split_batch(
     batch: Union[torch.Tensor, List[torch.Tensor]], sp_group: ProcessGroup, seq_dim=1, varlen: bool = False
 ):
     """
-    Split the input along the sequence dimension for Ring Attention. As naively spliting sequence
-    in the causual setting will result in the first ranks having much less workload than the last ranks,
-    we split after "folding" the 2D attention mask in half (https://github.com/zhuzilin/ring-flash-attention/issues/2).
+    Split the input along the sequence dimension for Ring Attention. Naively splitting the sequence
+    in the causal setting will result in the preceding ranks having much less workload than the later ones.
+    We split after "folding" the 2D attention mask in half (https://github.com/zhuzilin/ring-flash-attention/issues/2).
     For example, for sp_size = 4 and seq_len = 8, we get | s0, s7 | s1, s6 | s2, s5 | s3, s4 |.
     Args:
diff --git a/examples/language/opt/README.md b/examples/language/opt/README.md
index af1e794374ed..694c5cf91acc 100644
--- a/examples/language/opt/README.md
+++ b/examples/language/opt/README.md
@@ -17,7 +17,7 @@ limitations under the License.
 ## OPT
 Meta recently released [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), a 175-Billion parameter AI language model, which stimulates AI programmers to perform various downstream tasks and application deployments.
-The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Casual Language Modelling at low cost.
+The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Causal Language Modelling at low cost.
 ## Our Modifications
diff --git a/examples/tutorial/opt/opt/README.md b/examples/tutorial/opt/opt/README.md
index a01209cbda0e..3776e0c64552 100644
--- a/examples/tutorial/opt/opt/README.md
+++ b/examples/tutorial/opt/opt/README.md
@@ -19,7 +19,7 @@ limitations under the License.
 ## OPT
 Meta recently released [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), a 175-Billion parameter AI language model, which stimulates AI programmers to perform various downstream tasks and application deployments.
-The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Casual Language Modelling at low cost.
+The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Causal Language Modelling at low cost.
 We are using the pre-training weights of the OPT model provided by Hugging Face Hub on the raw WikiText-2 (no tokens were replaced before the tokenization).
 This training script is adapted from the [HuggingFace Language Modelling examples](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling).
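For reference, the "folded" split described in the zigzag_split_batch docstring above can be sketched in a few lines of plain PyTorch. This is an illustrative sketch only: the helper name zigzag_split and its signature are hypothetical and do not mirror the library's zigzag_split_batch API, and it assumes a dense batch whose sequence length divides evenly into 2 * sp_size chunks.

# Hypothetical sketch of the zigzag (folded) split; not the library's API.
import torch

def zigzag_split(batch: torch.Tensor, sp_size: int, sp_rank: int, seq_dim: int = 1) -> torch.Tensor:
    # Cut the sequence into 2 * sp_size equal chunks, then give rank i the i-th chunk
    # and its mirror chunk (2 * sp_size - 1 - i), so causal workload is balanced.
    chunks = batch.chunk(2 * sp_size, dim=seq_dim)
    return torch.cat([chunks[sp_rank], chunks[2 * sp_size - 1 - sp_rank]], dim=seq_dim)

if __name__ == "__main__":
    x = torch.arange(8).unsqueeze(0)  # toy batch of shape [1, 8], tokens s0..s7
    # Prints [[0, 7]], [[1, 6]], [[2, 5]], [[3, 4]],
    # matching | s0, s7 | s1, s6 | s2, s5 | s3, s4 | for sp_size = 4 and seq_len = 8.
    for rank in range(4):
        print(zigzag_split(x, sp_size=4, sp_rank=rank).tolist())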
diff --git a/ring-flash-attention/benchmark/benchmark_qkvpacked_func.py b/ring-flash-attention/benchmark/benchmark_qkvpacked_func.py deleted file mode 100644 index a6742e04a696..000000000000 --- a/ring-flash-attention/benchmark/benchmark_qkvpacked_func.py +++ /dev/null @@ -1,87 +0,0 @@ -import torch -import torch.cuda -import torch.distributed as dist -from flash_attn import flash_attn_qkvpacked_func -from ring_flash_attn import ( - ring_flash_attn_qkvpacked_func, - stripe_flash_attn_qkvpacked_func, - zigzag_ring_flash_attn_qkvpacked_func, -) - - -def benchmark(f, num_iter=100, forward_only=True, log=True): - dtype = torch.bfloat16 - rank = dist.get_rank() - world_size = dist.get_world_size() - device = torch.device(f"cuda:{rank}") - torch.cuda.set_device(device) - - batch_size = 1 - seqlen = 1024 * 8 - nheads = 5 - d = 128 - dropout_p = 0 - causal = True - deterministic = False - - assert seqlen % (2 * world_size) == 0 - assert d % 8 == 0 - - qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) - dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype) - - begin = torch.cuda.Event(enable_timing=True) - begin.record() - - if forward_only: - with torch.no_grad(): - for _ in range(num_iter): - _ = f( - qkv, - dropout_p=dropout_p, - causal=causal, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=deterministic, - return_attn_probs=False, - ) - - else: - for _ in range(num_iter): - qkv.grad = None - out = f( - qkv, - dropout_p=dropout_p, - causal=causal, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=deterministic, - return_attn_probs=False, - ) - out.backward(dout) - end = torch.cuda.Event(enable_timing=True) - end.record() - torch.cuda.synchronize(device=device) - time = begin.elapsed_time(end) / 1000.0 - - if rank == 0 and log: - print(f"{num_iter / time:.3f} iter/s, {time:.3f} sec") - - -if __name__ == "__main__": - dist.init_process_group("nccl") - rank = dist.get_rank() - - forward_only = False - - for f in [ - flash_attn_qkvpacked_func, - ring_flash_attn_qkvpacked_func, - zigzag_ring_flash_attn_qkvpacked_func, - stripe_flash_attn_qkvpacked_func, - ]: - torch.cuda.empty_cache() - if rank == 0: - print(f"# {f.__name__}") - benchmark(f, forward_only=forward_only, log=False) - benchmark(f, forward_only=forward_only, log=True) diff --git a/ring-flash-attention/benchmark/benchmark_varlen_qkvpacked_func.py b/ring-flash-attention/benchmark/benchmark_varlen_qkvpacked_func.py deleted file mode 100644 index 18c8cafc0078..000000000000 --- a/ring-flash-attention/benchmark/benchmark_varlen_qkvpacked_func.py +++ /dev/null @@ -1,91 +0,0 @@ -import torch -import torch.cuda -import torch.distributed as dist -from flash_attn import flash_attn_varlen_qkvpacked_func -from ring_flash_attn import ring_flash_attn_varlen_qkvpacked_func, zigzag_ring_flash_attn_varlen_qkvpacked_func - - -def benchmark(f, num_iter=100, forward_only=True, log=True): - dtype = torch.bfloat16 - rank = dist.get_rank() - world_size = dist.get_world_size() - device = torch.device(f"cuda:{rank}") - torch.cuda.set_device(device) - - seqlen = 1024 * 8 - nheads = 5 - d = 128 - dropout_p = 0 - causal = True - deterministic = False - - assert seqlen % (2 * world_size) == 0 - assert d % 8 == 0 - - qkv = torch.randn(seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) - dout = torch.randn(seqlen, nheads, d, device=device, dtype=dtype) - - cu_seqlens_list = [ - torch.tensor([0, 8192], device=device, dtype=torch.int32), - torch.tensor([0, 
256, 7648, 8192], device=device, dtype=torch.int32), - torch.tensor([0, 4096, 8192], device=device, dtype=torch.int32), - torch.tensor([0, 3104, 6304, 7904, 8064, 8192], device=device, dtype=torch.int32), - ] - max_seqlen_list = [(cu_seqlens[1:] - cu_seqlens[:1]).max().item() for cu_seqlens in cu_seqlens_list] - - begin = torch.cuda.Event(enable_timing=True) - begin.record() - if forward_only: - with torch.no_grad(): - for i in range(num_iter): - _ = f( - qkv, - cu_seqlens_list[i % len(cu_seqlens_list)], - max_seqlen_list[i % len(max_seqlen_list)], - dropout_p=dropout_p, - causal=causal, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=deterministic, - return_attn_probs=False, - ) - else: - for i in range(num_iter): - qkv.grad = None - out = f( - qkv, - cu_seqlens_list[i % len(cu_seqlens_list)], - max_seqlen_list[i % len(max_seqlen_list)], - dropout_p=dropout_p, - causal=causal, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=deterministic, - return_attn_probs=False, - ) - out.backward(dout) - end = torch.cuda.Event(enable_timing=True) - end.record() - torch.cuda.synchronize(device=device) - time = begin.elapsed_time(end) / 1000.0 - - if rank == 0 and log: - print(f"{num_iter / time} iter/s, {time} sec") - - -if __name__ == "__main__": - dist.init_process_group("nccl") - rank = dist.get_rank() - - forward_only = False - - for f in [ - flash_attn_varlen_qkvpacked_func, - ring_flash_attn_varlen_qkvpacked_func, - zigzag_ring_flash_attn_varlen_qkvpacked_func, - ]: - torch.cuda.empty_cache() - if rank == 0: - print(f"# {f.__name__}") - benchmark(f, forward_only=forward_only, log=False) - benchmark(f, forward_only=forward_only, log=True) diff --git a/ring-flash-attention/ring_flash_attn/__init__.py b/ring-flash-attention/ring_flash_attn/__init__.py deleted file mode 100644 index 01d5ec36218c..000000000000 --- a/ring-flash-attention/ring_flash_attn/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -from .ring_flash_attn import ring_flash_attn_func, ring_flash_attn_kvpacked_func, ring_flash_attn_qkvpacked_func -from .ring_flash_attn_varlen import ( - ring_flash_attn_varlen_func, - ring_flash_attn_varlen_kvpacked_func, - ring_flash_attn_varlen_qkvpacked_func, -) -from .stripe_flash_attn import stripe_flash_attn_func, stripe_flash_attn_kvpacked_func, stripe_flash_attn_qkvpacked_func -from .zigzag_ring_flash_attn import ( - zigzag_ring_flash_attn_func, - zigzag_ring_flash_attn_kvpacked_func, - zigzag_ring_flash_attn_qkvpacked_func, -) -from .zigzag_ring_flash_attn_varlen import ( - zigzag_ring_flash_attn_varlen_func, - zigzag_ring_flash_attn_varlen_qkvpacked_func, -) diff --git a/ring-flash-attention/ring_flash_attn/ring_flash_attn.py b/ring-flash-attention/ring_flash_attn/ring_flash_attn.py deleted file mode 100644 index b36484dbd145..000000000000 --- a/ring-flash-attention/ring_flash_attn/ring_flash_attn.py +++ /dev/null @@ -1,281 +0,0 @@ -import torch -from flash_attn.flash_attn_interface import _flash_attn_backward, _flash_attn_forward - -from .utils import RingComm, update_out_and_lse - - -def ring_flash_attn_forward( - process_group, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - softmax_scale, - dropout_p=0, - causal=True, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, -): - comm = RingComm(process_group) - - out = None - lse = None - - next_k, next_v = None, None - - for step in range(comm.world_size): - if step + 1 != comm.world_size: - next_k: torch.Tensor = comm.send_recv(k) - next_v: torch.Tensor = comm.send_recv(v) - comm.commit() - - if 
not causal or step <= comm.rank: - block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward( - q, - k, - v, - dropout_p, - softmax_scale, - causal=causal and step == 0, - window_size=window_size, - alibi_slopes=alibi_slopes, - return_softmax=True and dropout_p > 0, - ) - out, lse = update_out_and_lse(out, lse, block_out, block_lse) - - if step + 1 != comm.world_size: - comm.wait() - k = next_k - v = next_v - - out = out.to(q.dtype) - lse = lse.squeeze(dim=-1).transpose(1, 2) - return out, lse - - -def ring_flash_attn_backward( - process_group, - dout, - q, - k, - v, - out, - softmax_lse, - softmax_scale, - dropout_p=0, - causal=True, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, -): - kv_comm = RingComm(process_group) - d_kv_comm = RingComm(process_group) - dq, dk, dv = None, None, None - next_dk, next_dv = None, None - - block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) - block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) - block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) - - next_dk, next_dv = None, None - next_k, next_v = None, None - - for step in range(kv_comm.world_size): - if step + 1 != kv_comm.world_size: - next_k = kv_comm.send_recv(k) - next_v = kv_comm.send_recv(v) - kv_comm.commit() - if step <= kv_comm.rank or not causal: - bwd_causal = causal and step == 0 - _flash_attn_backward( - dout, - q, - k, - v, - out, - softmax_lse, - block_dq_buffer, - block_dk_buffer, - block_dv_buffer, - dropout_p, - softmax_scale, - bwd_causal, - window_size, - alibi_slopes, - deterministic, - rng_state=None, - ) - - if dq is None: - dq = block_dq_buffer.to(torch.float32) - dk = block_dk_buffer.to(torch.float32) - dv = block_dv_buffer.to(torch.float32) - else: - dq += block_dq_buffer - d_kv_comm.wait() - dk = block_dk_buffer + next_dk - dv = block_dv_buffer + next_dv - elif step != 0: - d_kv_comm.wait() - dk = next_dk - dv = next_dv - - if step + 1 != kv_comm.world_size: - kv_comm.wait() - k = next_k - v = next_v - - next_dk = d_kv_comm.send_recv(dk) - next_dv = d_kv_comm.send_recv(dv) - d_kv_comm.commit() - - d_kv_comm.wait() - - return dq.to(torch.bfloat16), next_dk.to(q.dtype), next_dv.to(q.dtype) - - -class RingFlashAttnFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_softmax, - group, - ): - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - - assert alibi_slopes is None - k = k.contiguous() - v = v.contiguous() - out, softmax_lse = ring_flash_attn_forward( - group, - q, - k, - v, - softmax_scale=softmax_scale, - dropout_p=dropout_p, - causal=causal, - window_size=window_size, - alibi_slopes=alibi_slopes, - deterministic=False, - ) - # this should be out_padded - ctx.save_for_backward(q, k, v, out, softmax_lse) - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - ctx.group = group - return out if not return_softmax else (out, softmax_lse, None) - - @staticmethod - def backward(ctx, dout, *args): - q, k, v, out, softmax_lse = ctx.saved_tensors - dq, dk, dv = ring_flash_attn_backward( - ctx.group, - dout, - q, - k, - v, - out, - softmax_lse, - softmax_scale=ctx.softmax_scale, - dropout_p=ctx.dropout_p, - causal=ctx.causal, - window_size=ctx.window_size, - alibi_slopes=ctx.alibi_slopes, - deterministic=ctx.deterministic, - ) - 
return dq, dk, dv, None, None, None, None, None, None, None, None - - -def ring_flash_attn_qkvpacked_func( - qkv, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=None, -): - return RingFlashAttnFunc.apply( - qkv[:, :, 0], - qkv[:, :, 1], - qkv[:, :, 2], - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - group, - ) - - -def ring_flash_attn_kvpacked_func( - q, - kv, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=None, -): - return RingFlashAttnFunc.apply( - q, - kv[:, :, 0], - kv[:, :, 1], - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - group, - ) - - -def ring_flash_attn_func( - q, - k, - v, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=None, -): - return RingFlashAttnFunc.apply( - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - group, - ) diff --git a/ring-flash-attention/ring_flash_attn/ring_flash_attn_varlen.py b/ring-flash-attention/ring_flash_attn/ring_flash_attn_varlen.py deleted file mode 100644 index 118bdea4c7d0..000000000000 --- a/ring-flash-attention/ring_flash_attn/ring_flash_attn_varlen.py +++ /dev/null @@ -1,318 +0,0 @@ -import torch -from flash_attn.flash_attn_interface import _flash_attn_varlen_backward, _flash_attn_varlen_forward - -from .utils import RingComm, update_out_and_lse - -try: - from .triton_utils import flatten_varlen_lse, unflatten_varlen_lse -except: - from .utils import flatten_varlen_lse, unflatten_varlen_lse - - -def ring_flash_attn_varlen_forward( - process_group, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - cu_seqlens, - max_seqlen, - softmax_scale, - dropout_p=0, - causal=True, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, -): - comm = RingComm(process_group) - - out = None - lse = None - next_k, next_v = None, None - - for step in range(comm.world_size): - if step + 1 != comm.world_size: - next_k: torch.Tensor = comm.send_recv(k) - next_v: torch.Tensor = comm.send_recv(v) - comm.commit() - if not causal or step <= comm.rank: - block_out, _, _, _, _, block_lse, _, _ = _flash_attn_varlen_forward( - q, - k, - v, - cu_seqlens, - cu_seqlens, - max_seqlen, - max_seqlen, - dropout_p, - softmax_scale, - causal=causal and step == 0, - window_size=window_size, - alibi_slopes=alibi_slopes, - return_softmax=True and dropout_p > 0, - ) - block_lse = flatten_varlen_lse( - block_lse, - cu_seqlens=cu_seqlens, - ) - out, lse = update_out_and_lse(out, lse, block_out, block_lse) - - if step + 1 != comm.world_size: - comm.wait() - k = next_k - v = next_v - - out = out.to(q.dtype) - lse = unflatten_varlen_lse(lse, cu_seqlens, max_seqlen) - return out, lse - - -def ring_flash_attn_varlen_backward( - process_group, - dout, - q, - k, - v, - out, - softmax_lse, - cu_seqlens, - max_seqlen, - softmax_scale, - dropout_p=0, - causal=True, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, -): - kv_comm = RingComm(process_group) - d_kv_comm = RingComm(process_group) - dq, dk, dv = None, None, None - next_dk, next_dv = None, None - - block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) - 
block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) - block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) - - next_dk, next_dv = None, None - next_k, next_v = None, None - for step in range(kv_comm.world_size): - if step + 1 != kv_comm.world_size: - next_k = kv_comm.send_recv(k) - next_v = kv_comm.send_recv(v) - kv_comm.commit() - if step <= kv_comm.rank or not causal: - bwd_causal = causal and step == 0 - _flash_attn_varlen_backward( - dout, - q, - k, - v, - out, - softmax_lse, - block_dq_buffer, - block_dk_buffer, - block_dv_buffer, - cu_seqlens, - cu_seqlens, - max_seqlen, - max_seqlen, - dropout_p, - softmax_scale, - bwd_causal, - window_size, - alibi_slopes, - deterministic, - rng_state=None, - ) - - if dq is None: - dq = block_dq_buffer.to(torch.float32) - dk = block_dk_buffer.to(torch.float32) - dv = block_dv_buffer.to(torch.float32) - else: - dq += block_dq_buffer - d_kv_comm.wait() - dk = block_dk_buffer + next_dk - dv = block_dv_buffer + next_dv - elif step != 0: - d_kv_comm.wait() - dk = next_dk - dv = next_dv - - if step + 1 != kv_comm.world_size: - kv_comm.wait() - k = next_k - v = next_v - - next_dk = d_kv_comm.send_recv(dk) - next_dv = d_kv_comm.send_recv(dv) - d_kv_comm.commit() - - d_kv_comm.wait() - - return dq.to(torch.bfloat16), next_dk.to(q.dtype), next_dv.to(q.dtype) - - -class RingFlashAttnVarlenFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - cu_seqlens, - max_seqlen, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_softmax, - group, - ): - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - - assert alibi_slopes is None - k = k.contiguous() - v = v.contiguous() - out, softmax_lse = ring_flash_attn_varlen_forward( - group, - q, - k, - v, - cu_seqlens, - max_seqlen, - softmax_scale=softmax_scale, - dropout_p=dropout_p, - causal=causal, - window_size=window_size, - alibi_slopes=alibi_slopes, - deterministic=False, - ) - # this should be out_padded - ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens) - ctx.max_seqlen = max_seqlen - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - ctx.group = group - return out if not return_softmax else (out, softmax_lse, None) - - @staticmethod - def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, cu_seqlens = ctx.saved_tensors - dq, dk, dv = ring_flash_attn_varlen_backward( - ctx.group, - dout, - q, - k, - v, - out, - softmax_lse, - cu_seqlens, - ctx.max_seqlen, - softmax_scale=ctx.softmax_scale, - dropout_p=ctx.dropout_p, - causal=ctx.causal, - window_size=ctx.window_size, - alibi_slopes=ctx.alibi_slopes, - deterministic=ctx.deterministic, - ) - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None - - -def ring_flash_attn_varlen_qkvpacked_func( - qkv, - cu_seqlens, - max_seqlen, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=None, -): - return RingFlashAttnVarlenFunc.apply( - qkv[:, 0], - qkv[:, 1], - qkv[:, 2], - cu_seqlens, - max_seqlen, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - group, - ) - - -def ring_flash_attn_varlen_kvpacked_func( - q, - kv, - cu_seqlens, - max_seqlen, - dropout_p=0.0, - 
softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=None, -): - return RingFlashAttnVarlenFunc.apply( - q, - kv[:, 0], - kv[:, 1], - cu_seqlens, - max_seqlen, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - group, - ) - - -def ring_flash_attn_varlen_func( - q, - k, - v, - cu_seqlens, - max_seqlen, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=None, -): - return RingFlashAttnVarlenFunc.apply( - q, - k, - v, - cu_seqlens, - max_seqlen, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - group, - ) diff --git a/ring-flash-attention/ring_flash_attn/stripe_flash_attn.py b/ring-flash-attention/ring_flash_attn/stripe_flash_attn.py deleted file mode 100644 index ca426920f4ed..000000000000 --- a/ring-flash-attention/ring_flash_attn/stripe_flash_attn.py +++ /dev/null @@ -1,325 +0,0 @@ -import torch -from flash_attn.flash_attn_interface import _flash_attn_backward, _flash_attn_forward - -from .utils import RingComm, update_out_and_lse - - -def stripe_flash_attn_forward( - process_group, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - softmax_scale, - dropout_p=0, - causal=True, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, -): - assert causal, "stripe flash attn only supports causal attention, if not causal, use ring flash attn instead" - comm = RingComm(process_group) - - out = None - lse = None - - next_k, next_v = None, None - - for step in range(comm.world_size): - if step + 1 != comm.world_size: - next_k: torch.Tensor = comm.send_recv(k) - next_v: torch.Tensor = comm.send_recv(v) - comm.commit() - - if step <= comm.rank: - block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward( - q, - k, - v, - dropout_p, - softmax_scale, - causal=causal, - window_size=window_size, - alibi_slopes=alibi_slopes, - return_softmax=True and dropout_p > 0, - ) - out, lse = update_out_and_lse(out, lse, block_out, block_lse) - else: - block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward( - q[:, 1:], - k[:, :-1], - v[:, :-1], - dropout_p, - softmax_scale, - causal=causal, - window_size=window_size, - alibi_slopes=alibi_slopes, - return_softmax=True and dropout_p > 0, - ) - out, lse = update_out_and_lse(out, lse, block_out, block_lse, slice_=(slice(None), slice(1, None))) - - if step + 1 != comm.world_size: - comm.wait() - k = next_k - v = next_v - - out = out.to(q.dtype) - lse = lse.squeeze(dim=-1).transpose(1, 2) - return out, lse - - -def stripe_flash_attn_backward( - process_group, - dout, - q, - k, - v, - out, - softmax_lse, - softmax_scale, - dropout_p=0, - causal=True, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, -): - assert causal, "stripe flash attn only supports causal attention, if not causal, ring flash attn instead" - kv_comm = RingComm(process_group) - d_kv_comm = RingComm(process_group) - dq, dk, dv = None, None, None - next_dk, next_dv = None, None - next_k, next_v = None, None - dk_comm_buffer, dv_comm_buffer = None, None - - block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) - block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) - block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) - for step in 
range(kv_comm.world_size): - if step + 1 != kv_comm.world_size: - next_k = kv_comm.send_recv(k) - next_v = kv_comm.send_recv(v) - kv_comm.commit() - - shift_causal = step > kv_comm.rank - softmax_lse_1 = None - if not shift_causal: - _flash_attn_backward( - dout, - q, - k, - v, - out, - softmax_lse, - block_dq_buffer, - block_dk_buffer, - block_dv_buffer, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - rng_state=None, - ) - else: - if softmax_lse_1 is None: - # lazy init, since the last rank does not need softmax_lse_1 - softmax_lse_1 = softmax_lse[:, :, 1:].contiguous() - _flash_attn_backward( - dout[:, 1:], - q[:, 1:], - k[:, :-1], - v[:, :-1], - out[:, 1:], - softmax_lse_1, - block_dq_buffer[:, 1:], - block_dk_buffer[:, :-1], - block_dv_buffer[:, :-1], - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - rng_state=None, - ) - - if dq is None: - dq = block_dq_buffer.to(torch.float32) - dk = block_dk_buffer.to(torch.float32) - dv = block_dv_buffer.to(torch.float32) - else: - if not shift_causal: - dq += block_dq_buffer - else: - dq[:, 1:] += block_dq_buffer[:, 1:] - d_kv_comm.wait() - dk_comm_buffer, dv_comm_buffer = dk, dv - dk = next_dk - dv = next_dv - - if not shift_causal: - dk = block_dk_buffer + dk - dv = block_dv_buffer + dv - else: - dk[:, :-1] += block_dk_buffer[:, :-1] - dv[:, :-1] += block_dv_buffer[:, :-1] - - if step + 1 != kv_comm.world_size: - kv_comm.wait() - k = next_k - v = next_v - - next_dk = d_kv_comm.send_recv(dk, dk_comm_buffer) - next_dv = d_kv_comm.send_recv(dv, dv_comm_buffer) - d_kv_comm.commit() - - d_kv_comm.wait() - - return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) - - -class StripeFlashAttnFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_softmax, - group, - ): - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - - assert alibi_slopes is None - k = k.contiguous() - v = v.contiguous() - out, softmax_lse = stripe_flash_attn_forward( - group, - q, - k, - v, - softmax_scale=softmax_scale, - dropout_p=dropout_p, - causal=causal, - window_size=window_size, - alibi_slopes=alibi_slopes, - deterministic=False, - ) - # this should be out_padded - ctx.save_for_backward(q, k, v, out, softmax_lse) - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - ctx.group = group - return out if not return_softmax else (out, softmax_lse, None) - - @staticmethod - def backward(ctx, dout, *args): - q, k, v, out, softmax_lse = ctx.saved_tensors - dq, dk, dv = stripe_flash_attn_backward( - ctx.group, - dout, - q, - k, - v, - out, - softmax_lse, - softmax_scale=ctx.softmax_scale, - dropout_p=ctx.dropout_p, - causal=ctx.causal, - window_size=ctx.window_size, - alibi_slopes=ctx.alibi_slopes, - deterministic=ctx.deterministic, - ) - return dq, dk, dv, None, None, None, None, None, None, None, None - - -def stripe_flash_attn_qkvpacked_func( - qkv, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=None, -): - return StripeFlashAttnFunc.apply( - qkv[:, :, 0], - qkv[:, :, 1], - qkv[:, :, 2], - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - 
return_attn_probs, - group, - ) - - -def stripe_flash_attn_kvpacked_func( - q, - kv, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=None, -): - return StripeFlashAttnFunc.apply( - q, - kv[:, :, 0], - kv[:, :, 1], - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - group, - ) - - -def stripe_flash_attn_func( - q, - k, - v, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=None, -): - return StripeFlashAttnFunc.apply( - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - group, - ) diff --git a/ring-flash-attention/ring_flash_attn/triton_utils.py b/ring-flash-attention/ring_flash_attn/triton_utils.py deleted file mode 100644 index 66e362d93d68..000000000000 --- a/ring-flash-attention/ring_flash_attn/triton_utils.py +++ /dev/null @@ -1,137 +0,0 @@ -import torch -import triton -import triton.language as tl - - -@triton.jit -def flatten_kernel( - # pointers to matrices - OUT, - LSE, - CU_SEQLENS, - # strides - stride_out_nheads, - stride_out_seqlen, - stride_lse_batch, - stride_lse_nheads, - stride_lse_seqlen, - # meta-parameters - BLOCK_M: tl.constexpr, -): - pid_m = tl.program_id(axis=0) - pid_batch = tl.program_id(axis=1) - pid_head = tl.program_id(axis=2) - - start_idx = tl.load(CU_SEQLENS + pid_batch) - seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx - LSE = LSE + pid_batch * stride_lse_batch + pid_head * stride_lse_nheads - OUT = OUT + pid_head * stride_out_nheads + start_idx * stride_out_seqlen - - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - - LSE = LSE + rm[:, None] * stride_lse_seqlen - x = tl.load(LSE, mask=rm[:, None] < seqlen, other=0.0) - - OUT = OUT + rm[:, None] * stride_out_seqlen - tl.store(OUT, x, mask=rm[:, None] < seqlen) - - -def flatten_varlen_lse(lse, cu_seqlens): - """ - Arguments: - lse: (batch_size, nheads, max_seqlen) - cu_seqlens: (batch_size + 1,) - Return: - flatten_lse: (nheads, total_seqlen) - """ - total_seqlen = cu_seqlens[-1] - batch_size, nheads, max_seqlen = lse.shape - output = torch.empty((nheads, total_seqlen), dtype=lse.dtype, device=lse.device) - - grid = lambda META: (triton.cdiv(max_seqlen, META["BLOCK_M"]), batch_size, nheads) - BLOCK_M = 4 - - with torch.cuda.device(lse.device.index): - flatten_kernel[grid]( - output, - lse, - cu_seqlens, - # strides - output.stride(0), - output.stride(1), - lse.stride(0), - lse.stride(1), - lse.stride(2), - BLOCK_M, - ) - return output - - -@triton.jit -def unflatten_kernel( - # pointers to matrices - OUT, - LSE, - CU_SEQLENS, - # strides - stride_out_batch, - stride_out_nheads, - stride_out_seqlen, - stride_lse_seqlen, - stride_lse_nheads, - # meta-parameters - BLOCK_M: tl.constexpr, -): - pid_m = tl.program_id(axis=0) - pid_batch = tl.program_id(axis=1) - pid_head = tl.program_id(axis=2) - - start_idx = tl.load(CU_SEQLENS + pid_batch) - seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx - LSE = LSE + pid_head * stride_lse_nheads + start_idx * stride_lse_seqlen - OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads - - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - - LSE = LSE + rm[:, None] * stride_lse_seqlen - x = tl.load(LSE, mask=rm[:, None] < seqlen, other=0.0) 
- - OUT = OUT + rm[:, None] * stride_out_seqlen - tl.store(OUT, x, mask=rm[:, None] < seqlen) - - -def unflatten_varlen_lse(lse, cu_seqlens, max_seqlen: int): - """ - Arguments: - lse: (total_seqlen, nheads, 1) - cu_seqlens: (batch_size + 1,) - max_seqlen: int - Return: - unflatten_lse: (batch_size, nheads, max_seqlen) - """ - lse = lse.unsqueeze(dim=-1) - batch_size = len(cu_seqlens) - 1 - nheads = lse.shape[1] - output = torch.empty( - (batch_size, nheads, max_seqlen), - dtype=lse.dtype, - device=lse.device, - ) - - grid = lambda META: (triton.cdiv(max_seqlen, META["BLOCK_M"]), batch_size, nheads) - BLOCK_M = 4 - - with torch.cuda.device(lse.device.index): - unflatten_kernel[grid]( - output, - lse, - cu_seqlens, - # strides - output.stride(0), - output.stride(1), - output.stride(2), - lse.stride(0), - lse.stride(1), - BLOCK_M, - ) - return output diff --git a/ring-flash-attention/ring_flash_attn/utils.py b/ring-flash-attention/ring_flash_attn/utils.py deleted file mode 100644 index 787732af8135..000000000000 --- a/ring-flash-attention/ring_flash_attn/utils.py +++ /dev/null @@ -1,110 +0,0 @@ -from typing import Optional, Tuple - -import torch -import torch.distributed as dist -import torch.nn.functional as F - -__all__ = ["update_out_and_lse", "RingComm"] - - -@torch.jit.script -def _update_out_and_lse( - out: torch.Tensor, - lse: torch.Tensor, - block_out: torch.Tensor, - block_lse: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - - block_out = block_out.to(torch.float32) - block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) - - # new_lse = lse + torch.log(1 + torch.exp(block_lse - lse)) - # torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out - # For additional context and discussion, please refer to: - # https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795 - out = out - F.sigmoid(block_lse - lse) * (out - block_out) - lse = lse - F.logsigmoid(lse - block_lse) - - return out, lse - - -def update_out_and_lse( - out: Optional[torch.Tensor], - lse: Optional[torch.Tensor], - block_out: torch.Tensor, - block_lse: torch.Tensor, - slice_=None, -) -> Tuple[torch.Tensor, torch.Tensor]: - if out is None: - if slice_ is not None: - raise RuntimeError("first update_out_and_lse should not pass slice_ args") - out = block_out.to(torch.float32) - lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) - elif slice_ is not None: - slice_out, slice_lse = out[slice_], lse[slice_] - slice_out, slice_lse = _update_out_and_lse(slice_out, slice_lse, block_out, block_lse) - out[slice_], lse[slice_] = slice_out, slice_lse - else: - out, lse = _update_out_and_lse(out, lse, block_out, block_lse) - return out, lse - - -@torch.jit.script -def flatten_varlen_lse(lse, cu_seqlens): - new_lse = [] - for i in range(len(cu_seqlens) - 1): - start, end = cu_seqlens[i], cu_seqlens[i + 1] - new_lse.append(lse[i, :, : end - start]) - return torch.cat(new_lse, dim=1) - - -@torch.jit.script -def unflatten_varlen_lse(lse, cu_seqlens, max_seqlen: int): - num_seq = len(cu_seqlens) - 1 - num_head = lse.shape[-2] - new_lse = torch.empty((num_seq, max_seqlen, num_head, 1), dtype=torch.float32, device=lse.device) - for i in range(num_seq): - start, end = cu_seqlens[i], cu_seqlens[i + 1] - new_lse[i, : end - start] = lse[start:end] - return new_lse.squeeze(dim=-1).transpose(1, 2).contiguous() - - -class RingComm: - def __init__(self, process_group: dist.ProcessGroup): - self._process_group = process_group - self._ops = [] - self.rank = 
dist.get_rank(self._process_group) - self.world_size = dist.get_world_size(self._process_group) - self._reqs = None - - self.send_rank = (self.rank + 1) % self.world_size - self.recv_rank = (self.rank - 1) % self.world_size - - if process_group is not None: - self.send_rank = dist.get_global_rank(self._process_group, self.send_rank) - self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank) - - def send_recv(self, to_send: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None) -> torch.Tensor: - if recv_tensor is None: - res = torch.empty_like(to_send) - else: - res = recv_tensor - - send_op = dist.P2POp(dist.isend, to_send, self.send_rank, group=self._process_group) - recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group) - self._ops.append(send_op) - self._ops.append(recv_op) - return res - - def commit(self): - if self._reqs is not None: - raise RuntimeError("commit called twice") - self._reqs = dist.batch_isend_irecv(self._ops) - - def wait(self): - if self._reqs is None: - raise RuntimeError("wait called before commit") - for req in self._reqs: - req.wait() - self._reqs = None - self._ops = [] diff --git a/ring-flash-attention/ring_flash_attn/zigzag_ring_flash_attn.py b/ring-flash-attention/ring_flash_attn/zigzag_ring_flash_attn.py deleted file mode 100644 index d3e2821c5d4d..000000000000 --- a/ring-flash-attention/ring_flash_attn/zigzag_ring_flash_attn.py +++ /dev/null @@ -1,327 +0,0 @@ -import torch -import torch.distributed as dist -from flash_attn.flash_attn_interface import _flash_attn_backward, _flash_attn_forward - -from .utils import RingComm, update_out_and_lse - - -def zigzag_ring_flash_attn_forward( - process_group, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - softmax_scale, - dropout_p=0, - causal=True, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, -): - assert causal == True, "zigzag ring is meaningless for causal=False" - comm = RingComm(process_group) - - block_seq_len = q.shape[1] // 2 - q1 = q[:, block_seq_len:] - - out = None - lse = None - next_k, next_v = None, None - - def forward(q, k, v, causal): - block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward( - q, - k, - v, - dropout_p, - softmax_scale, - causal=causal, - window_size=window_size, - alibi_slopes=alibi_slopes, - return_softmax=True and dropout_p > 0, - ) - return block_out, block_lse - - for step in range(comm.world_size): - if step + 1 != comm.world_size: - next_k: torch.Tensor = comm.send_recv(k) - next_v: torch.Tensor = comm.send_recv(v) - comm.commit() - - if step == 0: - block_out, block_lse = forward(q, k, v, causal=True) - out, lse = update_out_and_lse(out, lse, block_out, block_lse) - elif step <= comm.rank: - k0 = k[:, :block_seq_len] - v0 = v[:, :block_seq_len] - block_out, block_lse = forward(q, k0, v0, causal=False) - out, lse = update_out_and_lse(out, lse, block_out, block_lse) - else: - block_out, block_lse = forward(q1, k, v, causal=False) - out, lse = update_out_and_lse( - out, - lse, - block_out, - block_lse, - slice_=(slice(None), slice(block_seq_len, None)), - ) - - if step + 1 != comm.world_size: - comm.wait() - k = next_k - v = next_v - - out = out.to(q.dtype) - lse = lse.squeeze(dim=-1).transpose(1, 2) - return out, lse - - -def zigzag_ring_flash_attn_backward( - process_group, - dout, - q, - k, - v, - out, - softmax_lse, - softmax_scale, - dropout_p=0, - causal=True, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, -): - assert causal == True, "zigzag ring is meaningless for 
causal=False" - kv_comm = RingComm(process_group) - d_kv_comm = RingComm(process_group) - dq, dk, dv = None, None, None - next_dk, next_dv = None, None - next_k, next_v = None, None - dk_comm_buffer, dv_comm_buffer = None, None - - dout1 = dout.chunk(2, dim=1)[1] - q1 = q.chunk(2, dim=1)[1] - out1 = out.chunk(2, dim=1)[1] - softmax_lse1 = softmax_lse.chunk(2, dim=2)[1].contiguous() - block_seq_len = q.shape[1] // 2 - - # repeatly allocating buffer may be slow... - dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) - dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) - dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) - - def backward(dout, q, k, v, out, softmax_lse, causal): - seqlen_q = q.shape[1] - seqlen_kv = k.shape[1] - _flash_attn_backward( - dout, - q, - k, - v, - out, - softmax_lse, - dq_buffer[:, :seqlen_q], - dk_buffer[:, :seqlen_kv], - dv_buffer[:, :seqlen_kv], - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - rng_state=None, - ) - - for step in range(kv_comm.world_size): - if step + 1 != kv_comm.world_size: - next_k = kv_comm.send_recv(k) - next_v = kv_comm.send_recv(v) - kv_comm.commit() - - if step == 0: - backward(dout, q, k, v, out, softmax_lse, causal=True) - dq = dq_buffer.to(torch.float32) - dk = dk_buffer.to(torch.float32) - dv = dv_buffer.to(torch.float32) - else: - if step <= kv_comm.rank: - k0 = k[:, :block_seq_len] - v0 = v[:, :block_seq_len] - backward(dout, q, k0, v0, out, softmax_lse, causal=False) - dq += dq_buffer - else: - backward(dout1, q1, k, v, out1, softmax_lse1, causal=False) - # always use the first half in dq_buffer. - dq[:, block_seq_len:] += dq_buffer[:, :block_seq_len] - - d_kv_comm.wait() - dk_comm_buffer, dv_comm_buffer = dk, dv - dk, dv = next_dk, next_dv - - if step <= kv_comm.rank: - dk[:, :block_seq_len] += dk_buffer[:, :block_seq_len] - dv[:, :block_seq_len] += dv_buffer[:, :block_seq_len] - else: - dk += dk_buffer - dv += dv_buffer - if dist.get_rank() == 0: - torch.save(torch.stack((dk, dv)), f"step_{step}.pt") - - if step + 1 != kv_comm.world_size: - kv_comm.wait() - k = next_k - v = next_v - - next_dk = d_kv_comm.send_recv(dk, dk_comm_buffer) - next_dv = d_kv_comm.send_recv(dv, dv_comm_buffer) - d_kv_comm.commit() - - d_kv_comm.wait() - - return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) - - -class ZigZagRingFlashAttnFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_softmax, - group, - ): - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - - assert alibi_slopes is None - k = k.contiguous() - v = v.contiguous() - out, softmax_lse = zigzag_ring_flash_attn_forward( - group, - q, - k, - v, - softmax_scale=softmax_scale, - dropout_p=dropout_p, - causal=causal, - window_size=window_size, - alibi_slopes=alibi_slopes, - deterministic=False, - ) - # this should be out_padded - ctx.save_for_backward(q, k, v, out, softmax_lse) - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - ctx.group = group - return out if not return_softmax else (out, softmax_lse, None) - - @staticmethod - def backward(ctx, dout, *args): - q, k, v, out, softmax_lse = ctx.saved_tensors - dq, dk, dv = zigzag_ring_flash_attn_backward( - ctx.group, - dout, - q, - k, - v, - out, - softmax_lse, - 
softmax_scale=ctx.softmax_scale, - dropout_p=ctx.dropout_p, - causal=ctx.causal, - window_size=ctx.window_size, - alibi_slopes=ctx.alibi_slopes, - deterministic=ctx.deterministic, - ) - return dq, dk, dv, None, None, None, None, None, None, None, None - - -def zigzag_ring_flash_attn_qkvpacked_func( - qkv, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=None, -): - return ZigZagRingFlashAttnFunc.apply( - qkv[:, :, 0], - qkv[:, :, 1], - qkv[:, :, 2], - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - group, - ) - - -def zigzag_ring_flash_attn_kvpacked_func( - q, - kv, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=None, -): - return ZigZagRingFlashAttnFunc.apply( - q, - kv[:, :, 0], - kv[:, :, 1], - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - group, - ) - - -def zigzag_ring_flash_attn_func( - q, - k, - v, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=None, -): - return ZigZagRingFlashAttnFunc.apply( - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - group, - ) diff --git a/ring-flash-attention/ring_flash_attn/zigzag_ring_flash_attn_varlen.py b/ring-flash-attention/ring_flash_attn/zigzag_ring_flash_attn_varlen.py deleted file mode 100644 index 5d4a8dd2daf0..000000000000 --- a/ring-flash-attention/ring_flash_attn/zigzag_ring_flash_attn_varlen.py +++ /dev/null @@ -1,441 +0,0 @@ -import torch -from flash_attn.flash_attn_interface import _flash_attn_varlen_backward, _flash_attn_varlen_forward - -from .utils import RingComm, update_out_and_lse - -try: - from .triton_utils import flatten_varlen_lse, unflatten_varlen_lse -except: - from .utils import flatten_varlen_lse, unflatten_varlen_lse - - -def get_half_index(cu_seqlens, *, front: bool): - if len(cu_seqlens) == 2: - if front: - return slice(None, cu_seqlens[-1] // 2) - else: - return slice(cu_seqlens[-1] // 2, None) - - index = torch.zeros((cu_seqlens[-1],), dtype=bool) - for i in range(len(cu_seqlens) - 1): - start, end = cu_seqlens[i], cu_seqlens[i + 1] - if front: - end = (start + end) // 2 - else: - start = (start + end) // 2 - index[start:end] = True - return index - - -@torch.jit.script -def get_half_lse(lse, cu_seqlens, *, front: bool): - new_lse = torch.empty( - (lse.shape[0], lse.shape[1], lse.shape[2] // 2), - dtype=lse.dtype, - device=lse.device, - ) - for i in range(len(cu_seqlens) - 1): - seqlen = (cu_seqlens[i + 1] - cu_seqlens[i]).item() - if front: - start, end = 0, seqlen // 2 - else: - start, end = seqlen // 2, seqlen - new_lse[i, :, : seqlen // 2] = lse[i, :, start:end] - return new_lse - - -def zigzag_ring_flash_attn_varlen_forward( - process_group, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - cu_seqlens, - max_seqlen, - half_index0, - half_index1, - softmax_scale, - dropout_p=0, - causal=True, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, -): - assert causal == True, "zigzag ring is meaningless for causal=False" - comm = RingComm(process_group) - - block_seq_len = q.shape[0] // 2 - q1 = q[half_index1] - - out = None - lse = None - next_k, next_v = None, None - 
half_cu_seqlens = cu_seqlens // 2 - half_max_seqlen = max_seqlen // 2 - - def forward(q, k, v, causal): - seqlen_q = q.shape[0] - seqlen_kv = k.shape[0] - cu_seqlens_q = half_cu_seqlens if seqlen_q == block_seq_len else cu_seqlens - max_seqlen_q = half_max_seqlen if seqlen_q == block_seq_len else max_seqlen - cu_seqlens_kv = half_cu_seqlens if seqlen_kv == block_seq_len else cu_seqlens - max_seqlen_kv = half_max_seqlen if seqlen_kv == block_seq_len else max_seqlen - block_out, _, _, _, _, block_lse, _, _ = _flash_attn_varlen_forward( - q, - k, - v, - # the first half and the second half are the same - cu_seqlens_q, - cu_seqlens_kv, - max_seqlen_q, - max_seqlen_kv, - dropout_p, - softmax_scale, - causal=causal, - window_size=window_size, - alibi_slopes=alibi_slopes, - return_softmax=True and dropout_p > 0, - ) - return block_out, block_lse - - for step in range(comm.world_size): - if step + 1 != comm.world_size: - next_k: torch.Tensor = comm.send_recv(k) - next_v: torch.Tensor = comm.send_recv(v) - comm.commit() - - if step == 0: - block_out, block_lse = forward(q, k, v, causal=True) - block_lse = flatten_varlen_lse( - block_lse, - cu_seqlens=cu_seqlens, - ) - out, lse = update_out_and_lse(out, lse, block_out, block_lse) - elif step <= comm.rank: - k0 = k[half_index0] - v0 = v[half_index0] - block_out, block_lse = forward(q, k0, v0, causal=False) - block_lse = flatten_varlen_lse( - block_lse, - cu_seqlens=cu_seqlens, - ) - out, lse = update_out_and_lse(out, lse, block_out, block_lse) - else: - block_out, block_lse = forward(q1, k, v, causal=False) - block_lse = flatten_varlen_lse( - block_lse, - cu_seqlens=half_cu_seqlens, - ) - out[half_index1], lse[half_index1] = update_out_and_lse( - out[half_index1], lse[half_index1], block_out, block_lse - ) - - if step + 1 != comm.world_size: - comm.wait() - k = next_k - v = next_v - - out = out.to(q.dtype) - lse = unflatten_varlen_lse(lse, cu_seqlens, max_seqlen) - return out, lse - - -def zigzag_ring_flash_attn_varlen_backward( - process_group, - dout, - q, - k, - v, - out, - softmax_lse, - cu_seqlens, - max_seqlen, - half_index0, - half_index1, - softmax_scale, - dropout_p=0, - causal=True, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, -): - assert causal == True, "zigzag ring is meaningless for causal=False" - kv_comm = RingComm(process_group) - d_kv_comm = RingComm(process_group) - dq, dk, dv = None, None, None - next_dk, next_dv = None, None - next_k, next_v = None, None - dk_comm_buffer, dv_comm_buffer = None, None - - dout1 = dout[half_index1] - q1 = q[half_index1] - out1 = out[half_index1] - softmax_lse1 = get_half_lse(softmax_lse, cu_seqlens, front=False) - block_seq_len = q.shape[0] // 2 - - half_cu_seqlens = cu_seqlens // 2 - half_max_seqlen = max_seqlen // 2 - - # repeatly allocating buffer may be slow... 
- dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) - dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) - dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) - - def backward(dout, q, k, v, out, softmax_lse, causal): - seqlen_q = q.shape[0] - seqlen_kv = k.shape[0] - cu_seqlens_q = half_cu_seqlens if seqlen_q == block_seq_len else cu_seqlens - max_seqlen_q = half_max_seqlen if seqlen_q == block_seq_len else max_seqlen - cu_seqlens_kv = half_cu_seqlens if seqlen_kv == block_seq_len else cu_seqlens - max_seqlen_kv = half_max_seqlen if seqlen_kv == block_seq_len else max_seqlen - _flash_attn_varlen_backward( - dout, - q, - k, - v, - out, - softmax_lse, - dq_buffer[:seqlen_q], - dk_buffer[:seqlen_kv], - dv_buffer[:seqlen_kv], - # the first half and the second half are the same - cu_seqlens_q, - cu_seqlens_kv, - max_seqlen_q, - max_seqlen_kv, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - rng_state=None, - ) - - for step in range(kv_comm.world_size): - if step + 1 != kv_comm.world_size: - next_k = kv_comm.send_recv(k) - next_v = kv_comm.send_recv(v) - kv_comm.commit() - - if step == 0: - backward(dout, q, k, v, out, softmax_lse, causal=True) - dq = dq_buffer.to(torch.float32) - dk = dk_buffer.to(torch.float32) - dv = dv_buffer.to(torch.float32) - else: - if step <= kv_comm.rank: - k0 = k[half_index0] - v0 = v[half_index0] - backward(dout, q, k0, v0, out, softmax_lse, causal=False) - dq += dq_buffer - else: - backward(dout1, q1, k, v, out1, softmax_lse1, causal=False) - dq[half_index1] += dq_buffer[:block_seq_len] - - d_kv_comm.wait() - dk_comm_buffer, dv_comm_buffer = dk, dv - dk, dv = next_dk, next_dv - - if step <= kv_comm.rank: - dk[half_index0] += dk_buffer[:block_seq_len] - dv[half_index0] += dv_buffer[:block_seq_len] - else: - dk += dk_buffer - dv += dv_buffer - - if step + 1 != kv_comm.world_size: - kv_comm.wait() - k = next_k - v = next_v - - next_dk = d_kv_comm.send_recv(dk, dk_comm_buffer) - next_dv = d_kv_comm.send_recv(dv, dv_comm_buffer) - d_kv_comm.commit() - - d_kv_comm.wait() - - return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) - - -class ZigZagRingFlashAttnVarlenFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - cu_seqlens, - max_seqlen, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_softmax, - group, - ): - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - - assert alibi_slopes is None - k = k.contiguous() - v = v.contiguous() - half_index0 = get_half_index(cu_seqlens, front=True) - half_index1 = get_half_index(cu_seqlens, front=False) - out, softmax_lse = zigzag_ring_flash_attn_varlen_forward( - group, - q, - k, - v, - cu_seqlens, - max_seqlen, - half_index0, - half_index1, - softmax_scale=softmax_scale, - dropout_p=dropout_p, - causal=causal, - window_size=window_size, - alibi_slopes=alibi_slopes, - deterministic=False, - ) - # this should be out_padded - is_half_index_tensor = isinstance(half_index0, torch.Tensor) - ctx.is_half_index_tensor = is_half_index_tensor - if is_half_index_tensor: - ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens, half_index0, half_index1) - else: - ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens) - ctx.half_index0 = half_index0 - ctx.half_index1 = half_index1 - ctx.max_seqlen = max_seqlen - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.alibi_slopes 
= alibi_slopes - ctx.deterministic = deterministic - ctx.group = group - return out if not return_softmax else (out, softmax_lse, None) - - @staticmethod - def backward(ctx, dout, *args): - if ctx.is_half_index_tensor: - (q, k, v, out, softmax_lse, cu_seqlens, half_index0, half_index1) = ctx.saved_tensors - else: - q, k, v, out, softmax_lse, cu_seqlens = ctx.saved_tensors - half_index0 = ctx.half_index0 - half_index1 = ctx.half_index1 - dq, dk, dv = zigzag_ring_flash_attn_varlen_backward( - ctx.group, - dout, - q, - k, - v, - out, - softmax_lse, - cu_seqlens, - ctx.max_seqlen, - half_index0, - half_index1, - softmax_scale=ctx.softmax_scale, - dropout_p=ctx.dropout_p, - causal=ctx.causal, - window_size=ctx.window_size, - alibi_slopes=ctx.alibi_slopes, - deterministic=ctx.deterministic, - ) - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None - - -def zigzag_ring_flash_attn_varlen_qkvpacked_func( - qkv, - cu_seqlens, - max_seqlen, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=None, -): - return ZigZagRingFlashAttnVarlenFunc.apply( - qkv[:, 0], - qkv[:, 1], - qkv[:, 2], - cu_seqlens, - max_seqlen, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - group, - ) - - -def zigzag_ring_flash_attn_varlen_kvpacked_func( - q, - kv, - cu_seqlens, - max_seqlen, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=None, -): - return ZigZagRingFlashAttnVarlenFunc.apply( - q, - kv[:, 0], - kv[:, 1], - cu_seqlens, - max_seqlen, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - group, - ) - - -def zigzag_ring_flash_attn_varlen_func( - q, - k, - v, - cu_seqlens, - max_seqlen, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=None, -): - return ZigZagRingFlashAttnVarlenFunc.apply( - q, - k, - v, - cu_seqlens, - max_seqlen, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - group, - ) diff --git a/ring-flash-attention/setup.py b/ring-flash-attention/setup.py deleted file mode 100644 index 58413e1b54f3..000000000000 --- a/ring-flash-attention/setup.py +++ /dev/null @@ -1,9 +0,0 @@ -from setuptools import find_packages, setup - -setup( - name="ring_flash_attn", - version="0.1", - author="zhuzilin", - url="https://github.com/zhuzilin/ring-flash-attention", - packages=find_packages(), -) diff --git a/ring-flash-attention/test/test_ring_flash_attn_func.py b/ring-flash-attention/test/test_ring_flash_attn_func.py deleted file mode 100644 index 50edd03bef4e..000000000000 --- a/ring-flash-attention/test/test_ring_flash_attn_func.py +++ /dev/null @@ -1,124 +0,0 @@ -import random - -import torch -import torch.distributed as dist -from flash_attn import flash_attn_qkvpacked_func -from ring_flash_attn import ring_flash_attn_qkvpacked_func - - -def set_seed(rank, seed=42): - seed = rank + seed - random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -def log(msg, a, rank0_only=False): - world_size = dist.get_world_size() - rank = 
dist.get_rank() - if rank0_only: - if rank == 0: - print( - f"{msg}: " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", - flush=True, - ) - return - - for i in range(world_size): - if i == rank: - if rank == 0: - print(f"{msg}:") - print( - f"[{rank}] " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", - flush=True, - ) - dist.barrier() - - -if __name__ == "__main__": - dist.init_process_group("nccl") - rank = dist.get_rank() - set_seed(rank) - world_size = dist.get_world_size() - dtype = torch.bfloat16 - device = torch.device(f"cuda:{rank}") - - batch_size = 1 - seqlen = 3816 - nheads = 5 - d = 128 - dropout_p = 0 - causal = True - deterministic = False - - assert seqlen % world_size == 0 - assert d % 8 == 0 - - qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) - dist.broadcast(qkv, src=0) - - dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype) - dist.broadcast(dout, src=0) - - local_qkv = qkv.chunk(world_size, dim=1)[rank].detach().clone() - local_qkv.requires_grad = True - local_dout = dout.chunk(world_size, dim=1)[rank].detach().clone() - - dist.barrier() - if rank == 0: - print("#" * 30) - print("# forward:") - print("#" * 30) - - out, lse, _ = flash_attn_qkvpacked_func( - qkv, - dropout_p=dropout_p, - causal=causal, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=deterministic, - return_attn_probs=True, - ) - - local_out = out.chunk(world_size, dim=1)[rank] - local_lse = lse.chunk(world_size, dim=-1)[rank] - - fn = ring_flash_attn_qkvpacked_func - - ring_out, ring_lse, _ = fn( - local_qkv, - dropout_p=dropout_p, - causal=causal, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=deterministic, - return_attn_probs=True, - ) - - log("out", out, rank0_only=True) - log("lse", lse, rank0_only=True) - log("out diff", local_out - ring_out) - log("lse diff", local_lse - ring_lse) - - dist.barrier() - if rank == 0: - print("#" * 30) - print("# backward:") - print("#" * 30) - - out.backward(dout) - dqkv = qkv.grad - local_dqkv = dqkv.chunk(world_size, dim=1)[rank] - - ring_out.backward(local_dout) - ring_dqkv = local_qkv.grad - - log("local_dq", local_dqkv[:, :, 0, :]) - log("dq diff", local_dqkv[:, :, 0, :] - ring_dqkv[:, :, 0, :]) - - log("local_dk", local_dqkv[:, :, 1, :]) - log("dk diff", local_dqkv[:, :, 1, :] - ring_dqkv[:, :, 1, :]) - - log("local_dv", local_dqkv[:, :, 2, :]) - log("dv diff", local_dqkv[:, :, 2, :] - ring_dqkv[:, :, 2, :]) diff --git a/ring-flash-attention/test/test_ring_flash_attn_varlen_func.py b/ring-flash-attention/test/test_ring_flash_attn_varlen_func.py deleted file mode 100644 index 51bb1ec5d67d..000000000000 --- a/ring-flash-attention/test/test_ring_flash_attn_varlen_func.py +++ /dev/null @@ -1,157 +0,0 @@ -import random - -import torch -import torch.distributed as dist -from flash_attn import flash_attn_varlen_qkvpacked_func -from ring_flash_attn import ring_flash_attn_varlen_qkvpacked_func - - -def set_seed(rank, seed=42): - seed = rank + seed - random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -def log(msg, a, rank0_only=False): - world_size = dist.get_world_size() - rank = dist.get_rank() - if rank0_only: - if rank == 0: - print( - f"{msg}: " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", - flush=True, - ) - return - - for i in range(world_size): - if i == rank: - if rank == 0: - print(f"{msg}:") - print( - f"[{rank}] " f"max {a.abs().max().item()}, 
" f"mean {a.abs().mean().item()}", - flush=True, - ) - dist.barrier() - - -def extract_local(value, cu_seqlens, rank, world_size): - local_values = [] - for i in range(len(cu_seqlens) - 1): - start, end = cu_seqlens[i], cu_seqlens[i + 1] - local_value = value[start:end].chunk(world_size, dim=0)[rank].detach().clone() - local_values.append(local_value) - return torch.cat(local_values, dim=0).contiguous() - - -def extract_lse(lse, cu_seqlens): - values = [] - for i in range(len(cu_seqlens) - 1): - start, end = cu_seqlens[i], cu_seqlens[i + 1] - value = lse[i, :, : end - start] - values.append(value) - return values - - -if __name__ == "__main__": - dist.init_process_group("nccl") - rank = dist.get_rank() - set_seed(rank) - world_size = dist.get_world_size() - dtype = torch.bfloat16 - device = torch.device(f"cuda:{rank}") - - batch_size = 1 - nheads = 5 - d = 128 - dropout_p = 0 - causal = True - deterministic = False - - cu_seqlens = [0, 120, 1248, 4232] - cu_seqlens_tensor = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) - max_seqlen = (cu_seqlens_tensor[1:] - cu_seqlens_tensor[:-1]).max().item() - total_length = cu_seqlens[-1] - num_seq = len(cu_seqlens) - 1 - - assert torch.all(cu_seqlens_tensor % world_size == 0) - assert d % 8 == 0 - - qkv = torch.randn(total_length, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) - dist.broadcast(qkv, src=0) - - dout = torch.randn(total_length, nheads, d, device=device, dtype=dtype) - dist.broadcast(dout, src=0) - - local_cu_seqlens_tensor = cu_seqlens_tensor // world_size - local_max_seqlen = max_seqlen // world_size - - local_qkv = extract_local(qkv, cu_seqlens, rank, world_size) - local_qkv.requires_grad = True - local_dout = extract_local(dout, cu_seqlens, rank, world_size) - - dist.barrier() - if rank == 0: - print("#" * 30) - print("# forward:") - print("#" * 30) - - out, lse, _ = flash_attn_varlen_qkvpacked_func( - qkv, - cu_seqlens_tensor, - max_seqlen, - dropout_p=dropout_p, - causal=causal, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=deterministic, - return_attn_probs=True, - ) - - local_out = extract_local(out, cu_seqlens, rank, world_size) - lse_list = extract_lse(lse, cu_seqlens) - - ring_out, ring_lse, _ = ring_flash_attn_varlen_qkvpacked_func( - local_qkv, - local_cu_seqlens_tensor, - local_max_seqlen, - dropout_p=dropout_p, - causal=causal, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=deterministic, - return_attn_probs=True, - ) - - ring_lse_list = extract_lse(ring_lse, local_cu_seqlens_tensor.tolist()) - - log("out", out, rank0_only=True) - log("out diff", local_out - ring_out) - - for lse, ring_lse in zip(lse_list, ring_lse_list): - local_lse = lse.chunk(world_size, dim=-1)[rank] - log("lse", lse, rank0_only=True) - log("lse diff", local_lse - ring_lse) - - dist.barrier() - if rank == 0: - print("#" * 30) - print("# backward:") - print("#" * 30) - - out.backward(dout) - dqkv = qkv.grad - local_dqkv = extract_local(dqkv, cu_seqlens, rank, world_size) - - ring_out.backward(local_dout) - ring_dqkv = local_qkv.grad - - log("local_dq", local_dqkv[:, 0]) - log("dq diff", local_dqkv[:, 0] - ring_dqkv[:, 0]) - - log("local_dk", local_dqkv[:, 1]) - log("dk diff", local_dqkv[:, 1] - ring_dqkv[:, 1]) - - log("local_dv", local_dqkv[:, 2]) - log("dv diff", local_dqkv[:, 2] - ring_dqkv[:, 2]) diff --git a/ring-flash-attention/test/test_stripe_flash_attn_func.py b/ring-flash-attention/test/test_stripe_flash_attn_func.py deleted file mode 100644 index dc9f5248d69d..000000000000 --- 
a/ring-flash-attention/test/test_stripe_flash_attn_func.py +++ /dev/null @@ -1,130 +0,0 @@ -import random - -import torch -import torch.distributed as dist -from flash_attn import flash_attn_qkvpacked_func -from ring_flash_attn import stripe_flash_attn_qkvpacked_func - - -def set_seed(rank, seed=42): - seed = rank + seed - random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -def log(msg, a, rank0_only=False): - world_size = dist.get_world_size() - rank = dist.get_rank() - if rank0_only: - if rank == 0: - print( - f"{msg}: " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", - flush=True, - ) - return - - for i in range(world_size): - if i == rank: - if rank == 0: - print(f"{msg}:") - print( - f"[{rank}] " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", - flush=True, - ) - dist.barrier() - - -def extract_local(value, rank, world_size, dim=1): - value = torch.stack(value.split(world_size, dim=dim), dim=dim).transpose(dim, dim + 1) - slicer = [rank if i == dim else slice(None) for i in range(len(value.shape))] - return value[slicer].contiguous() - - -if __name__ == "__main__": - dist.init_process_group("nccl") - rank = dist.get_rank() - set_seed(rank) - world_size = dist.get_world_size() - dtype = torch.bfloat16 - device = torch.device(f"cuda:{rank}") - - batch_size = 1 - seqlen = 3824 - nheads = 5 - d = 128 - dropout_p = 0 - causal = True - deterministic = False - - assert causal - assert seqlen % (2 * world_size) == 0 - assert d % 8 == 0 - - qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) - dist.broadcast(qkv, src=0) - - dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype) - dist.broadcast(dout, src=0) - - local_qkv = extract_local(qkv, rank, world_size).detach().clone() - local_qkv.requires_grad = True - local_dout = extract_local(dout, rank, world_size).detach().clone() - - dist.barrier() - if rank == 0: - print("#" * 30) - print("# forward:") - print("#" * 30) - - out, lse, _ = flash_attn_qkvpacked_func( - qkv, - dropout_p=dropout_p, - causal=causal, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=deterministic, - return_attn_probs=True, - ) - - local_out = extract_local(out, rank, world_size) - local_lse = extract_local(lse, rank, world_size, dim=2) - - ring_out, ring_lse, _ = stripe_flash_attn_qkvpacked_func( - local_qkv, - dropout_p=dropout_p, - causal=causal, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=deterministic, - return_attn_probs=True, - ) - - log("out", out, rank0_only=True) - log("lse", lse, rank0_only=True) - log("out diff", local_out - ring_out) - log("lse diff", local_lse - ring_lse) - - dist.barrier() - if rank == 0: - print("#" * 30) - print("# backward:") - print("#" * 30) - - out.backward(dout) - dqkv = qkv.grad - - local_dqkv = extract_local(dqkv, rank, world_size) - - ring_out.backward(local_dout) - ring_dqkv = local_qkv.grad - - log("local_dq", local_dqkv[:, :, 0, :]) - log("dq diff", local_dqkv[:, :, 0, :] - ring_dqkv[:, :, 0, :]) - - log("local_dk", local_dqkv[:, :, 1, :]) - log("dk0 diff", local_dqkv[:, :, 1, :] - ring_dqkv[:, :, 1, :]) - - log("local_dv", local_dqkv[:, :, 2, :]) - log("dv diff", local_dqkv[:, :, 2, :] - ring_dqkv[:, :, 2, :]) diff --git a/ring-flash-attention/test/test_triton_kernels.py b/ring-flash-attention/test/test_triton_kernels.py deleted file mode 100644 index aa1c1fdcd338..000000000000 --- a/ring-flash-attention/test/test_triton_kernels.py 
+++ /dev/null @@ -1,30 +0,0 @@ -import torch -from ring_flash_attn.triton_utils import flatten_varlen_lse as triton_flatten_varlen_lse -from ring_flash_attn.triton_utils import unflatten_varlen_lse as triton_unflatten_varlen_lse -from ring_flash_attn.utils import flatten_varlen_lse, unflatten_varlen_lse - -if __name__ == "__main__": - device = torch.device("cuda:0") - - cu_seqlens = [0, 15, 156, 529] - cu_seqlens_tensor = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) - batch_size = len(cu_seqlens) - 1 - max_seqlen = (cu_seqlens_tensor[1:] - cu_seqlens_tensor[:-1]).max().item() - n_head = 5 - - lse = torch.randn((batch_size, n_head, max_seqlen), dtype=torch.float32, device=device) - flatten_lse = flatten_varlen_lse(lse, cu_seqlens_tensor) - triton_flatten_lse = triton_flatten_varlen_lse(lse, cu_seqlens_tensor) - assert torch.all(flatten_lse == triton_flatten_lse) - - flatten_lse = flatten_lse.transpose(-2, -1).unsqueeze(dim=-1) - triton_flatten_lse = triton_flatten_lse.transpose(-2, -1).unsqueeze(dim=-1) - - unflatten_lse = unflatten_varlen_lse(flatten_lse, cu_seqlens_tensor, max_seqlen) - triton_unflatten_lse = triton_unflatten_varlen_lse(triton_flatten_lse, cu_seqlens_tensor, max_seqlen) - - for i in range(batch_size): - seqlen = cu_seqlens[i + 1] - cu_seqlens[i] - assert torch.all( - unflatten_lse[i, :, :seqlen] == triton_unflatten_lse[i, :, :seqlen] - ), f"{unflatten_lse[i, :seqlen]} vs {triton_unflatten_lse[i, :seqlen]}" diff --git a/ring-flash-attention/test/test_zigzag_ring_flash_attn_func.py b/ring-flash-attention/test/test_zigzag_ring_flash_attn_func.py deleted file mode 100644 index 5f84bc58cf10..000000000000 --- a/ring-flash-attention/test/test_zigzag_ring_flash_attn_func.py +++ /dev/null @@ -1,150 +0,0 @@ -import os -import random - -import torch -import torch.distributed as dist -import torch.multiprocessing as mp -from flash_attn import flash_attn_qkvpacked_func -from ring_flash_attn import zigzag_ring_flash_attn_qkvpacked_func - -from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention - - -def set_seed(rank, seed=42): - seed = rank + seed - random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -def log(msg, a, rank0_only=False): - world_size = dist.get_world_size() - rank = dist.get_rank() - if rank0_only: - if rank == 0: - print( - f"{msg}: " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", - flush=True, - ) - return - - for i in range(world_size): - if i == rank: - if rank == 0: - print(f"{msg}:") - print( - f"[{rank}] " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", - flush=True, - ) - dist.barrier() - - -def extract_local(value, rank, world_size, dim=1): - value_chunks = value.chunk(2 * world_size, dim=dim) - local_value = torch.cat([value_chunks[rank], value_chunks[2 * world_size - rank - 1]], dim=dim) - return local_value.contiguous() - - -def run_test(rank, world_size): - os.environ["MASTER_ADDR"] = "localhost" # or the IP of the master node - os.environ["MASTER_PORT"] = "8125" # make sure this port is free - dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) - set_seed(rank) - dtype = torch.bfloat16 - device = torch.device(f"cuda:{rank}") - - batch_size = 1 - seqlen = 3824 - nheads = 5 - d = 128 - dropout_p = 0 - causal = True - deterministic = False - - assert causal - assert seqlen % (2 * world_size) == 0 - assert d % 8 == 0 - - qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, 
requires_grad=True) - dist.broadcast(qkv, src=0) - - dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype) - dist.broadcast(dout, src=0) - - local_qkv = extract_local(qkv, rank, world_size).detach().clone() - local_qkv.requires_grad = True - extract_local(dout, rank, world_size).detach().clone() - - dist.barrier() - if rank == 0: - print("#" * 30) - print("# forward:") - print("#" * 30) - - out, lse, _ = flash_attn_qkvpacked_func( - qkv, - dropout_p=dropout_p, - causal=causal, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=deterministic, - return_attn_probs=True, - ) - - local_out = extract_local(out, rank, world_size) - # local_lse = extract_local(lse, rank, world_size, dim=2) - q, k, v = local_qkv.chunk(3, dim=2) - q, k, v = [x.squeeze(2).detach().clone().transpose(1, 2) for x in (q, k, v)] - q.requires_grad = k.requires_grad = v.requires_grad = True - sp_stream = torch.cuda.Stream() - sp_group = dist.new_group() - colo_out = RingAttention.attention(q, k, v, sp_group, sp_stream, AttnMaskType.CAUSAL) - - ring_out, ring_lse, _ = zigzag_ring_flash_attn_qkvpacked_func( - local_qkv, - dropout_p=dropout_p, - causal=causal, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=deterministic, - return_attn_probs=True, - ) - log("colo_out", colo_out, rank0_only=True) - log("ring_out", ring_out, rank0_only=True) - # log("lse", lse, rank0_only=True) - log("colo_out - ring_out", colo_out - ring_out) - # log("lse diff", local_lse - ring_lse) - log("ring_out - local_out", ring_out - local_out) - log("colo_out - local_out", colo_out - local_out) - - dist.barrier() - if rank == 0: - print("#" * 30) - print("# backward:") - print("#" * 30) - - colo_out.sum().backward() - qkv.grad - # q, k, v = [x.transpose(1, 2) for x in (q, k, v)] - colo_dq, colo_dk, colo_dv = [x.transpose(1, 2) for x in (q.grad, k.grad, v.grad)] - - ring_out.sum().backward() - ring_dqkv = local_qkv.grad - out.sum().backward() - dqkv = extract_local(qkv.grad, rank, world_size) - - # log("colo_dq", colo_dq) - log("dq diff", colo_dq - ring_dqkv[:, :, 0, :]) - - # log("colo_dk", colo_dk) - log("dk diff", colo_dk - ring_dqkv[:, :, 1, :]) - - # log("colo_dv", colo_dv) - log("dv diff", colo_dv - ring_dqkv[:, :, 2, :]) - log("colo_dv - local_dv", colo_dv - dqkv[:, :, 2, :]) - - -if __name__ == "__main__": - world_size = 4 - mp.spawn(run_test, args=(world_size,), nprocs=world_size, join=True) diff --git a/ring-flash-attention/test/test_zigzag_ring_flash_attn_varlen_func.py b/ring-flash-attention/test/test_zigzag_ring_flash_attn_varlen_func.py deleted file mode 100644 index 7f6eced6e57b..000000000000 --- a/ring-flash-attention/test/test_zigzag_ring_flash_attn_varlen_func.py +++ /dev/null @@ -1,163 +0,0 @@ -import random - -import torch -import torch.distributed as dist -from flash_attn import flash_attn_varlen_qkvpacked_func -from ring_flash_attn import zigzag_ring_flash_attn_varlen_qkvpacked_func - - -def set_seed(rank, seed=42): - seed = rank + seed - random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -def log(msg, a, rank0_only=False): - world_size = dist.get_world_size() - rank = dist.get_rank() - if rank0_only: - if rank == 0: - print( - f"{msg}: " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", - flush=True, - ) - return - - for i in range(world_size): - if i == rank: - if rank == 0: - print(f"{msg}:") - print( - f"[{rank}] " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", - flush=True, - ) - 
dist.barrier() - - -def extract_local(value, cu_seqlens, rank, world_size): - local_values = [] - for i in range(len(cu_seqlens) - 1): - start, end = cu_seqlens[i], cu_seqlens[i + 1] - local_value = value[start:end].chunk(2 * world_size, dim=0) - local_values.extend( - [ - local_value[rank].detach().clone(), - local_value[2 * world_size - 1 - rank].detach().clone(), - ] - ) - return torch.cat(local_values, dim=0).contiguous() - - -def extract_lse(lse, cu_seqlens): - values = [] - for i in range(len(cu_seqlens) - 1): - start, end = cu_seqlens[i], cu_seqlens[i + 1] - value = lse[i, :, : end - start] - values.append(value) - return values - - -if __name__ == "__main__": - dist.init_process_group("nccl") - rank = dist.get_rank() - set_seed(rank) - world_size = dist.get_world_size() - dtype = torch.bfloat16 - device = torch.device(f"cuda:{rank}") - - batch_size = 1 - nheads = 5 - d = 128 - dropout_p = 0 - causal = True - deterministic = False - - cu_seqlens = [0, 128, 1248, 4240] - cu_seqlens_tensor = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) - max_seqlen = (cu_seqlens_tensor[1:] - cu_seqlens_tensor[:-1]).max().item() - total_length = cu_seqlens[-1] - num_seq = len(cu_seqlens) - 1 - - assert torch.all(cu_seqlens_tensor % (2 * world_size) == 0) - assert d % 8 == 0 - - qkv = torch.randn(total_length, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) - dist.broadcast(qkv, src=0) - - dout = torch.randn(total_length, nheads, d, device=device, dtype=dtype) - dist.broadcast(dout, src=0) - - local_cu_seqlens_tensor = cu_seqlens_tensor // world_size - local_max_seqlen = max_seqlen // world_size - - local_qkv = extract_local(qkv, cu_seqlens, rank, world_size) - local_qkv.requires_grad = True - local_dout = extract_local(dout, cu_seqlens, rank, world_size) - - dist.barrier() - if rank == 0: - print("#" * 30) - print("# forward:") - print("#" * 30) - - out, lse, _ = flash_attn_varlen_qkvpacked_func( - qkv, - cu_seqlens_tensor, - max_seqlen, - dropout_p=dropout_p, - causal=causal, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=deterministic, - return_attn_probs=True, - ) - - local_out = extract_local(out, cu_seqlens, rank, world_size) - lse_list = extract_lse(lse, cu_seqlens) - - ring_out, ring_lse, _ = zigzag_ring_flash_attn_varlen_qkvpacked_func( - local_qkv, - local_cu_seqlens_tensor, - local_max_seqlen, - dropout_p=dropout_p, - causal=causal, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=deterministic, - return_attn_probs=True, - ) - - ring_lse_list = extract_lse(ring_lse, local_cu_seqlens_tensor.tolist()) - - log("out", out, rank0_only=True) - log("out diff", local_out - ring_out) - - for i, (lse, ring_lse) in enumerate(zip(lse_list, ring_lse_list)): - local_lse = lse.chunk(2 * world_size, dim=-1) - local_lse = torch.cat([local_lse[rank], local_lse[2 * world_size - 1 - rank]], dim=-1) - log(f"lse {i}", lse, rank0_only=True) - log(f"lse diff {i}", local_lse - ring_lse) - - dist.barrier() - if rank == 0: - print("#" * 30) - print("# backward:") - print("#" * 30) - - out.backward(dout) - dqkv = qkv.grad - local_dqkv = extract_local(dqkv, cu_seqlens, rank, world_size) - - ring_out.backward(local_dout) - ring_dqkv = local_qkv.grad - - log("local_dq", local_dqkv[:, 0]) - log("dq diff", local_dqkv - ring_dqkv) - - log("local_dk", local_dqkv[:, 1]) - log("dk diff", local_dqkv[:, 1] - ring_dqkv[:, 1]) - - log("local_dv", local_dqkv[:, 2]) - log("dv diff", local_dqkv[:, 2] - ring_dqkv[:, 2]) diff --git a/tests/test_shardformer/test_layer/test_ring_attn.py 
b/tests/test_shardformer/test_layer/test_ring_attn.py
index 14e3dbe08acb..51f31d89adab 100644
--- a/tests/test_shardformer/test_layer/test_ring_attn.py
+++ b/tests/test_shardformer/test_layer/test_ring_attn.py
@@ -23,7 +23,7 @@ def test_ring_attn(seq_len, batch_size, nheads, d, dtype):
     # Some outliers may seem large, but our errors are still much lower than
     # Megatron-LM's context parallel
-    # https://github.com/NVIDIA/TransformerEngine/blob/33a3d02f81c56e6f7b542c09bfa86657078d57fb/tests/pytorch/fused_attn/run_fused_attn_with_cp.py#L215
+    # (https://github.com/NVIDIA/TransformerEngine/blob/33a3d02f81c56e6f7b542c09bfa86657078d57fb/tests/pytorch/fused_attn/run_fused_attn_with_cp.py#L215)
     # and the original zigzag implementation: https://github.com/zhuzilin/ring-flash-attention/tree/main
     atol = rtol = 7e-3
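
The deleted zigzag tests above all exercise the same data layout: the sequence dimension is split into `2 * world_size` chunks and each rank keeps chunk `r` together with its mirror chunk `2 * world_size - 1 - r`. The sketch below is a minimal, CPU-only illustration of that split; `extract_local` mirrors the helper from the removed `test_zigzag_ring_flash_attn_func.py`, while `reassemble` and the `__main__` check are hypothetical additions for this illustration, not part of the removed package.

```python
import torch


def extract_local(value: torch.Tensor, rank: int, world_size: int, dim: int = 1) -> torch.Tensor:
    # Zigzag split: cut the sequence into 2 * world_size chunks and give rank r
    # chunk r plus its mirror chunk (2 * world_size - 1 - r), so every rank
    # holds an equal share of the causal workload.
    chunks = value.chunk(2 * world_size, dim=dim)
    return torch.cat([chunks[rank], chunks[2 * world_size - rank - 1]], dim=dim).contiguous()


def reassemble(locals_per_rank, world_size: int, dim: int = 1) -> torch.Tensor:
    # Hypothetical inverse of extract_local, used only to sanity-check the split.
    chunks = [None] * (2 * world_size)
    for rank, local in enumerate(locals_per_rank):
        first, second = local.chunk(2, dim=dim)
        chunks[rank] = first
        chunks[2 * world_size - rank - 1] = second
    return torch.cat(chunks, dim=dim)


if __name__ == "__main__":
    world_size, batch, seqlen, heads, d = 4, 1, 3824, 5, 128
    x = torch.randn(batch, seqlen, heads, d)
    locals_per_rank = [extract_local(x, r, world_size) for r in range(world_size)]
    assert all(l.shape[1] == seqlen // world_size for l in locals_per_rank)
    assert torch.equal(reassemble(locals_per_rank, world_size), x)
    print("zigzag split round-trips correctly")
```

Pairing chunk `r` with chunk `2 * world_size - 1 - r` balances the causal workload across ranks, which is why the deleted tests assert `seqlen % (2 * world_size) == 0` before splitting.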
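
The removed varlen tests drive the packed path with hard-coded `cu_seqlens` boundaries and a `max_seqlen` derived from them. As a hedged sketch of that bookkeeping, the snippet below rebuilds the same quantities from per-sequence lengths; `build_cu_seqlens` is an illustrative helper written for this example, not an API of the deleted package.

```python
import torch


def build_cu_seqlens(seq_lens, device="cpu"):
    # Cumulative sequence boundaries for a packed (varlen) batch:
    # cu_seqlens[i] is the start offset of sequence i and cu_seqlens[-1] the total token count.
    cu_seqlens = torch.zeros(len(seq_lens) + 1, dtype=torch.int32, device=device)
    cu_seqlens[1:] = torch.cumsum(torch.tensor(seq_lens, dtype=torch.int32, device=device), dim=0)
    max_seqlen = max(seq_lens)
    return cu_seqlens, max_seqlen


if __name__ == "__main__":
    world_size = 4
    # Per-sequence lengths that reproduce the boundaries hard-coded in the
    # deleted zigzag varlen test: [0, 128, 1248, 4240].
    seq_lens = [128, 1120, 2992]
    cu_seqlens, max_seqlen = build_cu_seqlens(seq_lens)
    print(cu_seqlens.tolist(), max_seqlen)  # [0, 128, 1248, 4240] 2992
    # The zigzag varlen test requires every boundary to be divisible by
    # 2 * world_size so each sequence splits into equal zigzag chunks per rank.
    assert torch.all(cu_seqlens % (2 * world_size) == 0)
```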