Skip to content

Commit

Permalink
fix
Browse files (browse the repository at this point in the history)
  • Loading branch information
flybird11111 committed Sep 13, 2024
1 parent 9802a7d commit 9c59e6c
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 11 deletions.
24 changes: 14 additions & 10 deletions colossalai/pipeline/schedule/zero_bubble_pp.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.schedule.v_schedule import ScheduledNode
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.zero.low_level import LowLevelZeroOptimizer
from contextlib import nullcontext

from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
from .base import PipelineSchedule
Expand Down Expand Up @@ -485,16 +487,18 @@ def backward_b_step(
assert output_obj_grad is None

input_obj_ = input_obj["hidden_states"]
if output_obj_grad is None:
optimizer.backward(output_obj, inputs=input_obj_, retain_graph=True)
else:
output_obj_ = output_obj["hidden_states"]
optimizer.backward_by_grad(
tensor=output_obj_,
grad=output_obj_grad,
inputs=input_obj_,
retain_graph=True,
)
ctx = optimizer.no_sync() if isinstance(optimizer, LowLevelZeroOptimizer) else nullcontext()
with ctx:
if output_obj_grad is None:
optimizer.backward(output_obj, inputs=input_obj_, retain_graph=True)
else:
output_obj_ = output_obj["hidden_states"]
optimizer.backward_by_grad(
tensor=output_obj_,
grad=output_obj_grad,
inputs=input_obj_,
retain_graph=True,
)
return input_obj_.grad

def backward_w_step(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_shardformer/test_model/test_shard_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"num_microbatches": 4,
"enable_all_optimization": False,
"precision": "fp16",
"zero_stage": 0,
"zero_stage": 1,
"initial_scale": 1,
"enable_gradient_checkpointing": True,
"parallel_output": False,
Expand Down

0 comments on commit 9c59e6c

Please sign in to comment.