diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml new file mode 100644 index 00000000..33096a85 --- /dev/null +++ b/.github/workflows/build.yaml @@ -0,0 +1,95 @@ +name: Build AutoAWQ Wheels with CUDA + +on: + push: + tags: + - "v*" + +jobs: + release: + # Retrieve tag and create release + name: Create Release + runs-on: ubuntu-latest + outputs: + upload_url: ${{ steps.create_release.outputs.upload_url }} + steps: + - name: Checkout + uses: actions/checkout@v3 + + - name: Extract branch info + shell: bash + run: | + echo "release_tag=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV + + - name: Create Release + id: create_release + uses: "actions/github-script@v6" + env: + RELEASE_TAG: ${{ env.release_tag }} + with: + github-token: "${{ secrets.GITHUB_TOKEN }}" + script: | + const script = require('.github/workflows/scripts/github_create_release.js') + await script(github, context, core) + + build_wheels: + name: Build AWQ + runs-on: ${{ matrix.os }} + needs: release + + strategy: + matrix: + os: [ubuntu-20.04, windows-latest] + pyver: ["3.8", "3.9", "3.10", "3.11"] + cuda: ["11.8"] + defaults: + run: + shell: pwsh + env: + CUDA_VERSION: ${{ matrix.cuda }} + + steps: + - uses: actions/checkout@v3 + + - uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.pyver }} + + - name: Setup Miniconda + uses: conda-incubator/setup-miniconda@v2.2.0 + with: + activate-environment: "build" + python-version: ${{ matrix.pyver }} + mamba-version: "*" + use-mamba: false + channels: conda-forge,defaults + channel-priority: true + add-pip-as-python-dependency: true + auto-activate-base: false + + - name: Install Dependencies + run: | + conda install cuda-toolkit -c "nvidia/label/cuda-${env:CUDA_VERSION}.0" + conda install pytorch "pytorch-cuda=${env:CUDA_VERSION}" -c pytorch -c nvidia + python -m pip install --upgrade build setuptools wheel ninja + + # Environment variables + Add-Content $env:GITHUB_ENV "CUDA_PATH=$env:CONDA_PREFIX" + Add-Content $env:GITHUB_ENV "CUDA_HOME=$env:CONDA_PREFIX" + if ($IsLinux) {$env:LD_LIBRARY_PATH = $env:CONDA_PREFIX + '/lib:' + $env:LD_LIBRARY_PATH} + + # Print version information + python --version + python -c "import torch; print('PyTorch:', torch.__version__)" + python -c "import torch; print('CUDA:', torch.version.cuda)" + python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)" + + - name: Build Wheel + run: | + python setup.py sdist bdist_wheel + + - name: Upload Assets + uses: shogo82148/actions-upload-release-asset@v1 + with: + upload_url: ${{ needs.release.outputs.upload_url }} + asset_path: ./dist/*.whl \ No newline at end of file diff --git a/.github/workflows/scripts/github_create_release.js b/.github/workflows/scripts/github_create_release.js new file mode 100644 index 00000000..fe26188b --- /dev/null +++ b/.github/workflows/scripts/github_create_release.js @@ -0,0 +1,17 @@ +module.exports = async (github, context, core) => { + try { + const response = await github.rest.repos.createRelease({ + draft: false, + generate_release_notes: true, + name: process.env.RELEASE_TAG, + owner: context.repo.owner, + prerelease: false, + repo: context.repo.repo, + tag_name: process.env.RELEASE_TAG, + }); + + core.setOutput('upload_url', response.data.upload_url); + } catch (error) { + core.setFailed(error.message); + } +} diff --git a/README.md b/README.md index 8396b1f0..e77e2068 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ AutoAWQ is a package that implements the Activation-aware Weight Quantization (A Roadmap: -- [ ] 
Publish pip package
+- [x] Publish pip package
 - [ ] Refactor quantization code
 - [ ] Support more models
 - [ ] Optimize the speed of models
@@ -13,8 +13,20 @@ Roadmap:
 
 Requirements:
 - Compute Capability 8.0 (sm80). Ampere and later architectures are supported.
+- CUDA Toolkit 11.8 or later.
 
-Clone this repository and install with pip.
+Install:
+- Install from PyPI with pip:
+
+```
+pip install awq
+```
+
+### Build from source
+
+Build AutoAWQ from scratch:
 
 ```
 git clone https://github.com/casper-hansen/AutoAWQ
@@ -22,6 +34,8 @@ cd AutoAWQ
 pip install -e .
 ```
 
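+
+A minimal sketch for sanity-checking the install, using the top-level import exposed by `awq/__init__.py` (the assert message and print are illustrative):
+
+```python
+# Smoke test: the import should succeed and a CUDA device should be visible,
+# since the AWQ kernels target Compute Capability 8.0 (sm80) or newer GPUs.
+import torch
+from awq import AutoAWQForCausalLM
+
+assert torch.cuda.is_available(), "AutoAWQ requires a CUDA-capable GPU"
+print("AutoAWQ import OK, running on:", torch.cuda.get_device_name(0))
+```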
+ ## Supported models The detailed support list: @@ -36,6 +50,7 @@ The detailed support list: | OPT | 125m/1.3B/2.7B/6.7B/13B/30B | | Bloom | 560m/3B/7B/ | | LLaVA-v0 | 13B | +| GPTJ | 6.7B | ## Usage @@ -44,8 +59,8 @@ Below, you will find examples for how to easily quantize a model and run inferen ### Quantization ```python +from awq import AutoAWQForCausalLM from transformers import AutoTokenizer -from awq.models.auto import AutoAWQForCausalLM model_path = 'lmsys/vicuna-7b-v1.5' quant_path = 'vicuna-7b-v1.5-awq' @@ -68,8 +83,8 @@ tokenizer.save_pretrained(quant_path) Run inference on a quantized model from Huggingface: ```python +from awq import AutoAWQForCausalLM from transformers import AutoTokenizer -from awq.models.auto import AutoAWQForCausalLM quant_path = "casperhansen/vicuna-7b-v1.5-awq" quant_file = "awq_model_w4_g128.pt" @@ -101,8 +116,11 @@ Benchmark speeds may vary from server to server and that it also depends on your | MPT-30B | A6000 | OOM | 31.57 | -- | | Falcon-7B | A6000 | 39.44 | 27.34 | 1.44x | +
-For example, here is the difference between a fast and slow CPU on MPT-7B:
+### Detailed benchmark (CPU vs. GPU)
+
+Here is the difference between a fast and slow CPU on MPT-7B:
 
 RTX 4090 + Intel i9 13900K (2 different VMs):
 - CUDA 12.0, Driver 525.125.06: 134 tokens/s (7.46 ms/token)
@@ -113,6 +131,8 @@ RTX 4090 + AMD EPYC 7-Series (3 different VMs):
 - CUDA 12.2, Driver 535.54.03: 56 tokens/s (17.71 ms/token)
 - CUDA 12.0, Driver 525.125.06: 55 tokens/s (18.15 ms/token)
+
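+
+The throughput numbers above are tokens generated per second of wall-clock time. A minimal sketch of how such a figure can be measured with the generation API shown earlier (the prompt, `max_new_tokens`, and the lack of warmup are illustrative simplifications):
+
+```python
+import time
+from awq import AutoAWQForCausalLM
+from transformers import AutoTokenizer
+
+quant_path = "casperhansen/vicuna-7b-v1.5-awq"
+quant_file = "awq_model_w4_g128.pt"
+
+# Load the quantized model with fused layers, as in the inference example above
+model = AutoAWQForCausalLM.from_quantized(quant_path, quant_file, fuse_layers=True)
+tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)
+
+tokens = tokenizer("A chat between a user and an assistant.", return_tensors="pt").input_ids.cuda()
+
+# Time a single generation pass and report tokens/s and ms/token
+start = time.time()
+output = model.generate(tokens, max_new_tokens=256)
+elapsed = time.time() - start
+
+new_tokens = output.shape[1] - tokens.shape[1]
+print(f"{new_tokens / elapsed:.1f} tokens/s ({1000 * elapsed / new_tokens:.2f} ms/token)")
+```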
+ ## Reference If you find AWQ useful or relevant to your research, you can cite their [paper](https://arxiv.org/abs/2306.00978): diff --git a/awq/__init__.py b/awq/__init__.py new file mode 100644 index 00000000..0ffd9d73 --- /dev/null +++ b/awq/__init__.py @@ -0,0 +1 @@ +from awq.models.auto import AutoAWQForCausalLM \ No newline at end of file diff --git a/awq/entry.py b/awq/entry.py index 4a7e135e..3886cdfb 100644 --- a/awq/entry.py +++ b/awq/entry.py @@ -4,7 +4,7 @@ import argparse from lm_eval import evaluator from transformers import AutoTokenizer -from awq.models.auto import AutoAWQForCausalLM +from awq import AutoAWQForCausalLM from awq.quantize.auto_clip import apply_clip from awq.quantize.auto_scale import apply_scale from awq.utils.lm_eval_adaptor import LMEvalAdaptor @@ -152,7 +152,7 @@ def _warmup(device:str): parser.add_argument('--tasks', type=str, default='wikitext', help='Tasks to evaluate. ' 'Separate tasks by comma for multiple tasks.' 'https://github.com/EleutherAI/lm-evaluation-harness/blob/master/docs/task_table.md') - parser.add_argument("--task_use_pretrained", default=False, action=argparse.BooleanOptionalAction, + parser.add_argument("--task_use_pretrained", default=False, action='store_true', help="Pass '--task_use_pretrained' to use a pretrained model running FP16") parser.add_argument('--task_batch_size', type=int, default=1) parser.add_argument('--task_n_shot', type=int, default=0) diff --git a/awq/modules/fused_attn.py b/awq/modules/fused_attn.py index 2615ce7a..28d55cae 100644 --- a/awq/modules/fused_attn.py +++ b/awq/modules/fused_attn.py @@ -34,8 +34,6 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): sin = freqs.sin() cache = torch.cat((cos, sin), dim=-1) - # self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - # self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) self.register_buffer("cos_sin_cache", cache.half(), persistent=False) def forward( @@ -46,7 +44,6 @@ def forward( ): # Apply rotary embedding to the query and key before passing them # to the attention op. - # print(positions.shape, query.shape, key.shape, self.cos_sin_cache.shape) query = query.contiguous() key = key.contiguous() awq_inference_engine.rotary_embedding_neox( @@ -146,7 +143,7 @@ def make_quant_attn(model, dev): qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1) qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1) scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1) - # g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0) + g_idx = None bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None @@ -156,8 +153,6 @@ def make_quant_attn(model, dev): qkv_layer.scales = scales qkv_layer.bias = bias - # We're dropping the rotary embedding layer m.rotary_emb here. We don't need it in the triton branch. - attn = QuantLlamaAttention(m.hidden_size, m.num_heads, qkv_layer, m.o_proj, dev) if '.' 
in name: @@ -169,6 +164,4 @@ def make_quant_attn(model, dev): parent = model child_name = name - #print(f"Replacing {name} with quant_attn; parent: {parent_name}, child's name: {child_name}") - setattr(parent, child_name, attn) diff --git a/awq/modules/fused_mlp.py b/awq/modules/fused_mlp.py index 6957a329..0ca30baf 100644 --- a/awq/modules/fused_mlp.py +++ b/awq/modules/fused_mlp.py @@ -71,7 +71,6 @@ def our_llama_mlp(self, x): def make_fused_mlp(m, parent_name=''): if not hasattr(make_fused_mlp, "called"): - # print("[Warning] Calling a fake MLP fusion. But still faster than Huggingface Implimentation.") make_fused_mlp.called = True """ Replace all LlamaMLP modules with QuantLlamaMLP modules, which fuses many of the operations. diff --git a/awq/modules/fused_norm.py b/awq/modules/fused_norm.py index 50f49c3a..9ce8f64b 100644 --- a/awq/modules/fused_norm.py +++ b/awq/modules/fused_norm.py @@ -38,6 +38,4 @@ def make_quant_norm(model): parent = model child_name = name - #print(f"Replacing {name} with quant_attn; parent: {parent_name}, child's name: {child_name}") - setattr(parent, child_name, norm) diff --git a/awq/quantize/auto_scale.py b/awq/quantize/auto_scale.py index a868ed6f..18bc3009 100644 --- a/awq/quantize/auto_scale.py +++ b/awq/quantize/auto_scale.py @@ -1,6 +1,7 @@ import gc import torch import torch.nn as nn +import logging from transformers.models.bloom.modeling_bloom import BloomBlock, BloomGelu from transformers.models.opt.modeling_opt import OPTDecoderLayer @@ -154,9 +155,8 @@ def _search_module_scale(block, linears2scale: list, x, kwargs={}): best_scales = scales block.load_state_dict(org_sd) if best_ratio == -1: - print(history) + logging.debug(history) raise Exception - # print(best_ratio) best_scales = best_scales.view(-1) assert torch.isnan(best_scales).sum() == 0, best_scales diff --git a/awq/utils/calib_data.py b/awq/utils/calib_data.py index 0c6f82be..9320a1e1 100644 --- a/awq/utils/calib_data.py +++ b/awq/utils/calib_data.py @@ -1,4 +1,5 @@ import torch +import logging from datasets import load_dataset def get_calib_dataset(data="pileval", tokenizer=None, n_samples=512, block_size=512): @@ -25,5 +26,5 @@ def get_calib_dataset(data="pileval", tokenizer=None, n_samples=512, block_size= # now concatenate all samples and split according to block size cat_samples = torch.cat(samples, dim=1) n_split = cat_samples.shape[1] // block_size - print(f" * Split into {n_split} blocks") + logging.debug(f" * Split into {n_split} blocks") return [cat_samples[:, i*block_size:(i+1)*block_size] for i in range(n_split)] diff --git a/awq/utils/lm_eval_adaptor.py b/awq/utils/lm_eval_adaptor.py index c1c35d05..a38f1c4f 100644 --- a/awq/utils/lm_eval_adaptor.py +++ b/awq/utils/lm_eval_adaptor.py @@ -2,7 +2,7 @@ import torch from lm_eval.base import BaseLM import fnmatch - +import logging class LMEvalAdaptor(BaseLM): @@ -52,7 +52,7 @@ def max_length(self): elif 'falcon' in self.model_name: return 2048 else: - print(self.model.config) + logging.debug(self.model.config) raise NotImplementedError @property diff --git a/awq/utils/parallel.py b/awq/utils/parallel.py index f1ba27b0..eb4389bc 100644 --- a/awq/utils/parallel.py +++ b/awq/utils/parallel.py @@ -1,6 +1,7 @@ import os import torch import gc +import logging def auto_parallel(args): @@ -23,5 +24,5 @@ def auto_parallel(args): cuda_visible_devices = list(range(8)) os.environ["CUDA_VISIBLE_DEVICES"] = ",".join( [str(dev) for dev in cuda_visible_devices[:n_gpu]]) - print("CUDA_VISIBLE_DEVICES: ", os.environ["CUDA_VISIBLE_DEVICES"]) + 
logging.debug("CUDA_VISIBLE_DEVICES: ", os.environ["CUDA_VISIBLE_DEVICES"]) return cuda_visible_devices diff --git a/awq_cuda/layernorm/reduction.cuh b/awq_cuda/layernorm/reduction.cuh index 678160e8..f670d185 100644 --- a/awq_cuda/layernorm/reduction.cuh +++ b/awq_cuda/layernorm/reduction.cuh @@ -16,7 +16,7 @@ https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kern #include #include -static const float HALF_FLT_MAX = 65504.F; +#define HALF_FLT_MAX 65504.F #define FINAL_MASK 0xffffffff diff --git a/awq_cuda/quantization/gemm_cuda_gen.cu b/awq_cuda/quantization/gemm_cuda_gen.cu index 1632d8be..067a9c0f 100644 --- a/awq_cuda/quantization/gemm_cuda_gen.cu +++ b/awq_cuda/quantization/gemm_cuda_gen.cu @@ -9,7 +9,6 @@ */ - #include #include "gemm_cuda.h" #include "dequantize.cuh" @@ -31,9 +30,6 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i float C_warp[32]; __shared__ half A_shared[16 * (32 + 8)]; __shared__ half B_shared[32 * (128 + 8)]; - - __shared__ half scaling_factors_shared[128]; - __shared__ half zeros_shared[128]; int j_factors1 = ((OC + 128 - 1) / 128); int blockIdx_x = 0; @@ -154,14 +150,14 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) { { unsigned int addr; - __asm__ __volatile__( + asm volatile( "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" : "=r"(addr) : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8)))) ); - __asm__ __volatile__( + asm volatile( "ldmatrix.sync.aligned.m8n8.x4.shared.b16" "{%0, %1, %2, %3}, [%4];\n" : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3]) @@ -172,12 +168,12 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i for (int ax1_0 = 0; ax1_0 < 4; ++ax1_0) { { unsigned int addr; - __asm__ __volatile__( + asm volatile( "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" : "=r"(addr) : "l"((void *)((&(B_shared[(((k_0_1 * 2176) + (((int)threadIdx.y) * 64)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 136) + ((((int)threadIdx.x) >> 4) * 8)))) ); - __asm__ __volatile__( + asm volatile( "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" "{%0, %1, %2, %3}, [%4];\n" : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3]) @@ -187,7 +183,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i } for (int j_0_4 = 0; j_0_4 < 4; ++j_0_4) { { - __asm__ __volatile__( + asm volatile( "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) @@ -195,7 +191,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i } { - __asm__ __volatile__( + asm volatile( "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + 
((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) @@ -349,12 +345,12 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in { { unsigned int addr; - __asm__ __volatile__( + asm volatile( "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" : "=r"(addr) : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8)))) ); - __asm__ __volatile__( + asm volatile( "ldmatrix.sync.aligned.m8n8.x4.shared.b16" "{%0, %1, %2, %3}, [%4];\n" : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3]) @@ -367,12 +363,12 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in { { unsigned int addr; - __asm__ __volatile__( + asm volatile( "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" : "=r"(addr) : "l"((void *)((&(B_shared[(((k_0_1 * 1152) + (((int)threadIdx.y) * 32)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 72) + ((((int)threadIdx.x) >> 4) * 8)))) ); - __asm__ __volatile__( + asm volatile( "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" "{%0, %1, %2, %3}, [%4];\n" : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3]) @@ -385,7 +381,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in { { - __asm__ __volatile__( + asm volatile( "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) @@ -393,7 +389,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in } { - __asm__ __volatile__( + asm volatile( "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) diff --git a/examples/basic_generate.py b/examples/basic_generate.py new file mode 100644 index 00000000..5a9a678f --- /dev/null +++ b/examples/basic_generate.py @@ -0,0 +1,29 @@ +from awq import AutoAWQForCausalLM +from transformers import AutoTokenizer, TextStreamer + +quant_path = "casperhansen/vicuna-7b-v1.5-awq" +quant_file = "awq_model_w4_g128.pt" + +# Load model +model = AutoAWQForCausalLM.from_quantized(quant_path, quant_file, fuse_layers=True) +tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True) +streamer = TextStreamer(tokenizer, skip_special_tokens=True) + +# Convert prompt to tokens +prompt_template = """\ +A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. 
+ +USER: {prompt} +ASSISTANT:""" + +tokens = tokenizer( + prompt_template.format(prompt="How are you today?"), + return_tensors='pt' +).input_ids.cuda() + +# Generate output +generation_output = model.generate( + tokens, + streamer=streamer, + max_new_tokens=512 +) diff --git a/examples/basic_quant.py b/examples/basic_quant.py new file mode 100644 index 00000000..d6fdc96a --- /dev/null +++ b/examples/basic_quant.py @@ -0,0 +1,19 @@ +from awq import AutoAWQForCausalLM +from transformers import AutoTokenizer + +model_path = 'lmsys/vicuna-7b-v1.5' +quant_path = 'vicuna-7b-v1.5-awq' +quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4 } + +# Load model +model = AutoAWQForCausalLM.from_pretrained(model_path) +tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + +# Quantize +model.quantize(tokenizer, quant_config=quant_config) + +# Save quantized model +model.save_quantized(quant_path) +tokenizer.save_pretrained(quant_path) + +print(f'Model is quantized and saved at "{quant_path}"') \ No newline at end of file diff --git a/setup.py b/setup.py index 16a316af..ad1295d4 100644 --- a/setup.py +++ b/setup.py @@ -1,56 +1,117 @@ import os +import torch +from pathlib import Path from setuptools import setup, find_packages -from torch.utils.cpp_extension import BuildExtension, CUDAExtension - -# Get environment variables -build_cuda_extension = os.environ.get('BUILD_CUDA_EXT', '1') == '1' -torch_is_prebuilt = os.environ.get('TORCH_IS_PREBUILT', '0') == '1' - -# Define dependencies -dependencies = [ - "accelerate", "sentencepiece", "tokenizers>=0.12.1", - "transformers>=4.32.0", - "lm_eval", "texttable", - "toml", "attributedict", - "protobuf" +from distutils.sysconfig import get_python_lib +from torch.utils.cpp_extension import BuildExtension, CUDA_HOME, CUDAExtension + +os.environ["CC"] = "g++" +os.environ["CXX"] = "g++" + +common_setup_kwargs = { + "version": "0.0.1", + "name": "autoawq", + "author": "Casper Hansen", + "license": "MIT", + "python_requires": ">=3.8.0", + "description": "AutoAWQ implements the AWQ algorithm for 4-bit quantization with a 2x speedup during inference.", + "long_description": (Path(__file__).parent / "README.md").read_text(encoding="UTF-8"), + "long_description_content_type": "text/markdown", + "url": "https://github.com/casper-hansen/AutoAWQ", + "keywords": ["awq", "autoawq", "quantization", "transformers"], + "platforms": ["linux", "windows"], + "classifiers": [ + "Environment :: GPU :: NVIDIA CUDA :: 11.8", + "Environment :: GPU :: NVIDIA CUDA :: 12", + "License :: OSI Approved :: MIT License", + "Natural Language :: English", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: C++", + ] +} + +requirements = [ + "torch>=2.0.0", + "transformers>=4.32.0", + "tokenizers>=0.12.1", + "accelerate", + "sentencepiece", + "lm_eval", + "texttable", + "toml", + "attributedict", + "protobuf", + "torchvision" ] -if not torch_is_prebuilt: - dependencies.extend(["torch>=2.0.0", "torchvision"]) - -# Setup CUDA extension -ext_modules = [] - -if build_cuda_extension: - ext_modules.append( - CUDAExtension( - name="awq_inference_engine", - sources=[ - "awq_cuda/pybind.cpp", - "awq_cuda/quantization/gemm_cuda_gen.cu", - "awq_cuda/layernorm/layernorm.cu", - "awq_cuda/position_embedding/pos_encoding_kernels.cu" - ], - extra_compile_args={ - "cxx": ["-g", "-O3", "-fopenmp", "-lgomp", 
"-std=c++17"], - "nvcc": ["-O3", "-std=c++17"] - }, - ) +include_dirs = [] + +conda_cuda_include_dir = os.path.join(get_python_lib(), "nvidia/cuda_runtime/include") +if os.path.isdir(conda_cuda_include_dir): + include_dirs.append(conda_cuda_include_dir) + +def check_dependencies(): + if CUDA_HOME is None: + raise RuntimeError( + f"Cannot find CUDA_HOME. CUDA must be available to build the package.") + +def get_compute_capabilities(): + # Collect the compute capabilities of all available GPUs. + compute_capabilities = set() + for i in range(torch.cuda.device_count()): + major, minor = torch.cuda.get_device_capability(i) + if major < 8: + raise RuntimeError("GPUs with compute capability less than 8.0 are not supported.") + compute_capabilities.add(major * 10 + minor) + + # figure out compute capability + compute_capabilities = {80, 86, 89, 90} + + capability_flags = [] + for cap in compute_capabilities: + capability_flags += ["-gencode", f"arch=compute_{cap},code=sm_{cap}"] + + return capability_flags + +check_dependencies() +arch_flags = get_compute_capabilities() + +if os.name == "nt": + # Relaxed args on Windows + extra_compile_args={ + "nvcc": arch_flags + } +else: + extra_compile_args={ + "cxx": ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17"], + "nvcc": ["-O3", "-std=c++17"] + arch_flags + } + +extensions = [ + CUDAExtension( + "awq_inference_engine", + [ + "awq_cuda/pybind.cpp", + "awq_cuda/quantization/gemm_cuda_gen.cu", + "awq_cuda/layernorm/layernorm.cu", + "awq_cuda/position_embedding/pos_encoding_kernels.cu" + ], extra_compile_args=extra_compile_args ) +] + +additional_setup_kwargs = { + "ext_modules": extensions, + "cmdclass": {'build_ext': BuildExtension} +} + +common_setup_kwargs.update(additional_setup_kwargs) setup( - name="awq", - version="0.1.0", - description="An efficient and accurate low-bit weight quantization(INT3/4) method for LLMs.", - long_description=open("README.md", "r").read(), - long_description_content_type="text/markdown", - python_requires=">=3.8", - classifiers=[ - "Programming Language :: Python :: 3", - "License :: OSI Approved :: Apache Software License", - ], - install_requires=dependencies, - packages=find_packages(exclude=["results*", "scripts*", "examples*"]), - ext_modules=ext_modules, - cmdclass={"build_ext": BuildExtension} -) + packages=find_packages(), + install_requires=requirements, + include_dirs=include_dirs, + **common_setup_kwargs +) \ No newline at end of file