diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index d260c98b6721..df56e3b9825d 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -95,7 +95,8 @@ def add_compile_parser(subparsers, _): metavar=("name=value"), help="configurations to be used at compile time. This option can be provided multiple " "times, each one to set one configuration value, " - "e.g. '--pass-config relay.backend.use_auto_scheduler=0'.", + "e.g. '--pass-config relay.backend.use_auto_scheduler=0', " + "e.g. '--pass-config tir.add_lower_pass=opt_level1,pass1,opt_level2,pass2'.", ) generate_target_args(parser) diff --git a/python/tvm/driver/tvmc/pass_config.py b/python/tvm/driver/tvmc/pass_config.py index 7cf0f0143e60..dde5b9c659d8 100644 --- a/python/tvm/driver/tvmc/pass_config.py +++ b/python/tvm/driver/tvmc/pass_config.py @@ -18,10 +18,41 @@ TVMC PassContext Interface """ +import importlib + import tvm from tvm.driver.tvmc import TVMCException +def load_function(full_name): + """Dynamic loading a function by the full name. + Parameters + ---------- + full_name: str + The name of a PackedFunc or a string of the form "path.to.module.func" + that indicates the module that can be imported. + You must be aware of the load order here, it first tries to find it via + TVM global function, if not find, try to import it by "importlib.import_module". + Returns + ------- + func: function or PackedFunc + The loaded fucntion. + """ + global_func = tvm.get_global_func(full_name, allow_missing=True) + if global_func is not None: + return global_func + + # split full name "path.to.module.func" into two parts ["path.to.module", "func"] + module_name, func_name = full_name.rsplit(".", 1) + + # import module and find the function + module = importlib.import_module(module_name) + if hasattr(module, func_name): + return getattr(module, func_name) + + raise TVMCException(f"No function '{func_name}' found in module '{module_name}'.") + + def get_pass_config_value(name, value, config_type): """Get a PassContext configuration value, based on its config data type. @@ -41,6 +72,8 @@ def get_pass_config_value(name, value, config_type): specified by config_type. """ + parsed_value = None + if config_type == "IntImm": # "Bool" configurations in the PassContext are recognized as # IntImm, so deal with this case here @@ -56,11 +89,44 @@ def get_pass_config_value(name, value, config_type): parsed_value = mapping_values.get(value.lower(), None) if parsed_value is None: - raise TVMCException(f"Invalid value '{value}' for configuration '{name}'. ") + raise TVMCException(f"Invalid value '{value}' for configuration '{name}'.") - if config_type == "runtime.String": + elif config_type == "runtime.String": parsed_value = value + elif config_type == "Array": + if name == "tir.add_lower_pass": + pass_list = value.split(",") + if len(pass_list) % 2 != 0: + raise TVMCException( + f"The configuration of '{name}' must be of the form " + "'tir.add_lower_pass=opt_level1,pass1,opt_evel2,pass2'" + ) + + parsed_value = [] + for i in range(0, len(pass_list), 2): + level, pass_func = pass_list[i].strip(), pass_list[i + 1].strip() + try: + level = int(level) + except ValueError: + raise TVMCException(f"Only integer is allow for configuration '{name}'.") + + # TODO (@leeexyz) We should parse configurations of each tir Pass. + # For now, we only use the defaults. Currently, There are four config nodes: + # `tir.transform.LoopPartitionConfig` + # `tir.transform.UnrollLoopConfig` + # `tir.transform.HoistIfThenElseConfig` + # `tir.transform.InjectDoubleBufferConfig` + # loading pass func and calling it to get the Pass + pass_func = load_function(pass_func)() + parsed_value.append((level, pass_func)) + else: + raise TVMCException(f"Unsupported configuration '{name}' for '{config_type}' type.") + + else: + # not raise here cause we alreay checked before calling this function + pass + return parsed_value @@ -81,7 +147,7 @@ def parse_configs(input_configs): return {} all_configs = tvm.ir.transform.PassContext.list_configs() - supported_config_types = ("IntImm", "runtime.String") + supported_config_types = ("IntImm", "runtime.String", "Array") supported_configs = [ name for name in all_configs.keys() if all_configs[name]["type"] in supported_config_types ] @@ -116,7 +182,13 @@ def parse_configs(input_configs): f"The following configurations are supported: {', '.join(supported_configs)}" ) - parsed_value = get_pass_config_value(name, value, all_configs[name]["type"]) - pass_context_configs[name] = parsed_value + config_type = all_configs[name]["type"] + parsed_value = get_pass_config_value(name, value, config_type) + + if config_type == "Array" and name in pass_context_configs: + # merge configs if the configuration exists + pass_context_configs[name].extend(parsed_value) + else: + pass_context_configs[name] = parsed_value return pass_context_configs diff --git a/tests/python/driver/tvmc/test_pass_config.py b/tests/python/driver/tvmc/test_pass_config.py index bb815e1dc8aa..f928c8a31293 100644 --- a/tests/python/driver/tvmc/test_pass_config.py +++ b/tests/python/driver/tvmc/test_pass_config.py @@ -16,11 +16,13 @@ # under the License. import pytest +from unittest import mock from tvm.contrib.target.vitis_ai import vitis_ai_available from tvm.driver.tvmc import TVMCException from tvm.driver.tvmc.pass_config import parse_configs +from tvm.tir.transform import PrimFuncPass def test_config_invalid_format(): @@ -71,3 +73,89 @@ def test_config_valid_multiple_configs(): assert configs["tir.detect_global_barrier"] == 10 assert "relay.ext.vitis_ai.options.build_dir" in configs.keys() assert configs["relay.ext.vitis_ai.options.build_dir"] == "mystring" + + +def test_add_lower_pass_multi_built_in_pass(): + configs = parse_configs( + [ + "tir.add_lower_pass=1,tir.transform.UnrollLoop", + "tir.add_lower_pass=1,tir.transform.HoistIfThenElse,2,tir.transform.LoopPartition", + ] + ) + + assert len(configs["tir.add_lower_pass"]) == 3 + # opt_level: 1, pass: tir.transform.UnrollLoop + assert configs["tir.add_lower_pass"][0][0] == 1 + assert isinstance(configs["tir.add_lower_pass"][0][1], PrimFuncPass) + # opt_level: 1, pass: tir.transform.HoistIfThenElse + assert configs["tir.add_lower_pass"][1][0] == 1 + assert isinstance(configs["tir.add_lower_pass"][1][1], PrimFuncPass) + # opt_level: 2, pass: tir.transform.LoopPartition + assert configs["tir.add_lower_pass"][2][0] == 2 + assert isinstance(configs["tir.add_lower_pass"][2][1], PrimFuncPass) + + +def test_add_lower_pass_multi_external_pass(): + fake_pass_1 = mock.MagicMock() + fake_pass_2 = mock.MagicMock() + fake_pass_3 = mock.MagicMock() + with mock.patch.dict( + "sys.modules", + {"fake_module": fake_pass_1, "fake_module": fake_pass_2, "fake_module": fake_pass_3}, + ): + configs = parse_configs( + [ + "tir.add_lower_pass=1,fake_module.fake_pass_1,2,fake_module.fake_pass2", + "tir.add_lower_pass=3,fake_module.fake_pass_3", + ] + ) + assert len(configs["tir.add_lower_pass"]) == 3 + # opt_level: 1, pass: fake_module.fake_pass_1 + assert configs["tir.add_lower_pass"][0][0] == 1 + # opt_level: 2, pass: fake_module.fake_pass_2 + assert configs["tir.add_lower_pass"][1][0] == 2 + # opt_level: 3, pass: fake_module.fake_pass_3 + assert configs["tir.add_lower_pass"][2][0] == 3 + + +def test_add_lower_pass_multi_mix_pass(): + fake_pass_1 = mock.MagicMock() + fake_pass_2 = mock.MagicMock() + with mock.patch.dict("sys.modules", {"fake_module": fake_pass_1, "fake_module": fake_pass_2}): + configs = parse_configs( + [ + "tir.add_lower_pass=1,fake_module.fake_pass_1,1,tir.transform.UnrollLoop", + "tir.add_lower_pass=2,fake_module.fake_pass_2,2,tir.transform.LoopPartition", + ] + ) + assert len(configs["tir.add_lower_pass"]) == 4 + # opt_level: 1, pass: fake_module.fake_pass_1 + assert configs["tir.add_lower_pass"][0][0] == 1 + # opt_level: 1, pass: tir.transform.UnrollLoop + assert configs["tir.add_lower_pass"][1][0] == 1 + assert isinstance(configs["tir.add_lower_pass"][1][1], PrimFuncPass) + # opt_level: 2, pass: fake_module.fake_pass_2 + assert configs["tir.add_lower_pass"][2][0] == 2 + # opt_level: 2, pass: tir.transform.LoopPartition + assert configs["tir.add_lower_pass"][3][0] == 2 + assert isinstance(configs["tir.add_lower_pass"][3][1], PrimFuncPass) + + +def test_add_lower_pass_invalid_format(): + # wrong format + with pytest.raises(TVMCException): + _ = parse_configs(["tir.add_lower_pass=tir.transform.UnrollLoop,1"]) + # missing pass name + with pytest.raises(TVMCException): + _ = parse_configs(["tir.add_lower_pass=1,tir.transform.UnrollLoop,3"]) + # wrong opt level + with pytest.raises(TVMCException): + _ = parse_configs(["tir.add_lower_pass=a,tir.transform.UnrollLoop"]) + # fake module + with pytest.raises(ModuleNotFoundError): + _ = parse_configs( + ["tir.add_lower_pass=1,tir.transform.UnrollLoop,2,path.to.module.fake_func"] + ) + # real module and fake func + with pytest.raises(TVMCException): + _ = parse_configs(["tir.add_lower_pass=1,tir.transform.UnrollLoop,2,tvm.tir.fake_func"])