diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 74bb8a5771a6..b43904f463da 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -1285,8 +1285,8 @@ class SparseBlockNode : public StmtNode { public: /*! \brief The sparse iteration variables of the block. */ Array sp_iter_vars; - /*! \brief The sparse buffers defined in the block. */ - Array sp_buffers; + /*! \brief The mapping from sparse data structures to the PrimFunc parameters */ + Map> sp_struct2param_map; /*! \brief The name of the block */ String name; /*! \brief The body of the block */ @@ -1296,20 +1296,21 @@ class SparseBlockNode : public StmtNode { void VisitAttrs(AttrVisitor* v) { v->Visit("sp_iter_vars", &sp_iter_vars); - v->Visit("sp_buffers", &sp_buffers); + v->Visit("sp_struct2param_map", &sp_struct2param_map); v->Visit("name", &name); v->Visit("body", &body); v->Visit("init", &init); } bool SEqualReduce(const SparseBlockNode* other, SEqualReducer equal) const { - return equal(sp_iter_vars, other->sp_iter_vars) && equal(sp_buffers, other->sp_buffers) && - equal(name, other->name) && equal(body, other->body) && equal(init, other->init); + return equal(sp_iter_vars, other->sp_iter_vars) && + equal(sp_struct2param_map, other->sp_struct2param_map) && equal(name, other->name) && + equal(body, other->body) && equal(init, other->init); } void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(sp_iter_vars); - hash_reduce(sp_buffers); + hash_reduce(sp_struct2param_map); hash_reduce(name); hash_reduce(body); hash_reduce(init); @@ -1325,9 +1326,9 @@ class SparseBlockNode : public StmtNode { */ class SparseBlock : public Stmt { public: - TVM_DLL explicit SparseBlock(Array sp_iter_vars, Array sp_buffers, - String name, Stmt body, Optional init = NullOpt, - Span span = Span()); + TVM_DLL explicit SparseBlock(Array sp_iter_vars, + Map> sp_struct2param_map, String name, + Stmt body, Optional init = NullOpt, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(SparseBlock, Stmt, SparseBlockNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(SparseBlockNode); diff --git a/python/tvm/script/context_maintainer.py b/python/tvm/script/context_maintainer.py index 2a84c3d896d2..724f9a27078b 100644 --- a/python/tvm/script/context_maintainer.py +++ b/python/tvm/script/context_maintainer.py @@ -128,13 +128,15 @@ class ContextMaintainer: """List[Var]: The function parameters""" func_buffer_map: Mapping[Var, Buffer] = {} """Mapping[Var, Buffer]: The function buffer map""" - func_sparse_buffer_map: Mapping[Var, SparseBuffer] = {} - """Mapping[Var, SparseBuffer]: The function sparse buffer map""" func_dict_attr: Mapping[str, Object] = {} """Mapping[str, Object]: The function attrs""" func_var_env_dict: Mapping[Var, str] = {} """Mapping[Var, str]: The map from var to env thread""" + # sparse block context + sp_struct2param_map: Mapping[Object, List[Var]] = {} + """Mapping[Object, List[Var]]: The mapping from sparse data structures to the func parameters""" + # parser and analyzer analyzer: tvm.arith.Analyzer = tvm.arith.Analyzer() """tvm.arith.Analyzer: The analyzer for simplifying""" @@ -154,9 +156,10 @@ def __init__(self, _report_error: Callable[[str, Union[Span, synr.ast.Span]], No # function context self.func_params = [] self.func_buffer_map = {} - self.func_sparse_buffer_map = {} self.func_dict_attr = {} self.func_var_env_dict = {} + # sparse block context + self.sp_struct2param_map = {} # parser and analyzer self._report_error = _report_error self.analyzer = tvm.arith.Analyzer() diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/tir/special_stmt.py index ff4837696db8..b243942aabb7 100644 --- a/python/tvm/script/tir/special_stmt.py +++ b/python/tvm/script/tir/special_stmt.py @@ -903,6 +903,7 @@ def __init__(self): def dense_fixed(name: str, length: PrimExpr, span: Optional[Span] = None): var_name = self.node.lhs[0].id.name axis = DenseFixedAxis(name, length) + self.context.sp_struct2param_map[axis] = [] self.context.update_symbol(var_name, axis, self.node) super().__init__(dense_fixed, def_symbol=True) @@ -926,7 +927,7 @@ def dense_variable( (indptr_len,), dtype=idtype, name=name + "_indptr", span=span ) axis = DenseVariableAxis(name, length, indptr_buf) - self.context.func_buffer_map[indptr_var] = indptr_buf + self.context.sp_struct2param_map[axis] = indptr_var self.context.update_symbol(var_name, axis, self.node) self.context.update_symbol(name + "_indptr", indptr_buf, self.node) @@ -951,7 +952,7 @@ def sparse_fixed( (nnz,), dtype=idtype, name=name + "_indices", span=span ) axis = SparseFixedAxis(name, length, indices_buf, nnz_cols) - self.context.func_buffer_map[indices_var] = indices_buf + self.context.sp_struct2param_map[axis] = [indices_var] self.context.update_symbol(var_name, axis, self.node) self.context.update_symbol(name + "_indices", indices_buf, self.node) @@ -980,8 +981,7 @@ def sparse_variable( (nnz,), dtype=idtype, name=name + "_indices", span=span ) axis = SparseVariableAxis(name, length, indptr_buf, indices_buf) - self.context.func_buffer_map[indices_var] = indices_buf - self.context.func_buffer_map[indptr_var] = indptr_buf + self.context.sp_struct2param_map[axis] = [indptr_var, indices_var] self.context.update_symbol(var_name, axis, self.node) self.context.update_symbol(name + "_indptr", indptr_buf, self.node) self.context.update_symbol(name + "_indices", indices_buf, self.node) @@ -1017,8 +1017,7 @@ def match_sparse_buffer( if param in self.context.func_params: data = tvm.tir.decl_buffer(nnz, dtype, buffer_name + "_data", span=span) buffer = tvm.tir.sparse.SparseBuffer(axes, data, buffer_name) - self.context.func_buffer_map[param] = data - self.context.func_sparse_buffer_map[param] = buffer + self.context.sp_struct2param_map[buffer] = [param] self.context.update_symbol(buffer_name + "_data", data, self.node) self.context.update_symbol(buffer_name, buffer, self.node) else: diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index 84b91981ea89..68b5eca8ecda 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -35,7 +35,7 @@ from . import _ffi_api from .buffer import Buffer -from .expr import IterVar +from .expr import Var, IterVar from .sparse import SpIterVar, SparseBuffer @@ -624,8 +624,8 @@ class SparseBlock(Stmt): sp_iter_vars : List[SpIterVar] The sparse iteration variables of the block. - sp_buffers : List[SparseBuffer] - The sparse buffers defined in the block. + sp_struct2param_map : Mapping[Object, List[Var]] + The mapping from sparse data structures to the PrimFunc parameters. name : str The name of the block. @@ -641,7 +641,7 @@ class SparseBlock(Stmt): """ sp_iter_vars: List[SpIterVar] - sp_buffers: List[SparseBuffer] + sp_struct2param_map: Mapping[Object, List[Var]] name: str body: Stmt init: Optional[Stmt] @@ -650,7 +650,7 @@ class SparseBlock(Stmt): def __init__( self, sp_iter_vars: List[SpIterVar], - sp_buffers: List[SparseBuffer], + sp_struct2param_map: Mapping[Object, List[Var]], name: str, body: Stmt, init: Optional[Stmt] = None, @@ -659,7 +659,7 @@ def __init__( self.__init_handle_by_constructor__( _ffi_api.SparseBlock, # type: ignore sp_iter_vars, - sp_buffers, + sp_struct2param_map, name, body, init, diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 1cc80dd4d73c..b62b206b1420 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -975,11 +975,12 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "}\n"; }); -SparseBlock::SparseBlock(Array sp_iter_vars, Array sp_buffers, String name, - Stmt body, Optional init, Span span) { +SparseBlock::SparseBlock(Array sp_iter_vars, + Map> sp_struct2param_map, String name, Stmt body, + Optional init, Span span) { ObjectPtr node = make_object(); node->sp_iter_vars = std::move(sp_iter_vars); - node->sp_buffers = std::move(sp_buffers); + node->sp_struct2param_map = std::move(sp_struct2param_map); node->name = std::move(name); node->body = std::move(body); node->init = std::move(init); @@ -988,9 +989,10 @@ SparseBlock::SparseBlock(Array sp_iter_vars, Array sp_b } TVM_REGISTER_GLOBAL("tir.SparseBlock") - .set_body_typed([](Array sp_iter_vars, Array sp_buffers, String name, - Stmt body, Optional init, Span span) { - return SparseBlock(sp_iter_vars, sp_buffers, name, body, init, span); + .set_body_typed([](Array sp_iter_vars, + Map> sp_struct2param_map, String name, Stmt body, + Optional init, Span span) { + return SparseBlock(sp_iter_vars, sp_struct2param_map, name, body, init, span); }); TVM_REGISTER_NODE_TYPE(SparseBlockNode);