From 834e975cfe29d44627aedd9d30f720bdd5881d0b Mon Sep 17 00:00:00 2001 From: Florin Blanaru Date: Fri, 2 Dec 2022 14:49:10 +0200 Subject: [PATCH] [Relax][AOT] Add pass that lowers relax.builtin.alloc_tensor to relax.memory.alloc_{storage,tensor} when USMP is not enabled --- CMakeLists.txt | 1 + include/tvm/relax/backend.h | 8 ++ python/tvm/relax/aot.py | 4 +- python/tvm/relax/transform/transform.py | 11 ++ src/relax/backend/aot/aot_lower_main.cc | 4 +- src/relax/backend/aot/aot_memory_lower.cc | 116 ++++++++++++++++++ src/relax/backend/aot/codegen_aot.cc | 13 +- src/target/llvm/codegen_llvm.cc | 2 +- tests/python/relax/aot/test_aot_build.py | 47 ++++--- .../relax/aot/test_pass_aot_lower_main.py | 36 ++++-- tests/python/relax/test_transform.py | 35 ++++++ 11 files changed, 246 insertions(+), 31 deletions(-) create mode 100644 src/relax/backend/aot/aot_memory_lower.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index e5bfcb326541..04f3f2f5aee1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -291,6 +291,7 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS src/relax/analysis/*.cc src/relax/usmp/*.cc src/relax/transform/*.cc + src/relax/backend/aot/*.cc src/relax/backend/vm/*.cc src/relax/backend/aot/*.cc src/relax/backend/task_extraction.cc diff --git a/include/tvm/relax/backend.h b/include/tvm/relax/backend.h index 596905ae9d90..85df333137a8 100644 --- a/include/tvm/relax/backend.h +++ b/include/tvm/relax/backend.h @@ -44,6 +44,14 @@ TVM_DLL Pass VMMemoryLower(); */ TVM_DLL Pass VMShapeLower(); +/*! + * \brief Perform memory lowering in AOT. Lowers the relax.builtin.alloc_tensor intrinsic to + * relax.memory.* intrinsics. + * + * \return The Pass. + */ +TVM_DLL Pass AOTMemoryLower(); + } // namespace transform } // namespace relax } // namespace tvm diff --git a/python/tvm/relax/aot.py b/python/tvm/relax/aot.py index 2109d6d74508..0302735ee7b3 100644 --- a/python/tvm/relax/aot.py +++ b/python/tvm/relax/aot.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=invalid-name, redefined-builtin, no-else-return -"""The Relax virtual machine""" +"""The Relax AOT executor""" from typing import Callable, List, Optional, Union, Dict import tvm @@ -63,7 +63,7 @@ def build( if not isinstance(ir_mod, IRModule): raise ValueError("Type of input parameter mod must be tvm.IRModule") - ctxt = tvm.transform.PassContext() + ctxt = tvm.transform.PassContext.current() config = make_compilation_config(ctxt, target, target_host) ir_mod = lower(ir_mod) diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 702548b9e3c2..3f3eb276997b 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -122,6 +122,17 @@ def VMShapeLower() -> tvm.ir.transform.Pass: return _ffi_api.VMShapeLower() # type: ignore +def AOTMemoryLower() -> tvm.ir.transform.Pass: + """Perform memory lowering in AOT. Lowers the relax.builtin.alloc_tensor intrinsic to + relax.memory.* intrinsics. + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.AOTMemoryLower() # type: ignore + + def Normalize() -> tvm.ir.transform.Pass: """Transforming Relax IR to normal form, i.e., the expressions are normalized(no nesting and hence the AST is in ANF), and all checked_type_ and shape_ of expressions are available. 
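Note for reviewers: a minimal usage sketch of the new pass from the Python side, mirroring the test added in tests/python/relax/test_transform.py further down in this patch (the module name `Example` and the `test.op.identity` packed call are illustrative only, not part of the change):

    import tvm
    from tvm import relax
    from tvm.script import relax as R
    from tvm.script import tir as T


    @tvm.script.ir_module
    class Example:
        @R.function
        def foo(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor:
            m, n = T.var("int64"), T.var("int64")
            # The builtin alloc that AOTMemoryLower is expected to rewrite.
            alloc = R.builtin.alloc_tensor((m, n), runtime_device_index=0, dtype="float32")
            _ = R.call_packed("test.op.identity", x, alloc, type_args=(R.Tensor(ndim=2, dtype="float32")))
            gv0 = alloc
            return gv0


    # The pass rewrites the builtin alloc into two bindings:
    #   relax.memory.alloc_storage + relax.memory.alloc_tensor
    lowered = relax.transform.AOTMemoryLower()(Example)
    print(lowered["foo"])

The end-to-end flow is analogous: wrapping the AOT build call in a PassContext with config={"relax.usmp.enable": False} routes compilation through AOTMemoryLower instead of the UnifiedStaticMemoryPlanner, as exercised by the parametrized tests in tests/python/relax/aot/test_aot_build.py.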
diff --git a/src/relax/backend/aot/aot_lower_main.cc b/src/relax/backend/aot/aot_lower_main.cc
index 800a5f1dd676..75dd28a742d9 100644
--- a/src/relax/backend/aot/aot_lower_main.cc
+++ b/src/relax/backend/aot/aot_lower_main.cc
@@ -58,7 +58,7 @@ class AOTMainLowerer : public ExprVisitor {
   IRModule Lower(IRModule mod, String mod_name) {
     IRModule lowered_mod = GetRef<IRModule>(mod.CopyOnWrite());
-    
+
     auto lowered_main = lowered_mod->Lookup("main");
     auto lowered_main_func = GetRef<Function>(lowered_main.as<FunctionNode>());
@@ -76,7 +76,7 @@ class AOTMainLowerer : public ExprVisitor {
            .value_or(Map()));
     VisitExpr(lowered_main_func);
-    
+
     // Remove the Relay main and replace it with the lowered TIR version
     mod->Remove(lowered_mod->GetGlobalVar("main"));
     auto tir_main_func = CreateMainFunc(mod_name);
diff --git a/src/relax/backend/aot/aot_memory_lower.cc b/src/relax/backend/aot/aot_memory_lower.cc
new file mode 100644
index 000000000000..6995d783fbab
--- /dev/null
+++ b/src/relax/backend/aot/aot_memory_lower.cc
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/relax/backend/aot/aot_memory_lower.cc
+ * \brief Perform memory lowering. Lowers the relax.builtin.alloc_tensor intrinsic to
+ * relax.memory.alloc_storage + relax.memory.alloc_tensor.
+ */
+#include <tvm/relax/attrs/memory.h>
+#include <tvm/relax/backend.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/type.h>
+#include <tvm/tir/op.h>
+
+#include "../../../relay/transforms/pattern_utils.h"
+
+namespace tvm {
+namespace relax {
+
+// ==================
+// MemLowerMutator
+// Lower the relax.builtin.alloc_tensor op to relax.memory builtin functions.
+// Example:
+// x = relax.builtin.alloc_tensor((m, n), relax.attrs.AllocTensorAttrs)
+// -->
+// gv0 = relax.memory.alloc_storage(m * n * dtype, relax.attrs.MemAllocStorageAttrs)
+// gv1 = relax.memory.alloc_tensor(gv0, (m, n), relax.attrs.MemAllocTensorAttrs)
+
+class AOTMemLowerMutator : public ExprMutator {
+
+  // TODO(gigiblender): Dedup this function with the one in VMMemoryLower.
+  Expr ComputeStorageSize(const Expr& shape, const DataType& dtype) const {
+    // Question: what if the dtype of tensor_type is unknown?
+    // Symbolic/static shape case
+    if (auto* shape_expr = shape.as<ShapeExprNode>()) {
+      PrimExpr num = PrimExpr(dtype.bits()) * PrimExpr(dtype.lanes());
+      PrimExpr add = num + 7;
+      PrimExpr ret = 1;
+      for (PrimExpr dim : shape_expr->values) {
+        ret = ret * dim;
+      }
+      ret = ret * (add / PrimExpr(8));
+      return ShapeExpr({ret});
+    }
+    // Fully dynamic shape case will need to dedup with ComputeStorageInRelay when we upstream
+    Expr prod = relay::Prod(shape, Array<Integer>(nullptr), false, false);
+    Expr num = relay::MakeConstantScalar(DataType::Int(64), dtype.bits() * dtype.lanes());
+    Expr add = relay::Add(num, relay::MakeConstantScalar(DataType::Int(64), 7));
+    Expr div = relay::MakeConstantScalar(DataType::Int(64), 8);
+    Expr ret = relay::Multiply(prod, relay::Divide(add, div));
+    return ret;
+  }
+
+  using ExprMutator::VisitExpr_;
+
+  Expr VisitExpr_(const CallNode* call) override {
+    // post-order mutation
+    Expr expr = VisitExprPostOrder_(call);
+    call = expr.as<CallNode>();
+
+    static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor");
+    static const Op& memory_alloc_storage_op = Op::Get("relax.memory.alloc_storage");
+    static const Op& memory_alloc_tensor_op = Op::Get("relax.memory.alloc_tensor");
+    if (call->op == alloc_tensor_op) {
+      ShapeExpr output_shape = Downcast<ShapeExpr>(call->args[0]);
+      auto alloc_attrs = call->attrs.as<AllocTensorAttrs>();
+      ICHECK(alloc_attrs != nullptr) << "must be AllocTensorAttrs";
+      DataType dtype = alloc_attrs->dtype;
+      Expr storage_size = ComputeStorageSize(output_shape, dtype);
+      auto storage_attr = make_object<MemAllocStorageAttrs>();
+      storage_attr->dtype = dtype;
+
+      Var storage =
+          builder_->Emit(Call(memory_alloc_storage_op, {storage_size}, Attrs(storage_attr)),
+                         "storage");
+      auto tensor_attr = make_object<MemAllocTensorAttrs>();
+      tensor_attr->offset = 0;
+      tensor_attr->dtype = dtype;
+      Expr shape = call->args[0];
+      return Call(memory_alloc_tensor_op, {storage, shape}, Attrs(tensor_attr));
+    }
+
+    return GetRef<Expr>(call);
+  }
+};
+
+Expr AOTMemoryLower(const Expr& e) { return AOTMemLowerMutator().VisitExpr(e); }
+
+namespace transform {
+
+Pass AOTMemoryLower() {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+      [=](Function f, IRModule m, PassContext pc) { return Downcast<Function>(AOTMemoryLower(f)); };
+  return CreateFunctionPass(pass_func, 0, "AOTMemoryLower", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.AOTMemoryLower").set_body_typed(AOTMemoryLower);
+
+}  // namespace transform
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/backend/aot/codegen_aot.cc b/src/relax/backend/aot/codegen_aot.cc
index 5a3da9756194..12b767d25995 100644
--- a/src/relax/backend/aot/codegen_aot.cc
+++ b/src/relax/backend/aot/codegen_aot.cc
@@ -30,6 +30,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -39,6 +40,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
@@ -64,8 +66,15 @@ runtime::Module Build(IRModule mod, String mod_name, CompilationConfig config, r
   Integer constant_byte_alignment =
       executor->GetAttr<Integer>("constant-byte-alignment").value_or(16);
 
+  transform::PassContext pass_ctx = transform::PassContext::Current();
+  bool enable_usmp = pass_ctx->GetConfig<Bool>(kUSMPRelaxEnableOption, Bool(false)).value();
+
   mod = LowerModule(mod);
-  mod = relax::transform::UnifiedStaticMemoryPlanner()(mod);
+  if (enable_usmp) {
+    mod = relax::transform::UnifiedStaticMemoryPlanner()(mod);
+  } else {
+    mod = relax::transform::AOTMemoryLower()(mod);
+  }
 
   mod = AOTLowerMain(mod_name, config)(mod);
   mod = tir::transform::LegalizePackedCalls()(mod);
@@ -85,4 +94,4 @@ TVM_REGISTER_GLOBAL("relax.aot.build")
 
 }  // namespace aot
 }  // namespace relax
-}
// namespace tvm \ No newline at end of file +} // namespace tvm diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index aa25b1436737..3e0664e64492 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1784,7 +1784,7 @@ void CodeGenLLVM::VisitStmt_(const AllocateConstNode* op) { llvm::GlobalVariable* param_symbol = new llvm::GlobalVariable( *module_, array->getType(), true, llvm::GlobalValue::InternalLinkage, array, symbol_name); - param_symbol->setAlignment(data.DataType().bits()); + param_symbol->setAlignment(llvm::Align(data.DataType().bits())); var_map_[op->buffer_var.operator->()] = param_symbol; this->VisitStmt(op->body); } diff --git a/tests/python/relax/aot/test_aot_build.py b/tests/python/relax/aot/test_aot_build.py index a84a783a113d..9bb811ac1d01 100644 --- a/tests/python/relax/aot/test_aot_build.py +++ b/tests/python/relax/aot/test_aot_build.py @@ -35,7 +35,8 @@ def _export_mod(mod): return tvm.runtime.load_module(test_so_path) -def test_single_elementwise(): +@pytest.mark.parametrize("enable_usmp", [True, False]) +def test_single_elementwise(enable_usmp): dtype = "int32" target = "llvm" inputs = {"x": np.array([[-10, 5], [1, 2]], dtype=dtype)} @@ -48,13 +49,14 @@ def _relay(): def _reference(inputs): x = inputs["x"] return np.abs(x) # abs - + relax_mod = relay_translator.from_relay( _relay(), target, ) - mod = build(relax_mod, target) + with tvm.transform.PassContext(config={"relax.usmp.enable": enable_usmp}): + mod = build(relax_mod, target) loaded_mod = _export_mod(mod) runner = tvm.runtime.executor.AotModule(loaded_mod["default"](tvm.cpu(0))) runner.set_input(**inputs) @@ -62,7 +64,8 @@ def _reference(inputs): assert (runner.get_output(0).numpy() == _reference(inputs)).all() -def test_scalar_constant(): +@pytest.mark.parametrize("enable_usmp", [True, False]) +def test_scalar_constant(enable_usmp): dtype = "int32" target = "llvm" inputs = {"x": np.array([[-10, 5], [1, 2]], dtype=dtype)} @@ -75,13 +78,14 @@ def _relay(): def _reference(inputs): x = inputs["x"] return np.add(x, -1) # add - + relax_mod = relay_translator.from_relay( _relay(), target, ) - mod = build(relax_mod, target) + with tvm.transform.PassContext(config={"relax.usmp.enable": enable_usmp}): + mod = build(relax_mod, target) loaded_mod = _export_mod(mod) runner = tvm.runtime.executor.AotModule(loaded_mod["default"](tvm.cpu(0))) runner.set_input(**inputs) @@ -89,7 +93,8 @@ def _reference(inputs): assert (runner.get_output(0).numpy() == _reference(inputs)).all() -def test_tensor_constant(): +@pytest.mark.parametrize("enable_usmp", [True, False]) +def test_tensor_constant(enable_usmp): dtype = "int32" target = "llvm" inputs = {"x": np.array([[-10, 1], [5, 1]], dtype=dtype)} @@ -102,13 +107,14 @@ def _relay(): def _reference(inputs): x = inputs["x"] return np.add(x, np.array([[1, 2], [3, 4]])) # add - + relax_mod = relay_translator.from_relay( _relay(), target, ) - mod = build(relax_mod, target) + with tvm.transform.PassContext(config={"relax.usmp.enable": enable_usmp}): + mod = build(relax_mod, target) loaded_mod = _export_mod(mod) runner = tvm.runtime.executor.AotModule(loaded_mod["default"](tvm.cpu(0))) runner.set_input(**inputs) @@ -116,10 +122,14 @@ def _reference(inputs): assert (runner.get_output(0).numpy() == _reference(inputs)).all() -def test_multi_input(): +@pytest.mark.parametrize("enable_usmp", [True, False]) +def test_multi_input(enable_usmp): dtype = "int32" target = "llvm" - inputs = {"x": np.array([[-10, 1], [5, 1]], dtype=dtype), "y": 
np.array([[1, 2], [3, 4]], dtype=dtype)} + inputs = { + "x": np.array([[-10, 1], [5, 1]], dtype=dtype), + "y": np.array([[1, 2], [3, 4]], dtype=dtype), + } def _relay(): x = relay.var("x", shape=(2, 2), dtype=dtype) @@ -131,13 +141,14 @@ def _reference(inputs): x = inputs["x"] y = inputs["y"] return np.add(x, y) # add - + relax_mod = relay_translator.from_relay( _relay(), target, ) - mod = build(relax_mod, target) + with tvm.transform.PassContext(config={"relax.usmp.enable": enable_usmp}): + mod = build(relax_mod, target) loaded_mod = _export_mod(mod) runner = tvm.runtime.executor.AotModule(loaded_mod["default"](tvm.cpu(0))) runner.set_input(**inputs) @@ -145,7 +156,8 @@ def _reference(inputs): assert (runner.get_output(0).numpy() == _reference(inputs)).all() -def test_multi_output(): +@pytest.mark.parametrize("enable_usmp", [True, False]) +def test_multi_output(enable_usmp): dtype = "int32" target = "llvm" inputs = {"x": np.array([[-10, 1], [5, 1]], dtype=dtype)} @@ -159,16 +171,17 @@ def _relay(): def _reference(inputs): x = inputs["x"] - abs = np.abs(x) # abs + abs = np.abs(x) # abs out = abs - 1 return [abs, out] - + relax_mod = relay_translator.from_relay( _relay(), target, ) - mod = build(relax_mod, target) + with tvm.transform.PassContext(config={"relax.usmp.enable": enable_usmp}): + mod = build(relax_mod, target) loaded_mod = _export_mod(mod) runner = tvm.runtime.executor.AotModule(loaded_mod["default"](tvm.cpu(0))) runner.set_input(**inputs) diff --git a/tests/python/relax/aot/test_pass_aot_lower_main.py b/tests/python/relax/aot/test_pass_aot_lower_main.py index 7a0d5380f072..c0a664f7921f 100644 --- a/tests/python/relax/aot/test_pass_aot_lower_main.py +++ b/tests/python/relax/aot/test_pass_aot_lower_main.py @@ -124,7 +124,11 @@ def test_multi_input(): @tvm.script.ir_module class MultiInput: @R.function - def main(a: R.Tensor((5, 7), "float32"), b: R.Tensor((5, 7), "float32"), output: R.Tensor((5, 7), "float32")): + def main( + a: R.Tensor((5, 7), "float32"), + b: R.Tensor((5, 7), "float32"), + output: R.Tensor((5, 7), "float32"), + ): R.func_attr({"input_vars": [a, b], "output_vars": [output]}) tid_0 = output _ = R.call_packed("add", a, b, tid_0, type_args=R.Tensor(ndim=2, dtype="float32")) @@ -149,11 +153,17 @@ def test_multi_output(): @tvm.script.ir_module class MultiOutput: @R.function - def main(a: R.Tensor((5, 7), "float32"), output_0: R.Tensor((5, 7), "float32"), output_1: R.Tensor((5, 7), "float32")): + def main( + a: R.Tensor((5, 7), "float32"), + output_0: R.Tensor((5, 7), "float32"), + output_1: R.Tensor((5, 7), "float32"), + ): R.func_attr({"input_vars": [a], "output_vars": [output_0, output_1]}) tid_0 = output_0 tid_1 = output_1 - _ = R.call_packed("duplicate", a, tid_0, tid_1, type_args=R.Tensor(ndim=2, dtype="float32")) + _ = R.call_packed( + "duplicate", a, tid_0, tid_1, type_args=R.Tensor(ndim=2, dtype="float32") + ) return () # fmt: off @@ -202,7 +212,9 @@ class TupleGetItem: def main(a: R.Tensor((5, 7), "float32"), output: R.Tensor((5, 7), "float32")): R.func_attr({"input_vars": [a], "output_vars": [output]}) tup = (a, a) - _ = R.call_packed("identity", tup[1], output, type_args=R.Tensor(ndim=2, dtype="float32")) + _ = R.call_packed( + "identity", tup[1], output, type_args=R.Tensor(ndim=2, dtype="float32") + ) return () # fmt: off @@ -235,7 +247,9 @@ def main(a: R.Tensor((5, 7), "float32"), output: R.Tensor((5, 7), "float32")): tid_2 = R.memory.alloc_tensor(alloc_2, (5, 7), offset=0, dtype="float32") _ = R.call_packed("identity", tid_0, tid_2, 
type_args=R.Tensor(ndim=2, dtype="float32")) tid_3 = output - _ = R.call_packed("add", tid_1, tid_2, tid_3, type_args=R.Tensor(ndim=2, dtype="float32")) + _ = R.call_packed( + "add", tid_1, tid_2, tid_3, type_args=R.Tensor(ndim=2, dtype="float32") + ) return () # fmt: off @@ -268,9 +282,17 @@ def test_device_hooks(): @tvm.script.ir_module class DeviceHooks: @T.prim_func - def identity(a: T.handle, output: T.handle, device_context_example_target_hook: T.handle) -> None: + def identity( + a: T.handle, output: T.handle, device_context_example_target_hook: T.handle + ) -> None: # function attr dict - T.func_attr({"global_symbol": "identity", "runner_function": True, "target": T.target({"kind":"llvm", "tag":"", "keys":["cpu"]})}) + T.func_attr( + { + "global_symbol": "identity", + "runner_function": True, + "target": T.target({"kind": "llvm", "tag": "", "keys": ["cpu"]}), + } + ) a_buffer = T.match_buffer(a, [5, 7], dtype="float32", align=16) output_buffer = T.match_buffer(output, [5, 7], dtype="float32", align=16) # body diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py index 86f85a974042..a17c6b3237dd 100644 --- a/tests/python/relax/test_transform.py +++ b/tests/python/relax/test_transform.py @@ -322,6 +322,41 @@ def foo(x: R.Tensor(("m", "n"), "float32")): assert s2.op.global_symbol == "test.op.identity" +def test_aot_memory_lower(): + # fmt:off + @tvm.script.ir_module + class TestAOTMemoryLower: + @R.function + def foo(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor: + m, n = T.var("int64"), T.var("int64") + alloc = R.builtin.alloc_tensor((m, n), runtime_device_index=0, dtype="float32") + _ = R.call_packed("test.op.identity", x, alloc, type_args=(R.Tensor(ndim=2, dtype="float32"))) + gv0 = alloc + return gv0 + # fmt:on + + mod = TestAOTMemoryLower + + # after aot memory lowering + new_mod = relax.transform.AOTMemoryLower()(mod) + func = new_mod["foo"] + + assert isinstance(new_mod, tvm.IRModule) + assert isinstance(func, tvm.relax.expr.Function) + + block = func.body.blocks[0] + s1 = block.bindings[0].value + assert isinstance(s1, tvm.relay.Call) + assert s1.op.name == "relax.memory.alloc_storage" + s2 = block.bindings[1].value + assert isinstance(s2, tvm.relay.Call) + assert s2.op.name == "relax.memory.alloc_tensor" + s3 = block.bindings[2].value + assert isinstance(s3, tvm.relay.Call) + assert isinstance(s3.op, relax.ExternFunc) + assert s3.op.global_symbol == "test.op.identity" + + def test_vm_memory_lower(): @tvm.script.ir_module class TestVMMemoryLower: