[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
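
pre-commit.ci runs the hooks declared in the repository's .pre-commit-config.yaml against each pull request and pushes the resulting fixes back as a bot commit such as this one. The repository's actual configuration is not shown on this page; as a rough sketch, a minimal config producing the kind of auto-formatting seen below (import sorting, comment alignment, line re-wrapping) might look like the following, with repository revisions chosen purely for illustration:

# Illustrative sketch only -- not the repository's actual .pre-commit-config.yaml
repos:
  - repo: https://github.com/PyCQA/isort
    rev: 5.13.2   # assumed version, for illustration
    hooks:
      - id: isort
  - repo: https://github.com/psf/black
    rev: 24.3.0   # assumed version, for illustration
    hooks:
      - id: black

The same fixes can be reproduced locally by installing pre-commit and running "pre-commit run --all-files" before pushing.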
pre-commit-ci[bot] committed Apr 17, 2024
1 parent 87746ec commit 01d0f95
Showing 6 changed files with 151 additions and 137 deletions.
8 changes: 4 additions & 4 deletions colossalai/nn/optimizer/__init__.py
@@ -1,5 +1,7 @@
from .adafactor import Adafactor
from .came import CAME
from .cpu_adam import CPUAdam
from .distributed_adafactor import DistributedAdaFactor
from .distributed_came import DistributedCAME
from .distributed_lamb import DistributedLamb
from .fused_adam import FusedAdam
@@ -8,8 +10,6 @@
from .hybrid_adam import HybridAdam
from .lamb import Lamb
from .lars import Lars
from .adafactor import Adafactor
from .distributed_adafactor import DistributedAdaFactor

__all__ = [
"FusedLAMB",
@@ -22,6 +22,6 @@
"DistributedLamb",
"CAME",
"DistributedCAME",
"Adafactor",
"DistributedAdaFactor"
"Adafactor",
"DistributedAdaFactor",
]
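
Because the rendered hunks above omit the usual +/- markers, the net effect on colossalai/nn/optimizer/__init__.py is easier to read from the resulting code: the two Adafactor imports move from the end of the import block to their alphabetical positions, and the __all__ list gains a trailing comma. Reconstructed from the visible hunks only (lines collapsed in the diff are elided with comments), the relevant section now reads:

from .adafactor import Adafactor
from .came import CAME
from .cpu_adam import CPUAdam
from .distributed_adafactor import DistributedAdaFactor
from .distributed_came import DistributedCAME
from .distributed_lamb import DistributedLamb
from .fused_adam import FusedAdam
# ... lines collapsed in the diff ...
from .hybrid_adam import HybridAdam
from .lamb import Lamb
from .lars import Lars

__all__ = [
    "FusedLAMB",
    # ... entries collapsed in the diff ...
    "DistributedLamb",
    "CAME",
    "DistributedCAME",
    "Adafactor",
    "DistributedAdaFactor",
]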
2 changes: 1 addition & 1 deletion colossalai/nn/optimizer/adafactor.py
@@ -36,7 +36,7 @@ def __init__(
relative_step=True,
warmup_init=False,
):
lr=None
lr = None
if lr is not None and relative_step:
raise ValueError("Cannot combine manual `lr` and `relative_step=True` options")
if warmup_init and not relative_step:
120 changes: 60 additions & 60 deletions colossalai/nn/optimizer/distributed_adafactor.py
@@ -3,8 +3,8 @@

import torch
import torch.distributed as dist
from colossalai.interface.optimizer import DistributedOptim

from colossalai.interface.optimizer import DistributedOptim
from colossalai.shardformer.layer._operation import _gather, _split
from colossalai.tensor.d_tensor import get_sharding_spec, is_distributed_tensor

@@ -49,14 +49,14 @@ def __init__(
self.data_parallel_group = None
self.shard_to_param = None # Dict{id:shape}, sample {id(param): torch.tensor}
self.use_zero = True
self.param_is_dtensor_dict = {} # {id(p): True/False}
self.grad_shape_dict = {} # {id(p): master param shape}
self.factored_dict = {} # {id(p): True/False}

self.param_is_dtensor_dict = {} # {id(p): True/False}
self.grad_shape_dict = {} # {id(p): master param shape}
self.factored_dict = {} # {id(p): True/False}
self.use_first_moment_dict = {} # {id(p): True/False}
self.shard_spec_dict = {} # {id(p): ShardSpec}
super().__init__(params, defaults)

def setup_distributed(
self,
tensor_parallel_group: dist.ProcessGroup = None,
@@ -82,19 +82,21 @@ def setup_distributed(
if self.data_parallel_group is not None:
self.data_parallel_size = dist.get_world_size(self.data_parallel_group)
self.use_zero = use_zero

self.shard_to_param = shard_to_param if shard_to_param is not None else {}
# grad is None, cause we dont setup now
for group in self.param_groups:
for p in group["params"]:
self.param_is_dtensor_dict[id(p)] = is_distributed_tensor(self.shard_to_param.get(id(p)))
self.grad_shape_dict[id(p)] = self.shard_to_param.get(id(p)).shape
self.factored_dict[id(p)], self.use_first_moment_dict[id(p)] = self._get_options(group, self.grad_shape_dict[id(p)])
self.grad_shape_dict[id(p)] = self.shard_to_param.get(id(p)).shape
self.factored_dict[id(p)], self.use_first_moment_dict[id(p)] = self._get_options(
group, self.grad_shape_dict[id(p)]
)
if self.param_is_dtensor_dict[id(p)]:
self.shard_spec_dict[id(p)] = get_sharding_spec(self.shard_to_param.get(id(p)))
else:
self.shard_spec_dict[id(p)] = None

@staticmethod
def _get_lr(param_group, param_state):
rel_step_sz = param_group["lr"]
@@ -123,7 +125,7 @@ def _get_options(param_group, param_shape):
def _rms(tensor, param_is_dtensor, use_zero, tp_size, dp_size, tp_group, dp_group):
tensor_sum = tensor.pow(2).sum()
num_of_element = tensor.numel()

if param_is_dtensor:
# reduce tensor_sum from tp_group
dist.all_reduce(tensor_sum, group=tp_group)
@@ -151,25 +153,21 @@ def _approx_sq_grad_row_parallel(exp_avg_sq_row, exp_avg_sq_col, sq_row_meam):
r_factor = (exp_avg_sq_row / sq_row_meam).rsqrt_().unsqueeze(-1)
c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
return torch.mul(r_factor, c_factor)

def _col_parallel_factor(self, update, grad, state, grad_shape, beta2t):
if grad_shape[0] % self.data_parallel_size != 0:
# gather update[flatten] along dp group then reshape to [H, W/tp]
update = _gather(
input_=update, dim=-1, process_group=self.data_parallel_group
)
update = _gather(input_=update, dim=-1, process_group=self.data_parallel_group)
update_reshape = update.view(-1, grad_shape[1])
# gather grad[flatten] along dp group then reshape to [H, W/tp]
grad = _gather(
input_=grad, dim=-1, process_group=self.data_parallel_group
)
grad = _gather(input_=grad, dim=-1, process_group=self.data_parallel_group)
grad_reshape = grad.view(-1, grad_shape[1])
exp_avg_sq_row = state["exp_avg_sq_row"] # [H]
exp_avg_sq_col = state["exp_avg_sq_col"] # [W/tp]
exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t))
exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t))
update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
update_reshape.mul_(grad_reshape)
update_reshape.mul_(grad_reshape)
else:
update_reshape = update.view(-1, grad_shape[1])
grad_reshape = grad.view(-1, grad_shape[1])
@@ -181,25 +179,21 @@ def _col_parallel_factor(self, update, grad, state, grad_shape, beta2t):
exp_avg_sq_row.div_(self.tensor_parallel_size)
update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
update_reshape.mul_(grad_reshape)

if self.use_zero:
update = update_reshape.view(-1)
else:
update = update_reshape
return update

def _row_parallel_factor(self, update, grad, state, grad_shape, beta2t):
if grad_shape[0] % self.data_parallel_size != 0:
# gather update[flatten] along dp group then reshape to [H/tp, W]
update = _gather(
input_=update, dim=-1, process_group=self.data_parallel_group
)
update = _gather(input_=update, dim=-1, process_group=self.data_parallel_group)
# view update to origin[tp] shape
update_reshape = update.view(-1, grad_shape[1])
# gather grad[flatten] along dp group then reshape to [H/tp, W]
grad = _gather(
input_=grad, dim=-1, process_group=self.data_parallel_group
)
grad = _gather(input_=grad, dim=-1, process_group=self.data_parallel_group)
grad_reshape = grad.view(-1, grad_shape[1])
exp_avg_sq_row = state["exp_avg_sq_row"] # [H/tp]
exp_avg_sq_col = state["exp_avg_sq_col"] # [W]
@@ -225,9 +219,7 @@ def _row_parallel_factor(self, update, grad, state, grad_shape, beta2t):
dist.all_reduce(exp_avg_sq_col, group=self.tensor_parallel_group)
exp_avg_sq_col.div_(self.tensor_parallel_size)
# gather row
exp_avg_sq_row_gather = _gather(
input_=exp_avg_sq_row, dim=-1, process_group=self.tensor_parallel_group
)
exp_avg_sq_row_gather = _gather(input_=exp_avg_sq_row, dim=-1, process_group=self.tensor_parallel_group)
sq_row_meam = exp_avg_sq_row_gather.mean(dim=-1, keepdim=True)
update_reshape = self._approx_sq_grad_row_parallel(exp_avg_sq_row, exp_avg_sq_col, sq_row_meam)
update_reshape.mul_(grad_reshape)
@@ -236,24 +228,20 @@ def _row_parallel_factor(self, update, grad, state, grad_shape, beta2t):
else:
update = update_reshape
return update

def _base_factor(self, update, grad, state, grad_shape, beta2t):
if self.use_zero:
# only zero
if grad_shape[0] % self.data_parallel_size != 0:
# view update to origin shape update.view(grad_shape[0]//self.data_parallel_size , grad_shape[1])
# row mean no change
# view update to origin shape update.view(grad_shape[0]//self.data_parallel_size , grad_shape[1])
# row mean no change
# col mean need reduce and div
# gather update[flatten] along dp group then reshape to [H, W]
update = _gather(
input_=update, dim=-1, process_group=self.data_parallel_group
)
update = _gather(input_=update, dim=-1, process_group=self.data_parallel_group)
# view update to origin[tp] shape
update_reshape = update.view(-1, grad_shape[1])
# gather grad[flatten] along dp group then reshape to [H, W]
grad = _gather(
input_=grad, dim=-1, process_group=self.data_parallel_group
)
grad = _gather(input_=grad, dim=-1, process_group=self.data_parallel_group)
grad_reshape = grad.view(-1, grad_shape[1])
exp_avg_sq_row = state["exp_avg_sq_row"] # [H/dp]
exp_avg_sq_col = state["exp_avg_sq_col"] # [W]
@@ -268,8 +256,8 @@ def _base_factor(self, update, grad, state, grad_shape, beta2t):
else:
# no residual row
# view update to origin[tp] shape
update_reshape = update.view(-1, grad_shape[1]) # [H/dp, W]
grad_reshape = grad.view(-1, grad_shape[1]) # [H/dp, W]
update_reshape = update.view(-1, grad_shape[1]) # [H/dp, W]
grad_reshape = grad.view(-1, grad_shape[1]) # [H/dp, W]
exp_avg_sq_row = state["exp_avg_sq_row"] # [H/tp]
exp_avg_sq_col = state["exp_avg_sq_col"] # [W]
exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t))
@@ -279,7 +267,7 @@ def _base_factor(self, update, grad, state, grad_shape, beta2t):
exp_avg_sq_col.div_(self.tensor_parallel_size)
update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
update_reshape.mul_(grad_reshape)
update = update_reshape.view(-1)
update = update_reshape.view(-1)
else:
# base factor; no tp, no dp
exp_avg_sq_row = state["exp_avg_sq_row"]
@@ -292,7 +280,7 @@ def _base_factor(self, update, grad, state, grad_shape, beta2t):
update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
update.mul_(grad)
return update

@torch.no_grad()
def step(self, closure=None):
"""
@@ -327,7 +315,7 @@ def step(self, closure=None):
grad = p.grad
if grad.is_sparse:
raise RuntimeError("Adafactor does not support sparse gradients.")

state = self.state[p]
grad_shape = self.grad_shape_dict[id(p)]
param_is_dtensor = self.param_is_dtensor_dict[id(p)]
@@ -347,11 +335,11 @@ def step(self, closure=None):
if grad_shape[0] % self.data_parallel_size != 0:
state["exp_avg_sq_row"] = torch.zeros(
grad_shape[0], device=p.device, dtype=p.dtype
) # [H]
) # [H]
else:
state["exp_avg_sq_row"] = torch.zeros(
grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype
) # [H/dp]
grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype
) # [H/dp]
state["exp_avg_sq_col"] = torch.zeros(
grad_shape[1], device=p.device, dtype=p.dtype
) # [W/TP]
@@ -361,23 +349,27 @@ def step(self, closure=None):
if grad_shape[0] % self.data_parallel_size != 0:
state["exp_avg_sq_row"] = torch.zeros(
grad_shape[0], device=p.device, dtype=p.dtype
) # [H/tp]
) # [H/tp]
else:
state["exp_avg_sq_row"] = torch.zeros(
grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype
) # [H/dp/tp]
) # [H/dp/tp]

state["exp_avg_sq_col"] = torch.zeros(
grad_shape[1], device=p.device, dtype=p.dtype
) # [W]
grad_shape[1], device=p.device, dtype=p.dtype
) # [W]
else:
if self.use_zero:
if grad_shape[0] % self.data_parallel_size != 0:
# save all exp_avg_sq_row [H]
state["exp_avg_sq_row"] = torch.zeros(grad_shape[0], device=grad.device, dtype=p.dtype)
state["exp_avg_sq_row"] = torch.zeros(
grad_shape[0], device=grad.device, dtype=p.dtype
)
else:
# exp_avg_sq_row [H // dp]
state["exp_avg_sq_row"] = torch.zeros(grad_shape[0] // self.data_parallel_size, device=grad.device, dtype=p.dtype)
state["exp_avg_sq_row"] = torch.zeros(
grad_shape[0] // self.data_parallel_size, device=grad.device, dtype=p.dtype
)
else:
# exp_avg_sq_row [H]
state["exp_avg_sq_row"] = torch.zeros(grad_shape[0], device=grad.device, dtype=p.dtype)
@@ -399,7 +391,7 @@ def step(self, closure=None):
lr = self._get_lr(group, state)
beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
update = (grad**2) + group["eps"][0]

if factored:
if param_is_dtensor:
# ==============================
@@ -415,16 +407,24 @@ def step(self, closure=None):
elif shard_spec.sharding_sequence[-1] == "R":
update = self._row_parallel_factor(update, grad, state, grad_shape, beta2t)
else:
update = self._base_factor(update, grad, state, grad_shape, beta2t)
update = self._base_factor(update, grad, state, grad_shape, beta2t)
else:
exp_avg_sq = state["exp_avg_sq"]
exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
update = exp_avg_sq.rsqrt().mul_(grad)

# # (Line No.8) RMS
rms = self._rms(update, param_is_dtensor, self.use_zero,self.tensor_parallel_size, self.data_parallel_size, self.tensor_parallel_group, self.data_parallel_group)
rms = self._rms(
update,
param_is_dtensor,
self.use_zero,
self.tensor_parallel_size,
self.data_parallel_size,
self.tensor_parallel_group,
self.data_parallel_group,
)
update.div_((rms / group["clip_threshold"]).clamp_(min=1.0))

update.mul_(lr)
if use_first_moment:
exp_avg = state["exp_avg"]
@@ -433,7 +433,7 @@ def step(self, closure=None):

if group["weight_decay"] != 0:
p.add_(p, alpha=(-group["weight_decay"] * lr))

p.add_(-update)

return loss
(The remaining three changed files were not loaded in this view.)
