// KMeans.cs (6.3 KB)
  1. using o0;
  2. using o0.Num;
  3. using System;
  4. using System.Collections.Generic;
  5. using System.Linq;
  6. using System.Threading.Tasks;
  7. using UnityEngine;
  8. using static ScreenLocate;
  9. namespace ZIM
  10. {
  /// <summary>
  /// K-means clustering of infrared spot pixels. Several cluster counts are tried and the
  /// one with the highest mean silhouette coefficient is kept.
  /// </summary>
  public class KMeans
  {
      /// <summary>
      /// Random source used to draw the initial centroids. Assign a seeded instance to get
      /// reproducible clustering.
      /// </summary>
      public System.Random Rand = new System.Random();

      /// <summary>
      /// Clusters <paramref name="spotPoint"/> with k-means for every k from 1 up to
      /// <paramref name="spotMaxCount"/> (capped at the number of points) and returns the
      /// clustering whose silhouette score is highest. Ties keep the smallest k.
      /// </summary>
      /// <param name="spotPoint">Points that drive the clustering.</param>
      /// <param name="brightPoint">
      /// Secondary ("bloom") points; after clustering each one is handed to the cluster with the
      /// nearest centroid through <c>Add1</c>.
      /// </param>
      /// <param name="spotMaxCount">Largest number of clusters to try.</param>
      /// <param name="PointFilter">
      /// Optional predicate; when given, only points for which it returns true are kept
      /// (the original author used it to discard points that lie outside the screen).
      /// </param>
      /// <returns>
      /// The selected clusters, or an empty list when no point survives the filter or
      /// <paramref name="spotMaxCount"/> is not positive.
      /// </returns>
      /// <exception cref="ArgumentNullException">
      /// <paramref name="spotPoint"/> or <paramref name="brightPoint"/> is null.
      /// </exception>
      public List<PixelSpotAreaOld> InfraredSpotCluster(List<Vector2> spotPoint, List<Vector2> brightPoint, int spotMaxCount, Func<Vector2, bool> PointFilter = null)
      {
          if (spotPoint == null)
              throw new ArgumentNullException(nameof(spotPoint));
          if (brightPoint == null)
              throw new ArgumentNullException(nameof(brightPoint));

          if (PointFilter != null)
          {
              // Work on filtered copies so the caller's lists are never modified.
              spotPoint = spotPoint.Where(PointFilter).ToList();
              brightPoint = brightPoint.Where(PointFilter).ToList();
          }

          // k-means cannot build more clusters than there are points, so larger k are skipped.
          int maxK = Mathf.Min(spotMaxCount, spotPoint.Count);
          if (maxK <= 0)
              return new List<PixelSpotAreaOld>();

          // Cluster once per candidate k and remember the candidate with the best score.
          List<PixelSpotAreaOld> bestClusters = null;
          float bestScore = float.NegativeInfinity;
          for (int k = 1; k <= maxK; k++)
          {
              var clusters = Cluster(spotPoint, brightPoint, k);
              Debug.Log("k: " + k);
              foreach (var c in clusters)
              {
                  Debug.Log(c.Centroid + ", " + c.Radius);
              }

              float score = CalculateSilhouetteScore(spotPoint, clusters);
              Debug.Log("silhouetteScore: " + score);

              // Strict '>' keeps the smallest k on ties and never lets a NaN score win.
              if (bestClusters == null || score > bestScore)
              {
                  bestClusters = clusters;
                  bestScore = score;
              }
          }

          foreach (var c in bestClusters)
          {
              Debug.Log("select: " + c.Centroid + ", " + c.Radius);
          }
          return bestClusters;
      }

      /// <summary>
      /// Computes the mean silhouette coefficient of <paramref name="points"/> for the given
      /// clustering. For each point, a is the mean distance to the other members of its own
      /// cluster, b is the smallest mean distance to the members of any other non-empty
      /// cluster, and the point's score is (b - a) / max(a, b).
      /// </summary>
      /// <param name="points">Points to score; each must appear in some cluster's <c>Pixels0</c>.</param>
      /// <param name="clusters">Clusters whose <c>Pixels0</c> collections hold the points.</param>
      /// <returns>
      /// The mean score in [-1, 1]. Returns 0 when there are no points or fewer than two
      /// clusters, because the coefficient is undefined there; a neutral 0 keeps a single
      /// cluster from always looking like the best choice.
      /// </returns>
      /// <exception cref="ArgumentNullException">An argument is null.</exception>
      /// <exception cref="KeyNotFoundException">A point is not contained in any cluster.</exception>
      public float CalculateSilhouetteScore(List<Vector2> points, List<PixelSpotAreaOld> clusters)
      {
          if (points == null)
              throw new ArgumentNullException(nameof(points));
          if (clusters == null)
              throw new ArgumentNullException(nameof(clusters));
          if (points.Count == 0 || clusters.Count < 2)
              return 0f;

          // Reverse lookup: point -> index of the cluster that holds it.
          var pointToClusterMap = new Dictionary<Vector2, int>();
          for (int i = 0; i < clusters.Count; i++)
          {
              foreach (var p in clusters[i].Pixels0)
                  pointToClusterMap[p] = i;
          }

          float totalSilhouetteScore = 0f;
          foreach (var point in points)
          {
              int ownIndex = pointToClusterMap[point];
              int ownSize = clusters[ownIndex].Pixels0.Count;
              if (ownSize <= 1)
                  continue; // a point alone in its cluster scores 0 by convention

              // The point's distance to itself is 0, so summing over the whole cluster and
              // dividing by (size - 1) yields the mean distance to the *other* members.
              float a = TotalDistance(point, clusters[ownIndex]) / (ownSize - 1);

              float b = float.MaxValue;
              bool hasOtherCluster = false;
              for (int i = 0; i < clusters.Count; i++)
              {
                  if (i == ownIndex || clusters[i].Pixels0.Count == 0)
                      continue;
                  hasOtherCluster = true;
                  float distance = AverageDistance(point, clusters[i]);
                  if (distance < b)
                      b = distance;
              }
              if (!hasOtherCluster)
                  continue; // nothing to compare against: score 0

              float denominator = Mathf.Max(a, b);
              if (denominator > 0f) // a == b == 0 would otherwise produce NaN
                  totalSilhouetteScore += (b - a) / denominator;
          }
          return totalSilhouetteScore / points.Count;
      }

      /// <summary>Sum of the distances from <paramref name="point"/> to every entry of the cluster's <c>Pixels0</c>.</summary>
      private float TotalDistance(Vector2 point, PixelSpotAreaOld cluster)
      {
          float totalDistance = 0f;
          foreach (var c in cluster.Pixels0)
              totalDistance += Vector2.Distance(point, c);
          return totalDistance;
      }

      /// <summary>
      /// Mean distance from <paramref name="point"/> to every entry of the cluster's
      /// <c>Pixels0</c>; 0 when the cluster holds no points.
      /// </summary>
      private float AverageDistance(Vector2 point, PixelSpotAreaOld cluster)
      {
          int count = cluster.Pixels0.Count;
          if (count == 0)
              return 0f;
          return TotalDistance(point, cluster) / count;
      }

      /// <summary>
      /// Runs k-means on <paramref name="spotPoint"/> and then attaches every point of
      /// <paramref name="brightPoint"/> to the cluster with the nearest centroid.
      /// </summary>
      /// <param name="spotPoint">Points that are clustered.</param>
      /// <param name="brightPoint">Points attached after clustering through <c>Add1</c>.</param>
      /// <param name="k">Number of clusters to build.</param>
      /// <param name="maxIterations">Upper bound on assignment/update rounds.</param>
      /// <returns>
      /// The k clusters, or an empty list when <paramref name="k"/> is not positive or exceeds
      /// the number of points.
      /// </returns>
      private List<PixelSpotAreaOld> Cluster(List<Vector2> spotPoint, List<Vector2> brightPoint, int k, int maxIterations = 10)
      {
          if (k <= 0 || spotPoint.Count < k)
              return new List<PixelSpotAreaOld>();

          var clusters = InitializeClusters(spotPoint, k);

          bool centroidsChanged = true;
          for (int iteration = 0; centroidsChanged && iteration < maxIterations; iteration++)
          {
              // Assignment step: empty every cluster, then give each point to its nearest centroid.
              foreach (var cluster in clusters)
                  cluster.Clear();
              foreach (var point in spotPoint)
                  clusters[GetClosestClusterIndex(point, clusters)].Add0(point);

              // Update step: every cluster must be updated, so the result is accumulated
              // instead of stopping at the first cluster that reports a change.
              centroidsChanged = false;
              foreach (var cluster in clusters)
              {
                  if (cluster.UpdateCentroid())
                      centroidsChanged = true;
              }
          }

          // Clustering finished: attach the bloom points to their nearest cluster.
          foreach (var point in brightPoint)
              clusters[GetClosestClusterIndex(point, clusters)].Add1(point);

          return clusters;
      }

      /// <summary>
      /// Creates <paramref name="k"/> clusters, each seeded with a point drawn at random
      /// (distinct indices) from <paramref name="points"/>.
      /// </summary>
      /// <exception cref="ArgumentOutOfRangeException">
      /// <paramref name="k"/> exceeds the number of points, so k distinct indices do not exist.
      /// </exception>
      private List<PixelSpotAreaOld> InitializeClusters(List<Vector2> points, int k)
      {
          if (k <= 0)
              return new List<PixelSpotAreaOld>();
          if (k > points.Count)
              throw new ArgumentOutOfRangeException(nameof(k), "k must not exceed the number of points.");

          var chosenIndices = new HashSet<int>();
          var clusters = new List<PixelSpotAreaOld>(k);
          while (clusters.Count < k)
          {
              // HashSet.Add returns false for an index that was already drawn; draw again then.
              int index = Rand.Next(points.Count);
              if (chosenIndices.Add(index))
                  clusters.Add(new PixelSpotAreaOld(points[index]));
          }
          return clusters;
      }

      /// <summary>
      /// Returns the index of the cluster whose <c>Centroid</c> is nearest to
      /// <paramref name="point"/>; the lowest index wins on ties.
      /// <paramref name="centroids"/> must contain at least one cluster.
      /// </summary>
      private int GetClosestClusterIndex(Vector2 point, List<PixelSpotAreaOld> centroids)
      {
          int closestIndex = 0;
          float minDistance = Vector2.Distance(point, centroids[0].Centroid);
          for (int i = 1; i < centroids.Count; i++)
          {
              float distance = Vector2.Distance(point, centroids[i].Centroid);
              if (distance < minDistance)
              {
                  minDistance = distance;
                  closestIndex = i;
              }
          }
          return closestIndex;
      }
  }
  162. }