Skip to content

Commit

Permalink
support parameter shape_override (#497)
Browse files Browse the repository at this point in the history
  • Loading branch information
xadupre authored Sep 14, 2021
1 parent 8e52e09 commit 1767f05
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions onnxmltools/convert/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,8 @@ def _convert_tf_wrapper(frozen_graph_def,
doc_string='',
target_opset=None,
channel_first_inputs=None,
debug_mode=False, custom_op_conversions=None):
debug_mode=False, custom_op_conversions=None,
**kwargs):
"""
convert a tensorflow graph def into a ONNX model proto, just like how keras does.
:param graph_def: the frozen tensorflow graph
Expand All @@ -220,6 +221,8 @@ def _convert_tf_wrapper(frozen_graph_def,
:param target_opset: the targeted onnx model opset
:param channel_first_inputs: A list of channel first input (not supported yet)
:param debug_mode: will enable the log and try to convert as much as possible on conversion
:param kwargs: additional parameters of function `processs_tf_graph
<https://github.com/onnx/tensorflow-onnx#creating-custom-op-mappings-from-python>`_
:return an ONNX ModelProto
"""
import tensorflow as tf
Expand All @@ -244,7 +247,8 @@ def _convert_tf_wrapper(frozen_graph_def,
custom_op_handlers=custom_op_conversions,
inputs_as_nchw=channel_first_inputs,
output_names=output_names,
input_names=input_names)
input_names=input_names,
**kwargs)

onnx_graph = tf2onnx.optimizer.optimize_graph(g)
model_proto = onnx_graph.make_model(doc_string)
Expand All @@ -257,10 +261,12 @@ def convert_tensorflow(frozen_graph_def,
doc_string='',
target_opset=None,
channel_first_inputs=None,
debug_mode=False, custom_op_conversions=None):
debug_mode=False, custom_op_conversions=None,
**kwargs):
import pkgutil
if not pkgutil.find_loader('tf2onnx'):
raise RuntimeError('tf2onnx is not installed, please install it before calling this function.')

return _convert_tf_wrapper(frozen_graph_def, name, input_names, output_names, doc_string,
target_opset, channel_first_inputs, debug_mode, custom_op_conversions)
target_opset, channel_first_inputs, debug_mode, custom_op_conversions,
**kwargs)

0 comments on commit 1767f05

Please sign in to comment.