From 67179f78692ccc071573093a0faae952bae4fc38 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Thu, 8 Sep 2016 16:28:47 -0700 Subject: [PATCH] Enable Alias, refactor C API to reflect Op Semantics (#41) * Enable Alias, refactor C API to reflect Op Semantics * add alias example --- nnvm/example/src/operator.cc | 1 + nnvm/include/dmlc/registry.h | 41 +++++++++--- nnvm/include/nnvm/c_api.h | 54 +++++++++++----- nnvm/include/nnvm/op.h | 7 +++ nnvm/python/nnvm/_base.py | 2 +- nnvm/python/nnvm/{ctypes => _ctypes}/README | 0 .../nnvm/{ctypes => _ctypes}/__init__.py | 0 .../python/nnvm/{ctypes => _ctypes}/symbol.py | 37 ++++++----- nnvm/python/nnvm/cython/base.pyi | 2 +- nnvm/python/nnvm/cython/symbol.pyd | 6 -- nnvm/python/nnvm/cython/symbol.pyx | 62 +++++++++++-------- nnvm/python/nnvm/symbol.py | 15 ++--- nnvm/src/c_api/c_api_symbolic.cc | 47 ++++++++++---- nnvm/src/core/op.cc | 5 ++ 14 files changed, 184 insertions(+), 95 deletions(-) rename nnvm/python/nnvm/{ctypes => _ctypes}/README (100%) rename nnvm/python/nnvm/{ctypes => _ctypes}/__init__.py (100%) rename nnvm/python/nnvm/{ctypes => _ctypes}/symbol.py (87%) delete mode 100644 nnvm/python/nnvm/cython/symbol.pyd diff --git a/nnvm/example/src/operator.cc b/nnvm/example/src/operator.cc index e67ad8e024f3..7ec898961c9b 100644 --- a/nnvm/example/src/operator.cc +++ b/nnvm/example/src/operator.cc @@ -127,6 +127,7 @@ NNVM_REGISTER_OP(identity) NNVM_REGISTER_OP(add) .describe("add two data together") .set_num_inputs(2) +.add_alias("__add_symbol__") .attr("FInferShape", SameShape) .attr("FInplaceOption", InplaceIn0Out0) .attr( diff --git a/nnvm/include/dmlc/registry.h b/nnvm/include/dmlc/registry.h index 380b31cd3d61..4f9947831c28 100644 --- a/nnvm/include/dmlc/registry.h +++ b/nnvm/include/dmlc/registry.h @@ -26,9 +26,19 @@ namespace dmlc { template class Registry { public: - /*! \return list of functions in the registry */ - inline static const std::vector &List() { - return Get()->entry_list_; + /*! \return list of entries in the registry(excluding alias) */ + inline static const std::vector& List() { + return Get()->const_list_; + } + /*! \return list all names registered in the registry, including alias */ + inline static std::vector ListAllNames() { + const std::map &fmap = Get()->fmap_; + typename std::map::const_iterator p; + std::vector names; + for (p = fmap.begin(); p !=fmap.end(); ++p) { + names.push_back(p->first); + } + return names; } /*! * \brief Find the entry with corresponding name. @@ -44,6 +54,21 @@ class Registry { return NULL; } } + /*! + * \brief Add alias to the key_name + * \param key_name The original entry key + * \param alias The alias key. + */ + inline void AddAlias(const std::string& key_name, + const std::string& alias) { + EntryType* e = fmap_.at(key_name); + if (fmap_.count(alias)) { + CHECK_EQ(e, fmap_.at(alias)) + << "Entry " << e->name << " already registered under different entry"; + } else { + fmap_[alias] = e; + } + } /*! * \brief Internal function to register a name function under name. * \param name name of the function @@ -55,6 +80,7 @@ class Registry { EntryType *e = new EntryType(); e->name = name; fmap_[name] = e; + const_list_.push_back(e); entry_list_.push_back(e); return *e; } @@ -79,16 +105,17 @@ class Registry { private: /*! \brief list of entry types */ - std::vector entry_list_; + std::vector entry_list_; + /*! \brief list of entry types */ + std::vector const_list_; /*! \brief map of name->function */ std::map fmap_; /*! \brief constructor */ Registry() {} /*! \brief destructor */ ~Registry() { - for (typename std::map::iterator p = fmap_.begin(); - p != fmap_.end(); ++p) { - delete p->second; + for (size_t i = 0; i < entry_list_.size(); ++i) { + delete entry_list_[i]; } } }; diff --git a/nnvm/include/nnvm/c_api.h b/nnvm/include/nnvm/c_api.h index f92470111594..3122e26b7038 100644 --- a/nnvm/include/nnvm/c_api.h +++ b/nnvm/include/nnvm/c_api.h @@ -29,7 +29,7 @@ typedef unsigned int nn_uint; /*! \brief handle to a function that takes param and creates symbol */ -typedef void *AtomicSymbolCreator; +typedef void *OpHandle; /*! \brief handle to a symbol that can be bind as operator */ typedef void *SymbolHandle; /*! \brief handle to Graph */ @@ -53,17 +53,39 @@ NNVM_DLL void NNAPISetLastError(const char* msg); NNVM_DLL const char *NNGetLastError(void); /*! - * \brief list all the available AtomicSymbolEntry + * \brief list all the available operator names, include entries. + * \param out_size the size of returned array + * \param out_array the output operator name array. + * \return 0 when success, -1 when failure happens + */ +NNVM_DLL int NNListAllOpNames(nn_uint *out_size, + const char*** out_array); + +/*! + * \brief Get operator handle given name. + * \param op_name The name of the operator. + * \param op_out The returnning op handle. + */ +NNVM_DLL int NNGetOpHandle(const char* op_name, + OpHandle* op_out); + +/*! + * \brief list all the available operators. + * This won't include the alias, use ListAllNames + * instead to get all alias names. + * * \param out_size the size of returned array * \param out_array the output AtomicSymbolCreator array * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolListAtomicSymbolCreators(nn_uint *out_size, - AtomicSymbolCreator **out_array); +NNVM_DLL int NNListUniqueOps(nn_uint *out_size, + OpHandle **out_array); + /*! * \brief Get the detailed information about atomic symbol. - * \param creator the AtomicSymbolCreator. - * \param name The returned name of the creator. + * \param op The operator handle. + * \param real_name The returned name of the creator. + * This name is not the alias name of the atomic symbol. * \param description The returned description of the symbol. * \param num_doc_args Number of arguments that contain documents. * \param arg_names Name of the arguments of doc args @@ -72,24 +94,24 @@ NNVM_DLL int NNSymbolListAtomicSymbolCreators(nn_uint *out_size, * \param return_type Return type of the function, if any. * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator, - const char **name, - const char **description, - nn_uint *num_doc_args, - const char ***arg_names, - const char ***arg_type_infos, - const char ***arg_descriptions, - const char **return_type); +NNVM_DLL int NNGetOpInfo(OpHandle op, + const char **real_name, + const char **description, + nn_uint *num_doc_args, + const char ***arg_names, + const char ***arg_type_infos, + const char ***arg_descriptions, + const char **return_type); /*! * \brief Create an AtomicSymbol functor. - * \param creator the AtomicSymbolCreator + * \param op The operator handle * \param num_param the number of parameters * \param keys the keys to the params * \param vals the vals of the params * \param out pointer to the created symbol handle * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolCreateAtomicSymbol(AtomicSymbolCreator creator, +NNVM_DLL int NNSymbolCreateAtomicSymbol(OpHandle op, nn_uint num_param, const char **keys, const char **vals, diff --git a/nnvm/include/nnvm/op.h b/nnvm/include/nnvm/op.h index e49bc9ae6643..e754d5988013 100644 --- a/nnvm/include/nnvm/op.h +++ b/nnvm/include/nnvm/op.h @@ -199,6 +199,13 @@ class Op { * \return reference to self. */ inline Op& set_attr_parser(std::function fn); // NOLINT(*) + /*! + * \brief Add another alias to this operator. + * The same Op can be queried with Op::Get(alias) + * \param alias The alias of the operator. + * \return reference to self. + */ + Op& add_alias(const std::string& alias); // NOLINT(*) /*! * \brief Register additional attributes to operator. * \param attr_name The name of the attribute. diff --git a/nnvm/python/nnvm/_base.py b/nnvm/python/nnvm/_base.py index 825a3d380f38..ccc4a3bdb501 100644 --- a/nnvm/python/nnvm/_base.py +++ b/nnvm/python/nnvm/_base.py @@ -45,7 +45,7 @@ def _load_lib(): # type definitions nn_uint = ctypes.c_uint -SymbolCreatorHandle = ctypes.c_void_p +OpHandle = ctypes.c_void_p SymbolHandle = ctypes.c_void_p GraphHandle = ctypes.c_void_p diff --git a/nnvm/python/nnvm/ctypes/README b/nnvm/python/nnvm/_ctypes/README similarity index 100% rename from nnvm/python/nnvm/ctypes/README rename to nnvm/python/nnvm/_ctypes/README diff --git a/nnvm/python/nnvm/ctypes/__init__.py b/nnvm/python/nnvm/_ctypes/__init__.py similarity index 100% rename from nnvm/python/nnvm/ctypes/__init__.py rename to nnvm/python/nnvm/_ctypes/__init__.py diff --git a/nnvm/python/nnvm/ctypes/symbol.py b/nnvm/python/nnvm/_ctypes/symbol.py similarity index 87% rename from nnvm/python/nnvm/ctypes/symbol.py rename to nnvm/python/nnvm/_ctypes/symbol.py index 3f5cb4e2abd7..0ae6d7a08495 100644 --- a/nnvm/python/nnvm/ctypes/symbol.py +++ b/nnvm/python/nnvm/_ctypes/symbol.py @@ -8,7 +8,7 @@ import sys from .._base import _LIB from .._base import c_array, c_str, nn_uint, py_str, string_types -from .._base import SymbolHandle +from .._base import SymbolHandle, OpHandle from .._base import check_call, ctypes2docstring from ..name import NameManager from ..attribute import AttrScope @@ -114,9 +114,9 @@ def _set_symbol_class(cls): _symbol_cls = cls -def _make_atomic_symbol_function(handle): +def _make_atomic_symbol_function(handle, name): """Create an atomic symbol function by handle and funciton name.""" - name = ctypes.c_char_p() + real_name = ctypes.c_char_p() desc = ctypes.c_char_p() num_args = nn_uint() arg_names = ctypes.POINTER(ctypes.c_char_p)() @@ -124,15 +124,15 @@ def _make_atomic_symbol_function(handle): arg_descs = ctypes.POINTER(ctypes.c_char_p)() ret_type = ctypes.c_char_p() - check_call(_LIB.NNSymbolGetAtomicSymbolInfo( - handle, ctypes.byref(name), ctypes.byref(desc), + check_call(_LIB.NNGetOpInfo( + handle, ctypes.byref(real_name), ctypes.byref(desc), ctypes.byref(num_args), ctypes.byref(arg_names), ctypes.byref(arg_types), ctypes.byref(arg_descs), ctypes.byref(ret_type))) param_str = ctypes2docstring(num_args, arg_names, arg_types, arg_descs) - func_name = py_str(name.value) + func_name = name desc = py_str(desc.value) doc_str = ('%s\n\n' + @@ -199,22 +199,25 @@ def creator(*args, **kwargs): return creator -def _init_symbol_module(): +def _init_symbol_module(symbol_class, root_namespace): """List and add all the atomic symbol functions to current module.""" - plist = ctypes.POINTER(ctypes.c_void_p)() + _set_symbol_class(symbol_class) + plist = ctypes.POINTER(ctypes.c_char_p)() size = ctypes.c_uint() - check_call(_LIB.NNSymbolListAtomicSymbolCreators(ctypes.byref(size), - ctypes.byref(plist))) - module_obj = sys.modules["nnvm.symbol"] - module_internal = sys.modules["nnvm._symbol_internal"] + check_call(_LIB.NNListAllOpNames(ctypes.byref(size), + ctypes.byref(plist))) + op_names = [] for i in range(size.value): - hdl = SymbolHandle(plist[i]) - function = _make_atomic_symbol_function(hdl) + op_names.append(py_str(plist[i])) + + module_obj = sys.modules["%s.symbol" % root_namespace] + module_internal = sys.modules["%s._symbol_internal" % root_namespace] + for name in op_names: + hdl = OpHandle() + check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl))) + function = _make_atomic_symbol_function(hdl, name) if function.__name__.startswith('_'): setattr(module_internal, function.__name__, function) else: setattr(module_obj, function.__name__, function) - -# Initialize the atomic symbol in startups -_init_symbol_module() diff --git a/nnvm/python/nnvm/cython/base.pyi b/nnvm/python/nnvm/cython/base.pyi index 1163f64b07fd..186d65a46b68 100644 --- a/nnvm/python/nnvm/cython/base.pyi +++ b/nnvm/python/nnvm/cython/base.pyi @@ -1,5 +1,5 @@ ctypedef void* SymbolHandle -ctypedef void* AtomicSymbolCreator +ctypedef void* OpHandle ctypedef unsigned nn_uint cdef py_str(const char* x): diff --git a/nnvm/python/nnvm/cython/symbol.pyd b/nnvm/python/nnvm/cython/symbol.pyd deleted file mode 100644 index a5df49d25f54..000000000000 --- a/nnvm/python/nnvm/cython/symbol.pyd +++ /dev/null @@ -1,6 +0,0 @@ -ctypedef void* SymbolHandle - - -cdef class Symbol: - # handle for symbolic operator. - cdef SymbolHandle handle diff --git a/nnvm/python/nnvm/cython/symbol.pyx b/nnvm/python/nnvm/cython/symbol.pyx index ee3e2d0bf258..fd7c3755786c 100644 --- a/nnvm/python/nnvm/cython/symbol.pyx +++ b/nnvm/python/nnvm/cython/symbol.pyx @@ -14,21 +14,25 @@ include "./base.pyi" cdef extern from "nnvm/c_api.h": const char* NNGetLastError(); - int NNSymbolListAtomicSymbolCreators(nn_uint *out_size, - AtomicSymbolCreator **out_array); - int NNSymbolCreateAtomicSymbol(AtomicSymbolCreator creator, + int NNListAllOpNames(nn_uint *out_size, + const char ***out_array); + int NNGetOpHandle(const char *op_name, + OpHandle *handle); + int NNGetOpInfo(OpHandle op, + const char **name, + const char **description, + nn_uint *num_doc_args, + const char ***arg_names, + const char ***arg_type_infos, + const char ***arg_descriptions, + const char **return_type); + int NNListOpNames(nn_uint *out_size, + const char ***out_array); + int NNSymbolCreateAtomicSymbol(OpHandle op, nn_uint num_param, const char **keys, const char **vals, SymbolHandle *out); - int NNSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator, - const char **name, - const char **description, - nn_uint *num_doc_args, - const char ***arg_names, - const char ***arg_type_infos, - const char ***arg_descriptions, - const char **return_type); int NNSymbolFree(SymbolHandle symbol); int NNSymbolSetAttrs(SymbolHandle symbol, nn_uint num_param, @@ -88,7 +92,7 @@ cdef SymbolSetAttr(SymbolHandle handle, dict kwargs): _symbol_cls = SymbolBase -def _set_symbol_class(cls): +cdef _set_symbol_class(cls): global _symbol_cls _symbol_cls = cls @@ -98,9 +102,9 @@ cdef NewSymbol(SymbolHandle handle): (sym).handle = handle return sym -cdef _make_atomic_symbol_function(AtomicSymbolCreator handle): +cdef _make_atomic_symbol_function(OpHandle handle, string name): """Create an atomic symbol function by handle and funciton name.""" - cdef const char *name + cdef const char *real_name cdef const char *desc cdef nn_uint num_args cdef const char** arg_names @@ -108,13 +112,14 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle): cdef const char** arg_descs cdef const char* return_type - CALL(NNSymbolGetAtomicSymbolInfo( - handle, &name, &desc, + CALL(NNGetOpInfo( + handle, &real_name, &desc, &num_args, &arg_names, &arg_types, &arg_descs, &return_type)) + param_str = BuildDoc(num_args, arg_names, arg_types, arg_descs) - func_name = py_str(name) + func_name = py_str(name.c_str()) doc_str = ('%s\n\n' + '%s\n' + 'name : string, optional.\n' + @@ -190,20 +195,23 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle): return creator -def _init_symbol_module(): +def _init_symbol_module(symbol_class, root_namespace): """List and add all the atomic symbol functions to current module.""" - cdef AtomicSymbolCreator* plist + cdef const char** op_name_ptrs cdef nn_uint size - CALL(NNSymbolListAtomicSymbolCreators(&size, &plist)) - module_obj = _sys.modules["nnvm.symbol"] - module_internal = _sys.modules["nnvm._symbol_internal"] - for i in range(size): - function = _make_atomic_symbol_function(plist[i]) + cdef vector[string] op_names + cdef OpHandle handle + _set_symbol_class(symbol_class) + CALL(NNListAllOpNames(&size, &op_name_ptrs)) + for i in range(size): + op_names.push_back(string(op_name_ptrs[i])); + module_obj = _sys.modules["%s.symbol" % root_namespace] + module_internal = _sys.modules["%s._symbol_internal" % root_namespace] + for i in range(op_names.size()): + CALL(NNGetOpHandle(op_names[i].c_str(), &handle)) + function = _make_atomic_symbol_function(handle, op_names[i]) if function.__name__.startswith('_'): setattr(module_internal, function.__name__, function) else: setattr(module_obj, function.__name__, function) - -# Initialize the atomic symbol in startups -_init_symbol_module() diff --git a/nnvm/python/nnvm/symbol.py b/nnvm/python/nnvm/symbol.py index e753299c7433..df5bad28266a 100644 --- a/nnvm/python/nnvm/symbol.py +++ b/nnvm/python/nnvm/symbol.py @@ -12,15 +12,16 @@ # Use different verison of SymbolBase # When possible, use cython to speedup part of computation. + try: - if int(_os.environ.get("NNVM_ENABLE_CYTHON", True)) == 0: - from .ctypes.symbol import SymbolBase, _set_symbol_class + if int(_os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0: + from ._ctypes.symbol import SymbolBase, _init_symbol_module elif _sys.version_info >= (3, 0): - from ._cy3.symbol import SymbolBase, _set_symbol_class + from ._cy3.symbol import SymbolBase, _init_symbol_module else: - from ._cy2.symbol import SymbolBase, _set_symbol_class -except: - from .ctypes.symbol import SymbolBase, _set_symbol_class + from ._cy2.symbol import SymbolBase, _init_symbol_module +except ImportError: + from ._ctypes.symbol import SymbolBase, _init_symbol_module class Symbol(SymbolBase): @@ -286,4 +287,4 @@ def Group(symbols): return Symbol(handle) # Set the real symbol class to Symbol -_set_symbol_class(Symbol) +_init_symbol_module(Symbol, "nnvm") diff --git a/nnvm/src/c_api/c_api_symbolic.cc b/nnvm/src/c_api/c_api_symbolic.cc index dcdce820c7e6..ebcc252ed455 100644 --- a/nnvm/src/c_api/c_api_symbolic.cc +++ b/nnvm/src/c_api/c_api_symbolic.cc @@ -10,24 +10,45 @@ using namespace nnvm; -int NNSymbolListAtomicSymbolCreators(nn_uint *out_size, - AtomicSymbolCreator **out_array) { +int NNListAllOpNames(nn_uint *out_size, + const char*** out_array) { + API_BEGIN(); + NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); + ret->ret_vec_str = dmlc::Registry::ListAllNames(); + ret->ret_vec_charp.clear(); + for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) { + ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str()); + } + *out_array = dmlc::BeginPtr(ret->ret_vec_charp); + *out_size = static_cast(ret->ret_vec_str.size()); + API_END(); +} + +int NNGetOpHandle(const char* op_name, + OpHandle* op_out) { + API_BEGIN(); + *op_out = (OpHandle)Op::Get(op_name); // NOLINT(*) + API_END(); +} + +int NNListUniqueOps(nn_uint *out_size, + OpHandle **out_array) { API_BEGIN(); auto &vec = dmlc::Registry::List(); *out_size = static_cast(vec.size()); - *out_array = (AtomicSymbolCreator*)(dmlc::BeginPtr(vec)); // NOLINT(*) + *out_array = (OpHandle*)(dmlc::BeginPtr(vec)); // NOLINT(*) API_END(); } -int NNSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator, - const char **name, - const char **description, - nn_uint *num_doc_args, - const char ***arg_names, - const char ***arg_type_infos, - const char ***arg_descriptions, - const char **return_type) { - const Op *op = static_cast(creator); +int NNGetOpInfo(OpHandle handle, + const char **name, + const char **description, + nn_uint *num_doc_args, + const char ***arg_names, + const char ***arg_type_infos, + const char ***arg_descriptions, + const char **return_type) { + const Op *op = static_cast(handle); NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); API_BEGIN(); @@ -51,7 +72,7 @@ int NNSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator, API_END(); } -int NNSymbolCreateAtomicSymbol(AtomicSymbolCreator creator, +int NNSymbolCreateAtomicSymbol(OpHandle creator, nn_uint num_param, const char **keys, const char **vals, diff --git a/nnvm/src/core/op.cc b/nnvm/src/core/op.cc index a5518ace6be1..a8e54ab92cab 100644 --- a/nnvm/src/core/op.cc +++ b/nnvm/src/core/op.cc @@ -38,6 +38,11 @@ Op::Op() { index_ = mgr->op_counter++; } +Op& Op::add_alias(const std::string& alias) { // NOLINT(*) + dmlc::Registry::Get()->AddAlias(this->name, alias); + return *this; +} + // find operator by name const Op* Op::Get(const std::string& name) { const Op* op = dmlc::Registry::Find(name);