diff --git a/python/tvm/script/tir/intrin.py b/python/tvm/script/tir/intrin.py
index 75e8df1d4033..119fc77db7be 100644
--- a/python/tvm/script/tir/intrin.py
+++ b/python/tvm/script/tir/intrin.py
@@ -271,5 +271,6 @@ def dense(axis: Axis, span: Optional[Span] = None):
 
 
 @register
-def fuse(group: List[Axis], span: Optional[Span] = None):
-    return [FusedAxis(group, _) for _ in range(len(group))]
+def fuse(*group: Axis, span: Optional[Span] = None):
+    """Return one FusedAxis(group, i) for every position i in ``group``."""
+    return [FusedAxis(group, i) for i, _ in enumerate(group)]
diff --git a/python/tvm/script/tir/scope_handler.py b/python/tvm/script/tir/scope_handler.py
index 1bae0698a2ff..1bc189ccefa4 100644
--- a/python/tvm/script/tir/scope_handler.py
+++ b/python/tvm/script/tir/scope_handler.py
@@ -24,7 +24,7 @@
 from tvm.runtime import Object
 from tvm.ir import Span, Range
 from tvm.tir import Stmt, PrimExpr, IterVar, Var, Buffer, BufferRegion, ForKind
-from tvm.tir.sparse import SpIterVar
+from tvm.tir.sparse import SpIterVar, Axis
 
 from .node import BufferSlice
 from .utils import buffer_slice_to_region
@@ -331,10 +331,22 @@ class SparseBlock(WithScopeHandler):
 
     def __init__(self):
         def iter(axes: List, iter_types: str, name: str = "", span: Optional[Span] = None):
+
+            # Flatten nested axis groups (e.g. the list returned by T.fuse) into a flat list of axes.
+            def flatten_axes(axes: List[Union[Axis, List[Axis]]]) -> List[Axis]:
+                ret = []
+                for axis_group in axes:
+                    if isinstance(axis_group, list):
+                        ret += axis_group
+                    else:
+                        ret.append(axis_group)
+                return ret
+
             assert (
                 self.node and self.context and self.body
             ), "call 'exit_scope' before 'enter_scope'"
             block_info = self.context.block_info_stack[-1]
+            axes = flatten_axes(axes)
 
             if len(axes) != len(self.sp_iters):
                 self.context.report_error(
diff --git a/tests/python/sparsetir/test_tir_sparse_lower.py b/tests/python/sparsetir/test_tir_sparse_lower.py
index eb538f961351..b43d91831a73 100644
--- a/tests/python/sparsetir/test_tir_sparse_lower.py
+++ b/tests/python/sparsetir/test_tir_sparse_lower.py
@@ -346,6 +346,36 @@ def bmm(
         Z[vb, vi, vk] = Z[vb, vi, vk] + X[vb, vi, vk] * Y[vb, vk, vj]
 
 
+@T.prim_func
+def sddmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle, m: T.int32, n: T.int32, k: T.int32, nnz: T.int32) -> None:
+    I = T.dense_fixed(m)
+    J = T.sparse_variable(I, (n, nnz), (indptr, indices), "int32")
+    K = T.dense_fixed(k)
+    A = T.match_sparse_buffer(a, (I, K), "float32")
+    B = T.match_sparse_buffer(b, (T.dense(J), K), "float32")
+    C = T.match_sparse_buffer(c, (I, J), "float32")
+
+    with T.iter([I, J, K], "SSR", "sddmm") as [vi, vj, vk]:
+        with T.init():
+            C[vi, vj] = 0.
+        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
+@T.prim_func
+def fused_sddmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle, m: T.int32, n: T.int32, k: T.int32, nnz: T.int32) -> None:
+    I = T.dense_fixed(m)
+    J = T.sparse_variable(I, (n, nnz), (indptr, indices), "int32")
+    K = T.dense_fixed(k)
+    A = T.match_sparse_buffer(a, (I, K), "float32")
+    B = T.match_sparse_buffer(b, (T.dense(J), K), "float32")
+    C = T.match_sparse_buffer(c, (I, J), "float32")
+
+    with T.iter([T.fuse(I, J), K], "SSR", "sddmm") as [vi, vj, vk]:
+        with T.init():
+            C[vi, vj] = 0.
+        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
 @T.prim_func
 def square_sum(a: T.handle, b: T.handle, indptr_j: T.handle, indices_j: T.handle, indptr_k: T.handle, indices_k: T.handle, nnz_j: T.int32, nnz_k: T.int32, M: T.int32, N1: T.int32, N2: T.int32):
     I = T.dense_fixed(M)
@@ -616,7 +646,20 @@ def test_csr_element_wise():
 def test_bmm():
     mod = tvm.IRModule.from_expr(bmm)
     mod = tvm.tir.transform.LowerSparseTIR()(mod)
-    # Todo
+    # TODO
+
+
+def test_sddmm():
+    mod = tvm.IRModule.from_expr(sddmm)
+    mod = tvm.tir.transform.LowerSparseTIR()(mod)
+    print(mod['main'].script())
+    # TODO
+
+
+def test_fused_sddmm():
+    mod = tvm.IRModule.from_expr(fused_sddmm)
+    print(mod['main'].script())
+    # TODO
 
 
 def test_square_sum():
@@ -707,6 +750,8 @@ def test_square_sum_two_K():
     test_bsrmm()
     test_ellpack_mm()
     test_csr_element_wise()
+    test_sddmm()
+    test_fused_sddmm()
     test_bmm()
     test_square_sum()
     test_square_sum_two_K()