
Commit

Fix tests and naming
Signed-off-by: char-1ee <xingjianli59@gmail.com>
char-1ee committed Jun 3, 2024
1 parent 45c747b commit d34bec9
Showing 5 changed files with 154 additions and 250 deletions.
8 changes: 2 additions & 6 deletions colossalai/inference/core/engine.py
@@ -18,7 +18,7 @@
from colossalai.accelerator import get_accelerator
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig, InputMetaData, ModelInferenceConfig
from colossalai.inference.config import InferenceConfig, InputMetaData
from colossalai.inference.graph_runner import CUDAGraphRunner
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.sampler import search_tokens
@@ -72,9 +72,8 @@ def __init__(

self.verbose = verbose
self.logger = get_dist_logger(__name__)
self.model_inference_config = inference_config.to_model_inference_config()

self.init_model(model_or_path, model_policy, self.model_inference_config)
self.init_model(model_or_path, model_policy)

self.generation_config = inference_config.to_generation_config(self.model_config)
self.generation_config_dict = self.generation_config.to_dict()
@@ -113,7 +112,6 @@ def init_model(
self,
model_or_path: Union[nn.Module, str],
model_policy: Union[Policy, Type[Policy]] = None,
model_inference_config: ModelInferenceConfig = None,
):
"""
Shard the model and/or load weights
@@ -178,7 +176,6 @@ def init_model(
self.model = self._shardformer(
model,
model_policy,
model_inference_config,
None,
tp_group=tp_group,
)
@@ -299,7 +296,6 @@ def _shardformer(
self,
model: nn.Module,
model_policy: Policy,
model_inference_config: ModelInferenceConfig,
stage_manager: PipelineStageManager = None,
tp_group: ProcessGroupMesh = None,
) -> nn.Module:
98 changes: 47 additions & 51 deletions colossalai/inference/modeling/backends/attention_backend.py
@@ -8,15 +8,16 @@
from colossalai.logging import get_dist_logger
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import (
context_attention_unpadded,
context_attention_unpadded,
flash_decoding_attention,
)

logger = get_dist_logger(__name__)
inference_ops = InferenceOpsLoader().load()



@dataclass
class AttentionMetaData:
class AttentionMetaData:
query_states: torch.Tensor
key_states: torch.Tensor
value_states: torch.Tensor
@@ -32,7 +33,8 @@ class AttentionMetaData:
output_tensor: torch.Tensor = None
use_spec_dec: bool = False
use_alibi_attn: bool = False

use_cuda_kernel: bool = False


class AttentionBackend(ABC):
@abstractmethod
@@ -42,46 +44,30 @@ def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
@abstractmethod
def decode(self, attn_metadatas: AttentionMetaData, **kwargs):
raise NotImplementedError


class CudaAttentionBackend(AttentionBackend):
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
if not attn_metadata.use_spec_dec:
token_nums = kwargs.get('token_nums', -1)

attn_output = flash_attn_varlen_func(
attn_metadata.query_states,
attn_metadata.key_states,
attn_metadata.value_states,
cu_seqlens_q=attn_metadata.cu_seqlens,
cu_seqlens_k=attn_metadata.cu_seqlens,
max_seqlen_k=attn_metadata.kv_seq_len,
max_seqlen_v=attn_metadata.kv_seq_len,
dropout_p=0.0,
softmax_scale=attn_metadata.sm_scale,
causal=True,
)
attn_output = attn_output.view(token_nums, -1)
else:
attn_output = context_attention_unpadded(
q=attn_metadata.query_states,
k=attn_metadata.key_states,
v=attn_metadata.value_states,
k_cache=attn_metadata.k_cache,
v_cache=attn_metadata.v_cache,
context_lengths=attn_metadata.sequence_lengths,
block_tables=attn_metadata.block_tables,
block_size=attn_metadata.block_size,
output=attn_metadata.output_tensor,
max_seq_len=attn_metadata.kv_seq_len,
sm_scale=attn_metadata.sm_scale,
use_new_kcache_layout=True,
)
token_nums = kwargs.get("token_nums", -1)

attn_output = flash_attn_varlen_func(
attn_metadata.query_states,
attn_metadata.key_states,
attn_metadata.value_states,
cu_seqlens_q=attn_metadata.cu_seqlens,
cu_seqlens_k=attn_metadata.cu_seqlens,
max_seqlen_q=attn_metadata.kv_seq_len,
max_seqlen_k=attn_metadata.kv_seq_len,
dropout_p=0.0,
softmax_scale=attn_metadata.sm_scale,
causal=True,
alibi_slopes=attn_metadata.alibi_slopes,
)
attn_output = attn_output.view(token_nums, -1)
return attn_output



def decode(self, attn_metadata: AttentionMetaData, **kwargs):
fd_inter_tensor = kwargs.get('fd_inter_tensor', None)
fd_inter_tensor = kwargs.get("fd_inter_tensor", None)
output_tensor = attn_metadata.output_tensor
inference_ops.flash_decoding_attention(
output_tensor,
@@ -99,8 +85,8 @@ def decode(self, attn_metadata: AttentionMetaData, **kwargs):
attn_metadata.sm_scale,
)
return output_tensor


class TritonAttentionBackend(AttentionBackend):
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
return context_attention_unpadded(
@@ -113,13 +99,14 @@ def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
block_tables=attn_metadata.block_tables,
block_size=attn_metadata.block_size,
output=attn_metadata.output_tensor,
alibi_slopes=attn_metadata.alibi_slopes,
max_seq_len=attn_metadata.kv_seq_len,
sm_scale=attn_metadata.sm_scale,
use_new_kcache_layout=False,
use_new_kcache_layout=attn_metadata.use_cuda_kernel,
)

def decode(self, attn_metadata: AttentionMetaData, **kwargs):
fd_inter_tensor = kwargs.get('fd_inter_tensor', None)
fd_inter_tensor = kwargs.get("fd_inter_tensor", None)
return flash_decoding_attention(
q=attn_metadata.query_states,
k_cache=attn_metadata.k_cache,
@@ -131,16 +118,25 @@ def decode(self, attn_metadata: AttentionMetaData, **kwargs):
output=attn_metadata.output_tensor,
mid_output=fd_inter_tensor.mid_output,
mid_output_lse=fd_inter_tensor.mid_output_lse,
alibi_slopes=attn_metadata.alibi_slopes,
sm_scale=attn_metadata.sm_scale,
kv_group_num=kwargs.get('num_key_value_groups', 0),
q_len=kwargs.get('q_len', 1),
kv_group_num=kwargs.get("num_key_value_groups", 1),
q_len=kwargs.get("q_len", 1),
)


def get_attention_backend(use_spec_dec: bool, use_cuda_kernel: bool, dtype: torch.dtype) -> AttentionBackend:
use_flash_attn = can_use_flash_attn2(dtype)


def get_attention_backend(
use_spec_dec: bool, use_cuda_kernel: bool, dtype: torch.dtype
) -> AttentionBackend:
"""
Get the attention backend based on the inference configuration. The CUDA-kernel-based backend is used for attention-layer computations only when:
1. CUDA kernels are enabled (use_cuda_kernel=True),
2. flash attention is usable (flash-attn is installed and dtype is fp16 or bf16), and
3. speculative decoding is not enabled (the CUDA kernel does not yet support it).
Otherwise, the Triton attention backend is used.
"""
use_flash_attn = can_use_flash_attn2(dtype)
if use_cuda_kernel and use_flash_attn and not use_spec_dec:
return CudaAttentionBackend()
else:
return TritonAttentionBackend()
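To illustrate the selection rule above, here is a minimal usage sketch. The flag values are hypothetical and would normally come from InferenceConfig/InputMetaData, and the result also depends on whether flash-attn is installed (checked via can_use_flash_attn2):

import torch

from colossalai.inference.modeling.backends.attention_backend import get_attention_backend

# Hypothetical flags for illustration only.
use_spec_dec = False      # speculative decoding disabled
use_cuda_kernel = True    # CUDA kernels requested
dtype = torch.float16     # fp16/bf16 is required for flash-attn

backend = get_attention_backend(use_spec_dec, use_cuda_kernel, dtype)
# CudaAttentionBackend when all three conditions hold; TritonAttentionBackend otherwise.
# Both expose prefill(attn_metadata, **kwargs) and decode(attn_metadata, **kwargs)
# over the shared AttentionMetaData fields.
print(type(backend).__name__)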

@@ -13,60 +13,52 @@

logger = get_dist_logger(__name__)
inference_ops = InferenceOpsLoader().load()


class AttentionContext(ABC):

class PreAttentionBackend(ABC):
@abstractmethod
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
raise NotImplementedError

@abstractmethod
def decode(self, attn_metadata: AttentionMetaData, **kwargs):
raise NotImplementedError
class CudaAttentionContext(AttentionContext):


class CudaPreAttentionBackend(PreAttentionBackend):
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
if not attn_metadata.use_spec_dec:
if not attn_metadata.use_alibi_attn:
inference_ops.rotary_embedding(
attn_metadata.query_states,
attn_metadata.key_states,
kwargs.get('cos', None),
kwargs.get('sin', None),
kwargs.get('high_precision', False),
)
inference_ops.context_kv_cache_memcpy(
attn_metadata.key_states,
attn_metadata.value_states,
attn_metadata.k_cache,
attn_metadata.v_cache,
attn_metadata.sequence_lengths,
attn_metadata.cu_seqlens,
attn_metadata.block_tables,
attn_metadata.kv_seq_len,
)
else:
rotary_embedding(
if not attn_metadata.use_alibi_attn:
inference_ops.rotary_embedding(
attn_metadata.query_states,
attn_metadata.key_states,
kwargs.get('cos', None),
kwargs.get('sin', None),
)

kwargs.get("cos", None),
kwargs.get("sin", None),
kwargs.get("high_precision", False),
)
inference_ops.context_kv_cache_memcpy(
attn_metadata.key_states,
attn_metadata.value_states,
attn_metadata.k_cache,
attn_metadata.v_cache,
attn_metadata.sequence_lengths,
attn_metadata.cu_seqlens,
attn_metadata.block_tables,
attn_metadata.kv_seq_len,
)

def decode(self, attn_metadata: AttentionMetaData, **kwargs):
if attn_metadata.use_alibi_attn:
if not attn_metadata.use_alibi_attn:
inference_ops.rotary_embedding_and_cache_copy(
attn_metadata.query_states,
attn_metadata.key_states,
attn_metadata.value_states,
kwargs.get('cos', None),
kwargs.get('sin', None),
kwargs.get("cos", None),
kwargs.get("sin", None),
attn_metadata.k_cache,
attn_metadata.v_cache,
attn_metadata.sequence_lengths,
attn_metadata.block_tables,
attn_metadata.high_precision,
kwargs.get("high_precision", None),
)
else:
inference_ops.decode_kv_cache_memcpy(
@@ -77,58 +69,63 @@ def decode(self, attn_metadata: AttentionMetaData, **kwargs):
attn_metadata.sequence_lengths,
attn_metadata.block_tables,
)
class TritonAttentionContext(AttentionContext):


class TritonPreAttentionBackend(PreAttentionBackend):
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
if not attn_metadata.use_alibi_attn:
rotary_embedding(
attn_metadata.query_states,
attn_metadata.key_states,
kwargs.get('cos', None),
kwargs.get('sin', None),
kwargs.get("cos", None),
kwargs.get("sin", None),
)

def decode(self, attn_metadata: AttentionMetaData, **kwargs):
if not attn_metadata.use_spec_dec and not attn_metadata.use_alibi_attn:
decoding_fused_rotary_embedding(
attn_metadata.query_states,
attn_metadata.key_states,
attn_metadata.value_states,
kwargs.get('cos', None),
kwargs.get('sin', None),
kwargs.get("cos", None),
kwargs.get("sin", None),
attn_metadata.k_cache,
attn_metadata.v_cache,
attn_metadata.block_tables,
attn_metadata.sequence_lengths,
)
else:
else:  # speculative decoding or ALiBi attention path
if not attn_metadata.use_alibi_attn:
rotary_embedding(
attn_metadata.query_states,
attn_metadata.key_states,
kwargs.get('cos', None),
kwargs.get('sin', None),
kwargs.get("cos", None),
kwargs.get("sin", None),
)
copy_k_to_blocked_cache(
attn_metadata.key_states,
attn_metadata.k_cache,
kv_lengths=attn_metadata.sequence_lengths,
block_tables=attn_metadata.block_tables,
n=kwargs.get('q_len', 1)
n=kwargs.get("q_len", 1),
)
copy_k_to_blocked_cache(
attn_metadata.value_states,
attn_metadata.v_cache,
kv_lengths=attn_metadata.sequence_lengths,
block_tables=attn_metadata.block_tables,
n=kwargs.get('q_len', 1)
n=kwargs.get("q_len", 1),
)


def get_attention_context(use_spec_dec: bool, use_cuda_kernel: bool, dtype: torch.dtype) -> AttentionContext:


def get_pre_attention_backend(
use_spec_dec: bool, use_cuda_kernel: bool, dtype: torch.dtype
) -> PreAttentionBackend:
"""
Get the backend for pre-attention computations, including positional encoding (e.g., RoPE) and KV cache initialization.
"""
use_flash_attn = can_use_flash_attn2(dtype)
if use_cuda_kernel and use_flash_attn and not use_spec_dec:
return CudaAttentionContext()
return CudaPreAttentionBackend()
else:
return TritonAttentionContext()
return TritonPreAttentionBackend()
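As a rough sketch of how the two selectors are meant to be paired, the pre-attention backend applies RoPE and writes the KV cache before the attention backend computes over it. The import path of the renamed pre-attention module is assumed here, since it is not shown in this excerpt:

import torch

from colossalai.inference.modeling.backends.attention_backend import get_attention_backend
# Module path assumed for illustration; the renamed file's path is not shown in this excerpt.
from colossalai.inference.modeling.backends.pre_attention_backend import get_pre_attention_backend


def decode_step(attn_metadata, cos, sin, fd_inter_tensor, num_key_value_groups):
    # Hypothetical wiring of one decode step; the flags and dtype would come from the inference config.
    use_spec_dec, use_cuda_kernel, dtype = False, True, torch.float16
    pre_backend = get_pre_attention_backend(use_spec_dec, use_cuda_kernel, dtype)
    attn_backend = get_attention_backend(use_spec_dec, use_cuda_kernel, dtype)

    # RoPE and KV-cache write for the new token(s).
    pre_backend.decode(attn_metadata, cos=cos, sin=sin, high_precision=False)
    # Flash-decoding attention over the updated paged KV cache.
    return attn_backend.decode(
        attn_metadata,
        fd_inter_tensor=fd_inter_tensor,
        num_key_value_groups=num_key_value_groups,
        q_len=1,
    )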