[PP] Fix PP meta init #582

Open · wants to merge 1 commit into base: gh/wconstab/39/base
17 changes: 9 additions & 8 deletions torchtitan/parallelisms/pipeline_llama.py
@@ -104,9 +104,11 @@ def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=Fal
model.norm = None
model.output = None

# TODO(whc) once ManualPipelineStage supports lazy shape inference, we can leave model on meta device longer and
# get rid of the input shape hardcoded here. For now, it should not be a big deal since we only materialize the
# layers of the model that map to this stage, not the whole model.
# Note: these tensors are only here as metadata hints, so the pipelining runtime knows what size buffer to allocate.
# These tensors should be on meta device, and the model should be too. It will be allocated on device after
# applying all other parallelisms.

# TODO(whc) once ManualPipelineStage supports lazy shape inference, we can avoid specifying input/output shapes
mp_dtype = _mixed_precision_dtype(job_config, parallel_dims)
batch_size = job_config.training.batch_size
local_seq_len = int(job_config.training.seq_len // parallel_dims.tp)
@@ -117,18 +119,17 @@ def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=Fal
model_config.vocab_size,
)
if is_first:
(input,) = _llama_trace_input(job_config, model_config, device=device)
(input,) = _llama_trace_input(job_config, model_config, device="meta")
else:
# later layers (assume all start w/ a transformer layer)
input = torch.rand(layers_io_shape, dtype=mp_dtype, device=device)
input = torch.rand(layers_io_shape, dtype=mp_dtype, device="meta")

if is_last:
output = torch.rand(output_layer_shape, dtype=torch.float32, device=device)
output = torch.rand(output_layer_shape, dtype=torch.float32, device="meta")
else:
# earlier layers (assume all end in a transformer layer)
output = torch.rand(layers_io_shape, dtype=mp_dtype, device=device)
output = torch.rand(layers_io_shape, dtype=mp_dtype, device="meta")

model.to_empty(device=device)
stage = PipelineStage(
Contributor:

just to understand: the device arg for PipelineStage still needs to be the actual device, e.g. cuda, correct?

Contributor Author:

Correct. And I want to remove that too in PipelineStage but I didn't do it yet.

model,
stage_idx,
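For context, here is a minimal sketch (not from the PR) of the meta-device init pattern the diff and review thread describe, using only core torch APIs. The module, sizes, and variable names below are hypothetical stand-ins for the stage submodule and job_config values; as confirmed in the review thread, the real code still passes the actual device to PipelineStage while the shape-hint tensors stay on meta.

import torch
import torch.nn as nn

# Hypothetical sizes standing in for job_config.training values.
batch_size, seq_len, dim = 8, 2048, 4096
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the stage's submodule on the meta device: parameters carry only
# shape/dtype metadata, so no real memory is allocated yet.
with torch.device("meta"):
    stage_module = nn.Linear(dim, dim, bias=False)

# Shape-hint tensors also live on meta; the pipelining runtime only needs
# their sizes and dtypes to know what buffers to allocate later.
example_input = torch.rand(batch_size, seq_len, dim, device="meta")
example_output = torch.rand(batch_size, seq_len, dim, device="meta")

# After all other parallelisms are applied, materialize storage on the real
# device; to_empty() allocates uninitialized memory, so parameter init must
# still run afterwards.
stage_module.to_empty(device=device)

# The stage wrapper (PipelineStage in the PR) still receives the real device,
# while the example input/output tensors remain on meta.
print(example_input.device, stage_module.weight.device)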