Skip to content

Commit

Permalink
Update zero_bubble_pp.py
Browse files Browse the repository at this point in the history
  • Loading branch information
flybird11111 authored Sep 12, 2024
1 parent 24727de commit 9802a7d
Showing 1 changed file with 1 addition and 5 deletions.
6 changes: 1 addition & 5 deletions colossalai/pipeline/schedule/zero_bubble_pp.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,11 +484,7 @@ def backward_b_step(
# loss backward; output_obj is loss; so output_obj_grad should be None
assert output_obj_grad is None

if "hidden_states" in input_obj.keys():
input_obj_ = input_obj["hidden_states"]
else:
input_obj_ = input_obj["input_ids"]

input_obj_ = input_obj["hidden_states"]
if output_obj_grad is None:
optimizer.backward(output_obj, inputs=input_obj_, retain_graph=True)
else:
Expand Down

0 comments on commit 9802a7d

Please sign in to comment.