From 13946c4448ccd2ce981b192190251348ccc8a302 Mon Sep 17 00:00:00 2001
From: Hongxin Liu
Date: Wed, 11 Sep 2024 16:11:25 +0800
Subject: [PATCH] [fp8] hotfix backward hook (#6053)

* [fp8] hotfix backward hook

* [fp8] hotfix pipeline loss accumulation
---
 .../booster/plugin/hybrid_parallel_plugin.py | 22 ++++++++++---------
 .../booster/plugin/low_level_zero_plugin.py  |  8 ++++---
 colossalai/initialize.py                     |  6 +++++
 .../pipeline/schedule/interleaved_pp.py      |  2 +-
 colossalai/pipeline/schedule/one_f_one_b.py  |  2 +-
 colossalai/zero/low_level/low_level_optim.py |  8 +++++--
 6 files changed, 31 insertions(+), 17 deletions(-)

diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 6a333862a909..8e972d0146da 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -216,7 +216,7 @@ def forward(self, *args, **kwargs):
         if self.convert_fn is not None:
             args = tree_map(self.convert_fn, args)
             kwargs = tree_map(self.convert_fn, kwargs)
-        with self._wait_all_gather():
+        with self._hook_context():
             return super().forward(*args, **kwargs)

     def unwrap(self):
@@ -229,12 +229,8 @@ def _force_wait_all_gather(self):
         for p in self.module.parameters():
             wait_all_gather_handle(p)

-    def _wait_all_gather(self):
-        return (
-            ColoParamOpHookManager.use_hooks(*self.op_hooks)
-            if (self.overlap_allgather or self.use_fp8)
-            else nullcontext()
-        )
+    def _hook_context(self):
+        return ColoParamOpHookManager.use_hooks(*self.op_hooks) if len(self.op_hooks) > 0 else nullcontext()


 def get_param_info(optim: Optimizer):
@@ -317,7 +313,8 @@ def backward(self, loss: Tensor, *args, **kwargs):
         """
         # Call the superclass backward method to compute gradients.
-        super().backward(loss, *args, **kwargs)
+        with self.model._hook_context():
+            super().backward(loss, *args, **kwargs)

         if self.model.require_grad_sync:
             # If gradient synchronization is required, sync sequence parallelism gradients.
@@ -540,7 +537,8 @@ def backward(self, loss: Tensor, *args, **kwargs):
             None
         """
         # Call the superclass backward method to compute gradients.
-        super().backward(loss, *args, **kwargs)
+        with self.model._hook_context():
+            super().backward(loss, *args, **kwargs)

         if self.model.require_grad_sync:
             # If gradient synchronization is required, sync sequence parallelism gradients.
@@ -683,6 +681,7 @@ def __init__(
         pp_process_group: Optional[ProcessGroup] = None,  # if using pp
         forced_dtype: Optional[torch.dtype] = None,
         overlap_allgather: bool = False,
+        fp8_communication: bool = False,
     ):
         self.model = model
         self.param_info = param_info
@@ -712,6 +711,8 @@ def __init__(
             dp_process_group=dp_process_group,
             forced_dtype=forced_dtype,
             overlap_allgather=overlap_allgather,
+            fp8_communication=fp8_communication,
+            backward_context=model._hook_context,
         )

     def sync_dp_grads(self):
@@ -1206,6 +1207,7 @@ def __init__(
             partition_grad=(self.zero_stage == 2),
             forced_dtype=PRECISION_TORCH_TYPE[precision],
             overlap_allgather=overlap_allgather,
+            fp8_communication=fp8_communication,
         )

         self.max_norm = max_norm
@@ -1371,7 +1373,7 @@ def execute_pipeline(
         # so we disable it, performing manual reduction instead.
         ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()

-        with ctx, model._wait_all_gather():
+        with ctx, model._hook_context():
             outputs = self.schedule.forward_backward_step(
                 model, data_iter, criterion, optimizer, return_loss, return_outputs
             )
diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index 8cc511a5610f..cec15dd5dd34 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -100,14 +100,16 @@ def forward(self, *args, **kwargs):
         if self.convert_fn is not None:
             args = tree_map(self.convert_fn, args)
             kwargs = tree_map(self.convert_fn, kwargs)
-        ctx = ColoParamOpHookManager.use_hooks(*self.op_hooks) if self.overlap_allgather else nullcontext()
-        with ctx:
+        with self._hook_context():
             return super().forward(*args, **kwargs)

     def _force_wait_all_gather(self):
         for p in self.module.parameters():
             wait_all_gather_handle(p)

+    def _hook_context(self):
+        return ColoParamOpHookManager.use_hooks(*self.op_hooks) if len(self.op_hooks) > 0 else nullcontext()
+

 class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
     def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
@@ -520,7 +522,7 @@ def configure(
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(
-                optimizer, **zero_optim_kwargs, verbose=self.verbose
+                optimizer, **zero_optim_kwargs, verbose=self.verbose, backward_context=model._hook_context
             )
             # inject update_master_params
             model.update_master_params = MethodType(optimizer.update_master_params, model)
diff --git a/colossalai/initialize.py b/colossalai/initialize.py
index 4e2eff7ce352..5414791461c6 100644
--- a/colossalai/initialize.py
+++ b/colossalai/initialize.py
@@ -9,6 +9,7 @@
 # https://forums.developer.nvidia.com/t/how-many-streams-maximum-number-of-streams/6571/16
 os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"

+import torch
 import torch.distributed as dist

 from colossalai.accelerator import get_accelerator
@@ -64,6 +65,11 @@ def launch(
     set_seed(seed)

+    try:
+        torch._dynamo.config.optimize_ddp = world_size > 1
+    except AttributeError:
+        pass
+
     if verbose:
         logger = get_dist_logger()
         logger.info(f"Distributed environment is initialized, world size: {dist.get_world_size()}", ranks=[0])
diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py
index 28c4eb8d8d4a..c538ee0715b4 100644
--- a/colossalai/pipeline/schedule/interleaved_pp.py
+++ b/colossalai/pipeline/schedule/interleaved_pp.py
@@ -318,7 +318,7 @@ def forward_step(
         if self.stage_manager.is_last_stage():
             loss = criterion(output_obj, micro_batch) / self.num_microbatch
             if accum_loss is not None:
-                accum_loss.add_(loss.detach())
+                accum_loss.add_(loss.data)
             if outputs is not None:
                 outputs.append(tree_map(detach, output_obj))
             return loss
diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py
index 67aaa5eb1b59..0fc90995adcc 100644
--- a/colossalai/pipeline/schedule/one_f_one_b.py
+++ b/colossalai/pipeline/schedule/one_f_one_b.py
@@ -273,7 +273,7 @@ def forward_step(
             loss = criterion(output_obj, micro_batch) / self.num_microbatches

             if accum_loss is not None:
-                accum_loss.add_(loss.detach())
+                accum_loss.add_(loss.data)
             if outputs is not None:
                 outputs.append(tree_map_hf(detach, output_obj))
             return loss
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 458e6e41a29e..ed51c2bacafc 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -1,6 +1,6 @@
 # this code is inspired by the DeepSpeed library and implemented with our own design from scratch
 import copy
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from functools import partial
 from typing import Dict, Iterator, List, Optional, Tuple
 from weakref import proxy
@@ -88,6 +88,7 @@ def __init__(
         master_weights: bool = True,  # master weights
         overlap_allgather: bool = False,
         fp8_communication: bool = False,
+        backward_context=None,
     ):
         super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
@@ -130,6 +131,7 @@ def __init__(
         self._reduce_bucket_size = reduce_bucket_size
         self._communication_dtype = communication_dtype
         self._fp8_communication = fp8_communication
+        self._backward_context = backward_context

         # gradient clipping
         self._clip_grad_norm = clip_grad_norm
@@ -429,7 +431,9 @@ def backward(self, loss, retain_graph=False):
         if self.mixed_precision_mixin is not None:
             loss = self.mixed_precision_mixin.pre_backward(loss)
-        loss.backward(retain_graph=retain_graph)
+        ctx = nullcontext() if self._backward_context is None else self._backward_context()
+        with ctx:
+            loss.backward(retain_graph=retain_graph)

         if not self.require_grad_sync:
             return
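
Note: below is a minimal, self-contained sketch of the backward_context pattern this
patch wires up: the optimizer accepts an optional callable that returns a context
manager and enters it around loss.backward(), so parameter-op hooks (e.g. the fp8
hooks) also cover the backward pass. ToyModelWrapper, ToyZeroOptimizer, and the
print-based demo_hook are illustrative stand-ins, not ColossalAI APIs.

    from contextlib import contextmanager, nullcontext

    import torch
    import torch.nn as nn


    class ToyModelWrapper(nn.Module):
        """Wraps a module and exposes a hook context, mirroring model._hook_context()."""

        def __init__(self, module: nn.Module, op_hooks=None):
            super().__init__()
            self.module = module
            # Simplification: each entry is a factory returning a context manager.
            self.op_hooks = op_hooks or []

        def _hook_context(self):
            # Enter the hook context if any hooks are registered, otherwise do nothing.
            if len(self.op_hooks) == 0:
                return nullcontext()
            return self.op_hooks[0]()

        def forward(self, *args, **kwargs):
            # Forward runs inside the hook context, as in the patched plugins.
            with self._hook_context():
                return self.module(*args, **kwargs)


    class ToyZeroOptimizer:
        """Mimics LowLevelZeroOptimizer's optional backward_context argument."""

        def __init__(self, optim: torch.optim.Optimizer, backward_context=None):
            self.optim = optim
            self._backward_context = backward_context

        def backward(self, loss: torch.Tensor, retain_graph: bool = False):
            # Same shape as the patched LowLevelZeroOptimizer.backward: enter the
            # provided context (or a no-op) around the autograd call.
            ctx = nullcontext() if self._backward_context is None else self._backward_context()
            with ctx:
                loss.backward(retain_graph=retain_graph)


    @contextmanager
    def demo_hook():
        print("hooks active")  # stand-in for ColoParamOpHookManager.use_hooks(...)
        yield
        print("hooks released")


    if __name__ == "__main__":
        model = ToyModelWrapper(nn.Linear(4, 4), op_hooks=[demo_hook])
        optimizer = ToyZeroOptimizer(
            torch.optim.SGD(model.parameters(), lr=0.1),
            backward_context=model._hook_context,
        )
        out = model(torch.randn(2, 4))
        optimizer.backward(out.sum())  # backward now also runs inside the hook context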