Skip to content

Commit

Permalink
[SparseTIR] SparseTIR Lowering (#20)
Browse files Browse the repository at this point in the history
* Fix a previous bug of sparse-fixed SpIterVar creation

* Fix a previous bug in `GetDenseValue`

* Refactor Collector and IndexTransformer

* Construct block and loops

* Fix a previous bug which rejects DV iters in collector

* Update buffer map

* Create root block

* Fix bug of sparse-fixed SpIterVar creation

* Fix bug on SpIterVar conversion (with refactor)

* Fix bug when getting dependent SpIterVars

* Fix bug on dependency map and index lowering

* Full block read/write region

* Test version 1

* Fix bug of loop order

* Fix bug of batch-mm iterator ordering

* Update PrimFunc args to use symbolic params

* Fix bug of test "csr_element_wise"

* Fix bug of index accumulation for sparse-fixed axis

* Update correctness test

* Test structural equality

* Refactor and use Array
  • Loading branch information
MasterJH5574 authored and yzh119 committed Nov 15, 2021
1 parent c120956 commit d2bccc2
Show file tree
Hide file tree
Showing 14 changed files with 843 additions and 84 deletions.
10 changes: 0 additions & 10 deletions include/tvm/tir/sparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,11 @@ class DenseFixedAxisNode : public DenseAxisNode {
}

bool SEqualReduce(const DenseFixedAxisNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return equal(name, other->name) && equal(length, other->length) &&
equal(from_sparse, other->from_sparse);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce->MarkGraphNode();
hash_reduce(name);
hash_reduce(length);
hash_reduce(from_sparse);
Expand Down Expand Up @@ -170,12 +168,10 @@ class DenseVariableAxisNode : public DenseAxisNode {
}

bool SEqualReduce(const DenseVariableAxisNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return equal(name, other->name) && equal(length, other->length) && equal(indptr, other->indptr);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce->MarkGraphNode();
hash_reduce(name);
hash_reduce(length);
hash_reduce(indptr);
Expand Down Expand Up @@ -213,13 +209,11 @@ class SparseFixedAxisNode : public SparseAxisNode {
}

bool SEqualReduce(const SparseFixedAxisNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return equal(name, other->name) && equal(length, other->length) &&
equal(indices, other->indices) && equal(num_cols, other->num_cols);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce->MarkGraphNode();
hash_reduce(name);
hash_reduce(length);
hash_reduce(indices);
Expand Down Expand Up @@ -257,13 +251,11 @@ class SparseVariableAxisNode : public SparseAxisNode {
}

bool SEqualReduce(const SparseVariableAxisNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return equal(name, other->name) && equal(length, other->length) &&
equal(indptr, other->indptr) && equal(indices, other->indices);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce->MarkGraphNode();
hash_reduce(name);
hash_reduce(length);
hash_reduce(indptr);
Expand Down Expand Up @@ -347,12 +339,10 @@ class SparseBufferNode : public Object {
}

bool SEqualReduce(const SparseBufferNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return equal(axes, other->axes) && equal(data, other->data) && equal(name, other->name);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce->MarkGraphNode();
hash_reduce(axes);
hash_reduce(data);
hash_reduce(name);
Expand Down
13 changes: 8 additions & 5 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -1285,6 +1285,8 @@ class SparseBlockNode : public StmtNode {
public:
/*! \brief The sparse iteration variables of the block. */
Array<SpIterVar> sp_iter_vars;
/*! \brief The sparse data structures */
Array<ObjectRef> sp_structs;
/*! \brief The mapping from sparse data structures to the PrimFunc parameters */
Map<ObjectRef, Array<Var>> sp_struct2param_map;
/*! \brief The name of the block */
Expand All @@ -1296,6 +1298,7 @@ class SparseBlockNode : public StmtNode {

void VisitAttrs(AttrVisitor* v) {
v->Visit("sp_iter_vars", &sp_iter_vars);
v->Visit("sp_structs", &sp_structs);
v->Visit("sp_struct2param_map", &sp_struct2param_map);
v->Visit("name", &name);
v->Visit("body", &body);
Expand All @@ -1305,15 +1308,15 @@ class SparseBlockNode : public StmtNode {
bool SEqualReduce(const SparseBlockNode* other, SEqualReducer equal) const {
return equal(sp_iter_vars, other->sp_iter_vars) && equal(name, other->name) &&
equal(body, other->body) && equal(init, other->init) &&
equal(sp_struct2param_map, other->sp_struct2param_map);
equal(sp_structs, other->sp_structs);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(sp_iter_vars);
hash_reduce(name);
hash_reduce(body);
hash_reduce(init);
hash_reduce(sp_struct2param_map);
hash_reduce(sp_structs);
}

static constexpr const char* _type_key = "tir.SparseBlock";
Expand All @@ -1326,9 +1329,9 @@ class SparseBlockNode : public StmtNode {
*/
class SparseBlock : public Stmt {
public:
TVM_DLL explicit SparseBlock(Array<SpIterVar> sp_iter_vars,
Map<ObjectRef, Array<Var>> sp_struct2param_map, String name,
Stmt body, Optional<Stmt> init = NullOpt, Span span = Span());
TVM_DLL explicit SparseBlock(Array<SpIterVar> sp_iter_vars, Array<ObjectRef> sp_structs,
Array<Array<Var>> sp_struct_params, String name, Stmt body,
Optional<Stmt> init = NullOpt, Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(SparseBlock, Stmt, SparseBlockNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(SparseBlockNode);
Expand Down
9 changes: 6 additions & 3 deletions python/tvm/script/context_maintainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,10 @@ class ContextMaintainer:
"""Mapping[Var, str]: The map from var to env thread"""

# sparse block context
sp_struct2param_map: Mapping[Object, List[Var]] = {}
"""Mapping[Object, List[Var]]: The mapping from sparse data structures to the func parameters"""
sp_struct: List[Object] = []
"""List[Object]: The sparse data structures"""
sp_struct_params: List[List[Var]] = []
"""List[List[Var]]: The function parameters that corresponding to each sparse data structures"""

# parser and analyzer
analyzer: tvm.arith.Analyzer = tvm.arith.Analyzer()
Expand All @@ -155,7 +157,8 @@ def __init__(self, _report_error: Callable[[str, Union[Span, synr.ast.Span]], No
self.func_dict_attr = {}
self.func_var_env_dict = {}
# sparse block context
self.sp_struct2param_map = {}
self.sp_struct = []
self.sp_struct_params = []
# parser and analyzer
self._report_error = _report_error
self.analyzer = tvm.arith.Analyzer()
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/script/tir/intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,6 @@ def pos(axis: Axis, span: Optional[Span] = None):
elif isinstance(axis, DenseVariableAxis):
return SpIterVar(var_temp, axis.length, SpIterVar.DenseVariable, False, axis)
elif isinstance(axis, SparseFixedAxis):
return SpIterVar(var_temp, axis.length, SpIterVar.SparseFixed, False, axis)
return SpIterVar(var_temp, axis.num_cols, SpIterVar.SparseFixed, False, axis)
else:
return SpIterVar(var_temp, axis.length, SpIterVar.SparseVariable, False, axis)
3 changes: 2 additions & 1 deletion python/tvm/script/tir/scope_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,8 @@ def iter(iters: List, iter_types: str, name: str = "", span: Optional[Span] = No

block = tvm.tir.SparseBlock(
sp_iters,
self.context.sp_struct2param_map,
self.context.sp_struct,
self.context.sp_struct_params,
name,
self.body,
block_info.init,
Expand Down
15 changes: 10 additions & 5 deletions python/tvm/script/tir/special_stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -832,7 +832,8 @@ def dense_fixed(length: PrimExpr, span: Optional[Span] = None):
)

axis = DenseFixedAxis(names[0], length)
self.context.sp_struct2param_map[axis] = []
self.context.sp_struct.append(axis)
self.context.sp_struct_params.append([])
self.context.update_symbol(names[0], axis, self.node)

super().__init__(dense_fixed, def_symbol=True)
Expand Down Expand Up @@ -860,7 +861,8 @@ def dense_variable(
(indptr_len,), dtype=idtype, name=names[0] + "_indptr", span=span
)
axis = DenseVariableAxis(names[0], length, indptr_buf)
self.context.sp_struct2param_map[axis] = [indptr_var]
self.context.sp_struct.append(axis)
self.context.sp_struct_params.append([indptr_var])
self.context.update_symbol(names[0], axis, self.node)
self.context.update_symbol(names[0] + "_indptr", indptr_buf, self.node)

Expand Down Expand Up @@ -889,7 +891,8 @@ def sparse_fixed(
(nnz,), dtype=idtype, name=names[0] + "_indices", span=span
)
axis = SparseFixedAxis(names[0], length, indices_buf, nnz_cols)
self.context.sp_struct2param_map[axis] = [indices_var]
self.context.sp_struct.append(axis)
self.context.sp_struct_params.append([indices_var])
self.context.update_symbol(names[0], axis, self.node)
self.context.update_symbol(names[0] + "_indices", indices_buf, self.node)

Expand Down Expand Up @@ -922,7 +925,8 @@ def sparse_variable(
(nnz,), dtype=idtype, name=names[0] + "_indices", span=span
)
axis = SparseVariableAxis(names[0], length, indptr_buf, indices_buf)
self.context.sp_struct2param_map[axis] = [indptr_var, indices_var]
self.context.sp_struct.append(axis)
self.context.sp_struct_params.append([indptr_var, indices_var])
self.context.update_symbol(names[0], axis, self.node)
self.context.update_symbol(names[0] + "_indptr", indptr_buf, self.node)
self.context.update_symbol(names[0] + "_indices", indices_buf, self.node)
Expand Down Expand Up @@ -958,7 +962,8 @@ def match_sparse_buffer(
if param in self.context.func_params:
data = tvm.tir.decl_buffer(nnz, dtype, buffer_name + "_data", span=span)
buffer = tvm.tir.sparse.SparseBuffer(axes, data, buffer_name)
self.context.sp_struct2param_map[buffer] = [param]
self.context.sp_struct.append(buffer)
self.context.sp_struct_params.append([param])
self.context.update_symbol(buffer_name + "_data", data, self.node)
self.context.update_symbol(buffer_name, buffer, self.node)
else:
Expand Down
3 changes: 3 additions & 0 deletions python/tvm/tir/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from .buffer import Buffer


@tvm._ffi.register_object("tir.sparse.Axis")
class Axis(Object):
"""Base class of all the sparse axes."""

Expand All @@ -42,10 +43,12 @@ def idtype(self):
return _ffi_api.GetAxisIndexType(self)


@tvm._ffi.register_object("tir.sparse.DenseAxis")
class DenseAxis(Axis):
pass


@tvm._ffi.register_object("tir.sparse.SparseAxis")
class SparseAxis(Axis):
pass

Expand Down
13 changes: 11 additions & 2 deletions python/tvm/tir/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,6 +649,12 @@ class SparseBlock(Stmt):
sp_iter_vars : List[SpIterVar]
The sparse iteration variables of the block.
sp_struct : List[Object]
The sparse data structures
sp_struct_params : List[List[Var]]
The function parameters that corresponding to each sparse data structures
sp_struct2param_map : Mapping[Object, List[Var]]
The mapping from sparse data structures to the PrimFunc parameters.
Expand All @@ -666,6 +672,7 @@ class SparseBlock(Stmt):
"""

sp_iter_vars: List[SpIterVar]
sp_struct: List[Object]
sp_struct2param_map: Mapping[Object, List[Var]]
name: str
body: Stmt
Expand All @@ -675,7 +682,8 @@ class SparseBlock(Stmt):
def __init__(
self,
sp_iter_vars: List[SpIterVar],
sp_struct2param_map: Mapping[Object, List[Var]],
sp_struct: List[Object],
sp_struct_params: List[List[Var]],
name: str,
body: Stmt,
init: Optional[Stmt] = None,
Expand All @@ -684,7 +692,8 @@ def __init__(
self.__init_handle_by_constructor__(
_ffi_api.SparseBlock, # type: ignore
sp_iter_vars,
sp_struct2param_map,
sp_struct,
sp_struct_params,
name,
body,
init,
Expand Down
34 changes: 18 additions & 16 deletions src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1277,12 +1277,14 @@ Doc TVMScriptPrinter::PrintSparseStructDefinitions(const SparseBlockNode* sp_blo
std::vector<Doc> axis_docs;
std::vector<Doc> sp_buf_docs;

for (auto it : sp_block->sp_struct2param_map) {
for (const ObjectRef& obj : sp_block->sp_structs) {
Array<Var> params = sp_block->sp_struct2param_map.Get(obj).value();

Doc doc;
doc << Print(it.first) << " = " << tir_prefix_ << ".";
doc << Print(obj) << " = " << tir_prefix_ << ".";

if (const auto* sp_buffer = it.first.as<SparseBufferNode>()) {
ICHECK_EQ(it.second.size(), 1);
if (const auto* sp_buffer = obj.as<SparseBufferNode>()) {
ICHECK_EQ(params.size(), 1);
Doc axes_doc;
if (sp_buffer->axes.size() != 1) {
std::vector<Doc> axes_docs;
Expand All @@ -1295,30 +1297,30 @@ Doc TVMScriptPrinter::PrintSparseStructDefinitions(const SparseBlockNode* sp_blo
axes_doc << Print(sp_buffer->axes[0]) << ",";
}

doc << "match_sparse_buffer(" << Print(it.second[0]) << ", (" << axes_doc << "), "
doc << "match_sparse_buffer(" << Print(params[0]) << ", (" << axes_doc << "), "
<< Print(sp_buffer->data->shape[0]) << ", " << PrintDType(sp_buffer->data->dtype) << ")";
sp_buf_docs.push_back(doc);
continue;
}

if (const auto* df_axis = it.first.as<DenseFixedAxisNode>()) {
ICHECK_EQ(it.second.size(), 0);
if (const auto* df_axis = obj.as<DenseFixedAxisNode>()) {
ICHECK_EQ(params.size(), 0);
doc << "dense_fixed(" << Print(df_axis->length) << ")";
} else if (const auto* dv_axis = it.first.as<DenseVariableAxisNode>()) {
ICHECK_EQ(it.second.size(), 1);
} else if (const auto* dv_axis = obj.as<DenseVariableAxisNode>()) {
ICHECK_EQ(params.size(), 1);
doc << "dense_variable((" << Print(dv_axis->length) << ", "
<< Print(dv_axis->indptr->shape[0]) << "), " << Print(it.second[0]) << ", "
<< Print(dv_axis->indptr->shape[0]) << "), " << Print(params[0]) << ", "
<< PrintDType(dv_axis->indptr->dtype) << ")";
} else if (const auto* sf_axis = it.first.as<SparseFixedAxisNode>()) {
ICHECK_EQ(it.second.size(), 1);
} else if (const auto* sf_axis = obj.as<SparseFixedAxisNode>()) {
ICHECK_EQ(params.size(), 1);
doc << "sparse_fixed((" << Print(sf_axis->length) << ", " << Print(sf_axis->indices->shape[0])
<< ", " << Print(sf_axis->num_cols) << "), " << Print(it.second[0]) << ", "
<< ", " << Print(sf_axis->num_cols) << "), " << Print(params[0]) << ", "
<< PrintDType(sf_axis->indices->dtype) << ")";
} else if (const auto* sv_axis = it.first.as<SparseVariableAxisNode>()) {
ICHECK_EQ(it.second.size(), 2);
} else if (const auto* sv_axis = obj.as<SparseVariableAxisNode>()) {
ICHECK_EQ(params.size(), 2);
doc << "sparse_variable((" << Print(sv_axis->length) << ", "
<< Print(sv_axis->indptr->shape[0]) << ", " << Print(sv_axis->indices->shape[0]) << "), ("
<< Print(it.second[0]) << ", " << Print(it.second[1]) << "), "
<< Print(params[0]) << ", " << Print(params[1]) << "), "
<< PrintDType(sv_axis->indptr->dtype) << ")";
} else {
ICHECK(false) << "Cannot reach here";
Expand Down
1 change: 0 additions & 1 deletion src/tir/ir/sparse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,6 @@ SpIterVar::SpIterVar(Var var, PrimExpr max_extent, SpIterKind kind, bool is_redu
ObjectPtr<SpIterVarNode> node = make_object<SpIterVarNode>();

arith::Analyzer ana;
CHECK(ana.CanProveEqual(axis->length, max_extent));
const char* err_str = "ValueError: The given kind doesn't match the type of the given axis";
if (kind == SpIterKind::kDenseFixed) {
CHECK(!axis->IsInstance<DenseVariableAxisNode>()) << err_str;
Expand Down
41 changes: 36 additions & 5 deletions src/tir/ir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -968,11 +968,42 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "}\n";
});

SparseBlock::SparseBlock(Array<SpIterVar> sp_iter_vars,
Map<ObjectRef, Array<Var>> sp_struct2param_map, String name, Stmt body,
SparseBlock::SparseBlock(Array<SpIterVar> sp_iter_vars, Array<ObjectRef> sp_structs,
Array<Array<Var>> sp_struct_params, String name, Stmt body,
Optional<Stmt> init, Span span) {
CHECK_EQ(sp_structs.size(), sp_struct_params.size())
<< "ValueError: The length of `sp_struct_params` is expected to be equal to the length "
"`sp_structs`, which is the number of sparse data structures";
Map<ObjectRef, Array<Var>> sp_struct2param_map;
for (int i = 0; i < static_cast<int>(sp_structs.size()); ++i) {
ObjectRef obj = sp_structs[i];
Array<Var> params = sp_struct_params[i];

if (obj->IsInstance<DenseFixedAxisNode>()) {
CHECK(params.size() == 0)
<< "ValueError: The number of function parameters for dense-fixed axes should be 0";
} else if (obj->IsInstance<DenseVariableAxisNode>()) {
CHECK(params.size() == 1)
<< "ValueError: The number of function parameters for dense-variable axes should be 1";
} else if (obj->IsInstance<SparseFixedAxisNode>()) {
CHECK(params.size() == 1)
<< "ValueError: The number of function parameters for sparse-fixed axes should be 1";
} else if (obj->IsInstance<SparseVariableAxisNode>()) {
CHECK(params.size() == 2)
<< "ValueError: The number of function parameters for sparse-variable axes should be 2";
} else if (obj->IsInstance<SparseBufferNode>()) {
CHECK(params.size() == 1)
<< "ValueError: The number of function parameters for SparseBuffer should be 1";
} else {
LOG(FATAL) << "ValueError: " << obj->_type_key << " is not a sparse data structure";
}

sp_struct2param_map.Set(obj, params);
}

ObjectPtr<SparseBlockNode> node = make_object<SparseBlockNode>();
node->sp_iter_vars = std::move(sp_iter_vars);
node->sp_structs = std::move(sp_structs);
node->sp_struct2param_map = std::move(sp_struct2param_map);
node->name = std::move(name);
node->body = std::move(body);
Expand All @@ -982,10 +1013,10 @@ SparseBlock::SparseBlock(Array<SpIterVar> sp_iter_vars,
}

TVM_REGISTER_GLOBAL("tir.SparseBlock")
.set_body_typed([](Array<SpIterVar> sp_iter_vars,
Map<ObjectRef, Array<Var>> sp_struct2param_map, String name, Stmt body,
.set_body_typed([](Array<SpIterVar> sp_iter_vars, Array<ObjectRef> sp_structs,
Array<Array<Var>> sp_struct_params, String name, Stmt body,
Optional<Stmt> init, Span span) {
return SparseBlock(sp_iter_vars, sp_struct2param_map, name, body, init, span);
return SparseBlock(sp_iter_vars, sp_structs, sp_struct_params, name, body, init, span);
});

TVM_REGISTER_NODE_TYPE(SparseBlockNode);
Expand Down
Loading

0 comments on commit d2bccc2

Please sign in to comment.