Commit

test distributed came passed

chongqichuizi875 committed Apr 7, 2024
1 parent 82b04cb commit f95d875
Showing 9 changed files with 705 additions and 303 deletions.
8 changes: 8 additions & 0 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -24,6 +24,7 @@
from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.interface.optimizer import DistributedOptimizer
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
@@ -1166,6 +1167,13 @@ def configure(
**self.zero_config,
**self.amp_config,
)
# Setup optimizers that require global states
if isinstance(optimizer.optim, DistributedOptimizer):
tp_group = self.__dict__.get("tp_group", None)
dp_group = self.__dict__.get("dp_group", None)
master_to_working_map = optimizer.get_master_to_working_map()
zero_flag = self.zero_stage > 0
optimizer.optim.setup_distributed(master_to_working_map, tp_group, dp_group, zero_flag=zero_flag)
# inject update_master_params
model.update_master_params = MethodType(optimizer.update_master_params, model)
return model, optimizer, criterion, dataloader, lr_scheduler
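For context, a minimal sketch, not taken from this commit, of how the branch above is reached: an optimizer whose inner optim subclasses DistributedOptimizer is passed through booster.boost, and configure() hands it the plugin's process groups. The plugin sizes, the MyDistOptimizer name (a hypothetical DistributedOptimizer subclass, see the sketch after the interface/optimizer.py diff below), and the model/criterion/dataloader objects are illustrative assumptions.

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Assumes torch.distributed is already initialized (e.g. via colossalai.launch_from_torch)
# and that model, criterion and dataloader are defined elsewhere.
plugin = HybridParallelPlugin(tp_size=2, pp_size=1, zero_stage=1)  # illustrative sizes
booster = Booster(plugin=plugin)

optimizer = MyDistOptimizer(model.parameters(), lr=2e-4)  # hypothetical DistributedOptimizer subclass
model, optimizer, criterion, dataloader, _ = booster.boost(
    model, optimizer, criterion=criterion, dataloader=dataloader
)
# configure() above then calls optimizer.optim.setup_distributed(...) with the
# plugin's tp_group / dp_group and zero_flag=True because zero_stage > 0.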
7 changes: 7 additions & 0 deletions colossalai/booster/plugin/low_level_zero_plugin.py
@@ -25,6 +25,7 @@
sharded_optimizer_loading_epilogue,
)
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.interface.optimizer import DistributedOptimizer
from colossalai.zero import LowLevelZeroOptimizer

from .dp_plugin_base import DPPluginBase
@@ -325,6 +326,12 @@ def configure(
)
# inject update_master_params
model.update_master_params = MethodType(optimizer.update_master_params, model)
# Setup optimizers that require global states
if isinstance(optimizer.optim, DistributedOptimizer):
tp_group = self.__dict__.get("tp_group", None)
dp_group = self.__dict__.get("dp_group", None)
master_to_working_map = optimizer.get_master_to_working_map()
optimizer.optim.setup_distributed(master_to_working_map, tp_group, dp_group, zero_flag=True)

return model, optimizer, criterion, dataloader, lr_scheduler

24 changes: 24 additions & 0 deletions colossalai/interface/optimizer.py
@@ -3,6 +3,7 @@
import torch
import torch.nn as nn
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.optim import Optimizer


@@ -133,3 +134,26 @@ def unwrap(self):
Unwrap the optimizer for checkpoint saving/loading.
"""
return self.optim


class DistributedOptimizer(Optimizer):
def __init__(
self,
params,
lr=None,
eps=(1e-30, 1e-16),
clip_threshold=1.0,
betas=(0.9, 0.999, 0.9999),
weight_decay=0.0,
):
defaults = dict(
lr=lr,
eps=eps,
clip_threshold=clip_threshold,
betas=betas,
weight_decay=weight_decay,
)
super(DistributedOptimizer, self).__init__(params, defaults)

def setup_distributed(self, master_to_working_map: dict, tp_group: ProcessGroup, dp_group: ProcessGroup, zero_flag: bool = False):
pass
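As a sketch only, with the subclass name and stored attributes being hypothetical rather than part of this commit, a concrete optimizer would override setup_distributed to keep the handed-over mapping and process groups for later use inside step():

import torch.distributed as dist
from torch.distributed import ProcessGroup

from colossalai.interface.optimizer import DistributedOptimizer


class MyDistOptimizer(DistributedOptimizer):
    """Hypothetical subclass illustrating the setup_distributed contract."""

    def setup_distributed(
        self,
        master_to_working_map: dict,
        tp_group: ProcessGroup,
        dp_group: ProcessGroup,
        zero_flag: bool = False,
    ):
        # Keep the master-to-working mapping and the process groups so that
        # step() can all-reduce sharded statistics across TP / DP ranks.
        self.master_to_working_map = master_to_working_map
        self.tp_group = tp_group
        self.dp_group = dp_group
        self.zero_flag = zero_flag
        self.tp_size = dist.get_world_size(group=tp_group) if tp_group is not None else 1
        self.dp_size = dist.get_world_size(group=dp_group) if dp_group is not None else 1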
140 changes: 19 additions & 121 deletions colossalai/nn/optimizer/came.py
@@ -1,11 +1,8 @@
import torch
import torch.distributed as dist
from torch.optim import Optimizer
import torch.optim

from colossalai.tensor.d_tensor import api


class CAME(Optimizer):
class CAME(torch.optim.Optimizer):
"""Implements CAME algorithm.
This implementation is based on:
`CAME: Confidence-guided Adaptive Memory Efficient Optimization`
@@ -30,25 +27,10 @@ def __init__(
clip_threshold=1.0,
betas=(0.9, 0.999, 0.9999),
weight_decay=0.0,
tp_process_group=None,
zero_process_group=None,
):
assert lr > 0.0
assert all([0.0 <= beta <= 1.0 for beta in betas])

self.tensor_parallel_group = tp_process_group
self.zero_parallel_group = zero_process_group
self.tensor_parallel_world_size = (
dist.get_world_size(group=self.tensor_parallel_group) if tp_process_group else 1
)
self.zero_parallel_world_size = dist.get_world_size(group=self.zero_parallel_group) if zero_process_group else 1
combined_parallel_ranks = []
for tp_rank in range(self.tensor_parallel_world_size):
for zero_rank in range(self.zero_parallel_world_size):
combined_parallel_ranks.append(tp_rank * self.zero_parallel_world_size + zero_rank)
combined_parallel_ranks = [dist.get_rank()] if len(combined_parallel_ranks) == 1 else combined_parallel_ranks
self.combined_parallel_group = dist.new_group(combined_parallel_ranks)

defaults = dict(
lr=lr,
eps=eps,
@@ -58,18 +40,6 @@ def __init__(
)
super(CAME, self).__init__(params, defaults)

self.clip_method = dict()
self.ori_shape = dict()
for group in self.param_groups:
for p in group["params"]:
self.ori_shape[id(p)] = p.data.shape
try:
api.get_device_mesh(p)
sharding_spec = api.get_sharding_spec(p)
self.clip_method[id(p)] = "col" if 0 in sharding_spec.dim_partition_dict.keys() else "row"
except:
self.clip_method[id(p)] = None

@property
def supports_memory_efficient_fp16(self):
return True
@@ -83,50 +53,13 @@ def _get_options(self, param_shape):
return factored

def _rms(self, tensor):
# return tensor.norm(2) / (tensor.numel() ** 0.5)
# Compute the sum of squares of the tensor on the current device
local_sum_sq = tensor.pow(2).sum()

# Aggregate the sum of squares across all devices
global_sum_sq = local_sum_sq.clone()
dist.all_reduce(global_sum_sq, op=dist.ReduceOp.SUM, group=self.combined_parallel_group)

# Aggregate the total number of elements across all devices
local_numel = torch.tensor(tensor.numel(), device=tensor.device)
global_numel = local_numel.clone()
dist.all_reduce(global_numel, op=dist.ReduceOp.SUM, group=self.combined_parallel_group)

# Compute the RMS
rms = (global_sum_sq / global_numel).sqrt()
return rms

def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col, clip_method):
exp_avg_sq_row_mean = exp_avg_sq_row.mean(dim=-1, keepdim=True)
if clip_method == "col":
group = self.combined_parallel_group
else:
group = self.zero_parallel_group
dist.all_reduce(exp_avg_sq_row_mean, op=dist.ReduceOp.SUM, group=group)
exp_avg_sq_row_mean /= dist.get_world_size(group=group)

r_factor = (exp_avg_sq_row / exp_avg_sq_row_mean).rsqrt_().unsqueeze(-1)
c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
return tensor.norm(2) / (tensor.numel() ** 0.5)

def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col):
r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
return torch.mul(r_factor, c_factor)

def _unflatten_grad_tensor_by_param(self, param):
ori_shape = self.ori_shape[id(param)]
if not (len(ori_shape) >= 2 and len(param.grad.data.shape) == 1):
return param.grad.data
remaining_dims = ori_shape[1:]
return param.grad.data.reshape(-1, *remaining_dims)

def _flatten_update_tensor_by_param(self, param, tensor):
ori_shape = self.ori_shape[id(param)]
if not (len(ori_shape) >= 2 and len(param.grad.data.shape) == 1):
return tensor
return torch.flatten(tensor)

def step(self, closure=None):
"""Performs a single optimization step.
Args:
@@ -141,14 +74,13 @@ def step(self, closure=None):
for p in group["params"]:
if p.grad is None:
continue
grad = self._unflatten_grad_tensor_by_param(p)
grad = p.grad.data
if grad.dtype in {torch.float16, torch.bfloat16}:
grad = grad.float()
if grad.is_sparse:
raise RuntimeError("CAME does not support sparse gradients.")

state = self.state[p]
# Under ZeRO, grad_shape is the original grad flattened and then sharded (one-dimensional only)
grad_shape = grad.shape

factored = self._get_options(grad_shape)
@@ -176,79 +108,45 @@
exp_avg_sq_row = state["exp_avg_sq_row"]
exp_avg_sq_col = state["exp_avg_sq_col"]

# Local mean
sq_mean_row = update.mean(dim=-1)
sq_mean_col = update.mean(dim=-2)
if self.tensor_parallel_world_size > 1:
# Global synchronization
if self.clip_method[id(p)] == "row":
dist.all_reduce(sq_mean_row, op=dist.ReduceOp.SUM, group=self.tensor_parallel_group)
sq_mean_row /= dist.get_world_size(group=self.tensor_parallel_group)
elif self.clip_method[id(p)] == "col":
dist.all_reduce(sq_mean_col, op=dist.ReduceOp.SUM, group=self.tensor_parallel_group)
sq_mean_col /= dist.get_world_size(group=self.tensor_parallel_group)
else:
pass
if self.zero_parallel_world_size > 1:
dist.all_reduce(sq_mean_col, op=dist.ReduceOp.SUM, group=self.zero_parallel_group)
sq_mean_col /= dist.get_world_size(group=self.zero_parallel_group)

# The resulting exp_avg is a shard of the full exp_avg
exp_avg_sq_row.mul_(group["betas"][1]).add_(sq_mean_row, alpha=1.0 - group["betas"][1])
exp_avg_sq_col.mul_(group["betas"][1]).add_(sq_mean_col, alpha=1.0 - group["betas"][1])
exp_avg_sq_row.mul_(group["betas"][1]).add_(update.mean(dim=-1), alpha=1.0 - group["betas"][1])
exp_avg_sq_col.mul_(group["betas"][1]).add_(update.mean(dim=-2), alpha=1.0 - group["betas"][1])

# Approximation of exponential moving average of square of gradient
update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col, clip_method=self.clip_method[id(p)])
update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
update.mul_(grad)
else:
# Bias parameters fall through to this branch
exp_avg_sq = state["exp_avg_sq"]

exp_avg_sq.mul_(group["betas"][1]).add_(update, alpha=1.0 - group["betas"][1])
update = exp_avg_sq.rsqrt().mul_(grad)
# if dist.get_rank() == 0:
# print("came: ", torch.sum(grad), grad)

# update is likewise a shard of the full update
update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))

exp_avg = state["exp_avg"]
exp_avg.mul_(group["betas"][0]).add_(update, alpha=1 - group["betas"][0])

# Confidence-guided strategy
# Calculation of instability
res = (update - exp_avg) ** 2 + group["eps"][0]
res = (update - exp_avg) ** 2 + group["eps"][1]

if factored:
exp_avg_res_row = state["exp_avg_res_row"]
exp_avg_res_col = state["exp_avg_res_col"]

res_mean_row = res.mean(dim=-1)
res_mean_col = res.mean(dim=-2)
if self.tensor_parallel_world_size > 1:
if self.clip_method[id(p)] == "row":
dist.all_reduce(res_mean_row, op=dist.ReduceOp.SUM, group=self.tensor_parallel_group)
res_mean_row /= dist.get_world_size(group=self.tensor_parallel_group)
elif self.clip_method[id(p)] == "col":
dist.all_reduce(res_mean_col, op=dist.ReduceOp.SUM, group=self.tensor_parallel_group)
res_mean_col /= dist.get_world_size(group=self.tensor_parallel_group)
else:
pass
if self.zero_parallel_world_size > 1:
dist.all_reduce(res_mean_col, op=dist.ReduceOp.SUM, group=self.zero_parallel_group)
res_mean_col /= dist.get_world_size(group=self.zero_parallel_group)

exp_avg_res_row.mul_(group["betas"][2]).add_(res_mean_row, alpha=1.0 - group["betas"][2])
exp_avg_res_col.mul_(group["betas"][2]).add_(res_mean_col, alpha=1.0 - group["betas"][2])
exp_avg_res_row.mul_(group["betas"][2]).add_(res.mean(dim=-1), alpha=1.0 - group["betas"][2])
exp_avg_res_col.mul_(group["betas"][2]).add_(res.mean(dim=-2), alpha=1.0 - group["betas"][2])

# Approximation of exponential moving average of instability
res_approx = self._approx_sq_grad(
exp_avg_res_row, exp_avg_res_col, clip_method=self.clip_method[id(p)]
)
res_approx = self._approx_sq_grad(exp_avg_res_row, exp_avg_res_col)
update = res_approx.mul_(exp_avg)
else:
update = exp_avg.clone()

if group["weight_decay"] != 0:
p.data.add_(p.data, alpha=-group["weight_decay"] * group["lr"])

update.mul_(group["lr"])
update = self._flatten_update_tensor_by_param(p, update)
p.data.add_(-update)

return loss
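For reference, a minimal single-process usage sketch of the CAME class defined above; the toy model and data are placeholders and not part of this commit, while the hyperparameters match the defaults shown in __init__ above.

import torch
import torch.nn as nn
import torch.nn.functional as F

from colossalai.nn.optimizer.came import CAME

model = nn.Linear(64, 8)
optimizer = CAME(model.parameters(), lr=2e-4, betas=(0.9, 0.999, 0.9999), weight_decay=0.0)

x = torch.randn(16, 64)
target = torch.randn(16, 8)
loss = F.mse_loss(model(x), target)
loss.backward()
optimizer.step()       # factored second moment + confidence-guided residual update
optimizer.zero_grad()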