diff --git a/nnvm/python/nnvm/cython/symbol.pyx b/nnvm/python/nnvm/cython/symbol.pyx index 7b1435381e99..ee3e2d0bf258 100644 --- a/nnvm/python/nnvm/cython/symbol.pyx +++ b/nnvm/python/nnvm/cython/symbol.pyx @@ -14,10 +14,6 @@ include "./base.pyi" cdef extern from "nnvm/c_api.h": const char* NNGetLastError(); - int NNSymbolCreateVariable(const char *name, SymbolHandle *out); - int NNSymbolCreateGroup(nn_uint num_symbols, - SymbolHandle *symbols, - SymbolHandle *out); int NNSymbolListAtomicSymbolCreators(nn_uint *out_size, AtomicSymbolCreator **out_array); int NNSymbolCreateAtomicSymbol(AtomicSymbolCreator creator, @@ -34,31 +30,10 @@ cdef extern from "nnvm/c_api.h": const char ***arg_descriptions, const char **return_type); int NNSymbolFree(SymbolHandle symbol); - int NNSymbolPrint(SymbolHandle symbol, const char **out_str); - int NNSymbolCopy(SymbolHandle symbol, SymbolHandle *out); - int NNSymbolGetAttr(SymbolHandle symbol, - const char* key, - const char** out, - int *success); int NNSymbolSetAttrs(SymbolHandle symbol, nn_uint num_param, const char** keys, const char** values); - int NNSymbolListAttrs(SymbolHandle symbol, - int recursive_option, - nn_uint *out_size, - const char*** out); - int NNSymbolListArguments(SymbolHandle symbol, - nn_uint *out_size, - const char ***out_str_array); - int NNSymbolListOutputs(SymbolHandle symbol, - nn_uint *out_size, - const char ***out_str_array); - int NNSymbolGetInternals(SymbolHandle symbol, - SymbolHandle *out); - int NNSymbolGetOutput(SymbolHandle symbol, - nn_uint index, - SymbolHandle *out); int NNSymbolCompose(SymbolHandle sym, const char* name, nn_uint num_args, diff --git a/nnvm/python/nnvm/symbol.py b/nnvm/python/nnvm/symbol.py index 7eab9f6045f7..31f1660fc899 100644 --- a/nnvm/python/nnvm/symbol.py +++ b/nnvm/python/nnvm/symbol.py @@ -30,12 +30,71 @@ class Symbol(SymbolBase): def __add__(self, other): if isinstance(other, Symbol): - return _internal.__add__symbol__(self, other) + return _internal.__add_symbol__(self, other) elif isinstance(other, _Number): - return _internal.__add__scalar__(self, scalar=other) + return _internal.__add_scalar__(self, scalar=other) else: raise TypeError("type %s not supported" % str(type(other))) + def __radd__(self, other): + return self.__add__(other) + + def __sub__(self, other): + if isinstance(other, Symbol): + return _internal.__sub_symbol__(self, other) + if isinstance(other, Number): + return _internal.__sub_scalar__(self, scalar=other) + else: + raise TypeError('type %s not supported' % str(type(other))) + + def __rsub__(self, other): + if isinstance(other, Number): + return _internal.__rsub_scalar__(self, scalar=other) + else: + raise TypeError('type %s not supported' % str(type(other))) + + def __mul__(self, other): + if isinstance(other, Symbol): + return _internal.__mul_symbol__(self, other) + if isinstance(other, Number): + return _internal.__mul_scalar__(self, scalar=other) + else: + raise TypeError('type %s not supported' % str(type(other))) + + def __rmul__(self, other): + return self.__mul__(other) + + def __div__(self, other): + if isinstance(other, Symbol): + return _internal.__div_symbol__(self, other) + if isinstance(other, Number): + return _internal.__div_scalar__(self, scalar=other) + else: + raise TypeError('type %s not supported' % str(type(other))) + + def __rdiv__(self, other): + if isinstance(other, Number): + return _internal.__rdiv_scalar__(self, scalar=other) + else: + raise TypeError('type %s not supported' % str(type(other))) + + def __truediv__(self, other): + return self.__div__(other) + + def __rtruediv__(self, other): + return self.__rdiv__(other) + + def __pow__(self, other): + if isinstance(other, Symbol): + return _internal.__pow_symbol__(self, other) + if isinstance(other, Number): + return _internal.__pow_scalar__(self, scalar=other) + else: + raise TypeError('type %s not supported' % str(type(other))) + + def __neg__(self): + return self.__mul__(-1.0) + def __copy__(self): return self.__deepcopy__() diff --git a/nnvm/src/example/operator.cc b/nnvm/src/example/operator.cc index 2bd4a22ed2dd..9078c314b119 100644 --- a/nnvm/src/example/operator.cc +++ b/nnvm/src/example/operator.cc @@ -11,9 +11,11 @@ using nnvm::NodeAttrs; NNVM_REGISTER_OP(add) .describe("add two data together") -.set_num_inputs(2) -.attr("inplace_pair", std::make_pair(0, 0)); +.set_num_inputs(2); +NNVM_REGISTER_OP(__add_symbol__) +.describe("Alias of add") +.set_num_inputs(2); NNVM_REGISTER_OP(exp) .describe("take exponmential") diff --git a/nnvm/src/pass/saveload_json.cc b/nnvm/src/pass/saveload_json.cc index b454d4f7a736..6ba7ac23f50f 100644 --- a/nnvm/src/pass/saveload_json.cc +++ b/nnvm/src/pass/saveload_json.cc @@ -30,7 +30,32 @@ namespace pass { // auxiliary node structure for serialization. struct JSONNode { // the node entry structure in serialized format - typedef std::pair Entry; + struct Entry { + uint32_t node_id; + uint32_t index; + uint32_t version; + void Save(dmlc::JSONWriter *writer) const { + writer->BeginArray(); + writer->WriteArrayItem(node_id); + writer->WriteArrayItem(index); + writer->WriteArrayItem(version); + writer->EndArray(); + } + void Load(dmlc::JSONReader *reader) { + reader->BeginArray(); + CHECK(reader->NextArrayItem()) << "invalid json format"; + reader->Read(&node_id); + CHECK(reader->NextArrayItem()) << "invalid json format"; + reader->Read(&index); + if (reader->NextArrayItem()) { + reader->Read(&version); + CHECK(!reader->NextArrayItem()) << "invalid json format"; + } else { + version = 0; + } + } + }; + // pointer to the graph node NodePtr node; // inputs @@ -75,6 +100,10 @@ struct JSONNode { if (op_type_str != "null") { try { node->op = Op::Get(op_type_str); + // rebuild attribute parser + if (node->op->attr_parser != nullptr) { + node->op->attr_parser(&(node->attrs)); + } } catch (const dmlc::Error &err) { std::ostringstream os; os << "Failed loading Op " << node->attrs.name @@ -132,7 +161,7 @@ Graph LoadJSON(const Graph& src) { n.node->inputs.reserve(n.inputs.size()); for (const JSONNode::Entry &e : n.inputs) { n.node->inputs.emplace_back( - NodeEntry{jgraph.nodes[e.first].node, e.second}); + NodeEntry{jgraph.nodes[e.node_id].node, e.index, e.version}); } n.node->control_deps.reserve(n.control_deps.size()); for (uint32_t nid : n.control_deps) { @@ -150,7 +179,7 @@ Graph LoadJSON(const Graph& src) { ret.outputs.reserve(jgraph.heads.size()); for (const JSONNode::Entry &e : jgraph.heads) { ret.outputs.emplace_back( - NodeEntry{jgraph.nodes[e.first].node, e.second}); + NodeEntry{jgraph.nodes[e.node_id].node, e.index, e.version}); } return ret; } @@ -170,7 +199,7 @@ Graph SaveJSON(const Graph& src) { jnode.inputs.reserve(n->inputs.size()); for (const NodeEntry& e : n->inputs) { jnode.inputs.emplace_back( - std::make_pair(node2index.at(e.node.get()), e.index)); + JSONNode::Entry{node2index.at(e.node.get()), e.index, e.version}); } for (const NodePtr& c : n->control_deps) { jnode.control_deps.push_back(node2index.at(c.get())); @@ -179,7 +208,8 @@ Graph SaveJSON(const Graph& src) { }); for (const NodeEntry& e : src.outputs) { - jgraph.heads.push_back(std::make_pair(node2index.at(e.node.get()), e.index)); + jgraph.heads.push_back( + JSONNode::Entry{node2index.at(e.node.get()), e.index, e.version}); } std::ostringstream os; diff --git a/nnvm/tests/python/test_graph.py b/nnvm/tests/python/test_graph.py index b6082364458e..0af641932a5d 100644 --- a/nnvm/tests/python/test_graph.py +++ b/nnvm/tests/python/test_graph.py @@ -33,7 +33,7 @@ def test_order_mutation_pass(): assert nindex['assign'] in jnodes[nindex['add2']]['control_deps'] assert nindex['conv'] in jnodes[nindex['assign']]['control_deps'] assert nindex['add1'] in jnodes[nindex['assign']]['control_deps'] - + assert jnodes[nindex['assign']]['inputs'][0][2] == 1 if __name__ == "__main__": test_order_mutation_pass()