Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add more examples part 1 (sddmm) #22

Merged
merged 3 commits into from
Nov 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions src/target/source/literal/cuda_binary_search.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@

static constexpr const char* _cuda_binary_search_def = R"(
template <typename DType>
__forceinline__ __device__ int32_t __lower_bound(
__forceinline__ __device__ int __lower_bound(
const DType* __restrict__ arr,
DType val,
int32_t l,
int32_t r) {
int32_t low = l - 1, high = r;
int l,
int r) {
int low = l - 1, high = r;
/* loop invariant: low < mid < high, arr[low] < val, arr[high] >= val */
while (low + 1 < high) {
int32_t mid = (low + high) >> 1;
int mid = (low + high) >> 1;
if (arr[mid] < val) {
low = mid;
} else {
Expand All @@ -46,15 +46,15 @@ __forceinline__ __device__ int32_t __lower_bound(
}

template <typename DType>
__forceinline__ __device__ int32_t __upper_bound(
__forceinline__ __device__ int __upper_bound(
const DType* __restrict__ arr,
DType val,
int32_t l,
int32_t r) {
int32_t low = l - 1, high = r;
int l,
int r) {
int low = l - 1, high = r;
/* loop invariant: low < mid < high, arr[low] < val, arr[high] > val */
while (low + 1 < high) {
int32_t mid = (low + high) >> 1;
int mid = (low + high) >> 1;
if (arr[mid] > val) {
high = mid;
} else {
Expand Down
64 changes: 64 additions & 0 deletions tests/python/sparsetir/test_tir_sparse_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,25 @@ def ellmm_tir(a: T.handle, b: T.handle, c: T.handle, indices: T.handle, M: T.int
B[A_indices[vi * NNZ_COLS + vj] * K + vk]


@T.prim_func
def sddmm_tir(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle, M: T.int32, N: T.int32, K: T.int32, NNZ: T.int32) -> None:
    # SDDMM (sampled dense-dense matmul): C[i, j] = sum_k A[i, k] * B[j, k],
    # evaluated only at the NNZ positions stored in CSR form via
    # C_indptr (row pointers, length M + 1) and C_indices (column ids, length NNZ).
    #
    # Fix: the attribute key was misspelled "tir.noalis"; TVM's attribute is
    # "tir.noalias", so with the typo the no-alias hint was silently ignored.
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    A = T.match_buffer(a, (M * K,), "float32")      # dense LHS, flattened (M, K)
    B = T.match_buffer(b, (N * K,), "float32")      # dense RHS, flattened (N, K)
    C_data = T.match_buffer(c, (NNZ,), "float32")   # output nonzero values
    C_indptr = T.match_buffer(indptr, (M + 1,), "int32")
    C_indices = T.match_buffer(indices, (NNZ,), "int32")
    # ij iterates over nonzeros (spatial), k over the reduction dimension.
    for ij, k in T.grid(NNZ, K):
        with T.block("sddmm"):
            vij, vk = T.axis.remap("SR", [ij, k])
            T.reads([A[0: M * K], B[0: N * K], C_data[vij], C_indices[vij], C_indptr[0: M + 1]])
            T.writes([C_data[vij]])
            with T.init():
                C_data[vij] = 0.
            # The row of nonzero vij is recovered by binary-searching C_indptr.
            # NOTE(review): per the CUDA helper's invariant, lower_bound returns
            # the first idx with indptr[idx] >= vij, which is i for a nonzero at
            # a row start (vij == indptr[i]) but i + 1 for an interior nonzero —
            # verify this off-by-one against __lower_bound's semantics; an
            # upper_bound - 1 formulation would be uniform.
            C_data[vij] = C_data[vij] + \
                A[T.lower_bound(C_indptr.data, vij, 0, M + 1) * K + vk] * B[C_indices[vij] * K + vk]


def test_csrmm():
# generate random input
m = 4096
Expand Down Expand Up @@ -219,6 +238,50 @@ def test_ellmm():
assert np.allclose(y_ground_truth.reshape(-1), Y_nd.numpy())


def test_sddmm():
    """End-to-end check of sddmm_tir on CUDA against a dense numpy SDDMM.

    Builds a random CSR sampling pattern, computes the dense X @ Y^T ground
    truth gathered at the sparse positions, and compares it with the values
    produced by the specialized, GPU-scheduled TIR kernel.
    """
    # generate random input
    m = 4096
    n = 4096
    k = 256
    C = sp.random(m, n, dtype="float32", density=0.0125, format='csr')
    indptr = C.indptr
    indices = C.indices
    C_coo = C.tocoo()
    nnz = C.nnz
    x = np.random.rand(m, k).astype("float32")
    y = np.random.rand(n, k).astype("float32")
    # dense ground truth, sampled at the sparse (row, col) positions of C
    z_ground_truth = np.matmul(x, y.transpose())[C_coo.row, C_coo.col]
    z = np.zeros((nnz,)).astype("float32")

    # specialize the generic kernel to the concrete sizes
    _, _, _, _, _, M, N, K, NNZ = sddmm_tir.params
    sch = tir.Schedule(
        sddmm_tir.specialize(
            {M: m, N: n, K: k, NNZ: nnz}
        )
    )
    blk = sch.get_block("sddmm")
    # use a distinct name for the loop RV so it doesn't shadow the integer k
    ij, k_loop = sch.get_loops(blk)
    sch.bind(ij, "blockIdx.x")
    ko, ki = sch.split(k_loop, [None, 1])
    sch.bind(ki, "threadIdx.x")

    # convert numpy tensors to tvm ndarrays on the GPU
    C_indices = tvm.nd.array(indices.astype("int32"), device=tvm.cuda(0))
    C_indptr = tvm.nd.array(indptr.astype("int32"), device=tvm.cuda(0))
    X_nd = tvm.nd.array(x.reshape(-1), device=tvm.cuda(0))
    Y_nd = tvm.nd.array(y.reshape(-1), device=tvm.cuda(0))
    C_data = tvm.nd.array(z, device=tvm.cuda(0))

    # build and run
    f = tvm.build(sch.mod['main'], target="cuda")
    f(X_nd, Y_nd, C_data, C_indptr, C_indices)

    # assertion — the original discarded np.allclose's return value, so this
    # test could never fail; assert it so mismatches are actually reported
    assert np.allclose(z_ground_truth, C_data.numpy())


def test_bmm():
    """Placeholder for the batched-matmul correctness test (not implemented)."""
    # TODO(zihao): implement the bmm kernel check
Expand All @@ -228,4 +291,5 @@ def test_bmm():
test_csrmm()
test_bsrmm()
test_ellmm()
test_sddmm()
test_bmm()
28 changes: 15 additions & 13 deletions tests/python/unittest/test_tir_intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,19 +253,21 @@ def test_fma():
assert mod["test_tir_fma"].body.body.value.op.name == "tir.call_llvm_pure_intrin"


@tvm.script.tir
def binary_search(a: ty.handle, b: ty.handle, c: ty.handle, d: ty.handle) -> None:
n = tir.var('int32')
m = tir.var('int32')
A = tir.match_buffer(a, (n,), dtype='int32')
B = tir.match_buffer(b, (m,), dtype='int32')
C = tir.match_buffer(c, (m,), dtype='int32')
D = tir.match_buffer(d, (m,), dtype='int32')
with tir.block([m], 'search') as [vi]:
tir.reads([A[0:n], B[vi]])
tir.writes([C[vi], D[vi]])
C[vi] = tir.lower_bound(A.data, B[vi], 0, n)
D[vi] = tir.upper_bound(A.data, B[vi], 0, n)
@T.prim_func
def binary_search(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None:
    # Batched binary search: for every query value B[vi] over the sorted
    # buffer A of length n, record the lower-bound index in C[vi] and the
    # upper-bound index in D[vi].
    n = T.var('int32')
    m = T.var('int32')
    A = T.match_buffer(a, (n,), dtype='int32')  # sorted haystack
    B = T.match_buffer(b, (m,), dtype='int32')  # query values
    C = T.match_buffer(c, (m,), dtype='int32')  # lower-bound results
    D = T.match_buffer(d, (m,), dtype='int32')  # upper-bound results
    for i in T.serial(0, m):
        with T.block('search'):
            vi = T.axis.S(m, i)
            T.reads([A[0:n], B[vi]])
            T.writes([C[vi], D[vi]])
            # Per the CUDA helpers' loop invariants: lower_bound yields the
            # first idx with A[idx] >= B[vi]; upper_bound the first with
            # A[idx] > B[vi] (std::lower_bound/upper_bound semantics).
            C[vi] = T.lower_bound(A.data, B[vi], 0, n)
            D[vi] = T.upper_bound(A.data, B[vi], 0, n)


def test_binary_search():
Expand Down