fix
flybird11111 committed Feb 26, 2024
1 parent e1402fd commit acccb4b
Showing 8 changed files with 64 additions and 97 deletions.
6 changes: 3 additions & 3 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -35,6 +35,8 @@

DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2

PRECISION_TORCH_TYPE = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}


def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
@@ -1033,7 +1035,6 @@ def __init__(
enable_sequence_parallelism=enable_sequence_parallelism,
enable_sequence_overlap=enable_sequence_overlap,
)
print("self.shard_config", self.shard_config)
self.amp_config = dict(
initial_scale=initial_scale,
growth_factor=growth_factor,
@@ -1059,7 +1060,7 @@
overlap_communication=overlap_communication,
cpu_offload=cpu_offload,
partition_grad=(self.zero_stage == 2),
forced_dtype=torch.bfloat16,
forced_dtype=PRECISION_TORCH_TYPE[precision],
)

self.max_norm = max_norm
@@ -1101,7 +1102,6 @@ def configure(
param_info = get_param_info(optimizer)
if not isinstance(model, ModelWrapper):
use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
print("use_ddp", use_ddp)
model = HybridParallelModule(
model,
precision=self.precision,
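Note: the change in this file replaces a hard-coded `forced_dtype=torch.bfloat16` with a lookup through the new `PRECISION_TORCH_TYPE` mapping, so the dtype forced on the ZeRO optimizer follows the plugin's `precision` argument. A minimal sketch of the idea; the helper name `build_zero_optimizer_kwargs` and its reduced kwargs are illustrative, not the plugin's API:

```python
import torch

# Same mapping as introduced in the diff above.
PRECISION_TORCH_TYPE = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}

def build_zero_optimizer_kwargs(precision: str, zero_stage: int) -> dict:
    """Resolve the user-facing precision string into the dtype forced on ZeRO."""
    if precision not in PRECISION_TORCH_TYPE:
        raise ValueError(f"Unsupported precision: {precision!r}")
    return {
        "partition_grad": zero_stage == 2,
        "forced_dtype": PRECISION_TORCH_TYPE[precision],  # was hard-coded to torch.bfloat16
    }

print(build_zero_optimizer_kwargs("fp16", zero_stage=2))
# {'partition_grad': True, 'forced_dtype': torch.float16}
```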
23 changes: 11 additions & 12 deletions colossalai/shardformer/layer/_operation.py
@@ -78,11 +78,11 @@ def backward(ctx, grad_output):
use_bias = ctx.use_bias

# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias.
# weight = weight.view(weight.shape)
# bias = bias.view(bias.shape)
weight = weight.view(weight.shape)
if bias is not None:
bias = bias.view(bias.shape)

total_input = input
# print("grad_output.shape", grad_output.shape, "weight.shape", weight.shape)
grad_input = grad_output.matmul(weight.T)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
@@ -93,12 +93,11 @@
if ctx.async_grad_allreduce:
# Asynchronous all-reduce
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated
# _ = torch.empty(1, device=grad_output.device) + 1
_ = torch.empty(1, device=grad_output.device) + 1

grad_weight = total_input.t().matmul(grad_output)
# print("use_biasuse_biasuse_biasuse_biasuse_bias",use_bias)
grad_bias = grad_output.sum(dim=0) if use_bias else None

if ctx.async_grad_allreduce:
@@ -131,8 +130,8 @@ def backward(ctx, grad_output):
use_bias = ctx.use_bias

# In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias.
# if use_bias:
# bias.view(bias.shape)
if use_bias:
bias.view(bias.shape)

total_input = input
grad_input = grad_output.matmul(weight)
@@ -145,9 +144,9 @@
if ctx.async_grad_allreduce:
# Asynchronous all-reduce
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated
# _ = torch.empty(1, device=grad_output.device) + 1
_ = torch.empty(1, device=grad_output.device) + 1

if _grad_accum_fusion_available and weight.grad is not None:
grad = weight.grad
@@ -207,8 +206,8 @@ def backward(ctx, grad_output):
overlap = ctx.overlap

# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
# if use_bias:
# bias = bias.view(bias.shape)
if use_bias:
bias = bias.view(bias.shape)

if not overlap:
input_parallel = _gather(input_, dim, process_group)
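Note: the hunks above restore two things that had been commented out: the `.view(...)` calls that let Gemini's `__torch_function__` hook see the weight and bias, and the one-element dummy op that, together with `CUDA_DEVICE_MAX_CONNECTIONS=1`, keeps the asynchronous all-reduce scheduled before the weight-gradient matmul. A simplified sketch of that overlap pattern, assuming `torch.distributed` is already initialized and using 2-D shapes for brevity:

```python
import torch
import torch.distributed as dist

def linear_backward_with_overlap(grad_output, total_input, weight, process_group):
    # dL/dx for this tensor-parallel shard; it must be all-reduced across ranks.
    grad_input = grad_output.matmul(weight)
    handle = dist.all_reduce(grad_input, group=process_group, async_op=True)

    # With CUDA_DEVICE_MAX_CONNECTIONS=1, this tiny kernel keeps the all-reduce
    # enqueued ahead of the large matmul below, so the two can overlap.
    _ = torch.empty(1, device=grad_output.device) + 1

    # dL/dW is computed while the all-reduce is in flight.
    grad_weight = total_input.t().matmul(grad_output)

    handle.wait()
    return grad_input, grad_weight
```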
7 changes: 4 additions & 3 deletions colossalai/shardformer/layer/normalization.py
@@ -7,6 +7,7 @@

from colossalai.lazy import LazyInitContext

from ._operation import hook_paramter_in_backward
from .utils import SeqParallelUtils

__all__ = ["FusedLayerNorm", "FusedRMSNorm", "LayerNorm", "RMSNorm", "BaseLayerNorm"]
@@ -28,7 +29,7 @@ def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):

def forward(self, input):
output = super().forward(input)
# output = hook_paramter_in_backward(output, self.weight, self.bias)
output = hook_paramter_in_backward(output, self.weight, self.bias)
return output

class FusedRMSNormWithHook(ApexFusedRMSNorm):
@@ -37,7 +38,7 @@ def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):

def forward(self, input):
output = super().forward(input)
# output = hook_paramter_in_backward(output, self.weight)
output = hook_paramter_in_backward(output, self.weight)
return output

except ImportError:
@@ -78,7 +79,7 @@ def __init__(self, hidden_size, eps=0.00001):

def forward(self, input):
output = super().forward(input)
# output = hook_paramter_in_backward(output, self.weight, self.bias)
output = hook_paramter_in_backward(output, self.weight, self.bias)
return output


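Note: re-enabling `hook_paramter_in_backward` threads the fused norm's weight (and bias) back through the autograd graph of the output, so frameworks that rely on seeing parameters during backward (e.g. Gemini) still observe them even though the fused kernel owns the actual gradient computation. The real helper lives in `colossalai/shardformer/layer/_operation.py`; the sketch below only illustrates the general pass-through idea and is not the library implementation:

```python
import torch

class _HookParamInBackward(torch.autograd.Function):
    """Illustrative stand-in: forward is an identity on the layer output, but the
    parameters are carried along so they are touched again during backward and any
    machinery hooked on them gets a chance to fire."""

    @staticmethod
    def forward(ctx, output, *params):
        ctx.save_for_backward(*params)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Touch the saved parameters with a no-op view so parameter-level hooks
        # (such as Gemini's '__torch_function__' interception) see them in backward.
        for param in ctx.saved_tensors:
            if param is not None:
                param.view(param.shape)
        # The fused kernel already produces the real parameter gradients.
        return (grad_output,) + (None,) * len(ctx.saved_tensors)

def hook_params_in_backward(output, *params):
    return _HookParamInBackward.apply(output, *params)
```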
17 changes: 7 additions & 10 deletions colossalai/shardformer/layer/qkv_fused_linear.py
@@ -320,8 +320,9 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
else:
# Set up backprop all-reduce.
input_parallel = reduce_backward(input_, self.process_group)
input_parallel = input_
output_parallel = matmul_with_async_comm(input_parallel, self.weight, bias, self.process_group, False)
output_parallel = matmul_with_async_comm(
input_parallel, self.weight, bias, self.process_group, self.async_communication
)

if self.gather_output:
# All-gather across the partitions.
@@ -330,8 +331,7 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
output = output_parallel

if self.skip_bias_add:
# return output, self.bias
return output
return output, self.bias
else:
return output

@@ -528,8 +528,7 @@ def forward(self, input_: Tensor) -> Tensor:
handle.wait()
output = torch.cat(output_parallel_list, dim=-1)
else:
# output_parallel = torch.matmul(input_, self.weight)
output_parallel = matmul_with_async_comm(input_, self.weight, None, self.process_group, False)
output_parallel = torch.matmul(input_, self.weight)
if self.seq_parallel:
output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group, 1)
else:
@@ -540,8 +539,7 @@ def forward(self, input_: Tensor) -> Tensor:
output = output + self.bias
return output
else:
# return output, self.bias
return output
return output, self.bias


# ====================================
@@ -736,7 +734,6 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
output = output_parallel

if self.skip_bias_add:
# return output, self.bias
return output
return output, self.bias
else:
return output
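Note: the `skip_bias_add` paths are restored to return `(output, self.bias)` instead of dropping the bias, so the caller can fold the bias into a later fused kernel rather than paying for a separate elementwise add. A toy sketch of why that contract is useful; `fused_dense_gelu` is an illustrative name, not the library API:

```python
import torch
import torch.nn.functional as F

def fused_dense_gelu(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor):
    # Linear layer with the bias skipped (what skip_bias_add hands back separately)...
    output = x.matmul(weight.t())
    # ...so the bias can be applied inside the "fused" epilogue together with GELU.
    return F.gelu(output + bias)

x = torch.randn(4, 8)
w = torch.randn(16, 8)
b = torch.randn(16)
y = fused_dense_gelu(x, w, b)
print(y.shape)  # torch.Size([4, 16])
```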
21 changes: 9 additions & 12 deletions colossalai/shardformer/modeling/gpt2.py
@@ -330,13 +330,12 @@ def gpt2_lmhead_model_forward(
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, shift_logits.size(-1))
shift_labels = shift_labels.view(-1)
# if shard_config.enable_tensor_parallelism:
# loss = cross_entropy_1d(
# shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
# )
# else:
# loss = loss_fct(shift_logits, shift_labels)
loss = loss_fct(shift_logits, shift_labels)
if shard_config.enable_tensor_parallelism:
loss = cross_entropy_1d(
shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
)
else:
loss = loss_fct(shift_logits, shift_labels)

if not return_dict:
output = (lm_logits,) + outputs[1:]
@@ -727,7 +726,7 @@ def gpt2_for_sequence_classification_forward(
)


def get_gpt2_flash_attention_forward(shard_config: ShardConfig):
def get_gpt2_flash_attention_forward():
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention

from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
@@ -778,12 +777,10 @@ def forward(
else:
present = None

flash_attention_mask = None
if not self.is_cross_attention:
attn_mask_type = AttnMaskType.causal
else:
attn_mask_type = None
if not getattr(shard_config, "causal_lm", False) and attention_mask != None:
flash_attention_mask = None
if attention_mask != None:
flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
if not torch.all(flash_attention_mask):
if attn_mask_type == AttnMaskType.causal:
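Note: restoring `cross_entropy_1d` means the LM loss is computed directly on vocab-sharded logits when tensor parallelism is enabled, instead of gathering the full vocabulary dimension first. A simplified sketch of how such a distributed cross entropy can work (labels assumed valid, no `ignore_index` handling; argument names are illustrative, not the library signature):

```python
import torch
import torch.distributed as dist

def cross_entropy_over_sharded_logits(logits_shard, labels, vocab_start, process_group):
    """logits_shard: [N, V/tp] slice of the vocab dim held by this rank;
    vocab_start: first global vocab id owned by this rank."""
    # Global max over the vocab dimension for numerical stability.
    local_max = logits_shard.max(dim=-1).values
    dist.all_reduce(local_max, op=dist.ReduceOp.MAX, group=process_group)
    shifted = logits_shard - local_max.unsqueeze(-1)

    # Global log-sum-exp via an all-reduce of the local partial sums.
    sum_exp = shifted.exp().sum(dim=-1)
    dist.all_reduce(sum_exp, group=process_group)

    # Pick out the target logit only on the rank whose shard contains the label.
    vocab_size_local = logits_shard.size(-1)
    in_shard = (labels >= vocab_start) & (labels < vocab_start + vocab_size_local)
    local_idx = (labels - vocab_start).clamp(0, vocab_size_local - 1)
    target_logit = shifted.gather(-1, local_idx.unsqueeze(-1)).squeeze(-1) * in_shard
    dist.all_reduce(target_logit, group=process_group)

    # -log softmax(z)[y] = log(sum exp(z)) - z_y
    return (sum_exp.log() - target_logit).mean()
```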
54 changes: 27 additions & 27 deletions colossalai/shardformer/policies/gpt2.py
@@ -5,7 +5,12 @@

import colossalai.shardformer.layer as col_nn

from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward, gpt2_sequence_parallel_forward_fn
from ..modeling.gpt2 import (
GPT2PipelineForwards,
get_gpt2_flash_attention_forward,
get_lm_forward_with_dist_cross_entropy,
gpt2_sequence_parallel_forward_fn,
)
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = [
@@ -53,10 +58,10 @@ def module_policy(self):
suffix="wte",
target_module=col_nn.VocabParallelEmbedding1D,
),
# SubModuleReplacementDescription(
# suffix="drop",
# target_module=col_nn.DropoutForParallelInput,
# ),
SubModuleReplacementDescription(
suffix="drop",
target_module=col_nn.DropoutForParallelInput,
),
]
)

@@ -82,30 +87,25 @@ def module_policy(self):
SubModuleReplacementDescription(
suffix="mlp.c_fc",
target_module=col_nn.GPT2FusedLinearConv1D_Col,
kwargs={
"n_fused": 1,
"seq_parallel": use_sequence_parallel,
"overlap": overlap,
"skip_bias_add": True,
},
kwargs={"n_fused": 1, "seq_parallel": use_sequence_parallel, "overlap": overlap},
),
SubModuleReplacementDescription(
suffix="mlp.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Row,
kwargs={"seq_parallel": use_sequence_parallel, "skip_bias_add": True},
kwargs={"seq_parallel": use_sequence_parallel},
),
SubModuleReplacementDescription(
suffix="attn.attn_dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="attn.resid_dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="mlp.dropout",
target_module=col_nn.DropoutForParallelInput,
),
# SubModuleReplacementDescription(
# suffix="attn.attn_dropout",
# target_module=col_nn.DropoutForParallelInput,
# ),
# SubModuleReplacementDescription(
# suffix="attn.resid_dropout",
# target_module=col_nn.DropoutForParallelInput,
# ),
# SubModuleReplacementDescription(
# suffix="mlp.dropout",
# target_module=col_nn.DropoutForParallelInput,
# ),
],
)

@@ -145,7 +145,7 @@ def module_policy(self):
if self.shard_config.enable_flash_attention:
self.append_or_create_method_replacement(
description={
"forward": get_gpt2_flash_attention_forward(self.shard_config),
"forward": get_gpt2_flash_attention_forward(),
},
policy=policy,
target_key=GPT2Attention,
@@ -269,10 +269,10 @@ def module_policy(self):
GPT2LMHeadModel: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": False}
)
],
# method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
)
}
module_policy.update(addon_module)
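Note: alongside re-enabling the dropout replacements, the policy switches `lm_head` back to `gather_output=False` and restores `get_lm_forward_with_dist_cross_entropy`, so the vocab-sharded head output feeds the distributed loss directly. For readers unfamiliar with how these descriptions are consumed, here is a simplified sketch of applying suffix-based submodule replacements; the classes and the plain constructor call are stand-ins (ShardFormer builds replacements via `from_native_module`), not the real implementation:

```python
from dataclasses import dataclass, field
from typing import Dict, List, Type

import torch.nn as nn

@dataclass
class SubModuleReplacement:
    """Simplified stand-in for SubModuleReplacementDescription."""
    suffix: str
    target_module: Type[nn.Module]
    kwargs: Dict = field(default_factory=dict)

def apply_replacements(model: nn.Module, replacements: List[SubModuleReplacement]) -> None:
    # Collect matches first so the module tree is not mutated while iterating it.
    matches = [
        (name, repl)
        for name, _ in model.named_modules()
        for repl in replacements
        if name and name.endswith(repl.suffix)
    ]
    for name, repl in matches:
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        # Placeholder construction; the real policy converts the native module in place.
        setattr(parent, child_name, repl.target_module(**repl.kwargs))
```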
32 changes: 3 additions & 29 deletions examples/language/gpt/hybridparallelism/benchmark.py
@@ -131,11 +131,7 @@ def empty_init():
pp_style=args.pp_style,
zero_stage=args.zero,
num_model_chunks=args.num_model_chunks,
# enable_all_optimization=True,
# enable_flash_attention=True,
# enable_jit_fused=True,
enable_fused_normalization=True,
# enable_sequence_parallelism=True,
enable_all_optimization=True,
num_microbatches=args.mbs,
cpu_offload=args.cpu_offload,
precision="bf16",
@@ -176,18 +172,14 @@ def empty_init():
else nullcontext()
)

# with init_ctx:
# model = GPT2LMHeadModel(config)
model = GPT2LMHeadModel(config)
with init_ctx:
model = GPT2LMHeadModel(config)

if args.grad_checkpoint:
model.gradient_checkpointing_enable()

model_numel = get_model_numel(model)
coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
# print("args.ignore_steps", args.ignore_steps)
# print("args.batch_size", args.batch_size)
# print("max_length", args.max_length)
performance_evaluator = PerformanceEvaluator(
model_numel,
model.config.n_layer,
@@ -226,24 +218,6 @@ def empty_init():
optimizer.step()
optimizer.zero_grad()
performance_evaluator.on_step_end(**batch)

# for step, batch in enumerate(tqdm(dataloader, desc="Step", disable=not coordinator.is_master())):
# performance_evaluator.on_step_start(step)

# with torch.profiler.profile(
# activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
# schedule=torch.profiler.schedule(wait=1, warmup=2, active=3, repeat=5),
# on_trace_ready=torch.profiler.tensorboard_trace_handler("/home/jiangmingyan/workspace/trace/shardformer/GPT2-12-bf16"),
# with_stack=True,
# record_shapes=True
# ) as prof:
# for _ in range(0 + 2 + 5):
# outputs = model(**batch)
# loss = outputs[0]
# booster.backward(loss, optimizer)
# optimizer.step()
# optimizer.zero_grad()
# prof.step()
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")

performance_evaluator.on_fit_end()
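Note: moving `GPT2LMHeadModel(config)` back under `init_ctx` (and turning `enable_all_optimization` back on) restores the intended pattern: construct the model inside a lazy-init context so full weights are not materialized on every rank before the booster shards them. A sketch of that pattern, assuming `LazyInitContext` from `colossalai.lazy`; the fallback to `nullcontext` only keeps the sketch importable without ColossalAI:

```python
from contextlib import nullcontext

from transformers import GPT2Config, GPT2LMHeadModel

def build_model(config: GPT2Config, use_lazy_init: bool = True) -> GPT2LMHeadModel:
    try:
        from colossalai.lazy import LazyInitContext
        init_ctx = LazyInitContext() if use_lazy_init else nullcontext()
    except ImportError:  # sketch stays runnable without ColossalAI installed
        init_ctx = nullcontext()
    # Parameters created inside the context are deferred until the booster
    # (e.g. HybridParallelPlugin) decides how to shard and materialize them.
    with init_ctx:
        model = GPT2LMHeadModel(config)
    return model

model = build_model(GPT2Config(n_layer=2, n_head=4, n_embd=128))
```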
1 change: 0 additions & 1 deletion examples/language/performance_evaluator.py
@@ -107,7 +107,6 @@ def on_step_end(self, input_ids: Tensor, **kwargs) -> None:

def on_fit_end(self) -> None:
avg_duration = all_reduce_mean(self.timer.duration, self.coordinator.world_size)
# avg_duration = self.timer.duration
avg_throughput = self.num_samples * self.dp_world_size / (avg_duration + 1e-12)
mp_world_size = self.coordinator.world_size // self.dp_world_size
avg_tflops_per_gpu_megatron = self.flop_megatron / 1e12 / (avg_duration + 1e-12) / mp_world_size
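Note: switching back to `all_reduce_mean(self.timer.duration, ...)` makes the reported throughput use the step duration averaged over all ranks rather than one rank's local timer. A minimal sketch of such a helper, assuming `torch.distributed` is initialized with a GPU backend:

```python
import torch
import torch.distributed as dist

def all_reduce_mean(value: float, world_size: int) -> float:
    # Sum the scalar across ranks, then divide by world size to get the mean.
    tensor = torch.tensor(value, dtype=torch.float32, device="cuda")
    dist.all_reduce(tensor)
    return tensor.item() / world_size

# Downstream use mirrors the evaluator (1e-12 guards against division by zero):
#   avg_throughput = num_samples * dp_world_size / (avg_duration + 1e-12)
```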
