METADATA 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582
  1. Metadata-Version: 2.4
  2. Name: torchmetrics
  3. Version: 1.9.0
  4. Summary: PyTorch native Metrics
  5. Home-page: https://github.com/Lightning-AI/torchmetrics
  6. Download-URL: https://github.com/Lightning-AI/torchmetrics/archive/master.zip
  7. Author: Lightning-AI et al.
  8. Author-email: name@pytorchlightning.ai
  9. License: Apache-2.0
  10. Project-URL: Bug Tracker, https://github.com/Lightning-AI/torchmetrics/issues
  11. Project-URL: Documentation, https://torchmetrics.rtfd.io/en/latest/
  12. Project-URL: Source Code, https://github.com/Lightning-AI/torchmetrics
  13. Keywords: deep learning,machine learning,pytorch,metrics,AI
  14. Classifier: Environment :: Console
  15. Classifier: Natural Language :: English
  16. Classifier: Development Status :: 5 - Production/Stable
  17. Classifier: Intended Audience :: Developers
  18. Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
  19. Classifier: Topic :: Scientific/Engineering :: Image Recognition
  20. Classifier: Topic :: Scientific/Engineering :: Information Analysis
  21. Classifier: License :: OSI Approved :: Apache Software License
  22. Classifier: Operating System :: OS Independent
  23. Classifier: Programming Language :: Python :: 3
  24. Classifier: Programming Language :: Python :: 3.10
  25. Classifier: Programming Language :: Python :: 3.11
  26. Classifier: Programming Language :: Python :: 3.12
  27. Requires-Python: >=3.10
  28. Description-Content-Type: text/markdown
  29. License-File: LICENSE
  30. Requires-Dist: numpy>1.20.0
  31. Requires-Dist: packaging>17.1
  32. Requires-Dist: torch>=2.0.0
  33. Requires-Dist: lightning-utilities>=0.15.3
  34. Provides-Extra: audio
  35. Requires-Dist: requests>=2.22.0; extra == "audio"
  36. Requires-Dist: onnxruntime>=1.12.0; extra == "audio"
  37. Requires-Dist: gammatone>=1.0.0; extra == "audio"
  38. Requires-Dist: pesq>=0.0.4; extra == "audio"
  39. Requires-Dist: pystoi>=0.4.0; extra == "audio"
  40. Requires-Dist: librosa>=0.10.0; extra == "audio"
  41. Requires-Dist: torchaudio>=2.0.1; extra == "audio"
  42. Provides-Extra: clustering
  43. Requires-Dist: torch_linear_assignment>=0.0.2; extra == "clustering"
  44. Provides-Extra: debug
  45. Provides-Extra: detection
  46. Requires-Dist: pycocotools>2.0.0; extra == "detection"
  47. Requires-Dist: torchvision>=0.15.1; extra == "detection"
  48. Provides-Extra: image
  49. Requires-Dist: torch-fidelity<=0.4.0; extra == "image"
  50. Requires-Dist: torchvision>=0.15.1; extra == "image"
  51. Requires-Dist: scipy>1.0.0; extra == "image"
  52. Provides-Extra: integrate
  53. Provides-Extra: multimodal
  54. Requires-Dist: timm>=0.9.0; extra == "multimodal"
  55. Requires-Dist: transformers>=4.43.0; extra == "multimodal"
  56. Requires-Dist: einops>=0.7.0; extra == "multimodal"
  57. Requires-Dist: piq<=0.8.0; extra == "multimodal"
  58. Provides-Extra: text
  59. Requires-Dist: tqdm<4.68.0; extra == "text"
  60. Requires-Dist: nltk>3.8.1; extra == "text"
  61. Requires-Dist: ipadic>=1.0.0; extra == "text"
  62. Requires-Dist: mecab-python3>=1.0.6; extra == "text"
  63. Requires-Dist: transformers>=4.43.0; extra == "text"
  64. Requires-Dist: regex>=2021.9.24; extra == "text"
  65. Requires-Dist: sentencepiece>=0.2.0; extra == "text"
  66. Provides-Extra: typing
  67. Requires-Dist: types-six; extra == "typing"
  68. Requires-Dist: mypy==1.17.1; extra == "typing"
  69. Requires-Dist: types-requests; extra == "typing"
  70. Requires-Dist: types-tabulate; extra == "typing"
  71. Requires-Dist: types-setuptools; extra == "typing"
  72. Requires-Dist: types-emoji; extra == "typing"
  73. Requires-Dist: torch==2.8.0; extra == "typing"
  74. Requires-Dist: types-PyYAML; extra == "typing"
  75. Requires-Dist: types-protobuf; extra == "typing"
  76. Provides-Extra: video
  77. Requires-Dist: vmaf-torch>=1.1.0; extra == "video"
  78. Requires-Dist: einops>=0.7.0; extra == "video"
  79. Provides-Extra: visual
  80. Requires-Dist: matplotlib>=3.6.0; extra == "visual"
  81. Requires-Dist: SciencePlots>=2.0.0; extra == "visual"
  82. Provides-Extra: all
  83. Requires-Dist: requests>=2.22.0; extra == "all"
  84. Requires-Dist: onnxruntime>=1.12.0; extra == "all"
  85. Requires-Dist: gammatone>=1.0.0; extra == "all"
  86. Requires-Dist: pesq>=0.0.4; extra == "all"
  87. Requires-Dist: pystoi>=0.4.0; extra == "all"
  88. Requires-Dist: librosa>=0.10.0; extra == "all"
  89. Requires-Dist: torchaudio>=2.0.1; extra == "all"
  90. Requires-Dist: torch_linear_assignment>=0.0.2; extra == "all"
  91. Requires-Dist: pycocotools>2.0.0; extra == "all"
  92. Requires-Dist: torchvision>=0.15.1; extra == "all"
  93. Requires-Dist: torch-fidelity<=0.4.0; extra == "all"
  94. Requires-Dist: torchvision>=0.15.1; extra == "all"
  95. Requires-Dist: scipy>1.0.0; extra == "all"
  96. Requires-Dist: timm>=0.9.0; extra == "all"
  97. Requires-Dist: transformers>=4.43.0; extra == "all"
  98. Requires-Dist: einops>=0.7.0; extra == "all"
  99. Requires-Dist: piq<=0.8.0; extra == "all"
  100. Requires-Dist: tqdm<4.68.0; extra == "all"
  101. Requires-Dist: nltk>3.8.1; extra == "all"
  102. Requires-Dist: ipadic>=1.0.0; extra == "all"
  103. Requires-Dist: mecab-python3>=1.0.6; extra == "all"
  104. Requires-Dist: transformers>=4.43.0; extra == "all"
  105. Requires-Dist: regex>=2021.9.24; extra == "all"
  106. Requires-Dist: sentencepiece>=0.2.0; extra == "all"
  107. Requires-Dist: types-six; extra == "all"
  108. Requires-Dist: mypy==1.17.1; extra == "all"
  109. Requires-Dist: types-requests; extra == "all"
  110. Requires-Dist: types-tabulate; extra == "all"
  111. Requires-Dist: types-setuptools; extra == "all"
  112. Requires-Dist: types-emoji; extra == "all"
  113. Requires-Dist: torch==2.8.0; extra == "all"
  114. Requires-Dist: types-PyYAML; extra == "all"
  115. Requires-Dist: types-protobuf; extra == "all"
  116. Requires-Dist: vmaf-torch>=1.1.0; extra == "all"
  117. Requires-Dist: einops>=0.7.0; extra == "all"
  118. Requires-Dist: matplotlib>=3.6.0; extra == "all"
  119. Requires-Dist: SciencePlots>=2.0.0; extra == "all"
  120. Provides-Extra: dev
  121. Requires-Dist: requests>=2.22.0; extra == "dev"
  122. Requires-Dist: onnxruntime>=1.12.0; extra == "dev"
  123. Requires-Dist: gammatone>=1.0.0; extra == "dev"
  124. Requires-Dist: pesq>=0.0.4; extra == "dev"
  125. Requires-Dist: pystoi>=0.4.0; extra == "dev"
  126. Requires-Dist: librosa>=0.10.0; extra == "dev"
  127. Requires-Dist: torchaudio>=2.0.1; extra == "dev"
  128. Requires-Dist: torch_linear_assignment>=0.0.2; extra == "dev"
  129. Requires-Dist: pycocotools>2.0.0; extra == "dev"
  130. Requires-Dist: torchvision>=0.15.1; extra == "dev"
  131. Requires-Dist: torch-fidelity<=0.4.0; extra == "dev"
  132. Requires-Dist: torchvision>=0.15.1; extra == "dev"
  133. Requires-Dist: scipy>1.0.0; extra == "dev"
  134. Requires-Dist: timm>=0.9.0; extra == "dev"
  135. Requires-Dist: transformers>=4.43.0; extra == "dev"
  136. Requires-Dist: einops>=0.7.0; extra == "dev"
  137. Requires-Dist: piq<=0.8.0; extra == "dev"
  138. Requires-Dist: tqdm<4.68.0; extra == "dev"
  139. Requires-Dist: nltk>3.8.1; extra == "dev"
  140. Requires-Dist: ipadic>=1.0.0; extra == "dev"
  141. Requires-Dist: mecab-python3>=1.0.6; extra == "dev"
  142. Requires-Dist: transformers>=4.43.0; extra == "dev"
  143. Requires-Dist: regex>=2021.9.24; extra == "dev"
  144. Requires-Dist: sentencepiece>=0.2.0; extra == "dev"
  145. Requires-Dist: types-six; extra == "dev"
  146. Requires-Dist: mypy==1.17.1; extra == "dev"
  147. Requires-Dist: types-requests; extra == "dev"
  148. Requires-Dist: types-tabulate; extra == "dev"
  149. Requires-Dist: types-setuptools; extra == "dev"
  150. Requires-Dist: types-emoji; extra == "dev"
  151. Requires-Dist: torch==2.8.0; extra == "dev"
  152. Requires-Dist: types-PyYAML; extra == "dev"
  153. Requires-Dist: types-protobuf; extra == "dev"
  154. Requires-Dist: vmaf-torch>=1.1.0; extra == "dev"
  155. Requires-Dist: einops>=0.7.0; extra == "dev"
  156. Requires-Dist: matplotlib>=3.6.0; extra == "dev"
  157. Requires-Dist: SciencePlots>=2.0.0; extra == "dev"
  158. Requires-Dist: pytorch-msssim==1.0.0; extra == "dev"
  159. Requires-Dist: sewar>=0.4.4; extra == "dev"
  160. Requires-Dist: setuptools<82.0.0; extra == "dev"
  161. Requires-Dist: scikit-image>=0.19.0; extra == "dev"
  162. Requires-Dist: dists-pytorch==0.1; extra == "dev"
  163. Requires-Dist: rouge-score>0.1.0; extra == "dev"
  164. Requires-Dist: netcal>1.0.0; extra == "dev"
  165. Requires-Dist: pandas>1.4.0; extra == "dev"
  166. Requires-Dist: numpy<2.4.0; extra == "dev"
  167. Requires-Dist: torch_complex<0.5.0; extra == "dev"
  168. Requires-Dist: permetrics==2.0.0; extra == "dev"
  169. Requires-Dist: jiwer>=2.3.0; extra == "dev"
  170. Requires-Dist: aeon>=1.0.0; python_version > "3.10" and extra == "dev"
  171. Requires-Dist: mir-eval>=0.6; extra == "dev"
  172. Requires-Dist: huggingface-hub<0.35; extra == "dev"
  173. Requires-Dist: faster-coco-eval>=1.6.3; extra == "dev"
  174. Requires-Dist: mecab-ko-dic>=1.0.0; python_version < "3.12" and extra == "dev"
  175. Requires-Dist: monai==1.4.0; extra == "dev"
  176. Requires-Dist: mecab-ko<1.1.0,>=1.0.0; python_version < "3.12" and extra == "dev"
  177. Requires-Dist: bert_score==0.3.13; extra == "dev"
  178. Requires-Dist: sacrebleu>=2.3.0; extra == "dev"
  179. Requires-Dist: scipy>1.0.0; extra == "dev"
  180. Requires-Dist: lpips<=0.1.4; extra == "dev"
  181. Requires-Dist: dython==0.7.9; extra == "dev"
  182. Requires-Dist: properscoring==0.1; extra == "dev"
  183. Requires-Dist: fast-bss-eval>=0.1.0; extra == "dev"
  184. Requires-Dist: PyTDC==0.4.1; (platform_system == "Windows" and python_version < "3.12") and extra == "dev"
  185. Requires-Dist: fairlearn; extra == "dev"
  186. Requires-Dist: kornia>=0.6.7; extra == "dev"
  187. Requires-Dist: statsmodels>0.13.5; extra == "dev"
  188. Dynamic: author
  189. Dynamic: author-email
  190. Dynamic: classifier
  191. Dynamic: description
  192. Dynamic: description-content-type
  193. Dynamic: download-url
  194. Dynamic: home-page
  195. Dynamic: keywords
  196. Dynamic: license
  197. Dynamic: license-file
  198. Dynamic: project-url
  199. Dynamic: provides-extra
  200. Dynamic: requires-dist
  201. Dynamic: requires-python
  202. Dynamic: summary
  203. <div align="center">
  204. <img src="https://github.com/Lightning-AI/torchmetrics/raw/v1.9.0/docs/source/_static/images/logo.png" width="400px">
  205. **Machine learning metrics for distributed, scalable PyTorch applications.**
  206. ______________________________________________________________________
  207. <p align="center">
  208. <a href="#what-is-torchmetrics">What is Torchmetrics</a> •
  209. <a href="#implementing-your-own-module-metric">Implementing a metric</a> •
  210. <a href="#build-in-metrics">Built-in metrics</a> •
  211. <a href="https://lightning.ai/docs/torchmetrics/stable/">Docs</a> •
  212. <a href="#community">Community</a> •
  213. <a href="#license">License</a>
  214. </p>
  215. ______________________________________________________________________
  216. [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/torchmetrics)](https://pypi.org/project/torchmetrics/)
  217. [![PyPI Status](https://badge.fury.io/py/torchmetrics.svg)](https://badge.fury.io/py/torchmetrics)
  218. [![PyPI - Downloads](https://img.shields.io/pypi/dm/torchmetrics)
  219. ](https://pepy.tech/project/torchmetrics)
  220. [![Conda](https://img.shields.io/conda/v/conda-forge/torchmetrics?label=conda&color=success)](https://anaconda.org/conda-forge/torchmetrics)
  221. [![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/Lightning-AI/torchmetrics/blob/master/LICENSE)
  222. [![CI testing | CPU](https://github.com/Lightning-AI/torchmetrics/actions/workflows/ci-tests.yml/badge.svg?event=push)](https://github.com/Lightning-AI/torchmetrics/actions/workflows/ci-tests.yml)
  223. [![Build Status](https://dev.azure.com/Lightning-AI/Metrics/_apis/build/status%2FTM.unittests?branchName=refs%2Ftags%2Fv1.9.0)](https://dev.azure.com/Lightning-AI/Metrics/_build/latest?definitionId=2&branchName=refs%2Ftags%2Fv1.9.0)
  224. [![codecov](https://codecov.io/gh/Lightning-AI/torchmetrics/release/v1.9.0/graph/badge.svg?token=NER6LPI3HS)](https://codecov.io/gh/Lightning-AI/torchmetrics)
  225. [![pre-commit.ci status](https://results.pre-commit.ci/badge/github/Lightning-AI/torchmetrics/master.svg)](https://results.pre-commit.ci/latest/github/Lightning-AI/torchmetrics/master)
  226. [![Documentation Status](https://readthedocs.org/projects/torchmetrics/badge/?version=latest)](https://torchmetrics.readthedocs.io/en/latest/?badge=latest)
  227. [![Discord](https://img.shields.io/discord/1077906959069626439?style=plastic)](https://discord.gg/VptPCZkGNa)
  228. [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.5844769.svg)](https://doi.org/10.5281/zenodo.5844769)
  229. [![JOSS status](https://joss.theoj.org/papers/561d9bb59b400158bc8204e2639dca43/status.svg)](https://joss.theoj.org/papers/561d9bb59b400158bc8204e2639dca43)
  230. ______________________________________________________________________
  231. </div>
  232. # Looking for GPUs?
  233. Over 340,000 developers use [Lightning Cloud](https://lightning.ai/?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme) - purpose-built for PyTorch and PyTorch Lightning.
  234. - [GPUs](https://lightning.ai/pricing?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme) from $0.19.
  235. - [Clusters](https://lightning.ai/clusters?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme): frontier-grade training/inference clusters.
  236. - [AI Studio (vibe train)](https://lightning.ai/studios?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme): workspaces where AI helps you debug, tune and vibe train.
  237. - [AI Studio (vibe deploy)](https://lightning.ai/studios?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme): workspaces where AI helps you optimize, and deploy models.
  238. - [Notebooks](https://lightning.ai/notebooks?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme): Persistent GPU workspaces where AI helps you code and analyze.
  239. - [Inference](https://lightning.ai/deploy?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme): Deploy models as inference APIs.
  240. # Installation
  241. Simple installation from PyPI
  242. ```bash
  243. pip install torchmetrics
  244. ```
  245. <details>
  246. <summary>Other installations</summary>
  247. Install using conda
  248. ```bash
  249. conda install -c conda-forge torchmetrics
  250. ```
  251. Install using uv
  252. ```bash
  253. uv add torchmetrics
  254. ```
  255. Pip from source
  256. ```bash
  257. # with git
  258. pip install git+https://github.com/Lightning-AI/torchmetrics.git@release/stable
  259. ```
  260. Pip from archive
  261. ```bash
  262. pip install https://github.com/Lightning-AI/torchmetrics/archive/refs/heads/release/stable.zip
  263. ```
  264. Extra dependencies for specialized metrics:
  265. ```bash
  266. pip install torchmetrics[audio]
  267. pip install torchmetrics[image]
  268. pip install torchmetrics[text]
  269. pip install torchmetrics[all] # install all of the above
  270. ```
  271. Install latest developer version
  272. ```bash
  273. pip install https://github.com/Lightning-AI/torchmetrics/archive/master.zip
  274. ```
  275. </details>
  276. ______________________________________________________________________
  277. # What is TorchMetrics
  278. TorchMetrics is a collection of 100+ PyTorch metrics implementations and an easy-to-use API to create custom metrics. It offers:
  279. - A standardized interface to increase reproducibility
  280. - Reduces boilerplate
  281. - Automatic accumulation over batches
  282. - Metrics optimized for distributed-training
  283. - Automatic synchronization between multiple devices
  284. You can use TorchMetrics with any PyTorch model or with [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/) to enjoy additional features such as:
  285. - Module metrics are automatically placed on the correct device.
  286. - Native support for logging metrics in Lightning to reduce even more boilerplate.
  287. # Using TorchMetrics
  288. ### Module metrics
  289. The [module-based metrics](https://lightning.ai/docs/torchmetrics/stable/references/metric.html) contain internal metric states (similar to the parameters of the PyTorch module) that automate accumulation and synchronization across devices!
  290. - Automatic accumulation over multiple batches
  291. - Automatic synchronization between multiple devices
  292. - Metric arithmetic
  293. **This can be run on CPU, single GPU or multi-GPUs!**
  294. For the single GPU/CPU case:
  295. ```python
  296. import torch
  297. # import our library
  298. import torchmetrics
  299. # initialize metric
  300. metric = torchmetrics.classification.Accuracy(task="multiclass", num_classes=5)
  301. # move the metric to device you want computations to take place
  302. device = "cuda" if torch.cuda.is_available() else "cpu"
  303. metric.to(device)
  304. n_batches = 10
  305. for i in range(n_batches):
  306. # simulate a classification problem
  307. preds = torch.randn(10, 5).softmax(dim=-1).to(device)
  308. target = torch.randint(5, (10,)).to(device)
  309. # metric on current batch
  310. acc = metric(preds, target)
  311. print(f"Accuracy on batch {i}: {acc}")
  312. # metric on all batches using custom accumulation
  313. acc = metric.compute()
  314. print(f"Accuracy on all data: {acc}")
  315. ```
  316. Module metric usage remains the same when using multiple GPUs or multiple nodes.
  317. <details>
  318. <summary>Example using DDP</summary>
  319. <!--phmdoctest-mark.skip-->
  320. ```python
  321. import os
  322. import torch
  323. import torch.distributed as dist
  324. import torch.multiprocessing as mp
  325. from torch import nn
  326. from torch.nn.parallel import DistributedDataParallel as DDP
  327. import torchmetrics
  328. def metric_ddp(rank, world_size):
  329. os.environ["MASTER_ADDR"] = "localhost"
  330. os.environ["MASTER_PORT"] = "12355"
  331. # create default process group
  332. dist.init_process_group("gloo", rank=rank, world_size=world_size)
  333. # initialize model
  334. metric = torchmetrics.classification.Accuracy(task="multiclass", num_classes=5)
  335. # define a model and append your metric to it
  336. # this allows metric states to be placed on correct accelerators when
  337. # .to(device) is called on the model
  338. model = nn.Linear(10, 10)
  339. model.metric = metric
  340. model = model.to(rank)
  341. # initialize DDP
  342. model = DDP(model, device_ids=[rank])
  343. n_epochs = 5
  344. # this shows iteration over multiple training epochs
  345. for n in range(n_epochs):
  346. # this will be replaced by a DataLoader with a DistributedSampler
  347. n_batches = 10
  348. for i in range(n_batches):
  349. # simulate a classification problem
  350. preds = torch.randn(10, 5).softmax(dim=-1)
  351. target = torch.randint(5, (10,))
  352. # metric on current batch
  353. acc = metric(preds, target)
  354. if rank == 0: # print only for rank 0
  355. print(f"Accuracy on batch {i}: {acc}")
  356. # metric on all batches and all accelerators using custom accumulation
  357. # accuracy is same across both accelerators
  358. acc = metric.compute()
  359. print(f"Accuracy on all data: {acc}, accelerator rank: {rank}")
  360. # Resetting internal state such that metric ready for new data
  361. metric.reset()
  362. # cleanup
  363. dist.destroy_process_group()
  364. if __name__ == "__main__":
  365. world_size = 2 # number of gpus to parallelize over
  366. mp.spawn(metric_ddp, args=(world_size,), nprocs=world_size, join=True)
  367. ```
  368. </details>
  369. ### Implementing your own Module metric
  370. Implementing your own metric is as easy as subclassing an [`torch.nn.Module`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html). Simply, subclass `torchmetrics.Metric`
  371. and just implement the `update` and `compute` methods:
  372. ```python
  373. import torch
  374. from torchmetrics import Metric
  375. class MyAccuracy(Metric):
  376. def __init__(self):
  377. # remember to call super
  378. super().__init__()
  379. # call `self.add_state`for every internal state that is needed for the metrics computations
  380. # dist_reduce_fx indicates the function that should be used to reduce
  381. # state from multiple processes
  382. self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
  383. self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
  384. def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
  385. # extract predicted class index for computing accuracy
  386. preds = preds.argmax(dim=-1)
  387. assert preds.shape == target.shape
  388. # update metric states
  389. self.correct += torch.sum(preds == target)
  390. self.total += target.numel()
  391. def compute(self) -> torch.Tensor:
  392. # compute final result
  393. return self.correct.float() / self.total
  394. my_metric = MyAccuracy()
  395. preds = torch.randn(10, 5).softmax(dim=-1)
  396. target = torch.randint(5, (10,))
  397. print(my_metric(preds, target))
  398. ```
  399. ### Functional metrics
  400. Similar to [`torch.nn`](https://pytorch.org/docs/stable/nn.html), most metrics have both a [module-based](https://lightning.ai/docs/torchmetrics/stable/references/metric.html) and functional version.
  401. The functional versions are simple python functions that as input take [torch.tensors](https://pytorch.org/docs/stable/tensors.html) and return the corresponding metric as a [torch.tensor](https://pytorch.org/docs/stable/tensors.html).
  402. ```python
  403. import torch
  404. # import our library
  405. import torchmetrics
  406. # simulate a classification problem
  407. preds = torch.randn(10, 5).softmax(dim=-1)
  408. target = torch.randint(5, (10,))
  409. acc = torchmetrics.functional.classification.multiclass_accuracy(
  410. preds, target, num_classes=5
  411. )
  412. ```
  413. ### Covered domains and example metrics
  414. In total TorchMetrics contains [100+ metrics](https://lightning.ai/docs/torchmetrics/stable/all-metrics.html), which
  415. covers the following domains:
  416. - Audio
  417. - Classification
  418. - Detection
  419. - Information Retrieval
  420. - Image
  421. - Multimodal (Image-Text-3D Talking Heads)
  422. - Nominal
  423. - Regression
  424. - Segmentation
  425. - Text
  426. Each domain may require some additional dependencies which can be installed with `pip install torchmetrics[audio]`,
  427. `pip install torchmetrics['image']` etc.
  428. ### Additional features
  429. #### Plotting
  430. Visualization of metrics can be important to help understand what is going on with your machine learning algorithms.
  431. Torchmetrics have built-in plotting support (install dependencies with `pip install torchmetrics[visual]`) for nearly
  432. all modular metrics through the `.plot` method. Simply call the method to get a simple visualization of any metric!
  433. ```python
  434. import torch
  435. from torchmetrics.classification import MulticlassAccuracy, MulticlassConfusionMatrix
  436. num_classes = 3
  437. # this will generate two distributions that comes more similar as iterations increase
  438. w = torch.randn(num_classes)
  439. target = lambda it: torch.multinomial((it * w).softmax(dim=-1), 100, replacement=True)
  440. preds = lambda it: torch.multinomial((it * w).softmax(dim=-1), 100, replacement=True)
  441. acc = MulticlassAccuracy(num_classes=num_classes, average="micro")
  442. acc_per_class = MulticlassAccuracy(num_classes=num_classes, average=None)
  443. confmat = MulticlassConfusionMatrix(num_classes=num_classes)
  444. # plot single value
  445. for i in range(5):
  446. acc_per_class.update(preds(i), target(i))
  447. confmat.update(preds(i), target(i))
  448. fig1, ax1 = acc_per_class.plot()
  449. fig2, ax2 = confmat.plot()
  450. # plot multiple values
  451. values = []
  452. for i in range(10):
  453. values.append(acc(preds(i), target(i)))
  454. fig3, ax3 = acc.plot(values)
  455. ```
  456. <p align="center">
  457. <img src="https://github.com/Lightning-AI/torchmetrics/raw/v1.9.0/docs/source/_static/images/plot_example.png" width="1000">
  458. </p>
  459. For examples of plotting different metrics try running [this example file](_samples/plotting.py).
  460. # Contribute!
  461. The lightning + TorchMetrics team is hard at work adding even more metrics.
  462. But we're looking for incredible contributors like you to submit new metrics
  463. and improve existing ones!
  464. Join our [Discord](https://discord.com/invite/tfXFetEZxv) to get help with becoming a contributor!
  465. # Community
  466. For help or questions, join our huge community on [Discord](https://discord.com/invite/tfXFetEZxv)!
  467. # Citation
  468. We’re excited to continue the strong legacy of open source software and have been inspired
  469. over the years by Caffe, Theano, Keras, PyTorch, torchbearer, ignite, sklearn and fast.ai.
  470. If you want to cite this framework feel free to use GitHub's built-in citation option to generate a bibtex or APA-Style citation based on [this file](https://github.com/Lightning-AI/torchmetrics/blob/master/CITATION.cff) (but only if you loved it 😊).
  471. # License
  472. Please observe the Apache 2.0 license that is listed in this repository.
  473. In addition, the Lightning framework is Patent Pending.