Skip to content

Commit

Permalink
[SparseTIR] Add "square sum" lowering test (#37)
Browse files Browse the repository at this point in the history
* Add square sum test

* Remove pylint comment
  • Loading branch information
MasterJH5574 authored and yzh119 committed Jan 21, 2022
1 parent c7c1678 commit 93c1daf
Showing 1 changed file with 87 additions and 1 deletion.
88 changes: 87 additions & 1 deletion tests/python/sparsetir/test_tir_sparse_lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,52 @@ def bmm(
Z[vb, vi, vk] = Z[vb, vi, vk] + X[vb, vi, vk] * Y[vb, vk, vj]


@T.prim_func
def square_sum(a: T.handle, b: T.handle, indptr_j: T.handle, indices_j: T.handle, indptr_k: T.handle, indices_k: T.handle, nnz_j: T.int32, nnz_k: T.int32, M: T.int32, N1: T.int32, N2: T.int32):
I = T.dense_fixed(M)
J = T.sparse_variable(I, (N1, nnz_j), (indptr_j, indices_j), "int32")
K = T.sparse_variable(J, (N2, nnz_k), (indptr_k, indices_k), "int32")
A = T.match_sparse_buffer(a, (I, J, K), "float32")
B = T.match_sparse_buffer(b, (I,), "float32")

with T.iter([I, J, K], "SRR", "square_sum") as [vi, vj, vk]:
with T.init():
B[vi] = 0.0
B[vi] = B[vi] + A[vi, vj, vk]


@T.prim_func
def lowered_square_sum(a: T.handle, b: T.handle, indptr_j: T.handle, indices_j: T.handle, indptr_k: T.handle, indices_k: T.handle, nnz_j: T.int32, nnz_k: T.int32, M: T.int32, N1: T.int32, N2: T.int32) -> None:
A_data = T.match_buffer(a, [nnz_k], dtype="float32")
B_data = T.match_buffer(b, [M], dtype="float32")
J_indptr = T.match_buffer(indptr_j, [M + 1], dtype="int32")
J_indices = T.match_buffer(indices_j, [nnz_j], dtype="int32")
K_indptr = T.match_buffer(indptr_k, [nnz_j + 1], dtype="int32")
K_indices = T.match_buffer(indices_k, [nnz_k], dtype="int32")

for v_vi in T.serial(0, M):
with T.block("square_sum_2"):
vi = T.axis.spatial(M, v_vi)
T.reads([J_indptr[0 : M + 1], J_indices[0 : nnz_j], K_indptr[0 : nnz_j + 1], K_indices[0 : nnz_k], A_data[0 : nnz_k], B_data[0 : M]])
T.writes([B_data[0 : M]])
T.block_attr({"sparse":True})
for v_vj in T.serial(0, J_indptr[v_vi + 1] - J_indptr[v_vi]):
with T.block("square_sum_1"):
vj = T.axis.reduce(J_indptr[v_vi + 1] - J_indptr[v_vi], v_vj)
T.reads([J_indptr[0 : M + 1], J_indices[0 : nnz_j], K_indptr[0 : nnz_j + 1], K_indices[0 : nnz_k], A_data[0 : nnz_k], B_data[0 : M]])
T.writes([B_data[0 : M]])
T.block_attr({"sparse":True})
with T.init():
B_data[vi] = T.float32(0)
for v_vk in T.serial(0, K_indptr[J_indptr[v_vi] + v_vj + 1] - K_indptr[J_indptr[v_vi] + v_vj]):
with T.block("square_sum"):
vk = T.axis.reduce(K_indptr[J_indptr[v_vi] + v_vj + 1] - K_indptr[J_indptr[v_vi] + v_vj], v_vk)
T.reads([J_indptr[0 : M + 1], J_indices[0 : nnz_j], K_indptr[0 : nnz_j + 1], K_indices[0 : nnz_k], A_data[0 : nnz_k], B_data[0 : M]])
T.writes([B_data[0 : M]])
T.block_attr({"sparse":True})
B_data[vi] = B_data[vi] + A_data[K_indptr[J_indptr[vi] + vj] + vk]


def test_csrmm():
mod = tvm.IRModule.from_expr(csrmm)
mod = tvm.tir.transform.LowerSparseTIR()(mod)
Expand Down Expand Up @@ -372,11 +418,13 @@ def test_csrmm_dense_iter():
mod = tvm.IRModule.from_expr(csrmm_dense_iter)
mod = tvm.tir.transform.LowerSparseTIR()(mod)
# tvm.ir.assert_structural_equal(mod["main"], lowered_csrmm, True)
# Todo


def test_segment_reduce():
mod = tvm.IRModule.from_expr(segment_reduce)
mod = tvm.tir.transform.LowerSparseTIR()(mod)
# Todo


def test_csr_reduce():
Expand Down Expand Up @@ -512,7 +560,44 @@ def test_csr_element_wise():
def test_bmm():
mod = tvm.IRModule.from_expr(bmm)
mod = tvm.tir.transform.LowerSparseTIR()(mod)
print(mod['main'].script())
# Todo


def test_square_sum():
mod = tvm.IRModule.from_expr(square_sum)
mod = tvm.tir.transform.LowerSparseTIR()(mod)
tvm.ir.assert_structural_equal(mod["main"], lowered_square_sum, True)

density = 0.0125
M = N1 = N2 = 128
A_J = sp.random(M, N1, dtype="float32", density=1 - (1 - density) ** N2, format="csr")
indptr_j = A_J.indptr
indices_j = A_J.indices
nnz_j = A_J.nnz
A_K = sp.random(nnz_j, N2, dtype="float32", density=density, format="csr")
indptr_k = A_K.indptr
indices_k = A_K.indices
nnz_k = A_K.nnz
data = A_K.data

b_ij = np.asarray(A_K.sum(axis=1)).squeeze()
A_J = sp.csr_matrix((b_ij, indices_j, indptr_j), shape=(M, N1))
b_ground_truth = np.asarray(A_J.sum(axis=1)).squeeze()
b = np.zeros((M,)).astype("float32")

v_nnz_j, v_nnz_k, v_M, v_N1, v_N2 = square_sum.params[-5:]
f = tvm.build(mod["main"].specialize({v_nnz_j: nnz_j, v_nnz_k: nnz_k, v_M: M, v_N1: N1, v_N2: N2}), target="llvm")

ctx = tvm.cpu(0)
A_data = tvm.nd.array(data.astype("float32"), device=ctx)
A_indptr_j = tvm.nd.array(indptr_j.astype("int32"), device=ctx)
A_indices_j = tvm.nd.array(indices_j.astype("int32"), device=ctx)
A_indptr_k = tvm.nd.array(indptr_k.astype("int32"), device=ctx)
A_indices_k = tvm.nd.array(indices_k.astype("int32"), device=ctx)
B_data = tvm.nd.array(b.astype("float32"), device=ctx)
f(A_data, B_data, A_indptr_j, A_indices_j, A_indptr_k, A_indices_k)

tvm.testing.assert_allclose(b_ground_truth, B_data.numpy(), rtol=1e-5, atol=1e-5)


if __name__ == "__main__":
Expand All @@ -524,3 +609,4 @@ def test_bmm():
test_ellpack_mm()
test_csr_element_wise()
test_bmm()
test_square_sum()

0 comments on commit 93c1daf

Please sign in to comment.