From 9660c40cb556d13e7814f55457cc8426231df423 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Mon, 1 Nov 2021 11:19:15 +0800 Subject: [PATCH] [BugFix] Add field `is_reduction` for SpIterVar (#9) * [BugFix] Add field `is_reduction` for SpIterVar * Formatting --- include/tvm/tir/sparse.h | 15 +++++++------- python/tvm/tir/sparse.py | 6 +++++- src/tir/ir/sparse.cc | 45 +++++++++++++++++----------------------- 3 files changed, 32 insertions(+), 34 deletions(-) diff --git a/include/tvm/tir/sparse.h b/include/tvm/tir/sparse.h index ac40fea615a1..a6fbbda19a91 100644 --- a/include/tvm/tir/sparse.h +++ b/include/tvm/tir/sparse.h @@ -187,7 +187,7 @@ class SparseAxis : public Axis { */ class SparseFixedAxisNode : public SparseAxisNode { public: - Buffer indices; + Buffer indices; /* fixed number of columns of current sparse axis. */ PrimExpr num_cols; @@ -267,7 +267,6 @@ class SparseVariableAxis : public SparseAxis { TVM_DEFINE_OBJECT_REF_METHODS(SparseVariableAxis, SparseAxis, SparseVariableAxisNode); }; - /*! * \brief Axis Dependency Tree. */ @@ -314,9 +313,7 @@ class SparseBufferNode : public Object { /* Data type */ runtime::DataType dtype; - inline int ndim() const { - return static_cast(axes.size()); - } + inline int ndim() const { return static_cast(axes.size()); } void VisitAttrs(AttrVisitor* v) { v->Visit("name", &tree); @@ -370,24 +367,28 @@ class SpIterVarNode : public Object { Var var; PrimExpr max_extent; SpIterKind kind; + bool is_reduction; Optional axis; void VisitAttrs(AttrVisitor* v) { v->Visit("var", &var); v->Visit("max_extent", &max_extent); v->Visit("axis", &axis); + v->Visit("is_reduction", &is_reduction); v->Visit("kind", &kind); } bool SEqualReduce(const SpIterVarNode* other, SEqualReducer equal) const { return equal(var, other->var) && equal(max_extent, other->max_extent) && - equal(axis, other->axis) && equal(kind, other->kind); + equal(axis, other->axis) && equal(is_reduction, other->is_reduction) && + equal(kind, other->kind); } void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(var); hash_reduce(max_extent); hash_reduce(axis); + hash_reduce(is_reduction); hash_reduce(kind); } @@ -399,7 +400,7 @@ class SpIterVarNode : public Object { class SpIterVar : public ObjectRef { public: - TVM_DLL explicit SpIterVar(String name, PrimExpr max_extent, SpIterKind kind, + TVM_DLL explicit SpIterVar(String name, PrimExpr max_extent, SpIterKind kind, bool is_reduction, Optional axis = NullOpt); /*! diff --git a/python/tvm/tir/sparse.py b/python/tvm/tir/sparse.py index 7f5c38585980..09cf6a3e9f8d 100644 --- a/python/tvm/tir/sparse.py +++ b/python/tvm/tir/sparse.py @@ -214,6 +214,9 @@ class SpIterVar(Object): kind : int The kind of the SpIterVar + + is_reduction : bool + Whether the SpIterVar is a reduction iterator axis : Optional[Axis] The axis over which the SpIterVar iterates. Required to be defined @@ -222,6 +225,7 @@ class SpIterVar(Object): var: Var max_extent: PrimExpr kind: int + is_reduction: bool axis: Optional[Axis] DenseFixed = 0 @@ -231,6 +235,6 @@ class SpIterVar(Object): def __init__(self, var, max_extent, kind, axis=None): self.__init_handle_by_constructor__( - _ffi_api.SpIterVar, var, max_extent, kind, axis # type: ignore + _ffi_api.SpIterVar, var, max_extent, kind, is_reduction, axis # type: ignore ) diff --git a/src/tir/ir/sparse.cc b/src/tir/ir/sparse.cc index 17eca58bcf7a..9154d96f818f 100644 --- a/src/tir/ir/sparse.cc +++ b/src/tir/ir/sparse.cc @@ -38,14 +38,12 @@ DenseFixedAxis::DenseFixedAxis(String name, PrimExpr length) { TVM_REGISTER_NODE_TYPE(DenseFixedAxisNode); -TVM_REGISTER_GLOBAL("tir.sparse.DenseFixedAxis") - .set_body_typed([](String name, PrimExpr length) { - return DenseFixedAxis(name, length); - }); +TVM_REGISTER_GLOBAL("tir.sparse.DenseFixedAxis").set_body_typed([](String name, PrimExpr length) { + return DenseFixedAxis(name, length); +}); // DenseVariableAxis -DenseVariableAxis::DenseVariableAxis(String name, PrimExpr length, - Buffer indptr) { +DenseVariableAxis::DenseVariableAxis(String name, PrimExpr length, Buffer indptr) { ObjectPtr node = make_object(); node->name = std::move(name); node->length = std::move(length); @@ -61,8 +59,7 @@ TVM_REGISTER_GLOBAL("tir.sparse.DenseVariableAxis") }); // SparseFixedAxis -SparseFixedAxis::SparseFixedAxis(String name, PrimExpr length, Buffer indices, - PrimExpr num_cols) { +SparseFixedAxis::SparseFixedAxis(String name, PrimExpr length, Buffer indices, PrimExpr num_cols) { ObjectPtr node = make_object(); node->name = std::move(name); node->length = std::move(length); @@ -74,16 +71,14 @@ SparseFixedAxis::SparseFixedAxis(String name, PrimExpr length, Buffer indices, TVM_REGISTER_NODE_TYPE(SparseFixedAxisNode); TVM_REGISTER_GLOBAL("tir.sparse.SparseFixedAxis") - .set_body_typed([](String name, PrimExpr length, Buffer indices, - PrimExpr num_cols) { + .set_body_typed([](String name, PrimExpr length, Buffer indices, PrimExpr num_cols) { return SparseFixedAxis(name, length, indices, num_cols); }); // SparseVariableAxis -SparseVariableAxis::SparseVariableAxis(String name, PrimExpr length, - Buffer indptr, Buffer indices) { - ObjectPtr node = - make_object(); +SparseVariableAxis::SparseVariableAxis(String name, PrimExpr length, Buffer indptr, + Buffer indices) { + ObjectPtr node = make_object(); node->name = std::move(name); node->length = std::move(length); node->indptr = std::move(indptr); @@ -94,14 +89,12 @@ SparseVariableAxis::SparseVariableAxis(String name, PrimExpr length, TVM_REGISTER_NODE_TYPE(SparseVariableAxisNode); TVM_REGISTER_GLOBAL("tir.sparse.SparseVariableAxis") - .set_body_typed([](String name, PrimExpr length, Buffer indptr, - Buffer indices) { + .set_body_typed([](String name, PrimExpr length, Buffer indptr, Buffer indices) { return SparseVariableAxis(name, length, indptr, indices); }); // AxisTree -AxisTree::AxisTree(Array axes, - Array> axis_parent_names) { +AxisTree::AxisTree(Array axes, Array> axis_parent_names) { CHECK_EQ(axes.size(), axis_parent_names.size()) << "ValueError: The axes array should have the same length as axis_parent_names " "array."; @@ -121,9 +114,7 @@ AxisTree::AxisTree(Array axes, CHECK(node->axis_map.find(parent_name.value()) != node->axis_map.end()) << "ValueError: Parent axis name doesn't exist."; } - Axis parent_axis = (parent_name.get() != nullptr) - ? node->axis_map[parent_name.value()] - : root; + Axis parent_axis = (parent_name.get() != nullptr) ? node->axis_map[parent_name.value()] : root; node->parent[axis] = parent_axis; if (node->children.find(parent_axis) != node->children.end()) { node->children[parent_axis].push_back(axis); @@ -139,8 +130,7 @@ AxisTree::AxisTree(Array axes, TVM_REGISTER_NODE_TYPE(AxisTreeNode); TVM_REGISTER_GLOBAL("tir.sparse.AxisTree") - .set_body_typed([](Array axes, - Array> axis_parent_names) { + .set_body_typed([](Array axes, Array> axis_parent_names) { return AxisTree(axes, axis_parent_names); }); @@ -164,7 +154,8 @@ TVM_REGISTER_GLOBAL("tir.sparse.SparseBuffer") }); // SpIterVar -SpIterVar::SpIterVar(String name, PrimExpr max_extent, SpIterKind kind, Optional axis) { +SpIterVar::SpIterVar(String name, PrimExpr max_extent, SpIterKind kind, bool is_reduction, + Optional axis) { ObjectPtr node = make_object(); if (kind != SpIterKind::kDenseFixed) { @@ -175,6 +166,7 @@ SpIterVar::SpIterVar(String name, PrimExpr max_extent, SpIterKind kind, Optional node->var = Var(std::move(name)); node->max_extent = std::move(max_extent); node->kind = kind; + node->is_reduction = is_reduction; node->axis = std::move(axis); data_ = std::move(node); } @@ -182,8 +174,9 @@ SpIterVar::SpIterVar(String name, PrimExpr max_extent, SpIterKind kind, Optional TVM_REGISTER_NODE_TYPE(SpIterVarNode); TVM_REGISTER_GLOBAL("tir.sparse.SpIterVar") - .set_body_typed([](String name, PrimExpr max_extent, SpIterKind kind, Optional axis) { - return SpIterVar(name, max_extent, kind, axis); + .set_body_typed([](String name, PrimExpr max_extent, SpIterKind kind, bool is_reduction, + Optional axis) { + return SpIterVar(name, max_extent, kind, is_reduction, axis); }); } // namespace tir