
[example]add gpt2 benchmark example script. #5295

Merged: 40 commits, Mar 4, 2024

Commits (40)
6a2aaae
benchmark gpt2
flybird11111 Jan 12, 2024
01ee32a
fix
flybird11111 Jan 22, 2024
dde982f
Merge branch 'hpcaitech:main' into benchmark-gpt
flybird11111 Jan 22, 2024
9ce5280
[doc] fix typo in Colossal-LLaMA-2/README.md (#5247)
digger-yu Jan 10, 2024
929c32e
[workflow] fixed build CI (#5240)
FrankLeeeee Jan 10, 2024
03c6112
[ci] fixed booster test (#5251)
FrankLeeeee Jan 11, 2024
482f1ea
[ci] fixed ddp test (#5254)
FrankLeeeee Jan 11, 2024
907ee2a
fix typo in applications/ColossalEval/README.md (#5250)
digger-yu Jan 11, 2024
54aca87
[ci] fix shardformer tests. (#5255)
flybird11111 Jan 11, 2024
1b53824
[doc] fix doc typo (#5256)
binmakeswell Jan 11, 2024
94bd340
[hotfix]: add pp sanity check and fix mbs arg (#5268)
CWHer Jan 15, 2024
7f282f7
[workflow] fixed incomplete bash command (#5272)
FrankLeeeee Jan 16, 2024
6e158b7
[workflow] fixed oom tests (#5275)
FrankLeeeee Jan 16, 2024
ce924a1
[ci] fix test_hybrid_parallel_plugin_checkpoint_io.py (#5276)
flybird11111 Jan 17, 2024
e7ca755
[shardformer] hybridparallelplugin support gradients accumulation. (#…
flybird11111 Jan 17, 2024
ef53503
[hotfix] Fix ShardFormer test execution path when using sequence para…
KKZ20 Jan 17, 2024
85860e7
fix auto loading gpt2 tokenizer (#5279)
MichelleMa8 Jan 18, 2024
b40cc06
[doc] add llama2-13B disyplay (#5285)
Desperado-Jia Jan 19, 2024
9422351
fix llama pretrain (#5287)
flybird11111 Jan 19, 2024
977888b
fix
flybird11111 Jan 22, 2024
c5279a6
Merge branch 'benchmark-gpt' of github.com:flybird11111/ColossalAI in…
flybird11111 Jan 22, 2024
f556e1d
fix
flybird11111 Jan 22, 2024
2da389d
fix
flybird11111 Jan 22, 2024
09267dc
fix
flybird11111 Jan 22, 2024
46f4c87
fix
flybird11111 Jan 22, 2024
d165dee
fix
flybird11111 Jan 22, 2024
d2593b8
fix
flybird11111 Jan 28, 2024
30ffe10
benchmark gpt2
flybird11111 Jan 12, 2024
1149884
fix
flybird11111 Jan 22, 2024
e5a33da
[workflow] fixed build CI (#5240)
FrankLeeeee Jan 10, 2024
c15223b
[ci] fixed booster test (#5251)
FrankLeeeee Jan 11, 2024
c642f32
fix
flybird11111 Jan 29, 2024
cc2fac8
fix
flybird11111 Jan 29, 2024
9249a5c
fix
flybird11111 Jan 29, 2024
e2aa82e
Merge branch 'benchmark-gpt' of github.com:flybird11111/ColossalAI in…
flybird11111 Jan 29, 2024
e1402fd
fix
flybird11111 Feb 23, 2024
acccb4b
fix
flybird11111 Feb 26, 2024
85f0cea
fix
flybird11111 Feb 26, 2024
e7e382a
Merge branch 'main' into benchmark-gpt
flybird11111 Feb 27, 2024
8728c3c
Update shardformer.py
flybird11111 Feb 27, 2024

Files changed

2 changes: 1 addition & 1 deletion .github/workflows/build_on_pr.yml
@@ -201,4 +201,4 @@ jobs:
uses: actions/upload-artifact@v3
with:
name: report
path: report/
path: report/
2 changes: 1 addition & 1 deletion .github/workflows/build_on_schedule.yml
@@ -83,4 +83,4 @@ jobs:
SERVER_URL: ${{github.server_url }}
REPO: ${{ github.repository }}
RUN_ID: ${{ github.run_id }}
WEBHOOK_URL: ${{ secrets.LARK_NOTIFICATION_WEBHOOK_URL }}
WEBHOOK_URL: ${{ secrets.LARK_NOTIFICATION_WEBHOOK_URL }}
3 changes: 3 additions & 0 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -36,6 +36,8 @@

DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2

PRECISION_TORCH_TYPE = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}


def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
@@ -1059,6 +1061,7 @@ def __init__(
overlap_communication=overlap_communication,
cpu_offload=cpu_offload,
partition_grad=(self.zero_stage == 2),
forced_dtype=PRECISION_TORCH_TYPE[precision],
)

self.max_norm = max_norm
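For context, a minimal sketch of how the new PRECISION_TORCH_TYPE mapping turns the plugin's precision string into the forced_dtype handed to the ZeRO optimizer. The helper name resolve_forced_dtype is illustrative and not part of the PR:

```python
import torch

# Mapping added in hybrid_parallel_plugin.py: precision string -> torch dtype.
PRECISION_TORCH_TYPE = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}

def resolve_forced_dtype(precision: str) -> torch.dtype:
    """Illustrative helper: resolve the dtype passed to the ZeRO optimizer as forced_dtype."""
    if precision not in PRECISION_TORCH_TYPE:
        raise ValueError(f"unsupported precision {precision!r}, expected one of {list(PRECISION_TORCH_TYPE)}")
    return PRECISION_TORCH_TYPE[precision]

# A plugin configured with precision="bf16" maps to torch.bfloat16 here.
assert resolve_forced_dtype("bf16") is torch.bfloat16
```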
27 changes: 12 additions & 15 deletions colossalai/shardformer/layer/_operation.py
@@ -9,6 +9,7 @@

try:
import fused_weight_gradient_mlp_cuda

_grad_accum_fusion_available = True
except ImportError:
_grad_accum_fusion_available = False
@@ -78,7 +79,8 @@ def backward(ctx, grad_output):

# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias.
weight = weight.view(weight.shape)
bias = bias.view(bias.shape)
if bias is not None:
bias = bias.view(bias.shape)

total_input = input
grad_input = grad_output.matmul(weight.T)
@@ -91,9 +93,8 @@
if ctx.async_grad_allreduce:
# Asynchronous all-reduce
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# all-reduce scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
# Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have the all-reduce scheduled first
# and have GPU resources allocated; CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py

grad_weight = total_input.t().matmul(grad_output)
grad_bias = grad_output.sum(dim=0) if use_bias else None
@@ -115,7 +116,6 @@ def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce):
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_allreduce = async_grad_allreduce

if bias is not None:
output = F.linear(input_, weight, bias)
else:
@@ -143,9 +143,8 @@ def backward(ctx, grad_output):
if ctx.async_grad_allreduce:
# Asynchronous all-reduce
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# all-reduce scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
# Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have the all-reduce scheduled first
# and have GPU resources allocated; CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py

if _grad_accum_fusion_available and weight.grad is not None:
grad = weight.grad
@@ -228,9 +227,8 @@ def backward(ctx, grad_output):
input_.shape, dtype=input_parallel.dtype, device=input_parallel.device
).contiguous()
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# reduce-scatter scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
# Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have the reduce-scatter scheduled first
# and have GPU resources allocated; CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py

if _grad_accum_fusion_available and weight.grad is not None:
grad = weight.grad
@@ -394,9 +392,8 @@ def backward(ctx, grad_output):
input_.shape, dtype=input_parallel.dtype, device=input_parallel.device
).contiguous()
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# reduce-scatter scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
# Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have the reduce-scatter scheduled first
# and have GPU resources allocated; CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py

grad_weight = total_input.t().matmul(grad_output)
grad_bias = grad_output.sum(dim=0) if use_bias else None
@@ -431,7 +428,7 @@ def backward(ctx, grad_output):
input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
# calculate gradient
if len(input_parallel.shape) > 2:
input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
grad_weight = input_parallel.t().matmul(grad_output)
# wait until reduce-scatter finished
reducescatter_handle.wait()
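The edits above replace the old trick of launching a tiny dummy kernel (to delay the weight-gradient matmul) with a reliance on CUDA_DEVICE_MAX_CONNECTIONS=1, which shardformer.py now sets. A simplified sketch of the overlap pattern, assuming an initialized process group; the function name and argument handling are illustrative rather than the PR's actual code:

```python
import os
import torch
import torch.distributed as dist

# With CUDA_DEVICE_MAX_CONNECTIONS=1, kernels are issued into a single hardware work
# queue, so the async all-reduce below is guaranteed to be scheduled before the matmul
# that follows it, letting communication overlap with the weight-gradient compute.
os.environ.setdefault("CUDA_DEVICE_MAX_CONNECTIONS", "1")

def linear_backward_with_overlap(grad_output: torch.Tensor,
                                 total_input: torch.Tensor,
                                 weight: torch.Tensor,
                                 process_group):
    # Assumes weight is stored (in_features, out_features), as in GPT-2's Conv1D layers.
    grad_input = grad_output.matmul(weight.T)
    # Launch the all-reduce asynchronously; no dummy delay kernel is needed anymore.
    handle = dist.all_reduce(grad_input, group=process_group, async_op=True)
    # This matmul runs while the all-reduce is still in flight.
    grad_weight = total_input.t().matmul(grad_output)
    handle.wait()  # make sure grad_input is fully reduced before it is used
    return grad_input, grad_weight
```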
93 changes: 92 additions & 1 deletion colossalai/shardformer/modeling/gpt2.py
@@ -24,6 +24,8 @@
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
from colossalai.shardformer.shard import ShardConfig

from ..layer import cross_entropy_1d


class GPT2PipelineForwards:
"""
@@ -326,7 +328,15 @@ def gpt2_lmhead_model_forward(
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
shift_logits = shift_logits.view(-1, shift_logits.size(-1))
shift_labels = shift_labels.view(-1)
if shard_config.enable_tensor_parallelism:
loss = cross_entropy_1d(
shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
)
else:
loss = loss_fct(shift_logits, shift_labels)

if not return_dict:
output = (lm_logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
@@ -1006,3 +1016,84 @@ def custom_forward(*inputs):
)

return forward


def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
from transformers import GPT2LMHeadModel

def forward(
self: GPT2LMHeadModel,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

transformer_outputs = self.transformer(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]

lm_logits = self.lm_head(hidden_states)

loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(lm_logits.device)
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, shift_logits.size(-1))
shift_labels = shift_labels.view(-1)
if shard_config.enable_tensor_parallelism:
loss = cross_entropy_1d(
shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
)
else:
loss = loss_fct(shift_logits, shift_labels)

if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output

return CausalLMOutputWithCrossAttentions(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
cross_attentions=transformer_outputs.cross_attentions,
)

return forward
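For intuition, here is a toy, unoptimized sketch of the vocab-parallel cross-entropy idea behind cross_entropy_1d: each tensor-parallel rank holds only a contiguous vocabulary shard of the logits, and the loss is computed without gathering the full vocabulary. This is not the actual ColossalAI kernel and it omits ignore_index (-100) handling:

```python
import torch
import torch.distributed as dist

def vocab_parallel_cross_entropy(local_logits, labels, process_group):
    # local_logits: (N, vocab_size // tp_size) shard on this rank; labels: (N,) global vocab ids.
    tp_rank = dist.get_rank(process_group)
    vocab_per_rank = local_logits.size(-1)
    vocab_start = tp_rank * vocab_per_rank

    # Numerically stable log-softmax: subtract the global max over the full vocabulary.
    logits_max = local_logits.max(dim=-1).values
    dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group)
    shifted = local_logits - logits_max.unsqueeze(-1)

    # Denominator: sum of exp over all vocabulary shards.
    sum_exp = shifted.exp().sum(dim=-1)
    dist.all_reduce(sum_exp, group=process_group)

    # Numerator: the target logit lives on exactly one rank; the others contribute zero.
    local_mask = (labels >= vocab_start) & (labels < vocab_start + vocab_per_rank)
    local_idx = (labels - vocab_start).clamp(min=0, max=vocab_per_rank - 1)
    target_logit = shifted.gather(-1, local_idx.unsqueeze(-1)).squeeze(-1) * local_mask
    dist.all_reduce(target_logit, group=process_group)

    # Cross entropy per token: log(sum exp) - target logit (both in the max-shifted frame).
    return (sum_exp.log() - target_logit).mean()
```

This is why the policy below can keep lm_head column-parallel with gather_output=False: the loss never needs the full logits on any single rank.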
82 changes: 60 additions & 22 deletions colossalai/shardformer/policies/gpt2.py
@@ -5,7 +5,12 @@

import colossalai.shardformer.layer as col_nn

from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward, gpt2_sequence_parallel_forward_fn
from ..modeling.gpt2 import (
GPT2PipelineForwards,
get_gpt2_flash_attention_forward,
get_lm_forward_with_dist_cross_entropy,
gpt2_sequence_parallel_forward_fn,
)
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = [
@@ -87,9 +92,7 @@ def module_policy(self):
SubModuleReplacementDescription(
suffix="mlp.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Row,
kwargs={
"seq_parallel": use_sequence_parallel,
},
kwargs={"seq_parallel": use_sequence_parallel},
),
SubModuleReplacementDescription(
suffix="attn.attn_dropout",
@@ -167,15 +170,35 @@ def get_held_layers(self) -> List[nn.Module]:
stage_manager = self.pipeline_stage_manager

held_layers = []
layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
if stage_manager.is_first_stage():
held_layers.append(module.wte)
held_layers.append(module.wpe)
held_layers.append(module.drop)
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
held_layers.extend(module.h[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.ln_f)
if stage_manager.is_interleave:
assert stage_manager.num_model_chunks is not None
layers_per_stage = self.distribute_layers(
len(module.h), stage_manager.num_stages * stage_manager.num_model_chunks
)
stage_indices = Policy.get_stage_index(
layers_per_stage,
stage_manager.stage,
num_model_chunks=stage_manager.num_model_chunks,
num_stages=stage_manager.num_stages,
)
if stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(module.wte)
held_layers.append(module.wpe)
held_layers.append(module.drop)
for start_idx, end_idx in stage_indices:
held_layers.extend(module.h[start_idx:end_idx])
if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(module.ln_f)
else:
layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
if stage_manager.is_first_stage():
held_layers.append(module.wte)
held_layers.append(module.wpe)
held_layers.append(module.drop)
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
held_layers.extend(module.h[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.ln_f)
return held_layers

def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
@@ -189,13 +212,27 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
else:
module = self.model.transformer

layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = {
"forward": partial(
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
if stage_manager.is_interleave:
layers_per_stage = self.distribute_layers(
len(module.h), stage_manager.num_stages * stage_manager.num_model_chunks
)
stage_manager.stage_indices = Policy.get_stage_index(
layers_per_stage,
stage_manager.stage,
num_model_chunks=stage_manager.num_model_chunks,
num_stages=stage_manager.num_stages,
)
}
method_replacement = {
"forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config)
}
else:
layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = {
"forward": partial(
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
)
}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)


@@ -232,9 +269,10 @@ def module_policy(self):
GPT2LMHeadModel: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": False}
)
]
],
method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
)
}
module_policy.update(addon_module)
@@ -249,7 +287,7 @@ def module_policy(self):

def get_held_layers(self) -> List[nn.Module]:
held_layers = super().get_held_layers()
if self.pipeline_stage_manager.is_last_stage():
if self.pipeline_stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.lm_head)
return held_layers

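The interleaved branch in get_held_layers and set_pipeline_forward splits the GPT-2 blocks into num_stages * num_model_chunks groups and gives each pipeline stage several non-contiguous ranges. A toy illustration of that assignment follows; it mirrors Policy.distribute_layers / Policy.get_stage_index only in spirit, and the real implementations may balance layers differently:

```python
# Interleaved (virtual pipeline) layer assignment: stage `s` owns groups
# s, s + num_stages, s + 2 * num_stages, ...
def distribute_layers(num_layers: int, num_groups: int) -> list:
    base, rem = divmod(num_layers, num_groups)
    return [base + (1 if i < rem else 0) for i in range(num_groups)]

def stage_indices(num_layers: int, stage: int, num_stages: int, num_model_chunks: int) -> list:
    sizes = distribute_layers(num_layers, num_stages * num_model_chunks)
    starts = [sum(sizes[:i]) for i in range(len(sizes))]
    return [(starts[g], starts[g] + sizes[g])
            for g in range(stage, num_stages * num_model_chunks, num_stages)]

# Example: 16 transformer blocks, 4 pipeline stages, 2 model chunks per stage.
# Stage 0 holds blocks [0:2) and [8:10); stage 1 holds [2:4) and [10:12); and so on.
print(stage_indices(16, 0, num_stages=4, num_model_chunks=2))  # [(0, 2), (8, 10)]
```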
4 changes: 4 additions & 0 deletions colossalai/shardformer/shard/shardformer.py
@@ -1,3 +1,4 @@
import os
from typing import Dict, List, Tuple

import torch.nn as nn
@@ -9,6 +10,9 @@
from .shard_config import ShardConfig
from .sharder import ModelSharder

# Set CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that, when communication and computation overlap, kernels are scheduled in the intended order.
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"


class ShardFormer:
"""
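A small defensive variant of the environment-variable setting above (illustrative, not part of the PR): CUDA generally reads CUDA_DEVICE_MAX_CONNECTIONS when the context is created, so the value should be in place before any CUDA work, and a pre-existing different value is worth surfacing rather than silently overriding:

```python
import os

desired = "1"
current = os.environ.get("CUDA_DEVICE_MAX_CONNECTIONS")
if current is not None and current != desired:
    # The overlap comments in _operation.py assume a single hardware work queue.
    print(
        f"Warning: CUDA_DEVICE_MAX_CONNECTIONS={current} is already set; "
        f"shardformer expects {desired} for the intended comm/compute scheduling order."
    )
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = desired
```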
Empty file added examples/__init__.py
Empty file.
Empty file added examples/language/__init__.py
Empty file.
@@ -121,4 +121,4 @@ def __getitem__(self, idx):
"input_ids": self.input_ids[idx],
"attention_mask": self.attention_mask[idx],
"labels": self.input_ids[idx],
}
}
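
The hunk above comes from the benchmark's dataset helper (the filename is not shown in this view). A minimal, hypothetical dataset with the same __getitem__ contract, using randomly generated token ids for benchmarking; the class name and defaults are assumptions, not the PR's actual code:

```python
import torch
from torch.utils.data import Dataset

class RandomTokenDataset(Dataset):
    """Illustrative stand-in for the benchmark dataset: returns the dict layout shown
    above, with labels equal to input_ids (causal language-modeling objective)."""

    def __init__(self, num_samples: int = 1024, seq_len: int = 1024, vocab_size: int = 50257):
        self.input_ids = torch.randint(0, vocab_size, (num_samples, seq_len))
        self.attention_mask = torch.ones_like(self.input_ids)

    def __len__(self):
        return self.input_ids.size(0)

    def __getitem__(self, idx):
        return {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.attention_mask[idx],
            "labels": self.input_ids[idx],
        }
```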