diff --git a/include/tvm/tir/sparse.h b/include/tvm/tir/sparse.h index 8621c2b572e75..c2514168c46d1 100644 --- a/include/tvm/tir/sparse.h +++ b/include/tvm/tir/sparse.h @@ -131,13 +131,11 @@ class DenseFixedAxisNode : public DenseAxisNode { } bool SEqualReduce(const DenseFixedAxisNode* other, SEqualReducer equal) const { - equal->MarkGraphNode(); return equal(name, other->name) && equal(length, other->length) && equal(from_sparse, other->from_sparse); } void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce->MarkGraphNode(); hash_reduce(name); hash_reduce(length); hash_reduce(from_sparse); @@ -170,12 +168,10 @@ class DenseVariableAxisNode : public DenseAxisNode { } bool SEqualReduce(const DenseVariableAxisNode* other, SEqualReducer equal) const { - equal->MarkGraphNode(); return equal(name, other->name) && equal(length, other->length) && equal(indptr, other->indptr); } void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce->MarkGraphNode(); hash_reduce(name); hash_reduce(length); hash_reduce(indptr); @@ -213,13 +209,11 @@ class SparseFixedAxisNode : public SparseAxisNode { } bool SEqualReduce(const SparseFixedAxisNode* other, SEqualReducer equal) const { - equal->MarkGraphNode(); return equal(name, other->name) && equal(length, other->length) && equal(indices, other->indices) && equal(num_cols, other->num_cols); } void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce->MarkGraphNode(); hash_reduce(name); hash_reduce(length); hash_reduce(indices); @@ -257,13 +251,11 @@ class SparseVariableAxisNode : public SparseAxisNode { } bool SEqualReduce(const SparseVariableAxisNode* other, SEqualReducer equal) const { - equal->MarkGraphNode(); return equal(name, other->name) && equal(length, other->length) && equal(indptr, other->indptr) && equal(indices, other->indices); } void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce->MarkGraphNode(); hash_reduce(name); hash_reduce(length); hash_reduce(indptr); @@ -347,12 +339,10 @@ class SparseBufferNode : public Object { } bool SEqualReduce(const SparseBufferNode* other, SEqualReducer equal) const { - equal->MarkGraphNode(); return equal(axes, other->axes) && equal(data, other->data) && equal(name, other->name); } void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce->MarkGraphNode(); hash_reduce(axes); hash_reduce(data); hash_reduce(name); diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 3e5fd05d6382e..105dd70093555 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -1285,6 +1285,8 @@ class SparseBlockNode : public StmtNode { public: /*! \brief The sparse iteration variables of the block. */ Array sp_iter_vars; + /*! \brief The sparse data structures */ + Array sp_structs; /*! \brief The mapping from sparse data structures to the PrimFunc parameters */ Map> sp_struct2param_map; /*! \brief The name of the block */ @@ -1296,6 +1298,7 @@ class SparseBlockNode : public StmtNode { void VisitAttrs(AttrVisitor* v) { v->Visit("sp_iter_vars", &sp_iter_vars); + v->Visit("sp_structs", &sp_structs); v->Visit("sp_struct2param_map", &sp_struct2param_map); v->Visit("name", &name); v->Visit("body", &body); @@ -1305,7 +1308,7 @@ class SparseBlockNode : public StmtNode { bool SEqualReduce(const SparseBlockNode* other, SEqualReducer equal) const { return equal(sp_iter_vars, other->sp_iter_vars) && equal(name, other->name) && equal(body, other->body) && equal(init, other->init) && - equal(sp_struct2param_map, other->sp_struct2param_map); + equal(sp_structs, other->sp_structs); } void SHashReduce(SHashReducer hash_reduce) const { @@ -1313,7 +1316,7 @@ class SparseBlockNode : public StmtNode { hash_reduce(name); hash_reduce(body); hash_reduce(init); - hash_reduce(sp_struct2param_map); + hash_reduce(sp_structs); } static constexpr const char* _type_key = "tir.SparseBlock"; @@ -1326,9 +1329,9 @@ class SparseBlockNode : public StmtNode { */ class SparseBlock : public Stmt { public: - TVM_DLL explicit SparseBlock(Array sp_iter_vars, - Map> sp_struct2param_map, String name, - Stmt body, Optional init = NullOpt, Span span = Span()); + TVM_DLL explicit SparseBlock(Array sp_iter_vars, Array sp_structs, + Array> sp_struct_params, String name, Stmt body, + Optional init = NullOpt, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(SparseBlock, Stmt, SparseBlockNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(SparseBlockNode); diff --git a/python/tvm/script/context_maintainer.py b/python/tvm/script/context_maintainer.py index e81a41699fe60..09dc06055e8c1 100644 --- a/python/tvm/script/context_maintainer.py +++ b/python/tvm/script/context_maintainer.py @@ -134,8 +134,10 @@ class ContextMaintainer: """Mapping[Var, str]: The map from var to env thread""" # sparse block context - sp_struct2param_map: Mapping[Object, List[Var]] = {} - """Mapping[Object, List[Var]]: The mapping from sparse data structures to the func parameters""" + sp_struct: List[Object] = [] + """List[Object]: The sparse data structures""" + sp_struct_params: List[List[Var]] = [] + """List[List[Var]]: The function parameters that corresponding to each sparse data structures""" # parser and analyzer analyzer: tvm.arith.Analyzer = tvm.arith.Analyzer() @@ -155,7 +157,8 @@ def __init__(self, _report_error: Callable[[str, Union[Span, synr.ast.Span]], No self.func_dict_attr = {} self.func_var_env_dict = {} # sparse block context - self.sp_struct2param_map = {} + self.sp_struct = [] + self.sp_struct_params = [] # parser and analyzer self._report_error = _report_error self.analyzer = tvm.arith.Analyzer() diff --git a/python/tvm/script/tir/intrin.py b/python/tvm/script/tir/intrin.py index 157cc7a722b9b..0c61141fba030 100644 --- a/python/tvm/script/tir/intrin.py +++ b/python/tvm/script/tir/intrin.py @@ -282,6 +282,6 @@ def pos(axis: Axis, span: Optional[Span] = None): elif isinstance(axis, DenseVariableAxis): return SpIterVar(var_temp, axis.length, SpIterVar.DenseVariable, False, axis) elif isinstance(axis, SparseFixedAxis): - return SpIterVar(var_temp, axis.length, SpIterVar.SparseFixed, False, axis) + return SpIterVar(var_temp, axis.num_cols, SpIterVar.SparseFixed, False, axis) else: return SpIterVar(var_temp, axis.length, SpIterVar.SparseVariable, False, axis) diff --git a/python/tvm/script/tir/scope_handler.py b/python/tvm/script/tir/scope_handler.py index 44610c9833064..7ed93ba9157ca 100644 --- a/python/tvm/script/tir/scope_handler.py +++ b/python/tvm/script/tir/scope_handler.py @@ -365,7 +365,8 @@ def iter(iters: List, iter_types: str, name: str = "", span: Optional[Span] = No block = tvm.tir.SparseBlock( sp_iters, - self.context.sp_struct2param_map, + self.context.sp_struct, + self.context.sp_struct_params, name, self.body, block_info.init, diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/tir/special_stmt.py index 4fcf206d1cac3..1c9ef30351d6b 100644 --- a/python/tvm/script/tir/special_stmt.py +++ b/python/tvm/script/tir/special_stmt.py @@ -832,7 +832,8 @@ def dense_fixed(length: PrimExpr, span: Optional[Span] = None): ) axis = DenseFixedAxis(names[0], length) - self.context.sp_struct2param_map[axis] = [] + self.context.sp_struct.append(axis) + self.context.sp_struct_params.append([]) self.context.update_symbol(names[0], axis, self.node) super().__init__(dense_fixed, def_symbol=True) @@ -860,7 +861,8 @@ def dense_variable( (indptr_len,), dtype=idtype, name=names[0] + "_indptr", span=span ) axis = DenseVariableAxis(names[0], length, indptr_buf) - self.context.sp_struct2param_map[axis] = [indptr_var] + self.context.sp_struct.append(axis) + self.context.sp_struct_params.append([indptr_var]) self.context.update_symbol(names[0], axis, self.node) self.context.update_symbol(names[0] + "_indptr", indptr_buf, self.node) @@ -889,7 +891,8 @@ def sparse_fixed( (nnz,), dtype=idtype, name=names[0] + "_indices", span=span ) axis = SparseFixedAxis(names[0], length, indices_buf, nnz_cols) - self.context.sp_struct2param_map[axis] = [indices_var] + self.context.sp_struct.append(axis) + self.context.sp_struct_params.append([indices_var]) self.context.update_symbol(names[0], axis, self.node) self.context.update_symbol(names[0] + "_indices", indices_buf, self.node) @@ -922,7 +925,8 @@ def sparse_variable( (nnz,), dtype=idtype, name=names[0] + "_indices", span=span ) axis = SparseVariableAxis(names[0], length, indptr_buf, indices_buf) - self.context.sp_struct2param_map[axis] = [indptr_var, indices_var] + self.context.sp_struct.append(axis) + self.context.sp_struct_params.append([indptr_var, indices_var]) self.context.update_symbol(names[0], axis, self.node) self.context.update_symbol(names[0] + "_indptr", indptr_buf, self.node) self.context.update_symbol(names[0] + "_indices", indices_buf, self.node) @@ -958,7 +962,8 @@ def match_sparse_buffer( if param in self.context.func_params: data = tvm.tir.decl_buffer(nnz, dtype, buffer_name + "_data", span=span) buffer = tvm.tir.sparse.SparseBuffer(axes, data, buffer_name) - self.context.sp_struct2param_map[buffer] = [param] + self.context.sp_struct.append(buffer) + self.context.sp_struct_params.append([param]) self.context.update_symbol(buffer_name + "_data", data, self.node) self.context.update_symbol(buffer_name, buffer, self.node) else: diff --git a/python/tvm/tir/sparse.py b/python/tvm/tir/sparse.py index 07fd48208d1f0..bb63ca3a76662 100644 --- a/python/tvm/tir/sparse.py +++ b/python/tvm/tir/sparse.py @@ -26,6 +26,7 @@ from .buffer import Buffer +@tvm._ffi.register_object("tir.sparse.Axis") class Axis(Object): """Base class of all the sparse axes.""" @@ -42,10 +43,12 @@ def idtype(self): return _ffi_api.GetAxisIndexType(self) +@tvm._ffi.register_object("tir.sparse.DenseAxis") class DenseAxis(Axis): pass +@tvm._ffi.register_object("tir.sparse.SparseAxis") class SparseAxis(Axis): pass diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index 7a3677cf8f3a5..67785bb1ae51a 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -649,6 +649,12 @@ class SparseBlock(Stmt): sp_iter_vars : List[SpIterVar] The sparse iteration variables of the block. + sp_struct : List[Object] + The sparse data structures + + sp_struct_params : List[List[Var]] + The function parameters that corresponding to each sparse data structures + sp_struct2param_map : Mapping[Object, List[Var]] The mapping from sparse data structures to the PrimFunc parameters. @@ -666,6 +672,7 @@ class SparseBlock(Stmt): """ sp_iter_vars: List[SpIterVar] + sp_struct: List[Object] sp_struct2param_map: Mapping[Object, List[Var]] name: str body: Stmt @@ -675,7 +682,8 @@ class SparseBlock(Stmt): def __init__( self, sp_iter_vars: List[SpIterVar], - sp_struct2param_map: Mapping[Object, List[Var]], + sp_struct: List[Object], + sp_struct_params: List[List[Var]], name: str, body: Stmt, init: Optional[Stmt] = None, @@ -684,7 +692,8 @@ def __init__( self.__init_handle_by_constructor__( _ffi_api.SparseBlock, # type: ignore sp_iter_vars, - sp_struct2param_map, + sp_struct, + sp_struct_params, name, body, init, diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index c83399db90fce..3d6aa96fc5c65 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -1277,12 +1277,14 @@ Doc TVMScriptPrinter::PrintSparseStructDefinitions(const SparseBlockNode* sp_blo std::vector axis_docs; std::vector sp_buf_docs; - for (auto it : sp_block->sp_struct2param_map) { + for (const ObjectRef& obj : sp_block->sp_structs) { + Array params = sp_block->sp_struct2param_map.Get(obj).value(); + Doc doc; - doc << Print(it.first) << " = " << tir_prefix_ << "."; + doc << Print(obj) << " = " << tir_prefix_ << "."; - if (const auto* sp_buffer = it.first.as()) { - ICHECK_EQ(it.second.size(), 1); + if (const auto* sp_buffer = obj.as()) { + ICHECK_EQ(params.size(), 1); Doc axes_doc; if (sp_buffer->axes.size() != 1) { std::vector axes_docs; @@ -1295,30 +1297,30 @@ Doc TVMScriptPrinter::PrintSparseStructDefinitions(const SparseBlockNode* sp_blo axes_doc << Print(sp_buffer->axes[0]) << ","; } - doc << "match_sparse_buffer(" << Print(it.second[0]) << ", (" << axes_doc << "), " + doc << "match_sparse_buffer(" << Print(params[0]) << ", (" << axes_doc << "), " << Print(sp_buffer->data->shape[0]) << ", " << PrintDType(sp_buffer->data->dtype) << ")"; sp_buf_docs.push_back(doc); continue; } - if (const auto* df_axis = it.first.as()) { - ICHECK_EQ(it.second.size(), 0); + if (const auto* df_axis = obj.as()) { + ICHECK_EQ(params.size(), 0); doc << "dense_fixed(" << Print(df_axis->length) << ")"; - } else if (const auto* dv_axis = it.first.as()) { - ICHECK_EQ(it.second.size(), 1); + } else if (const auto* dv_axis = obj.as()) { + ICHECK_EQ(params.size(), 1); doc << "dense_variable((" << Print(dv_axis->length) << ", " - << Print(dv_axis->indptr->shape[0]) << "), " << Print(it.second[0]) << ", " + << Print(dv_axis->indptr->shape[0]) << "), " << Print(params[0]) << ", " << PrintDType(dv_axis->indptr->dtype) << ")"; - } else if (const auto* sf_axis = it.first.as()) { - ICHECK_EQ(it.second.size(), 1); + } else if (const auto* sf_axis = obj.as()) { + ICHECK_EQ(params.size(), 1); doc << "sparse_fixed((" << Print(sf_axis->length) << ", " << Print(sf_axis->indices->shape[0]) - << ", " << Print(sf_axis->num_cols) << "), " << Print(it.second[0]) << ", " + << ", " << Print(sf_axis->num_cols) << "), " << Print(params[0]) << ", " << PrintDType(sf_axis->indices->dtype) << ")"; - } else if (const auto* sv_axis = it.first.as()) { - ICHECK_EQ(it.second.size(), 2); + } else if (const auto* sv_axis = obj.as()) { + ICHECK_EQ(params.size(), 2); doc << "sparse_variable((" << Print(sv_axis->length) << ", " << Print(sv_axis->indptr->shape[0]) << ", " << Print(sv_axis->indices->shape[0]) << "), (" - << Print(it.second[0]) << ", " << Print(it.second[1]) << "), " + << Print(params[0]) << ", " << Print(params[1]) << "), " << PrintDType(sv_axis->indptr->dtype) << ")"; } else { ICHECK(false) << "Cannot reach here"; diff --git a/src/tir/ir/sparse.cc b/src/tir/ir/sparse.cc index f782eea32e747..d3d9865e9cca7 100644 --- a/src/tir/ir/sparse.cc +++ b/src/tir/ir/sparse.cc @@ -234,7 +234,6 @@ SpIterVar::SpIterVar(Var var, PrimExpr max_extent, SpIterKind kind, bool is_redu ObjectPtr node = make_object(); arith::Analyzer ana; - CHECK(ana.CanProveEqual(axis->length, max_extent)); const char* err_str = "ValueError: The given kind doesn't match the type of the given axis"; if (kind == SpIterKind::kDenseFixed) { CHECK(!axis->IsInstance()) << err_str; diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index a64ab47d72f85..c8dbfdffe30b6 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -968,11 +968,42 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "}\n"; }); -SparseBlock::SparseBlock(Array sp_iter_vars, - Map> sp_struct2param_map, String name, Stmt body, +SparseBlock::SparseBlock(Array sp_iter_vars, Array sp_structs, + Array> sp_struct_params, String name, Stmt body, Optional init, Span span) { + CHECK_EQ(sp_structs.size(), sp_struct_params.size()) + << "ValueError: The length of `sp_struct_params` is expected to be equal to the length " + "`sp_structs`, which is the number of sparse data structures"; + Map> sp_struct2param_map; + for (int i = 0; i < static_cast(sp_structs.size()); ++i) { + ObjectRef obj = sp_structs[i]; + Array params = sp_struct_params[i]; + + if (obj->IsInstance()) { + CHECK(params.size() == 0) + << "ValueError: The number of function parameters for dense-fixed axes should be 0"; + } else if (obj->IsInstance()) { + CHECK(params.size() == 1) + << "ValueError: The number of function parameters for dense-variable axes should be 1"; + } else if (obj->IsInstance()) { + CHECK(params.size() == 1) + << "ValueError: The number of function parameters for sparse-fixed axes should be 1"; + } else if (obj->IsInstance()) { + CHECK(params.size() == 2) + << "ValueError: The number of function parameters for sparse-variable axes should be 2"; + } else if (obj->IsInstance()) { + CHECK(params.size() == 1) + << "ValueError: The number of function parameters for SparseBuffer should be 1"; + } else { + LOG(FATAL) << "ValueError: " << obj->_type_key << " is not a sparse data structure"; + } + + sp_struct2param_map.Set(obj, params); + } + ObjectPtr node = make_object(); node->sp_iter_vars = std::move(sp_iter_vars); + node->sp_structs = std::move(sp_structs); node->sp_struct2param_map = std::move(sp_struct2param_map); node->name = std::move(name); node->body = std::move(body); @@ -982,10 +1013,10 @@ SparseBlock::SparseBlock(Array sp_iter_vars, } TVM_REGISTER_GLOBAL("tir.SparseBlock") - .set_body_typed([](Array sp_iter_vars, - Map> sp_struct2param_map, String name, Stmt body, + .set_body_typed([](Array sp_iter_vars, Array sp_structs, + Array> sp_struct_params, String name, Stmt body, Optional init, Span span) { - return SparseBlock(sp_iter_vars, sp_struct2param_map, name, body, init, span); + return SparseBlock(sp_iter_vars, sp_structs, sp_struct_params, name, body, init, span); }); TVM_REGISTER_NODE_TYPE(SparseBlockNode); diff --git a/src/tir/transforms/lower_sparse_tir.cc b/src/tir/transforms/lower_sparse_tir.cc index ee7337d60e7d3..b27539d55af61 100644 --- a/src/tir/transforms/lower_sparse_tir.cc +++ b/src/tir/transforms/lower_sparse_tir.cc @@ -33,6 +33,38 @@ namespace tvm { namespace tir { +Map UpdateBufferMap(PrimFunc f) { + struct BufferMapUpdater : public StmtVisitor { + explicit BufferMapUpdater(Map buffer_map) : buffer_map_(std::move(buffer_map)) {} + + void VisitStmt_(const SparseBlockNode* sp_block) { + for (const auto& it : sp_block->sp_struct2param_map) { + if (const auto* dv_axis = it.first.as()) { + ICHECK_EQ(it.second.size(), 1); + buffer_map_.Set(it.second[0], dv_axis->indptr); + } else if (const auto* sf_axis = it.first.as()) { + ICHECK_EQ(it.second.size(), 1); + buffer_map_.Set(it.second[0], sf_axis->indices); + } else if (const auto* sv_axis = it.first.as()) { + ICHECK_EQ(it.second.size(), 2); + buffer_map_.Set(it.second[0], sv_axis->indptr); + buffer_map_.Set(it.second[1], sv_axis->indices); + } else if (const auto* sp_buffer = it.first.as()) { + ICHECK_EQ(it.second.size(), 1); + buffer_map_.Set(it.second[0], sp_buffer->data); + } + } + return; + } + + Map buffer_map_; + }; + + BufferMapUpdater updater(f->buffer_map); + updater(f->body); + return std::move(updater.buffer_map_); +} + /*! * \brief Check whether a given SparseBuffer contains the given axis. * \brief buffer The SparseBuffer to be checked @@ -67,19 +99,44 @@ class AccessAndDependencyCollector : public StmtExprVisitor { for (int k = 0; k < ndim; ++k) { const SpIterVar& sp_iter = kv_pair.second[k]; if (sp_iter->kind == SpIterKind::kDenseFixed || - sp_iter->kind == SpIterKind::kDenseVariable || !BufferContainsAxis(buffer, sp_iter->axis)) { continue; } - ICHECK(dependency_map_.count(sp_iter) == 0); - dependency_map_[sp_iter] = std::make_pair(buffer, k); + auto it = dependency_map_.find(sp_iter); + if (it == dependency_map_.end()) { + dependency_map_[sp_iter] = std::make_pair(buffer, k); + } else { + const Array& dependent_iters = buffer_access_map_[it->second.first]; + for (int i = 0; i < k; ++i) { + CHECK(kv_pair.second[i].same_as(dependent_iters[i])) + << "ValueError: A SpIterVar can only depend on a fixed set of iterators"; + } + } } } } - BufferAccessMap buffer_access_map_; - DependencyMap dependency_map_; + void GetIteratedBufferAndDependentIters(const SpIterVar& sp_iter, SparseBuffer* iterated_buffer, + Array* dependent_iters) { + std::pair dependent_pair = dependency_map_[sp_iter]; + Array buffer_access_iters = buffer_access_map_[dependent_pair.first]; + int n_dependent = dependent_pair.second; + + *iterated_buffer = std::move(dependent_pair.first); + *dependent_iters = Array(); + dependent_iters->reserve(n_dependent); + for (int i = 0; i < n_dependent; ++i) { + dependent_iters->push_back(buffer_access_iters[i]->var); + } + } + + SpIterVar GetSpIterFromIndex(PrimExpr index) { + auto it = var2sp_iter_map_.find(index.as()); + CHECK(it != var2sp_iter_map_.end()) + << "ValueError: Currently an index is only allowed to be SpIterVar"; + return it->second; + } private: void AddAccessPattern(const SparseBuffer& buffer, const Array& indices) { @@ -89,9 +146,7 @@ class AccessAndDependencyCollector : public StmtExprVisitor { Array iters; iters.reserve(ndim); for (int i = 0; i < ndim; ++i) { - const SpIterVarNode* sp_iter = indices[i].as(); - CHECK(sp_iter) << "ValueError: Currently an index is only allowed to be SpIterVar"; - iters.push_back(GetRef(sp_iter)); + iters.push_back(GetSpIterFromIndex(indices[i])); } BufferAccessMap::iterator it = buffer_access_map_.find(buffer); @@ -106,6 +161,13 @@ class AccessAndDependencyCollector : public StmtExprVisitor { } } + void VisitStmt_(const SparseBlockNode* sp_block) final { + for (const SpIterVar& sp_iter : sp_block->sp_iter_vars) { + var2sp_iter_map_[sp_iter->var.get()] = sp_iter; + } + StmtVisitor::VisitStmt_(sp_block); + } + void VisitStmt_(const SparseBufferStoreNode* store) final { ExprVisitor::VisitExpr(store->value); AddAccessPattern(store->buffer, store->indices); @@ -114,13 +176,16 @@ class AccessAndDependencyCollector : public StmtExprVisitor { void VisitExpr_(const SparseBufferLoadNode* load) final { AddAccessPattern(load->buffer, load->indices); } + + BufferAccessMap buffer_access_map_; + DependencyMap dependency_map_; + std::unordered_map var2sp_iter_map_; }; class IndexTransformer : public StmtExprMutator { public: - explicit IndexTransformer(BufferAccessMap buffer_access_map, DependencyMap dependency_map) - : buffer_access_map_(std::move(buffer_access_map)), - dependency_map_(std::move(dependency_map)) {} + explicit IndexTransformer(AccessAndDependencyCollector collector) + : collector_(std::move(collector)) {} private: PrimExpr LowerIndices(SparseBuffer sp_buffer, const Array& indices) { @@ -135,9 +200,8 @@ class IndexTransformer : public StmtExprMutator { const PrimExpr& index = indices[i]; // Stage 1. Get the sparse index. - const auto* sp_iter = index.as(); + SpIterVar sp_iter = collector_.GetSpIterFromIndex(index); PrimExpr sp_index{nullptr}; - CHECK(sp_iter) << "ValueError: Currently an index is only allowed to be SpIterVar"; PrimExpr l = AccumulateLowerIndex(lowered_index, sp_buffer, i, 0); PrimExpr r = AccumulateLowerIndex(add(lowered_index, 1), sp_buffer, i, 0); @@ -147,7 +211,7 @@ class IndexTransformer : public StmtExprMutator { CHECK(!axis->IsInstance()); if (const auto* df_axis = axis.as()) { CHECK(ana_.CanProveEqual(sp_iter->max_extent, df_axis->length)); - sp_index = GetRef(sp_iter); + sp_index = sp_iter; } else { Var buffer_var; if (const auto* sf_axis = axis.as()) { @@ -165,24 +229,21 @@ class IndexTransformer : public StmtExprMutator { const auto* dv_axis = axis.as(); CHECK(dv_axis != nullptr); CHECK(sp_iter->axis.defined()); - sp_index = GetRef(sp_iter); + sp_index = sp_iter; } else if (kind == SpIterKind::kSparseFixed) { CHECK(!axis->IsInstance()); CHECK(sp_iter->axis.defined()); const Axis& iterated_axis = sp_iter->axis; - if (const auto* df_axis = axis.as()) { - CHECK(ana_.CanProveEqual(sp_iter->max_extent, df_axis->length)); + if (axis->IsInstance()) { sp_index = GetDenseValue(sp_iter); } else if (const auto* sf_axis = axis.as()) { - CHECK(ana_.CanProveEqual(sp_iter->max_extent, sf_axis->length)); if (iterated_axis.get() == sf_axis) { - sp_index = GetRef(sp_iter); + sp_index = sp_iter; } else { sp_index = lower_bound(sf_axis->indices->data, GetDenseValue(sp_iter), std::move(l), std::move(r)); } } else if (const auto* sv_axis = axis.as()) { - CHECK(ana_.CanProveEqual(sp_iter->max_extent, sv_axis->length)); sp_index = lower_bound(sv_axis->indices->data, GetDenseValue(sp_iter), std::move(l), std::move(r)); } else { @@ -203,7 +264,7 @@ class IndexTransformer : public StmtExprMutator { } else if (const auto* sv_axis = axis.as()) { CHECK(ana_.CanProveEqual(sp_iter->max_extent, sv_axis->length)); if (iterated_axis.get() == sv_axis) { - sp_index = GetRef(sp_iter); + sp_index = sp_iter; } else { sp_index = lower_bound(sv_axis->indices->data, GetDenseValue(sp_iter), std::move(l), std::move(r)); @@ -224,8 +285,10 @@ class IndexTransformer : public StmtExprMutator { PrimExpr AccumulateLowerIndex(PrimExpr prev_lowered_index, const SparseBuffer& sp_buffer, int dim, PrimExpr index) { const Axis& axis = sp_buffer->axes[dim]; - if (axis->IsInstance() || axis->IsInstance()) { + if (axis->IsInstance()) { return ana_.Simplify(std::move(prev_lowered_index) * axis->length + std::move(index)); + } else if (const auto* sf_axis = axis.as()) { + return ana_.Simplify(std::move(prev_lowered_index) * sf_axis->num_cols + std::move(index)); } else if (const auto* dv_axis = axis.as()) { return ana_.Simplify( add(BufferLoad(dv_axis->indptr, {std::move(prev_lowered_index)}), std::move(index))); @@ -237,18 +300,17 @@ class IndexTransformer : public StmtExprMutator { throw; } - PrimExpr GetDenseValue(const SpIterVarNode* sp_iter) { + PrimExpr GetDenseValue(SpIterVar sp_iter) { SpIterKind kind = sp_iter->kind; CHECK(kind == SpIterKind::kSparseFixed || kind == SpIterKind::kSparseVariable); Axis iterated_axis = sp_iter->axis; - std::pair dependent_pair = dependency_map_[GetRef(sp_iter)]; - Array buffer_access_iters = buffer_access_map_[dependent_pair.first]; - int n_dependent = dependent_pair.second; + SparseBuffer iterated_buffer{nullptr}; + Array iters{nullptr}; - Array dependent_iters{buffer_access_iters.begin(), - buffer_access_iters.begin() + n_dependent}; - PrimExpr lowered_indices = LowerIndices(dependent_pair.first, dependent_iters); + collector_.GetIteratedBufferAndDependentIters(sp_iter, &iterated_buffer, &iters); + iters.push_back(sp_iter); + PrimExpr lowered_indices = LowerIndices(std::move(iterated_buffer), iters); if (kind == SpIterKind::kSparseFixed) { return BufferLoad(Downcast(iterated_axis)->indices, @@ -260,29 +322,146 @@ class IndexTransformer : public StmtExprMutator { } PrimExpr VisitExpr_(const SparseBufferLoadNode* load) final { + buffer_read_.insert(load->buffer.get()); PrimExpr lowered_indices = LowerIndices(load->buffer, load->indices); return BufferLoad(load->buffer->data, {std::move(lowered_indices)}); } Stmt VisitStmt_(const SparseBufferStoreNode* store) final { + buffer_write_.insert(store->buffer.get()); PrimExpr value = ExprMutator::VisitExpr(store->value); PrimExpr lowered_indices = LowerIndices(store->buffer, store->indices); return BufferStore(store->buffer->data, std::move(value), {std::move(lowered_indices)}); } - BufferAccessMap buffer_access_map_; - DependencyMap dependency_map_; + Stmt VisitStmt_(const SparseBlockNode* sp_block) { + int n_iter = static_cast(sp_block->sp_iter_vars.size()); + buffer_read_.clear(); + buffer_write_.clear(); + + // Step 1. Recursively mutate the `init` field and the block body. + Optional init = + sp_block->init.defined() ? VisitStmt(sp_block->init.value()) : Optional(NullOpt); + Stmt body = VisitStmt(sp_block->body); + + // Step 2. Create the new outer loop vars. + Array loop_vars; + std::unordered_map var_map; + loop_vars.reserve(n_iter); + var_map.reserve(n_iter); + for (const SpIterVar& sp_iter : sp_block->sp_iter_vars) { + Var loop_var("v_" + sp_iter->var->name_hint); + loop_vars.push_back(loop_var); + var_map[sp_iter->var.get()] = loop_var; + } + + // Step 3. Create block iters and iter bindings. + Array block_iters; + Array iter_bindings; + block_iters.reserve(n_iter); + iter_bindings.reserve(n_iter); + for (int i = 0; i < n_iter; ++i) { + block_iters.push_back(SpIterVar2IterVar(sp_block->sp_iter_vars[i], var_map)); + iter_bindings.push_back(loop_vars[i]); + } + + // Step 4. Generate the read-region and write-retion of the block. + Array reads{nullptr}; + Array writes{nullptr}; + GenerateReadWriteRegions(sp_block, &reads, &writes); + + // Step 5. Create the block and block-realize + Block block(block_iters, std::move(reads), std::move(writes), sp_block->name, std::move(body), + std::move(init)); + BlockRealize block_realize(std::move(iter_bindings), const_true(), std::move(block)); + + // Step 6. Create outer loops and the block binding. + Stmt loop = GenerateLoops(std::move(block_realize), block_iters, loop_vars); + + return loop; + } + + IterVar SpIterVar2IterVar(const SpIterVar& sp_iter, + const std::unordered_map& var_map) { + PrimExpr extent{nullptr}; + + SpIterKind kind = sp_iter->kind; + if (kind == SpIterKind::kDenseFixed || kind == SpIterKind::kSparseFixed) { + extent = sp_iter->max_extent; + } else { + SparseBuffer iterated_buffer{nullptr}; + Array dependent_iters{nullptr}; + collector_.GetIteratedBufferAndDependentIters(sp_iter, &iterated_buffer, &dependent_iters); + PrimExpr lowered_indices = LowerIndices(std::move(iterated_buffer), dependent_iters); + + Buffer indptr{kind == SpIterKind::kDenseVariable + ? Downcast(sp_iter->axis)->indptr + : Downcast(sp_iter->axis)->indptr}; + PrimExpr l = BufferLoad(indptr, {lowered_indices}); + PrimExpr r = BufferLoad(indptr, {add(lowered_indices, 1)}); + extent = sub(r, l); + } + + // Substitute the iteration vars in the expression with the loop vars. + return IterVar(Range::FromMinExtent(0, Substitute(std::move(extent), var_map)), sp_iter->var, + sp_iter->is_reduction ? kCommReduce : kDataPar); + } + + void GenerateReadWriteRegions(const SparseBlockNode* sp_block, Array* reads, + Array* writes) { + for (const ObjectRef& obj : sp_block->sp_structs) { + if (const auto* dv_axis = obj.as()) { + reads->push_back(BufferRegion::FullRegion(dv_axis->indptr)); + } else if (const auto* sf_axis = obj.as()) { + reads->push_back(BufferRegion::FullRegion(sf_axis->indices)); + } else if (const auto* sv_axis = obj.as()) { + reads->push_back(BufferRegion::FullRegion(sv_axis->indptr)); + reads->push_back(BufferRegion::FullRegion(sv_axis->indices)); + } else if (const auto* sp_buffer = obj.as()) { + if (buffer_read_.count(sp_buffer)) { + reads->push_back(BufferRegion::FullRegion(sp_buffer->data)); + } + if (buffer_write_.count(sp_buffer)) { + writes->push_back(BufferRegion::FullRegion(sp_buffer->data)); + } + } + } + } + + Stmt GenerateLoops(Stmt body, const Array& block_iters, const Array& loop_vars) { + int n_iter = static_cast(block_iters.size()); + for (int i = n_iter - 1; i >= 0; --i) { + const Range& dom = block_iters[i]->dom; + body = For(loop_vars[i], dom->min, dom->extent, ForKind::kSerial, std::move(body)); + } + return body; + } + + AccessAndDependencyCollector collector_; arith::Analyzer ana_; + std::unordered_set buffer_read_; + std::unordered_set buffer_write_; }; +Stmt WrapWithRootBlock(Stmt body) { + Block root_block({}, {}, {}, "root", std::move(body)); + body = BlockRealize({}, const_true(), std::move(root_block)); + return Stmt(body); +} + PrimFunc LowerSparseTIR(PrimFunc f) { // Only apply this pass to TIR that is not from TE schedules if (!IsFromLegacyTESchedule(f)) { PrimFuncNode* fptr = f.CopyOnWrite(); + // Step 1. Update the PrimFunc's buffer map. + fptr->buffer_map = UpdateBufferMap(f); + // Step 2. Collect buffer access information and dependency. AccessAndDependencyCollector collector; collector.Collect(f->body); - fptr->body = IndexTransformer(collector.buffer_access_map_, - collector.dependency_map_)(std::move(f->body)); + // Step 3. Lower indices. + fptr->body = IndexTransformer(collector)(std::move(f->body)); + // Step 4. Wrap the function body with a root block. + fptr->body = WrapWithRootBlock(std::move(fptr->body)); return f; } else { return f; diff --git a/tests/python/sparsetir/test_tir_sparse_lower.py b/tests/python/sparsetir/test_tir_sparse_lower.py new file mode 100644 index 0000000000000..684886b4cd765 --- /dev/null +++ b/tests/python/sparsetir/test_tir_sparse_lower.py @@ -0,0 +1,534 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from os import replace +from numpy.core.fromnumeric import size +from scipy.sparse import bsr +import tvm +import tvm.testing +import tvm.tir as tir +import scipy.sparse as sp +import numpy as np +from tvm.script import tir as T + + +@T.prim_func +def csrmm( + a: T.handle, + b: T.handle, + c: T.handle, + indptr: T.handle, + indices: T.handle, + n: T.int32, + m: T.int32, + k: T.int32, + nnz: T.int32, +) -> None: + I = T.dense_fixed(n) + J = T.sparse_variable((m, n + 1, nnz), (indptr, indices), "int32") + K = T.dense_fixed(k) + A = T.match_sparse_buffer(a, (I, J), nnz, "float32") + B = T.match_sparse_buffer(b, (T.to_dense(J), K), m * k, "float32") + C = T.match_sparse_buffer(c, (I, K), n * k, "float32") + with T.iter([T.cord(I), T.pos(J), T.cord(K)], "SRS", "csrmm") as [vi, vj, vk]: + with T.init(): + C[vi, vk] = 0.0 + C[vi, vk] = C[vi, vk] + A[vi, vj] * B[vj, vk] + + +@T.prim_func +def lowered_csrmm( + a: T.handle, + b: T.handle, + c: T.handle, + indptr: T.handle, + indices: T.handle, + n: T.int32, + m: T.int32, + k: T.int32, + nnz: T.int32, +) -> None: + A_data = T.match_buffer(a, [nnz], dtype="float32") + B_data = T.match_buffer(b, [m * k], dtype="float32") + C_data = T.match_buffer(c, [n * k], dtype="float32") + J_indptr = T.match_buffer(indptr, [n + 1], dtype="int32") + J_indices = T.match_buffer(indices, [nnz], dtype="int32") + for v_vi in T.serial(0, n): + for v_vj, v_vk in T.grid(J_indptr[v_vi + 1] - J_indptr[v_vi], k): + with T.block("csrmm"): + vi, vj, vk = T.axis.remap("SRS", [v_vi, v_vj, v_vk]) + T.reads( + [ + J_indptr[0 : n + 1], + J_indices[0:nnz], + A_data[0:nnz], + B_data[0 : m * k], + C_data[0 : n * k], + ] + ) + T.writes([C_data[0 : n * k]]) + with T.init(): + C_data[vi * k + vk] = T.float32(0) + C_data[vi * k + vk] = ( + C_data[vi * k + vk] + + A_data[J_indptr[vi] + vj] * B_data[J_indices[J_indptr[vi] + vj] * k + vk] + ) + + +@T.prim_func +def csr_reduce( + a: T.handle, + b: T.handle, + indptr: T.handle, + indices: T.handle, + n: T.int32, + m: T.int32, + nnz: T.int32, +) -> None: + I = T.dense_fixed(n) + J = T.sparse_variable((m, n + 1, nnz), (indptr, indices), "int32") + A = T.match_sparse_buffer(a, (I, J), nnz, "float32") + B = T.match_sparse_buffer(b, (I,), n, "float32") + with T.iter([T.cord(I), T.pos(J)], "SR", "csr_reduce") as [vi, vj]: + with T.init(): + B[vi] = 0.0 + B[vi] = B[vi] + A[vi, vj] + + +@T.prim_func +def lowered_csr_reduce( + a: T.handle, + b: T.handle, + indptr: T.handle, + indices: T.handle, + n: T.int32, + m: T.int32, + nnz: T.int32, +) -> None: + A_data = T.match_buffer(a, [nnz], dtype="float32") + B_data = T.match_buffer(b, [n], dtype="float32") + J_indptr = T.match_buffer(indptr, [n + 1], dtype="int32") + J_indices = T.match_buffer(indices, [nnz], dtype="int32") + for v_vi in T.serial(0, n): + for v_vj in T.serial(0, J_indptr[v_vi + 1] - J_indptr[v_vi]): + with T.block("csr_reduce"): + vi, vj = T.axis.remap("SR", [v_vi, v_vj]) + T.reads([J_indptr[0 : n + 1], J_indices[0:nnz], A_data[0:nnz], B_data[0:n]]) + T.writes([B_data[0:n]]) + with T.init(): + B_data[vi] = T.float32(0) + B_data[vi] = B_data[vi] + A_data[J_indptr[vi] + vj] + + +@T.prim_func +def bsrmm( + a: T.handle, + b: T.handle, + c: T.handle, + indptr: T.handle, + indices: T.handle, + nb: T.int32, + mb: T.int32, + nnzb: T.int32, + blk: T.int32, + feat_size: T.int32, +) -> None: + I = T.dense_fixed(nb) + J = T.sparse_variable((mb, nb + 1, nnzb), (indptr, indices), "int32") + BI = T.dense_fixed(blk) + BJ = T.dense_fixed(blk) + F = T.dense_fixed(feat_size) + A = T.match_sparse_buffer(a, (I, J, BI, BJ), nnzb * blk * blk, "float32") + B = T.match_sparse_buffer(b, (T.to_dense(J), BJ, F), mb * blk * feat_size, "float32") + C = T.match_sparse_buffer(c, (I, BI, F), nb * blk * feat_size, "float32") + + with T.iter([T.cord(I), T.pos(J), T.cord(BI), T.cord(BJ), T.cord(F)], "SRSRS", "bsrmm") as [ + vi, + vj, + vbi, + vbj, + vf, + ]: + with T.init(): + C[vi, vbi, vf] = 0.0 + C[vi, vbi, vf] = C[vi, vbi, vf] + A[vi, vj, vbi, vbj] * B[vj, vbj, vf] + + +@T.prim_func +def lowered_bsrmm( + a: T.handle, + b: T.handle, + c: T.handle, + indptr: T.handle, + indices: T.handle, + nb: T.int32, + mb: T.int32, + nnzb: T.int32, + blk: T.int32, + feat_size: T.int32, +) -> None: + A_data = T.match_buffer(a, [nnzb * blk * blk], dtype="float32") + B_data = T.match_buffer(b, [mb * blk * feat_size], dtype="float32") + C_data = T.match_buffer(c, [nb * blk * feat_size], dtype="float32") + J_indptr = T.match_buffer(indptr, [nb + 1], dtype="int32") + J_indices = T.match_buffer(indices, [nnzb], dtype="int32") + for v_vi in T.serial(0, nb): + for v_vj, v_vbi, v_vbj, v_vf in T.grid( + J_indptr[v_vi + 1] - J_indptr[v_vi], blk, blk, feat_size + ): + with T.block("bsrmm"): + vi, vj, vbi, vbj, vf = T.axis.remap("SRSRS", [v_vi, v_vj, v_vbi, v_vbj, v_vf]) + T.reads( + [ + J_indptr[0 : nb + 1], + J_indices[0:nnzb], + A_data[0 : nnzb * blk * blk], + B_data[0 : mb * blk * feat_size], + C_data[0 : nb * blk * feat_size], + ] + ) + T.writes([C_data[0 : nb * blk * feat_size]]) + with T.init(): + C_data[(vi * blk + vbi) * feat_size + vf] = T.float32(0) + C_data[(vi * blk + vbi) * feat_size + vf] = ( + C_data[(vi * blk + vbi) * feat_size + vf] + + A_data[((J_indptr[vi] + vj) * blk + vbi) * blk + vbj] + * B_data[(J_indices[J_indptr[vi] + vj] * blk + vbj) * feat_size + vf] + ) + + +@T.prim_func +def ellpack_mm( + a: T.handle, + b: T.handle, + c: T.handle, + indices: T.handle, + nb: T.int32, + mb: T.int32, + feat_size: T.int32, + nnz: T.int32, + col: T.int32, + blk: T.int32, +) -> None: + I = T.dense_fixed(nb) + J = T.sparse_fixed((mb, nnz, col), indices, "int32") + F = T.dense_fixed(feat_size) + BI = T.dense_fixed(blk) + BJ = T.dense_fixed(blk) + A = T.match_sparse_buffer(a, (I, J, BI, BJ), nnz * blk * blk, "float32") + B = T.match_sparse_buffer(b, (T.to_dense(J), BJ, F), mb * blk * feat_size, "float32") + C = T.match_sparse_buffer(c, (I, BI, F), nb * blk * feat_size, "float32") + + with T.iter([T.cord(I), T.pos(J), T.cord(BI), T.cord(BJ), T.cord(F)], "SRSRS", "bsrmm") as [ + vi, + vj, + vbi, + vbj, + vf, + ]: + with T.init(): + C[vi, vbi, vf] = 0.0 + C[vi, vbi, vf] = C[vi, vbi, vf] + A[vi, vj, vbi, vbj] * B[vj, vbj, vf] + + +@T.prim_func +def lowered_ellpack_mm( + a: T.handle, + b: T.handle, + c: T.handle, + indices: T.handle, + nb: T.int32, + mb: T.int32, + feat_size: T.int32, + nnz: T.int32, + col: T.int32, + blk: T.int32, +) -> None: + A_data = T.match_buffer(a, [nnz * blk * blk], dtype="float32") + B_data = T.match_buffer(b, [mb * blk * feat_size], dtype="float32") + C_data = T.match_buffer(c, [nb * blk * feat_size], dtype="float32") + J_indices = T.match_buffer(indices, [nnz], dtype="int32") + for v_vi, v_vj, v_vbi, v_vbj, v_vf in T.grid(nb, col, blk, blk, feat_size): + with T.block("bsrmm"): + vi, vj, vbi, vbj, vf = T.axis.remap("SRSRS", [v_vi, v_vj, v_vbi, v_vbj, v_vf]) + T.reads( + [ + J_indices[0:nnz], + A_data[0 : nnz * blk * blk], + B_data[0 : mb * blk * feat_size], + C_data[0 : nb * blk * feat_size], + ] + ) + T.writes([C_data[0 : nb * blk * feat_size]]) + with T.init(): + C_data[(vi * blk + vbi) * feat_size + vf] = T.float32(0) + C_data[(vi * blk + vbi) * feat_size + vf] = ( + C_data[(vi * blk + vbi) * feat_size + vf] + + A_data[((vi * col + vj) * blk + vbi) * blk + vbj] + * B_data[(J_indices[vi * col + vj] * blk + vbj) * feat_size + vf] + ) + + +@T.prim_func +def batch_mm( + a: T.handle, + b: T.handle, + c: T.handle, + i_indptr: T.handle, + j_a_indptr: T.handle, + j_b_indptr: T.handle, + k_b_indptr: T.handle, + k_c_indptr: T.handle, + batch: T.int32, + n_max: T.int32, + m_max: T.int32, + k_max: T.int32, + nnz_ac1: T.int32, + nnz_b1: T.int32, + nnz_a2: T.int32, + nnz_b2: T.int32, + nnz_c2: T.int32, +) -> None: + Batch = T.dense_fixed(batch) + I = T.dense_variable((n_max, batch + 1), i_indptr, "int32") + J_a = T.dense_variable((m_max, nnz_ac1 + 1), j_a_indptr, "int32") + J_b = T.dense_variable((m_max, batch + 1), j_b_indptr, "int32") + K_b = T.dense_variable((k_max, nnz_b1 + 1), k_b_indptr, "int32") + K_c = T.dense_variable((k_max, nnz_ac1 + 1), k_c_indptr, "int32") + A = T.match_sparse_buffer(a, (Batch, I, J_a), nnz_a2, "float32") + B = T.match_sparse_buffer(b, (Batch, J_b, K_b), nnz_b2, "float32") + C = T.match_sparse_buffer(c, (Batch, I, K_c), nnz_c2, "float32") + + with T.iter([T.cord(Batch), T.cord(I), T.cord(J_a), T.cord(K_b)], "SSSR", "batch_mm") as [ + vb, + vi, + vj, + vk, + ]: + with T.init(): + C[vb, vi, vk] = 0.0 + C[vb, vi, vk] = C[vb, vi, vk] + A[vb, vi, vj] * B[vb, vj, vk] + + +@T.prim_func +def csr_element_wise( + a: T.handle, + b: T.handle, + indptr: T.handle, + indices: T.handle, + m: T.int32, + n: T.int32, + nnz: T.int32, +) -> None: + I = T.dense_fixed(m) + J = T.sparse_variable((n, m + 1, nnz), (indptr, indices), "int32") + A = T.match_sparse_buffer(a, (I, J), nnz, "float32") + B = T.match_sparse_buffer(b, (I, J), nnz, "float32") + + with T.iter([T.cord(I), T.pos(J)], "SS", "csr_element_wise") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.5 + + +@T.prim_func +def lowered_csr_element_wise( + a: T.handle, + b: T.handle, + indptr: T.handle, + indices: T.handle, + m: T.int32, + n: T.int32, + nnz: T.int32, +) -> None: + A_data = T.match_buffer(a, [nnz], dtype="float32") + B_data = T.match_buffer(b, [nnz], dtype="float32") + J_indptr = T.match_buffer(indptr, [m + 1], dtype="int32") + J_indices = T.match_buffer(indices, [nnz], dtype="int32") + for v_vi in T.serial(0, m): + for v_vj in T.serial(0, J_indptr[v_vi + 1] - J_indptr[v_vi]): + with T.block("csr_element_wise"): + vi, vj = T.axis.remap("SS", [v_vi, v_vj]) + T.reads([J_indptr[0 : m + 1], J_indices[0:nnz], A_data[0:nnz]]) + T.writes([B_data[0:nnz]]) + B_data[J_indptr[vi] + vj] = A_data[J_indptr[vi] + vj] * T.float32(2.5) + + +def test_csrmm(): + mod = tvm.IRModule.from_expr(csrmm) + mod = tvm.tir.transform.LowerSparseTIR()(mod) + tvm.ir.assert_structural_equal(mod["main"], lowered_csrmm, True) + + A = sp.random(512, 512, dtype="float32", density=0.0125, format="csr") + x = np.random.rand(512, 128).astype("float32") + y_ground_truth = A * x + y = np.zeros((512, 128)).astype("float32") + + n, m, k, nnz = mod["main"].params[-4:] + f = tvm.build(mod["main"].specialize({n: 512, m: 512, k: 128, nnz: A.nnz}), target="llvm") + + ctx = tvm.cpu(0) + A_indptr = tvm.nd.array(A.indptr.astype("int32"), device=ctx) + A_indices = tvm.nd.array(A.indices.astype("int32"), device=ctx) + A_data = tvm.nd.array(A.data.astype("float32"), device=ctx) + X_nd = tvm.nd.array(x.reshape(-1), device=ctx) + Y_nd = tvm.nd.array(y.reshape(-1), device=ctx) + f(A_data, X_nd, Y_nd, A_indptr, A_indices) + tvm.testing.assert_allclose(y_ground_truth.reshape(-1), Y_nd.numpy(), rtol=1e-5, atol=1e-5) + + +def test_csr_reduce(): + mod = tvm.IRModule.from_expr(csr_reduce) + mod = tvm.tir.transform.LowerSparseTIR()(mod) + tvm.ir.assert_structural_equal(mod["main"], lowered_csr_reduce, True) + + A = sp.random(128, 128, dtype="float32", density=0.0125, format="csr") + b_ground_truth = np.array(np.sum(A, axis=1)) + b = np.zeros((128,)).astype("float32") + + n, m, nnz = csr_reduce.params[-3:] + f = tvm.build(mod["main"].specialize({n: 128, m: 128, nnz: A.nnz}), target="llvm") + + ctx = tvm.cpu(0) + A_indptr = tvm.nd.array(A.indptr.astype("int32"), device=ctx) + A_indices = tvm.nd.array(A.indices.astype("int32"), device=ctx) + A_data = tvm.nd.array(A.data.astype("float32"), device=ctx) + B_nd = tvm.nd.array(b, device=ctx) + f(A_data, B_nd, A_indptr, A_indices) + tvm.testing.assert_allclose(b_ground_truth.reshape(-1), B_nd.numpy(), rtol=1e-5, atol=1e-5) + + +def test_bsrmm(): + mod = tvm.IRModule.from_expr(bsrmm) + mod = tvm.tir.transform.LowerSparseTIR()(mod) + tvm.ir.assert_structural_equal(mod["main"], lowered_bsrmm, True) + + block_size = 16 + nb = 32 + mb = 32 + feat_size = 256 + n = nb * block_size + m = mb * block_size + + A_block = sp.random(mb, nb, dtype="float32", density=0.05, format="csr") + indptr = A_block.indptr + indices = A_block.indices + nnzb = A_block.nnz + data = np.random.rand(nnzb, block_size, block_size) + A = sp.bsr_matrix((data, indices, indptr), shape=(n, m)) + x = np.random.rand(m, feat_size).astype("float32") + y_ground_truth = A * x + y = np.zeros((n * feat_size,)).astype("float32") + + v_nb, v_mb, v_nnzb, v_blk, v_feat_size = bsrmm.params[-5:] + f = tvm.build( + mod["main"].specialize( + {v_nb: nb, v_mb: mb, v_nnzb: nnzb, v_blk: block_size, v_feat_size: feat_size} + ), + target="llvm", + ) + + ctx = tvm.cpu(0) + A_indptr = tvm.nd.array(indptr.astype("int32"), device=ctx) + A_indices = tvm.nd.array(indices.astype("int32"), device=ctx) + A_data = tvm.nd.array(data.reshape(-1).astype("float32"), device=ctx) + X_nd = tvm.nd.array(x.reshape(-1), device=ctx) + Y_nd = tvm.nd.array(y, device=ctx) + f(A_data, X_nd, Y_nd, A_indptr, A_indices) + tvm.testing.assert_allclose(y_ground_truth.reshape(-1), Y_nd.numpy(), rtol=1e-5, atol=1e-5) + + +def test_ellpack_mm(): + mod = tvm.IRModule.from_expr(ellpack_mm) + mod = tvm.tir.transform.LowerSparseTIR()(mod) + tvm.ir.assert_structural_equal(mod["main"], lowered_ellpack_mm, True) + + nnz_cols = 4 + nb = 64 + mb = 64 + feat_size = 1024 + nnz = nb * nnz_cols + block_size = 16 + n = nb * block_size + m = mb * block_size + + rng = np.random.default_rng() + indptr = np.arange(0, (nb + 1) * nnz_cols, nnz_cols) + indices = np.array([rng.choice(mb, size=nnz_cols, replace=False) for i in range(nb)]) + order = indices.argsort(axis=1) + indices = np.array([indices[i, order[i]] for i in range(0, nb)]).reshape(-1) + data = np.random.rand(nnz, block_size, block_size) + A = sp.bsr_matrix((data, indices, indptr), shape=(n, m)) + x = np.random.rand(m, feat_size).astype("float32") + y_ground_truth = A * x + y = np.zeros((n * feat_size,)).astype("float32") + + v_nb, v_mb, v_feat_size, v_nnz, v_col, v_blk = ellpack_mm.params[-6:] + f = tvm.build( + mod["main"].specialize( + { + v_nb: nb, + v_mb: mb, + v_feat_size: feat_size, + v_nnz: nnz, + v_col: nnz_cols, + v_blk: block_size, + } + ), + target="llvm", + ) + + ctx = tvm.cpu(0) + A_indices = tvm.nd.array(indices.astype("int32"), device=ctx) + A_data = tvm.nd.array(data.reshape(-1).astype("float32"), device=ctx) + X_nd = tvm.nd.array(x.reshape(-1), device=ctx) + Y_nd = tvm.nd.array(y, device=ctx) + f(A_data, X_nd, Y_nd, A_indices) + tvm.testing.assert_allclose(y_ground_truth.reshape(-1), Y_nd.numpy(), rtol=1e-5, atol=1e-5) + + +def test_batch_mm(): + mod = tvm.IRModule.from_expr(batch_mm) + mod = tvm.tir.transform.LowerSparseTIR()(mod) + # print(mod["main"].script(tir_prefix="T")) + + +def test_csr_element_wise(): + mod = tvm.IRModule.from_expr(csr_element_wise) + mod = tvm.tir.transform.LowerSparseTIR()(mod) + tvm.ir.assert_structural_equal(mod["main"], lowered_csr_element_wise, True) + + A = sp.random(128, 128, dtype="float32", density=0.0125, format="csr") + b_ground_truth = A * 2.5 + b = np.zeros((A.nnz,)).astype("float32") + + m, n, nnz = csr_element_wise.params[-3:] + f = tvm.build(mod["main"].specialize({m: 128, n: 128, nnz: A.nnz}), target="llvm") + + ctx = tvm.cpu(0) + A_indptr = tvm.nd.array(A.indptr.astype("int32"), device=ctx) + A_indices = tvm.nd.array(A.indices.astype("int32"), device=ctx) + A_data = tvm.nd.array(A.data.astype("float32"), device=ctx) + B_nd = tvm.nd.array(b, device=ctx) + f(A_data, B_nd, A_indptr, A_indices) + tvm.testing.assert_allclose(b_ground_truth.data.reshape(-1), B_nd.numpy(), rtol=1e-5, atol=1e-5) + + +if __name__ == "__main__": + test_csrmm() + test_csr_reduce() + test_bsrmm() + test_ellpack_mm() + test_batch_mm() + test_csr_element_wise() diff --git a/tests/python/sparsetir/test_tir_sparse_script_roundtrip.py b/tests/python/sparsetir/test_tir_sparse_script_roundtrip.py index 17e2d8f04c9b7..84fbdd707e7a8 100644 --- a/tests/python/sparsetir/test_tir_sparse_script_roundtrip.py +++ b/tests/python/sparsetir/test_tir_sparse_script_roundtrip.py @@ -141,11 +141,11 @@ def batch_mm( B = T.match_sparse_buffer(b, (Batch, J_b, K_b), nnz_b2, "float32") C = T.match_sparse_buffer(c, (Batch, I, K_c), nnz_c2, "float32") - with T.iter([T.cord(Batch), T.cord(I), T.cord(K_b), T.cord(J_a)], "SSSR", "batch_mm") as [ + with T.iter([T.cord(Batch), T.cord(I), T.cord(J_a), T.cord(K_b)], "SSSR", "batch_mm") as [ vb, vi, - vk, vj, + vk, ]: with T.init(): C[vb, vi, vk] = 0.0