
add iter time tracking via cuda events, add data loading times, add columnar display to show both, show avg iter & data loading times at end of training #87

Merged · 6 commits · Feb 26, 2024

Conversation

lessw2020 (Contributor) commented on Feb 25, 2024

This PR adds basic perf timing and display for per-iter and final-average iteration times (in part based on Andrew's comment about having to open the trace to compare iter timing).

  1. The tracking list is housed in TrainState, but I do not save it as part of the state dict, as I view this as useful but not saveable info.

  2. Iter times are tracked after data loading is done each iter and after the optimizer step. The idea is to make this timing cover only the model training iter (not data loading or post-iter metrics calcs).

  3. A 'time' column is now displayed at each iter along with the usual loss and lr.

  4. At the end of training, assuming more than 3 iters were run, the average iter time is calculated by ignoring the first three iters (consider these warmup, especially as the CUDA caching allocator gets warmed up) and displayed.

  5. Based on @tianyu-l's feedback, I have added data loading times as well.
    I used the same timeit.default_timer() from timeit to be consistent
    (CPU side, so no syncs needed :)

  6. After fiddling with printf width formatting options, I added a beautiful aligned columnar display for the per-iter updates (a rough sketch of the overall flow follows below the screenshots):
Now:
Screenshot 2024-02-26 at 9 39 25 AM

Before:
Screenshot 2024-02-26 at 8 39 46 AM
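
A minimal sketch of how the pieces described above could fit together, assuming hypothetical names (the TrainState fields, logger, loop structure, and column widths here are illustrative, not the exact code in this PR):

```python
from dataclasses import dataclass, field
from timeit import default_timer as timer


@dataclass
class TrainState:
    step: int = 0
    # timing lists live on TrainState but are intentionally not part of the state dict
    iter_times: list = field(default_factory=list)
    data_load_times: list = field(default_factory=list)


def train_loop(train_state, data_iter, model, optimizer, num_steps, logger):
    for step in range(num_steps):
        # time data loading on the CPU side (no device sync needed here)
        data_load_start = timer()
        batch = next(data_iter)
        train_state.data_load_times.append(timer() - data_load_start)

        # time only the training work for this iter (not data loading or metrics)
        # NOTE: as the review below points out, a device sync or CUDA events are
        # needed for this to reflect actual GPU time
        iter_start = timer()
        loss = model(batch).mean()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        train_state.iter_times.append(timer() - iter_start)

        # fixed-width fields give the aligned columnar per-iter display
        logger.info(
            f"step: {step:>5}  loss: {loss.item():>8.4f}  "
            f"iter time: {train_state.iter_times[-1]:>7.4f}s  "
            f"data load: {train_state.data_load_times[-1]:>7.4f}s"
        )

    # report averages at the end, skipping the first three iters as warmup
    if len(train_state.iter_times) > 3:
        avg_iter = sum(train_state.iter_times[3:]) / (len(train_state.iter_times) - 3)
        avg_data = sum(train_state.data_load_times[3:]) / (len(train_state.data_load_times) - 3)
        logger.info(f"avg iter time: {avg_iter:.4f}s | avg data load time: {avg_data:.4f}s")
```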

facebook-github-bot added the CLA Signed label on Feb 25, 2024
tianyu-l (Contributor) left a comment

This would be useful! Some comments:

  1. I feel it could be useful to have another version where data loading time is included, so that not only do we know the marginal time spent on each iter, but by the difference we also know how much time is used/wasted on data loading. It could be especially valuable to our future exploration of more scalable data loading solutions.
  2. According to https://docs.python.org/3/library/timeit.html, time.perf_counter() and timeit.default_timer() are doing the same thing under the hood. Let's consolidate them into one; either is fine.
  3. Optionally we can log this to TensorBoard as well, although I assume it shouldn't fluctuate too much other than the first couple of iters.
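
For what it's worth, point 2 can be checked directly; since Python 3.3, timeit.default_timer is simply an alias for time.perf_counter:

```python
import time
import timeit

# timeit.default_timer and time.perf_counter are the same function object,
# so consolidating on either one reads the same monotonic clock
assert timeit.default_timer is time.perf_counter
```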

train.py Outdated
@@ -207,6 +211,11 @@ def main(job_config: JobConfig):
# updates the scale for next iteration
scaler.update()

# training iteration complete
iter_end_time = perf_counter()
yifuwang (Contributor) commented on Feb 26, 2024

Is it guaranteed that a device synchronization has already taken place at this point? I might be missing something, but I don't see anything that guarantees a synchronization between .backward() and iter_end_time = perf_counter(). Any chance we should be measuring with cuda events here instead?
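
For reference, a minimal sketch of the CUDA-event timing being suggested here (names are illustrative):

```python
import torch

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

start_event.record()
# ... forward, backward, and optimizer step enqueued on the current stream ...
end_event.record()

# elapsed_time() requires both events to have completed on the device
torch.cuda.synchronize()
iter_time_s = start_event.elapsed_time(end_event) / 1000.0  # elapsed_time() returns milliseconds
```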

lessw2020 (Contributor, author) replied:

loss.item() is a CPU-GPU sync point, so I was thinking the loss calc above would sync. But that may be incorrect, so agreed: I'll update to use CUDA events to guarantee the timing. Thanks for flagging this!

Another contributor commented:

I think adding a torch.cuda.synchronize() before iter_time_end = perf_counter() should be good. I agree loss.item() is a sync point, but I think it is only called a few lines after :/

https://github.com/pytorch/torchtrain/blob/eafcee6b5d7156ec2db833c693987927b3698075/train.py#L214-L223
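
A minimal sketch of that suggestion, using the variable name from the diff above:

```python
import torch
from time import perf_counter

# block the host until all queued GPU work for this iter has finished,
# so the host-side timestamp actually reflects the end of the iteration
torch.cuda.synchronize()
iter_end_time = perf_counter()
```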

lessw2020 (Contributor, author) replied:

Thanks @yifuwang, @awgu and @tianyu-l for all the feedback here!
To update: I went ahead and moved to CUDA events in order to ideally get max precision, and have added the cuda synchronize.
I also tested moving the loss.item() higher to use that as a sync point, but it seemed cleaner to just stick with pure CUDA events and the sync.
Anyway, all tested and looks good.
Of interest, the net times were currently the same with and without the sync for eager, but that would go out the window as soon as we start using torch.compile, so we definitely do want to keep the sync here.

lessw2020 changed the title from "add iter time tracking and display, avg iter time at end of training" to "add iter time tracking and display via cuda events, add data loading times, show avg iter and data loading times at end of training" on Feb 26, 2024
lessw2020 changed the title from "add iter time tracking and display via cuda events, add data loading times, show avg iter and data loading times at end of training" to "add iter time tracking via cuda events, add data loading times, add columnar display to show both, show avg iter & data loading times at end of training" on Feb 26, 2024
tianyu-l (Contributor) left a comment

LGTM, thanks for adding these time metrics!

lessw2020 merged commit ae85e97 into pytorch:main on Feb 26, 2024
4 checks passed
lessw2020 deleted the add_perf_iter_timing branch on February 26, 2024 18:16
lessw2020 added a commit that referenced this pull request Apr 18, 2024
philippguevorguian pushed a commit to YerevaNN/YNNtitan that referenced this pull request Aug 17, 2024