diff --git a/python/tvm/driver/tvmc/model.py b/python/tvm/driver/tvmc/model.py index 48bb052124ee..5110aed21378 100644 --- a/python/tvm/driver/tvmc/model.py +++ b/python/tvm/driver/tvmc/model.py @@ -47,6 +47,7 @@ import tarfile import json from typing import Optional, Union, Dict, Callable, TextIO +from pathlib import Path import numpy as np import tvm @@ -258,7 +259,7 @@ def export_package( Parameters ---------- executor_factory : GraphExecutorFactoryModule - The factory containing compiled the compiled artifacts needed to run this model. + The factory containing the compiled artifacts needed to run this model. package_path : str, None Where the model should be saved. Note that it will be packaged as a .tar file. If not provided, the package will be saved to a generically named file in tmp. @@ -311,13 +312,20 @@ class TVMCPackage(object): ---------- package_path : str The path to the saved TVMCPackage that will be loaded. + + project_dir : Path, str + If given and loading a MLF file, the path to the project directory that contains the file. """ - def __init__(self, package_path: str): + def __init__(self, package_path: str, project_dir: Optional[Union[Path, str]] = None): self._tmp_dir = utils.tempdir() self.package_path = package_path self.import_package(self.package_path) + if project_dir and self.type != "mlf": + raise TVMCException("Setting 'project_dir' is only allowed when importing a MLF.!") + self.project_dir = project_dir + def import_package(self, package_path: str): """Load a TVMCPackage from a previously exported TVMCModel. diff --git a/python/tvm/driver/tvmc/runner.py b/python/tvm/driver/tvmc/runner.py index 76149b24ca8f..fd342a569956 100644 --- a/python/tvm/driver/tvmc/runner.py +++ b/python/tvm/driver/tvmc/runner.py @@ -24,7 +24,6 @@ from typing import Dict, List, Optional, Union from tarfile import ReadError import argparse -import os import sys import numpy as np @@ -161,11 +160,13 @@ def add_run_parser(subparsers, main_parser): project_info = project_.info() options_by_method = get_project_options(project_info) + mlf_path = project_info["model_library_format_path"] parser.formatter_class = ( argparse.RawTextHelpFormatter ) # Set raw help text so customized help_text format works - parser.set_defaults(valid_options=options_by_method["open_transport"]) + + parser.set_defaults(valid_options=options_by_method["open_transport"], mlf_path=mlf_path) required = any([opt["required"] for opt in options_by_method["open_transport"]]) nargs = "+" if required else "*" @@ -195,10 +196,14 @@ def drive_run(args): path = pathlib.Path(args.PATH) options = None + project_dir = None if args.device == "micro": - path = path / "model.tar" - if not path.is_file(): - TVMCException(f"Could not find model '{path}'!") + # If it's a micro device, then grab the model.tar path from Project API instead. + # args.PATH will be used too since it points to the project directory. N.B.: there is no + # way to determine the model.tar path from the project dir or vice-verse (each platform + # is free to put model.tar whereever it's convenient). + project_dir = path + path = pathlib.Path(args.mlf_path) # Check for options unavailable for micro targets. @@ -232,7 +237,7 @@ def drive_run(args): ) try: - tvmc_package = TVMCPackage(package_path=path) + tvmc_package = TVMCPackage(package_path=path, project_dir=project_dir) except IsADirectoryError: raise TVMCException(f"File {path} must be an archive, not a directory.") except FileNotFoundError: @@ -497,8 +502,7 @@ def run_module( if tvmc_package.type != "mlf": raise TVMCException(f"Model {tvmc_package.package_path} is not a MLF archive.") - project_dir = get_project_dir(tvmc_package.package_path) - project_dir = os.path.dirname(project_dir) + project_dir = get_project_dir(tvmc_package.project_dir) # This is guaranteed to work since project_dir was already checked when # building the dynamic parser to accommodate the project options, so no diff --git a/tests/python/driver/tvmc/conftest.py b/tests/python/driver/tvmc/conftest.py index ca4ab2247bd9..e56b7016ab77 100644 --- a/tests/python/driver/tvmc/conftest.py +++ b/tests/python/driver/tvmc/conftest.py @@ -187,6 +187,7 @@ def model_compiler(model_file, **overrides): args = {"target": "llvm", **overrides} return tvmc.compiler.compile_model(tvmc_model, package_path=package_path, **args) + # Returns a TVMCPackage return model_compiler diff --git a/tests/python/driver/tvmc/test_mlf.py b/tests/python/driver/tvmc/test_mlf.py index 4f61aec946d7..045562ad5bb6 100644 --- a/tests/python/driver/tvmc/test_mlf.py +++ b/tests/python/driver/tvmc/test_mlf.py @@ -88,6 +88,33 @@ def test_tvmc_export_package_mlf(tflite_mobilenet_v1_1_quant, tmpdir_factory): assert str(exp.value) == expected_reason, on_error +def test_tvmc_import_package_project_dir(tflite_mobilenet_v1_1_quant, tflite_compile_model): + pytest.importorskip("tflite") + + # Generate a MLF archive. + compiled_model_mlf_tvmc_package = tflite_compile_model( + tflite_mobilenet_v1_1_quant, output_format="mlf" + ) + + # Import the MLF archive setting 'project_dir'. It must succeed. + mlf_archive_path = compiled_model_mlf_tvmc_package.package_path + tvmc_package = TVMCPackage(mlf_archive_path, project_dir="/tmp/foobar") + assert tvmc_package.type == "mlf", "Can't load the MLF archive passing the project directory!" + + # Generate a Classic archive. + compiled_model_classic_tvmc_package = tflite_compile_model(tflite_mobilenet_v1_1_quant) + + # Import the Classic archive setting 'project_dir'. + # It must fail since setting 'project_dir' is only support when importing a MLF archive. + classic_archive_path = compiled_model_classic_tvmc_package.package_path + with pytest.raises(TVMCException) as exp: + tvmc_package = TVMCPackage(classic_archive_path, project_dir="/tmp/foobar") + + expected_reason = "Setting 'project_dir' is only allowed when importing a MLF.!" + on_error = "A TVMCException was caught but its reason is not the expected one." + assert str(exp.value) == expected_reason, on_error + + def test_tvmc_import_package_mlf_graph(tflite_mobilenet_v1_1_quant, tflite_compile_model): pytest.importorskip("tflite")