Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FAI-917: Add error message capturing within Python models #137

Merged
merged 5 commits into from
Jan 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ build-backend = "setuptools.build_meta"
package-dir = { "" = "src" }

[tool.pytest.ini_options]
log_cli = true
addopts = '-m="not block_plots"'
markers = [
"block_plots: Test plots will block execution of subsequent tests until closed"
Expand Down
27 changes: 24 additions & 3 deletions src/trustyai/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# pylint: disable = unused-import, wrong-import-order
# pylint: disable = consider-using-f-string
"""General model classes"""
import logging
import traceback
import uuid as _uuid
from typing import List, Optional, Union, Callable
import pandas as pd
Expand Down Expand Up @@ -322,7 +324,8 @@ def __init__(self, predict_fun, **kwargs):
transfer between Java and Python. If false, Arrow will be automatically used in
situations where it is advantageous to do so.
"""
self.predict_fun = predict_fun

self.predict_fun = self._error_catcher(predict_fun)
self.kwargs = kwargs

self.prediction_provider_arrow = None
Expand All @@ -332,6 +335,24 @@ def __init__(self, predict_fun, **kwargs):
# set model to use non-arrow by default, as this requires no dataset information
self._set_nonarrow()

def _error_catcher(self, predict_fun):
"""Wrapper for predict function to capture errors to Python logger before the JVM dies"""

def wrapper(x):
try:
return predict_fun(x)
except Exception as e:
logging.error(
" Fatal runtime error within the `predict_fun` supplied to trustyai.Model"
)
logging.error(
" The error message has been captured and reproduced below:"
)
logging.error(" %s", traceback.format_exc())
raise e

return wrapper

@property
def dataframe_input(self):
"""Get dataframe_input kwarg value"""
Expand Down Expand Up @@ -483,7 +504,7 @@ def __enter__(self):
self.previous_model_state = self.model.prediction_provider
self.model._set_arrow(self.paradigm_input)

def __exit__(self, exit_type, value, traceback):
def __exit__(self, exit_type, value, tb):
if self.model_is_python:
self.model.prediction_provider = self.previous_model_state

Expand All @@ -502,7 +523,7 @@ def __enter__(self):
self.previous_model_state = self.model.prediction_provider
self.model._set_nonarrow()

def __exit__(self, exit_type, value, traceback):
def __exit__(self, exit_type, value, tb):
if self.model_is_python:
self.model.prediction_provider = self.previous_model_state

Expand Down
12 changes: 12 additions & 0 deletions tests/general/test_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# pylint: disable=import-error, wrong-import-position, wrong-import-order, invalid-name
"""Test model provider interface"""
from trustyai.explainers import LimeExplainer

from common import *
from trustyai.model import Model, Dataset, feature
Expand Down Expand Up @@ -46,3 +47,14 @@ def test_cast_output_arrow():
output_val = m.predictAsync(pis).get()
assert len(output_val) == 25


def test_error_model(caplog):
"""test that a broken model spits out useful debugging info"""
m = Model(lambda x: str(x) - str(x))
try:
LimeExplainer().explain(0, 0, m)
except Exception:
pass

assert "Fatal runtime error" in caplog.text
assert "TypeError: unsupported operand type(s) for -: 'str' and 'str'" in caplog.text