diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 678dc7f706b4..6ccd1412ce2c 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1046,8 +1046,8 @@ def __init__(
             (
                 self.dp_axis,
                 self.pp_axis,
-                self.sp_axis,
                 self.tp_axis,
+                self.sp_axis,
             ) = (
                 0,
                 1,
@@ -1056,7 +1056,7 @@ def __init__(
             )
             self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
         else:
-            self.pp_axis, self.dp_axis, self.sp_axis, self.tp_axis = 0, 1, 2, 3
+            self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
             self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
 
         self.stage_manager = None
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index da1643f795db..0cb62a7bf114 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -6,6 +6,7 @@
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers.cache_utils import Cache
 from transformers.modeling_attn_mask_utils import (
     _prepare_4d_causal_attention_mask,
     _prepare_4d_causal_attention_mask_for_sdpa,
@@ -577,10 +578,10 @@ def forward(
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
         if sp_mode is not None:
             assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode"
             assert (sp_size is not None) and (
@@ -625,15 +626,27 @@ def forward(
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
         kv_seq_len = key_states.shape[-2]
+        # if past_key_value is not None:
+        #     kv_seq_len += past_key_value[0].shape[-2]
         if past_key_value is not None:
-            kv_seq_len += past_key_value[0].shape[-2]
+            if self.layer_idx is None:
+                raise ValueError(
+                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                    "with a layer index."
+                )
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
+        # if past_key_value is not None:
+        #     # reuse k, v, self_attention
+        #     key_states = torch.cat([past_key_value[0], key_states], dim=2)
+        #     value_states = torch.cat([past_key_value[1], value_states], dim=2)
         if past_key_value is not None:
-            # reuse k, v, self_attention
-            key_states = torch.cat([past_key_value[0], key_states], dim=2)
-            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
         past_key_value = (key_states, value_states) if use_cache else None
 
diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py
index dd36605a81bf..f4ac79ba2713 100644
--- a/colossalai/shardformer/policies/llama.py
+++ b/colossalai/shardformer/policies/llama.py
@@ -80,8 +80,6 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
         )
         sp_partial_derived = sp_mode in ["split_gather", "ring"]
 
-        self.shard_config.enable_flash_attention
-
         if sp_mode == "all_to_all":
             decoder_attribute_replacement = {
                 "num_heads": self.model.config.num_attention_heads // sp_size,
diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py
index a2f7d0d10c02..2c4f91fb8d58 100644
--- a/tests/test_shardformer/test_model/test_shard_llama.py
+++ b/tests/test_shardformer/test_model/test_shard_llama.py
@@ -160,25 +160,25 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "initial_scale": 1,
         },
         {
-            "tp_size": 4,
+            "tp_size": 1,
             "pp_size": 1,
-            "num_microbatches": 1,
+            "sp_size": 2,
+            "num_microbatches": 2,
             "enable_sequence_parallelism": True,
-            "sequence_parallelism_mode": "split_gather",
-            "enable_flash_attention": False,
+            "sequence_parallelism_mode": "all_to_all",
             "use_lazy_init": True,
+            "zero_stage": 1,
             "precision": "fp16",
             "initial_scale": 1,
         },
         {
-            "tp_size": 1,
+            "tp_size": 4,
             "pp_size": 1,
-            "sp_size": 2,
             "num_microbatches": 1,
             "enable_sequence_parallelism": True,
-            "sequence_parallelism_mode": "all_to_all",
+            "sequence_parallelism_mode": "split_gather",
+            "enable_flash_attention": False,
             "use_lazy_init": True,
-            "zero_stage": 2,
             "precision": "fp16",
             "initial_scale": 1,
         },
@@ -227,7 +227,11 @@ def run_llama_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")
 
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+        try:
+            check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+        except Exception as e:
+            print(f"Failed config: {test_config}")
+            raise e
 
     clear_layout_converter()
     Randomizer.reset_index()
@@ -277,7 +281,11 @@ def run_llama_3d_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")
 
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+        try:
+            check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+        except Exception as e:
+            print(f"Failed config: {test_config}")
+            raise e
 
     clear_layout_converter()
     Randomizer.reset_index()
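
As context for the `colossalai/shardformer/modeling/llama.py` change above, here is a minimal standalone sketch (not part of the diff) of the transformers >= 4.36 `Cache` API that the patched forward switches to. `DynamicCache`, `update()`, and `get_usable_length()` are the public transformers calls the diff relies on; the tensor shapes and `layer_idx` value are illustrative only.

```python
import torch
from transformers.cache_utils import DynamicCache

bsz, num_heads, seq_len, head_dim = 2, 4, 8, 16
layer_idx = 0  # each attention layer must be created with its layer index to address the cache

past_key_value = DynamicCache()

# Prefill step: nothing is cached yet, so get_usable_length() returns 0.
key_states = torch.randn(bsz, num_heads, seq_len, head_dim)
value_states = torch.randn(bsz, num_heads, seq_len, head_dim)
kv_seq_len = key_states.shape[-2] + past_key_value.get_usable_length(key_states.shape[-2], layer_idx)

# update() appends the new K/V for this layer and returns the full cached tensors,
# replacing the old torch.cat([past_key_value[0], key_states], dim=2) pattern.
key_states, value_states = past_key_value.update(key_states, value_states, layer_idx)
print(kv_seq_len, key_states.shape)  # 8 torch.Size([2, 4, 8, 16])

# Decode step with one new token: the cache now reports the 8 previously seen positions.
new_k = torch.randn(bsz, num_heads, 1, head_dim)
new_v = torch.randn(bsz, num_heads, 1, head_dim)
kv_seq_len = new_k.shape[-2] + past_key_value.get_usable_length(new_k.shape[-2], layer_idx)
key_states, value_states = past_key_value.update(new_k, new_v, layer_idx)
print(kv_seq_len, key_states.shape)  # 9 torch.Size([2, 4, 9, 16])
```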