Commit

fix
flybird11111 committed Apr 10, 2024
1 parent b570f1a commit f08e084
Showing 18 changed files with 118 additions and 203 deletions.
15 changes: 9 additions & 6 deletions colossalai/booster/plugin/gemini_plugin.py
@@ -45,20 +45,23 @@ def get_param_info(model: nn.Module, optim: Optimizer):
# Get a backup of necessary information of parameters for future use, which includes:
# 1. A mapping from integer param_id to param32 shape.

if optim is None:
return {}
param_info = {"id2shape": {}, "name2shape": {}}
for m_name, m_var in model.named_modules():
for p_name, p_var in m_var.named_parameters(recurse=False):
param_name = m_name + "." + p_name if m_name else p_name
original_shape = p_var.shape if isinstance(p_var, torch.Tensor) else None
param_info["name2shape"][param_name] = original_shape

if optim is None:
return param_info

start_index = 0
for group in optim.param_groups:
for param_id, param in enumerate(group["params"], start_index):
original_shape = param.shape if isinstance(param, torch.Tensor) else None
param_info["id2shape"][param_id] = original_shape

start_index += len(group["params"])
for name, param in model.named_parameters():
original_shape = param.shape if isinstance(param, torch.Tensor) else None
param_info["name2shape"][name] = original_shape
print("original_shape", original_shape)

return param_info
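For reference, the sketch below (a hypothetical toy model and optimizer, not part of this commit) shows the two mappings get_param_info builds after this change:

# Illustrative sketch only; names and shapes are hypothetical.
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optim = torch.optim.SGD(model.parameters(), lr=0.1)

param_info = {"id2shape": {}, "name2shape": {}}

# name2shape: fully qualified parameter name -> original (unsharded) shape
for m_name, m_var in model.named_modules():
    for p_name, p_var in m_var.named_parameters(recurse=False):
        param_name = m_name + "." + p_name if m_name else p_name
        param_info["name2shape"][param_name] = p_var.shape

# id2shape: flat integer id across optimizer param groups -> shape
start_index = 0
for group in optim.param_groups:
    for param_id, param in enumerate(group["params"], start_index):
        param_info["id2shape"][param_id] = param.shape
    start_index += len(group["params"])

print(param_info["name2shape"])  # {'weight': torch.Size([2, 4]), 'bias': torch.Size([2])}
print(param_info["id2shape"])    # {0: torch.Size([2, 4]), 1: torch.Size([2])}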

26 changes: 6 additions & 20 deletions colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -234,15 +234,15 @@ def save_sharded_model(

# Devices along the same dp_group share the same copies of model.
# So only let the device with dp_rank == 0 save the model.
# if self.dp_rank != 0:
# return
if self.dp_rank != 0:
return

# Then collect the sharded parameters & buffers along tp_group.
# Only devices with tp_rank == 0 are responsible for model saving.
state_dict_shard = HybridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard)
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
index_file = CheckpointIndexFile(checkpoint)
control_saving = self.tp_rank == 0 and self.dp_rank == 0
control_saving = self.tp_rank == 0

if self.pp_size == 1:
# When pipeline is not used, save the model shards as in general checkpointIO
@@ -288,7 +288,7 @@ def save_sharded_model(
use_safetensors=use_safetensors,
use_pp_format=True,
)
dist.barrier(self.pp_group)

if control_saving:
assert (
self.dp_rank == 0 and self.tp_rank == 0
@@ -298,6 +298,8 @@ def save_sharded_model(
else:
return

dist.barrier(self.pp_group)

# The global master rank integrates the index files and clean the folder.
if self.pp_rank == 0:
final_index_file = CheckpointIndexFile(checkpoint)
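Moving dist.barrier(self.pp_group) after the control_saving block appears to enforce a write-before-merge ordering across pipeline stages. A hedged standalone sketch (stub helper names are hypothetical, not ColossalAI API):

import torch.distributed as dist

def write_stage_index_file(pp_rank: int) -> None:
    ...  # hypothetical: this stage persists its own shard index file

def merge_stage_index_files() -> None:
    ...  # hypothetical: combine the per-stage index files into the final one

def save_then_merge(pp_rank: int, pp_group) -> None:
    write_stage_index_file(pp_rank)
    dist.barrier(pp_group)      # no stage starts merging before every stage has written
    if pp_rank == 0:
        merge_stage_index_files()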
@@ -682,14 +684,6 @@ def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dten
else:
# When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict.
state_dict_list = [None for _ in range(self.pp_size)]
print(
"barrier state dicts",
(
torch.distributed.get_rank(self.dp_group),
torch.distributed.get_rank(self.pp_group),
torch.distributed.get_rank(self.tp_group),
),
)
dist.barrier(self.pp_group)
dist.all_gather_object(state_dict_list, state_dict, self.pp_group)

@@ -698,14 +692,6 @@ def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dten
complete_state_dict = dict()
for _state_dict in state_dict_list:
complete_state_dict.update(_state_dict)
print(
"before save_state_dict",
(
torch.distributed.get_rank(self.dp_group),
torch.distributed.get_rank(self.pp_group),
torch.distributed.get_rank(self.tp_group),
),
)
save_state_dict(complete_state_dict, checkpoint, use_safetensors)

def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = False):
12 changes: 6 additions & 6 deletions colossalai/checkpoint_io/utils.py
@@ -108,14 +108,14 @@ def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Siz
"""
partition_dim = None
for dim, length in enumerate(original_shape):
if length != current_shape[dim]:
if length > current_shape[dim]:
partition_dim = dim
break
# if partition_dim is not None:
# assert (
# original_shape[partition_dim] == tp_size * current_shape[partition_dim]
# ), f"The parameter isn't evenly distributed among tensor parallel group: \
# shape before sharding {original_shape}, shape after sharding {current_shape}"
if partition_dim is not None:
assert (
original_shape[partition_dim] == tp_size * current_shape[partition_dim]
), f"The parameter isn't evenly distributed among tensor parallel group: \
shape before sharding {original_shape}, shape after sharding {current_shape}"

return partition_dim
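A worked example (hypothetical shapes and tp_size, mirroring the logic above) of how the stricter > comparison and the restored assertion behave:

import torch

def search_tp_partition_dim(current_shape, original_shape, tp_size):
    partition_dim = None
    for dim, length in enumerate(original_shape):
        if length > current_shape[dim]:  # only a dimension that shrank can be the TP-sharded one
            partition_dim = dim
            break
    if partition_dim is not None:
        assert original_shape[partition_dim] == tp_size * current_shape[partition_dim], (
            "The parameter isn't evenly distributed among tensor parallel group"
        )
    return partition_dim

# A (50304, 768) weight sharded over tp_size=4 along dim 0 appears locally as (12576, 768).
print(search_tp_partition_dim(torch.Size([12576, 768]), torch.Size([50304, 768]), 4))  # 0
# A dimension that grew (e.g. a padded vocab) is no longer mistaken for the sharded one.
print(search_tp_partition_dim(torch.Size([50368, 768]), torch.Size([50304, 768]), 4))  # None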

23 changes: 7 additions & 16 deletions colossalai/shardformer/layer/embedding.py
@@ -21,15 +21,13 @@
)

from ._operation import gather_forward_split_backward, reduce_forward
from .parallel_module import PaddingParallelModule
from .parallel_module import PaddingParallelModule, ParallelModule
from .utils import create_randomizer_with_offset

_EXTRA_STATE_KEY_SUFFIX = "_extra_state"

__all__ = ["Embedding1D", "VocabParallelEmbedding1D", "PaddingEmbedding"]


class Embedding1D(PaddingParallelModule):
class Embedding1D(ParallelModule):
r"""Embedding for 1D parallelism.
Args:
@@ -73,9 +71,12 @@ def __init__(
*args,
**kwargs,
):
super().__init__()

self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.process_group = process_group

self.padding_idx = padding_idx
self.embed_args = args
self.embed_kwargs = kwargs
@@ -88,12 +89,10 @@ def __init__(
# Parameters.
if weight is None:
factory_kwargs = {"device": device, "dtype": dtype}
weight = nn.Parameter(torch.empty((num_embeddings, embedding_dim), **factory_kwargs))
self.weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
else:
weight.data = weight.data.to(device=device, dtype=dtype)

super(Embedding1D, self).__init__(num_embeddings, num_embeddings, embedding_dim, weight)

self.weight = weight
if not is_distributed_tensor(self.weight):
sharded_weight = shard_colwise(self.weight.data, process_group)
sharded_tensor_to_existing_param(sharded_weight, self.weight)
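A conceptual sketch (plain PyTorch with toy sizes, not ColossalAI's d_tensor API) of the column-wise weight sharding Embedding1D relies on: each tensor-parallel rank keeps a slice of the embedding dimension and the partial outputs are gathered across the TP group in forward.

import torch
import torch.nn.functional as F

num_embeddings, embedding_dim, tp_size, tp_rank = 10, 8, 2, 0
full_weight = torch.randn(num_embeddings, embedding_dim)

# Each rank holds embedding_dim // tp_size columns of the embedding table.
local_weight = full_weight.chunk(tp_size, dim=1)[tp_rank]

tokens = torch.tensor([[1, 3, 5]])
partial_out = F.embedding(tokens, local_weight)   # shape (1, 3, embedding_dim // tp_size)
print(partial_out.shape)                          # torch.Size([1, 3, 4])
# An all-gather along the last dim across the TP group would rebuild (1, 3, 8),
# which is what gather_forward_split_backward does in the real layer.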
@@ -322,11 +321,6 @@ def __init__(
if weight is None:
self.reset_parameters(weight_initializer)

print(
f"embedding self.weight{self.num_embeddings} {self.old_num_embeddings}{dist.get_rank(self.process_group)}, bias{self.bias}",
self.weight.shape,
)

@staticmethod
def from_native_module(
module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
@@ -346,8 +340,6 @@ def from_native_module(
assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
process_group = process_group[0]

make_vocab_size_divisible_by = kwargs.pop("make_vocab_size_divisible_by", 128)

# create the parallel module
vocab_embedding_1d = VocabParallelEmbedding1D(
num_embeddings=num_embeddings,
@@ -356,7 +348,6 @@ def from_native_module(
device=device,
process_group=process_group,
weight=module.weight,
make_vocab_size_divisible_by=make_vocab_size_divisible_by,
*args,
**kwargs,
)
4 changes: 0 additions & 4 deletions colossalai/shardformer/layer/loss.py
@@ -115,10 +115,6 @@ def backward(ctx, grad_output):
grad_logits_2d = grad_logits.view(-1, partion_vocab_size)

update = 1.0 - mask.view(-1).float()
print("masked_target_1d", masked_target_1d.dtype)
print("grad_logits_2d", grad_logits_2d.dtype)
print("update", update.dtype)
grad_logits_2d = grad_logits_2d.float()
grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update

grad_logits.mul_(grad_output.unsqueeze(dim=-1))
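The index-subtract above implements the usual softmax-cross-entropy gradient restricted to this rank's vocabulary partition; a hedged toy illustration (hypothetical sizes):

# For softmax cross entropy, d(loss)/d(logit) = softmax(logits) - one_hot(target);
# the "- 1" is applied only where the target falls inside this rank's vocab partition
# (mask == 1 marks out-of-partition targets, so update = 1 - mask).
import torch

grad_logits_2d = torch.softmax(torch.randn(4, 6), dim=-1)  # stand-in for the saved softmax
masked_target_1d = torch.tensor([2, 0, 5, 0])              # local target indices (0 where masked out)
mask = torch.tensor([0.0, 1.0, 0.0, 0.0])                  # row 1's target lives on another rank
update = 1.0 - mask
grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update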
70 changes: 0 additions & 70 deletions colossalai/shardformer/modeling/gpt2.py
@@ -30,76 +30,6 @@
logger = logging.get_logger(__name__)


def _get_attention_mask(
self: GPT2Model,
shard_config: ShardConfig,
hidden_states: torch.Tensor,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]],
attention_mask: Optional[torch.FloatTensor],
encoder_hidden_states: Optional[torch.Tensor],
encoder_attention_mask: Optional[torch.FloatTensor],
) -> Tuple[Optional[Union[torch.Tensor, dict]], Optional[Union[torch.Tensor, dict]]]:
batch_size, seq_len = hidden_states.shape[:2]
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.add_cross_attention and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
if shard_config.enable_flash_attention:
encoder_attention_mask = ColoAttention.prepare_attn_kwargs(
(encoder_batch_size, 1, seq_len, encoder_sequence_length),
dtype=hidden_states.dtype,
dtype2=encoder_hidden_states.dtype,
q_padding_mask=attention_mask,
kv_padding_mask=encoder_attention_mask,
)
else:
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=encoder_hidden_states.device)
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
if shard_config.enable_flash_attention:
encoder_attention_mask = {"attention_mask": None}
else:
encoder_attention_mask = None
# GPT2Attention mask.
past_key_values_length = 0
if past_key_values is not None and past_key_values[0] is not None:
past_key_values_length = past_key_values[0][0].shape[2]
if shard_config.enable_flash_attention:
if attention_mask is not None:
attention_mask = attention_mask.view(batch_size, -1)
attention_mask = ColoAttention.prepare_attn_kwargs(
(batch_size, 1, seq_len, seq_len + past_key_values_length),
hidden_states.dtype,
hidden_states.device,
attention_mask,
is_causal=True,
)
elif attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
attention_mask = attention_mask.view(batch_size, -1)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask = attention_mask[:, None, None, :]

# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and the dtype's smallest value for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
return attention_mask, encoder_attention_mask
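For context on the non-flash branch of this apparently duplicated definition being removed, a toy illustration (hypothetical dtype and mask) of the 2D-to-4D additive-mask conversion it performs:

import torch

dtype = torch.float16
attention_mask = torch.tensor([[1, 1, 0]])               # (batch, to_seq_length), 1.0 = attend
mask_4d = attention_mask[:, None, None, :].to(dtype)      # (batch, 1, 1, to_seq_length)
mask_4d = (1.0 - mask_4d) * torch.finfo(dtype).min
print(mask_4d)  # attended positions stay ~0, the masked position becomes -65504 (float16 lowest)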


logger = logging.get_logger(__name__)


def _get_attention_mask(
self: GPT2Model,
shard_config: ShardConfig,
9 changes: 9 additions & 0 deletions colossalai/shardformer/policies/base_policy.py
@@ -195,3 +195,12 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
List[Dict[int, Tensor]]: List of parameters that should be shared across stages. E.g. [{0: module.model.embed_tokens.weight, 3: module.lm_head.weight}]
"""
return []

def tie_weight_check(self):
input_embedding = self.model.get_input_embeddings()
output_embedding = self.model.get_output_embeddings()
return (
input_embedding is not None
and output_embedding is not None
and id(input_embedding.weight) == id(output_embedding.weight)
)
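A hedged usage sketch of the new tie_weight_check helper (assumes the transformers package is available; GPT-2 ties its input and output embeddings by default):

from transformers import GPT2LMHeadModel  # assumption: transformers is installed

model = GPT2LMHeadModel.from_pretrained("gpt2")
input_embedding = model.get_input_embeddings()
output_embedding = model.get_output_embeddings()
print(
    input_embedding is not None
    and output_embedding is not None
    and id(input_embedding.weight) == id(output_embedding.weight)
)  # True, since GPT-2's lm_head shares its weight with the token embedding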
17 changes: 15 additions & 2 deletions colossalai/shardformer/policies/bert.py
@@ -52,11 +52,12 @@ def module_policy(self):

policy = {}

embedding_cls = None
if self.shard_config.enable_tensor_parallelism:
col_nn.VocabParallelEmbedding1D
embedding_cls = col_nn.VocabParallelEmbedding1D
else:
if self.tie_weight:
col_nn.PaddingEmbedding
embedding_cls = col_nn.PaddingEmbedding
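The hunk above fixes a silent no-op: the pre-fix code evaluated the class object as a bare expression instead of binding it to embedding_cls. A minimal self-contained illustration (placeholder class, not the real col_nn module):

class VocabParallelEmbedding1D:  # stand-in for col_nn.VocabParallelEmbedding1D
    pass

embedding_cls = None
enable_tensor_parallelism = True

if enable_tensor_parallelism:
    VocabParallelEmbedding1D                  # pre-fix: bare expression, nothing is assigned
print(embedding_cls)                          # None -> no embedding replacement gets registered

if enable_tensor_parallelism:
    embedding_cls = VocabParallelEmbedding1D  # post-fix: the binding actually takes effect
print(embedding_cls)                          # <class '__main__.VocabParallelEmbedding1D'>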

if self.shard_config.enable_fused_normalization:
norm_cls = col_nn.FusedLayerNorm
@@ -160,6 +161,18 @@ def module_policy(self):
target_key=BertModel,
)

if embedding_cls is not None:
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="word_embeddings",
target_module=embedding_cls,
)
],
policy=policy,
target_key=BertEmbeddings,
)

# optimization configuration
# Handle bert layer
self.append_or_create_submodule_replacement(
5 changes: 5 additions & 0 deletions colossalai/shardformer/policies/gpt2.py
@@ -10,6 +10,7 @@
GPT2PipelineForwards,
get_gpt2_flash_attention_forward,
get_gpt_model_forward_for_flash_attn,
get_lm_forward_with_dist_cross_entropy,
gpt2_sequence_parallel_forward_fn,
)
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@@ -315,6 +316,10 @@ def module_policy(self):
],
)
}
if self.shard_config.parallel_output:
addon_module[GPT2LMHeadModel].method_replacement = {
"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
}
else:
addon_module = {
GPT2LMHeadModel: ModulePolicyDescription(
15 changes: 5 additions & 10 deletions colossalai/shardformer/policies/llama.py
@@ -23,6 +23,7 @@
get_llama_model_forward_for_flash_attn,
get_llama_seq_parallel_attention_forward,
get_llama_seq_parallel_model_forward,
get_lm_forward_with_dist_cross_entropy,
)
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

@@ -184,16 +185,6 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
policy=policy,
target_key=LlamaModel,
)
else:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=PaddingEmbedding,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
),
policy=policy,
target_key=LlamaModel,
)

# optimization configuration
self.append_or_create_submodule_replacement(
@@ -355,6 +346,10 @@ def module_policy(self):
],
)
}
if self.shard_config.parallel_output:
new_item[LlamaForCausalLM].method_replacement = {
"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
}
else:
new_item = {
LlamaForCausalLM: ModulePolicyDescription(
1 change: 0 additions & 1 deletion colossalai/shardformer/shard/shard_config.py
@@ -45,7 +45,6 @@ class ShardConfig:
make_vocab_size_divisible_by: int = 64
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None
extra_kwargs: Dict[str, Any] = field(default_factory=dict)
make_vocab_size_divisible_by: int = 64
# pipeline_parallel_size: int
# data_parallel_size: int
# tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
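On the duplicate field removed above: repeating a field annotation in a dataclass is not an error, it simply re-binds the same field, so the second make_vocab_size_divisible_by line was redundant. A tiny hypothetical demonstration:

from dataclasses import dataclass, fields

@dataclass
class Cfg:
    make_vocab_size_divisible_by: int = 64
    make_vocab_size_divisible_by: int = 64  # duplicate annotation re-binds the same field

print([f.name for f in fields(Cfg)])  # ['make_vocab_size_divisible_by'] -- only one field survives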