From 375e1b944197835eacf1a8db4d6a85ca8e98d0c3 Mon Sep 17 00:00:00 2001 From: duanjunwen <54985467+duanjunwen@users.noreply.github.com> Date: Tue, 30 Apr 2024 10:42:19 +0800 Subject: [PATCH] [optim] Distributed Adafactor (#5484) * [feature] solve conflict; update optimizer readme; * [feature] update optimize readme; * [fix] fix testcase; * [feature] Add transformer-bert to testcase;solve a bug related to indivisible shape (induction in use_zero and tp is row parallel); * [feature] Add transformers_bert model zoo in testcase; * [feature] add user documentation to docs/source/feature. * [feature] add API Reference & Sample to optimizer Readme; add state check for bert exam; * [feature] modify user documentation; * [fix] fix readme format issue; * [fix] add zero=0 in testcase; cached augment in dict; * [fix] fix percision issue; * [feature] add distributed rms; * [feature] remove useless comment in testcase; * [fix] Remove useless test; open zero test; remove fp16 test in bert exam; * [feature] Extract distributed rms function; * [feature] add booster + lowlevelzeroPlugin in test; * [feature] add Start_with_booster_API case in md; add Supporting Information in md; * [fix] Also remove state movement in base adafactor; * [feature] extract factor function; * [feature] add LowLevelZeroPlugin test; * [fix] add tp=False and zero=True in logic; * [fix] fix use zero logic; * [feature] add row residue logic in column parallel factor; * [feature] add check optim state func; * [feature] Remove duplicate logic; * [feature] update optim state check func and percision test bug; * [fix] update/fix optim state; Still exist percision issue; * [fix] Add use_zero check in _rms; Add plugin support info in Readme; Add Dist Adafactor init Info; * [feature] removed print & comments in utils; * [feature] uodate Readme; * [feature] add LowLevelZeroPlugin test with Bert model zoo; * [fix] fix logic in _rms; * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [fix] remove comments in testcase; * [feature] add zh-Han Readme; --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../booster/plugin/low_level_zero_plugin.py | 5 +- colossalai/nn/optimizer/__init__.py | 4 + colossalai/nn/optimizer/adafactor.py | 201 +++++ .../nn/optimizer/distributed_adafactor.py | 435 +++++++++++ .../en/features/distributed_adafactor.md | 156 ++++ .../zh-Hans/features/distributed_adafactor.md | 155 ++++ tests/test_optimizer/_utils.py | 118 +++ tests/test_optimizer/test_dist_adafactor.py | 697 ++++++++++++++++++ tests/test_shardformer/test_model/_utils.py | 67 +- 9 files changed, 1835 insertions(+), 3 deletions(-) create mode 100644 colossalai/nn/optimizer/adafactor.py create mode 100644 colossalai/nn/optimizer/distributed_adafactor.py create mode 100644 docs/source/en/features/distributed_adafactor.md create mode 100644 docs/source/zh-Hans/features/distributed_adafactor.md create mode 100644 tests/test_optimizer/test_dist_adafactor.py diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index 8fc390414484..19faf80b0e81 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -6,6 +6,7 @@ from typing import Callable, Iterator, List, Optional, Tuple import torch +import torch.distributed import torch.nn as nn from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as 
LRScheduler @@ -328,8 +329,8 @@ def configure( model.update_master_params = MethodType(optimizer.update_master_params, model) # Setup optimizers that require global states if isinstance(optimizer.optim, DistributedOptim): - tp_group = self.tp_group - dp_group = self.dp_group + tp_group = None + dp_group = torch.distributed.distributed_c10d._get_default_group() shard_to_param = optimizer.get_master_to_working_map() is_zero = True optimizer.optim.setup_distributed(tp_group, dp_group, shard_to_param, is_zero) diff --git a/colossalai/nn/optimizer/__init__.py b/colossalai/nn/optimizer/__init__.py index 2ead709bff13..155051f04516 100644 --- a/colossalai/nn/optimizer/__init__.py +++ b/colossalai/nn/optimizer/__init__.py @@ -1,5 +1,7 @@ +from .adafactor import Adafactor from .came import CAME from .cpu_adam import CPUAdam +from .distributed_adafactor import DistributedAdaFactor from .distributed_came import DistributedCAME from .distributed_lamb import DistributedLamb from .fused_adam import FusedAdam @@ -20,4 +22,6 @@ "DistributedLamb", "CAME", "DistributedCAME", + "Adafactor", + "DistributedAdaFactor", ] diff --git a/colossalai/nn/optimizer/adafactor.py b/colossalai/nn/optimizer/adafactor.py new file mode 100644 index 000000000000..22a6c8f4d3ce --- /dev/null +++ b/colossalai/nn/optimizer/adafactor.py @@ -0,0 +1,201 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
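# Adafactor keeps a factored estimate of the squared-gradient statistics: for a 2-D
# weight of shape [H, W] it stores a row statistic `exp_avg_sq_row` of shape [H] and a
# column statistic `exp_avg_sq_col` of shape [W] instead of a full [H, W] tensor, and
# reconstructs the per-element scaling from their outer product normalized by the row
# mean. 1-D parameters such as biases fall back to an unfactored `exp_avg_sq`.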
+ +import math + +import torch +from torch.optim import Optimizer + +__all__ = ["Adafactor"] + + +# Adafactor +class Adafactor(Optimizer): + def __init__( + self, + params, + lr=None, + eps=(1e-30, 1e-3), + clip_threshold=1.0, + decay_rate=-0.8, + beta1=None, + weight_decay=0.0, + scale_parameter=True, + relative_step=True, + warmup_init=False, + ): + lr = None + if lr is not None and relative_step: + raise ValueError("Cannot combine manual `lr` and `relative_step=True` options") + if warmup_init and not relative_step: + raise ValueError("`warmup_init=True` requires `relative_step=True`") + + defaults = { + "lr": lr, + "eps": eps, + "clip_threshold": clip_threshold, + "decay_rate": decay_rate, + "beta1": beta1, + "weight_decay": weight_decay, + "scale_parameter": scale_parameter, + "relative_step": relative_step, + "warmup_init": warmup_init, + } + super().__init__(params, defaults) + + @staticmethod + def _get_lr(param_group, param_state): + rel_step_sz = param_group["lr"] + if param_group["relative_step"]: + min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2 + rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"])) + param_scale = 1.0 + if param_group["scale_parameter"]: + param_scale = max(param_group["eps"][1], param_state["RMS"]) + return param_scale * rel_step_sz + + @staticmethod + def _get_options(param_group, param_shape): + factored = len(param_shape) >= 2 + use_first_moment = param_group["beta1"] is not None + return factored, use_first_moment + + @staticmethod + def _rms(tensor): + return tensor.norm(2) / (tensor.numel() ** 0.5) + + @staticmethod + def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col): + r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1) + c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() + return torch.mul(r_factor, c_factor) + + @torch.no_grad() + def step(self, closure=None): + """ + Performs a single optimization step + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
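        Note:
            With ``relative_step=True`` (the default) the learning rate is derived from the
            current step count in ``_get_lr`` and, when ``scale_parameter=True``, scaled by
            the parameter's RMS.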
+ """ + loss = None + if closure is not None: + loss = closure() + + """ + param_groups: Dict + { + "params":[weight, bias] + "lr" + "eps" + "clip_threshold" + "decay_rate" + "beta1" + "weight_decay" + "scale_parameter" + "relative_step" + "warmup_init" + } + """ + + for group in self.param_groups: + # update weight & bias + for p in group["params"]: + if p.grad is None: + continue + """ + # grad shape is same as weigh / bias + """ + grad = p.grad + if grad.is_sparse: + raise RuntimeError("Adafactor does not support sparse gradients.") + + """ + p is weight + state + {'step', + 'exp_avg_sq_row', + 'exp_avg_sq_col', + 'RMS' + } + + p is bias + state + {'step', + 'exp_avg_sq', + 'RMS' + } + """ + + state = self.state[p] + grad_shape = grad.shape + + factored, use_first_moment = self._get_options(group, grad_shape) + # State Initialization + if len(state) == 0: + state["step"] = 0 + if use_first_moment: + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(grad) + if factored: + state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1], device=grad.device) + state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:], device=grad.device) + else: + state["exp_avg_sq"] = torch.zeros_like(grad) + + state["RMS"] = 0 + else: + if use_first_moment: + state["exp_avg"] = state["exp_avg"] + if factored: + state["exp_avg_sq_row"] = state["exp_avg_sq_row"] + state["exp_avg_sq_col"] = state["exp_avg_sq_col"] + else: + state["exp_avg_sq"] = state["exp_avg_sq"] + + state["step"] += 1 + # state["RMS"] = self._rms(p_data_fp32) + lr = self._get_lr(group, state) + beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) + update = (grad**2) + group["eps"][0] + if factored: + exp_avg_sq_row = state["exp_avg_sq_row"] + exp_avg_sq_col = state["exp_avg_sq_col"] + # Exponential average of row indexes + exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t)) + # Exponential average of columns indexes + exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t)) + # Approximation of exponential moving average of square of gradient + update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update.mul_(grad) + else: + exp_avg_sq = state["exp_avg_sq"] + exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t)) + update = exp_avg_sq.rsqrt().mul_(grad) + # RMS + update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0)) + update.mul_(lr) + + if use_first_moment: + exp_avg = state["exp_avg"] + exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"])) + update = exp_avg + + if group["weight_decay"] != 0: + p.add_(p, alpha=(-group["weight_decay"] * lr)) + p.add_(-update) + + return loss diff --git a/colossalai/nn/optimizer/distributed_adafactor.py b/colossalai/nn/optimizer/distributed_adafactor.py new file mode 100644 index 000000000000..d0794f450d8a --- /dev/null +++ b/colossalai/nn/optimizer/distributed_adafactor.py @@ -0,0 +1,435 @@ +import math +from typing import Dict + +import torch +import torch.distributed as dist + +from colossalai.interface.optimizer import DistributedOptim +from colossalai.shardformer.layer._operation import _gather, _split +from colossalai.tensor.d_tensor import get_sharding_spec, is_distributed_tensor + +# DistributedAdaFactor (with Tensor parallel and Zero stage 2) +__all__ = ["DistributedAdaFactor"] + + +class DistributedAdaFactor(DistributedOptim): + def __init__( + self, + params, + lr=None, + eps=(1e-30, 1e-3), + clip_threshold=1.0, + decay_rate=-0.8, + beta1=None, + 
weight_decay=0.0, + scale_parameter=True, + relative_step=True, + warmup_init=False, + ): + lr = None + if lr is not None and relative_step: + raise ValueError("Cannot combine manual `lr` and `relative_step=True` options") + if warmup_init and not relative_step: + raise ValueError("`warmup_init=True` requires `relative_step=True`") + + defaults = { + "lr": lr, + "eps": eps, + "clip_threshold": clip_threshold, + "decay_rate": decay_rate, + "beta1": beta1, + "weight_decay": weight_decay, + "scale_parameter": scale_parameter, + "relative_step": relative_step, + "warmup_init": warmup_init, + } + self.tensor_parallel_size = 1 + self.tensor_parallel_group = None + self.data_parallel_size = 1 + self.data_parallel_group = None + self.shard_to_param = None # Dict{id:shape}, sample {id(param): torch.tensor} + self.use_zero = True + + self.param_is_dtensor_dict = {} # {id(p): True/False} + self.grad_shape_dict = {} # {id(p): master param shape} + self.factored_dict = {} # {id(p): True/False} + self.use_first_moment_dict = {} # {id(p): True/False} + self.shard_spec_dict = {} # {id(p): ShardSpec} + super().__init__(params, defaults) + + def setup_distributed( + self, + tensor_parallel_group: dist.ProcessGroup = None, + data_parallel_group: dist.ProcessGroup = None, + shard_to_param: Dict = {}, + use_zero: bool = True, + ) -> None: + """ + Inject features to the Optimizer + + Args: + tensor_parallel_group: The devices group for tensor parallel; + data_parallel_group: The devices group for data parallel; + sharding_spec_dict: ShardingSpecs of Each params; + param_shape: Paramater Shape of Each params; + use_zero: Whether or not to use zero; + + """ + self.tensor_parallel_group = tensor_parallel_group # "Expected row process group" + self.data_parallel_group = data_parallel_group + if self.tensor_parallel_group is not None: + self.tensor_parallel_size = dist.get_world_size(self.tensor_parallel_group) + if self.data_parallel_group is not None: + self.data_parallel_size = dist.get_world_size(self.data_parallel_group) + self.use_zero = use_zero + + self.shard_to_param = shard_to_param if shard_to_param is not None else {} + # grad is None, cause we dont setup now + for group in self.param_groups: + for p in group["params"]: + self.param_is_dtensor_dict[id(p)] = is_distributed_tensor(self.shard_to_param.get(id(p))) + self.grad_shape_dict[id(p)] = self.shard_to_param.get(id(p)).shape + self.factored_dict[id(p)], self.use_first_moment_dict[id(p)] = self._get_options( + group, self.grad_shape_dict[id(p)] + ) + if self.param_is_dtensor_dict[id(p)]: + self.shard_spec_dict[id(p)] = get_sharding_spec(self.shard_to_param.get(id(p))) + else: + self.shard_spec_dict[id(p)] = None + + @staticmethod + def _get_lr(param_group, param_state): + rel_step_sz = param_group["lr"] + if param_group["relative_step"]: + min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2 + rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"])) + param_scale = 1.0 + if param_group["scale_parameter"]: + param_scale = max(param_group["eps"][1], param_state["RMS"]) + return param_scale * rel_step_sz + + @staticmethod + def _get_options(param_group, param_shape): + """ + Determines whether the current param is factored + Args: + param_group : param group + param_shape : Original Shape of param + + """ + factored = len(param_shape) >= 2 + use_first_moment = param_group["beta1"] is not None + return factored, use_first_moment + + @staticmethod + def _rms(tensor, param_is_dtensor, use_zero, tp_size, dp_size, tp_group, 
dp_group): + tensor_sum = tensor.pow(2).sum() + num_of_element = tensor.numel() + + if param_is_dtensor: + # reduce tensor_sum from tp_group + dist.all_reduce(tensor_sum, group=tp_group) + num_of_element = num_of_element * tp_size + if use_zero: + dist.all_reduce(tensor_sum, group=dp_group) + num_of_element = num_of_element * dp_size + rms = (tensor_sum / num_of_element).sqrt() + return rms + + @staticmethod + def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col): + r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1) + c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() + return torch.mul(r_factor, c_factor) + + # approx_sq_grad for row parallel weight + @staticmethod + def _approx_sq_grad_row_parallel(exp_avg_sq_row, exp_avg_sq_col, sq_row_meam): + # row_meam = sq_row_meam + r_factor = (exp_avg_sq_row / sq_row_meam).rsqrt_().unsqueeze(-1) + c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() + return torch.mul(r_factor, c_factor) + + def _col_parallel_factor(self, update, grad, state, grad_shape, beta2t): + if grad_shape[0] % self.data_parallel_size != 0: + # gather update[flatten] along dp group then reshape to [H, W/tp] + update = _gather(input_=update, dim=-1, process_group=self.data_parallel_group) + update_reshape = update.view(-1, grad_shape[1]) + # gather grad[flatten] along dp group then reshape to [H, W/tp] + grad = _gather(input_=grad, dim=-1, process_group=self.data_parallel_group) + grad_reshape = grad.view(-1, grad_shape[1]) + exp_avg_sq_row = state["exp_avg_sq_row"] # [H] + exp_avg_sq_col = state["exp_avg_sq_col"] # [W/tp] + exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t)) + update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update_reshape.mul_(grad_reshape) + else: + update_reshape = update.view(-1, grad_shape[1]) + grad_reshape = grad.view(-1, grad_shape[1]) + exp_avg_sq_row = state["exp_avg_sq_row"] # [H/dp] + exp_avg_sq_col = state["exp_avg_sq_col"] # [W/tp] + exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t)) + dist.all_reduce(exp_avg_sq_row, group=self.tensor_parallel_group) + exp_avg_sq_row.div_(self.tensor_parallel_size) + update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update_reshape.mul_(grad_reshape) + + if self.use_zero: + update = update_reshape.view(-1) + else: + update = update_reshape + return update + + def _row_parallel_factor(self, update, grad, state, grad_shape, beta2t): + if grad_shape[0] % self.data_parallel_size != 0: + # gather update[flatten] along dp group then reshape to [H/tp, W] + update = _gather(input_=update, dim=-1, process_group=self.data_parallel_group) + # view update to origin[tp] shape + update_reshape = update.view(-1, grad_shape[1]) + # gather grad[flatten] along dp group then reshape to [H/tp, W] + grad = _gather(input_=grad, dim=-1, process_group=self.data_parallel_group) + grad_reshape = grad.view(-1, grad_shape[1]) + exp_avg_sq_row = state["exp_avg_sq_row"] # [H/tp] + exp_avg_sq_col = state["exp_avg_sq_col"] # [W] + exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t)) + # reduce col + dist.all_reduce(exp_avg_sq_col, group=self.tensor_parallel_group) + exp_avg_sq_col.div_(self.tensor_parallel_size) + 
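            # Each tensor-parallel rank holds only H/tp rows of this weight, so its local
            # column statistic is averaged over a partial height; averaging exp_avg_sq_col
            # across the tensor-parallel group approximates the column mean over the full
            # height before the factored update is formed.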
update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update_reshape.mul_(grad_reshape) + if self.use_zero: + update = _split(input_=update_reshape.view(-1), dim=-1, process_group=self.data_parallel_group) + else: + update = update_reshape + else: + update_reshape = update.view(-1, grad_shape[1]) + grad_reshape = grad.view(-1, grad_shape[1]) + exp_avg_sq_row = state["exp_avg_sq_row"] # [H/dp/tp] + exp_avg_sq_col = state["exp_avg_sq_col"] # [W] + exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t)) + # reduce col + dist.all_reduce(exp_avg_sq_col, group=self.tensor_parallel_group) + exp_avg_sq_col.div_(self.tensor_parallel_size) + # gather row + exp_avg_sq_row_gather = _gather(input_=exp_avg_sq_row, dim=-1, process_group=self.tensor_parallel_group) + sq_row_meam = exp_avg_sq_row_gather.mean(dim=-1, keepdim=True) + update_reshape = self._approx_sq_grad_row_parallel(exp_avg_sq_row, exp_avg_sq_col, sq_row_meam) + update_reshape.mul_(grad_reshape) + if self.use_zero: + update = update_reshape.view(-1) + else: + update = update_reshape + return update + + def _base_factor(self, update, grad, state, grad_shape, beta2t): + if self.use_zero: + # only zero + if grad_shape[0] % self.data_parallel_size != 0: + # view update to origin shape update.view(grad_shape[0]//self.data_parallel_size , grad_shape[1]) + # row mean no change + # col mean need reduce and div + # gather update[flatten] along dp group then reshape to [H, W] + update = _gather(input_=update, dim=-1, process_group=self.data_parallel_group) + # view update to origin[tp] shape + update_reshape = update.view(-1, grad_shape[1]) + # gather grad[flatten] along dp group then reshape to [H, W] + grad = _gather(input_=grad, dim=-1, process_group=self.data_parallel_group) + grad_reshape = grad.view(-1, grad_shape[1]) + exp_avg_sq_row = state["exp_avg_sq_row"] # [H/dp] + exp_avg_sq_col = state["exp_avg_sq_col"] # [W] + exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t)) + # reduce col + dist.all_reduce(exp_avg_sq_col, group=self.tensor_parallel_group) + exp_avg_sq_col.div_(self.tensor_parallel_size) + update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update_reshape.mul_(grad_reshape) + update = _split(input_=update_reshape.view(-1), dim=-1, process_group=self.data_parallel_group) + else: + # no residual row + # view update to origin[tp] shape + update_reshape = update.view(-1, grad_shape[1]) # [H/dp, W] + grad_reshape = grad.view(-1, grad_shape[1]) # [H/dp, W] + exp_avg_sq_row = state["exp_avg_sq_row"] # [H/tp] + exp_avg_sq_col = state["exp_avg_sq_col"] # [W] + exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t)) + # reduce col + dist.all_reduce(exp_avg_sq_col, group=self.tensor_parallel_group) + exp_avg_sq_col.div_(self.tensor_parallel_size) + update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update_reshape.mul_(grad_reshape) + update = update_reshape.view(-1) + else: + # base factor; no tp, no dp + exp_avg_sq_row = state["exp_avg_sq_row"] + exp_avg_sq_col = state["exp_avg_sq_col"] + # Exponential average of row indexes + exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t)) + # Exponential average of columns indexes + 
exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t)) + # Approximation of exponential moving average of square of gradient + update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update.mul_(grad) + return update + + @torch.no_grad() + def step(self, closure=None): + """ + Performs a single optimization steps + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + """ + param_groups: Dict + { + "params":[weight, bias] + "lr" + "eps" + "clip_threshold" + "decay_rate" + "beta1" + "weight_decay" + "scale_parameter" + "relative_step" + "warmup_init" + } + """ + for group in self.param_groups: + # update weight & bias + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad + if grad.is_sparse: + raise RuntimeError("Adafactor does not support sparse gradients.") + + state = self.state[p] + grad_shape = self.grad_shape_dict[id(p)] + param_is_dtensor = self.param_is_dtensor_dict[id(p)] + if param_is_dtensor: + grad_shape = self.shard_to_param.get(id(p)).shape # tp shape (2 dim) + factored, use_first_moment = self.factored_dict[id(p)], self.use_first_moment_dict[id(p)] + + shard_spec = self.shard_spec_dict[id(p)] + if len(state) == 0: + state["step"] = 0 + if use_first_moment: + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(p) + if factored: + if param_is_dtensor: + if shard_spec.sharding_sequence[0] == "R": # Col Parallel + if grad_shape[0] % self.data_parallel_size != 0: + state["exp_avg_sq_row"] = torch.zeros( + grad_shape[0], device=p.device, dtype=p.dtype + ) # [H] + else: + state["exp_avg_sq_row"] = torch.zeros( + grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype + ) # [H/dp] + state["exp_avg_sq_col"] = torch.zeros( + grad_shape[1], device=p.device, dtype=p.dtype + ) # [W/TP] + + if shard_spec.sharding_sequence[-1] == "R": # Row Parallel + # Row indivisible shape situation + if grad_shape[0] % self.data_parallel_size != 0: + state["exp_avg_sq_row"] = torch.zeros( + grad_shape[0], device=p.device, dtype=p.dtype + ) # [H/tp] + else: + state["exp_avg_sq_row"] = torch.zeros( + grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype + ) # [H/dp/tp] + + state["exp_avg_sq_col"] = torch.zeros( + grad_shape[1], device=p.device, dtype=p.dtype + ) # [W] + else: + if self.use_zero: + if grad_shape[0] % self.data_parallel_size != 0: + # save all exp_avg_sq_row [H] + state["exp_avg_sq_row"] = torch.zeros( + grad_shape[0], device=grad.device, dtype=p.dtype + ) + else: + # exp_avg_sq_row [H // dp] + state["exp_avg_sq_row"] = torch.zeros( + grad_shape[0] // self.data_parallel_size, device=grad.device, dtype=p.dtype + ) + else: + # exp_avg_sq_row [H] + state["exp_avg_sq_row"] = torch.zeros(grad_shape[0], device=grad.device, dtype=p.dtype) + # exp_avg_sq_col alaways [W] + state["exp_avg_sq_col"] = torch.zeros(grad_shape[1], device=grad.device, dtype=p.dtype) + else: + state["exp_avg_sq"] = torch.zeros_like(p) + state["RMS"] = 0 + else: + if use_first_moment: + state["exp_avg"] = state["exp_avg"] + if factored: + state["exp_avg_sq_row"] = state["exp_avg_sq_row"] + state["exp_avg_sq_col"] = state["exp_avg_sq_col"] + else: + state["exp_avg_sq"] = state["exp_avg_sq"] + + state["step"] += 1 + lr = self._get_lr(group, state) + beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) + update = (grad**2) + group["eps"][0] + + if factored: + if param_is_dtensor: + # 
============================== + # First Dim is R, Last Dim is S{} means split dim -1 ---> + # Coloum Parallel ---> sq_row need Do (col) Reduce + # ============================== + if shard_spec.sharding_sequence[0] == "R": + update = self._col_parallel_factor(update, grad, state, grad_shape, beta2t) + # ============================== + # Last Dim is R, First Dim is S{} means split dim 0 ---> + # Row Parallel ---> sq_col need Do (row) Reduce + # ============================== + elif shard_spec.sharding_sequence[-1] == "R": + update = self._row_parallel_factor(update, grad, state, grad_shape, beta2t) + else: + update = self._base_factor(update, grad, state, grad_shape, beta2t) + else: + exp_avg_sq = state["exp_avg_sq"] + exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t)) + update = exp_avg_sq.rsqrt().mul_(grad) + + # # (Line No.8) RMS + rms = self._rms( + update, + param_is_dtensor, + self.use_zero, + self.tensor_parallel_size, + self.data_parallel_size, + self.tensor_parallel_group, + self.data_parallel_group, + ) + update.div_((rms / group["clip_threshold"]).clamp_(min=1.0)) + + update.mul_(lr) + if use_first_moment: + exp_avg = state["exp_avg"] + exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"])) + update = exp_avg + + if group["weight_decay"] != 0: + p.add_(p, alpha=(-group["weight_decay"] * lr)) + + p.add_(-update) + + return loss diff --git a/docs/source/en/features/distributed_adafactor.md b/docs/source/en/features/distributed_adafactor.md new file mode 100644 index 000000000000..8d3691177ad6 --- /dev/null +++ b/docs/source/en/features/distributed_adafactor.md @@ -0,0 +1,156 @@ +# Distributed Adafactor + +Author: + +**Related Paper** +- [Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235) + +## Introduction + +Distributed Adafactor is an optimiser that supports hybrid optimisation, including 1D tensor parallelism as well as ZerO. It makes full use of computational resources through reasonable task parallelism, improves training efficiency and speed, and reduces space pressure on single card storage. It has a wide range of applications and currently supports a range of Transformer based models, see [tests.kit.model_zoo](https://github.com/hpcaitech/ColossalAI/tree/main/tests/kit/model_zoo) for details. + +## API Reference + +{{ autodoc:colossalai.nn.optimizer.distributed_adafactor.DistributedAdaFactor }} + +## Hands-On Practice +We now demonstrate how to start Distributed Adafactor with booster API. +### step 1. Import libraries + +```python +import torch +from torch import nn +import torch.distributed as dist +from transformers import LlamaModel, LlamaConfig + +from colossalai.cluster import ProcessGroupMesh +from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row +from colossalai.nn.optimizer.distributed_adafactor import DistributedAdaFactor +from colossal_llama2.dataset.loader import load_tokenized_dataset +from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin +``` + +### step 2. Initialize Distributed Environment and Parallism Group +We then need to initialize distributed environment. For demo purpose, we uses `colossalai.launch`. You can refer to [Launch Colossal-AI](../basics/launch_colossalai.md) +for other initialization methods. We use `ProcessGroupMesh` to create tensor parallelism group and data parallelism group. 
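
A minimal sketch of creating those groups with `ProcessGroupMesh` (the 2x2 tensor-parallel / data-parallel layout is illustrative and mirrors the usage in this patch's tests; it assumes the distributed environment has already been launched, as shown in the snippet that follows):

```python
from colossalai.cluster import ProcessGroupMesh

# Illustrative layout: 4 GPUs arranged as 2-way tensor parallel x 2-way data parallel.
TP_SIZE, DP_SIZE = 2, 2
proc_mesh = ProcessGroupMesh(TP_SIZE, DP_SIZE)
tp_group = proc_mesh.get_group_along_axis(0)  # tensor-parallel group
dp_group = proc_mesh.get_group_along_axis(1)  # data-parallel / ZeRO group
```
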
+
+```python
+# Distributed environment
+config = {}
+colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+```
+
+### step 3. Initialize Module and Optimizer
+Build the model and optimizer. Here we initialize a LLaMA model from Hugging Face Transformers and pass its parameters to `DistributedAdaFactor`.
+
+```python
+# Init Llama from huggingface
+configuration = LlamaConfig()
+model = LlamaModel(configuration)
+dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train")
+dataloader = plugin.prepare_dataloader(dataset, batch_size=8)
+criterion = lambda x: x.mean()
+dist_optim = DistributedAdaFactor(model.parameters())
+```
+
+### step 4. Initialize Booster
+
+```python
+plugin = LowLevelZeroPlugin()
+booster = Booster(plugin=plugin)
+model, dist_optim, criterion, dataloader, _ = booster.boost(model, dist_optim, criterion, dataloader)
+```
+### step 5. Train Your Model
+```python
+for epoch in range(max_epochs):
+    for input_ids, attention_mask in dataloader:
+        outputs = model(input_ids.cuda(), attention_mask.cuda())
+        loss = criterion(outputs.last_hidden_state)
+        booster.backward(loss, dist_optim)
+        dist_optim.step()
+        dist_optim.zero_grad()
+```
+
+## Supporting Information
+Model/Feature Compatibility Matrix:
+
+| Model/Feature | Transformers Bert | Transformers Bert For Pretraining | Transformers Bert Lm Head Model | Transformers Bert For Masked Lm | Transformers Bert For Sequence Classification | Transformers Bert For Token Classification | Transformers Bert For Next Sentence | Transformers Bert For Multiple-choice Question | Transformers Bert For Question Answering |
+| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
+| Hybrid Parallel Plugin | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
+| Low Level Zero Plugin | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
+| Torch DDP Plugin | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
+| Gemini Plugin | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
+| Moe Hybrid Plugin | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
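
The plugins above recognize `DistributedAdaFactor` as a `DistributedOptim` and inject the process groups and the master-to-working parameter mapping for you through `setup_distributed`. A hedged sketch of what that wiring looks like when done by hand, mirroring the tests added by this patch (`model`, `tp_group` and `dp_group` are assumed to come from the earlier steps, and no ZeRO wrapper is used here):

```python
from colossalai.nn.optimizer.distributed_adafactor import DistributedAdaFactor

dist_optim = DistributedAdaFactor(model.parameters())
# Without a ZeRO wrapper the master-to-working parameter mapping is simply the identity.
shard_to_param = {id(p): p for p in model.parameters()}
dist_optim.setup_distributed(
    tensor_parallel_group=tp_group,
    data_parallel_group=dp_group,
    shard_to_param=shard_to_param,
    use_zero=False,
)
```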
+ + diff --git a/docs/source/zh-Hans/features/distributed_adafactor.md b/docs/source/zh-Hans/features/distributed_adafactor.md new file mode 100644 index 000000000000..19610a85c8c1 --- /dev/null +++ b/docs/source/zh-Hans/features/distributed_adafactor.md @@ -0,0 +1,155 @@ +# 分布式 Adafactor + +作者: + +**相关论文** +- [Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235) + +## 简介 + +分布式 Adafactor 是一种支持混合优化的优化器,包括 1D 张量并行和 ZerO。它通过合理的任务并行化充分利用了计算资源,提高了训练效率和速度,并减少了存储压力。它应用广泛,目前支持一系列基于 Transformer 的模型,详见 [tests.kit.model_zoo](https://github.com/hpcaitech/ColossalAI/tree/main/tests/kit/model_zoo). + +## API接口 + +{{ autodoc:colossalai.nn.optimizer.distributed_adafactor.DistributedAdaFactor }} + +## 实例演示 +现在我们演示如何使用 Booster API 启动分布式 Adafactor。 +### 步骤 1. 导入相关库 + +```python +import torch +from torch import nn +import torch.distributed as dist +from transformers import LlamaModel, LlamaConfig + +from colossalai.cluster import ProcessGroupMesh +from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row +from colossalai.nn.optimizer.distributed_adafactor import DistributedAdaFactor +from colossal_llama2.dataset.loader import load_tokenized_dataset +from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin +``` + +### 步骤 2. 初始化分布式环境和参数 +然后,我们需要初始化分布式环境。为了演示的目的,我们使用了 `colossalai.launch`。您可以参考 [Launch Colossal-AI](../basics/launch_colossalai.md) 获得其他的初始化方法。这里, 我们使用 "ProcessGroupMesh"来创建张量并行组和数据并行组。 + +```python +# Distributed Enviroment +config = {} +colossalai.launch(config=config, rank=rank, world_size=world_size,host="localhost", port=port, backend="nccl") +``` + +### 步骤 3.初始化模块和优化器 +Build our model. We created an MLP using two Linear Layer. + +```python +# Init Llama from huggingface +configuration = LlamaConfig() +model = LlamaModel(configuration) +dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train") +dataloader = plugin.prepare_dataloader(dataset, batch_size=8) +criterion = lambda x: x.mean() +dist_optim = DistributedAdaFactor(model.parameters()) + +``` + +### 步骤 4.初始化Booster + +```python +plugin = LowLevelZeroPlugin() +booster = Booster(plugin=plugin) +model, dist_optim, criterion, dataloader, _ = booster.boost(model, dist_optim, criterion, dataloader) +``` +### 步骤 5.训练模型 +```python +for epoch in range(max_epochs): + for input_ids, attention_mask in dataloader: + outputs = model(input_ids.cuda(), attention_mask.cuda()) + loss = criterion(outputs.logits, input_ids) + booster.backward(loss, optimizer) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() +``` + +## 支持信息 +模型/功能兼容性矩阵: + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+| Model/Feature | Transformers Bert | Transformers Bert For Pretraining | Transformers Bert Lm Head Model | Transformers Bert For Masked Lm | Transformers Bert For Sequence Classification | Transformers Bert For Token Classification | Transformers Bert For Next Sentence | Transformers Bert For Multiple-choice Question | Transformers Bert For Question Answering |
+| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
+| Hybrid Parallel Plugin | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
+| Low Level Zero Plugin | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
+| Torch DDP Plugin | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
+| Gemini Plugin | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
+| Moe Hybrid Plugin | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
+ + diff --git a/tests/test_optimizer/_utils.py b/tests/test_optimizer/_utils.py index 6ce9b0364abe..75b57db134ec 100644 --- a/tests/test_optimizer/_utils.py +++ b/tests/test_optimizer/_utils.py @@ -1,7 +1,9 @@ import torch +import torch.distributed from torch.testing import assert_close import colossalai +from colossalai.shardformer.layer._operation import _gather from colossalai.shardformer.layer.utils import Randomizer from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.testing import parameterize, spawn @@ -135,3 +137,119 @@ def _run_bert_test(rank, world_size, port, optim_class, sharded_optim_class): def check_optim_on_bert(optim_class, sharded_optim_class): spawn(_run_bert_test, 4, optim_class, sharded_optim_class) + + +def check_dist_optim_state(org_optimizer, sharded_optimizer): + torch.set_default_dtype(torch.bfloat16) + for group, tp_group in zip(org_optimizer.param_groups, sharded_optimizer.param_groups): + for p, tp in zip(group["params"], tp_group["params"]): + p_state = org_optimizer.state[p] + tp_state = sharded_optimizer.state[tp] + # TODO "exp_avg_sq_col", "exp_avg_sq_row", "exp_avg_sq" + for key in ["exp_avg_sq_row"]: + if key in tp_state.keys() and type(tp_state[key]) is torch.Tensor: + tp_is_dtensor = sharded_optimizer.param_is_dtensor_dict[id(tp)] + shard_spec = sharded_optimizer.shard_spec_dict[id(tp)] + use_zero = sharded_optimizer.use_zero + tp_optim_state = tp_state[key] + p_state_shape, tp_state_shape = p_state[key].shape, tp_state[key].shape + dp_size, tp_size = ( + sharded_optimizer.data_parallel_size, + sharded_optimizer.tensor_parallel_size, + ) + # we start init model with first tensor parallel then zero; + # So, we gather model with first zero then tensor parallel + + if tp_is_dtensor: + # col parallel + if shard_spec.sharding_sequence[0] == "R": + if use_zero: + # sq_row need gather alone dp group + if key == "exp_avg_sq_row": + tp_optim_state = _gather( + input_=tp_optim_state, + dim=-1, + process_group=sharded_optimizer.data_parallel_group, + ) + tp_optim_state.shape + # sq_col don't need gather alone dp group + if key == "exp_avg_sq_col": + pass + else: + pass + # gather from tp group + # sq_row don need gather alone tp group + if key == "exp_avg_sq_row": + pass + # sq_col need gather alone dp group + if key == "exp_avg_sq_col": + tp_optim_state = _gather( + input_=tp_optim_state, dim=-1, process_group=sharded_optimizer.tensor_parallel_group + ) + tp_optim_state.shape + + # row parallel + if shard_spec.sharding_sequence[-1] == "R": + if use_zero: + # sq_row need gather alone dp group + if key == "exp_avg_sq_row": + if p_state[key].shape[0] // tp_size % dp_size != 0: + pass + else: + tp_optim_state = _gather( + input_=tp_optim_state, + dim=-1, + process_group=sharded_optimizer.data_parallel_group, + ) + tp_optim_state.shape + # sq_col don't need gather alone dp group + if key == "exp_avg_sq_col": + pass + else: + pass + # gather from tp group + # sq_row need gather alone tp group + if key == "exp_avg_sq_row": + tp_optim_state = _gather( + input_=tp_optim_state, dim=-1, process_group=sharded_optimizer.tensor_parallel_group + ) + tp_optim_state.shape + # sq_col don't need gather alone dp group + if key == "exp_avg_sq_col": + pass + else: + if use_zero: + # sq_row need gather alone dp group + if key == "exp_avg_sq_row": + # row residule; no gather + if p_state[key].shape[0] % dp_size != 0: + pass + else: + tp_optim_state = _gather( + input_=tp_optim_state, + dim=-1, + process_group=sharded_optimizer.data_parallel_group, + ) + 
tp_optim_state.shape + # sq_col don't need gather alone dp group + if key == "exp_avg_sq_col": + tp_optim_state = tp_optim_state.div_(dp_size) + # need a div; + else: + pass + # Sovled a New issus: different dtype; + # So far, only happen in H100 env; + # Seem torch.set_default_dtype(torch.bfloat16) not act on booster.percision; + # Or assert_close just update to check dtype; + if p_state[key].dtype != tp_optim_state.dtype: + tp_optim_state = tp_optim_state.type(p_state[key].dtype) + assert_close(p_state[key], tp_optim_state, atol=5e-4, rtol=1.6e-2) + + +def check_dist_param(org_model, sharded_model, weight_layer_for_check, atol, rtol): + for (org_name, org_param), (sharded_name, sharded_param) in zip( + org_model.named_parameters(), sharded_model.named_parameters() + ): + if org_name in weight_layer_for_check: + # print(f"org_name {org_name} shape {org_param.shape} {org_param}\n sharded_name {sharded_name} shape {sharded_param.shape} {sharded_param}\n") + assert_close(org_param, sharded_param, atol=atol, rtol=rtol) diff --git a/tests/test_optimizer/test_dist_adafactor.py b/tests/test_optimizer/test_dist_adafactor.py new file mode 100644 index 000000000000..237851a90f6c --- /dev/null +++ b/tests/test_optimizer/test_dist_adafactor.py @@ -0,0 +1,697 @@ +import copy + +import pytest +import torch +import torch.distributed as dist +from torch import nn +from torch.testing import assert_close + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import LowLevelZeroPlugin +from colossalai.cluster import ProcessGroupMesh +from colossalai.logging import disable_existing_loggers +from colossalai.nn.optimizer.adafactor import Adafactor +from colossalai.nn.optimizer.distributed_adafactor import DistributedAdaFactor +from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row +from colossalai.shardformer.layer._operation import _gather +from colossalai.shardformer.layer.utils import Randomizer +from colossalai.tensor.d_tensor import ( + distribute_tensor, + get_device_mesh, + get_layout, + get_sharding_spec, + is_distributed_tensor, + shard_colwise, + shard_rowwise, +) +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.tensor.d_tensor.sharding_spec import DimSpec +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils import set_seed +from colossalai.zero import LowLevelZeroOptimizer +from tests.kit.model_zoo import model_zoo +from tests.test_optimizer._utils import check_dist_optim_state, check_dist_param, check_optim_states +from tests.test_shardformer.test_model._utils import ( + build_model_from_hybrid_plugin, + build_model_from_low_level_zero_plugin, + check_weight, + run_forward_backward_with_hybrid_plugin, + run_forward_backward_with_low_level_zero_plugin, + unwrap_model, +) + +HEIGHT = 4 +WIDTH = 4 +_TP_SPEC = DimSpec([0]) + + +def correctness_verify(tensor1: torch.Tensor, tensor2: torch.Tensor, dtype: torch.dtype = torch.float32): + rtol = None + atol = None + if dtype is torch.float32: + rtol = 5e-04 + atol = 5e-04 + elif dtype is torch.float16: + rtol = 5e-2 + atol = 5e-4 + elif dtype is torch.bfloat16: + rtol = 4e-3 + atol = 4e-3 + + # return torch.all(tensor1.isclose(tensor2, rtol=rtol, atol=atol)) + assert_close(tensor1, tensor2, rtol=rtol, atol=atol) + + +# setup param groups; (For zero test optim) +def setup_param_groups_zero(model: nn.Module) -> list: + no_decay = ["bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in 
model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": 0.1, + }, + { + "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + return optimizer_grouped_parameters + + +# setup param groups; (For base optim) +def setup_param_groups(model: nn.Module) -> list: + optimizer_grouped_parameters = [p for n, p in model.named_parameters()] + return optimizer_grouped_parameters + + +# setup flatten param groups, sharding spec and shape; (For dist optim) +def setup_flatten_param_groups_sharding_spec_shape(model: nn.Module) -> dict: + flatten_optimizer_grouped_parameters = [] + sharding_spec = {} # {id(flatten param): get_layout(p).global_shape} + param_shape = {} # {id(flatten param): get_sharding_spec(p)} + for n, p in model.named_parameters(): + # flatten_p = copy.deepcopy(p).flatten() + flatten_p = nn.Parameter(p.clone().flatten().requires_grad_(True)) + flatten_optimizer_grouped_parameters.append(flatten_p) + if is_distributed_tensor(p): + sharding_spec[id(flatten_p)] = get_sharding_spec(p) + param_shape[id(flatten_p)] = get_layout(p).global_shape + else: + sharding_spec[id(flatten_p)] = None + param_shape[id(flatten_p)] = p.shape + return flatten_optimizer_grouped_parameters, sharding_spec, param_shape + + +def set_dist_grad( + dist_module: nn.Module, torch_model: nn.Module, g_dtype: torch.dtype, group: dist.ProcessGroup +) -> None: + """ + Set split grads for Tensor Parallel or ZeRO DP. + We do not need a separate treatment for ZeRO, + as the wrapper takes care of reduce-scattering grads. + """ + rank = dist.get_rank(group) + world_size = dist.get_world_size(group) + + for p, torch_p in zip(dist_module.parameters(), torch_model.parameters()): + if torch_p.grad is None: + torch_p.grad = torch.zeros_like(torch_p) + + is_distributed = hasattr(p, "dist_layout") + if is_distributed: + sharding = p.dist_layout.sharding_spec.sharding_sequence + split_dim = sharding.index(_TP_SPEC) + shape = torch_p.split(world_size, dim=split_dim)[rank].shape + + indices = torch.arange(shape[split_dim] * rank, shape[split_dim] * (rank + 1)) + # Generate grads only for the correctly split chunk + torch_p.grad.index_add_(split_dim, indices, torch.randn(shape, device=torch_p.device, dtype=g_dtype)) + + else: + shape = torch_p.shape + torch_p.grad += torch.randn(shape, device=torch_p.device, dtype=g_dtype) + + # avoid inconsistent grad and param dtype error + orig_p = p.data + p.data = torch_p.grad.clone().to(g_dtype) + p.grad = p.data + p.data = orig_p + + +def set_master_param_to_shard_param(master_param_list) -> dict: + master_param_to_shard_param = {id(p): p for p in master_param_list} + return master_param_to_shard_param + + +class MlpModel(nn.Module): + def __init__(self): + super(MlpModel, self).__init__() + self.linear1 = nn.Linear(HEIGHT, WIDTH) + self.linear2 = nn.Linear(WIDTH, HEIGHT) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + +class TPModel(nn.Module): + def __init__(self, linear1, linear2, tp_group=None): + super().__init__() + self.linear1 = Linear1D_Col.from_native_module( + linear1, process_group=tp_group, gather_output=False, overlap=True + ) + self.linear2 = Linear1D_Row.from_native_module(linear2, process_group=tp_group, parallel_input=True) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + +@parameterize("dtype", [torch.float32, torch.float16, torch.bfloat16]) # torch.float32, torch.float16, torch.bfloat16 
+@parameterize("tp_zero_size", [(4, 1)]) +def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]): + tp_size, zero_size = tp_zero_size + local_rank = dist.get_rank() + use_zero = True if zero_size > 1 else False + + proc_mesh = ProcessGroupMesh(tp_size, zero_size) + tp_group, dp_group = proc_mesh.get_group_along_axis(0), proc_mesh.get_group_along_axis(1) + + torch.set_default_dtype(dtype) + set_seed(42) + + # ============================== + # Base Case + # ============================== + H, W = HEIGHT, WIDTH + model_col = nn.Linear(H, W).to(local_rank) # Col parallel weight + weight, bias = model_col.weight, model_col.bias + + # ============================== + # Col Parallel + # ============================== + weight_col_shard = shard_colwise(weight.clone(), tp_group) + weight_col_shard_layout = get_layout(weight_col_shard) # Layout info weight_col_shard_layout.global_shape + weight_col_shard_shard_spec = get_sharding_spec(weight_col_shard) # Shard spec + weight_col_shard_flatten = nn.Parameter(weight_col_shard.clone().flatten().requires_grad_(True)) + bias_col_flatten = nn.Parameter(bias.clone().flatten().requires_grad_(True)) + + # ============================== + # Row Parallel + # ============================== + weight_row_shard = shard_rowwise(weight.clone(), tp_group) + weight_row_shard_layout = get_layout(weight_row_shard) # Layout info weight_row_shard_layout.global_shape + weight_row_shard_shard_spec = get_sharding_spec(weight_row_shard) # Shard spec + weight_row_shard_flatten = nn.Parameter( + weight_row_shard.clone().flatten().requires_grad_(True) + ) # flatten input(not dtensor) to optimizer + bias_row_flatten = nn.Parameter(bias.clone().flatten().requires_grad_(True)) + + # base_param_group = setup_param_groups([weight, bias]) + # cp_param_group = setup_param_groups([weight_col_shard_flatten, bias_col_flatten]) + # rp_param_group = setup_param_groups([weight_row_shard_flatten, bias_row_flatten]) + + # ============================== + # Init Optimizer + # ============================== + + # base + optimizer_base = Adafactor([weight, bias]) + cp_dist_optim = DistributedAdaFactor([weight_col_shard_flatten, bias_col_flatten]) + rp_dist_optim = DistributedAdaFactor([weight_row_shard_flatten, bias_row_flatten]) + + shard_to_param_cp = set_master_param_to_shard_param([weight_col_shard_flatten, bias_col_flatten]) + cp_dist_optim.setup_distributed( + tensor_parallel_group=tp_group, + data_parallel_group=dp_group, + shard_to_param=shard_to_param_cp, + use_zero=use_zero, + ) + + shard_to_param_rp = set_master_param_to_shard_param([weight_row_shard_flatten, bias_row_flatten]) + rp_dist_optim.setup_distributed( + tensor_parallel_group=tp_group, + data_parallel_group=dp_group, + shard_to_param=shard_to_param_rp, + use_zero=use_zero, + ) + + N_STEPS = 1 + for _ in range(N_STEPS): + # base step + optimizer_base.zero_grad() + weight.grad = torch.rand_like(weight) + bias.grad = torch.rand_like(bias) + optimizer_base.step() + + # col parallel step + cp_dist_optim.zero_grad() + weight_col_shard_flatten.grad = ( + distribute_tensor(weight.grad, get_device_mesh(weight_col_shard), weight_col_shard_shard_spec) + .clone() + .flatten() + ) + bias_col_flatten.grad = bias.grad.clone().flatten() + cp_dist_optim.step() + + # row parallel step + rp_dist_optim.zero_grad() + weight_row_shard_flatten.grad = ( + distribute_tensor(weight.grad, get_device_mesh(weight_row_shard), weight_row_shard_shard_spec) + .clone() + .flatten() + ) + bias_row_flatten.grad = 
bias.grad.clone().flatten() + rp_dist_optim.step() + + # gather result + weight_col_gather = _gather( + input_=weight_col_shard_flatten.data.view(-1, H // tp_size), + dim=-1, + process_group=tp_group, + ) # gather + weight_row_gather = _gather(input_=weight_row_shard_flatten.data, dim=-1, process_group=tp_group).view( + -1, W + ) # gather + + # verify + correctness_verify(weight.data, weight_col_gather.data, dtype) + correctness_verify(weight.data, weight_row_gather.data, dtype) + + print(f"Base Test Pass") + + +@parameterize("dtype", [torch.float16]) # torch.float32, torch.float16, torch.bfloat16 +@parameterize("tp_zero_size", [(1, 4)]) # (2, 2), (4, 1), (1, 4) +def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]): + tp_size, zero_size = tp_zero_size + use_zero = True if zero_size > 1 else False + local_rank = dist.get_rank() + + clear_layout_converter() + + proc_mesh = ProcessGroupMesh(tp_size, zero_size) + tp_group, dp_group = proc_mesh.get_group_along_axis(0), proc_mesh.get_group_along_axis(1) + + torch.set_default_dtype(dtype) + set_seed(42) + + # ============================== + # Model Init + # ============================== + base_model = MlpModel().to(local_rank) + tp_model = TPModel(copy.deepcopy(base_model.linear1), copy.deepcopy(base_model.linear2), tp_group).to(local_rank) + + base_param_group = setup_param_groups(base_model) + tp_param_group = setup_param_groups(tp_model) + tp_param_group_, tp_shard_spec, tp_param_shape = setup_flatten_param_groups_sharding_spec_shape(tp_model) + + # ============================== + # Optimizer Init + # ============================== + base_optim = Adafactor(base_param_group) + dist_optim = DistributedAdaFactor(tp_param_group) + + # Setup distributed optimizer + if zero_size > 1: + base_optim = LowLevelZeroOptimizer( + base_optim, + overlap_communication=True, + initial_scale=128, + partition_grad=True, + dp_process_group=dp_group, + verbose=True, + ) + + dist_optim = LowLevelZeroOptimizer( + dist_optim, + overlap_communication=True, + initial_scale=128, + partition_grad=True, + dp_process_group=dp_group, + verbose=True, + ) + shard_to_param = dist_optim._param_store.master_to_working_param # {id(): param tensor} but flattened + dist_optim.optim.setup_distributed( + tensor_parallel_group=tp_group, + data_parallel_group=dp_group, + shard_to_param=shard_to_param, + use_zero=use_zero, + ) + else: + shard_to_param = set_master_param_to_shard_param(tp_param_group) + dist_optim.setup_distributed( + tensor_parallel_group=tp_group, + data_parallel_group=dp_group, + shard_to_param=shard_to_param, + use_zero=use_zero, + ) + + # ============================== + # Correctness Verify + # ============================== + x = torch.randn(HEIGHT, WIDTH, device=local_rank) + + out = base_model(x) + out_tp = tp_model(x) + + if zero_size > 1: + dist_optim.backward(out_tp.sum()) + base_optim.backward(out.sum()) + else: + out_tp.sum().backward() + out.sum().backward() + + base_optim.step() + dist_optim.step() + + base_optim.zero_grad() + dist_optim.zero_grad() + + for p, tp_p in zip(base_param_group, tp_param_group): + param_is_distributed = is_distributed_tensor(tp_p) + if param_is_distributed: + shard_spec = get_sharding_spec(tp_p) + if len(shard_spec.sharding_sequence) >= 2: + # Col Parallel + if shard_spec.sharding_sequence[0] == "R": + tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group) # gather + # ROW Parallel + if shard_spec.sharding_sequence[-1] == "R": + tp_p = _gather(input_=tp_p, dim=0, process_group=tp_group) # 
gather + else: + # TP bias + tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group) # gather + else: + # No TP bias + pass + correctness_verify(p.data, tp_p.data, dtype) + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + print(f"Zero Test Pass") + + +@parameterize("dtype", [torch.float16]) +@parameterize("tp_zero_size", [(1, 4)]) +def exam_dist_adafactor_booster(dtype: torch.dtype, tp_zero_size: tuple[int, int]): + tp_size, zero_size = tp_zero_size + use_zero = True if zero_size > 1 else False + local_rank = dist.get_rank() + + clear_layout_converter() + + proc_mesh = ProcessGroupMesh(tp_size, zero_size) + tp_group, dp_group = proc_mesh.get_group_along_axis(0), proc_mesh.get_group_along_axis(1) + + torch.set_default_dtype(dtype) + set_seed(42) + + # ============================== + # Model Init + # ============================== + base_model = MlpModel().to(local_rank) + # tp_model = TPModel(copy.deepcopy(base_model.linear1), copy.deepcopy(base_model.linear2), tp_group).to(local_rank) + tp_model = copy.deepcopy(base_model).to(local_rank) + + base_param_group = setup_param_groups(base_model) + tp_param_group = setup_param_groups(tp_model) + tp_param_group_, tp_shard_spec, tp_param_shape = setup_flatten_param_groups_sharding_spec_shape(tp_model) + + # ============================== + # Optimizer Init + # ============================== + base_optim = Adafactor(base_param_group) + dist_optim = DistributedAdaFactor(tp_param_group) + + # Setup distributed optimizer + if zero_size > 1: + base_optim = LowLevelZeroOptimizer( + base_optim, + overlap_communication=True, + initial_scale=128, + partition_grad=True, + dp_process_group=dp_group, + verbose=True, + ) + + dist_optim = LowLevelZeroOptimizer( + dist_optim, + overlap_communication=True, + initial_scale=128, + partition_grad=True, + dp_process_group=dp_group, + verbose=True, + ) + shard_to_param = dist_optim._param_store.master_to_working_param # {id(): param tensor} but flattened + dist_optim.optim.setup_distributed( + tensor_parallel_group=tp_group, + data_parallel_group=dp_group, + shard_to_param=shard_to_param, + use_zero=use_zero, + ) + else: + shard_to_param = set_master_param_to_shard_param(tp_param_group) + dist_optim.setup_distributed( + tensor_parallel_group=tp_group, + data_parallel_group=dp_group, + shard_to_param=shard_to_param, + use_zero=use_zero, + ) + + # ============================== + # Booster Init + # ============================== + plugin = LowLevelZeroPlugin() + booster = Booster(plugin=plugin) + criterion = lambda x: x.mean() + + tp_model, dist_optim, criterion, _, _ = booster.boost(tp_model, dist_optim, criterion) + + # ============================== + # Correctness Verify + # ============================== + x = torch.randn(HEIGHT, WIDTH, device=local_rank) + + out = base_model(x) + out_tp = tp_model(x) + + if zero_size > 1: + dist_optim.backward(out_tp.sum()) + base_optim.backward(out.sum()) + else: + out_tp.sum().backward() + out.sum().backward() + + base_optim.step() + dist_optim.step() + + base_optim.zero_grad() + dist_optim.zero_grad() + + for p, tp_p in zip(base_param_group, tp_param_group): + param_is_distributed = is_distributed_tensor(tp_p) + if param_is_distributed: + shard_spec = get_sharding_spec(tp_p) + if len(shard_spec.sharding_sequence) >= 2: + # Col Parallel + if shard_spec.sharding_sequence[0] == "R": + tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group) # gather + # ROW Parallel + if shard_spec.sharding_sequence[-1] == "R": + tp_p = _gather(input_=tp_p, 
dim=0, process_group=tp_group) # gather + else: + # TP bias + tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group) # gather + else: + # No TP bias + pass + correctness_verify(p.data, tp_p.data, dtype) + Randomizer.reset_index() + torch.cuda.empty_cache() + print(f"Booster Test Pass") + + +@parameterize( + "test_config", + [ + { + "stage": 1, + "precision": "bf16", + }, + { + "stage": 2, + "precision": "bf16", + }, + ], +) +def exam_bert_test_on_lowlevelzero_plugin(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_bert") + model_list = [ + "transformers_bert", + "transformers_bert_for_pretraining", + "transformers_bert_lm_head_model", + "transformers_bert_for_masked_lm", + "transformers_bert_for_sequence_classification", + "transformers_bert_for_token_classification", + "transformers_bert_for_next_sentence", + "transformers_bert_for_mcq", + "transformers_bert_for_question_answering", + ] + clear_layout_converter() + torch.set_default_dtype(torch.bfloat16) + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + if name in model_list: + ( + org_model, + org_optimizer, + sharded_model, + sharded_optimizer, + criterion, + booster, + ) = build_model_from_low_level_zero_plugin(model_fn, loss_fn, test_config, Adafactor, DistributedAdaFactor) + + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_low_level_zero_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) + + # LowLevelZero not need warp + # bert = unwrap_model(org_model, "BertModel", "bert") + # sharded_bert = unwrap_model(sharded_model, "BertModel", "bert") + weight_layer_for_check = [ + "bert.encoder.layer.0.output.dense.weight", + "bert.encoder.layer.0.output.dense.weight", + ] + + org_optimizer.step() + sharded_optimizer.step() + + # check weights + if test_config["precision"] == "bf16": + atol, rtol = 5e-4, 5e-4 + else: + atol, rtol = 5e-4, 5e-4 + + check_dist_param(org_model, sharded_model, weight_layer_for_check, atol, rtol) + check_optim_states(org_optimizer, sharded_optimizer.optim) + + Randomizer.reset_index() + torch.cuda.empty_cache() + print(f"Bert Model Zoo Test Pass") + + +@parameterize( + "test_config", + [ + { + "tp_size": 1, + "num_microbatches": 4, + "zero_stage": 2, + "precision": "bf16", + }, + { + "tp_size": 2, + "num_microbatches": 4, + "zero_stage": 2, + "precision": "bf16", + }, + { + "tp_size": 4, + "num_microbatches": 4, + "zero_stage": 2, + "precision": "bf16", + }, + { + "tp_size": 2, + "num_microbatches": 4, + "zero_stage": 1, + "precision": "bf16", + }, + { + "tp_size": 4, + "num_microbatches": 4, + "zero_stage": 0, + "precision": "bf16", + }, + ], +) +def exam_bert_test_on_hybrid_plugin(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_bert") + test_config["use_lazy_init"] = False + test_config["pp_size"] = 1 # Do NOT test Pipeline Parallel + test_config["initial_scale"] = 2**16 # avoid overflow + model_list = [ + "transformers_bert", + "transformers_bert_for_pretraining", + "transformers_bert_lm_head_model", + "transformers_bert_for_masked_lm", + "transformers_bert_for_sequence_classification", + "transformers_bert_for_token_classification", + "transformers_bert_for_next_sentence", + "transformers_bert_for_mcq", + "transformers_bert_for_question_answering", + ] + clear_layout_converter() + torch.set_default_dtype(torch.bfloat16) + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + if 
name in model_list: + ( + org_model, + org_optimizer, + sharded_model, + sharded_optimizer, + criterion, + booster, + ) = build_model_from_hybrid_plugin(model_fn, loss_fn, test_config, Adafactor, DistributedAdaFactor) + + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) + + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group + + bert = unwrap_model(org_model, "BertModel", "bert") + sharded_bert = unwrap_model(sharded_model, "BertModel", "bert") + weight_layer_for_check = ["encoder.layer[0].output.dense", "encoder.layer[1].output.dense"] + + org_optimizer.step() + sharded_optimizer.step() + + # check weights + if test_config["precision"] == "bf16": + atol, rtol = 5e-4, 5e-4 + else: + atol, rtol = 5e-4, 5e-4 + if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True): + check_weight(bert, sharded_bert, weight_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1) + # check optim states + check_dist_optim_state(org_optimizer, sharded_optimizer.optim) + + Randomizer.reset_index() + torch.cuda.empty_cache() + print(f"Bert Model Zoo Test Pass") + + +def run_dist(rank, world_size, port): + disable_existing_loggers() + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + exam_bert_test_on_lowlevelzero_plugin() + exam_bert_test_on_hybrid_plugin() + exam_dist_adafactor_base() + exam_dist_adafactor_zero() + exam_dist_adafactor_booster() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_dist_adafactor(): + spawn(run_dist, nprocs=4) + + +if __name__ == "__main__": + test_dist_adafactor() diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 4719fa0b0546..4c46e98f174e 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -12,8 +12,9 @@ from torch.optim import Adam, Optimizer from torch.testing import assert_close +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster -from colossalai.booster.plugin import HybridParallelPlugin +from colossalai.booster.plugin import HybridParallelPlugin, LowLevelZeroPlugin from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule from colossalai.lazy import LazyInitContext from colossalai.pipeline.stage_manager import PipelineStageManager @@ -137,6 +138,32 @@ def build_model_from_hybrid_plugin( return org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster +def build_model_from_low_level_zero_plugin( + model_fn: Callable, loss_fn: Callable, test_config: Dict[str, Any], optim_class=Adam, sharded_optim_class=Adam +): + use_lazy_init = False + if "use_lazy_init" in test_config: + use_lazy_init = test_config.pop("use_lazy_init") + + ctx = LazyInitContext() if use_lazy_init else nullcontext() + with ctx: + org_model = model_fn() + sharded_model = copy.deepcopy(org_model) + if use_lazy_init: + ctx.materialize(org_model) + + org_model = org_model.cuda() + org_optimizer = optim_class(org_model.parameters(), lr=1e-3) + sharded_optimizer = sharded_optim_class(sharded_model.parameters(), lr=1e-3) + criterion = loss_fn + + plugin = LowLevelZeroPlugin(**test_config, max_norm=1.0, initial_scale=2**5) + booster = Booster(plugin=plugin) + + sharded_model, sharded_optimizer, criterion, _, _ = 
booster.boost(sharded_model, sharded_optimizer, criterion) + return org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster + + def run_forward_backward_with_hybrid_plugin( org_model: Module, sharded_model: Module, @@ -195,6 +222,44 @@ def _criterion(outputs, inputs): return org_loss, org_output, sharded_loss, sharded_output +def run_forward_backward_with_low_level_zero_plugin( + org_model: Module, + sharded_model: Module, + sharded_optimizer: Optimizer, + data_gen_fn: Callable, + output_transform_fn: Callable, + criterion: Callable, + booster: Booster, +): + get_accelerator().get_current_device() + org_model.cuda() + sharded_model.cuda() + + def _criterion(outputs, inputs): + outputs = output_transform_fn(outputs) + loss = criterion(outputs) + return loss + + data = data_gen_fn() + + # data = { + # k: v.to(device) if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items() + # } + data = {k: v.cuda() for k, v in data.items()} + + sharded_model.train() + sharded_output = sharded_model(**data) + sharded_loss = criterion(sharded_output) + sharded_optimizer.backward(sharded_loss) + + org_model.train() + org_output = org_model(**data) + org_loss = criterion(org_output) + org_loss.backward() + + return org_loss, org_output, sharded_loss, sharded_output + + def check_output_hidden_state( org_output: Tensor, sharded_output: Tensor,