diff --git a/python/tvm/script/builder/tir/__init__.py b/python/tvm/script/builder/tir/__init__.py index 2fdee6a46d10..d74b73341758 100644 --- a/python/tvm/script/builder/tir/__init__.py +++ b/python/tvm/script/builder/tir/__init__.py @@ -29,6 +29,14 @@ unroll, vectorized, ) -from .prim_func_frame import arg, func_attr, func_ret, prim_func, match_buffer, preflattened_buffer -from .var import Buffer from .op import * +from .prim_func_frame import ( + arg, + func_attr, + func_name, + func_ret, + match_buffer, + preflattened_buffer, + prim_func, +) +from .var import Buffer diff --git a/python/tvm/script/builder/tir/prim_func_frame.py b/python/tvm/script/builder/tir/prim_func_frame.py index 59b5ce251743..386dc974d2cd 100644 --- a/python/tvm/script/builder/tir/prim_func_frame.py +++ b/python/tvm/script/builder/tir/prim_func_frame.py @@ -15,14 +15,13 @@ # specific language governing permissions and limitations # under the License. """TVM Script TIR Prim Func Frame""" -from typing import Union, Dict, Any +from typing import Any, Callable, Dict, Optional, Union from tvm._ffi import register_object as _register_object +from tvm.ir import Type from tvm.tir.buffer import Buffer from tvm.tir.expr import Var -from tvm.ir import Type -from ..builder import Builder from . import _ffi_api from .base import TIRFrame @@ -32,15 +31,23 @@ class PrimFuncFrame(TIRFrame): ... -def prim_func(name) -> PrimFuncFrame: - return _ffi_api.PrimFuncFrame(name) # pylint: disable=no-member # type: ignore +def prim_func(f: Optional[Callable] = None) -> PrimFuncFrame: + if f is not None: + from tvm.script.parse import parse # pylint: disable=import-outside-toplevel + + return parse(f) + return _ffi_api.PrimFuncFrame() # pylint: disable=no-member # type: ignore + + +setattr(prim_func, "dispatch_token", "tir") def arg(name, obj) -> Union[Var, Buffer]: return _ffi_api.Arg(name, obj) # pylint: disable=no-member # type: ignore -setattr(prim_func, "dispatch_token", "tir") +def func_name(name) -> str: + return _ffi_api.FuncName(name) # pylint: disable=no-member # type: ignore def func_attr(attrs: Dict[str, Any]) -> None: @@ -65,7 +72,7 @@ def match_buffer( axis_separators=None, span=None, ) -> Buffer: - return _ffi_api.MatchBuffer( + return _ffi_api.MatchBuffer( # pylint: disable=no-member # type: ignore param, shape, dtype, @@ -95,7 +102,7 @@ def preflattened_buffer( axis_separators=None, span=None, ) -> None: - _ffi_api.PreflattenedBuffer( + _ffi_api.PreflattenedBuffer( # pylint: disable=no-member # type: ignore postflattened, shape, dtype, diff --git a/python/tvm/script/parse/entry.py b/python/tvm/script/parse/entry.py index b4f756b0582c..a6a3e114ce85 100644 --- a/python/tvm/script/parse/entry.py +++ b/python/tvm/script/parse/entry.py @@ -40,7 +40,6 @@ def __init__(self, program: Union[str, doc.AST]): else: self.source_name = inspect.getsourcefile(program) # type: ignore lines, self.start_line = inspect.getsourcelines(program) # type: ignore - if lines: self.start_column = len(lines[0]) - len(lines[0].lstrip()) else: @@ -48,7 +47,7 @@ def __init__(self, program: Union[str, doc.AST]): if self.start_column and lines: self.source = "\n".join([l[self.start_column :].rstrip() for l in lines]) else: - self.source = "" + self.source = "".join(lines) try: # It will cause a problem when running in Jupyter Notebook. # `mod` will be , which is a built-in module @@ -69,16 +68,16 @@ def as_ast(self) -> doc.AST: return doc.parse(self.source) -def parse( - program: Union[doc.AST, Any, str], - extra_vars: Optional[Dict[str, Any]] = None, -): +def parse(program: Union[doc.AST, Any, str]): + # TODO: `extra_vars` is a hack + from tvm.script.builder import tir as T + + extra_vars = {"T": T} program_ast = SourceCode(program).as_ast() parser = Parser() with Builder() as builder: with parser.var_table.with_frame(): - if extra_vars: - for k, v in extra_vars.items(): - parser.var_table.add(k, v) + for k, v in extra_vars.items(): + parser.var_table.add(k, v) parser.visit(program_ast) return builder.get() diff --git a/python/tvm/script/parse/tir/tir.py b/python/tvm/script/parse/tir/tir.py index 202f51614a9b..f7893e124782 100644 --- a/python/tvm/script/parse/tir/tir.py +++ b/python/tvm/script/parse/tir/tir.py @@ -70,7 +70,8 @@ def visit_with(self: Parser, node: doc.With) -> None: def visit_function_def(self: Parser, node: doc.FunctionDef) -> None: with self.var_table.with_frame(): self.var_table.add("range", T.serial) - with T.prim_func(node.name): + with T.prim_func(): + T.func_name(node.name) with self.with_dispatch_token("tir"): # TODO: define the GlobalVar, handle the return value self.visit(node.args) diff --git a/src/script/builder/tir/prim_func_frame.cc b/src/script/builder/tir/prim_func_frame.cc index d4acc151c571..ecc6f97d663e 100644 --- a/src/script/builder/tir/prim_func_frame.cc +++ b/src/script/builder/tir/prim_func_frame.cc @@ -50,9 +50,9 @@ void PrimFuncFrameNode::ExitWithScope() { } } -PrimFuncFrame PrimFunc_(String name) { +PrimFuncFrame PrimFunc_() { ObjectPtr n = make_object(); - n->name = name; + n->name = ""; n->args.clear(); n->ret_type = TupleType::Empty(); n->buffer_map.clear(); @@ -78,6 +78,11 @@ tvm::tir::Buffer Arg(String name, tvm::tir::Buffer buffer) { return buffer; } +void FuncName(String name) { + PrimFuncFrame frame = Builder::Current()->FindFrame().value(); + frame->name = name; +} + void FuncAttrs(Map attrs) { using namespace tvm::tir; PrimFuncFrame frame = Builder::Current()->FindFrame().value(); @@ -165,6 +170,7 @@ TVM_REGISTER_GLOBAL("script.builder.tir.Arg") LOG(FATAL) << "ValueError: Unexpected type for TIR Arg."; throw; }); +TVM_REGISTER_GLOBAL("script.builder.tir.FuncName").set_body_typed(FuncName); TVM_REGISTER_GLOBAL("script.builder.tir.FuncAttrs").set_body_typed(FuncAttrs); TVM_REGISTER_GLOBAL("script.builder.tir.FuncRet").set_body_typed(FuncRet); TVM_REGISTER_GLOBAL("script.builder.tir.MatchBuffer").set_body_typed(MatchBuffer); diff --git a/src/script/builder/tir/prim_func_frame.h b/src/script/builder/tir/prim_func_frame.h index 4b03985a0bbd..519603343259 100644 --- a/src/script/builder/tir/prim_func_frame.h +++ b/src/script/builder/tir/prim_func_frame.h @@ -57,9 +57,10 @@ class PrimFuncFrame : public TIRFrame { TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PrimFuncFrame, TIRFrame, PrimFuncFrameNode); }; -PrimFuncFrame PrimFunc_(String name); +PrimFuncFrame PrimFunc_(); tvm::tir::Var Arg(String name, tvm::tir::Var var); tvm::tir::Buffer Arg(String name, tvm::tir::Buffer buffer); +void FuncName(String name); void FuncAttrs(Map attrs); tvm::Type FuncRet(tvm::Type ret_type); diff --git a/tests/python/tvmscript/test_builder_basic.py b/tests/python/tvmscript/test_builder_basic.py index ee50b8a717ec..c22265119ed9 100644 --- a/tests/python/tvmscript/test_builder_basic.py +++ b/tests/python/tvmscript/test_builder_basic.py @@ -22,15 +22,16 @@ def test_builder_basic(): with Builder() as b: - with T.prim_func(name="main"): + with T.prim_func(): + T.func_name("main") T.func_attr({"global_symbol": "main"}) + T.func_ret(tvm.ir.PrimType("int8")) arg_a = T.arg("a", T.handle()) arg_b = T.arg("b", T.handle()) buffer_c = T.Buffer((128,), "float32") buffer_d = T.Buffer((128,), "float32") arg_c = T.arg("c", buffer_c) arg_d = T.arg("d", buffer_d) - T.func_ret(tvm.ir.PrimType("int8")) A = def_("A", T.match_buffer(arg_a, (128, 128, 128))) B = def_("B", T.match_buffer(arg_b, (128, 128, 128))) T.preflattened_buffer(buffer_c, (128,), data=buffer_c.data) diff --git a/tests/python/tvmscript/test_parse_basic.py b/tests/python/tvmscript/test_parse_basic.py index 2dac332feccc..8a9ecf4656dc 100644 --- a/tests/python/tvmscript/test_parse_basic.py +++ b/tests/python/tvmscript/test_parse_basic.py @@ -1,11 +1,11 @@ from tvm.script.builder import tir as T -from tvm.script.parse import parse -elementwise = """ + +# pylint: disable=unused-argument,unused-variable,invalid-name @T.prim_func def elementwise( - A: T.Buffer(shape=(128, 128, 128), dtype="float32"), - B: T.Buffer(shape=(128, 128, 128), dtype="float32"), + A: T.Buffer(shape=(128, 128, 128), dtype="float32"), # type: ignore + B: T.Buffer(shape=(128, 128, 128), dtype="float32"), # type: ignore ) -> None: for i, j, *vvv, k in T.grid(128, 128, 128, 128, 128, 128, 128): with T.block("inner_block"): @@ -13,11 +13,13 @@ def elementwise( vi = T.axis.S(128, i + 1) vj = T.axis.S(128, j + 20) vk = T.axis.R(128, k - i) -""" + + +# pylint: enable=unused-argument,unused-variable,invalid-name def main(): - result = parse(elementwise, extra_vars={"T": T}) + result = elementwise print(result.script())