diff --git a/.github/workflows/linux-CI.yml b/.github/workflows/linux-CI.yml index 0b48b88..00e5fc5 100644 --- a/.github/workflows/linux-CI.yml +++ b/.github/workflows/linux-CI.yml @@ -21,11 +21,11 @@ jobs: python-version: ["3.8", "3.9", "3.10"] include: - python-version: "3.8" - onnx-version: 1.12.0 + onnx-version: 1.16 - python-version: "3.9" - onnx-version: 1.13 + onnx-version: 1.16 - python-version: "3.10" - onnx-version: 1.14.1 + onnx-version: 1.16 steps: @@ -39,7 +39,7 @@ jobs: python -m pip install --upgrade pip pip install flake8 pytest if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - pip install onnxruntime==1.15 + pip install onnxruntime pip install onnxmltools pip install onnx==${{ matrix.onnx-version }} pip install -e . diff --git a/.github/workflows/windows-CI.yml b/.github/workflows/windows-CI.yml index 467706d..786bc12 100644 --- a/.github/workflows/windows-CI.yml +++ b/.github/workflows/windows-CI.yml @@ -21,11 +21,11 @@ jobs: python-version: ["3.8", "3.9", "3.10"] include: - python-version: "3.8" - onnx-version: 1.12.0 + onnx-version: 1.16 - python-version: "3.9" - onnx-version: 1.13 + onnx-version: 1.16 - python-version: "3.10" - onnx-version: 1.14.1 + onnx-version: 1.16 steps: - uses: actions/checkout@v4 @@ -37,7 +37,7 @@ jobs: run: | python -m pip install --upgrade pip pip install flake8 pytest - pip install onnxruntime==1.15 + pip install onnxruntime pip install onnxmltools pip install onnx==${{ matrix.onnx-version }} pip install -e . diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..969d736 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,16 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. 
+ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python Debugger: Current File", + "type": "debugpy", + "request": "launch", + "program": "${file}", + "console": "integratedTerminal", + "justMyCode": false + } + ] +} \ No newline at end of file diff --git a/onnxconverter_common/auto_mixed_precision.py b/onnxconverter_common/auto_mixed_precision.py index 6d474b4..da422be 100644 --- a/onnxconverter_common/auto_mixed_precision.py +++ b/onnxconverter_common/auto_mixed_precision.py @@ -77,9 +77,6 @@ def run_attempt(node_block_list, return_model=False): print(valid) return valid - if not run_attempt(node_names): - raise ValueError("validation failed for model with all nodes in node_block_list") - print("Sanity checks passed. Starting autoconvert.") segments = SegmentList(node_names) i = 0 while segments.get_largest() is not None: @@ -129,6 +126,9 @@ def get_tensor_values_using_ort(model, input_feed, output_names=None, sess_optio # delayed import to avoid taking a strong dependancy on onnxruntime import onnxruntime as ort if output_names is None: + # Below code is for debug only, keep it for next time use + # sess_options = ort.SessionOptions() + # sess_options.optimized_model_filepath = "d:/optimized_model.onnx" sess = ort.InferenceSession(model.SerializeToString(), sess_options, providers=['CUDAExecutionProvider']) return sess.run(None, input_feed) original_outputs = list(model.graph.output) diff --git a/onnxconverter_common/float16.py b/onnxconverter_common/float16.py index a63572f..b2291d2 100644 --- a/onnxconverter_common/float16.py +++ b/onnxconverter_common/float16.py @@ -302,6 +302,7 @@ def convert_float_to_float16(model, min_positive_val=1e-7, max_finite_val=1e4, break sort_topology(model.graph) + remove_unnecessary_cast_node(model.graph) return model @@ -388,3 +389,97 @@ def sort_topology(graph_proto): for g in attr.graphs: if isinstance(g, onnx_proto.GraphProto): 
sort_topology(g) # sort sub-graph + + +def remove_unnecessary_cast_node(graph_proto): + # 1. find all cast nodes in the graph + cast_node_list = [] + input_name_to_cast_node_dict = {} + output_name_to_cast_node_dict = {} + # using name as key to point to a node. because node cannot be key + name_to_node_dict = {} + for node in graph_proto.node: + if node.op_type == 'Cast': + if node.name not in ["graph_input_cast0", "graph_output_cast0"]: + cast_node_list.append(node) + + name_to_node_dict[node.name] = node + for input_name in node.input: + input_name_to_cast_node_dict[input_name] = node + for output_name in node.output: + output_name_to_cast_node_dict[output_name] = node + + # 2. find upstream and downstream node of the cast node + cast_node_upstream_dict = {} # mapping cast node(name) to its upstream node + cast_node_downstream_dict = {} # mapping cast node(name) to its downstream node + for current_node in graph_proto.node: + # find the downstream node(s) + for input_name in current_node.input: + if input_name in output_name_to_cast_node_dict: + # found the downstream node of the cast node, might be multiple + cast_node = output_name_to_cast_node_dict[input_name] + if cast_node.name not in cast_node_downstream_dict: + cast_node_downstream_dict[cast_node.name] = current_node + else: # already exists one downstream node, make it a list + existing_downstream_nodes = cast_node_downstream_dict[cast_node.name] + if isinstance(existing_downstream_nodes, list): + existing_downstream_nodes.append(current_node) + else: # make a list + existing_downstream_nodes = [existing_downstream_nodes, current_node] + cast_node_downstream_dict[cast_node.name] = existing_downstream_nodes + # find the upstream node + for output_name in current_node.output: + if output_name in input_name_to_cast_node_dict: + # found the upstream node of the cast node, should be unique + cast_node = input_name_to_cast_node_dict[output_name] + cast_node_upstream_dict[cast_node.name] = current_node + + # 3. 
remove the cast node which upstream is 'Constant' + for cast_node_name, upstream_node in cast_node_upstream_dict.items(): + cast_node = name_to_node_dict[cast_node_name] + if upstream_node.op_type == 'Constant': + cast_node_list.remove(cast_node) + + # 4. find the cast(to16) node which downstream is Cast(to32) + remove_candidate = [] + for cast_node_name, downstream_node in cast_node_downstream_dict.items(): + cast_node = name_to_node_dict[cast_node_name] + if isinstance(downstream_node, list): + for dn in downstream_node: + if dn.op_type == 'Cast' and \ + dn.attribute[0].i == 1 and \ + cast_node.attribute[0].i == 10 and \ + dn in cast_node_list and \ + cast_node in cast_node_list: + remove_candidate.append((cast_node, dn)) + else: + if downstream_node.op_type == 'Cast' and \ + cast_node.attribute[0].i == 10 and \ + downstream_node.attribute[0].i == 1 and \ + downstream_node in cast_node_list and \ + cast_node in cast_node_list: + remove_candidate.append((cast_node, downstream_node)) + + # 5. change the connection of "upstream->cast16->cast32->downstream" to "upstream->downstream" + for cast_node_pair in remove_candidate: + first_cast_node = cast_node_pair[0] + second_cast_node = cast_node_pair[1] + upstream_node = cast_node_upstream_dict[first_cast_node.name] + downstream_node = cast_node_downstream_dict[second_cast_node.name] + # find the upstream node's output to first_cast_node + out = None + for output_name in upstream_node.output: + if output_name == first_cast_node.input[0]: + out = output_name + break + # find the downstream node's input as second_cast_node's output + for i, input_name in enumerate(downstream_node.input): + for output_name in second_cast_node.output: + if input_name == output_name: + # change the input as the upstream node's output + downstream_node.input[i] = out + + # 6. remove the cast node pair + for cast_node_pair in remove_candidate: + graph_proto.node.remove(cast_node_pair[0]) + graph_proto.node.remove(cast_node_pair[1])