[fp8] hotfix backward hook (#6053)
* [fp8] hotfix backward hook

* [fp8] hotfix pipeline loss accumulation
ver217 authored Sep 11, 2024
1 parent c54c4fc commit 13946c4
Showing 6 changed files with 31 additions and 17 deletions.

colossalai/booster/plugin/hybrid_parallel_plugin.py: 12 additions & 10 deletions

@@ -216,7 +216,7 @@ def forward(self, *args, **kwargs):
         if self.convert_fn is not None:
             args = tree_map(self.convert_fn, args)
             kwargs = tree_map(self.convert_fn, kwargs)
-        with self._wait_all_gather():
+        with self._hook_context():
             return super().forward(*args, **kwargs)

     def unwrap(self):
@@ -229,12 +229,8 @@ def _force_wait_all_gather(self):
         for p in self.module.parameters():
             wait_all_gather_handle(p)

-    def _wait_all_gather(self):
-        return (
-            ColoParamOpHookManager.use_hooks(*self.op_hooks)
-            if (self.overlap_allgather or self.use_fp8)
-            else nullcontext()
-        )
+    def _hook_context(self):
+        return ColoParamOpHookManager.use_hooks(*self.op_hooks) if len(self.op_hooks) > 0 else nullcontext()


 def get_param_info(optim: Optimizer):
@@ -317,7 +313,8 @@ def backward(self, loss: Tensor, *args, **kwargs):
         """

         # Call the superclass backward method to compute gradients.
-        super().backward(loss, *args, **kwargs)
+        with self.model._hook_context():
+            super().backward(loss, *args, **kwargs)

         if self.model.require_grad_sync:
             # If gradient synchronization is required, sync sequence parallelism gradients.
@@ -540,7 +537,8 @@ def backward(self, loss: Tensor, *args, **kwargs):
             None
         """
         # Call the superclass backward method to compute gradients.
-        super().backward(loss, *args, **kwargs)
+        with self.model._hook_context():
+            super().backward(loss, *args, **kwargs)

         if self.model.require_grad_sync:
             # If gradient synchronization is required, sync sequence parallelism gradients.
@@ -683,6 +681,7 @@ def __init__(
         pp_process_group: Optional[ProcessGroup] = None,  # if using pp
         forced_dtype: Optional[torch.dtype] = None,
         overlap_allgather: bool = False,
+        fp8_communication: bool = False,
     ):
         self.model = model
         self.param_info = param_info
@@ -712,6 +711,8 @@ def __init__(
             dp_process_group=dp_process_group,
             forced_dtype=forced_dtype,
             overlap_allgather=overlap_allgather,
+            fp8_communication=fp8_communication,
+            backward_context=model._hook_context,
         )

     def sync_dp_grads(self):
@@ -1206,6 +1207,7 @@ def __init__(
             partition_grad=(self.zero_stage == 2),
             forced_dtype=PRECISION_TORCH_TYPE[precision],
             overlap_allgather=overlap_allgather,
+            fp8_communication=fp8_communication,
         )

         self.max_norm = max_norm
@@ -1371,7 +1373,7 @@ def execute_pipeline(
         # so we disable it, performing manual reduction instead.
         ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()

-        with ctx, model._wait_all_gather():
+        with ctx, model._hook_context():
             outputs = self.schedule.forward_backward_step(
                 model, data_iter, criterion, optimizer, return_loss, return_outputs
             )
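
In short, this file drops the old _wait_all_gather() helper, which only checked the overlap_allgather/use_fp8 flags, in favor of a generic _hook_context() driven purely by whatever op_hooks were registered, and then enters that context around the optimizer's backward() and the pipeline step in addition to forward(). A minimal sketch of the resulting pattern; the wrapper class and helper below are hypothetical, and the ColoParamOpHookManager import path is an assumption:

from contextlib import nullcontext

from colossalai.tensor.param_op_hook import ColoParamOpHookManager  # import path assumed


class HookedModelWrapper:
    """Hypothetical wrapper mirroring the pattern adopted in this file."""

    def __init__(self, module, op_hooks):
        self.module = module
        self.op_hooks = list(op_hooks)  # e.g. overlap-allgather and/or fp8 hooks

    def _hook_context(self):
        # Hooks are active only when at least one is registered; otherwise a no-op.
        return ColoParamOpHookManager.use_hooks(*self.op_hooks) if len(self.op_hooks) > 0 else nullcontext()

    def forward(self, *args, **kwargs):
        with self._hook_context():
            return self.module(*args, **kwargs)


def backward_under_hooks(model, loss):
    # The hotfix in a nutshell: backward must run under the same hook context as
    # forward, otherwise the fp8/all-gather param-op hooks never see the gradient pass.
    with model._hook_context():
        loss.backward()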

colossalai/booster/plugin/low_level_zero_plugin.py: 5 additions & 3 deletions

@@ -100,14 +100,16 @@ def forward(self, *args, **kwargs):
         if self.convert_fn is not None:
             args = tree_map(self.convert_fn, args)
             kwargs = tree_map(self.convert_fn, kwargs)
-        ctx = ColoParamOpHookManager.use_hooks(*self.op_hooks) if self.overlap_allgather else nullcontext()
-        with ctx:
+        with self._hook_context():
             return super().forward(*args, **kwargs)

     def _force_wait_all_gather(self):
         for p in self.module.parameters():
             wait_all_gather_handle(p)

+    def _hook_context(self):
+        return ColoParamOpHookManager.use_hooks(*self.op_hooks) if len(self.op_hooks) > 0 else nullcontext()
+

 class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
     def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
@@ -520,7 +522,7 @@ def configure(

         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(
-                optimizer, **zero_optim_kwargs, verbose=self.verbose
+                optimizer, **zero_optim_kwargs, verbose=self.verbose, backward_context=model._hook_context
             )
             # inject update_master_params
             model.update_master_params = MethodType(optimizer.update_master_params, model)
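
The low-level ZeRO plugin gains the same _hook_context() helper and now hands it to the optimizer as backward_context, so the optimizer can re-enter the model's hook context when it runs loss.backward(). Roughly, the wiring looks like the sketch below; the function, its arguments, and the import path are illustrative assumptions rather than plugin API:

from colossalai.zero import LowLevelZeroOptimizer  # import path assumed


def zero_training_step(model, base_optimizer, criterion, batch, zero_optim_kwargs):
    # `model` stands for the plugin's wrapped module, which exposes _hook_context().
    optimizer = LowLevelZeroOptimizer(
        base_optimizer, **zero_optim_kwargs, backward_context=model._hook_context
    )
    with model._hook_context():      # forward runs under the param-op hooks
        loss = criterion(model(**batch))
    optimizer.backward(loss)         # re-enters model._hook_context() around loss.backward()
    optimizer.step()
    return loss

Note that the plugin passes the bound method itself, not an already-entered context, so the hooks are scoped to each individual backward call.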

colossalai/initialize.py: 6 additions & 0 deletions

@@ -9,6 +9,7 @@
 # https://forums.developer.nvidia.com/t/how-many-streams-maximum-number-of-streams/6571/16
 os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"

+import torch
 import torch.distributed as dist

 from colossalai.accelerator import get_accelerator
@@ -64,6 +65,11 @@ def launch(

     set_seed(seed)

+    try:
+        torch._dynamo.config.optimize_ddp = world_size > 1
+    except AttributeError:
+        pass
+
     if verbose:
         logger = get_dist_logger()
         logger.info(f"Distributed environment is initialized, world size: {dist.get_world_size()}", ranks=[0])
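
The new lines in launch() flip torch._dynamo.config.optimize_ddp based on the world size, and the try/except AttributeError keeps the call harmless on PyTorch builds that do not expose torch._dynamo. For illustration only, the same guard written as a small helper (the function name is hypothetical):

import torch


def _maybe_tune_dynamo_for_ddp(world_size: int) -> None:
    # Same guard as the hunk above, pulled into a helper: enable dynamo's DDP
    # handling only when the job is actually distributed; torch builds without
    # _dynamo simply skip the setting.
    try:
        torch._dynamo.config.optimize_ddp = world_size > 1
    except AttributeError:
        pass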

colossalai/pipeline/schedule/interleaved_pp.py: 1 addition & 1 deletion

@@ -318,7 +318,7 @@ def forward_step(
         if self.stage_manager.is_last_stage():
             loss = criterion(output_obj, micro_batch) / self.num_microbatch
             if accum_loss is not None:
-                accum_loss.add_(loss.detach())
+                accum_loss.add_(loss.data)
             if outputs is not None:
                 outputs.append(tree_map(detach, output_obj))
             return loss

colossalai/pipeline/schedule/one_f_one_b.py: 1 addition & 1 deletion

@@ -273,7 +273,7 @@ def forward_step(
             loss = criterion(output_obj, micro_batch) / self.num_microbatches

             if accum_loss is not None:
-                accum_loss.add_(loss.detach())
+                accum_loss.add_(loss.data)
             if outputs is not None:
                 outputs.append(tree_map_hf(detach, output_obj))
             return loss
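
Both pipeline schedules accumulate each microbatch's loss into accum_loss in place; the hotfix swaps loss.detach() for loss.data. Either form yields a value that is excluded from the autograd graph, and the commit message does not spell out why .data is preferred here, so the sketch below only illustrates the accumulation pattern itself (all names are stand-ins, not schedule code):

import torch

# Minimal sketch of the loss-accumulation pattern in both schedules: each
# microbatch loss is pre-scaled by the microbatch count, backpropagated, and
# its value folded into a running total that stays outside the autograd graph.
num_microbatches = 4
params = torch.randn(8, requires_grad=True)   # stand-in for model parameters
accum_loss = torch.zeros(())                  # running loss, never tracked by autograd

for _ in range(num_microbatches):
    loss = (params * torch.randn(8)).sum() / num_microbatches
    accum_loss.add_(loss.data)   # .data gives a plain, gradient-free view of the value
    loss.backward()              # gradients still accumulate into params as usual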

colossalai/zero/low_level/low_level_optim.py: 6 additions & 2 deletions

@@ -1,6 +1,6 @@
 # this code is inspired by the DeepSpeed library and implemented with our own design from scratch
 import copy
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from functools import partial
 from typing import Dict, Iterator, List, Optional, Tuple
 from weakref import proxy
@@ -88,6 +88,7 @@ def __init__(
         master_weights: bool = True,  # master weights
         overlap_allgather: bool = False,
         fp8_communication: bool = False,
+        backward_context=None,
     ):
         super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)

@@ -130,6 +131,7 @@ def __init__(
         self._reduce_bucket_size = reduce_bucket_size
         self._communication_dtype = communication_dtype
         self._fp8_communication = fp8_communication
+        self._backward_context = backward_context

         # gradient clipping
         self._clip_grad_norm = clip_grad_norm
@@ -429,7 +431,9 @@ def backward(self, loss, retain_graph=False):
         if self.mixed_precision_mixin is not None:
             loss = self.mixed_precision_mixin.pre_backward(loss)

-        loss.backward(retain_graph=retain_graph)
+        ctx = nullcontext() if self._backward_context is None else self._backward_context()
+        with ctx:
+            loss.backward(retain_graph=retain_graph)

         if not self.require_grad_sync:
             return
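
LowLevelZeroOptimizer now accepts an optional backward_context: a zero-argument callable that returns a context manager, entered freshly around every loss.backward() and defaulting to nullcontext() when absent. A stripped-down sketch of just that pattern (the class below is hypothetical, not ColossalAI code):

from contextlib import nullcontext


class BackwardContextOptimizer:
    # Hypothetical illustration of the pattern added above: the caller supplies a
    # callable producing a context manager, and backward() enters a new one on
    # each call so the hook lifetime is scoped to that single backward pass.
    def __init__(self, backward_context=None):
        self._backward_context = backward_context

    def backward(self, loss, retain_graph=False):
        ctx = nullcontext() if self._backward_context is None else self._backward_context()
        with ctx:
            loss.backward(retain_graph=retain_graph)

Passing the bound method (model._hook_context) rather than an already-entered context lets the plugins reuse one model-side definition for both the forward pass and every backward call.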
