diff --git a/colossalai/nn/optimizer/adafactor.py b/colossalai/nn/optimizer/adafactor.py
index 57d677ef0059..22a6c8f4d3ce 100644
--- a/colossalai/nn/optimizer/adafactor.py
+++ b/colossalai/nn/optimizer/adafactor.py
@@ -36,7 +36,7 @@ def __init__(
         relative_step=True,
         warmup_init=False,
     ):
-        lr=None
+        lr = None
         if lr is not None and relative_step:
             raise ValueError("Cannot combine manual `lr` and `relative_step=True` options")
         if warmup_init and not relative_step:
diff --git a/colossalai/nn/optimizer/distributed_adafactor.py b/colossalai/nn/optimizer/distributed_adafactor.py
index ccda1f3e65b1..3f455c1c5b41 100644
--- a/colossalai/nn/optimizer/distributed_adafactor.py
+++ b/colossalai/nn/optimizer/distributed_adafactor.py
@@ -3,9 +3,9 @@
 import torch
 import torch.distributed as dist
+
 # from torch.optim import Optimizer
 from colossalai.interface.optimizer import DistributedOptim
-
 from colossalai.shardformer.layer._operation import _gather, _split
 from colossalai.tensor.d_tensor import get_sharding_spec, is_distributed_tensor
@@ -50,14 +50,13 @@ def __init__(
         self.data_parallel_group = None
         self.shard_to_param = None  # Dict{id:shape}, sample {id(param): torch.tensor}
         self.use_zero = True
-
-        self.param_is_dtensor_dict = {} # {id(p): True/False}
-        self.grad_shape_dict = {} # {id(p): master param shape}
-        self.factored_dict = {} # {id(p): True/False}
+
+        self.param_is_dtensor_dict = {}  # {id(p): True/False}
+        self.grad_shape_dict = {}  # {id(p): master param shape}
+        self.factored_dict = {}  # {id(p): True/False}
         self.use_first_moment_dict = {} # {id(p): True/False}
         self.shard_spec_dict = {} # {id(p): ShardSpec}
         super().__init__(params, defaults)
-
     def setup_distributed(
         self,
@@ -84,24 +83,26 @@ def setup_distributed(
         if self.data_parallel_group is not None:
             self.data_parallel_size = dist.get_world_size(self.data_parallel_group)
         self.use_zero = use_zero
-
+
         self.shard_to_param = shard_to_param if shard_to_param is not None else {}
         # grad is None, cause we dont setup now
         for group in self.param_groups:
             for p in group["params"]:
                 self.param_is_dtensor_dict[id(p)] = is_distributed_tensor(self.shard_to_param.get(id(p)))
                 if self.param_is_dtensor_dict[id(p)]:
-                    self.grad_shape_dict[id(p)] = self.shard_to_param.get(id(p)).shape
+                    self.grad_shape_dict[id(p)] = self.shard_to_param.get(id(p)).shape
                 else:
                     # no tp; could be zero or not zero
                     # self.grad_shape_dict[id(p)] = p.shape
-                    self.grad_shape_dict[id(p)] = self.shard_to_param.get(id(p)).shape
-                self.factored_dict[id(p)], self.use_first_moment_dict[id(p)] = self._get_options(group, self.grad_shape_dict[id(p)])
+                    self.grad_shape_dict[id(p)] = self.shard_to_param.get(id(p)).shape
+                self.factored_dict[id(p)], self.use_first_moment_dict[id(p)] = self._get_options(
+                    group, self.grad_shape_dict[id(p)]
+                )
                 if self.param_is_dtensor_dict[id(p)]:
                     self.shard_spec_dict[id(p)] = get_sharding_spec(self.shard_to_param.get(id(p)))
                 else:
                     self.shard_spec_dict[id(p)] = None
-
+
     @staticmethod
     def _get_lr(param_group, param_state):
         rel_step_sz = param_group["lr"]
@@ -130,7 +131,7 @@ def _get_options(param_group, param_shape):
     def _rms(tensor, param_is_dtensor, tp_size, dp_size, tp_group, dp_group):
         tensor_sum = tensor.pow(2).sum()
         num_of_element = tensor.numel()
-
+
         if param_is_dtensor:
             # reduce tensor_sum from tp_group
             dist.all_reduce(tensor_sum, group=tp_group)
@@ -162,30 +163,26 @@ def _approx_sq_grad_row_parallel(exp_avg_sq_row, exp_avg_sq_col, sq_row_meam):
         r_factor = (exp_avg_sq_row / sq_row_meam).rsqrt_().unsqueeze(-1)
         c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
         return torch.mul(r_factor, c_factor)
-
+
     def _col_parallel_factor(self, update, grad, state, grad_shape, beta2t):
         if grad_shape[0] % self.data_parallel_size != 0:
             # gather update[flatten] along dp group then reshape to [H, W/tp]
-            update = _gather(
-                input_=update, dim=-1, process_group=self.data_parallel_group
-            )
+            update = _gather(input_=update, dim=-1, process_group=self.data_parallel_group)
             update_reshape = update.view(-1, grad_shape[1])
             # gather grad[flatten] along dp group then reshape to [H, W/tp]
-            grad = _gather(
-                input_=grad, dim=-1, process_group=self.data_parallel_group
-            )
+            grad = _gather(input_=grad, dim=-1, process_group=self.data_parallel_group)
             grad_reshape = grad.view(-1, grad_shape[1])
             exp_avg_sq_row = state["exp_avg_sq_row"]  # [H]
             exp_avg_sq_col = state["exp_avg_sq_col"]  # [W/tp]
             exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t))
             exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t))
             update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
-            update_reshape.mul_(grad_reshape)
+            update_reshape.mul_(grad_reshape)
             if self.use_zero:
                 update = update_reshape.view(-1)
             else:
                 update = update_reshape
-
+
         else:
             update_reshape = update.view(-1, grad_shape[1])
             grad_reshape = grad.view(-1, grad_shape[1])
@@ -202,19 +199,15 @@ def _col_parallel_factor(self, update, grad, state, grad_shape, beta2t):
         else:
             update = update_reshape
         return update
-
+
     def _row_parallel_factor(self, update, grad, state, grad_shape, beta2t):
         if grad_shape[0] % self.data_parallel_size != 0:
             # gather update[flatten] along dp group then reshape to [H/tp, W]
-            update = _gather(
-                input_=update, dim=-1, process_group=self.data_parallel_group
-            )
+            update = _gather(input_=update, dim=-1, process_group=self.data_parallel_group)
             # view update to origin[tp] shape
             update_reshape = update.view(-1, grad_shape[1])
             # gather grad[flatten] along dp group then reshape to [H/tp, W]
-            grad = _gather(
-                input_=grad, dim=-1, process_group=self.data_parallel_group
-            )
+            grad = _gather(input_=grad, dim=-1, process_group=self.data_parallel_group)
             grad_reshape = grad.view(-1, grad_shape[1])
             exp_avg_sq_row = state["exp_avg_sq_row"]  # [H/tp]
             exp_avg_sq_col = state["exp_avg_sq_col"]  # [W]
@@ -240,9 +233,7 @@ def _row_parallel_factor(self, update, grad, state, grad_shape, beta2t):
                 dist.all_reduce(exp_avg_sq_col, group=self.tensor_parallel_group)
                 exp_avg_sq_col.div_(self.tensor_parallel_size)
                 # gather row
-                exp_avg_sq_row_gather = _gather(
-                    input_=exp_avg_sq_row, dim=-1, process_group=self.tensor_parallel_group
-                )
+                exp_avg_sq_row_gather = _gather(input_=exp_avg_sq_row, dim=-1, process_group=self.tensor_parallel_group)
                 sq_row_meam = exp_avg_sq_row_gather.mean(dim=-1, keepdim=True)
                 update_reshape = self._approx_sq_grad_row_parallel(exp_avg_sq_row, exp_avg_sq_col, sq_row_meam)
                 update_reshape.mul_(grad_reshape)
@@ -251,24 +242,20 @@ def _row_parallel_factor(self, update, grad, state, grad_shape, beta2t):
             else:
                 update = update_reshape
         return update
-
+
     def _base_factor(self, update, grad, state, grad_shape, beta2t):
         if self.use_zero:
             # only zero
             if grad_shape[0] % self.data_parallel_size != 0:
-                # view update to origin shape update.view(grad_shape[0]//self.data_parallel_size , grad_shape[1])
-                # row mean no change
+                # view update to origin shape update.view(grad_shape[0]//self.data_parallel_size , grad_shape[1])
+                # row mean no change
                 # col mean need reduce and div
                 # gather update[flatten] along dp group then reshape to [H, W]
-                update = _gather(
-                    input_=update, dim=-1, process_group=self.data_parallel_group
-                )
+                update = _gather(input_=update, dim=-1, process_group=self.data_parallel_group)
                 # view update to origin[tp] shape
                 update_reshape = update.view(-1, grad_shape[1])
                 # gather grad[flatten] along dp group then reshape to [H, W]
-                grad = _gather(
-                    input_=grad, dim=-1, process_group=self.data_parallel_group
-                )
+                grad = _gather(input_=grad, dim=-1, process_group=self.data_parallel_group)
                 grad_reshape = grad.view(-1, grad_shape[1])
                 exp_avg_sq_row = state["exp_avg_sq_row"]  # [H/dp]
                 exp_avg_sq_col = state["exp_avg_sq_col"]  # [W]
@@ -283,8 +270,8 @@ def _base_factor(self, update, grad, state, grad_shape, beta2t):
             else:
                 # no residual row
                 # view update to origin[tp] shape
-                update_reshape = update.view(-1, grad_shape[1]) # [H/dp, W]
-                grad_reshape = grad.view(-1, grad_shape[1]) # [H/dp, W]
+                update_reshape = update.view(-1, grad_shape[1])  # [H/dp, W]
+                grad_reshape = grad.view(-1, grad_shape[1])  # [H/dp, W]
                 exp_avg_sq_row = state["exp_avg_sq_row"]  # [H/tp]
                 exp_avg_sq_col = state["exp_avg_sq_col"]  # [W]
                 exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t))
@@ -294,7 +281,7 @@ def _base_factor(self, update, grad, state, grad_shape, beta2t):
                 exp_avg_sq_col.div_(self.tensor_parallel_size)
                 update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
                 update_reshape.mul_(grad_reshape)
-                update = update_reshape.view(-1)
+                update = update_reshape.view(-1)
         else:
             # base factor; no tp, no dp
             exp_avg_sq_row = state["exp_avg_sq_row"]
@@ -307,9 +294,7 @@ def _base_factor(self, update, grad, state, grad_shape, beta2t):
             update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
             update.mul_(grad)
         return update
-
-
-
+
     @torch.no_grad()
     def step(self, closure=None):
         """
@@ -344,7 +329,7 @@ def step(self, closure=None):
                 grad = p.grad
                 if grad.is_sparse:
                     raise RuntimeError("Adafactor does not support sparse gradients.")
-
+
                 state = self.state[p]
                 grad_shape = self.grad_shape_dict[id(p)]
                 param_is_dtensor = self.param_is_dtensor_dict[id(p)]
@@ -364,11 +349,11 @@ def step(self, closure=None):
                         if grad_shape[0] % self.data_parallel_size != 0:
                             state["exp_avg_sq_row"] = torch.zeros(
                                 grad_shape[0], device=p.device, dtype=p.dtype
-                            ) # [H]
+                            )  # [H]
                         else:
                             state["exp_avg_sq_row"] = torch.zeros(
-                                grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype
-                            ) # [H/dp]
+                                grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype
+                            )  # [H/dp]
                         state["exp_avg_sq_col"] = torch.zeros(
                             grad_shape[1], device=p.device, dtype=p.dtype
                         )  # [W/TP]
@@ -378,23 +363,27 @@ def step(self, closure=None):
                         if grad_shape[0] % self.data_parallel_size != 0:
                             state["exp_avg_sq_row"] = torch.zeros(
                                 grad_shape[0], device=p.device, dtype=p.dtype
-                            ) # [H/tp]
+                            )  # [H/tp]
                         else:
                             state["exp_avg_sq_row"] = torch.zeros(
                                 grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype
-                            ) # [H/dp/tp]
-
+                            )  # [H/dp/tp]
+
                         state["exp_avg_sq_col"] = torch.zeros(
-                            grad_shape[1], device=p.device, dtype=p.dtype
-                        ) # [W]
+                            grad_shape[1], device=p.device, dtype=p.dtype
+                        )  # [W]
                 else:
                     if self.use_zero:
                         if grad_shape[0] % self.data_parallel_size != 0:
                             # save all exp_avg_sq_row [H]
-                            state["exp_avg_sq_row"] = torch.zeros(grad_shape[0], device=grad.device, dtype=p.dtype)
+                            state["exp_avg_sq_row"] = torch.zeros(
+                                grad_shape[0], device=grad.device, dtype=p.dtype
+                            )
                         else:
                             # exp_avg_sq_row [H // dp]
-                            state["exp_avg_sq_row"] = torch.zeros(grad_shape[0] // self.data_parallel_size, device=grad.device, dtype=p.dtype)
+                            state["exp_avg_sq_row"] = torch.zeros(
+                                grad_shape[0] // self.data_parallel_size, device=grad.device, dtype=p.dtype
+                            )
                     else:
                         # exp_avg_sq_row [H]
                         state["exp_avg_sq_row"] = torch.zeros(grad_shape[0], device=grad.device, dtype=p.dtype)
@@ -432,16 +421,23 @@ def step(self, closure=None):
                     elif shard_spec.sharding_sequence[-1] == "R":
                         update = self._row_parallel_factor(update, grad, state, grad_shape, beta2t)
                     else:
-                        update = self._base_factor(update, grad, state, grad_shape, beta2t)
+                        update = self._base_factor(update, grad, state, grad_shape, beta2t)
                 else:
                     exp_avg_sq = state["exp_avg_sq"]
                     exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
                     update = exp_avg_sq.rsqrt().mul_(grad)
-
+
                 # # (Line No.8) RMS
-                rms = self._rms(update, param_is_dtensor, self.tensor_parallel_size, self.data_parallel_size, self.tensor_parallel_group, self.data_parallel_group)
+                rms = self._rms(
+                    update,
+                    param_is_dtensor,
+                    self.tensor_parallel_size,
+                    self.data_parallel_size,
+                    self.tensor_parallel_group,
+                    self.data_parallel_group,
+                )
                 update.div_((rms / group["clip_threshold"]).clamp_(min=1.0))
-
+
                 update.mul_(lr)
                 if use_first_moment:
                     exp_avg = state["exp_avg"]
@@ -450,8 +446,7 @@ def step(self, closure=None):
                 if group["weight_decay"] != 0:
                     p.add_(p, alpha=(-group["weight_decay"] * lr))
-
+
                 p.add_(-update)
-
         return loss
diff --git a/docs/source/en/features/distributed_adafactor.md b/docs/source/en/features/distributed_adafactor.md
index 5a8d8ebade85..6553c0b03f7d 100644
--- a/docs/source/en/features/distributed_adafactor.md
+++ b/docs/source/en/features/distributed_adafactor.md
@@ -1,20 +1,20 @@
 # Distributed Adafactor
 
-Author:
+Author:
 
 **Related Paper**
 - [Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235)
 
 ## Introduction
 
-Distributed Adafactor is an optimiser that supports hybrid optimisation, including 1D tensor parallelism as well as ZerO. It makes full use of computational resources through reasonable task parallelism, improves training efficiency and speed, and reduces space pressure on single card storage. It has a wide range of applications and currently supports a range of Transformer based models, see [tests.kit.model_zoo](https://github.com/hpcaitech/ColossalAI/tree/main/tests/kit/model_zoo) for details.
+Distributed Adafactor is an optimiser that supports hybrid optimisation, combining 1D tensor parallelism with ZeRO. By splitting the optimiser states and the update computation across devices, it makes full use of the available compute, improves training efficiency and speed, and reduces the memory pressure on each individual card. It has a wide range of applications and currently supports a range of Transformer-based models; see [tests.kit.model_zoo](https://github.com/hpcaitech/ColossalAI/tree/main/tests/kit/model_zoo) for details.
 
 ### API Reference
 
 {{ autodoc:colossalai.nn.optimizer.distributed_adafactor.DistributedAdaFactor }}
 
 ## Hands-On Practice
-We now demonstrate how to start Distributed Adafactor with booster API.
+We now demonstrate how to start Distributed Adafactor with the booster API.
 ### step 1. Import libraries
 ```python
@@ -65,9 +65,9 @@ dist_optim = DistributedAdaFactor(model.parameters())
 ```python
 plugin = LowLevelZeroPlugin()
 booster = Booster(plugin=plugin)
-model, dist_optim, criterion, dataloader, _ = booster.boost(model, dist_optim, criterion, dataloader)
+model, dist_optim, criterion, dataloader, _ = booster.boost(model, dist_optim, criterion, dataloader)
 ```
-### step 5.Train Your Model
+### step 5. Train Your Model
 ```python
 for epoch in range(max_epochs):
     for input_ids, attention_mask in dataloader:
@@ -106,7 +106,7 @@ Model/Feature Compatibility Matrix:
 ✔️
 ✔️
-
+
 </table>
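Reviewer note (illustrative, not part of the patch): the sublinear-memory claim in the documentation above rests on Adafactor's factored second moment, which keeps only a row statistic and a column statistic per weight matrix instead of a full-size accumulator. A minimal single-device sketch of what `_approx_sq_grad` reconstructs follows; the shapes and hyperparameters here are assumptions for illustration, not values taken from this PR.

```python
import torch

# Hypothetical shapes for a single [H, W] weight matrix.
H, W, beta2t, eps = 1024, 4096, 0.999, 1e-30

grad = torch.randn(H, W)
update_sq = grad.pow(2) + eps

# Factored second moment: one [H] row statistic and one [W] column statistic,
# i.e. O(H + W) optimizer memory instead of O(H * W).
exp_avg_sq_row = torch.zeros(H)
exp_avg_sq_col = torch.zeros(W)
exp_avg_sq_row.mul_(beta2t).add_(update_sq.mean(dim=-1), alpha=1.0 - beta2t)
exp_avg_sq_col.mul_(beta2t).add_(update_sq.mean(dim=-2), alpha=1.0 - beta2t)

# Rank-1 reconstruction of the preconditioner, mirroring _approx_sq_grad above.
r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt().unsqueeze(-1)
c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
update = r_factor * c_factor * grad  # preconditioned update, same shape as grad
```

The row/column-parallel branches in `distributed_adafactor.py` compute exactly these statistics, but on shards, with the missing means obtained via all-reduce or gather over the tensor-parallel and data-parallel groups.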
diff --git a/tests/test_optimizer/_utils.py b/tests/test_optimizer/_utils.py
index 49ab3ee4562c..84f47556273d 100644
--- a/tests/test_optimizer/_utils.py
+++ b/tests/test_optimizer/_utils.py
@@ -2,6 +2,7 @@
 from torch.testing import assert_close
 
 import colossalai
+from colossalai.shardformer.layer._operation import _gather
 from colossalai.shardformer.layer.utils import Randomizer
 from colossalai.tensor.d_tensor.api import clear_layout_converter
 from colossalai.testing import parameterize, spawn
@@ -12,7 +13,7 @@
     run_forward_backward_with_hybrid_plugin,
     unwrap_model,
 )
-from colossalai.shardformer.layer._operation import _gather
+
 
 def check_optim_states(org_optim, sharded_optim):
     for group in org_optim.param_groups:
@@ -148,7 +149,7 @@ def check_dist_optim_state(org_optimizer, sharded_optimizer):
                 use_zero = sharded_optimizer.use_zero
                 tp_optim_state = tp_state[key]
                 p_state_shape, tp_state_shape = p_state[key].shape, tp_state[key].shape
-                # we start init model as first tensor parallel then zero;
+                # we start init model as first tensor parallel then zero;
                 # we gather model as first zero then tensor parallel
                 if use_zero:
                     # gather from dp group
@@ -159,33 +160,35 @@ def check_dist_optim_state(org_optimizer, sharded_optimizer):
                         tp_state_shape = tp_optim_state.shape
                     else:
                         pass
-
+
                     # check tp
                     if tp_is_dtensor:
                         # gather from tp group
                         if p_state_shape != tp_state_shape:
                             tp_optim_state = _gather(
-                                input_=tp_optim_state, dim=-1, process_group=sharded_optimizer.tensor_parallel_group
-                            )
+                                input_=tp_optim_state, dim=-1, process_group=sharded_optimizer.tensor_parallel_group
+                            )
                             tp_state_shape = tp_optim_state.shape
                         else:
                             pass
                     else:
                         pass
-
+
                 else:
                     # check tp
                     if tp_is_dtensor:
                         # gather from tp group
                         if p_state_shape != tp_state_shape:
                             tp_optim_state = _gather(
-                                input_=tp_optim_state, dim=-1, process_group=sharded_optimizer.tensor_parallel_group
-                            )
+                                input_=tp_optim_state, dim=-1, process_group=sharded_optimizer.tensor_parallel_group
+                            )
                             tp_state_shape = tp_optim_state.shape
                         else:
                             pass
                     else:
                         pass
-                print(f"{key} \np_state {p_state[key].shape} \ntp_optim_state {tp_state[key].shape} {tp_state_shape}\n")
-
+                print(
+                    f"{key} \np_state {p_state[key].shape} \ntp_optim_state {tp_state[key].shape} {tp_state_shape}\n"
+                )
+
                 assert_close(p_state[key], tp_optim_state, atol=5e-3, rtol=1.6e-2)
diff --git a/tests/test_optimizer/test_dist_adafactor.py b/tests/test_optimizer/test_dist_adafactor.py
index ab6cde370bc3..d0d5c79d963d 100644
--- a/tests/test_optimizer/test_dist_adafactor.py
+++ b/tests/test_optimizer/test_dist_adafactor.py
@@ -1,5 +1,4 @@
 import copy
-import os
 
 import pytest
 import torch
@@ -9,19 +8,19 @@
 
 import colossalai
 from colossalai.booster import Booster
-from colossalai.booster.plugin import TorchDDPPlugin, HybridParallelPlugin, LowLevelZeroPlugin
-from colossalai.logging import disable_existing_loggers
+from colossalai.booster.plugin import LowLevelZeroPlugin
 from colossalai.cluster import ProcessGroupMesh
-from colossalai.device.device_mesh import DeviceMesh
+from colossalai.logging import disable_existing_loggers
 from colossalai.nn.optimizer.adafactor import Adafactor
 from colossalai.nn.optimizer.distributed_adafactor import DistributedAdaFactor
 from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
 from colossalai.shardformer.layer._operation import _gather
+from colossalai.shardformer.layer.utils import Randomizer
 from colossalai.tensor.d_tensor import (
     distribute_tensor,
+    get_device_mesh,
     get_layout,
     get_sharding_spec,
-    get_device_mesh,
     is_distributed_tensor,
     shard_colwise,
     shard_rowwise,
@@ -32,21 +31,18 @@
 from colossalai.utils import set_seed
 from colossalai.zero import LowLevelZeroOptimizer
 from tests.kit.model_zoo import model_zoo
-from tests.test_optimizer._utils import run_bert_test, check_dist_optim_state
-from colossalai.shardformer.layer._operation import _gather
 from tests.test_shardformer.test_model._utils import (
     build_model_from_hybrid_plugin,
     check_weight,
     run_forward_backward_with_hybrid_plugin,
     unwrap_model,
 )
-from colossalai.shardformer.layer.utils import Randomizer
-
 HEIGHT = 4
 WIDTH = 4
 _TP_SPEC = DimSpec([0])
+
 def correctness_verify(tensor1: torch.Tensor, tensor2: torch.Tensor, dtype: torch.dtype = torch.float32):
     rtol = None
     atol = None
@@ -63,6 +59,7 @@ def correctness_verify(tensor1: torch.Tensor, tensor2: torch.Tensor, dtype: torc
     # return torch.all(tensor1.isclose(tensor2, rtol=rtol, atol=atol))
     assert_close(tensor1, tensor2, rtol=rtol, atol=atol)
 
+
 # setup param groups; (For zero test optim)
 def setup_param_groups_zero(model: nn.Module) -> list:
     no_decay = ["bias", "LayerNorm.weight"]
@@ -78,11 +75,13 @@ def setup_param_groups_zero(model: nn.Module) -> list:
     ]
     return optimizer_grouped_parameters
 
+
 # setup param groups; (For base optim)
 def setup_param_groups(model: nn.Module) -> list:
     optimizer_grouped_parameters = [p for n, p in model.named_parameters()]
     return optimizer_grouped_parameters
 
+
 # setup flatten param groups, sharding spec and shape; (For dist optim)
 def setup_flatten_param_groups_sharding_spec_shape(model: nn.Module) -> dict:
     flatten_optimizer_grouped_parameters = []
@@ -136,10 +135,12 @@ def set_dist_grad(
             p.grad = p.data
             p.data = orig_p
 
+
 def set_master_param_to_shard_param(master_param_list) -> dict:
-    master_param_to_shard_param ={id(p):p for p in master_param_list}
+    master_param_to_shard_param = {id(p): p for p in master_param_list}
     return master_param_to_shard_param
-
+
+
 class MlpModel(nn.Module):
     def __init__(self):
         super(MlpModel, self).__init__()
@@ -151,6 +152,7 @@ def forward(self, x):
         x = self.linear2(x)
         return x
 
+
 class TPModel(nn.Module):
     def __init__(self, linear1, linear2, tp_group=None):
         super().__init__()
@@ -164,8 +166,9 @@ def forward(self, x):
         x = self.linear2(x)
         return x
 
+
 @parameterize("dtype", [torch.float32, torch.float16, torch.bfloat16])  # torch.float32, torch.float16, torch.bfloat16
-@parameterize("tp_zero_size", [(4, 1)])
+@parameterize("tp_zero_size", [(4, 1)])
 def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
     tp_size, zero_size = tp_zero_size
     local_rank = dist.get_rank()
@@ -183,7 +186,7 @@ def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
     H, W = HEIGHT, WIDTH
     model_col = nn.Linear(H, W).to(local_rank)  # Col parallel weight
     weight, bias = model_col.weight, model_col.bias
-
+
     # ==============================
     # Col Parallel
     # ==============================
@@ -203,7 +206,7 @@ def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
         weight_row_shard.clone().flatten().requires_grad_(True)
     )  # flatten input(not dtensor) to optimizer
     bias_row_flatten = nn.Parameter(bias.clone().flatten().requires_grad_(True))
-
+
     # base_param_group = setup_param_groups([weight, bias])
     # cp_param_group = setup_param_groups([weight_col_shard_flatten, bias_col_flatten])
     # rp_param_group = setup_param_groups([weight_row_shard_flatten, bias_row_flatten])
@@ -216,7 +219,7 @@ def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
     optimizer_base = Adafactor([weight, bias])
     cp_dist_optim = DistributedAdaFactor([weight_col_shard_flatten, bias_col_flatten])
     rp_dist_optim = DistributedAdaFactor([weight_row_shard_flatten, bias_row_flatten])
-
+
     shard_to_param_cp = set_master_param_to_shard_param([weight_col_shard_flatten, bias_col_flatten])
     cp_dist_optim.setup_distributed(
         tensor_parallel_group=tp_group,
@@ -224,7 +227,7 @@ def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
         shard_to_param=shard_to_param_cp,
         use_zero=use_zero,
     )
-
+
     shard_to_param_rp = set_master_param_to_shard_param([weight_row_shard_flatten, bias_row_flatten])
     rp_dist_optim.setup_distributed(
         tensor_parallel_group=tp_group,
@@ -233,7 +236,6 @@ def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
         use_zero=use_zero,
     )
 
-
     N_STEPS = 1
     for _ in range(N_STEPS):
         # base step
@@ -245,7 +247,9 @@ def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
         # col parallel step
        cp_dist_optim.zero_grad()
         weight_col_shard_flatten.grad = (
-            distribute_tensor(weight.grad, get_device_mesh(weight_col_shard), weight_col_shard_shard_spec).clone().flatten()
+            distribute_tensor(weight.grad, get_device_mesh(weight_col_shard), weight_col_shard_shard_spec)
+            .clone()
+            .flatten()
         )
         bias_col_flatten.grad = bias.grad.clone().flatten()
         cp_dist_optim.step()
@@ -253,7 +257,9 @@ def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
         # row parallel step
         rp_dist_optim.zero_grad()
         weight_row_shard_flatten.grad = (
-            distribute_tensor(weight.grad, get_device_mesh(weight_row_shard), weight_row_shard_shard_spec).clone().flatten()
+            distribute_tensor(weight.grad, get_device_mesh(weight_row_shard), weight_row_shard_shard_spec)
+            .clone()
+            .flatten()
         )
         bias_row_flatten.grad = bias.grad.clone().flatten()
         rp_dist_optim.step()
@@ -264,25 +270,24 @@ def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
         dim=-1,
         process_group=tp_group,
     )  # gather
-    weight_row_gather = _gather(
-        input_=weight_row_shard_flatten.data, dim=-1, process_group=tp_group
-    ).view(
+    weight_row_gather = _gather(input_=weight_row_shard_flatten.data, dim=-1, process_group=tp_group).view(
         -1, W
     )  # gather
     # verify
-    col_correct = correctness_verify(weight.data, weight_col_gather.data, dtype)
-    row_correct = correctness_verify(weight.data, weight_row_gather.data, dtype)
+    correctness_verify(weight.data, weight_col_gather.data, dtype)
+    correctness_verify(weight.data, weight_row_gather.data, dtype)
 
     print(f"Base Test Pass")
 
+
 @parameterize("dtype", [torch.float16])  # torch.float32, torch.float16, torch.bfloat16
 @parameterize("tp_zero_size", [(1, 4)])  # (2, 2), (4, 1), (1, 4)
 def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
     tp_size, zero_size = tp_zero_size
     use_zero = True if zero_size > 1 else False
     local_rank = dist.get_rank()
-
+
     clear_layout_converter()
 
     proc_mesh = ProcessGroupMesh(tp_size, zero_size)
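Reviewer note (illustrative, not part of the patch): the verification in these tests repeatedly gathers a flattened shard with `_gather(input_=..., dim=-1, process_group=...)` before comparing it against the single-device Adafactor result. A plain `torch.distributed` equivalent is sketched below purely to document the pattern; the helper name is hypothetical and it assumes the process group has already been initialised, as it is in these tests.

```python
import torch
import torch.distributed as dist


def gather_flat_shard(shard: torch.Tensor, group=None) -> torch.Tensor:
    """Gather a 1-D parameter shard from every rank in `group` and concatenate
    along the last dim, mirroring the `_gather(..., dim=-1, ...)` calls above."""
    world_size = dist.get_world_size(group)
    buffers = [torch.empty_like(shard) for _ in range(world_size)]
    dist.all_gather(buffers, shard.contiguous(), group=group)
    return torch.cat(buffers, dim=-1)
```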
@@ -381,19 +386,20 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
             else:
                 # No TP bias
                 pass
-            correctness = correctness_verify(p.data, tp_p.data, dtype)
+            correctness_verify(p.data, tp_p.data, dtype)
 
     clear_layout_converter()
     Randomizer.reset_index()
     torch.cuda.empty_cache()
     print(f"Zero Test Pass")
-
+
+
 @parameterize("dtype", [torch.float16])
 @parameterize("tp_zero_size", [(1, 4)])
 def exam_dist_adafactor_booster(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
     tp_size, zero_size = tp_zero_size
     use_zero = True if zero_size > 1 else False
     local_rank = dist.get_rank()
-
+
     clear_layout_converter()
 
     proc_mesh = ProcessGroupMesh(tp_size, zero_size)
@@ -454,14 +460,14 @@ def exam_dist_adafactor_booster(dtype: torch.dtype, tp_zero_size: tuple[int, int
         shard_to_param=shard_to_param,
         use_zero=use_zero,
     )
-
+
     # ==============================
     # Booster Init
     # ==============================
     plugin = LowLevelZeroPlugin()
     booster = Booster(plugin=plugin)
     criterion = lambda x: x.mean()
-
+
     tp_model, dist_optim, criterion, _, _ = booster.boost(tp_model, dist_optim, criterion)
 
     # ==============================
@@ -502,12 +508,12 @@ def exam_dist_adafactor_booster(dtype: torch.dtype, tp_zero_size: tuple[int, int
             else:
                 # No TP bias
                 pass
-            correctness = correctness_verify(p.data, tp_p.data, dtype)
+            correctness_verify(p.data, tp_p.data, dtype)
 
     Randomizer.reset_index()
-    torch.cuda.empty_cache()
-    print(f"Booster Test Pass")
-
-
+    torch.cuda.empty_cache()
+    print(f"Booster Test Pass")
+
+
 @parameterize(
     "test_config",
     [
@@ -562,25 +568,29 @@ def exam_bert_test(test_config):
     clear_layout_converter()
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         if name in model_list:
+            (
+                org_model,
+                org_optimizer,
+                sharded_model,
+                sharded_optimizer,
+                criterion,
+                booster,
+            ) = build_model_from_hybrid_plugin(model_fn, loss_fn, test_config, Adafactor, DistributedAdaFactor)
-            org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
-                model_fn, loss_fn, test_config, Adafactor, DistributedAdaFactor
-            )
-
             org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
                 org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
             )
-
+
             stage_manager = booster.plugin.stage_manager
             tp_group = booster.plugin.tp_group
 
             bert = unwrap_model(org_model, "BertModel", "bert")
             sharded_bert = unwrap_model(sharded_model, "BertModel", "bert")
             weight_layer_for_check = ["encoder.layer[0].output.dense", "encoder.layer[1].output.dense"]
-
+
             org_optimizer.step()
             sharded_optimizer.step()
-
+
             # check weights
             if test_config["precision"] == "bf16":
                 atol, rtol = 5e-4, 5e-4
@@ -589,12 +599,13 @@
             if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True):
                 check_weight(bert, sharded_bert, weight_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1)
             # check optim states
-            # check_dist_optim_state(org_optimizer, sharded_optimizer.optim)
+            # check_dist_optim_state(org_optimizer, sharded_optimizer.optim)
             Randomizer.reset_index()
             torch.cuda.empty_cache()
-    print(f"Bert Model Zoo Test Pass")
-
+    print(f"Bert Model Zoo Test Pass")
+
+
 def run_dist(rank, world_size, port):
     disable_existing_loggers()
     config = {}
@@ -604,6 +615,7 @@ def run_dist(rank, world_size, port):
     exam_dist_adafactor_zero()
     exam_dist_adafactor_booster()
 
+
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_dist_adafactor():
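Reviewer note (illustrative, not part of the patch): after the factored or unfactored second moment is applied, `step()` in `distributed_adafactor.py` clips the update with Adafactor's RMS rule before scaling by the learning rate; in the distributed path `_rms` first all-reduces the squared sum and element count over the tensor-parallel and data-parallel groups. A single-device sketch of that clipping step, with assumed values:

```python
import torch

clip_threshold = 1.0  # corresponds to group["clip_threshold"]
update = torch.randn(1024, 1024) * 5.0  # stand-in for a preconditioned update

# Root-mean-square of the update. The distributed version computes the same
# quantity from tensor.pow(2).sum() and tensor.numel(), all-reduced across
# the tp/dp groups, so every rank sees the RMS of the full (unsharded) update.
rms = update.pow(2).mean().sqrt()

# Scale the update down only when its RMS exceeds the threshold,
# matching update.div_((rms / group["clip_threshold"]).clamp_(min=1.0)).
update.div_((rms / clip_threshold).clamp_(min=1.0))
```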