Skip to content

Commit

Permalink
Updates (#14)
Browse files Browse the repository at this point in the history
* Remove outstanding cython functions

* Add in operator overload

* Enable JSON to save version
  • Loading branch information
tqchen committed May 29, 2018
1 parent 20ac351 commit 5629330
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 35 deletions.
25 changes: 0 additions & 25 deletions nnvm/python/nnvm/cython/symbol.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@ include "./base.pyi"

cdef extern from "nnvm/c_api.h":
const char* NNGetLastError();
int NNSymbolCreateVariable(const char *name, SymbolHandle *out);
int NNSymbolCreateGroup(nn_uint num_symbols,
SymbolHandle *symbols,
SymbolHandle *out);
int NNSymbolListAtomicSymbolCreators(nn_uint *out_size,
AtomicSymbolCreator **out_array);
int NNSymbolCreateAtomicSymbol(AtomicSymbolCreator creator,
Expand All @@ -34,31 +30,10 @@ cdef extern from "nnvm/c_api.h":
const char ***arg_descriptions,
const char **return_type);
int NNSymbolFree(SymbolHandle symbol);
int NNSymbolPrint(SymbolHandle symbol, const char **out_str);
int NNSymbolCopy(SymbolHandle symbol, SymbolHandle *out);
int NNSymbolGetAttr(SymbolHandle symbol,
const char* key,
const char** out,
int *success);
int NNSymbolSetAttrs(SymbolHandle symbol,
nn_uint num_param,
const char** keys,
const char** values);
int NNSymbolListAttrs(SymbolHandle symbol,
int recursive_option,
nn_uint *out_size,
const char*** out);
int NNSymbolListArguments(SymbolHandle symbol,
nn_uint *out_size,
const char ***out_str_array);
int NNSymbolListOutputs(SymbolHandle symbol,
nn_uint *out_size,
const char ***out_str_array);
int NNSymbolGetInternals(SymbolHandle symbol,
SymbolHandle *out);
int NNSymbolGetOutput(SymbolHandle symbol,
nn_uint index,
SymbolHandle *out);
int NNSymbolCompose(SymbolHandle sym,
const char* name,
nn_uint num_args,
Expand Down
63 changes: 61 additions & 2 deletions nnvm/python/nnvm/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,71 @@ class Symbol(SymbolBase):

def __add__(self, other):
if isinstance(other, Symbol):
return _internal.__add__symbol__(self, other)
return _internal.__add_symbol__(self, other)
elif isinstance(other, _Number):
return _internal.__add__scalar__(self, scalar=other)
return _internal.__add_scalar__(self, scalar=other)
else:
raise TypeError("type %s not supported" % str(type(other)))

def __radd__(self, other):
return self.__add__(other)

def __sub__(self, other):
if isinstance(other, Symbol):
return _internal.__sub_symbol__(self, other)
if isinstance(other, Number):
return _internal.__sub_scalar__(self, scalar=other)
else:
raise TypeError('type %s not supported' % str(type(other)))

def __rsub__(self, other):
if isinstance(other, Number):
return _internal.__rsub_scalar__(self, scalar=other)
else:
raise TypeError('type %s not supported' % str(type(other)))

def __mul__(self, other):
if isinstance(other, Symbol):
return _internal.__mul_symbol__(self, other)
if isinstance(other, Number):
return _internal.__mul_scalar__(self, scalar=other)
else:
raise TypeError('type %s not supported' % str(type(other)))

def __rmul__(self, other):
return self.__mul__(other)

def __div__(self, other):
if isinstance(other, Symbol):
return _internal.__div_symbol__(self, other)
if isinstance(other, Number):
return _internal.__div_scalar__(self, scalar=other)
else:
raise TypeError('type %s not supported' % str(type(other)))

def __rdiv__(self, other):
if isinstance(other, Number):
return _internal.__rdiv_scalar__(self, scalar=other)
else:
raise TypeError('type %s not supported' % str(type(other)))

def __truediv__(self, other):
return self.__div__(other)

def __rtruediv__(self, other):
return self.__rdiv__(other)

def __pow__(self, other):
if isinstance(other, Symbol):
return _internal.__pow_symbol__(self, other)
if isinstance(other, Number):
return _internal.__pow_scalar__(self, scalar=other)
else:
raise TypeError('type %s not supported' % str(type(other)))

def __neg__(self):
return self.__mul__(-1.0)

def __copy__(self):
return self.__deepcopy__()

Expand Down
6 changes: 4 additions & 2 deletions nnvm/src/example/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@ using nnvm::NodeAttrs;

NNVM_REGISTER_OP(add)
.describe("add two data together")
.set_num_inputs(2)
.attr("inplace_pair", std::make_pair(0, 0));
.set_num_inputs(2);

NNVM_REGISTER_OP(__add_symbol__)
.describe("Alias of add")
.set_num_inputs(2);

NNVM_REGISTER_OP(exp)
.describe("take exponmential")
Expand Down
40 changes: 35 additions & 5 deletions nnvm/src/pass/saveload_json.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,32 @@ namespace pass {
// auxiliary node structure for serialization.
struct JSONNode {
// the node entry structure in serialized format
typedef std::pair<uint32_t, uint32_t> Entry;
struct Entry {
uint32_t node_id;
uint32_t index;
uint32_t version;
void Save(dmlc::JSONWriter *writer) const {
writer->BeginArray();
writer->WriteArrayItem(node_id);
writer->WriteArrayItem(index);
writer->WriteArrayItem(version);
writer->EndArray();
}
void Load(dmlc::JSONReader *reader) {
reader->BeginArray();
CHECK(reader->NextArrayItem()) << "invalid json format";
reader->Read(&node_id);
CHECK(reader->NextArrayItem()) << "invalid json format";
reader->Read(&index);
if (reader->NextArrayItem()) {
reader->Read(&version);
CHECK(!reader->NextArrayItem()) << "invalid json format";
} else {
version = 0;
}
}
};

// pointer to the graph node
NodePtr node;
// inputs
Expand Down Expand Up @@ -75,6 +100,10 @@ struct JSONNode {
if (op_type_str != "null") {
try {
node->op = Op::Get(op_type_str);
// rebuild attribute parser
if (node->op->attr_parser != nullptr) {
node->op->attr_parser(&(node->attrs));
}
} catch (const dmlc::Error &err) {
std::ostringstream os;
os << "Failed loading Op " << node->attrs.name
Expand Down Expand Up @@ -132,7 +161,7 @@ Graph LoadJSON(const Graph& src) {
n.node->inputs.reserve(n.inputs.size());
for (const JSONNode::Entry &e : n.inputs) {
n.node->inputs.emplace_back(
NodeEntry{jgraph.nodes[e.first].node, e.second});
NodeEntry{jgraph.nodes[e.node_id].node, e.index, e.version});
}
n.node->control_deps.reserve(n.control_deps.size());
for (uint32_t nid : n.control_deps) {
Expand All @@ -150,7 +179,7 @@ Graph LoadJSON(const Graph& src) {
ret.outputs.reserve(jgraph.heads.size());
for (const JSONNode::Entry &e : jgraph.heads) {
ret.outputs.emplace_back(
NodeEntry{jgraph.nodes[e.first].node, e.second});
NodeEntry{jgraph.nodes[e.node_id].node, e.index, e.version});
}
return ret;
}
Expand All @@ -170,7 +199,7 @@ Graph SaveJSON(const Graph& src) {
jnode.inputs.reserve(n->inputs.size());
for (const NodeEntry& e : n->inputs) {
jnode.inputs.emplace_back(
std::make_pair(node2index.at(e.node.get()), e.index));
JSONNode::Entry{node2index.at(e.node.get()), e.index, e.version});
}
for (const NodePtr& c : n->control_deps) {
jnode.control_deps.push_back(node2index.at(c.get()));
Expand All @@ -179,7 +208,8 @@ Graph SaveJSON(const Graph& src) {
});

for (const NodeEntry& e : src.outputs) {
jgraph.heads.push_back(std::make_pair(node2index.at(e.node.get()), e.index));
jgraph.heads.push_back(
JSONNode::Entry{node2index.at(e.node.get()), e.index, e.version});
}

std::ostringstream os;
Expand Down
2 changes: 1 addition & 1 deletion nnvm/tests/python/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_order_mutation_pass():
assert nindex['assign'] in jnodes[nindex['add2']]['control_deps']
assert nindex['conv'] in jnodes[nindex['assign']]['control_deps']
assert nindex['add1'] in jnodes[nindex['assign']]['control_deps']

assert jnodes[nindex['assign']]['inputs'][0][2] == 1

if __name__ == "__main__":
test_order_mutation_pass()
Expand Down

0 comments on commit 5629330

Please sign in to comment.