cluster_accuracy.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import torch
  15. from torch import Tensor
  16. from torchmetrics.functional.classification import multiclass_confusion_matrix
  17. from torchmetrics.functional.clustering.utils import check_cluster_labels
  18. from torchmetrics.utilities.imports import _TORCH_LINEAR_ASSIGNMENT_AVAILABLE
# The docstring example of `cluster_accuracy` calls the function, which raises
# RuntimeError when `torch_linear_assignment` is missing, so list it for skipping.
# NOTE(review): `__doctest_skip__` is presumably honoured by the project's doctest
# collector -- confirm against the test configuration.
if not _TORCH_LINEAR_ASSIGNMENT_AVAILABLE:
    __doctest_skip__ = ["cluster_accuracy"]
  21. def _cluster_accuracy_compute(confmat: Tensor) -> Tensor:
  22. """Computes the clustering accuracy from a confusion matrix."""
  23. from torch_linear_assignment import batch_linear_assignment
  24. confmat = confmat[None]
  25. # solve the linear sum assignment problem
  26. assignment = batch_linear_assignment(confmat.max() - confmat)
  27. confmat = confmat[0]
  28. # extract the true positives
  29. tps = confmat[torch.arange(confmat.shape[0]), assignment.flatten()]
  30. return tps.sum() / confmat.sum()
  31. def cluster_accuracy(preds: Tensor, target: Tensor, num_classes: int) -> Tensor:
  32. """Computes the clustering accuracy between the predicted and target clusters.
  33. Args:
  34. preds: predicted cluster labels
  35. target: ground truth cluster labels
  36. num_classes: number of classes
  37. Returns:
  38. Scalar tensor with clustering accuracy between 0.0 and 1.0
  39. Raises:
  40. RuntimeError:
  41. If `torch_linear_assignment` is not installed
  42. Example:
  43. >>> from torchmetrics.functional.clustering import cluster_accuracy
  44. >>> preds = torch.tensor([0, 0, 1, 1])
  45. >>> target = torch.tensor([1, 1, 0, 0])
  46. >>> cluster_accuracy(preds, target, 2)
  47. tensor(1.000)
  48. """
  49. if not _TORCH_LINEAR_ASSIGNMENT_AVAILABLE:
  50. raise RuntimeError(
  51. "Missing `torch_linear_assignment`. Please install it with `pip install torchmetrics[clustering]`."
  52. )
  53. check_cluster_labels(preds, target)
  54. confmat = multiclass_confusion_matrix(preds, target, num_classes=num_classes)
  55. return _cluster_accuracy_compute(confmat)