diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index 4fd27a0fde10..cabf299a886b 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -177,7 +177,17 @@ class ConstIntBoundAnalyzer::Impl } Entry VisitExpr_(const CastNode* op) final { - Entry a = VisitExpr(op->value); + Entry a; + + // int(ceil(log2(cast(n,"float64")))) is used as the + // implementation of topi.math.ceil_log2, and appears in iteration + // bounds. + if (auto opt = FindCeilLog2Arg(op)) { + a = CeilLog2Bounds(opt.value()); + } else { + a = VisitExpr(op->value); + } + Entry b = Everything(op->dtype); return Intersect(a, b); } @@ -314,6 +324,8 @@ class ConstIntBoundAnalyzer::Impl if (op->op.same_as(tir::builtin::shift_right())) { return VisitRightShift(op); + } else if (op->op.same_as(tir::builtin::shift_left())) { + return VisitLeftShift(op); } else if (op->op.same_as(tir::builtin::bitwise_and())) { return VisitBitwiseAnd(op); } else { @@ -341,6 +353,20 @@ class ConstIntBoundAnalyzer::Impl } } + Entry VisitLeftShift(const CallNode* op) { + Entry a = VisitExpr(op->args[0]); + Entry b = VisitExpr(op->args[1]); + + if (a.min_value < 0 || b.min_value < 0) { + // If either operand can be negative, we may run into undefined + // behavior for some targets. In these cases, avoid making any + // assumptions about the result. + return Everything(op->dtype); + } + + return BinaryOpBoundary(a, b, InfAwareLeftShift); + } + Entry VisitRightShift(const CallNode* op) { Entry a = VisitExpr(op->args[0]); Entry b = VisitExpr(op->args[1]); @@ -509,7 +535,33 @@ class ConstIntBoundAnalyzer::Impl return floordiv(x, y); } /*! - * \brief Compute x / y, aware of inf. + * \brief Compute x << y, aware of inf. + * \param x The left operand. + * \param y The right operand. + * \return the result. 
+ */ + static int64_t InfAwareLeftShift(int64_t x, int64_t y) { + if (x == kPosInf || x == kNegInf) return x; + + // Can be replaced with std::bit_width in C++20 + auto bit_width = [](int64_t as_signed) { + uint64_t val = std::abs(as_signed); + int num_bits = 0; + while (val) { + ++num_bits; + val >>= 1; + } + return num_bits; + }; + int x_bits = bit_width(x); + if (y < 64 - x_bits) { // headroom form: x_bits + y would overflow when y is kPosInf + return x << y; + } else { + return kPosInf; + } + } + /*! + * \brief Compute x >> y, aware of inf. * \param x The left operand. * \param y The right operand. * \return the result. @@ -609,6 +661,46 @@ class ConstIntBoundAnalyzer::Impl } return {}; } + + /*! + * \brief Extract the argument from int(ceil(log2(arg))) + * + * This expression is used as the implementation of + * topi.math.ceil_log2, and can appear in iteration bounds. + */ + static Optional<PrimExpr> FindCeilLog2Arg(const CastNode* op) { + if (op->dtype.is_int()) { + if (auto as_call = op->value.as<CallNode>()) { + if (as_call->op.same_as(Op::Get("tir.ceil"))) { + PrimExpr ceil_arg = as_call->args[0]; + if (auto arg_call = ceil_arg.as<CallNode>()) { + if (arg_call->op.same_as(Op::Get("tir.log2"))) { + PrimExpr log_arg = arg_call->args[0]; + return log_arg; + } + } + } + } + } + return NullOpt; + } + + /*! \brief Propagate constraints through ceil(log2(arg)) + * + * Helper function for CastNode visitor + */ + Entry CeilLog2Bounds(PrimExpr arg) { + if (auto as_float = arg.as<FloatImmNode>()) { + // A cast from int to float may have already been simplified + // out. 
Normally we don't inspect floating-point arguments, but here we can + int64_t val = std::ceil(std::log2(as_float->value)); + return MakeBound(val, val); + } else { + Entry arg_bounds = VisitExpr(arg); + return MakeBound(std::ceil(std::log2(arg_bounds.min_value)), + std::ceil(std::log2(arg_bounds.max_value))); + } + } }; ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr) const { diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index a168e1f0836c..769e58698e09 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -1640,13 +1640,34 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) { // the operator overload will eagerly constant fold. return op->args[0] << op->args[1]; } + } else if (op->op.same_as(Op::Get("tir.ceil"))) { + PrimExpr ceil_arg = op->args[0]; + if (auto arg_int = ceil_arg.as<IntImmNode>()) { + return cast(op->dtype, IntImm(arg_int->dtype, arg_int->value)); + } else if (auto arg_float = ceil_arg.as<FloatImmNode>()) { + return cast(op->dtype, FloatImm(arg_float->dtype, std::ceil(arg_float->value))); + } else if (auto arg_call = ceil_arg.as<CallNode>()) { + // ceil(log2(cast(n,"float64"))) is used as the implementation of + // topi.math.ceil_log2, and appears in iteration bounds. + if (arg_call->op.same_as(Op::Get("tir.log2"))) { + PrimExpr log_arg = arg_call->args[0]; + if (auto as_float = log_arg.as<FloatImmNode>()) { + // ceil(log2(n)) can be simplified, and should produce the + // same integer result regardless of the target's rounding + // conventions. + return FloatImm(op->dtype, std::ceil(std::log2(as_float->value))); + } + } + } } + if (op->op.same_as(tir::builtin::likely())) { // Cases such as for (i, 0, bound) {if (likely(iter_var < bound)) { .. 
} } if (auto match = TryMatchLiteralConstraint(op->args[0])) { return match.value(); } } + return ret; } diff --git a/tests/python/unittest/test_tir_transform_simplify.py b/tests/python/unittest/test_tir_transform_simplify.py index 4f727cd89b12..49e8ee3f786d 100644 --- a/tests/python/unittest/test_tir_transform_simplify.py +++ b/tests/python/unittest/test_tir_transform_simplify.py @@ -391,5 +391,112 @@ def expected(A: T.Buffer[(16, 16), "int32"], n: T.int32): A[i, j] = 2 +class TestCeilLog2Int(BaseBeforeAfter): + """Simplify expressions resulting from topi.math.ceil_log2""" + + @T.prim_func + def before(A: T.Buffer[1, "int32"]): + A[0] = T.cast( + T.ceil(T.log2(T.cast(14, "float64"), dtype="float64"), dtype="float64"), dtype="int32" + ) + + @T.prim_func + def expected(A: T.Buffer[1, "int32"]): + A[0] = 4 + + +class TestLeftCeilLog2LowerBound(BaseBeforeAfter): + """Integer bounds are propagated through topi.math.ceil_log2""" + + @T.prim_func + def before(A: T.Buffer[16, "float32"]): + for i in T.serial(16): + x = T.cast( + T.ceil(T.log2(T.cast(i + 1024 + 1, "float64"), dtype="float64"), dtype="float64"), + dtype="int32", + ) + if x == 11: + A[i] = 0.0 + + @T.prim_func + def expected(A: T.Buffer[16, "float32"]): + for i in T.serial(16): + A[i] = 0.0 + + +class TestLeftShiftLowerBound(BaseBeforeAfter): + """Integer bounds are propagated through left shift + + min(1 << i) = 1 << min(i) + = 1 << 0 + = 1 + """ + + @T.prim_func + def before(A: T.Buffer[16, "float32"]): + for i in T.serial(16): + if T.shift_left(1, i, dtype="int32") >= 1: + A[i] = 0.0 + + @T.prim_func + def expected(A: T.Buffer[16, "float32"]): + for i in T.serial(16): + A[i] = 0.0 + + +class TestLeftShiftUpperBound(BaseBeforeAfter): + """Integer bounds are propagated through left shift + + max(31 << i) = 31 << max(i) + = 31 << 15 + = 1015808 + """ + + @T.prim_func + def before(A: T.Buffer[16, "float32"]): + for i in T.serial(16): + if T.shift_left(31, i, dtype="int32") <= 1015808: + A[i] = 0.0 + + 
@T.prim_func + def expected(A: T.Buffer[16, "float32"]): + for i in T.serial(16): + A[i] = 0.0 + + +class TestLeftShiftOfNegativeValue(BaseBeforeAfter): + """No const int bounds of left shift of negative value. + + This is target dependent, and does not currently have a specified + behavior in TIR. For example, in CodeGenC, this generates C code + with undefined behavior. + """ + + @T.prim_func + def before(A: T.Buffer[16, "float32"]): + for i in T.serial(16): + if -64 <= T.shift_left(-i, 4, dtype="int32"): + A[i] = 0.0 + + expected = before + + +class TestLeftShiftByNegativeValue(BaseBeforeAfter): + """No const int bounds of left shift by negative bit count. + + This is target dependent, and does not currently have a specified + behavior in TIR. For example, in CodeGenC, this generates C code + with undefined behavior. + """ + + @T.prim_func + def before(A: T.Buffer[16, "float32"]): + for i in T.serial(16): + if T.shift_left(16, -i, dtype="int32") <= 16: + A[i] = 0.0 + + expected = before + + if __name__ == "__main__": tvm.testing.main()