Skip to content

Commit

Permalink
Bugfix and more test for axis fusion, new workload (#50)
Browse files Browse the repository at this point in the history
* upd

* upd
  • Loading branch information
yzh119 committed Feb 15, 2022
1 parent 235327a commit f476319
Show file tree
Hide file tree
Showing 7 changed files with 188 additions and 5 deletions.
2 changes: 1 addition & 1 deletion python/tvm/script/tir/special_stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,7 +935,7 @@ def dense_variable(
)

length, nnz = shape
indptr_len = parent_axis.length + 1
indptr_len = parent_axis.nnz + 1
indptr_buf = tvm.tir.decl_buffer(
(indptr_len,), dtype=idtype, name=names[0] + "_indptr", span=span
)
Expand Down
6 changes: 5 additions & 1 deletion src/tir/transforms/lower_sparse_tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,7 @@ class IndexTransformer : public StmtExprMutator {

const Optional<Var>& loop_var = axis2loop_var.Get(axis->GetParentAxis().value());
CHECK(loop_var.defined()) << "ValueError: The parent axis of " << axis
<< "does not appear in the sparse block";
<< " does not appear in the sparse block";

if (LoopVarAppears(loop_var.value())) {
return true;
Expand Down Expand Up @@ -559,6 +559,10 @@ class IndexTransformer : public StmtExprMutator {
for (const SpIterVar& sp_iter_var : sp_block->sp_iter_vars) {
Var loop_var("v_" + sp_iter_var->var->name_hint);
var_map.Set(sp_iter_var->var, loop_var);
if (auto fused_axis = sp_iter_var->axis.as<FusedAxisNode>()) {
// handle the special case of fused_axis
axis2loop_var.Set(fused_axis->group[fused_axis->index], loop_var);
}
axis2loop_var.Set(sp_iter_var->axis, loop_var);
}

Expand Down
54 changes: 54 additions & 0 deletions tests/python/sparsetir/bench_rgcn_new.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from dgl.heterograph import DGLHeteroGraph
import tvm
import tvm.testing
import tvm.tir as tir
import scipy.sparse as sp
import numpy as np
import dgl
import dgl.function as fn
import torch as th
from tvm.script import tir as T
from dgl.data.rdf import AIFBDataset, MUTAGDataset, BGSDataset, AMDataset


@T.prim_func
def rgcn_hetero_forward(
offset_ntype: T.handle,
w: T.handle,
x: T.handle,
y: T.handle,
indptr_i: T.handle,
indptr_j: T.handle,
indices_j: T.handle,
n: T.int32,
r: T.int32,
feat_size: T.int32,
nnz_i: T.int32,
nnz_j: T.int32
):
I_flatten = T.dense_fixed(n)
R = T.dense_fixed(r)
I = T.dense_variable(R, (n, nnz_i), indptr_i, "int32")
J = T.sparse_variable(I, (n, nnz_j), (indptr_j, indices_j), "int32")
F_in = T.dense_fixed(feat_size)
F_out = T.dense_fixed(feat_size)
offset = T.match_sparse_buffer(offset_ntype, (R,), "int32")
W = T.match_sparse_buffer(w, (R, F_out, F_in), "float32")
X = T.match_sparse_buffer(x, (I_flatten, F_in), "float32")
Y = T.match_sparse_buffer(y, (I_flatten, R, F_out), "float32")
with T.iter([T.fuse(R, I), F_out, J, F_in], "SSSRR", "rgcn-hetero-forward") as [
vr, vi, vout, vj, vin
]:
with T.init():
Y[offset[vr] + vi, vr, vout] = 0.
Y[offset[vr] + vi, vr, vout] = Y[offset[vr] + vi, vr, vout] + W[vr, vout, vin] * X[vj, vin]


def test_lower_rgcn_hetero():
mod = tvm.IRModule.from_expr(rgcn_hetero_forward)
mod = tvm.tir.transform.LowerSparseTIR()(mod)
print(mod["main"].script())


if __name__ == "__main__":
test_lower_rgcn_hetero()
68 changes: 68 additions & 0 deletions tests/python/sparsetir/lowered_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,3 +394,71 @@ def lowered_rgcn_forward(etype: T.handle, w: T.handle, x: T.handle, y: T.handle,
Y_data[vi * feat_size + vout] = T.float32(0)
Y_data[vi * feat_size + vout] = Y_data[vi * feat_size + vout] + W_data[(
E_data[J_indptr[vi] + vj] * feat_size + vout) * feat_size + vin] * X_data[J_indices[J_indptr[vi] + vj] * feat_size + vin]


@T.prim_func
def lowered_fused_reduction_4d_2d(x: T.handle, y: T.handle, indptr_j: T.handle, indptr_k: T.handle, indptr_l: T.handle, n: T.int32, nnz_j: T.int32, nnz_k: T.int32, nnz_l: T.int32) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
X_data = T.match_buffer(x, [nnz_l], dtype="float32")
Y_data = T.match_buffer(y, [nnz_j], dtype="float32")
J_indptr = T.match_buffer(indptr_j, [n + 1], dtype="int32")
K_indptr = T.match_buffer(indptr_k, [nnz_j + 1], dtype="int32")
L_indptr = T.match_buffer(indptr_l, [nnz_k + 1], dtype="int32")
# body
# with T.block("root")
for v_vi, v_vj in T.grid(1, nnz_j):
with T.block("reduction_4d_2d0"):
vi, vj = T.axis.remap("SS", [v_vi, v_vj])
T.reads(J_indptr[0: n + 1], K_indptr[0: nnz_j + 1],
L_indptr[0: nnz_k + 1], X_data[0: nnz_l], Y_data[0: nnz_j])
T.writes(Y_data[0: nnz_j])
T.block_attr({"sparse": True})
for v_vk in T.serial(K_indptr[vj + 1] - K_indptr[vj]):
with T.block("reduction_4d_2d1"):
vk = T.axis.reduce(K_indptr[vj + 1] - K_indptr[vj], v_vk)
T.reads(J_indptr[0: n + 1], K_indptr[0: nnz_j + 1],
L_indptr[0: nnz_k + 1], X_data[0: nnz_l], Y_data[0: nnz_j])
T.writes(Y_data[0: nnz_j])
T.block_attr({"sparse": True})
with T.init():
Y_data[vj] = T.float32(0)
for v_vl in T.serial(L_indptr[K_indptr[vj] + vk + 1] - L_indptr[K_indptr[vj] + vk]):
with T.block("reduction_4d_2d2"):
vl = T.axis.reduce(
L_indptr[K_indptr[vj] + vk + 1] - L_indptr[K_indptr[vj] + vk], v_vl)
T.reads(J_indptr[0: n + 1], K_indptr[0: nnz_j + 1],
L_indptr[0: nnz_k + 1], X_data[0: nnz_l], Y_data[0: nnz_j])
T.writes(Y_data[0: nnz_j])
T.block_attr({"sparse": True})
Y_data[vj] = Y_data[vj] + X_data[L_indptr[K_indptr[vj] + vk] + vl]


@T.prim_func
def lowered_fused_reduction_4d_3d(x: T.handle, y: T.handle, indptr_j: T.handle, indptr_k: T.handle, indptr_l: T.handle, n: T.int32, nnz_j: T.int32, nnz_k: T.int32, nnz_l: T.int32) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
X_data = T.match_buffer(x, [nnz_l], dtype="float32")
Y_data = T.match_buffer(y, [nnz_k], dtype="float32")
J_indptr = T.match_buffer(indptr_j, [n + 1], dtype="int32")
K_indptr = T.match_buffer(indptr_k, [nnz_j + 1], dtype="int32")
L_indptr = T.match_buffer(indptr_l, [nnz_k + 1], dtype="int32")
# body
# with T.block("root")
for v_vi, v_vj, v_vk in T.grid(1, 1, nnz_k):
with T.block("reduction_4d_3d0"):
vi, vj, vk = T.axis.remap("SSS", [v_vi, v_vj, v_vk])
T.reads(J_indptr[0: n + 1], K_indptr[0: nnz_j + 1],
L_indptr[0: nnz_k + 1], X_data[0: nnz_l], Y_data[0: nnz_k])
T.writes(Y_data[0: nnz_k])
T.block_attr({"sparse": True})
for v_vl in T.serial(L_indptr[vk + 1] - L_indptr[vk]):
with T.block("reduction_4d_3d1"):
vl = T.axis.reduce(L_indptr[vk + 1] - L_indptr[vk], v_vl)
T.reads(J_indptr[0: n + 1], K_indptr[0: nnz_j + 1],
L_indptr[0: nnz_k + 1], X_data[0: nnz_l], Y_data[0: nnz_k])
T.writes(Y_data[0: nnz_k])
T.block_attr({"sparse": True})
with T.init():
Y_data[vk] = T.float32(0)
Y_data[vk] = Y_data[vk] + X_data[L_indptr[vk] + vl]
48 changes: 48 additions & 0 deletions tests/python/sparsetir/sparse_tir_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,54 @@ def square_sum_two_K(a: T.handle, b: T.handle, indptr_j: T.handle, indices_j: T.
B[vi] = B[vi] + A[vi, vj, vk]


@T.prim_func
def fused_reduction_4d_2d(
x: T.handle,
y: T.handle,
indptr_j: T.handle,
indptr_k: T.handle,
indptr_l: T.handle,
n: T.int32,
nnz_j: T.int32,
nnz_k: T.int32,
nnz_l: T.int32) -> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True})
I = T.dense_fixed(n)
J = T.dense_variable(I, (32768, nnz_j), indptr_j, "int32")
K = T.dense_variable(J, (32768, nnz_k), indptr_k, "int32")
L = T.dense_variable(K, (32768, nnz_l), indptr_l, "int32")
X = T.match_sparse_buffer(x, (I, J, K, L), "float32")
Y = T.match_sparse_buffer(y, (I, J), "float32")
with T.iter([T.fuse(I, J), K, L], "SSRR", "reduction_4d_2d") as [vi, vj, vk, vl]:
with T.init():
Y[vi, vj] = 0.0
Y[vi, vj] = Y[vi, vj] + X[vi, vj, vk, vl]


@T.prim_func
def fused_reduction_4d_3d(
x: T.handle,
y: T.handle,
indptr_j: T.handle,
indptr_k: T.handle,
indptr_l: T.handle,
n: T.int32,
nnz_j: T.int32,
nnz_k: T.int32,
nnz_l: T.int32) -> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True})
I = T.dense_fixed(n)
J = T.dense_variable(I, (32768, nnz_j), indptr_j, "int32")
K = T.dense_variable(J, (32768, nnz_k), indptr_k, "int32")
L = T.dense_variable(K, (32768, nnz_l), indptr_l, "int32")
X = T.match_sparse_buffer(x, (I, J, K, L), "float32")
Y = T.match_sparse_buffer(y, (I, J, K), "float32")
with T.iter([T.fuse(I, J, K), L], "SSSR", "reduction_4d_3d") as [vi, vj, vk, vl]:
with T.init():
Y[vi, vj, vk] = 0.0
Y[vi, vj, vk] = Y[vi, vj, vk] + X[vi, vj, vk, vl]


@T.prim_func
def rgcn_forward(
etype: T.handle,
Expand Down
14 changes: 11 additions & 3 deletions tests/python/sparsetir/test_tir_sparse_lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,21 +63,18 @@ def test_csr_element_wise():
tvm.ir.assert_structural_equal(mod["main"], lowered_csr_element_wise, True)


@pytest.mark.skip(reason="Under implementation")
def test_bmm():
mod = tvm.IRModule.from_expr(bmm)
mod = tvm.tir.transform.LowerSparseTIR()(mod)
tvm.ir.assert_structural_equal(mod["main"], lowered_bmm)


@pytest.mark.skip(reason="Under implementation")
def test_sddmm():
mod = tvm.IRModule.from_expr(sddmm)
mod = tvm.tir.transform.LowerSparseTIR()(mod)
tvm.ir.assert_structural_equal(mod["main"], lowered_sddmm)


@pytest.mark.skip(reason="Under implementation")
def test_fused_sddmm():
mod = tvm.IRModule.from_expr(fused_sddmm)
mod = tvm.tir.transform.LowerSparseTIR()(mod)
Expand All @@ -96,6 +93,16 @@ def test_square_sum_two_K():
tvm.ir.assert_structural_equal(mod["main"], lowered_square_sum_two_K, True)


def test_fused_reduction():
mod = tvm.IRModule.from_expr(fused_reduction_4d_2d)
mod = tvm.tir.transform.LowerSparseTIR()(mod)
tvm.ir.assert_structural_equal(mod["main"], lowered_fused_reduction_4d_2d, True)

mod = tvm.IRModule.from_expr(fused_reduction_4d_3d)
mod = tvm.tir.transform.LowerSparseTIR()(mod)
tvm.ir.assert_structural_equal(mod["main"], lowered_fused_reduction_4d_3d, True)


if __name__ == "__main__":
test_csrmm()
test_csrmm_dense_iter()
Expand All @@ -109,3 +116,4 @@ def test_square_sum_two_K():
test_bmm()
test_square_sum()
test_square_sum_two_K()
test_fused_reduction()
1 change: 1 addition & 0 deletions tests/python/sparsetir/test_tir_sparse_tensorize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# TODO

0 comments on commit f476319

Please sign in to comment.