From 88a8f91cd228fdfd6228fc7d898e4c3da0c8fd54 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 15 Mar 2021 23:06:09 +0100 Subject: [PATCH] Enable option zipmap for LGBM converter (fix issue #451) (#452) * Enable option zipmap for LGBM converter * add one more unittest * support booster --- onnxmltools/convert/lightgbm/_parse.py | 28 ++++++--- onnxmltools/convert/lightgbm/convert.py | 6 +- .../lightgbm/operator_converters/LightGbm.py | 24 ++++--- onnxmltools/convert/main.py | 5 +- onnxmltools/utils/tests_helper.py | 14 ++--- .../test_LightGbmTreeEnsembleConverters.py | 62 +++++++++++++++++++ 6 files changed, 111 insertions(+), 28 deletions(-) diff --git a/onnxmltools/convert/lightgbm/_parse.py b/onnxmltools/convert/lightgbm/_parse.py index 0f023a4e0..d37c153ed 100644 --- a/onnxmltools/convert/lightgbm/_parse.py +++ b/onnxmltools/convert/lightgbm/_parse.py @@ -87,11 +87,13 @@ def _parse_lightgbm_simple_model(scope, model, inputs): return this_operator.outputs -def _parse_sklearn_classifier(scope, model, inputs): +def _parse_sklearn_classifier(scope, model, inputs, zipmap=True): probability_tensor = _parse_lightgbm_simple_model( scope, model, inputs) this_operator = scope.declare_local_operator('LgbmZipMap') this_operator.inputs = probability_tensor + this_operator.zipmap = zipmap + classes = model.classes_ label_type = Int64Type() @@ -116,33 +118,39 @@ def _parse_sklearn_classifier(scope, model, inputs): label_type = StringType() output_label = scope.declare_local_variable('label', label_type) - output_probability = scope.declare_local_variable( - 'probabilities', - SequenceType(DictionaryType(label_type, FloatTensorType()))) + if zipmap: + output_probability = scope.declare_local_variable( + 'probabilities', + SequenceType(DictionaryType(label_type, FloatTensorType()))) + else: + output_probability = scope.declare_local_variable( + 'probabilities', FloatTensorType()) this_operator.outputs.append(output_label) this_operator.outputs.append(output_probability) return this_operator.outputs -def _parse_lightgbm(scope, model, inputs): +def _parse_lightgbm(scope, model, inputs, zipmap=True): ''' This is a delegate function. It doesn't nothing but invoke the correct parsing function according to the input model's type. :param scope: Scope object :param model: A lightgbm object :param inputs: A list of variables + :param zipmap: add operator ZipMap after operator TreeEnsembleClassifier :return: The output variables produced by the input model ''' if isinstance(model, LGBMClassifier): - return _parse_sklearn_classifier(scope, model, inputs) + return _parse_sklearn_classifier(scope, model, inputs, zipmap=zipmap) if (isinstance(model, WrappedBooster) and model.operator_name == 'LgbmClassifier'): - return _parse_sklearn_classifier(scope, model, inputs) + return _parse_sklearn_classifier(scope, model, inputs, zipmap=zipmap) return _parse_lightgbm_simple_model(scope, model, inputs) def parse_lightgbm(model, initial_types=None, target_opset=None, - custom_conversion_functions=None, custom_shape_calculators=None): + custom_conversion_functions=None, custom_shape_calculators=None, + zipmap=True): raw_model_container = LightGbmModelContainer(model) topology = Topology(raw_model_container, default_batch_size='None', initial_types=initial_types, target_opset=target_opset, @@ -157,9 +165,9 @@ def parse_lightgbm(model, initial_types=None, target_opset=None, for variable in inputs: raw_model_container.add_input(variable) - outputs = _parse_lightgbm(scope, model, inputs) + outputs = _parse_lightgbm(scope, model, inputs, zipmap=zipmap) for variable in outputs: raw_model_container.add_output(variable) - return topology + return topology \ No newline at end of file diff --git a/onnxmltools/convert/lightgbm/convert.py b/onnxmltools/convert/lightgbm/convert.py index 5495f27de..d1ac2b051 100644 --- a/onnxmltools/convert/lightgbm/convert.py +++ b/onnxmltools/convert/lightgbm/convert.py @@ -15,7 +15,7 @@ def convert(model, name=None, initial_types=None, doc_string='', target_opset=None, targeted_onnx=onnx.__version__, custom_conversion_functions=None, - custom_shape_calculators=None, without_onnx_ml=False): + custom_shape_calculators=None, without_onnx_ml=False, zipmap=True): ''' This function produces an equivalent ONNX model of the given lightgbm model. The supported lightgbm modules are listed below. @@ -34,6 +34,7 @@ def convert(model, name=None, initial_types=None, doc_string='', target_opset=No :param custom_conversion_functions: a dictionary for specifying the user customized conversion function :param custom_shape_calculators: a dictionary for specifying the user customized shape calculator :param without_onnx_ml: whether to generate a model composed by ONNX operators only, or to allow the converter + :param zipmap: remove operator ZipMap from the ONNX graph to use ONNX-ML operators as well. :return: An ONNX model (type: ModelProto) which is equivalent to the input lightgbm model ''' @@ -50,7 +51,8 @@ def convert(model, name=None, initial_types=None, doc_string='', target_opset=No name = str(uuid4().hex) target_opset = target_opset if target_opset else get_maximum_opset_supported() - topology = parse_lightgbm(model, initial_types, target_opset, custom_conversion_functions, custom_shape_calculators) + topology = parse_lightgbm(model, initial_types, target_opset, custom_conversion_functions, + custom_shape_calculators, zipmap=zipmap) topology.compile() onnx_ml_model = convert_topology(topology, name, doc_string, target_opset, targeted_onnx) diff --git a/onnxmltools/convert/lightgbm/operator_converters/LightGbm.py b/onnxmltools/convert/lightgbm/operator_converters/LightGbm.py index 6bd8d0315..976434806 100644 --- a/onnxmltools/convert/lightgbm/operator_converters/LightGbm.py +++ b/onnxmltools/convert/lightgbm/operator_converters/LightGbm.py @@ -4,7 +4,8 @@ import numbers import numpy as np from collections import Counter -from ...common._apply_operation import apply_div, apply_reshape, apply_sub, apply_cast, apply_identity +from ...common._apply_operation import ( + apply_div, apply_reshape, apply_sub, apply_cast, apply_identity, apply_clip) from ...common._registration import register_converter from ...common.tree_ensemble import get_default_tree_classifier_attribute_pairs from ....proto import onnx_proto @@ -453,23 +454,32 @@ def str2number(val): def convert_lgbm_zipmap(scope, operator, container): zipmap_attrs = {'name': scope.get_unique_operator_name('ZipMap')} - to_type = onnx_proto.TensorProto.INT64 - if hasattr(operator, 'classlabels_int64s'): zipmap_attrs['classlabels_int64s'] = operator.classlabels_int64s + to_type = onnx_proto.TensorProto.INT64 elif hasattr(operator, 'classlabels_strings'): zipmap_attrs['classlabels_strings'] = operator.classlabels_strings to_type = onnx_proto.TensorProto.STRING - + else: + raise RuntimeError("Unknown class type.") if to_type == onnx_proto.TensorProto.STRING: apply_identity(scope, operator.inputs[0].full_name, operator.outputs[0].full_name, container) else: apply_cast(scope, operator.inputs[0].full_name, operator.outputs[0].full_name, container, to=to_type) - container.add_node('ZipMap', operator.inputs[1].full_name, - operator.outputs[1].full_name, - op_domain='ai.onnx.ml', **zipmap_attrs) + + if operator.zipmap: + container.add_node('ZipMap', operator.inputs[1].full_name, + operator.outputs[1].full_name, + op_domain='ai.onnx.ml', **zipmap_attrs) + else: + # This should be apply_identity but optimization fails in + # onnxconverter-common when trying to remove identity nodes. + apply_clip(scope, operator.inputs[1].full_name, + operator.outputs[1].full_name, container, + min=np.array([0], dtype=np.float32), + max=np.array([1], dtype=np.float32)) register_converter('LgbmClassifier', convert_lightgbm) diff --git a/onnxmltools/convert/main.py b/onnxmltools/convert/main.py index 20164b58c..c29479a0d 100644 --- a/onnxmltools/convert/main.py +++ b/onnxmltools/convert/main.py @@ -52,13 +52,14 @@ def convert_catboost(model, name=None, initial_types=None, doc_string='', target def convert_lightgbm(model, name=None, initial_types=None, doc_string='', target_opset=None, targeted_onnx=onnx.__version__, custom_conversion_functions=None, - custom_shape_calculators=None, without_onnx_ml=False): + custom_shape_calculators=None, without_onnx_ml=False, zipmap=True): if not utils.lightgbm_installed(): raise RuntimeError('lightgbm is not installed. Please install lightgbm to use this feature.') from .lightgbm.convert import convert return convert(model, name, initial_types, doc_string, target_opset, targeted_onnx, - custom_conversion_functions, custom_shape_calculators, without_onnx_ml) + custom_conversion_functions, custom_shape_calculators, without_onnx_ml, + zipmap=zipmap) def convert_sklearn(model, name=None, initial_types=None, doc_string='', target_opset=None, diff --git a/onnxmltools/utils/tests_helper.py b/onnxmltools/utils/tests_helper.py index f2ebda7f3..4989c2692 100644 --- a/onnxmltools/utils/tests_helper.py +++ b/onnxmltools/utils/tests_helper.py @@ -191,7 +191,7 @@ def dump_data_and_model(data, model, onnx=None, basename="model", folder=None, return names -def convert_model(model, name, input_types, without_onnx_ml=False): +def convert_model(model, name, input_types, without_onnx_ml=False, **kwargs): """ Runs the appropriate conversion method. @@ -201,26 +201,26 @@ def convert_model(model, name, input_types, without_onnx_ml=False): from sklearn.base import BaseEstimator if model.__class__.__name__.startswith("LGBM"): from onnxmltools.convert import convert_lightgbm - model, prefix = convert_lightgbm(model, name, input_types, without_onnx_ml=without_onnx_ml), "LightGbm" + model, prefix = convert_lightgbm(model, name, input_types, without_onnx_ml=without_onnx_ml, **kwargs), "LightGbm" elif model.__class__.__name__.startswith("XGB"): from onnxmltools.convert import convert_xgboost - model, prefix = convert_xgboost(model, name, input_types), "XGB" + model, prefix = convert_xgboost(model, name, input_types, **kwargs), "XGB" elif model.__class__.__name__ == 'Booster': import lightgbm if isinstance(model, lightgbm.Booster): from onnxmltools.convert import convert_lightgbm - model, prefix = convert_lightgbm(model, name, input_types, without_onnx_ml=without_onnx_ml), "LightGbm" + model, prefix = convert_lightgbm(model, name, input_types, without_onnx_ml=without_onnx_ml, **kwargs), "LightGbm" else: raise RuntimeError("Unable to convert model of type '{0}'.".format(type(model))) elif model.__class__.__name__.startswith("CatBoost"): from onnxmltools.convert import convert_catboost - model, prefix = convert_catboost(model, name, input_types), "CatBoost" + model, prefix = convert_catboost(model, name, input_types, **kwargs), "CatBoost" elif isinstance(model, BaseEstimator): from onnxmltools.convert import convert_sklearn - model, prefix = convert_sklearn(model, name, input_types), "Sklearn" + model, prefix = convert_sklearn(model, name, input_types, **kwargs), "Sklearn" else: from onnxmltools.convert import convert_coreml - model, prefix = convert_coreml(model, name, input_types), "Cml" + model, prefix = convert_coreml(model, name, input_types, **kwargs), "Cml" if model is None: raise RuntimeError("Unable to convert model of type '{0}'.".format(type(model))) return model, prefix diff --git a/tests/lightgbm/test_LightGbmTreeEnsembleConverters.py b/tests/lightgbm/test_LightGbmTreeEnsembleConverters.py index 80307e3ac..e6eae16ff 100644 --- a/tests/lightgbm/test_LightGbmTreeEnsembleConverters.py +++ b/tests/lightgbm/test_LightGbmTreeEnsembleConverters.py @@ -5,10 +5,12 @@ import lightgbm import numpy +from numpy.testing import assert_almost_equal from lightgbm import LGBMClassifier, LGBMRegressor import onnxruntime from onnxmltools.convert.common.utils import hummingbird_installed from onnxmltools.convert.common.data_types import FloatTensorType +from onnxmltools.convert import convert_lightgbm from onnxmltools.utils import dump_data_and_model from onnxmltools.utils import dump_binary_classification, dump_multiple_classification from onnxmltools.utils import dump_single_regression @@ -32,6 +34,50 @@ def test_lightgbm_classifier_zipmap(self): model, 'dummy', input_types=[('X', FloatTensorType([None, X.shape[1]]))]) assert "zipmap" in str(onx).lower() + def test_lightgbm_classifier_nozipmap(self): + X = [[0, 1], [1, 1], [2, 0], [1, 2], [1, 5], [6, 2]] + X = numpy.array(X, dtype=numpy.float32) + y = [0, 1, 0, 1, 1, 0] + model = LGBMClassifier(n_estimators=3, min_child_samples=1, max_depth=2) + model.fit(X, y) + onx = convert_model( + model, 'dummy', input_types=[('X', FloatTensorType([None, X.shape[1]]))], + zipmap=False) + assert "zipmap" not in str(onx).lower() + onxs = onx[0].SerializeToString() + try: + sess = onnxruntime.InferenceSession(onxs) + except Exception as e: + raise AssertionError( + "Model cannot be loaded by onnxruntime due to %r\n%s." % ( + e, onx[0])) + exp = model.predict(X), model.predict_proba(X) + got = sess.run(None, {'X': X}) + assert_almost_equal(exp[0], got[0]) + assert_almost_equal(exp[1], got[1]) + + def test_lightgbm_classifier_nozipmap2(self): + X = [[0, 1], [1, 1], [2, 0], [1, 2], [1, 5], [6, 2]] + X = numpy.array(X, dtype=numpy.float32) + y = [0, 1, 0, 1, 1, 0] + model = LGBMClassifier(n_estimators=3, min_child_samples=1, max_depth=2) + model.fit(X, y) + onx = convert_lightgbm( + model, 'dummy', initial_types=[('X', FloatTensorType([None, X.shape[1]]))], + zipmap=False) + assert "zipmap" not in str(onx).lower() + onxs = onx.SerializeToString() + try: + sess = onnxruntime.InferenceSession(onxs) + except Exception as e: + raise AssertionError( + "Model cannot be loaded by onnxruntime due to %r\n%s." % ( + e, onx[0])) + exp = model.predict(X), model.predict_proba(X) + got = sess.run(None, {'X': X}) + assert_almost_equal(exp[0], got[0]) + assert_almost_equal(exp[1], got[1]) + def test_lightgbm_regressor(self): model = LGBMRegressor(n_estimators=3, min_child_samples=1) dump_single_regression(model) @@ -58,6 +104,22 @@ def test_lightgbm_booster_classifier(self): allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')", basename=prefix + "BoosterBin" + model.__class__.__name__) + def test_lightgbm_booster_classifier_nozipmap(self): + X = [[0, 1], [1, 1], [2, 0], [1, 2]] + X = numpy.array(X, dtype=numpy.float32) + y = [0, 1, 0, 1] + data = lightgbm.Dataset(X, label=y) + model = lightgbm.train({'boosting_type': 'gbdt', 'objective': 'binary', + 'n_estimators': 3, 'min_child_samples': 1}, + data) + model_onnx, prefix = convert_model(model, 'tree-based classifier', + [('input', FloatTensorType([None, 2]))], + zipmap=False) + assert "zipmap" not in str(model_onnx).lower() + dump_data_and_model(X, model, model_onnx, + allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')", + basename=prefix + "BoosterBin" + model.__class__.__name__) + def test_lightgbm_booster_classifier_zipmap(self): X = [[0, 1], [1, 1], [2, 0], [1, 2]] X = numpy.array(X, dtype=numpy.float32)