Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Frontend update, demo scripts. #10

Merged
merged 30 commits into from
Nov 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
75cb0ad
Format and Buffer data structure (#1)
yzh119 Oct 20, 2021
d186b7e
[SparseTIR] Constructors and Python Interface for `Axis` and `SparseB…
MasterJH5574 Oct 20, 2021
632baa0
[CherryPick][Intrinsic] lower_bound and upper_bound for binary search…
MasterJH5574 Oct 22, 2021
85e8283
Fix AxisTree (#3)
yzh119 Oct 22, 2021
53c1709
Format and Buffer data structure (#1)
yzh119 Oct 20, 2021
3c09a0e
[SparseTIR] Constructors and Python Interface for `Axis` and `SparseB…
MasterJH5574 Oct 20, 2021
a3a43f4
fix axis tree
yzh119 Oct 22, 2021
e766a44
upd
yzh119 Oct 22, 2021
bb912d8
Merge remote-tracking branch 'origin/sparse' into develop
yzh119 Oct 22, 2021
1f051e3
Format and Buffer data structure (#1)
yzh119 Oct 20, 2021
e08dd33
[SparseTIR] Constructors and Python Interface for `Axis` and `SparseB…
MasterJH5574 Oct 20, 2021
a61c9f2
[CherryPick][Intrinsic] lower_bound and upper_bound for binary search…
MasterJH5574 Oct 22, 2021
6023a16
Fix AxisTree (#3)
yzh119 Oct 22, 2021
9e8c926
[SparseTIR] Add SparseBufferLoad/SparseBufferStore (#5)
MasterJH5574 Oct 23, 2021
61f1671
merge
yzh119 Oct 25, 2021
d3e1ce6
Format and Buffer data structure (#1)
yzh119 Oct 20, 2021
85413f2
[SparseTIR] Constructors and Python Interface for `Axis` and `SparseB…
MasterJH5574 Oct 20, 2021
5f88452
[CherryPick][Intrinsic] lower_bound and upper_bound for binary search…
MasterJH5574 Oct 22, 2021
933f86f
Fix AxisTree (#3)
yzh119 Oct 22, 2021
14f6f92
[SparseTIR] Add SparseBufferLoad/SparseBufferStore (#5)
MasterJH5574 Oct 23, 2021
b264051
Merge remote-tracking branch 'origin/sparse' into develop
yzh119 Oct 26, 2021
3f96d26
[SparseTIR] Introduce SpIterVar (#6)
MasterJH5574 Oct 27, 2021
8bd5f9f
[BugFix] Fix binary search & SpIterVar (#7)
MasterJH5574 Oct 29, 2021
fe9610e
[BugFix] Add field `is_reduction` for SpIterVar (#9)
MasterJH5574 Nov 1, 2021
eba6a26
Merge remote-tracking branch 'upstream/main' into develop
yzh119 Nov 1, 2021
ab9d570
upd
yzh119 Nov 3, 2021
5798c08
upd
yzh119 Nov 3, 2021
ce51cda
upd
yzh119 Nov 3, 2021
9e4832c
Merge branch 'sparse' into develop
yzh119 Nov 3, 2021
4968e07
upd
yzh119 Nov 5, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 39 additions & 33 deletions include/tvm/tir/sparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ class AxisNode : public Object {
* the current axis. */
PrimExpr length;

String GetName() const { return name; }
PrimExpr GetLength() const { return length; }
DataType GetIndexType() const { return length->dtype; }

static constexpr const char* _type_key = "tir.sparse.Axis";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
Expand Down Expand Up @@ -139,8 +143,10 @@ class DenseVariableAxisNode : public DenseAxisNode {
v->Visit("indptr", &indptr);
}

bool SEqualReduce(const DenseVariableAxisNode* other, SEqualReducer equal) const {
return equal(name, other->name) && equal(length, other->length) && equal(indptr, other->indptr);
bool SEqualReduce(const DenseVariableAxisNode* other,
SEqualReducer equal) const {
return equal(name, other->name) && equal(length, other->length) &&
equal(indptr, other->indptr);
}

void SHashReduce(SHashReducer hash_reduce) const {
Expand All @@ -159,9 +165,11 @@ class DenseVariableAxisNode : public DenseAxisNode {
*/
class DenseVariableAxis : public DenseAxis {
public:
TVM_DLL explicit DenseVariableAxis(String name, PrimExpr length, Buffer indptr);
TVM_DLL explicit DenseVariableAxis(String name, PrimExpr length,
Buffer indptr);

TVM_DEFINE_OBJECT_REF_METHODS(DenseVariableAxis, DenseAxis, DenseVariableAxisNode);
TVM_DEFINE_OBJECT_REF_METHODS(DenseVariableAxis, DenseAxis,
DenseVariableAxisNode);
};

/*!
Expand Down Expand Up @@ -198,7 +206,8 @@ class SparseFixedAxisNode : public SparseAxisNode {
v->Visit("num_cols", &num_cols);
}

bool SEqualReduce(const SparseFixedAxisNode* other, SEqualReducer equal) const {
bool SEqualReduce(const SparseFixedAxisNode* other,
SEqualReducer equal) const {
return equal(name, other->name) && equal(length, other->length) &&
equal(indices, other->indices) && equal(num_cols, other->num_cols);
}
Expand All @@ -220,9 +229,11 @@ class SparseFixedAxisNode : public SparseAxisNode {
*/
class SparseFixedAxis : public SparseAxis {
public:
TVM_DLL explicit SparseFixedAxis(String name, PrimExpr length, Buffer indices, PrimExpr num_cols);
TVM_DLL explicit SparseFixedAxis(String name, PrimExpr length, Buffer indices,
PrimExpr num_cols);

TVM_DEFINE_OBJECT_REF_METHODS(SparseFixedAxis, SparseAxis, SparseFixedAxisNode);
TVM_DEFINE_OBJECT_REF_METHODS(SparseFixedAxis, SparseAxis,
SparseFixedAxisNode);
};

/*!
Expand All @@ -240,7 +251,8 @@ class SparseVariableAxisNode : public SparseAxisNode {
v->Visit("indices", &indices);
}

bool SEqualReduce(const SparseVariableAxisNode* other, SEqualReducer equal) const {
bool SEqualReduce(const SparseVariableAxisNode* other,
SEqualReducer equal) const {
return equal(name, other->name) && equal(length, other->length) &&
equal(indptr, other->indptr) && equal(indices, other->indices);
}
Expand All @@ -262,24 +274,25 @@ class SparseVariableAxisNode : public SparseAxisNode {
*/
class SparseVariableAxis : public SparseAxis {
public:
TVM_DLL explicit SparseVariableAxis(String name, PrimExpr length, Buffer indptr, Buffer indices);
TVM_DLL explicit SparseVariableAxis(String name, PrimExpr length,
Buffer indptr, Buffer indices);

TVM_DEFINE_OBJECT_REF_METHODS(SparseVariableAxis, SparseAxis, SparseVariableAxisNode);
TVM_DEFINE_OBJECT_REF_METHODS(SparseVariableAxis, SparseAxis,
SparseVariableAxisNode);
};

/*!
* \brief Axis Dependency Tree.
*/
class AxisTreeNode : public Object {
public:
// mapping from names to axes.
std::unordered_map<String, Axis> axis_map;
// unordered map that stores the parent relationship between axes.
std::unordered_map<Axis, Axis, ObjectPtrHash, ObjectPtrEqual> parent;
std::unordered_map<String, Optional<String>, ObjectPtrHash, ObjectPtrEqual>
parent;
// unordered map that stores the children relationship between axes.
std::unordered_map<Axis, Array<Axis>, ObjectPtrHash, ObjectPtrEqual> children;
// The root axis.
Axis root;
std::unordered_map<Optional<String>, Array<String>, ObjectPtrHash,
ObjectPtrEqual>
yzh119 marked this conversation as resolved.
Show resolved Hide resolved
children;

void VisitAttrs(AttrVisitor* v) {}

Expand All @@ -293,7 +306,9 @@ class AxisTreeNode : public Object {
*/
class AxisTree : public ObjectRef {
public:
TVM_DLL AxisTree(Array<Axis> axes, Array<Optional<String>> axis_parent_names);
TVM_DLL AxisTree(Array<String> axis_names,
Array<Optional<String>> axis_parent_names);

TVM_DEFINE_OBJECT_REF_METHODS(AxisTree, ObjectRef, AxisTreeNode);
};

Expand All @@ -302,38 +317,30 @@ class AxisTree : public ObjectRef {
*/
class SparseBufferNode : public Object {
public:
/* Root of Axis Dependency Tree. */
AxisTree tree;
/* Axes */
Array<Axis> axes;
/* Buffer corresponding to flattened value */
Buffer data;
/* Buffer Name */
String name;
/* Data type */
runtime::DataType dtype;

inline int ndim() const { return static_cast<int>(axes.size()); }

void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &tree);
v->Visit("length", &axes);
v->Visit("num_cols", &data);
v->Visit("name", &name);
v->Visit("dtype", &dtype);
}

bool SEqualReduce(const SparseBufferNode* other, SEqualReducer equal) const {
return equal(tree, other->tree) && equal(axes, other->axes) && equal(data, other->data) &&
equal(name, other->name) && equal(dtype, other->dtype);
return equal(axes, other->axes) && equal(data, other->data) &&
equal(name, other->name);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(tree);
hash_reduce(axes);
hash_reduce(data);
hash_reduce(name);
hash_reduce(dtype);
}

static constexpr const char* _type_key = "tir.sparse.SparseBuffer";
Expand All @@ -346,8 +353,7 @@ class SparseBufferNode : public Object {
*/
class SparseBuffer : public ObjectRef {
public:
TVM_DLL explicit SparseBuffer(AxisTree tree, Array<Axis> axes, Buffer data, String name,
DataType dtype);
TVM_DLL explicit SparseBuffer(Array<Axis> axes, Buffer data, String name);

TVM_DEFINE_OBJECT_REF_METHODS(SparseBuffer, ObjectRef, SparseBufferNode);
};
Expand Down Expand Up @@ -380,8 +386,8 @@ class SpIterVarNode : public Object {

bool SEqualReduce(const SpIterVarNode* other, SEqualReducer equal) const {
return equal(var, other->var) && equal(max_extent, other->max_extent) &&
equal(axis, other->axis) && equal(is_reduction, other->is_reduction) &&
equal(kind, other->kind);
equal(axis, other->axis) &&
equal(is_reduction, other->is_reduction) && equal(kind, other->kind);
}

void SHashReduce(SHashReducer hash_reduce) const {
Expand All @@ -400,8 +406,8 @@ class SpIterVarNode : public Object {

class SpIterVar : public ObjectRef {
public:
TVM_DLL explicit SpIterVar(String name, PrimExpr max_extent, SpIterKind kind, bool is_reduction,
Optional<Axis> axis = NullOpt);
TVM_DLL explicit SpIterVar(String name, PrimExpr max_extent, SpIterKind kind,
bool is_reduction, Optional<Axis> axis = NullOpt);

/*!
* \return the corresponding var in the IterVar.
Expand Down
22 changes: 22 additions & 0 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,28 @@ class BufferStore : public Stmt {
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferStoreNode);
};

/*!
* \brief Sparse Block node.
*/
class SparseBlockNode : public StmtNode {
public:
/*! \brief The sparse iteration variables of the block. */
Array<SpIterVar> sp_iter_vars;
/*! \brief The sparse buffers defined in the block. */
Array<SparseBuffer> sp_buffers;
Comment on lines +337 to +338
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aren't all the SparseBuffers defined under the root block of the PrimFunc?

/*! \brief The body of the block */
Stmt body;

static constexpr const char* _type_key = "tir.SparseBlock";
TVM_DECLARE_FINAL_OBJECT_INFO(SparseBlockNode, StmtNode);
};

class SparseBlock : public Stmt {
public:
TVM_DEFINE_OBJECT_REF_METHODS(SparseBlock, Stmt, SparseBlockNode);
};


/*!
* \brief Store value to the high dimension sparse buffer.
*
Expand Down
17 changes: 13 additions & 4 deletions python/tvm/script/context_maintainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@
import tvm
from tvm.ir import Span
from tvm.ir.expr import Range
from tvm.script.tir.sparse import MatchSparseBuffer
from tvm.tir import Var, Buffer, PrimExpr, Stmt, MatchBufferRegion
from tvm.runtime import Object
from tvm.tir.expr import IterVar
from tvm.tir.sparse import Axis, SparseBuffer
from .tir.node import BufferSlice


Expand Down Expand Up @@ -74,6 +76,10 @@ def example_func(a: T.handle, b: T.handle, c: T.handle) -> None:
"""List[Buffer]: list of T.alloc_buffer statements in the block signature"""
match_buffers: List[MatchBufferRegion] = []
"""List[MatchBufferRegion]: list of T.match_buffer statements in the block signature"""
axes: List[Axis] = []
"""List[Axis]: list of sparse axis created in the block signature."""
match_sparse_buffers: List[MatchSparseBuffer]
"""List[MatchSparseBuffer]: list of T.match_sparse_buffer statements in the block signature."""
iter_values: List[PrimExpr] = []
"""List[PrimExpr]: list of binding values for iter vars"""
iter_vars: List[IterVar] = []
Expand Down Expand Up @@ -119,14 +125,16 @@ class ContextMaintainer:
"""List[BlockInfo]: The block info for the current block scope"""
loop_stack: Dict[Var, Range] = {}
"""Dict[Var, Range]: The dict from loop var to its domain outside the block"""
symbols: List[Dict[str, Union[Var, Buffer]]] = []
symbols: List[Dict[str, Union[Var, Buffer, SparseBuffer, Axis]]] = []
"""List[Dict[str, Union[Var, Buffer]]]: Symbol map from name to object for the current scope"""

# function context
func_params: List[Var] = []
"""List[Var]: The function parameters"""
func_buffer_map: Mapping[Var, Buffer] = {}
"""Mapping[Var, Buffer]: The function buffer map"""
func_sparse_buffer_map: Mapping[Var, SparseBuffer] = {}
"""Mapping[Var, SparseBuffer]: The function sparse buffer map"""
func_dict_attr: Mapping[str, Object] = {}
"""Mapping[str, Object]: The function attrs"""
func_var_env_dict: Mapping[Var, str] = {}
Expand All @@ -147,6 +155,7 @@ def __init__(self, _report_error: Callable[[str, Union[Span, synr.ast.Span]], No
# function context
self.func_params = []
self.func_buffer_map = {}
self.func_sparse_buffer_map = {}
self.func_dict_attr = {}
self.func_var_env_dict = {}
# parser and analyzer
Expand Down Expand Up @@ -202,9 +211,9 @@ def exit_block_scope(self):
# Pop block_info
self.block_info_stack.pop()

def update_symbol(self, name: str, symbol: Union[Buffer, Var], node: synr.ast.Node):
def update_symbol(self, name: str, symbol: Union[Buffer, Var, SparseBuffer, Axis], node: synr.ast.Node):
"""Append a symbol into current scope"""
if isinstance(symbol, Buffer):
if isinstance(symbol, (Buffer, Var, SparseBuffer, Axis)):
if name in self.symbols[0]:
self.report_error("Duplicate Buffer name: " + symbol.name, node.span)
self.symbols[0][name] = symbol
Expand All @@ -219,7 +228,7 @@ def remove_symbol(self, name: str):
return
raise RuntimeError("Internal error of tvm script parser: no symbol named " + name)

def lookup_symbol(self, name: str) -> Optional[Union[Buffer, Var]]:
def lookup_symbol(self, name: str) -> Optional[Union[Buffer, Var, SparseBuffer, Axis]]:
"""Look up symbol by name"""
for symbols in reversed(self.symbols):
if name in symbols:
Expand Down
Loading