__init__.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from torchmetrics.classification.accuracy import Accuracy, BinaryAccuracy, MulticlassAccuracy, MultilabelAccuracy
  15. from torchmetrics.classification.auroc import AUROC, BinaryAUROC, MulticlassAUROC, MultilabelAUROC
  16. from torchmetrics.classification.average_precision import (
  17. AveragePrecision,
  18. BinaryAveragePrecision,
  19. MulticlassAveragePrecision,
  20. MultilabelAveragePrecision,
  21. )
  22. from torchmetrics.classification.calibration_error import (
  23. BinaryCalibrationError,
  24. CalibrationError,
  25. MulticlassCalibrationError,
  26. )
  27. from torchmetrics.classification.cohen_kappa import BinaryCohenKappa, CohenKappa, MulticlassCohenKappa
  28. from torchmetrics.classification.confusion_matrix import (
  29. BinaryConfusionMatrix,
  30. ConfusionMatrix,
  31. MulticlassConfusionMatrix,
  32. MultilabelConfusionMatrix,
  33. )
  34. from torchmetrics.classification.eer import EER, BinaryEER, MulticlassEER, MultilabelEER
  35. from torchmetrics.classification.exact_match import ExactMatch, MulticlassExactMatch, MultilabelExactMatch
  36. from torchmetrics.classification.f_beta import (
  37. BinaryF1Score,
  38. BinaryFBetaScore,
  39. F1Score,
  40. FBetaScore,
  41. MulticlassF1Score,
  42. MulticlassFBetaScore,
  43. MultilabelF1Score,
  44. MultilabelFBetaScore,
  45. )
  46. from torchmetrics.classification.group_fairness import BinaryFairness, BinaryGroupStatRates
  47. from torchmetrics.classification.hamming import (
  48. BinaryHammingDistance,
  49. HammingDistance,
  50. MulticlassHammingDistance,
  51. MultilabelHammingDistance,
  52. )
  53. from torchmetrics.classification.hinge import BinaryHingeLoss, HingeLoss, MulticlassHingeLoss
  54. from torchmetrics.classification.jaccard import (
  55. BinaryJaccardIndex,
  56. JaccardIndex,
  57. MulticlassJaccardIndex,
  58. MultilabelJaccardIndex,
  59. )
  60. from torchmetrics.classification.logauc import BinaryLogAUC, LogAUC, MulticlassLogAUC, MultilabelLogAUC
  61. from torchmetrics.classification.matthews_corrcoef import (
  62. BinaryMatthewsCorrCoef,
  63. MatthewsCorrCoef,
  64. MulticlassMatthewsCorrCoef,
  65. MultilabelMatthewsCorrCoef,
  66. )
  67. from torchmetrics.classification.negative_predictive_value import (
  68. BinaryNegativePredictiveValue,
  69. MulticlassNegativePredictiveValue,
  70. MultilabelNegativePredictiveValue,
  71. NegativePredictiveValue,
  72. )
  73. from torchmetrics.classification.precision_fixed_recall import (
  74. BinaryPrecisionAtFixedRecall,
  75. MulticlassPrecisionAtFixedRecall,
  76. MultilabelPrecisionAtFixedRecall,
  77. PrecisionAtFixedRecall,
  78. )
  79. from torchmetrics.classification.precision_recall import (
  80. BinaryPrecision,
  81. BinaryRecall,
  82. MulticlassPrecision,
  83. MulticlassRecall,
  84. MultilabelPrecision,
  85. MultilabelRecall,
  86. Precision,
  87. Recall,
  88. )
  89. from torchmetrics.classification.precision_recall_curve import (
  90. BinaryPrecisionRecallCurve,
  91. MulticlassPrecisionRecallCurve,
  92. MultilabelPrecisionRecallCurve,
  93. PrecisionRecallCurve,
  94. )
  95. from torchmetrics.classification.ranking import (
  96. MultilabelCoverageError,
  97. MultilabelRankingAveragePrecision,
  98. MultilabelRankingLoss,
  99. )
  100. from torchmetrics.classification.recall_fixed_precision import (
  101. BinaryRecallAtFixedPrecision,
  102. MulticlassRecallAtFixedPrecision,
  103. MultilabelRecallAtFixedPrecision,
  104. RecallAtFixedPrecision,
  105. )
  106. from torchmetrics.classification.roc import ROC, BinaryROC, MulticlassROC, MultilabelROC
  107. from torchmetrics.classification.sensitivity_specificity import (
  108. BinarySensitivityAtSpecificity,
  109. MulticlassSensitivityAtSpecificity,
  110. MultilabelSensitivityAtSpecificity,
  111. SensitivityAtSpecificity,
  112. )
  113. from torchmetrics.classification.specificity import (
  114. BinarySpecificity,
  115. MulticlassSpecificity,
  116. MultilabelSpecificity,
  117. Specificity,
  118. )
  119. from torchmetrics.classification.specificity_sensitivity import (
  120. BinarySpecificityAtSensitivity,
  121. MulticlassSpecificityAtSensitivity,
  122. MultilabelSpecificityAtSensitivity,
  123. SpecificityAtSensitivity,
  124. )
  125. from torchmetrics.classification.stat_scores import (
  126. BinaryStatScores,
  127. MulticlassStatScores,
  128. MultilabelStatScores,
  129. StatScores,
  130. )
# Public API of this subpackage: the names re-exported by ``from <package> import *``
# and treated as public by linters and type checkers.
# Every name imported above appears here exactly once, and nothing else does, so
# add/remove an entry here whenever an import above is added/removed.
# Ordering: all-uppercase names ("AUROC", "EER", "ROC") come first, followed by the
# remaining names in case-sensitive alphabetical order (e.g. "BinaryAUROC" sorts
# before "BinaryAccuracy"). Keep that order when editing; it looks like it is
# enforced by a sorting lint (e.g. ruff RUF022) -- TODO confirm against project config.
__all__ = [
    "AUROC",
    "EER",
    "ROC",
    "Accuracy",
    "AveragePrecision",
    "BinaryAUROC",
    "BinaryAccuracy",
    "BinaryAveragePrecision",
    "BinaryCalibrationError",
    "BinaryCohenKappa",
    "BinaryConfusionMatrix",
    "BinaryEER",
    "BinaryF1Score",
    "BinaryFBetaScore",
    "BinaryFairness",
    "BinaryGroupStatRates",
    "BinaryHammingDistance",
    "BinaryHingeLoss",
    "BinaryJaccardIndex",
    "BinaryLogAUC",
    "BinaryMatthewsCorrCoef",
    "BinaryNegativePredictiveValue",
    "BinaryPrecision",
    "BinaryPrecisionAtFixedRecall",
    "BinaryPrecisionRecallCurve",
    "BinaryROC",
    "BinaryRecall",
    "BinaryRecallAtFixedPrecision",
    "BinarySensitivityAtSpecificity",
    "BinarySpecificity",
    "BinarySpecificityAtSensitivity",
    "BinaryStatScores",
    "CalibrationError",
    "CohenKappa",
    "ConfusionMatrix",
    "ExactMatch",
    "F1Score",
    "FBetaScore",
    "HammingDistance",
    "HingeLoss",
    "JaccardIndex",
    "LogAUC",
    "MatthewsCorrCoef",
    "MulticlassAUROC",
    "MulticlassAccuracy",
    "MulticlassAveragePrecision",
    "MulticlassCalibrationError",
    "MulticlassCohenKappa",
    "MulticlassConfusionMatrix",
    "MulticlassEER",
    "MulticlassExactMatch",
    "MulticlassF1Score",
    "MulticlassFBetaScore",
    "MulticlassHammingDistance",
    "MulticlassHingeLoss",
    "MulticlassJaccardIndex",
    "MulticlassLogAUC",
    "MulticlassMatthewsCorrCoef",
    "MulticlassNegativePredictiveValue",
    "MulticlassPrecision",
    "MulticlassPrecisionAtFixedRecall",
    "MulticlassPrecisionRecallCurve",
    "MulticlassROC",
    "MulticlassRecall",
    "MulticlassRecallAtFixedPrecision",
    "MulticlassSensitivityAtSpecificity",
    "MulticlassSpecificity",
    "MulticlassSpecificityAtSensitivity",
    "MulticlassStatScores",
    "MultilabelAUROC",
    "MultilabelAccuracy",
    "MultilabelAveragePrecision",
    "MultilabelConfusionMatrix",
    "MultilabelCoverageError",
    "MultilabelEER",
    "MultilabelExactMatch",
    "MultilabelF1Score",
    "MultilabelFBetaScore",
    "MultilabelHammingDistance",
    "MultilabelJaccardIndex",
    "MultilabelLogAUC",
    "MultilabelMatthewsCorrCoef",
    "MultilabelNegativePredictiveValue",
    "MultilabelPrecision",
    "MultilabelPrecisionAtFixedRecall",
    "MultilabelPrecisionRecallCurve",
    "MultilabelROC",
    "MultilabelRankingAveragePrecision",
    "MultilabelRankingLoss",
    "MultilabelRecall",
    "MultilabelRecallAtFixedPrecision",
    "MultilabelSensitivityAtSpecificity",
    "MultilabelSpecificity",
    "MultilabelSpecificityAtSensitivity",
    "MultilabelStatScores",
    "NegativePredictiveValue",
    "Precision",
    "PrecisionAtFixedRecall",
    "PrecisionRecallCurve",
    "Recall",
    "RecallAtFixedPrecision",
    "SensitivityAtSpecificity",
    "Specificity",
    "SpecificityAtSensitivity",
    "StatScores",
]