Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TVMC] run: Don't use static path to find model.tar #9712

Merged
merged 2 commits into from
Dec 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions python/tvm/driver/tvmc/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import tarfile
import json
from typing import Optional, Union, Dict, Callable, TextIO
from pathlib import Path
import numpy as np

import tvm
Expand Down Expand Up @@ -258,7 +259,7 @@ def export_package(
Parameters
----------
executor_factory : GraphExecutorFactoryModule
The factory containing compiled the compiled artifacts needed to run this model.
The factory containing the compiled artifacts needed to run this model.
package_path : str, None
Where the model should be saved. Note that it will be packaged as a .tar file.
If not provided, the package will be saved to a generically named file in tmp.
Expand Down Expand Up @@ -311,13 +312,20 @@ class TVMCPackage(object):
----------
package_path : str
The path to the saved TVMCPackage that will be loaded.

project_dir : Path, str
If given and loading a MLF file, the path to the project directory that contains the file.
"""

def __init__(self, package_path: str):
def __init__(self, package_path: str, project_dir: Optional[Union[Path, str]] = None):
self._tmp_dir = utils.tempdir()
self.package_path = package_path
self.import_package(self.package_path)

if project_dir and self.type != "mlf":
raise TVMCException("Setting 'project_dir' is only allowed when importing a MLF.!")
self.project_dir = project_dir

def import_package(self, package_path: str):
"""Load a TVMCPackage from a previously exported TVMCModel.

Expand Down
20 changes: 12 additions & 8 deletions python/tvm/driver/tvmc/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from typing import Dict, List, Optional, Union
from tarfile import ReadError
import argparse
import os
import sys
import numpy as np

Expand Down Expand Up @@ -161,11 +160,13 @@ def add_run_parser(subparsers, main_parser):

project_info = project_.info()
options_by_method = get_project_options(project_info)
mlf_path = project_info["model_library_format_path"]

parser.formatter_class = (
argparse.RawTextHelpFormatter
) # Set raw help text so customized help_text format works
parser.set_defaults(valid_options=options_by_method["open_transport"])

parser.set_defaults(valid_options=options_by_method["open_transport"], mlf_path=mlf_path)

required = any([opt["required"] for opt in options_by_method["open_transport"]])
nargs = "+" if required else "*"
Expand Down Expand Up @@ -195,10 +196,14 @@ def drive_run(args):

path = pathlib.Path(args.PATH)
options = None
project_dir = None
if args.device == "micro":
path = path / "model.tar"
if not path.is_file():
TVMCException(f"Could not find model '{path}'!")
# If it's a micro device, then grab the model.tar path from Project API instead.
# args.PATH will be used too since it points to the project directory. N.B.: there is no
# way to determine the model.tar path from the project dir or vice-verse (each platform
# is free to put model.tar whereever it's convenient).
project_dir = path
path = pathlib.Path(args.mlf_path)

# Check for options unavailable for micro targets.

Expand Down Expand Up @@ -232,7 +237,7 @@ def drive_run(args):
)

try:
tvmc_package = TVMCPackage(package_path=path)
tvmc_package = TVMCPackage(package_path=path, project_dir=project_dir)
except IsADirectoryError:
raise TVMCException(f"File {path} must be an archive, not a directory.")
except FileNotFoundError:
Expand Down Expand Up @@ -497,8 +502,7 @@ def run_module(
if tvmc_package.type != "mlf":
raise TVMCException(f"Model {tvmc_package.package_path} is not a MLF archive.")

project_dir = get_project_dir(tvmc_package.package_path)
project_dir = os.path.dirname(project_dir)
project_dir = get_project_dir(tvmc_package.project_dir)

# This is guaranteed to work since project_dir was already checked when
# building the dynamic parser to accommodate the project options, so no
Expand Down
1 change: 1 addition & 0 deletions tests/python/driver/tvmc/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ def model_compiler(model_file, **overrides):
args = {"target": "llvm", **overrides}
return tvmc.compiler.compile_model(tvmc_model, package_path=package_path, **args)

# Returns a TVMCPackage
return model_compiler


Expand Down
27 changes: 27 additions & 0 deletions tests/python/driver/tvmc/test_mlf.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,33 @@ def test_tvmc_export_package_mlf(tflite_mobilenet_v1_1_quant, tmpdir_factory):
assert str(exp.value) == expected_reason, on_error


def test_tvmc_import_package_project_dir(tflite_mobilenet_v1_1_quant, tflite_compile_model):
pytest.importorskip("tflite")

# Generate a MLF archive.
compiled_model_mlf_tvmc_package = tflite_compile_model(
tflite_mobilenet_v1_1_quant, output_format="mlf"
)

# Import the MLF archive setting 'project_dir'. It must succeed.
mlf_archive_path = compiled_model_mlf_tvmc_package.package_path
tvmc_package = TVMCPackage(mlf_archive_path, project_dir="/tmp/foobar")
assert tvmc_package.type == "mlf", "Can't load the MLF archive passing the project directory!"

# Generate a Classic archive.
compiled_model_classic_tvmc_package = tflite_compile_model(tflite_mobilenet_v1_1_quant)

# Import the Classic archive setting 'project_dir'.
# It must fail since setting 'project_dir' is only support when importing a MLF archive.
classic_archive_path = compiled_model_classic_tvmc_package.package_path
with pytest.raises(TVMCException) as exp:
tvmc_package = TVMCPackage(classic_archive_path, project_dir="/tmp/foobar")

expected_reason = "Setting 'project_dir' is only allowed when importing a MLF.!"
on_error = "A TVMCException was caught but its reason is not the expected one."
assert str(exp.value) == expected_reason, on_error


def test_tvmc_import_package_mlf_graph(tflite_mobilenet_v1_1_quant, tflite_compile_model):
pytest.importorskip("tflite")

Expand Down