dunn_index.py 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from itertools import combinations
  15. import torch
  16. from torch import Tensor
  17. def _dunn_index_update(data: Tensor, labels: Tensor, p: float) -> tuple[Tensor, Tensor]:
  18. """Update and return variables required to compute the Dunn index.
  19. Args:
  20. data: feature vectors of shape (n_samples, n_features)
  21. labels: cluster labels
  22. p: p-norm (distance metric)
  23. Returns:
  24. intercluster_distance: intercluster distances
  25. max_intracluster_distance: max intracluster distances
  26. """
  27. unique_labels, inverse_indices = labels.unique(return_inverse=True)
  28. clusters = [data[inverse_indices == label_idx] for label_idx in range(len(unique_labels))]
  29. centroids = [c.mean(dim=0) for c in clusters]
  30. intercluster_distance = torch.linalg.norm(
  31. torch.stack([a - b for a, b in combinations(centroids, 2)], dim=0), ord=p, dim=1
  32. )
  33. max_intracluster_distance = torch.stack([
  34. torch.linalg.norm(ci - mu, ord=p, dim=1).max() for ci, mu in zip(clusters, centroids)
  35. ])
  36. return intercluster_distance, max_intracluster_distance
  37. def _dunn_index_compute(intercluster_distance: Tensor, max_intracluster_distance: Tensor) -> Tensor:
  38. """Compute the Dunn index based on updated state.
  39. Args:
  40. intercluster_distance: intercluster distances
  41. max_intracluster_distance: max intracluster distances
  42. Returns:
  43. scalar tensor with the dunn index
  44. """
  45. return intercluster_distance.min() / max_intracluster_distance.max()
  46. def dunn_index(data: Tensor, labels: Tensor, p: float = 2) -> Tensor:
  47. """Compute the Dunn index.
  48. Args:
  49. data: feature vectors
  50. labels: cluster labels
  51. p: p-norm used for distance metric
  52. Returns:
  53. scalar tensor with the dunn index
  54. Example:
  55. >>> from torchmetrics.functional.clustering import dunn_index
  56. >>> data = torch.tensor([[0, 0], [0.5, 0], [1, 0], [0.5, 1]])
  57. >>> labels = torch.tensor([0, 0, 0, 1])
  58. >>> dunn_index(data, labels)
  59. tensor(2.)
  60. """
  61. pairwise_distance, max_distance = _dunn_index_update(data, labels, p)
  62. return _dunn_index_compute(pairwise_distance, max_distance)