
Sum loss instead of mean loss should be used if gradient accumulation step is larger than 1 when training a language model #24725

Closed
Atry opened this issue Jul 10, 2023 · 16 comments

Comments

Atry commented Jul 10, 2023

System Info

Not applicable, because this is a design issue, not a runtime error.

Who can help?

@sgugger, @ArthurZucker and @younesbelkada

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Given gradient accumulation steps of 2, a batch size of 1, and a training set of 2 samples, where sample 1 contains 11 tokens and sample 2 contains 101 tokens, train a decoder-only model with unsupervised (causal) language modeling (the first token in each sample is untrainable). The resulting gradient will differ from training the same model on the same dataset with gradient accumulation steps of 1 and a batch size of 2.

The reason is that transformers currently uses the mean loss for most models (if not all). Because each batch's loss is divided by its own number of trainable tokens (10 for sample 1, 100 for sample 2), each token in sample 1 produces a gradient 10 times larger than that of each token in sample 2.

Expected behavior

Gradient accumulation steps of 2 with batch size 1 should produce the same gradient as gradient accumulation steps of 1 with batch size 2.
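
A minimal PyTorch sketch of the discrepancy (not part of the original report; the linear layer stands in for a language model and the data are random):

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 8)          # toy stand-in for an LM head: 4-dim features -> 8-token vocab
loss_fn = torch.nn.CrossEntropyLoss()  # default reduction="mean" over tokens

# Sample 1 has 10 trainable tokens, sample 2 has 100 trainable tokens.
x1, y1 = torch.randn(10, 4), torch.randint(0, 8, (10,))
x2, y2 = torch.randn(100, 4), torch.randint(0, 8, (100,))

def grad_of(loss):
    model.zero_grad()
    loss.backward()
    return model.weight.grad.clone()

# Accumulation steps = 2, batch size = 1: average of the two per-sample mean losses.
g_accum = grad_of((loss_fn(model(x1), y1) + loss_fn(model(x2), y2)) / 2)

# Accumulation steps = 1, batch size = 2 (no padding needed here): one mean over all 110 tokens.
g_joint = grad_of(loss_fn(model(torch.cat([x1, x2])), torch.cat([y1, y2])))

print(torch.allclose(g_accum, g_joint))  # False
```

The two gradients differ because the joint mean weights every token by 1/110, while the accumulated mean weights sample-1 tokens by 1/(2·10) and sample-2 tokens by 1/(2·100).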

Atry changed the title from "Sum loss instead of mean loss should be used when gradient accumulation step is larger than 1 when training a language model" to "Sum loss instead of mean loss should be used if gradient accumulation step is larger than 1 when training a language model" on Jul 10, 2023
Atry closed this as completed on Jul 10, 2023
Atry reopened this on Jul 10, 2023
ydshieh (Collaborator) commented Jul 10, 2023

Hi @Atry

Your description is correct. However, the loss logic is implemented in each model class, so it cannot see multiple batches in a single forward pass (and that is probably the main reason we simply use the mean).

The best and easiest way to get a correct computation is to modify the trainer class: given the mean loss from the model output, compute the sum of losses in the batch back (using the sequence length, or rather the number of meaningful tokens, i.e. excluding padding tokens etc.), send this new custom loss value to compute the gradients, and then accumulate them.
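
A rough sketch of that suggestion (my code, not the actual Trainer; it assumes the usual causal-LM convention that a label of -100 marks ignored positions):

```python
import torch

def accumulation_step(model, optimizer, micro_batches):
    """micro_batches: list of dicts with input_ids, attention_mask and labels (-100 = ignored)."""
    # Total number of trainable tokens across the whole accumulation window.
    # (For HF causal LMs the internal label shift drops one position per sequence,
    # so this count is an approximation; adjust it if you need an exact match.)
    total_tokens = sum(int((b["labels"] != -100).sum()) for b in micro_batches)
    optimizer.zero_grad()
    for batch in micro_batches:
        out = model(**batch)                       # out.loss is a per-token mean
        n_tokens = int((batch["labels"] != -100).sum())
        sum_loss = out.loss * n_tokens             # recover the sum over this batch's tokens
        (sum_loss / total_tokens).backward()       # normalize by tokens, not by the step count
    optimizer.step()
```

Dividing by `total_tokens` instead of by the number of accumulation steps is what makes the result match a single large batch.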

Atry (Author) commented Jul 10, 2023

Computing the sum back from the mean would damage precision if the loss is in fp16.

Atry (Author) commented Jul 10, 2023

An idea is to switch all models to sum loss and create a custom GradientScaler to count the number of trainable tokens.
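
A sketch of what that might look like (an assumption on my part; "GradientScaler" here is just an illustrative helper, unrelated to torch.cuda.amp.GradScaler):

```python
import torch

def scale_grads_and_step(model, optimizer, total_trainable_tokens):
    # With models returning a *sum* loss, the accumulated gradients are sums over tokens;
    # dividing by the token count turns them back into a per-token mean before stepping.
    for p in model.parameters():
        if p.grad is not None:
            p.grad /= total_trainable_tokens
    optimizer.step()
    optimizer.zero_grad()
```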

Atry (Author) commented Jul 10, 2023

By the way, there is another example of the issue with mean loss. Suppose you have a batch size of 33, 1 epoch, and a dataset of 100 samples. The last iteration will then contain only 1 sample, and the gradient produced by that sample is 33 times larger than each other sample's, because a full batch divides each sample's loss by 33 while the final batch of one divides by 1.

ydshieh (Collaborator) commented Jul 10, 2023

switch all models to sum loss

This would be a big breaking change, and would not be an option.

Computing the sum back from the mean would damage precision if the loss is in fp16

I don't think it would make a big difference if, in the end, we still use some form of mean after we accumulate (sum) all the gradients (say, dividing by the total number of non-padding tokens appearing in all the batches of a gradient accumulation).

When the loss is computed as a sum within a batch, it actually takes specific work to get back to the usual definition of that loss (say, the average loss per non-padding token) when we sum over all batches.

(Here I only mention non-padding tokens, but the loss definition can get very complex depending on the task and the specific model.)

Atry (Author) commented Jul 10, 2023

As studied in https://arxiv.org/abs/1711.00489, changing the batch size has the side effect of also changing the learning rate per sample (and per token) even when the learning rate per iteration is unchanged. However, their analysis of their experimental results is nonsense. The actual explanation is that the side effect is simply due to the mean loss; a sum loss would not produce it.

sgugger (Collaborator) commented Jul 11, 2023

If you are not happy with the loss computation inside the model, you can just not pass the labels to the model and compute it yourself outside of the forward pass. Note that all of our examples account for gradient accumulation by dividing the final loss by the number of gradient accumulation steps.
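
A sketch of that suggestion (my code, not from the examples): call the model without labels and compute a token-level sum loss outside the forward pass.

```python
import torch
import torch.nn.functional as F

def external_sum_loss(model, input_ids, attention_mask, labels):
    logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
    # Causal-LM shift: position i predicts token i + 1.
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    loss = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,       # padding/untrainable positions
        reduction="sum",         # sum over tokens instead of mean
    )
    n_tokens = int((shift_labels != -100).sum())
    return loss, n_tokens
```

The caller can then divide the accumulated sum by the total `n_tokens` across all micro-batches before the optimizer step.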

As @ydshieh mentioned, a breaking change across all models of this magnitude is not possible.

Atry (Author) commented Jul 12, 2023

Good idea! I wonder if the Trainer can fix this loss issue by not passing labels, too.

sgugger (Collaborator) commented Jul 12, 2023

The Trainer already divides the loss by the number of gradient accumulation steps, and there are tests in the CI to ensure that training with batch size X and training with batch size X / g and g gradient accumulation steps yield the same results.

Atry (Author) commented Jul 12, 2023

Suppose you have a dataset of two samples used for unsupervised training of a decoder-only language model: sample 1 contains 11 tokens and sample 2 contains 101 tokens. When training at batch size 1 without padding, the mean loss of sample 1 is 0.1 and the mean loss of sample 2 is 0.9. Mathematically, what would you expect the loss to be when the batch size is 2?

In the current transformers implementation:

  • when gradient accumulation steps = 1 and batch size = 2 (padding to sequence length 101), the loss is (0.1*10 + 0.9*100)/(10 + 100) = 0.82727
  • when gradient accumulation steps = 2 and batch size = 1 (no padding), the loss is (0.1 + 0.9)/2 = 0.5.

IMHO the loss should ideally be 0.82727.

ydshieh (Collaborator) commented Jul 12, 2023

when gradient accumulation step is 1 and batch size is 2, padding to sequence length 101, the loss would be (0.1*10+0.9*100)/(100*2)=0.455

where does 100*2 come from in the denominator?

ydshieh (Collaborator) commented Jul 12, 2023

I believe in transformers we do take care of the padding token.

If you find a HF causal LM model that has a loss computation (in the model forward) that doesn't take care of the padding token, please let us know. 🙏

Atry (Author) commented Jul 12, 2023

You are right. I misunderstood the implementation. I just updated my previous comments. Thank you!

ydshieh (Collaborator) commented Jul 12, 2023

Thanks!

As mentioned earlier:

  • you can either compute the sum back from the mean,
  • or, since you don't like the fp16 precision loss of that approach, you can choose not to pass the labels to the model forward and compute the actual sum yourself.

But:

  • (*) you need to modify the code a bit so that it divides not by the number of accumulation steps (2 here) but by the total number of non-padding tokens seen in all the batches of that gradient accumulation;
  • this necessary change (*) cannot be done in the model forward, no matter whether the forward pass returns a mean or a sum.

getao commented Jul 15, 2023

I ran into the same issue. The result with gradient accumulation is much worse than using a large (per-device) batch size.

The main reason, I assume, is that gradient accumulation macro-averages the loss scores, whereas they should be micro-averaged.

I think this problem is critical: it affects the results a lot for LMs (where sequence lengths vary across batches), and without a fix the training result will be suboptimal.
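
For reference, in my own notation (not getao's): write L_b for the mean loss of micro-batch b over its n_b trainable tokens. Then

```math
L_{\mathrm{macro}} = \frac{1}{B}\sum_{b=1}^{B} L_b
\qquad\text{vs.}\qquad
L_{\mathrm{micro}} = \frac{\sum_{b=1}^{B} n_b L_b}{\sum_{b=1}^{B} n_b}
```

The two coincide only when every micro-batch contains the same number of trainable tokens, which is exactly the condition that fails for variable-length language-model batches.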

github-actions (bot) commented:

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
