Improve tvmc error message from lazy-loading frontend imports
When TVM is installed from the Python package, frontend framework dependencies such as TensorFlow, PyTorch, and ONNX are not installed by default.
If a user runs tvmc with a model whose framework is not installed, they are presented with a raw Python exception in the output.
This commit implements better error messages for failures related to lazy-loading frontend frameworks in tvmc.

Change-Id: Ida52fac4116af392ee436390e14ea02c7090cef0
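
For context, a minimal self-contained sketch of the pattern this commit introduces: each optional frontend import is wrapped in a helper that turns an ImportError into a TVMCException carrying a pip install hint. The TVMCException class below is a stand-in defined only for the sketch, not imported from TVM, and the example is illustrative rather than code from the commit.

# Illustrative sketch only; not code from this commit.
class TVMCException(Exception):
    """Stand-in for tvm.driver.tvmc.common.TVMCException."""


def create_import_error_string(library_name, import_name):
    # Mirrors the helper added in frontends.py below.
    return (
        f"{library_name} is required and was not found. "
        f'Please install it using "pip install {import_name}".'
    )


def import_onnx():
    """Lazy import: only pull in onnx when an ONNX model is actually loaded."""
    try:
        import onnx  # heavy optional dependency, imported on demand
    except ImportError:
        raise TVMCException(create_import_error_string("ONNX", "onnx"))
    return onnx


if __name__ == "__main__":
    try:
        import_onnx()
    except TVMCException as err:
        # Without onnx installed, this prints:
        # ONNX is required and was not found. Please install it using "pip install onnx".
        print(err)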
ophirfrish committed Oct 7, 2021
1 parent d9aae9c commit 8c23001
Showing 2 changed files with 117 additions and 7 deletions.
64 changes: 57 additions & 7 deletions python/tvm/driver/tvmc/frontends.py
@@ -76,6 +76,14 @@ def load(self, path, shape_dict=None, **kwargs):
        """


def create_import_error_string(library_name, import_name):
    err = (
        f"{library_name} is required and was not found. "
        f'Please install it using "pip install {import_name}".'
    )
    return str(err)


def import_keras():
    """Lazy import function for Keras"""
    # Keras writes the message "Using TensorFlow backend." to stderr
@@ -88,10 +96,54 @@ def import_keras():
        from tensorflow import keras

        return tf, keras
    except ImportError:
        raise TVMCException(create_import_error_string("Tensorflow", "tensorflow"))
    finally:
        sys.stderr = stderr


def import_onnx():
    """Lazy import function for onnx"""
    try:
        # pylint: disable=C0415
        import onnx as _onnx
    except ImportError:
        raise TVMCException(create_import_error_string("ONNX", "onnx"))
    return _onnx


def import_tensorflow():
    """Lazy import function for tensorflow"""
    try:
        # pylint: disable=C0415
        import tensorflow as tf
    except ImportError:
        raise TVMCException(create_import_error_string("Tensorflow", "tensorflow"))
    return tf


def import_torch():
    """Lazy import function for torch"""
    try:
        # pylint: disable=C0415
        import torch as tc
    except ImportError:
        raise TVMCException(create_import_error_string("Torch", "torch"))

    return tc


def import_tflite():
    """Lazy import function for tflite.Model"""
    try:
        # pylint: disable=C0415
        import tflite.Model as model
    except ImportError:
        raise TVMCException(create_import_error_string("tflite", "tflite"))

    return model


class KerasFrontend(Frontend):
    """Keras frontend for TVMC"""

@@ -162,8 +214,7 @@ def suffixes():
        return ["onnx"]

    def load(self, path, shape_dict=None, **kwargs):
-        # pylint: disable=C0415
-        import onnx
+        onnx = import_onnx()

        # pylint: disable=E1101
        model = onnx.load(path)
@@ -184,9 +235,10 @@ def suffixes():

    def load(self, path, shape_dict=None, **kwargs):
        # pylint: disable=C0415
-        import tensorflow as tf
        import tvm.relay.testing.tf as tf_testing

+        tf = import_tensorflow()

        with tf.io.gfile.GFile(path, "rb") as tf_graph:
            content = tf_graph.read()

@@ -210,8 +262,7 @@ def suffixes():
        return ["tflite"]

    def load(self, path, shape_dict=None, **kwargs):
-        # pylint: disable=C0415
-        import tflite.Model as model
+        model = import_tflite()

        with open(path, "rb") as tf_graph:
            content = tf_graph.read()
@@ -249,8 +300,7 @@ def suffixes():
        return ["pth", "zip"]

    def load(self, path, shape_dict=None, **kwargs):
-        # pylint: disable=C0415
-        import torch
+        torch = import_torch()

        if shape_dict is None:
            raise TVMCException("--input-shapes must be specified for %s" % self.name())
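
For illustration only, a hedged caller-side sketch of how the friendlier error now surfaces; the model path is a placeholder, not a file from this commit.

# Hypothetical usage; "resnet50.onnx" is a placeholder path.
from tvm.driver import tvmc
from tvm.driver.tvmc.common import TVMCException

try:
    model = tvmc.frontends.load_model("resnet50.onnx", model_format="onnx")
except TVMCException as err:
    # With the onnx package missing, err now reads roughly:
    # 'ONNX is required and was not found. Please install it using "pip install onnx".'
    print(err)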
60 changes: 60 additions & 0 deletions tests/python/driver/tvmc/test_frontends.py
@@ -18,14 +18,28 @@
import tarfile

import pytest
import builtins

from unittest import mock
from tvm.ir.module import IRModule

from tvm.driver import tvmc
from tvm.driver.tvmc.common import TVMCException
from tvm.driver.tvmc.model import TVMCModel


orig_import = builtins.__import__


def mock_error_on_name(name):
    def mock_imports(module_name, *args):
        if module_name == name:
            raise ImportError()
        return orig_import(module_name, *args)

    return mock_imports


def test_get_frontends_contains_only_strings():
    sut = tvmc.frontends.get_frontend_names()
    assert all([type(x) is str for x in sut]) is True
@@ -211,3 +225,49 @@ def test_load_model___wrong_language__to_pytorch(tflite_mobilenet_v1_1_quant):
model_format="pytorch",
shape_dict={"input": [1, 3, 224, 224]},
)


def test_import_keras_friendly_message(keras_resnet50, monkeypatch):
    # some CI environments won't offer Keras, so skip in case it is not present
    pytest.importorskip("keras")
    # keras is part of tensorflow
    monkeypatch.setattr("builtins.__import__", mock_error_on_name("tensorflow"))

    with pytest.raises(TVMCException) as e:
        _ = tvmc.frontends.load_model(keras_resnet50, model_format="keras")


def test_import_onnx_friendly_message(onnx_resnet50, monkeypatch):
    # some CI environments won't offer onnx, so skip in case it is not present
    pytest.importorskip("onnx")
    monkeypatch.setattr("builtins.__import__", mock_error_on_name("onnx"))

    with pytest.raises(TVMCException) as e:
        _ = tvmc.frontends.load_model(onnx_resnet50, model_format="onnx")


def test_import_tensorflow_friendly_message(pb_mobilenet_v1_1_quant, monkeypatch):
    # some CI environments won't offer tensorflow, so skip in case it is not present
    pytest.importorskip("tensorflow")
    monkeypatch.setattr("builtins.__import__", mock_error_on_name("tensorflow"))

    with pytest.raises(TVMCException) as e:
        _ = tvmc.frontends.load_model(pb_mobilenet_v1_1_quant, model_format="pb")


def test_import_torch_friendly_message(pytorch_resnet18, monkeypatch):
    # some CI environments won't offer pytorch, so skip in case it is not present
    pytest.importorskip("torch")
    monkeypatch.setattr("builtins.__import__", mock_error_on_name("torch"))

    with pytest.raises(TVMCException) as e:
        _ = tvmc.frontends.load_model(pytorch_resnet18, model_format="pytorch")


def test_import_tflite_friendly_message(tflite_mobilenet_v1_1_quant, monkeypatch):
    # some CI environments won't offer TFLite, so skip in case it is not present
    pytest.importorskip("tflite")
    monkeypatch.setattr("builtins.__import__", mock_error_on_name("tflite.Model"))

    with pytest.raises(TVMCException) as e:
        _ = tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant, model_format="tflite")
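
A brief aside on the testing technique: the new tests replace builtins.__import__ via monkeypatch so that importing one specific module raises ImportError, simulating an environment where that dependency is absent. Below is a standalone sketch of the same idea, using the stdlib json module as an arbitrary stand-in; nothing in it is tvmc-specific.

# Standalone sketch of the __import__-mocking technique used by the new tests.
# "json" is just an arbitrary installed module to block; any module name works.
import builtins

import pytest

orig_import = builtins.__import__


def mock_error_on_name(name):
    def mock_imports(module_name, *args):
        if module_name == name:
            raise ImportError()
        return orig_import(module_name, *args)

    return mock_imports


def test_blocked_import_raises(monkeypatch):
    # For the duration of this test, importing "json" behaves as if it were missing.
    monkeypatch.setattr("builtins.__import__", mock_error_on_name("json"))
    with pytest.raises(ImportError):
        import json  # noqa: F401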
