
add pp validation for schedule #568

Open

H-Huang wants to merge 2 commits into base: gh/H-Huang/16/base

Changes from 1 commit

12 changes: 12 additions & 0 deletions torchtitan/parallelisms/pipelining_utils.py
@@ -41,6 +41,18 @@ def build_pipeline_schedule(job_config, stages, loss_fn):
     if n_microbatches is None:
         n_microbatches = job_config.experimental.pipeline_parallel_degree
 
+    # Validation that the stages are compatible with the schedule
+    if issubclass(schedule_class, PipelineScheduleSingle):

Contributor:
hm, shouldn't this validation actually go inside of PipelineSchedule* __init__ functions?

H-Huang (Member Author), Sep 5, 2024:
Hm, yeah, for PipelineScheduleMulti's __init__ we do check that len(stages) > 2. I guess the issue I ran into was due to this line below:

    stages if looped_schedule else stages[0],

I accidentally commented out the schedule config, so it defaulted to a single-stage schedule, but the stages were still cut as if it were a "looped schedule". Since only stages[0] is picked up in the line above, the schedule was created without error, but it later caused a hang at runtime.
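
A toy reconstruction of that failure mode (the stage values and variable names below are illustrative stand-ins, not the real torchtitan/PyTorch objects): when the model is cut into multiple stages per rank but the config falls back to a single-stage schedule, the stages[0] branch silently drops the extra stages, so construction succeeds and the mismatch only shows up later as a hang.

    # Hypothetical stand-ins for the pipeline stages built on this rank
    # (the real objects would be PipelineStage instances).
    stages = ["stage_0_on_this_rank", "stage_1_on_this_rank"]  # cut for a looped schedule

    # The accidentally-defaulted config picks a single-stage schedule,
    # so looped_schedule ends up False.
    looped_schedule = False

    # The line quoted above: only stages[0] survives; stages[1] is silently dropped.
    schedule_arg = stages if looped_schedule else stages[0]
    print(schedule_arg)  # "stage_0_on_this_rank" -- no error here, the hang comes later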

H-Huang (Member Author):
If we wanted to be more consistent, we could just support passing in a list of stages for both single and multi schedules; then all of the validation could be done in PyTorch.
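
Not part of the diff: a rough sketch of that suggestion, assuming (hypothetically) that both PipelineScheduleSingle and PipelineScheduleMulti accepted a list of stages, so the caller never special-cases stages[0] and the length checks live upstream in the schedules' __init__ methods (the loss_fn kwarg is assumed from build_pipeline_schedule's signature). The tail of build_pipeline_schedule would then be just:

    # Hypothetical unified call: always pass the full list of stages and let
    # the schedule's __init__ validate it; no `stages[0]` special case.
    return schedule_class(
        stages,
        n_microbatches=n_microbatches,
        loss_fn=loss_fn,
    )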

(diff continued)

+        if len(stages) != 1:
+            raise ValueError(
+                f"PipelineScheduleSingle requires exactly one stage, got {len(stages)}"
+            )
+    elif issubclass(schedule_class, PipelineScheduleMulti):
+        if len(stages) < 2:
+            raise ValueError(
+                f"PipelineScheduleMulti requires at least two stages, got {len(stages)}"
+            )
+
     return schedule_class(
         stages if looped_schedule else stages[0],
         n_microbatches=n_microbatches,
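
For illustration, a self-contained rerun of the added checks, with stub classes standing in for the real PyTorch schedule classes so it runs without torch (class names and stage strings here are hypothetical). The misconfiguration described in the thread now fails fast with a ValueError instead of hanging:

    # Stub schedule classes standing in for torch.distributed.pipelining's
    # PipelineScheduleSingle / PipelineScheduleMulti, only to exercise the checks.
    class PipelineScheduleSingle: ...
    class PipelineScheduleMulti: ...

    def validate(schedule_class, stages):
        # Same logic as the validation block added in this PR.
        if issubclass(schedule_class, PipelineScheduleSingle):
            if len(stages) != 1:
                raise ValueError(
                    f"PipelineScheduleSingle requires exactly one stage, got {len(stages)}"
                )
        elif issubclass(schedule_class, PipelineScheduleMulti):
            if len(stages) < 2:
                raise ValueError(
                    f"PipelineScheduleMulti requires at least two stages, got {len(stages)}"
                )

    validate(PipelineScheduleSingle, ["stage0"])           # passes
    validate(PipelineScheduleMulti, ["stage0", "stage1"])  # passes
    try:
        # The scenario from the thread: stages cut for a looped schedule,
        # but a single-stage schedule selected by the config.
        validate(PipelineScheduleSingle, ["stage0", "stage1"])
    except ValueError as e:
        print(e)  # PipelineScheduleSingle requires exactly one stage, got 2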