[feature] add row residue logic in column parallel factor;
duanjunwen committed Apr 14, 2024
1 parent 02ea83e commit fb14125
Showing 1 changed file with 50 additions and 22 deletions.
72 changes: 50 additions & 22 deletions colossalai/nn/optimizer/distributed_adafactor.py
@@ -157,20 +157,43 @@ def _approx_sq_grad_row_parallel(exp_avg_sq_row, exp_avg_sq_col, sq_row_meam):
         return torch.mul(r_factor, c_factor)
 
     def _col_parallel_factor(self, update, grad, state, grad_shape, beta2t):
-        update_reshape = update.view(-1, grad_shape[1])
-        grad_reshape = grad.view(-1, grad_shape[1])
-        exp_avg_sq_row = state["exp_avg_sq_row"]  # [H/dp]
-        exp_avg_sq_col = state["exp_avg_sq_col"]  # [W/tp]
-        exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t))
-        exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t))
-        dist.all_reduce(exp_avg_sq_row, group=self.tensor_parallel_group)
-        exp_avg_sq_row.div_(self.tensor_parallel_size)
-        update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
-        update_reshape.mul_(grad_reshape)
-        if self.use_zero:
-            update = update_reshape.view(-1)
+        if grad_shape[0] % self.data_parallel_size != 0:
+            # gather update[flatten] along dp group then reshape to [H, W/tp]
+            update = _gather(
+                input_=update, dim=-1, process_group=self.data_parallel_group
+            )
+            update_reshape = update.view(-1, grad_shape[1])
+            # gather grad[flatten] along dp group then reshape to [H, W/tp]
+            grad = _gather(
+                input_=grad, dim=-1, process_group=self.data_parallel_group
+            )
+            grad_reshape = grad.view(-1, grad_shape[1])
+            exp_avg_sq_row = state["exp_avg_sq_row"]  # [H]
+            exp_avg_sq_col = state["exp_avg_sq_col"]  # [W/tp]
+            exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t))
+            exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t))
+            update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
+            update_reshape.mul_(grad_reshape)
+            if self.use_zero:
+                update = update_reshape.view(-1)
+            else:
+                update = update_reshape
+
         else:
-            update = update_reshape
+            update_reshape = update.view(-1, grad_shape[1])
+            grad_reshape = grad.view(-1, grad_shape[1])
+            exp_avg_sq_row = state["exp_avg_sq_row"]  # [H/dp]
+            exp_avg_sq_col = state["exp_avg_sq_col"]  # [W/tp]
+            exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t))
+            exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t))
+            dist.all_reduce(exp_avg_sq_row, group=self.tensor_parallel_group)
+            exp_avg_sq_row.div_(self.tensor_parallel_size)
+            update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
+            update_reshape.mul_(grad_reshape)
+            if self.use_zero:
+                update = update_reshape.view(-1)
+            else:
+                update = update_reshape
         return update
 
     def _row_parallel_factor(self, update, grad, state, grad_shape, beta2t):
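
Note on the hunk above: the new branch covers column-parallel parameters whose row dimension H does not divide evenly across the data-parallel group. In that case the ZeRO-flattened local chunk of the [H, W/tp] shard holds partial rows, so it cannot simply be viewed as [-1, W/tp]; as the in-code comments say, update and grad are first gathered along the dp group and only then reshaped. A minimal single-process sketch of the shape arithmetic (H, W, tp_size and dp_size are illustrative values, not library defaults):

# Illustrative shapes only; assumes ZeRO splits the flattened [H, W/tp] shard
# evenly across dp ranks, which is itself a simplification.
H, W = 10, 8
tp_size, dp_size = 2, 4

W_tp = W // tp_size              # column-parallel shard is [H, W/tp]
flat_numel = H * W_tp            # ZeRO flattens the shard before splitting over dp
per_rank = flat_numel // dp_size

print(H % dp_size != 0)          # True  -> this parameter takes the residue path
print(per_rank % W_tp == 0)      # False -> the local chunk holds partial rows and
                                 #          cannot be viewed as [-1, W/tp] without a gather

On the divisible path the previous per-shard logic is kept unchanged in the else branch.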
@@ -333,9 +356,14 @@ def step(self, closure=None):
                     if factored:
                         if param_is_dtensor:
                             if shard_spec.sharding_sequence[0] == "R":  # Col Parallel
-                                state["exp_avg_sq_row"] = torch.zeros(
-                                    grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype
-                                )  # [H/dp]
+                                if grad_shape[0] % self.data_parallel_size != 0:
+                                    state["exp_avg_sq_row"] = torch.zeros(
+                                        grad_shape[0], device=p.device, dtype=p.dtype
+                                    )  # [H]
+                                else:
+                                    state["exp_avg_sq_row"] = torch.zeros(
+                                        grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype
+                                    )  # [H/dp]
                                 state["exp_avg_sq_col"] = torch.zeros(
                                     grad_shape[1], device=p.device, dtype=p.dtype
                                 )  # [W/TP]
@@ -344,16 +372,16 @@
                                 # Row indivisible shape situation
                                 if grad_shape[0] % self.data_parallel_size != 0:
                                     state["exp_avg_sq_row"] = torch.zeros(
-                                        grad_shape[0], device=p.device, dtype=p.dtype
-                                    )  # [H/dp/Tp]
+                                        grad_shape[0], device=p.device, dtype=p.dtype
+                                    )  # [H/tp]
                                 else:
                                     state["exp_avg_sq_row"] = torch.zeros(
-                                        grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype
-                                    )  # [H/dp/Tp]
+                                        grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype
+                                    )  # [H/dp/tp]
 
                                 state["exp_avg_sq_col"] = torch.zeros(
-                                    grad_shape[1], device=p.device, dtype=p.dtype
-                                )  # [W]
+                                    grad_shape[1], device=p.device, dtype=p.dtype
+                                )  # [W]
                             else:
                                 if self.use_zero:
                                     # param grad [H // dp]
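
The state-initialization hunks above apply the same divisibility rule when sizing the row statistic exp_avg_sq_row: keep the full (sharded) row dimension when it does not divide evenly across the data-parallel group, otherwise shard it by the dp size. A hedged sketch of that rule, using row_stat_len as a hypothetical helper name that is not part of the optimizer's API:

import torch

# Hypothetical helper mirroring the branching above: keep the full row
# dimension on the residue path, otherwise shard it across the dp group.
def row_stat_len(row_dim: int, dp_size: int) -> int:
    return row_dim if row_dim % dp_size != 0 else row_dim // dp_size

print(row_stat_len(10, 4))  # 10 -> indivisible, exp_avg_sq_row covers every row
print(row_stat_len(12, 4))  # 3  -> divisible, exp_avg_sq_row holds row_dim / dp entries

exp_avg_sq_row = torch.zeros(row_stat_len(10, 4))  # analogous to the [H] case above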
