diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index fd4e2114b11a..87606f3f738c 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -60,7 +60,7 @@ namespace runtime { // forward declarations class TVMArgs; class TVMArgValue; -class TVMMovableArgValue_; +class TVMMovableArgValueWithContext_; class TVMRetValue; class TVMArgsSetter; @@ -215,7 +215,7 @@ class TypedPackedFunc { * \brief constructor from TVMMovableArgValue_ * \param value The TVMMovableArgValue_ */ - inline TypedPackedFunc(TVMMovableArgValue_&& value); // NOLINT(*) + inline TypedPackedFunc(TVMMovableArgValueWithContext_&& value); // NOLINT(*) /*! * \brief construct from a lambda function with the same signature. * @@ -223,6 +223,30 @@ class TypedPackedFunc { * \code * auto typed_lambda = [](int x)->int { return x + 1; } * // construct from packed function + * TypedPackedFunc ftyped(typed_lambda, "add_one"); + * // call the typed version. + * ICHECK_EQ(ftyped(1), 2); + * \endcode + * + * \param typed_lambda typed lambda function. + * \param name the name of the lambda function. + * \tparam FLambda the type of the lambda function. + */ + template >::value>::type> + TypedPackedFunc(const FLambda& typed_lambda, std::string name) { // NOLINT(*) + this->AssignTypedLambda(typed_lambda, name); + } + /*! + * \brief construct from a lambda function with the same signature. + * + * This version does not take a name. It is highly recommend you use the + * version that takes a name for the lambda. + * + * Example usage: + * \code + * auto typed_lambda = [](int x)->int { return x + 1; } + * // construct from packed function * TypedPackedFunc ftyped(typed_lambda); * // call the typed version. * ICHECK_EQ(ftyped(1), 2); @@ -231,9 +255,8 @@ class TypedPackedFunc { * \param typed_lambda typed lambda function. * \tparam FLambda the type of the lambda function. */ - template >::value>::type> + template >::value>::type> TypedPackedFunc(const FLambda& typed_lambda) { // NOLINT(*) this->AssignTypedLambda(typed_lambda); } @@ -297,6 +320,17 @@ class TypedPackedFunc { * \brief Assign the packed field using a typed lambda function. * * \param flambda The lambda function. + * \param name The name associated with this lambda. + * \tparam FLambda The lambda function type. + * \note We capture the lambda when possible for maximum efficiency. + */ + template + inline void AssignTypedLambda(FLambda flambda, std::string name); + /*! + * \brief Assign the packed field using a typed lambda function. This variant is for functions + * without names. + * + * \param flambda The lambda function. * \tparam FLambda The lambda function type. * \note We capture the lambda when possible for maximum efficiency. */ @@ -337,7 +371,7 @@ inline const char* ArgTypeCode2Str(int type_code); // macro to check type code. #define TVM_CHECK_TYPE_CODE(CODE, T) \ - ICHECK_EQ(CODE, T) << " expected " << ArgTypeCode2Str(T) << " but get " << ArgTypeCode2Str(CODE) + ICHECK_EQ(CODE, T) << "expected " << ArgTypeCode2Str(T) << " but got " << ArgTypeCode2Str(CODE) /*! * \brief Type traits for runtime type check during FFI conversion. @@ -401,8 +435,8 @@ class TVMPODValue_ { return static_cast(value_.v_handle); } else { if (type_code_ == kTVMNullptr) return nullptr; - LOG(FATAL) << "Expect " - << "DLTensor* or NDArray but get " << ArgTypeCode2Str(type_code_); + LOG(FATAL) << "Expected " + << "DLTensor* or NDArray but got " << ArgTypeCode2Str(type_code_); return nullptr; } } @@ -442,6 +476,7 @@ class TVMPODValue_ { protected: friend class TVMArgsSetter; friend class TVMRetValue; + friend class TVMMovableArgValue_; TVMPODValue_() : type_code_(kTVMNullptr) {} TVMPODValue_(TVMValue value, int type_code) : value_(value), type_code_(type_code) {} @@ -562,6 +597,44 @@ class TVMMovableArgValue_ : public TVMPODValue_ { TVMArgValue AsArgValue() const { return TVMArgValue(value_, type_code_); } }; +/*! + * \brief Internal auxiliary struct for TypedPackedFunc to indicate a movable argument with + * additional context information (function name and argument index) for better error reporting. + * + * \sa MovableArgValue_ + * \note For internal development purpose only. + */ +class TVMMovableArgValueWithContext_ { + public: + /*! + * \brief move constructor from another return value. + * \param value The other return value. + * \param type_code The code associated with the type of the value. + * \param arg_index In a function call, this argument is at index arg_index (0-indexed). + * \param optional_name Name of the function being called. Can be nullptr if the function is not + * named. + */ + TVMMovableArgValueWithContext_(TVMValue value, int type_code, int arg_index, + const std::string* optional_name) + : value_(value, type_code), arg_index_(arg_index), optional_name_(optional_name) {} + + template + operator T() const { + try { + return value_; // implicit conversion happens here + } catch (dmlc::Error& e) { + LOG(FATAL) << "In function " << (optional_name_ == nullptr ? "" : *optional_name_) + << ": error while converting argument " << arg_index_ << ": " << e.what(); + throw; // never reached, LOG(FATAL) throws, but this silences a warning. + } + } + + private: + TVMMovableArgValue_ value_; + int arg_index_; + const std::string* optional_name_; +}; + /*! * \brief Return Value container, * Unlike TVMArgValue, which only holds reference and do not delete @@ -1213,20 +1286,23 @@ namespace detail { template struct unpack_call_dispatcher { template - TVM_ALWAYS_INLINE static void run(const F& f, const TVMArgs& args_pack, TVMRetValue* rv, + TVM_ALWAYS_INLINE static void run(const std::string* optional_name, const F& f, + const TVMArgs& args_pack, TVMRetValue* rv, Args&&... unpacked_args) { // construct a movable argument value // which allows potential move of argument to the input of F. unpack_call_dispatcher::run( - f, args_pack, rv, std::forward(unpacked_args)..., - TVMMovableArgValue_(args_pack.values[index], args_pack.type_codes[index])); + optional_name, f, args_pack, rv, std::forward(unpacked_args)..., + TVMMovableArgValueWithContext_(args_pack.values[index], args_pack.type_codes[index], index, + optional_name)); } }; template struct unpack_call_dispatcher { template - TVM_ALWAYS_INLINE static void run(const F& f, const TVMArgs& args_pack, TVMRetValue* rv, + TVM_ALWAYS_INLINE static void run(const std::string* optional_name, const F& f, + const TVMArgs& args_pack, TVMRetValue* rv, Args&&... unpacked_args) { using RetType = decltype(f(std::forward(unpacked_args)...)); if (std::is_same::value) { @@ -1240,16 +1316,21 @@ struct unpack_call_dispatcher { template struct unpack_call_dispatcher { template - TVM_ALWAYS_INLINE static void run(const F& f, const TVMArgs& args_pack, TVMRetValue* rv, + TVM_ALWAYS_INLINE static void run(const std::string* optional_name, const F& f, + const TVMArgs& args_pack, TVMRetValue* rv, Args&&... unpacked_args) { f(std::forward(unpacked_args)...); } }; template -TVM_ALWAYS_INLINE void unpack_call(const F& f, const TVMArgs& args, TVMRetValue* rv) { - ICHECK_EQ(nargs, args.size()) << "Expect " << nargs << " arguments but get " << args.size(); - unpack_call_dispatcher::run(f, args, rv); +TVM_ALWAYS_INLINE void unpack_call(const std::string* optional_name, const F& f, + const TVMArgs& args, TVMRetValue* rv) { + CHECK_EQ(nargs, args.size()) << "Function " + << (optional_name == nullptr ? "" : *optional_name) + << " expects " << nargs << " arguments but " << args.size() + << " were provided"; + unpack_call_dispatcher::run(optional_name, f, args, rv); } template @@ -1259,7 +1340,7 @@ template struct unpack_call_by_signature { template TVM_ALWAYS_INLINE static void run(const F& f, const TVMArgs& args, TVMRetValue* rv) { - unpack_call(f, args, rv); + unpack_call(nullptr, f, args, rv); } }; @@ -1297,14 +1378,30 @@ TypedPackedFunc::TypedPackedFunc(const TVMArgValue& value) : packed_(value.operator PackedFunc()) {} template -TypedPackedFunc::TypedPackedFunc(TVMMovableArgValue_&& value) +TypedPackedFunc::TypedPackedFunc(TVMMovableArgValueWithContext_&& value) : packed_(value.operator PackedFunc()) {} +template +template +inline void TypedPackedFunc::AssignTypedLambda(FType flambda, std::string name) { + packed_ = PackedFunc([flambda, name](const TVMArgs& args, TVMRetValue* rv) { + if (args.size() != sizeof...(Args)) { + LOG(FATAL) << "Function " << name << " expects " << sizeof...(Args) << " arguments, but " + << args.size() << " were provided."; + } + detail::unpack_call(&name, flambda, args, rv); + }); +} + template template inline void TypedPackedFunc::AssignTypedLambda(FType flambda) { packed_ = PackedFunc([flambda](const TVMArgs& args, TVMRetValue* rv) { - detail::unpack_call(flambda, args, rv); + if (args.size() != sizeof...(Args)) { + LOG(FATAL) << "Function expects " << sizeof...(Args) << " arguments, but " + << args.size() << " were provided."; + } + detail::unpack_call(nullptr, flambda, args, rv); }); } @@ -1377,7 +1474,7 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { using ContainerType = typename TObjectRef::ContainerType; if (type_code_ == kTVMNullptr) { - ICHECK(TObjectRef::_type_is_nullable) + CHECK(TObjectRef::_type_is_nullable) << "Expect a not null value of " << ContainerType::_type_key; return TObjectRef(ObjectPtr(nullptr)); } @@ -1387,29 +1484,29 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { TVM_CHECK_TYPE_CODE(type_code_, kTVMNDArrayHandle); ObjectPtr data = NDArray::FFIDataFromHandle(static_cast(value_.v_handle)); - ICHECK(data->IsInstance()) - << "Expect " << ContainerType::_type_key << " but get " << data->GetTypeKey(); + CHECK(data->IsInstance()) + << "Expected " << ContainerType::_type_key << " but got " << data->GetTypeKey(); return TObjectRef(data); } if (std::is_base_of::value) { // Casting to a sub-class of Module TVM_CHECK_TYPE_CODE(type_code_, kTVMModuleHandle); ObjectPtr data = GetObjectPtr(static_cast(value_.v_handle)); - ICHECK(data->IsInstance()) - << "Expect " << ContainerType::_type_key << " but get " << data->GetTypeKey(); + CHECK(data->IsInstance()) + << "Expected " << ContainerType::_type_key << " but got " << data->GetTypeKey(); return TObjectRef(data); } if (type_code_ == kTVMObjectHandle) { // normal object type check. Object* ptr = static_cast(value_.v_handle); - ICHECK(ObjectTypeChecker::Check(ptr)) - << "Expect " << ObjectTypeChecker::TypeName() << " but get " + CHECK(ObjectTypeChecker::Check(ptr)) + << "Expected " << ObjectTypeChecker::TypeName() << " but got " << ptr->GetTypeKey(); return TObjectRef(GetObjectPtr(ptr)); } else if (type_code_ == kTVMObjectRValueRefArg) { Object* ptr = *static_cast(value_.v_handle); - ICHECK(ObjectTypeChecker::Check(ptr)) - << "Expect " << ObjectTypeChecker::TypeName() << " but get " + CHECK(ObjectTypeChecker::Check(ptr)) + << "Expected " << ObjectTypeChecker::TypeName() << " but got " << ptr->GetTypeKey(); return TObjectRef(GetObjectPtr(ptr)); } else if (std::is_base_of::value && diff --git a/include/tvm/runtime/registry.h b/include/tvm/runtime/registry.h index 86e3706b2058..859a8ace1abe 100644 --- a/include/tvm/runtime/registry.h +++ b/include/tvm/runtime/registry.h @@ -93,7 +93,7 @@ class Registry { template Registry& set_body_typed(FLambda f) { using FType = typename detail::function_signature::FType; - return set_body(TypedPackedFunc(std::move(f)).packed()); + return set_body(TypedPackedFunc(std::move(f), name_).packed()); } /*! * \brief set the body of the function to be the passed method pointer. @@ -122,7 +122,7 @@ class Registry { // call method pointer return (target.*f)(params...); }; - return set_body(TypedPackedFunc(fwrap)); + return set_body(TypedPackedFunc(fwrap, name_)); } /*! @@ -152,7 +152,7 @@ class Registry { // call method pointer return (target.*f)(params...); }; - return set_body(TypedPackedFunc(fwrap)); + return set_body(TypedPackedFunc(fwrap, name_)); } /*! @@ -194,7 +194,7 @@ class Registry { // call method pointer return (target->*f)(params...); }; - return set_body(TypedPackedFunc(fwrap)); + return set_body(TypedPackedFunc(fwrap, name_)); } /*! @@ -236,7 +236,7 @@ class Registry { // call method pointer return (target->*f)(params...); }; - return set_body(TypedPackedFunc(fwrap)); + return set_body(TypedPackedFunc(fwrap, name_)); } /*! diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index ce622429bdb9..8ace82be9ff8 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -718,10 +718,7 @@ Expr MakeInstanceNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon return Call(op, {data, gamma, beta}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn._make.instance_norm") - .set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeInstanceNorm, args, rv); - }); +TVM_REGISTER_GLOBAL("relay.op.nn._make.instance_norm").set_body_typed(MakeInstanceNorm); RELAY_REGISTER_OP("nn.instance_norm") .describe(R"code(Instance Normalization (Ulyanov and et al., 2016) @@ -785,10 +782,7 @@ Expr MakeLayerNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon, b return Call(op, {data, gamma, beta}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn._make.layer_norm") - .set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeLayerNorm, args, rv); - }); +TVM_REGISTER_GLOBAL("relay.op.nn._make.layer_norm").set_body_typed(MakeLayerNorm); RELAY_REGISTER_OP("nn.layer_norm") .describe(R"code( @@ -831,10 +825,7 @@ Expr MakeGroupNorm(Expr data, Expr gamma, Expr beta, int num_groups, int axis, d return Call(op, {data, gamma, beta}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn._make.group_norm") - .set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeGroupNorm, args, rv); - }); +TVM_REGISTER_GLOBAL("relay.op.nn._make.group_norm").set_body_typed(MakeGroupNorm); RELAY_REGISTER_OP("nn.group_norm") .describe(R"code( diff --git a/src/relay/op/nn/sparse.cc b/src/relay/op/nn/sparse.cc index e9073730641d..6322cfffd7c2 100644 --- a/src/relay/op/nn/sparse.cc +++ b/src/relay/op/nn/sparse.cc @@ -101,10 +101,7 @@ Expr MakeSparseDense(Expr data, Expr weight_data, Expr weight_indices, Expr weig return Call(op, {data, weight_data, weight_indices, weight_indptr}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn._make.sparse_dense") - .set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeSparseDense, args, rv); - }); +TVM_REGISTER_GLOBAL("relay.op.nn._make.sparse_dense").set_body_typed(MakeSparseDense); RELAY_REGISTER_OP("nn.sparse_dense") .describe( @@ -130,10 +127,7 @@ Expr MakeSparseDensePadded(Expr data, Expr weight_data, Expr weight_indices, Exp return Call(op, {data, weight_data, weight_indices, weight_indptr}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn._make.sparse_dense_padded") - .set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeSparseDensePadded, args, rv); - }); +TVM_REGISTER_GLOBAL("relay.op.nn._make.sparse_dense_padded").set_body_typed(MakeSparseDensePadded); RELAY_REGISTER_OP("nn.internal.sparse_dense_padded") .describe( diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc index f611dc2eefd2..0b198005001b 100644 --- a/src/relay/op/tensor/reduce.cc +++ b/src/relay/op/tensor/reduce.cc @@ -595,9 +595,7 @@ Expr MakeVariance(Expr data, Expr mean, Array axis, bool keepdims, bool return Call(op, {data, mean}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make._variance").set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeVariance, args, rv); -}); +TVM_REGISTER_GLOBAL("relay.op._make._variance").set_body_typed(MakeVariance); RELAY_REGISTER_OP("variance") .describe(R"code(Computes the variance of array elements over given axes. diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 0e868cdc50c9..d44bfe6959ca 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -157,9 +157,7 @@ Expr MakeReinterpret(Expr data, DataType dtype) { return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay._make.reinterpret").set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeReinterpret, args, rv); -}); +TVM_REGISTER_GLOBAL("relay._make.reinterpret").set_body_typed(MakeReinterpret); RELAY_REGISTER_OP("reinterpret") .describe(R"code(Reinterpret the data into a new data type.