diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index cabf299a886b..fa74f83313c9 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -637,29 +637,36 @@ class ConstIntBoundAnalyzer::Impl static std::vector DetectBoundInfo(const PrimExpr& cond) { PVar x, y; PVar c; - // NOTE: canonical form always use <= or < - if ((c <= x).Match(cond)) { - return {BoundInfo(x.Eval(), MakeBound(c.Eval()->value, kPosInf))}; - } - if ((c < x).Match(cond)) { - return {BoundInfo(x.Eval(), MakeBound(c.Eval()->value + 1, kPosInf))}; - } - if ((x <= c).Match(cond)) { - return {BoundInfo(x.Eval(), MakeBound(kNegInf, c.Eval()->value))}; - } - if ((x < c).Match(cond)) { - return {BoundInfo(x.Eval(), MakeBound(kNegInf, c.Eval()->value - 1))}; - } - if ((x == c).Match(cond) || (c == x).Match(cond)) { - return {BoundInfo(x.Eval(), MakeBound(c.Eval()->value, c.Eval()->value))}; - } if ((x && y).Match(cond)) { auto ret1 = DetectBoundInfo(x.Eval()); auto ret2 = DetectBoundInfo(y.Eval()); ret1.insert(ret1.end(), ret2.begin(), ret2.end()); return ret1; } - return {}; + + // NOTE: canonical form always use <= or < + Entry bound; + if ((c <= x).Match(cond)) { + bound = MakeBound(c.Eval()->value, kPosInf); + } else if ((c < x).Match(cond)) { + bound = MakeBound(c.Eval()->value + 1, kPosInf); + } else if ((x <= c).Match(cond)) { + bound = MakeBound(kNegInf, c.Eval()->value); + } else if ((x < c).Match(cond)) { + bound = MakeBound(kNegInf, c.Eval()->value - 1); + } else if ((x == c).Match(cond) || (c == x).Match(cond)) { + bound = MakeBound(c.Eval()->value, c.Eval()->value); + } else { + return {}; + } + + // If the conditional is comparing two integers, do not assign a + // value to them. + if (x.Eval().as()) { + return {}; + } + + return {BoundInfo(x.Eval(), bound)}; } /*! diff --git a/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py b/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py index fb1fb72eb82c..872afeeba5c7 100644 --- a/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py +++ b/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py @@ -16,6 +16,7 @@ # under the License. import tvm +import tvm.testing from tvm.script import tir as T # fmt: off @@ -124,5 +125,55 @@ def test_renormalize_split_pattern(): tvm.ir.assert_structural_equal(after, After_simplified) +@T.prim_func +def impossible_equality(n: T.int32): + # Prior to bugfix, this conditional defined the expression "2" as + # equal to zero within the then_case. [min_value=2, max_value=0] + if 2 == 0: + # Then this expression evaluates n/2, using the min/max values + # of "2", which is caught as a divide by zero error. + if n / 2 >= 16: + T.evaluate(0) + + +@T.prim_func +def impossible_inequality(n: T.int32): + # Prior to bugfix, this conditional set up a range of possible + # values for the expression "-2" as [0, kPosInf]. + if -1 < -2: + if n / (-2) >= 16: + T.evaluate(0) + + +integer_condition = tvm.testing.parameter( + impossible_equality, + impossible_inequality, +) + + +def test_analyze_inside_integer_conditional(integer_condition): + """Avoid crash occurring in ConstIntBoundAnalyzer. + + Crash occurred when simplifying some expressions with provably + false integer expressions. If the expressions were renormalized + before calling Simplify, conditional statements could assign a + range of possible values to integers, as if they were variables. + This would result in divide by zero throwing an exception, + followed by a second exception during stack unwinding causing the + program to crash. + """ + + # Similar issue would occur in most transformations that subclass + # IRMutatorWithAnalyzer. tir.transform.Simplify() is an + # exception, as it rewrites the integer conditionals first. These + # tests are written using RenormalizeSplitPattern as it is the + # first case identified. + transform = tvm.tir.transform.RenormalizeSplitPattern() + + # Issue would result in an error through while applying the transformation. + mod = tvm.IRModule.from_expr(integer_condition) + transform(mod) + + if __name__ == "__main__": - tesd_renormalize_split_pattern() + tvm.testing.main()