diff --git a/include/tvm/tir/sparse.h b/include/tvm/tir/sparse.h index 6043dce377de..9c88dd4c7218 100644 --- a/include/tvm/tir/sparse.h +++ b/include/tvm/tir/sparse.h @@ -50,18 +50,15 @@ class AxisNode : public Object { void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); v->Visit("length", &length); - v->Visit("is_derived_axis", &is_derived_axis); } bool SEqualReduce(const AxisNode* other, SEqualReducer equal) const { - return equal(name, other->name) && equal(length, other->length) && - equal(is_derived_axis, other->is_derived_axis); + return equal(name, other->name) && equal(length, other->length); } void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(name); hash_reduce(length); - hash_reduce(is_derived_axis); } /* name of current axis. */ @@ -69,13 +66,12 @@ class AxisNode : public Object { /* length of current axis. For sparse axis, length refers to the upperbound of * the current axis. */ PrimExpr length; - /* indicates whether current axis is derived by dense(axis) or fuse(axis1, axis2, ...) */ - bool is_derived_axis = false; String GetName() const { return name; } PrimExpr GetLength() const { return length; } DataType GetIndexType() const { return length->dtype; } virtual Optional GetParentAxis() const = 0; + Axis GetRootAxis() const; virtual AxisKind kind() const = 0; virtual PrimExpr nnz() const = 0; @@ -266,7 +262,7 @@ class DenseVariableAxisNode : public DenseAxisNode { Optional GetParentAxis() const final { return parent_; } static constexpr const char* _type_key = "tir.sparse.DenseVariableAxis"; - TVM_DECLARE_FINAL_OBJECT_INFO(DenseVariableAxisNode, DenseAxisNode); + TVM_DECLARE_BASE_OBJECT_INFO(DenseVariableAxisNode, DenseAxisNode); }; /*! @@ -281,6 +277,30 @@ class DenseVariableAxis : public DenseAxis { TVM_DEFINE_OBJECT_REF_METHODS(DenseVariableAxis, DenseAxis, DenseVariableAxisNode); }; +/*! + * \brief Dense variable axis attached to another dense variable axis. 
+ */ +class AttachedAxisNode : public DenseVariableAxisNode { + public: + /* The original axis before attaching. */ + Axis orig_; + + Axis GetOriginalAxis() const { return orig_; } + + static constexpr const char* _type_key = "tir.sparse.AttachedAxis"; + TVM_DECLARE_FINAL_OBJECT_INFO(AttachedAxisNode, DenseVariableAxisNode); +}; + +/*! + * \brief Managed reference to AttachedAxisNode. + * \sa AttachedAxisNode + */ +class AttachedAxis : public DenseVariableAxis { + public: + TVM_DLL explicit AttachedAxis(String name, Axis parent, Axis orig, PrimExpr nnz, Buffer indptr); + TVM_DEFINE_OBJECT_REF_METHODS(AttachedAxis, DenseVariableAxis, AttachedAxisNode); +}; + /*! * \brief Sparse axis with fixed number of non-zero columns per row. */ diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/tir/special_stmt.py index 1acfa85767e1..e1cd6d3f3bc3 100644 --- a/python/tvm/script/tir/special_stmt.py +++ b/python/tvm/script/tir/special_stmt.py @@ -17,6 +17,7 @@ """TVM Script Parser Special Stmt Classes""" # pylint: disable=unused-argument, no-self-argument, inconsistent-return-statements # pylint: disable=relative-beyond-top-level +from os import name from typing import Callable, List, Optional, Tuple, Any, Mapping, Union import synr @@ -35,6 +36,7 @@ DenseVariableAxis, SparseFixedAxis, SparseVariableAxis, + AttachedAxis, ) from .node import BufferSlice @@ -946,6 +948,38 @@ def dense_variable( super().__init__(dense_variable, def_symbol=True) +@register +class Attach(SpecialStmt): + """Special Stmt for attaching axis.""" + + def __init__(self): + def attach_axis( + parent: Axis, + orig: Axis, + nnz: PrimExpr, + indptr_var: tvm.tir.Var, + idtype: str = "int32", + span: Optional[Span] = None, + ): + names = [x.id.name for x in self.node.lhs] + if len(names) != 1: + self.context.report_error( + f"`attach_axis` expected assign to only one var, but got {names}", span + ) + + indptr_len = orig.nnz + 1 + indptr_buf = tvm.tir.decl_buffer( + (indptr_len,), dtype=idtype, 
+ name=names[0] + "_indptr", span=span + ) + axis = AttachedAxis(names[0], parent, orig, nnz, indptr_buf) + self.context.sp_struct.append(axis) + self.context.sp_struct_params.append([indptr_var]) + self.context.update_symbol(names[0], axis, self.node) + self.context.update_symbol(names[0] + "_indptr", indptr_buf, self.node) + + super().__init__(attach_axis, def_symbol=True) + + @register class SparseFixed(SpecialStmt): """Special Stmt for creating sparse fixed axis.""" diff --git a/python/tvm/tir/sparse.py b/python/tvm/tir/sparse.py index d4df74f9685c..d75e4337f483 100644 --- a/python/tvm/tir/sparse.py +++ b/python/tvm/tir/sparse.py @@ -144,6 +144,36 @@ def __init__(self, name, parent, length, nnz, indptr): ) +@tvm._ffi.register_object("tir.sparse.AttachedAxis") +class AttachedAxis(DenseVariableAxis): + """AttachedAxis node + + Parameters + ---------- + name : str + The name of the axis. + parent : Axis + The axis to attach to. + orig : Axis + The axis to be attached. + nnz : PrimExpr + The number of nonzeros of the returned axis. + indptr : PrimExpr + The new indptr array of the returned axis. 
+ """ + + name : str + parent : Axis + orig : Axis + nnz : PrimExpr + indptr : PrimExpr + + def __init__(self, name, parent, orig, nnz, indptr): + self.__init_handle_by_constructor__( + _ffi_api.AttachedAxis, name, parent, orig, nnz, indptr + ) + + @tvm._ffi.register_object("tir.sparse.SparseFixedAxis") class SparseFixedAxis(DenseAxis): """SparseFixedAxis node diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index 66490c10ab1a..3789cf6f0ecd 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -469,14 +469,12 @@ Doc TVMScriptPrinter::AllocAxis(const Axis& axis) { return it->second; } Doc val; - const auto* df_axis = axis.as(); - - if (df_axis != nullptr && df_axis->is_derived_axis) { - if (const DenseFromSparseAxisNode* dfs_axis = axis.as()) { - val = Doc::Text(tir_prefix_ + ".dense(" + dfs_axis->base->name + ")"); - } else { - CHECK(false) << "Cannot allocate fused axis"; - } + if (const DenseFromSparseAxisNode* dfs_axis = axis.as()) { + // DenseFromSparseAxis is a temporarily defined axis. + val = Doc::Text(tir_prefix_ + ".dense(" + dfs_axis->base->name + ")"); + } else if (axis.as()) { + // FusedAxis is also a temporarily defined axis. 
+ CHECK(false) << "Cannot allocate fused axis"; } else { std::string name = axis->name; if (name.length() == 0 || !std::isalnum(name[0])) { @@ -1336,19 +1334,16 @@ Doc TVMScriptPrinter::PrintSparseBlockName(const SparseBlockNode* op) { Doc iter_doc; std::string axis_repr = sp_iter->axis->name; - if (axis->is_derived_axis) { - if (const DenseFromSparseAxisNode* dfs_axis = axis.as()) { - iter_doc << tir_prefix_ << ".dense(" << dfs_axis->base->name << ")"; + if (const DenseFromSparseAxisNode* dfs_axis = axis.as()) { + iter_doc << tir_prefix_ << ".dense(" << dfs_axis->base->name << ")"; + } else if (const FusedAxisNode* fused_axis = axis.as()) { + std::string orig_axis_name = fused_axis->group[fused_axis->index]->name; + if (fused_axis->index == 0) { + iter_doc << tir_prefix_ << ".fuse(" << orig_axis_name; + } else if (fused_axis->index == int(fused_axis->group.size() - 1)) { + iter_doc << orig_axis_name << ")"; } else { - const FusedAxisNode* fused_axis = axis.as(); - std::string orig_axis_name = fused_axis->group[fused_axis->index]->name; - if (fused_axis->index == 0) { - iter_doc << tir_prefix_ << ".fuse(" << orig_axis_name; - } else if (fused_axis->index == fused_axis->group.size() - 1) { - iter_doc << orig_axis_name << ")"; - } else { - iter_doc << orig_axis_name; - } + iter_doc << orig_axis_name; } } else { iter_doc << axis->name; @@ -1421,10 +1416,17 @@ Doc TVMScriptPrinter::PrintSparseStructDefinitions(const SparseBlockNode* sp_blo ICHECK_EQ(params.size(), 0); doc << "dense_fixed(" << Print(df_axis->length) << ")"; } else if (const auto* dv_axis = obj.as()) { - ICHECK_EQ(params.size(), 1); - doc << "dense_variable(" << dv_axis->parent_->name << ", (" << Print(dv_axis->length) << ", " - << Print(dv_axis->nnz()) << "), " << Print(params[0]) << ", " - << PrintDType(dv_axis->indptr->dtype) << ")"; + if (const auto* attached_axis = obj.as()) { + ICHECK_EQ(params.size(), 1); + doc << "attach_axis(" << attached_axis->parent_->name << ", " << attached_axis->orig_->name 
+ << ", " << Print(attached_axis->nnz()) << ", " << Print(params[0]) << ", " + << PrintDType(attached_axis->indptr->dtype) << ")"; + } else { + ICHECK_EQ(params.size(), 1); + doc << "dense_variable(" << dv_axis->parent_->name << ", (" << Print(dv_axis->length) + << ", " << Print(dv_axis->nnz()) << "), " << Print(params[0]) << ", " + << PrintDType(dv_axis->indptr->dtype) << ")"; + } } else if (const auto* sf_axis = obj.as()) { ICHECK_EQ(params.size(), 1); doc << "sparse_fixed(" << sf_axis->parent_->name << ", (" << Print(sf_axis->length) << ", " diff --git a/src/tir/ir/sparse.cc b/src/tir/ir/sparse.cc index 7c6225f59cdb..7897b5bae01c 100644 --- a/src/tir/ir/sparse.cc +++ b/src/tir/ir/sparse.cc @@ -45,6 +45,18 @@ TVM_REGISTER_GLOBAL("tir.sparse.GetAxisIndexType").set_body_typed([](Axis axis) TVM_REGISTER_GLOBAL("tir.sparse.GetNNZ").set_body_typed([](Axis axis) { return axis->nnz(); }); +/******** AxisNode ********/ + +/*! \brief Recursively follow GetParentAxis() and return the axis that has no parent. */ +Axis AxisNode::GetRootAxis() const { + Optional parent = GetParentAxis(); + if (parent.defined()) { + return parent.value()->GetRootAxis(); + } else { + return GetRef(this); + } +} + /******** DenseFixedAxis ********/ /*! \brief Default constructor of DenseFixedAxis */ @@ -67,35 +79,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "dense_fixed(" << op->name << ", " << op->length << ")"; }); -/******** DenseVariableAxis ********/ - -/*! 
\brief Default constuctor of DenseVariableAxis */ -DenseVariableAxis::DenseVariableAxis(String name, Axis parent, PrimExpr length, PrimExpr nnz, - Buffer indptr) { - ObjectPtr node = make_object(); - node->name = std::move(name); - node->parent_ = std::move(parent); - node->length = std::move(length); - node->nnz_ = std::move(nnz); - node->indptr = std::move(indptr); - data_ = std::move(node); -} - -TVM_REGISTER_NODE_TYPE(DenseVariableAxisNode); - -TVM_REGISTER_GLOBAL("tir.sparse.DenseVariableAxis") - .set_body_typed([](String name, Axis parent, PrimExpr length, PrimExpr nnz, Buffer indptr) { - return DenseVariableAxis(std::move(name), std::move(parent), std::move(length), - std::move(nnz), std::move(indptr)); - }); - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "dense_variable(" << op->name << ", " << op->length << ", " << op->indptr->name - << ")"; - }); - /******** DenseFromSparseAxis ********/ /*! \brief Default constructor of DenseFromSparseAxis */ @@ -103,7 +86,6 @@ DenseFromSparseAxis::DenseFromSparseAxis(SparseAxis base) { ObjectPtr node = make_object(); node->name = base->name + "_dense"; node->length = base->length; - node->is_derived_axis = true; node->base = std::move(base); data_ = std::move(node); } @@ -135,7 +117,6 @@ FusedAxis::FusedAxis(Array group, int index) { } node->name = "fused_" + fused_name + "_" + group[index]->name; node->length = group[index]->nnz(); - node->is_derived_axis = true; node->group = std::move(group); node->index = index; data_ = std::move(node); @@ -163,6 +144,63 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ")"; }); +/******** DenseVariableAxis ********/ + +/*! 
\brief Default constructor of DenseVariableAxis */ +DenseVariableAxis::DenseVariableAxis(String name, Axis parent, PrimExpr length, PrimExpr nnz, + Buffer indptr) { + ObjectPtr node = make_object(); + node->name = std::move(name); + node->parent_ = std::move(parent); + node->length = std::move(length); + node->nnz_ = std::move(nnz); + node->indptr = std::move(indptr); + data_ = std::move(node); +} + +TVM_REGISTER_NODE_TYPE(DenseVariableAxisNode); + +TVM_REGISTER_GLOBAL("tir.sparse.DenseVariableAxis") + .set_body_typed([](String name, Axis parent, PrimExpr length, PrimExpr nnz, Buffer indptr) { + return DenseVariableAxis(std::move(name), std::move(parent), std::move(length), + std::move(nnz), std::move(indptr)); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "dense_variable(" << op->name << ", " << op->length << ", " << op->indptr->name + << ")"; + }); + +/******** AttachedAxis ********/ +/*! 
\brief Default constructor of AttachedAxis; length is copied from the original axis. */ +AttachedAxis::AttachedAxis(String name, Axis parent, Axis orig, PrimExpr nnz, Buffer indptr) { + ObjectPtr node = make_object(); + node->name = std::move(name); + node->parent_ = std::move(parent); + node->orig_ = std::move(orig); + node->length = node->orig_->length; + node->nnz_ = std::move(nnz); + node->indptr = std::move(indptr); + data_ = std::move(node); +} + +TVM_REGISTER_NODE_TYPE(AttachedAxisNode); + +TVM_REGISTER_GLOBAL("tir.sparse.AttachedAxis") + .set_body_typed([](String name, Axis parent, Axis orig, PrimExpr nnz, Buffer indptr) { + return AttachedAxis(std::move(name), std::move(parent), std::move(orig), std::move(nnz), + std::move(indptr)); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "attached_axis(" << op->name << ", " << op->length << ", " << op->indptr->name + << ")"; + }); + /******** SparseFixedAxis ********/ /*! \brief Default constructor of SparseFixedAxis */ diff --git a/src/tir/transforms/lower_sparse_tir.cc b/src/tir/transforms/lower_sparse_tir.cc index 1d91f6a4969c..5e28b1974a06 100644 --- a/src/tir/transforms/lower_sparse_tir.cc +++ b/src/tir/transforms/lower_sparse_tir.cc @@ -105,6 +105,7 @@ PrimExpr AggregateOffset(PrimExpr prev_offset, Axis axis, PrimExpr index, break; } case AxisKind::kDenseVariable: { + // TODO(zihao): finish the aggregating offset for attached axis. 
auto dv_axis = axis.as(); new_offset = add(BufferLoad(dv_axis->indptr, {std::move(prev_offset)}), std::move(index)); break; diff --git a/tests/python/sparsetir/test_tir_sparse_lower.py b/tests/python/sparsetir/test_tir_sparse_lower.py index 086b8094be7b..9c6bcc2c2bd9 100644 --- a/tests/python/sparsetir/test_tir_sparse_lower.py +++ b/tests/python/sparsetir/test_tir_sparse_lower.py @@ -310,6 +310,40 @@ def lowered_csr_element_wise(a: T.handle, b: T.handle, indptr: T.handle, indices T.block_attr({"sparse": True}) B_data[J_indptr[vi] + vj] = A_data[J_indptr[vi] + vj] * T.float32(2.5) +@T.prim_func +def bmm( + x: T.handle, + y: T.handle, + z: T.handle, + indptr_i: T.handle, + indptr_j: T.handle, + indptr_k: T.handle, + indptr_ij: T.handle, + indptr_jk: T.handle, + indptr_ik: T.handle, + batch_size: T.int32, + nnz_i: T.int32, + nnz_j: T.int32, + nnz_k: T.int32, + nnz_ij: T.int32, + nnz_jk: T.int32, + nnz_ik: T.int32 +) -> None: + B = T.dense_fixed(batch_size) + I = T.dense_variable(B, (32768, nnz_i), indptr_i, "int32") + J = T.dense_variable(B, (32768, nnz_j), indptr_j, "int32") + K = T.dense_variable(B, (32768, nnz_k), indptr_k, "int32") + IJ = T.attach_axis(I, J, nnz_ij, indptr_ij, "int32") + JK = T.attach_axis(J, K, nnz_jk, indptr_jk, "int32") + IK = T.attach_axis(I, K, nnz_ik, indptr_ik, "int32") + X = T.match_sparse_buffer(x, (B, I, IJ), "float32") + Y = T.match_sparse_buffer(y, (B, J, JK), "float32") + Z = T.match_sparse_buffer(z, (B, I, IK), "float32") + with T.iter([B, I, J, K], "SSRS", "bmm") as [vb, vi, vj, vk]: + with T.init(): + Z[vb, vi, vk] = 0. 
+ Z[vb, vi, vk] = Z[vb, vi, vk] + X[vb, vi, vj] * Y[vb, vj, vk] + def test_csrmm(): mod = tvm.IRModule.from_expr(csrmm) @@ -475,6 +509,12 @@ def test_csr_element_wise(): tvm.testing.assert_allclose(b_ground_truth.data.reshape(-1), B_nd.numpy(), rtol=1e-5, atol=1e-5) +def test_bmm(): + mod = tvm.IRModule.from_expr(bmm) + mod = tvm.tir.transform.LowerSparseTIR()(mod) + print(mod['main'].script()) + + if __name__ == "__main__": test_csrmm() test_csrmm_dense_iter() @@ -483,3 +523,4 @@ def test_csr_element_wise(): test_bsrmm() test_ellpack_mm() test_csr_element_wise() + test_bmm()