KgeN

A TVM-like CUDA/C code generator.

Components

  • Expression IR
  • Compute primitives
  • Schedule primitives
  • Infer bound pass
  • CUDA codegen pass

TODO

  • consolidate ranges
  • if_then_else expression
  • reduce expression
  • bind to thread
  • fix pass up and pass down
  • bound normalization and consumer index change
  • fix eval_expr_bound with opening and closing corner case
  • add expr comparison for min max expr
  • add codegen for reduce and if_then_else expr
  • add boundary check to avoid out-of-bounds access
  • add symbolic expression simplify
  • apply expr simplifier
  • fix attach.py
  • fix bound normalization
  • fix recursive attach path
  • change codegen to visitor pattern
  • transform into stmts
  • tensor index flatten
  • cache read dataflow rewrite
  • virtual thread and reduce axis
  • cache write
  • move collect input as build graph pass
  • expand split axis to enable expr simplification (i -> i_outer * 32 + i_inner)
  • normalize single point or not?
  • add expr simplification: treat single-point iter vars as const exprs
  • add sync_threads()
  • add unroll
  • split nparts
  • KgeN.all for padding
  • conv example
  • tir print
  • compute inline
  • default scope for placeholder is global, for compute is local
  • consider hashing exprs to avoid same_as overhead; make expr attributes immutable so hashes stay valid
  • consider if_then_else when evaluate bound
  • check if thread range equals axis range after infer bound
  • add interval set
  • move const folding from tir to expr_simplifier?
  • add host code
  • sort expression terms before simplification?
  • fix stride.py
  • directory restructure
  • rebase
  • operation
  • block fusion optimization

Example

Install

pip3 install -e .

Matmul

Running the matmul example prints the generated CUDA kernel:

python3 example/matmul_cache_write.py
// tensor: C[64, 64]
// tensor: C_local[4, 4]
// tensor: B_shared_local[1, 4]
// tensor: B_shared[32, 16]
// tensor: B[64, 64]
// tensor: A_shared_local[4, 1]
// tensor: A_shared[16, 32]
// tensor: A[64, 64]
// gridDim: [4, 4, 1]
// blockDim: [4, 4, 1]
__global__ void kernel(float* A, float* B, float* C) {
    float C_local[16];
    float B_shared_local[4];
    __shared__ float B_shared[512];
    float A_shared_local[4];
    __shared__ float A_shared[512];
    #pragma unroll
    for (int C_local_i = 0; C_local_i < 4 ; C_local_i += 1) {
        #pragma unroll
        for (int C_local_j = 0; C_local_j < 4 ; C_local_j += 1) {
            C_local[((C_local_i * 4) + C_local_j)] = 0;
        }
    }
    for (int k_outer = 0; k_outer < 2 ; k_outer += 1) {
        for (int A_shared_i0_inner = 0; A_shared_i0_inner < 4 ; A_shared_i0_inner += 1) {
            for (int A_shared_i1_inner = 0; A_shared_i1_inner < 8 ; A_shared_i1_inner += 1) {
                A_shared[(((threadIdx.x * 128) + (A_shared_i0_inner * 32)) + ((threadIdx.y * 8) + A_shared_i1_inner))] = A[((((threadIdx.x * 256) + (A_shared_i0_inner * 64)) + (blockIdx.x * 1024)) + (((threadIdx.y * 8) + A_shared_i1_inner) + (k_outer * 32)))];
            }
        }
        for (int B_shared_i0_inner = 0; B_shared_i0_inner < 8 ; B_shared_i0_inner += 1) {
            for (int B_shared_i1_inner = 0; B_shared_i1_inner < 4 ; B_shared_i1_inner += 1) {
                B_shared[(((threadIdx.x * 128) + (B_shared_i0_inner * 16)) + ((threadIdx.y * 4) + B_shared_i1_inner))] = B[((((threadIdx.x * 512) + (B_shared_i0_inner * 64)) + (k_outer * 2048)) + (((threadIdx.y * 4) + B_shared_i1_inner) + (blockIdx.y * 16)))];
            }
        }
        __syncthreads();
        for (int k_inner = 0; k_inner < 32 ; k_inner += 1) {
            for (int A_shared_local_i0 = 0; A_shared_local_i0 < 4 ; A_shared_local_i0 += 1) {
                A_shared_local[A_shared_local_i0] = A_shared[(((A_shared_local_i0 * 32) + (threadIdx.x * 128)) + k_inner)];
            }
            for (int B_shared_local_i1 = 0; B_shared_local_i1 < 4 ; B_shared_local_i1 += 1) {
                B_shared_local[B_shared_local_i1] = B_shared[((k_inner * 16) + (B_shared_local_i1 + (threadIdx.y * 4)))];
            }
            #pragma unroll
            for (int C_local_i = 0; C_local_i < 4 ; C_local_i += 1) {
                #pragma unroll
                for (int C_local_j = 0; C_local_j < 4 ; C_local_j += 1) {
                    C_local[((C_local_i * 4) + C_local_j)] = (C_local[((C_local_i * 4) + C_local_j)] + (A_shared_local[C_local_i] * B_shared_local[C_local_j]));
                }
            }
        }
        __syncthreads();
    }
    for (int C_i_inner = 0; C_i_inner < 4 ; C_i_inner += 1) {
        for (int C_j_inner = 0; C_j_inner < 4 ; C_j_inner += 1) {
            C[((((blockIdx.x * 1024) + (threadIdx.x * 256)) + (C_i_inner * 64)) + (((blockIdx.y * 16) + (threadIdx.y * 4)) + C_j_inner))] = C_local[((C_i_inner * 4) + C_j_inner)];
        }
    }
}
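
KgeN does not emit host code yet (the TODO above lists "add host code"), so here is a minimal host-side launch sketch for the generated kernel. It uses only the standard CUDA runtime API; the 64x64 matrix size and the launch dimensions are taken from the // gridDim and // blockDim comments above, and helper names such as N, hA, and dA are illustrative, not part of KgeN.

#include <cuda_runtime.h>
#include <vector>

__global__ void kernel(float* A, float* B, float* C);  // the generated kernel above

int main() {
    const int N = 64;                                   // C[64, 64] = A[64, 64] * B[64, 64]
    const size_t bytes = N * N * sizeof(float);
    std::vector<float> hA(N * N, 1.0f), hB(N * N, 1.0f), hC(N * N);

    // Allocate device buffers and copy the inputs over.
    float *dA, *dB, *dC;
    cudaMalloc(&dA, bytes);
    cudaMalloc(&dB, bytes);
    cudaMalloc(&dC, bytes);
    cudaMemcpy(dA, hA.data(), bytes, cudaMemcpyHostToDevice);
    cudaMemcpy(dB, hB.data(), bytes, cudaMemcpyHostToDevice);

    // Launch with the configuration printed by the codegen pass.
    dim3 grid(4, 4, 1), block(4, 4, 1);                 // // gridDim: [4, 4, 1], // blockDim: [4, 4, 1]
    kernel<<<grid, block>>>(dA, dB, dC);
    cudaDeviceSynchronize();

    // Copy the result back and release device memory.
    cudaMemcpy(hC.data(), dC, bytes, cudaMemcpyDeviceToHost);
    cudaFree(dA);
    cudaFree(dB);
    cudaFree(dC);
    return 0;
}

This sketch assumes the generated kernel is compiled in the same nvcc translation unit (or linked in) so that the forward declaration resolves.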
