From 39fcb8074b6b326fd849795a8619e09551347895 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Sat, 8 May 2021 04:31:31 +0800 Subject: [PATCH] [TensorIR] CreatePrimFunc from TE (#7987) Co-authored-by: Tianqi Chen Co-authored-by: Wuwei Lin Co-authored-by: Ruihang Lai --- include/tvm/tir/var.h | 1 + python/tvm/te/__init__.py | 1 + python/tvm/te/operation.py | 50 +++ src/te/operation/create_primfunc.cc | 306 ++++++++++++++++++ .../unittest/test_te_create_primfunc.py | 292 +++++++++++++++++ 5 files changed, 650 insertions(+) create mode 100644 src/te/operation/create_primfunc.cc create mode 100644 tests/python/unittest/test_te_create_primfunc.py diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h index 67e3a76f97b1..65c5c12a701b 100644 --- a/include/tvm/tir/var.h +++ b/include/tvm/tir/var.h @@ -298,6 +298,7 @@ class IterVar : public ObjectRef { inline operator PrimExpr() const; TVM_DEFINE_OBJECT_REF_METHODS(IterVar, ObjectRef, IterVarNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(IterVarNode); }; // inline implementations diff --git a/python/tvm/te/__init__.py b/python/tvm/te/__init__.py index 939956c1a005..250c165caf9a 100644 --- a/python/tvm/te/__init__.py +++ b/python/tvm/te/__init__.py @@ -33,6 +33,7 @@ from .tag import tag_scope from .operation import placeholder, compute, scan, extern, var, size_var from .operation import thread_axis, reduce_axis +from .operation import create_prim_func from .tensor import PlaceholderOp, ComputeOp, TensorComputeOp, ScanOp, ExternOp, HybridOp from .autodiff import gradient diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py index 32b83dff1baa..52eb591c48d4 100644 --- a/python/tvm/te/operation.py +++ b/python/tvm/te/operation.py @@ -17,6 +17,7 @@ """ Operation class for computation declaration.""" # pylint: disable=invalid-name from numbers import Integral as _Integral +from typing import List import tvm._ffi import tvm.tir @@ -426,3 +427,52 @@ def reduce_axis(dom, name="rv", thread_tag="", span=None): An iteration variable representing the value. """ return tvm.tir.IterVar(dom, name, 2, thread_tag, span) + + +def create_prim_func(ops: List[_tensor.Tensor]) -> tvm.tir.PrimFunc: + """Create a TensorIR PrimFunc from tensor expression + Parameters + ---------- + ops : List[Tensor] + The source expression. + + Example + ------- + We define a matmul kernel using following code: + + .. code-block:: python + + import tvm + from tvm import te + + A = te.placeholder((128, 128), name="A") + B = te.placeholder((128, 128), name="B") + C = te.compute((128, 128), lambda x, y: te.sum(A[x, k] * B[y, k], axis=k), name="C") + func = create_prim_func([A, B, C]) + print(tvm.script.asscript(func)) + + If we want to use TensorIR schedule to do transformations on such kernel, + we need to use `create_prim_func([A, B, C])` to create a schedulable PrimFunc. + The generated function looks like: + + .. code-block:: python + + @tvm.script.tir + def tir_matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + + with tir.block([128, 128, tir.reduce_axis(0, 128)]) as [i, j, k]: + with tir.init(): + C[i, j] = 0.0 + C[i, j] += A[i, k] * B[j, k] + + Returns + ------- + func : tir.PrimFunc + The created function. + """ + if not isinstance(ops, list): + ops = [ops] + return _ffi_api.CreatePrimFunc(ops) diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc new file mode 100644 index 000000000000..386bc539b924 --- /dev/null +++ b/src/te/operation/create_primfunc.cc @@ -0,0 +1,306 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include + +#include + +#include "../schedule/graph.h" + +namespace tvm { +namespace tir { + +/*! \brief The helper mutator that transforms ProducerLoad to BufferLoad */ +class ProducerToBufferTransformer : public StmtExprMutator { + public: + explicit ProducerToBufferTransformer(const std::unordered_map& tensor2buffers) + : tensor2buffers_(tensor2buffers) {} + + PrimExpr VisitExpr_(const ProducerLoadNode* op) final { + te::Tensor tensor = Downcast(op->producer); + auto it = tensor2buffers_.find(tensor); + ICHECK(it != tensor2buffers_.end()) << "IndexError: Cannot find the tensor " << tensor; + const Buffer& buffer = it->second; + return BufferLoad(buffer, op->indices); + } + + private: + /*! \brief The Map from Operations to buffers */ + const std::unordered_map& tensor2buffers_; +}; + +/*! \brief Helper data structural to store informations. */ +struct CreateFuncInfo { + /*! \brief The Tensor arg_list. */ + Array arg_list; + /*! \brief The map from each Tensor to its corresponding buffer. */ + std::unordered_map tensor2buffers; + /*! \brief The transformer from ProducerLoad to BufferLoad. */ + ProducerToBufferTransformer transformer; + /*! \brief The buffers should be allocated at function root. */ + Array root_alloc; + /*! \brief The count map to make block name unique. */ + std::unordered_map name_count; + + explicit CreateFuncInfo(Array arg_list) + : arg_list(std::move(arg_list)), transformer(tensor2buffers) {} + + bool IsArg(const te::Tensor& tensor) const { + return std::any_of(arg_list.begin(), arg_list.end(), + [&tensor](const te::Tensor& arg) { return tensor == arg; }); + } + + String GetUniqueName(const String& prefix) { + String unique_prefix = prefix; + auto it = name_count.find(prefix); + while (name_count.count(unique_prefix)) { + unique_prefix = prefix + "_" + std::to_string(++it->second); + } + name_count[unique_prefix] = 0; + return unique_prefix; + } +}; + +BlockRealize GenerateBlockFromTensor(const te::ComputeOp& compute_op, const te::Tensor& tensor, + Array bindings, PrimExpr expr_body, + CreateFuncInfo* info) { + // Step 1. Push_back data_par axis and reduce_axis into block_vars. + Array iter_vars; + std::unordered_map var_map; + iter_vars.reserve(compute_op->axis.size() + compute_op->reduce_axis.size()); + auto f_push_block_vars = [&iter_vars, &var_map](const Array& iters) { + for (IterVar iter_var : iters) { + // Create new var + Var new_var(iter_var->var->name_hint, iter_var->var->dtype); + var_map[iter_var->var.get()] = new_var; + + IterVarNode* iter_var_node = iter_var.CopyOnWrite(); + iter_var_node->dom = Range::FromMinExtent(iter_var->dom->min, iter_var->dom->extent); + iter_var_node->var = new_var; + iter_vars.push_back(iter_var); + } + }; + f_push_block_vars(compute_op->axis); + f_push_block_vars(compute_op->reduce_axis); + + // Step 2. Declare buffer and update op2buffers + Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, tensor->GetNameHint()); + info->tensor2buffers[tensor] = buffer; + + // Step 3. Add Buffer to root_alloc + if (!info->IsArg(tensor)) { + info->root_alloc.push_back(buffer); + } + + // Step 4. Calculate indices for BufferStore + Array indices; + indices.reserve(compute_op->axis.size()); + for (const IterVar& iter_var : compute_op->axis) { + auto it = var_map.find(iter_var->var.get()); + ICHECK(it != var_map.end()); + indices.push_back(it->second); + } + + // Step 5. Create block body. + Optional init = NullOpt; + Stmt body; + if (const auto* reduce = expr_body.as()) { + // Case 1. Reduce compute + ICHECK_EQ(reduce->source.size(), 1); + const PrimExpr& lhs = BufferLoad(buffer, indices); + const PrimExpr& rhs = Substitute(info->transformer(reduce->source[0]), var_map); + ICHECK(lhs->dtype == rhs->dtype); + body = BufferStore(buffer, reduce->combiner.get()->operator()({lhs}, {rhs})[0], indices); + init = BufferStore(buffer, reduce->combiner->identity_element[0], indices); + } else { + // Case 2. Data parallel compute + body = BufferStore(buffer, Substitute(info->transformer(expr_body), var_map), indices); + } + + // Step 6. Add script_parsing_detect_access attr for auto complete the whole IR. + Map annotations = compute_op->attrs; + annotations.Set(tir::attr::script_parsing_detect_access, IntImm(DataType::Int(32), 3)); + + // Step 7. Create Block and BlockRealize. + return BlockRealize(/*iter_values=*/std::move(bindings), + /*predicate=*/Bool(true), + /*block=*/ + Block(/*iter_vars=*/std::move(iter_vars), + /*reads=*/{}, + /*writes=*/{}, + /*name_hint=*/info->GetUniqueName(tensor->GetNameHint()), + /*body=*/std::move(body), + /*init=*/std::move(init), + /*alloc_buffers=*/{}, + /*match_buffers=*/{}, + /*annotations=*/std::move(annotations))); +} + +Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* info) { + // Step 1. Creating loop vars for block bindings. + Array axes = compute_op->axis; + axes.insert(axes.end(), compute_op->reduce_axis.begin(), compute_op->reduce_axis.end()); + Array bindings; + for (size_t i = 0; i < axes.size(); ++i) { + bindings.push_back(Var("i" + std::to_string(i))); + } + // Step 2. Generate block bodies. + Array seq_stmt; + for (int i = 0; i < compute_op->num_outputs(); ++i) { + const te::Tensor& tensor = compute_op.output(i); + PrimExpr expr_body = compute_op->body[i]; + seq_stmt.push_back( + GenerateBlockFromTensor(compute_op, tensor, bindings, std::move(expr_body), info)); + } + Stmt body = SeqStmt::Flatten(seq_stmt); + + // Step 3. Generate loop nesting. + for (size_t i = axes.size(); i > 0; --i) { + const IterVar& axis = axes[i - 1]; + const Var& loop_var = Downcast(bindings[i - 1]); + body = For(loop_var, axis->dom->min, axis->dom->extent, ForKind::kSerial, body); + } + + return body; +} + +Stmt GenerateStmtFromExternOp(const te::ExternOp& extern_op, CreateFuncInfo* info) { + // Step 1. Check all inputs are visited before and update var_map. + std::unordered_map var_map; + ICHECK_EQ(extern_op->inputs.size(), extern_op->input_placeholders.size()); + for (size_t i = 0; i < extern_op->inputs.size(); ++i) { + const Buffer& placeholder = extern_op->input_placeholders[i]; + const te::Tensor& input_tensor = extern_op->inputs[i]; + auto it = info->tensor2buffers.find(input_tensor); + ICHECK(it != info->tensor2buffers.end()); + var_map[placeholder->data.get()] = it->second->data; + } + + // Step 2. Update info with its output tensor and placeholder buffer. + ICHECK_EQ(extern_op->num_outputs(), extern_op->output_placeholders.size()); + for (int i = 0; i < extern_op->num_outputs(); ++i) { + const Buffer& placeholder = extern_op->output_placeholders[i]; + const te::Tensor& output_tensor = extern_op.output(i); + info->tensor2buffers[output_tensor] = placeholder; + if (!info->IsArg(output_tensor)) { + info->root_alloc.push_back(placeholder); + } + } + + // Step 3. Collect Access Region + Array reads, writes; + for (const te::Tensor& tensor : extern_op->inputs) { + // We have ICHECK before so it is not needed here. + reads.push_back(BufferRegion::FullRegion(info->tensor2buffers[tensor])); + } + for (const Buffer& buffer : extern_op->output_placeholders) { + writes.push_back(BufferRegion::FullRegion(buffer)); + } + + Stmt body = Substitute(extern_op->body, var_map); + + // Step 4. Generate opaque block as body. + return BlockRealize(/*iter_values=*/{}, + /*predicate=*/Bool(true), + /*block=*/ + Block(/*iter_vars=*/{}, + /*reads=*/std::move(reads), + /*writes=*/std::move(writes), + /*name_hint=*/info->GetUniqueName(extern_op->name), + /*body=*/std::move(body), + /*init=*/NullOpt, + /*alloc_buffers=*/{}, + /*match_buffers=*/{}, + /*annotations=*/extern_op->attrs)); +} + +/*! \brief Use Tensor Expression to create a schedulable TensorIR func. */ +PrimFunc CreatePrimFunc(const Array& arg_list) { + // Step 1. Create tensor read graph. + Array arg_ops; + for (const te::Tensor& arg : arg_list) { + arg_ops.push_back(arg->op); + } + te::ReadGraph g = te::CreateReadGraph(arg_ops); + Array order = te::PostDFSOrder(arg_ops, g); + + // Step 2. Checking all Operations are supported. + for (const te::Operation& op : order) { + if (!(op->IsInstance() || op->IsInstance() || + op->IsInstance())) + LOG(FATAL) << "TypeError: Unsupported Operation: " << op->GetTypeKey() << ". " + << "Only te.placeholder and te.compute are allowed for now."; + } + + // Infomations used in CreatePrimFunc and its sub-funtions. + CreateFuncInfo info(arg_list); + // Root body stmts. + Array root_stmts; + + // Step 3. Rewrite compute stages into blocks. + for (const te::Operation& op : order) { + if (const auto* placeholder = op.as()) { + // Case 1. PlaceholderOp (te.placeholder) + ICHECK_EQ(op->num_outputs(), 1); + const te::Tensor& tensor = op.output(0); + // Check op is in op list + ICHECK(info.IsArg(tensor)); + const Buffer& buffer = decl_buffer(placeholder->shape, placeholder->dtype, placeholder->name); + info.tensor2buffers[tensor] = buffer; + } else if (const auto* compute_op = op.as()) { + // Case 2. ComputeOp (te.compute) + root_stmts.push_back(GenerateStmtFromCompute(GetRef(compute_op), &info)); + } else if (const auto extern_op = op.as()) { + // Case 3. ExternOp (te.extern) + root_stmts.push_back(GenerateStmtFromExternOp(GetRef(extern_op), &info)); + } else { + ICHECK(false) << "TypeError: Unsupported Operation: " << op->GetTypeKey() << ". " + << "Only te.placeholder and te.compute are allowed for now."; + } + } + + // Step 4. Create func and complete it. + Array parameters; + Map buffer_map; + for (const te::Tensor& tensor : arg_list) { + Var arg("var_" + tensor->GetNameHint(), PrimType(DataType::Handle())); + parameters.push_back(arg); + auto it = info.tensor2buffers.find(tensor); + ICHECK(it != info.tensor2buffers.end()); + buffer_map.Set(arg, it->second); + } + PrimFunc func = PrimFunc(/*params=*/std::move(parameters), + /*body=*/SeqStmt::Flatten(root_stmts), + /*ret_type=*/VoidType(), + /*buffer_map=*/std::move(buffer_map)); + + const auto* complete = runtime::Registry::Get("script.Complete"); + ICHECK(complete); + + return (*complete)(func, info.root_alloc); +} // namespace tir + +TVM_REGISTER_GLOBAL("te.CreatePrimFunc").set_body_typed([](const Array& tensors) { + return CreatePrimFunc(tensors); +}); + +} // namespace tir +} // namespace tvm diff --git a/tests/python/unittest/test_te_create_primfunc.py b/tests/python/unittest/test_te_create_primfunc.py new file mode 100644 index 000000000000..b3ef8d5570b6 --- /dev/null +++ b/tests/python/unittest/test_te_create_primfunc.py @@ -0,0 +1,292 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import tvm +from tvm.script import ty +from tvm import te, tir + + +def test_unique_name(): + A = te.placeholder((16, 16), name="A") + B = te.compute((16, 16), lambda x, y: A[x, y] * 2, name="main") + C = te.compute((16, 16), lambda x, y: B[x, y] + 1, name="main") + func = te.create_prim_func([A, C]) + s = tir.Schedule(func, debug_mode=True) + assert isinstance(s.get_sref(s.get_block("main")), tir.schedule.StmtSRef) + assert isinstance(s.get_sref(s.get_block("main_1")), tir.schedule.StmtSRef) + + +def _check_workload(te_workload, tir_workload): + func = te.create_prim_func(te_workload()) + tvm.ir.assert_structural_equal(func, tir_workload) + # make sure that we can create schedule from the func + s = tir.Schedule(func, debug_mode=True) + assert s + + +def te_matmul(): + k = te.reduce_axis((0, 128), "k") + A = te.placeholder((128, 128), name="A") + B = te.placeholder((128, 128), name="B") + C = te.compute((128, 128), lambda x, y: te.sum(A[x, k] * B[y, k], axis=k), name="C") + return [A, B, C] + + +@tvm.script.tir +def tir_matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + + with tir.block([128, 128, tir.reduce_axis(0, 128)]) as [i, j, k]: + with tir.init(): + C[i, j] = 0.0 + C[i, j] += A[i, k] * B[j, k] + + +def test_matmul(): + _check_workload(te_matmul, tir_matmul) + + +def te_element_wise(): + A = te.placeholder((128, 128), name="A") + B = te.compute((128, 128), lambda x, y: A[x, y] * 2, name="B") + C = te.compute((128, 128), lambda x, y: B[x, y] + 1, name="C") + return [A, C] + + +@tvm.script.tir +def tir_element_wise(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + B = tir.alloc_buffer((128, 128)) + + with tir.block([128, 128]) as [i, j]: + B[i, j] = A[i, j] * 2.0 + with tir.block([128, 128]) as [i, j]: + C[i, j] = B[i, j] + 1.0 + + +def test_element_wise(): + _check_workload(te_element_wise, tir_element_wise) + + +def te_conv2d(): + batch = 16 + in_channel = 16 + out_channel = 32 + size = 14 + kernel = 3 + + A = te.placeholder((batch, in_channel, size, size), name="A") + W = te.placeholder((in_channel, kernel, kernel, out_channel), name="W") + Apad = te.compute( + (batch, in_channel, size + 2, size + 2), + lambda nn, cc, yy, xx: tvm.tir.if_then_else( + tvm.tir.all(yy >= 1, yy - 1 < size, xx >= 1, xx - 1 < size), + A[nn, cc, yy - 1, xx - 1], + 0.0, + ), + name="Apad", + ) + rc = te.reduce_axis((0, in_channel), name="rc") + ry = te.reduce_axis((0, kernel), name="ry") + rx = te.reduce_axis((0, kernel), name="rx") + B = te.compute( + (batch, out_channel, size, size), + lambda nn, ff, yy, xx: te.sum( + Apad[nn, rc, yy + ry, xx + rx] * W[rc, ry, rx, ff], axis=[rc, ry, rx] + ), + name="B", + ) + return [A, W, B] + + +@tvm.script.tir +def tir_conv2d(a: ty.handle, w: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [16, 16, 14, 14]) + W = tir.match_buffer(w, [16, 3, 3, 32]) + B = tir.match_buffer(b, [16, 32, 14, 14]) + Apad = tir.alloc_buffer([16, 16, 16, 16]) + + with tir.block([16, 16, 16, 16], "Apad") as [nn, cc, yy, xx]: + Apad[nn, cc, yy, xx] = tir.if_then_else( + yy >= 1 and yy - 1 < 14 and xx >= 1 and xx - 1 < 14, + A[nn, cc, yy - 1, xx - 1], + 0.0, + dtype="float32", + ) + with tir.block( + [16, 32, 14, 14, tir.reduce_axis(0, 16), tir.reduce_axis(0, 3), tir.reduce_axis(0, 3)], "B" + ) as [nn, ff, yy, xx, rc, ry, rx]: + with tir.init(): + B[nn, ff, yy, xx] = 0.0 + B[nn, ff, yy, xx] += Apad[nn, rc, yy + ry, xx + rx] * W[rc, ry, rx, ff] + + +def test_conv2d(): + _check_workload(te_conv2d, tir_conv2d) + + +def te_multi_output(): + n = te.var("n") + m = te.var("m") + A0 = te.placeholder((m, n), name="A0") + A1 = te.placeholder((m, n), name="A1") + B0, B1 = te.compute((m, n), lambda i, j: (A0[i, j] + 2, A1[i, j] * 3), name="B") + return [A0, A1, B0, B1] + + +@tvm.script.tir +def tir_multi_output(a0: ty.handle, a1: ty.handle, b0: ty.handle, b1: ty.handle) -> None: + m = tir.var("int32") + n = tir.var("int32") + A0 = tir.match_buffer(a0, (m, n)) + A1 = tir.match_buffer(a1, (m, n)) + B0 = tir.match_buffer(b0, (m, n)) + B1 = tir.match_buffer(b1, (m, n)) + + for i0, i1 in tir.grid(m, n): + with tir.block([m, n], "B.v0") as [i, j]: + B0[i, j] = A0[i, j] + 2.0 + with tir.block([m, n], "B.v1") as [i, j]: + B1[i, j] = A1[i, j] * 3.0 + + +def test_multi_output(): + _check_workload(te_multi_output, tir_multi_output) + + +def te_extern(): + A = te.placeholder((128, 128), name="A") + B = te.placeholder((128, 128), name="B") + C = te.extern( + (128, 128), + [A, B], + lambda ins, outs: tvm.tir.call_packed( + "tvm.contrib.cblas.matmul", ins[0], ins[1], outs[0], 0, 0 + ), + name="C", + ) + return [A, B, C] + + +@tvm.script.tir +def tir_extern(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + # body + with tir.block([], "C"): + tir.reads([A[0:128, 0:128], B[0:128, 0:128]]) + tir.writes([C[0:128, 0:128]]) + tir.evaluate( + tir.tvm_call_packed( + "tvm.contrib.cblas.matmul", + tir.tvm_stack_make_array( + A.data, + tir.tvm_stack_make_shape(128, 128, dtype="handle"), + 0, + 2, + 0.0, + 0, + dtype="handle", + ), + tir.tvm_stack_make_array( + B.data, + tir.tvm_stack_make_shape(128, 128, dtype="handle"), + 0, + 2, + 0.0, + 0, + dtype="handle", + ), + tir.tvm_stack_make_array( + C.data, + tir.tvm_stack_make_shape(128, 128, dtype="handle"), + 0, + 2, + 0.0, + 0, + dtype="handle", + ), + 0, + 0, + dtype="int32", + ) + ) + + +def test_extern(): + _check_workload(te_extern, tir_extern) + + +def te_reordered_matmul(): + k = te.reduce_axis((0, 128), "k") + A = te.placeholder((128, 128), name="A") + B = te.placeholder((128, 128), name="B") + C = te.compute((128, 128), lambda x, y: te.sum(A[x, k] * B[y, k], axis=k), name="C") + return [C, A, B] + + +@tvm.script.tir +def tir_reordered_matmul(c: ty.handle, a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + + with tir.block([128, 128, tir.reduce_axis(0, 128)]) as [i, j, k]: + with tir.init(): + C[i, j] = 0.0 + C[i, j] += A[i, k] * B[j, k] + + +def test_arg_order(): + _check_workload(te_reordered_matmul, tir_reordered_matmul) + + +def te_scan(): + m = te.var("m") + n = te.var("n") + X = te.placeholder((m, n), name="X") + s_state = te.placeholder((m, n)) + s_init = te.compute((1, n), lambda _, i: X[0, i]) + s_update = te.compute((m, n), lambda t, i: s_state[t - 1, i] + X[t, i]) + s_scan = tvm.te.scan(s_init, s_update, s_state, inputs=[X]) + return [X, s_scan] + + +def test_error_reporting(): + try: + te.create_prim_func(te_scan()) + assert False + except TypeError as e: + error_message = str(e) + assert error_message.find("Unsupported Operation: ScanOp.") != -1 + return + assert False + + +if __name__ == "__main__": + test_unique_name() + test_matmul() + test_element_wise() + test_conv2d() + test_multi_output() + test_extern() + test_arg_order() + test_error_reporting()