Commit e1402fd: fix

flybird11111 committed Feb 23, 2024
1 parent e2aa82e

Showing 9 changed files with 110 additions and 69 deletions.
3 changes: 3 additions & 0 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1033,6 +1033,7 @@ def __init__(
enable_sequence_parallelism=enable_sequence_parallelism,
enable_sequence_overlap=enable_sequence_overlap,
)
print("self.shard_config", self.shard_config)
self.amp_config = dict(
initial_scale=initial_scale,
growth_factor=growth_factor,
@@ -1058,6 +1059,7 @@ def __init__(
overlap_communication=overlap_communication,
cpu_offload=cpu_offload,
partition_grad=(self.zero_stage == 2),
forced_dtype=torch.bfloat16,
)

self.max_norm = max_norm
@@ -1099,6 +1101,7 @@ def configure(
param_info = get_param_info(optimizer)
if not isinstance(model, ModelWrapper):
use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
print("use_ddp", use_ddp)
model = HybridParallelModule(
model,
precision=self.precision,
22 changes: 12 additions & 10 deletions colossalai/shardformer/layer/_operation.py
@@ -9,6 +9,7 @@

try:
import fused_weight_gradient_mlp_cuda

_grad_accum_fusion_available = True
except ImportError:
_grad_accum_fusion_available = False
@@ -77,10 +78,11 @@ def backward(ctx, grad_output):
use_bias = ctx.use_bias

# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias.
weight = weight.view(weight.shape)
bias = bias.view(bias.shape)
# weight = weight.view(weight.shape)
# bias = bias.view(bias.shape)

total_input = input
# print("grad_output.shape", grad_output.shape, "weight.shape", weight.shape)
grad_input = grad_output.matmul(weight.T)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
@@ -93,9 +95,10 @@ def backward(ctx, grad_output):
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# all-reduce scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
# _ = torch.empty(1, device=grad_output.device) + 1

grad_weight = total_input.t().matmul(grad_output)
# print("use_biasuse_biasuse_biasuse_biasuse_bias",use_bias)
grad_bias = grad_output.sum(dim=0) if use_bias else None

if ctx.async_grad_allreduce:
@@ -115,7 +118,6 @@ def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce):
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_allreduce = async_grad_allreduce

if bias is not None:
output = F.linear(input_, weight, bias)
else:
@@ -129,8 +131,8 @@ def backward(ctx, grad_output):
use_bias = ctx.use_bias

# In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias.
if use_bias:
bias.view(bias.shape)
# if use_bias:
# bias.view(bias.shape)

total_input = input
grad_input = grad_output.matmul(weight)
@@ -145,7 +147,7 @@ def backward(ctx, grad_output):
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# all-reduce scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
# _ = torch.empty(1, device=grad_output.device) + 1

if _grad_accum_fusion_available and weight.grad is not None:
grad = weight.grad
@@ -205,8 +207,8 @@ def backward(ctx, grad_output):
overlap = ctx.overlap

# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
if use_bias:
bias = bias.view(bias.shape)
# if use_bias:
# bias = bias.view(bias.shape)

if not overlap:
input_parallel = _gather(input_, dim, process_group)
@@ -431,7 +433,7 @@ def backward(ctx, grad_output):
input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
# calculate gradient
if len(input_parallel.shape) > 2:
input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
grad_weight = input_parallel.t().matmul(grad_output)
# wait until reduce-scatter finished
reducescatter_handle.wait()
7 changes: 3 additions & 4 deletions colossalai/shardformer/layer/normalization.py
@@ -7,7 +7,6 @@

from colossalai.lazy import LazyInitContext

from ._operation import hook_paramter_in_backward
from .utils import SeqParallelUtils

__all__ = ["FusedLayerNorm", "FusedRMSNorm", "LayerNorm", "RMSNorm", "BaseLayerNorm"]
@@ -29,7 +28,7 @@ def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):

def forward(self, input):
output = super().forward(input)
output = hook_paramter_in_backward(output, self.weight, self.bias)
# output = hook_paramter_in_backward(output, self.weight, self.bias)
return output

class FusedRMSNormWithHook(ApexFusedRMSNorm):
@@ -38,7 +37,7 @@ def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):

def forward(self, input):
output = super().forward(input)
output = hook_paramter_in_backward(output, self.weight)
# output = hook_paramter_in_backward(output, self.weight)
return output

except ImportError:
@@ -79,7 +78,7 @@ def __init__(self, hidden_size, eps=0.00001):

def forward(self, input):
output = super().forward(input)
output = hook_paramter_in_backward(output, self.weight, self.bias)
# output = hook_paramter_in_backward(output, self.weight, self.bias)
return output


17 changes: 10 additions & 7 deletions colossalai/shardformer/layer/qkv_fused_linear.py
@@ -320,9 +320,8 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
else:
# Set up backprop all-reduce.
input_parallel = reduce_backward(input_, self.process_group)
output_parallel = matmul_with_async_comm(
input_parallel, self.weight, bias, self.process_group, self.async_communication
)
input_parallel = input_
output_parallel = matmul_with_async_comm(input_parallel, self.weight, bias, self.process_group, False)

if self.gather_output:
# All-gather across the partitions.
@@ -331,7 +330,8 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
output = output_parallel

if self.skip_bias_add:
return output, self.bias
# return output, self.bias
return output
else:
return output

@@ -528,7 +528,8 @@ def forward(self, input_: Tensor) -> Tensor:
handle.wait()
output = torch.cat(output_parallel_list, dim=-1)
else:
output_parallel = torch.matmul(input_, self.weight)
# output_parallel = torch.matmul(input_, self.weight)
output_parallel = matmul_with_async_comm(input_, self.weight, None, self.process_group, False)
if self.seq_parallel:
output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group, 1)
else:
@@ -539,7 +540,8 @@ def forward(self, input_: Tensor) -> Tensor:
output = output + self.bias
return output
else:
return output, self.bias
# return output, self.bias
return output


# ====================================
@@ -734,6 +736,7 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
output = output_parallel

if self.skip_bias_add:
return output, self.bias
# return output, self.bias
return output
else:
return output
13 changes: 7 additions & 6 deletions colossalai/shardformer/modeling/gpt2.py
@@ -330,12 +330,13 @@ def gpt2_lmhead_model_forward(
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, shift_logits.size(-1))
shift_labels = shift_labels.view(-1)
if shard_config.enable_tensor_parallelism:
loss = cross_entropy_1d(
shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
)
else:
loss = loss_fct(shift_logits, shift_labels)
# if shard_config.enable_tensor_parallelism:
# loss = cross_entropy_1d(
# shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
# )
# else:
# loss = loss_fct(shift_logits, shift_labels)
loss = loss_fct(shift_logits, shift_labels)

if not return_dict:
output = (lm_logits,) + outputs[1:]
52 changes: 25 additions & 27 deletions colossalai/shardformer/policies/gpt2.py
@@ -5,12 +5,7 @@

import colossalai.shardformer.layer as col_nn

from ..modeling.gpt2 import (
GPT2PipelineForwards,
get_gpt2_flash_attention_forward,
get_lm_forward_with_dist_cross_entropy,
gpt2_sequence_parallel_forward_fn,
)
from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward, gpt2_sequence_parallel_forward_fn
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = [
@@ -58,10 +53,10 @@ def module_policy(self):
suffix="wte",
target_module=col_nn.VocabParallelEmbedding1D,
),
SubModuleReplacementDescription(
suffix="drop",
target_module=col_nn.DropoutForParallelInput,
),
# SubModuleReplacementDescription(
# suffix="drop",
# target_module=col_nn.DropoutForParallelInput,
# ),
]
)

@@ -87,27 +82,30 @@ def module_policy(self):
SubModuleReplacementDescription(
suffix="mlp.c_fc",
target_module=col_nn.GPT2FusedLinearConv1D_Col,
kwargs={"n_fused": 1, "seq_parallel": use_sequence_parallel, "overlap": overlap},
),
SubModuleReplacementDescription(
suffix="mlp.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Row,
kwargs={
"n_fused": 1,
"seq_parallel": use_sequence_parallel,
"overlap": overlap,
"skip_bias_add": True,
},
),
SubModuleReplacementDescription(
suffix="attn.attn_dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="attn.resid_dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="mlp.dropout",
target_module=col_nn.DropoutForParallelInput,
suffix="mlp.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Row,
kwargs={"seq_parallel": use_sequence_parallel, "skip_bias_add": True},
),
# SubModuleReplacementDescription(
# suffix="attn.attn_dropout",
# target_module=col_nn.DropoutForParallelInput,
# ),
# SubModuleReplacementDescription(
# suffix="attn.resid_dropout",
# target_module=col_nn.DropoutForParallelInput,
# ),
# SubModuleReplacementDescription(
# suffix="mlp.dropout",
# target_module=col_nn.DropoutForParallelInput,
# ),
],
)

@@ -271,10 +269,10 @@ def module_policy(self):
GPT2LMHeadModel: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": False}
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}
)
],
method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
# method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
)
}
module_policy.update(addon_module)