From c9a3efb11ed3670fd4c729521008d80e0ae0ea75 Mon Sep 17 00:00:00 2001
From: Giuseppe Rossini
Date: Mon, 23 Nov 2020 20:15:32 +0000
Subject: [PATCH] AArch64 base algorithm refactoring in LLVM (#6907)

* AArch64 base algorithm refactoring in LLVM

- I refactored the assembly in arm_cpu/tensor_intrin.py to use LLVM+TIR
- Removed the `interleave` boolean parameter that the intrinsic used to
  switch between two different interleaving modes. LLVM will now take
  care of interleaving the instructions
- Applied the changes accordingly to conv2d_gemm.py to call the right
  intrinsic

Note: I found LLVM very sensitive to the choice of `-mcpu`. So, in order
to preserve performance, it is important to specify the right `-mcpu`
when creating the LLVM target

* Fix linting

* Fix linting -2

* Fixing comments

* Address review comments

* Fix spaces around ':' in docstrings
---
 python/tvm/topi/arm_cpu/conv2d_gemm.py   |  15 +-
 python/tvm/topi/arm_cpu/tensor_intrin.py | 746 +++++++++++------------
 2 files changed, 348 insertions(+), 413 deletions(-)

diff --git a/python/tvm/topi/arm_cpu/conv2d_gemm.py b/python/tvm/topi/arm_cpu/conv2d_gemm.py
index 6a5cb2ae890e..85c03997a98d 100644
--- a/python/tvm/topi/arm_cpu/conv2d_gemm.py
+++ b/python/tvm/topi/arm_cpu/conv2d_gemm.py
@@ -24,8 +24,7 @@
 from ..utils import get_const_tuple, get_const_int
 from ..nn.utils import get_pad_tuple
 from .tensor_intrin import (
-    gemm_quantized,
-    gemm_quantized_impl,
+    gemm_4x4_int8_int8_int32,
     gemm_acc_4x4_int8_int8_int32,
     gemm_acc_nx16_int8_int8_int32,
     gemm_acc_2x2_int8_int8_int32,
@@ -51,11 +50,8 @@ def configure_knobs(cfg, M, K):

     if not is_dotprod_available():
         cfg.define_knob("gemm_quantized_unroll", [True, False])
-        cfg.define_knob("gemm_quantized_interleave", [True, False])
-
         if cfg.is_fallback:
             cfg["gemm_quantized_unroll"] = OtherOptionEntity(False)
-            cfg["gemm_quantized_interleave"] = OtherOptionEntity(True)


 # Compute function
@@ -361,14 +357,9 @@ def schedule_conv2d_gemm_interleaved(cfg, s, out, final_out):
         elif is_aarch64_arm():
             s[C_interleaved].reorder(yi, xi)
             K = A_interleaved_input.shape[2]
+            assert in_type in ["int8", "uint8"], "Only int8 and uint8 gemm are supported"
             unroll = cfg["gemm_quantized_unroll"].val
-            interleave = cfg["gemm_quantized_interleave"].val
-            gemm = gemm_quantized(M, N, K, unroll, interleave, in_type, out_type)
-            s[C_interleaved].pragma(
-                b_outer_gemm_fused,
-                "import_llvm",
-                gemm_quantized_impl(M, N, K, unroll, interleave, in_type),
-            )
+            gemm = gemm_4x4_int8_int8_int32(M, N, K, unroll, in_type)
             s[C_interleaved].tensorize(yi, gemm)

     # Output transform
diff --git a/python/tvm/topi/arm_cpu/tensor_intrin.py b/python/tvm/topi/arm_cpu/tensor_intrin.py
index 8ccbe0c41298..4055d7b05c24 100644
--- a/python/tvm/topi/arm_cpu/tensor_intrin.py
+++ b/python/tvm/topi/arm_cpu/tensor_intrin.py
@@ -19,392 +19,42 @@
 import tvm
 from tvm import te
-from tvm.contrib import utils, clang
-
-
-def gemm_quantized_4_4_batched():
-    return """
-          // First half
-          // Higher part of a0 * {b0,b1,b2,b3}
-          "umull v8.8h, v0.8b, v4.8b\\n"
-          "umull v9.8h, v0.8b, v5.8b\\n"
-          "umull v10.8h, v0.8b, v6.8b\\n"
-          "umull v11.8h, v0.8b, v7.8b\\n"
-
-          // Higher part of a1 * {b0,b1,b2,b3}
-          "umull v12.8h, v1.8b, v4.8b\\n"
-          "umull v13.8h, v1.8b, v5.8b\\n"
-          "umull v14.8h, v1.8b, v6.8b\\n"
-          "umull v15.8h, v1.8b, v7.8b\\n"
-
-          // Accumulate
-          "uadalp v16.4s, v8.8h\\n"
-          "uadalp v17.4s, v9.8h\\n"
-          "uadalp v18.4s, v10.8h\\n"
-          "uadalp v19.4s, v11.8h\\n"
-          "uadalp v20.4s, v12.8h\\n"
-          "uadalp v21.4s, v13.8h\\n"
-          "uadalp v22.4s, v14.8h\\n"
-          "uadalp v23.4s,
v15.8h\\n" - - // Lower part of a0 * {b0,b1,b2,b3} - "umull2 v8.8h, v0.16b, v4.16b\\n" - "umull2 v9.8h, v0.16b, v5.16b\\n" - "umull2 v10.8h, v0.16b, v6.16b\\n" - "umull2 v11.8h, v0.16b, v7.16b\\n" - - // Lower part of a1 * {b0,b1,b2,b3} - "umull2 v12.8h, v1.16b, v4.16b\\n" - "umull2 v13.8h, v1.16b, v5.16b\\n" - "umull2 v14.8h, v1.16b, v6.16b\\n" - "umull2 v15.8h, v1.16b, v7.16b\\n" - - // Accumulate again - "uadalp v16.4s, v8.8h\\n" - "uadalp v17.4s, v9.8h\\n" - "uadalp v18.4s, v10.8h\\n" - "uadalp v19.4s, v11.8h\\n" - "uadalp v20.4s, v12.8h\\n" - "uadalp v21.4s, v13.8h\\n" - "uadalp v22.4s, v14.8h\\n" - "uadalp v23.4s, v15.8h\\n" - - // Second half - // Lower part of a2 * {b0,b1,b2,b3} - "umull v8.8h, v2.8b, v4.8b\\n" - "umull v9.8h, v2.8b, v5.8b\\n" - "umull v10.8h, v2.8b, v6.8b\\n" - "umull v11.8h, v2.8b, v7.8b\\n" - - // Lower part of a3 * {b0,b1,b2,b3} - "umull v12.8h, v3.8b, v4.8b\\n" - "umull v13.8h, v3.8b, v5.8b\\n" - "umull v14.8h, v3.8b, v6.8b\\n" - "umull v15.8h, v3.8b, v7.8b\\n" - - // Accumulate - "uadalp v24.4s, v8.8h\\n" - "uadalp v25.4s, v9.8h\\n" - "uadalp v26.4s, v10.8h\\n" - "uadalp v27.4s, v11.8h\\n" - "uadalp v28.4s, v12.8h\\n" - "uadalp v29.4s, v13.8h\\n" - "uadalp v30.4s, v14.8h\\n" - "uadalp v31.4s, v15.8h\\n" - - // Higher part of a2 * {b0,b1,b2,b3} - "umull2 v8.8h, v2.16b, v4.16b\\n" - "umull2 v9.8h, v2.16b, v5.16b\\n" - "umull2 v10.8h, v2.16b, v6.16b\\n" - "umull2 v11.8h, v2.16b, v7.16b\\n" - - // Higher part of a3 * {b0,b1,b2,b3} - "umull2 v12.8h, v3.16b, v4.16b\\n" - "umull2 v13.8h, v3.16b, v5.16b\\n" - "umull2 v14.8h, v3.16b, v6.16b\\n" - "umull2 v15.8h, v3.16b, v7.16b\\n" - - // Accumulate again - "uadalp v24.4s, v8.8h\\n" - "uadalp v25.4s, v9.8h\\n" - "uadalp v26.4s, v10.8h\\n" - "uadalp v27.4s, v11.8h\\n" - "uadalp v28.4s, v12.8h\\n" - "uadalp v29.4s, v13.8h\\n" - "uadalp v30.4s, v14.8h\\n" - "uadalp v31.4s, v15.8h\\n" - """ - - -def gemm_quantized_4_4_interleaved(): - return """ - // First half - // Higher part of a0 * {b0,b1,b2,b3} and accumulate - "umull v8.8h, v0.8b, v4.8b\\n" - "uadalp v16.4s, v8.8h\\n" - "umull v9.8h, v0.8b, v5.8b\\n" - "uadalp v17.4s, v9.8h\\n" - "umull v10.8h, v0.8b, v6.8b\\n" - "uadalp v18.4s, v10.8h\\n" - "umull v11.8h, v0.8b, v7.8b\\n" - "uadalp v19.4s, v11.8h\\n" - - // Higher part of a1 * {b0,b1,b2,b3} and accumulate - "umull v12.8h, v1.8b, v4.8b\\n" - "uadalp v20.4s, v12.8h\\n" - "umull v13.8h, v1.8b, v5.8b\\n" - "uadalp v21.4s, v13.8h\\n" - "umull v14.8h, v1.8b, v6.8b\\n" - "uadalp v22.4s, v14.8h\\n" - "umull v15.8h, v1.8b, v7.8b\\n" - "uadalp v23.4s, v15.8h\\n" - - // Lower part of a0 * {b0,b1,b2,b3} and accumulate - "umull2 v8.8h, v0.16b, v4.16b\\n" - "uadalp v16.4s, v8.8h\\n" - "umull2 v9.8h, v0.16b, v5.16b\\n" - "uadalp v17.4s, v9.8h\\n" - "umull2 v10.8h, v0.16b, v6.16b\\n" - "uadalp v18.4s, v10.8h\\n" - "umull2 v11.8h, v0.16b, v7.16b\\n" - "uadalp v19.4s, v11.8h\\n" - - // Lower part of a1 * {b0,b1,b2,b3} and accumulate - "umull2 v12.8h, v1.16b, v4.16b\\n" - "uadalp v20.4s, v12.8h\\n" - "umull2 v13.8h, v1.16b, v5.16b\\n" - "uadalp v21.4s, v13.8h\\n" - "umull2 v14.8h, v1.16b, v6.16b\\n" - "uadalp v22.4s, v14.8h\\n" - "umull2 v15.8h, v1.16b, v7.16b\\n" - "uadalp v23.4s, v15.8h\\n" - - // Second half - // Higher part of a2 * {b0,b1,b2,b3} and accumulate - "umull v8.8h, v2.8b, v4.8b\\n" - "uadalp v24.4s, v8.8h\\n" - "umull v9.8h, v2.8b, v5.8b\\n" - "uadalp v25.4s, v9.8h\\n" - "umull v10.8h, v2.8b, v6.8b\\n" - "uadalp v26.4s, v10.8h\\n" - "umull v11.8h, v2.8b, v7.8b\\n" - "uadalp v27.4s, v11.8h\\n" - - // Higher part of a3 * 
{b0,b1,b2,b3} and accumulate - "umull v12.8h, v3.8b, v4.8b\\n" - "uadalp v28.4s, v12.8h\\n" - "umull v13.8h, v3.8b, v5.8b\\n" - "uadalp v29.4s, v13.8h\\n" - "umull v14.8h, v3.8b, v6.8b\\n" - "uadalp v30.4s, v14.8h\\n" - "umull v15.8h, v3.8b, v7.8b\\n" - "uadalp v31.4s, v15.8h\\n" - - // Lower part of a2 * {b0,b1,b2,b3} and accumulate - "umull2 v8.8h, v2.16b, v4.16b\\n" - "uadalp v24.4s, v8.8h\\n" - "umull2 v9.8h, v2.16b, v5.16b\\n" - "uadalp v25.4s, v9.8h\\n" - "umull2 v10.8h, v2.16b, v6.16b\\n" - "uadalp v26.4s, v10.8h\\n" - "umull2 v11.8h, v2.16b, v7.16b\\n" - "uadalp v27.4s, v11.8h\\n" - - // Lower part of a3 * {b0,b1,b2,b3} and accumulate - "umull2 v12.8h, v3.16b, v4.16b\\n" - "uadalp v28.4s, v12.8h\\n" - "umull2 v13.8h, v3.16b, v5.16b\\n" - "uadalp v29.4s, v13.8h\\n" - "umull2 v14.8h, v3.16b, v6.16b\\n" - "uadalp v30.4s, v14.8h\\n" - "umull2 v15.8h, v3.16b, v7.16b\\n" - "uadalp v31.4s, v15.8h\\n" - """ - - -def gemm_quantized_impl(M, N, K, unroll, interleave, data_type="uint8"): - """Assembly implementation of a blocked gemv. Given - a block a of shape (4, k) and a block b' of shape (4, k) - produces the output block c = a*b of shape (4,4)""" - stepA = min(4, M) - stepB = min(4, N) - assert data_type in ["uint8", "int8"], "Only uint8/int8 supported for this implementation" - signature = """extern "C" int gemm_quantized_{0}_{0}_int32_{1}_{2}""".format( - data_type, stepA, stepB - ) - if unroll: - signature += "_" + str(K) - - if interleave: - signature += "_interleaved" - - signature += """(int *c_buffer, - unsigned char *a_buffer, - unsigned char *b_buffer, - int K, int m, int n)""" - - cc_code = signature - cc_code += """ - { - unsigned char * a_ptr = a_buffer; - unsigned char * b_ptr = b_buffer; - int * c_ptr = c_buffer; - - int k = K / 16; - - __asm__ __volatile__ ( - "movi v16.4s, #0\\n" - "movi v17.4s, #0\\n" - "movi v18.4s, #0\\n" - "movi v19.4s, #0\\n" - "movi v20.4s, #0\\n" - "movi v21.4s, #0\\n" - "movi v22.4s, #0\\n" - "movi v23.4s, #0\\n" - "movi v24.4s, #0\\n" - "movi v25.4s, #0\\n" - "movi v26.4s, #0\\n" - "movi v27.4s, #0\\n" - "movi v28.4s, #0\\n" - "movi v29.4s, #0\\n" - "movi v30.4s, #0\\n" - "movi v31.4s, #0\\n" - "1:" +def gemm_4x4_int8_int8_int32(M, N, K, unroll, in_type): """ + Int8 4x4 matrix multiplication and accumulation using a sequence of + umull -> uadalp -> umull2 -> uadalp instructions. This function + takes two arrays of int8 data type A[4][K] and B[4][K], and produces + a 4x4 matrix which is equal to A*B'. - main_loop = ' "ldr q0, [%[a_ptr]]\\n" ' - - if M > 1: - main_loop += ' "ldr q1, [%[a_ptr], #16]\\n" ' - else: - main_loop += ' "movi v1.4s, #0\\n" ' - - if M > 2: - main_loop += ' "ldr q2, [%[a_ptr], #32]\\n" ' - else: - main_loop += ' "movi v2.4s, #0\\n" ' - - if M > 3: - main_loop += ' "ldr q3, [%[a_ptr], #48]\\n" ' - else: - main_loop += ' "movi v3.4s, #0\\n" ' - - main_loop += ' "ldr q4, [%[b_ptr]]\\n" ' - - if N > 1: - main_loop += ' "ldr q5, [%[b_ptr], #16]\\n" ' - - if N > 2: - main_loop += ' "ldr q6, [%[b_ptr], #32]\\n" ' - - if N > 3: - main_loop += ' "ldr q7, [%[b_ptr], #48]\\n" ' - - # Main computation can interleave multiply/accumulate instructions - # or schedule them in batches (first all multiplies then all accumulates) - if interleave: - main_loop += gemm_quantized_4_4_interleaved() - else: - main_loop += gemm_quantized_4_4_batched() + The pseudo code is as follows. 
- blockA = min(64, M * 16) - blockB = min(64, N * 16) - main_loop += """// Increment pointers - "add %[a_ptr], %[a_ptr], #{0}\\n" - "add %[b_ptr], %[b_ptr], #{1}\\n" """.format( - blockA, blockB - ) + .. code-block:: c - if unroll: - k = int(K // 16) - for l in range(0, k): - cc_code += main_loop - else: - cc_code += main_loop - cc_code += """ - "subs %w[k], %w[k], #1\\n" - "cbnz %w[k], 1b\\n" - """ - cc_code += """ - // Final additions - - // v16 contains the four partial sums of a[0, 0:K].*b[0,0:K], let's call them (a,b,c,d) - // v17 contains the four partial sums of a[0, 0:K].*b[1,0:K], let's call them (e,f,g,h) - // v18 contains the four partial sums of a[0, 0:K].*b[2,0:K], let's call them (i,j,k,l) - // v19 contains the four partial sums of a[0, 0:K].*b[3,0:K], let's call them (m,n,o,p) - "addp v16.4s, v16.4s, v17.4s\\n" // v16 = (a+b, c+d, e+f, g+h) - "addp v17.4s, v18.4s, v19.4s\\n" // v17 = (i+j, k+l, m+n, o+p) - "addp v16.4s, v16.4s, v17.4s\\n" // v16 = (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p) - - // v20 contains the four partial sums of a[1, 0:K].*b[0,0:K], let's call them (a,b,c,d) - // v21 contains the four partial sums of a[1, 0:K].*b[1,0:K], let's call them (e,f,g,h) - // v22 contains the four partial sums of a[1, 0:K].*b[2,0:K], let's call them (i,j,k,l) - // v23 contains the four partial sums of a[1, 0:K].*b[3,0:K], let's call them (m,n,o,p) - "addp v20.4s, v20.4s, v21.4s\\n" // v20 = (a+b, c+d, e+f, g+h) - "addp v21.4s, v22.4s, v23.4s\\n" // v21 = (i+j, k+l, m+n, o+p) - "addp v20.4s, v20.4s, v21.4s\\n" // v20 = (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p) - - // v24 contains the four partial sums of a[2, 0:K].*b[0,0:K], let's call them (a,b,c,d) - // v25 contains the four partial sums of a[2, 0:K].*b[1,0:K], let's call them (e,f,g,h) - // v26 contains the four partial sums of a[2, 0:K].*b[2,0:K], let's call them (i,j,k,l) - // v27 contains the four partial sums of a[2, 0:K].*b[3,0:K], let's call them (m,n,o,p) - "addp v24.4s, v24.4s, v25.4s\\n" // v24 = (a+b, c+d, e+f, g+h) - "addp v25.4s, v26.4s, v27.4s\\n" // v25 = (i+j, k+l, m+n, o+p) - "addp v24.4s, v24.4s, v25.4s\\n" // v24 = (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p) - - // v28 contains the four partial sums of a[3, 0:K].*b[0,0:K], let's call them (a,b,c,d) - // v29 contains the four partial sums of a[3, 0:K].*b[1,0:K], let's call them (e,f,g,h) - // v30 contains the four partial sums of a[3, 0:K].*b[2,0:K], let's call them (i,j,k,l) - // v31 contains the four partial sums of a[3, 0:K].*b[3,0:K], let's call them (m,n,o,p) - "addp v28.4s, v28.4s, v29.4s\\n" // v28 = (a+b, c+d, e+f, g+h) - "addp v29.4s, v30.4s, v31.4s\\n" // v29 = (i+j, k+l, m+n, o+p) - "addp v28.4s, v28.4s, v29.4s\\n" // v28 = (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p) - - "str q16, [%[c_ptr]]\\n" - """ - - stepC = min(4, N) - if M > 1: - cc_code += ' "str q20, [%[c_ptr], #{0}]\\n" '.format(stepC * 4) - - if M > 2: - cc_code += ' "str q24, [%[c_ptr], #{0}]\\n" '.format(stepC * 8) - - if M > 3: - cc_code += ' "str q28, [%[c_ptr], #{0}]\\n" '.format(stepC * 12) - - cc_code += """ - : [c_ptr] "+r" (c_ptr), [a_ptr] "+r" (a_ptr), [b_ptr] "+r" (b_ptr), [k] "+r" (k) - : - : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", - "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", - "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", - "v27", "v28", "v29", "v30", "v31" - ); - return 0; + void gemm_4x4_int8_int8_int32(int8 A[4][K], int8 B[4][K], int32 C[4][4]){ + for (int i = 0; i < 4; i++){ + for (int j = 0; j < 4; j++){ + for (int k = 0; k < K; 
k++){
+                        C[i][j] += A[i][k] * B[j][k]
+                    }
+                }
+            }
+        }
-    }
-    """
-    if data_type == "int8":
-        cc_code = cc_code.replace("unsigned char", "char")
-        cc_code = cc_code.replace("umull", "smull")
-        cc_code = cc_code.replace("uadalp", "sadalp")
-
-    temp = utils.tempdir()
-    ll_path = temp.relpath("temp.ll")
-    # Create LLVM ir from c source code
-    ll_code = clang.create_llvm(
-        cc_code, options=["--target=aarch64-linux-gnu -mattr=+neon"], output=ll_path
-    )
-    return ll_code
-
-
-def gemm_quantized(M, N, K, unroll, interleave, in_type, out_type):
-    """
-    Use integer ARM v8 instructions in order to produce a block c of 4x4 elements
-    given two 4xK blocks a and b' (where b' is a Kx4 block transposed). The final
-    result is c = a*b (where '*' indicates the matrix product)
-
-    Every row of the matrix c is obtained (for uint8) by a sequence of
-
-          umull -> uadalp -> umull2 -> uadalp
-
-    The block size is constrained by the number of registers available in arvm8. This
-    function returns a TensorIntrin that can be used to tensorize
-    a schedule.
+    Notes:
+        * The tiling strategy is picked to maximize register usage.

     Parameters
     ----------
-    M: int
+    M : int
         rows of the matrix A
-    N: int
+    N : int
         columns of the matrix B
-    K: int
+    K : int
         columns of matrix A
-    in_type: str, {'uint8', 'int8'}
-    out_type: str, {'uint32', 'int32'}
+    unroll : bool
+        Unroll the loop accumulation if True
+    in_type : str, {'uint8', 'int8'}

     Returns
     -------
@@ -414,7 +64,7 @@ def gemm_quantized(M, N, K, unroll, interleave, in_type, out_type):
     assert in_type in ["uint8", "int8"]
     A = te.placeholder((K // 16, te.var("m"), 16), dtype=in_type, name="A")
     B = te.placeholder((K // 16, te.var("n"), 16), dtype=in_type, name="B")
-
+    dtype_vec = in_type + "x16"

     idxm = tvm.tir.indexmod
     k = te.reduce_axis((0, K), "k")
@@ -447,28 +97,322 @@ def gemm_quantized(M, N, K, unroll, interleave, in_type, out_type):
         C.shape, dtype="int32", name="c_buffer", offset_factor=1, strides=[te.var("sc"), 1]
     )

+    # Intrinsics used in the following algorithm
+    umull_intrin = "llvm.aarch64.neon.umull" if in_type == "uint8" else "llvm.aarch64.neon.smull"
+    uaddlp_intrin = "llvm.aarch64.neon.uaddlp" if in_type == "uint8" else "llvm.aarch64.neon.saddlp"
+    addp_intrin = "llvm.aarch64.neon.addp"
+
+    def uadalp(a, b):
+        """Add pair and accumulate
+
+        Parameters:
+        ----------
+        a: int32x4 vector
+        b: int16x8 vector
+
+        Returns:
+        --------
+        returns an int32x4 vector
+
+        Pseudocode:
+        ----------
+        a += (b0+b1, b2+b3, b4+b5, b6+b7)
+        """
+
+        return a + tvm.tir.call_llvm_pure_intrin(
+            "int32x4", uaddlp_intrin, tvm.tir.const(1, "uint32"), b
+        )
+
+    def umull(a, b):
+        """Multiply long (higher part)
+
+        Parameters:
+        ----------
+        a: int8x16 vector
+        b: int8x16 vector
+
+        Returns:
+        --------
+        returns an int16x8 vector
+
+        Pseudocode:
+        ----------
+        c = (a8*b8, a9*b9, a10*b10, a11*b11, a12*b12, a13*b13, a14*b14, a15*b15)
+        """
+        a_high = tvm.tir.call_intrin("int8x8", "tir.vectorhigh", a)
+        b_high = tvm.tir.call_intrin("int8x8", "tir.vectorhigh", b)
+        c = tvm.tir.call_llvm_pure_intrin(
+            "int16x8", umull_intrin, tvm.tir.const(2, "uint32"), a_high, b_high
+        )
+        return c
+
+    def umull2(a, b):
+        """Multiply long (lower part)
+
+        Parameters:
+        ----------
+        a: int8x16 vector
+        b: int8x16 vector
+
+        Returns:
+        --------
+        returns an int16x8 vector
+
+        Pseudocode:
+        ----------
+        c = (a0*b0, a1*b1, a2*b2, a3*b3, a4*b4, a5*b5, a6*b6, a7*b7)
+        """
+        a_low = tvm.tir.call_intrin("int8x8", "tir.vectorlow", a)
+        b_low = tvm.tir.call_intrin("int8x8", "tir.vectorlow", b)
+        c = tvm.tir.call_llvm_pure_intrin(
"int16x8", umull_intrin, tvm.tir.const(2, "uint32"), a_low, b_low + ) + return c + + def addp(a, b): + """Add two vectors in pairs + + Parameters: + ---------- + a: int32x4 vector + b: int32x4 vector + + Returns: + -------- + return a int32x4 vector + + Pseudocode: + ---------- + c = (a0+a1, a2+a3, b0+b1, b0+b3) + """ + return tvm.tir.call_llvm_pure_intrin( + "int32x4", addp_intrin, tvm.tir.const(2, "uint32"), a, b + ) + + def accumulation_loop(M, N, ins, acc, tile_idx): + """Internal tile accumulation. This function + takes two arrays of int8 data type A[tile_idx][4][16] and B[tile_idx][4][16], produces + a 4x4 matrix which is equal to A*B' and accumulates into C[4][4] + + The pseudo code is as follows. + + .. code-block:: c + + void gemm_4x4_int8_int8_int32(int8 A[tile_idx][4][K], + int8 B[tile_idx][4][K], + int32 C[4][4]){ + for (int i = 0; i < 4; i++){ + for (int j = 0; j < 4; j++){ + for (int k = 0; k < 16; k++){ + C[i][j] += A[tile_idx][i][k] * B[tile_idx][j][k] + } + } + } + + Notes: + * The tiling strategy is picked to maximize register usage. + + Parameters: + ---------- + M : int + Number of total rows of the output matrix + N : int + Number of total columns of the output matrix + ins : list of tvm.tir.buffer + Input buffers + acc : tvm.tir.ir_builder.BufferVar + Bank of register accumulators + tiled_idx : int + Index of a sub-tile of A and B in A[tile_idx][:][:] and B[tile_idx][:][:]. + Please note that 0 <= tile_idx <= K//16 + + """ + a0 = ins[0].vload([tile_idx, 0, 0], dtype_vec) + a1 = tvm.tir.const(0, "int8x16") + if M > 1: + a1 = ins[0].vload([tile_idx, 1, 0], dtype_vec) + a2 = tvm.tir.const(0, "int8x16") + if M > 2: + a2 = ins[0].vload([tile_idx, 2, 0], dtype_vec) + a3 = tvm.tir.const(0, "int8x16") + if M > 3: + a3 = ins[0].vload([tile_idx, 3, 0], dtype_vec) + + b0 = ins[1].vload([tile_idx, 0, 0], dtype_vec) + b1 = tvm.tir.const(0, "int8x16") + if N > 1: + b1 = ins[1].vload([tile_idx, 1, 0], dtype_vec) + b2 = tvm.tir.const(0, "int8x16") + if N > 2: + b2 = ins[1].vload([tile_idx, 2, 0], dtype_vec) + b3 = tvm.tir.const(0, "int8x16") + if N > 3: + b3 = ins[1].vload([tile_idx, 3, 0], dtype_vec) + + # First half + # Lower part of a0 * {b0,b1,b2,b3} + d00 = umull(a0, b0) + d01 = umull(a0, b1) + d02 = umull(a0, b2) + d03 = umull(a0, b3) + + # Lower part of a1 * {b0,b1,b2,b3} + d10 = umull(a1, b0) + d11 = umull(a1, b1) + d12 = umull(a1, b2) + d13 = umull(a1, b3) + + # Accumulate + acc[0] = uadalp(acc[0], d00) + acc[1] = uadalp(acc[1], d01) + acc[2] = uadalp(acc[2], d02) + acc[3] = uadalp(acc[3], d03) + acc[4] = uadalp(acc[4], d10) + acc[5] = uadalp(acc[5], d11) + acc[6] = uadalp(acc[6], d12) + acc[7] = uadalp(acc[7], d13) + + # Higher part of a0 * {b0,b1,b2,b3} + d00 = umull2(a0, b0) + d01 = umull2(a0, b1) + d02 = umull2(a0, b2) + d03 = umull2(a0, b3) + + # Higher part of a1 * {b0,b1,b2,b3} + d10 = umull2(a1, b0) + d11 = umull2(a1, b1) + d12 = umull2(a1, b2) + d13 = umull2(a1, b3) + + # Accumulate again + acc[0] = uadalp(acc[0], d00) + acc[1] = uadalp(acc[1], d01) + acc[2] = uadalp(acc[2], d02) + acc[3] = uadalp(acc[3], d03) + acc[4] = uadalp(acc[4], d10) + acc[5] = uadalp(acc[5], d11) + acc[6] = uadalp(acc[6], d12) + acc[7] = uadalp(acc[7], d13) + + # Second half + # Lower part of a2 * {b0,b1,b2,b3} + d00 = umull(a2, b0) + d01 = umull(a2, b1) + d02 = umull(a2, b2) + d03 = umull(a2, b3) + + # Lower part of a3 * {b0,b1,b2,b3} + d10 = umull(a3, b0) + d11 = umull(a3, b1) + d12 = umull(a3, b2) + d13 = umull(a3, b3) + + # Accumulate + acc[8] = uadalp(acc[8], d00) + acc[9] = 
uadalp(acc[9], d01) + acc[10] = uadalp(acc[10], d02) + acc[11] = uadalp(acc[11], d03) + acc[12] = uadalp(acc[12], d10) + acc[13] = uadalp(acc[13], d11) + acc[14] = uadalp(acc[14], d12) + acc[15] = uadalp(acc[15], d13) + + # Higher part of a2 * {b0,b1,b2,b3} + d00 = umull2(a2, b0) + d01 = umull2(a2, b1) + d02 = umull2(a2, b2) + d03 = umull2(a2, b3) + + # Lower part of a3 * {b0,b1,b2,b3} + d10 = umull2(a3, b0) + d11 = umull2(a3, b1) + d12 = umull2(a3, b2) + d13 = umull2(a3, b3) + + # Accumulate + acc[8] = uadalp(acc[8], d00) + acc[9] = uadalp(acc[9], d01) + acc[10] = uadalp(acc[10], d02) + acc[11] = uadalp(acc[11], d03) + acc[12] = uadalp(acc[12], d10) + acc[13] = uadalp(acc[13], d11) + acc[14] = uadalp(acc[14], d12) + acc[15] = uadalp(acc[15], d13) + def _intrin_func(ins, outs): def _instr(): ib = tvm.tir.ir_builder.create() - aa, bb = ins - cc = outs[0] - stepA = min(4, M) - stepB = min(4, N) - intrin_name = "gemm_quantized_{0}_{0}_int32_{1}_{2}".format(in_type, stepA, stepB) + # Allocate a local buffer (possibly translates to registers) + acc = ib.allocate("int32x4", 16, name="accs", scope="local") + m = outs[0].shape[0] + n = outs[0].shape[1] + # Initialization + for i in range(0, 16): + acc[i] = tvm.tir.const(0, "int32x4") + if unroll: - intrin_name += "_" + str(K) - if interleave: - intrin_name += "_interleaved" - ib.emit( - tvm.tir.call_extern( - "int32", - intrin_name, - outs[0].access_ptr("w"), - a_buffer.access_ptr("r"), - b_buffer.access_ptr("r"), - K, - ) - ) + for i in range(0, int(K // 16)): + accumulation_loop(M, N, ins, acc, i) + else: + with ib.for_range(0, K // 16, name="i") as i: + accumulation_loop(M, N, ins, acc, i) + + # Final accumulations + # acc[4*r + c] contains the partial accumulations of element C[r][c] + # + # In particular: + # acc[4*r] contains the partial sums of a[r,0:K].*b[0,0:K] -> (a,b,c,d) + # acc[4*r+1] contains the partial sums of a[r, 0:K].*b[1,0:K] -> (e,f,g,h) + # acc[4*r+2] contains the partial sums of a[r, 0:K].*b[2,0:K] -> (i,j,k,l) + # acc[4*r+3] contains the partial sums of a[r, 0:K].*b[3,0:K] -> (m,n,o,p) + # + # Please note that 0<= r, c < 4 + + acc[0] = addp(acc[0], acc[1]) # (a+b, c+d, e+f, g+h) + acc[1] = addp(acc[2], acc[3]) # (i+j, k+l, m+n, o+p) + acc[0] = addp(acc[0], acc[1]) # (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p) + + acc[4] = addp(acc[4], acc[5]) # (a+b, c+d, e+f, g+h) + acc[5] = addp(acc[6], acc[7]) # (i+j, k+l, m+n, o+p) + acc[4] = addp(acc[4], acc[5]) # (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p) + + acc[8] = addp(acc[8], acc[9]) # (a+b, c+d, e+f, g+h) + acc[9] = addp(acc[10], acc[11]) # (i+j, k+l, m+n, o+p) + acc[8] = addp(acc[8], acc[9]) # (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p) + + acc[12] = addp(acc[12], acc[13]) # (a+b, c+d, e+f, g+h) + acc[13] = addp(acc[14], acc[15]) # (i+j, k+l, m+n, o+p) + acc[12] = addp(acc[12], acc[13]) # (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p) + + # Store the result + if N > 3: + out_0 = acc[0] + out_1 = acc[4] + out_2 = acc[8] + out_3 = acc[12] + elif N > 2: + out_0 = tvm.tir.call_intrin("int32x3", "tir.reinterpret", acc[0]) + out_1 = tvm.tir.call_intrin("int32x3", "tir.reinterpret", acc[4]) + out_2 = tvm.tir.call_intrin("int32x3", "tir.reinterpret", acc[8]) + out_3 = tvm.tir.call_intrin("int32x3", "tir.reinterpret", acc[12]) + elif N > 1: + out_0 = tvm.tir.call_intrin("int32x2", "tir.reinterpret", acc[0]) + out_1 = tvm.tir.call_intrin("int32x2", "tir.reinterpret", acc[4]) + out_2 = tvm.tir.call_intrin("int32x2", "tir.reinterpret", acc[8]) + out_3 = tvm.tir.call_intrin("int32x2", "tir.reinterpret", acc[12]) + else: 
+                out_0 = tvm.tir.call_intrin("int32", "tir.reinterpret", acc[0])
+                out_1 = tvm.tir.call_intrin("int32", "tir.reinterpret", acc[4])
+                out_2 = tvm.tir.call_intrin("int32", "tir.reinterpret", acc[8])
+                out_3 = tvm.tir.call_intrin("int32", "tir.reinterpret", acc[12])
+
+            ib.emit(outs[0].vstore([0, 0], out_0))
+            if M > 1:
+                ib.emit(outs[0].vstore([1, 0], out_1))
+            if M > 2:
+                ib.emit(outs[0].vstore([2, 0], out_2))
+            if M > 3:
+                ib.emit(outs[0].vstore([3, 0], out_3))

             return ib.get()

         # body, reset, update
@@ -509,9 +453,9 @@ def dot_int8_int8_int32(int32_lanes, dtype="uint"):

     Parameters
     ----------
-    int32_lanes: int
+    int32_lanes : int
         How many int32/uint32 to produce
-    dtype: str, optional, {"uint", "int"}
+    dtype : str, optional, {"uint", "int"}
         Whether it works on unsigned int or signed int

     Returns
@@ -602,16 +546,16 @@ def select_word(vec, lane, dtype_vec):

     Parameters
     ----------
-    vec: tvm.tir.Expr
+    vec : tvm.tir.Expr
         int8x16 vector expression
-    lane: int
+    lane : int
         vector lane we want to replicate
-    dtype_vec: str
+    dtype_vec : str
         vector data type (e.g., int8x16)

     Returns
     ----------
-    output: tvm.tir.Expr
+    output : tvm.tir.Expr
         replicated vector
     """
     # Reinterpret vec_a as 4 int32 words
@@ -648,7 +592,7 @@ def gemm_acc_4x4_int8_int8_int32(dtype):

     Parameters
     ----------
-    dtype: str, {"uint8", "int8"}
+    dtype : str, {"uint8", "int8"}
         Whether it works on unsigned int or signed int

     Returns
@@ -779,9 +723,9 @@ def gemm_acc_nx16_int8_int8_int32(dtype, rows):

     Parameters
     ----------
-    dtype: str, {"uint8", "int8"}
+    dtype : str, {"uint8", "int8"}
         Whether it works on unsigned int or signed int
-    rows: int
+    rows : int
         Number of the output rows "n"

     Returns
@@ -990,7 +934,7 @@ def gemm_acc_2x2_int8_int8_int32(dtype):

     Parameters
     ----------
-    dtype: str, {"uint8", "int8"}
+    dtype : str, {"uint8", "int8"}
         Whether it works on unsigned int or signed int

     Returns
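For reference, the contract of the new intrinsic can be cross-checked against
plain NumPy. The sketch below is not part of the patch; the helper name
`gemm_4x4_ref` is illustrative only. It follows the docstring pseudocode: A and
B are laid out as (K//16, rows, 16) tiles, matching the `te.placeholder`
declarations above, and the result is C = A*B' accumulated in int32.

.. code-block:: python

    import numpy as np

    def gemm_4x4_ref(a_tiles, b_tiles):
        """NumPy reference for the semantics of gemm_4x4_int8_int8_int32.

        a_tiles : (K//16, M, 16) int8/uint8 array
        b_tiles : (K//16, N, 16) int8/uint8 array
        Returns an (M, N) int32 array with C[i][j] = sum_k A[i][k] * B[j][k].
        """
        k_tiles, m, _ = a_tiles.shape
        _, n, _ = b_tiles.shape
        c = np.zeros((m, n), dtype="int32")
        for t in range(k_tiles):  # the tile loop the intrinsic unrolls when unroll=True
            for i in range(m):
                for j in range(n):
                    # Widen to int32 before multiplying, mirroring the
                    # umull/uadalp pipeline: 8-bit products fit in 16 bits and
                    # the pairwise sums are widened into 32-bit accumulators,
                    # so no intermediate overflow occurs.
                    c[i, j] += int(
                        np.dot(
                            a_tiles[t, i].astype("int32"),
                            b_tiles[t, j].astype("int32"),
                        )
                    )
        return c

Under the patch's scheme, `unroll=True` corresponds to replicating the body of
the tile loop K//16 times when the TIR is built, instead of emitting a runtime
loop over the tiles.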