Skip to content

Commit

Permalink
[TVMC] run: Don't use static path to find model.tar (apache#9712)
Browse files Browse the repository at this point in the history
* [TVMC] run: Don't use static path to find model.tar

Currently 'tvmc run' when '--device micro' is specified looks for the
model in the project directory at <project_dir>/model.tar. That works
for Zephyr but fails on Arduino because model.tar is actually located at
<project_dir>/src/model/model.tar. As a consequence 'tvmc run' when used
to run a model on Arduino exists because model.tar is never found.

This commit fixes it by using the MLF path returned by the Project API
instead of using a static path.

This commit also adds a project_dir attribute to TVMCPackage that can be
set when a MLF archive is loaded/imported so the project dir can be
conveniently found (similarly to package_path attribute).

Signed-off-by: Gustavo Romero <gustavo.romero@linaro.org>

* [TVMC] test: Add test for importing a MLF with project_dir

Add test for TVMCPackage when importing a MLF archive and setting a
project directory too. Setting a project dir is only supported when a
MLF model is imported, so it must fail on Classic format.

Signed-off-by: Gustavo Romero <gustavo.romero@linaro.org>
  • Loading branch information
gromero authored and baoxinqi committed Dec 27, 2021
1 parent 8fc376b commit 259c877
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 10 deletions.
12 changes: 10 additions & 2 deletions python/tvm/driver/tvmc/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import tarfile
import json
from typing import Optional, Union, Dict, Callable, TextIO
from pathlib import Path
import numpy as np

import tvm
Expand Down Expand Up @@ -258,7 +259,7 @@ def export_package(
Parameters
----------
executor_factory : GraphExecutorFactoryModule
The factory containing compiled the compiled artifacts needed to run this model.
The factory containing the compiled artifacts needed to run this model.
package_path : str, None
Where the model should be saved. Note that it will be packaged as a .tar file.
If not provided, the package will be saved to a generically named file in tmp.
Expand Down Expand Up @@ -311,13 +312,20 @@ class TVMCPackage(object):
----------
package_path : str
The path to the saved TVMCPackage that will be loaded.
project_dir : Path, str
If given and loading a MLF file, the path to the project directory that contains the file.
"""

def __init__(self, package_path: str):
def __init__(self, package_path: str, project_dir: Optional[Union[Path, str]] = None):
self._tmp_dir = utils.tempdir()
self.package_path = package_path
self.import_package(self.package_path)

if project_dir and self.type != "mlf":
raise TVMCException("Setting 'project_dir' is only allowed when importing a MLF.!")
self.project_dir = project_dir

def import_package(self, package_path: str):
"""Load a TVMCPackage from a previously exported TVMCModel.
Expand Down
20 changes: 12 additions & 8 deletions python/tvm/driver/tvmc/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from typing import Dict, List, Optional, Union
from tarfile import ReadError
import argparse
import os
import sys
import numpy as np

Expand Down Expand Up @@ -161,11 +160,13 @@ def add_run_parser(subparsers, main_parser):

project_info = project_.info()
options_by_method = get_project_options(project_info)
mlf_path = project_info["model_library_format_path"]

parser.formatter_class = (
argparse.RawTextHelpFormatter
) # Set raw help text so customized help_text format works
parser.set_defaults(valid_options=options_by_method["open_transport"])

parser.set_defaults(valid_options=options_by_method["open_transport"], mlf_path=mlf_path)

required = any([opt["required"] for opt in options_by_method["open_transport"]])
nargs = "+" if required else "*"
Expand Down Expand Up @@ -195,10 +196,14 @@ def drive_run(args):

path = pathlib.Path(args.PATH)
options = None
project_dir = None
if args.device == "micro":
path = path / "model.tar"
if not path.is_file():
TVMCException(f"Could not find model '{path}'!")
# If it's a micro device, then grab the model.tar path from Project API instead.
# args.PATH will be used too since it points to the project directory. N.B.: there is no
# way to determine the model.tar path from the project dir or vice-verse (each platform
# is free to put model.tar whereever it's convenient).
project_dir = path
path = pathlib.Path(args.mlf_path)

# Check for options unavailable for micro targets.

Expand Down Expand Up @@ -232,7 +237,7 @@ def drive_run(args):
)

try:
tvmc_package = TVMCPackage(package_path=path)
tvmc_package = TVMCPackage(package_path=path, project_dir=project_dir)
except IsADirectoryError:
raise TVMCException(f"File {path} must be an archive, not a directory.")
except FileNotFoundError:
Expand Down Expand Up @@ -497,8 +502,7 @@ def run_module(
if tvmc_package.type != "mlf":
raise TVMCException(f"Model {tvmc_package.package_path} is not a MLF archive.")

project_dir = get_project_dir(tvmc_package.package_path)
project_dir = os.path.dirname(project_dir)
project_dir = get_project_dir(tvmc_package.project_dir)

# This is guaranteed to work since project_dir was already checked when
# building the dynamic parser to accommodate the project options, so no
Expand Down
1 change: 1 addition & 0 deletions tests/python/driver/tvmc/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ def model_compiler(model_file, **overrides):
args = {"target": "llvm", **overrides}
return tvmc.compiler.compile_model(tvmc_model, package_path=package_path, **args)

# Returns a TVMCPackage
return model_compiler


Expand Down
27 changes: 27 additions & 0 deletions tests/python/driver/tvmc/test_mlf.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,33 @@ def test_tvmc_export_package_mlf(tflite_mobilenet_v1_1_quant, tmpdir_factory):
assert str(exp.value) == expected_reason, on_error


def test_tvmc_import_package_project_dir(tflite_mobilenet_v1_1_quant, tflite_compile_model):
pytest.importorskip("tflite")

# Generate a MLF archive.
compiled_model_mlf_tvmc_package = tflite_compile_model(
tflite_mobilenet_v1_1_quant, output_format="mlf"
)

# Import the MLF archive setting 'project_dir'. It must succeed.
mlf_archive_path = compiled_model_mlf_tvmc_package.package_path
tvmc_package = TVMCPackage(mlf_archive_path, project_dir="/tmp/foobar")
assert tvmc_package.type == "mlf", "Can't load the MLF archive passing the project directory!"

# Generate a Classic archive.
compiled_model_classic_tvmc_package = tflite_compile_model(tflite_mobilenet_v1_1_quant)

# Import the Classic archive setting 'project_dir'.
# It must fail since setting 'project_dir' is only support when importing a MLF archive.
classic_archive_path = compiled_model_classic_tvmc_package.package_path
with pytest.raises(TVMCException) as exp:
tvmc_package = TVMCPackage(classic_archive_path, project_dir="/tmp/foobar")

expected_reason = "Setting 'project_dir' is only allowed when importing a MLF.!"
on_error = "A TVMCException was caught but its reason is not the expected one."
assert str(exp.value) == expected_reason, on_error


def test_tvmc_import_package_mlf_graph(tflite_mobilenet_v1_1_quant, tflite_compile_model):
pytest.importorskip("tflite")

Expand Down

0 comments on commit 259c877

Please sign in to comment.