test_graph_simplifier.cpp 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. // This file is part of OpenCV project.
  2. // It is subject to the license terms in the LICENSE file found in the top-level directory
  3. // of this distribution and at http://opencv.org/license.html.
  4. #include "test_precomp.hpp"
  5. namespace opencv_test { namespace {
  6. class Test_Graph_Simplifier : public ::testing::Test {
  7. public:
  8. bool required;
  9. Test_Graph_Simplifier() : required(true) {}
  10. void test_conformance(const std::string &basename, const std::string &expected_layer) {
  11. test(basename + std::string("/model"), std::vector<std::string>{expected_layer}, std::string("dnn/onnx/conformance/node/"));
  12. }
  13. void test(const std::string &basename, const std::string &expected_layer) {
  14. test(basename, std::vector<std::string>{expected_layer});
  15. }
  16. void test(const std::string &basename, const std::vector<std::string> &expected_layers, const std::string &model_path_prefix = std::string("dnn/onnx/models/")) {
  17. std::string model_path = findDataFile(model_path_prefix + basename + std::string(".onnx"), required);
  18. auto net = readNet(model_path);
  19. std::vector<std::string> layers;
  20. net.getLayerTypes(layers);
  21. // remove Const, Identity (output layer), __NetInputLayer__ (input layer)
  22. layers.erase(std::remove_if(layers.begin(), layers.end(), [] (const std::string l) { return l == "Const" || l == "Identity" || l == "__NetInputLayer__"; }), layers.end());
  23. EXPECT_EQ(layers, expected_layers);
  24. }
  25. };
  26. TEST_F(Test_Graph_Simplifier, GeluSubGraph) {
  27. test("gelu", "Gelu");
  28. test("bias_gelu", std::vector<std::string>{"Gelu", "NaryEltwise"});
  29. }
  30. TEST_F(Test_Graph_Simplifier, GeluApproximationSubGraph) {
  31. test("gelu_approximation", "GeluApproximation");
  32. }
  33. TEST_F(Test_Graph_Simplifier, LayerNormSubGraph) {
  34. test("layer_norm_expanded", "LayerNormalization");
  35. test("layer_norm_expanded_with_initializers", "LayerNormalization");
  36. }
  37. TEST_F(Test_Graph_Simplifier, LayerNormNoFusionSubGraph) {
  38. test("layer_norm_no_fusion", std::vector<std::string>{"NaryEltwise", "Reduce", "Sqrt"});
  39. }
  40. TEST_F(Test_Graph_Simplifier, ResizeSubgraph) {
  41. /* Test for 6 subgraphs:
  42. - GatherCastSubgraph
  43. - MulCastSubgraph
  44. - UpsampleSubgraph
  45. - ResizeSubgraph1
  46. - ResizeSubgraph2
  47. - ResizeSubgraph3
  48. */
  49. test("upsample_unfused_torch1.2", std::vector<std::string>{"BatchNorm", "Resize"});
  50. test("resize_nearest_unfused_opset11_torch1.3", std::vector<std::string>{"BatchNorm", "Convolution", "Resize"});
  51. test("resize_nearest_unfused_opset11_torch1.4", std::vector<std::string>{"BatchNorm", "Convolution", "Resize"});
  52. test("upsample_unfused_opset9_torch1.4", std::vector<std::string>{"BatchNorm", "Convolution", "Resize"});
  53. test("two_resizes_with_shared_subgraphs", std::vector<std::string>{"NaryEltwise", "Resize"});
  54. }
  55. TEST_F(Test_Graph_Simplifier, SoftmaxSubgraph) {
  56. /* Test for 3 subgraphs
  57. - SoftMaxSubgraph
  58. - SoftMaxSubgraph2 (conformance)
  59. - LogSoftMaxSubgraph (conformance)
  60. */
  61. test("softmax_unfused", "Softmax");
  62. test_conformance("test_softmax_example_expanded", "Softmax");
  63. test_conformance("test_softmax_axis_2_expanded", "Softmax");
  64. test_conformance("test_softmax_default_axis_expanded", "Softmax");
  65. test_conformance("test_softmax_axis_0_expanded", "Softmax");
  66. test_conformance("test_softmax_axis_1_expanded", "Softmax");
  67. test_conformance("test_softmax_large_number_expanded", "Softmax");
  68. test_conformance("test_softmax_negative_axis_expanded", "Softmax");
  69. test_conformance("test_logsoftmax_axis_2_expanded", "Softmax");
  70. test_conformance("test_logsoftmax_example_1_expanded", "Softmax");
  71. test_conformance("test_logsoftmax_negative_axis_expanded", "Softmax");
  72. test_conformance("test_logsoftmax_axis_0_expanded", "Softmax");
  73. test_conformance("test_logsoftmax_axis_1_expanded", "Softmax");
  74. test_conformance("test_logsoftmax_large_number_expanded", "Softmax");
  75. test_conformance("test_logsoftmax_default_axis_expanded", "Softmax");
  76. }
  77. TEST_F(Test_Graph_Simplifier, HardSwishSubgraph) {
  78. test_conformance("test_hardswish_expanded", "HardSwish");
  79. }
  80. TEST_F(Test_Graph_Simplifier, CeluSubgraph) {
  81. test_conformance("test_celu_expanded", "Celu");
  82. }
  83. TEST_F(Test_Graph_Simplifier, NormalizeSubgraph) {
  84. /* Test for 6 subgraphs
  85. - NormalizeSubgraph1
  86. - NormalizeSubgraph2
  87. - NormalizeSubgraph2_2
  88. - NormalizeSubgraph3
  89. - NormalizeSubgraph4
  90. - NormalizeSubgraph5
  91. */
  92. test("reduceL2_subgraph_2", "Normalize");
  93. test("reduceL2_subgraph", "Normalize");
  94. test("normalize_fusion", "Normalize");
  95. }
  96. TEST_F(Test_Graph_Simplifier, BatchNormalizationSubgraph) {
  97. /* Test for 2 subgraphs
  98. - BatchNormalizationSubgraph1
  99. - BatchNormalizationSubgraph2
  100. */
  101. test("frozenBatchNorm2d", "BatchNorm");
  102. test("batch_norm_subgraph", "BatchNorm");
  103. }
  104. TEST_F(Test_Graph_Simplifier, ExpandSubgraph) {
  105. test("expand_neg_batch", "Expand");
  106. }
  107. TEST_F(Test_Graph_Simplifier, MishSubgraph) {
  108. /* Test for 2 subgraphs
  109. - SoftplusSubgraph
  110. - MishSubgraph
  111. */
  112. test("mish_no_softplus", "Mish");
  113. test("mish", "Mish");
  114. }
  115. TEST_F(Test_Graph_Simplifier, AttentionSubgraph) {
  116. /* Test for 2 subgraphs
  117. - AttentionSubgraph
  118. - AttentionSingleHeadSubgraph
  119. */
  120. test("attention", "Attention");
  121. test("attention_single_head", "Attention");
  122. }
  123. TEST_F(Test_Graph_Simplifier, BiasedMatMulSubgraph) {
  124. /* Test for 1 subgraphs
  125. - BiasedMatMulSubgraph
  126. */
  127. test("biased_matmul", "MatMul");
  128. }
  129. }}