diff --git a/colossalai/nn/optimizer/adafactor.py b/colossalai/nn/optimizer/adafactor.py
new file mode 100644
index 000000000000..53d4f1a69625
--- /dev/null
+++ b/colossalai/nn/optimizer/adafactor.py
@@ -0,0 +1,207 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import os
+
+import torch
+from torch.optim import Optimizer
+
+__all__ = ["Adafactor"]
+
+
+# Adafactor
+class Adafactor(Optimizer):
+    def __init__(
+        self,
+        params,
+        lr=None,
+        eps=(1e-30, 1e-3),
+        clip_threshold=1.0,
+        decay_rate=-0.8,
+        beta1=None,
+        weight_decay=0.0,
+        scale_parameter=True,
+        relative_step=True,
+        warmup_init=False,
+    ):
+        if lr is not None and relative_step:
+            raise ValueError("Cannot combine manual `lr` and `relative_step=True` options")
+        if warmup_init and not relative_step:
+            raise ValueError("`warmup_init=True` requires `relative_step=True`")
+
+        defaults = {
+            "lr": lr,
+            "eps": eps,
+            "clip_threshold": clip_threshold,
+            "decay_rate": decay_rate,
+            "beta1": beta1,
+            "weight_decay": weight_decay,
+            "scale_parameter": scale_parameter,
+            "relative_step": relative_step,
+            "warmup_init": warmup_init,
+        }
+        super().__init__(params, defaults)
+
+    @staticmethod
+    def _get_lr(param_group, param_state):
+        rel_step_sz = param_group["lr"]
+        if param_group["relative_step"]:
+            min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2
+            rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"]))
+        param_scale = 1.0
+        if param_group["scale_parameter"]:
+            param_scale = max(param_group["eps"][1], param_state["RMS"])
+        return param_scale * rel_step_sz
+
+    @staticmethod
+    def _get_options(param_group, param_shape):
+        factored = len(param_shape) >= 2
+        use_first_moment = param_group["beta1"] is not None
+        return factored, use_first_moment
+
+    @staticmethod
+    def _rms(tensor):
+        return tensor.norm(2) / (tensor.numel() ** 0.5)
+
+    @staticmethod
+    def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
+        r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
+        c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
+        return torch.mul(r_factor, c_factor)
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        """
+        Performs a single optimization step.
+
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        """
+        param_groups: Dict
+            {
+                "params": [weight, bias]
+                "lr"
+                "eps"
+                "clip_threshold"
+                "decay_rate"
+                "beta1"
+                "weight_decay"
+                "scale_parameter"
+                "relative_step"
+                "warmup_init"
+            }
+        """
+
+        for group in self.param_groups:
+            # update weight & bias
+            for p in group["params"]:
+                if p.grad is None:
+                    continue
+                # grad has the same shape as the weight / bias
+                grad = p.grad
+                if grad.dtype in {torch.float16, torch.bfloat16}:
+                    grad = grad.float()
+                if grad.is_sparse:
+                    raise RuntimeError("Adafactor does not support sparse gradients.")
+
+                """
+                p is a weight:
+                    state
+                    {'step',
+                     'exp_avg_sq_row',
+                     'exp_avg_sq_col',
+                     'RMS'
+                    }
+
+                p is a bias:
+                    state
+                    {'step',
+                     'exp_avg_sq',
+                     'RMS'
+                    }
+                """
+
+                state = self.state[p]
+                grad_shape = grad.shape
+
+                factored, use_first_moment = self._get_options(group, grad_shape)
+                # State Initialization
+                if len(state) == 0:
+                    state["step"] = 0
+                    if use_first_moment:
+                        # Exponential moving average of gradient values
+                        state["exp_avg"] = torch.zeros_like(grad)
+                    if factored:
+                        state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1], device=grad.device)
+                        state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:], device=grad.device)
+                    else:
+                        state["exp_avg_sq"] = torch.zeros_like(grad)
+
+                    state["RMS"] = 0
+                else:
+                    if use_first_moment:
+                        state["exp_avg"] = state["exp_avg"].to(grad)
+                    if factored:
+                        state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
+                        state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
+                    else:
+                        state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
+
+                p_data_fp32 = p
+                if p.dtype in {torch.float16, torch.bfloat16}:
+                    p_data_fp32 = p_data_fp32.float()
+
+                state["step"] += 1
+                # state["RMS"] = self._rms(p_data_fp32)
+                lr = self._get_lr(group, state)
+                beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
+                update = (grad**2) + group["eps"][0]
+                if factored:
+                    exp_avg_sq_row = state["exp_avg_sq_row"]
+                    exp_avg_sq_col = state["exp_avg_sq_col"]
+                    # Exponential moving average of the row statistics
+                    exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
+                    # Exponential moving average of the column statistics
+                    exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))
+                    # Approximation of the exponential moving average of the squared gradient
+                    update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
+                    update.mul_(grad)
+                else:
+                    exp_avg_sq = state["exp_avg_sq"]
+                    exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
+                    update = exp_avg_sq.rsqrt().mul_(grad)
+                # RMS clipping
+                # update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
+                update.mul_(lr)
+
+                if use_first_moment:
+                    exp_avg = state["exp_avg"]
+                    exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
+                    update = exp_avg
+
+                if group["weight_decay"] != 0:
+                    p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))
+                p_data_fp32.add_(-update)
+                if p.dtype in {torch.float16, torch.bfloat16}:
+                    p.copy_(p_data_fp32)
+
+        return loss
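A minimal usage sketch of the `Adafactor` class added above; the toy `torch.nn.Linear` module and its shapes are illustrative, not part of this PR. Any 2D parameter gets the factored `exp_avg_sq_row` / `exp_avg_sq_col` states, while 1D parameters such as biases fall back to a full `exp_avg_sq`:

```python
import torch

from colossalai.nn.optimizer.adafactor import Adafactor

# toy module, purely illustrative
model = torch.nn.Linear(128, 64)
optim = Adafactor(model.parameters())  # lr=None + relative_step=True: the lr is derived from the step count

loss = model(torch.randn(8, 128)).sum()
loss.backward()
optim.step()       # weight ([64, 128]) gets factored states of shape [64] and [128]; the bias gets exp_avg_sq
optim.zero_grad()
```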
+ """ + loss = None + if closure is not None: + loss = closure() + + """ + param_groups: Dict + { + "params":[weight, bias] + "lr" + "eps" + "clip_threshold" + "decay_rate" + "beta1" + "weight_decay" + "scale_parameter" + "relative_step" + "warmup_init" + } + """ + + for group in self.param_groups: + # update weight & bias + for p in group["params"]: + if p.grad is None: + continue + """ + # grad shape is same as weigh / bias + """ + grad = p.grad + if grad.dtype in {torch.float16, torch.bfloat16}: + grad = grad.float() + if grad.is_sparse: + raise RuntimeError("Adafactor does not support sparse gradients.") + + """ + p is weight + state + {'step', + 'exp_avg_sq_row', + 'exp_avg_sq_col', + 'RMS' + } + + p is bias + state + {'step', + 'exp_avg_sq', + 'RMS' + } + """ + + state = self.state[p] + grad_shape = grad.shape + + factored, use_first_moment = self._get_options(group, grad_shape) + # State Initialization + if len(state) == 0: + state["step"] = 0 + if use_first_moment: + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(grad) + if factored: + state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1], device=grad.device) + state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:], device=grad.device) + else: + state["exp_avg_sq"] = torch.zeros_like(grad) + + state["RMS"] = 0 + else: + if use_first_moment: + state["exp_avg"] = state["exp_avg"].to(grad) + if factored: + state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad) + state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad) + else: + state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) + + p_data_fp32 = p + if p.dtype in {torch.float16, torch.bfloat16}: + p_data_fp32 = p_data_fp32.float() + + state["step"] += 1 + # state["RMS"] = self._rms(p_data_fp32) + lr = self._get_lr(group, state) + beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) + update = (grad**2) + group["eps"][0] + if factored: + exp_avg_sq_row = state["exp_avg_sq_row"] + exp_avg_sq_col = state["exp_avg_sq_col"] + # Exponential average of row indexes + exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t)) + # Exponential average of columns indexes + exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t)) + # Approximation of exponential moving average of square of gradient + update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update.mul_(grad) + else: + exp_avg_sq = state["exp_avg_sq"] + exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t)) + update = exp_avg_sq.rsqrt().mul_(grad) + # RMS + # update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0)) + update.mul_(lr) + + if use_first_moment: + exp_avg = state["exp_avg"] + exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"])) + update = exp_avg + + if group["weight_decay"] != 0: + p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr)) + p_data_fp32.add_(-update) + if p.dtype in {torch.float16, torch.bfloat16}: + p.copy_(p_data_fp32) + + return loss diff --git a/colossalai/nn/optimizer/distributed_adafactor.py b/colossalai/nn/optimizer/distributed_adafactor.py new file mode 100644 index 000000000000..63af9bc1b352 --- /dev/null +++ b/colossalai/nn/optimizer/distributed_adafactor.py @@ -0,0 +1,259 @@ +import math +import os +from typing import Dict + +import torch +from torch.optim import Optimizer +import torch.distributed as dist + +from colossalai.shardformer.layer._operation import _gather +from colossalai.tensor.d_tensor import ( + is_distributed_tensor, + 
+    @staticmethod
+    def _get_lr(param_group, param_state):
+        rel_step_sz = param_group["lr"]
+        if param_group["relative_step"]:
+            min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2
+            rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"]))
+        param_scale = 1.0
+        if param_group["scale_parameter"]:
+            param_scale = max(param_group["eps"][1], param_state["RMS"])
+        return param_scale * rel_step_sz
+
+    @staticmethod
+    def _get_options(param_group, param_shape):
+        """
+        Determine whether the current param is factored.
+
+        Args:
+            param_group: The param group.
+            param_shape: The original (global) shape of the param.
+        """
+        factored = len(param_shape) >= 2
+        use_first_moment = param_group["beta1"] is not None
+        return factored, use_first_moment
+
+    @staticmethod
+    def _rms(tensor):
+        return tensor.norm(2) / (tensor.numel() ** 0.5)
+
+    @staticmethod
+    def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
+        r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
+        c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
+        return torch.mul(r_factor, c_factor)
+
+    # approx_sq_grad for row-parallel weights
+    @staticmethod
+    def _approx_sq_grad_row_parallel(exp_avg_sq_row, exp_avg_sq_col, sq_row_mean):
+        # sq_row_mean is the mean of the full (gathered) exp_avg_sq_row
+        r_factor = (exp_avg_sq_row / sq_row_mean).rsqrt_().unsqueeze(-1)
+        c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
+        return torch.mul(r_factor, c_factor)
+
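For reference, the two helpers above implement the standard Adafactor rank-1 factorization of the second-moment matrix $V$ (Shazeer & Stern, 2018). With running row and column statistics $R_i = \operatorname{mean}_j V_{ij}$ and $C_j = \operatorname{mean}_i V_{ij}$, `_approx_sq_grad` reconstructs

$$
\hat{V}_{ij} \;=\; \frac{R_i\, C_j}{\operatorname{mean}_k R_k}, \qquad
\text{update}_{ij} \;=\; \frac{g_{ij}}{\sqrt{\hat{V}_{ij}}},
$$

which is the paper's $\hat{V} = R\,C^{\top} / (\mathbf{1}^{\top} R)$ written with means instead of sums (the constants cancel). `_approx_sq_grad_row_parallel` exists because a row-parallel shard only holds a slice of `exp_avg_sq_row`, so the normalizer $\operatorname{mean}_k R_k$ must be computed from the gathered row statistic and is passed in explicitly as `sq_row_mean`.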
+    @torch.no_grad()
+    def step(self, closure=None):
+        """
+        Performs a single optimization step.
+
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        """
+        param_groups: Dict
+            {
+                "params": [weight, bias]
+                "lr"
+                "eps"
+                "clip_threshold"
+                "decay_rate"
+                "beta1"
+                "weight_decay"
+                "scale_parameter"
+                "relative_step"
+                "warmup_init"
+            }
+        """
+        for group in self.param_groups:
+            # update weight & bias
+            for p in group["params"]:
+                if p.grad is None:
+                    continue
+                grad = p.grad
+                if grad.is_sparse:
+                    raise RuntimeError("Adafactor does not support sparse gradients.")
+                state = self.state[p]
+                param_is_dtensor = is_distributed_tensor(self.shard_to_param.get(id(p)))
+                if param_is_dtensor:
+                    grad_shape = self.shard_to_param.get(id(p)).shape  # tp shape (2 dims)
+                else:
+                    grad_shape = grad.shape  # tp and zero shape (1 dim)
+                factored, use_first_moment = self._get_options(group, grad_shape)
+                # State Initialization
+                if len(state) == 0:
+                    state["step"] = 0
+                    if use_first_moment:
+                        # Exponential moving average of gradient values
+                        state["exp_avg"] = torch.zeros_like(grad)
+                    if factored:
+                        shard_spec = get_sharding_spec(self.shard_to_param.get(id(p)))
+                        if shard_spec.sharding_sequence[0] == "R":  # Col Parallel
+                            state["exp_avg_sq_row"] = torch.zeros(grad_shape[0] // self.data_parallel_size).to(grad)  # [H/dp]
+                            state["exp_avg_sq_col"] = torch.zeros(grad_shape[1]).to(grad)  # [W/tp]
+                        if shard_spec.sharding_sequence[-1] == "R":  # Row Parallel
+                            state["exp_avg_sq_row"] = torch.zeros(grad_shape[0] // self.tensor_parallel_size).to(grad)  # [H/dp/tp]
+                            state["exp_avg_sq_col"] = torch.zeros(grad_shape[1]).to(grad)  # [W]
+                    else:
+                        state["exp_avg_sq"] = torch.zeros_like(grad)
+                    state["RMS"] = 0
+                else:
+                    if use_first_moment:
+                        state["exp_avg"] = state["exp_avg"].to(grad)
+                    if factored:
+                        state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
+                        state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
+                    else:
+                        state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
+
+                p_data_fp32 = p
+                if p.dtype in {torch.float16, torch.bfloat16}:
+                    p_data_fp32 = p_data_fp32.float()
+
+                state["step"] += 1
+                lr = self._get_lr(group, state)
+                beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
+                update = (grad**2) + group["eps"][0]
+
+                if factored:
+                    shard_spec = get_sharding_spec(self.shard_to_param.get(id(p)))
+                    # ==============================
+                    # First dim is R, last dim is S{} (split along dim -1) --->
+                    # column parallel ---> sq_row needs an all-reduce over the tp (col) group
+                    # ==============================
+                    if shard_spec.sharding_sequence[0] == "R":
+                        update_reshape = update.view(-1, grad_shape[1])
+                        grad_reshape = grad.view(-1, grad_shape[1])
+                        exp_avg_sq_row = state["exp_avg_sq_row"]  # [H/dp]
+                        exp_avg_sq_col = state["exp_avg_sq_col"]  # [W/tp]
+                        exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t))
+                        exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t))
+                        dist.all_reduce(exp_avg_sq_row, group=self.tensor_parallel_group)
+                        exp_avg_sq_row.div_(self.tensor_parallel_size)
+                        update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
+                        update_reshape.mul_(grad_reshape)
+                        update = update_reshape.flatten()
+                    # ==============================
+                    # Last dim is R, first dim is S{} (split along dim 0) --->
+                    # row parallel ---> sq_col needs an all-reduce over the tp (row) group
+                    # ==============================
+                    elif shard_spec.sharding_sequence[-1] == "R":
+                        update_reshape = update.view(-1, grad_shape[1])
+                        grad_reshape = grad.view(-1, grad_shape[1])
+                        exp_avg_sq_row = state["exp_avg_sq_row"]  # [H/dp/tp]
+                        exp_avg_sq_col = state["exp_avg_sq_col"]  # [W]
+                        exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t))
+                        exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t))
+                        # reduce col
+                        dist.all_reduce(exp_avg_sq_col, group=self.tensor_parallel_group)
+                        exp_avg_sq_col.div_(self.tensor_parallel_size)
+                        # gather row
+                        exp_avg_sq_row_gather = _gather(input_=exp_avg_sq_row, dim=-1, process_group=self.tensor_parallel_group)
+                        sq_row_mean = exp_avg_sq_row_gather.mean(dim=-1, keepdim=True)
+                        update_reshape = self._approx_sq_grad_row_parallel(exp_avg_sq_row, exp_avg_sq_col, sq_row_mean)
+                        update_reshape.mul_(grad_reshape)
+                        update = update_reshape.flatten()
+                else:
+                    exp_avg_sq = state["exp_avg_sq"]
+                    exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
+                    update = exp_avg_sq.rsqrt().mul_(grad)
+
+                # RMS clipping
+                # update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
+                update.mul_(lr)
+                if use_first_moment:
+                    exp_avg = state["exp_avg"]
+                    exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
+                    update = exp_avg
+
+                if group["weight_decay"] != 0:
+                    p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))
+
+                p_data_fp32.add_(-update)
+
+                if p.dtype in {torch.float16, torch.bfloat16}:
+                    p.copy_(p_data_fp32)
+
+        return loss
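For orientation, a sketch of how this optimizer is wired up elsewhere in this PR (it mirrors `exam_dist_adafactor_zero` in the test file below): the optimizer is wrapped by `LowLevelZeroOptimizer`, and the ZeRO store's master-to-working mapping is what `setup_distributed` expects as `shard_to_param`. The `tp_model`, `tp_group` and `dp_group` names are assumed to already exist in the caller.

```python
from colossalai.nn.optimizer.distributed_adafactor import DistributedAdaFactor
from colossalai.zero import LowLevelZeroOptimizer

# assumes tp_model, tp_group and dp_group are already set up (see the test below)
optim = DistributedAdaFactor([p for p in tp_model.parameters()])
optim = LowLevelZeroOptimizer(
    optim, overlap_communication=True, initial_scale=128, partition_grad=True, dp_process_group=dp_group
)
# id(flattened master shard) -> working (TP-sharded) parameter
shard_to_param = optim._param_store.master_to_working_param
optim.optim.setup_distributed(
    tensor_parallel_group=tp_group, data_parallel_group=dp_group, shard_to_param=shard_to_param
)
```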
diff --git a/tests/test_optimizer/test_dist_adafactor.py b/tests/test_optimizer/test_dist_adafactor.py
new file mode 100644
index 000000000000..e477611f4882
--- /dev/null
+++ b/tests/test_optimizer/test_dist_adafactor.py
@@ -0,0 +1,527 @@
+import pytest
+import os
+import sys
+import copy
+import torch
+import torch.distributed as dist
+from torch import nn
+from torch.testing import assert_close
+
+import colossalai
+from colossalai.tensor.d_tensor.layout import Layout
+from colossalai.tensor.d_tensor import layout_converter
+from colossalai.tensor.d_tensor.sharding_spec import DimSpec
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.utils import set_seed
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.cluster import DistCoordinator, ProcessGroupMesh
+from colossalai.tensor.d_tensor import (
+    is_distributed_tensor,
+    distribute_tensor,
+    sharded_tensor_to_param,
+    shard_rowwise,
+    shard_colwise,
+    get_layout,
+    get_sharding_spec,
+)
+from colossalai.tensor.d_tensor.api import clear_layout_converter
+from colossalai.shardformer.layer._operation import _gather
+from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
+from colossalai.zero import LowLevelZeroOptimizer
+from colossalai.nn.optimizer.adafactor import Adafactor
+from colossalai.nn.optimizer.distributed_adafactor import DistributedAdaFactor
+from colossalai.booster import Booster
+from colossalai.booster.plugin import TorchDDPPlugin
+
+HEIGHT = 4096
+WIDTH = 4096
+_TP_SPEC = DimSpec([0])
+
+
+def init_dist():
+    rank = int(os.environ["RANK"])
+    local_rank = int(os.environ["LOCAL_RANK"])
+    world_size = int(os.environ["WORLD_SIZE"])
+    dist.init_process_group(world_size=world_size, rank=rank, init_method="env://", backend="nccl")
+    torch.cuda.set_device(local_rank)
+
+
+def correctness_verify(tensor1: torch.Tensor, tensor2: torch.Tensor, dtype: torch.dtype = torch.float32):
+    rtol = None
+    atol = None
+    if dtype is torch.float32:
+        rtol = 1e-05
+        atol = 1e-05
+    elif dtype is torch.float16:
+        rtol = 5e-2
+        atol = 5e-4
+    elif dtype is torch.bfloat16:
+        rtol = 4e-3
+        atol = 4e-3
+
+    return torch.all(tensor1.isclose(tensor2, rtol=rtol, atol=atol, equal_nan=True))
+    # assert_close(tensor1, tensor2, rtol=rtol, atol=atol, equal_nan=True)
+
+
+# setup param groups; (For zero test optim)
+def setup_param_groups_zero(model: nn.Module) -> list:
+    no_decay = ["bias", "LayerNorm.weight"]
+    optimizer_grouped_parameters = [
+        {
+            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+            "weight_decay": 0.1,
+        },
+        {
+            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
+            "weight_decay": 0.0,
+        },
+    ]
+    return optimizer_grouped_parameters
+
+
+# setup param groups; (For base optim)
+def setup_param_groups(model: nn.Module) -> list:
+    optimizer_grouped_parameters = [p for n, p in model.named_parameters()]
+    return optimizer_grouped_parameters
+
+
+# setup flattened param groups, sharding specs and shapes; (For dist optim)
+def setup_flatten_param_groups_sharding_spec_shape(model: nn.Module) -> tuple:
+    flatten_optimizer_grouped_parameters = []
+    sharding_spec = {}  # {id(flatten param): get_sharding_spec(p)}
+    param_shape = {}  # {id(flatten param): get_layout(p).global_shape}
+    for n, p in model.named_parameters():
+        # flatten_p = copy.deepcopy(p).flatten()
+        flatten_p = nn.Parameter(p.clone().flatten().requires_grad_(True))
+        flatten_optimizer_grouped_parameters.append(flatten_p)
+        if is_distributed_tensor(p):
+            sharding_spec[id(flatten_p)] = get_sharding_spec(p)
+            param_shape[id(flatten_p)] = get_layout(p).global_shape
+        else:
+            sharding_spec[id(flatten_p)] = None
+            param_shape[id(flatten_p)] = p.shape
+    return flatten_optimizer_grouped_parameters, sharding_spec, param_shape
+
+
+def set_dist_grad(
+    dist_module: nn.Module, torch_model: nn.Module, g_dtype: torch.dtype, group: dist.ProcessGroup
+) -> None:
+    """
+    Set split grads for tensor parallelism or ZeRO DP.
+    We do not need a separate treatment for ZeRO,
+    as the wrapper takes care of reduce-scattering grads.
+ """ + rank = dist.get_rank(group) + world_size = dist.get_world_size(group) + + for p, torch_p in zip(dist_module.parameters(), torch_model.parameters()): + if torch_p.grad is None: + torch_p.grad = torch.zeros_like(torch_p) + + is_distributed = hasattr(p, "dist_layout") + if is_distributed: + sharding = p.dist_layout.sharding_spec.sharding_sequence + split_dim = sharding.index(_TP_SPEC) + shape = torch_p.split(world_size, dim=split_dim)[rank].shape + + indices = torch.arange(shape[split_dim] * rank, shape[split_dim] * (rank + 1)) + # Generate grads only for the correctly split chunk + torch_p.grad.index_add_(split_dim, indices, torch.randn(shape, device=torch_p.device, dtype=g_dtype)) + + else: + shape = torch_p.shape + torch_p.grad += torch.randn(shape, device=torch_p.device, dtype=g_dtype) + + # avoid inconsistent grad and param dtype error + orig_p = p.data + p.data = torch_p.grad.clone().to(g_dtype) + p.grad = p.data + p.data = orig_p + +class MlpModel(nn.Module): + def __init__(self): + super(MlpModel, self).__init__() + self.linear1 = nn.Linear(HEIGHT, WIDTH) + self.linear2 = nn.Linear(WIDTH, HEIGHT) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + +class TPModel(nn.Module): + def __init__(self, linear1, linear2, tp_group=None): + super().__init__() + self.linear1 = Linear1D_Col.from_native_module(linear1, process_group=tp_group, gather_output=False, overlap=True) + self.linear2 = Linear1D_Row.from_native_module(linear2, process_group=tp_group, parallel_input=True) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + +@parameterize("dtype", [torch.float32]) # , torch.float16, torch.bfloat16 +def exam_dist_adafactor_base(dtype: torch.dtype): + local_rank = int(os.environ['LOCAL_RANK']) + world_size = int(os.environ['WORLD_SIZE']) + + tensor_parallel_size = world_size + torch.set_default_dtype(dtype) + set_seed(42) + + # ============================== + # Base Case + # ============================== + H, W = 4096, 4096 + model_col = nn.Linear(H, W).to(local_rank) # Col parallel weight + weight, bias = model_col.weight, model_col.bias + device_mesh = DeviceMesh(torch.Tensor([i for i in range(world_size)]), (1, tensor_parallel_size), init_process_group=True) + tp_group = device_mesh.get_process_group(axis=1) + # ============================== + # Col Parallel + # ============================== + weight_col_shard = shard_colwise(weight.clone(), device_mesh.get_process_group(axis=1)) + weight_col_shard_layout = get_layout(weight_col_shard) # Layout info weight_col_shard_layout.global_shape + weight_col_shard_shard_spec = get_sharding_spec(weight_col_shard) # Shard spec + weight_col_shard_flatten = nn.Parameter(weight_col_shard.clone().flatten().requires_grad_(True)) + bias_col_flatten = nn.Parameter(bias.clone().flatten().requires_grad_(True)) + col_params_shape = {id(weight_col_shard_flatten): weight_col_shard_layout.global_shape, id(bias_col_flatten): bias.shape} + col_sharding_spec_dict = {id(weight_col_shard_flatten): weight_col_shard_shard_spec, id(bias_col_flatten): None} + + # ============================== + # Row Parallel + # ============================== + weight_row_shard = shard_rowwise(weight.clone(), device_mesh.get_process_group(axis=1)) + weight_row_shard_layout = get_layout(weight_row_shard) # Layout info weight_row_shard_layout.global_shape + weight_row_shard_shard_spec = get_sharding_spec(weight_row_shard) # Shard spec + weight_row_shard_flatten = 
+    weight_row_shard_flatten = nn.Parameter(weight_row_shard.clone().flatten().requires_grad_(True))  # flatten the (non-dtensor) input for the optimizer
+    bias_row_flatten = nn.Parameter(bias.clone().flatten().requires_grad_(True))
+    row_params_shape = {id(weight_row_shard_flatten): weight_row_shard_layout.global_shape, id(bias_row_flatten): bias.shape}
+    row_sharding_spec_dict = {id(weight_row_shard_flatten): weight_row_shard_shard_spec, id(bias_row_flatten): None}
+
+    # ==============================
+    # Init Optimizer
+    # ==============================
+
+    # base
+    optimizer_base = Adafactor([weight, bias])
+
+    # col parallel
+    optimizer_cp = DistributedAdaFactor([weight_col_shard_flatten, bias_col_flatten])
+    optimizer_cp.setup_distributed(tensor_parallel_group=tp_group, data_parallel_group=None, sharding_spec_dict=col_sharding_spec_dict, param_shape=col_params_shape)
+
+    # row parallel
+    optimizer_rp = DistributedAdaFactor([weight_row_shard_flatten, bias_row_flatten])
+    optimizer_rp.setup_distributed(tensor_parallel_group=tp_group, data_parallel_group=None, sharding_spec_dict=row_sharding_spec_dict, param_shape=row_params_shape)
+
+    N_STEPS = 1
+    for _ in range(N_STEPS):
+        # base step
+        optimizer_base.zero_grad()
+        weight.grad = torch.rand_like(weight)
+        bias.grad = torch.rand_like(bias)
+        optimizer_base.step()
+
+        # col parallel step
+        optimizer_cp.zero_grad()
+        weight_col_shard_flatten.grad = distribute_tensor(weight.grad, device_mesh, weight_col_shard_shard_spec).clone().flatten()
+        bias_col_flatten.grad = bias.grad.clone().flatten()
+        optimizer_cp.step()
+
+        # row parallel step
+        optimizer_rp.zero_grad()
+        weight_row_shard_flatten.grad = distribute_tensor(weight.grad, device_mesh, weight_row_shard_shard_spec).clone().flatten()
+        bias_row_flatten.grad = bias.grad.clone().flatten()
+        optimizer_rp.step()
+
+    # gather results
+    weight_col_gather = _gather(input_=weight_col_shard_flatten.data.view(-1, H // tensor_parallel_size), dim=-1, process_group=device_mesh.get_process_group(axis=1))  # gather
+    weight_row_gather = _gather(input_=weight_row_shard_flatten.data, dim=-1, process_group=device_mesh.get_process_group(axis=1)).view(-1, W)  # gather
+
+    # verify
+    col_correct = correctness_verify(weight.data, weight_col_gather.data, dtype)
+    row_correct = correctness_verify(weight.data, weight_row_gather.data, dtype)
+
+    print(f"col correct {col_correct}, row correct {row_correct}")
+
+
+@parameterize("dtype", [torch.float32])  # , torch.float16, torch.bfloat16
+def exam_dist_adafactor_fwd_bwd(dtype: torch.dtype):
+    local_rank = int(os.environ["LOCAL_RANK"])
+    world_size = int(os.environ["WORLD_SIZE"])
+    tensor_parallel_size = world_size
+    torch.set_default_dtype(dtype)
+    set_seed(42)
+
+    # ==============================
+    # Model Init
+    # ==============================
+    device_mesh = DeviceMesh(torch.Tensor([i for i in range(world_size)]), (1, tensor_parallel_size), init_process_group=True)
+    base_model = MlpModel().to(local_rank)
+    tp_model = TPModel(copy.deepcopy(base_model.linear1), copy.deepcopy(base_model.linear2), device_mesh.get_process_group(axis=1)).to(local_rank)
+    tp_group = device_mesh.get_process_group(axis=1)
+
+    base_param_group = setup_param_groups(base_model)
+    tp_param_group, tp_shard_spec, tp_param_shape = setup_flatten_param_groups_sharding_spec_shape(tp_model)
+
+    # ==============================
+    # Optimizer Init
+    # ==============================
+    base_optim = Adafactor(base_param_group)
+    dist_optim = DistributedAdaFactor(tp_param_group)
+    dist_optim.setup_distributed(
+        tensor_parallel_group=tp_group, data_parallel_group=None, sharding_spec_dict=tp_shard_spec, param_shape=tp_param_shape
+    )
+
+    # ==============================
+    # Correctness Verify
+    # ==============================
+    x = torch.randn(HEIGHT, WIDTH, device=local_rank)
+
+    loss_tp = tp_model(x).sum()
+    loss_tp.backward()
+
+    loss = base_model(x).sum()
+    loss.backward()
+
+    base_optim.step()
+    dist_optim.step()
+
+    base_optim.zero_grad()
+    dist_optim.zero_grad()
+
+    for p, tp_p in zip(base_param_group, tp_param_group):
+        if tp_shard_spec[id(tp_p)] is not None:
+            if len(tp_shard_spec[id(tp_p)].sharding_sequence) >= 2:
+                # col-parallel weight --> view to (-1, HEIGHT // tensor_parallel_size), then gather
+                if tp_shard_spec[id(tp_p)].sharding_sequence[0] == "R":
+                    tp_p = _gather(input_=tp_p.data.view(-1, HEIGHT // tensor_parallel_size), dim=-1, process_group=device_mesh.get_process_group(axis=1))  # gather
+                # row-parallel weight --> gather, then view to (-1, WIDTH)
+                else:
+                    tp_p = _gather(input_=tp_p.data, dim=-1, process_group=device_mesh.get_process_group(axis=1)).view(-1, WIDTH)  # gather
+            else:
+                # TP-sharded bias
+                tp_p = _gather(input_=tp_p.data, dim=-1, process_group=device_mesh.get_process_group(axis=1))
+        else:
+            # replicated param: compare directly, no gather needed
+            pass
+        correctness = correctness_verify(p.data, tp_p.data, dtype)
+
+
+@parameterize("dtype", [torch.bfloat16])  # torch.float32, torch.float16, torch.bfloat16
+@parameterize("tp_zero_size", [(2, 2)])  # (2, 2), (4, 1), (1, 4)
+def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
+    tp_size, zero_size = tp_zero_size
+    local_rank = dist.get_rank()
+
+    proc_mesh = ProcessGroupMesh(tp_size, zero_size)
+    tp_group, dp_group = proc_mesh.get_group_along_axis(0), proc_mesh.get_group_along_axis(1)
+
+    torch.set_default_dtype(dtype)
+    set_seed(42)
+
+    # ==============================
+    # Model Init
+    # ==============================
+    base_model = MlpModel().to(local_rank)
+    tp_model = TPModel(copy.deepcopy(base_model.linear1), copy.deepcopy(base_model.linear2), tp_group).to(local_rank)
+
+    base_param_group = setup_param_groups(base_model)
+    tp_param_group = setup_param_groups(tp_model)
+    tp_param_group_, tp_shard_spec, tp_param_shape = setup_flatten_param_groups_sharding_spec_shape(tp_model)
+
+    # ==============================
+    # Optimizer Init
+    # ==============================
+    base_optim = Adafactor(base_param_group)
+    dist_optim = DistributedAdaFactor(tp_param_group)
+
+    # Setup distributed optimizer
+    if zero_size > 1:
+        base_optim = LowLevelZeroOptimizer(
+            base_optim,
+            overlap_communication=True,
+            initial_scale=128,
+            partition_grad=True,
+            dp_process_group=dp_group,
+            verbose=True,
+        )
+        dist_optim = LowLevelZeroOptimizer(
+            dist_optim,
+            overlap_communication=True,
+            initial_scale=128,
+            partition_grad=True,
+            dp_process_group=dp_group,
+            verbose=True,
+        )
+        shard_to_param = dist_optim._param_store.master_to_working_param  # {id(master shard): working param}, masters are flattened
+        dist_optim.optim.setup_distributed(tensor_parallel_group=tp_group, data_parallel_group=dp_group, shard_to_param=shard_to_param)
+    else:
+        # without ZeRO the working params are the master params themselves
+        shard_to_param = {id(p): p for p in tp_param_group}
+        dist_optim.setup_distributed(tensor_parallel_group=tp_group, data_parallel_group=dp_group, shard_to_param=shard_to_param)
+
+    # ==============================
+    # Correctness Verify
+    # ==============================
+    x = torch.randn(HEIGHT, WIDTH, device=local_rank)
+
+    out = base_model(x)
+    out_tp = tp_model(x)
+
+    if zero_size > 1:
+        dist_optim.backward(out_tp.sum())
+        base_optim.backward(out.sum())
+    else:
+        out_tp.sum().backward()
+        out.sum().backward()
+
+    base_optim.step()
+    dist_optim.step()
+
+    base_optim.zero_grad()
+    dist_optim.zero_grad()
+
+    for p, tp_p in zip(base_param_group, tp_param_group):
+        param_is_distributed = is_distributed_tensor(tp_p)
+        if param_is_distributed:
+            shard_spec = get_sharding_spec(tp_p)
+            if len(shard_spec.sharding_sequence) >= 2:
+                # Col Parallel
+                if shard_spec.sharding_sequence[0] == "R":
+                    tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group)  # gather
+                # Row Parallel
+                if shard_spec.sharding_sequence[-1] == "R":
+                    tp_p = _gather(input_=tp_p, dim=0, process_group=tp_group)  # gather
+            else:
+                # TP bias
+                tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group)  # gather
+        else:
+            # No TP bias
+            pass
+        correctness = correctness_verify(p.data, tp_p.data, dtype)
+        print(f"Curr Param correct {correctness}")
+
+
+@parameterize("dtype", [torch.bfloat16])  # torch.float32, torch.float16, torch.bfloat16
+@parameterize("tp_zero_size", [(2, 2)])  # (2, 2), (4, 1), (1, 4)
+def exam_dist_adafactor_booster(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
+    tp_size, zero_size = tp_zero_size
+    local_rank = dist.get_rank()
+
+    proc_mesh = ProcessGroupMesh(tp_size, zero_size)
+    tp_group, dp_group = proc_mesh.get_group_along_axis(0), proc_mesh.get_group_along_axis(1)
+
+    torch.set_default_dtype(dtype)
+    set_seed(42)
+
+    # ==============================
+    # Model Init
+    # ==============================
+    base_model = MlpModel().to(local_rank)
+    tp_model = TPModel(copy.deepcopy(base_model.linear1), copy.deepcopy(base_model.linear2), tp_group).to(local_rank)
+
+    base_param_group = setup_param_groups(base_model)
+    tp_param_group = setup_param_groups(tp_model)
+
+    # ==============================
+    # Optimizer Init
+    # ==============================
+    base_optim = Adafactor(base_param_group)
+    dist_optim = DistributedAdaFactor(tp_param_group)
+
+    # Setup distributed optimizer
+    if zero_size > 1:
+        base_optim = LowLevelZeroOptimizer(
+            base_optim,
+            overlap_communication=True,
+            initial_scale=128,
+            partition_grad=True,
+            dp_process_group=dp_group,
+            verbose=True,
+        )
+        dist_optim = LowLevelZeroOptimizer(
+            dist_optim,
+            overlap_communication=True,
+            initial_scale=128,
+            partition_grad=True,
+            dp_process_group=dp_group,
+            verbose=True,
+        )
+        shard_to_param = dist_optim._param_store.master_to_working_param  # {id(master shard): working param}, masters are flattened
+        dist_optim.optim.setup_distributed(tensor_parallel_group=tp_group, data_parallel_group=dp_group, shard_to_param=shard_to_param)
+    else:
+        # without ZeRO the working params are the master params themselves
+        shard_to_param = {id(p): p for p in tp_param_group}
+        dist_optim.setup_distributed(tensor_parallel_group=tp_group, data_parallel_group=dp_group, shard_to_param=shard_to_param)
+
+    # ==============================
+    # Booster Init
+    # ==============================
+    plugin = TorchDDPPlugin()
+    booster = Booster(plugin=plugin)
+    criterion = lambda x: x.mean()
+
+    tp_model, dist_optim, criterion, _, _ = booster.boost(tp_model, dist_optim, criterion)
+
+    # ==============================
+    # Correctness Verify
+    # ==============================
+    x = torch.randn(HEIGHT, WIDTH, device=local_rank)
+
+    out = base_model(x)
+    out_tp = tp_model(x)
+
+    if zero_size > 1:
+        dist_optim.backward(out_tp.sum())
+        base_optim.backward(out.sum())
+    else:
+        out_tp.sum().backward()
+        out.sum().backward()
+
+    base_optim.step()
+    dist_optim.step()
+
+    base_optim.zero_grad()
+    dist_optim.zero_grad()
+
+    for p, tp_p in zip(base_param_group, tp_param_group):
+        param_is_distributed = is_distributed_tensor(tp_p)
+        if param_is_distributed:
+            shard_spec = get_sharding_spec(tp_p)
+            if len(shard_spec.sharding_sequence) >= 2:
+                # Col Parallel
+                if shard_spec.sharding_sequence[0] == "R":
+                    tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group)  # gather
+                # Row Parallel
+                if shard_spec.sharding_sequence[-1] == "R":
+                    tp_p = _gather(input_=tp_p, dim=0, process_group=tp_group)  # gather
+            else:
+                # TP bias
+                tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group)  # gather
+        else:
+            # No TP bias
+            pass
+        correctness = correctness_verify(p.data, tp_p.data, dtype)
+        print(f"Curr Param correct {correctness}")
+
+
+def run_dist(rank, world_size, port):
+    config = {}
+    colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    # init_dist()
+    # exam_dist_adafactor_base()
+    # exam_dist_adafactor_fwd_bwd()
+    exam_dist_adafactor_zero()
+    exam_dist_adafactor_booster()
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_dist_adafactor():
+    spawn(run_dist, nprocs=4)
+
+
+if __name__ == "__main__":
+    test_dist_adafactor()
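The test spawns four ranks (`spawn(run_dist, nprocs=4)`, i.e. TP=2 x ZeRO=2), so four CUDA devices are assumed. A minimal driver sketch for running it outside pytest; the import path is an assumption and requires the repository root to be on `PYTHONPATH`:

```python
# hypothetical standalone driver; equivalent to `python tests/test_optimizer/test_dist_adafactor.py`
from colossalai.testing import spawn

from tests.test_optimizer.test_dist_adafactor import run_dist

if __name__ == "__main__":
    spawn(run_dist, nprocs=4)  # adjust nprocs to tp_size * zero_size if the parameterization changes
```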