From 41cfd3d92d1ccd397b76ac45075ccddcf7da6628 Mon Sep 17 00:00:00 2001
From: Jinkun Lin
Date: Tue, 5 Apr 2022 02:08:36 -0400
Subject: [PATCH] [TIR] Fix int32 vs int64 mismatch in For construct. (#10595)

* Respect dtype in Scalarize.
* Add unittest.
* Fix lint.
* Promote dtype of IntImm to match loop_var in For.
* Fix dtype mismatches.
* Lint
* Lint.
* jostle ci
* Match dtype in hybrid parser.
---
 python/tvm/script/tir/scope_handler.py        | 18 ++++++++++-----
 python/tvm/te/hybrid/parser.py                |  8 ++++++-
 python/tvm/tir/ir_builder.py                  | 22 ++++++++++++++++++-
 python/tvm/topi/cuda/scan.py                  |  4 ++--
 python/tvm/topi/cuda/sort.py                  |  2 +-
 src/te/operation/op_utils.cc                  |  2 +-
 src/tir/ir/stmt.cc                            | 20 +++++++++++++++++
 .../schedule/primitive/cache_read_write.cc    |  2 +-
 src/tir/transforms/vectorize_loop.cc          |  3 ++-
 tests/python/unittest/test_tir_buffer.py      |  2 +-
 tests/python/unittest/test_tir_ir_builder.py  |  2 +-
 .../unittest/test_tir_transform_ir_utils.py   |  4 ++--
 .../unittest/test_tir_transform_vectorize.py  | 10 +++++++++
 13 files changed, 81 insertions(+), 18 deletions(-)

diff --git a/python/tvm/script/tir/scope_handler.py b/python/tvm/script/tir/scope_handler.py
index 2da7b78b16cd..2e1d5b605913 100644
--- a/python/tvm/script/tir/scope_handler.py
+++ b/python/tvm/script/tir/scope_handler.py
@@ -467,18 +467,24 @@ def enter_scope(
         self.node = node
         self.context = context
 
-        # generate loop vars
-        self.loop_vars = [
-            tvm.te.var(name, dtype="int32", span=span) for name, span in zip(loop_var_names, spans)
-        ]
         # collect loop infos by calling self.func
         call_with_error_reporting(context.report_error, span, self.func, *arg_list)
-        if len(self.loop_vars) != len(self.loop_info):
+        if len(loop_var_names) != len(self.loop_info):
             self.context.report_error(
-                f"Inconsistent number of vars and loops, got {len(self.loop_vars)} "
+                f"Inconsistent number of vars and loops, got {len(loop_var_names)} "
                 + f"vs {len(self.loop_info)}",
                 self.node.span,
             )
+        # generate loop vars
+        self.loop_vars = []
+        for name, lv_span, li in zip(loop_var_names, spans, self.loop_info):
+            if not li.begin.dtype.startswith("int"):
+                raise NotImplementedError(f"Unsupported dtype in loop begin: {li.begin.dtype}")
+            if not li.extent.dtype.startswith("int"):
+                raise NotImplementedError(f"Unsupported dtype in loop extent: {li.extent.dtype}")
+            dtype = "int64" if "int64" in [li.begin.dtype, li.extent.dtype] else "int32"
+            self.loop_vars.append(tvm.te.var(name, dtype=dtype, span=lv_span))
+
         for loop_var, loop_info in zip(self.loop_vars, self.loop_info):
             context.update_symbol(loop_var.name, loop_var, node)
             context.loop_stack[loop_var] = Range.from_min_extent(loop_info.begin, loop_info.extent)
diff --git a/python/tvm/te/hybrid/parser.py b/python/tvm/te/hybrid/parser.py
index 442aeb6f1027..1e1e4c50f7b9 100644
--- a/python/tvm/te/hybrid/parser.py
+++ b/python/tvm/te/hybrid/parser.py
@@ -511,7 +511,13 @@ def visit_For(self, node):
 
         if iter_var is None:
             _internal_assert(kind is not None, "The loop iterating function parse error!")
-            offset = iter_var = tvm.te.var(_name)
+            if isinstance(ext, _expr.PrimExpr):
+                dtype = ext.dtype
+            elif isinstance(ext, int):
+                dtype = "int32"
+            else:
+                raise NotImplementedError(f"Unsupported type of ext: {type(ext)}")
+            offset = iter_var = tvm.te.var(_name, dtype=dtype)
             if not tvm.tir.analysis.expr_deep_equal(low, tvm.runtime.const(0, "int32")):
                 offset = iter_var + low
             self.add_symbol(_name, Symbol.LoopVar, offset)
diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py
index 334902b53229..ce8cd1b403bc 100644
--- a/python/tvm/tir/ir_builder.py
+++ b/python/tvm/tir/ir_builder.py
@@ -201,7 +201,7 @@ def scope_attr(self, node, attr_key, value):
             value = op.max(1, value)
         self.emit(lambda x: _stmt.AttrStmt(node, attr_key, value, x))
 
-    def for_range(self, begin, end, name="i", dtype="int32", kind="serial"):
+    def for_range(self, begin, end, name="i", dtype=None, kind="serial"):
         """Create a for iteration scope.
 
         Parameters
@@ -240,6 +240,26 @@ def for_range(self, begin, end, name="i", dtype="int32", kind="serial"):
             name = chr(ord(name) + self.nidx) if self.nidx < 3 else name + "_" + str(self.nidx - 3)
         self.nidx += 1
         self._seq_stack.append([])
+
+        # auto infer dtype when it's not specified
+        def get_dtype(expr):
+            if isinstance(expr, _expr.PrimExpr):
+                if not expr.dtype.startswith("int"):
+                    raise NotImplementedError(
+                        f"Infer loop_var dtype failed:"
+                        f" unsupported dtype in loop begin or end {expr.dtype}"
+                    )
+                return expr.dtype
+            if isinstance(expr, int):
+                return "int32"
+            raise NotImplementedError(
+                f"Infer loop_var dtype failed:"
+                f" unsupported dtype in loop begin or end {expr.dtype}"
+            )
+
+        if dtype is None:
+            dtype = "int64" if "int64" in [get_dtype(begin), get_dtype(end)] else "int32"
+
         loop_var = _expr.Var(name, dtype=dtype)
         extent = end if begin == 0 else (end - begin)
 
diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py
index 0d19a92f2058..3be13d7711db 100644
--- a/python/tvm/topi/cuda/scan.py
+++ b/python/tvm/topi/cuda/scan.py
@@ -105,7 +105,7 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i
         # Up Sweep of exclusive scan
         lim = ceil_log2(scan_axis_size)
 
-        with ib.for_range(0, lim, dtype="int64") as l2_width:
+        with ib.for_range(0, cast(lim, "int64"), dtype="int64") as l2_width:
             width = 2 << l2_width
 
             with ib.new_scope():
@@ -143,7 +143,7 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i
                     reduction[bx] = output[(bx + 1) * scan_axis_size - 1]
                 output[(bx + 1) * scan_axis_size - 1] = cast(identity_value, out_dtype)
 
-        with ib.for_range(0, lim, dtype="int64") as l2_width:
+        with ib.for_range(0, cast(lim, "int64"), dtype="int64") as l2_width:
             width = 2 << (lim - l2_width - 1)
 
             with ib.new_scope():
diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py
index 25cc7a4e2cfb..b23c3db007f3 100644
--- a/python/tvm/topi/cuda/sort.py
+++ b/python/tvm/topi/cuda/sort.py
@@ -323,7 +323,7 @@ def assign_j():
         with ib.else_scope():
             assign_j()
 
-    with ib.for_range(0, upper_lim - lower_lim, dtype="int64") as l2_width:
+    with ib.for_range(0, cast(upper_lim - lower_lim, "int64"), dtype="int64") as l2_width:
         width = 2 << (l2_width + lower_lim)
         # Define and launch the cuda kernel
         with ib.new_scope():
diff --git a/src/te/operation/op_utils.cc b/src/te/operation/op_utils.cc
index bedea414474f..fd2a5c89f324 100644
--- a/src/te/operation/op_utils.cc
+++ b/src/te/operation/op_utils.cc
@@ -128,7 +128,7 @@ std::vector<std::vector<Stmt> > MakeLoopNest(const Stage& stage,
         nest[i + 1].emplace_back(LetStmt(var, promote_to_bound_dtype(dom->min), no_op));
         value_map[iv] = promote_to_bound_dtype(dom->min);
       } else if (is_zero(dom->min)) {
-        nest[i + 1].emplace_back(For(var, 0, dom->extent, kind, no_op));
+        nest[i + 1].emplace_back(For(var, 0, promote_to_bound_dtype(dom->extent), kind, no_op));
         value_map[iv] = promote_to_bound_dtype(var);
       } else {
         Var idx(bind_iv->var->name_hint + ".idx", iv->var.dtype());
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index d46132b89713..43c2d3745964 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -147,6 +147,26 @@ For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body,
   ICHECK(loop_var.dtype().is_scalar());
   ICHECK(body.defined());
 
+  // When extent or min is an IntImm but has narrower dtype than loop_var, we directly promote them
+  // without raising errors.
+  auto try_promote_imm_dtype = [&](const PrimExpr& e) {
+    ICHECK(e.dtype().bits() <= loop_var.dtype().bits())
+        << " Loop variable's dtype (" << loop_var.dtype()
+        << ") is narrower than that of `min` or `extent` (" << e.dtype() << ")";
+    const IntImmNode* a = e.as<IntImmNode>();
+    if (a && e.dtype().bits() < loop_var.dtype().bits()) {
+      return make_const(loop_var.dtype(), a->value);
+    } else {
+      return e;
+    }
+  };
+
+  min = try_promote_imm_dtype(min);
+  extent = try_promote_imm_dtype(extent);
+
+  ICHECK(loop_var.dtype() == min.dtype()) << loop_var.dtype() << " vs " << min.dtype();
+  ICHECK(loop_var.dtype() == extent.dtype()) << loop_var.dtype() << " vs " << extent.dtype();
+
   ObjectPtr<ForNode> node = make_object<ForNode>();
   node->loop_var = std::move(loop_var);
   node->min = std::move(min);
diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc
index 13b7a5a328ea..1bba2ae4fc61 100644
--- a/src/tir/schedule/primitive/cache_read_write.cc
+++ b/src/tir/schedule/primitive/cache_read_write.cc
@@ -108,7 +108,7 @@ Block MakeCacheStage(const BufferRegion& cache_region, CacheStageInfo* info,
   std::vector<PrimExpr> iter_values;
   // Create loop vars and block vars' binding_value
   for (const Range& axis_range : cache_region->region) {
-    Var loop_var("ax" + std::to_string(loop_vars.size()));
+    Var loop_var("ax" + std::to_string(loop_vars.size()), axis_range->extent.dtype());
     loop_vars.push_back(loop_var);
     iter_values.push_back(axis_range->min + loop_var);
   }
diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc
index feb396569ff9..5c5a47e86a9a 100644
--- a/src/tir/transforms/vectorize_loop.cc
+++ b/src/tir/transforms/vectorize_loop.cc
@@ -569,7 +569,8 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimEx
     Var idx(var_->name_hint + ".s", var_->dtype);
     Map<Var, PrimExpr> values{{var_, idx}};
     stmt = Substitute(stmt, values);
-    return For(idx, 0, var_lanes_, ForKind::kSerial, stmt);
+    return For(idx, IntImm(var_->dtype, 0), IntImm(var_->dtype, var_lanes_), ForKind::kSerial,
+               stmt);
   }
   // ProducerStore
   Stmt VisitStmt_(const ProducerStoreNode* op) final {
diff --git a/tests/python/unittest/test_tir_buffer.py b/tests/python/unittest/test_tir_buffer.py
index e790ffc199e5..990d0a22c817 100644
--- a/tests/python/unittest/test_tir_buffer.py
+++ b/tests/python/unittest/test_tir_buffer.py
@@ -99,7 +99,7 @@ def test_buffer_vload_nullptr():
     buf_load = tvm.tir.expr.BufferLoad(buffer=buf, indices=tvm.runtime.convert([0]))
     buf_load_stmt = tvm.tir.stmt.Evaluate(buf_load)
     for_loop = tvm.tir.stmt.For(
-        loop_var=var, kind=0, min_val=0, extent=buf_load, body=buf_load_stmt
+        loop_var=var, kind=0, min_val=0, extent=tvm.tir.Cast("int32", buf_load), body=buf_load_stmt
     )
     buf_func = tvm.tir.PrimFunc(params={}, body=for_loop)
     mod = tvm.IRModule({"main": buf_func})
diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py
index 9438da17ede2..8a39337575a7 100644
--- a/tests/python/unittest/test_tir_ir_builder.py
+++ b/tests/python/unittest/test_tir_ir_builder.py
@@ -517,7 +517,7 @@ def test_device_ir(A, B):
             temp[tx] = Aptr[tx]
         depth = tvm.tir.log2(cast(n, "float32"))
 
-        with ib.for_range(0, depth) as i:
+        with ib.for_range(0, cast(tvm.tir.ceil(depth), n.dtype)) as i:
            ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])))
            d = n >> (i + 1)
            with ib.if_scope(tx < d):
diff --git a/tests/python/unittest/test_tir_transform_ir_utils.py b/tests/python/unittest/test_tir_transform_ir_utils.py
index 8030b77f9946..0946b32cca3f 100644
--- a/tests/python/unittest/test_tir_transform_ir_utils.py
+++ b/tests/python/unittest/test_tir_transform_ir_utils.py
@@ -26,9 +26,9 @@ def test_convert_ssa():
     var_type = ir.PointerType(ir.PrimType(dtype))
     v = tir.Var("i1", var_type)
     buf = tir.decl_buffer([16], dtype=dtype, data=v)
-    for_stmt = tir.For(v, zero, zero, tir.ForKind.SERIAL, nop)
+    let = tir.LetStmt(v, v, nop)
     load = tir.Evaluate(tir.BufferLoad(buf, [zero]))
-    seq = tir.SeqStmt([for_stmt, for_stmt, load])
+    seq = tir.SeqStmt([let, let, load])
     func = tir.PrimFunc([], seq)
     mod = tvm.IRModule({"main": func})
     mod = tir.transform.InjectVirtualThread()(
diff --git a/tests/python/unittest/test_tir_transform_vectorize.py b/tests/python/unittest/test_tir_transform_vectorize.py
index 6558de31c00b..5b6f7de97bc6 100644
--- a/tests/python/unittest/test_tir_transform_vectorize.py
+++ b/tests/python/unittest/test_tir_transform_vectorize.py
@@ -85,6 +85,16 @@ def test_vectorize_with_if():
     assert isinstance(stmt.else_case, tvm.tir.For)
 
 
+def test_vectorize_with_if_cond_int64():
+    m = te.size_var("m", dtype="int64")
+    A = te.placeholder((m,), name="A", dtype="float32")
+    B = te.compute((m,), lambda i: te.if_then_else(i < 2, A[i], A[i] * 2), name="B")
+    s = te.create_schedule(B.op)
+    x, y = s[B].split(B.op.axis[0], factor=4)
+    s[B].vectorize(y)
+    f = tvm.build(s, [A, B], "llvm")
+
+
 def test_vectorize_let():
     v = tvm.tir.Var("v", "float32")
     ib = tvm.tir.ir_builder.create()
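
Usage note (not part of the patch): the sketch below is a minimal illustration of the behavior this change targets, assuming a TVM build that includes it. It exercises ir_builder.for_range with dtype left unspecified so an int64 bound drives the inferred loop-variable dtype, and constructs a tir.For whose int32 IntImm extent should be promoted to match the int64 loop variable. The names n, A, loop_var, and the extent value 16 are illustrative only.

import tvm
from tvm import tir

# for_range with dtype unspecified: an int64 bound should yield an int64 loop var.
ib = tvm.tir.ir_builder.create()
n = tir.Var("n", "int64")
A = ib.allocate("float32", n, name="A", scope="global")
with ib.for_range(0, n, name="i") as i:
    A[i] = A[i] + 1
    assert i.dtype == "int64"  # inferred from `n`, not the old "int32" default

# For() constructor: the int32 IntImm extent is expected to be promoted to the
# loop variable's int64 dtype instead of tripping a dtype-mismatch check.
loop_var = tir.Var("j", "int64")
loop = tir.For(
    loop_var,
    tir.IntImm("int64", 0),
    tir.IntImm("int32", 16),
    tir.ForKind.SERIAL,
    tir.Evaluate(0),
)
assert loop.extent.dtype == "int64"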