test_trackers.cpp 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. // This file is part of OpenCV project.
  2. // It is subject to the license terms in the LICENSE file found in the top-level directory
  3. // of this distribution and at http://opencv.org/license.html.
  4. #include "test_precomp.hpp"
  5. //#define DEBUG_TEST
  6. #ifdef DEBUG_TEST
  7. #include <opencv2/highgui.hpp>
  8. #endif
  9. namespace opencv_test { namespace {
  10. //using namespace cv::tracking;
  11. #define TESTSET_NAMES testing::Values("david", "dudek", "faceocc2")
  12. const string TRACKING_DIR = "tracking";
  13. const string FOLDER_IMG = "data";
  14. const string FOLDER_OMIT_INIT = "initOmit";
  15. #include "test_trackers.impl.hpp"
  16. //[TESTDATA]
  17. PARAM_TEST_CASE(DistanceAndOverlap, string, int)
  18. {
  19. string dataset;
  20. int numFramesLimit;
  21. virtual void SetUp()
  22. {
  23. dataset = GET_PARAM(0);
  24. numFramesLimit = GET_PARAM(1);
  25. }
  26. };
  27. TEST_P(DistanceAndOverlap, MIL)
  28. {
  29. TrackerTest<Tracker, Rect> test(TrackerMIL::create(), dataset, 30, .65f, NoTransform);
  30. test.run(numFramesLimit);
  31. }
  32. TEST_P(DistanceAndOverlap, Shifted_Data_MIL)
  33. {
  34. TrackerTest<Tracker, Rect> test(TrackerMIL::create(), dataset, 30, .6f, CenterShiftLeft);
  35. test.run(numFramesLimit);
  36. }
  37. /***************************************************************************************/
  38. //Tests with scaled initial window
  39. TEST_P(DistanceAndOverlap, Scaled_Data_MIL)
  40. {
  41. TrackerTest<Tracker, Rect> test(TrackerMIL::create(), dataset, 30, .7f, Scale_1_1);
  42. test.run(numFramesLimit);
  43. }
  44. TEST_P(DistanceAndOverlap, GOTURN)
  45. {
  46. std::string model = cvtest::findDataFile("dnn/gsoc2016-goturn/goturn.prototxt");
  47. std::string weights = cvtest::findDataFile("dnn/gsoc2016-goturn/goturn.caffemodel", false);
  48. cv::TrackerGOTURN::Params params;
  49. params.modelTxt = model;
  50. params.modelBin = weights;
  51. TrackerTest<Tracker, Rect> test(TrackerGOTURN::create(params), dataset, 35, .35f, NoTransform);
  52. test.run(numFramesLimit);
  53. }
  54. INSTANTIATE_TEST_CASE_P(Tracking, DistanceAndOverlap,
  55. testing::Combine(
  56. TESTSET_NAMES,
  57. testing::Values(0)
  58. )
  59. );
  60. INSTANTIATE_TEST_CASE_P(Tracking5Frames, DistanceAndOverlap,
  61. testing::Combine(
  62. TESTSET_NAMES,
  63. testing::Values(5)
  64. )
  65. );
  66. static bool checkIOU(const Rect& r0, const Rect& r1, double threshold)
  67. {
  68. int interArea = (r0 & r1).area();
  69. double iouVal = (interArea * 1.0 )/ (r0.area() + r1.area() - interArea);;
  70. if (iouVal > threshold)
  71. return true;
  72. else
  73. {
  74. std::cout <<"Unmatched IOU: expect IOU val ("<<iouVal <<") > the IOU threadhold ("<<threshold<<")! Box 0 is "
  75. << r0 <<", and Box 1 is "<<r1<< std::endl;
  76. return false;
  77. }
  78. }
  79. static void checkTrackingAccuracy(cv::Ptr<Tracker>& tracker, double iouThreshold = 0.7)
  80. {
  81. // Template image
  82. Mat img0 = imread(findDataFile("tracking/bag/00000001.jpg"), 1);
  83. // Tracking image sequence.
  84. std::vector<Mat> imgs;
  85. imgs.push_back(imread(findDataFile("tracking/bag/00000002.jpg"), 1));
  86. imgs.push_back(imread(findDataFile("tracking/bag/00000003.jpg"), 1));
  87. imgs.push_back(imread(findDataFile("tracking/bag/00000004.jpg"), 1));
  88. imgs.push_back(imread(findDataFile("tracking/bag/00000005.jpg"), 1));
  89. imgs.push_back(imread(findDataFile("tracking/bag/00000006.jpg"), 1));
  90. cv::Rect roi(325, 164, 100, 100);
  91. std::vector<Rect> targetRois;
  92. targetRois.push_back(cv::Rect(278, 133, 99, 104));
  93. targetRois.push_back(cv::Rect(293, 88, 93, 110));
  94. targetRois.push_back(cv::Rect(287, 76, 89, 116));
  95. targetRois.push_back(cv::Rect(297, 74, 82, 122));
  96. targetRois.push_back(cv::Rect(311, 83, 78, 125));
  97. tracker->init(img0, roi);
  98. CV_Assert(targetRois.size() == imgs.size());
  99. for (int i = 0; i < (int)imgs.size(); i++)
  100. {
  101. bool res = tracker->update(imgs[i], roi);
  102. ASSERT_TRUE(res);
  103. ASSERT_TRUE(checkIOU(roi, targetRois[i], iouThreshold)) << cv::format("Fail at img %d.",i);
  104. }
  105. }
  106. TEST(GOTURN, accuracy)
  107. {
  108. std::string model = cvtest::findDataFile("dnn/gsoc2016-goturn/goturn.prototxt");
  109. std::string weights = cvtest::findDataFile("dnn/gsoc2016-goturn/goturn.caffemodel", false);
  110. cv::TrackerGOTURN::Params params;
  111. params.modelTxt = model;
  112. params.modelBin = weights;
  113. cv::Ptr<Tracker> tracker = TrackerGOTURN::create(params);
  114. // TODO! GOTURN have low accuracy. Try to remove this api at 5.x.
  115. checkTrackingAccuracy(tracker, 0.08);
  116. }
  117. TEST(DaSiamRPN, accuracy)
  118. {
  119. std::string model = cvtest::findDataFile("dnn/onnx/models/dasiamrpn_model.onnx", false);
  120. std::string kernel_r1 = cvtest::findDataFile("dnn/onnx/models/dasiamrpn_kernel_r1.onnx", false);
  121. std::string kernel_cls1 = cvtest::findDataFile("dnn/onnx/models/dasiamrpn_kernel_cls1.onnx", false);
  122. cv::TrackerDaSiamRPN::Params params;
  123. params.model = model;
  124. params.kernel_r1 = kernel_r1;
  125. params.kernel_cls1 = kernel_cls1;
  126. cv::Ptr<Tracker> tracker = TrackerDaSiamRPN::create(params);
  127. checkTrackingAccuracy(tracker, 0.7);
  128. }
  129. TEST(NanoTrack, accuracy_NanoTrack_V1)
  130. {
  131. std::string backbonePath = cvtest::findDataFile("dnn/onnx/models/nanotrack_backbone_sim.onnx", false);
  132. std::string neckheadPath = cvtest::findDataFile("dnn/onnx/models/nanotrack_head_sim.onnx", false);
  133. cv::TrackerNano::Params params;
  134. params.backbone = backbonePath;
  135. params.neckhead = neckheadPath;
  136. cv::Ptr<Tracker> tracker = TrackerNano::create(params);
  137. checkTrackingAccuracy(tracker);
  138. }
  139. TEST(NanoTrack, accuracy_NanoTrack_V2)
  140. {
  141. std::string backbonePath = cvtest::findDataFile("dnn/onnx/models/nanotrack_backbone_sim_v2.onnx", false);
  142. std::string neckheadPath = cvtest::findDataFile("dnn/onnx/models/nanotrack_head_sim_v2.onnx", false);
  143. cv::TrackerNano::Params params;
  144. params.backbone = backbonePath;
  145. params.neckhead = neckheadPath;
  146. cv::Ptr<Tracker> tracker = TrackerNano::create(params);
  147. checkTrackingAccuracy(tracker, 0.69);
  148. }
  149. TEST(vittrack, accuracy_vittrack)
  150. {
  151. std::string model = cvtest::findDataFile("dnn/onnx/models/vitTracker.onnx");
  152. cv::TrackerVit::Params params;
  153. params.net = model;
  154. cv::Ptr<Tracker> tracker = TrackerVit::create(params);
  155. checkTrackingAccuracy(tracker, 0.64);
  156. }
  157. }} // namespace opencv_test::