Skip to content

Commit

Permalink
[TIR][OP][API-CHANGE] Remove CallNode.call_type in favor of attribute.
Browse files Browse the repository at this point in the history
This is a followup refactor for tir::Call.
Now that we have switched call->name to call->op, the function effect property
can be registered through the op itself, so we no longer need the call_type in the CallNode.

- Introduce CallEffectKind to provide a more fine grained categorization of calls.
- Introduce call_pure_extern and call_llvm_pure_intrin to
  allow us to indicate pure calls in those cases.
- Migrate existing usecases to the new API.
  • Loading branch information
tqchen committed Jun 26, 2020
1 parent 75f2539 commit b1a7fbb
Show file tree
Hide file tree
Showing 80 changed files with 744 additions and 694 deletions.
34 changes: 30 additions & 4 deletions include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,25 @@ TVM_DLL const Op& fma();
*/
TVM_DLL const Op& call_extern();

/*!
* \brief Call an pure extern C function with given name
* and signature from the types of args in the runtime environment.
*
* Type call_pure_extern(name, args...) {
* return dlsym(name)(args...);
* }
*
* \note This intrinsic does not provide any type checking,
* and is main used for backward compatibility reasons.
* Always consider use pre-registered and typed tvm::Op first.
*/
TVM_DLL const Op& call_pure_extern();

/*!
* \brief Call an LLVM intrinsic with a given intrinsic id
* and signature from the types of args in the runtime environment.
*
* Type call_llvm_intrin(intrin_id, args...) {
* Type call_llvm_pure_intrin(intrin_id, args...) {
* return dlsym(name)(args...);
* }
*
Expand All @@ -165,15 +179,27 @@ TVM_DLL const Op& call_extern();
TVM_DLL const Op& call_llvm_intrin();

/*!
* \brief Call an SPIRV GLSL450 intrinsic.
* \brief Call an LLVM pure intrinsic with a given intrinsic id
* and signature from the types of args in the runtime environment.
*
* Type call_llvm_pure_intrin(intrin_id, args...) {
* return dlsym(name)(args...);
* }
*
* \note This op does not provide any type checking.
*/
TVM_DLL const Op& call_llvm_pure_intrin();

/*!
* \brief Call an SPIRV pure GLSL450 intrinsic.
*
* Type call_spirv_glsl450(intrin_id, args...) {
* Type call_spirv_pure_glsl450(intrin_id, args...) {
* return dlsym(name)(args...);
* }
*
* \note This op does not provide any type checking.
*/
TVM_DLL const Op& call_spirv_glsl450();
TVM_DLL const Op& call_spirv_pure_glsl450();

// TODO(tvm-team) revisit the builtins below
// some of them can simply become ops with special codegen attr.
Expand Down
28 changes: 2 additions & 26 deletions include/tvm/tir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -875,19 +875,6 @@ class Let : public PrimExpr {
*/
class CallNode : public PrimExprNode {
public:
/*! \brief Possible types of calls. */
enum CallType : int {
/*! \brief Extern "C" function. */
Extern = 0,
/*! \brief Extern CXX function. */
ExternCPlusPlus = 1,
/*! \brief Extern "C" without side-effect. */
PureExtern = 2,
/*! \brief Intrinsic functions. */
Intrinsic = 4,
/*! \brief Intrinsic functions that are pure. */
PureIntrinsic = 5
};
/*!
* \brief The operator(function) being invoked
*
Expand All @@ -898,31 +885,22 @@ class CallNode : public PrimExprNode {

/*! \brief The arguments. */
Array<PrimExpr> args;
/*! \brief Type of calls. */
CallType call_type;

void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
v->Visit("op", &op);
v->Visit("args", &args);
v->Visit("call_type", &call_type);
}

bool SEqualReduce(const CallNode* other, SEqualReducer equal) const {
return equal(dtype, other->dtype) && equal(op, other->op) && equal(args, other->args) &&
equal(call_type, other->call_type);
return equal(dtype, other->dtype) && equal(op, other->op) && equal(args, other->args);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dtype);
hash_reduce(op);
hash_reduce(args);
hash_reduce(call_type);
}

/*! \return Whether call node is pure. */
bool is_pure() const { return (call_type == PureExtern || call_type == PureIntrinsic); }

static constexpr const char* _type_key = "tir.Call";
TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, PrimExprNode);
};
Expand All @@ -933,9 +911,7 @@ class CallNode : public PrimExprNode {
*/
class Call : public PrimExpr {
public:
using CallType = CallNode::CallType;

TVM_DLL Call(DataType dtype, RelayExpr op, Array<PrimExpr> args, CallType call_type);
TVM_DLL Call(DataType dtype, RelayExpr op, Array<PrimExpr> args);
TVM_DEFINE_OBJECT_REF_METHODS(Call, PrimExpr, CallNode);
};

Expand Down
16 changes: 8 additions & 8 deletions include/tvm/tir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -553,10 +553,10 @@ TVM_DLL PrimExpr trunc(PrimExpr x);
TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high);

// Intrinsic operators
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
inline PrimExpr OpName(PrimExpr x) { \
static const Op& op = Op::Get("tir." #OpName); \
return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic); \
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
inline PrimExpr OpName(PrimExpr x) { \
static const Op& op = Op::Get("tir." #OpName); \
return tir::Call(x.dtype(), op, {x}); \
}

TVM_DECLARE_INTRIN_UNARY(exp);
Expand All @@ -583,10 +583,10 @@ TVM_DECLARE_INTRIN_UNARY(acosh);
TVM_DECLARE_INTRIN_UNARY(asinh);
TVM_DECLARE_INTRIN_UNARY(atanh);

#define TVM_DECLARE_INTRIN_BINARY(OpName) \
inline PrimExpr OpName(PrimExpr x, PrimExpr y) { \
static const Op& op = Op::Get("tir." #OpName); \
return tir::Call(x.dtype(), op, {x, y}, tir::CallNode::PureIntrinsic); \
#define TVM_DECLARE_INTRIN_BINARY(OpName) \
inline PrimExpr OpName(PrimExpr x, PrimExpr y) { \
static const Op& op = Op::Get("tir." #OpName); \
return tir::Call(x.dtype(), op, {x, y}); \
}

TVM_DECLARE_INTRIN_BINARY(atan2);
Expand Down
37 changes: 37 additions & 0 deletions include/tvm/tir/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,43 @@ using TGlobalSymbol = String;
*/
using TVectorizable = bool;

/*!
* \brief The effect type of the call.
*/
enum class CallEffectKind : int {
/*! \brief Function corresponds to an annotation(e.g. likely) and can translate to identity. */
kExprAnnotation = 0,
/*!
* \brief Pure function that do not interacts
* with any external state.
*/
kPure = 1,
/*!
* \brief Function's that may read from states(e.g. RAM)
*/
kReadState = 2,
/*!
* \brief Function that may read/write from states(e.g. RAM).
*/
kUpdateState = 3,
/*!
* \brief Opaque function, cannot make any assumption
*/
kOpaque = kUpdateState,
/*!
* \brief Special intrinsic to annotate call arguments info
* only valid as a direct argument to a call.
*/
kSpecialCallArg = 4,
/*!
* \brief Embed opaque information in the Expr, cannot be codegen.
*/
kEmbedInfo = 5
};

/*! \brief Use integer to record the kind. */
using TCallEffectKind = Integer;

} // namespace tir
} // namespace tvm
#endif // TVM_TIR_OP_ATTR_TYPES_H_
4 changes: 2 additions & 2 deletions python/tvm/te/hybrid/calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from tvm.ir.container import Array
from tvm import target as _tgt
from tvm.tir import expr as _expr
from tvm.tir import call_pure_intrin
from tvm.tir import call_intrin
from tvm.tir.stmt import For

from .util import _internal_assert
Expand Down Expand Up @@ -148,7 +148,7 @@ def likely(func_id, args):
_internal_assert(args.__len__() == 1, \
"Only one expression can be likely")
_internal_assert(func_id == "likely", "This function cannot be directly invoked!")
return call_pure_intrin(args[0].dtype, 'tir.likely', *args)
return call_intrin(args[0].dtype, 'tir.likely', *args)


def max_num_threads(func_id, args):
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@

from .function import PrimFunc

from .op import call_packed, call_pure_intrin, call_intrin, call_pure_extern, call_extern
from .op import call_llvm_intrin, all, any, min_value, max_value, trace
from .op import call_packed, call_intrin, call_pure_extern, call_extern
from .op import call_llvm_intrin, call_llvm_pure_intrin, all, any, min_value, max_value, trace
from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp
from .op import sin, sinh, asin, asinh
from .op import cos, cosh, acos, acosh
Expand Down
12 changes: 2 additions & 10 deletions python/tvm/tir/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -974,16 +974,8 @@ class Call(PrimExprWithOp):
args : list of Expr
The input arguments to the call
call_type : int
The type of the call
"""
Extern = 0
ExternCPlusPlus = 1
PureExtern = 2
Intrinsic = 4
PureIntrinsic = 5
def __init__(self, dtype, op, args, call_type):
def __init__(self, dtype, op, args):
if isinstance(op, str):
if not op.startswith("tir."):
raise ValueError(
Expand All @@ -992,7 +984,7 @@ def __init__(self, dtype, op, args, call_type):
"certain about the intrinsic name, pass in Op.get(name) instead") % op)
op = Op.get(op)
self.__init_handle_by_constructor__(
_ffi_api.Call, dtype, op, args, call_type)
_ffi_api.Call, dtype, op, args)


@tvm._ffi.register_object("tir.Let")
Expand Down
3 changes: 1 addition & 2 deletions python/tvm/tir/ir_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,8 +379,7 @@ def likely(self, expr):
expr : Expr
The expression will likely tag.
"""
return _expr.Call(expr.dtype, "tir.likely", [expr],
_expr.Call.PureIntrinsic)
return _expr.Call(expr.dtype, "tir.likely", [expr])

def get(self):
"""Return the builded IR.
Expand Down
Loading

0 comments on commit b1a7fbb

Please sign in to comment.