Update flash_attention_patch.py
To be compatible with a recent change in the Transformers library, where a new argument `padding_mask` was added to the forward method of the attention layer.
huggingface/transformers#25598
Orion-Zheng authored and flybird11111 committed Nov 9, 2023
1 parent dd59ca2 commit 52707c6
Showing 1 changed file with 1 addition and 0 deletions.
@@ -65,6 +65,7 @@ def attention_forward(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
+    **kwargs
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     """
     Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention.
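
For context, a minimal sketch of why accepting `**kwargs` keeps a monkey-patched forward compatible once the upstream caller starts passing extra keyword arguments such as `padding_mask`. The `ToyAttention` class and `patched_forward` below are simplified stand-ins for illustration, not the actual `flash_attention_patch.py` code.

from typing import Optional

import torch
import torch.nn as nn


class ToyAttention(nn.Module):
    """Simplified stand-in for an attention layer such as `LlamaAttention`."""

    def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
        return hidden_states


def patched_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    **kwargs,  # absorbs newly added arguments (e.g. `padding_mask`) so the patch does not break
) -> torch.Tensor:
    # A real patch would compute flash-attention here; identity is enough for this sketch.
    return hidden_states


# Monkey-patch the forward method, analogous to replacing `LlamaAttention.forward` with the patched version.
ToyAttention.forward = patched_forward

attn = ToyAttention()
x = torch.randn(1, 4, 8)
# Newer Transformers versions pass `padding_mask`; without `**kwargs` this call would raise a TypeError.
out = attn(x, padding_mask=torch.ones(1, 4, dtype=torch.bool))
print(out.shape)  # torch.Size([1, 4, 8])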
