diff --git a/colossalai/nn/optimizer/adafactor.py b/colossalai/nn/optimizer/adafactor.py index 57d677ef0059..22a6c8f4d3ce 100644 --- a/colossalai/nn/optimizer/adafactor.py +++ b/colossalai/nn/optimizer/adafactor.py @@ -36,7 +36,7 @@ def __init__( relative_step=True, warmup_init=False, ): - lr=None + lr = None if lr is not None and relative_step: raise ValueError("Cannot combine manual `lr` and `relative_step=True` options") if warmup_init and not relative_step: diff --git a/colossalai/nn/optimizer/distributed_adafactor.py b/colossalai/nn/optimizer/distributed_adafactor.py index ccda1f3e65b1..3f455c1c5b41 100644 --- a/colossalai/nn/optimizer/distributed_adafactor.py +++ b/colossalai/nn/optimizer/distributed_adafactor.py @@ -3,9 +3,9 @@ import torch import torch.distributed as dist + # from torch.optim import Optimizer from colossalai.interface.optimizer import DistributedOptim - from colossalai.shardformer.layer._operation import _gather, _split from colossalai.tensor.d_tensor import get_sharding_spec, is_distributed_tensor @@ -50,14 +50,13 @@ def __init__( self.data_parallel_group = None self.shard_to_param = None # Dict{id:shape}, sample {id(param): torch.tensor} self.use_zero = True - - self.param_is_dtensor_dict = {} # {id(p): True/False} - self.grad_shape_dict = {} # {id(p): master param shape} - self.factored_dict = {} # {id(p): True/False} + + self.param_is_dtensor_dict = {} # {id(p): True/False} + self.grad_shape_dict = {} # {id(p): master param shape} + self.factored_dict = {} # {id(p): True/False} self.use_first_moment_dict = {} # {id(p): True/False} self.shard_spec_dict = {} # {id(p): ShardSpec} super().__init__(params, defaults) - def setup_distributed( self, @@ -84,24 +83,26 @@ def setup_distributed( if self.data_parallel_group is not None: self.data_parallel_size = dist.get_world_size(self.data_parallel_group) self.use_zero = use_zero - + self.shard_to_param = shard_to_param if shard_to_param is not None else {} # grad is None, cause we dont setup now for group in self.param_groups: for p in group["params"]: self.param_is_dtensor_dict[id(p)] = is_distributed_tensor(self.shard_to_param.get(id(p))) if self.param_is_dtensor_dict[id(p)]: - self.grad_shape_dict[id(p)] = self.shard_to_param.get(id(p)).shape + self.grad_shape_dict[id(p)] = self.shard_to_param.get(id(p)).shape else: # no tp; could be zero or not zero # self.grad_shape_dict[id(p)] = p.shape - self.grad_shape_dict[id(p)] = self.shard_to_param.get(id(p)).shape - self.factored_dict[id(p)], self.use_first_moment_dict[id(p)] = self._get_options(group, self.grad_shape_dict[id(p)]) + self.grad_shape_dict[id(p)] = self.shard_to_param.get(id(p)).shape + self.factored_dict[id(p)], self.use_first_moment_dict[id(p)] = self._get_options( + group, self.grad_shape_dict[id(p)] + ) if self.param_is_dtensor_dict[id(p)]: self.shard_spec_dict[id(p)] = get_sharding_spec(self.shard_to_param.get(id(p))) else: self.shard_spec_dict[id(p)] = None - + @staticmethod def _get_lr(param_group, param_state): rel_step_sz = param_group["lr"] @@ -130,7 +131,7 @@ def _get_options(param_group, param_shape): def _rms(tensor, param_is_dtensor, tp_size, dp_size, tp_group, dp_group): tensor_sum = tensor.pow(2).sum() num_of_element = tensor.numel() - + if param_is_dtensor: # reduce tensor_sum from tp_group dist.all_reduce(tensor_sum, group=tp_group) @@ -162,30 +163,26 @@ def _approx_sq_grad_row_parallel(exp_avg_sq_row, exp_avg_sq_col, sq_row_meam): r_factor = (exp_avg_sq_row / sq_row_meam).rsqrt_().unsqueeze(-1) c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() return torch.mul(r_factor, c_factor) - + def _col_parallel_factor(self, update, grad, state, grad_shape, beta2t): if grad_shape[0] % self.data_parallel_size != 0: # gather update[flatten] along dp group then reshape to [H, W/tp] - update = _gather( - input_=update, dim=-1, process_group=self.data_parallel_group - ) + update = _gather(input_=update, dim=-1, process_group=self.data_parallel_group) update_reshape = update.view(-1, grad_shape[1]) # gather grad[flatten] along dp group then reshape to [H, W/tp] - grad = _gather( - input_=grad, dim=-1, process_group=self.data_parallel_group - ) + grad = _gather(input_=grad, dim=-1, process_group=self.data_parallel_group) grad_reshape = grad.view(-1, grad_shape[1]) exp_avg_sq_row = state["exp_avg_sq_row"] # [H] exp_avg_sq_col = state["exp_avg_sq_col"] # [W/tp] exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t)) update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) - update_reshape.mul_(grad_reshape) + update_reshape.mul_(grad_reshape) if self.use_zero: update = update_reshape.view(-1) else: update = update_reshape - + else: update_reshape = update.view(-1, grad_shape[1]) grad_reshape = grad.view(-1, grad_shape[1]) @@ -202,19 +199,15 @@ def _col_parallel_factor(self, update, grad, state, grad_shape, beta2t): else: update = update_reshape return update - + def _row_parallel_factor(self, update, grad, state, grad_shape, beta2t): if grad_shape[0] % self.data_parallel_size != 0: # gather update[flatten] along dp group then reshape to [H/tp, W] - update = _gather( - input_=update, dim=-1, process_group=self.data_parallel_group - ) + update = _gather(input_=update, dim=-1, process_group=self.data_parallel_group) # view update to origin[tp] shape update_reshape = update.view(-1, grad_shape[1]) # gather grad[flatten] along dp group then reshape to [H/tp, W] - grad = _gather( - input_=grad, dim=-1, process_group=self.data_parallel_group - ) + grad = _gather(input_=grad, dim=-1, process_group=self.data_parallel_group) grad_reshape = grad.view(-1, grad_shape[1]) exp_avg_sq_row = state["exp_avg_sq_row"] # [H/tp] exp_avg_sq_col = state["exp_avg_sq_col"] # [W] @@ -240,9 +233,7 @@ def _row_parallel_factor(self, update, grad, state, grad_shape, beta2t): dist.all_reduce(exp_avg_sq_col, group=self.tensor_parallel_group) exp_avg_sq_col.div_(self.tensor_parallel_size) # gather row - exp_avg_sq_row_gather = _gather( - input_=exp_avg_sq_row, dim=-1, process_group=self.tensor_parallel_group - ) + exp_avg_sq_row_gather = _gather(input_=exp_avg_sq_row, dim=-1, process_group=self.tensor_parallel_group) sq_row_meam = exp_avg_sq_row_gather.mean(dim=-1, keepdim=True) update_reshape = self._approx_sq_grad_row_parallel(exp_avg_sq_row, exp_avg_sq_col, sq_row_meam) update_reshape.mul_(grad_reshape) @@ -251,24 +242,20 @@ def _row_parallel_factor(self, update, grad, state, grad_shape, beta2t): else: update = update_reshape return update - + def _base_factor(self, update, grad, state, grad_shape, beta2t): if self.use_zero: # only zero if grad_shape[0] % self.data_parallel_size != 0: - # view update to origin shape update.view(grad_shape[0]//self.data_parallel_size , grad_shape[1]) - # row mean no change + # view update to origin shape update.view(grad_shape[0]//self.data_parallel_size , grad_shape[1]) + # row mean no change # col mean need reduce and div # gather update[flatten] along dp group then reshape to [H, W] - update = _gather( - input_=update, dim=-1, process_group=self.data_parallel_group - ) + update = _gather(input_=update, dim=-1, process_group=self.data_parallel_group) # view update to origin[tp] shape update_reshape = update.view(-1, grad_shape[1]) # gather grad[flatten] along dp group then reshape to [H, W] - grad = _gather( - input_=grad, dim=-1, process_group=self.data_parallel_group - ) + grad = _gather(input_=grad, dim=-1, process_group=self.data_parallel_group) grad_reshape = grad.view(-1, grad_shape[1]) exp_avg_sq_row = state["exp_avg_sq_row"] # [H/dp] exp_avg_sq_col = state["exp_avg_sq_col"] # [W] @@ -283,8 +270,8 @@ def _base_factor(self, update, grad, state, grad_shape, beta2t): else: # no residual row # view update to origin[tp] shape - update_reshape = update.view(-1, grad_shape[1]) # [H/dp, W] - grad_reshape = grad.view(-1, grad_shape[1]) # [H/dp, W] + update_reshape = update.view(-1, grad_shape[1]) # [H/dp, W] + grad_reshape = grad.view(-1, grad_shape[1]) # [H/dp, W] exp_avg_sq_row = state["exp_avg_sq_row"] # [H/tp] exp_avg_sq_col = state["exp_avg_sq_col"] # [W] exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) @@ -294,7 +281,7 @@ def _base_factor(self, update, grad, state, grad_shape, beta2t): exp_avg_sq_col.div_(self.tensor_parallel_size) update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) update_reshape.mul_(grad_reshape) - update = update_reshape.view(-1) + update = update_reshape.view(-1) else: # base factor; no tp, no dp exp_avg_sq_row = state["exp_avg_sq_row"] @@ -307,9 +294,7 @@ def _base_factor(self, update, grad, state, grad_shape, beta2t): update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) update.mul_(grad) return update - - - + @torch.no_grad() def step(self, closure=None): """ @@ -344,7 +329,7 @@ def step(self, closure=None): grad = p.grad if grad.is_sparse: raise RuntimeError("Adafactor does not support sparse gradients.") - + state = self.state[p] grad_shape = self.grad_shape_dict[id(p)] param_is_dtensor = self.param_is_dtensor_dict[id(p)] @@ -364,11 +349,11 @@ def step(self, closure=None): if grad_shape[0] % self.data_parallel_size != 0: state["exp_avg_sq_row"] = torch.zeros( grad_shape[0], device=p.device, dtype=p.dtype - ) # [H] + ) # [H] else: state["exp_avg_sq_row"] = torch.zeros( - grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype - ) # [H/dp] + grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype + ) # [H/dp] state["exp_avg_sq_col"] = torch.zeros( grad_shape[1], device=p.device, dtype=p.dtype ) # [W/TP] @@ -378,23 +363,27 @@ def step(self, closure=None): if grad_shape[0] % self.data_parallel_size != 0: state["exp_avg_sq_row"] = torch.zeros( grad_shape[0], device=p.device, dtype=p.dtype - ) # [H/tp] + ) # [H/tp] else: state["exp_avg_sq_row"] = torch.zeros( grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype - ) # [H/dp/tp] - + ) # [H/dp/tp] + state["exp_avg_sq_col"] = torch.zeros( - grad_shape[1], device=p.device, dtype=p.dtype - ) # [W] + grad_shape[1], device=p.device, dtype=p.dtype + ) # [W] else: if self.use_zero: if grad_shape[0] % self.data_parallel_size != 0: # save all exp_avg_sq_row [H] - state["exp_avg_sq_row"] = torch.zeros(grad_shape[0], device=grad.device, dtype=p.dtype) + state["exp_avg_sq_row"] = torch.zeros( + grad_shape[0], device=grad.device, dtype=p.dtype + ) else: # exp_avg_sq_row [H // dp] - state["exp_avg_sq_row"] = torch.zeros(grad_shape[0] // self.data_parallel_size, device=grad.device, dtype=p.dtype) + state["exp_avg_sq_row"] = torch.zeros( + grad_shape[0] // self.data_parallel_size, device=grad.device, dtype=p.dtype + ) else: # exp_avg_sq_row [H] state["exp_avg_sq_row"] = torch.zeros(grad_shape[0], device=grad.device, dtype=p.dtype) @@ -432,16 +421,23 @@ def step(self, closure=None): elif shard_spec.sharding_sequence[-1] == "R": update = self._row_parallel_factor(update, grad, state, grad_shape, beta2t) else: - update = self._base_factor(update, grad, state, grad_shape, beta2t) + update = self._base_factor(update, grad, state, grad_shape, beta2t) else: exp_avg_sq = state["exp_avg_sq"] exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t)) update = exp_avg_sq.rsqrt().mul_(grad) - + # # (Line No.8) RMS - rms = self._rms(update, param_is_dtensor, self.tensor_parallel_size, self.data_parallel_size, self.tensor_parallel_group, self.data_parallel_group) + rms = self._rms( + update, + param_is_dtensor, + self.tensor_parallel_size, + self.data_parallel_size, + self.tensor_parallel_group, + self.data_parallel_group, + ) update.div_((rms / group["clip_threshold"]).clamp_(min=1.0)) - + update.mul_(lr) if use_first_moment: exp_avg = state["exp_avg"] @@ -450,8 +446,7 @@ def step(self, closure=None): if group["weight_decay"] != 0: p.add_(p, alpha=(-group["weight_decay"] * lr)) - + p.add_(-update) - return loss diff --git a/docs/source/en/features/distributed_adafactor.md b/docs/source/en/features/distributed_adafactor.md index 5a8d8ebade85..6553c0b03f7d 100644 --- a/docs/source/en/features/distributed_adafactor.md +++ b/docs/source/en/features/distributed_adafactor.md @@ -1,20 +1,20 @@ # Distributed Adafactor -Author: +Author: **Related Paper** - [Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235) ## Introduction -Distributed Adafactor is an optimiser that supports hybrid optimisation, including 1D tensor parallelism as well as ZerO. It makes full use of computational resources through reasonable task parallelism, improves training efficiency and speed, and reduces space pressure on single card storage. It has a wide range of applications and currently supports a range of Transformer based models, see [tests.kit.model_zoo](https://github.com/hpcaitech/ColossalAI/tree/main/tests/kit/model_zoo) for details. +Distributed Adafactor is an optimiser that supports hybrid optimisation, including 1D tensor parallelism as well as ZerO. It makes full use of computational resources through reasonable task parallelism, improves training efficiency and speed, and reduces space pressure on single card storage. It has a wide range of applications and currently supports a range of Transformer based models, see [tests.kit.model_zoo](https://github.com/hpcaitech/ColossalAI/tree/main/tests/kit/model_zoo) for details. ### API Reference {{ autodoc:colossalai.nn.optimizer.distributed_adafactor.DistributedAdaFactor }} ## Hands-On Practice -We now demonstrate how to start Distributed Adafactor with booster API. +We now demonstrate how to start Distributed Adafactor with booster API. ### step 1. Import libraries ```python @@ -65,9 +65,9 @@ dist_optim = DistributedAdaFactor(model.parameters()) ```python plugin = LowLevelZeroPlugin() booster = Booster(plugin=plugin) -model, dist_optim, criterion, dataloader, _ = booster.boost(model, dist_optim, criterion, dataloader) +model, dist_optim, criterion, dataloader, _ = booster.boost(model, dist_optim, criterion, dataloader) ``` -### step 5.Train Your Model +### step 5.Train Your Model ```python for epoch in range(max_epochs): for input_ids, attention_mask in dataloader: @@ -106,7 +106,7 @@ Model/Feature Compatibility Matrix: