| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285 |
- from collections.abc import Sequence
- from typing import Any, Literal, Optional
- from torchmetrics.text.bleu import BLEUScore
- from torchmetrics.text.cer import CharErrorRate
- from torchmetrics.text.chrf import CHRFScore
- from torchmetrics.text.eed import ExtendedEditDistance
- from torchmetrics.text.mer import MatchErrorRate
- from torchmetrics.text.perplexity import Perplexity
- from torchmetrics.text.sacre_bleu import SacreBLEUScore
- from torchmetrics.text.squad import SQuAD
- from torchmetrics.text.ter import TranslationEditRate
- from torchmetrics.text.wer import WordErrorRate
- from torchmetrics.text.wil import WordInfoLost
- from torchmetrics.text.wip import WordInfoPreserved
- from torchmetrics.utilities.prints import _deprecated_root_import_class
class _BLEUScore(BLEUScore):
    """Wrapper for deprecated import.

    Alias of ``torchmetrics.text.bleu.BLEUScore`` kept for the deprecated import path.
    On construction it calls ``_deprecated_root_import_class`` to flag that path, then
    forwards every argument unchanged to ``BLEUScore``.

    >>> preds = ['the cat is on the mat']
    >>> target = [['there is a cat on the mat', 'a cat is on the mat']]
    >>> bleu = _BLEUScore()
    >>> bleu(preds, target)
    tensor(0.7598)
    """

    def __init__(
        self,
        n_gram: int = 4,
        smooth: bool = False,
        weights: Optional[Sequence[float]] = None,
        **kwargs: Any,
    ) -> None:
        _deprecated_root_import_class("BLEUScore", "text")
        # Named options cannot collide with ``kwargs`` (they are bound to explicit parameters),
        # so splatting both mappings is equivalent to passing them as keywords one by one.
        bleu_options = {"n_gram": n_gram, "smooth": smooth, "weights": weights}
        super().__init__(**bleu_options, **kwargs)
class _CharErrorRate(CharErrorRate):
    """Wrapper for deprecated import.

    Alias of ``torchmetrics.text.cer.CharErrorRate`` kept for the deprecated import path.
    The constructor flags the deprecated path via ``_deprecated_root_import_class`` and
    passes all keyword arguments straight through to ``CharErrorRate``.

    >>> preds = ["this is the prediction", "there is an other sample"]
    >>> target = ["this is the reference", "there is another one"]
    >>> cer = _CharErrorRate()
    >>> cer(preds, target)
    tensor(0.3415)
    """

    def __init__(self, **kwargs: Any) -> None:
        _deprecated_root_import_class("CharErrorRate", "text")
        super().__init__(**kwargs)
class _CHRFScore(CHRFScore):
    """Wrapper for deprecated import.

    Alias of ``torchmetrics.text.chrf.CHRFScore`` kept for the deprecated import path.
    Construction calls ``_deprecated_root_import_class`` first and then hands every
    option, unchanged, to ``CHRFScore``.

    >>> preds = ['the cat is on the mat']
    >>> target = [['there is a cat on the mat', 'a cat is on the mat']]
    >>> chrf = _CHRFScore()
    >>> chrf(preds, target)
    tensor(0.8640)
    """

    def __init__(
        self,
        n_char_order: int = 6,
        n_word_order: int = 2,
        beta: float = 2.0,
        lowercase: bool = False,
        whitespace: bool = False,
        return_sentence_level_score: bool = False,
        **kwargs: Any,
    ) -> None:
        _deprecated_root_import_class("CHRFScore", "text")
        chrf_options = {
            "n_char_order": n_char_order,
            "n_word_order": n_word_order,
            "beta": beta,
            "lowercase": lowercase,
            "whitespace": whitespace,
            "return_sentence_level_score": return_sentence_level_score,
        }
        super().__init__(**chrf_options, **kwargs)
class _ExtendedEditDistance(ExtendedEditDistance):
    """Wrapper for deprecated import.

    Alias of ``torchmetrics.text.eed.ExtendedEditDistance`` kept for the deprecated
    import path. The constructor flags that path with ``_deprecated_root_import_class``
    and then forwards all options untouched to ``ExtendedEditDistance``.

    >>> preds = ["this is the prediction", "here is an other sample"]
    >>> target = ["this is the reference", "here is another one"]
    >>> eed = _ExtendedEditDistance()
    >>> eed(preds=preds, target=target)
    tensor(0.3078)
    """

    def __init__(
        self,
        language: Literal["en", "ja"] = "en",
        return_sentence_level_score: bool = False,
        alpha: float = 2.0,
        rho: float = 0.3,
        deletion: float = 0.2,
        insertion: float = 1.0,
        **kwargs: Any,
    ) -> None:
        _deprecated_root_import_class("ExtendedEditDistance", "text")
        eed_options = {
            "language": language,
            "return_sentence_level_score": return_sentence_level_score,
            "alpha": alpha,
            "rho": rho,
            "deletion": deletion,
            "insertion": insertion,
        }
        super().__init__(**eed_options, **kwargs)
class _MatchErrorRate(MatchErrorRate):
    """Wrapper for deprecated import.

    Alias of ``torchmetrics.text.mer.MatchErrorRate`` kept for the deprecated import path.
    ``_deprecated_root_import_class`` is called on construction, after which all keyword
    arguments go unchanged to ``MatchErrorRate``.

    >>> preds = ["this is the prediction", "there is an other sample"]
    >>> target = ["this is the reference", "there is another one"]
    >>> mer = _MatchErrorRate()
    >>> mer(preds, target)
    tensor(0.4444)
    """

    def __init__(self, **kwargs: Any) -> None:
        _deprecated_root_import_class("MatchErrorRate", "text")
        super().__init__(**kwargs)
class _Perplexity(Perplexity):
    """Wrapper for deprecated import.

    Alias of ``torchmetrics.text.perplexity.Perplexity`` kept for the deprecated import
    path. Construction calls ``_deprecated_root_import_class`` and then forwards
    ``ignore_index`` and any extra keyword arguments unchanged to ``Perplexity``.

    The example seeds the generator used for both random draws. Without the seed the
    inputs change on every run, and the exact value asserted below could not be
    reproduced.

    >>> from torch import manual_seed, rand, randint
    >>> gen = manual_seed(42)
    >>> preds = rand(2, 8, 5, generator=gen)
    >>> target = randint(5, (2, 8), generator=gen)
    >>> target[0, 6:] = -100
    >>> perp = _Perplexity(ignore_index=-100)
    >>> perp(preds, target)
    tensor(5.8540)
    """

    def __init__(
        self,
        ignore_index: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        # ignore_index is handed to Perplexity as-is; the example above uses it together
        # with ``target[0, 6:] = -100`` to mark positions to skip.
        _deprecated_root_import_class("Perplexity", "text")
        super().__init__(ignore_index=ignore_index, **kwargs)
class _SacreBLEUScore(SacreBLEUScore):
    """Wrapper for deprecated import.

    Alias of ``torchmetrics.text.sacre_bleu.SacreBLEUScore`` kept for the deprecated
    import path. The constructor flags that path through
    ``_deprecated_root_import_class`` and then passes every option, unchanged, to
    ``SacreBLEUScore``.

    >>> preds = ['the cat is on the mat']
    >>> target = [['there is a cat on the mat', 'a cat is on the mat']]
    >>> sacre_bleu = _SacreBLEUScore()
    >>> sacre_bleu(preds, target)
    tensor(0.7598)
    """

    def __init__(
        self,
        n_gram: int = 4,
        smooth: bool = False,
        tokenize: Literal["none", "13a", "zh", "intl", "char"] = "13a",
        lowercase: bool = False,
        weights: Optional[Sequence[float]] = None,
        **kwargs: Any,
    ) -> None:
        _deprecated_root_import_class("SacreBLEUScore", "text")
        sacre_options = {
            "n_gram": n_gram,
            "smooth": smooth,
            "tokenize": tokenize,
            "lowercase": lowercase,
            "weights": weights,
        }
        super().__init__(**sacre_options, **kwargs)
class _SQuAD(SQuAD):
    """Wrapper for deprecated import.

    Alias of ``torchmetrics.text.squad.SQuAD`` kept for the deprecated import path.
    On construction it calls ``_deprecated_root_import_class`` and then defers to
    ``SQuAD`` with the same keyword arguments.

    >>> preds = [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}]
    >>> target = [{"answers": {"answer_start": [97], "text": ["1976"]}, "id": "56e10a3be3433e1400422b22"}]
    >>> squad = _SQuAD()
    >>> squad(preds, target)
    {'exact_match': tensor(100.), 'f1': tensor(100.)}
    """

    def __init__(
        self,
        **kwargs: Any,
    ) -> None:
        _deprecated_root_import_class("SQuAD", "text")
        super().__init__(**kwargs)
class _TranslationEditRate(TranslationEditRate):
    """Wrapper for deprecated import.

    Alias of ``torchmetrics.text.ter.TranslationEditRate`` kept for the deprecated
    import path. ``_deprecated_root_import_class`` is invoked first, then every option
    is forwarded untouched to ``TranslationEditRate``.

    >>> preds = ['the cat is on the mat']
    >>> target = [['there is a cat on the mat', 'a cat is on the mat']]
    >>> ter = _TranslationEditRate()
    >>> ter(preds, target)
    tensor(0.1538)
    """

    def __init__(
        self,
        normalize: bool = False,
        no_punctuation: bool = False,
        lowercase: bool = True,
        asian_support: bool = False,
        return_sentence_level_score: bool = False,
        **kwargs: Any,
    ) -> None:
        _deprecated_root_import_class("TranslationEditRate", "text")
        ter_options = {
            "normalize": normalize,
            "no_punctuation": no_punctuation,
            "lowercase": lowercase,
            "asian_support": asian_support,
            "return_sentence_level_score": return_sentence_level_score,
        }
        super().__init__(**ter_options, **kwargs)
class _WordErrorRate(WordErrorRate):
    """Wrapper for deprecated import.

    Alias of ``torchmetrics.text.wer.WordErrorRate`` kept for the deprecated import path.
    Construction flags that path with ``_deprecated_root_import_class`` and then passes
    all keyword arguments on to ``WordErrorRate`` unchanged.

    >>> preds = ["this is the prediction", "there is an other sample"]
    >>> target = ["this is the reference", "there is another one"]
    >>> wer = _WordErrorRate()
    >>> wer(preds, target)
    tensor(0.5000)
    """

    def __init__(
        self,
        **kwargs: Any,
    ) -> None:
        _deprecated_root_import_class("WordErrorRate", "text")
        super().__init__(**kwargs)
class _WordInfoLost(WordInfoLost):
    """Wrapper for deprecated import.

    Alias of ``torchmetrics.text.wil.WordInfoLost`` kept for the deprecated import path.
    The constructor calls ``_deprecated_root_import_class`` and then defers to
    ``WordInfoLost`` with the same keyword arguments.

    >>> preds = ["this is the prediction", "there is an other sample"]
    >>> target = ["this is the reference", "there is another one"]
    >>> wil = _WordInfoLost()
    >>> wil(preds, target)
    tensor(0.6528)
    """

    def __init__(
        self,
        **kwargs: Any,
    ) -> None:
        _deprecated_root_import_class("WordInfoLost", "text")
        super().__init__(**kwargs)
class _WordInfoPreserved(WordInfoPreserved):
    """Wrapper for deprecated import.

    Alias of ``torchmetrics.text.wip.WordInfoPreserved`` kept for the deprecated import
    path. Construction calls ``_deprecated_root_import_class`` and then forwards all
    keyword arguments unchanged to ``WordInfoPreserved``.

    The example constructs the wrapper itself, ``_WordInfoPreserved``, like every sibling
    wrapper's example. It previously built the plain ``WordInfoPreserved`` class, which
    meant the deprecated-import path was never exercised.

    >>> preds = ["this is the prediction", "there is an other sample"]
    >>> target = ["this is the reference", "there is another one"]
    >>> wip = _WordInfoPreserved()
    >>> wip(preds, target)
    tensor(0.3472)
    """

    def __init__(self, **kwargs: Any) -> None:
        _deprecated_root_import_class("WordInfoPreserved", "text")
        super().__init__(**kwargs)
|