Enable option zipmap for LGBM converter (fix issue #451) (#452)
* Enable option zipmap for LGBM converter
* add one more unittest
* support booster
xadupre authored Mar 15, 2021
1 parent 331df2e commit 88a8f91
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 28 deletions.
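
The new zipmap keyword threads through the public conversion entry points. Below is a minimal usage sketch mirroring the unit tests added in this commit (the toy data and model settings are taken from those tests); it is an illustration, not part of the diff itself.

import numpy
from lightgbm import LGBMClassifier
import onnxruntime
from onnxmltools.convert import convert_lightgbm
from onnxmltools.convert.common.data_types import FloatTensorType

# Fit a small classifier on toy data.
X = numpy.array([[0, 1], [1, 1], [2, 0], [1, 2], [1, 5], [6, 2]], dtype=numpy.float32)
y = [0, 1, 0, 1, 1, 0]
model = LGBMClassifier(n_estimators=3, min_child_samples=1, max_depth=2)
model.fit(X, y)

# zipmap=False skips the ZipMap operator, so the probability output stays
# a plain float tensor instead of a sequence of dictionaries.
onx = convert_lightgbm(
    model, 'lgbm', initial_types=[('X', FloatTensorType([None, X.shape[1]]))],
    zipmap=False)
sess = onnxruntime.InferenceSession(onx.SerializeToString())
label, probabilities = sess.run(None, {'X': X})
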
28 changes: 18 additions & 10 deletions onnxmltools/convert/lightgbm/_parse.py
@@ -87,11 +87,13 @@ def _parse_lightgbm_simple_model(scope, model, inputs):
return this_operator.outputs


def _parse_sklearn_classifier(scope, model, inputs):
def _parse_sklearn_classifier(scope, model, inputs, zipmap=True):
probability_tensor = _parse_lightgbm_simple_model(
scope, model, inputs)
this_operator = scope.declare_local_operator('LgbmZipMap')
this_operator.inputs = probability_tensor
this_operator.zipmap = zipmap

classes = model.classes_
label_type = Int64Type()

@@ -116,33 +118,39 @@ def _parse_sklearn_classifier(scope, model, inputs):
label_type = StringType()

output_label = scope.declare_local_variable('label', label_type)
output_probability = scope.declare_local_variable(
'probabilities',
SequenceType(DictionaryType(label_type, FloatTensorType())))
if zipmap:
output_probability = scope.declare_local_variable(
'probabilities',
SequenceType(DictionaryType(label_type, FloatTensorType())))
else:
output_probability = scope.declare_local_variable(
'probabilities', FloatTensorType())
this_operator.outputs.append(output_label)
this_operator.outputs.append(output_probability)
return this_operator.outputs


def _parse_lightgbm(scope, model, inputs):
def _parse_lightgbm(scope, model, inputs, zipmap=True):
'''
This is a delegate function. It does nothing but invoke the correct parsing function according to the input
model's type.
:param scope: Scope object
:param model: A lightgbm object
:param inputs: A list of variables
:param zipmap: if True, add a ZipMap operator after the TreeEnsembleClassifier operator; if False, output probabilities as a plain tensor
:return: The output variables produced by the input model
'''
if isinstance(model, LGBMClassifier):
return _parse_sklearn_classifier(scope, model, inputs)
return _parse_sklearn_classifier(scope, model, inputs, zipmap=zipmap)
if (isinstance(model, WrappedBooster) and
model.operator_name == 'LgbmClassifier'):
return _parse_sklearn_classifier(scope, model, inputs)
return _parse_sklearn_classifier(scope, model, inputs, zipmap=zipmap)
return _parse_lightgbm_simple_model(scope, model, inputs)


def parse_lightgbm(model, initial_types=None, target_opset=None,
custom_conversion_functions=None, custom_shape_calculators=None):
custom_conversion_functions=None, custom_shape_calculators=None,
zipmap=True):
raw_model_container = LightGbmModelContainer(model)
topology = Topology(raw_model_container, default_batch_size='None',
initial_types=initial_types, target_opset=target_opset,
@@ -157,9 +165,9 @@ def parse_lightgbm(model, initial_types=None, target_opset=None,
for variable in inputs:
raw_model_container.add_input(variable)

outputs = _parse_lightgbm(scope, model, inputs)
outputs = _parse_lightgbm(scope, model, inputs, zipmap=zipmap)

for variable in outputs:
raw_model_container.add_output(variable)

return topology
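
As declared above, the type of the probabilities output depends on the option: with zipmap=True it is SequenceType(DictionaryType(label_type, FloatTensorType())), matching the ZipMap operator, while with zipmap=False it stays a FloatTensorType, so onnxruntime returns a plain 2-D array comparable to predict_proba. A small sketch for checking the declared outputs of a converted model (assuming onx is the ModelProto produced in the example above):

# Inspect the graph outputs of the converted model (onx is a ModelProto).
for out in onx.graph.output:
    print(out.name, out.type.WhichOneof('value'))
# With zipmap=False both outputs are tensors ('label', 'probabilities');
# with zipmap=True the probabilities output is a sequence of maps instead.
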
6 changes: 4 additions & 2 deletions onnxmltools/convert/lightgbm/convert.py
@@ -15,7 +15,7 @@

def convert(model, name=None, initial_types=None, doc_string='', target_opset=None,
targeted_onnx=onnx.__version__, custom_conversion_functions=None,
custom_shape_calculators=None, without_onnx_ml=False):
custom_shape_calculators=None, without_onnx_ml=False, zipmap=True):
'''
This function produces an equivalent ONNX model of the given lightgbm model.
The supported lightgbm modules are listed below.
@@ -34,6 +34,7 @@ def convert(model, name=None, initial_types=None, doc_string='', target_opset=None,
:param custom_conversion_functions: a dictionary for specifying the user customized conversion function
:param custom_shape_calculators: a dictionary for specifying the user customized shape calculator
:param without_onnx_ml: whether to generate a model composed by ONNX operators only, or to allow the converter
    to use ONNX-ML operators as well.
:param zipmap: if True (default), add a ZipMap operator after the classifier; if False, keep the probabilities as a plain tensor
:return: An ONNX model (type: ModelProto) which is equivalent to the input lightgbm model
'''
@@ -50,7 +51,8 @@ def convert(model, name=None, initial_types=None, doc_string='', target_opset=None,
name = str(uuid4().hex)

target_opset = target_opset if target_opset else get_maximum_opset_supported()
topology = parse_lightgbm(model, initial_types, target_opset, custom_conversion_functions, custom_shape_calculators)
topology = parse_lightgbm(model, initial_types, target_opset, custom_conversion_functions,
custom_shape_calculators, zipmap=zipmap)
topology.compile()
onnx_ml_model = convert_topology(topology, name, doc_string, target_opset, targeted_onnx)

24 changes: 17 additions & 7 deletions onnxmltools/convert/lightgbm/operator_converters/LightGbm.py
@@ -4,7 +4,8 @@
import numbers
import numpy as np
from collections import Counter
from ...common._apply_operation import apply_div, apply_reshape, apply_sub, apply_cast, apply_identity
from ...common._apply_operation import (
apply_div, apply_reshape, apply_sub, apply_cast, apply_identity, apply_clip)
from ...common._registration import register_converter
from ...common.tree_ensemble import get_default_tree_classifier_attribute_pairs
from ....proto import onnx_proto
@@ -453,23 +454,32 @@ def str2number(val):

def convert_lgbm_zipmap(scope, operator, container):
zipmap_attrs = {'name': scope.get_unique_operator_name('ZipMap')}
to_type = onnx_proto.TensorProto.INT64

if hasattr(operator, 'classlabels_int64s'):
zipmap_attrs['classlabels_int64s'] = operator.classlabels_int64s
to_type = onnx_proto.TensorProto.INT64
elif hasattr(operator, 'classlabels_strings'):
zipmap_attrs['classlabels_strings'] = operator.classlabels_strings
to_type = onnx_proto.TensorProto.STRING

else:
raise RuntimeError("Unknown class type.")
if to_type == onnx_proto.TensorProto.STRING:
apply_identity(scope, operator.inputs[0].full_name,
operator.outputs[0].full_name, container)
else:
apply_cast(scope, operator.inputs[0].full_name,
operator.outputs[0].full_name, container, to=to_type)
container.add_node('ZipMap', operator.inputs[1].full_name,
operator.outputs[1].full_name,
op_domain='ai.onnx.ml', **zipmap_attrs)

if operator.zipmap:
container.add_node('ZipMap', operator.inputs[1].full_name,
operator.outputs[1].full_name,
op_domain='ai.onnx.ml', **zipmap_attrs)
else:
# This should be apply_identity but optimization fails in
# onnxconverter-common when trying to remove identity nodes.
apply_clip(scope, operator.inputs[1].full_name,
operator.outputs[1].full_name, container,
min=np.array([0], dtype=np.float32),
max=np.array([1], dtype=np.float32))


register_converter('LgbmClassifier', convert_lightgbm)
5 changes: 3 additions & 2 deletions onnxmltools/convert/main.py
@@ -52,13 +52,14 @@ def convert_catboost(model, name=None, initial_types=None, doc_string='', target

def convert_lightgbm(model, name=None, initial_types=None, doc_string='', target_opset=None,
targeted_onnx=onnx.__version__, custom_conversion_functions=None,
custom_shape_calculators=None, without_onnx_ml=False):
custom_shape_calculators=None, without_onnx_ml=False, zipmap=True):
if not utils.lightgbm_installed():
raise RuntimeError('lightgbm is not installed. Please install lightgbm to use this feature.')

from .lightgbm.convert import convert
return convert(model, name, initial_types, doc_string, target_opset, targeted_onnx,
custom_conversion_functions, custom_shape_calculators, without_onnx_ml)
custom_conversion_functions, custom_shape_calculators, without_onnx_ml,
zipmap=zipmap)


def convert_sklearn(model, name=None, initial_types=None, doc_string='', target_opset=None,
14 changes: 7 additions & 7 deletions onnxmltools/utils/tests_helper.py
@@ -191,7 +191,7 @@ def dump_data_and_model(data, model, onnx=None, basename="model", folder=None,
return names


def convert_model(model, name, input_types, without_onnx_ml=False):
def convert_model(model, name, input_types, without_onnx_ml=False, **kwargs):
"""
Runs the appropriate conversion method.
@@ -201,26 +201,26 @@ def convert_model(model, name, input_types, without_onnx_ml=False):
from sklearn.base import BaseEstimator
if model.__class__.__name__.startswith("LGBM"):
from onnxmltools.convert import convert_lightgbm
model, prefix = convert_lightgbm(model, name, input_types, without_onnx_ml=without_onnx_ml), "LightGbm"
model, prefix = convert_lightgbm(model, name, input_types, without_onnx_ml=without_onnx_ml, **kwargs), "LightGbm"
elif model.__class__.__name__.startswith("XGB"):
from onnxmltools.convert import convert_xgboost
model, prefix = convert_xgboost(model, name, input_types), "XGB"
model, prefix = convert_xgboost(model, name, input_types, **kwargs), "XGB"
elif model.__class__.__name__ == 'Booster':
import lightgbm
if isinstance(model, lightgbm.Booster):
from onnxmltools.convert import convert_lightgbm
model, prefix = convert_lightgbm(model, name, input_types, without_onnx_ml=without_onnx_ml), "LightGbm"
model, prefix = convert_lightgbm(model, name, input_types, without_onnx_ml=without_onnx_ml, **kwargs), "LightGbm"
else:
raise RuntimeError("Unable to convert model of type '{0}'.".format(type(model)))
elif model.__class__.__name__.startswith("CatBoost"):
from onnxmltools.convert import convert_catboost
model, prefix = convert_catboost(model, name, input_types), "CatBoost"
model, prefix = convert_catboost(model, name, input_types, **kwargs), "CatBoost"
elif isinstance(model, BaseEstimator):
from onnxmltools.convert import convert_sklearn
model, prefix = convert_sklearn(model, name, input_types), "Sklearn"
model, prefix = convert_sklearn(model, name, input_types, **kwargs), "Sklearn"
else:
from onnxmltools.convert import convert_coreml
model, prefix = convert_coreml(model, name, input_types), "Cml"
model, prefix = convert_coreml(model, name, input_types, **kwargs), "Cml"
if model is None:
raise RuntimeError("Unable to convert model of type '{0}'.".format(type(model)))
return model, prefix
62 changes: 62 additions & 0 deletions tests/lightgbm/test_LightGbmTreeEnsembleConverters.py
@@ -5,10 +5,12 @@

import lightgbm
import numpy
from numpy.testing import assert_almost_equal
from lightgbm import LGBMClassifier, LGBMRegressor
import onnxruntime
from onnxmltools.convert.common.utils import hummingbird_installed
from onnxmltools.convert.common.data_types import FloatTensorType
from onnxmltools.convert import convert_lightgbm
from onnxmltools.utils import dump_data_and_model
from onnxmltools.utils import dump_binary_classification, dump_multiple_classification
from onnxmltools.utils import dump_single_regression
@@ -32,6 +34,50 @@ def test_lightgbm_classifier_zipmap(self):
model, 'dummy', input_types=[('X', FloatTensorType([None, X.shape[1]]))])
assert "zipmap" in str(onx).lower()

def test_lightgbm_classifier_nozipmap(self):
X = [[0, 1], [1, 1], [2, 0], [1, 2], [1, 5], [6, 2]]
X = numpy.array(X, dtype=numpy.float32)
y = [0, 1, 0, 1, 1, 0]
model = LGBMClassifier(n_estimators=3, min_child_samples=1, max_depth=2)
model.fit(X, y)
onx = convert_model(
model, 'dummy', input_types=[('X', FloatTensorType([None, X.shape[1]]))],
zipmap=False)
assert "zipmap" not in str(onx).lower()
onxs = onx[0].SerializeToString()
try:
sess = onnxruntime.InferenceSession(onxs)
except Exception as e:
raise AssertionError(
"Model cannot be loaded by onnxruntime due to %r\n%s." % (
e, onx[0]))
exp = model.predict(X), model.predict_proba(X)
got = sess.run(None, {'X': X})
assert_almost_equal(exp[0], got[0])
assert_almost_equal(exp[1], got[1])

def test_lightgbm_classifier_nozipmap2(self):
X = [[0, 1], [1, 1], [2, 0], [1, 2], [1, 5], [6, 2]]
X = numpy.array(X, dtype=numpy.float32)
y = [0, 1, 0, 1, 1, 0]
model = LGBMClassifier(n_estimators=3, min_child_samples=1, max_depth=2)
model.fit(X, y)
onx = convert_lightgbm(
model, 'dummy', initial_types=[('X', FloatTensorType([None, X.shape[1]]))],
zipmap=False)
assert "zipmap" not in str(onx).lower()
onxs = onx.SerializeToString()
try:
sess = onnxruntime.InferenceSession(onxs)
except Exception as e:
raise AssertionError(
"Model cannot be loaded by onnxruntime due to %r\n%s." % (
e, onx))
exp = model.predict(X), model.predict_proba(X)
got = sess.run(None, {'X': X})
assert_almost_equal(exp[0], got[0])
assert_almost_equal(exp[1], got[1])

def test_lightgbm_regressor(self):
model = LGBMRegressor(n_estimators=3, min_child_samples=1)
dump_single_regression(model)
@@ -58,6 +104,22 @@ def test_lightgbm_booster_classifier(self):
allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')",
basename=prefix + "BoosterBin" + model.__class__.__name__)

def test_lightgbm_booster_classifier_nozipmap(self):
X = [[0, 1], [1, 1], [2, 0], [1, 2]]
X = numpy.array(X, dtype=numpy.float32)
y = [0, 1, 0, 1]
data = lightgbm.Dataset(X, label=y)
model = lightgbm.train({'boosting_type': 'gbdt', 'objective': 'binary',
'n_estimators': 3, 'min_child_samples': 1},
data)
model_onnx, prefix = convert_model(model, 'tree-based classifier',
[('input', FloatTensorType([None, 2]))],
zipmap=False)
assert "zipmap" not in str(model_onnx).lower()
dump_data_and_model(X, model, model_onnx,
allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')",
basename=prefix + "BoosterBin" + model.__class__.__name__)

def test_lightgbm_booster_classifier_zipmap(self):
X = [[0, 1], [1, 1], [2, 0], [1, 2]]
X = numpy.array(X, dtype=numpy.float32)
