From fb141253b0f06fca8373d68514742447fcb932d1 Mon Sep 17 00:00:00 2001
From: duanjunwen <935724073@qq.com>
Date: Sun, 14 Apr 2024 21:06:09 +0800
Subject: [PATCH] [feature] add row residue logic in column parallel factor;

---
 .../nn/optimizer/distributed_adafactor.py    | 72 +++++++++++++------
 1 file changed, 50 insertions(+), 22 deletions(-)

diff --git a/colossalai/nn/optimizer/distributed_adafactor.py b/colossalai/nn/optimizer/distributed_adafactor.py
index 956bf07604e3..9331da396a65 100644
--- a/colossalai/nn/optimizer/distributed_adafactor.py
+++ b/colossalai/nn/optimizer/distributed_adafactor.py
@@ -157,20 +157,43 @@ def _approx_sq_grad_row_parallel(exp_avg_sq_row, exp_avg_sq_col, sq_row_meam):
         return torch.mul(r_factor, c_factor)
 
     def _col_parallel_factor(self, update, grad, state, grad_shape, beta2t):
-        update_reshape = update.view(-1, grad_shape[1])
-        grad_reshape = grad.view(-1, grad_shape[1])
-        exp_avg_sq_row = state["exp_avg_sq_row"]  # [H/dp]
-        exp_avg_sq_col = state["exp_avg_sq_col"]  # [W/tp]
-        exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t))
-        exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t))
-        dist.all_reduce(exp_avg_sq_row, group=self.tensor_parallel_group)
-        exp_avg_sq_row.div_(self.tensor_parallel_size)
-        update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
-        update_reshape.mul_(grad_reshape)
-        if self.use_zero:
-            update = update_reshape.view(-1)
+        if grad_shape[0] % self.data_parallel_size != 0:
+            # gather the flattened update along the dp group, then reshape to [H, W/tp]
+            update = _gather(
+                input_=update, dim=-1, process_group=self.data_parallel_group
+            )
+            update_reshape = update.view(-1, grad_shape[1])
+            # gather the flattened grad along the dp group, then reshape to [H, W/tp]
+            grad = _gather(
+                input_=grad, dim=-1, process_group=self.data_parallel_group
+            )
+            grad_reshape = grad.view(-1, grad_shape[1])
+            exp_avg_sq_row = state["exp_avg_sq_row"]  # [H]
+            exp_avg_sq_col = state["exp_avg_sq_col"]  # [W/tp]
+            exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t))
+            exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t))
+            update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
+            update_reshape.mul_(grad_reshape)
+            if self.use_zero:
+                update = update_reshape.view(-1)
+            else:
+                update = update_reshape
+
         else:
-            update = update_reshape
+            update_reshape = update.view(-1, grad_shape[1])
+            grad_reshape = grad.view(-1, grad_shape[1])
+            exp_avg_sq_row = state["exp_avg_sq_row"]  # [H/dp]
+            exp_avg_sq_col = state["exp_avg_sq_col"]  # [W/tp]
+            exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t))
+            exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t))
+            dist.all_reduce(exp_avg_sq_row, group=self.tensor_parallel_group)
+            exp_avg_sq_row.div_(self.tensor_parallel_size)
+            update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
+            update_reshape.mul_(grad_reshape)
+            if self.use_zero:
+                update = update_reshape.view(-1)
+            else:
+                update = update_reshape
         return update
 
     def _row_parallel_factor(self, update, grad, state, grad_shape, beta2t):
@@ -333,9 +356,14 @@ def step(self, closure=None):
                 if factored:
                     if param_is_dtensor:
                         if shard_spec.sharding_sequence[0] == "R":  # Col Parallel
-                            state["exp_avg_sq_row"] = torch.zeros(
-                                grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype
-                            )  # [H/dp]
+                            if grad_shape[0] % self.data_parallel_size != 0:
+                                state["exp_avg_sq_row"] = torch.zeros(
+                                    grad_shape[0], device=p.device, dtype=p.dtype
+                                )  # [H]
+                            else:
+                                state["exp_avg_sq_row"] = torch.zeros(
+                                    grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype
+                                )  # [H/dp]
                             state["exp_avg_sq_col"] = torch.zeros(
                                 grad_shape[1], device=p.device, dtype=p.dtype
                             )  # [W/TP]
@@ -344,16 +372,16 @@ def step(self, closure=None):
                                 # Row indivisible shape situation
                                 if grad_shape[0] % self.data_parallel_size != 0:
                                     state["exp_avg_sq_row"] = torch.zeros(
-                                    grad_shape[0], device=p.device, dtype=p.dtype
-                                )  # [H/dp/Tp]
+                                        grad_shape[0], device=p.device, dtype=p.dtype
+                                    )  # [H/tp]
                                 else:
                                     state["exp_avg_sq_row"] = torch.zeros(
-                                    grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype
-                                )  # [H/dp/Tp]
+                                        grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype
+                                    )  # [H/dp/tp]
                                 state["exp_avg_sq_col"] = torch.zeros(
-                                    grad_shape[1], device=p.device, dtype=p.dtype
-                                )  # [W]
+                                    grad_shape[1], device=p.device, dtype=p.dtype
+                                )  # [W]
                     else:
                         if self.use_zero:
                             # param grad [H // dp]
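Reviewer note (not part of the patch): the sketch below illustrates, in a single process, why the residue path gathers before factoring. `torch.cat` stands in for the `_gather` collective over the data-parallel group, `approx_sq_grad` follows the stock Adafactor helper the patch calls as `self._approx_sq_grad`, and the sizes `H=10`, `W_tp=8`, `dp=4` are illustrative assumptions, not values from the patch.

```python
import torch

def approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
    # Stock Adafactor factorization: row statistic normalized by its mean,
    # combined with the column statistic via an outer product of rsqrt terms.
    r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt().unsqueeze(-1)
    c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
    return torch.mul(r_factor, c_factor)

H, W_tp, dp, beta2t = 10, 8, 4, 0.99  # H % dp != 0 -> residue path

grad = torch.randn(H, W_tp)                # local [H, W/tp] slice of a col-parallel param
update_flat = (grad**2 + 1e-30).flatten()  # Adafactor's `update` is grad^2 + eps

# ZeRO leaves each dp rank a flat shard of 20 elements; 20 % W_tp != 0, so
# shard.view(-1, W_tp) would fail on a rank -- hence gather along dp first.
shards = list(torch.chunk(update_flat, dp))
update_reshape = torch.cat(shards).view(-1, W_tp)  # gathered: [H, W/tp]

exp_avg_sq_row = torch.zeros(H)     # full [H] state, not [H/dp]
exp_avg_sq_col = torch.zeros(W_tp)  # [W/tp]
exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=1.0 - beta2t)
exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=1.0 - beta2t)

update = approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col).mul_(grad)
print(update.shape)  # torch.Size([10, 8])
```

The gather adds a per-step collective, but only for parameters whose row dimension leaves a residue under the dp size; evenly divisible shapes keep the cheaper sharded [H/dp] statistics and the existing tensor-parallel all-reduce path.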