Skip to content

Commit

Permalink
Remove useless cast nodes in graph after converting to fp16 model (#286)
Browse files Browse the repository at this point in the history
* update

* Update float16.py

* Update float16.py

* Update float16.py

* remove comments

* update

* upgrade to 1.16 onnx

Due to security issue

* update

* Update float16.py
  • Loading branch information
xiaowuhu authored Apr 22, 2024
1 parent 18f795f commit 741f937
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 11 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/linux-CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ jobs:
python-version: ["3.8", "3.9", "3.10"]
include:
- python-version: "3.8"
onnx-version: 1.12.0
onnx-version: 1.16
- python-version: "3.9"
onnx-version: 1.13
onnx-version: 1.16
- python-version: "3.10"
onnx-version: 1.14.1
onnx-version: 1.16


steps:
Expand All @@ -39,7 +39,7 @@ jobs:
python -m pip install --upgrade pip
pip install flake8 pytest
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
pip install onnxruntime==1.15
pip install onnxruntime
pip install onnxmltools
pip install onnx==${{ matrix.onnx-version }}
pip install -e .
Expand Down
8 changes: 4 additions & 4 deletions .github/workflows/windows-CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ jobs:
python-version: ["3.8", "3.9", "3.10"]
include:
- python-version: "3.8"
onnx-version: 1.12.0
onnx-version: 1.16
- python-version: "3.9"
onnx-version: 1.13
onnx-version: 1.16
- python-version: "3.10"
onnx-version: 1.14.1
onnx-version: 1.16

steps:
- uses: actions/checkout@v4
Expand All @@ -37,7 +37,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install flake8 pytest
pip install onnxruntime==1.15
pip install onnxruntime
pip install onnxmltools
pip install onnx==${{ matrix.onnx-version }}
pip install -e .
Expand Down
16 changes: 16 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python Debugger: Current File",
"type": "debugpy",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
"justMyCode": false
}
]
}
6 changes: 3 additions & 3 deletions onnxconverter_common/auto_mixed_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,6 @@ def run_attempt(node_block_list, return_model=False):
print(valid)
return valid

if not run_attempt(node_names):
raise ValueError("validation failed for model with all nodes in node_block_list")
print("Sanity checks passed. Starting autoconvert.")
segments = SegmentList(node_names)
i = 0
while segments.get_largest() is not None:
Expand Down Expand Up @@ -129,6 +126,9 @@ def get_tensor_values_using_ort(model, input_feed, output_names=None, sess_optio
# delayed import to avoid taking a strong dependancy on onnxruntime
import onnxruntime as ort
if output_names is None:
# Below code is for debug only, keep it for next time use
# sess_options = ort.SessionOptions()
# sess_options.optimized_model_filepath = "d:/optimized_model.onnx"
sess = ort.InferenceSession(model.SerializeToString(), sess_options, providers=['CUDAExecutionProvider'])
return sess.run(None, input_feed)
original_outputs = list(model.graph.output)
Expand Down
95 changes: 95 additions & 0 deletions onnxconverter_common/float16.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ def convert_float_to_float16(model, min_positive_val=1e-7, max_finite_val=1e4,
break

sort_topology(model.graph)
remove_unnecessary_cast_node(model.graph)
return model


Expand Down Expand Up @@ -388,3 +389,97 @@ def sort_topology(graph_proto):
for g in attr.graphs:
if isinstance(g, onnx_proto.GraphProto):
sort_topology(g) # sort sub-graph


def remove_unnecessary_cast_node(graph_proto):
# 1. find all cast nodes in the graph
cast_node_list = []
input_name_to_cast_node_dict = {}
output_name_to_cast_node_dict = {}
# using name as key to point to a node. because node cannot be key
name_to_node_dict = {}
for node in graph_proto.node:
if node.op_type == 'Cast':
if node.name not in ["graph_input_cast0", "graph_output_cast0"]:
cast_node_list.append(node)

name_to_node_dict[node.name] = node
for input_name in node.input:
input_name_to_cast_node_dict[input_name] = node
for output_name in node.output:
output_name_to_cast_node_dict[output_name] = node

# 2. find upstream and downstream node of the cast node
cast_node_upstream_dict = {} # mapping cast node(name) to its upstream node
cast_node_downstream_dict = {} # mapping cast node(name) to its downstream node
for current_node in graph_proto.node:
# find the downstream node(s)
for input_name in current_node.input:
if input_name in output_name_to_cast_node_dict:
# found the downstream node of the cast node, might be multiple
cast_node = output_name_to_cast_node_dict[input_name]
if cast_node.name not in cast_node_downstream_dict:
cast_node_downstream_dict[cast_node.name] = current_node
else: # already exists one downstream node, make it a list
existing_downstream_nodes = cast_node_downstream_dict[cast_node.name]
if isinstance(existing_downstream_nodes, list):
existing_downstream_nodes.append(current_node)
else: # make a list
existing_downstream_nodes = [existing_downstream_nodes, current_node]
cast_node_downstream_dict[cast_node.name] = existing_downstream_nodes
# find the upstream node
for output_name in current_node.output:
if output_name in input_name_to_cast_node_dict:
# found the upstream node of the cast node, should be unique
cast_node = input_name_to_cast_node_dict[output_name]
cast_node_upstream_dict[cast_node.name] = current_node

# 3. remove the cast node which upstream is 'Constant'
for cast_node_name, upstream_node in cast_node_upstream_dict.items():
cast_node = name_to_node_dict[cast_node_name]
if upstream_node.op_type == 'Constant':
cast_node_list.remove(cast_node)

# 4. find the cast(to16) node which downstream is Cast(to32)
remove_candidate = []
for cast_node_name, downstream_node in cast_node_downstream_dict.items():
cast_node = name_to_node_dict[cast_node_name]
if isinstance(downstream_node, list):
for dn in downstream_node:
if dn.op_type == 'Cast' and \
dn.attribute[0].i == 32 and \
cast_node.attribute[0].i == 16 and \
dn in cast_node_list and \
cast_node in cast_node_list:
remove_candidate.append((cast_node, dn))
else:
if downstream_node.op_type == 'Cast' and \
cast_node.attribute[0].i == 10 and \
downstream_node.attribute[0].i == 1 and \
downstream_node in cast_node_list and \
cast_node in cast_node_list:
remove_candidate.append((cast_node, downstream_node))

# 5. change the connection of "upstream->cast16->cast32->downstream" to "upstream->downstream"
for cast_node_pair in remove_candidate:
first_cast_node = cast_node_pair[0]
second_cast_node = cast_node_pair[1]
upstream_node = cast_node_upstream_dict[first_cast_node.name]
downstream_node = cast_node_downstream_dict[second_cast_node.name]
# find the upstream node's output to first_cast_node
out = None
for output_name in upstream_node.output:
if output_name == first_cast_node.input[0]:
out = output_name
break
# find the downstream node's input as second_cast_node's output
for i, input_name in enumerate(downstream_node.input):
for output_name in second_cast_node.output:
if input_name == output_name:
# change the input as the upstream node's output
downstream_node.input[i] = out

# 6. remove the cast node pair
for cast_node_pair in remove_candidate:
graph_proto.node.remove(cast_node_pair[0])
graph_proto.node.remove(cast_node_pair[1])

0 comments on commit 741f937

Please sign in to comment.