| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409 |
- import os
- from collections.abc import Sequence
- from typing import Any, Callable, List, Literal, Optional, Union
- import torch
- from torch import Tensor
- from torch.nn import Module
- from torchmetrics.functional.text.bert import bert_score
- from torchmetrics.functional.text.bleu import bleu_score
- from torchmetrics.functional.text.cer import char_error_rate
- from torchmetrics.functional.text.chrf import chrf_score
- from torchmetrics.functional.text.eed import extended_edit_distance
- from torchmetrics.functional.text.infolm import (
- _ALLOWED_INFORMATION_MEASURE_LITERAL as _INFOLM_ALLOWED_INFORMATION_MEASURE_LITERAL,
- )
- from torchmetrics.functional.text.infolm import infolm
- from torchmetrics.functional.text.mer import match_error_rate
- from torchmetrics.functional.text.perplexity import perplexity
- from torchmetrics.functional.text.rouge import rouge_score
- from torchmetrics.functional.text.sacre_bleu import sacre_bleu_score
- from torchmetrics.functional.text.squad import squad
- from torchmetrics.functional.text.ter import translation_edit_rate
- from torchmetrics.functional.text.wer import word_error_rate
- from torchmetrics.functional.text.wil import word_information_lost
- from torchmetrics.functional.text.wip import word_information_preserved
- from torchmetrics.utilities.imports import _TRANSFORMERS_GREATER_EQUAL_4_4
- from torchmetrics.utilities.prints import _deprecated_root_import_func
# Doctest collection settings for this module.
# The key must be a real tuple of name patterns: ``("_rouge_score")`` without the trailing
# comma is only a parenthesized string, so the comma is required to build a one-element tuple.
__doctest_requires__ = {("_rouge_score",): ["nltk"]}
if not _TRANSFORMERS_GREATER_EQUAL_4_4:
    # The examples of these two wrappers are skipped when the transformers version flag is false.
    __doctest_skip__ = ["_bert_score", "_infolm"]

# Type aliases used by the ``target`` annotation of the ``_squad`` wrapper.
SQUAD_SINGLE_TARGET_TYPE = dict[str, Union[str, dict[str, Union[list[str], list[int]]]]]
SQUAD_TARGETS_TYPE = Union[SQUAD_SINGLE_TARGET_TYPE, list[SQUAD_SINGLE_TARGET_TYPE]]
def _bert_score(
    preds: Union[list[str], dict[str, Tensor]],
    target: Union[list[str], dict[str, Tensor]],
    model_name_or_path: Optional[str] = None,
    num_layers: Optional[int] = None,
    all_layers: bool = False,
    model: Optional[Module] = None,
    user_tokenizer: Any = None,
    user_forward_fn: Optional[Callable[[Module, dict[str, Tensor]], Tensor]] = None,
    verbose: bool = False,
    idf: bool = False,
    device: Optional[Union[str, torch.device]] = None,
    max_length: int = 512,
    batch_size: int = 64,
    num_threads: int = 4,
    return_hash: bool = False,
    lang: str = "en",
    rescale_with_baseline: bool = False,
    baseline_path: Optional[str] = None,
    baseline_url: Optional[str] = None,
) -> dict[str, Union[Tensor, list[float], str]]:
    """Deprecated-import wrapper around ``bert_score``.

    Calls ``_deprecated_root_import_func("bert_score", "text")`` and then passes every argument
    through, by keyword and unchanged, to ``bert_score``; that function's result is returned as-is.
    See ``bert_score`` for the meaning of the individual parameters.

    >>> preds = ["hello there", "general kenobi"]
    >>> target = ["hello there", "master kenobi"]
    >>> score = _bert_score(preds, target)
    >>> from pprint import pprint
    >>> pprint(score)
    {'f1': tensor([1.0000, 0.9961]),
     'precision': tensor([1.0000, 0.9961]),
     'recall': tensor([1.0000, 0.9961])}

    """
    _deprecated_root_import_func("bert_score", "text")
    scores = bert_score(
        preds=preds,
        target=target,
        model_name_or_path=model_name_or_path,
        num_layers=num_layers,
        all_layers=all_layers,
        model=model,
        user_tokenizer=user_tokenizer,
        user_forward_fn=user_forward_fn,
        verbose=verbose,
        idf=idf,
        device=device,
        max_length=max_length,
        batch_size=batch_size,
        num_threads=num_threads,
        return_hash=return_hash,
        lang=lang,
        rescale_with_baseline=rescale_with_baseline,
        baseline_path=baseline_path,
        baseline_url=baseline_url,
    )
    return scores
def _bleu_score(
    preds: Union[str, Sequence[str]],
    target: Sequence[Union[str, Sequence[str]]],
    n_gram: int = 4,
    smooth: bool = False,
    weights: Optional[Sequence[float]] = None,
) -> Tensor:
    """Deprecated-import wrapper around ``bleu_score``.

    Calls ``_deprecated_root_import_func("bleu_score", "text")`` and then passes every argument
    through, by keyword and unchanged, to ``bleu_score``; that function's result is returned as-is.

    >>> preds = ['the cat is on the mat']
    >>> target = [['there is a cat on the mat', 'a cat is on the mat']]
    >>> _bleu_score(preds, target)
    tensor(0.7598)

    """
    _deprecated_root_import_func("bleu_score", "text")
    score = bleu_score(
        preds=preds,
        target=target,
        n_gram=n_gram,
        smooth=smooth,
        weights=weights,
    )
    return score
def _char_error_rate(
    preds: Union[str, list[str]],
    target: Union[str, list[str]],
) -> Tensor:
    """Deprecated-import wrapper around ``char_error_rate``.

    Calls ``_deprecated_root_import_func("char_error_rate", "text")`` and then forwards ``preds``
    and ``target`` by keyword to ``char_error_rate``, returning its result as-is.

    >>> preds = ["this is the prediction", "there is an other sample"]
    >>> target = ["this is the reference", "there is another one"]
    >>> _char_error_rate(preds=preds, target=target)
    tensor(0.3415)

    """
    _deprecated_root_import_func("char_error_rate", "text")
    rate = char_error_rate(preds=preds, target=target)
    return rate
def _chrf_score(
    preds: Union[str, Sequence[str]],
    target: Sequence[Union[str, Sequence[str]]],
    n_char_order: int = 6,
    n_word_order: int = 2,
    beta: float = 2.0,
    lowercase: bool = False,
    whitespace: bool = False,
    return_sentence_level_score: bool = False,
) -> Union[Tensor, tuple[Tensor, Tensor]]:
    """Deprecated-import wrapper around ``chrf_score``.

    Calls ``_deprecated_root_import_func("chrf_score", "text")`` and then passes every argument
    through, by keyword and unchanged, to ``chrf_score``; that function's result is returned as-is.

    >>> preds = ['the cat is on the mat']
    >>> target = [['there is a cat on the mat', 'a cat is on the mat']]
    >>> _chrf_score(preds, target)
    tensor(0.8640)

    """
    _deprecated_root_import_func("chrf_score", "text")
    score = chrf_score(
        preds=preds,
        target=target,
        n_char_order=n_char_order,
        n_word_order=n_word_order,
        beta=beta,
        lowercase=lowercase,
        whitespace=whitespace,
        return_sentence_level_score=return_sentence_level_score,
    )
    return score
def _extended_edit_distance(
    preds: Union[str, Sequence[str]],
    target: Sequence[Union[str, Sequence[str]]],
    language: Literal["en", "ja"] = "en",
    return_sentence_level_score: bool = False,
    alpha: float = 2.0,
    rho: float = 0.3,
    deletion: float = 0.2,
    insertion: float = 1.0,
) -> Union[Tensor, tuple[Tensor, Tensor]]:
    """Deprecated-import wrapper around ``extended_edit_distance``.

    Calls ``_deprecated_root_import_func("extended_edit_distance", "text")`` and then passes every
    argument through, by keyword and unchanged, to ``extended_edit_distance``; that function's
    result is returned as-is.

    >>> preds = ["this is the prediction", "here is an other sample"]
    >>> target = ["this is the reference", "here is another one"]
    >>> _extended_edit_distance(preds=preds, target=target)
    tensor(0.3078)

    """
    _deprecated_root_import_func("extended_edit_distance", "text")
    distance = extended_edit_distance(
        preds=preds,
        target=target,
        language=language,
        return_sentence_level_score=return_sentence_level_score,
        alpha=alpha,
        rho=rho,
        deletion=deletion,
        insertion=insertion,
    )
    return distance
def _infolm(
    preds: Union[str, Sequence[str]],
    target: Union[str, Sequence[str]],
    model_name_or_path: Union[str, os.PathLike] = "bert-base-uncased",
    temperature: float = 0.25,
    information_measure: _INFOLM_ALLOWED_INFORMATION_MEASURE_LITERAL = "kl_divergence",
    idf: bool = True,
    alpha: Optional[float] = None,
    beta: Optional[float] = None,
    device: Optional[Union[str, torch.device]] = None,
    max_length: Optional[int] = None,
    batch_size: int = 64,
    num_threads: int = 0,
    verbose: bool = True,
    return_sentence_level_score: bool = False,
) -> Union[Tensor, tuple[Tensor, Tensor]]:
    """Deprecated-import wrapper around ``infolm``.

    Calls ``_deprecated_root_import_func("infolm", "text")`` and then passes every argument
    through, by keyword and unchanged, to ``infolm``; that function's result is returned as-is.

    >>> preds = ['he read the book because he was interested in world history']
    >>> target = ['he was interested in world history because he read the book']
    >>> _infolm(preds, target, model_name_or_path='google/bert_uncased_L-2_H-128_A-2', idf=False)
    tensor(-0.1784)

    """
    _deprecated_root_import_func("infolm", "text")
    score = infolm(
        preds=preds,
        target=target,
        model_name_or_path=model_name_or_path,
        temperature=temperature,
        information_measure=information_measure,
        idf=idf,
        alpha=alpha,
        beta=beta,
        device=device,
        max_length=max_length,
        batch_size=batch_size,
        num_threads=num_threads,
        verbose=verbose,
        return_sentence_level_score=return_sentence_level_score,
    )
    return score
def _match_error_rate(
    preds: Union[str, list[str]],
    target: Union[str, list[str]],
) -> Tensor:
    """Deprecated-import wrapper around ``match_error_rate``.

    Calls ``_deprecated_root_import_func("match_error_rate", "text")`` and then forwards ``preds``
    and ``target`` by keyword to ``match_error_rate``, returning its result as-is.

    >>> preds = ["this is the prediction", "there is an other sample"]
    >>> target = ["this is the reference", "there is another one"]
    >>> _match_error_rate(preds=preds, target=target)
    tensor(0.4444)

    """
    _deprecated_root_import_func("match_error_rate", "text")
    rate = match_error_rate(preds=preds, target=target)
    return rate
def _perplexity(
    preds: Tensor,
    target: Tensor,
    ignore_index: Optional[int] = None,
) -> Tensor:
    """Deprecated-import wrapper around ``perplexity``.

    Calls ``_deprecated_root_import_func("perplexity", "text")`` and then forwards ``preds``,
    ``target`` and ``ignore_index`` by keyword to ``perplexity``, returning its result as-is.

    >>> from torch import rand, randint
    >>> preds = rand(2, 8, 5)
    >>> target = randint(5, (2, 8))
    >>> target[0, 6:] = -100
    >>> _perplexity(preds, target, ignore_index=-100)
    tensor(5.8540)

    """
    _deprecated_root_import_func("perplexity", "text")
    value = perplexity(preds=preds, target=target, ignore_index=ignore_index)
    return value
def _rouge_score(
    preds: Union[str, Sequence[str]],
    target: Union[str, Sequence[str], Sequence[Sequence[str]]],
    accumulate: Literal["avg", "best"] = "best",
    use_stemmer: bool = False,
    normalizer: Optional[Callable[[str], str]] = None,
    tokenizer: Optional[Callable[[str], Sequence[str]]] = None,
    rouge_keys: Union[str, tuple[str, ...]] = ("rouge1", "rouge2", "rougeL", "rougeLsum"),
) -> dict[str, Tensor]:
    """Deprecated-import wrapper around ``rouge_score``.

    Calls ``_deprecated_root_import_func("rouge_score", "text")`` and then passes every argument
    through, by keyword and unchanged, to ``rouge_score``; that function's result is returned as-is.

    >>> preds = "My name is John"
    >>> target = "Is your name John"
    >>> from pprint import pprint
    >>> pprint(_rouge_score(preds, target))
    {'rouge1_fmeasure': tensor(0.7500),
     'rouge1_precision': tensor(0.7500),
     'rouge1_recall': tensor(0.7500),
     'rouge2_fmeasure': tensor(0.),
     'rouge2_precision': tensor(0.),
     'rouge2_recall': tensor(0.),
     'rougeL_fmeasure': tensor(0.5000),
     'rougeL_precision': tensor(0.5000),
     'rougeL_recall': tensor(0.5000),
     'rougeLsum_fmeasure': tensor(0.5000),
     'rougeLsum_precision': tensor(0.5000),
     'rougeLsum_recall': tensor(0.5000)}

    """
    _deprecated_root_import_func("rouge_score", "text")
    scores = rouge_score(
        preds=preds,
        target=target,
        accumulate=accumulate,
        use_stemmer=use_stemmer,
        normalizer=normalizer,
        tokenizer=tokenizer,
        rouge_keys=rouge_keys,
    )
    return scores
def _sacre_bleu_score(
    preds: Sequence[str],
    target: Sequence[Sequence[str]],
    n_gram: int = 4,
    smooth: bool = False,
    tokenize: Literal["none", "13a", "zh", "intl", "char"] = "13a",
    lowercase: bool = False,
    weights: Optional[Sequence[float]] = None,
) -> Tensor:
    """Deprecated-import wrapper around ``sacre_bleu_score``.

    Calls ``_deprecated_root_import_func("sacre_bleu_score", "text")`` and then passes every
    argument through, by keyword and unchanged, to ``sacre_bleu_score``; that function's result is
    returned as-is.

    >>> preds = ['the cat is on the mat']
    >>> target = [['there is a cat on the mat', 'a cat is on the mat']]
    >>> _sacre_bleu_score(preds, target)
    tensor(0.7598)

    """
    _deprecated_root_import_func("sacre_bleu_score", "text")
    score = sacre_bleu_score(
        preds=preds,
        target=target,
        n_gram=n_gram,
        smooth=smooth,
        tokenize=tokenize,
        lowercase=lowercase,
        weights=weights,
    )
    return score
def _squad(
    preds: Union[dict[str, str], list[dict[str, str]]],
    target: SQUAD_TARGETS_TYPE,
) -> dict[str, Tensor]:
    """Deprecated-import wrapper around ``squad``.

    Calls ``_deprecated_root_import_func("squad", "text")`` and then forwards ``preds`` and
    ``target`` by keyword to ``squad``, returning its result as-is.

    >>> preds = [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}]
    >>> target = [{"answers": {"answer_start": [97], "text": ["1976"]},"id": "56e10a3be3433e1400422b22"}]
    >>> _squad(preds, target)
    {'exact_match': tensor(100.), 'f1': tensor(100.)}

    """
    _deprecated_root_import_func("squad", "text")
    results = squad(preds=preds, target=target)
    return results
def _translation_edit_rate(
    preds: Union[str, Sequence[str]],
    target: Sequence[Union[str, Sequence[str]]],
    normalize: bool = False,
    no_punctuation: bool = False,
    lowercase: bool = True,
    asian_support: bool = False,
    return_sentence_level_score: bool = False,
) -> Union[Tensor, tuple[Tensor, List[Tensor]]]:
    """Deprecated-import wrapper around ``translation_edit_rate``.

    Calls ``_deprecated_root_import_func("translation_edit_rate", "text")`` and then passes every
    argument through, by keyword and unchanged, to ``translation_edit_rate``; that function's
    result is returned as-is.

    >>> preds = ['the cat is on the mat']
    >>> target = [['there is a cat on the mat', 'a cat is on the mat']]
    >>> _translation_edit_rate(preds, target)
    tensor(0.1538)

    """
    _deprecated_root_import_func("translation_edit_rate", "text")
    rate = translation_edit_rate(
        preds=preds,
        target=target,
        normalize=normalize,
        no_punctuation=no_punctuation,
        lowercase=lowercase,
        asian_support=asian_support,
        return_sentence_level_score=return_sentence_level_score,
    )
    return rate
def _word_error_rate(
    preds: Union[str, list[str]],
    target: Union[str, list[str]],
) -> Tensor:
    """Deprecated-import wrapper around ``word_error_rate``.

    Calls ``_deprecated_root_import_func("word_error_rate", "text")`` and then forwards ``preds``
    and ``target`` by keyword to ``word_error_rate``, returning its result as-is.

    >>> preds = ["this is the prediction", "there is an other sample"]
    >>> target = ["this is the reference", "there is another one"]
    >>> _word_error_rate(preds=preds, target=target)
    tensor(0.5000)

    """
    _deprecated_root_import_func("word_error_rate", "text")
    rate = word_error_rate(preds=preds, target=target)
    return rate
def _word_information_lost(
    preds: Union[str, list[str]],
    target: Union[str, list[str]],
) -> Tensor:
    """Deprecated-import wrapper around ``word_information_lost``.

    Calls ``_deprecated_root_import_func("word_information_lost", "text")`` and then forwards
    ``preds`` and ``target`` by keyword to ``word_information_lost``, returning its result as-is.

    >>> preds = ["this is the prediction", "there is an other sample"]
    >>> target = ["this is the reference", "there is another one"]
    >>> _word_information_lost(preds, target)
    tensor(0.6528)

    """
    _deprecated_root_import_func("word_information_lost", "text")
    lost = word_information_lost(preds=preds, target=target)
    return lost
def _word_information_preserved(
    preds: Union[str, list[str]],
    target: Union[str, list[str]],
) -> Tensor:
    """Deprecated-import wrapper around ``word_information_preserved``.

    Calls ``_deprecated_root_import_func("word_information_preserved", "text")`` and then forwards
    ``preds`` and ``target`` by keyword to ``word_information_preserved``, returning its result
    as-is.

    >>> preds = ["this is the prediction", "there is an other sample"]
    >>> target = ["this is the reference", "there is another one"]
    >>> _word_information_preserved(preds, target)
    tensor(0.3472)

    """
    _deprecated_root_import_func("word_information_preserved", "text")
    preserved = word_information_preserved(preds=preds, target=target)
    return preserved
|