Skip to content

Commit

Permalink
Add more examples part 1 (sddmm) (#22)
Browse files Browse the repository at this point in the history
* upd

* upd

* upd
  • Loading branch information
yzh119 committed Jan 24, 2022
1 parent a4213c9 commit f9736b8
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 23 deletions.
20 changes: 10 additions & 10 deletions src/target/source/literal/cuda_binary_search.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@

static constexpr const char* _cuda_binary_search_def = R"(
template <typename DType>
__forceinline__ __device__ int32_t __lower_bound(
__forceinline__ __device__ int __lower_bound(
const DType* __restrict__ arr,
DType val,
int32_t l,
int32_t r) {
int32_t low = l - 1, high = r;
int l,
int r) {
int low = l - 1, high = r;
/* loop invariant: low < mid < high, arr[low] < val, arr[high] >= val */
while (low + 1 < high) {
int32_t mid = (low + high) >> 1;
int mid = (low + high) >> 1;
if (arr[mid] < val) {
low = mid;
} else {
Expand All @@ -46,15 +46,15 @@ __forceinline__ __device__ int32_t __lower_bound(
}
template <typename DType>
__forceinline__ __device__ int32_t __upper_bound(
__forceinline__ __device__ int __upper_bound(
const DType* __restrict__ arr,
DType val,
int32_t l,
int32_t r) {
int32_t low = l - 1, high = r;
int l,
int r) {
int low = l - 1, high = r;
/* loop invariant: low < mid < high, arr[low] < val, arr[high] > val */
while (low + 1 < high) {
int32_t mid = (low + high) >> 1;
int mid = (low + high) >> 1;
if (arr[mid] > val) {
high = mid;
} else {
Expand Down
64 changes: 64 additions & 0 deletions tests/python/sparsetir/test_tir_sparse_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,25 @@ def ellmm_tir(a: T.handle, b: T.handle, c: T.handle, indices: T.handle, M: T.int
B[A_indices[vi * NNZ_COLS + vj] * K + vk]


@T.prim_func
def sddmm_tir(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle, M: T.int32, N: T.int32, K: T.int32, NNZ: T.int32) -> None:
T.func_attr({"global_symbol": "main", "tir.noalis": True})
A = T.match_buffer(a, (M * K,), "float32")
B = T.match_buffer(b, (N * K,), "float32")
C_data = T.match_buffer(c, (NNZ,), "float32")
C_indptr = T.match_buffer(indptr, (M + 1,), "int32")
C_indices = T.match_buffer(indices, (NNZ,), "int32")
for ij, k in T.grid(NNZ, K):
with T.block("sddmm"):
vij, vk = T.axis.remap("SR", [ij, k])
T.reads([A[0: M * K], B[0: N * K], C_data[vij], C_indices[vij], C_indptr[0: M + 1]])
T.writes([C_data[vij]])
with T.init():
C_data[vij] = 0.
C_data[vij] = C_data[vij] + \
A[T.lower_bound(C_indptr.data, vij, 0, M + 1) * K + vk] * B[C_indices[vij] * K + vk]


def test_csrmm():
# generate random input
m = 4096
Expand Down Expand Up @@ -219,6 +238,50 @@ def test_ellmm():
assert np.allclose(y_ground_truth.reshape(-1), Y_nd.numpy())


def test_sddmm():
# generate random input
m = 4096
n = 4096
k = 256
C = sp.random(m, n, dtype="float32", density=0.0125, format='csr')
indptr = C.indptr
indices = C.indices
C_coo = C.tocoo()
nnz = C.nnz
x = np.random.rand(m, k).astype("float32")
y = np.random.rand(n, k).astype("float32")
z_ground_truth = np.matmul(x, y.transpose())[C_coo.row, C_coo.col]
z = np.zeros((nnz,)).astype("float32")

# specialize function
_, _, _, _, _, M, N, K, NNZ = sddmm_tir.params
sch = tir.Schedule(
sddmm_tir.specialize(
{M: m, N: n, K: k, NNZ: nnz}
)
)
blk = sch.get_block("sddmm")
ij, k = sch.get_loops(blk)
#sch.decompose_reduction(blk, ij)
sch.bind(ij, "blockIdx.x")
ko, ki = sch.split(k, [None, 1])
sch.bind(ki, "threadIdx.x")

# convert numpy tensor to tvm ndarray
C_indices = tvm.nd.array(indices.astype("int32"), device=tvm.cuda(0))
C_indptr = tvm.nd.array(indptr.astype("int32"), device=tvm.cuda(0))
X_nd = tvm.nd.array(x.reshape(-1), device=tvm.cuda(0))
Y_nd = tvm.nd.array(y.reshape(-1), device=tvm.cuda(0))
C_data = tvm.nd.array(z, device=tvm.cuda(0))

# build function
f = tvm.build(sch.mod['main'], target="cuda")
f(X_nd, Y_nd, C_data, C_indptr, C_indices)

# assertion
np.allclose(z_ground_truth, C_data.numpy())


def test_bmm():
# TODO(zihao)
pass
Expand All @@ -228,4 +291,5 @@ def test_bmm():
test_csrmm()
test_bsrmm()
test_ellmm()
test_sddmm()
test_bmm()
28 changes: 15 additions & 13 deletions tests/python/unittest/test_tir_intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,19 +253,21 @@ def test_fma():
assert mod["test_tir_fma"].body.body.value.op.name == "tir.call_llvm_pure_intrin"


@tvm.script.tir
def binary_search(a: ty.handle, b: ty.handle, c: ty.handle, d: ty.handle) -> None:
n = tir.var('int32')
m = tir.var('int32')
A = tir.match_buffer(a, (n,), dtype='int32')
B = tir.match_buffer(b, (m,), dtype='int32')
C = tir.match_buffer(c, (m,), dtype='int32')
D = tir.match_buffer(d, (m,), dtype='int32')
with tir.block([m], 'search') as [vi]:
tir.reads([A[0:n], B[vi]])
tir.writes([C[vi], D[vi]])
C[vi] = tir.lower_bound(A.data, B[vi], 0, n)
D[vi] = tir.upper_bound(A.data, B[vi], 0, n)
@T.prim_func
def binary_search(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None:
n = T.var('int32')
m = T.var('int32')
A = T.match_buffer(a, (n,), dtype='int32')
B = T.match_buffer(b, (m,), dtype='int32')
C = T.match_buffer(c, (m,), dtype='int32')
D = T.match_buffer(d, (m,), dtype='int32')
for i in T.serial(0, m):
with T.block('search'):
vi = T.axis.S(m, i)
T.reads([A[0:n], B[vi]])
T.writes([C[vi], D[vi]])
C[vi] = T.lower_bound(A.data, B[vi], 0, n)
D[vi] = T.upper_bound(A.data, B[vi], 0, n)


def test_binary_search():
Expand Down

0 comments on commit f9736b8

Please sign in to comment.