Skip to content

Commit

Permalink
[TIR][Arith] Avoid assigning range of possible values to integers (#1…
Browse files Browse the repository at this point in the history
…1859)

Previously, in `ConstIntBoundAnalyzer`, entering a conditional such as
`if 2==0` could result in the expression `2` being treated as having a
known value of zero within the body of the conditional.  Evaluating
the range of expressions using `2` in the body of the conditional
could result in exceptions being thrown, such as evaluating `expr / 2`
while setting `2` to its maximum value of zero.

This issue was present for conditions with inequalities for some time,
but was introduced for conditions with equalities in
#11524.  Both types are resolved in
this PR.
  • Loading branch information
Lunderberg committed Jun 24, 2022
1 parent d439f6c commit ed638ef
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 18 deletions.
41 changes: 24 additions & 17 deletions src/arith/const_int_bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -637,29 +637,36 @@ class ConstIntBoundAnalyzer::Impl
static std::vector<BoundInfo> DetectBoundInfo(const PrimExpr& cond) {
PVar<PrimExpr> x, y;
PVar<IntImm> c;
// NOTE: canonical form always use <= or <
if ((c <= x).Match(cond)) {
return {BoundInfo(x.Eval(), MakeBound(c.Eval()->value, kPosInf))};
}
if ((c < x).Match(cond)) {
return {BoundInfo(x.Eval(), MakeBound(c.Eval()->value + 1, kPosInf))};
}
if ((x <= c).Match(cond)) {
return {BoundInfo(x.Eval(), MakeBound(kNegInf, c.Eval()->value))};
}
if ((x < c).Match(cond)) {
return {BoundInfo(x.Eval(), MakeBound(kNegInf, c.Eval()->value - 1))};
}
if ((x == c).Match(cond) || (c == x).Match(cond)) {
return {BoundInfo(x.Eval(), MakeBound(c.Eval()->value, c.Eval()->value))};
}
if ((x && y).Match(cond)) {
auto ret1 = DetectBoundInfo(x.Eval());
auto ret2 = DetectBoundInfo(y.Eval());
ret1.insert(ret1.end(), ret2.begin(), ret2.end());
return ret1;
}
return {};

// NOTE: canonical form always use <= or <
Entry bound;
if ((c <= x).Match(cond)) {
bound = MakeBound(c.Eval()->value, kPosInf);
} else if ((c < x).Match(cond)) {
bound = MakeBound(c.Eval()->value + 1, kPosInf);
} else if ((x <= c).Match(cond)) {
bound = MakeBound(kNegInf, c.Eval()->value);
} else if ((x < c).Match(cond)) {
bound = MakeBound(kNegInf, c.Eval()->value - 1);
} else if ((x == c).Match(cond) || (c == x).Match(cond)) {
bound = MakeBound(c.Eval()->value, c.Eval()->value);
} else {
return {};
}

// If the conditional is comparing two integers, do not assign a
// value to them.
if (x.Eval().as<IntImmNode>()) {
return {};
}

return {BoundInfo(x.Eval(), bound)};
}

/*!
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.

import tvm
import tvm.testing
from tvm.script import tir as T

# fmt: off
Expand Down Expand Up @@ -124,5 +125,55 @@ def test_renormalize_split_pattern():
tvm.ir.assert_structural_equal(after, After_simplified)


@T.prim_func
def impossible_equality(n: T.int32):
# Prior to bugfix, this conditional defined the expression "2" as
# equal to zero within the then_case. [min_value=2, max_value=0]
if 2 == 0:
# Then this expression evaluates n/2, using the min/max values
# of "2", which is caught as a divide by zero error.
if n / 2 >= 16:
T.evaluate(0)


@T.prim_func
def impossible_inequality(n: T.int32):
# Prior to bugfix, this conditional set up a range of possible
# values for the expression "-2" as [0, kPosInf].
if -1 < -2:
if n / (-2) >= 16:
T.evaluate(0)


integer_condition = tvm.testing.parameter(
impossible_equality,
impossible_inequality,
)


def test_analyze_inside_integer_conditional(integer_condition):
"""Avoid crash occurring in ConstIntBoundAnalyzer.
Crash occurred when simplifying some expressions with provably
false integer expressions. If the expressions were renormalized
before calling Simplify, conditional statements could assign a
range of possible values to integers, as if they were variables.
This would result in divide by zero throwing an exception,
followed by a second exception during stack unwinding causing the
program to crash.
"""

# Similar issue would occur in most transformations that subclass
# IRMutatorWithAnalyzer. tir.transform.Simplify() is an
# exception, as it rewrites the integer conditionals first. These
# tests are written using RenormalizeSplitPattern as it is the
# first case identified.
transform = tvm.tir.transform.RenormalizeSplitPattern()

# Issue would result in an error through while applying the transformation.
mod = tvm.IRModule.from_expr(integer_condition)
transform(mod)


if __name__ == "__main__":
tesd_renormalize_split_pattern()
tvm.testing.main()

0 comments on commit ed638ef

Please sign in to comment.