Skip to content

Commit

Permalink
[RUNTIME] Improve error messages for TypedPackedFunc (#7152)
Browse files Browse the repository at this point in the history
* [RUNTIME] Improve error messages for TypedPackedFunc

- TypedPackedFunc now prints the function name when the incorrect number
  of arguments is passed.
- TypedPackedFunc now prints the function name and which argument when
  an argument cannot be converted to the correct type.

* check argument conversion by template deducing argument types

* switch from template approach to TVMMovableArgValueWithContext

* move passes back into cc files

* remove error message prefixes

* Remove TVM_ICHECK_TYPE_CODE. Rename name to optional_name.

* revert changes to module pass for later PR

* reverted too much

* documentation

* formatting

* more docs

* unify error message language. TypedPackedFunc contrustor that does not take a name

* Update include/tvm/runtime/packed_func.h

Co-authored-by: Junru Shao <junrushao1994@gmail.com>

Co-authored-by: Junru Shao <junrushao1994@gmail.com>
  • Loading branch information
tkonolige and junrushao committed Jan 28, 2021
1 parent 67acad3 commit f17cba7
Show file tree
Hide file tree
Showing 6 changed files with 137 additions and 59 deletions.
153 changes: 125 additions & 28 deletions include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ namespace runtime {
// forward declarations
class TVMArgs;
class TVMArgValue;
class TVMMovableArgValue_;
class TVMMovableArgValueWithContext_;
class TVMRetValue;
class TVMArgsSetter;

Expand Down Expand Up @@ -215,14 +215,38 @@ class TypedPackedFunc<R(Args...)> {
* \brief constructor from TVMMovableArgValue_
* \param value The TVMMovableArgValue_
*/
inline TypedPackedFunc(TVMMovableArgValue_&& value); // NOLINT(*)
inline TypedPackedFunc(TVMMovableArgValueWithContext_&& value); // NOLINT(*)
/*!
* \brief construct from a lambda function with the same signature.
*
* Example usage:
* \code
* auto typed_lambda = [](int x)->int { return x + 1; }
* // construct from packed function
* TypedPackedFunc<int(int)> ftyped(typed_lambda, "add_one");
* // call the typed version.
* ICHECK_EQ(ftyped(1), 2);
* \endcode
*
* \param typed_lambda typed lambda function.
* \param name the name of the lambda function.
* \tparam FLambda the type of the lambda function.
*/
template <typename FLambda, typename = typename std::enable_if<std::is_convertible<
FLambda, std::function<R(Args...)>>::value>::type>
TypedPackedFunc(const FLambda& typed_lambda, std::string name) { // NOLINT(*)
this->AssignTypedLambda(typed_lambda, name);
}
/*!
* \brief construct from a lambda function with the same signature.
*
* This version does not take a name. It is highly recommend you use the
* version that takes a name for the lambda.
*
* Example usage:
* \code
* auto typed_lambda = [](int x)->int { return x + 1; }
* // construct from packed function
* TypedPackedFunc<int(int)> ftyped(typed_lambda);
* // call the typed version.
* ICHECK_EQ(ftyped(1), 2);
Expand All @@ -231,9 +255,8 @@ class TypedPackedFunc<R(Args...)> {
* \param typed_lambda typed lambda function.
* \tparam FLambda the type of the lambda function.
*/
template <typename FLambda, typename = typename std::enable_if<
std::is_convertible<FLambda,
std::function<R(Args...)>>::value>::type>
template <typename FLambda, typename = typename std::enable_if<std::is_convertible<
FLambda, std::function<R(Args...)>>::value>::type>
TypedPackedFunc(const FLambda& typed_lambda) { // NOLINT(*)
this->AssignTypedLambda(typed_lambda);
}
Expand Down Expand Up @@ -297,6 +320,17 @@ class TypedPackedFunc<R(Args...)> {
* \brief Assign the packed field using a typed lambda function.
*
* \param flambda The lambda function.
* \param name The name associated with this lambda.
* \tparam FLambda The lambda function type.
* \note We capture the lambda when possible for maximum efficiency.
*/
template <typename FLambda>
inline void AssignTypedLambda(FLambda flambda, std::string name);
/*!
* \brief Assign the packed field using a typed lambda function. This variant is for functions
* without names.
*
* \param flambda The lambda function.
* \tparam FLambda The lambda function type.
* \note We capture the lambda when possible for maximum efficiency.
*/
Expand Down Expand Up @@ -337,7 +371,7 @@ inline const char* ArgTypeCode2Str(int type_code);

// macro to check type code.
#define TVM_CHECK_TYPE_CODE(CODE, T) \
ICHECK_EQ(CODE, T) << " expected " << ArgTypeCode2Str(T) << " but get " << ArgTypeCode2Str(CODE)
ICHECK_EQ(CODE, T) << "expected " << ArgTypeCode2Str(T) << " but got " << ArgTypeCode2Str(CODE)

/*!
* \brief Type traits for runtime type check during FFI conversion.
Expand Down Expand Up @@ -401,8 +435,8 @@ class TVMPODValue_ {
return static_cast<DLTensor*>(value_.v_handle);
} else {
if (type_code_ == kTVMNullptr) return nullptr;
LOG(FATAL) << "Expect "
<< "DLTensor* or NDArray but get " << ArgTypeCode2Str(type_code_);
LOG(FATAL) << "Expected "
<< "DLTensor* or NDArray but got " << ArgTypeCode2Str(type_code_);
return nullptr;
}
}
Expand Down Expand Up @@ -442,6 +476,7 @@ class TVMPODValue_ {
protected:
friend class TVMArgsSetter;
friend class TVMRetValue;
friend class TVMMovableArgValue_;
TVMPODValue_() : type_code_(kTVMNullptr) {}
TVMPODValue_(TVMValue value, int type_code) : value_(value), type_code_(type_code) {}

Expand Down Expand Up @@ -562,6 +597,44 @@ class TVMMovableArgValue_ : public TVMPODValue_ {
TVMArgValue AsArgValue() const { return TVMArgValue(value_, type_code_); }
};

/*!
* \brief Internal auxiliary struct for TypedPackedFunc to indicate a movable argument with
* additional context information (function name and argument index) for better error reporting.
*
* \sa MovableArgValue_
* \note For internal development purpose only.
*/
class TVMMovableArgValueWithContext_ {
public:
/*!
* \brief move constructor from another return value.
* \param value The other return value.
* \param type_code The code associated with the type of the value.
* \param arg_index In a function call, this argument is at index arg_index (0-indexed).
* \param optional_name Name of the function being called. Can be nullptr if the function is not
* named.
*/
TVMMovableArgValueWithContext_(TVMValue value, int type_code, int arg_index,
const std::string* optional_name)
: value_(value, type_code), arg_index_(arg_index), optional_name_(optional_name) {}

template <typename T>
operator T() const {
try {
return value_; // implicit conversion happens here
} catch (dmlc::Error& e) {
LOG(FATAL) << "In function " << (optional_name_ == nullptr ? "<anonymous>" : *optional_name_)
<< ": error while converting argument " << arg_index_ << ": " << e.what();
throw; // never reached, LOG(FATAL) throws, but this silences a warning.
}
}

private:
TVMMovableArgValue_ value_;
int arg_index_;
const std::string* optional_name_;
};

/*!
* \brief Return Value container,
* Unlike TVMArgValue, which only holds reference and do not delete
Expand Down Expand Up @@ -1213,20 +1286,23 @@ namespace detail {
template <typename R, int nleft, int index, typename F>
struct unpack_call_dispatcher {
template <typename... Args>
TVM_ALWAYS_INLINE static void run(const F& f, const TVMArgs& args_pack, TVMRetValue* rv,
TVM_ALWAYS_INLINE static void run(const std::string* optional_name, const F& f,
const TVMArgs& args_pack, TVMRetValue* rv,
Args&&... unpacked_args) {
// construct a movable argument value
// which allows potential move of argument to the input of F.
unpack_call_dispatcher<R, nleft - 1, index + 1, F>::run(
f, args_pack, rv, std::forward<Args>(unpacked_args)...,
TVMMovableArgValue_(args_pack.values[index], args_pack.type_codes[index]));
optional_name, f, args_pack, rv, std::forward<Args>(unpacked_args)...,
TVMMovableArgValueWithContext_(args_pack.values[index], args_pack.type_codes[index], index,
optional_name));
}
};

template <typename R, int index, typename F>
struct unpack_call_dispatcher<R, 0, index, F> {
template <typename... Args>
TVM_ALWAYS_INLINE static void run(const F& f, const TVMArgs& args_pack, TVMRetValue* rv,
TVM_ALWAYS_INLINE static void run(const std::string* optional_name, const F& f,
const TVMArgs& args_pack, TVMRetValue* rv,
Args&&... unpacked_args) {
using RetType = decltype(f(std::forward<Args>(unpacked_args)...));
if (std::is_same<RetType, R>::value) {
Expand All @@ -1240,16 +1316,21 @@ struct unpack_call_dispatcher<R, 0, index, F> {
template <int index, typename F>
struct unpack_call_dispatcher<void, 0, index, F> {
template <typename... Args>
TVM_ALWAYS_INLINE static void run(const F& f, const TVMArgs& args_pack, TVMRetValue* rv,
TVM_ALWAYS_INLINE static void run(const std::string* optional_name, const F& f,
const TVMArgs& args_pack, TVMRetValue* rv,
Args&&... unpacked_args) {
f(std::forward<Args>(unpacked_args)...);
}
};

template <typename R, int nargs, typename F>
TVM_ALWAYS_INLINE void unpack_call(const F& f, const TVMArgs& args, TVMRetValue* rv) {
ICHECK_EQ(nargs, args.size()) << "Expect " << nargs << " arguments but get " << args.size();
unpack_call_dispatcher<R, nargs, 0, F>::run(f, args, rv);
TVM_ALWAYS_INLINE void unpack_call(const std::string* optional_name, const F& f,
const TVMArgs& args, TVMRetValue* rv) {
CHECK_EQ(nargs, args.size()) << "Function "
<< (optional_name == nullptr ? "<anonymous>" : *optional_name)
<< " expects " << nargs << " arguments but " << args.size()
<< " were provided";
unpack_call_dispatcher<R, nargs, 0, F>::run(optional_name, f, args, rv);
}

template <typename FType>
Expand All @@ -1259,7 +1340,7 @@ template <typename R, typename... Args>
struct unpack_call_by_signature<R(Args...)> {
template <typename F>
TVM_ALWAYS_INLINE static void run(const F& f, const TVMArgs& args, TVMRetValue* rv) {
unpack_call<R, sizeof...(Args)>(f, args, rv);
unpack_call<R, sizeof...(Args)>(nullptr, f, args, rv);
}
};

Expand Down Expand Up @@ -1297,14 +1378,30 @@ TypedPackedFunc<R(Args...)>::TypedPackedFunc(const TVMArgValue& value)
: packed_(value.operator PackedFunc()) {}

template <typename R, typename... Args>
TypedPackedFunc<R(Args...)>::TypedPackedFunc(TVMMovableArgValue_&& value)
TypedPackedFunc<R(Args...)>::TypedPackedFunc(TVMMovableArgValueWithContext_&& value)
: packed_(value.operator PackedFunc()) {}

template <typename R, typename... Args>
template <typename FType>
inline void TypedPackedFunc<R(Args...)>::AssignTypedLambda(FType flambda, std::string name) {
packed_ = PackedFunc([flambda, name](const TVMArgs& args, TVMRetValue* rv) {
if (args.size() != sizeof...(Args)) {
LOG(FATAL) << "Function " << name << " expects " << sizeof...(Args) << " arguments, but "
<< args.size() << " were provided.";
}
detail::unpack_call<R, sizeof...(Args)>(&name, flambda, args, rv);
});
}

template <typename R, typename... Args>
template <typename FType>
inline void TypedPackedFunc<R(Args...)>::AssignTypedLambda(FType flambda) {
packed_ = PackedFunc([flambda](const TVMArgs& args, TVMRetValue* rv) {
detail::unpack_call<R, sizeof...(Args)>(flambda, args, rv);
if (args.size() != sizeof...(Args)) {
LOG(FATAL) << "Function <anonymous> expects " << sizeof...(Args) << " arguments, but "
<< args.size() << " were provided.";
}
detail::unpack_call<R, sizeof...(Args)>(nullptr, flambda, args, rv);
});
}

Expand Down Expand Up @@ -1377,7 +1474,7 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const {
using ContainerType = typename TObjectRef::ContainerType;

if (type_code_ == kTVMNullptr) {
ICHECK(TObjectRef::_type_is_nullable)
CHECK(TObjectRef::_type_is_nullable)
<< "Expect a not null value of " << ContainerType::_type_key;
return TObjectRef(ObjectPtr<Object>(nullptr));
}
Expand All @@ -1387,29 +1484,29 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const {
TVM_CHECK_TYPE_CODE(type_code_, kTVMNDArrayHandle);
ObjectPtr<Object> data =
NDArray::FFIDataFromHandle(static_cast<TVMArrayHandle>(value_.v_handle));
ICHECK(data->IsInstance<ContainerType>())
<< "Expect " << ContainerType::_type_key << " but get " << data->GetTypeKey();
CHECK(data->IsInstance<ContainerType>())
<< "Expected " << ContainerType::_type_key << " but got " << data->GetTypeKey();
return TObjectRef(data);
}
if (std::is_base_of<Module::ContainerType, ContainerType>::value) {
// Casting to a sub-class of Module
TVM_CHECK_TYPE_CODE(type_code_, kTVMModuleHandle);
ObjectPtr<Object> data = GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle));
ICHECK(data->IsInstance<ContainerType>())
<< "Expect " << ContainerType::_type_key << " but get " << data->GetTypeKey();
CHECK(data->IsInstance<ContainerType>())
<< "Expected " << ContainerType::_type_key << " but got " << data->GetTypeKey();
return TObjectRef(data);
}
if (type_code_ == kTVMObjectHandle) {
// normal object type check.
Object* ptr = static_cast<Object*>(value_.v_handle);
ICHECK(ObjectTypeChecker<TObjectRef>::Check(ptr))
<< "Expect " << ObjectTypeChecker<TObjectRef>::TypeName() << " but get "
CHECK(ObjectTypeChecker<TObjectRef>::Check(ptr))
<< "Expected " << ObjectTypeChecker<TObjectRef>::TypeName() << " but got "
<< ptr->GetTypeKey();
return TObjectRef(GetObjectPtr<Object>(ptr));
} else if (type_code_ == kTVMObjectRValueRefArg) {
Object* ptr = *static_cast<Object**>(value_.v_handle);
ICHECK(ObjectTypeChecker<TObjectRef>::Check(ptr))
<< "Expect " << ObjectTypeChecker<TObjectRef>::TypeName() << " but get "
CHECK(ObjectTypeChecker<TObjectRef>::Check(ptr))
<< "Expected " << ObjectTypeChecker<TObjectRef>::TypeName() << " but got "
<< ptr->GetTypeKey();
return TObjectRef(GetObjectPtr<Object>(ptr));
} else if (std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
Expand Down
10 changes: 5 additions & 5 deletions include/tvm/runtime/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class Registry {
template <typename FLambda>
Registry& set_body_typed(FLambda f) {
using FType = typename detail::function_signature<FLambda>::FType;
return set_body(TypedPackedFunc<FType>(std::move(f)).packed());
return set_body(TypedPackedFunc<FType>(std::move(f), name_).packed());
}
/*!
* \brief set the body of the function to be the passed method pointer.
Expand Down Expand Up @@ -122,7 +122,7 @@ class Registry {
// call method pointer
return (target.*f)(params...);
};
return set_body(TypedPackedFunc<R(T, Args...)>(fwrap));
return set_body(TypedPackedFunc<R(T, Args...)>(fwrap, name_));
}

/*!
Expand Down Expand Up @@ -152,7 +152,7 @@ class Registry {
// call method pointer
return (target.*f)(params...);
};
return set_body(TypedPackedFunc<R(const T, Args...)>(fwrap));
return set_body(TypedPackedFunc<R(const T, Args...)>(fwrap, name_));
}

/*!
Expand Down Expand Up @@ -194,7 +194,7 @@ class Registry {
// call method pointer
return (target->*f)(params...);
};
return set_body(TypedPackedFunc<R(TObjectRef, Args...)>(fwrap));
return set_body(TypedPackedFunc<R(TObjectRef, Args...)>(fwrap, name_));
}

/*!
Expand Down Expand Up @@ -236,7 +236,7 @@ class Registry {
// call method pointer
return (target->*f)(params...);
};
return set_body(TypedPackedFunc<R(TObjectRef, Args...)>(fwrap));
return set_body(TypedPackedFunc<R(TObjectRef, Args...)>(fwrap, name_));
}

/*!
Expand Down
15 changes: 3 additions & 12 deletions src/relay/op/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -718,10 +718,7 @@ Expr MakeInstanceNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon
return Call(op, {data, gamma, beta}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.nn._make.instance_norm")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 7>(MakeInstanceNorm, args, rv);
});
TVM_REGISTER_GLOBAL("relay.op.nn._make.instance_norm").set_body_typed(MakeInstanceNorm);

RELAY_REGISTER_OP("nn.instance_norm")
.describe(R"code(Instance Normalization (Ulyanov and et al., 2016)
Expand Down Expand Up @@ -785,10 +782,7 @@ Expr MakeLayerNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon, b
return Call(op, {data, gamma, beta}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.nn._make.layer_norm")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 7>(MakeLayerNorm, args, rv);
});
TVM_REGISTER_GLOBAL("relay.op.nn._make.layer_norm").set_body_typed(MakeLayerNorm);

RELAY_REGISTER_OP("nn.layer_norm")
.describe(R"code(
Expand Down Expand Up @@ -831,10 +825,7 @@ Expr MakeGroupNorm(Expr data, Expr gamma, Expr beta, int num_groups, int axis, d
return Call(op, {data, gamma, beta}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.nn._make.group_norm")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 8>(MakeGroupNorm, args, rv);
});
TVM_REGISTER_GLOBAL("relay.op.nn._make.group_norm").set_body_typed(MakeGroupNorm);

RELAY_REGISTER_OP("nn.group_norm")
.describe(R"code(
Expand Down
10 changes: 2 additions & 8 deletions src/relay/op/nn/sparse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,7 @@ Expr MakeSparseDense(Expr data, Expr weight_data, Expr weight_indices, Expr weig
return Call(op, {data, weight_data, weight_indices, weight_indptr}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.nn._make.sparse_dense")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 5>(MakeSparseDense, args, rv);
});
TVM_REGISTER_GLOBAL("relay.op.nn._make.sparse_dense").set_body_typed(MakeSparseDense);

RELAY_REGISTER_OP("nn.sparse_dense")
.describe(
Expand All @@ -130,10 +127,7 @@ Expr MakeSparseDensePadded(Expr data, Expr weight_data, Expr weight_indices, Exp
return Call(op, {data, weight_data, weight_indices, weight_indptr}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.nn._make.sparse_dense_padded")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 4>(MakeSparseDensePadded, args, rv);
});
TVM_REGISTER_GLOBAL("relay.op.nn._make.sparse_dense_padded").set_body_typed(MakeSparseDensePadded);

RELAY_REGISTER_OP("nn.internal.sparse_dense_padded")
.describe(
Expand Down
Loading

0 comments on commit f17cba7

Please sign in to comment.