Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[shardformer]fix flash attention, when mask is casual, just don't unpad it #5084

Merged
merged 2 commits into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion colossalai/shardformer/modeling/chatglm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ def forward(self: CoreAttention, query_layer, key_layer, value_layer, attention_
attn_mask_type = AttnMaskType.causal
else:
flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
attn_mask_type = AttnMaskType.paddedcausal
if not torch.all(flash_attention_mask):
attn_mask_type = AttnMaskType.paddedcausal

attention = ColoAttention(
embed_dim=self.hidden_size_per_partition,
Expand Down
9 changes: 5 additions & 4 deletions colossalai/shardformer/modeling/gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,11 +771,12 @@ def forward(
attn_mask_type = AttnMaskType.causal
flash_attention_mask = None
if attention_mask != None:
if attn_mask_type == AttnMaskType.causal:
attn_mask_type == AttnMaskType.paddedcausal
else:
attn_mask_type = AttnMaskType.padding
flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
if not torch.all(flash_attention_mask):
if attn_mask_type == AttnMaskType.causal:
attn_mask_type == AttnMaskType.paddedcausal
else:
attn_mask_type = AttnMaskType.padding

scale = value.size(-1) ** -0.5
if self.scale_attn_by_inverse_layer_idx:
Expand Down
3 changes: 2 additions & 1 deletion colossalai/shardformer/modeling/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,8 @@ def forward(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
attn_mask_type = AttnMaskType.paddedcausal
if not torch.all(flash_attention_mask):
attn_mask_type = AttnMaskType.paddedcausal

attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
attn_output = attention(
Expand Down
3 changes: 2 additions & 1 deletion colossalai/shardformer/modeling/opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,8 @@ def forward(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
attn_mask_type = AttnMaskType.paddedcausal
if not torch.all(flash_attention_mask):
attn_mask_type = AttnMaskType.paddedcausal

attention = ColoAttention(
embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout, scale=self.scaling
Expand Down
5 changes: 4 additions & 1 deletion colossalai/shardformer/modeling/whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,10 @@ def forward(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool).contiguous())
attn_type = AttnMaskType.paddedcausal
if not torch.all(flash_attention_mask):
attn_type = AttnMaskType.paddedcausal
else:
attn_type = AttnMaskType.causal

attention = ColoAttention(
embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout, scale=self.scaling
Expand Down
1 change: 1 addition & 0 deletions examples/language/llama2/pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def tokenize_batch_for_pretrain(batch, tokenizer: Optional[LlamaTokenizer] = Non

def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
tensor = tensor.data
FrankLeeeee marked this conversation as resolved.
Show resolved Hide resolved
tensor.div_(dist.get_world_size())
return tensor

Expand Down
Loading