Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SparseTIR] SparseTIR Lowering #20

Merged
merged 21 commits into from
Nov 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions include/tvm/tir/sparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,11 @@ class DenseFixedAxisNode : public DenseAxisNode {
}

bool SEqualReduce(const DenseFixedAxisNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return equal(name, other->name) && equal(length, other->length) &&
equal(from_sparse, other->from_sparse);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce->MarkGraphNode();
hash_reduce(name);
hash_reduce(length);
hash_reduce(from_sparse);
Expand Down Expand Up @@ -170,12 +168,10 @@ class DenseVariableAxisNode : public DenseAxisNode {
}

bool SEqualReduce(const DenseVariableAxisNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return equal(name, other->name) && equal(length, other->length) && equal(indptr, other->indptr);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce->MarkGraphNode();
hash_reduce(name);
hash_reduce(length);
hash_reduce(indptr);
Expand Down Expand Up @@ -213,13 +209,11 @@ class SparseFixedAxisNode : public SparseAxisNode {
}

bool SEqualReduce(const SparseFixedAxisNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return equal(name, other->name) && equal(length, other->length) &&
equal(indices, other->indices) && equal(num_cols, other->num_cols);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce->MarkGraphNode();
hash_reduce(name);
hash_reduce(length);
hash_reduce(indices);
Expand Down Expand Up @@ -257,13 +251,11 @@ class SparseVariableAxisNode : public SparseAxisNode {
}

bool SEqualReduce(const SparseVariableAxisNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return equal(name, other->name) && equal(length, other->length) &&
equal(indptr, other->indptr) && equal(indices, other->indices);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce->MarkGraphNode();
hash_reduce(name);
hash_reduce(length);
hash_reduce(indptr);
Expand Down Expand Up @@ -347,12 +339,10 @@ class SparseBufferNode : public Object {
}

bool SEqualReduce(const SparseBufferNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return equal(axes, other->axes) && equal(data, other->data) && equal(name, other->name);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce->MarkGraphNode();
hash_reduce(axes);
hash_reduce(data);
hash_reduce(name);
Expand Down
13 changes: 8 additions & 5 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -1285,6 +1285,8 @@ class SparseBlockNode : public StmtNode {
public:
/*! \brief The sparse iteration variables of the block. */
Array<SpIterVar> sp_iter_vars;
/*! \brief The sparse data structures */
Array<ObjectRef> sp_structs;
/*! \brief The mapping from sparse data structures to the PrimFunc parameters */
Map<ObjectRef, Array<Var>> sp_struct2param_map;
/*! \brief The name of the block */
Expand All @@ -1296,6 +1298,7 @@ class SparseBlockNode : public StmtNode {

void VisitAttrs(AttrVisitor* v) {
v->Visit("sp_iter_vars", &sp_iter_vars);
v->Visit("sp_structs", &sp_structs);
v->Visit("sp_struct2param_map", &sp_struct2param_map);
v->Visit("name", &name);
v->Visit("body", &body);
Expand All @@ -1305,15 +1308,15 @@ class SparseBlockNode : public StmtNode {
bool SEqualReduce(const SparseBlockNode* other, SEqualReducer equal) const {
return equal(sp_iter_vars, other->sp_iter_vars) && equal(name, other->name) &&
equal(body, other->body) && equal(init, other->init) &&
equal(sp_struct2param_map, other->sp_struct2param_map);
equal(sp_structs, other->sp_structs);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(sp_iter_vars);
hash_reduce(name);
hash_reduce(body);
hash_reduce(init);
hash_reduce(sp_struct2param_map);
hash_reduce(sp_structs);
}

static constexpr const char* _type_key = "tir.SparseBlock";
Expand All @@ -1326,9 +1329,9 @@ class SparseBlockNode : public StmtNode {
*/
class SparseBlock : public Stmt {
public:
TVM_DLL explicit SparseBlock(Array<SpIterVar> sp_iter_vars,
Map<ObjectRef, Array<Var>> sp_struct2param_map, String name,
Stmt body, Optional<Stmt> init = NullOpt, Span span = Span());
TVM_DLL explicit SparseBlock(Array<SpIterVar> sp_iter_vars, Array<ObjectRef> sp_structs,
yzh119 marked this conversation as resolved.
Show resolved Hide resolved
Array<Array<Var>> sp_struct_params, String name, Stmt body,
Optional<Stmt> init = NullOpt, Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(SparseBlock, Stmt, SparseBlockNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(SparseBlockNode);
Expand Down
9 changes: 6 additions & 3 deletions python/tvm/script/context_maintainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,10 @@ class ContextMaintainer:
"""Mapping[Var, str]: The map from var to env thread"""

# sparse block context
sp_struct2param_map: Mapping[Object, List[Var]] = {}
"""Mapping[Object, List[Var]]: The mapping from sparse data structures to the func parameters"""
sp_struct: List[Object] = []
"""List[Object]: The sparse data structures"""
sp_struct_params: List[List[Var]] = []
"""List[List[Var]]: The function parameters that corresponding to each sparse data structures"""

# parser and analyzer
analyzer: tvm.arith.Analyzer = tvm.arith.Analyzer()
Expand All @@ -155,7 +157,8 @@ def __init__(self, _report_error: Callable[[str, Union[Span, synr.ast.Span]], No
self.func_dict_attr = {}
self.func_var_env_dict = {}
# sparse block context
self.sp_struct2param_map = {}
self.sp_struct = []
self.sp_struct_params = []
# parser and analyzer
self._report_error = _report_error
self.analyzer = tvm.arith.Analyzer()
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/script/tir/intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,6 @@ def pos(axis: Axis, span: Optional[Span] = None):
elif isinstance(axis, DenseVariableAxis):
return SpIterVar(var_temp, axis.length, SpIterVar.DenseVariable, False, axis)
elif isinstance(axis, SparseFixedAxis):
return SpIterVar(var_temp, axis.length, SpIterVar.SparseFixed, False, axis)
return SpIterVar(var_temp, axis.num_cols, SpIterVar.SparseFixed, False, axis)
yzh119 marked this conversation as resolved.
Show resolved Hide resolved
else:
return SpIterVar(var_temp, axis.length, SpIterVar.SparseVariable, False, axis)
3 changes: 2 additions & 1 deletion python/tvm/script/tir/scope_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,8 @@ def iter(iters: List, iter_types: str, name: str = "", span: Optional[Span] = No

block = tvm.tir.SparseBlock(
sp_iters,
self.context.sp_struct2param_map,
self.context.sp_struct,
self.context.sp_struct_params,
name,
self.body,
block_info.init,
Expand Down
15 changes: 10 additions & 5 deletions python/tvm/script/tir/special_stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -832,7 +832,8 @@ def dense_fixed(length: PrimExpr, span: Optional[Span] = None):
)

axis = DenseFixedAxis(names[0], length)
self.context.sp_struct2param_map[axis] = []
self.context.sp_struct.append(axis)
self.context.sp_struct_params.append([])
self.context.update_symbol(names[0], axis, self.node)

super().__init__(dense_fixed, def_symbol=True)
Expand Down Expand Up @@ -860,7 +861,8 @@ def dense_variable(
(indptr_len,), dtype=idtype, name=names[0] + "_indptr", span=span
)
axis = DenseVariableAxis(names[0], length, indptr_buf)
self.context.sp_struct2param_map[axis] = [indptr_var]
self.context.sp_struct.append(axis)
self.context.sp_struct_params.append([indptr_var])
self.context.update_symbol(names[0], axis, self.node)
self.context.update_symbol(names[0] + "_indptr", indptr_buf, self.node)

Expand Down Expand Up @@ -889,7 +891,8 @@ def sparse_fixed(
(nnz,), dtype=idtype, name=names[0] + "_indices", span=span
)
axis = SparseFixedAxis(names[0], length, indices_buf, nnz_cols)
self.context.sp_struct2param_map[axis] = [indices_var]
self.context.sp_struct.append(axis)
self.context.sp_struct_params.append([indices_var])
self.context.update_symbol(names[0], axis, self.node)
self.context.update_symbol(names[0] + "_indices", indices_buf, self.node)

Expand Down Expand Up @@ -922,7 +925,8 @@ def sparse_variable(
(nnz,), dtype=idtype, name=names[0] + "_indices", span=span
)
axis = SparseVariableAxis(names[0], length, indptr_buf, indices_buf)
self.context.sp_struct2param_map[axis] = [indptr_var, indices_var]
self.context.sp_struct.append(axis)
self.context.sp_struct_params.append([indptr_var, indices_var])
self.context.update_symbol(names[0], axis, self.node)
self.context.update_symbol(names[0] + "_indptr", indptr_buf, self.node)
self.context.update_symbol(names[0] + "_indices", indices_buf, self.node)
Expand Down Expand Up @@ -958,7 +962,8 @@ def match_sparse_buffer(
if param in self.context.func_params:
data = tvm.tir.decl_buffer(nnz, dtype, buffer_name + "_data", span=span)
buffer = tvm.tir.sparse.SparseBuffer(axes, data, buffer_name)
self.context.sp_struct2param_map[buffer] = [param]
self.context.sp_struct.append(buffer)
self.context.sp_struct_params.append([param])
self.context.update_symbol(buffer_name + "_data", data, self.node)
self.context.update_symbol(buffer_name, buffer, self.node)
else:
Expand Down
3 changes: 3 additions & 0 deletions python/tvm/tir/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from .buffer import Buffer


@tvm._ffi.register_object("tir.sparse.Axis")
class Axis(Object):
"""Base class of all the sparse axes."""

Expand All @@ -42,10 +43,12 @@ def idtype(self):
return _ffi_api.GetAxisIndexType(self)


@tvm._ffi.register_object("tir.sparse.DenseAxis")
class DenseAxis(Axis):
pass


@tvm._ffi.register_object("tir.sparse.SparseAxis")
class SparseAxis(Axis):
pass

Expand Down
13 changes: 11 additions & 2 deletions python/tvm/tir/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,6 +649,12 @@ class SparseBlock(Stmt):
sp_iter_vars : List[SpIterVar]
The sparse iteration variables of the block.

sp_struct : List[Object]
The sparse data structures

sp_struct_params : List[List[Var]]
The function parameters that corresponding to each sparse data structures

sp_struct2param_map : Mapping[Object, List[Var]]
The mapping from sparse data structures to the PrimFunc parameters.

Expand All @@ -666,6 +672,7 @@ class SparseBlock(Stmt):
"""

sp_iter_vars: List[SpIterVar]
sp_struct: List[Object]
sp_struct2param_map: Mapping[Object, List[Var]]
name: str
body: Stmt
Expand All @@ -675,7 +682,8 @@ class SparseBlock(Stmt):
def __init__(
self,
sp_iter_vars: List[SpIterVar],
sp_struct2param_map: Mapping[Object, List[Var]],
sp_struct: List[Object],
sp_struct_params: List[List[Var]],
name: str,
body: Stmt,
init: Optional[Stmt] = None,
Expand All @@ -684,7 +692,8 @@ def __init__(
self.__init_handle_by_constructor__(
_ffi_api.SparseBlock, # type: ignore
sp_iter_vars,
sp_struct2param_map,
sp_struct,
sp_struct_params,
name,
body,
init,
Expand Down
34 changes: 18 additions & 16 deletions src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1263,12 +1263,14 @@ Doc TVMScriptPrinter::PrintSparseStructDefinitions(const SparseBlockNode* sp_blo
std::vector<Doc> axis_docs;
std::vector<Doc> sp_buf_docs;

for (auto it : sp_block->sp_struct2param_map) {
for (const ObjectRef& obj : sp_block->sp_structs) {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we design an interface for sparse structures rather than use ObjectRef?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm fine with both designs. Hmm, it concerns me a bit that the unified class has no other use. Do you think it's necessary?

Array<Var> params = sp_block->sp_struct2param_map.Get(obj).value();
yzh119 marked this conversation as resolved.
Show resolved Hide resolved

Doc doc;
doc << Print(it.first) << " = " << tir_prefix_ << ".";
doc << Print(obj) << " = " << tir_prefix_ << ".";

if (const auto* sp_buffer = it.first.as<SparseBufferNode>()) {
ICHECK_EQ(it.second.size(), 1);
if (const auto* sp_buffer = obj.as<SparseBufferNode>()) {
ICHECK_EQ(params.size(), 1);
Doc axes_doc;
if (sp_buffer->axes.size() != 1) {
std::vector<Doc> axes_docs;
Expand All @@ -1281,30 +1283,30 @@ Doc TVMScriptPrinter::PrintSparseStructDefinitions(const SparseBlockNode* sp_blo
axes_doc << Print(sp_buffer->axes[0]) << ",";
}

doc << "match_sparse_buffer(" << Print(it.second[0]) << ", (" << axes_doc << "), "
doc << "match_sparse_buffer(" << Print(params[0]) << ", (" << axes_doc << "), "
<< Print(sp_buffer->data->shape[0]) << ", " << PrintDType(sp_buffer->data->dtype) << ")";
sp_buf_docs.push_back(doc);
continue;
}

if (const auto* df_axis = it.first.as<DenseFixedAxisNode>()) {
ICHECK_EQ(it.second.size(), 0);
if (const auto* df_axis = obj.as<DenseFixedAxisNode>()) {
ICHECK_EQ(params.size(), 0);
doc << "dense_fixed(" << Print(df_axis->length) << ")";
} else if (const auto* dv_axis = it.first.as<DenseVariableAxisNode>()) {
ICHECK_EQ(it.second.size(), 1);
} else if (const auto* dv_axis = obj.as<DenseVariableAxisNode>()) {
ICHECK_EQ(params.size(), 1);
doc << "dense_variable((" << Print(dv_axis->length) << ", "
<< Print(dv_axis->indptr->shape[0]) << "), " << Print(it.second[0]) << ", "
<< Print(dv_axis->indptr->shape[0]) << "), " << Print(params[0]) << ", "
<< PrintDType(dv_axis->indptr->dtype) << ")";
} else if (const auto* sf_axis = it.first.as<SparseFixedAxisNode>()) {
ICHECK_EQ(it.second.size(), 1);
} else if (const auto* sf_axis = obj.as<SparseFixedAxisNode>()) {
ICHECK_EQ(params.size(), 1);
doc << "sparse_fixed((" << Print(sf_axis->length) << ", " << Print(sf_axis->indices->shape[0])
<< ", " << Print(sf_axis->num_cols) << "), " << Print(it.second[0]) << ", "
<< ", " << Print(sf_axis->num_cols) << "), " << Print(params[0]) << ", "
<< PrintDType(sf_axis->indices->dtype) << ")";
} else if (const auto* sv_axis = it.first.as<SparseVariableAxisNode>()) {
ICHECK_EQ(it.second.size(), 2);
} else if (const auto* sv_axis = obj.as<SparseVariableAxisNode>()) {
ICHECK_EQ(params.size(), 2);
doc << "sparse_variable((" << Print(sv_axis->length) << ", "
<< Print(sv_axis->indptr->shape[0]) << ", " << Print(sv_axis->indices->shape[0]) << "), ("
<< Print(it.second[0]) << ", " << Print(it.second[1]) << "), "
<< Print(params[0]) << ", " << Print(params[1]) << "), "
<< PrintDType(sv_axis->indptr->dtype) << ")";
} else {
ICHECK(false) << "Cannot reach here";
Expand Down
1 change: 0 additions & 1 deletion src/tir/ir/sparse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,6 @@ SpIterVar::SpIterVar(Var var, PrimExpr max_extent, SpIterKind kind, bool is_redu
ObjectPtr<SpIterVarNode> node = make_object<SpIterVarNode>();

arith::Analyzer ana;
CHECK(ana.CanProveEqual(axis->length, max_extent));
yzh119 marked this conversation as resolved.
Show resolved Hide resolved
const char* err_str = "ValueError: The given kind doesn't match the type of the given axis";
if (kind == SpIterKind::kDenseFixed) {
CHECK(!axis->IsInstance<DenseVariableAxisNode>()) << err_str;
Expand Down
41 changes: 36 additions & 5 deletions src/tir/ir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -968,11 +968,42 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "}\n";
});

SparseBlock::SparseBlock(Array<SpIterVar> sp_iter_vars,
Map<ObjectRef, Array<Var>> sp_struct2param_map, String name, Stmt body,
SparseBlock::SparseBlock(Array<SpIterVar> sp_iter_vars, Array<ObjectRef> sp_structs,
Array<Array<Var>> sp_struct_params, String name, Stmt body,
Optional<Stmt> init, Span span) {
CHECK_EQ(sp_structs.size(), sp_struct_params.size())
<< "ValueError: The length of `sp_struct_params` is expected to be equal to the length "
"`sp_structs`, which is the number of sparse data structures";
Map<ObjectRef, Array<Var>> sp_struct2param_map;
for (int i = 0; i < static_cast<int>(sp_structs.size()); ++i) {
ObjectRef obj = sp_structs[i];
Array<Var> params = sp_struct_params[i];

if (obj->IsInstance<DenseFixedAxisNode>()) {
CHECK(params.size() == 0)
<< "ValueError: The number of function parameters for dense-fixed axes should be 0";
} else if (obj->IsInstance<DenseVariableAxisNode>()) {
CHECK(params.size() == 1)
<< "ValueError: The number of function parameters for dense-variable axes should be 1";
} else if (obj->IsInstance<SparseFixedAxisNode>()) {
CHECK(params.size() == 1)
<< "ValueError: The number of function parameters for sparse-fixed axes should be 1";
} else if (obj->IsInstance<SparseVariableAxisNode>()) {
CHECK(params.size() == 2)
<< "ValueError: The number of function parameters for sparse-variable axes should be 2";
} else if (obj->IsInstance<SparseBufferNode>()) {
CHECK(params.size() == 1)
<< "ValueError: The number of function parameters for SparseBuffer should be 1";
} else {
LOG(FATAL) << "ValueError: " << obj->_type_key << " is not a sparse data structure";
}

sp_struct2param_map.Set(obj, params);
}

ObjectPtr<SparseBlockNode> node = make_object<SparseBlockNode>();
node->sp_iter_vars = std::move(sp_iter_vars);
node->sp_structs = std::move(sp_structs);
node->sp_struct2param_map = std::move(sp_struct2param_map);
node->name = std::move(name);
node->body = std::move(body);
Expand All @@ -982,10 +1013,10 @@ SparseBlock::SparseBlock(Array<SpIterVar> sp_iter_vars,
}

TVM_REGISTER_GLOBAL("tir.SparseBlock")
.set_body_typed([](Array<SpIterVar> sp_iter_vars,
Map<ObjectRef, Array<Var>> sp_struct2param_map, String name, Stmt body,
.set_body_typed([](Array<SpIterVar> sp_iter_vars, Array<ObjectRef> sp_structs,
Array<Array<Var>> sp_struct_params, String name, Stmt body,
Optional<Stmt> init, Span span) {
return SparseBlock(sp_iter_vars, sp_struct2param_map, name, body, init, span);
return SparseBlock(sp_iter_vars, sp_structs, sp_struct_params, name, body, init, span);
});

TVM_REGISTER_NODE_TYPE(SparseBlockNode);
Expand Down
Loading