[PASS] Basic storage flatten #13

Merged (1 commit) on Jan 16, 2017
python/tvm/_ctypes/_api.py (2 changes: 1 addition & 1 deletion)
@@ -225,7 +225,7 @@ def func(*args):
         """TVM function"""
         cargs = []
         for x in args:
-            if isinstance(x, (list, tuple, SliceBase)):
+            if isinstance(x, (list, tuple, dict, SliceBase)):
                 cargs.append(convert(x))
             else:
                 cargs.append(x)
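
Adding dict to the convertible argument types routes plain Python dicts through convert() before they cross into a packed function, which is what the new StorageFlatten binding relies on. A minimal usage sketch, patterned on the test at the end of this PR (the `+ 1` compute and the trimmed schedule are illustrative, not taken verbatim from the PR):

    import tvm

    m = tvm.Var('m')
    l = tvm.Var('l')
    A = tvm.placeholder((m, l), name='A')
    A1 = tvm.compute((m, l), lambda i, j: A[i, j] + 1, name='A1')
    s = tvm.Schedule(A1.op)
    bounds = tvm.schedule.InferBound(s)
    stmt = tvm.ir_pass.ScheduleOps(s, bounds)
    Ab = tvm.Buffer(A.shape, A.dtype, name='A')
    A1b = tvm.Buffer(A1.shape, A1.dtype, name='A1')
    # A plain dict of Tensor -> Buffer bindings; it is accepted because of the
    # new isinstance(..., dict) branch above.
    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, A1: A1b})
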
python/tvm/function.py (3 changes: 2 additions & 1 deletion)
@@ -133,7 +133,8 @@ def compute(shape, fcompute, name="compute"):
 
 
 def Buffer(shape, dtype=None,
-           name="buffer", ptr=None,
+           name="buffer",
+           ptr=None,
            strides=None):
     """Create a new buffer
 
src/c_api/c_api_pass.cc (1 change: 1 addition & 0 deletions)
@@ -36,6 +36,7 @@ REGISTER_PASS1(ConvertSSA);
 REGISTER_PASS1(VerifySSA);
 REGISTER_PASS4(Inline);
 REGISTER_PASS2(ScheduleOps);
+REGISTER_PASS2(StorageFlatten);
 
 }  // namespace ir
 }  // namespace tvm
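
Following the convention of the surrounding macros, REGISTER_PASS2 registers a pass taking two arguments; for StorageFlatten these are the Stmt and the Tensor-to-Buffer binding map, which exposes it to Python as tvm.ir_pass.StorageFlatten (exercised in the new test below).
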
src/lang/buffer.cc (2 changes: 1 addition & 1 deletion)
@@ -51,7 +51,7 @@ Expr Buffer::MakeLoad(Array<Expr> index) const {
 Stmt Buffer::MakeStore(Array<Expr> index, Expr value) const {
   const BufferNode* n = operator->();
   CHECK_EQ(value.type(), n->dtype);
-  return ir::Store::make(n->ptr, BufferOffset(n, index), value);
+  return ir::Store::make(n->ptr, value, BufferOffset(n, index));
 }
 
 Buffer BufferNode::make(std::string name,
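
Judging from the corrected call, ir::Store::make expects the buffer variable, then the stored value, then the index; the old order passed the flattened offset as the value and the value as the index, so flattened stores were emitted back to front.
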
src/pass/ir_mutator.cc (2 changes: 1 addition & 1 deletion)
@@ -83,7 +83,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
         body.same_as(op->body)) {
       return s;
     } else {
-      return AttrStmt::make(op->node, op->type_key, op->value, op->body);
+      return AttrStmt::make(op->node, op->type_key, value, body);
     }
   });
 
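
With this fix the AttrStmt case rebuilds the node from the freshly mutated value and body (the locals compared via same_as just above) rather than from op->value and op->body, so mutations applied beneath an AttrStmt are no longer silently dropped.
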
src/pass/storage_flatten.cc (168 changes: 168 additions & 0 deletions)
@@ -0,0 +1,168 @@
/*!
 *  Copyright (c) 2016 by Contributors
 * \file storage_flatten.cc
 */
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <unordered_map>

namespace tvm {
namespace ir {

// key of function buffer
struct TensorKey {
  FunctionRef f;
  int value_index;

  inline bool operator==(const TensorKey& other) const {
    return f == other.f && value_index == other.value_index;
  }
  inline std::string GetName() const {
    if (f->num_outputs() == 1) return f->func_name();
    std::ostringstream os;
    os << f->func_name() << ".v" << value_index;
    return os.str();
  }
};

}  // namespace ir
}  // namespace tvm

namespace std {
template <>
struct hash<::tvm::ir::TensorKey> {
  std::size_t operator()(const ::tvm::ir::TensorKey& k) const {
    size_t lhs = k.f.hash();
    size_t rhs = static_cast<size_t>(k.value_index);
    lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
    return lhs;
  }
};
}  // namespace std

namespace tvm {
namespace ir {

using Halide::Internal::Region;

// Flatten tensor-level Realize/Provide/Call(Halide) nodes into
// Allocate/Store/Load on flat buffers.
class StorageFlattener : public IRMutator {
 public:
  explicit StorageFlattener(Map<Tensor, Buffer> extern_buffer) {
    for (auto kv : extern_buffer) {
      BufferEntry e;
      e.buffer = kv.second;
      e.external = true;
      buf_map_[TensorKey{kv.first->op, kv.first->value_index}] = e;
    }
  }
  Expr Mutate(Expr expr) final {
    expr = IRMutator::Mutate(expr);
    const Call* op = expr.as<Call>();
    if (op != nullptr && op->call_type == Call::Halide) {
      TensorKey key{op->func, op->value_index};
      auto it = buf_map_.find(key);
      CHECK(it != buf_map_.end())
          << "Cannot find allocated buffer for " << key.f;
      const BufferEntry& e = it->second;
      CHECK(!e.released)
          << "Read a buffer that is already out of scope";
      return e.buffer.MakeLoad(e.RelIndex(op->args));
    } else {
      return expr;
    }
  }

  Stmt Mutate(Stmt stmt) final {
    const Realize* realize = stmt.as<Realize>();
    if (realize != nullptr) {
      return HandleRealize(realize);
    } else if (stmt.as<Provide>()) {
      return HandleProvide(stmt);
    } else {
      return IRMutator::Mutate(stmt);
    }
  }

 private:
  // The buffer entry in the flatten map
  struct BufferEntry {
    // the buffer of storage
    Buffer buffer;
    // the bounds of realization, can be null
    Region bounds;
    // Whether the buffer is external
    bool external{false};
    // Whether we are out of allocation bounds and buffer get released.
    bool released{false};
    // TODO(tqchen) allow permutation and inference of index dimension.
    // relative index
    inline Array<Expr> RelIndex(Array<Expr> args) const {
      if (bounds.size() != 0) {
        Array<Expr> index;
        CHECK_EQ(bounds.size(), args.size());
        for (size_t i = 0; i < bounds.size(); ++i) {
          index.push_back(args[i] - bounds[i]->min);
        }
        return index;
      } else {
        return args;
      }
    }
  };

  // The buffer assignment map
  std::unordered_map<TensorKey, BufferEntry> buf_map_;

  Stmt HandleRealize(const Realize* op) {
    TensorKey key{op->func, op->value_index};
    if (buf_map_.count(key)) {
      CHECK(buf_map_.at(key).external);
      return this->Mutate(op->body);
    } else {
      // create a buffer entry
      // TODO(tqchen) allow permutation and inference of index dimension.
      BufferEntry e;
      e.bounds = op->bounds;
      Array<Expr> shape;
      for (auto r : e.bounds) {
        shape.push_back(r->extent);
      }
      e.buffer = Buffer(shape, op->type, key.GetName());

      buf_map_[key] = e;
      Stmt body = this->Mutate(op->body);
      buf_map_[key].released = true;

      return Allocate::make(
          e.buffer->ptr, e.buffer->dtype, e.buffer->shape,
          make_const(Bool(e.buffer->dtype.lanes()), true), body);
    }
  }

  Stmt HandleProvide(Stmt stmt) {
    stmt = IRMutator::Mutate(stmt);
    const Provide* op = stmt.as<Provide>();
    TensorKey key{op->func, op->value_index};
    auto it = buf_map_.find(key);
    CHECK(it != buf_map_.end())
        << "Cannot find allocated buffer for " << key.f;
    const BufferEntry& e = it->second;
    CHECK(!e.released)
        << "Write to a buffer that is already out of scope";
    return e.buffer.MakeStore(e.RelIndex(op->args), op->value);
  }
};


Stmt StorageFlatten(Stmt stmt,
                    Map<Tensor, Buffer> extern_buffer) {
  stmt = StorageFlattener(extern_buffer).Mutate(stmt);
  return stmt;
}

}  // namespace ir
}  // namespace tvm
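
Taken together: each Realize becomes an Allocate over a flat buffer whose shape comes from the realize extents, each Provide becomes a buffer store, and each Halide-call read becomes a buffer load, with indices first shifted by the realize bounds' min (RelIndex) and then linearized by the Buffer helpers. Tensors bound in the extern_buffer map keep their user-supplied buffers instead of receiving a fresh allocation.
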
tests/python/test_pass_storage_flatten.py (24 changes: 24 additions & 0 deletions)
@@ -0,0 +1,24 @@
import tvm

def test_flatten2():
    m = tvm.Var('m')
    l = tvm.Var('l')
    A = tvm.placeholder((m, l), name='A')
    A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
    A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')

    s = tvm.Schedule(A2.op)
    xo, xi = s[A2].split(A2.op.axis[0], 8)
    s[A1].compute_at(s[A2], xo)
    bounds = tvm.schedule.InferBound(s)
    assert isinstance(bounds, tvm.collections.Map)
    stmt = tvm.ir_pass.ScheduleOps(s, bounds)

    print(stmt)
    Ab = tvm.Buffer(A.shape, A.dtype, name='A')
    A2b = tvm.Buffer(A2.shape, A2.dtype, name='A2')
    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b})
    print(stmt)

if __name__ == "__main__":
    test_flatten2()