diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py
index 593f7df7a0ac..9521b5197d6d 100644
--- a/colossalai/inference/core/engine.py
+++ b/colossalai/inference/core/engine.py
@@ -18,7 +18,7 @@
 from colossalai.accelerator import get_accelerator
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.inference.batch_bucket import BatchBucket
-from colossalai.inference.config import InferenceConfig, InputMetaData, ModelInferenceConfig
+from colossalai.inference.config import InferenceConfig, InputMetaData
 from colossalai.inference.graph_runner import CUDAGraphRunner
 from colossalai.inference.modeling.policy import model_policy_map
 from colossalai.inference.sampler import search_tokens
@@ -72,9 +72,8 @@ def __init__(
         self.verbose = verbose
         self.logger = get_dist_logger(__name__)
-        self.model_inference_config = inference_config.to_model_inference_config()
 
-        self.init_model(model_or_path, model_policy, self.model_inference_config)
+        self.init_model(model_or_path, model_policy)
 
         self.generation_config = inference_config.to_generation_config(self.model_config)
         self.generation_config_dict = self.generation_config.to_dict()
 
@@ -113,7 +112,6 @@ def init_model(
         self,
         model_or_path: Union[nn.Module, str],
         model_policy: Union[Policy, Type[Policy]] = None,
-        model_inference_config: ModelInferenceConfig = None,
     ):
         """
         Shard model or/and Load weight
@@ -178,7 +176,6 @@ def init_model(
             self.model = self._shardformer(
                 model,
                 model_policy,
-                model_inference_config,
                 None,
                 tp_group=tp_group,
             )
@@ -299,7 +296,6 @@ def _shardformer(
         self,
         model: nn.Module,
         model_policy: Policy,
-        model_inference_config: ModelInferenceConfig,
         stage_manager: PipelineStageManager = None,
         tp_group: ProcessGroupMesh = None,
     ) -> nn.Module:
diff --git a/colossalai/inference/modeling/backends/attention_backend.py b/colossalai/inference/modeling/backends/attention_backend.py
index 4d82161313e7..ecdd9d4c4e5e 100644
--- a/colossalai/inference/modeling/backends/attention_backend.py
+++ b/colossalai/inference/modeling/backends/attention_backend.py
@@ -8,15 +8,16 @@
 from colossalai.logging import get_dist_logger
 from colossalai.kernel.kernel_loader import InferenceOpsLoader
 from colossalai.kernel.triton import (
-    context_attention_unpadded, 
+    context_attention_unpadded,
     flash_decoding_attention,
 )
 
 logger = get_dist_logger(__name__)
 inference_ops = InferenceOpsLoader().load()
-    
+
+
 @dataclass
-class AttentionMetaData: 
+class AttentionMetaData:
     query_states: torch.Tensor
     key_states: torch.Tensor
     value_states: torch.Tensor
@@ -32,7 +33,8 @@ class AttentionMetaData:
     output_tensor: torch.Tensor = None
     use_spec_dec: bool = False
     use_alibi_attn: bool = False
-    
+    use_cuda_kernel: bool = False
+
 
 class AttentionBackend(ABC):
     @abstractmethod
@@ -42,46 +44,30 @@ def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
     @abstractmethod
     def decode(self, attn_metadatas: AttentionMetaData, **kwargs):
         raise NotImplementedError
-    
+
 
 class CudaAttentionBackend(AttentionBackend):
     def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
-        if not attn_metadata.use_spec_dec:
-            token_nums = kwargs.get('token_nums', -1)
-
-            attn_output = flash_attn_varlen_func(
-                attn_metadata.query_states,
-                attn_metadata.key_states,
-                attn_metadata.value_states,
-                cu_seqlens_q=attn_metadata.cu_seqlens,
-                cu_seqlens_k=attn_metadata.cu_seqlens,
-                max_seqlen_k=attn_metadata.kv_seq_len,
-                max_seqlen_v=attn_metadata.kv_seq_len,
-                dropout_p=0.0,
-                softmax_scale=attn_metadata.sm_scale,
-                causal=True,
-            )
-            attn_output = attn_output.view(token_nums, -1)
-        else:
-            attn_output = context_attention_unpadded(
-                q=attn_metadata.query_states,
-                k=attn_metadata.key_states,
-                v=attn_metadata.value_states,
-                k_cache=attn_metadata.k_cache,
-                v_cache=attn_metadata.v_cache,
-                context_lengths=attn_metadata.sequence_lengths,
-                block_tables=attn_metadata.block_tables,
-                block_size=attn_metadata.block_size,
-                output=attn_metadata.output_tensor,
-                max_seq_len=attn_metadata.kv_seq_len,
-                sm_scale=attn_metadata.sm_scale,
-                use_new_kcache_layout=True,
-            )
+        token_nums = kwargs.get("token_nums", -1)
+
+        attn_output = flash_attn_varlen_func(
+            attn_metadata.query_states,
+            attn_metadata.key_states,
+            attn_metadata.value_states,
+            cu_seqlens_q=attn_metadata.cu_seqlens,
+            cu_seqlens_k=attn_metadata.cu_seqlens,
+            max_seqlen_q=attn_metadata.kv_seq_len,
+            max_seqlen_k=attn_metadata.kv_seq_len,
+            dropout_p=0.0,
+            softmax_scale=attn_metadata.sm_scale,
+            causal=True,
+            alibi_slopes=attn_metadata.alibi_slopes,
+        )
+        attn_output = attn_output.view(token_nums, -1)
         return attn_output
-    
-    
+
     def decode(self, attn_metadata: AttentionMetaData, **kwargs):
-        fd_inter_tensor = kwargs.get('fd_inter_tensor', None)
+        fd_inter_tensor = kwargs.get("fd_inter_tensor", None)
         output_tensor = attn_metadata.output_tensor
         inference_ops.flash_decoding_attention(
             output_tensor,
@@ -99,8 +85,8 @@ def decode(self, attn_metadata: AttentionMetaData, **kwargs):
             attn_metadata.sm_scale,
         )
         return output_tensor
-    
-    
+
+
 class TritonAttentionBackend(AttentionBackend):
     def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
         return context_attention_unpadded(
@@ -113,13 +99,14 @@ def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
             block_tables=attn_metadata.block_tables,
             block_size=attn_metadata.block_size,
             output=attn_metadata.output_tensor,
+            alibi_slopes=attn_metadata.alibi_slopes,
             max_seq_len=attn_metadata.kv_seq_len,
             sm_scale=attn_metadata.sm_scale,
-            use_new_kcache_layout=False,
+            use_new_kcache_layout=attn_metadata.use_cuda_kernel,
         )
-    
+
     def decode(self, attn_metadata: AttentionMetaData, **kwargs):
-        fd_inter_tensor = kwargs.get('fd_inter_tensor', None)
+        fd_inter_tensor = kwargs.get("fd_inter_tensor", None)
         return flash_decoding_attention(
             q=attn_metadata.query_states,
             k_cache=attn_metadata.k_cache,
@@ -131,16 +118,25 @@ def decode(self, attn_metadata: AttentionMetaData, **kwargs):
             output=attn_metadata.output_tensor,
             mid_output=fd_inter_tensor.mid_output,
             mid_output_lse=fd_inter_tensor.mid_output_lse,
+            alibi_slopes=attn_metadata.alibi_slopes,
             sm_scale=attn_metadata.sm_scale,
-            kv_group_num=kwargs.get('num_key_value_groups', 0),
-            q_len=kwargs.get('q_len', 1),
+            kv_group_num=kwargs.get("num_key_value_groups", 1),
+            q_len=kwargs.get("q_len", 1),
         )
-    
-    
-def get_attention_backend(use_spec_dec: bool, use_cuda_kernel: bool, dtype: torch.dtype) -> AttentionBackend:
-    use_flash_attn = can_use_flash_attn2(dtype)
+
+
+def get_attention_backend(
+    use_spec_dec: bool, use_cuda_kernel: bool, dtype: torch.dtype
+) -> AttentionBackend:
+    """
+    Get the attention backend based on the inference configuration. The CUDA-kernel-based backend is used only when:
+    1. the CUDA kernel is enabled (use_cuda_kernel=True),
+    2. flash attention is available (flash-attn is installed and the dtype is fp16 or bf16), and
+    3. speculative decoding is not in use (the CUDA kernel does not support it yet).
+    Otherwise, the Triton attention backend is used.
+ """ + use_flash_attn = can_use_flash_attn2(dtype) if use_cuda_kernel and use_flash_attn and not use_spec_dec: return CudaAttentionBackend() else: return TritonAttentionBackend() - \ No newline at end of file diff --git a/colossalai/inference/modeling/backends/attention_context.py b/colossalai/inference/modeling/backends/pre_attention_backend.py similarity index 63% rename from colossalai/inference/modeling/backends/attention_context.py rename to colossalai/inference/modeling/backends/pre_attention_backend.py index 909660121d4f..73cf325920a0 100644 --- a/colossalai/inference/modeling/backends/attention_context.py +++ b/colossalai/inference/modeling/backends/pre_attention_backend.py @@ -13,60 +13,52 @@ logger = get_dist_logger(__name__) inference_ops = InferenceOpsLoader().load() - -class AttentionContext(ABC): + +class PreAttentionBackend(ABC): @abstractmethod def prefill(self, attn_metadata: AttentionMetaData, **kwargs): raise NotImplementedError - + @abstractmethod def decode(self, attn_metadata: AttentionMetaData, **kwargs): raise NotImplementedError - - -class CudaAttentionContext(AttentionContext): + + +class CudaPreAttentionBackend(PreAttentionBackend): def prefill(self, attn_metadata: AttentionMetaData, **kwargs): - if not attn_metadata.use_spec_dec: - if not attn_metadata.use_alibi_attn: - inference_ops.rotary_embedding( - attn_metadata.query_states, - attn_metadata.key_states, - kwargs.get('cos', None), - kwargs.get('sin', None), - kwargs.get('high_precision', False), - ) - inference_ops.context_kv_cache_memcpy( - attn_metadata.key_states, - attn_metadata.value_states, - attn_metadata.k_cache, - attn_metadata.v_cache, - attn_metadata.sequence_lengths, - attn_metadata.cu_seqlens, - attn_metadata.block_tables, - attn_metadata.kv_seq_len, - ) - else: - rotary_embedding( + if not attn_metadata.use_alibi_attn: + inference_ops.rotary_embedding( attn_metadata.query_states, attn_metadata.key_states, - kwargs.get('cos', None), - kwargs.get('sin', None), - ) - + kwargs.get("cos", None), + kwargs.get("sin", None), + kwargs.get("high_precision", False), + ) + inference_ops.context_kv_cache_memcpy( + attn_metadata.key_states, + attn_metadata.value_states, + attn_metadata.k_cache, + attn_metadata.v_cache, + attn_metadata.sequence_lengths, + attn_metadata.cu_seqlens, + attn_metadata.block_tables, + attn_metadata.kv_seq_len, + ) + def decode(self, attn_metadata: AttentionMetaData, **kwargs): - if attn_metadata.use_alibi_attn: + if not attn_metadata.use_alibi_attn: inference_ops.rotary_embedding_and_cache_copy( attn_metadata.query_states, attn_metadata.key_states, attn_metadata.value_states, - kwargs.get('cos', None), - kwargs.get('sin', None), + kwargs.get("cos", None), + kwargs.get("sin", None), attn_metadata.k_cache, attn_metadata.v_cache, attn_metadata.sequence_lengths, attn_metadata.block_tables, - attn_metadata.high_precision, + kwargs.get("high_precision", None), ) else: inference_ops.decode_kv_cache_memcpy( @@ -77,58 +69,63 @@ def decode(self, attn_metadata: AttentionMetaData, **kwargs): attn_metadata.sequence_lengths, attn_metadata.block_tables, ) - - -class TritonAttentionContext(AttentionContext): + + +class TritonPreAttentionBackend(PreAttentionBackend): def prefill(self, attn_metadata: AttentionMetaData, **kwargs): if not attn_metadata.use_alibi_attn: rotary_embedding( attn_metadata.query_states, attn_metadata.key_states, - kwargs.get('cos', None), - kwargs.get('sin', None), + kwargs.get("cos", None), + kwargs.get("sin", None), ) - + def decode(self, attn_metadata: AttentionMetaData, 
     def decode(self, attn_metadata: AttentionMetaData, **kwargs):
         if not attn_metadata.use_spec_dec and not attn_metadata.use_alibi_attn:
             decoding_fused_rotary_embedding(
                 attn_metadata.query_states,
                 attn_metadata.key_states,
                 attn_metadata.value_states,
-                kwargs.get('cos', None),
-                kwargs.get('sin', None),
+                kwargs.get("cos", None),
+                kwargs.get("sin", None),
                 attn_metadata.k_cache,
                 attn_metadata.v_cache,
                 attn_metadata.block_tables,
                 attn_metadata.sequence_lengths,
             )
-        else:
+        else:  # speculative decoding or alibi attention
             if not attn_metadata.use_alibi_attn:
                 rotary_embedding(
                     attn_metadata.query_states,
                     attn_metadata.key_states,
-                    kwargs.get('cos', None),
-                    kwargs.get('sin', None),
+                    kwargs.get("cos", None),
+                    kwargs.get("sin", None),
                 )
             copy_k_to_blocked_cache(
                 attn_metadata.key_states,
                 attn_metadata.k_cache,
                 kv_lengths=attn_metadata.sequence_lengths,
                 block_tables=attn_metadata.block_tables,
-                n=kwargs.get('q_len', 1)
+                n=kwargs.get("q_len", 1),
             )
             copy_k_to_blocked_cache(
                 attn_metadata.value_states,
                 attn_metadata.v_cache,
                 kv_lengths=attn_metadata.sequence_lengths,
                 block_tables=attn_metadata.block_tables,
-                n=kwargs.get('q_len', 1)
+                n=kwargs.get("q_len", 1),
             )
-    
-    
-def get_attention_context(use_spec_dec: bool, use_cuda_kernel: bool, dtype: torch.dtype) -> AttentionContext:
+
+
+def get_pre_attention_backend(
+    use_spec_dec: bool, use_cuda_kernel: bool, dtype: torch.dtype
+) -> PreAttentionBackend:
+    """
+    Get the backend for pre-attention computations, including positional encoding (e.g. RoPE) and KV cache initialization.
+    """
     use_flash_attn = can_use_flash_attn2(dtype)
     if use_cuda_kernel and use_flash_attn and not use_spec_dec:
-        return CudaAttentionContext()
+        return CudaPreAttentionBackend()
     else:
-        return TritonAttentionContext()
\ No newline at end of file
+        return TritonPreAttentionBackend()
diff --git a/colossalai/inference/modeling/models/nopadding_baichuan.py b/colossalai/inference/modeling/models/nopadding_baichuan.py
index 920de0d8a86f..d722c80ea0c7 100644
--- a/colossalai/inference/modeling/models/nopadding_baichuan.py
+++ b/colossalai/inference/modeling/models/nopadding_baichuan.py
@@ -10,6 +10,8 @@
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
 from colossalai.inference.utils import get_alibi_slopes
 from colossalai.inference.modeling.models.nopadding_llama import NopadLlamaMLP
+from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData, get_attention_backend
+from colossalai.inference.modeling.backends.pre_attention_backend import get_pre_attention_backend
 from colossalai.kernel.kernel_loader import InferenceOpsLoader
 from colossalai.kernel.triton import (
     context_attention_unpadded,
@@ -23,28 +25,8 @@
 from colossalai.shardformer.layer.parallel_module import ParallelModule
 from colossalai.tensor.d_tensor import Layout, distribute_tensor, is_distributed_tensor
 
-logger = get_dist_logger(__name__)
-
-try:
-    from flash_attn import flash_attn_varlen_func
-
-    use_flash_attn2 = True
-except ImportError:
-    use_flash_attn2 = False
-    logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")
-
-logger = get_dist_logger(__name__)
-
-try:
-    from flash_attn import flash_attn_varlen_func
-
-    use_flash_attn2 = True
-except ImportError:
-    use_flash_attn2 = False
-    logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")
 
 inference_ops = InferenceOpsLoader().load()
-
 
 logger = get_dist_logger(__name__)
 
@@ -251,122 +233,54 @@ def forward(
         )
 
         block_size = k_cache.size(-2)
-
-        if is_prompts:
-            if not is_verifier and use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2:
-                # flash attn 2 currently only supports FP16/BF16.
-                if not self.use_alibi_attn:
-                    inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
-                inference_ops.context_kv_cache_memcpy(
-                    key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len
-                )
-                attn_output = flash_attn_varlen_func(
-                    query_states,
-                    key_states,
-                    value_states,
-                    cu_seqlens_q=cu_seqlens,
-                    cu_seqlens_k=cu_seqlens,
-                    max_seqlen_q=kv_seq_len,
-                    max_seqlen_k=kv_seq_len,
-                    dropout_p=0.0,
-                    softmax_scale=sm_scale,
-                    causal=True,
-                    alibi_slopes=self.alibi_slopes,
-                )
-                attn_output = attn_output.view(token_nums, -1)
-            else:
-                if not self.use_alibi_attn:
-                    rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
-                attn_output = context_attention_unpadded(
-                    q=query_states,
-                    k=key_states,
-                    v=value_states,
-                    k_cache=k_cache,
-                    v_cache=v_cache,
-                    context_lengths=sequence_lengths,
-                    block_tables=block_tables,
-                    block_size=block_size,
-                    output=output_tensor,
-                    alibi_slopes=self.alibi_slopes,
-                    max_seq_len=kv_seq_len,
-                    sm_scale=sm_scale,
-                    use_new_kcache_layout=use_cuda_kernel,
-                )
-        else:
+
+        attn_metadata = AttentionMetaData(
+            query_states=query_states,
+            key_states=key_states,
+            value_states=value_states,
+            k_cache=k_cache,
+            v_cache=v_cache,
+            block_tables=block_tables,
+            block_size=block_size,
+            kv_seq_len=kv_seq_len,
+            sequence_lengths=sequence_lengths,
+            sm_scale=sm_scale,
+            alibi_slopes=self.alibi_slopes,
+            cu_seqlens=cu_seqlens,
+            output_tensor=output_tensor,
+            use_spec_dec=is_verifier,
+            use_alibi_attn=self.use_alibi_attn,
+            use_cuda_kernel=use_cuda_kernel,
+        )
+
+        attention_backend = get_attention_backend(use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=query_states.dtype)
+        pre_attention_backend = get_pre_attention_backend(use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=query_states.dtype)
+
+        if is_prompts:  # prefilling stage
+            pre_attention_backend.prefill(
+                attn_metadata,
+                cos=cos_sin[0],
+                sin=cos_sin[1],
+                high_precision=high_precision,
+            )
+            attn_output = attention_backend.prefill(
+                attn_metadata,
+                token_nums=token_nums,
+            )
+        else:  # decoding stage
             q_len = tokens_to_verify + 1 if is_verifier else 1
-
-            if use_cuda_kernel:
-                if not self.use_alibi_attn:
-                    inference_ops.rotary_embedding_and_cache_copy(
-                        query_states,
-                        key_states,
-                        value_states,
-                        cos_sin[0],
-                        cos_sin[1],
-                        k_cache,
-                        v_cache,
-                        sequence_lengths,
-                        block_tables,
-                        high_precision,
-                    )
-                else:
-                    inference_ops.decode_kv_cache_memcpy(
-                        key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables
-                    )
-                inference_ops.flash_decoding_attention(
-                    output_tensor,
-                    query_states,
-                    k_cache,
-                    v_cache,
-                    sequence_lengths,
-                    block_tables,
-                    block_size,
-                    kv_seq_len,
-                    fd_inter_tensor.mid_output,
-                    fd_inter_tensor.exp_sums,
-                    fd_inter_tensor.max_logits,
-                    self.alibi_slopes,
-                    sm_scale,
-                )
-                attn_output = output_tensor
-            else:
-                if not is_verifier and not self.use_alibi_attn:
-                    decoding_fused_rotary_embedding(
-                        query_states,
-                        key_states,
-                        value_states,
-                        cos_sin[0],
-                        cos_sin[1],
-                        k_cache,
-                        v_cache,
-                        block_tables,
-                        sequence_lengths,
-                    )
-                else:
-                    if not self.use_alibi_attn:
-                        rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
-                    copy_k_to_blocked_cache(
-                        key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
-                    )
-                    copy_k_to_blocked_cache(
-                        value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
-                    )
-
-                attn_output = flash_decoding_attention(
-                    q=query_states,
-                    k_cache=k_cache,
-                    v_cache=v_cache,
-                    kv_seq_len=sequence_lengths,
-                    block_tables=block_tables,
-                    block_size=block_size,
-                    max_seq_len_in_batch=kv_seq_len,
-                    output=output_tensor,
-                    mid_output=fd_inter_tensor.mid_output,
-                    mid_output_lse=fd_inter_tensor.mid_output_lse,
-                    alibi_slopes=self.alibi_slopes,
-                    sm_scale=sm_scale,
-                    q_len=q_len,
-                )
+
+            pre_attention_backend.decode(
+                attn_metadata,
+                cos=cos_sin[0],
+                sin=cos_sin[1],
+                q_len=q_len,
+            )
+            attn_output = attention_backend.decode(
+                attn_metadata,
+                fd_inter_tensor=fd_inter_tensor,
+                q_len=q_len,
+            )
 
         attn_output = attn_output.view(-1, self.hidden_size)
         attn_output = self.o_proj(attn_output)
diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py
index 459c6e040e23..e9346017a976 100644
--- a/colossalai/inference/modeling/models/nopadding_llama.py
+++ b/colossalai/inference/modeling/models/nopadding_llama.py
@@ -19,7 +19,7 @@
 from colossalai.inference.config import InputMetaData
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
 from colossalai.inference.modeling.backends.attention_backend import get_attention_backend, AttentionMetaData
-from colossalai.inference.modeling.backends.attention_context import get_attention_context
+from colossalai.inference.modeling.backends.pre_attention_backend import get_pre_attention_backend
 from colossalai.inference.utils import can_use_flash_attn2
 from colossalai.kernel.kernel_loader import InferenceOpsLoader
 from colossalai.kernel.triton import (
@@ -121,7 +121,7 @@ def llama_model_forward(
         cos_sin = (self._cos_cached[rotary_indexes], self._sin_cached[rotary_indexes])
 
     elif use_cuda_kernel:
-        if inputmetadata.dtype != torch.float32 and can_use_flash_attn2():
+        if can_use_flash_attn2(inputmetadata.dtype):
             cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.torch.int32), (1, 0))
 
         hidden_dim = self._cos_cached.size(-1)
@@ -544,13 +544,14 @@ def forward(
             output_tensor=output_tensor,
             use_spec_dec=is_verifier,
             use_alibi_attn=False,
+            use_cuda_kernel=use_cuda_kernel,
         )
 
         attention_backend = get_attention_backend(use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=query_states.dtype)
-        attention_context = get_attention_context(use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=query_states.dtype)
+        pre_attention_backend = get_pre_attention_backend(use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=query_states.dtype)
 
         if is_prompts:  # prefilling stage
-            attention_context.prefill(
+            pre_attention_backend.prefill(
                 attn_metadata,
                 cos=cos_sin[0],
                 sin=cos_sin[1],
@@ -563,7 +564,7 @@ def forward(
         else:  # decoding stage
             q_len = tokens_to_verify + 1 if is_verifier else 1
 
-            attention_context.decode(
+            pre_attention_backend.decode(
                 attn_metadata,
                 cos=cos_sin[0],
                 sin=cos_sin[1],