From 9c59e6cae311c7f6cca33d2ddcb6c1b7ac0390d2 Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Fri, 13 Sep 2024 02:51:28 +0000
Subject: [PATCH] fix

---
 .../pipeline/schedule/zero_bubble_pp.py      | 24 +++++++++++--------
 .../test_model/test_shard_llama.py           |  2 +-
 2 files changed, 15 insertions(+), 11 deletions(-)

diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py
index 759a6144c169..5b4092bcefe3 100644
--- a/colossalai/pipeline/schedule/zero_bubble_pp.py
+++ b/colossalai/pipeline/schedule/zero_bubble_pp.py
@@ -11,6 +11,8 @@
 from colossalai.pipeline.p2p import PipelineP2PCommunication
 from colossalai.pipeline.schedule.v_schedule import ScheduledNode
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.zero.low_level import LowLevelZeroOptimizer
+from contextlib import nullcontext
 
 from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
 from .base import PipelineSchedule
@@ -485,16 +487,18 @@ def backward_b_step(
             assert output_obj_grad is None
             input_obj_ = input_obj["hidden_states"]
 
-        if output_obj_grad is None:
-            optimizer.backward(output_obj, inputs=input_obj_, retain_graph=True)
-        else:
-            output_obj_ = output_obj["hidden_states"]
-            optimizer.backward_by_grad(
-                tensor=output_obj_,
-                grad=output_obj_grad,
-                inputs=input_obj_,
-                retain_graph=True,
-            )
+        ctx = optimizer.no_sync() if isinstance(optimizer, LowLevelZeroOptimizer) else nullcontext()
+        with ctx:
+            if output_obj_grad is None:
+                optimizer.backward(output_obj, inputs=input_obj_, retain_graph=True)
+            else:
+                output_obj_ = output_obj["hidden_states"]
+                optimizer.backward_by_grad(
+                    tensor=output_obj_,
+                    grad=output_obj_grad,
+                    inputs=input_obj_,
+                    retain_graph=True,
+                )
         return input_obj_.grad
 
     def backward_w_step(
diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py
index d925687cd875..9f67ecbea687 100644
--- a/tests/test_shardformer/test_model/test_shard_llama.py
+++ b/tests/test_shardformer/test_model/test_shard_llama.py
@@ -299,7 +299,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "num_microbatches": 4,
             "enable_all_optimization": False,
             "precision": "fp16",
-            "zero_stage": 0,
+            "zero_stage": 1,
             "initial_scale": 1,
             "enable_gradient_checkpointing": True,
             "parallel_output": False,
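
The substance of the change is the conditional no_sync() wrapper around the dx-only ("B") backward in backward_b_step: when the optimizer is a LowLevelZeroOptimizer (ColossalAI's ZeRO-1/2 optimizer, and the test now exercises zero_stage 1), gradient synchronization is suppressed, presumably so ZeRO's reduction hooks do not fire while only the input gradient is being computed; any other optimizer falls back to nullcontext(). Below is a minimal, self-contained sketch of that pattern, not the ColossalAI implementation: the helper name backward_b_like, the toy linear layer, and the duck-typed hasattr check on no_sync are illustrative assumptions standing in for the real scheduler and OptimizerWrapper API.

from contextlib import nullcontext

import torch
import torch.nn as nn


def backward_b_like(optimizer, output, inp, output_grad=None):
    # Hypothetical helper illustrating the patched control flow: the real code
    # checks isinstance(optimizer, LowLevelZeroOptimizer); here we duck-type on
    # a no_sync attribute so the sketch runs without ColossalAI installed.
    ctx = optimizer.no_sync() if hasattr(optimizer, "no_sync") else nullcontext()
    with ctx:
        if output_grad is None:
            # loss-like scalar output: backward restricted to the given input
            torch.autograd.backward(output, inputs=[inp], retain_graph=True)
        else:
            # intermediate activation: backward driven by an upstream gradient
            torch.autograd.backward(
                output, grad_tensors=output_grad, inputs=[inp], retain_graph=True
            )
    return inp.grad


if __name__ == "__main__":
    layer = nn.Linear(4, 4)
    x = torch.randn(2, 4, requires_grad=True)
    loss = layer(x).sum()
    # torch.optim.SGD has no no_sync(), so this call takes the nullcontext path
    dx = backward_b_like(torch.optim.SGD(layer.parameters(), lr=0.1), loss, x)
    print(dx.shape)  # torch.Size([2, 4])

With a ColossalAI LowLevelZeroOptimizer passed in, the hasattr check would take the no_sync() branch, mirroring the behavior the patch introduces.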