From 576a2f7b10711bcdb43b86da6a5afaa98f4ad867 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Fri, 10 Nov 2023 10:15:16 +0800 Subject: [PATCH] [gemini] gemini support tensor parallelism. (#4942) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [colossalai]fix typo * [inference] Add smmoothquant for llama (#4904) * [inference] add int8 rotary embedding kernel for smoothquant (#4843) * [inference] add smoothquant llama attention (#4850) * add smoothquant llama attention * remove uselss code * remove useless code * fix import error * rename file name * [inference] add silu linear fusion for smoothquant llama mlp (#4853) * add silu linear * update skip condition * catch smoothquant cuda lib exception * prcocess exception for tests * [inference] add llama mlp for smoothquant (#4854) * add llama mlp for smoothquant * fix down out scale * remove duplicate lines * add llama mlp check * delete useless code * [inference] add smoothquant llama (#4861) * add smoothquant llama * fix attention accuracy * fix accuracy * add kv cache and save pretrained * refactor example * delete smooth * refactor code * [inference] add smooth function and delete useless code for smoothquant (#4895) * add smooth function and delete useless code * update datasets * remove duplicate import * delete useless file * refactor codes (#4902) * rafactor code * add license * add torch-int and smoothquant license * Update flash_attention_patch.py To be compatible with the new change in the Transformers library, where a new argument 'padding_mask' was added to forward function of attention layer. https://github.com/huggingface/transformers/pull/25598 * [kernel] support pure fp16 for cpu adam and update gemini optim tests (#4921) * [kernel] support pure fp16 for cpu adam (#4896) * [kernel] fix cpu adam kernel for pure fp16 and update tests (#4919) * [kernel] fix cpu adam * [test] update gemini optim test * [format] applied code formatting on changed files in pull request 4908 (#4918) Co-authored-by: github-actions * [gemini] support gradient accumulation (#4869) * add test * fix no_sync bug in low level zero plugin * fix test * add argument for grad accum * add grad accum in backward hook for gemini * finish implementation, rewrite tests * fix test * skip stuck model in low level zero test * update doc * optimize communication & fix gradient checkpoint * modify doc * cleaning codes * update cpu adam fp16 case * [hotfix] fix torch 2.0 compatibility (#4936) * [hotfix] fix launch * [test] fix test gemini optim * [shardformer] fix vit * [test] add no master test for low level zero plugin (#4934) * [format] applied code formatting on changed files in pull request 4820 (#4886) Co-authored-by: github-actions * [nfc] fix some typo with colossalai/ docs/ etc. (#4920) * [Refactor] Integrated some lightllm kernels into token-attention (#4946) * add some req for inference * clean codes * add codes * add some lightllm deps * clean codes * hello * delete rms files * add some comments * add comments * add doc * add lightllm deps * add lightllm cahtglm2 kernels * add lightllm cahtglm2 kernels * replace rotary embedding with lightllm kernel * add some commnets * add some comments * add some comments * add * replace fwd kernel att1 * fix a arg * add * add * fix token attention * add some comments * clean codes * modify comments * fix readme * fix bug * fix bug --------- Co-authored-by: cuiqing.li Co-authored-by: CjhHa1 * [test] merge old components to test to model zoo (#4945) * [test] add custom models in model zoo * [test] update legacy test * [test] update model zoo * [test] update gemini test * [test] remove components to test * [inference] add reference and fix some bugs (#4937) * add reference and fix some bugs * update gptq init --------- Co-authored-by: Xu Kai * [Inference]ADD Bench Chatglm2 script (#4963) * add bench chatglm * fix bug and make utils --------- Co-authored-by: CjhHa1 * [Pipeline inference] Combine kvcache with pipeline inference (#4938) * merge kvcache with pipeline inference and refactor the code structure * support ppsize > 2 * refactor pipeline code * do pre-commit * modify benchmark * fix bench mark * polish code * add docstring and update readme * refactor the code * fix some logic bug of ppinfer * polish readme * fix typo * skip infer test * updated c++17 compiler flags (#4983) * [Inference] Dynamic Batching Inference, online and offline (#4953) * [inference] Dynamic Batching for Single and Multiple GPUs (#4831) * finish batch manager * 1 * first * fix * fix dynamic batching * llama infer * finish test * support different lengths generating * del prints * del prints * fix * fix bug --------- Co-authored-by: CjhHa1 * [inference] Async dynamic batching (#4894) * finish input and output logic * add generate * test forward * 1 * [inference]Re push async dynamic batching (#4901) * adapt to ray server * finish async * finish test * del test --------- Co-authored-by: yuehuayingxueluo <867460659@qq.com> * Revert "[inference]Re push async dynamic batching (#4901)" (#4905) This reverts commit fbf3c09e673794ed18c91d4bab1a7dfea052e95a. * Revert "[inference] Async dynamic batching (#4894)" This reverts commit fced14025043e29ce816b315f440601188f7f79f. * Revert "[inference] Async dynamic batching (#4894)" (#4909) This reverts commit fced14025043e29ce816b315f440601188f7f79f. * Add Ray Distributed Environment Init Scripts * support DynamicBatchManager base function * revert _set_tokenizer version * add driver async generate * add async test * fix bugs in test_ray_dist.py * add get_tokenizer.py * fix code style * fix bugs about No module named 'pydantic' in ci test * fix bugs in ci test * fix bugs in ci test * fix bugs in ci test * [infer]Add Ray Distributed Environment Init Scripts (#4911) * Revert "[inference] Async dynamic batching (#4894)" This reverts commit fced14025043e29ce816b315f440601188f7f79f. * Add Ray Distributed Environment Init Scripts * support DynamicBatchManager base function * revert _set_tokenizer version * add driver async generate * add async test * fix bugs in test_ray_dist.py * add get_tokenizer.py * fix code style * fix bugs about No module named 'pydantic' in ci test * fix bugs in ci test * fix bugs in ci test * fix bugs in ci test * support dynamic batch for bloom model and is_running function * [Inference]Test for new Async engine (#4935) * infer engine * infer engine * test engine * test engine * new manager * change step * add * test * fix * fix * finish test * finish test * finish test * finish test * add license --------- Co-authored-by: yuehuayingxueluo <867460659@qq.com> * add assertion for config (#4947) * [Inference] Finish dynamic batching offline test (#4948) * test * fix test * fix quant * add default * fix * fix some bugs * fix some bugs * fix * fix bug * fix bugs * reset param --------- Co-authored-by: yuehuayingxueluo <867460659@qq.com> Co-authored-by: Cuiqing Li Co-authored-by: CjhHa1 * [Kernels]Updated Triton kernels into 2.1.0 and adding flash-decoding for llama token attention (#4965) * adding flash-decoding * clean * adding kernel * adding flash-decoding * add integration * add * adding kernel * adding kernel * adding triton 2.1.0 features for inference * update bloom triton kernel * remove useless vllm kernels * clean codes * fix * adding files * fix readme * update llama flash-decoding --------- Co-authored-by: cuiqing.li * fix ColossalEval (#4992) Co-authored-by: Xu Yuanchen * [doc]Update doc for colossal-inference (#4989) * update doc * Update README.md --------- Co-authored-by: cuiqing.li * [hotfix] Fix the bug where process groups were not being properly released. (#4940) * Fix the bug where process groups were not being properly released. * test * Revert "test" This reverts commit 479900c1398637310abf92eefa3cd168038ea02f. * [hotfix] fix the bug of repeatedly storing param group (#4951) * [doc] add supported feature diagram for hybrid parallel plugin (#4996) * [Pipeline Inference] Merge pp with tp (#4993) * refactor pipeline into new CaiInferEngine * updata llama modeling forward * merge tp with pp * update docstring * optimize test workflow and example * fix typo * add assert and todo * [release] update version (#4995) * [release] update version * [hotfix] fix ci * [gemini] gemini support tp [gemini] gemini support tp [gemini] gemini support tp [gemini] gemini support tp [gemini] gemini support tp * fix fix fix * update checkpointIO update checkpointIO update checkpointIO update checkpointIO update checkpointIO update checkpointIO update checkpointIO update checkpointIO update checkpointIO * support fused layernorm support fused layernorm support fused layernorm * update fusedlayernorm update fusedlayernorm update fusedlayernorm * add sequence parallel to gemini add sequence parallel to gemini * fix * fix comments fix comments fix comments * fix * fix t5 * clear cache * fix * activate ci * activate ci * fix * fix * fix * fix * revert * modify tp gather method modify tp gather method modify tp gather method modify tp gather method * fix test --------- Co-authored-by: Xu Kai Co-authored-by: Zian(Andy) Zheng <62330719+Orion-Zheng@users.noreply.github.com> Co-authored-by: Hongxin Liu Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions Co-authored-by: Baizhou Zhang Co-authored-by: Zhongkai Zhao Co-authored-by: digger yu Co-authored-by: Cuiqing Li Co-authored-by: cuiqing.li Co-authored-by: CjhHa1 Co-authored-by: Xu Kai Co-authored-by: Jianghai <72591262+CjhHa1@users.noreply.github.com> Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com> Co-authored-by: アマデウス Co-authored-by: yuehuayingxueluo <867460659@qq.com> Co-authored-by: Yuanchen <70520919+chengeharrison@users.noreply.github.com> Co-authored-by: Xu Yuanchen Co-authored-by: littsk <1214689160@qq.com> Co-authored-by: ppt0011 <143150326+ppt0011@users.noreply.github.com> --- colossalai/booster/plugin/gemini_plugin.py | 81 ++++++++++++++++++- colossalai/cluster/process_group_mesh.py | 1 + colossalai/shardformer/layer/_operation.py | 58 +++++++++++-- colossalai/shardformer/layer/embedding.py | 5 +- colossalai/shardformer/layer/normalization.py | 72 ++++++++++++----- colossalai/shardformer/modeling/bloom.py | 3 +- colossalai/shardformer/policies/t5.py | 8 -- colossalai/tensor/d_tensor/__init__.py | 4 + colossalai/tensor/d_tensor/api.py | 59 ++++++++++++++ colossalai/zero/gemini/gemini_ddp.py | 53 ++++++++++-- colossalai/zero/gemini/gemini_optimizer.py | 76 +++++++++++++++-- .../test_plugin/test_gemini_plugin.py | 19 ++++- .../test_gemini_checkpoint_io.py | 18 +++-- 13 files changed, 390 insertions(+), 67 deletions(-) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index d1a9bc2623a3..9c7dc6836c1e 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -5,6 +5,7 @@ from typing import Callable, Iterator, List, Optional, Tuple import torch +import torch.distributed as dist import torch.nn as nn from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler @@ -19,8 +20,9 @@ save_state_dict, save_state_dict_shards, ) -from colossalai.cluster import DistCoordinator +from colossalai.cluster import DistCoordinator, ProcessGroupMesh from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.utils import get_current_device from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.memory_tracer import MemStats @@ -32,7 +34,25 @@ SUPPORTED_PRECISION = ["fp16", "bf16"] PRECISION_STR_TO_DTYPE = {"fp16": torch.half, "bf16": torch.bfloat16} +DP_AXIS = 0 +TP_AXIS = 1 +def get_param_info(optim: Optimizer): + # Get a backup of necessary information of parameters for future use, which includes: + # 1. A mapping from integer param_id to param32 shape. + + if optim is None: + return {} + param_info = {"id2shape": {}} + start_index = 0 + for group in optim.param_groups: + for param_id, param in enumerate(group["params"], start_index): + original_shape = param.shape if isinstance(param, torch.Tensor) else None + param_info["id2shape"][param_id] = original_shape + + start_index += len(group["params"]) + + return param_info class GeminiCheckpointIO(GeneralCheckpointIO): def __init__(self) -> None: super().__init__() @@ -284,6 +304,16 @@ class GeminiPlugin(DPPluginBase): max_norm (float, optional): max_norm used for `clip_grad_norm`. You should notice that you shall not do clip_grad_norm by yourself when using ZeRO DDP. The ZeRO optimizer will take care of clip_grad_norm. norm_type (float, optional): norm_type used for `clip_grad_norm`. + enable_tensor_parallelism (bool, optional): Whether to use tensor parallelism strategy, which is implemented in Shardformer. Default to False. + tp_size (int, optional): If 'enable_tensor_parallelism' is set to true, please configure 'tp_size' which determines the size of the tensor parallel process group. Default to 1. + enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer. + Currently all the optimization methods include fused normalization, flash attention and JIT. + Defaults to False. + enable_fused_normalization (bool, optional): Whether to switch on fused normalization in Shardformer. Defaults to False. + enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False. + enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False. + enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False. + enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False. verbose (bool, optional): verbose mode. Debug info including chunk search result will be printed. Defaults to False. """ @@ -317,6 +347,14 @@ def __init__( max_scale: float = 2**32, max_norm: float = 0.0, norm_type: float = 2.0, + enable_tensor_parallelism: bool = False, + tp_size: int = 1, + enable_all_optimization: bool = False, + enable_fused_normalization: bool = False, + enable_flash_attention: bool = False, + enable_sequence_parallelism: bool = False, + enable_jit_fused: bool = False, + enable_sequence_overlap: bool = False, verbose: bool = False, ) -> None: super().__init__() @@ -355,8 +393,32 @@ def __init__( max_norm=max_norm, norm_type=norm_type, ) + self.enable_tensor_parallelism = enable_tensor_parallelism + self.enable_all_optimization = enable_all_optimization + self.enable_fused_normalization = enable_fused_normalization + self.enable_flash_attention = enable_flash_attention + self.enable_sequence_parallelism = enable_sequence_parallelism if self.enable_tensor_parallelism else False + self.enable_jit_fused = enable_jit_fused + self.enable_sequence_overlap = enable_sequence_overlap self.verbose = verbose + self.tp_size = tp_size if self.enable_tensor_parallelism else 1 + self.dp_size = dist.get_world_size() // self.tp_size + assert self.dp_size > 1, f"The size of the DP group should be greater than 1. Please reduce the TP group size." + self.pg_mesh = ProcessGroupMesh(self.dp_size, self.tp_size) + self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) + self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) + self.shard_config = ShardConfig( + tensor_parallel_process_group=self.tp_group, + enable_tensor_parallelism=self.enable_tensor_parallelism, + enable_all_optimization=self.enable_all_optimization, + enable_fused_normalization=self.enable_fused_normalization, + enable_flash_attention=self.enable_flash_attention, + enable_jit_fused=self.enable_jit_fused, + enable_sequence_parallelism=self.enable_sequence_parallelism, + enable_sequence_overlap=self.enable_sequence_overlap, + ) + def support_no_sync(self) -> bool: return False @@ -380,6 +442,7 @@ def configure( dataloader: Optional[DataLoader] = None, lr_scheduler: Optional[LRScheduler] = None, ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: + optimizer_params_info = get_param_info(optimizer) if not isinstance(model, ModelWrapper): # convert model to sync bn # FIXME(ver217): gemini does not support sync bn @@ -391,11 +454,21 @@ def configure( # model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None) # wrap the model with Gemini - model = GeminiDDP(model, **self.gemini_config, verbose=self.verbose) + if self.enable_tensor_parallelism: + shardformer = ShardFormer(self.shard_config) + model, _ = shardformer.optimize(model) + + model = GeminiDDP(model, **self.gemini_config, process_group=self.dp_group, verbose=self.verbose) if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): optimizer = GeminiOptimizer( - optimizer, model, **self.zero_optim_config, **self.optim_kwargs, verbose=self.verbose + optimizer, + model, + **self.zero_optim_config, + **self.optim_kwargs, + tp_group=self.tp_group, + optimizer_params_info=optimizer_params_info, + verbose=self.verbose, ) return model, optimizer, criterion, dataloader, lr_scheduler @@ -407,4 +480,4 @@ def get_checkpoint_io(self) -> CheckpointIO: return GeminiCheckpointIO() def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: - raise NotImplementedError + raise NotImplementedError \ No newline at end of file diff --git a/colossalai/cluster/process_group_mesh.py b/colossalai/cluster/process_group_mesh.py index eb4532194a26..7a3bde44869c 100644 --- a/colossalai/cluster/process_group_mesh.py +++ b/colossalai/cluster/process_group_mesh.py @@ -225,3 +225,4 @@ def get_group_along_axis( # no need to cache it explicitly, since it will be cached in `create_group_along_axis` return self.create_group_along_axis(axis, indices_at_axis, backend=backend) return self._ranks_to_group[ranks_in_group] + \ No newline at end of file diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 5ec48096183b..0d8c3d453ce1 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -53,7 +53,7 @@ class MatmulWithAsyncCommunication(torch.autograd.Function): @staticmethod def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce): - ctx.save_for_backward(input_, weight) + ctx.save_for_backward(input_, weight, bias) ctx.use_bias = bias is not None ctx.process_group = process_group ctx.async_grad_allreduce = async_grad_allreduce @@ -62,13 +62,18 @@ def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce): if bias is not None: output = output + bias + return output @staticmethod def backward(ctx, grad_output): - input, weight = ctx.saved_tensors + input, weight, bias = ctx.saved_tensors use_bias = ctx.use_bias + # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. + weight = weight.view(weight.shape) + bias = bias.view(bias.shape) + total_input = input grad_input = grad_output.matmul(weight.T) grad_output = grad_output.contiguous() @@ -100,7 +105,7 @@ class LinearWithAsyncCommunication(torch.autograd.Function): @staticmethod def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce): - ctx.save_for_backward(input_, weight) + ctx.save_for_backward(input_, weight, bias) ctx.use_bias = bias is not None ctx.process_group = process_group ctx.async_grad_allreduce = async_grad_allreduce @@ -109,13 +114,18 @@ def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce): output = F.linear(input_, weight, bias) else: output = F.linear(input_, weight) + return output @staticmethod def backward(ctx, grad_output): - input, weight = ctx.saved_tensors + input, weight, bias = ctx.saved_tensors use_bias = ctx.use_bias + # In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias. + if use_bias: + bias.view(bias.shape) + total_input = input grad_input = grad_output.matmul(weight) grad_output = grad_output.contiguous() @@ -152,7 +162,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): @staticmethod def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True): - ctx.save_for_backward(input_, weight) + ctx.save_for_backward(input_, weight, bias) ctx.use_bias = bias is not None ctx.process_group = process_group ctx.async_grad_reduce_scatter = async_grad_reduce_scatter @@ -170,12 +180,16 @@ def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, @staticmethod def backward(ctx, grad_output): - input_, weight = ctx.saved_tensors + input_, weight, bias = ctx.saved_tensors use_bias = ctx.use_bias dim = ctx.dim process_group = ctx.process_group overlap = ctx.overlap + # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm + if use_bias: + bias = bias.view(bias.shape) + if not overlap: input_parallel = _gather(input_, dim, process_group) @@ -289,7 +303,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): @staticmethod def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap): - ctx.save_for_backward(input_, weight) + ctx.save_for_backward(input_, weight, bias) ctx.use_bias = bias is not None ctx.process_group = process_group ctx.async_grad_reduce_scatter = async_grad_reduce_scatter @@ -306,12 +320,17 @@ def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, @staticmethod def backward(ctx, grad_output): - input_, weight = ctx.saved_tensors + input_, weight, bias = ctx.saved_tensors use_bias = ctx.use_bias dim = ctx.dim process_group = ctx.process_group overlap = ctx.overlap + # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm + weight = weight.view(weight.shape) + if use_bias: + bias = bias.view(bias.shape) + if not overlap: input_parallel = _gather(input_, dim, process_group) @@ -454,6 +473,29 @@ def forward(ctx, input_, dim, process_group): @staticmethod def backward(ctx, grad_output): return _split(grad_output, ctx.dim, ctx.process_group), None, None + + +class HookParameter(torch.autograd.Function): + """In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm""" + @staticmethod + def forward(ctx, input, weight, bias): + ctx.save_for_backward(weight, bias) + output = input + return output + + @staticmethod + def backward(ctx, grad_output): + weight, bias = ctx.saved_tensors + if weight is not None: + weight = weight.view(weight.shape) + if bias is not None: + bias = bias.view(bias.shape) + return grad_output, None, None + + +def hook_paramter_in_backward(input, weight=None, bias=None): + return HookParameter.apply(input, weight, bias) + def _reduce(input_, process_group): diff --git a/colossalai/shardformer/layer/embedding.py b/colossalai/shardformer/layer/embedding.py index 62163cb009aa..d081b204093b 100644 --- a/colossalai/shardformer/layer/embedding.py +++ b/colossalai/shardformer/layer/embedding.py @@ -309,7 +309,8 @@ def forward(self, input_: Tensor) -> Tensor: ) # Mask the output embedding. - output_parallel[input_mask, :] = 0.0 + embedding_output = output_parallel.clone() + embedding_output[input_mask, :] = 0.0 # Reduce across all the model parallel GPUs. - output = reduce_forward(output_parallel, self.process_group) + output = reduce_forward(embedding_output, self.process_group) return output diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py index 413d07e8742b..42efe9a44308 100644 --- a/colossalai/shardformer/layer/normalization.py +++ b/colossalai/shardformer/layer/normalization.py @@ -1,15 +1,29 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +import warnings from abc import ABC, abstractmethod - import torch.nn as nn - from colossalai.lazy import LazyInitContext +from ._operation import hook_paramter_in_backward from .utils import SeqParallelUtils __all__ = ["FusedLayerNorm", "FusedRMSNorm", "LayerNorm", "RMSNorm", "BaseLayerNorm"] +try: + from apex.contrib.layer_norm.layer_norm import FastLayerNorm + EnableFastLayerNorm = True +except ImportError: + EnableFastLayerNorm = False + +try: + from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm + from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm +except ImportError: + warnings.warn( + "Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel" + ) + FAST_LAYERNORM_SUPPORTED_SIZE = [ 1024, 1536, @@ -37,6 +51,34 @@ 65536, ] +if EnableFastLayerNorm: + class FastLayerNormWithHook(FastLayerNorm): + def __init__(self, hidden_size, eps=0.00001): + super().__init__(hidden_size, eps) + + def forward(self, input): + output = super().forward(input) + output = hook_paramter_in_backward(output, self.weight, self.bias) + return output + +class FusedLayerNormWithHook(ApexFusedLayerNorm): + def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True): + super().__init__(normalized_shape, eps, elementwise_affine) + + def forward(self, input): + output = super().forward(input) + output = hook_paramter_in_backward(output, self.weight, self.bias) + return output + +class FusedRMSNormWithHook(ApexFusedRMSNorm): + def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True): + super().__init__(normalized_shape, eps, elementwise_affine) + + def forward(self, input): + output = super().forward(input) + output = hook_paramter_in_backward(output, self.weight) + return output + class BaseLayerNorm(ABC): @abstractmethod @@ -161,16 +203,6 @@ def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, * Raises: AssertionError: If the provided module is not an instance of nn.LayerNorm. """ - # check if apex is installed - - assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm." - - try: - pass - except ImportError: - raise ImportError( - "Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel" - ) LazyInitContext.materialize(module) # get the attributes of the module @@ -184,18 +216,17 @@ def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, * use_fast_ln = normalized_shape in FAST_LAYERNORM_SUPPORTED_SIZE if use_fast_ln: - try: - from apex.contrib.layer_norm.layer_norm import FastLayerNorm as ApexFusedLayerNorm - except ImportError: + if EnableFastLayerNorm: + ApexFusedLayerNorm = FastLayerNormWithHook + else: # fall back to the normal fused layernorm is not built - from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm + ApexFusedLayerNorm = FusedLayerNormWithHook else: - from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm + ApexFusedLayerNorm = FusedLayerNormWithHook layernorm = ( ApexFusedLayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine).to(dtype).to(device) ) - layernorm.weight = module.weight layernorm.bias = module.bias @@ -213,13 +244,12 @@ class FusedRMSNorm(BaseLayerNorm): """ This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface. """ - def __init__(self) -> None: raise NotImplementedError( "FusedRMSNorm is not implemented as a physical class. " "It is meant to be used only with the from_native_module interface to Convert a native RMSNorm module to FusedRMSNorm module provided by apex." ) - + @staticmethod def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module: r""" @@ -252,7 +282,7 @@ def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *arg eps = module.eps elementwise_affine = module.elementwise_affine - rmsnorm = ApexFusedRMSNorm(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine) + rmsnorm = FusedRMSNormWithHook(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine) rmsnorm.weight = module.weight diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py index 1bf87e80a461..cd8a023306dc 100644 --- a/colossalai/shardformer/modeling/bloom.py +++ b/colossalai/shardformer/modeling/bloom.py @@ -719,7 +719,7 @@ def forward( ): fused_qkv = self.query_key_value(hidden_states) (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) - batch_size, tgt_len, _ = query_layer.size() + batch_size, tgt_len, _, _ = query_layer.size() _, kv_length, _, _ = key_layer.size() @@ -755,6 +755,7 @@ def forward( attention_numerical_mask = torch.masked_fill( attention_numerical_mask, attention_mask, torch.finfo(torch.float32).min ) + attention_numerical_mask = attention_numerical_mask.to(query_layer.dtype) context_layer = me_attention( query_layer, diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index fc5021600acc..4d906e3f4c04 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -183,14 +183,6 @@ def module_policy(self): policy=policy, target_key=T5LayerFF, ) - self.append_or_create_submodule_replacement( - description=SubModuleReplacementDescription( - suffix="layer_norm", - target_module=norm_cls, - ), - policy=policy, - target_key=T5LayerFF, - ) self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription(suffix="layer_norm", target_module=norm_cls), policy=policy, diff --git a/colossalai/tensor/d_tensor/__init__.py b/colossalai/tensor/d_tensor/__init__.py index fad5101d380c..6f8097735d57 100644 --- a/colossalai/tensor/d_tensor/__init__.py +++ b/colossalai/tensor/d_tensor/__init__.py @@ -2,7 +2,9 @@ compute_global_numel, customized_distributed_tensor_to_param, distribute_tensor, + init_as_dtensor, distribute_tensor_with_customization, + init_tensor_as_customization_distributed, get_device_mesh, get_global_shape, get_layout, @@ -23,6 +25,7 @@ __all__ = [ "is_distributed_tensor", "distribute_tensor", + "init_as_dtensor", "to_global", "is_sharded", "shard_rowwise", @@ -36,6 +39,7 @@ "get_layout", "is_customized_distributed_tensor", "distribute_tensor_with_customization", + "init_tensor_as_customization_distributed", "to_global_for_customized_distributed_tensor", "customized_distributed_tensor_to_param", "Layout", diff --git a/colossalai/tensor/d_tensor/api.py b/colossalai/tensor/d_tensor/api.py index 178bac428ea9..74a785f2dcd4 100644 --- a/colossalai/tensor/d_tensor/api.py +++ b/colossalai/tensor/d_tensor/api.py @@ -128,6 +128,17 @@ def distribute_tensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_sp return sharded_tensor +def init_as_dtensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec, global_shape: torch.Size) -> torch.Tensor: + assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor." + dist_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape) + + # shard tensor + tensor.dist_layout = dist_layout + + # hack some tensor methods + _hijack_detach_and_clone(tensor) + + return tensor def redistribute(dtensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> None: """ @@ -420,6 +431,54 @@ def gather_fn(tensor): return sharded_tensor +def init_tensor_as_customization_distributed(tensor: torch.Tensor, shard_fn, gather_fn: callable): + """ + Distribute the given tensor with the given shard_fn and gather_fn. + + Example: + + ```python + # define shard and gather functions + def shard_fn(tensor): + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + return tensor.chunk(world_size, dim=0)[rank] + + def gather_fn(tensor): + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + shard_list = [torch.zeros_like(tensor) for _ in range(world_size)] + torch.distributed.all_gather(shard_list, tensor) + return torch.cat(shard_list, dim=0) + + # create a distributed tensor + tensor = torch.rand(4, 4) + dtensor = init_tensor_as_customization_distributed(tensor, shard_fn, gather_fn) + ``` + + Args: + tensor (torch.Tensor): The tensor to be distributed. + shard_fn (callable): The function to shard the tensor. + gather_fn (callable): The function to gather the tensor. + + Returns: + torch.Tensor: The distributed tensor. + """ + assert callable(shard_fn), "The shard_fn must be callable." + assert callable(gather_fn), "The gather_fn must be callable." + assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor." + + + # set the shard_fn and gather_fn as attributes of the distributed tensor + tensor.shard_fn = shard_fn + tensor.gather_fn = gather_fn + + # set the shard_fn and gather_fn as attributes of the distributed tensor + _hijack_detach_and_clone_for_customized_distributed_tensor(tensor) + + return tensor + + def to_global_for_customized_distributed_tensor(dtensor: torch.Tensor) -> torch.Tensor: """ Gather the given tensor to the global tensor. diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 565f50c90dd1..ade0a4909902 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -17,6 +17,7 @@ from colossalai.tensor.colo_parameter import ColoParameter from colossalai.tensor.param_op_hook import ColoParamOpHookManager from colossalai.utils import _cast_float, free_storage, get_current_device, is_ddp_ignored +from colossalai.checkpoint_io.utils import gather_distributed_param from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager from .gemini_hook import GeminiZeROHook @@ -24,6 +25,18 @@ from .memory_tracer import MemStats, OrderedParamGenerator from .utils import get_temp_total_chunk_on_cuda +from colossalai.tensor.d_tensor import ( + distribute_tensor, + distribute_tensor_with_customization, + init_tensor_as_customization_distributed, + get_device_mesh, + get_sharding_spec, + is_customized_distributed_tensor, + is_distributed_tensor, + get_global_shape, + init_as_dtensor +) + try: from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys except ImportError: @@ -318,9 +331,7 @@ def backward(self, loss: torch.Tensor): self._post_backward() def backward_by_grad(self, tensor, grad): - with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook): - torch.autograd.backward(tensor, grad) - self._post_backward() + raise RuntimeError("Gemini is not compatible with pipeline. backward_by_grad shoudn't be called in Gemini.") def grad_handle(self, p, grad): setattr(p, "_gemini_reduced", True) @@ -431,7 +442,18 @@ def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool) -> Dict: record_tensor = torch.empty([0]) record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0) if record_flag: - record_tensor = temp_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape).cpu() + record_tensor = temp_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape).to(tensor.device) + if is_distributed_tensor(tensor): + global_shape = get_global_shape(tensor) + device_mesh = get_device_mesh(tensor) + shard_spec = get_sharding_spec(tensor) + record_tensor = init_as_dtensor(record_tensor, + device_mesh=device_mesh, + sharding_spec=shard_spec, + global_shape = global_shape) + elif is_customized_distributed_tensor(tensor): + init_tensor_as_customization_distributed(record_tensor, shard_fn=tensor.shard_fn, gather_fn=tensor.gather_fn) + record_tensor = gather_distributed_param(record_tensor, keep_vars=False).cpu() assert tensor not in chunk_to_save_data chunk_to_save_data[tensor] = record_tensor @@ -606,10 +628,16 @@ def _load_from_state_dict( local_name_params = itertools.chain(self.named_parameters(), persistent_buffers.items()) local_state = {k: v for k, v in local_name_params if v is not None} - def load(param_name, dest_tensor, copy_func): + def load(param_name, dest_tensor, copy_func, source_device_mesh=None, source_sharding_spec=None, shard_fn=None, gather_fn=None): state_key = prefix + param_name if state_key in state_dict: input_param = state_dict[state_key] + + if source_device_mesh is not None and source_sharding_spec is not None: + input_param = distribute_tensor(input_param, source_device_mesh, source_sharding_spec) + elif shard_fn is not None and gather_fn is not None: + input_param = distribute_tensor_with_customization(input_param, shard_fn=shard_fn, gather_fn=gather_fn) + # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ if len(dest_tensor.shape) == 0 and len(input_param.shape) == 1: input_param = input_param[0] @@ -653,9 +681,19 @@ def load_parameter(chunk_slice, data): temp_chunk = get_temp_total_chunk_on_cuda(chunk, self.mixed_precision) for tensor, tensor_info in chunk.tensors_info.items(): + + source_device_mesh, source_sharding_spec, shard_fn, gather_fn = None, None, None, None + if is_distributed_tensor(tensor): + # shard the input param + source_device_mesh = get_device_mesh(tensor) + source_sharding_spec = get_sharding_spec(tensor) + elif is_customized_distributed_tensor(tensor): + shard_fn = tensor.shard_fn + gather_fn = tensor.gather_fn + parameter_name = fp32_to_name[tensor] if self.reuse_fp16_chunk else self.param2name[tensor] parameter_slice = temp_chunk[tensor_info.offset : tensor_info.end] - load(parameter_name, tensor, partial(load_parameter, parameter_slice)) + load(parameter_name, tensor, partial(load_parameter, parameter_slice), source_device_mesh, source_sharding_spec, shard_fn, gather_fn) if chunk.is_gathered: chunk.cuda_global_chunk.copy_(temp_chunk) @@ -724,7 +762,8 @@ def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pi if self.master_weights: # create a fp32 parameter - fp32_p = p.data.float() + fp32_p = p.clone() + fp32_p.data = fp32_p.data.float() self.chunk_manager.register_tensor( tensor=fp32_p, group_type="fp32_param", diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index 0d0298e067f3..e20d846f1071 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -9,6 +9,7 @@ from packaging.version import Version from torch.nn import Parameter from torch.optim import Optimizer +from torch.distributed import ProcessGroup from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin from colossalai.checkpoint_io.utils import StateDictSharder @@ -19,6 +20,18 @@ from .chunk import Chunk, ChunkManager from .gemini_ddp import GeminiDDP +from colossalai.checkpoint_io.utils import gather_distributed_param +from colossalai.tensor.d_tensor import ( + distribute_tensor, + distribute_tensor_with_customization, + init_tensor_as_customization_distributed, + get_device_mesh, + get_sharding_spec, + is_customized_distributed_tensor, + is_distributed_tensor, + get_global_shape, + init_as_dtensor +) __all__ = ["GeminiOptimizer", "GeminiAdamOptimizer"] @@ -93,6 +106,8 @@ def __init__( max_scale: float = 2**32, max_norm: float = 0.0, norm_type: float = 2.0, + tp_group: ProcessGroup = None, + optimizer_params_info=None, verbose: bool = False, **defaults: Any, ): @@ -109,6 +124,10 @@ def __init__( self.chunk16_set: Set[Chunk] = set() self.clipping_flag = max_norm > 0.0 self.max_norm = max_norm + self.tp_group = tp_group + self.optimizer_params_info = optimizer_params_info + self.tp_size = dist.get_world_size(tp_group) if tp_group is not None else 1 + self.tp_rank = dist.get_rank(tp_group) if tp_group is not None else 0 self.verbose = verbose self.param_groups_backup = list() @@ -406,8 +425,8 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict: param = self.id_to_real_params[param_id] fake_param = self.id_to_fake_params.get(param_id, None) chunk = self.chunk_manager.get_chunk(param) - process_group = chunk.torch_pg - rank = dist.get_rank(process_group) + dp_group = chunk.torch_pg + rank = dist.get_rank(dp_group) master_rank = 0 collected_states = {} @@ -415,9 +434,9 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict: local_state_names = None if fake_param is not None: local_state_names = list(self.optim.state[fake_param].keys()) - gathered_state_names = [None for _ in range(dist.get_world_size(process_group))] + gathered_state_names = [None for _ in range(dist.get_world_size(dp_group))] dist.barrier() - dist.all_gather_object(gathered_state_names, local_state_names) + dist.all_gather_object(gathered_state_names, local_state_names, dp_group) state_names = None for names in gathered_state_names: if names is not None: @@ -436,6 +455,13 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict: # Every rank is collector when only_rank_0 is False. is_collector = (rank == master_rank) or (not only_rank_0) + # get tensor parallelism information + is_dtensor = is_distributed_tensor(param) + is_customized_distributed = is_customized_distributed_tensor(param) + shard_spec = get_sharding_spec(param) if is_dtensor else None + device_mesh = get_device_mesh(param) if is_dtensor else None + global_shape = self.optimizer_params_info["id2shape"][param_id] + # If the chunk is kept gathered, # the parameteres are treated the same as that of those in strict DDP during training. # So states can be directly fetched from current device. @@ -451,7 +477,18 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict: ).cpu() else: state_tensor = states[state_name].detach().clone().to(torch.float32).cpu() - collected_states[state_name] = torch.reshape(state_tensor, param.shape) + if is_dtensor: + state_tensor = torch.reshape(state_tensor, param.shape).to(param.device) + state_tensor = init_as_dtensor(state_tensor, + device_mesh=device_mesh, + sharding_spec=shard_spec, + global_shape = global_shape) + elif is_customized_distributed: + state_tensor = torch.reshape(state_tensor, param.shape).to(param.device) + init_tensor_as_customization_distributed(state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn) + state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu() + + collected_states[state_name] = state_tensor.reshape(global_shape) return collected_states # Check whether the param with given id is managed by current process. @@ -473,7 +510,7 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict: _, shard_offset, shard_size = self.get_offsets(param_id) # Collectors gather state shards through all_gathering. - gathered_state_shards = [None for _ in range(dist.get_world_size(process_group))] + gathered_state_shards = [None for _ in range(dist.get_world_size(dp_group))] dist.barrier() dist.all_gather_object(gathered_state_shards, [compacted_states, shard_offset, shard_size]) @@ -494,6 +531,16 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict: for state_name, state_tensor in collected_states.items(): if state_tensor.numel() == param.numel(): collected_states[state_name] = torch.reshape(state_tensor, param.shape) + if is_dtensor: + state_tensor = state_tensor.to(param.device) + state_tensor = init_as_dtensor(state_tensor, + sharding_spec=shard_spec, + device_mesh=device_mesh, + global_shape=global_shape) + elif is_customized_distributed: + state_tensor = state_tensor.to(param.device) + init_tensor_as_customization_distributed(state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn) + state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu() return collected_states @@ -658,6 +705,14 @@ def cast(param, state_range, value, key=None): ret_val = torch.zeros( state_end - state_start, dtype=torch.float32, device=param.device, requires_grad=False ) + + if is_dtensor: + value = torch.reshape(value, global_shape) + value = distribute_tensor(value, sharding_spec=shard_spec, device_mesh=device_mesh) + elif is_customized_distributed: + value = torch.reshape(value, global_shape) + value = distribute_tensor_with_customization(value, real_param.shard_fn, real_param.gather_fn) + ret_val.copy_(value.flatten()[state_start:state_end]) return ret_val @@ -668,6 +723,15 @@ def cast(param, state_range, value, key=None): # Copy states assigned to param (and cast tensors to appropriate types). updated_states = dict() + + # get tensor parallelism information + real_param = self.id_to_real_params[param_id] + is_dtensor = is_distributed_tensor(real_param) + is_customized_distributed = is_customized_distributed_tensor(real_param) + shard_spec = get_sharding_spec(real_param) if is_dtensor else None + device_mesh = get_device_mesh(real_param) if is_dtensor else None + global_shape = self.optimizer_params_info["id2shape"][param_id] + for k, v in saved_states.items(): updated_states[k] = cast(fake_param, state_range, v, k) del v # clean loaded states diff --git a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py index 00ff6cb37d2a..97ec0233f766 100644 --- a/tests/test_booster/test_plugin/test_gemini_plugin.py +++ b/tests/test_booster/test_plugin/test_gemini_plugin.py @@ -10,18 +10,21 @@ from colossalai.fx import is_compatible_with_meta from colossalai.lazy.lazy_init import LazyInitContext from colossalai.nn.optimizer import HybridAdam +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.shardformer.layer.utils import Randomizer from colossalai.tensor.colo_parameter import ColoParameter from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo -def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]: +def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, enable_tensor_parallelism) -> Optional[str]: try: if init_method == "lazy": ctx = LazyInitContext() else: ctx = nullcontext() - plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5) + enable_all_optimization = True if enable_tensor_parallelism else False + plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5, enable_tensor_parallelism=enable_tensor_parallelism, enable_all_optimization=enable_all_optimization) booster = Booster(plugin=plugin) with ctx: model = model_fn() @@ -46,6 +49,8 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[ booster.backward(loss, optimizer) optimizer.step() + except NotImplementedError: + print(f"Tensor Parallelism policy for {model.__class__} is not implemented yet\n.") except Exception as e: # raise e return repr(e) @@ -57,7 +62,8 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[ @parameterize("subset", ["torchvision", "transformers", "diffusers"]) @parameterize("init_method", ["none"]) -def check_gemini_plugin(subset: str, init_method: str = "none", early_stop: bool = True): +@parameterize("enable_tensor_parallelism", [True, False]) +def check_gemini_plugin(subset: str, init_method: str = "none", enable_tensor_parallelism: bool = True, early_stop: bool = True): """check gemini plugin over model zoo Args: @@ -116,7 +122,12 @@ def check_gemini_plugin(subset: str, init_method: str = "none", early_stop: bool "torchvision_efficientnet_v2_s", ]: continue - err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) + + # TODO debug blip2 when using tp, something wrong with shift_logits's shape + if "transformers_blip2" in name: + enable_tensor_parallelism = False + + err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, enable_tensor_parallelism) torch.cuda.empty_cache() if err is None: passed_models.append(name) diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py index f876040384b3..821ce9fbbbd9 100644 --- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py @@ -37,17 +37,20 @@ @parameterize("placement_config", MODEL_PLACEMENT_CONFIGS) @parameterize("model_name", ["transformers_bert_for_sequence_classification"]) @parameterize("use_safetensors", [False, True]) -def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool): +@parameterize("enable_tensor_parallelism", [True, False]) +@parameterize("tp_size", [2]) +def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool, enable_tensor_parallelism: bool, tp_size: int): from transformers import BertForSequenceClassification (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) bert_model = model_fn() + enable_all_optimization = True if enable_tensor_parallelism else False with shared_tempdir() as tempdir: pretrained_path = os.path.join(tempdir, "pretrained") bert_model.config.save_pretrained(save_directory=pretrained_path) - plugin = GeminiPlugin(**placement_config) + plugin = GeminiPlugin(**placement_config, enable_tensor_parallelism=enable_tensor_parallelism, tp_size=tp_size, enable_all_optimization=enable_all_optimization) booster = Booster(plugin=plugin) bert_model, _, _, _, _ = booster.boost(bert_model) model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2 @@ -63,13 +66,16 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b @clear_cache_before_run() @parameterize("placement_config", OPTIM_PLACEMENT_CONFIGS) -@parameterize("shard", [False, True]) +@parameterize("shard", [True, False]) @parameterize("model_name", ["transformers_gpt"]) @parameterize("size_per_shard", [32]) -def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int): +@parameterize("enable_tensor_parallelism", [True, False]) +@parameterize("tp_size", [2]) +def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int, enable_tensor_parallelism: bool, tp_size: int): (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) criterion = lambda x: x.mean() - plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14)) + enable_all_optimization = True if enable_tensor_parallelism else False + plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14), enable_tensor_parallelism=enable_tensor_parallelism, tp_size=tp_size, enable_all_optimization=enable_all_optimization) booster = Booster(plugin=plugin) model = model_fn() @@ -148,7 +154,7 @@ def run_dist(rank, world_size, port): @pytest.mark.dist -@pytest.mark.parametrize("world_size", [2]) +@pytest.mark.parametrize("world_size", [4]) @rerun_if_address_is_in_use() def test_gemini_ckpIO(world_size): spawn(run_dist, world_size)