Please check that this issue hasn't been reported before.
I searched previous Bug Reports and didn't find any similar reports.
Expected Behavior
Flash attention should not cause the training loss to differ significantly.
Current behaviour
I ran preliminary experiments on Gemma 2b with different datasets. When flash attention is on, the training loss is significantly lower than when flash attention is off.
Please see the figure below. Wandb run names ending in -flash indicate that flash attention is enabled.
The validation losses, however, look normal.
Steps to reproduce
Simply enable and disable flash attention in the configuration.
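For clarity, here is a minimal sketch of the toggle, assuming a standard axolotl config; the base model, dataset, and training values below are illustrative placeholders, not the actual config used (none was attached).

```yaml
# Illustrative sketch only -- not the reporter's actual config.
# Only the flash_attention flag is changed between the two runs.
base_model: google/gemma-2b

datasets:
  - path: my_dataset          # placeholder dataset
    type: alpaca

sequence_len: 2048
micro_batch_size: 2
num_epochs: 1

# Run once with this set to true (the "-flash" runs) and once with false,
# keeping everything else identical, then compare the training-loss curves in wandb.
flash_attention: true
```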
Config yaml
No response
Possible solution
Could there be something wrong with the loss calculation when flash_attention is off? Normally the training loss should be only slightly lower than the validation loss.
Which Operating Systems are you using?
Linux
macOS
Windows
Python Version
3.10
axolotl branch-commit
main/4d6490b
Acknowledgements
My issue title is concise, descriptive, and in title casing.
I have searched the existing issues to make sure this bug has not been reported yet.
I am using the latest version of axolotl.
I have provided enough information for the maintainers to reproduce and diagnose the issue.