
Commit

remove comments
Edenzzzz committed Jun 17, 2024
1 parent 20059ae commit a2e07c9
Showing 4 changed files with 15 additions and 17 deletions.
1 change: 0 additions & 1 deletion colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1015,7 +1015,6 @@ def __init__(
)
self.sp_size = 1
self.dp_size = dist.get_world_size() // (tp_size * pp_size)

elif self.sequence_parallelism_mode in ["all_to_all"]:
self.sp_size = 1 if sp_size is None else sp_size
self.dp_size = dist.get_world_size() // (self.sp_size * pp_size * tp_size)
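The all_to_all branch shown above carves the data-parallel group out of whatever ranks remain after sequence, pipeline, and tensor parallelism. A quick arithmetic sketch of that formula (the 8-rank world size and the specific tp/pp/sp values are illustrative assumptions, not taken from this commit):

# Hypothetical layout: 8 ranks, tp=2, pp=1, sp=2 (all_to_all mode).
world_size = 8
tp_size, pp_size, sp_size = 2, 1, 2
# Mirrors the plugin's formula: ranks not claimed by sp/pp/tp become data-parallel replicas.
dp_size = world_size // (sp_size * pp_size * tp_size)
assert dp_size == 2  # 2 * 1 * 2 * 2 == 8, so the decomposition covers every rank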
13 changes: 6 additions & 7 deletions colossalai/shardformer/modeling/llama.py
@@ -468,9 +468,10 @@ def forward(
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
if sp_mode is not None:
assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode"
assert (sp_size is not None) and (
@@ -519,8 +520,6 @@ def forward(
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
# if past_key_value is not None:
# kv_seq_len += past_key_value[0].shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
@@ -538,8 +537,6 @@
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

# past_key_value = (key_states, value_states) if use_cache else None

# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
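The tuple-style cache that the deleted comments referred to has been superseded by the transformers Cache API used in this hunk. A minimal sketch of that call pattern, assuming a recent transformers release that ships DynamicCache; the shapes are illustrative:

import torch
from transformers.cache_utils import DynamicCache

past_key_value = DynamicCache()
layer_idx = 0
# Per-layer key/value states: (batch, num_kv_heads, seq_len, head_dim)
key_states = torch.randn(1, 8, 16, 64)
value_states = torch.randn(1, 8, 16, 64)
# update() appends the new states for this layer and returns the full cached tensors,
# replacing the old `past_key_value = (key_states, value_states)` tuple convention.
key_states, value_states = past_key_value.update(key_states, value_states, layer_idx)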
@@ -619,8 +616,10 @@ def forward(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)

if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training:
if use_cache:
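The rewritten input check folds the two error cases into a single XOR test: it fires when both input_ids and inputs_embeds are given and when neither is. A standalone sketch of the same logic (the helper name and sample inputs are illustrative, not part of the commit):

def check_inputs(input_ids, inputs_embeds):
    # True in exactly the two bad cases:
    #   both None     -> True  ^ False = True
    #   both provided -> False ^ True  = True
    # and False whenever exactly one is supplied.
    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError(
            "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
        )

check_inputs(input_ids=[1, 2, 3], inputs_embeds=None)  # passes
# check_inputs(None, None) or check_inputs([1], some_embeds) would raise ValueError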
16 changes: 8 additions & 8 deletions colossalai/shardformer/policies/llama.py
@@ -90,14 +90,14 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
policy[attn_cls] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
)

self.append_or_create_method_replacement(
description={
"forward": get_llama_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
},
policy=policy,
target_key=attn_cls,
)
if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
self.append_or_create_method_replacement(
description={
"forward": get_llama_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
},
policy=policy,
target_key=attn_cls,
)
if self.pipeline_stage_manager is None:
self.append_or_create_method_replacement(
description={
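With this change the custom attention forward is only swapped in when a feature that needs it is enabled; otherwise the stock LlamaAttention.forward is left in place. A sketch of a plugin configuration that would exercise the new conditional path (values mirror the test config below and assume torch.distributed is already initialized, e.g. via colossalai.launch):

from colossalai.booster.plugin import HybridParallelPlugin

# enable_sequence_parallelism=True (or enable_flash_attention=True) is what makes the
# policy attach get_llama_flash_attention_forward to the attention class.
plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    sp_size=2,
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="all_to_all",
    enable_flash_attention=False,
)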
2 changes: 1 addition & 1 deletion tests/test_shardformer/test_model/test_shard_llama.py
@@ -174,7 +174,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"tp_size": 1,
"pp_size": 1,
"sp_size": 2,
"num_microbatches": 2,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "all_to_all",
"use_lazy_init": True,
