Commit

[feature] Add transformer-bert to testcase; solve a bug related to indivisible shape (induction in use_zero and tp is row parallel)

duanjunwen committed Apr 9, 2024
1 parent d5f72fe commit 020ed54
Showing 3 changed files with 279 additions and 145 deletions.
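The core of the fix, visible in the row-parallel branch of step() below: when a row-parallel weight's row count is not divisible by the data-parallel (ZeRO) size, the flattened shards are first gathered along the data-parallel group, the factored update is computed on the full [H/tp, W] slice, and the result is split back into per-rank shards. A minimal sketch of that pattern follows; row_parallel_update and build_update are hypothetical names standing in for the factored-moment math, not the committed implementation.

# Sketch only: gather -> factor -> split for rows indivisible by the ZeRO size.
# build_update is a hypothetical stand-in for the exp_avg_sq_row/col update math in step();
# _gather/_split are the shardformer ops the commit itself imports.
from colossalai.shardformer.layer._operation import _gather, _split


def row_parallel_update(update, grad, grad_shape, dp_group, dp_size, use_zero, build_update):
    H, W = grad_shape  # local [H/tp, W] shape of this tensor-parallel shard
    if use_zero and H % dp_size != 0:
        # Indivisible rows: gather the flattened ZeRO shards so this rank sees the
        # whole [H/tp, W] slice and can compute the factored update there ...
        update_full = _gather(input_=update, dim=-1, process_group=dp_group).view(-1, W)
        grad_full = _gather(input_=grad, dim=-1, process_group=dp_group).view(-1, W)
        update_full = build_update(update_full, grad_full)
        # ... then split the flattened result back into this rank's ZeRO shard.
        return _split(input_=update_full.view(-1), dim=-1, process_group=dp_group)
    # Divisible (or no ZeRO): each rank already holds a well-formed local view.
    update_local = build_update(update.view(-1, W), grad.view(-1, W))
    return update_local.view(-1) if use_zero else update_local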
1 change: 1 addition & 0 deletions colossalai/nn/optimizer/adafactor.py
@@ -36,6 +36,7 @@ def __init__(
relative_step=True,
warmup_init=False,
):
lr=None
if lr is not None and relative_step:
raise ValueError("Cannot combine manual `lr` and `relative_step=True` options")
if warmup_init and not relative_step:
111 changes: 71 additions & 40 deletions colossalai/nn/optimizer/distributed_adafactor.py
@@ -3,16 +3,17 @@

import torch
import torch.distributed as dist
from torch.optim import Optimizer
# from torch.optim import Optimizer
from colossalai.interface.optimizer import DistributedOptim

from colossalai.shardformer.layer._operation import _gather
from colossalai.shardformer.layer._operation import _gather, _split
from colossalai.tensor.d_tensor import get_sharding_spec, is_distributed_tensor

# DistributedAdaFactor (with Tensor parallel and Zero stage 2)
__all__ = ["DistributedAdaFactor"]


class DistributedAdaFactor(Optimizer):
class DistributedAdaFactor(DistributedOptim):
def __init__(
self,
params,
@@ -26,6 +27,7 @@ def __init__(
relative_step=True,
warmup_init=False,
):
lr=None
if lr is not None and relative_step:
raise ValueError("Cannot combine manual `lr` and `relative_step=True` options")
if warmup_init and not relative_step:
@@ -42,7 +44,6 @@ def __init__(
"relative_step": relative_step,
"warmup_init": warmup_init,
}
super().__init__(params, defaults)
self.tensor_parallel_size = 1
self.tensor_parallel_group = None
self.data_parallel_size = 1
@@ -53,13 +54,14 @@ def __init__(
self.factored = None # bool
self.use_first_moment = None # bool
self.use_zero = True
self.is_dist = {}
super().__init__(params, defaults)


def setup_distributed(
self,
tensor_parallel_group: dist.ProcessGroup = None,
data_parallel_group: dist.ProcessGroup = None,
shard_to_param: Dict = None,
shard_to_param: Dict = {},
use_zero: bool = True,
) -> None:
"""
@@ -82,6 +84,7 @@ def setup_distributed(
self.use_zero = use_zero

self.shard_to_param = shard_to_param if shard_to_param is not None else {}


@staticmethod
def _get_lr(param_group, param_state):
@@ -161,7 +164,9 @@ def step(self, closure=None):
raise RuntimeError("Adafactor does not support sparse gradients.")
state = self.state[p]
self.grad_shape = grad.shape # 1 dim shape


# print(f"self.shard_to_param {self.shard_to_param}")

param_is_dtensor = is_distributed_tensor(self.shard_to_param.get(id(p)))

if param_is_dtensor:
@@ -184,9 +189,16 @@ def step(self, closure=None):
) # [W/TP]

if self.shard_spec.sharding_sequence[-1] == "R": # Row Parallel
state["exp_avg_sq_row"] = torch.zeros(
# Row Residual situation
if self.grad_shape[0] % self.data_parallel_size != 0:
state["exp_avg_sq_row"] = torch.zeros(
self.grad_shape[0], device=p.device, dtype=p.dtype
) # [H/dp/Tp]
else:
state["exp_avg_sq_row"] = torch.zeros(
self.grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype
) # [H/dp/Tp]
) # [H/dp/Tp]

state["exp_avg_sq_col"] = torch.zeros(
self.grad_shape[1], device=p.device, dtype=p.dtype
) # [W]
@@ -202,10 +214,6 @@ def step(self, closure=None):
else:
state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)

p_data_fp32 = p.float()
# if p.dtype in {torch.float16, torch.bfloat16}:
# p_data_fp32 = p_data_fp32.float()

state["step"] += 1
lr = self._get_lr(group, state)
beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
@@ -228,7 +236,6 @@ def step(self, closure=None):
exp_avg_sq_row.div_(self.tensor_parallel_size)
update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
update_reshape.mul_(grad_reshape)
# update = update_reshape.view(update_reshape.shape[0]*update_reshape.shape[1])
if self.use_zero:
update = update_reshape.view(-1)
else:
@@ -238,27 +245,54 @@ def step(self, closure=None):
# Row Parallel ---> sq_col need Do (row) Reduce
# ==============================
elif self.shard_spec.sharding_sequence[-1] == "R":
update_reshape = update.view(-1, self.grad_shape[1])
grad_reshape = grad.view(-1, self.grad_shape[1])
exp_avg_sq_row = state["exp_avg_sq_row"] # [H/dp/tp]
exp_avg_sq_col = state["exp_avg_sq_col"] # [W]
exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t))
exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t))
# reduce col
dist.all_reduce(exp_avg_sq_col, group=self.tensor_parallel_group)
exp_avg_sq_col.div_(self.tensor_parallel_size)
# gather row
exp_avg_sq_row_gather = _gather(
input_=exp_avg_sq_row, dim=-1, process_group=self.tensor_parallel_group
)
sq_row_meam = exp_avg_sq_row_gather.mean(dim=-1, keepdim=True)
update_reshape = self._approx_sq_grad_row_parallel(exp_avg_sq_row, exp_avg_sq_col, sq_row_meam)
update_reshape.mul_(grad_reshape)
# update = update_reshape.view(update_reshape.shape[0]*update_reshape.shape[1])
if self.use_zero:
update = update_reshape.view(-1)
# Row Residual situation
if self.grad_shape[0] % self.data_parallel_size != 0:
# gather update[flatten] along dp group then reshape to [H/tp, W]
update = _gather(
input_=update, dim=-1, process_group=self.data_parallel_group
)
# view update to origin[tp] shape
update_reshape = update.view(-1, self.grad_shape[1])

# gather grad[flatten] along dp group then reshape to [H/tp, W]
grad = _gather(
input_=grad, dim=-1, process_group=self.data_parallel_group
)
grad_reshape = grad.view(-1, self.grad_shape[1])
exp_avg_sq_row = state["exp_avg_sq_row"] # [H/tp]
exp_avg_sq_col = state["exp_avg_sq_col"] # [W]
exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t))
exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t))
# reduce col
dist.all_reduce(exp_avg_sq_col, group=self.tensor_parallel_group)
exp_avg_sq_col.div_(self.tensor_parallel_size)
update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
update_reshape.mul_(grad_reshape)
if self.use_zero:
update = _split(input_=update_reshape.view(-1), dim=-1, process_group=self.data_parallel_group)
else:
update = update_reshape
else:
update = update_reshape
update_reshape = update.view(-1, self.grad_shape[1])
grad_reshape = grad.view(-1, self.grad_shape[1])
exp_avg_sq_row = state["exp_avg_sq_row"] # [H/dp/tp]
exp_avg_sq_col = state["exp_avg_sq_col"] # [W]
exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t))
exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t))
# reduce col
dist.all_reduce(exp_avg_sq_col, group=self.tensor_parallel_group)
exp_avg_sq_col.div_(self.tensor_parallel_size)
# gather row
exp_avg_sq_row_gather = _gather(
input_=exp_avg_sq_row, dim=-1, process_group=self.tensor_parallel_group
)
sq_row_meam = exp_avg_sq_row_gather.mean(dim=-1, keepdim=True)
update_reshape = self._approx_sq_grad_row_parallel(exp_avg_sq_row, exp_avg_sq_col, sq_row_meam)
update_reshape.mul_(grad_reshape)
if self.use_zero:
update = update_reshape.view(-1)
else:
update = update_reshape
else:
exp_avg_sq = state["exp_avg_sq"]
exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
@@ -273,12 +307,9 @@ def step(self, closure=None):
update = exp_avg

if group["weight_decay"] != 0:
p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))

p_data_fp32.add_(-update)

p.add_(p, alpha=(-group["weight_decay"] * lr))

p.add_(-update)

if p.dtype in {torch.float16, torch.bfloat16}:
p.copy_(p_data_fp32)

return loss
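
A hedged usage sketch of wiring the optimizer up after sharding; build_optimizer, the process-group construction, and the shard_to_param mapping (id of each parameter -> its sharded tensor) are assumptions based on the signatures above, not taken from the commit's test file.

import torch
import torch.distributed as dist

from colossalai.nn.optimizer.distributed_adafactor import DistributedAdaFactor


def build_optimizer(model: torch.nn.Module,
                    tp_group: dist.ProcessGroup,
                    dp_group: dist.ProcessGroup,
                    shard_to_param: dict,
                    use_zero: bool = True) -> DistributedAdaFactor:
    # relative_step=True leaves lr as None; _get_lr() then derives the step size per update.
    optim = DistributedAdaFactor(model.parameters(), relative_step=True, warmup_init=False)
    # Hand over the process groups and the id(param) -> sharded-tensor map that
    # step() consults via self.shard_to_param.get(id(p)).
    optim.setup_distributed(
        tensor_parallel_group=tp_group,
        data_parallel_group=dp_group,
        shard_to_param=shard_to_param,
        use_zero=use_zero,
    )
    return optim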