Skip to content

Commit

Permalink
Check consumers of the DQ node before swapping DQ and Transpose (#12099)
Browse files Browse the repository at this point in the history
* Check consumers of the DQ node before swapping DQ and Transpose

* add unit test
  • Loading branch information
yufenglee authored and RandySheriffH committed Jul 6, 2022
1 parent 0f20e84 commit 57730b2
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1979,11 +1979,18 @@ OptimizeResult OptimizeImpl(OptimizerCtx& ctx) {
continue;
}

auto consumers = ctx.graph.GetValueConsumers(transpose_node.Outputs()[0]);
bool is_part_of_qdq_group = std::find_if(consumers->nodes.cbegin(), consumers->nodes.cend(),
// Check if Transpose node is the only consumer of dq node
auto consumers_of_dq_node = ctx.graph.GetValueConsumers(dq_node->Outputs()[0]);
if (!consumers_of_dq_node->comprehensive || consumers_of_dq_node->nodes.size() > 1) {
continue;
}

auto consumers_of_transpose_node = ctx.graph.GetValueConsumers(transpose_node.Outputs()[0]);
bool is_part_of_qdq_group = std::find_if(consumers_of_transpose_node->nodes.cbegin(),
consumers_of_transpose_node->nodes.cend(),
[](const std::unique_ptr<api::NodeRef>& node) {
return node->OpType() == "QuantizeLinear";
}) != consumers->nodes.cend();
}) != consumers_of_transpose_node->nodes.cend();
if (is_part_of_qdq_group) {
continue;
}
Expand Down
10 changes: 10 additions & 0 deletions onnxruntime/test/optimizer/qdq_test_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -171,5 +171,15 @@ GetQDQTestCaseFn BuildQDQMatMulTestCase(const std::vector<int64_t>& input1_shape
};
}

std::vector<std::string> GetNodeOpTypesInTopologicalOrder(const Graph& graph) {
std::vector<std::string> op_types{};
GraphViewer graph_viewer{graph};
const auto& ordering = graph_viewer.GetNodesInTopologicalOrder();
for (const auto node_idx : ordering) {
op_types.push_back(graph.GetNode(node_idx)->OpType());
}
return op_types;
}

} // namespace test
} // namespace onnxruntime
5 changes: 5 additions & 0 deletions onnxruntime/test/optimizer/qdq_test_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

#pragma once

#include <vector>
#include <string>

#include "graph_transform_test_builder.h"

#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h"
Expand Down Expand Up @@ -359,5 +362,7 @@ GetQDQTestCaseFn BuildQDQGemmTestCase(const std::vector<int64_t>& input1_shape,
};
}

std::vector<std::string> GetNodeOpTypesInTopologicalOrder(const Graph& graph);

} // namespace test
} // namespace onnxruntime
10 changes: 0 additions & 10 deletions onnxruntime/test/optimizer/qdq_transformer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,6 @@
namespace onnxruntime {
namespace test {

static std::vector<std::string> GetNodeOpTypesInTopologicalOrder(const Graph& graph) {
std::vector<std::string> op_types{};
GraphViewer graph_viewer{graph};
const auto& ordering = graph_viewer.GetNodesInTopologicalOrder();
for (const auto node_idx : ordering) {
op_types.push_back(graph.GetNode(node_idx)->OpType());
}
return op_types;
}

#if !defined(DISABLE_CONTRIB_OPS)

template <typename InputType, typename WeightType, typename BiasType, typename OutputType>
Expand Down
37 changes: 37 additions & 0 deletions onnxruntime/test/optimizer/transpose_optimizer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "graph_transform_test_builder.h"

#include "core/graph/graph.h"
#include "qdq_test_utils.h"
#include "test/test_environment.h"
#include "test/util/include/asserts.h"

Expand Down Expand Up @@ -3591,6 +3592,42 @@ TEST(TransposeOptimizerTests, TestDequantizeLinearNoAxis) {
/*opset_version*/ 10);
}

// Builds a graph in which a single DequantizeLinear output is consumed by two Transpose
// nodes, runs the Level1 transformers, and expects the resulting graph to still be
// DequantizeLinear followed by two Transpose nodes.
TEST(TransposeOptimizerTests, TestDequantizeLinearTransposePropagation) {
  auto build_test_case_1 = [&](ModelTestBuilder& builder) {
    // The three DequantizeLinear inputs, created in the same order they are passed to the node.
    auto* dq_data_arg = MakeInput<uint8_t>(builder, {{2, -1, 6, 3}}, {2, 4, 6, 3}, 0, 5);
    auto* dq_scale_arg = MakeInput<float>(builder, {std::vector<int64_t>{}}, std::vector<int64_t>{}, {2.3f});
    auto* dq_zero_point_arg = MakeInput<uint8_t>(builder, {std::vector<int64_t>{}}, std::vector<int64_t>{}, {10});

    // One shared intermediate (the DequantizeLinear output) and two graph outputs.
    auto* dq_result = builder.MakeIntermediate();
    auto* first_transpose_result = builder.MakeOutput();
    auto* second_transpose_result = builder.MakeOutput();

    builder.AddNode("DequantizeLinear", {dq_data_arg, dq_scale_arg, dq_zero_point_arg}, {dq_result});

    // Both Transpose nodes read the same DequantizeLinear output but use different permutations.
    builder.AddNode("Transpose", {dq_result}, {first_transpose_result})
        .AddAttribute("perm", std::vector<int64_t>{0, 3, 1, 2});

    builder.AddNode("Transpose", {dq_result}, {second_transpose_result})
        .AddAttribute("perm", std::vector<int64_t>{0, 2, 3, 1});
  };

  auto check_graph = [&](InferenceSessionWrapper& session) {
    const auto op_types_in_order = GetNodeOpTypesInTopologicalOrder(session.GetGraph());

    // Expected node sequence after the transformers have run.
    const std::vector<std::string> expected_op_types_in_order{"DequantizeLinear", "Transpose", "Transpose"};
    EXPECT_EQ(op_types_in_order, expected_op_types_in_order);
  };

  TransformerTester(build_test_case_1,
                    check_graph,
                    TransformerLevel::Default,
                    TransformerLevel::Level1,
                    /*opset_version*/ 10);
}

TEST(TransposeOptimizerTests, TestCast) {
auto build_test_case_1 = [&](ModelTestBuilder& builder) {
auto* input0_arg = MakeInput<int32_t>(builder, {{-1, 4, -1, 5}}, {2, 4, 6, 5}, -1, 5);
Expand Down

0 comments on commit 57730b2

Please sign in to comment.