test_emd.cpp 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. // This file is part of OpenCV project.
  2. // It is subject to the license terms in the LICENSE file found in the top-level directory
  3. // of this distribution and at http://opencv.org/license.html
  4. #include "opencv2/imgproc.hpp"
  5. #include "test_precomp.hpp"
  6. using namespace cv;
  7. using namespace std;
  8. namespace opencv_test { namespace {
  9. //==============================================================================
  10. // Utility
  11. template <typename T>
  12. inline T sqr(T val)
  13. {
  14. return val * val;
  15. }
  16. inline static float calcEMD(Mat w1, Mat w2, Mat& flow, int dist, int dims)
  17. {
  18. float mass1 = 0.f, mass2 = 0.f, work = 0.f;
  19. for (int i = 0; i < flow.rows; ++i)
  20. {
  21. mass1 += w1.at<float>(i, 0);
  22. for (int j = 0; j < flow.cols; ++j)
  23. {
  24. if (i == 0)
  25. mass2 += w2.at<float>(j, 0);
  26. float dist_ = 0.f;
  27. switch (dist)
  28. {
  29. case DIST_L1:
  30. {
  31. for (int k = 1; k <= dims; ++k)
  32. {
  33. dist_ += abs(w1.at<float>(i, k) - w2.at<float>(j, k));
  34. }
  35. break;
  36. }
  37. case DIST_L2:
  38. {
  39. for (int k = 1; k <= dims; ++k)
  40. {
  41. dist_ += sqr(w1.at<float>(i, k) - w2.at<float>(j, k));
  42. }
  43. dist_ = sqrt(dist_);
  44. break;
  45. }
  46. case DIST_C:
  47. {
  48. for (int k = 1; k <= dims; ++k)
  49. {
  50. const float val = abs(w1.at<float>(i, k) - w2.at<float>(j, k));
  51. if (val > dist_)
  52. dist_ = val;
  53. }
  54. break;
  55. }
  56. }
  57. const float weight = flow.at<float>(i, j);
  58. work += dist_ * weight;
  59. }
  60. }
  61. return work / max(mass1, mass2);
  62. }
  63. //==============================================================================
// Regression check of cv::EMD with a user-supplied cost matrix (DIST_USER) on a
// hand-solved transportation problem: 4 source weights vs 5 target weights,
// both summing to 210.
TEST(Imgproc_EMD, regression)
{
    // input data
    // M is a prohibitively large cost (all other costs are <= 23); the expected
    // flow below moves no mass along any M-cost route.
    const float M = 10000;
    // 0-dimensional signatures: weights only, distances come from `cost`.
    Matx<float, 4, 1> w1 {50, 60, 50, 50};
    Matx<float, 5, 1> w2 {30, 20, 70, 30, 60};
    // cost(i, j): cost of moving a unit of mass from w1 row i to w2 row j.
    Matx<float, 4, 5> cost {16, 16, 13, 22, 17, 14, 14, 13, 19, 15,
                            19, 19, 20, 23, M, M, 0, M, 0, 0};
    // expected results
    // Total work of flow0 is 50*13 + (20*13 + 40*15) + (30*19 + 20*19) + 0 = 2460,
    // normalized by the total weight 210.
    const double emd0 = 2460. / 210;
    // Optimal flow: row sums equal w1, column sums equal w2.
    Matx<float, 4, 5> flow0 {0, 0, 50, 0, 0, 0, 0, 20, 0, 40, 30, 20, 0, 0, 0, 0, 0, 0, 30, 20};
    // basic call with cost
    {
        float emd = 0.f;
        ASSERT_NO_THROW(emd = EMD(w1, w2, DIST_USER, cost));
        // relative tolerance: emd is float, emd0 is double
        EXPECT_NEAR(emd, emd0, 1e-6 * emd0);
    }
    // basic call with cost and flow output
    {
        Mat flow;
        float emd = 0.f;
        ASSERT_NO_THROW(emd = EMD(w1, w2, DIST_USER, cost, nullptr, flow));
        EXPECT_NEAR(emd, emd0, 1e-6 * emd0);
        EXPECT_MAT_NEAR(Mat(flow0), flow, 1e-6);
    }
    // no cost and DIST_USER - error
    // DIST_USER requires a cost matrix; both calls (with and without a flow
    // output) must throw.
    {
        Mat flow;
        EXPECT_THROW(EMD(w1, w2, DIST_USER, noArray(), nullptr, flow), cv::Exception);
        EXPECT_THROW(EMD(w1, w2, DIST_USER), cv::Exception);
    }
}
// For every built-in metric, the value returned by cv::EMD must match the
// reference value recomputed by calcEMD from the flow that cv::EMD returned,
// for both 1-D and 2-D point signatures.
TEST(Imgproc_EMD, distance_types)
{
    // Signature layout: column 0 is the weight, remaining columns are coordinates.
    // 1D (sum = 210)
    Matx<float, 4, 2> w1 {50, 1, 60, 2, 50, 3, 50, 4};
    Matx<float, 5, 2> w2 {30, 1, 20, 2, 70, 3, 30, 4, 60, 5};
    // 2D (sum = 210)
    Matx<float, 4, 3> w3 {50, 0, 0, 60, 0, 1, 50, 1, 0, 50, 1, 1};
    Matx<float, 5, 3> w4 {20, 0, 1, 70, 1, 0, 30, 1, 1, 60, 2, 2, 30, 3, 3};
    // basic call with all distance types
    {
        const vector<DistanceTypes> good_types {DIST_L1, DIST_L2, DIST_C};
        for (const auto& dt : good_types)
        {
            SCOPED_TRACE(cv::format("dt=%d", dt));
            float emd = 0.f;
            // reused for both cases; each EMD call writes a fresh flow
            Mat flow;
            // 1D
            {
                ASSERT_NO_THROW(emd = EMD(w1, w2, dt, noArray(), nullptr, flow));
                const float emd0 = calcEMD(Mat(w1), Mat(w2), flow, dt, 1);
                EXPECT_NEAR(emd0, emd, 1e-6);
            }
            // 2D
            {
                ASSERT_NO_THROW(emd = EMD(w3, w4, dt, noArray(), nullptr, flow));
                const float emd0 = calcEMD(Mat(w3), Mat(w4), flow, dt, 2);
                EXPECT_NEAR(emd0, emd, 1e-6);
            }
        }
    }
}
  127. typedef testing::TestWithParam<int> Imgproc_EMD_dist;
  128. TEST_P(Imgproc_EMD_dist, random_flow_verify)
  129. {
  130. const int dist = GetParam();
  131. for (size_t iter = 0; iter < 100; ++iter)
  132. {
  133. SCOPED_TRACE(cv::format("iter=%zu", iter));
  134. RNG& rng = TS::ptr()->get_rng();
  135. const int dims = rng.uniform(1, 10);
  136. Mat w1(rng.uniform(1, 10), dims + 1, CV_32FC1);
  137. Mat w2(rng.uniform(1, 10), dims + 1, CV_32FC1);
  138. // weights > 0
  139. {
  140. Mat w1_weights = w1.col(0);
  141. Mat w2_weights = w2.col(0);
  142. cvtest::randUni(rng, w1_weights, 0, 100);
  143. cvtest::randUni(rng, w2_weights, 0, 100);
  144. }
  145. // coord
  146. {
  147. Mat w1_coord = w1.colRange(1, dims + 1);
  148. Mat w2_coord = w2.colRange(1, dims + 1);
  149. cvtest::randUni(rng, w1_coord, -10, +10);
  150. cvtest::randUni(rng, w2_coord, -10, +10);
  151. }
  152. float emd1 = 0.f, emd2 = 0.f;
  153. const float eps = 1e-5f;
  154. Mat flow;
  155. {
  156. ASSERT_NO_THROW(emd1 = EMD(w1, w2, dist, noArray(), nullptr, flow));
  157. const float emd0 = calcEMD(w1, w2, flow, dist, dims);
  158. EXPECT_NEAR(emd0, emd1, eps);
  159. }
  160. {
  161. ASSERT_NO_THROW(emd2 = EMD(w2, w1, dist, noArray(), nullptr, flow));
  162. const float emd0 = calcEMD(w2, w1, flow, dist, dims);
  163. EXPECT_NEAR(emd0, emd2, eps);
  164. }
  165. EXPECT_NEAR(emd1, emd2, eps);
  166. }
  167. }
  168. INSTANTIATE_TEST_CASE_P(, Imgproc_EMD_dist, testing::Values(DIST_L1, DIST_L2, DIST_C));
  169. TEST(Imgproc_EMD, invalid)
  170. {
  171. Matx<float, 4, 2> w1 {50, 1, 60, 2, 50, 3, 50, 4};
  172. Matx<float, 5, 2> w2 {30, 1, 20, 2, 70, 3, 30, 4, 60, 5};
  173. // empty signature
  174. {
  175. Mat empty;
  176. EXPECT_THROW(EMD(empty, w2, DIST_USER), cv::Exception);
  177. EXPECT_THROW(EMD(w1, empty, DIST_USER), cv::Exception);
  178. }
  179. // zero total weight, negative weight
  180. {
  181. Matx<float, 3, 1> wz {0, 0, 0};
  182. Matx<float, 3, 2> wz1 {0, 1, 0, 2, 0, 3};
  183. Matx<float, 3, 1> wn {0, 3, -2};
  184. Matx<float, 3, 2> wn1 {0, 1, 3, 2, -2, 3};
  185. EXPECT_THROW(EMD(wz, w2, DIST_USER), cv::Exception);
  186. EXPECT_THROW(EMD(wz1, w2, DIST_USER), cv::Exception);
  187. EXPECT_THROW(EMD(wn, w2, DIST_USER), cv::Exception);
  188. EXPECT_THROW(EMD(wn1, w2, DIST_USER), cv::Exception);
  189. }
  190. // user distance type, but no cost matrix provided or is wrong
  191. {
  192. Mat cost(3, 3, CV_32FC1, Scalar::all(0)), cost8u(4, 5, CV_8UC1, Scalar::all(0)), empty;
  193. EXPECT_THROW(EMD(w1, w2, DIST_USER, noArray()), cv::Exception);
  194. EXPECT_THROW(EMD(w1, w2, DIST_USER, empty), cv::Exception);
  195. EXPECT_THROW(EMD(w1, w2, DIST_USER, cost8u), cv::Exception);
  196. EXPECT_THROW(EMD(w1, w2, DIST_USER, cost), cv::Exception);
  197. }
  198. // lower_bound is set together with cost
  199. {
  200. Mat cost(4, 5, CV_32FC1, Scalar::all(0));
  201. float bound = 0.f;
  202. EXPECT_THROW(EMD(w1, w2, DIST_USER, cost, &bound), cv::Exception);
  203. }
  204. // zero dimensions with non-user distance type
  205. const vector<DistanceTypes> good_types {DIST_L1, DIST_L2, DIST_C};
  206. for (const auto& dt : good_types)
  207. {
  208. SCOPED_TRACE(cv::format("dt=%d", dt));
  209. Matx<float, 4, 1> w01 {20, 30, 40, 50};
  210. Matx<float, 5, 1> w02 {20, 30, 40, 50, 10};
  211. EXPECT_THROW(EMD(w01, w02, dt), cv::Exception);
  212. }
  213. // wrong distance type
  214. const vector<DistanceTypes> bad_types {DIST_L12, DIST_FAIR, DIST_WELSCH, DIST_HUBER};
  215. for (const auto& dt : bad_types)
  216. {
  217. SCOPED_TRACE(cv::format("dt=%d", dt));
  218. EXPECT_THROW(EMD(w1, w2, dt), cv::Exception);
  219. }
  220. }
  221. }} // namespace opencv_test