DbscanAlgorithm.cs 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222
  using System;
  using System.Collections.Generic;
  using System.Linq;
  using System.Threading;
  using System.Threading.Tasks;
  using DbscanImplementation.Eventing;
  using DbscanImplementation.ResultBuilding;
  namespace DbscanImplementation
  {
      /// <summary>
      /// DBSCAN (density-based spatial clustering of applications with noise) algorithm implementation.
      /// </summary>
      /// <typeparam name="TF">Type of a dataset item (row / feature vector) to be clustered.</typeparam>
      public class DbscanAlgorithm<TF>
      {
          /// <summary>
          /// Distance metric between two features.
          /// </summary>
          public readonly Func<TF, TF, double> MetricFunction;

          /// <summary>
          /// Curried neighborhood test: <c>RegionQueryPredicate(mainFeature, epsilon)</c> returns a predicate that is
          /// true for every point whose <see cref="MetricFunction"/> distance to <c>mainFeature</c> is at most
          /// <c>epsilon</c>.
          /// </summary>
          public readonly Func<TF, double, Func<DbscanPoint<TF>, bool>> RegionQueryPredicate;

          /// <summary>Receiver of the progress events raised during a computation.</summary>
          private readonly IDbscanEventPublisher publisher;

          /// <summary>
          /// Creates an instance that measures distances between two <typeparamref name="TF"/> values with
          /// <paramref name="metricFunc"/> and sends events to an <c>EmptyDbscanEventPublisher</c>
          /// (presumably a no-op sink).
          /// </summary>
          /// <param name="metricFunc">Distance metric; must not be null.</param>
          /// <exception cref="ArgumentNullException"><paramref name="metricFunc"/> is null.</exception>
          public DbscanAlgorithm(Func<TF, TF, double> metricFunc)
          {
              MetricFunction = metricFunc ?? throw new ArgumentNullException(nameof(metricFunc));
              RegionQueryPredicate =
                  (mainFeature, epsilon) => relatedPoint => MetricFunction(mainFeature, relatedPoint.Feature) <= epsilon;
              publisher = new EmptyDbscanEventPublisher();
          }

          /// <summary>
          /// Creates an instance that measures distances with <paramref name="metricFunc"/> and reports
          /// progress events to <paramref name="publisher"/>.
          /// </summary>
          /// <param name="metricFunc">Distance metric; must not be null.</param>
          /// <param name="publisher">Receiver of the computation events; must not be null.</param>
          /// <exception cref="ArgumentNullException">Either argument is null.</exception>
          public DbscanAlgorithm(Func<TF, TF, double> metricFunc, IDbscanEventPublisher publisher)
              : this(metricFunc)
          {
              this.publisher = publisher ?? throw new ArgumentNullException(nameof(publisher));
          }

          /// <summary>
          /// Runs the same computation as <see cref="ComputeClusterDbscan"/> on a long-running background task.
          /// </summary>
          /// <param name="allPoints">Feature set to cluster.</param>
          /// <param name="epsilon">Region ball radius; must be greater than zero.</param>
          /// <param name="minimumPoints">Minimum neighborhood size for a core point; must be greater than zero.</param>
          /// <param name="cancellationToken">
          /// Checked before the task starts and before every point visit during the computation.
          /// </param>
          /// <returns>
          /// A task producing the clustering result. Argument errors fault the task; cancellation cancels it.
          /// </returns>
          public Task<DbscanResult<TF>> ComputeClusterDbscanAsync(TF[] allPoints, double epsilon, int minimumPoints,
              CancellationToken cancellationToken)
          {
              // TaskScheduler.Default: TaskScheduler.Current would inherit whatever scheduler the caller is
              // running on (e.g. a UI scheduler), which is not meant for long CPU-bound work.
              return Task.Factory.StartNew(
                  () => Compute(allPoints, epsilon, minimumPoints, cancellationToken),
                  cancellationToken,
                  TaskCreationOptions.LongRunning,
                  TaskScheduler.Default);
          }

          /// <summary>
          /// Performs the DBSCAN clustering algorithm.
          /// </summary>
          /// <param name="allPoints">Feature set to cluster.</param>
          /// <param name="epsilon">Region ball radius; must be greater than zero.</param>
          /// <param name="minimumPoints">
          /// Minimum number of points a region query must return for a point to be core. The region returned by
          /// <see cref="RegionQuery"/> includes the queried point itself whenever its distance to itself is within
          /// <paramref name="epsilon"/>.
          /// </param>
          /// <returns>Overall result of the cluster compute operation.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="allPoints"/> is null.</exception>
          /// <exception cref="ArgumentOutOfRangeException">
          /// <paramref name="epsilon"/> is not greater than zero (including NaN), or
          /// <paramref name="minimumPoints"/> is not greater than zero.
          /// </exception>
          public DbscanResult<TF> ComputeClusterDbscan(TF[] allPoints, double epsilon, int minimumPoints)
          {
              return Compute(allPoints, epsilon, minimumPoints, CancellationToken.None);
          }

          /// <summary>
          /// Shared implementation of the sync and async entry points; observes <paramref name="cancellationToken"/>
          /// between point visits.
          /// </summary>
          private DbscanResult<TF> Compute(TF[] allPoints, double epsilon, int minimumPoints,
              CancellationToken cancellationToken)
          {
              if (allPoints is null)
              {
                  throw new ArgumentNullException(nameof(allPoints));
              }
              // Written as !(epsilon > 0) so NaN is rejected as well; with NaN every distance comparison is
              // false and all points would silently end up as noise.
              if (!(epsilon > 0))
              {
                  throw new ArgumentOutOfRangeException(nameof(epsilon), "Must be greater than zero");
              }
              if (minimumPoints <= 0)
              {
                  throw new ArgumentOutOfRangeException(nameof(minimumPoints), "Must be greater than zero");
              }

              var allPointsDbscan = allPoints.Select(x => new DbscanPoint<TF>(x)).ToArray();
              int clusterId = 0;
              var computeId = Guid.NewGuid();
              publisher.Publish(new ComputeStarted(computeId));

              foreach (var currentPoint in allPointsDbscan)
              {
                  cancellationToken.ThrowIfCancellationRequested();

                  // Points labelled earlier (absorbed into a cluster or marked noise) are not revisited here.
                  if (currentPoint.PointType.HasValue)
                  {
                      publisher.Publish(new PointAlreadyProcessed<TF>(currentPoint));
                      continue;
                  }

                  publisher.Publish(new PointProcessStarted<TF>(currentPoint));
                  publisher.Publish(new RegionQueryStarted<TF>(currentPoint, epsilon, minimumPoints));
                  var neighborPoints = RegionQuery(allPointsDbscan, currentPoint.Feature, epsilon);
                  publisher.Publish(new RegionQueryFinished<TF>(currentPoint, neighborPoints));

                  if (neighborPoints.Length < minimumPoints)
                  {
                      // Tentative label: ExpandCluster relabels noise reached from a core point as border.
                      currentPoint.PointType = PointType.Noise;
                      publisher.Publish(new PointTypeAssigned<TF>(currentPoint, PointType.Noise));
                      publisher.Publish(new PointProcessFinished<TF>(currentPoint));
                      continue;
                  }

                  // Dense enough: this point seeds a new cluster.
                  clusterId++;
                  currentPoint.ClusterId = clusterId;
                  currentPoint.PointType = PointType.Core;
                  publisher.Publish(new PointTypeAssigned<TF>(currentPoint, PointType.Core));
                  publisher.Publish(new PointProcessFinished<TF>(currentPoint));
                  publisher.Publish(
                      new ClusteringStarted<TF>(currentPoint, neighborPoints, clusterId, epsilon, minimumPoints));
                  ExpandCluster(allPointsDbscan, neighborPoints, clusterId, epsilon, minimumPoints, cancellationToken);
                  publisher.Publish(
                      new ClusteringFinished<TF>(currentPoint, neighborPoints, clusterId, epsilon, minimumPoints));
              }

              publisher.Publish(new ComputeFinished(computeId));

              var resultBuilder = new DbscanResultBuilder<TF>();
              foreach (var p in allPointsDbscan)
              {
                  resultBuilder.Process(p);
              }
              return resultBuilder.Result;
          }

          /// <summary>
          /// Grows cluster <paramref name="clusterId"/> from a core point's neighborhood. Neighbors are visited in
          /// insertion order.
          /// <list type="bullet">
          /// <item><description>A neighbor previously marked noise is relabelled as a border point of this cluster.</description></item>
          /// <item><description>An unlabelled neighbor joins the cluster.</description></item>
          /// <item><description>A joining neighbor with at least <paramref name="minimumPoints"/> neighbors of its own
          /// becomes core, and its neighbors are appended to the work list.</description></item>
          /// <item><description>A joining neighbor with fewer neighbors becomes border.</description></item>
          /// </list>
          /// </summary>
          /// <param name="allPoints">Dataset.</param>
          /// <param name="neighborPoints">Initial region of the seed core point. This array is not modified.</param>
          /// <param name="clusterId">Cluster id to assign.</param>
          /// <param name="epsilon">Region ball radius.</param>
          /// <param name="minimumPoints">Minimum number of points in a region for a point to be core.</param>
          /// <param name="cancellationToken">Checked before each neighbor is visited.</param>
          private void ExpandCluster(DbscanPoint<TF>[] allPoints, DbscanPoint<TF>[] neighborPoints,
              int clusterId, double epsilon, int minimumPoints, CancellationToken cancellationToken)
          {
              // Work list that grows as new core points are found.
              var pending = new List<DbscanPoint<TF>>(neighborPoints);
              // Mirrors 'pending' so every point is enqueued at most once. Same default equality as the
              // previous Union-based approach, but O(1) per lookup instead of rebuilding an array per core point.
              var enqueued = new HashSet<DbscanPoint<TF>>(neighborPoints);

              for (int i = 0; i < pending.Count; i++)
              {
                  cancellationToken.ThrowIfCancellationRequested();

                  var currentPoint = pending[i];
                  publisher.Publish(new PointProcessStarted<TF>(currentPoint));

                  // Its region query already ran (not dense), so it only becomes a border point; no expansion.
                  if (currentPoint.PointType == PointType.Noise)
                  {
                      currentPoint.ClusterId = clusterId;
                      currentPoint.PointType = PointType.Border;
                      publisher.Publish(new PointTypeAssigned<TF>(currentPoint, PointType.Border));
                      publisher.Publish(new PointProcessFinished<TF>(currentPoint));
                      continue;
                  }

                  // Already core/border (this or an earlier cluster): keep the first assignment.
                  if (currentPoint.PointType.HasValue)
                  {
                      publisher.Publish(new PointAlreadyProcessed<TF>(currentPoint));
                      continue;
                  }

                  currentPoint.ClusterId = clusterId;
                  publisher.Publish(new RegionQueryStarted<TF>(currentPoint, epsilon, minimumPoints));
                  var otherNeighborPoints = RegionQuery(allPoints, currentPoint.Feature, epsilon);
                  publisher.Publish(new RegionQueryFinished<TF>(currentPoint, otherNeighborPoints));

                  if (otherNeighborPoints.Length < minimumPoints)
                  {
                      currentPoint.PointType = PointType.Border;
                      publisher.Publish(new PointTypeAssigned<TF>(currentPoint, PointType.Border));
                      publisher.Publish(new PointProcessFinished<TF>(currentPoint));
                      continue;
                  }

                  currentPoint.PointType = PointType.Core;
                  publisher.Publish(new PointTypeAssigned<TF>(currentPoint, PointType.Core));
                  publisher.Publish(new PointProcessFinished<TF>(currentPoint));

                  // A new core point: its neighbors not yet seen become reachable and are queued.
                  foreach (var candidate in otherNeighborPoints)
                  {
                      if (enqueued.Add(candidate))
                      {
                          pending.Add(candidate);
                      }
                  }
              }
          }

          /// <summary>
          /// Finds the neighbor points of a given feature.
          /// </summary>
          /// <param name="allPoints">Dbscan points converted from the feature set.</param>
          /// <param name="mainFeature">Feature whose neighbors are searched.</param>
          /// <param name="epsilon">Region ball radius.</param>
          /// <returns>
          /// Every point of <paramref name="allPoints"/>, in original order, that satisfies
          /// <see cref="RegionQueryPredicate"/> for <paramref name="mainFeature"/> and <paramref name="epsilon"/>.
          /// </returns>
          public DbscanPoint<TF>[] RegionQuery(DbscanPoint<TF>[] allPoints, TF mainFeature, double epsilon)
          {
              return allPoints.Where(RegionQueryPredicate(mainFeature, epsilon)).ToArray();
          }
      }
  }