Commit
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Jun 3, 2024
1 parent d34bec9 commit 1ed7f7f
Showing 7 changed files with 66 additions and 82 deletions.
9 changes: 5 additions & 4 deletions colossalai/inference/config.py
@@ -290,7 +290,7 @@ def to_generation_config(self, model_config) -> GenerationConfig:
meta_config[type] = getattr(model_config, type)

return GenerationConfig.from_dict(meta_config)

def to_model_inference_config(self) -> "ModelInferenceConfig":
model_inference_config = ModelInferenceConfig(
dtype=self.dtype,
@@ -352,21 +352,22 @@ def from_dict(cls, config_dict: Dict[str, Any]) -> "InferenceConfig":
inference_config = cls(**inference_config_args)
return inference_config


@dataclass
class ModelInferenceConfig():
class ModelInferenceConfig:
"""
Configurations used when initializing/sharding model for inference.
Args:
dtype (torch.dtype): The data type for weights and activations.
use_cuda_kernel (bool): Whether to use CUDA kernels; faster, but may occasionally lose some precision.
use_spec_dec (bool): Indicate whether to use speculative decoding.
use_flash_attn (bool): Indicate whether to use flash attention.
use_cuda_graph (bool): Whether to enforce CUDA graph execution. If False, CUDA graph is disabled and the model always executes in eager mode. If True, CUDA graph is used in hybrid with eager execution.
"""

dtype: torch.dtype = None
use_cuda_kernel: bool = False
use_spec_dec: bool = False
use_flash_attn: bool = False
use_cuda_graph: bool = False
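
Editor's note: as a quick reference for the dataclass in this hunk, here is a minimal, hypothetical instantiation sketch. The field names are taken from the diff above; the values and the import path `colossalai.inference.config` are assumptions, not recommended settings.

```python
# Illustrative sketch only: construct the ModelInferenceConfig shown in this hunk.
# Field values below are arbitrary examples, not recommended defaults.
import torch

from colossalai.inference.config import ModelInferenceConfig  # assumed export location

model_inference_config = ModelInferenceConfig(
    dtype=torch.float16,    # run weights/activations in half precision
    use_cuda_kernel=True,   # prefer fused CUDA kernels over Triton fallbacks
    use_spec_dec=False,     # leave speculative decoding off
    use_flash_attn=True,    # only effective for fp16/bf16 (see can_use_flash_attn2)
    use_cuda_graph=False,   # stay in eager mode
)
```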

14 changes: 8 additions & 6 deletions colossalai/inference/core/engine.py
@@ -98,7 +98,7 @@ def __init__(

# Model and related attrs of speculative decoding will be set by `enable_spec_dec`
self.use_spec_dec = self.inference_config.use_spec_dec

# TODO: when use_spec_dec set to True, users should pass drafter_model configs into InferenceEngine
# We can add a SpecDecConfig class to store these configs.
self.drafter_model = None
@@ -109,8 +109,8 @@ def __init__(
self._verify_args()

def init_model(
self,
model_or_path: Union[nn.Module, str],
self,
model_or_path: Union[nn.Module, str],
model_policy: Union[Policy, Type[Policy]] = None,
):
"""
@@ -133,7 +133,7 @@ def init_model(
model = _supported_models[arch].from_pretrained(model_or_path, trust_remote_code=True)
else:
# TODO(char-1ee): if the model not supported, use transformers APIs to load and generate
raise ValueError(f"Model {arch} is not supported.")
raise ValueError(f"Model {arch} is not supported.")

except Exception as e:
self.logger.error(
@@ -357,8 +357,10 @@ def enable_spec_dec(
engine.clear_spec_dec()
```
"""
self.logger.warning(f"Current method will be deprecated soon. To use speculative decoding, please set `use_spec_dec` in `InferenceConfig` instead.")

self.logger.warning(
f"Current method will be deprecated soon. To use speculative decoding, please set `use_spec_dec` in `InferenceConfig` instead."
)

if drafter_model is None and self.drafter is None:
raise ValueError("Drafter not initialized. Please provide a Drafter Model")
if n_spec_tokens is not None:
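
Editor's note: the deprecation warning added above points to configuring speculative decoding through `InferenceConfig` rather than calling `enable_spec_dec()`. A hedged sketch of that path, with every other field left at its default (an assumption):

```python
# Sketch of the configuration path recommended by the warning above; not a verified
# end-to-end example. Only use_spec_dec is set, all other fields keep their defaults.
from colossalai.inference.config import InferenceConfig

inference_config = InferenceConfig(use_spec_dec=True)
# The config is then passed to InferenceEngine at construction time, where it is read
# into self.use_spec_dec, as the __init__ hunk above shows.
```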
15 changes: 5 additions & 10 deletions colossalai/inference/modeling/backends/attention_backend.py
@@ -1,16 +1,13 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from flash_attn import flash_attn_varlen_func

import torch
from flash_attn import flash_attn_varlen_func

from colossalai.inference.config import InputMetaData
from colossalai.inference.utils import can_use_flash_attn2
from colossalai.logging import get_dist_logger
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import (
context_attention_unpadded,
flash_decoding_attention,
)
from colossalai.kernel.triton import context_attention_unpadded, flash_decoding_attention
from colossalai.logging import get_dist_logger

logger = get_dist_logger(__name__)
inference_ops = InferenceOpsLoader().load()
@@ -125,9 +122,7 @@ def decode(self, attn_metadata: AttentionMetaData, **kwargs):
)


def get_attention_backend(
use_spec_dec: bool, use_cuda_kernel: bool, dtype: torch.dtype
) -> AttentionBackend:
def get_attention_backend(use_spec_dec: bool, use_cuda_kernel: bool, dtype: torch.dtype) -> AttentionBackend:
"""
Get the attention backend based on the inference configurations. Only when:
1. using CUDA kernel (use_cuda_kernel=True)
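
Editor's note: a hedged usage sketch for `get_attention_backend`. The selection criteria in the comments paraphrase the truncated docstring above, so treat them as assumptions rather than documented behaviour.

```python
# Illustrative backend selection; argument values are arbitrary examples.
import torch

from colossalai.inference.modeling.backends.attention_backend import get_attention_backend

backend = get_attention_backend(
    use_spec_dec=False,    # not running the speculative-decoding verifier
    use_cuda_kernel=True,  # fused CUDA kernels requested
    dtype=torch.float16,   # fp16/bf16 is what flash-attn 2 accepts
)
# backend.prefill(attn_metadata, ...) and backend.decode(attn_metadata, ...) are then
# called with an AttentionMetaData instance, as the model forward passes below do.
```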
15 changes: 5 additions & 10 deletions colossalai/inference/modeling/backends/pre_attention_backend.py
@@ -1,15 +1,12 @@
from abc import ABC, abstractmethod

import torch

from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData
from colossalai.inference.utils import can_use_flash_attn2
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData
from colossalai.kernel.triton import copy_k_to_blocked_cache, decoding_fused_rotary_embedding, rotary_embedding
from colossalai.logging import get_dist_logger
from colossalai.kernel.triton import (
copy_k_to_blocked_cache,
decoding_fused_rotary_embedding,
rotary_embedding,
)

logger = get_dist_logger(__name__)
inference_ops = InferenceOpsLoader().load()
@@ -94,7 +91,7 @@ def decode(self, attn_metadata: AttentionMetaData, **kwargs):
attn_metadata.block_tables,
attn_metadata.sequence_lengths,
)
else: # else if using speculative decoding
else: # else if using speculative decoding
if not attn_metadata.use_alibi_attn:
rotary_embedding(
attn_metadata.query_states,
@@ -118,9 +115,7 @@ def decode(self, attn_metadata: AttentionMetaData, **kwargs):
)


def get_pre_attention_backend(
use_spec_dec: bool, use_cuda_kernel: bool, dtype: torch.dtype
) -> PreAttentionBackend:
def get_pre_attention_backend(use_spec_dec: bool, use_cuda_kernel: bool, dtype: torch.dtype) -> PreAttentionBackend:
"""
Get the backend for pre-attention computations, including positional encoding like RoPE and KV cache initialization.
"""
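
Editor's note: for symmetry, a brief sketch of the pre-attention counterpart, which handles rotary embeddings and KV-cache population before the attention backend runs. The keyword arguments mirror the model files later in this commit and are assumptions here.

```python
# Sketch: obtain the pre-attention backend with the same switches as above.
import torch

from colossalai.inference.modeling.backends.pre_attention_backend import get_pre_attention_backend

pre_attention_backend = get_pre_attention_backend(
    use_spec_dec=False, use_cuda_kernel=True, dtype=torch.float16
)
# A decoding step would then look roughly like:
#   pre_attention_backend.decode(attn_metadata, cos=cos, sin=sin, q_len=1)
```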
39 changes: 17 additions & 22 deletions colossalai/inference/modeling/models/nopadding_baichuan.py
@@ -1,31 +1,22 @@
# This code is adapted from huggingface baichuan model: https://huggingface.co/baichuan-inc/Baichuan2-13B-Base/blob/main/modeling_baichuan.py
import itertools
import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.distributed import ProcessGroup

from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.utils import get_alibi_slopes
from colossalai.inference.modeling.models.nopadding_llama import NopadLlamaMLP
from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData, get_attention_backend
from colossalai.inference.modeling.backends.pre_attention_backend import get_pre_attention_backend
from colossalai.inference.modeling.models.nopadding_llama import NopadLlamaMLP
from colossalai.inference.utils import get_alibi_slopes
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import (
context_attention_unpadded,
copy_k_to_blocked_cache,
decoding_fused_rotary_embedding,
flash_decoding_attention,
rms_layernorm,
rotary_embedding,
)
from colossalai.kernel.triton import rms_layernorm
from colossalai.logging import get_dist_logger
from colossalai.shardformer.layer.parallel_module import ParallelModule
from colossalai.tensor.d_tensor import Layout, distribute_tensor, is_distributed_tensor


inference_ops = InferenceOpsLoader().load()
logger = get_dist_logger(__name__)

@@ -233,7 +224,7 @@ def forward(
)

block_size = k_cache.size(-2)

attn_metadata = AttentionMetaData(
query_states=query_states,
key_states=key_states,
@@ -252,10 +243,14 @@ def forward(
use_alibi_attn=self.use_alibi_attn,
use_cuda_kernel=use_cuda_kernel,
)

attention_backend = get_attention_backend(use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=query_states.dtype)
pre_attention_backend = get_pre_attention_backend(use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=query_states.dtype)


attention_backend = get_attention_backend(
use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=query_states.dtype
)
pre_attention_backend = get_pre_attention_backend(
use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=query_states.dtype
)

if is_prompts: # prefilling stage
pre_attention_backend.prefill(
attn_metadata,
@@ -266,19 +261,19 @@ def decode(self, attn_metadata: AttentionMetaData, **kwargs):
attn_output = attention_backend.prefill(
attn_metadata,
token_nums=token_nums,
)
else: # decoding stage
)
else: # decoding stage
q_len = tokens_to_verify + 1 if is_verifier else 1

pre_attention_backend.decode(
attn_metadata,
cos=cos_sin[0],
sin=cos_sin[1],
q_len=q_len,
)
attn_output = attention_backend.decode(
attn_metadata,
fd_inter_tensor=fd_inter_tensor,
attn_metadata,
fd_inter_tensor=fd_inter_tensor,
q_len=q_len,
)

39 changes: 17 additions & 22 deletions colossalai/inference/modeling/models/nopadding_llama.py
@@ -18,19 +18,11 @@

from colossalai.inference.config import InputMetaData
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.modeling.backends.attention_backend import get_attention_backend, AttentionMetaData
from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData, get_attention_backend
from colossalai.inference.modeling.backends.pre_attention_backend import get_pre_attention_backend
from colossalai.inference.utils import can_use_flash_attn2
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import (
context_attention_unpadded,
copy_k_to_blocked_cache,
decoding_fused_rotary_embedding,
flash_decoding_attention,
get_xine_cache,
rms_layernorm,
rotary_embedding,
)
from colossalai.kernel.triton import get_xine_cache, rms_layernorm
from colossalai.logging import get_dist_logger
from colossalai.shardformer.layer.parallel_module import ParallelModule
from colossalai.tensor.d_tensor import distribute_tensor, is_distributed_tensor
@@ -527,7 +519,7 @@ def forward(
)

block_size = k_cache.size(-2)

attn_metadata = AttentionMetaData(
query_states=query_states,
key_states=key_states,
@@ -546,10 +538,14 @@ def forward(
use_alibi_attn=False,
use_cuda_kernel=use_cuda_kernel,
)

attention_backend = get_attention_backend(use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=query_states.dtype)
pre_attention_backend = get_pre_attention_backend(use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=query_states.dtype)


attention_backend = get_attention_backend(
use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=query_states.dtype
)
pre_attention_backend = get_pre_attention_backend(
use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=query_states.dtype
)

if is_prompts: # prefilling stage
pre_attention_backend.prefill(
attn_metadata,
@@ -560,22 +556,22 @@ def decode(self, attn_metadata: AttentionMetaData, **kwargs):
attn_output = attention_backend.prefill(
attn_metadata,
token_nums=token_nums,
)
else: # decoding stage
)
else: # decoding stage
q_len = tokens_to_verify + 1 if is_verifier else 1

pre_attention_backend.decode(
attn_metadata,
cos=cos_sin[0],
sin=cos_sin[1],
q_len=q_len,
)
attn_output = attention_backend.decode(
attn_metadata,
fd_inter_tensor=fd_inter_tensor,
attn_metadata,
fd_inter_tensor=fd_inter_tensor,
num_key_value_groups=self.num_key_value_groups,
q_len=q_len,
)
)

attn_output = attn_output.view(-1, self.hidden_size)
attn_output = self.o_proj(attn_output)
@@ -633,4 +629,3 @@ def _load_from_state_dict(

def extra_repr(self) -> str:
return f"qkv_weight_proj MergedLinear1D_Col: in_features={self.qkv_weight.shape[1]}x3, out_features={self.qkv_weight.shape[2]}, bias=False"

17 changes: 9 additions & 8 deletions colossalai/inference/utils.py
@@ -1,17 +1,17 @@
"""
Utils for model inference
"""
import math
import os
import re
import math
from pathlib import Path
from typing import Optional, Tuple

import torch
from torch import nn

from colossalai.testing import free_port
from colossalai.logging import get_dist_logger
from colossalai.testing import free_port

logger = get_dist_logger(__name__)

@@ -122,11 +122,11 @@ def find_available_ports(num: int):
def get_alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor:
"""
Alibi slopes calculation adapted from https://github.com/huggingface/transformers/blob/v4.36.0/src/transformers/models/bloom/modeling_bloom.py#L57
Args:
num_heads (int): The number of attention heads.
device (torch.device): The device to use.
Returns:
torch.Tensor: The Alibi slopes.
"""
@@ -142,20 +142,21 @@ def get_alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor:
extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32, device=device)
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
return slopes


def can_use_flash_attn2(dtype: torch.dtype) -> bool:
"""
Check flash attention2 availability.
"""
if dtype not in (torch.float16, torch.bfloat16):
logger.warning(f"Flash attn2 currently only supports float16 and bfloat16.")
return False

try:
from flash_attn import __version__

logger.info(f"flash_attn2 version {__version__}.")
return True
except ImportError:
logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")
return False
return False
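
Editor's note: a small sketch exercising the two utilities touched in this hunk; the head count and device are arbitrary, and CUDA availability is assumed.

```python
# Sketch: compute ALiBi slopes and probe flash-attn 2 availability.
import torch

from colossalai.inference.utils import can_use_flash_attn2, get_alibi_slopes

slopes = get_alibi_slopes(num_heads=32, device=torch.device("cuda:0"))
print(slopes.shape)  # expected: torch.Size([32]), one slope per attention head

if can_use_flash_attn2(torch.float16):
    print("flash-attn 2 available for fp16")
else:
    print("falling back to the Triton flash attention path")
```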
