[gemini] gemini support tensor parallelism. (#4942)
* [colossalai] fix typo

* [inference] Add smoothquant for llama (#4904)

* [inference] add int8 rotary embedding kernel for smoothquant (#4843)

* [inference] add smoothquant llama attention (#4850)

* add smoothquant llama attention

* remove useless code

* remove useless code

* fix import error

* rename file name

* [inference] add silu linear fusion for smoothquant llama mlp  (#4853)

* add silu linear

* update skip condition

* catch smoothquant cuda lib exception

* process exceptions for tests

* [inference] add llama mlp for smoothquant (#4854)

* add llama mlp for smoothquant

* fix down out scale

* remove duplicate lines

* add llama mlp check

* delete useless code

* [inference] add smoothquant llama (#4861)

* add smoothquant llama

* fix attention accuracy

* fix accuracy

* add kv cache and save pretrained

* refactor example

* delete smooth

* refactor code

* [inference] add smooth function and delete useless code for smoothquant (#4895)

* add smooth function and delete useless code

* update datasets

* remove duplicate import

* delete useless file

* refactor codes (#4902)

* refactor code

* add license

* add torch-int and smoothquant license

* Update flash_attention_patch.py

To be compatible with a recent change in the Transformers library, where a new argument 'padding_mask' was added to the forward function of the attention layer.
huggingface/transformers#25598

* [kernel] support pure fp16 for cpu adam and update gemini optim tests (#4921)

* [kernel] support pure fp16 for cpu adam (#4896)

* [kernel] fix cpu adam kernel for pure fp16 and update tests (#4919)

* [kernel] fix cpu adam

* [test] update gemini optim test

* [format] applied code formatting on changed files in pull request 4908 (#4918)

Co-authored-by: github-actions <github-actions@github.com>

* [gemini] support gradient accumulation (#4869)

* add test

* fix no_sync bug in low level zero plugin

* fix test

* add argument for grad accum

* add grad accum in backward hook for gemini

* finish implementation, rewrite tests

* fix test

* skip stuck model in low level zero test

* update doc

* optimize communication & fix gradient checkpoint

* modify doc

* cleaning codes

* update cpu adam fp16 case

* [hotfix] fix torch 2.0 compatibility (#4936)

* [hotfix] fix launch

* [test] fix test gemini optim

* [shardformer] fix vit

* [test] add no master test for low level zero plugin (#4934)

* [format] applied code formatting on changed files in pull request 4820 (#4886)

Co-authored-by: github-actions <github-actions@github.com>

* [nfc] fix some typo with colossalai/ docs/ etc. (#4920)

* [Refactor] Integrated some lightllm kernels into token-attention  (#4946)

* add some req for inference

* clean codes

* add codes

* add some lightllm deps

* clean codes

* hello

* delete rms files

* add some comments

* add comments

* add doc

* add lightllm deps

* add lightllm chatglm2 kernels

* add lightllm chatglm2 kernels

* replace rotary embedding with lightllm kernel

* add some comments

* add some comments

* add some comments

* add

* replace fwd kernel att1

* fix an arg

* add

* add

* fix token attention

* add some comments

* clean codes

* modify comments

* fix readme

* fix bug

* fix bug

---------

Co-authored-by: cuiqing.li <lixx336@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>

* [test] merge old components to test to model zoo (#4945)

* [test] add custom models in model zoo

* [test] update legacy test

* [test] update model zoo

* [test] update gemini test

* [test] remove components to test

* [inference] add reference and fix some bugs (#4937)

* add reference and fix some bugs

* update gptq init

---------

Co-authored-by: Xu Kai <xukai16@foxamil.com>

* [Inference]ADD Bench Chatglm2 script (#4963)

* add bench chatglm

* fix bug and make utils

---------

Co-authored-by: CjhHa1 <cjh18671720497outlook.com>

* [Pipeline inference] Combine kvcache with pipeline inference (#4938)

* merge kvcache with pipeline inference and refactor the code structure

* support ppsize > 2

* refactor pipeline code

* do pre-commit

* modify benchmark

* fix benchmark

* polish code

* add docstring and update readme

* refactor the code

* fix some logic bug of ppinfer

* polish readme

* fix typo

* skip infer test

* updated c++17 compiler flags (#4983)

* [Inference] Dynamic Batching Inference, online and offline (#4953)

* [inference] Dynamic Batching for Single and Multiple GPUs (#4831)

* finish batch manager

* 1

* first

* fix

* fix dynamic batching

* llama infer

* finish test

* support generating sequences of different lengths

* del prints

* del prints

* fix

* fix bug

---------

Co-authored-by: CjhHa1 <cjh18671720497outlook.com>

* [inference] Async dynamic batching  (#4894)

* finish input and output logic

* add generate

* test forward

* 1

* [inference]Re push async dynamic batching (#4901)

* adapt to ray server

* finish async

* finish test

* del test

---------

Co-authored-by: yuehuayingxueluo <867460659@qq.com>

* Revert "[inference]Re push async dynamic batching (#4901)" (#4905)

This reverts commit fbf3c09.

* Revert "[inference] Async dynamic batching  (#4894)"

This reverts commit fced140.

* Revert "[inference] Async dynamic batching  (#4894)" (#4909)

This reverts commit fced140.

* Add Ray Distributed Environment Init Scripts

* support DynamicBatchManager base function

* revert _set_tokenizer version

* add driver async generate

* add async test

* fix bugs in test_ray_dist.py

* add get_tokenizer.py

* fix code style

* fix bugs about No module named 'pydantic' in ci test

* fix bugs in ci test

* fix bugs in ci test

* fix bugs in ci test

* [infer]Add Ray Distributed Environment Init Scripts (#4911)

* Revert "[inference] Async dynamic batching  (#4894)"

This reverts commit fced140.

* Add Ray Distributed Environment Init Scripts

* support DynamicBatchManager base function

* revert _set_tokenizer version

* add driver async generate

* add async test

* fix bugs in test_ray_dist.py

* add get_tokenizer.py

* fix code style

* fix bugs about No module named 'pydantic' in ci test

* fix bugs in ci test

* fix bugs in ci test

* fix bugs in ci test

* support dynamic batch for bloom model and is_running function

* [Inference]Test for new Async engine (#4935)

* infer engine

* infer engine

* test engine

* test engine

* new manager

* change step

* add

* test

* fix

* fix

* finish test

* finish test

* finish test

* finish test

* add license

---------

Co-authored-by: yuehuayingxueluo <867460659@qq.com>

* add assertion for config (#4947)

* [Inference] Finish dynamic batching offline test (#4948)

* test

* fix test

* fix quant

* add default

* fix

* fix some bugs

* fix some bugs

* fix

* fix bug

* fix bugs

* reset param

---------

Co-authored-by: yuehuayingxueluo <867460659@qq.com>
Co-authored-by: Cuiqing Li <lixx3527@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497outlook.com>

* [Kernels] Updated Triton kernels to 2.1.0 and added flash-decoding for llama token attention (#4965)

* adding flash-decoding

* clean

* adding kernel

* adding flash-decoding

* add integration

* add

* adding kernel

* adding kernel

* adding triton 2.1.0 features for inference

* update bloom triton kernel

* remove useless vllm kernels

* clean codes

* fix

* adding files

* fix readme

* update llama flash-decoding

---------

Co-authored-by: cuiqing.li <lixx336@gmail.com>

* fix ColossalEval (#4992)

Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>

* [doc]Update doc for colossal-inference (#4989)

* update doc

* Update README.md

---------

Co-authored-by: cuiqing.li <lixx336@gmail.com>

* [hotfix] Fix the bug where process groups were not being properly released. (#4940)

* Fix the bug where process groups were not being properly released.

* test

* Revert "test"

This reverts commit 479900c.

* [hotfix] fix the bug of repeatedly storing param group (#4951)

* [doc] add supported feature diagram for hybrid parallel plugin (#4996)

* [Pipeline Inference] Merge pp with tp (#4993)

* refactor pipeline into new CaiInferEngine

* update llama modeling forward

* merge tp with pp

* update docstring

* optimize test workflow and example

* fix typo

* add assert and todo

* [release] update version (#4995)

* [release] update version

* [hotfix] fix ci

* [gemini] gemini support tp

* fix

* update checkpointIO

* support fused layernorm

* update fusedlayernorm

* add sequence parallel to gemini

* fix

* fix comments

* fix

* fix t5

* clear cache

* fix

* activate ci

* activate ci

* fix

* fix

* fix

* fix

* revert

* modify tp gather method

* fix test

---------

Co-authored-by: Xu Kai <xukai16@foxmail.com>
Co-authored-by: Zian(Andy) Zheng <62330719+Orion-Zheng@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions <github-actions@github.com>
Co-authored-by: Baizhou Zhang <eddiezhang@pku.edu.cn>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: Cuiqing Li <lixx3527@gmail.com>
Co-authored-by: cuiqing.li <lixx336@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
Co-authored-by: Xu Kai <xukai16@foxamil.com>
Co-authored-by: Jianghai <72591262+CjhHa1@users.noreply.github.com>
Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: yuehuayingxueluo <867460659@qq.com>
Co-authored-by: Yuanchen <70520919+chengeharrison@users.noreply.github.com>
Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>
Co-authored-by: littsk <1214689160@qq.com>
Co-authored-by: ppt0011 <143150326+ppt0011@users.noreply.github.com>
21 people authored Nov 10, 2023
1 parent a448938 commit 576a2f7
Showing 13 changed files with 390 additions and 67 deletions.
81 changes: 77 additions & 4 deletions colossalai/booster/plugin/gemini_plugin.py
@@ -5,6 +5,7 @@
from typing import Callable, Iterator, List, Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
@@ -19,8 +20,9 @@
save_state_dict,
save_state_dict_shards,
)
from colossalai.cluster import DistCoordinator
from colossalai.cluster import DistCoordinator, ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.utils import get_current_device
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.memory_tracer import MemStats
@@ -32,7 +34,25 @@
SUPPORTED_PRECISION = ["fp16", "bf16"]
PRECISION_STR_TO_DTYPE = {"fp16": torch.half, "bf16": torch.bfloat16}

DP_AXIS = 0
TP_AXIS = 1

def get_param_info(optim: Optimizer):
# Get a backup of necessary information of parameters for future use, which includes:
# 1. A mapping from integer param_id to param32 shape.

if optim is None:
return {}
param_info = {"id2shape": {}}
start_index = 0
for group in optim.param_groups:
for param_id, param in enumerate(group["params"], start_index):
original_shape = param.shape if isinstance(param, torch.Tensor) else None
param_info["id2shape"][param_id] = original_shape

start_index += len(group["params"])

return param_info
class GeminiCheckpointIO(GeneralCheckpointIO):
def __init__(self) -> None:
super().__init__()
@@ -284,6 +304,16 @@ class GeminiPlugin(DPPluginBase):
max_norm (float, optional): max_norm used for `clip_grad_norm`. You should notice that you shall not do
clip_grad_norm by yourself when using ZeRO DDP. The ZeRO optimizer will take care of clip_grad_norm.
norm_type (float, optional): norm_type used for `clip_grad_norm`.
enable_tensor_parallelism (bool, optional): Whether to use tensor parallelism strategy, which is implemented in Shardformer. Default to False.
tp_size (int, optional): If 'enable_tensor_parallelism' is set to true, please configure 'tp_size' which determines the size of the tensor parallel process group. Default to 1.
enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer.
Currently all the optimization methods include fused normalization, flash attention and JIT.
Defaults to False.
enable_fused_normalization (bool, optional): Whether to switch on fused normalization in Shardformer. Defaults to False.
enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
verbose (bool, optional): verbose mode. Debug info including chunk search result will be printed. Defaults to False.
"""

@@ -317,6 +347,14 @@ def __init__(
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0,
enable_tensor_parallelism: bool = False,
tp_size: int = 1,
enable_all_optimization: bool = False,
enable_fused_normalization: bool = False,
enable_flash_attention: bool = False,
enable_sequence_parallelism: bool = False,
enable_jit_fused: bool = False,
enable_sequence_overlap: bool = False,
verbose: bool = False,
) -> None:
super().__init__()
@@ -355,8 +393,32 @@ def __init__(
max_norm=max_norm,
norm_type=norm_type,
)
self.enable_tensor_parallelism = enable_tensor_parallelism
self.enable_all_optimization = enable_all_optimization
self.enable_fused_normalization = enable_fused_normalization
self.enable_flash_attention = enable_flash_attention
self.enable_sequence_parallelism = enable_sequence_parallelism if self.enable_tensor_parallelism else False
self.enable_jit_fused = enable_jit_fused
self.enable_sequence_overlap = enable_sequence_overlap
self.verbose = verbose

self.tp_size = tp_size if self.enable_tensor_parallelism else 1
self.dp_size = dist.get_world_size() // self.tp_size
assert self.dp_size > 1, f"The size of the DP group should be greater than 1. Please reduce the TP group size."
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.tp_size)
self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
self.shard_config = ShardConfig(
tensor_parallel_process_group=self.tp_group,
enable_tensor_parallelism=self.enable_tensor_parallelism,
enable_all_optimization=self.enable_all_optimization,
enable_fused_normalization=self.enable_fused_normalization,
enable_flash_attention=self.enable_flash_attention,
enable_jit_fused=self.enable_jit_fused,
enable_sequence_parallelism=self.enable_sequence_parallelism,
enable_sequence_overlap=self.enable_sequence_overlap,
)

def support_no_sync(self) -> bool:
return False

@@ -380,6 +442,7 @@ def configure(
dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
optimizer_params_info = get_param_info(optimizer)
if not isinstance(model, ModelWrapper):
# convert model to sync bn
# FIXME(ver217): gemini does not support sync bn
@@ -391,11 +454,21 @@ def configure(
# model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None)

# wrap the model with Gemini
model = GeminiDDP(model, **self.gemini_config, verbose=self.verbose)
if self.enable_tensor_parallelism:
shardformer = ShardFormer(self.shard_config)
model, _ = shardformer.optimize(model)

model = GeminiDDP(model, **self.gemini_config, process_group=self.dp_group, verbose=self.verbose)

if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
optimizer = GeminiOptimizer(
optimizer, model, **self.zero_optim_config, **self.optim_kwargs, verbose=self.verbose
optimizer,
model,
**self.zero_optim_config,
**self.optim_kwargs,
tp_group=self.tp_group,
optimizer_params_info=optimizer_params_info,
verbose=self.verbose,
)

return model, optimizer, criterion, dataloader, lr_scheduler
@@ -407,4 +480,4 @@ def get_checkpoint_io(self) -> CheckpointIO:
return GeminiCheckpointIO()

def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
raise NotImplementedError
raise NotImplementedError
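For context, a minimal sketch of how the options added above might be used. The plugin arguments come from the new docstring; the launch call, model, and optimizer are illustrative placeholders following the usual ColossalAI booster workflow and are not part of this commit.

# Sketch: Gemini (ZeRO) combined with tensor parallelism via the new GeminiPlugin options.
# Assumes `torchrun --nproc_per_node=4`, so world_size=4 gives dp_size=2 with tp_size=2
# (the plugin asserts dp_size > 1, so tp_size must be smaller than the world size).
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam
from transformers import LlamaConfig, LlamaForCausalLM

colossalai.launch_from_torch(config={})

plugin = GeminiPlugin(
    precision="bf16",
    enable_tensor_parallelism=True,   # shard attention/MLP/embedding layers via Shardformer
    tp_size=2,                        # size of the tensor-parallel process group
    enable_fused_normalization=True,  # optional Shardformer optimization
)
booster = Booster(plugin=plugin)

model = LlamaForCausalLM(LlamaConfig(hidden_size=256, intermediate_size=512,
                                     num_hidden_layers=2, num_attention_heads=4))
optimizer = HybridAdam(model.parameters(), lr=1e-4)
model, optimizer, *_ = booster.boost(model, optimizer)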
1 change: 1 addition & 0 deletions colossalai/cluster/process_group_mesh.py
@@ -225,3 +225,4 @@ def get_group_along_axis(
# no need to cache it explicitly, since it will be cached in `create_group_along_axis`
return self.create_group_along_axis(axis, indices_at_axis, backend=backend)
return self._ranks_to_group[ranks_in_group]

58 changes: 50 additions & 8 deletions colossalai/shardformer/layer/_operation.py
@@ -53,7 +53,7 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):

@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce):
ctx.save_for_backward(input_, weight)
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_allreduce = async_grad_allreduce
@@ -62,13 +62,18 @@ def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce):

if bias is not None:
output = output + bias

return output

@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
input, weight, bias = ctx.saved_tensors
use_bias = ctx.use_bias

# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias.
weight = weight.view(weight.shape)
bias = bias.view(bias.shape)

total_input = input
grad_input = grad_output.matmul(weight.T)
grad_output = grad_output.contiguous()
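Aside: the added weight.view(weight.shape) / bias.view(bias.shape) calls look like no-ops. Their purpose is that even an identity view is dispatched through __torch_function__, which is the hook Gemini uses to make sure a parameter's chunk is resident before its gradient is computed. A standalone sketch of the mechanism (the TrackedParam subclass is illustrative only, not Gemini's actual parameter type):

# Toy demonstration: a numerically no-op view() still triggers __torch_function__.
import torch

class TrackedParam(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.Tensor.view:
            # Gemini's real hook would fetch/materialize the parameter's chunk here.
            print("view() intercepted")
        return super().__torch_function__(func, types, args, kwargs)

w = torch.randn(8, 8).as_subclass(TrackedParam)
_ = w.view(w.shape)  # identity reshape, but still routed through __torch_function__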
@@ -100,7 +105,7 @@ class LinearWithAsyncCommunication(torch.autograd.Function):

@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce):
ctx.save_for_backward(input_, weight)
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_allreduce = async_grad_allreduce
@@ -109,13 +114,18 @@ def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce):
output = F.linear(input_, weight, bias)
else:
output = F.linear(input_, weight)

return output

@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
input, weight, bias = ctx.saved_tensors
use_bias = ctx.use_bias

# In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias.
if use_bias:
bias.view(bias.shape)

total_input = input
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
@@ -152,7 +162,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):

@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True):
ctx.save_for_backward(input_, weight)
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
@@ -170,12 +180,16 @@ def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter,

@staticmethod
def backward(ctx, grad_output):
input_, weight = ctx.saved_tensors
input_, weight, bias = ctx.saved_tensors
use_bias = ctx.use_bias
dim = ctx.dim
process_group = ctx.process_group
overlap = ctx.overlap

# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
if use_bias:
bias = bias.view(bias.shape)

if not overlap:
input_parallel = _gather(input_, dim, process_group)

@@ -289,7 +303,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):

@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap):
ctx.save_for_backward(input_, weight)
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
@@ -306,12 +320,17 @@ def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter,

@staticmethod
def backward(ctx, grad_output):
input_, weight = ctx.saved_tensors
input_, weight, bias = ctx.saved_tensors
use_bias = ctx.use_bias
dim = ctx.dim
process_group = ctx.process_group
overlap = ctx.overlap

# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
weight = weight.view(weight.shape)
if use_bias:
bias = bias.view(bias.shape)

if not overlap:
input_parallel = _gather(input_, dim, process_group)

@@ -454,6 +473,29 @@ def forward(ctx, input_, dim, process_group):
@staticmethod
def backward(ctx, grad_output):
return _split(grad_output, ctx.dim, ctx.process_group), None, None


class HookParameter(torch.autograd.Function):
"""In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm"""
@staticmethod
def forward(ctx, input, weight, bias):
ctx.save_for_backward(weight, bias)
output = input
return output

@staticmethod
def backward(ctx, grad_output):
weight, bias = ctx.saved_tensors
if weight is not None:
weight = weight.view(weight.shape)
if bias is not None:
bias = bias.view(bias.shape)
return grad_output, None, None


def hook_paramter_in_backward(input, weight=None, bias=None):
return HookParameter.apply(input, weight, bias)



def _reduce(input_, process_group):
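A hedged sketch of how hook_paramter_in_backward (defined above) is meant to be used; the wrapper below is illustrative, with F.layer_norm standing in for the fused normalization kernel used by the real Shardformer layers:

# Illustrative only: pass weight/bias through the no-op HookParameter function so that
# Gemini's '__torch_function__' hook still sees them during backward, even when the
# normalization math itself runs inside a fused kernel.
import torch
import torch.nn.functional as F
from colossalai.shardformer.layer._operation import hook_paramter_in_backward

def layer_norm_with_hook(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor,
                         eps: float = 1e-5) -> torch.Tensor:
    out = F.layer_norm(x, (x.shape[-1],), weight, bias, eps)  # stand-in for a fused kernel
    # hook_paramter_in_backward returns `out` unchanged; the autograd Function only
    # saves weight/bias and view()s them in backward.
    return hook_paramter_in_backward(out, weight, bias)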
5 changes: 3 additions & 2 deletions colossalai/shardformer/layer/embedding.py
@@ -309,7 +309,8 @@ def forward(self, input_: Tensor) -> Tensor:
)

# Mask the output embedding.
output_parallel[input_mask, :] = 0.0
embedding_output = output_parallel.clone()
embedding_output[input_mask, :] = 0.0
# Reduce across all the model parallel GPUs.
output = reduce_forward(output_parallel, self.process_group)
output = reduce_forward(embedding_output, self.process_group)
return output
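The change above masks a clone of output_parallel instead of writing into it in place; a minimal sketch of the pattern (shapes and mask values are made up):

# Illustrative only: zero the rows of tokens owned by other tensor-parallel ranks on a
# clone, so the original embedding output is never mutated in place.
import torch

output_parallel = torch.randn(4, 8, requires_grad=True)  # per-rank embedding lookup result
input_mask = torch.tensor([False, True, False, True])    # tokens outside this rank's vocab shard

embedding_output = output_parallel.clone()
embedding_output[input_mask, :] = 0.0  # in-place write touches only the clone
# embedding_output (not output_parallel) is then all-reduced across the TP group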
(Diffs for the remaining changed files are not shown.)
