model_diagnostics.cpp 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
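/*************************************************
USAGE:
./model_diagnostics -m=<model file location> [-c=<config file location>] [-f=<framework name>]
    [--input0_name=<name> --input0_shape=<d0,d1,...>] ... [--input4_name=<name> --input4_shape=<d0,d1,...>]
Option values must be attached with '=' (cv::CommandLineParser syntax).
**************************************************/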
#include <cctype>
#include <cstdio>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

#include <opencv2/dnn.hpp>
#include <opencv2/core/utils/filesystem.hpp>
#include <opencv2/dnn/utils/debug_utils.hpp>
using namespace cv;
using namespace dnn;
  11. static
  12. int diagnosticsErrorCallback(int /*status*/, const char* /*func_name*/,
  13. const char* /*err_msg*/, const char* /*file_name*/,
  14. int /*line*/, void* /*userdata*/)
  15. {
  16. fflush(stdout);
  17. fflush(stderr);
  18. return 0;
  19. }
  20. static std::string checkFileExists(const std::string& fileName)
  21. {
  22. if (fileName.empty() || utils::fs::exists(fileName))
  23. return fileName;
  24. CV_Error(Error::StsObjectNotFound, "File " + fileName + " was not found! "
  25. "Please, specify a full path to the file.");
  26. }
  27. static std::vector<int> parseShape(const std::string &shape_str) {
  28. std::stringstream ss(shape_str);
  29. std::string item;
  30. std::vector<std::string> items;
  31. while (std::getline(ss, item, ',')) {
  32. items.push_back(item);
  33. }
  34. std::vector<int> shape;
  35. for (size_t i = 0; i < items.size(); i++) {
  36. shape.push_back(std::stoi(items[i]));
  37. }
  38. return shape;
  39. }
  40. std::string diagnosticKeys =
  41. "{ model m | | Path to the model file. }"
  42. "{ config c | | Path to the model configuration file. }"
  43. "{ framework f | | [Optional] Name of the model framework. }"
  44. "{ input0_name | | [Optional] Name of input0. Use with input0_shape}"
  45. "{ input0_shape | | [Optional] Shape of input0. Use with input0_name}"
  46. "{ input1_name | | [Optional] Name of input1. Use with input1_shape}"
  47. "{ input1_shape | | [Optional] Shape of input1. Use with input1_name}"
  48. "{ input2_name | | [Optional] Name of input2. Use with input2_shape}"
  49. "{ input2_shape | | [Optional] Shape of input2. Use with input2_name}"
  50. "{ input3_name | | [Optional] Name of input3. Use with input3_shape}"
  51. "{ input3_shape | | [Optional] Shape of input3. Use with input3_name}"
  52. "{ input4_name | | [Optional] Name of input4. Use with input4_shape}"
  53. "{ input4_shape | | [Optional] Shape of input4. Use with input4_name}";
  54. int main( int argc, const char** argv )
  55. {
  56. CommandLineParser argParser(argc, argv, diagnosticKeys);
  57. argParser.about("Use this tool to run the diagnostics of provided ONNX/TF model"
  58. "to obtain the information about its support (supported layers).");
  59. if (argc == 1)
  60. {
  61. argParser.printMessage();
  62. return 0;
  63. }
  64. std::string model = checkFileExists(argParser.get<std::string>("model"));
  65. std::string config = checkFileExists(argParser.get<std::string>("config"));
  66. std::string frameworkId = argParser.get<std::string>("framework");
  67. std::string input0_name = argParser.get<std::string>("input0_name");
  68. std::string input0_shape = argParser.get<std::string>("input0_shape");
  69. std::string input1_name = argParser.get<std::string>("input1_name");
  70. std::string input1_shape = argParser.get<std::string>("input1_shape");
  71. std::string input2_name = argParser.get<std::string>("input2_name");
  72. std::string input2_shape = argParser.get<std::string>("input2_shape");
  73. std::string input3_name = argParser.get<std::string>("input3_name");
  74. std::string input3_shape = argParser.get<std::string>("input3_shape");
  75. std::string input4_name = argParser.get<std::string>("input4_name");
  76. std::string input4_shape = argParser.get<std::string>("input4_shape");
  77. CV_Assert(!model.empty());
  78. enableModelDiagnostics(true);
  79. skipModelImport(true);
  80. redirectError(diagnosticsErrorCallback, NULL);
  81. Net ocvNet = readNet(model, config, frameworkId);
  82. std::vector<std::string> input_names;
  83. std::vector<std::vector<int>> input_shapes;
  84. if (!input0_name.empty() || !input0_shape.empty()) {
  85. CV_CheckFalse(input0_name.empty(), "input0_name cannot be empty");
  86. CV_CheckFalse(input0_shape.empty(), "input0_shape cannot be empty");
  87. input_names.push_back(input0_name);
  88. input_shapes.push_back(parseShape(input0_shape));
  89. }
  90. if (!input1_name.empty() || !input1_shape.empty()) {
  91. CV_CheckFalse(input1_name.empty(), "input1_name cannot be empty");
  92. CV_CheckFalse(input1_shape.empty(), "input1_shape cannot be empty");
  93. input_names.push_back(input1_name);
  94. input_shapes.push_back(parseShape(input1_shape));
  95. }
  96. if (!input2_name.empty() || !input2_shape.empty()) {
  97. CV_CheckFalse(input2_name.empty(), "input2_name cannot be empty");
  98. CV_CheckFalse(input2_shape.empty(), "input2_shape cannot be empty");
  99. input_names.push_back(input2_name);
  100. input_shapes.push_back(parseShape(input2_shape));
  101. }
  102. if (!input3_name.empty() || !input3_shape.empty()) {
  103. CV_CheckFalse(input3_name.empty(), "input3_name cannot be empty");
  104. CV_CheckFalse(input3_shape.empty(), "input3_shape cannot be empty");
  105. input_names.push_back(input3_name);
  106. input_shapes.push_back(parseShape(input3_shape));
  107. }
  108. if (!input4_name.empty() || !input4_shape.empty()) {
  109. CV_CheckFalse(input4_name.empty(), "input4_name cannot be empty");
  110. CV_CheckFalse(input4_shape.empty(), "input4_shape cannot be empty");
  111. input_names.push_back(input4_name);
  112. input_shapes.push_back(parseShape(input4_shape));
  113. }
  114. if (!input_names.empty() && !input_shapes.empty() && input_names.size() == input_shapes.size()) {
  115. ocvNet.setInputsNames(input_names);
  116. for (size_t i = 0; i < input_names.size(); i++) {
  117. Mat input(input_shapes[i], CV_32F);
  118. ocvNet.setInput(input, input_names[i]);
  119. }
  120. size_t dot_index = model.rfind('.');
  121. std::string graph_filename = model.substr(0, dot_index) + ".pbtxt";
  122. ocvNet.dumpToPbtxt(graph_filename);
  123. }
  124. return 0;
  125. }