From e4197a4c1358499fd5ba13651739eafd0c244dfc Mon Sep 17 00:00:00 2001
From: hxwang
Date: Wed, 5 Jun 2024 08:15:46 +0000
Subject: [PATCH] [WIP] further refactor strategy: move grad&comm to base class

---
 .../booster/plugin/low_level_zero_plugin.py   | 30 +-
 .../low_level/bookkeeping/gradient_store.py   | 17 ++
 colossalai/zero/low_level/low_level_optim.py  | 69 ++---
 .../zero/low_level/low_level_strategy.py      | 270 ++++++++++--------
 4 files changed, 221 insertions(+), 165 deletions(-)

diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index d21496f0b758..2af6ab32e1f6 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -25,7 +25,9 @@
     sharded_optimizer_loading_epilogue,
 )
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
+from colossalai.tensor.moe_tensor.api import is_moe_tensor
 from colossalai.zero import LowLevelZeroOptimizer
+from colossalai.zero.low_level.low_level_strategy import LowLevelOptStrategy, MoeZeroStrategy
 
 from .dp_plugin_base import DPPluginBase
 from .torch_ddp_plugin import TorchDDPCheckpointIO
@@ -320,8 +322,34 @@ def configure(
         model = LowLevelZeroModel(model, self.precision)
 
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
+            # TODO @botbw: better way of doing this
+            # separate the params into two groups: non-MoE params (normal dp) and MoE params (extra moe dp)
+            assert (
+                len(optimizer.param_groups) == 1 and len(optimizer.state) == 0
+            ), "This requires the optimizer to be uninitialized"
+            non_moe_params, moe_params = [], []
+            for param in model.parameters():
+                if is_moe_tensor(param):
+                    moe_params.append(param)
+                else:
+                    non_moe_params.append(param)
+            strategies = None
+            if len(moe_params) != 0:
+                print(f"{len(moe_params)=}, {len(non_moe_params)=}")
+                prev_param_group = optimizer.param_groups[0]
+                prev_param_group.pop("params")
+                optimizer.param_groups = [
+                    {"params": non_moe_params, **prev_param_group},
+                    {"params": moe_params, **prev_param_group},
+                ]
+                strategies = [
+                    LowLevelOptStrategy(param_group=optimizer.param_groups[0], **self.zero_optim_kwargs),
+                    MoeZeroStrategy(param_group=optimizer.param_groups[1], **self.zero_optim_kwargs),
+                ]
+            if "moe_extra_dp_process_group" in self.zero_optim_kwargs:
+                self.zero_optim_kwargs.pop("moe_extra_dp_process_group")
             optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(
-                optimizer, **self.zero_optim_kwargs, verbose=self.verbose
+                optimizer, group_strategies=strategies, **self.zero_optim_kwargs, verbose=self.verbose
             )
             # inject update_master_params
             model.update_master_params = MethodType(optimizer.update_master_params, model)
diff --git a/colossalai/zero/low_level/bookkeeping/gradient_store.py b/colossalai/zero/low_level/bookkeeping/gradient_store.py
index 73a1db5a0c0d..ecb711a3e88a 100644
--- a/colossalai/zero/low_level/bookkeeping/gradient_store.py
+++ b/colossalai/zero/low_level/bookkeeping/gradient_store.py
@@ -104,6 +104,23 @@ def get_working_grad_by_param_id(self, param_id) -> Tensor:
 
         raise KeyError(f"Working gradient for param_id {param_id} not found.")
 
+    def get_master_grad_by_param_id(self, param_id) -> List[Tensor]:
+        """
+        Return the master gradient slices for the specified parameter.
+
+        Args:
+            param_id (int): The index of the parameter.
+
+        Returns:
+            List[Tensor]: The master gradient slices for the specified param_id.
+ """ + + for group in self._grads_of_params.values(): + if param_id in group.keys(): + return group[param_id] + + raise KeyError(f"Working gradient for param_id {param_id} not found.") + def reset_grads_by_group_id(self, group_id: int): self._grads_of_params[group_id] = dict() diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index f9b3ae4fc20e..8bb9b397989d 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -84,6 +84,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): def __init__( self, optimizer: Optimizer, + group_strategies: List[LowLevelOptStrategyBase] = None, initial_scale: int = 2**16, # grad scaler config min_scale: int = 1, growth_factor: float = 2.0, @@ -107,10 +108,6 @@ def __init__( self._logger = get_dist_logger() self._verbose = verbose - assert len(self.optim.param_groups) == 1 - self.moe_extra_dp_pg = None # TODO @botbw: refactor this - self._world_size = dist.get_world_size(group=dp_process_group) # TODO @botbw: refactor this - self._local_rank = dist.get_rank(group=dp_process_group) # TODO @botbw: refactor this # gradient clipping self._clip_grad_norm = clip_grad_norm @@ -124,18 +121,23 @@ def __init__( # check argument conflict self._sanity_checks() - self._group_strategies = [ - LowLevelOptStrategy( - param_group=self.optim.param_groups[0], - reduce_bucket_size=reduce_bucket_size, - communication_dtype=communication_dtype, - overlap_communication=overlap_communication, - partition_grads=partition_grad, - cpu_offload=cpu_offload, - dp_process_group=dp_process_group, - master_weights=master_weights, - ) - ] + if len(self.optim.param_groups) == 1 and group_strategies is None: + group_strategies = [ + LowLevelOptStrategy( + param_group=self.optim.param_groups[0], + reduce_bucket_size=reduce_bucket_size, + communication_dtype=communication_dtype, + overlap_communication=overlap_communication, + partition_grads=partition_grad, + cpu_offload=cpu_offload, + dp_process_group=dp_process_group, + master_weights=master_weights, + ) + ] + elif len(self.optim.param_groups) > 1 and group_strategies is None: + raise ValueError("group_strategies must be provided when the optimizer has multiple param groups") + + self._group_strategies = group_strategies # initialize mixed precision mixin self.mixed_precision_mixin: Optional[MixedPrecisionMixin] = None @@ -198,34 +200,24 @@ def load_state_dict(self, state_dict: Dict): for strategy in self._group_strategies: strategy.scatter_optim_state(self.optim.state) - # def update_master_params(self, model: nn.Module) -> None: - # """Update master params from working params - - # Args: - # model (nn.Module): The model to update master params - # """ - # for p in model.parameters(): - # p_id = id(p) - # if p_id in self._group_strategies[0]._param_store.working_to_master_param: - # master_param = self._group_strategies[0]._param_store.working_to_master_param[p_id] - # padding_size = self._group_strategies[0]._param_store.get_param_padding_size(p) - # working_param = p.data.view(-1) - # if padding_size > 0: - # working_param = torch.nn.functional.pad(working_param, [0, padding_size]) - # if self.moe_extra_dp_pg is not None and is_moe_tensor(p): - # master_param.copy_(working_param.chunk(self.extra_dp_pg_size)[self.extra_dp_pg_rank]) - # else: - # master_param.copy_(working_param.chunk(self._world_size)[self._local_rank]) - # if hasattr(self, "master_moe_params"): - # for master_moe_param, working_moe_param in zip(self.master_moe_params, 
-    #             master_moe_param.copy_(working_moe_param)
+    def update_master_params(self, model: nn.Module) -> None:
+        """Update master params from working params
+
+        Args:
+            model (nn.Module): The model to update master params
+        """
+        # TODO @botbw: not sure if we should access params directly
+        all_working_params_strategy = []
+        for strategy in self._group_strategies:
+            all_working_params_strategy.extend(strategy.working_params)
+            strategy.update_master_params()
+        assert set(all_working_params_strategy) == set(model.parameters()), "the strategies' working params should cover all model parameters"
 
     def step(self, closure=None):
         assert closure is None, "closure is not supported by step()"
         if not self.require_grad_sync:
             return
 
-        # TODO @botbw: implement this
         if self.mixed_precision_mixin is not None and self.mixed_precision_mixin.should_skip_step():
             if self._verbose:
                 self._logger.info(f"Found overflow. Skip step")
@@ -234,6 +226,7 @@ def step(self, closure=None):
                 strategy.zero_grad()
             return
 
+        # TODO @botbw can be further refactored
         grad_partition_groups = []
         norm_groups = []
         for strategy in self._group_strategies:
diff --git a/colossalai/zero/low_level/low_level_strategy.py b/colossalai/zero/low_level/low_level_strategy.py
index f66dae2b1f93..ff7afcc557d9 100644
--- a/colossalai/zero/low_level/low_level_strategy.py
+++ b/colossalai/zero/low_level/low_level_strategy.py
@@ -17,7 +17,7 @@ class LowLevelOptStrategyBase(ABC):
     """
     Base class for low-level optimization strategies, this is to reduce the
-    coupling between different param group and their process group (parallel settings)
+    coupling between different param groups and their corresponding process groups
 
     This class contains only necessary stores/data for optimizer:
         1. params
@@ -33,7 +33,18 @@ class LowLevelOptStrategyBase(ABC):
     # but currently only one is used
     DEFAULT_STORE_GROUP_ID = 0
 
-    def __init__(self, param_group, process_group, master_weights, partition_grads, cpu_offload, **kwargs):
+    def __init__(
+        self,
+        param_group,
+        process_group,
+        master_weights,
+        partition_grads,
+        cpu_offload,
+        overlap_communication,
+        reduce_bucket_size,
+        communication_dtype,
+        **kwargs,
+    ):
         # param_group that current strategy is working on
         self.param_group = param_group
         self._dtype = self.param_group["params"][0].dtype
@@ -48,12 +59,12 @@ def __init__(self, param_group, process_group, master_weights, partition_grads,
         self._world_size = dist.get_world_size(group=self.process_group)
 
         # master weights copy
-        self._master_weights = master_weights
+        self._master_weights = master_weights  # TODO @botbw: this should be unique across all strategies
 
-        self._cpu_offload = cpu_offload
+        self._cpu_offload = cpu_offload  # TODO @botbw: this should be unique across all strategies
 
-        # stage 2 TODO @botbw: this should be unique across all strategies
-        self._partition_grads = partition_grads
+        # stage 2
+        self._partition_grads = partition_grads  # TODO @botbw: this should be unique across all strategies
 
         # ParameterStore will manage the tensor buffers used for zero
         # it will not manage the tensors used by mixed precision training
@@ -68,12 +79,40 @@ def __init__(self, param_group, process_group, master_weights, partition_grads,
                 group_params.append(param)
         master_param_current_rank = self._create_master_param_current_rank(group_params)
         param_group["params"] = master_param_current_rank
-        self._working_param_group: List[torch.Tensor] = group_params
-        self._master_param_group_of_current_rank: List[torch.Tensor] = master_param_current_rank
+        self.working_param_group: List[torch.Tensor] = group_params
+        self.master_param_group: List[torch.Tensor] = master_param_current_rank
 
         # by default this shouldn't be manipulate
         self.require_grad_sync = True
 
+        # communication params
+        self._overlap_communication = overlap_communication
+        self._reduce_bucket_size = reduce_bucket_size
+        self._communication_dtype = communication_dtype
+
+        # initialize communication stream for
+        # communication-computation overlapping
+        if self._overlap_communication:
+            self._comm_stream = get_accelerator().Stream()
+
+        # reduction hook is only used if overlapping communication
+        # or stage 2 is used
+        # if it is stage 1 without overlapping, no hook will be attached
+        if self._overlap_communication or self._partition_grads:
+            # we iterate over the working params
+            # on each param, we register a hook to its AccumulateGrad object
+            param_group = self.working_param_group
+            for param in param_group:
+                if param.requires_grad:
+
+                    def _grad_handler(grad, param):
+                        # if run with no_sync context, would not sync grad when backward
+                        if self.require_grad_sync:
+                            self._add_to_bucket(param)
+                        return grad
+
+                    param.register_hook(partial(_grad_handler, param=param))
+
     def _create_master_param_current_rank(self, param_list):
         # split each param evenly by world size
         params_current_rank = []
@@ -143,16 +182,97 @@ def _add_to_bucket(self, param):
         padding_size = self._param_store.get_param_padding_size(param)
         self._bucket_store.add_param_grad(LowLevelOptStrategy.DEFAULT_STORE_GROUP_ID, param, padding_size)
 
+    def _reduce_grad(self):
+        # if not overlapping communication (no reduction hook is attached) when zero1
+        # we need to manually reduce these gradients
+        if not self._partition_grads and not self._overlap_communication:
+            self._sync_grad()
+        else:
+            self._run_reduction()
+
+    def _sync_grad(self):
+        param_group = self.working_param_group
+        for param in param_group:
+            if param.requires_grad and param.grad is not None:
+                self._add_to_bucket(param)
+
+        self._run_reduction()
+
+    def _run_reduction(self):
+        if self._bucket_store.num_elements_in_bucket() <= 0:
+            return
+
+        self._bucket_store.build_grad_in_bucket()
+
+        flat_grads = self._bucket_store.get_flatten_grad()
+        flat_grads /= self._world_size
+
+        # ready to add other tensors to bucket
+        self._bucket_store.reset_num_elements_in_bucket()
+
+        if self._overlap_communication:
+            stream = self._comm_stream
+            # in case of the memory being reused in the default stream
+            flat_grads.record_stream(stream)
+            # waiting for ops in the default stream finishing
+            stream.wait_stream(get_accelerator().current_stream())
+        else:
+            stream = get_accelerator().current_stream()
+
+        with get_accelerator().stream(stream):
+            group_id = self._bucket_store.current_group_id
+            assert group_id == LowLevelOptStrategy.DEFAULT_STORE_GROUP_ID, "after refactoring, group_id should be 0"
+
+            grad_dtype = flat_grads.dtype
+            if self._communication_dtype is not None:
+                flat_grads = flat_grads.to(self._communication_dtype)
+
+            if not self._partition_grads:
+                dist.all_reduce(flat_grads, group=self.process_group)
+                if flat_grads.dtype != grad_dtype:
+                    flat_grads = flat_grads.to(grad_dtype)
+
+                flat_grads_per_rank = flat_grads.split(flat_grads.numel() // self._world_size)
+                grad_in_bucket = self._bucket_store.get_grad()
+                self._update_unpartitoned_grad(grad_in_bucket.values(), flat_grads_per_rank, group_id)
+            else:
+                flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size))
+                recieved_grad = torch.zeros_like(flat_grads_list[0])
+                dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.process_group)
+
+                if recieved_grad.dtype != grad_dtype:
+                    recieved_grad = recieved_grad.to(grad_dtype)
+
+                grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank]
+                self._update_partitoned_grad(grad_in_bucket_current_rank, recieved_grad, group_id, 1)
+
+        self._bucket_store.reset()
+
     ######################################################################
     # interfaces for child classes to manipulate the params, grads and buckets (and their stores)
     @property
     def master_params(self):
-        return self._master_param_group_of_current_rank
+        return self.master_param_group
 
     @property
     def working_grads(self):
         return self._grad_store.get_working_grads_by_group_id(LowLevelOptStrategyBase.DEFAULT_STORE_GROUP_ID)
 
+    def get_param_padding_size(self, param):
+        return self._param_store.get_param_padding_size(param)
+
+    def get_working_param_grads(self, working_param):
+        return self._grad_store.get_partitioned_gradients_by_param_id(
+            LowLevelOptStrategy.DEFAULT_STORE_GROUP_ID, id(working_param)
+        )
+
+    def update_master_params(self):
+        for working_param, master_param in zip(self.working_params, self.master_params):
+            padding_size = self.get_param_padding_size(working_param)
+            if padding_size > 0:
+                working_param = torch.nn.functional.pad(working_param, [0, padding_size])
+            master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
+
     def get_grad_norm(self, norm_type: int = 2) -> float:
         r"""
         Compute and return the gradient norm for gradient clipping.
@@ -193,7 +313,7 @@ def get_grad_norm(self, norm_type: int = 2) -> float:
         return total_norm
 
     def zero_grad(self, set_to_none=True):
-        param_group = self._working_param_group
+        param_group = self.working_param_group
         for param in param_group:
             if set_to_none:
                 param.grad = None
@@ -281,34 +401,6 @@ def __init__(
             **kwargs,
         )
 
-        # communication params
-        self._overlap_communication = overlap_communication
-        self._reduce_bucket_size = reduce_bucket_size
-        self._communication_dtype = communication_dtype
-
-        # initialize communication stream for
-        # communication-computation overlapping
-        if self._overlap_communication:
-            self._comm_stream = get_accelerator().Stream()
-
-        # reduction hook is only used if overlapping communication
-        # or stage 2 is used
-        # if it is stage 1 without overlapping, no hook will be attached
-        if self._overlap_communication or self._partition_grads:
-            # we iterate over the working params
-            # on each param, we register a hook to its AccumulateGrad object
-            param_group = self._working_param_group
-            for param in param_group:
-                if param.requires_grad:
-
-                    def _grad_handler(grad, param):
-                        # if run with no_sync context, would not sync grad when backward
-                        if self.require_grad_sync:
-                            self._add_to_bucket(param)
-                        return grad
-
-                    param.register_hook(partial(_grad_handler, param=param))
-
         # temporary variables
         self.__saved_master_params = None
         self.__saved_working_params = None
@@ -332,8 +424,6 @@ def post_backward(self):
         if self._overlap_communication:
             get_accelerator().synchronize()
 
-        self.zero_grad()
-
     def pre_backward_by_grad(self, tensor, grad):
         assert not (
             self._partition_grads and not self.require_grad_sync
@@ -348,16 +438,11 @@ def pre_step(self) -> None:
         # and should not be updated
         grad_index = 0 if self._partition_grads else self._local_rank
         real_master_params, real_working_params = [], []
-        for working_param, master_param in zip(self._working_param_group, self._master_param_group_of_current_rank):
-            assert (
-                master_param is self._param_store.working_to_master_param[id(working_param)]
-            ), f"sanity check @botbw: wrong refactor"
+        for working_param, master_param in zip(self.working_param_group, self.master_param_group):
             # if a working param requires grad and has no grad
             # it is not 'really' working, e.g. the droped layer
             # else the splited grad should be attached to the splited param
-            grads = self._grad_store.get_partitioned_gradients_by_param_id(
-                LowLevelOptStrategy.DEFAULT_STORE_GROUP_ID, id(working_param)
-            )
+            grads = self.get_working_param_grads(working_param)
             if len(grads) > 0:
                 real_master_params.append(master_param)
                 real_working_params.append(working_param)
@@ -374,101 +459,34 @@ def pre_step(self) -> None:
         # @botbw: to me, it seems like the original author only wants to keep the "real_xxx_params" when do the optimizer
         # computation, and add "non real_xxx_params" back after since we might still need them for checkpoint
         # not sure if it's necessary since None grads don't really bring lots of overhead
-        self.__saved_working_params = self._working_param_group
-        self.__saved_master_params = self._master_param_group_of_current_rank
-        self._working_param_group = real_working_params
-        self._master_param_group_of_current_rank = self.param_group["params"] = real_master_params
+        self.__saved_working_params = self.working_param_group
+        self.__saved_master_params = self.master_param_group
+        self.working_param_group = real_working_params
+        self.master_param_group = self.param_group["params"] = real_master_params
 
     def post_step(self):
-        release_param_grad(self._master_param_group_of_current_rank)
+        release_param_grad(self.master_param_group)
 
         # update working partition updated by the current rank
        device = get_accelerator().get_current_device()
-        for working_param, master_param in zip(self._working_param_group, self._master_param_group_of_current_rank):
+        for working_param, master_param in zip(self.working_param_group, self.master_param_group):
             all_splited_param = [
                 torch.zeros(master_param.shape, device=device, dtype=self._dtype) for _ in range(self._world_size)
             ]
             dist.all_gather(all_splited_param, master_param.to(device).to(self._dtype), group=self.process_group)
             working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
 
-        # restore
-        self._working_param_group = self.__saved_working_params
-        self._master_param_group_of_current_rank = self.__saved_master_params
+        # restore saved values
+        self.working_param_group = self.__saved_working_params
+        self.master_param_group = self.__saved_master_params
         self.__saved_master_params = self.__saved_working_params = None
-
-        self.param_group["params"] = self._master_param_group_of_current_rank
+        self.param_group["params"] = self.master_param_group
 
     ######################################################################
 
-    def _reduce_grad(self):
-        # if not overlapping communication (no reduction hook is attached) when zero1
-        # we need to manually reduce these gradients
-        if not self._partition_grads and not self._overlap_communication:
-            self._sync_grad()
-        else:
-            self._run_reduction()
-
-    def _sync_grad(self):
-        param_group = self._working_param_group
-        for param in param_group:
-            if param.requires_grad and param.grad is not None:
-                self._add_to_bucket(param)
-
-        self._run_reduction()
-
-    def _run_reduction(self):
-        if self._bucket_store.num_elements_in_bucket() <= 0:
-            return
-
-        self._bucket_store.build_grad_in_bucket()
-
-        flat_grads = self._bucket_store.get_flatten_grad()
-        flat_grads /= self._world_size
-
-        # ready to add other tensors to bucket
-        self._bucket_store.reset_num_elements_in_bucket()
-
-        if self._overlap_communication:
-            stream = self._comm_stream
-            # in case of the memory being reused in the default stream
-            flat_grads.record_stream(stream)
-            # waiting for ops in the default stream finishing
-            stream.wait_stream(get_accelerator().current_stream())
-        else:
-            stream = get_accelerator().current_stream()
-
-        with get_accelerator().stream(stream):
-            group_id = self._bucket_store.current_group_id
-            assert group_id == LowLevelOptStrategy.DEFAULT_STORE_GROUP_ID, "after refactoring, group_id should be 0"
-
-            grad_dtype = flat_grads.dtype
-            if self._communication_dtype is not None:
-                flat_grads = flat_grads.to(self._communication_dtype)
-
-            if not self._partition_grads:
-                dist.all_reduce(flat_grads, group=self.process_group)
-                if flat_grads.dtype != grad_dtype:
-                    flat_grads = flat_grads.to(grad_dtype)
-
-                flat_grads_per_rank = flat_grads.split(flat_grads.numel() // self._world_size)
-                grad_in_bucket = self._bucket_store.get_grad()
-                self._update_unpartitoned_grad(grad_in_bucket.values(), flat_grads_per_rank, group_id)
-            else:
-                flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size))
-                recieved_grad = torch.zeros_like(flat_grads_list[0])
-                dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.process_group)
-
-                if recieved_grad.dtype != grad_dtype:
-                    recieved_grad = recieved_grad.to(grad_dtype)
-
-                grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank]
-                self._update_partitoned_grad(grad_in_bucket_current_rank, recieved_grad, group_id, 1)
-
-        self._bucket_store.reset()
-
 
 class MoeZeroStrategy(LowLevelOptStrategy):
     def __init__(self, param_group, *args, **kwargs):
+        for param in param_group["params"]:
+            assert is_moe_tensor(param), f"Mixture-of-Experts parameters are required for MoeZeroStrategy, got {type(param)}"
         super().__init__(*args, param_group=param_group, **kwargs)
-        for param in param_group:
-            assert is_moe_tensor(param), "Mixture-of-Experts parameters are required for MoeZeroStrategy"
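
Usage sketch (illustrative only, not part of the diff): the snippet below shows how the new group_strategies argument introduced by this patch is expected to be wired up, mirroring what LowLevelZeroPlugin.configure does above. The names model, base_optimizer, hyperparams and zero_optim_kwargs are placeholders; zero_optim_kwargs stands for the ZeRO options the plugin normally forwards (reduce_bucket_size, overlap_communication, dp_process_group, ...).

# Illustrative sketch, mirroring the plugin code changed in this patch.
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.zero import LowLevelZeroOptimizer
from colossalai.zero.low_level.low_level_strategy import LowLevelOptStrategy, MoeZeroStrategy

# split the model parameters into a non-MoE group and a MoE group
non_moe_params = [p for p in model.parameters() if not is_moe_tensor(p)]
moe_params = [p for p in model.parameters() if is_moe_tensor(p)]

# one param group (and one strategy) per parameter kind; hyperparams is a
# placeholder for the usual optimizer settings (lr, weight_decay, ...)
base_optimizer.param_groups = [
    {"params": non_moe_params, **hyperparams},
    {"params": moe_params, **hyperparams},
]
strategies = [
    LowLevelOptStrategy(param_group=base_optimizer.param_groups[0], **zero_optim_kwargs),
    MoeZeroStrategy(param_group=base_optimizer.param_groups[1], **zero_optim_kwargs),
]

# the wrapper optimizer now takes the per-group strategies explicitly
optimizer = LowLevelZeroOptimizer(base_optimizer, group_strategies=strategies, **zero_optim_kwargs)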