fix sp group reversed bug
Edenzzzz committed Jun 13, 2024
1 parent 932753f commit c92951d
Showing 4 changed files with 39 additions and 20 deletions.
colossalai/booster/plugin/hybrid_parallel_plugin.py (4 changes: 2 additions & 2 deletions)
@@ -1046,8 +1046,8 @@ def __init__(
             (
                 self.dp_axis,
                 self.pp_axis,
-                self.sp_axis,
                 self.tp_axis,
+                self.sp_axis,
             ) = (
                 0,
                 1,
@@ -1056,7 +1056,7 @@
             )
             self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
         else:
-            self.pp_axis, self.dp_axis, self.sp_axis, self.tp_axis = 0, 1, 2, 3
+            self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
             self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
 
         self.stage_manager = None
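The fix above matters because the process groups that ProcessGroupMesh later hands out are selected by axis index, so the named axes (tp_axis, sp_axis) have to be listed in the same order as the sizes passed to the mesh constructor. Below is a minimal, self-contained NumPy sketch of the failure mode, not the ColossalAI API; the mesh layout (pp, dp, tp, sp) follows the else-branch above, the sizes are made up, and groups_along is a hypothetical helper.

import numpy as np

# Ranks laid out in the order the else-branch passes sizes to ProcessGroupMesh:
# (pp, dp, tp, sp). The sizes here are illustrative only.
pp, dp, tp, sp = 1, 2, 2, 2
mesh = np.arange(pp * dp * tp * sp).reshape(pp, dp, tp, sp)

def groups_along(mesh, axis):
    # Ranks that differ only along `axis` form one communication group.
    moved = np.moveaxis(mesh, axis, -1)
    return moved.reshape(-1, mesh.shape[axis]).tolist()

print(groups_along(mesh, 3))  # sp_axis = 3 (correct):  [[0, 1], [2, 3], [4, 5], [6, 7]]
print(groups_along(mesh, 2))  # sp_axis = 2 (reversed): [[0, 2], [1, 3], [4, 6], [5, 7]]

With the labels reversed, ranks such as {0, 2} would have been treated as a sequence-parallel group even though they hold different tensor-parallel shards, which is the kind of mix-up the commit title describes.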
colossalai/shardformer/modeling/llama.py (25 changes: 19 additions & 6 deletions)
@@ -6,6 +6,7 @@
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers.cache_utils import Cache
 from transformers.modeling_attn_mask_utils import (
     _prepare_4d_causal_attention_mask,
     _prepare_4d_causal_attention_mask_for_sdpa,
@@ -577,10 +578,10 @@ def forward(
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
         if sp_mode is not None:
             assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode"
             assert (sp_size is not None) and (
@@ -625,15 +626,27 @@ def forward(
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
         kv_seq_len = key_states.shape[-2]
+        # if past_key_value is not None:
+        #     kv_seq_len += past_key_value[0].shape[-2]
         if past_key_value is not None:
-            kv_seq_len += past_key_value[0].shape[-2]
+            if self.layer_idx is None:
+                raise ValueError(
+                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                    "with a layer index."
+                )
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
+        # if past_key_value is not None:
+        #     # reuse k, v, self_attention
+        #     key_states = torch.cat([past_key_value[0], key_states], dim=2)
+        #     value_states = torch.cat([past_key_value[1], value_states], dim=2)
         if past_key_value is not None:
-            # reuse k, v, self_attention
-            key_states = torch.cat([past_key_value[0], key_states], dim=2)
-            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
         past_key_value = (key_states, value_states) if use_cache else None
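Alongside the group fix, this file moves the patched attention forward from the old (key, value) tuple cache to the transformers >= 4.36 Cache object, which is why the new import and the get_usable_length / update calls appear. A rough sketch of that API follows, using DynamicCache as a stand-in for whatever cache the caller passes in; the shapes are illustrative and the behavior may shift in later transformers releases.

import torch
from transformers.cache_utils import DynamicCache  # transformers >= 4.36

cache = DynamicCache()
layer_idx = 0

# Prefill: cache four tokens of key/value states for this layer.
k = torch.randn(1, 8, 4, 64)  # (batch, num_heads, seq_len, head_dim)
v = torch.randn(1, 8, 4, 64)
k_all, v_all = cache.update(k, v, layer_idx)   # returns the accumulated k/v so far
print(cache.get_usable_length(1, layer_idx))   # 4 previously cached tokens

# Decode: append one new token; update() hands back the concatenated k/v,
# replacing the manual torch.cat the commented-out code used to do.
k_new, v_new = torch.randn(1, 8, 1, 64), torch.randn(1, 8, 1, 64)
k_all, v_all = cache.update(k_new, v_new, layer_idx)
print(k_all.shape)                             # torch.Size([1, 8, 5, 64])

The cache_kwargs dict ({"sin": sin, "cos": cos}) that the patched forward passes along is ignored by DynamicCache but consumed by RoPE-aware cache variants such as SinkCache.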
colossalai/shardformer/policies/llama.py (2 changes: 0 additions & 2 deletions)
@@ -80,8 +80,6 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
         )
         sp_partial_derived = sp_mode in ["split_gather", "ring"]
 
-        self.shard_config.enable_flash_attention
-
         if sp_mode == "all_to_all":
             decoder_attribute_replacement = {
                 "num_heads": self.model.config.num_attention_heads // sp_size,
tests/test_shardformer/test_model/test_shard_llama.py (28 changes: 18 additions & 10 deletions)
@@ -160,25 +160,25 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         "initial_scale": 1,
     },
     {
-        "tp_size": 4,
+        "tp_size": 1,
         "pp_size": 1,
-        "num_microbatches": 1,
+        "sp_size": 2,
+        "num_microbatches": 2,
         "enable_sequence_parallelism": True,
-        "sequence_parallelism_mode": "split_gather",
-        "enable_flash_attention": False,
+        "sequence_parallelism_mode": "all_to_all",
         "use_lazy_init": True,
+        "zero_stage": 1,
         "precision": "fp16",
         "initial_scale": 1,
     },
     {
-        "tp_size": 1,
+        "tp_size": 4,
         "pp_size": 1,
-        "sp_size": 2,
         "num_microbatches": 1,
         "enable_sequence_parallelism": True,
-        "sequence_parallelism_mode": "all_to_all",
+        "sequence_parallelism_mode": "split_gather",
+        "enable_flash_attention": False,
         "use_lazy_init": True,
-        "zero_stage": 2,
         "precision": "fp16",
         "initial_scale": 1,
     },
@@ -227,7 +227,11 @@ def run_llama_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")
 
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+        try:
+            check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+        except Exception as e:
+            print(f"Failed config: {test_config}")
+            raise e
 
     clear_layout_converter()
     Randomizer.reset_index()
@@ -277,7 +281,11 @@ def run_llama_3d_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")
 
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+        try:
+            check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+        except Exception as e:
+            print(f"Failed config: {test_config}")
+            raise e
 
     clear_layout_converter()
     Randomizer.reset_index()
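Each dict above is one parallel configuration that the shared shardformer test harness (a build_model_from_hybrid_plugin-style helper) unpacks into a HybridParallelPlugin; the edit reorders the 2-way all_to_all and 4-way split_gather cases and wraps each check in a try/except so the failing configuration is printed before the error propagates. Below is a hedged sketch of roughly how such a dict is consumed; the kwargs mapping is assumed from the plugin signature, popping use_lazy_init is an assumption about the harness, and the snippet only works inside a launched distributed worker.

# Rough sketch only; the real logic lives in the shared shardformer test utilities.
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

def build_booster(test_config: dict) -> Booster:
    config = dict(test_config)
    config.pop("use_lazy_init", False)       # consumed by the model builder, not the plugin
    plugin = HybridParallelPlugin(**config)  # tp_size / pp_size / sp_size / ... map onto plugin kwargs
    return Booster(plugin=plugin)

# Example (must run under torchrun / colossalai.launch with 2 devices):
# booster = build_booster({"tp_size": 1, "pp_size": 1, "sp_size": 2, "num_microbatches": 2,
#                          "enable_sequence_parallelism": True,
#                          "sequence_parallelism_mode": "all_to_all",
#                          "zero_stage": 1, "precision": "fp16", "initial_scale": 1})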
