import os from collections.abc import Sequence from typing import Any, Callable, List, Literal, Optional, Union import torch from torch import Tensor from torch.nn import Module from torchmetrics.functional.text.bert import bert_score from torchmetrics.functional.text.bleu import bleu_score from torchmetrics.functional.text.cer import char_error_rate from torchmetrics.functional.text.chrf import chrf_score from torchmetrics.functional.text.eed import extended_edit_distance from torchmetrics.functional.text.infolm import ( _ALLOWED_INFORMATION_MEASURE_LITERAL as _INFOLM_ALLOWED_INFORMATION_MEASURE_LITERAL, ) from torchmetrics.functional.text.infolm import infolm from torchmetrics.functional.text.mer import match_error_rate from torchmetrics.functional.text.perplexity import perplexity from torchmetrics.functional.text.rouge import rouge_score from torchmetrics.functional.text.sacre_bleu import sacre_bleu_score from torchmetrics.functional.text.squad import squad from torchmetrics.functional.text.ter import translation_edit_rate from torchmetrics.functional.text.wer import word_error_rate from torchmetrics.functional.text.wil import word_information_lost from torchmetrics.functional.text.wip import word_information_preserved from torchmetrics.utilities.imports import _TRANSFORMERS_GREATER_EQUAL_4_4 from torchmetrics.utilities.prints import _deprecated_root_import_func __doctest_requires__ = {("_rouge_score"): ["nltk"]} if not _TRANSFORMERS_GREATER_EQUAL_4_4: __doctest_skip__ = ["_bert_score", "_infolm"] SQUAD_SINGLE_TARGET_TYPE = dict[str, Union[str, dict[str, Union[list[str], list[int]]]]] SQUAD_TARGETS_TYPE = Union[SQUAD_SINGLE_TARGET_TYPE, list[SQUAD_SINGLE_TARGET_TYPE]] def _bert_score( preds: Union[list[str], dict[str, Tensor]], target: Union[list[str], dict[str, Tensor]], model_name_or_path: Optional[str] = None, num_layers: Optional[int] = None, all_layers: bool = False, model: Optional[Module] = None, user_tokenizer: Any = None, user_forward_fn: Optional[Callable[[Module, dict[str, Tensor]], Tensor]] = None, verbose: bool = False, idf: bool = False, device: Optional[Union[str, torch.device]] = None, max_length: int = 512, batch_size: int = 64, num_threads: int = 4, return_hash: bool = False, lang: str = "en", rescale_with_baseline: bool = False, baseline_path: Optional[str] = None, baseline_url: Optional[str] = None, ) -> dict[str, Union[Tensor, list[float], str]]: """Wrapper for deprecated import. >>> preds = ["hello there", "general kenobi"] >>> target = ["hello there", "master kenobi"] >>> score = _bert_score(preds, target) >>> from pprint import pprint >>> pprint(score) {'f1': tensor([1.0000, 0.9961]), 'precision': tensor([1.0000, 0.9961]), 'recall': tensor([1.0000, 0.9961])} """ _deprecated_root_import_func("bert_score", "text") return bert_score( preds=preds, target=target, model_name_or_path=model_name_or_path, num_layers=num_layers, all_layers=all_layers, model=model, user_tokenizer=user_tokenizer, user_forward_fn=user_forward_fn, verbose=verbose, idf=idf, device=device, max_length=max_length, batch_size=batch_size, num_threads=num_threads, return_hash=return_hash, lang=lang, rescale_with_baseline=rescale_with_baseline, baseline_path=baseline_path, baseline_url=baseline_url, ) def _bleu_score( preds: Union[str, Sequence[str]], target: Sequence[Union[str, Sequence[str]]], n_gram: int = 4, smooth: bool = False, weights: Optional[Sequence[float]] = None, ) -> Tensor: """Wrapper for deprecated import. >>> preds = ['the cat is on the mat'] >>> target = [['there is a cat on the mat', 'a cat is on the mat']] >>> _bleu_score(preds, target) tensor(0.7598) """ _deprecated_root_import_func("bleu_score", "text") return bleu_score(preds=preds, target=target, n_gram=n_gram, smooth=smooth, weights=weights) def _char_error_rate(preds: Union[str, list[str]], target: Union[str, list[str]]) -> Tensor: """Wrapper for deprecated import. >>> preds = ["this is the prediction", "there is an other sample"] >>> target = ["this is the reference", "there is another one"] >>> _char_error_rate(preds=preds, target=target) tensor(0.3415) """ _deprecated_root_import_func("char_error_rate", "text") return char_error_rate(preds=preds, target=target) def _chrf_score( preds: Union[str, Sequence[str]], target: Sequence[Union[str, Sequence[str]]], n_char_order: int = 6, n_word_order: int = 2, beta: float = 2.0, lowercase: bool = False, whitespace: bool = False, return_sentence_level_score: bool = False, ) -> Union[Tensor, tuple[Tensor, Tensor]]: """Wrapper for deprecated import. >>> preds = ['the cat is on the mat'] >>> target = [['there is a cat on the mat', 'a cat is on the mat']] >>> _chrf_score(preds, target) tensor(0.8640) """ _deprecated_root_import_func("chrf_score", "text") return chrf_score( preds=preds, target=target, n_char_order=n_char_order, n_word_order=n_word_order, beta=beta, lowercase=lowercase, whitespace=whitespace, return_sentence_level_score=return_sentence_level_score, ) def _extended_edit_distance( preds: Union[str, Sequence[str]], target: Sequence[Union[str, Sequence[str]]], language: Literal["en", "ja"] = "en", return_sentence_level_score: bool = False, alpha: float = 2.0, rho: float = 0.3, deletion: float = 0.2, insertion: float = 1.0, ) -> Union[Tensor, tuple[Tensor, Tensor]]: """Wrapper for deprecated import. >>> preds = ["this is the prediction", "here is an other sample"] >>> target = ["this is the reference", "here is another one"] >>> _extended_edit_distance(preds=preds, target=target) tensor(0.3078) """ _deprecated_root_import_func("extended_edit_distance", "text") return extended_edit_distance( preds=preds, target=target, language=language, return_sentence_level_score=return_sentence_level_score, alpha=alpha, rho=rho, deletion=deletion, insertion=insertion, ) def _infolm( preds: Union[str, Sequence[str]], target: Union[str, Sequence[str]], model_name_or_path: Union[str, os.PathLike] = "bert-base-uncased", temperature: float = 0.25, information_measure: _INFOLM_ALLOWED_INFORMATION_MEASURE_LITERAL = "kl_divergence", idf: bool = True, alpha: Optional[float] = None, beta: Optional[float] = None, device: Optional[Union[str, torch.device]] = None, max_length: Optional[int] = None, batch_size: int = 64, num_threads: int = 0, verbose: bool = True, return_sentence_level_score: bool = False, ) -> Union[Tensor, tuple[Tensor, Tensor]]: """Wrapper for deprecated import. >>> preds = ['he read the book because he was interested in world history'] >>> target = ['he was interested in world history because he read the book'] >>> _infolm(preds, target, model_name_or_path='google/bert_uncased_L-2_H-128_A-2', idf=False) tensor(-0.1784) """ _deprecated_root_import_func("infolm", "text") return infolm( preds=preds, target=target, model_name_or_path=model_name_or_path, temperature=temperature, information_measure=information_measure, idf=idf, alpha=alpha, beta=beta, device=device, max_length=max_length, batch_size=batch_size, num_threads=num_threads, verbose=verbose, return_sentence_level_score=return_sentence_level_score, ) def _match_error_rate(preds: Union[str, list[str]], target: Union[str, list[str]]) -> Tensor: """Wrapper for deprecated import. >>> preds = ["this is the prediction", "there is an other sample"] >>> target = ["this is the reference", "there is another one"] >>> _match_error_rate(preds=preds, target=target) tensor(0.4444) """ _deprecated_root_import_func("match_error_rate", "text") return match_error_rate(preds=preds, target=target) def _perplexity(preds: Tensor, target: Tensor, ignore_index: Optional[int] = None) -> Tensor: """Wrapper for deprecated import. >>> from torch import rand, randint >>> preds = rand(2, 8, 5) >>> target = randint(5, (2, 8)) >>> target[0, 6:] = -100 >>> _perplexity(preds, target, ignore_index=-100) tensor(5.8540) """ _deprecated_root_import_func("perplexity", "text") return perplexity(preds=preds, target=target, ignore_index=ignore_index) def _rouge_score( preds: Union[str, Sequence[str]], target: Union[str, Sequence[str], Sequence[Sequence[str]]], accumulate: Literal["avg", "best"] = "best", use_stemmer: bool = False, normalizer: Optional[Callable[[str], str]] = None, tokenizer: Optional[Callable[[str], Sequence[str]]] = None, rouge_keys: Union[str, tuple[str, ...]] = ("rouge1", "rouge2", "rougeL", "rougeLsum"), ) -> dict[str, Tensor]: """Wrapper for deprecated import. >>> preds = "My name is John" >>> target = "Is your name John" >>> from pprint import pprint >>> pprint(_rouge_score(preds, target)) {'rouge1_fmeasure': tensor(0.7500), 'rouge1_precision': tensor(0.7500), 'rouge1_recall': tensor(0.7500), 'rouge2_fmeasure': tensor(0.), 'rouge2_precision': tensor(0.), 'rouge2_recall': tensor(0.), 'rougeL_fmeasure': tensor(0.5000), 'rougeL_precision': tensor(0.5000), 'rougeL_recall': tensor(0.5000), 'rougeLsum_fmeasure': tensor(0.5000), 'rougeLsum_precision': tensor(0.5000), 'rougeLsum_recall': tensor(0.5000)} """ _deprecated_root_import_func("rouge_score", "text") return rouge_score( preds=preds, target=target, accumulate=accumulate, use_stemmer=use_stemmer, normalizer=normalizer, tokenizer=tokenizer, rouge_keys=rouge_keys, ) def _sacre_bleu_score( preds: Sequence[str], target: Sequence[Sequence[str]], n_gram: int = 4, smooth: bool = False, tokenize: Literal["none", "13a", "zh", "intl", "char"] = "13a", lowercase: bool = False, weights: Optional[Sequence[float]] = None, ) -> Tensor: """Wrapper for deprecated import. >>> preds = ['the cat is on the mat'] >>> target = [['there is a cat on the mat', 'a cat is on the mat']] >>> _sacre_bleu_score(preds, target) tensor(0.7598) """ _deprecated_root_import_func("sacre_bleu_score", "text") return sacre_bleu_score( preds=preds, target=target, n_gram=n_gram, smooth=smooth, tokenize=tokenize, lowercase=lowercase, weights=weights, ) def _squad(preds: Union[dict[str, str], list[dict[str, str]]], target: SQUAD_TARGETS_TYPE) -> dict[str, Tensor]: """Wrapper for deprecated import. >>> preds = [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}] >>> target = [{"answers": {"answer_start": [97], "text": ["1976"]},"id": "56e10a3be3433e1400422b22"}] >>> _squad(preds, target) {'exact_match': tensor(100.), 'f1': tensor(100.)} """ _deprecated_root_import_func("squad", "text") return squad(preds=preds, target=target) def _translation_edit_rate( preds: Union[str, Sequence[str]], target: Sequence[Union[str, Sequence[str]]], normalize: bool = False, no_punctuation: bool = False, lowercase: bool = True, asian_support: bool = False, return_sentence_level_score: bool = False, ) -> Union[Tensor, tuple[Tensor, List[Tensor]]]: """Wrapper for deprecated import. >>> preds = ['the cat is on the mat'] >>> target = [['there is a cat on the mat', 'a cat is on the mat']] >>> _translation_edit_rate(preds, target) tensor(0.1538) """ _deprecated_root_import_func("translation_edit_rate", "text") return translation_edit_rate( preds=preds, target=target, normalize=normalize, no_punctuation=no_punctuation, lowercase=lowercase, asian_support=asian_support, return_sentence_level_score=return_sentence_level_score, ) def _word_error_rate(preds: Union[str, list[str]], target: Union[str, list[str]]) -> Tensor: """Wrapper for deprecated import. >>> preds = ["this is the prediction", "there is an other sample"] >>> target = ["this is the reference", "there is another one"] >>> _word_error_rate(preds=preds, target=target) tensor(0.5000) """ _deprecated_root_import_func("word_error_rate", "text") return word_error_rate(preds=preds, target=target) def _word_information_lost(preds: Union[str, list[str]], target: Union[str, list[str]]) -> Tensor: """Wrapper for deprecated import. >>> preds = ["this is the prediction", "there is an other sample"] >>> target = ["this is the reference", "there is another one"] >>> _word_information_lost(preds, target) tensor(0.6528) """ _deprecated_root_import_func("word_information_lost", "text") return word_information_lost(preds=preds, target=target) def _word_information_preserved(preds: Union[str, list[str]], target: Union[str, list[str]]) -> Tensor: """Wrapper for deprecated import. >>> preds = ["this is the prediction", "there is an other sample"] >>> target = ["this is the reference", "there is another one"] >>> _word_information_preserved(preds, target) tensor(0.3472) """ _deprecated_root_import_func("word_information_preserved", "text") return word_information_preserved(preds=preds, target=target)