Skip to content

Commit

Permalink
Improve tvmc error message from lazy-loading frontend imports (apache…
Browse files Browse the repository at this point in the history
…#9074)

When installing TVM from the python package, the Frontend frameworks dependencies such as TensorFlow, PyTorch, ONNX, etc, are not installed by default.
In case a user tries to run tvmc using a model whose framework was not installed, it will be presented with a very raw Python exception in the output.
The aim of this commit is to implement a better error messages for errors related to lazy-loading frontend frameworks in tvmc.

Change-Id: Ida52fac4116af392ee436390e14ea02c7090cef0
  • Loading branch information
ophirfrish authored and baoxinqi committed Dec 27, 2021
1 parent f2195b1 commit 5c5bab9
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 25 deletions.
4 changes: 4 additions & 0 deletions python/tvm/driver/tvmc/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ def exit(self, status=0, message=None):
raise TVMCException()


class TVMCImportError(TVMCException):
"""TVMC TVMCImportError"""


def convert_graph_layout(mod, desired_layout):
"""Alter the layout of the input graph.
Expand Down
43 changes: 18 additions & 25 deletions python/tvm/driver/tvmc/frontends.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
loading the tool.
"""
import logging
import os
import sys
import importlib
from abc import ABC
from abc import abstractmethod
from typing import Optional, List, Dict
Expand All @@ -32,6 +32,7 @@

from tvm import relay
from tvm.driver.tvmc.common import TVMCException
from tvm.driver.tvmc.common import TVMCImportError
from tvm.driver.tvmc.model import TVMCModel


Expand Down Expand Up @@ -76,20 +77,15 @@ def load(self, path, shape_dict=None, **kwargs):
"""


def import_keras():
"""Lazy import function for Keras"""
# Keras writes the message "Using TensorFlow backend." to stderr
# Redirect stderr during the import to disable this
stderr = sys.stderr
sys.stderr = open(os.devnull, "w")
def lazy_import(pkg_name, from_pkg_name=None, hide_stderr=False):
"""Lazy import a frontend package or subpackage"""
try:
# pylint: disable=C0415
import tensorflow as tf
from tensorflow import keras

return tf, keras
return importlib.import_module(pkg_name, package=from_pkg_name)
except ImportError as error:
raise TVMCImportError(pkg_name) from error
finally:
sys.stderr = stderr
if hide_stderr:
sys.stderr = stderr


class KerasFrontend(Frontend):
Expand All @@ -105,7 +101,8 @@ def suffixes():

def load(self, path, shape_dict=None, **kwargs):
# pylint: disable=C0103
tf, keras = import_keras()
tf = lazy_import("tensorflow")
keras = lazy_import("keras", from_pkg_name="tensorflow")

# tvm build currently imports keras directly instead of tensorflow.keras
try:
Expand Down Expand Up @@ -136,11 +133,11 @@ def load(self, path, shape_dict=None, **kwargs):
return relay.frontend.from_keras(model, input_shapes, **kwargs)

def is_sequential_p(self, model):
_, keras = import_keras()
keras = lazy_import("keras", from_pkg_name="tensorflow")
return isinstance(model, keras.models.Sequential)

def sequential_to_functional(self, model):
_, keras = import_keras()
keras = lazy_import("keras", from_pkg_name="tensorflow")
assert self.is_sequential_p(model)
input_layer = keras.layers.Input(batch_shape=model.layers[0].input_shape)
prev_layer = input_layer
Expand All @@ -162,8 +159,7 @@ def suffixes():
return ["onnx"]

def load(self, path, shape_dict=None, **kwargs):
# pylint: disable=C0415
import onnx
onnx = lazy_import("onnx")

# pylint: disable=E1101
model = onnx.load(path)
Expand All @@ -183,9 +179,8 @@ def suffixes():
return ["pb"]

def load(self, path, shape_dict=None, **kwargs):
# pylint: disable=C0415
import tensorflow as tf
import tvm.relay.testing.tf as tf_testing
tf = lazy_import("tensorflow")
tf_testing = lazy_import("tvm.relay.testing.tf")

with tf.io.gfile.GFile(path, "rb") as tf_graph:
content = tf_graph.read()
Expand All @@ -210,8 +205,7 @@ def suffixes():
return ["tflite"]

def load(self, path, shape_dict=None, **kwargs):
# pylint: disable=C0415
import tflite.Model as model
model = lazy_import("tflite.Model")

with open(path, "rb") as tf_graph:
content = tf_graph.read()
Expand Down Expand Up @@ -249,8 +243,7 @@ def suffixes():
return ["pth", "zip"]

def load(self, path, shape_dict=None, **kwargs):
# pylint: disable=C0415
import torch
torch = lazy_import("torch")

if shape_dict is None:
raise TVMCException("--input-shapes must be specified for %s" % self.name())
Expand Down
6 changes: 6 additions & 0 deletions python/tvm/driver/tvmc/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import tvm

from tvm.driver.tvmc.common import TVMCException
from tvm.driver.tvmc.common import TVMCImportError


REGISTERED_PARSER = []
Expand Down Expand Up @@ -91,6 +92,11 @@ def _main(argv):

try:
return args.func(args)
except TVMCImportError as err:
sys.stderr.write(
f'Package "{err}" is not installed. ' f'Hint: "pip install tlcpack[tvmc]".'
)
return 5
except TVMCException as err:
sys.stderr.write("Error: %s\n" % err)
return 4
Expand Down
52 changes: 52 additions & 0 deletions tests/python/driver/tvmc/test_frontends.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,31 @@
# under the License.

import pytest
import builtins
import importlib

import tvm
from unittest import mock
from tvm.ir.module import IRModule

from tvm.driver import tvmc
from tvm.driver.tvmc.common import TVMCException
from tvm.driver.tvmc.common import TVMCImportError
from tvm.driver.tvmc.model import TVMCModel


orig_import = importlib.import_module


def mock_error_on_name(name):
def mock_imports(module_name, package=None):
if module_name == name:
raise ImportError()
return orig_import(module_name, package)

return mock_imports


def test_get_frontends_contains_only_strings():
sut = tvmc.frontends.get_frontend_names()
assert all([type(x) is str for x in sut]) is True
Expand Down Expand Up @@ -367,3 +383,39 @@ def _is_layout_transform(node):
tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform)

assert not any(layout_transform_calls), "Unexpected 'layout_transform' call"


def test_import_keras_friendly_message(keras_resnet50, monkeypatch):
# keras is part of tensorflow
monkeypatch.setattr("importlib.import_module", mock_error_on_name("tensorflow"))

with pytest.raises(TVMCImportError, match="tensorflow") as e:
_ = tvmc.frontends.load_model(keras_resnet50, model_format="keras")


def test_import_onnx_friendly_message(onnx_resnet50, monkeypatch):
monkeypatch.setattr("importlib.import_module", mock_error_on_name("onnx"))

with pytest.raises(TVMCImportError, match="onnx") as e:
_ = tvmc.frontends.load_model(onnx_resnet50, model_format="onnx")


def test_import_tensorflow_friendly_message(pb_mobilenet_v1_1_quant, monkeypatch):
monkeypatch.setattr("importlib.import_module", mock_error_on_name("tensorflow"))

with pytest.raises(TVMCImportError, match="tensorflow") as e:
_ = tvmc.frontends.load_model(pb_mobilenet_v1_1_quant, model_format="pb")


def test_import_torch_friendly_message(pytorch_resnet18, monkeypatch):
monkeypatch.setattr("importlib.import_module", mock_error_on_name("torch"))

with pytest.raises(TVMCImportError, match="torch") as e:
_ = tvmc.frontends.load_model(pytorch_resnet18, model_format="pytorch")


def test_import_tflite_friendly_message(tflite_mobilenet_v1_1_quant, monkeypatch):
monkeypatch.setattr("importlib.import_module", mock_error_on_name("tflite.Model"))

with pytest.raises(TVMCImportError, match="tflite.Model") as e:
_ = tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant, model_format="tflite")

0 comments on commit 5c5bab9

Please sign in to comment.