silhouette.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. from warnings import simplefilter
  2. import numpy as np
  3. from sklearn.metrics import silhouette_samples, silhouette_score
  4. from sklearn.preprocessing import LabelEncoder
  5. import wandb
  6. from wandb.integration.sklearn import utils
  7. # ignore all future warnings
  8. simplefilter(action="ignore", category=FutureWarning)
  9. def silhouette(clusterer, X, cluster_labels, labels, metric, kmeans): # noqa: N803
  10. # Run clusterer for n_clusters in range(len(cluster_ranges), get cluster labels
  11. # TODO - keep/delete once we decide if we should train clusterers
  12. # or ask for trained models
  13. # clusterer.set_params(n_clusters=n_clusters, random_state=42)
  14. # cluster_labels = clusterer.fit_predict(X)
  15. cluster_labels = np.asarray(cluster_labels)
  16. labels = np.asarray(labels)
  17. le = LabelEncoder()
  18. _ = le.fit_transform(cluster_labels)
  19. n_clusters = len(np.unique(cluster_labels))
  20. # The silhouette_score gives the average value for all the samples.
  21. # This gives a perspective into the density and separation of the formed
  22. # clusters
  23. silhouette_avg = silhouette_score(X, cluster_labels, metric=metric)
  24. # Compute the silhouette scores for each sample
  25. sample_silhouette_values = silhouette_samples(X, cluster_labels, metric=metric)
  26. x_sil, y_sil, color_sil = [], [], []
  27. count, y_lower = 0, 10
  28. for i in range(n_clusters):
  29. # Aggregate the silhouette scores for samples belonging to
  30. # cluster i, and sort them
  31. ith_cluster_silhouette_values = sample_silhouette_values[cluster_labels == i]
  32. ith_cluster_silhouette_values.sort()
  33. size_cluster_i = ith_cluster_silhouette_values.shape[0]
  34. y_upper = y_lower + size_cluster_i
  35. y_values = np.arange(y_lower, y_upper)
  36. for j in range(len(y_values)):
  37. y_sil.append(y_values[j])
  38. x_sil.append(ith_cluster_silhouette_values[j])
  39. color_sil.append(i)
  40. count += 1
  41. if utils.check_against_limit(count, "silhouette", utils.chart_limit):
  42. break
  43. # Compute the new y_lower for next plot
  44. y_lower = y_upper + 10 # 10 for the 0 samples
  45. if kmeans:
  46. centers = clusterer.cluster_centers_
  47. centerx = centers[:, 0]
  48. centery = centers[:, 1]
  49. else:
  50. centerx = [None] * len(color_sil)
  51. centery = [None] * len(color_sil)
  52. table = make_table(
  53. X[:, 0],
  54. X[:, 1],
  55. cluster_labels,
  56. centerx,
  57. centery,
  58. y_sil,
  59. x_sil,
  60. color_sil,
  61. silhouette_avg,
  62. )
  63. chart = wandb.visualize("wandb/silhouette_/v1", table)
  64. return chart
  65. def make_table(x, y, colors, centerx, centery, y_sil, x_sil, color_sil, silhouette_avg):
  66. columns = [
  67. "x",
  68. "y",
  69. "colors",
  70. "centerx",
  71. "centery",
  72. "y_sil",
  73. "x1",
  74. "x2",
  75. "color_sil",
  76. "silhouette_avg",
  77. ]
  78. data = [
  79. [
  80. x[i],
  81. y[i],
  82. colors[i],
  83. centerx[colors[i]],
  84. centery[colors[i]],
  85. y_sil[i],
  86. 0,
  87. x_sil[i],
  88. color_sil[i],
  89. silhouette_avg,
  90. ]
  91. for i in range(len(color_sil))
  92. ]
  93. table = wandb.Table(data=data, columns=columns)
  94. return table