[PP] Fix PP meta init #582
base: gh/wconstab/39/base
Conversation
Uses meta device for tensors/model used before pipeline splitting.

*Important:* Relies on pytorch/pytorch#136243 to make PipelineStage avoid materializing the model and the input/output buffers eagerly. Relies on the existing .to(device) in train.py to finally materialize the model.

ghstack-source-id: 66fa9f1f78dff0b1af753dc4b2afcc09d897751d
Pull Request resolved: #582
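The flow described above can be sketched as follows. This is a minimal toy illustration, not the actual torchtitan code: the model here is hypothetical, and the real materialization happens via the existing `.to(device)` call in train.py.

```python
import torch
import torch.nn as nn

# Build the model on the meta device: shapes and module structure exist,
# but no parameter storage is allocated yet.
with torch.device("meta"):
    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

# Every parameter is a meta tensor at this point.
assert all(p.is_meta for p in model.parameters())

# After pipeline splitting, materialize on the real device
# ("cpu" here for illustration; train.py uses the actual device).
model = model.to_empty(device="cpu")
assert all(p.device.type == "cpu" for p in model.parameters())
```

Note that `to_empty()` allocates uninitialized storage, so weight initialization still has to run afterwards.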
```python
model.to_empty(device=device)
stage = PipelineStage(
```
Just to understand: the `device` arg for `PipelineStage` still needs to be the actual device, e.g. `cuda`, correct?
Correct. And I want to remove that too in PipelineStage but I didn't do it yet.
LGTM.

Curious -- why would train.py make a `.to` call? Should `init_weight` create tensors on the right device directly? Or, if we are loading from DCP, would DCP return a state dict with DTensors on the target device, or just a state dict with DTensors on CPU?
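One reason the `.to` call (or rather `to_empty()`) cannot stand alone: it allocates storage without initializing it, so weight init must run on the materialized module afterwards. A small sketch (toy `nn.Linear`, not the torchtitan model; whether `init_weight` should instead construct tensors on-device directly is exactly the open question above):

```python
import torch
import torch.nn as nn

# Construct on meta: no storage allocated.
with torch.device("meta"):
    layer = nn.Linear(8, 8)

# Materialize: storage is allocated but holds uninitialized memory.
layer = layer.to_empty(device="cpu")

# Initialization has to happen on the real device after materialization.
layer.reset_parameters()

# Now the weights hold valid (finite) values.
assert torch.isfinite(layer.weight).all()
```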