Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[api update] Use ObjectRef instead of Value, NDArrays instead of TensorValue #33

Merged
merged 1 commit into from
Jan 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions aot/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,13 @@ def convert(a, ctx):
elif isinstance(a, np.ndarray):
a = tvm.nd.array(a, ctx)
elif isinstance(a, tvm.ndarray.NDArray):
a = relay.backend.interpreter.TensorValue(a)
return a
elif isinstance(a, relay.Call):
assert isinstance(a.op, relay.Constructor)
a = (a.op, *a.args)
elif isinstance(a, tuple):
assert isinstance(a[0], relay.Constructor)
a = relay.backend.interpreter.ConstructorValue(a[0].tag, [convert(arg, ctx) for arg in a[1:]], a[0])
elif isinstance(a, relay.backend.interpreter.TensorValue):
return a
elif isinstance(a, relay.backend.interpreter.ConstructorValue):
return a
else:
Expand Down
62 changes: 30 additions & 32 deletions aot/to_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def visit_match(self, node):
for v in pattern_var_set:
bind_name = self.fresh_local_name()
self.name_map[v] = bind_name
stmt_str += f"Value {bind_name};\n"
stmt_str += f"ObjectRef {bind_name};\n"

# match data_name to pat, and fill the var accordingly.
# go to fail_label or ok_label base on failure/success.
Expand All @@ -143,7 +143,7 @@ def visit_pattern(pat, data_name, fail_label, ok_label):
for i, input_type in enumerate(pat.constructor.inputs):
bind_name = self.fresh_local_name()
bind_names.append(bind_name)
ok_case += f"Value {bind_name} = {data_name}->fields[{i}];\n"
ok_case += f"ObjectRef {bind_name} = {data_name}->fields[{i}];\n"
for bind_name, p in zip(bind_names, pat.patterns):
next_label = self.fresh_label_name()
ok_case += visit_pattern(p, bind_name, fail_label, next_label)
Expand All @@ -166,8 +166,8 @@ def visit_pattern(pat, data_name, fail_label, ok_label):

in_name = self.fresh_local_name()
out_name = self.fresh_local_name()
stmt_str += f"Value {in_name} = {vd.expr};\n"
stmt_str += f"Value {out_name};\n"
stmt_str += f"ObjectRef {in_name} = {vd.expr};\n"
stmt_str += f"ObjectRef {out_name};\n"
match_finish_label = self.fresh_label_name()
for c in node.clause:
vc = self.visit(c[1])
Expand Down Expand Up @@ -203,10 +203,10 @@ def visit_if(self, node):
vt = self.visit(node.true_branch)
vf = self.visit(node.false_branch)
ret_name = self.fresh_local_name()
stmt = f"Value {ret_name};"
stmt = f"ObjectRef {ret_name};"
stmt += f"""
{vc.stmt}
if (NDToBool(ValueToND({vc.expr}))) {{
if (NDToBool(ObjectRefToND({vc.expr}))) {{
{vt.stmt}
{ret_name} = {vt.expr};
}} else {{
Expand All @@ -220,7 +220,7 @@ def visit_constant(self, const):
if const not in self.declare_map:
name = self.fresh_global_name()
self.declare_map[const] = name
self.declare += f"Value {name};\n"
self.declare += f"ObjectRef {name};\n"
self.input_const.append((name, const.data.asnumpy()))
return ExprWithStmt(self.declare_map[const])

Expand All @@ -247,7 +247,7 @@ def visit_args(self, args):
def visit_invoke(self, invoke):
args_str, stmt_str = self.visit_args(invoke.args)
func = self.visit(invoke.call)
return ExprWithStmt(f"Apply({func.expr}, std::vector<Value>({{{args_str}}}))", stmt_str + func.stmt)
return ExprWithStmt(f"Apply({func.expr}, std::vector<ObjectRef>({{{args_str}}}))", stmt_str + func.stmt)

def visit_decl(self, decl):
source = ""
Expand All @@ -256,7 +256,7 @@ def visit_decl(self, decl):
self.name_map[var] = local_name
vv = self.visit(value, name=local_name)
source += vv.stmt
source += f"""Value {local_name} = {vv.expr};"""
source += f"""ObjectRef {local_name} = {vv.expr};"""
vb = self.visit(decl.body)
source += vb.stmt
return ExprWithStmt(vb.expr, source)
Expand Down Expand Up @@ -286,7 +286,7 @@ def visit_packed_call(self, call):
args_str = []
def convert_input(ty, arg):
if isinstance(ty, relay.ty.TensorType):
args_str.append(f"ValueToND({arg})")
args_str.append(f"{arg}")
else:
assert isinstance(ty, relay.ty.TupleType)
tuple_name = self.fresh_local_name()
Expand All @@ -302,8 +302,8 @@ def convert_output(ty):
if isinstance(ty, relay.ty.TensorType):
tensor_name = self.fresh_local_name()
nonlocal decl_str
decl_str += f"TensorValue {tensor_name} = TensorValueNode::make(NDArray::Empty({self.nd_shape(ty)}, {self.nd_dtype(ty)}, context));\n"
args_str.append(f"{tensor_name}->data")
decl_str += f"NDArray {tensor_name} = NDArray::Empty({self.nd_shape(ty)}, {self.nd_dtype(ty)}, context);\n"
args_str.append(f"{tensor_name}")
return tensor_name
else:
assert isinstance(ty, relay.ty.TupleType)
Expand All @@ -324,12 +324,12 @@ def visit_cpp_function(self, func, local, name):
for i, param in enumerate(func.params):
pname = self.fresh_local_name(param)
self.name_map[param] = pname
body += f"Value {pname} = {vec}.at({i});\n"
body += f"ObjectRef {pname} = {vec}.at({i});\n"

body += f"Value {name} = self;\n"
body += f"ObjectRef {name} = self;\n"
vb = self.visit(func.body)
body = body + vb.stmt + f"""return {vb.expr};"""
expr = f"""FunctionValueNode::make([=](const std::vector<Value>& {vec}, const Value& self) {{
expr = f"""FunctionValueNode::make([=](const std::vector<ObjectRef>& {vec}, const ObjectRef& self) {{
{body}
}});
"""
Expand All @@ -340,11 +340,11 @@ def visit_cpp_function(self, func, local, name):
if name is None:
name = self.fresh_global_name()
self.declare += f"""
static Value {name}_func() {{
static Value ret = {expr};
static ObjectRef {name}_func() {{
static ObjectRef ret = {expr};
return ret;
}}
Value {name} = {name}_func();
ObjectRef {name} = {name}_func();
"""
return ExprWithStmt(f"{name}")

Expand All @@ -369,8 +369,8 @@ def mk_register_api(self, name: str, func) -> str:
TVM_REGISTER_GLOBAL("{name}")
.set_body([](TVMArgs args, TVMRetValue* ret) {{
{init}
std::initializer_list<Value> ilist = {{{args}}};
*ret = Apply({vf.expr}, std::vector<Value>(ilist));
std::initializer_list<ObjectRef> ilist = {{{args}}};
*ret = Apply({vf.expr}, std::vector<ObjectRef>(ilist));
}});
"""
return source
Expand All @@ -387,7 +387,7 @@ def mk_file(body, ctx):
return f"""
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <tvm/node/env_func.h>
#include <tvm/ir/env_func.h>
#include <tvm/relay/interpreter.h>
#include <iostream>

Expand All @@ -411,13 +411,11 @@ def mk_file(body, ctx):
return reinterpret_cast<uint8_t*>(cpu_array->data)[0];
}}

static NDArray ValueToND(const Value& v) {{
const TensorValueNode* tv = v.as<TensorValueNode>();
CHECK(tv);
return tv->data;
static NDArray ObjectRefToND(const ObjectRef& v) {{
return Downcast<runtime::NDArray>(v);
}}

static ConstructorValue TagToCV(size_t tag, const tvm::Array<Value>& fields) {{
static ConstructorValue TagToCV(size_t tag, const tvm::Array<ObjectRef>& fields) {{
ObjectPtr<ConstructorValueNode> n = make_object<ConstructorValueNode>();
ObjectPtr<ConstructorNode> con = make_object<ConstructorNode>();
con->tag = tag;
Expand All @@ -430,8 +428,8 @@ def mk_file(body, ctx):
/*! \\brief A Function value. */
class FunctionValue;

using function_value_t = std::function<Value(const std::vector<Value>&, const Value&)>;
struct FunctionValueNode : ValueNode {{
using function_value_t = std::function<ObjectRef(const std::vector<ObjectRef>&, const ObjectRef&)>;
struct FunctionValueNode : Object {{
function_value_t f;

FunctionValueNode() {{ }}
Expand All @@ -441,12 +439,12 @@ class FunctionValue;
TVM_DLL static FunctionValue make(const function_value_t& f);

static constexpr const char* _type_key = "relay.FunctionValue";
TVM_DECLARE_FINAL_OBJECT_INFO(FunctionValueNode, ValueNode);
TVM_DECLARE_FINAL_OBJECT_INFO(FunctionValueNode, Object);
}};

class FunctionValue : public Value {{
class FunctionValue : public ObjectRef {{
public:
TVM_DEFINE_OBJECT_REF_METHODS(FunctionValue, Value, FunctionValueNode);
TVM_DEFINE_OBJECT_REF_METHODS(FunctionValue, ObjectRef, FunctionValueNode);
}};

FunctionValue FunctionValueNode::make(const function_value_t& f) {{
Expand All @@ -455,7 +453,7 @@ class FunctionValue : public Value {{
return FunctionValue(n);
}}

Value Apply(const Value& op, const std::vector<Value>& args) {{
ObjectRef Apply(const ObjectRef& op, const std::vector<ObjectRef>& args) {{
return Downcast<FunctionValue>(op)->f(args, op);
}}

Expand Down