| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582 |
- Metadata-Version: 2.4
- Name: torchmetrics
- Version: 1.9.0
- Summary: PyTorch native Metrics
- Home-page: https://github.com/Lightning-AI/torchmetrics
- Download-URL: https://github.com/Lightning-AI/torchmetrics/archive/master.zip
- Author: Lightning-AI et al.
- Author-email: name@pytorchlightning.ai
- License: Apache-2.0
- Project-URL: Bug Tracker, https://github.com/Lightning-AI/torchmetrics/issues
- Project-URL: Documentation, https://torchmetrics.rtfd.io/en/latest/
- Project-URL: Source Code, https://github.com/Lightning-AI/torchmetrics
- Keywords: deep learning,machine learning,pytorch,metrics,AI
- Classifier: Environment :: Console
- Classifier: Natural Language :: English
- Classifier: Development Status :: 5 - Production/Stable
- Classifier: Intended Audience :: Developers
- Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
- Classifier: Topic :: Scientific/Engineering :: Image Recognition
- Classifier: Topic :: Scientific/Engineering :: Information Analysis
- Classifier: License :: OSI Approved :: Apache Software License
- Classifier: Operating System :: OS Independent
- Classifier: Programming Language :: Python :: 3
- Classifier: Programming Language :: Python :: 3.10
- Classifier: Programming Language :: Python :: 3.11
- Classifier: Programming Language :: Python :: 3.12
- Requires-Python: >=3.10
- Description-Content-Type: text/markdown
- License-File: LICENSE
- Requires-Dist: numpy>1.20.0
- Requires-Dist: packaging>17.1
- Requires-Dist: torch>=2.0.0
- Requires-Dist: lightning-utilities>=0.15.3
- Provides-Extra: audio
- Requires-Dist: requests>=2.22.0; extra == "audio"
- Requires-Dist: onnxruntime>=1.12.0; extra == "audio"
- Requires-Dist: gammatone>=1.0.0; extra == "audio"
- Requires-Dist: pesq>=0.0.4; extra == "audio"
- Requires-Dist: pystoi>=0.4.0; extra == "audio"
- Requires-Dist: librosa>=0.10.0; extra == "audio"
- Requires-Dist: torchaudio>=2.0.1; extra == "audio"
- Provides-Extra: clustering
- Requires-Dist: torch_linear_assignment>=0.0.2; extra == "clustering"
- Provides-Extra: debug
- Provides-Extra: detection
- Requires-Dist: pycocotools>2.0.0; extra == "detection"
- Requires-Dist: torchvision>=0.15.1; extra == "detection"
- Provides-Extra: image
- Requires-Dist: torch-fidelity<=0.4.0; extra == "image"
- Requires-Dist: torchvision>=0.15.1; extra == "image"
- Requires-Dist: scipy>1.0.0; extra == "image"
- Provides-Extra: integrate
- Provides-Extra: multimodal
- Requires-Dist: timm>=0.9.0; extra == "multimodal"
- Requires-Dist: transformers>=4.43.0; extra == "multimodal"
- Requires-Dist: einops>=0.7.0; extra == "multimodal"
- Requires-Dist: piq<=0.8.0; extra == "multimodal"
- Provides-Extra: text
- Requires-Dist: tqdm<4.68.0; extra == "text"
- Requires-Dist: nltk>3.8.1; extra == "text"
- Requires-Dist: ipadic>=1.0.0; extra == "text"
- Requires-Dist: mecab-python3>=1.0.6; extra == "text"
- Requires-Dist: transformers>=4.43.0; extra == "text"
- Requires-Dist: regex>=2021.9.24; extra == "text"
- Requires-Dist: sentencepiece>=0.2.0; extra == "text"
- Provides-Extra: typing
- Requires-Dist: types-six; extra == "typing"
- Requires-Dist: mypy==1.17.1; extra == "typing"
- Requires-Dist: types-requests; extra == "typing"
- Requires-Dist: types-tabulate; extra == "typing"
- Requires-Dist: types-setuptools; extra == "typing"
- Requires-Dist: types-emoji; extra == "typing"
- Requires-Dist: torch==2.8.0; extra == "typing"
- Requires-Dist: types-PyYAML; extra == "typing"
- Requires-Dist: types-protobuf; extra == "typing"
- Provides-Extra: video
- Requires-Dist: vmaf-torch>=1.1.0; extra == "video"
- Requires-Dist: einops>=0.7.0; extra == "video"
- Provides-Extra: visual
- Requires-Dist: matplotlib>=3.6.0; extra == "visual"
- Requires-Dist: SciencePlots>=2.0.0; extra == "visual"
- Provides-Extra: all
- Requires-Dist: requests>=2.22.0; extra == "all"
- Requires-Dist: onnxruntime>=1.12.0; extra == "all"
- Requires-Dist: gammatone>=1.0.0; extra == "all"
- Requires-Dist: pesq>=0.0.4; extra == "all"
- Requires-Dist: pystoi>=0.4.0; extra == "all"
- Requires-Dist: librosa>=0.10.0; extra == "all"
- Requires-Dist: torchaudio>=2.0.1; extra == "all"
- Requires-Dist: torch_linear_assignment>=0.0.2; extra == "all"
- Requires-Dist: pycocotools>2.0.0; extra == "all"
- Requires-Dist: torchvision>=0.15.1; extra == "all"
- Requires-Dist: torch-fidelity<=0.4.0; extra == "all"
- Requires-Dist: torchvision>=0.15.1; extra == "all"
- Requires-Dist: scipy>1.0.0; extra == "all"
- Requires-Dist: timm>=0.9.0; extra == "all"
- Requires-Dist: transformers>=4.43.0; extra == "all"
- Requires-Dist: einops>=0.7.0; extra == "all"
- Requires-Dist: piq<=0.8.0; extra == "all"
- Requires-Dist: tqdm<4.68.0; extra == "all"
- Requires-Dist: nltk>3.8.1; extra == "all"
- Requires-Dist: ipadic>=1.0.0; extra == "all"
- Requires-Dist: mecab-python3>=1.0.6; extra == "all"
- Requires-Dist: transformers>=4.43.0; extra == "all"
- Requires-Dist: regex>=2021.9.24; extra == "all"
- Requires-Dist: sentencepiece>=0.2.0; extra == "all"
- Requires-Dist: types-six; extra == "all"
- Requires-Dist: mypy==1.17.1; extra == "all"
- Requires-Dist: types-requests; extra == "all"
- Requires-Dist: types-tabulate; extra == "all"
- Requires-Dist: types-setuptools; extra == "all"
- Requires-Dist: types-emoji; extra == "all"
- Requires-Dist: torch==2.8.0; extra == "all"
- Requires-Dist: types-PyYAML; extra == "all"
- Requires-Dist: types-protobuf; extra == "all"
- Requires-Dist: vmaf-torch>=1.1.0; extra == "all"
- Requires-Dist: einops>=0.7.0; extra == "all"
- Requires-Dist: matplotlib>=3.6.0; extra == "all"
- Requires-Dist: SciencePlots>=2.0.0; extra == "all"
- Provides-Extra: dev
- Requires-Dist: requests>=2.22.0; extra == "dev"
- Requires-Dist: onnxruntime>=1.12.0; extra == "dev"
- Requires-Dist: gammatone>=1.0.0; extra == "dev"
- Requires-Dist: pesq>=0.0.4; extra == "dev"
- Requires-Dist: pystoi>=0.4.0; extra == "dev"
- Requires-Dist: librosa>=0.10.0; extra == "dev"
- Requires-Dist: torchaudio>=2.0.1; extra == "dev"
- Requires-Dist: torch_linear_assignment>=0.0.2; extra == "dev"
- Requires-Dist: pycocotools>2.0.0; extra == "dev"
- Requires-Dist: torchvision>=0.15.1; extra == "dev"
- Requires-Dist: torch-fidelity<=0.4.0; extra == "dev"
- Requires-Dist: torchvision>=0.15.1; extra == "dev"
- Requires-Dist: scipy>1.0.0; extra == "dev"
- Requires-Dist: timm>=0.9.0; extra == "dev"
- Requires-Dist: transformers>=4.43.0; extra == "dev"
- Requires-Dist: einops>=0.7.0; extra == "dev"
- Requires-Dist: piq<=0.8.0; extra == "dev"
- Requires-Dist: tqdm<4.68.0; extra == "dev"
- Requires-Dist: nltk>3.8.1; extra == "dev"
- Requires-Dist: ipadic>=1.0.0; extra == "dev"
- Requires-Dist: mecab-python3>=1.0.6; extra == "dev"
- Requires-Dist: transformers>=4.43.0; extra == "dev"
- Requires-Dist: regex>=2021.9.24; extra == "dev"
- Requires-Dist: sentencepiece>=0.2.0; extra == "dev"
- Requires-Dist: types-six; extra == "dev"
- Requires-Dist: mypy==1.17.1; extra == "dev"
- Requires-Dist: types-requests; extra == "dev"
- Requires-Dist: types-tabulate; extra == "dev"
- Requires-Dist: types-setuptools; extra == "dev"
- Requires-Dist: types-emoji; extra == "dev"
- Requires-Dist: torch==2.8.0; extra == "dev"
- Requires-Dist: types-PyYAML; extra == "dev"
- Requires-Dist: types-protobuf; extra == "dev"
- Requires-Dist: vmaf-torch>=1.1.0; extra == "dev"
- Requires-Dist: einops>=0.7.0; extra == "dev"
- Requires-Dist: matplotlib>=3.6.0; extra == "dev"
- Requires-Dist: SciencePlots>=2.0.0; extra == "dev"
- Requires-Dist: pytorch-msssim==1.0.0; extra == "dev"
- Requires-Dist: sewar>=0.4.4; extra == "dev"
- Requires-Dist: setuptools<82.0.0; extra == "dev"
- Requires-Dist: scikit-image>=0.19.0; extra == "dev"
- Requires-Dist: dists-pytorch==0.1; extra == "dev"
- Requires-Dist: rouge-score>0.1.0; extra == "dev"
- Requires-Dist: netcal>1.0.0; extra == "dev"
- Requires-Dist: pandas>1.4.0; extra == "dev"
- Requires-Dist: numpy<2.4.0; extra == "dev"
- Requires-Dist: torch_complex<0.5.0; extra == "dev"
- Requires-Dist: permetrics==2.0.0; extra == "dev"
- Requires-Dist: jiwer>=2.3.0; extra == "dev"
- Requires-Dist: aeon>=1.0.0; python_version > "3.10" and extra == "dev"
- Requires-Dist: mir-eval>=0.6; extra == "dev"
- Requires-Dist: huggingface-hub<0.35; extra == "dev"
- Requires-Dist: faster-coco-eval>=1.6.3; extra == "dev"
- Requires-Dist: mecab-ko-dic>=1.0.0; python_version < "3.12" and extra == "dev"
- Requires-Dist: monai==1.4.0; extra == "dev"
- Requires-Dist: mecab-ko<1.1.0,>=1.0.0; python_version < "3.12" and extra == "dev"
- Requires-Dist: bert_score==0.3.13; extra == "dev"
- Requires-Dist: sacrebleu>=2.3.0; extra == "dev"
- Requires-Dist: scipy>1.0.0; extra == "dev"
- Requires-Dist: lpips<=0.1.4; extra == "dev"
- Requires-Dist: dython==0.7.9; extra == "dev"
- Requires-Dist: properscoring==0.1; extra == "dev"
- Requires-Dist: fast-bss-eval>=0.1.0; extra == "dev"
- Requires-Dist: PyTDC==0.4.1; (platform_system == "Windows" and python_version < "3.12") and extra == "dev"
- Requires-Dist: fairlearn; extra == "dev"
- Requires-Dist: kornia>=0.6.7; extra == "dev"
- Requires-Dist: statsmodels>0.13.5; extra == "dev"
- Dynamic: author
- Dynamic: author-email
- Dynamic: classifier
- Dynamic: description
- Dynamic: description-content-type
- Dynamic: download-url
- Dynamic: home-page
- Dynamic: keywords
- Dynamic: license
- Dynamic: license-file
- Dynamic: project-url
- Dynamic: provides-extra
- Dynamic: requires-dist
- Dynamic: requires-python
- Dynamic: summary
- <div align="center">
- <img src="https://github.com/Lightning-AI/torchmetrics/raw/v1.9.0/docs/source/_static/images/logo.png" width="400px">
- **Machine learning metrics for distributed, scalable PyTorch applications.**
- ______________________________________________________________________
- <p align="center">
- <a href="#what-is-torchmetrics">What is Torchmetrics</a> •
- <a href="#implementing-your-own-module-metric">Implementing a metric</a> •
- <a href="#covered-domains-and-example-metrics">Built-in metrics</a> •
- <a href="https://lightning.ai/docs/torchmetrics/stable/">Docs</a> •
- <a href="#community">Community</a> •
- <a href="#license">License</a>
- </p>
- ______________________________________________________________________
- [PyPI - Python Version](https://pypi.org/project/torchmetrics/)
- [PyPI Status](https://badge.fury.io/py/torchmetrics)
- [PyPI - Downloads](https://pepy.tech/project/torchmetrics)
- [Conda](https://anaconda.org/conda-forge/torchmetrics)
- [License](https://github.com/Lightning-AI/torchmetrics/blob/master/LICENSE)
- [CI testing](https://github.com/Lightning-AI/torchmetrics/actions/workflows/ci-tests.yml)
- [Build Status](https://dev.azure.com/Lightning-AI/Metrics/_build/latest?definitionId=2&branchName=refs%2Ftags%2Fv1.9.0)
- [codecov](https://codecov.io/gh/Lightning-AI/torchmetrics)
- [pre-commit.ci status](https://results.pre-commit.ci/latest/github/Lightning-AI/torchmetrics/master)
- [Documentation Status](https://torchmetrics.readthedocs.io/en/latest/?badge=latest)
- [Discord](https://discord.gg/VptPCZkGNa)
- [DOI](https://doi.org/10.5281/zenodo.5844769)
- [JOSS status](https://joss.theoj.org/papers/561d9bb59b400158bc8204e2639dca43)
- ______________________________________________________________________
- </div>
- # Looking for GPUs?
- Over 340,000 developers use [Lightning Cloud](https://lightning.ai/?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme) - purpose-built for PyTorch and PyTorch Lightning.
- - [GPUs](https://lightning.ai/pricing?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme) from $0.19.
- - [Clusters](https://lightning.ai/clusters?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme): frontier-grade training/inference clusters.
- - [AI Studio (vibe train)](https://lightning.ai/studios?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme): workspaces where AI helps you debug, tune and vibe train.
- - [AI Studio (vibe deploy)](https://lightning.ai/studios?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme): workspaces where AI helps you optimize and deploy models.
- - [Notebooks](https://lightning.ai/notebooks?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme): persistent GPU workspaces where AI helps you code and analyze.
- - [Inference](https://lightning.ai/deploy?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme): Deploy models as inference APIs.
- # Installation
- Simple installation from PyPI
- ```bash
- pip install torchmetrics
- ```
- <details>
- <summary>Other installations</summary>
- Install using conda
- ```bash
- conda install -c conda-forge torchmetrics
- ```
- Install using uv
- ```bash
- uv add torchmetrics
- ```
- Pip from source
- ```bash
- # with git
- pip install git+https://github.com/Lightning-AI/torchmetrics.git@release/stable
- ```
- Pip from archive
- ```bash
- pip install https://github.com/Lightning-AI/torchmetrics/archive/refs/heads/release/stable.zip
- ```
- Extra dependencies for specialized metrics:
- ```bash
- pip install torchmetrics[audio]
- pip install torchmetrics[image]
- pip install torchmetrics[text]
- pip install torchmetrics[all] # install all of the above
- ```
- Install latest developer version
- ```bash
- pip install https://github.com/Lightning-AI/torchmetrics/archive/master.zip
- ```
- </details>
- ______________________________________________________________________
- # What is TorchMetrics
- TorchMetrics is a collection of 100+ PyTorch metrics implementations and an easy-to-use API to create custom metrics. It offers:
- - A standardized interface to increase reproducibility
- - Reduced boilerplate
- - Automatic accumulation over batches
- - Metrics optimized for distributed training
- - Automatic synchronization between multiple devices
- You can use TorchMetrics with any PyTorch model or with [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/) to enjoy additional features such as:
- - Module metrics are automatically placed on the correct device.
- - Native support for logging metrics in Lightning to reduce even more boilerplate.
- # Using TorchMetrics
- ### Module metrics
- The [module-based metrics](https://lightning.ai/docs/torchmetrics/stable/references/metric.html) contain internal metric states (similar to the parameters of the PyTorch module) that automate accumulation and synchronization across devices!
- - Automatic accumulation over multiple batches
- - Automatic synchronization between multiple devices
- - Metric arithmetic
- **This can be run on CPU, single GPU or multi-GPUs!**
- For the single GPU/CPU case:
- ```python
- import torch
- # import our library
- import torchmetrics
- # initialize metric
- metric = torchmetrics.classification.Accuracy(task="multiclass", num_classes=5)
- # move the metric to device you want computations to take place
- device = "cuda" if torch.cuda.is_available() else "cpu"
- metric.to(device)
- n_batches = 10
- for i in range(n_batches):
- # simulate a classification problem
- preds = torch.randn(10, 5).softmax(dim=-1).to(device)
- target = torch.randint(5, (10,)).to(device)
- # metric on current batch
- acc = metric(preds, target)
- print(f"Accuracy on batch {i}: {acc}")
- # metric on all batches using custom accumulation
- acc = metric.compute()
- print(f"Accuracy on all data: {acc}")
- ```
- Module metric usage remains the same when using multiple GPUs or multiple nodes.
- <details>
- <summary>Example using DDP</summary>
- <!--phmdoctest-mark.skip-->
- ```python
- import os
- import torch
- import torch.distributed as dist
- import torch.multiprocessing as mp
- from torch import nn
- from torch.nn.parallel import DistributedDataParallel as DDP
- import torchmetrics
- def metric_ddp(rank, world_size):
- os.environ["MASTER_ADDR"] = "localhost"
- os.environ["MASTER_PORT"] = "12355"
- # create default process group
- dist.init_process_group("gloo", rank=rank, world_size=world_size)
- # initialize model
- metric = torchmetrics.classification.Accuracy(task="multiclass", num_classes=5)
- # define a model and append your metric to it
- # this allows metric states to be placed on correct accelerators when
- # .to(device) is called on the model
- model = nn.Linear(10, 10)
- model.metric = metric
- model = model.to(rank)
- # initialize DDP
- model = DDP(model, device_ids=[rank])
- n_epochs = 5
- # this shows iteration over multiple training epochs
- for n in range(n_epochs):
- # this will be replaced by a DataLoader with a DistributedSampler
- n_batches = 10
- for i in range(n_batches):
- # simulate a classification problem
- preds = torch.randn(10, 5).softmax(dim=-1)
- target = torch.randint(5, (10,))
- # metric on current batch
- acc = metric(preds, target)
- if rank == 0: # print only for rank 0
- print(f"Accuracy on batch {i}: {acc}")
- # metric on all batches and all accelerators using custom accumulation
- # accuracy is same across both accelerators
- acc = metric.compute()
- print(f"Accuracy on all data: {acc}, accelerator rank: {rank}")
- # Resetting internal state such that metric ready for new data
- metric.reset()
- # cleanup
- dist.destroy_process_group()
- if __name__ == "__main__":
- world_size = 2 # number of gpus to parallelize over
- mp.spawn(metric_ddp, args=(world_size,), nprocs=world_size, join=True)
- ```
- </details>
- ### Implementing your own Module metric
- Implementing your own metric is as easy as subclassing a [`torch.nn.Module`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html). Simply subclass `torchmetrics.Metric`
- and implement the `update` and `compute` methods:
- ```python
- import torch
- from torchmetrics import Metric
- class MyAccuracy(Metric):
- def __init__(self):
- # remember to call super
- super().__init__()
- # call `self.add_state`for every internal state that is needed for the metrics computations
- # dist_reduce_fx indicates the function that should be used to reduce
- # state from multiple processes
- self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
- self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
- def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
- # extract predicted class index for computing accuracy
- preds = preds.argmax(dim=-1)
- assert preds.shape == target.shape
- # update metric states
- self.correct += torch.sum(preds == target)
- self.total += target.numel()
- def compute(self) -> torch.Tensor:
- # compute final result
- return self.correct.float() / self.total
- my_metric = MyAccuracy()
- preds = torch.randn(10, 5).softmax(dim=-1)
- target = torch.randint(5, (10,))
- print(my_metric(preds, target))
- ```
- ### Functional metrics
- Similar to [`torch.nn`](https://pytorch.org/docs/stable/nn.html), most metrics have both a [module-based](https://lightning.ai/docs/torchmetrics/stable/references/metric.html) and a functional version.
- The functional versions are simple Python functions that take [torch.Tensor](https://pytorch.org/docs/stable/tensors.html) objects as input and return the corresponding metric as a [torch.Tensor](https://pytorch.org/docs/stable/tensors.html).
- ```python
- import torch
- # import our library
- import torchmetrics
- # simulate a classification problem
- preds = torch.randn(10, 5).softmax(dim=-1)
- target = torch.randint(5, (10,))
- acc = torchmetrics.functional.classification.multiclass_accuracy(
- preds, target, num_classes=5
- )
- ```
- ### Covered domains and example metrics
- In total, TorchMetrics contains [100+ metrics](https://lightning.ai/docs/torchmetrics/stable/all-metrics.html), which
- cover the following domains:
- - Audio
- - Classification
- - Detection
- - Information Retrieval
- - Image
- - Multimodal (Image-Text-3D Talking Heads)
- - Nominal
- - Regression
- - Segmentation
- - Text
- Each domain may require some additional dependencies, which can be installed with `pip install torchmetrics[audio]`,
- `pip install torchmetrics[image]`, etc.
- ### Additional features
- #### Plotting
- Visualization of metrics can be important to help understand what is going on with your machine learning algorithms.
- TorchMetrics has built-in plotting support (install dependencies with `pip install torchmetrics[visual]`) for nearly
- all modular metrics through the `.plot` method. Simply call the method to get a simple visualization of any metric!
- ```python
- import torch
- from torchmetrics.classification import MulticlassAccuracy, MulticlassConfusionMatrix
- num_classes = 3
- # this will generate two distributions that comes more similar as iterations increase
- w = torch.randn(num_classes)
- target = lambda it: torch.multinomial((it * w).softmax(dim=-1), 100, replacement=True)
- preds = lambda it: torch.multinomial((it * w).softmax(dim=-1), 100, replacement=True)
- acc = MulticlassAccuracy(num_classes=num_classes, average="micro")
- acc_per_class = MulticlassAccuracy(num_classes=num_classes, average=None)
- confmat = MulticlassConfusionMatrix(num_classes=num_classes)
- # plot single value
- for i in range(5):
- acc_per_class.update(preds(i), target(i))
- confmat.update(preds(i), target(i))
- fig1, ax1 = acc_per_class.plot()
- fig2, ax2 = confmat.plot()
- # plot multiple values
- values = []
- for i in range(10):
- values.append(acc(preds(i), target(i)))
- fig3, ax3 = acc.plot(values)
- ```
- <p align="center">
- <img src="https://github.com/Lightning-AI/torchmetrics/raw/v1.9.0/docs/source/_static/images/plot_example.png" width="1000">
- </p>
- For examples of plotting different metrics, try running [this example file](https://github.com/Lightning-AI/torchmetrics/blob/v1.9.0/_samples/plotting.py).
- # Contribute!
- The Lightning + TorchMetrics team is hard at work adding even more metrics.
- But we're looking for incredible contributors like you to submit new metrics
- and improve existing ones!
- Join our [Discord](https://discord.com/invite/tfXFetEZxv) to get help with becoming a contributor!
- # Community
- For help or questions, join our huge community on [Discord](https://discord.com/invite/tfXFetEZxv)!
- # Citation
- We’re excited to continue the strong legacy of open source software and have been inspired
- over the years by Caffe, Theano, Keras, PyTorch, torchbearer, ignite, sklearn and fast.ai.
- If you want to cite this framework feel free to use GitHub's built-in citation option to generate a bibtex or APA-Style citation based on [this file](https://github.com/Lightning-AI/torchmetrics/blob/master/CITATION.cff) (but only if you loved it 😊).
- # License
- Please observe the Apache 2.0 license that is listed in this repository.
- In addition, the Lightning framework is Patent Pending.
|