test_knearest.cpp 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. // This file is part of OpenCV project.
  2. // It is subject to the license terms in the LICENSE file found in the top-level directory
  3. // of this distribution and at http://opencv.org/license.html.
  4. #include "test_precomp.hpp"
  5. namespace opencv_test { namespace {
  6. using cv::ml::TrainData;
  7. using cv::ml::EM;
  8. using cv::ml::KNearest;
  9. TEST(ML_KNearest, accuracy)
  10. {
  11. int sizesArr[] = { 500, 700, 800 };
  12. int pointsCount = sizesArr[0]+ sizesArr[1] + sizesArr[2];
  13. Mat trainData( pointsCount, 2, CV_32FC1 ), trainLabels;
  14. vector<int> sizes( sizesArr, sizesArr + sizeof(sizesArr) / sizeof(sizesArr[0]) );
  15. Mat means;
  16. vector<Mat> covs;
  17. defaultDistribs( means, covs );
  18. generateData( trainData, trainLabels, sizes, means, covs, CV_32FC1, CV_32FC1 );
  19. Mat testData( pointsCount, 2, CV_32FC1 );
  20. Mat testLabels;
  21. generateData( testData, testLabels, sizes, means, covs, CV_32FC1, CV_32FC1 );
  22. {
  23. SCOPED_TRACE("Default");
  24. Mat bestLabels;
  25. float err = 1000;
  26. Ptr<KNearest> knn = KNearest::create();
  27. knn->train(trainData, ml::ROW_SAMPLE, trainLabels);
  28. knn->findNearest(testData, 4, bestLabels);
  29. EXPECT_TRUE(calcErr( bestLabels, testLabels, sizes, err, true ));
  30. EXPECT_LE(err, 0.01f);
  31. }
  32. {
  33. SCOPED_TRACE("KDTree");
  34. Mat neighborIndexes;
  35. Mat neighborResponses;
  36. Mat dists;
  37. float err = 1000;
  38. Ptr<KNearest> knn = KNearest::create();
  39. knn->setAlgorithmType(KNearest::KDTREE);
  40. knn->train(trainData, ml::ROW_SAMPLE, trainLabels);
  41. knn->findNearest(testData, 4, neighborIndexes, neighborResponses, dists);
  42. EXPECT_EQ(neighborIndexes.size(), Size(4, pointsCount));
  43. EXPECT_EQ(neighborResponses.size(), Size(4, pointsCount * 2));
  44. EXPECT_EQ(dists.size(), Size(4, pointsCount));
  45. Mat bestLabels;
  46. // The output of the KDTree are the neighbor indexes, not actual class labels
  47. // so we need to do some extra work to get actual predictions
  48. for(int row_num = 0; row_num < neighborIndexes.rows; ++row_num){
  49. vector<float> labels;
  50. for(int index = 0; index < neighborIndexes.row(row_num).cols; ++index) {
  51. labels.push_back(trainLabels.at<float>(neighborIndexes.row(row_num).at<int>(0, index) , 0));
  52. }
  53. // computing the mode of the output class predictions to determine overall prediction
  54. std::vector<int> histogram(3,0);
  55. for( int i=0; i<3; ++i )
  56. ++histogram[ static_cast<int>(labels[i]) ];
  57. int bestLabel = static_cast<int>(std::max_element( histogram.begin(), histogram.end() ) - histogram.begin());
  58. bestLabels.push_back(bestLabel);
  59. }
  60. bestLabels.convertTo(bestLabels, testLabels.type());
  61. EXPECT_TRUE(calcErr( bestLabels, testLabels, sizes, err, true ));
  62. EXPECT_LE(err, 0.01f);
  63. }
  64. }
  65. TEST(ML_KNearest, regression_12347)
  66. {
  67. Mat xTrainData = (Mat_<float>(5,2) << 1, 1.1, 1.1, 1, 2, 2, 2.1, 2, 2.1, 2.1);
  68. Mat yTrainLabels = (Mat_<float>(5,1) << 1, 1, 2, 2, 2);
  69. Ptr<KNearest> knn = KNearest::create();
  70. knn->train(xTrainData, ml::ROW_SAMPLE, yTrainLabels);
  71. Mat xTestData = (Mat_<float>(2,2) << 1.1, 1.1, 2, 2.2);
  72. Mat zBestLabels, neighbours, dist;
  73. // check output shapes:
  74. int K = 16, Kexp = std::min(K, xTrainData.rows);
  75. knn->findNearest(xTestData, K, zBestLabels, neighbours, dist);
  76. EXPECT_EQ(xTestData.rows, zBestLabels.rows);
  77. EXPECT_EQ(neighbours.cols, Kexp);
  78. EXPECT_EQ(dist.cols, Kexp);
  79. // see if the result is still correct:
  80. K = 2;
  81. knn->findNearest(xTestData, K, zBestLabels, neighbours, dist);
  82. EXPECT_EQ(1, zBestLabels.at<float>(0,0));
  83. EXPECT_EQ(2, zBestLabels.at<float>(1,0));
  84. }
  85. TEST(ML_KNearest, bug_11877)
  86. {
  87. Mat trainData = (Mat_<float>(5,2) << 3, 3, 3, 3, 4, 4, 4, 4, 4, 4);
  88. Mat trainLabels = (Mat_<float>(5,1) << 0, 0, 1, 1, 1);
  89. Ptr<KNearest> knnKdt = KNearest::create();
  90. knnKdt->setAlgorithmType(KNearest::KDTREE);
  91. knnKdt->setIsClassifier(true);
  92. knnKdt->train(trainData, ml::ROW_SAMPLE, trainLabels);
  93. Mat testData = (Mat_<float>(2,2) << 3.1, 3.1, 4, 4.1);
  94. Mat testLabels = (Mat_<int>(2,1) << 0, 1);
  95. Mat result;
  96. knnKdt->findNearest(testData, 1, result);
  97. EXPECT_EQ(1, int(result.at<int>(0, 0)));
  98. EXPECT_EQ(2, int(result.at<int>(1, 0)));
  99. EXPECT_EQ(0, trainLabels.at<int>(result.at<int>(0, 0), 0));
  100. }
  101. }} // namespace