[optim] add distributed came #5526

Merged · 19 commits · Apr 15, 2024
Changes from 8 commits
10 changes: 10 additions & 0 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -24,6 +24,7 @@
from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.interface.optimizer import DistributedOptimizer
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
@@ -1166,6 +1167,15 @@ def configure(
                    **self.zero_config,
                    **self.amp_config,
                )
            # Setup optimizers that require global states
            if isinstance(optimizer.optim, DistributedOptimizer):
                self.tp_group = self.__dict__.get("tp_group", None)
                self.dp_group = self.__dict__.get("dp_group", None)
                master_to_working_map = optimizer.get_master_to_working_map()
                zero_flag = self.zero_stage > 0
                optimizer.optim.setup_distributed(
                    master_to_working_map, self.tp_group, self.dp_group, zero_flag=zero_flag
                )
            # inject update_master_params
            model.update_master_params = MethodType(optimizer.update_master_params, model)
        return model, optimizer, criterion, dataloader, lr_scheduler
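For orientation, here is a minimal sketch of how this hook is reached through the booster API. `DistributedCAME` refers to the distributed optimizer this PR ultimately adds (it is not part of these 8 commits), and the plugin sizes, model, and learning rate are illustrative only:

```python
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
# Name/path as added later in this PR; not present in these 8 commits.
from colossalai.nn.optimizer.distributed_came import DistributedCAME

# e.g. launched with `torchrun --nproc_per_node 4`; with tp_size=2, pp_size=1 the dp size is 2.
colossalai.launch_from_torch(config={})  # launch signature varies across ColossalAI versions

model = torch.nn.Linear(1024, 1024).cuda()
optimizer = DistributedCAME(model.parameters(), lr=2e-4)

plugin = HybridParallelPlugin(tp_size=2, pp_size=1, zero_stage=1)
booster = Booster(plugin=plugin)

# configure() above detects the DistributedOptimizer instance and calls
# setup_distributed(master_to_working_map, tp_group, dp_group, zero_flag=True)
# (zero_flag is True because zero_stage > 0) before handing back the wrapped optimizer.
model, optimizer, *_ = booster.boost(model, optimizer)
```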
7 changes: 7 additions & 0 deletions colossalai/booster/plugin/low_level_zero_plugin.py
@@ -25,6 +25,7 @@
    sharded_optimizer_loading_epilogue,
)
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.interface.optimizer import DistributedOptimizer
from colossalai.zero import LowLevelZeroOptimizer

from .dp_plugin_base import DPPluginBase
@@ -325,6 +326,12 @@ def configure(
            )
            # inject update_master_params
            model.update_master_params = MethodType(optimizer.update_master_params, model)
            # Setup optimizers that require global states
            if isinstance(optimizer.optim, DistributedOptimizer):
                tp_group = self.__dict__.get("tp_group", None)
                dp_group = self.__dict__.get("dp_group", None)
                master_to_working_map = optimizer.get_master_to_working_map()
                optimizer.optim.setup_distributed(master_to_working_map, tp_group, dp_group, zero_flag=True)

        return model, optimizer, criterion, dataloader, lr_scheduler

8 changes: 8 additions & 0 deletions colossalai/interface/optimizer.py
@@ -3,6 +3,7 @@
import torch
import torch.nn as nn
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.optim import Optimizer


@@ -133,3 +134,10 @@ def unwrap(self):
        Unwrap the optimizer for checkpoint saving/loading.
        """
        return self.optim


class DistributedOptimizer(Optimizer):
    def setup_distributed(
        self, master_to_working_map: dict, tp_group: ProcessGroup, dp_group: ProcessGroup, zero_flag: bool = False
    ):
        pass
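A minimal sketch of how a subclass might satisfy this hook, mirroring the keyword-argument call sites in the two plugins above. The class name and the stored attributes are illustrative and not part of this diff; `CAME` is the optimizer added in `colossalai/nn/optimizer/came.py` below:

```python
from torch.distributed import ProcessGroup

from colossalai.interface.optimizer import DistributedOptimizer
from colossalai.nn.optimizer.came import CAME  # added below in this PR


class _TensorParallelAwareCAME(CAME, DistributedOptimizer):
    """Illustrative only: records the groups handed over by the plugins so that
    step() could later reduce its factored statistics across ranks."""

    def setup_distributed(self, master_to_working_map, tp_group, dp_group, zero_flag=False):
        # The keyword name mirrors the plugin call sites:
        # optimizer.optim.setup_distributed(..., zero_flag=...)
        self.master_to_working_map = master_to_working_map
        self.tp_group: ProcessGroup = tp_group
        self.dp_group: ProcessGroup = dp_group
        self.zero_flag = zero_flag
```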
156 changes: 156 additions & 0 deletions colossalai/nn/optimizer/came.py
@@ -0,0 +1,156 @@
import torch
import torch.optim


class CAME(torch.optim.Optimizer):
    """Implements the CAME algorithm.
    This implementation is based on:
    `CAME: Confidence-guided Adaptive Memory Efficient Optimization`
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): external learning rate (default: None)
        eps (tuple[float, float]): regularization constants for square gradient
            and instability respectively (default: (1e-30, 1e-16))
        clip_threshold (float): threshold of root-mean-square of
            final gradient update (default: 1.0)
        betas (tuple[float, float, float]): coefficients used for computing running averages of
            update, square gradient and instability (default: (0.9, 0.999, 0.9999))
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
    """

    def __init__(
        self,
        params,
        lr=None,
        eps=(1e-30, 1e-16),
        clip_threshold=1.0,
        betas=(0.9, 0.999, 0.9999),
        weight_decay=0.0,
    ):
        assert lr > 0.0
        assert all([0.0 <= beta <= 1.0 for beta in betas])

        defaults = dict(
            lr=lr,
            eps=eps,
            clip_threshold=clip_threshold,
            betas=betas,
            weight_decay=weight_decay,
        )
        super(CAME, self).__init__(params, defaults)

    @property
    def supports_memory_efficient_fp16(self):
        return True

    @property
    def supports_flat_params(self):
        return False

    def _get_options(self, param_shape):
        factored = len(param_shape) >= 2
        return factored

    def _rms(self, tensor):
        return tensor.norm(2) / (tensor.numel() ** 0.5)

    def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col):
        r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
        c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
        return torch.mul(r_factor, c_factor)

    def step(self, closure=None):
        """Performs a single optimization step.
        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                # if grad.dtype in {torch.float16, torch.bfloat16}:
                #     grad = grad.float()
                if grad.is_sparse:
                    raise RuntimeError("CAME does not support sparse gradients.")

                state = self.state[p]
                grad_shape = grad.shape

                factored = self._get_options(grad_shape)
                # State Initialization
                if len(state) == 0:
                    state["step"] = 0

                    state["exp_avg"] = torch.zeros_like(grad)
                    if factored:
                        state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1], dtype=p.dtype, device=p.device)
                        state["exp_avg_sq_col"] = torch.zeros(
                            grad_shape[:-2] + grad_shape[-1:], dtype=p.dtype, device=p.device
                        ).type_as(grad)

                        state["exp_avg_res_row"] = torch.zeros(grad_shape[:-1], dtype=p.dtype, device=p.device)
                        state["exp_avg_res_col"] = torch.zeros(
                            grad_shape[:-2] + grad_shape[-1:], dtype=p.dtype, device=p.device
                        )
                    else:
                        state["exp_avg_sq"] = torch.zeros_like(p)

                    state["RMS"] = 0

                state["step"] += 1
                state["RMS"] = self._rms(p.data)

                update = (grad**2) + group["eps"][0]
                if factored:
                    exp_avg_sq_row = state["exp_avg_sq_row"]
                    exp_avg_sq_col = state["exp_avg_sq_col"]

                    exp_avg_sq_row.mul_(group["betas"][1]).add_(update.mean(dim=-1), alpha=1.0 - group["betas"][1])
                    exp_avg_sq_col.mul_(group["betas"][1]).add_(update.mean(dim=-2), alpha=1.0 - group["betas"][1])

                    # Approximation of exponential moving average of square of gradient
                    update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
                    update.mul_(grad)
                else:
                    exp_avg_sq = state["exp_avg_sq"]

                    exp_avg_sq.mul_(group["betas"][1]).add_(update, alpha=1.0 - group["betas"][1])
                    update = exp_avg_sq.rsqrt().mul_(grad)
                # if dist.get_rank() == 0:
                #     print("came: ", torch.sum(grad), grad)

                update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))

                exp_avg = state["exp_avg"]
                exp_avg.mul_(group["betas"][0]).add_(update, alpha=1 - group["betas"][0])

                # Confidence-guided strategy
                # Calculation of instability
                res = (update - exp_avg) ** 2 + group["eps"][1]

                if factored:
                    exp_avg_res_row = state["exp_avg_res_row"]
                    exp_avg_res_col = state["exp_avg_res_col"]

                    exp_avg_res_row.mul_(group["betas"][2]).add_(res.mean(dim=-1), alpha=1.0 - group["betas"][2])
                    exp_avg_res_col.mul_(group["betas"][2]).add_(res.mean(dim=-2), alpha=1.0 - group["betas"][2])

                    # Approximation of exponential moving average of instability
                    res_approx = self._approx_sq_grad(exp_avg_res_row, exp_avg_res_col)
                    update = res_approx.mul_(exp_avg)
                else:
                    update = exp_avg.clone()

                if group["weight_decay"] != 0:
                    p.data.add_(p.data, alpha=-group["weight_decay"] * group["lr"])
                update.mul_(group["lr"])
                p.data.add_(-update)

        return loss
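A minimal single-device usage sketch of the CAME class added above; the toy model, data, and hyperparameters are illustrative only. The 2-D weight matrix takes the factored branch, which stores only row- and column-wise second-moment statistics, while 1-D parameters such as the bias fall back to a full `exp_avg_sq`:

```python
import torch
import torch.nn.functional as F

from colossalai.nn.optimizer.came import CAME

# Toy regression problem; lr/betas/weight_decay values are illustrative only.
model = torch.nn.Linear(64, 8)
optimizer = CAME(model.parameters(), lr=2e-4, betas=(0.9, 0.999, 0.9999), weight_decay=0.01)

x, y = torch.randn(32, 64), torch.randn(32, 8)
for _ in range(10):
    optimizer.zero_grad()
    loss = F.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()  # factored statistics for the 8x64 weight, dense for the bias
```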