fix typos and remove misc files
Edenzzzz committed Jul 22, 2024
1 parent 4bad575 commit c46f7c8
Showing 22 changed files with 29 additions and 2,961 deletions.
82 changes: 23 additions & 59 deletions colossalai/shardformer/layer/attn.py
@@ -238,12 +238,7 @@ def attention(
# sanity check
if attention_mask is not None:
assert torch.is_floating_point(attention_mask), "attention_mask should be a floating point tensor."
if attention_mask_type in (
AttnMaskType.CUSTOM,
AttnMaskType.CAUSAL,
AttnMaskType.PADDED,
AttnMaskType.PADDED_CAUSAL,
):
if attention_mask_type in (AttnMaskType.CUSTOM, AttnMaskType.CAUSAL):
assert (
cu_seqlens_q is None
and cu_seqlens_kv is None
@@ -254,9 +249,18 @@ def attention(
)
if attention_mask_type == AttnMaskType.CUSTOM:
assert not torch.all(attention_mask != 0, dim=-1).any()
else:
# if attention_mask is None, attention_mask_type should be the default value
assert attention_mask_type == AttnMaskType.CUSTOM
elif attention_mask_type in (
AttnMaskType.PADDED,
AttnMaskType.PADDED_CAUSAL,
):
assert (
cu_seqlens_q is not None
and cu_seqlens_kv is not None
and max_seqlen_q is not None
and max_seqlen_kv is not None
and q_indices is not None
and kv_indices is not None
)
# kernel dispatch
mask_type = attention_mask_type if attention_mask is not None else None
attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type)
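# Illustrative sketch (not part of this diff): the varlen metadata that the
# PADDED mask types above now expect to be pre-computed (cu_seqlens, max_seqlen
# and flat token indices) could be derived from a 0/1 padding mask roughly as
# below. Toy values only; this is not the repository's get_pad_info helper.
import torch
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])                # 1 = valid token, 0 = padding (B=2, Sq=4)
seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)      # tensor([3, 2])
cu_seqlens = torch.nn.functional.pad(
    torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)
)                                                            # tensor([0, 3, 5])
max_seqlen = int(seqlens.max())                              # 3
indices = torch.nonzero(attention_mask.flatten()).flatten()  # flat positions of the valid tokens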
@@ -398,24 +402,16 @@ def _rescale_out_lse(out, block_out, lse, block_lse):
# new_lse = max_scale + torch.log(1 + torch.exp(min_scale - max_scale))

new_lse = lse + torch.log(1 + torch.exp(block_lse - lse))
assert not (new_lse.isnan().any() or new_lse.isinf().any()), f"lse is nan: {new_lse}"
new_block_lse = torch.exp(block_lse - new_lse)
out.copy_(torch.exp(lse - new_lse) * out + new_block_lse * block_out)
lse.copy_(new_lse)
assert _not_nan(new_lse), new_lse
assert _not_nan(new_block_lse), new_block_lse
assert _not_nan(out), out

# block_out = block_out.float()
# out.copy_(out - F.sigmoid(block_lse - lse) * (out - block_out))
# lse.copy_(lse - F.logsigmoid(lse - block_lse))
# assert not lse.isnan().any(), lse
# assert not out.isnan().any(), out


def _not_nan(x):
return not (x.isnan().any() or x.isinf().any())
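# The rescaling above uses the online-softmax identity
#   log(exp(a) + exp(b)) = a + log(1 + exp(b - a)).
# A minimal, self-contained sketch of the same merge on toy tensors
# (shapes are assumed for illustration and independent of the kernel layout):
import torch
out = torch.randn(1, 4, 2, 8)        # (B, Sq, H, D): running partial output
lse = torch.randn(1, 4, 2, 1)        # (B, Sq, H, 1): its log-sum-exp
block_out = torch.randn(1, 4, 2, 8)  # output of the next block
block_lse = torch.randn(1, 4, 2, 1)  # its log-sum-exp
new_lse = lse + torch.log(1 + torch.exp(block_lse - lse))
merged = torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out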


class RingAttention(torch.autograd.Function):
"""Implements the Ring Attention from `Ring Attention with Blockwise Transformers for Near-Infinite Context`
(https://arxiv.org/abs/2310.01889).
@@ -469,7 +465,7 @@ def attention(
deterministic (bool, optional): Whether to force deterministic backward pass. See https://github.com/Dao-AILab/flash-attention/issues/349
return_softmax (bool, optional): Whether to return the softmax denominator (logsumexp).
Returns:
out: Output tensor. Shape should be [B, Heads, Sq, D]
out: Output tensor. Shape should be [B, Heads, Sq, D] or [T, Heads, D]
softmax_lse: (if return_softmax is True) Softmax denominator (logsumexp).
Shape should be [B, Heads, Sq]
"""
@@ -495,23 +491,12 @@ def attention(
# (Ex: https://github.com/zhuzilin/ring-flash-attention/blob/49a50141bdce4e76418afe2051646c9a771fe867/test/test_zigzag_ring_flash_attn_varlen_func.py#L43)
# Left some logics here; to be supported depending on demands.
elif AttnMaskType.PADDED_CAUSAL:
# TODO: compute cu_seqlens locally using valid_positions
assert attention_mask is not None, "Padded attention requires inputing valid token positions!"
# Sequences are padded to the same length in a training round, so reuse the mask info.
if (
RingAttention.ATTENTION_MASK
and (RingAttention.ATTENTION_MASK.shape == attention_mask.shape)
and (RingAttention.ATTENTION_MASK == attention_mask).all()
):
cu_seqlens_q = cu_seqlens_kv = RingAttention.CU_SEQLENS
max_seqlen_q = max_seqlen_kv = RingAttention.MAX_SEQLEN
else:
max_seqlen, cu_seqlens, valid_positions = get_pad_info(attention_mask)
RingAttention.CU_SEQLENS = cu_seqlens
RingAttention.MAX_SEQLEN = max_seqlen
RingAttention.ATTENTION_MASK = attention_mask
# To [T, H, D] where T is the number of non-zero tokens
q, k, v = [_unpad_input(x, valid_positions) for x in (q, k, v)]
assert (
cu_seq_lens_q is not None
and cu_seq_lens_kv is not None
and max_seq_len_q is not None
and max_seq_len_kv is not None
), "Packed mode requires pre-computed cu_seqlens and max_seqlens."

out, softmax_lse = RingAttention.apply(
q,
Expand All @@ -529,11 +514,7 @@ def attention(
return_softmax,
)

if attention_mask_type == AttnMaskType.PADDED_CAUSAL:
# Pad and reshape back
# [T, N, D] -> [B, H, Sq, D]
out = _pad_input(out, valid_positions, b, sq)
else:
if not attention_mask_type == AttnMaskType.PADDED_CAUSAL:
out = out.transpose(1, 2) # [B, Sq, H, D] -> [B, H, Sq, D]

if return_softmax:
@@ -575,8 +556,6 @@ def forward(
b, h, sq, d = q.shape
# (B, H, Sq, D) -> (B, Sq, H, D)
q, k, v = [x.transpose(1, 2) for x in (q, k, v)]
assert _not_nan(q), q
assert _not_nan(k), k
kv_comms = [RingComm(sp_group) for _ in range(2)]
sp_size = kv_comms[0].world_size
sp_rank = kv_comms[0].rank
@@ -638,7 +617,6 @@ def forward(
kv_block = kv_buffers[i % 2]
# (2, B * Sq // 2, H, D)
kv_block = kv_block.view(2, b * sq, h, d)[:, : b * sq // 2].clone()
assert _not_nan(kv_block), f"rank {dist.get_rank()} step {i} kv_block {kv_block}"
(
_,
_,
@@ -665,7 +643,6 @@ def forward(
# Drop the first half of q
q_block = q.view(b * sq, h, d)[b * sq // 2 :]
kv_block = kv_buffers[i % 2].view(2, b * sq, h, d).clone()
assert _not_nan(kv_block), f"rank {dist.get_rank()} step {i} kv_block {kv_block}"

(
_,
Expand Down Expand Up @@ -696,11 +673,7 @@ def forward(
block_softmax_lse[i % 2] = (
block_softmax_lse[i % 2].transpose(1, 2).contiguous().unsqueeze(-1).float()
) # (B, H, Sq) -> (B, Sq, H, 1)

assert block_out[i % 2].shape[:-1] == block_softmax_lse[i % 2].shape[:-1]
assert _not_nan(
block_softmax_lse[i % 2]
), f"rank {sp_rank} step {i} softmax_lse is nan: {block_softmax_lse[i % 2]}"

# Overlap output correction with next flash attn kernel
if i == 0:
@@ -767,7 +740,6 @@ def backward(ctx, dout, _):
assert (
out.shape == dout.shape == (b, sq, h, d)
), f"out {out.shape} and dout {dout.shape} should have shape ({b}, {sq}, {h}, {d}) instead"
assert _not_nan(dout), f"dout is nan"
# Sequence parallel args
sp_group = ctx.sp_group
sp_rank = dist.get_rank(sp_group)
Expand Down Expand Up @@ -887,10 +859,6 @@ def backward(ctx, dout, _):

# Wait for mobile kv grad accumulators
dkv_comm.wait()
assert _not_nan(dq_block), f"rank {dist.get_rank()} step {i} dq_block is nan"
assert _not_nan(dkv_recv), f"rank {dist.get_rank()} step {i} dkv_buffers is nan"
assert _not_nan(dq)

if i <= sp_rank:
# q blocks "surrounded" by kv blocks
dkv_recv[0][:, : sq // 2] += dk_block[:, : sq // 2] # (B, Sq // 2, H, D)
Expand All @@ -899,14 +867,10 @@ def backward(ctx, dout, _):
# q blocks "surrounding" kv blocks
dkv_recv[0] += dk_block
dkv_recv[1] += dv_block
if dist.get_rank() == 0:
torch.save(dkv_recv, f"colo_step_{i}.pt")

dkv_comm.send_recv(send_tensor=dkv_recv, recv_tensor=dkv_send)
dkv_comm.wait()
dkv_recv = dkv_send

dq, dk, dv = [x.view(b, sq, h, d).transpose(1, 2).to(q.dtype) for x in (dq, *dkv_recv)]
assert _not_nan(dq), f"dq is nan"
assert _not_nan(dk), f"dk is nan"
assert _not_nan(dv), f"dv is nan"
return (dq, dk, dv, None, None, None, None, None, None, None, None, None, None)
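To make the masking branches in the forward and backward loops above easier to follow, the sketch below restates which query/KV halves participate at each ring step under the zigzag layout, assuming (as in the splitting utility below) that rank r holds sequence chunks r and 2*sp_size-1-r. This is a schematic restatement, not kernel code.

sp_size = 4
for r in range(sp_size):                       # sequence-parallel rank
    for i in range(sp_size):                   # ring step
        src = (r - i) % sp_size                # rank whose KV block arrives at step i
        if i == 0:
            case = "full causal attention on the local block"
        elif i <= r:
            case = "all queries attend the first half of the received KV block"
        else:
            case = "the second half of the queries attends the full received KV block"
        print(f"rank {r}, step {i}: KV from rank {src} -> {case}")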
6 changes: 3 additions & 3 deletions colossalai/shardformer/layer/utils.py
@@ -295,9 +295,9 @@ def zigzag_split_batch(
batch: Union[torch.Tensor, List[torch.Tensor]], sp_group: ProcessGroup, seq_dim=1, varlen: bool = False
):
"""
Split the input along the sequence dimension for Ring Attention. As naively spliting sequence
in the causual setting will result in the first ranks having much less workload than the last ranks,
we split after "folding" the 2D attention mask in half (https://github.com/zhuzilin/ring-flash-attention/issues/2).
Split the input along the sequence dimension for Ring Attention. Naively splitting the attention mask
in the causal setting will result in the preceding ranks having much less workload.
We split after "folding" the 2D attention mask in half (https://github.com/zhuzilin/ring-flash-attention/issues/2).
For example, for sp_size = 4 and seq_len = 8, we get | s0, s7 | s1, s6 | s2, s5 | s3, s4 |.
Args:
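The example in the docstring can be reproduced with a few lines of torch; the sketch below operates on a toy 1-D index tensor, whereas the real function splits full batches along seq_dim:

import torch

seq_len, sp_size = 8, 4
seq = torch.arange(seq_len)                          # stand-in for tokens s0..s7
chunks = seq.chunk(2 * sp_size)                      # fold into 2 * sp_size chunks: s0 | s1 | ... | s7
for rank in range(sp_size):
    local = torch.cat([chunks[rank], chunks[2 * sp_size - 1 - rank]])
    print(rank, local.tolist())
# prints: 0 [0, 7]   1 [1, 6]   2 [2, 5]   3 [3, 4]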
2 changes: 1 addition & 1 deletion examples/language/opt/README.md
@@ -17,7 +17,7 @@ limitations under the License.
## OPT
Meta recently released [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), a 175-Billion parameter AI language model, which stimulates AI programmers to perform various downstream tasks and application deployments.

The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Casual Language Modelling at low cost.
The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Causal Language Modelling at low cost.


## Our Modifications
2 changes: 1 addition & 1 deletion examples/tutorial/opt/opt/README.md
@@ -19,7 +19,7 @@ limitations under the License.
## OPT
Meta recently released [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), a 175-Billion parameter AI language model, which stimulates AI programmers to perform various downstream tasks and application deployments.

The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Casual Language Modelling at low cost.
The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Causal Language Modelling at low cost.

We are using the pre-training weights of the OPT model provided by Hugging Face Hub on the raw WikiText-2 (no tokens were replaced before
the tokenization). This training script is adapted from the [HuggingFace Language Modelling examples](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling).
87 changes: 0 additions & 87 deletions ring-flash-attention/benchmark/benchmark_qkvpacked_func.py

This file was deleted.

91 changes: 0 additions & 91 deletions ring-flash-attention/benchmark/benchmark_varlen_qkvpacked_func.py

This file was deleted.
