Pass inference model shard configs for module init
Signed-off-by: char-1ee <xingjianli59@gmail.com>
char-1ee committed Jun 7, 2024
1 parent eec77e5 commit 5f398fc
Showing 11 changed files with 238 additions and 136 deletions.
Binary file added colossalai/_C/.nfs0000000013155a3b0000021b
Binary file not shown.
21 changes: 11 additions & 10 deletions colossalai/inference/config.py
@@ -10,6 +10,7 @@
from transformers.generation import GenerationConfig

from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.utils import can_use_flash_attn2

GibiByte = 1024**3

@@ -312,13 +313,14 @@ def to_generation_config(self, model_config) -> GenerationConfig:
meta_config[type] = getattr(model_config, type)

return GenerationConfig.from_dict(meta_config)

def to_model_inference_config(self) -> "ModelInferenceConfig":
model_inference_config = ModelInferenceConfig(

def to_model_shard_inference_config(self) -> "ModelShardInferenceConfig":
use_flash_attn = can_use_flash_attn2(self.dtype)
model_inference_config = ModelShardInferenceConfig(
dtype=self.dtype,
use_cuda_kernel=self.use_cuda_kernel,
use_spec_dec=self.use_spec_dec,
use_cuda_graph=self.use_cuda_graph,
use_flash_attn=use_flash_attn,
)
return model_inference_config

@@ -374,21 +376,20 @@ def from_dict(cls, config_dict: Dict[str, Any]) -> "InferenceConfig":
inference_config = cls(**inference_config_args)
return inference_config


@dataclass
class ModelInferenceConfig():
class ModelShardInferenceConfig:
"""
Configurations used when initializing/sharding model for inference.
Configurations used during module initialization for inference modeling.
Args:
dtype (torch.dtype): The data type for weights and activations.
use_cuda_kernel (bool): Whether to use the CUDA kernel, which is faster but may occasionally lose some precision.
use_spec_dec (bool): Indicate whether to use speculative decoding.
use_flash_attn (bool): Indicate whether to use flash attention.
use_cuda_graph (bool): Whether to enforce CUDA graph execution. If False, we will disable CUDA graph and always execute the model in eager mode. If True, we will mix CUDA graph execution with eager execution in a hybrid fashion.
"""

dtype: torch.dtype = None
use_cuda_kernel: bool = False
use_spec_dec: bool = False
use_flash_attn: bool = False
use_cuda_graph: bool = False
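
For context, here is a minimal sketch (not part of the diff) of how the new `ModelShardInferenceConfig` is derived from an `InferenceConfig`; the field values are illustrative, and any constructor arguments not listed are assumed to keep their defaults.

import torch

from colossalai.inference.config import InferenceConfig

# Illustrative inference config; only the fields read by
# to_model_shard_inference_config() are spelled out here.
inference_config = InferenceConfig(
    dtype=torch.float16,
    use_cuda_kernel=True,
    use_spec_dec=False,
    use_cuda_graph=False,
)

# The engine derives the per-shard config once and reuses it for module init.
model_shard_infer_config = inference_config.to_model_shard_inference_config()
# use_flash_attn is True only if flash-attn is installed and dtype is fp16/bf16.
print(model_shard_infer_config.use_flash_attn)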

23 changes: 13 additions & 10 deletions colossalai/inference/core/engine.py
@@ -18,7 +18,7 @@
from colossalai.accelerator import get_accelerator
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig, InputMetaData
from colossalai.inference.config import InferenceConfig, InputMetaData, ModelShardInferenceConfig
from colossalai.inference.graph_runner import CUDAGraphRunner
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.sampler import search_tokens
@@ -72,8 +72,9 @@ def __init__(

self.verbose = verbose
self.logger = get_dist_logger(__name__)
self.model_shard_infer_config = inference_config.to_model_shard_inference_config()

self.init_model(model_or_path, model_policy)
self.init_model(model_or_path, model_policy, self.model_shard_infer_config)

self.generation_config = inference_config.to_generation_config(self.model_config)
self.generation_config_dict = self.generation_config.to_dict()
@@ -98,9 +99,7 @@ def __init__(

# Model and related attrs of speculative decoding will be set by `enable_spec_dec`
self.use_spec_dec = self.inference_config.use_spec_dec

# TODO: when use_spec_dec set to True, users should pass drafter_model configs into InferenceEngine
# We can add a SpecDecConfig class to store these configs.

self.drafter_model = None
self.drafter = None
self.use_glide = False
@@ -109,9 +108,10 @@
self._verify_args()

def init_model(
self,
model_or_path: Union[nn.Module, str],
self,
model_or_path: Union[nn.Module, str],
model_policy: Union[Policy, Type[Policy]] = None,
model_shard_infer_config: ModelShardInferenceConfig = None,
):
"""
Shard the model and/or load weights
@@ -120,6 +120,7 @@ def init_model(
model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format.
model_policy (Policy): the policy to replace the model.
model_inference_config: the configuration for modeling initialization when inference.
model_shard_infer_config (ModelShardInferenceConfig): the configuration for module initialization during inference.
"""

if isinstance(model_or_path, str):
@@ -133,7 +134,7 @@ def init_model(
model = _supported_models[arch].from_pretrained(model_or_path, trust_remote_code=True)
else:
# TODO(char-1ee): if the model is not supported, use transformers APIs to load and generate
raise ValueError(f"Model {arch} is not supported.")
raise ValueError(f"Model {arch} is not supported.")

except Exception as e:
self.logger.error(
@@ -176,6 +177,7 @@ def init_model(
self.model = self._shardformer(
model,
model_policy,
model_shard_infer_config,
None,
tp_group=tp_group,
)
@@ -296,6 +298,7 @@ def _shardformer(
self,
model: nn.Module,
model_policy: Policy,
model_shard_infer_config: ModelShardInferenceConfig = None,
stage_manager: PipelineStageManager = None,
tp_group: ProcessGroupMesh = None,
) -> nn.Module:
@@ -321,6 +324,7 @@ def _shardformer(
enable_flash_attention=False,
enable_jit_fused=False,
enable_sequence_parallelism=False,
extra_kwargs={"model_shard_infer_config": model_shard_infer_config},
)
shardformer = ShardFormer(shard_config=shardconfig)
shard_model, _ = shardformer.optimize(model, model_policy)
@@ -357,8 +361,7 @@ def enable_spec_dec(
engine.clear_spec_dec()
```
"""
self.logger.warning(f"Current method will be deprecated soon. To use speculative decoding, please set `use_spec_dec` in `InferenceConfig` instead.")


if drafter_model is None and self.drafter is None:
raise ValueError("Drafter not initialized. Please provide a Drafter Model")
if n_spec_tokens is not None:
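To make the wiring above concrete, here is a simplified sketch (based on the signatures shown in this diff; pipeline-stage handling and the other ShardConfig flags are omitted) of how the engine hands the shard config to model policies through `ShardConfig.extra_kwargs`:

import torch.nn as nn

from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.base_policy import Policy


def shard_for_inference(
    model: nn.Module,
    model_policy: Policy,
    model_shard_infer_config: ModelShardInferenceConfig,
    tp_group=None,
) -> nn.Module:
    # Mirrors InferenceEngine._shardformer: the per-shard inference config rides
    # along in ShardConfig.extra_kwargs so that policies can read it when they
    # build replacement modules.
    shardconfig = ShardConfig(
        tensor_parallel_process_group=tp_group,
        enable_tensor_parallelism=tp_group is not None,
        extra_kwargs={"model_shard_infer_config": model_shard_infer_config},
    )
    shard_model, _ = ShardFormer(shard_config=shardconfig).optimize(model, model_policy)
    return shard_model

A policy can then look up shard_config.extra_kwargs["model_shard_infer_config"] when deciding, for example, which attention backend its replacement modules should use.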
104 changes: 82 additions & 22 deletions colossalai/inference/modeling/backends/attention_backend.py
@@ -1,19 +1,12 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from flash_attn import flash_attn_varlen_func

import torch
from flash_attn import flash_attn_varlen_func

from colossalai.inference.config import InputMetaData
from colossalai.inference.utils import can_use_flash_attn2
from colossalai.logging import get_dist_logger
from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import (
context_attention_unpadded,
flash_decoding_attention,
)

logger = get_dist_logger(__name__)
inference_ops = InferenceOpsLoader().load()
from colossalai.kernel.triton import context_attention_unpadded, flash_decoding_attention


@dataclass
@@ -33,7 +26,6 @@ class AttentionMetaData:
output_tensor: torch.Tensor = None
use_spec_dec: bool = False
use_alibi_attn: bool = False
use_cuda_kernel: bool = False


class AttentionBackend(ABC):
@@ -46,7 +38,16 @@ def decode(self, attn_metadatas: AttentionMetaData, **kwargs):
raise NotImplementedError


class CudaAttentionBackend(AttentionBackend):
class FlashAttentionBackend(AttentionBackend):
"""
Attention backend when use_cuda_kernel is True and flash-attn is installed. It uses
`flash_attn_varlen_func` for prefilling and our cuda op `flash_decoding_attention` for decoding.
"""

def __init__(self):
super().__init__()
self.inference_ops = InferenceOpsLoader().load()

def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
token_nums = kwargs.get("token_nums", -1)

@@ -69,7 +70,55 @@ def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
def decode(self, attn_metadata: AttentionMetaData, **kwargs):
fd_inter_tensor = kwargs.get("fd_inter_tensor", None)
output_tensor = attn_metadata.output_tensor
inference_ops.flash_decoding_attention(
self.inference_ops.flash_decoding_attention(
output_tensor,
attn_metadata.query_states,
attn_metadata.k_cache,
attn_metadata.v_cache,
attn_metadata.sequence_lengths,
attn_metadata.block_tables,
attn_metadata.block_size,
attn_metadata.kv_seq_len,
fd_inter_tensor.mid_output,
fd_inter_tensor.exp_sums,
fd_inter_tensor.max_logits,
attn_metadata.alibi_slopes,
attn_metadata.sm_scale,
)
return output_tensor


class CudaAttentionBackend(AttentionBackend):
"""
Attention backend when use_cuda_kernel is True but flash-attn is not found. In that case, it uses the
Triton op `context_attention_unpadded` for prefilling and our cuda op `flash_decoding_attention` for decoding.
"""

def __init__(self):
super().__init__()
self.inference_ops = InferenceOpsLoader().load()

def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
return context_attention_unpadded(
q=attn_metadata.query_states,
k=attn_metadata.key_states,
v=attn_metadata.value_states,
k_cache=attn_metadata.k_cache,
v_cache=attn_metadata.v_cache,
context_lengths=attn_metadata.sequence_lengths,
block_tables=attn_metadata.block_tables,
block_size=attn_metadata.block_size,
output=attn_metadata.output_tensor,
alibi_slopes=attn_metadata.alibi_slopes,
max_seq_len=attn_metadata.kv_seq_len,
sm_scale=attn_metadata.sm_scale,
use_new_kcache_layout=True, # use new k cache layout for cuda kernels in this triton op
)

def decode(self, attn_metadata: AttentionMetaData, **kwargs):
fd_inter_tensor = kwargs.get("fd_inter_tensor", None)
output_tensor = attn_metadata.output_tensor
self.inference_ops.flash_decoding_attention(
output_tensor,
attn_metadata.query_states,
attn_metadata.k_cache,
@@ -88,6 +137,10 @@ def decode(self, attn_metadata: AttentionMetaData, **kwargs):


class TritonAttentionBackend(AttentionBackend):
"""
Attention backend when use_cuda_kernel is False. It uses pure Triton ops for prefilling and decoding.
"""

def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
return context_attention_unpadded(
q=attn_metadata.query_states,
@@ -102,7 +155,7 @@ def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
alibi_slopes=attn_metadata.alibi_slopes,
max_seq_len=attn_metadata.kv_seq_len,
sm_scale=attn_metadata.sm_scale,
use_new_kcache_layout=attn_metadata.use_cuda_kernel,
use_new_kcache_layout=False,
)

def decode(self, attn_metadata: AttentionMetaData, **kwargs):
@@ -126,17 +179,24 @@ def decode(self, attn_metadata: AttentionMetaData, **kwargs):


def get_attention_backend(
use_spec_dec: bool, use_cuda_kernel: bool, dtype: torch.dtype
model_shard_infer_config: ModelShardInferenceConfig,
) -> AttentionBackend:
"""
Get the attention backend based on the inference configurations. Only when:
Get the attention backend based on the inference configurations. The CUDA-kernel-based backend is used
for attention module computation only when:
1. using CUDA kernel (use_cuda_kernel=True)
2. can use flash attention (flash-attn installed and dtype is fp16 or bf16)
3. not using speculative decoding (currently the cuda kernel does not support speculative decoding)
will the CUDA-kernel-based backend be used for attention layer computations. Otherwise, use Triton attention backend.
Otherwise, use the Triton attention backend. If flash-attn is not installed while `use_cuda_kernel` is True,
the fallback backend uses Triton kernels with the new k cache layout.
"""
use_flash_attn = can_use_flash_attn2(dtype)
if use_cuda_kernel and use_flash_attn and not use_spec_dec:
return CudaAttentionBackend()
else:
# Currently only triton kernels support speculative decoding
if model_shard_infer_config.use_spec_dec:
return TritonAttentionBackend()

if model_shard_infer_config.use_cuda_kernel:
if model_shard_infer_config.use_flash_attn:
return FlashAttentionBackend()
return CudaAttentionBackend()

return TritonAttentionBackend()
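
A short usage sketch of the dispatch above (assuming a CUDA-capable environment, since the flash-attn and CUDA backends load their kernel ops on instantiation):

import torch

from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.inference.modeling.backends.attention_backend import get_attention_backend

# CUDA kernels enabled and flash-attn usable -> FlashAttentionBackend.
cfg = ModelShardInferenceConfig(dtype=torch.float16, use_cuda_kernel=True, use_flash_attn=True)
backend = get_attention_backend(cfg)

# CUDA kernels enabled but flash-attn unavailable -> CudaAttentionBackend
# (Triton prefill with the new k cache layout, CUDA decoding op).
cfg = ModelShardInferenceConfig(dtype=torch.float16, use_cuda_kernel=True, use_flash_attn=False)
backend = get_attention_backend(cfg)

# Speculative decoding short-circuits the first check -> TritonAttentionBackend,
# regardless of the other flags.
cfg = ModelShardInferenceConfig(dtype=torch.float16, use_cuda_kernel=True, use_spec_dec=True)
backend = get_attention_backend(cfg)

# An attention layer then calls backend.prefill(attn_metadata, token_nums=...) during
# context processing and backend.decode(attn_metadata, fd_inter_tensor=...) during
# generation, with AttentionMetaData filled from its KV cache and batch state.

Note that the flash-attn availability check now runs once, in to_model_shard_inference_config, rather than on every call to get_attention_backend.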