Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix
  • Loading branch information
flybird11111 committed Sep 13, 2024
1 parent 37d9623 commit 5965f8b
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions colossalai/pipeline/schedule/zero_bubble_pp.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from contextlib import nullcontext
from functools import partial
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

Expand All @@ -11,9 +12,6 @@
from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.schedule.v_schedule import ScheduledNode
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.zero.low_level import LowLevelZeroOptimizer
from contextlib import nullcontext

from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
from .base import PipelineSchedule

Expand Down Expand Up @@ -487,7 +485,13 @@ def backward_b_step(
assert output_obj_grad is None

input_obj_ = input_obj["hidden_states"]
ctx = optimizer.no_sync() if isinstance(optimizer, LowLevelZeroOptimizer) else nullcontext()

# Attempt to disable gradient synchronization when using the LowLevelZeroPlugin.
try:
ctx = optimizer.no_sync()
except Exception as e:
ctx = nullcontext()

with ctx:
if output_obj_grad is None:
optimizer.backward(output_obj, inputs=input_obj_, retain_graph=True)
Expand Down

0 comments on commit 5965f8b

Please sign in to comment.