Merge pull request #5771 from char-1ee/refactor/modeling
[Inference] Refactor modeling attention layer by abstracting attention backends
char-1ee authored Jun 10, 2024
2 parents 73e88a5 + b303976 commit 77a219a
Showing 15 changed files with 531 additions and 301 deletions.
2 changes: 1 addition & 1 deletion colossalai/inference/README.md
@@ -236,7 +236,7 @@ Completion api is used for single sequence request, like answer a question or co
- POST '/chat':
Chat api is used for conversation-style requests, which often include dialogue participants (i.e. roles) and their corresponding messages. Since such input data is very different from normal inputs, we introduce the Chat-Template to match the data format expected by chat models.
#### chat-template
Followed `transformers`, we add the chat-template argument. As chat models have been trained with very different formats for converting conversations into a single tokenizable string. Using a format that matches the training data is extremely important. This attribute(chat_template) is inclueded in HuggingFace tokenizers, containing a Jinja template that converts conversation histories into a correctly formatted string. You can refer to the [HuggingFace-blog](https://huggingface.co/blog/chat-templates) for more information. We also provide a simple example temlate bellow. Both str or file style chat template are supported.
Following `transformers`, we add the chat-template argument. Chat models have been trained with very different formats for converting conversations into a single tokenizable string, and using a format that matches the training data is extremely important. This attribute (chat_template) is included in HuggingFace tokenizers as a Jinja template that converts conversation histories into a correctly formatted string. You can refer to the [HuggingFace blog](https://huggingface.co/blog/chat-templates) for more information. We also provide a simple example template below. Both string and file style chat templates are supported.
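For illustration only, a minimal sketch of applying such a chat template with a HuggingFace tokenizer; the model name and template string below are placeholders, not part of this repository.

```python
from transformers import AutoTokenizer

# Placeholder model; any chat model whose tokenizer supports chat templates works the same way.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

# A simple Jinja chat template in string form (a template file is also supported).
tokenizer.chat_template = (
    "{% for message in messages %}"
    "{{ message['role'] }}: {{ message['content'] }}\n"
    "{% endfor %}"
    "assistant:"
)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is tensor parallelism?"},
]

# Render the conversation history into a single, correctly formatted prompt string.
prompt = tokenizer.apply_chat_template(messages, tokenize=False)
print(prompt)
```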
### Usage
#### Args for customizing your server
The configuration for api server contains both serving interface and engine backend.
33 changes: 32 additions & 1 deletion colossalai/inference/config.py
@@ -10,6 +10,7 @@
from transformers.generation import GenerationConfig

from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.utils import can_use_flash_attn2

GibiByte = 1024**3

@@ -169,7 +170,8 @@ class InferenceConfig(RPC_PARAM):
no_repeat_ngram_size (Optional[int]): If no_repeat_ngram_size > 0, n-grams of that size can only appear once in the generated text.
repetition_penalty (Optional[float]): The parameter that influences the model's treatment of new tokens in relation to their appearance in the prompt and the generated text. Values greater than 1 encourage the model to introduce new tokens, whereas values less than 1 encourage token repetition. Defaults to 1.0.
ignore_eos (bool): Whether to ignore the EOS token and continue generating tokens after encountering it.
n_spec_tokens (int): The maximum number of speculating tokens, defaults to None.
use_spec_dec (bool): Whether to use speculative decoding, defaults to False.
max_n_spec_tokens (int): The maximum number of speculative tokens, defaults to 5.
glimpse_large_kv (bool): Whether to use large KV in the drafter model, defaults to False.
block_size (int): The number of blocks in a logical block, defaults to 16.
tp_size (int): Tensor parallel size, defaults to 1.
@@ -214,6 +216,7 @@ class InferenceConfig(RPC_PARAM):
ignore_eos: bool = False

# speculative decoding configs
use_spec_dec: bool = False
max_n_spec_tokens: int = 5
glimpse_large_kv: bool = False

@@ -311,6 +314,16 @@ def to_generation_config(self, model_config) -> GenerationConfig:

return GenerationConfig.from_dict(meta_config)

def to_model_shard_inference_config(self) -> "ModelShardInferenceConfig":
use_flash_attn = can_use_flash_attn2(self.dtype)
model_inference_config = ModelShardInferenceConfig(
dtype=self.dtype,
use_cuda_kernel=self.use_cuda_kernel,
use_spec_dec=self.use_spec_dec,
use_flash_attn=use_flash_attn,
)
return model_inference_config

def to_rpc_param(self) -> dict:
kwargs = {
"dtype": str(self.dtype).split(".")[-1],
@@ -362,3 +375,21 @@ def from_dict(cls, config_dict: Dict[str, Any]) -> "InferenceConfig":
# Set the attributes from the parsed arguments.
inference_config = cls(**inference_config_args)
return inference_config


@dataclass
class ModelShardInferenceConfig:
"""
Configuration used during module initialization for inference modeling.
Args:
dtype (torch.dtype): The data type for weights and activations.
use_cuda_kernel (bool): Whether to use the CUDA kernel; faster, but may occasionally lose some precision.
use_spec_dec (bool): Whether to use speculative decoding.
use_flash_attn (bool): Whether to use flash attention.
"""

dtype: torch.dtype = None
use_cuda_kernel: bool = False
use_spec_dec: bool = False
use_flash_attn: bool = False
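As a hedged usage sketch (assuming the remaining `InferenceConfig` fields keep their defaults), this per-shard config is normally derived from the top-level `InferenceConfig` via `to_model_shard_inference_config()` rather than constructed by hand:

```python
import torch
from colossalai.inference.config import InferenceConfig

# Illustrative values; other InferenceConfig fields are left at their defaults.
inference_config = InferenceConfig(
    dtype=torch.float16,
    use_cuda_kernel=True,
    use_spec_dec=False,
)

# The engine derives the per-shard modeling config from the top-level config.
shard_infer_config = inference_config.to_model_shard_inference_config()

# use_flash_attn is True only if flash-attn is installed and dtype is fp16/bf16.
print(shard_infer_config.use_cuda_kernel, shard_infer_config.use_flash_attn)
```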
38 changes: 26 additions & 12 deletions colossalai/inference/core/engine.py
@@ -18,7 +18,7 @@
from colossalai.accelerator import get_accelerator
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig, InputMetaData
from colossalai.inference.config import InferenceConfig, InputMetaData, ModelShardInferenceConfig
from colossalai.inference.graph_runner import CUDAGraphRunner
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.sampler import search_tokens
@@ -72,8 +72,9 @@ def __init__(

self.verbose = verbose
self.logger = get_dist_logger(__name__)
self.model_shard_infer_config = inference_config.to_model_shard_inference_config()

self.init_model(model_or_path, model_policy)
self.init_model(model_or_path, model_policy, self.model_shard_infer_config)

self.generation_config = inference_config.to_generation_config(self.model_config)
self.generation_config_dict = self.generation_config.to_dict()
@@ -97,21 +98,29 @@ def __init__(
self.capture_model(self.k_cache, self.v_cache)

# Model and relatable attrs of speculative decoding will be set by `enable_spec_dec`
self.use_spec_dec = False
self.use_spec_dec = self.inference_config.use_spec_dec

self.drafter_model = None
self.drafter = None
self.use_glide = False
self.n_spec_tokens = self.inference_config.max_n_spec_tokens

self._verify_args()

def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Union[Policy, Type[Policy]] = None):
def init_model(
self,
model_or_path: Union[nn.Module, str],
model_policy: Union[Policy, Type[Policy]] = None,
model_shard_infer_config: ModelShardInferenceConfig = None,
):
"""
Shard the model and/or load its weights.
Args:
model_or_path (Union[nn.Module, str]): path to the checkpoint or a model in transformers format.
model_policy (Policy): the policy to replace the model
model_policy (Policy): the policy to replace the model.
model_shard_infer_config (ModelShardInferenceConfig): the configuration used to initialize modules for inference.
"""

if isinstance(model_or_path, str):
@@ -124,6 +133,7 @@ def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Union[P
# the model load process in the future.
model = _supported_models[arch].from_pretrained(model_or_path, trust_remote_code=True)
else:
# TODO(char-1ee): if the model is not supported, use transformers APIs to load and generate
raise ValueError(f"Model {arch} is not supported.")

except Exception as e:
@@ -167,6 +177,7 @@ def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Union[P
self.model = self._shardformer(
model,
model_policy,
model_shard_infer_config,
None,
tp_group=tp_group,
)
@@ -187,7 +198,7 @@ def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Union[P
# assert if_has_index_file, "the model path is invalid"
# cpt_io.load_model(self.model, model_index_file)

free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
free_gpu_memory, _ = torch.cuda.mem_get_info()
peak_memory = init_gpu_memory - free_gpu_memory
if self.verbose:
self.logger.info(
@@ -287,6 +298,7 @@ def _shardformer(
self,
model: nn.Module,
model_policy: Policy,
model_shard_infer_config: ModelShardInferenceConfig = None,
stage_manager: PipelineStageManager = None,
tp_group: ProcessGroupMesh = None,
) -> nn.Module:
@@ -312,6 +324,7 @@ def _shardformer(
enable_flash_attention=False,
enable_jit_fused=False,
enable_sequence_parallelism=False,
extra_kwargs={"model_shard_infer_config": model_shard_infer_config},
)
shardformer = ShardFormer(shard_config=shardconfig)
shard_model, _ = shardformer.optimize(model, model_policy)
@@ -348,6 +361,7 @@ def enable_spec_dec(
engine.clear_spec_dec()
```
"""

if drafter_model is None and self.drafter is None:
raise ValueError("Drafter not initialized. Please provide a Drafter Model")
if n_spec_tokens is not None:
@@ -517,19 +531,19 @@ def generate(
prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
return_token_ids: bool = False,
generation_config: Optional[GenerationConfig] = None,
) -> List[str]:
) -> Union[List[str], Tuple[List[str], List[List[int]]]]:
"""
Executing the inference step.
Args:
prompts (Union[List[str], optional): Input prompts. Defaults to None.
prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
request_ids (List[int], optional): The request ID. Defaults to None.
return_token_ids (bool): Whether to return output token ids. Defaults to False.
generation_config (GenerationConfig, optional): Huggingface GenerationConfig used for inference. Defaults to None.
prompts (List[str], optional): Input prompts. Defaults to None.
prompts_token_ids (Union[List[int], torch.Tensor, np.ndarray], optional): token ids of input prompts. Defaults to None.
return_token_ids (bool, optional): Whether to return output token ids. Defaults to False.
generation_config (Optional[GenerationConfig], optional): Huggingface GenerationConfig used for inference. Defaults to None.
Returns:
List[str]: Inference result returned by one generation.
Union[List[str], Tuple[List[str], List[List[int]]]]: Inference result returned by one generation.
"""

gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
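For context, a hedged sketch of calling the updated `generate` API, assuming an `InferenceEngine` instance named `engine` has already been initialized (construction details are omitted here):

```python
# `engine` is an already-initialized colossalai.inference.core.engine.InferenceEngine.

# Default call: returns a list of generated strings.
outputs = engine.generate(prompts=["What is speculative decoding?"])

# With return_token_ids=True, the return value is a tuple of
# (generated texts, generated token ids), matching the new return annotation.
texts, token_ids = engine.generate(
    prompts=["What is speculative decoding?"],
    return_token_ids=True,
)
print(texts[0])
print(token_ids[0])
```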
Empty file.
168 changes: 168 additions & 0 deletions colossalai/inference/modeling/backends/attention_backend.py
@@ -0,0 +1,168 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass

import torch
from flash_attn import flash_attn_varlen_func

from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import context_attention_unpadded, flash_decoding_attention


@dataclass
class AttentionMetaData:
query_states: torch.Tensor
key_states: torch.Tensor
value_states: torch.Tensor
k_cache: torch.Tensor
v_cache: torch.Tensor
block_tables: torch.Tensor
block_size: int
kv_seq_len: int = None
sequence_lengths: torch.Tensor = None
cu_seqlens: torch.Tensor = None
sm_scale: int = None
alibi_slopes: torch.Tensor = None
output_tensor: torch.Tensor = None
use_spec_dec: bool = False
use_alibi_attn: bool = False


class AttentionBackend(ABC):
@abstractmethod
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
raise NotImplementedError

@abstractmethod
def decode(self, attn_metadatas: AttentionMetaData, **kwargs):
raise NotImplementedError


class CudaAttentionBackend(AttentionBackend):
"""
Attention backend used when use_cuda_kernel is True. If flash-attn is available, prefill uses `flash_attn_varlen_func`;
otherwise prefill falls back to the Triton op `context_attention_unpadded`. Decoding uses the CUDA op `flash_decoding_attention`.
"""

def __init__(self, use_flash_attn: bool):
super().__init__()
self.inference_ops = InferenceOpsLoader().load()
self.use_flash_attn = use_flash_attn

def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
if self.use_flash_attn:
token_nums = kwargs.get("token_nums", -1)
attn_output = flash_attn_varlen_func(
attn_metadata.query_states,
attn_metadata.key_states,
attn_metadata.value_states,
cu_seqlens_q=attn_metadata.cu_seqlens,
cu_seqlens_k=attn_metadata.cu_seqlens,
max_seqlen_q=attn_metadata.kv_seq_len,
max_seqlen_k=attn_metadata.kv_seq_len,
dropout_p=0.0,
softmax_scale=attn_metadata.sm_scale,
causal=True,
alibi_slopes=attn_metadata.alibi_slopes,
)
attn_output = attn_output.view(token_nums, -1)
else:
attn_output = context_attention_unpadded(
q=attn_metadata.query_states,
k=attn_metadata.key_states,
v=attn_metadata.value_states,
k_cache=attn_metadata.k_cache,
v_cache=attn_metadata.v_cache,
context_lengths=attn_metadata.sequence_lengths,
block_tables=attn_metadata.block_tables,
block_size=attn_metadata.block_size,
output=attn_metadata.output_tensor,
alibi_slopes=attn_metadata.alibi_slopes,
max_seq_len=attn_metadata.kv_seq_len,
sm_scale=attn_metadata.sm_scale,
use_new_kcache_layout=True, # use new k-cache layout
)
return attn_output

def decode(self, attn_metadata: AttentionMetaData, **kwargs):
fd_inter_tensor = kwargs.get("fd_inter_tensor", None)
output_tensor = attn_metadata.output_tensor
self.inference_ops.flash_decoding_attention(
output_tensor,
attn_metadata.query_states,
attn_metadata.k_cache,
attn_metadata.v_cache,
attn_metadata.sequence_lengths,
attn_metadata.block_tables,
attn_metadata.block_size,
attn_metadata.kv_seq_len,
fd_inter_tensor.mid_output,
fd_inter_tensor.exp_sums,
fd_inter_tensor.max_logits,
attn_metadata.alibi_slopes,
attn_metadata.sm_scale,
)
return output_tensor


class TritonAttentionBackend(AttentionBackend):
"""
Attention backend used when use_cuda_kernel is False. It uses pure Triton ops for prefilling and decoding.
"""

def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
return context_attention_unpadded(
q=attn_metadata.query_states,
k=attn_metadata.key_states,
v=attn_metadata.value_states,
k_cache=attn_metadata.k_cache,
v_cache=attn_metadata.v_cache,
context_lengths=attn_metadata.sequence_lengths,
block_tables=attn_metadata.block_tables,
block_size=attn_metadata.block_size,
output=attn_metadata.output_tensor,
alibi_slopes=attn_metadata.alibi_slopes,
max_seq_len=attn_metadata.kv_seq_len,
sm_scale=attn_metadata.sm_scale,
)

def decode(self, attn_metadata: AttentionMetaData, **kwargs):
fd_inter_tensor = kwargs.get("fd_inter_tensor", None)
return flash_decoding_attention(
q=attn_metadata.query_states,
k_cache=attn_metadata.k_cache,
v_cache=attn_metadata.v_cache,
kv_seq_len=attn_metadata.sequence_lengths,
block_tables=attn_metadata.block_tables,
block_size=attn_metadata.block_size,
max_seq_len_in_batch=attn_metadata.kv_seq_len,
output=attn_metadata.output_tensor,
mid_output=fd_inter_tensor.mid_output,
mid_output_lse=fd_inter_tensor.mid_output_lse,
alibi_slopes=attn_metadata.alibi_slopes,
sm_scale=attn_metadata.sm_scale,
kv_group_num=kwargs.get("num_key_value_groups", 1),
q_len=kwargs.get("q_len", 1),
)


def get_attention_backend(
model_shard_infer_config: ModelShardInferenceConfig,
) -> AttentionBackend:
"""
Get the attention backend based on the inference configuration. The CUDA-kernel-based backend is used for
the attention calculation only when:
1. CUDA kernels are enabled (use_cuda_kernel=True)
2. speculative decoding is not used (the CUDA kernels do not support it yet)
Otherwise, the pure Triton attention backend is used. If flash-attn can be used (installed and dtype is fp16 or bf16),
the CUDA backend prefills with flash-attn; otherwise its prefill falls back to the Triton kernel with the new k-cache layout.
"""
# Currently only triton kernels support speculative decoding
if model_shard_infer_config.use_spec_dec:
return TritonAttentionBackend()

if model_shard_infer_config.use_cuda_kernel:
return CudaAttentionBackend(model_shard_infer_config.use_flash_attn)

return TritonAttentionBackend()
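To make the selection rule above concrete, here is a small sketch (assuming flash-attn and the ColossalAI kernels are importable, since this module imports them at load time) that probes which backend `get_attention_backend` returns for a few configurations:

```python
import torch

from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.inference.modeling.backends.attention_backend import get_attention_backend


def pick_backend(use_cuda_kernel: bool, use_spec_dec: bool, use_flash_attn: bool) -> str:
    """Return the class name of the attention backend chosen for a given shard config."""
    config = ModelShardInferenceConfig(
        dtype=torch.float16,
        use_cuda_kernel=use_cuda_kernel,
        use_spec_dec=use_spec_dec,
        use_flash_attn=use_flash_attn,
    )
    return type(get_attention_backend(config)).__name__


# Speculative decoding always falls back to the Triton backend.
print(pick_backend(use_cuda_kernel=True, use_spec_dec=True, use_flash_attn=True))    # TritonAttentionBackend

# CUDA kernels without speculative decoding select the CUDA backend; use_flash_attn only
# decides whether its prefill path uses flash-attn or the Triton kernel with the new k-cache layout.
print(pick_backend(use_cuda_kernel=True, use_spec_dec=False, use_flash_attn=False))  # CudaAttentionBackend

# Without CUDA kernels, the pure Triton backend is used.
print(pick_backend(use_cuda_kernel=False, use_spec_dec=False, use_flash_attn=False)) # TritonAttentionBackend
```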