__init__.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275
  1. r"""Root package info."""
  2. import logging as __logging
  3. import os
  4. from lightning_utilities.core.imports import package_available
  5. from torchmetrics.__about__ import * # noqa: F403
  6. _logger = __logging.getLogger("torchmetrics")
  7. _logger.addHandler(__logging.StreamHandler())
  8. _logger.setLevel(__logging.INFO)
  9. _PACKAGE_ROOT = os.path.dirname(__file__)
  10. _PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)
  11. if package_available("numpy"):
  12. # compatibility for AttributeError: `np.Inf` was removed in the NumPy 2.0 release. Use `np.inf` instead
  13. import numpy
  14. numpy.Inf = numpy.inf
  15. if package_available("PIL"):
  16. import PIL
  17. if not hasattr(PIL, "PILLOW_VERSION"):
  18. PIL.PILLOW_VERSION = PIL.__version__
  19. if package_available("scipy"):
  20. import scipy.signal
  21. # back compatibility patch due to SMRMpy using scipy.signal.hamming
  22. if not hasattr(scipy.signal, "hamming"):
  23. scipy.signal.hamming = scipy.signal.windows.hamming
  24. from torchmetrics import functional # noqa: E402
  25. from torchmetrics.aggregation import ( # noqa: E402
  26. CatMetric,
  27. MaxMetric,
  28. MeanMetric,
  29. MinMetric,
  30. RunningMean,
  31. RunningSum,
  32. SumMetric,
  33. )
  34. from torchmetrics.audio._deprecated import _PermutationInvariantTraining as PermutationInvariantTraining # noqa: E402
  35. from torchmetrics.audio._deprecated import ( # noqa: E402
  36. _ScaleInvariantSignalDistortionRatio as ScaleInvariantSignalDistortionRatio,
  37. )
  38. from torchmetrics.audio._deprecated import ( # noqa: E402
  39. _ScaleInvariantSignalNoiseRatio as ScaleInvariantSignalNoiseRatio,
  40. )
  41. from torchmetrics.audio._deprecated import _SignalDistortionRatio as SignalDistortionRatio # noqa: E402
  42. from torchmetrics.audio._deprecated import _SignalNoiseRatio as SignalNoiseRatio # noqa: E402
  43. from torchmetrics.classification import ( # noqa: E402
  44. AUROC,
  45. ROC,
  46. Accuracy,
  47. AveragePrecision,
  48. CalibrationError,
  49. CohenKappa,
  50. ConfusionMatrix,
  51. ExactMatch,
  52. F1Score,
  53. FBetaScore,
  54. HammingDistance,
  55. HingeLoss,
  56. JaccardIndex,
  57. LogAUC,
  58. MatthewsCorrCoef,
  59. NegativePredictiveValue,
  60. Precision,
  61. PrecisionAtFixedRecall,
  62. PrecisionRecallCurve,
  63. Recall,
  64. RecallAtFixedPrecision,
  65. SensitivityAtSpecificity,
  66. Specificity,
  67. SpecificityAtSensitivity,
  68. StatScores,
  69. )
  70. from torchmetrics.collections import MetricCollection # noqa: E402
  71. from torchmetrics.detection._deprecated import _ModifiedPanopticQuality as ModifiedPanopticQuality # noqa: E402
  72. from torchmetrics.detection._deprecated import _PanopticQuality as PanopticQuality # noqa: E402
  73. from torchmetrics.image._deprecated import ( # noqa: E402
  74. _ErrorRelativeGlobalDimensionlessSynthesis as ErrorRelativeGlobalDimensionlessSynthesis,
  75. )
  76. from torchmetrics.image._deprecated import ( # noqa: E402
  77. _MultiScaleStructuralSimilarityIndexMeasure as MultiScaleStructuralSimilarityIndexMeasure,
  78. )
  79. from torchmetrics.image._deprecated import _PeakSignalNoiseRatio as PeakSignalNoiseRatio # noqa: E402
  80. from torchmetrics.image._deprecated import _RelativeAverageSpectralError as RelativeAverageSpectralError # noqa: E402
  81. from torchmetrics.image._deprecated import ( # noqa: E402
  82. _RootMeanSquaredErrorUsingSlidingWindow as RootMeanSquaredErrorUsingSlidingWindow,
  83. )
  84. from torchmetrics.image._deprecated import _SpectralAngleMapper as SpectralAngleMapper # noqa: E402
  85. from torchmetrics.image._deprecated import _SpectralDistortionIndex as SpectralDistortionIndex # noqa: E402
  86. from torchmetrics.image._deprecated import ( # noqa: E402
  87. _StructuralSimilarityIndexMeasure as StructuralSimilarityIndexMeasure,
  88. )
  89. from torchmetrics.image._deprecated import _TotalVariation as TotalVariation # noqa: E402
  90. from torchmetrics.image._deprecated import _UniversalImageQualityIndex as UniversalImageQualityIndex # noqa: E402
  91. from torchmetrics.metric import Metric # noqa: E402
  92. from torchmetrics.nominal import ( # noqa: E402
  93. CramersV,
  94. FleissKappa,
  95. PearsonsContingencyCoefficient,
  96. TheilsU,
  97. TschuprowsT,
  98. )
  99. from torchmetrics.regression import ( # noqa: E402
  100. ConcordanceCorrCoef,
  101. CosineSimilarity,
  102. CriticalSuccessIndex,
  103. ExplainedVariance,
  104. KendallRankCorrCoef,
  105. KLDivergence,
  106. LogCoshError,
  107. MeanAbsoluteError,
  108. MeanAbsolutePercentageError,
  109. MeanSquaredError,
  110. MeanSquaredLogError,
  111. MinkowskiDistance,
  112. NormalizedRootMeanSquaredError,
  113. PearsonCorrCoef,
  114. R2Score,
  115. RelativeSquaredError,
  116. SpearmanCorrCoef,
  117. SymmetricMeanAbsolutePercentageError,
  118. TweedieDevianceScore,
  119. WeightedMeanAbsolutePercentageError,
  120. )
  121. from torchmetrics.retrieval._deprecated import _RetrievalFallOut as RetrievalFallOut # noqa: E402
  122. from torchmetrics.retrieval._deprecated import _RetrievalHitRate as RetrievalHitRate # noqa: E402
  123. from torchmetrics.retrieval._deprecated import _RetrievalMAP as RetrievalMAP # noqa: E402
  124. from torchmetrics.retrieval._deprecated import _RetrievalMRR as RetrievalMRR # noqa: E402
  125. from torchmetrics.retrieval._deprecated import _RetrievalNormalizedDCG as RetrievalNormalizedDCG # noqa: E402
  126. from torchmetrics.retrieval._deprecated import _RetrievalPrecision as RetrievalPrecision # noqa: E402
  127. from torchmetrics.retrieval._deprecated import ( # noqa: E402
  128. _RetrievalPrecisionRecallCurve as RetrievalPrecisionRecallCurve,
  129. )
  130. from torchmetrics.retrieval._deprecated import _RetrievalRecall as RetrievalRecall # noqa: E402
  131. from torchmetrics.retrieval._deprecated import ( # noqa: E402
  132. _RetrievalRecallAtFixedPrecision as RetrievalRecallAtFixedPrecision,
  133. )
  134. from torchmetrics.retrieval._deprecated import _RetrievalRPrecision as RetrievalRPrecision # noqa: E402
  135. from torchmetrics.text._deprecated import _BLEUScore as BLEUScore # noqa: E402
  136. from torchmetrics.text._deprecated import _CharErrorRate as CharErrorRate # noqa: E402
  137. from torchmetrics.text._deprecated import _CHRFScore as CHRFScore # noqa: E402
  138. from torchmetrics.text._deprecated import _ExtendedEditDistance as ExtendedEditDistance # noqa: E402
  139. from torchmetrics.text._deprecated import _MatchErrorRate as MatchErrorRate # noqa: E402
  140. from torchmetrics.text._deprecated import _Perplexity as Perplexity # noqa: E402
  141. from torchmetrics.text._deprecated import _SacreBLEUScore as SacreBLEUScore # noqa: E402
  142. from torchmetrics.text._deprecated import _SQuAD as SQuAD # noqa: E402
  143. from torchmetrics.text._deprecated import _TranslationEditRate as TranslationEditRate # noqa: E402
  144. from torchmetrics.text._deprecated import _WordErrorRate as WordErrorRate # noqa: E402
  145. from torchmetrics.text._deprecated import _WordInfoLost as WordInfoLost # noqa: E402
  146. from torchmetrics.text._deprecated import _WordInfoPreserved as WordInfoPreserved # noqa: E402
  147. from torchmetrics.wrappers import ( # noqa: E402
  148. BootStrapper,
  149. ClasswiseWrapper,
  150. MetricTracker,
  151. MinMaxMetric,
  152. MultioutputWrapper,
  153. MultitaskWrapper,
  154. )
  155. __all__ = [
  156. "AUROC",
  157. "ROC",
  158. "Accuracy",
  159. "AveragePrecision",
  160. "BLEUScore",
  161. "BootStrapper",
  162. "CHRFScore",
  163. "CalibrationError",
  164. "CatMetric",
  165. "CharErrorRate",
  166. "ClasswiseWrapper",
  167. "CohenKappa",
  168. "ConcordanceCorrCoef",
  169. "ConfusionMatrix",
  170. "CosineSimilarity",
  171. "CramersV",
  172. "CriticalSuccessIndex",
  173. "ErrorRelativeGlobalDimensionlessSynthesis",
  174. "ExactMatch",
  175. "ExplainedVariance",
  176. "ExtendedEditDistance",
  177. "F1Score",
  178. "FBetaScore",
  179. "FleissKappa",
  180. "HammingDistance",
  181. "HingeLoss",
  182. "JaccardIndex",
  183. "KLDivergence",
  184. "KendallRankCorrCoef",
  185. "LogAUC",
  186. "LogCoshError",
  187. "MatchErrorRate",
  188. "MatthewsCorrCoef",
  189. "MaxMetric",
  190. "MeanAbsoluteError",
  191. "MeanAbsolutePercentageError",
  192. "MeanMetric",
  193. "MeanSquaredError",
  194. "MeanSquaredLogError",
  195. "Metric",
  196. "MetricCollection",
  197. "MetricTracker",
  198. "MinMaxMetric",
  199. "MinMetric",
  200. "MinkowskiDistance",
  201. "ModifiedPanopticQuality",
  202. "MultiScaleStructuralSimilarityIndexMeasure",
  203. "MultioutputWrapper",
  204. "MultitaskWrapper",
  205. "NegativePredictiveValue",
  206. "NormalizedRootMeanSquaredError",
  207. "PanopticQuality",
  208. "PeakSignalNoiseRatio",
  209. "PearsonCorrCoef",
  210. "PearsonsContingencyCoefficient",
  211. "PermutationInvariantTraining",
  212. "Perplexity",
  213. "Precision",
  214. "PrecisionAtFixedRecall",
  215. "PrecisionRecallCurve",
  216. "R2Score",
  217. "Recall",
  218. "RecallAtFixedPrecision",
  219. "RelativeAverageSpectralError",
  220. "RelativeSquaredError",
  221. "RetrievalFallOut",
  222. "RetrievalHitRate",
  223. "RetrievalMAP",
  224. "RetrievalMRR",
  225. "RetrievalNormalizedDCG",
  226. "RetrievalPrecision",
  227. "RetrievalPrecisionRecallCurve",
  228. "RetrievalRPrecision",
  229. "RetrievalRecall",
  230. "RetrievalRecallAtFixedPrecision",
  231. "RootMeanSquaredErrorUsingSlidingWindow",
  232. "RunningMean",
  233. "RunningSum",
  234. "SQuAD",
  235. "SacreBLEUScore",
  236. "ScaleInvariantSignalDistortionRatio",
  237. "ScaleInvariantSignalNoiseRatio",
  238. "SensitivityAtSpecificity",
  239. "SignalDistortionRatio",
  240. "SignalNoiseRatio",
  241. "SpearmanCorrCoef",
  242. "Specificity",
  243. "SpecificityAtSensitivity",
  244. "SpectralAngleMapper",
  245. "SpectralDistortionIndex",
  246. "StatScores",
  247. "StructuralSimilarityIndexMeasure",
  248. "SumMetric",
  249. "SymmetricMeanAbsolutePercentageError",
  250. "TheilsU",
  251. "TotalVariation",
  252. "TranslationEditRate",
  253. "TschuprowsT",
  254. "TweedieDevianceScore",
  255. "UniversalImageQualityIndex",
  256. "WeightedMeanAbsolutePercentageError",
  257. "WordErrorRate",
  258. "WordInfoLost",
  259. "WordInfoPreserved",
  260. "functional",
  261. ]