forked from hpcaitech/ColossalAI
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[inference] add silu linear fusion for smoothquant llama mlp (hpcaite…
…ch#4853) * add silu linear * update skip condition * catch smoothquant cuda lib exception * prcocess exception for tests
- Loading branch information
Showing
5 changed files
with
273 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
#include <torch/extension.h> | ||
|
||
#include "linear.h" | ||
|
||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { | ||
m.def("linear_silu_a8_w8_bfp32_ofp32", &linear_silu_a8_w8_bfp32_ofp32, | ||
"Linear SiLU (INT8)"); | ||
} |
162 changes: 162 additions & 0 deletions
162
colossalai/kernel/cuda_native/csrc/smoothquant/linear.cu
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
// modified from https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/kernels/linear.cu | ||
|
||
#include "linear.h" | ||
#include <cutlass/core_io.h> | ||
#include <cutlass/cutlass.h> | ||
#include <cutlass/half.h> | ||
|
||
#include <cutlass/gemm/device/gemm.h> | ||
#include <cutlass/numeric_types.h> | ||
#include <cutlass/util/host_tensor.h> | ||
#include <cutlass/epilogue/thread/linear_combination_silu.h> | ||
#include <cstdint> | ||
#include <cuda.h> | ||
#include <cuda_runtime.h> | ||
#include <cuda_fp16.h> | ||
#include <iostream> | ||
#include <torch/torch.h> | ||
torch::Tensor linear_silu_a8_w8_bfp32_ofp32(torch::Tensor input, // INT8 | ||
torch::Tensor weight, // INT8 | ||
torch::Tensor bias, // FP32 | ||
float alpha, // FP32 | ||
float beta // FP32 | ||
) { | ||
auto M = input.size(0); | ||
auto N = weight.size(0); | ||
auto K = input.size(1); | ||
|
||
using ElementOutput = float; | ||
using ElementAccumulator = int32_t; | ||
using ElementComputeEpilogue = float; | ||
using ElementInputA = int8_t; // <- data type of elements in input matrix A | ||
using ElementInputB = int8_t; // <- data type of elements in input matrix B | ||
|
||
// The code section below describes matrix layout of input and output | ||
// matrices. Column Major for Matrix A, Row Major for Matrix B and Row Major | ||
// for Matrix C | ||
using LayoutInputA = cutlass::layout::RowMajor; | ||
using LayoutInputB = cutlass::layout::ColumnMajor; | ||
using LayoutOutput = cutlass::layout::RowMajor; | ||
|
||
#if CUDA_ARCH >= 800 | ||
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationSilu< | ||
ElementOutput, // <- data type of output matrix | ||
128 / cutlass::sizeof_bits< | ||
ElementOutput>::value, // <- this is the number of elements per | ||
// vectorized memory access. For half | ||
// precision, it's 8 elements. This | ||
// becomes the vector width of math | ||
// instructions in epilogue too | ||
ElementAccumulator, // <- data type of accumulator | ||
ElementComputeEpilogue // <- data type for alpha in linear combination | ||
// function | ||
>; | ||
using Gemm = cutlass::gemm::device::Gemm< | ||
int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, | ||
ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, | ||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, | ||
cutlass::gemm::GemmShape<256, 128, 64>, | ||
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, | ||
EpilogueOp, | ||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; | ||
#elif CUDA_ARCH >= 750 | ||
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationSilu< | ||
ElementOutput, // <- data type of output matrix | ||
128 / cutlass::sizeof_bits< | ||
ElementOutput>::value, // <- this is the number of elements per | ||
// vectorized memory access. For half | ||
// precision, it's 8 elements. This | ||
// becomes the vector width of math | ||
// instructions in epilogue too | ||
ElementAccumulator, // <- data type of accumulator | ||
ElementComputeEpilogue // <- data type for alpha in linear combination | ||
// function | ||
>; | ||
|
||
using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration< | ||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | ||
ElementInputA, ElementInputB, ElementOutput, ElementAccumulator>; | ||
using Gemm = cutlass::gemm::device::Gemm< | ||
int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, | ||
ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, | ||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | ||
DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape, | ||
DefaultGemmCfg::InstructionShape, | ||
EpilogueOp>; | ||
#elif CUDA_ARCH >= 700 | ||
#define USE_TORCH_SILU | ||
using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration< | ||
cutlass::arch::OpClassSimt, cutlass::arch::Sm70, | ||
ElementInputA, ElementInputB, ElementOutput, ElementAccumulator>; | ||
using Gemm = cutlass::gemm::device::Gemm< | ||
int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, | ||
ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, | ||
cutlass::arch::OpClassSimt, cutlass::arch::Sm70, | ||
DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape, | ||
DefaultGemmCfg::InstructionShape, | ||
cutlass::epilogue::thread::LinearCombination< | ||
ElementOutput, 1, ElementAccumulator, ElementComputeEpilogue>>; | ||
#else | ||
#error "Unsupported cuda arch" | ||
#endif | ||
|
||
auto input_size = cutlass::MatrixCoord(M, K); | ||
auto weight_size = cutlass::MatrixCoord(K, N); | ||
auto output_size = cutlass::MatrixCoord(M, N); | ||
|
||
auto device = input.device(); | ||
// use the broadcasted bias as the output | ||
auto out = bias.to(device).view({1, -1}).repeat({M, 1}); | ||
|
||
// constexpr int kSparse = Gemm::kSparse; | ||
// How many elements of A are covered per ElementE | ||
// constexpr int kElementsPerElementE = Gemm::kElementsPerElementE; | ||
// The size of individual meta data | ||
// constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits; | ||
cutlass::gemm::GemmCoord problem_size(M, N, K); | ||
|
||
cutlass::TensorRef<ElementInputA, LayoutInputA> input_ref( | ||
input.data_ptr<ElementInputA>(), LayoutInputA::packed(input_size)); | ||
cutlass::TensorRef<ElementInputB, LayoutInputB> weight_ref( | ||
weight.data_ptr<ElementInputB>(), LayoutInputB::packed(weight_size)); | ||
cutlass::TensorRef<ElementOutput, LayoutOutput> out_ref( | ||
out.data_ptr<ElementOutput>(), LayoutOutput::packed(output_size)); | ||
|
||
typename Gemm::Arguments arguments{ | ||
problem_size, // <- problem size of matrix multiplication | ||
input_ref, // <- reference to matrix A on device | ||
weight_ref, // <- reference to matrix B on device | ||
out_ref, // <- reference to matrix C on device | ||
out_ref, // <- reference to matrix D on device | ||
{alpha, beta}, 1}; | ||
Gemm gemm_op; | ||
|
||
// Using the arguments, query for extra workspace required for matrix | ||
// multiplication computation | ||
size_t workspace_size = Gemm::get_workspace_size(arguments); | ||
|
||
// Allocate workspace memory | ||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size); | ||
|
||
// Check the problem size is supported or not | ||
cutlass::Status status = gemm_op.can_implement(arguments); | ||
if (status != cutlass::Status::kSuccess) { | ||
throw std::runtime_error("cutlass cannot implement"); | ||
} | ||
|
||
// Initialize CUTLASS kernel with arguments and workspace pointer | ||
status = gemm_op.initialize(arguments, workspace.get()); | ||
if (status != cutlass::Status::kSuccess) { | ||
throw std::runtime_error("cutlass cannot initialize"); | ||
} | ||
|
||
status = gemm_op(); | ||
if (status != cutlass::Status::kSuccess) { | ||
throw std::runtime_error("cutlass cannot run"); | ||
} | ||
#ifdef USE_TORCH_SILU | ||
#undef USE_TORCH_SILU | ||
out = torch::silu(out); | ||
#endif | ||
return out; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
#include <torch/torch.h> | ||
#include <torch/types.h> | ||
|
||
#include <cstdint> | ||
#include <iostream> | ||
|
||
torch::Tensor linear_silu_a8_w8_bfp32_ofp32(torch::Tensor input, // INT8 | ||
torch::Tensor weight, // INT8 | ||
torch::Tensor bias, // FP32 | ||
float alpha, // FP32 | ||
float beta // FP32 | ||
); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import torch | ||
|
||
from .builder import Builder | ||
from .utils import append_nvcc_threads | ||
|
||
|
||
class SmoothquantBuilder(Builder): | ||
NAME = "cu_smoothquant" | ||
PREBUILT_IMPORT_PATH = "colossalai._C.cu_smoothquant" | ||
|
||
def __init__(self): | ||
super().__init__(name=SmoothquantBuilder.NAME, prebuilt_import_path=SmoothquantBuilder.PREBUILT_IMPORT_PATH) | ||
|
||
def include_dirs(self): | ||
ret = [self.csrc_abs_path("smoothquant"), self.get_cuda_home_include()] | ||
return ret | ||
|
||
def sources_files(self): | ||
ret = [ | ||
self.csrc_abs_path(fname) | ||
for fname in [ | ||
"smoothquant/binding.cpp", | ||
"smoothquant/linear.cu", | ||
] | ||
] | ||
return ret | ||
|
||
def cxx_flags(self): | ||
return ["-O3"] + self.version_dependent_macros | ||
|
||
def nvcc_flags(self): | ||
compute_capability = torch.cuda.get_device_capability() | ||
cuda_arch = compute_capability[0] * 100 + compute_capability[1] * 10 | ||
|
||
extra_cuda_flags = [ | ||
"-v", | ||
f"-DCUDA_ARCH={cuda_arch}", | ||
"-std=c++17", | ||
"-U__CUDA_NO_HALF_OPERATORS__", | ||
"-U__CUDA_NO_HALF_CONVERSIONS__", | ||
"-U__CUDA_NO_HALF2_OPERATORS__", | ||
"-DTHRUST_IGNORE_CUB_VERSION_CHECK", | ||
] | ||
|
||
ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags | ||
return append_nvcc_threads(ret) | ||
|
||
def builder(self): | ||
try: | ||
super().builder() | ||
except: | ||
warnings.warn("build smoothquant lib not successful") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import warnings | ||
|
||
import pytest | ||
import torch | ||
|
||
try: | ||
from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder | ||
|
||
smoothquant_cuda = SmoothquantBuilder().load() | ||
HAS_SMOOTHQUANT_CUDA = True | ||
except: | ||
warnings.warn("CUDA smoothquant linear is not installed") | ||
HAS_SMOOTHQUANT_CUDA = False | ||
|
||
|
||
@pytest.mark.skipif( | ||
not HAS_SMOOTHQUANT_CUDA, | ||
reason="smoothquant linear not installed properly", | ||
) | ||
def test_linear(): | ||
a = torch.randint(-127, 127, (128, 512), dtype=torch.int8, device="cuda") | ||
b = torch.randint(-127, 127, (512, 256), dtype=torch.int8, device="cuda") | ||
c = torch.rand(256, dtype=torch.float, device="cuda") | ||
|
||
alpha = 1 / 127 | ||
beta = 1.0 | ||
torch_out = torch.mm(a.to(torch.float) * alpha, b.to(torch.float)) + c | ||
|
||
silu = torch.nn.SiLU() | ||
torch_out = silu(torch_out) | ||
|
||
b = b.transpose(0, 1).contiguous() | ||
cuda_out = smoothquant_cuda.linear_silu_a8_w8_bfp32_ofp32(a, b, c, alpha, beta) | ||
|
||
assert torch.allclose(torch_out, cuda_out, rtol=1e-02, atol=1e-02) | ||
|
||
|
||
if __name__ == "__main__": | ||
test_linear() |