[TIR] Fix int32 vs int64 mismatch in For construct. (#10595)
* Respect dtype in Scalarize.

* Add unittest.

* Fix lint.

* Promote dtype of IntImm to match loop_var in For.

* Fix dtype mismatches.

* Lint

* Lint.

* jostle ci

* Match dtype in hybrid parser.
lazycal committed Apr 5, 2022
1 parent ceed331 commit 41cfd3d
Showing 13 changed files with 81 additions and 18 deletions.
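For context, the mismatch this commit addresses shows up when an int64 shape variable flows into loop splitting and vectorization. The sketch below is adapted from the unit test added in this commit (test_vectorize_with_if_cond_int64); before the fix, this path presumably produced For nodes mixing int32 and int64 dtypes.

```python
import tvm
from tvm import te

# Adapted from the new unit test added in this commit.
m = te.size_var("m", dtype="int64")
A = te.placeholder((m,), name="A", dtype="float32")
B = te.compute((m,), lambda i: te.if_then_else(i < 2, A[i], A[i] * 2), name="B")
s = te.create_schedule(B.op)
x, y = s[B].split(B.op.axis[0], factor=4)
s[B].vectorize(y)
f = tvm.build(s, [A, B], "llvm")  # builds cleanly once the dtypes are consistent
```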
18 changes: 12 additions & 6 deletions python/tvm/script/tir/scope_handler.py
@@ -467,18 +467,24 @@ def enter_scope(

self.node = node
self.context = context
- # generate loop vars
- self.loop_vars = [
-     tvm.te.var(name, dtype="int32", span=span) for name, span in zip(loop_var_names, spans)
- ]
# collect loop infos by calling self.func
call_with_error_reporting(context.report_error, span, self.func, *arg_list)
- if len(self.loop_vars) != len(self.loop_info):
+ if len(loop_var_names) != len(self.loop_info):
self.context.report_error(
f"Inconsistent number of vars and loops, got {len(self.loop_vars)} "
f"Inconsistent number of vars and loops, got {len(loop_var_names)} "
+ f"vs {len(self.loop_info)}",
self.node.span,
)
+ # generate loop vars
+ self.loop_vars = []
+ for name, lv_span, li in zip(loop_var_names, spans, self.loop_info):
+     if not li.begin.dtype.startswith("int"):
+         raise NotImplementedError(f"Unsupported dtype in loop begin: {li.begin.dtype}")
+     if not li.extent.dtype.startswith("int"):
+         raise NotImplementedError(f"Unsupported dtype in loop extent: {li.extent.dtype}")
+     dtype = "int64" if "int64" in [li.begin.dtype, li.extent.dtype] else "int32"
+     self.loop_vars.append(tvm.te.var(name, dtype=dtype, span=lv_span))

for loop_var, loop_info in zip(self.loop_vars, self.loop_info):
context.update_symbol(loop_var.name, loop_var, node)
context.loop_stack[loop_var] = Range.from_min_extent(loop_info.begin, loop_info.extent)
8 changes: 7 additions & 1 deletion python/tvm/te/hybrid/parser.py
@@ -511,7 +511,13 @@ def visit_For(self, node):

if iter_var is None:
_internal_assert(kind is not None, "The loop iterating function parse error!")
- offset = iter_var = tvm.te.var(_name)
+ if isinstance(ext, _expr.PrimExpr):
+     dtype = ext.dtype
+ elif isinstance(ext, int):
+     dtype = "int32"
+ else:
+     raise NotImplementedError(f"Unsupported type of ext: {type(ext)}")
+ offset = iter_var = tvm.te.var(_name, dtype=dtype)
if not tvm.tir.analysis.expr_deep_equal(low, tvm.runtime.const(0, "int32")):
offset = iter_var + low
self.add_symbol(_name, Symbol.LoopVar, offset)
22 changes: 21 additions & 1 deletion python/tvm/tir/ir_builder.py
@@ -201,7 +201,7 @@ def scope_attr(self, node, attr_key, value):
value = op.max(1, value)
self.emit(lambda x: _stmt.AttrStmt(node, attr_key, value, x))

def for_range(self, begin, end, name="i", dtype="int32", kind="serial"):
def for_range(self, begin, end, name="i", dtype=None, kind="serial"):
"""Create a for iteration scope.
Parameters
@@ -240,6 +240,26 @@ def for_range(self, begin, end, name="i", dtype="int32", kind="serial"):
name = chr(ord(name) + self.nidx) if self.nidx < 3 else name + "_" + str(self.nidx - 3)
self.nidx += 1
self._seq_stack.append([])

+ # auto infer dtype when it's not specified
+ def get_dtype(expr):
+     if isinstance(expr, _expr.PrimExpr):
+         if not expr.dtype.startswith("int"):
+             raise NotImplementedError(
+                 f"Infer loop_var dtype failed:"
+                 f" unsupported dtype in loop begin or end {expr.dtype}"
+             )
+         return expr.dtype
+     if isinstance(expr, int):
+         return "int32"
+     raise NotImplementedError(
+         f"Infer loop_var dtype failed:"
+         f" unsupported dtype in loop begin or end {expr.dtype}"
+     )
+
+ if dtype is None:
+     dtype = "int64" if "int64" in [get_dtype(begin), get_dtype(end)] else "int32"

loop_var = _expr.Var(name, dtype=dtype)
extent = end if begin == 0 else (end - begin)
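To illustrate the new inference when `dtype` is left unspecified, here is a minimal sketch (the variable names are illustrative, not taken from the diff): the loop variable picks up int64 whenever either bound is int64, and int32 otherwise.

```python
import tvm
from tvm import tir

ib = tvm.tir.ir_builder.create()
n = tir.Var("n", "int64")  # illustrative int64 extent
# dtype omitted: inferred from the bounds, so the loop variable becomes int64.
with ib.for_range(0, n) as i:
    ib.emit(tir.Evaluate(i))
assert i.dtype == "int64"
```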

4 changes: 2 additions & 2 deletions python/tvm/topi/cuda/scan.py
@@ -105,7 +105,7 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i
# Up Sweep of exclusive scan
lim = ceil_log2(scan_axis_size)

- with ib.for_range(0, lim, dtype="int64") as l2_width:
+ with ib.for_range(0, cast(lim, "int64"), dtype="int64") as l2_width:
width = 2 << l2_width

with ib.new_scope():
@@ -143,7 +143,7 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i
reduction[bx] = output[(bx + 1) * scan_axis_size - 1]
output[(bx + 1) * scan_axis_size - 1] = cast(identity_value, out_dtype)

- with ib.for_range(0, lim, dtype="int64") as l2_width:
+ with ib.for_range(0, cast(lim, "int64"), dtype="int64") as l2_width:
width = 2 << (lim - l2_width - 1)

with ib.new_scope():
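For context on the two `cast(lim, "int64")` call-site changes above: under the stricter dtype checks in `For`, an int64 loop variable can no longer take a non-constant int32 bound, so the bound is widened explicitly. A minimal sketch, assuming `lim` is an int32 expression standing in for `ceil_log2(scan_axis_size)`:

```python
import tvm
from tvm import tir

ib = tvm.tir.ir_builder.create()
lim = tir.Var("lim", "int32")  # assumed stand-in for ceil_log2(scan_axis_size)
# Widening the bound keeps its dtype consistent with the requested int64 loop var.
with ib.for_range(0, lim.astype("int64"), dtype="int64") as l2_width:
    ib.emit(tir.Evaluate(l2_width))
```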
2 changes: 1 addition & 1 deletion python/tvm/topi/cuda/sort.py
@@ -323,7 +323,7 @@ def assign_j():
with ib.else_scope():
assign_j()

- with ib.for_range(0, upper_lim - lower_lim, dtype="int64") as l2_width:
+ with ib.for_range(0, cast(upper_lim - lower_lim, "int64"), dtype="int64") as l2_width:
width = 2 << (l2_width + lower_lim)
# Define and launch the cuda kernel
with ib.new_scope():
2 changes: 1 addition & 1 deletion src/te/operation/op_utils.cc
@@ -128,7 +128,7 @@ std::vector<std::vector<Stmt> > MakeLoopNest(const Stage& stage,
nest[i + 1].emplace_back(LetStmt(var, promote_to_bound_dtype(dom->min), no_op));
value_map[iv] = promote_to_bound_dtype(dom->min);
} else if (is_zero(dom->min)) {
nest[i + 1].emplace_back(For(var, 0, dom->extent, kind, no_op));
nest[i + 1].emplace_back(For(var, 0, promote_to_bound_dtype(dom->extent), kind, no_op));
value_map[iv] = promote_to_bound_dtype(var);
} else {
Var idx(bind_iv->var->name_hint + ".idx", iv->var.dtype());
20 changes: 20 additions & 0 deletions src/tir/ir/stmt.cc
@@ -147,6 +147,26 @@ For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body,
ICHECK(loop_var.dtype().is_scalar());
ICHECK(body.defined());

+ // When extent or min is an IntImm but has narrower dtype than loop_var, we directly promote them
+ // without raising errors.
+ auto try_promote_imm_dtype = [&](const PrimExpr& e) {
+   ICHECK(e.dtype().bits() <= loop_var.dtype().bits())
+       << " Loop variable's dtype (" << loop_var.dtype()
+       << ") is narrower than that of `min` or `extent` (" << e.dtype() << ")";
+   const IntImmNode* a = e.as<IntImmNode>();
+   if (a && e.dtype().bits() < loop_var.dtype().bits()) {
+     return make_const(loop_var.dtype(), a->value);
+   } else {
+     return e;
+   }
+ };
+
+ min = try_promote_imm_dtype(min);
+ extent = try_promote_imm_dtype(extent);
+
+ ICHECK(loop_var.dtype() == min.dtype()) << loop_var.dtype() << " vs " << min.dtype();
+ ICHECK(loop_var.dtype() == extent.dtype()) << loop_var.dtype() << " vs " << extent.dtype();

ObjectPtr<ForNode> node = make_object<ForNode>();
node->loop_var = std::move(loop_var);
node->min = std::move(min);
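A small Python-side sketch of the promotion behavior added above (hypothetical values, not from the commit): integer-literal bounds that arrive as int32 IntImm constants are widened to match an int64 loop variable, while non-constant mismatches still trip the ICHECKs.

```python
import tvm
from tvm import tir

i = tir.Var("i", "int64")
body = tir.Evaluate(tir.IntImm("int32", 0))
# The Python ints 0 and 16 reach the constructor as int32 IntImm constants and
# are promoted to int64 to match the loop variable's dtype.
loop = tir.For(i, 0, 16, tir.ForKind.SERIAL, body)
assert loop.min.dtype == "int64" and loop.extent.dtype == "int64"
```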
2 changes: 1 addition & 1 deletion src/tir/schedule/primitive/cache_read_write.cc
@@ -108,7 +108,7 @@ Block MakeCacheStage(const BufferRegion& cache_region, CacheStageInfo* info,
std::vector<PrimExpr> iter_values;
// Create loop vars and block vars' binding_value
for (const Range& axis_range : cache_region->region) {
Var loop_var("ax" + std::to_string(loop_vars.size()));
Var loop_var("ax" + std::to_string(loop_vars.size()), axis_range->extent.dtype());
loop_vars.push_back(loop_var);
iter_values.push_back(axis_range->min + loop_var);
}
3 changes: 2 additions & 1 deletion src/tir/transforms/vectorize_loop.cc
@@ -569,7 +569,8 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
Var idx(var_->name_hint + ".s", var_->dtype);
Map<Var, PrimExpr> values{{var_, idx}};
stmt = Substitute(stmt, values);
- return For(idx, 0, var_lanes_, ForKind::kSerial, stmt);
+ return For(idx, IntImm(var_->dtype, 0), IntImm(var_->dtype, var_lanes_), ForKind::kSerial,
+            stmt);
}
// ProducerStore
Stmt VisitStmt_(const ProducerStoreNode* op) final {
2 changes: 1 addition & 1 deletion tests/python/unittest/test_tir_buffer.py
@@ -99,7 +99,7 @@ def test_buffer_vload_nullptr():
buf_load = tvm.tir.expr.BufferLoad(buffer=buf, indices=tvm.runtime.convert([0]))
buf_load_stmt = tvm.tir.stmt.Evaluate(buf_load)
for_loop = tvm.tir.stmt.For(
- loop_var=var, kind=0, min_val=0, extent=buf_load, body=buf_load_stmt
+ loop_var=var, kind=0, min_val=0, extent=tvm.tir.Cast("int32", buf_load), body=buf_load_stmt
)
buf_func = tvm.tir.PrimFunc(params={}, body=for_loop)
mod = tvm.IRModule({"main": buf_func})
2 changes: 1 addition & 1 deletion tests/python/unittest/test_tir_ir_builder.py
@@ -517,7 +517,7 @@ def test_device_ir(A, B):
temp[tx] = Aptr[tx]
depth = tvm.tir.log2(cast(n, "float32"))

- with ib.for_range(0, depth) as i:
+ with ib.for_range(0, cast(tvm.tir.ceil(depth), n.dtype)) as i:
ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])))
d = n >> (i + 1)
with ib.if_scope(tx < d):
4 changes: 2 additions & 2 deletions tests/python/unittest/test_tir_transform_ir_utils.py
@@ -26,9 +26,9 @@ def test_convert_ssa():
var_type = ir.PointerType(ir.PrimType(dtype))
v = tir.Var("i1", var_type)
buf = tir.decl_buffer([16], dtype=dtype, data=v)
- for_stmt = tir.For(v, zero, zero, tir.ForKind.SERIAL, nop)
+ let = tir.LetStmt(v, v, nop)
load = tir.Evaluate(tir.BufferLoad(buf, [zero]))
- seq = tir.SeqStmt([for_stmt, for_stmt, load])
+ seq = tir.SeqStmt([let, let, load])
func = tir.PrimFunc([], seq)
mod = tvm.IRModule({"main": func})
mod = tir.transform.InjectVirtualThread()(
10 changes: 10 additions & 0 deletions tests/python/unittest/test_tir_transform_vectorize.py
@@ -85,6 +85,16 @@ def test_vectorize_with_if():
assert isinstance(stmt.else_case, tvm.tir.For)


+ def test_vectorize_with_if_cond_int64():
+     m = te.size_var("m", dtype="int64")
+     A = te.placeholder((m,), name="A", dtype="float32")
+     B = te.compute((m,), lambda i: te.if_then_else(i < 2, A[i], A[i] * 2), name="B")
+     s = te.create_schedule(B.op)
+     x, y = s[B].split(B.op.axis[0], factor=4)
+     s[B].vectorize(y)
+     f = tvm.build(s, [A, B], "llvm")


def test_vectorize_let():
v = tvm.tir.Var("v", "float32")
ib = tvm.tir.ir_builder.create()