// vit_tracker.cpp
// VitTracker
// model: https://github.com/opencv/opencv_zoo/tree/main/models/object_tracking_vittrack
  #include <cctype>
  #include <cmath>
  #include <iostream>
  #include <string>
  #include <opencv2/dnn.hpp>
  #include <opencv2/imgproc.hpp>
  #include <opencv2/highgui.hpp>
  #include <opencv2/video.hpp>
  using namespace cv;
  using namespace cv::dnn;
  // Command-line option table passed to cv::CommandLineParser in run().
  // Every "{ names | default | help }" group declares one option; the adjacent
  // string literals below are concatenated by the compiler into one specification.
  // The groups follow each other directly: the stray ',' that used to trail the
  // "backend" group has been removed so nothing sits between two option groups.
  const char *keys =
      "{ help h | | Print help message }"
      "{ input i | | Full path to input video folder, the specific camera index. (empty for camera 0) }"
      "{ net | vitTracker.onnx | Path to onnx model of vitTracker.onnx}"
      "{ tracking_score_threshold t | 0.3 | Tracking score threshold. If a bbox of score >= 0.3, it is considered as found }"
      // Numeric id of the DNN computation backend.
      "{ backend | 0 | Choose one of computation backends: "
          "0: automatically (by default), "
          "1: Halide language (http://halide-lang.org/), "
          "2: Intel's Deep Learning Inference Engine (https://software.intel.com/openvino-toolkit), "
          "3: OpenCV implementation, "
          "4: VKCOM, "
          "5: CUDA }"
      // Numeric id of the DNN target device.
      "{ target | 0 | Choose one of target computation devices: "
          "0: CPU target (by default), "
          "1: OpenCL, "
          "2: OpenCL fp16 (half-float precision), "
          "3: VPU, "
          "4: Vulkan, "
          "6: CUDA, "
          "7: CUDA fp16 (half-float preprocess) }"
  ;
  32. static
  33. int run(int argc, char** argv)
  34. {
  35. // Parse command line arguments.
  36. CommandLineParser parser(argc, argv, keys);
  37. if (parser.has("help"))
  38. {
  39. parser.printMessage();
  40. return 0;
  41. }
  42. std::string inputName = parser.get<String>("input");
  43. std::string net = parser.get<String>("net");
  44. int backend = parser.get<int>("backend");
  45. int target = parser.get<int>("target");
  46. float tracking_score_threshold = parser.get<float>("tracking_score_threshold");
  47. Ptr<TrackerVit> tracker;
  48. try
  49. {
  50. TrackerVit::Params params;
  51. params.net = samples::findFile(net);
  52. params.backend = backend;
  53. params.target = target;
  54. params.tracking_score_threshold = tracking_score_threshold;
  55. tracker = TrackerVit::create(params);
  56. }
  57. catch (const cv::Exception& ee)
  58. {
  59. std::cerr << "Exception: " << ee.what() << std::endl;
  60. std::cout << "Can't load the network by using the following files:" << std::endl;
  61. std::cout << "net : " << net << std::endl;
  62. return 2;
  63. }
  64. const std::string winName = "vitTracker";
  65. namedWindow(winName, WINDOW_AUTOSIZE);
  66. // Open a video file or an image file or a camera stream.
  67. VideoCapture cap;
  68. if (inputName.empty() || (isdigit(inputName[0]) && inputName.size() == 1))
  69. {
  70. int c = inputName.empty() ? 0 : inputName[0] - '0';
  71. std::cout << "Trying to open camera #" << c << " ..." << std::endl;
  72. if (!cap.open(c))
  73. {
  74. std::cout << "Capture from camera #" << c << " didn't work. Specify -i=<video> parameter to read from video file" << std::endl;
  75. return 2;
  76. }
  77. }
  78. else if (inputName.size())
  79. {
  80. inputName = samples::findFileOrKeep(inputName);
  81. if (!cap.open(inputName))
  82. {
  83. std::cout << "Could not open: " << inputName << std::endl;
  84. return 2;
  85. }
  86. }
  87. // Read the first image.
  88. Mat image;
  89. cap >> image;
  90. if (image.empty())
  91. {
  92. std::cerr << "Can't capture frame!" << std::endl;
  93. return 2;
  94. }
  95. Mat image_select = image.clone();
  96. putText(image_select, "Select initial bounding box you want to track.", Point(0, 15), FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0, 255, 0));
  97. putText(image_select, "And Press the ENTER key.", Point(0, 35), FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0, 255, 0));
  98. Rect selectRect = selectROI(winName, image_select);
  99. std::cout << "ROI=" << selectRect << std::endl;
  100. if (selectRect.empty())
  101. {
  102. std::cerr << "Invalid ROI!" << std::endl;
  103. return 2;
  104. }
  105. tracker->init(image, selectRect);
  106. TickMeter tickMeter;
  107. for (int count = 0; ; ++count)
  108. {
  109. cap >> image;
  110. if (image.empty())
  111. {
  112. std::cerr << "Can't capture frame " << count << ". End of video stream?" << std::endl;
  113. break;
  114. }
  115. Rect rect;
  116. tickMeter.start();
  117. bool ok = tracker->update(image, rect);
  118. tickMeter.stop();
  119. float score = tracker->getTrackingScore();
  120. std::cout << "frame " << count;
  121. if (ok) {
  122. std::cout << ": predicted score=" << score <<
  123. "\trect=" << rect <<
  124. "\ttime=" << tickMeter.getTimeMilli() << "ms" << std::endl;
  125. rectangle(image, rect, Scalar(0, 255, 0), 2);
  126. std::string timeLabel = format("Inference time: %.2f ms", tickMeter.getTimeMilli());
  127. std::string scoreLabel = format("Score: %f", score);
  128. putText(image, timeLabel, Point(0, 15), FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0, 255, 0));
  129. putText(image, scoreLabel, Point(0, 35), FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0, 255, 0));
  130. } else {
  131. std::cout << ": target lost" << std::endl;
  132. putText(image, "Target lost", Point(0, 15), FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0, 0, 255));
  133. }
  134. imshow(winName, image);
  135. tickMeter.reset();
  136. int c = waitKey(1);
  137. if (c == 27 /*ESC*/ || c == 'q' || c == 'Q')
  138. break;
  139. }
  140. std::cout << "Exit" << std::endl;
  141. return 0;
  142. }
  143. int main(int argc, char **argv)
  144. {
  145. try
  146. {
  147. return run(argc, argv);
  148. }
  149. catch (const std::exception& e)
  150. {
  151. std::cerr << "FATAL: C++ exception: " << e.what() << std::endl;
  152. return 1;
  153. }
  154. }