[inference] add silu linear fusion for smoothquant llama mlp (hpcaitech#4853)

* add silu linear

* update skip condition

* catch smoothquant cuda lib exception

* process exception for tests
Xu-Kai committed Oct 13, 2023
1 parent 574dd2c commit 920751a
Showing 5 changed files with 273 additions and 0 deletions.
8 changes: 8 additions & 0 deletions colossalai/kernel/cuda_native/csrc/smoothquant/binding.cpp
@@ -0,0 +1,8 @@
#include <torch/extension.h>

#include "linear.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("linear_silu_a8_w8_bfp32_ofp32", &linear_silu_a8_w8_bfp32_ofp32,
"Linear SiLU (INT8)");
}
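
The exported name encodes the operand dtypes: a8 = INT8 activations, w8 = INT8 weights, bfp32 = FP32 bias, ofp32 = FP32 output. A minimal call sketch, assuming the extension has been built and loaded as smoothquant_cuda (as in the test at the end of this commit); the tensor names are illustrative:

# x: int8 [M, K], w: int8 [N, K], bias: fp32 [N]
# returns SiLU(alpha * (x @ w.T) + beta * bias) as fp32 [M, N]
out = smoothquant_cuda.linear_silu_a8_w8_bfp32_ofp32(x, w, bias, alpha, beta)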
162 changes: 162 additions & 0 deletions colossalai/kernel/cuda_native/csrc/smoothquant/linear.cu
@@ -0,0 +1,162 @@
// modified from https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/kernels/linear.cu

#include "linear.h"
#include <cutlass/core_io.h>
#include <cutlass/cutlass.h>
#include <cutlass/half.h>

#include <cutlass/gemm/device/gemm.h>
#include <cutlass/numeric_types.h>
#include <cutlass/util/host_tensor.h>
#include <cutlass/epilogue/thread/linear_combination_silu.h>
#include <cstdint>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <iostream>
#include <torch/torch.h>

torch::Tensor linear_silu_a8_w8_bfp32_ofp32(torch::Tensor input,  // INT8
                                            torch::Tensor weight, // INT8
                                            torch::Tensor bias,   // FP32
                                            float alpha,          // FP32
                                            float beta            // FP32
) {
  auto M = input.size(0);
  auto N = weight.size(0);
  auto K = input.size(1);

  using ElementOutput = float;
  using ElementAccumulator = int32_t;
  using ElementComputeEpilogue = float;
  using ElementInputA = int8_t; // <- data type of elements in input matrix A
  using ElementInputB = int8_t; // <- data type of elements in input matrix B

  // The section below describes the matrix layout of the input and output
  // matrices: Row Major for Matrix A (input), Column Major for Matrix B
  // (weight) and Row Major for Matrix C (output)
  using LayoutInputA = cutlass::layout::RowMajor;
  using LayoutInputB = cutlass::layout::ColumnMajor;
  using LayoutOutput = cutlass::layout::RowMajor;
#if CUDA_ARCH >= 800
  using EpilogueOp = cutlass::epilogue::thread::LinearCombinationSilu<
      ElementOutput, // <- data type of output matrix
      128 / cutlass::sizeof_bits<
                ElementOutput>::value, // <- the number of elements per
                                       // vectorized memory access. For half
                                       // precision, it's 8 elements. This
                                       // becomes the vector width of math
                                       // instructions in the epilogue too
      ElementAccumulator,              // <- data type of accumulator
      ElementComputeEpilogue           // <- data type for alpha in the linear
                                       // combination function
      >;
  using Gemm = cutlass::gemm::device::Gemm<
      int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<256, 128, 64>,
      cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
#elif CUDA_ARCH >= 750
  using EpilogueOp = cutlass::epilogue::thread::LinearCombinationSilu<
      ElementOutput, // <- data type of output matrix
      128 / cutlass::sizeof_bits<
                ElementOutput>::value, // <- the number of elements per
                                       // vectorized memory access
      ElementAccumulator,              // <- data type of accumulator
      ElementComputeEpilogue           // <- data type for alpha in the linear
                                       // combination function
      >;

  using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration<
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, ElementInputA,
      ElementInputB, ElementOutput, ElementAccumulator>;
  using Gemm = cutlass::gemm::device::Gemm<
      int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
      DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape,
      DefaultGemmCfg::InstructionShape, EpilogueOp>;
#elif CUDA_ARCH >= 700
#define USE_TORCH_SILU
  using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration<
      cutlass::arch::OpClassSimt, cutlass::arch::Sm70, ElementInputA,
      ElementInputB, ElementOutput, ElementAccumulator>;
  using Gemm = cutlass::gemm::device::Gemm<
      int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
      cutlass::arch::OpClassSimt, cutlass::arch::Sm70,
      DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape,
      DefaultGemmCfg::InstructionShape,
      cutlass::epilogue::thread::LinearCombination<
          ElementOutput, 1, ElementAccumulator, ElementComputeEpilogue>>;
#else
#error "Unsupported CUDA arch"
#endif

  auto input_size = cutlass::MatrixCoord(M, K);
  auto weight_size = cutlass::MatrixCoord(K, N);
  auto output_size = cutlass::MatrixCoord(M, N);

  auto device = input.device();
  // use the broadcasted bias as the output, so the epilogue reads it as C
  auto out = bias.to(device).view({1, -1}).repeat({M, 1});

  cutlass::gemm::GemmCoord problem_size(M, N, K);

  cutlass::TensorRef<ElementInputA, LayoutInputA> input_ref(
      input.data_ptr<ElementInputA>(), LayoutInputA::packed(input_size));
  cutlass::TensorRef<ElementInputB, LayoutInputB> weight_ref(
      weight.data_ptr<ElementInputB>(), LayoutInputB::packed(weight_size));
  cutlass::TensorRef<ElementOutput, LayoutOutput> out_ref(
      out.data_ptr<ElementOutput>(), LayoutOutput::packed(output_size));

  typename Gemm::Arguments arguments{
      problem_size, // <- problem size of matrix multiplication
      input_ref,    // <- reference to matrix A on device
      weight_ref,   // <- reference to matrix B on device
      out_ref,      // <- reference to matrix C on device
      out_ref,      // <- reference to matrix D on device
      {alpha, beta}, 1};
  Gemm gemm_op;

  // Using the arguments, query for the extra workspace required for the
  // matrix multiplication computation
  size_t workspace_size = Gemm::get_workspace_size(arguments);

  // Allocate workspace memory
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  // Check whether the problem size is supported
  cutlass::Status status = gemm_op.can_implement(arguments);
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error("cutlass cannot implement");
  }

  // Initialize the CUTLASS kernel with the arguments and workspace pointer
  status = gemm_op.initialize(arguments, workspace.get());
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error("cutlass cannot initialize");
  }

  status = gemm_op();
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error("cutlass cannot run");
  }

#ifdef USE_TORCH_SILU
#undef USE_TORCH_SILU
  // SM70 path: SiLU is not fused into the epilogue, so apply it here
  out = torch::silu(out);
#endif
  return out;
}
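
For reference, a plain PyTorch sketch of what the fused kernel computes; the function name here is illustrative, not part of the extension:

import torch
import torch.nn.functional as F

def linear_silu_reference(x, w, b, alpha, beta):
    # x: int8 [M, K]; w: int8 [N, K]; b: fp32 [N]
    # alpha scales the GEMM accumulator, beta scales the bias
    acc = torch.mm(x.to(torch.float32), w.to(torch.float32).t())
    return F.silu(alpha * acc + beta * b)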
12 changes: 12 additions & 0 deletions colossalai/kernel/cuda_native/csrc/smoothquant/linear.h
@@ -0,0 +1,12 @@
#pragma once

#include <torch/torch.h>
#include <torch/types.h>

#include <cstdint>
#include <iostream>

torch::Tensor linear_silu_a8_w8_bfp32_ofp32(torch::Tensor input,  // INT8
                                            torch::Tensor weight, // INT8
                                            torch::Tensor bias,   // FP32
                                            float alpha,          // FP32
                                            float beta            // FP32
);
52 changes: 52 additions & 0 deletions op_builder/smoothquant.py
@@ -0,0 +1,52 @@
import warnings

import torch

from .builder import Builder
from .utils import append_nvcc_threads


class SmoothquantBuilder(Builder):
    NAME = "cu_smoothquant"
    PREBUILT_IMPORT_PATH = "colossalai._C.cu_smoothquant"

    def __init__(self):
        super().__init__(name=SmoothquantBuilder.NAME, prebuilt_import_path=SmoothquantBuilder.PREBUILT_IMPORT_PATH)

    def include_dirs(self):
        ret = [self.csrc_abs_path("smoothquant"), self.get_cuda_home_include()]
        return ret

    def sources_files(self):
        ret = [
            self.csrc_abs_path(fname)
            for fname in [
                "smoothquant/binding.cpp",
                "smoothquant/linear.cu",
            ]
        ]
        return ret

    def cxx_flags(self):
        return ["-O3"] + self.version_dependent_macros

    def nvcc_flags(self):
        compute_capability = torch.cuda.get_device_capability()
        cuda_arch = compute_capability[0] * 100 + compute_capability[1] * 10
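        # Example: an A100 (compute capability 8.0) yields 8 * 100 + 0 * 10 = 800,
        # selecting the Sm80 path in linear.cu; a T4 (7.5) yields 750, and a
        # V100 (7.0) yields 700, which uses the SIMT + torch::silu fallback.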

        extra_cuda_flags = [
            "-v",
            f"-DCUDA_ARCH={cuda_arch}",
            "-std=c++17",
            "-U__CUDA_NO_HALF_OPERATORS__",
            "-U__CUDA_NO_HALF_CONVERSIONS__",
            "-U__CUDA_NO_HALF2_OPERATORS__",
            "-DTHRUST_IGNORE_CUB_VERSION_CHECK",
        ]

        ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags
        return append_nvcc_threads(ret)

    def builder(self):
        try:
            super().builder()
        except Exception:
            warnings.warn("Failed to build the smoothquant CUDA library")
39 changes: 39 additions & 0 deletions tests/test_smoothquant/test_smoothquant_linear.py
@@ -0,0 +1,39 @@
import warnings

import pytest
import torch

try:
    from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder

    smoothquant_cuda = SmoothquantBuilder().load()
    HAS_SMOOTHQUANT_CUDA = True
except Exception:
    warnings.warn("CUDA smoothquant linear is not installed")
    HAS_SMOOTHQUANT_CUDA = False


@pytest.mark.skipif(
    not HAS_SMOOTHQUANT_CUDA,
    reason="smoothquant linear not installed properly",
)
def test_linear():
    a = torch.randint(-127, 127, (128, 512), dtype=torch.int8, device="cuda")
    b = torch.randint(-127, 127, (512, 256), dtype=torch.int8, device="cuda")
    c = torch.rand(256, dtype=torch.float, device="cuda")

    alpha = 1 / 127
    beta = 1.0
    torch_out = torch.mm(a.to(torch.float) * alpha, b.to(torch.float)) + c

    silu = torch.nn.SiLU()
    torch_out = silu(torch_out)

    # the kernel expects the weight as an [N, K] tensor (a column-major K x N view)
    b = b.transpose(0, 1).contiguous()
    cuda_out = smoothquant_cuda.linear_silu_a8_w8_bfp32_ofp32(a, b, c, alpha, beta)

    assert torch.allclose(torch_out, cuda_out, rtol=1e-02, atol=1e-02)


if __name__ == "__main__":
    test_linear()
