diff --git a/src/arith/const_fold.h b/src/arith/const_fold.h index 606bc28ddd22..22a91b91b946 100644 --- a/src/arith/const_fold.h +++ b/src/arith/const_fold.h @@ -142,8 +142,6 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { static_cast(fb->value))); } else if (rtype.bits() == 64) { return FloatImm(rtype, fa->value + fb->value); - } else { - return NullOpt; } } if (fa && fa->value == 0) return b; @@ -171,8 +169,6 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { static_cast(fb->value))); } else if (rtype.bits() == 64) { return FloatImm(rtype, fa->value - fb->value); - } else { - return NullOpt; } } if (fb && fb->value == 0) return a; @@ -202,8 +198,6 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { static_cast(fb->value))); } else if (rtype.bits() == 64) { return FloatImm(rtype, fa->value * fb->value); - } else { - return NullOpt; } } if (fa) { @@ -243,8 +237,6 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { static_cast(fb->value))); } else if (rtype.bits() == 64) { return FloatImm(rtype, fa->value / fb->value); - } else { - return NullOpt; } } if (fa && fa->value == 0) return a; diff --git a/tests/python/unittest/test_arith_canonical_simplify.py b/tests/python/unittest/test_arith_canonical_simplify.py index 9f187685991e..9db3035fd944 100644 --- a/tests/python/unittest/test_arith_canonical_simplify.py +++ b/tests/python/unittest/test_arith_canonical_simplify.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import tvm +import tvm.testing from tvm import te @@ -124,6 +125,22 @@ def test_div_simplify(): ck.verify(fld(17 + 47 * x, 16), fld(x * 47 + 17, 16)) +def test_fp16_const_fold(): + ck = CanonicalChecker() + zero = tvm.tir.const(0, "float16") + one = tvm.tir.const(1, "float16") + half = tvm.tir.const(0.5, "float16") + + ck.verify(zero + half, half) + ck.verify(half - zero, half) + + ck.verify(zero * half, zero) + ck.verify(half * one, half) + + ck.verify(half / one, half) + ck.verify(zero / half, zero) + + def test_floormod_simplify(): ck = CanonicalChecker() flm = tvm.te.floormod @@ -356,14 +373,4 @@ def test_simplify_cast(): if __name__ == "__main__": - test_floormod_simplify() - test_mul_sum_simplify() - test_simplify_if_then_else() - test_div_simplify() - test_reduce_simplify() - test_reduce_combiner_simplify() - - test_split_index_simplify() - test_canonical_mixed() - test_complex_cases() - test_simplify_cast() + tvm.testing.main()