_deprecated.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409
  1. import os
  2. from collections.abc import Sequence
  3. from typing import Any, Callable, List, Literal, Optional, Union
  4. import torch
  5. from torch import Tensor
  6. from torch.nn import Module
  7. from torchmetrics.functional.text.bert import bert_score
  8. from torchmetrics.functional.text.bleu import bleu_score
  9. from torchmetrics.functional.text.cer import char_error_rate
  10. from torchmetrics.functional.text.chrf import chrf_score
  11. from torchmetrics.functional.text.eed import extended_edit_distance
  12. from torchmetrics.functional.text.infolm import (
  13. _ALLOWED_INFORMATION_MEASURE_LITERAL as _INFOLM_ALLOWED_INFORMATION_MEASURE_LITERAL,
  14. )
  15. from torchmetrics.functional.text.infolm import infolm
  16. from torchmetrics.functional.text.mer import match_error_rate
  17. from torchmetrics.functional.text.perplexity import perplexity
  18. from torchmetrics.functional.text.rouge import rouge_score
  19. from torchmetrics.functional.text.sacre_bleu import sacre_bleu_score
  20. from torchmetrics.functional.text.squad import squad
  21. from torchmetrics.functional.text.ter import translation_edit_rate
  22. from torchmetrics.functional.text.wer import word_error_rate
  23. from torchmetrics.functional.text.wil import word_information_lost
  24. from torchmetrics.functional.text.wip import word_information_preserved
  25. from torchmetrics.utilities.imports import _TRANSFORMERS_GREATER_EQUAL_4_4
  26. from torchmetrics.utilities.prints import _deprecated_root_import_func
# Doctest configuration for this module.
# NOTE(review): presumably consumed by the doctest plugin used by the test suite
# (pytest-doctestplus style) so that `_rouge_score` examples need `nltk` -- confirm.
# The parentheses around "_rouge_score" do not build a tuple; the key is the plain string.
__doctest_requires__ = {("_rouge_score"): ["nltk"]}

# `__doctest_skip__` is only defined when the `_TRANSFORMERS_GREATER_EQUAL_4_4` flag is falsy;
# it names the two wrappers whose docstring examples are listed for skipping in that case.
if not _TRANSFORMERS_GREATER_EQUAL_4_4:
    __doctest_skip__ = ["_bert_score", "_infolm"]

# Type aliases; `SQUAD_TARGETS_TYPE` annotates the `target` argument of `_squad` below.
# A single target maps string keys to either a string or a nested dict of string/int lists.
SQUAD_SINGLE_TARGET_TYPE = dict[str, Union[str, dict[str, Union[list[str], list[int]]]]]
# Either one such target or a list of them.
SQUAD_TARGETS_TYPE = Union[SQUAD_SINGLE_TARGET_TYPE, list[SQUAD_SINGLE_TARGET_TYPE]]
  32. def _bert_score(
  33. preds: Union[list[str], dict[str, Tensor]],
  34. target: Union[list[str], dict[str, Tensor]],
  35. model_name_or_path: Optional[str] = None,
  36. num_layers: Optional[int] = None,
  37. all_layers: bool = False,
  38. model: Optional[Module] = None,
  39. user_tokenizer: Any = None,
  40. user_forward_fn: Optional[Callable[[Module, dict[str, Tensor]], Tensor]] = None,
  41. verbose: bool = False,
  42. idf: bool = False,
  43. device: Optional[Union[str, torch.device]] = None,
  44. max_length: int = 512,
  45. batch_size: int = 64,
  46. num_threads: int = 4,
  47. return_hash: bool = False,
  48. lang: str = "en",
  49. rescale_with_baseline: bool = False,
  50. baseline_path: Optional[str] = None,
  51. baseline_url: Optional[str] = None,
  52. ) -> dict[str, Union[Tensor, list[float], str]]:
  53. """Wrapper for deprecated import.
  54. >>> preds = ["hello there", "general kenobi"]
  55. >>> target = ["hello there", "master kenobi"]
  56. >>> score = _bert_score(preds, target)
  57. >>> from pprint import pprint
  58. >>> pprint(score)
  59. {'f1': tensor([1.0000, 0.9961]),
  60. 'precision': tensor([1.0000, 0.9961]),
  61. 'recall': tensor([1.0000, 0.9961])}
  62. """
  63. _deprecated_root_import_func("bert_score", "text")
  64. return bert_score(
  65. preds=preds,
  66. target=target,
  67. model_name_or_path=model_name_or_path,
  68. num_layers=num_layers,
  69. all_layers=all_layers,
  70. model=model,
  71. user_tokenizer=user_tokenizer,
  72. user_forward_fn=user_forward_fn,
  73. verbose=verbose,
  74. idf=idf,
  75. device=device,
  76. max_length=max_length,
  77. batch_size=batch_size,
  78. num_threads=num_threads,
  79. return_hash=return_hash,
  80. lang=lang,
  81. rescale_with_baseline=rescale_with_baseline,
  82. baseline_path=baseline_path,
  83. baseline_url=baseline_url,
  84. )
  85. def _bleu_score(
  86. preds: Union[str, Sequence[str]],
  87. target: Sequence[Union[str, Sequence[str]]],
  88. n_gram: int = 4,
  89. smooth: bool = False,
  90. weights: Optional[Sequence[float]] = None,
  91. ) -> Tensor:
  92. """Wrapper for deprecated import.
  93. >>> preds = ['the cat is on the mat']
  94. >>> target = [['there is a cat on the mat', 'a cat is on the mat']]
  95. >>> _bleu_score(preds, target)
  96. tensor(0.7598)
  97. """
  98. _deprecated_root_import_func("bleu_score", "text")
  99. return bleu_score(preds=preds, target=target, n_gram=n_gram, smooth=smooth, weights=weights)
  100. def _char_error_rate(preds: Union[str, list[str]], target: Union[str, list[str]]) -> Tensor:
  101. """Wrapper for deprecated import.
  102. >>> preds = ["this is the prediction", "there is an other sample"]
  103. >>> target = ["this is the reference", "there is another one"]
  104. >>> _char_error_rate(preds=preds, target=target)
  105. tensor(0.3415)
  106. """
  107. _deprecated_root_import_func("char_error_rate", "text")
  108. return char_error_rate(preds=preds, target=target)
  109. def _chrf_score(
  110. preds: Union[str, Sequence[str]],
  111. target: Sequence[Union[str, Sequence[str]]],
  112. n_char_order: int = 6,
  113. n_word_order: int = 2,
  114. beta: float = 2.0,
  115. lowercase: bool = False,
  116. whitespace: bool = False,
  117. return_sentence_level_score: bool = False,
  118. ) -> Union[Tensor, tuple[Tensor, Tensor]]:
  119. """Wrapper for deprecated import.
  120. >>> preds = ['the cat is on the mat']
  121. >>> target = [['there is a cat on the mat', 'a cat is on the mat']]
  122. >>> _chrf_score(preds, target)
  123. tensor(0.8640)
  124. """
  125. _deprecated_root_import_func("chrf_score", "text")
  126. return chrf_score(
  127. preds=preds,
  128. target=target,
  129. n_char_order=n_char_order,
  130. n_word_order=n_word_order,
  131. beta=beta,
  132. lowercase=lowercase,
  133. whitespace=whitespace,
  134. return_sentence_level_score=return_sentence_level_score,
  135. )
  136. def _extended_edit_distance(
  137. preds: Union[str, Sequence[str]],
  138. target: Sequence[Union[str, Sequence[str]]],
  139. language: Literal["en", "ja"] = "en",
  140. return_sentence_level_score: bool = False,
  141. alpha: float = 2.0,
  142. rho: float = 0.3,
  143. deletion: float = 0.2,
  144. insertion: float = 1.0,
  145. ) -> Union[Tensor, tuple[Tensor, Tensor]]:
  146. """Wrapper for deprecated import.
  147. >>> preds = ["this is the prediction", "here is an other sample"]
  148. >>> target = ["this is the reference", "here is another one"]
  149. >>> _extended_edit_distance(preds=preds, target=target)
  150. tensor(0.3078)
  151. """
  152. _deprecated_root_import_func("extended_edit_distance", "text")
  153. return extended_edit_distance(
  154. preds=preds,
  155. target=target,
  156. language=language,
  157. return_sentence_level_score=return_sentence_level_score,
  158. alpha=alpha,
  159. rho=rho,
  160. deletion=deletion,
  161. insertion=insertion,
  162. )
  163. def _infolm(
  164. preds: Union[str, Sequence[str]],
  165. target: Union[str, Sequence[str]],
  166. model_name_or_path: Union[str, os.PathLike] = "bert-base-uncased",
  167. temperature: float = 0.25,
  168. information_measure: _INFOLM_ALLOWED_INFORMATION_MEASURE_LITERAL = "kl_divergence",
  169. idf: bool = True,
  170. alpha: Optional[float] = None,
  171. beta: Optional[float] = None,
  172. device: Optional[Union[str, torch.device]] = None,
  173. max_length: Optional[int] = None,
  174. batch_size: int = 64,
  175. num_threads: int = 0,
  176. verbose: bool = True,
  177. return_sentence_level_score: bool = False,
  178. ) -> Union[Tensor, tuple[Tensor, Tensor]]:
  179. """Wrapper for deprecated import.
  180. >>> preds = ['he read the book because he was interested in world history']
  181. >>> target = ['he was interested in world history because he read the book']
  182. >>> _infolm(preds, target, model_name_or_path='google/bert_uncased_L-2_H-128_A-2', idf=False)
  183. tensor(-0.1784)
  184. """
  185. _deprecated_root_import_func("infolm", "text")
  186. return infolm(
  187. preds=preds,
  188. target=target,
  189. model_name_or_path=model_name_or_path,
  190. temperature=temperature,
  191. information_measure=information_measure,
  192. idf=idf,
  193. alpha=alpha,
  194. beta=beta,
  195. device=device,
  196. max_length=max_length,
  197. batch_size=batch_size,
  198. num_threads=num_threads,
  199. verbose=verbose,
  200. return_sentence_level_score=return_sentence_level_score,
  201. )
  202. def _match_error_rate(preds: Union[str, list[str]], target: Union[str, list[str]]) -> Tensor:
  203. """Wrapper for deprecated import.
  204. >>> preds = ["this is the prediction", "there is an other sample"]
  205. >>> target = ["this is the reference", "there is another one"]
  206. >>> _match_error_rate(preds=preds, target=target)
  207. tensor(0.4444)
  208. """
  209. _deprecated_root_import_func("match_error_rate", "text")
  210. return match_error_rate(preds=preds, target=target)
  211. def _perplexity(preds: Tensor, target: Tensor, ignore_index: Optional[int] = None) -> Tensor:
  212. """Wrapper for deprecated import.
  213. >>> from torch import rand, randint
  214. >>> preds = rand(2, 8, 5)
  215. >>> target = randint(5, (2, 8))
  216. >>> target[0, 6:] = -100
  217. >>> _perplexity(preds, target, ignore_index=-100)
  218. tensor(5.8540)
  219. """
  220. _deprecated_root_import_func("perplexity", "text")
  221. return perplexity(preds=preds, target=target, ignore_index=ignore_index)
  222. def _rouge_score(
  223. preds: Union[str, Sequence[str]],
  224. target: Union[str, Sequence[str], Sequence[Sequence[str]]],
  225. accumulate: Literal["avg", "best"] = "best",
  226. use_stemmer: bool = False,
  227. normalizer: Optional[Callable[[str], str]] = None,
  228. tokenizer: Optional[Callable[[str], Sequence[str]]] = None,
  229. rouge_keys: Union[str, tuple[str, ...]] = ("rouge1", "rouge2", "rougeL", "rougeLsum"),
  230. ) -> dict[str, Tensor]:
  231. """Wrapper for deprecated import.
  232. >>> preds = "My name is John"
  233. >>> target = "Is your name John"
  234. >>> from pprint import pprint
  235. >>> pprint(_rouge_score(preds, target))
  236. {'rouge1_fmeasure': tensor(0.7500),
  237. 'rouge1_precision': tensor(0.7500),
  238. 'rouge1_recall': tensor(0.7500),
  239. 'rouge2_fmeasure': tensor(0.),
  240. 'rouge2_precision': tensor(0.),
  241. 'rouge2_recall': tensor(0.),
  242. 'rougeL_fmeasure': tensor(0.5000),
  243. 'rougeL_precision': tensor(0.5000),
  244. 'rougeL_recall': tensor(0.5000),
  245. 'rougeLsum_fmeasure': tensor(0.5000),
  246. 'rougeLsum_precision': tensor(0.5000),
  247. 'rougeLsum_recall': tensor(0.5000)}
  248. """
  249. _deprecated_root_import_func("rouge_score", "text")
  250. return rouge_score(
  251. preds=preds,
  252. target=target,
  253. accumulate=accumulate,
  254. use_stemmer=use_stemmer,
  255. normalizer=normalizer,
  256. tokenizer=tokenizer,
  257. rouge_keys=rouge_keys,
  258. )
  259. def _sacre_bleu_score(
  260. preds: Sequence[str],
  261. target: Sequence[Sequence[str]],
  262. n_gram: int = 4,
  263. smooth: bool = False,
  264. tokenize: Literal["none", "13a", "zh", "intl", "char"] = "13a",
  265. lowercase: bool = False,
  266. weights: Optional[Sequence[float]] = None,
  267. ) -> Tensor:
  268. """Wrapper for deprecated import.
  269. >>> preds = ['the cat is on the mat']
  270. >>> target = [['there is a cat on the mat', 'a cat is on the mat']]
  271. >>> _sacre_bleu_score(preds, target)
  272. tensor(0.7598)
  273. """
  274. _deprecated_root_import_func("sacre_bleu_score", "text")
  275. return sacre_bleu_score(
  276. preds=preds,
  277. target=target,
  278. n_gram=n_gram,
  279. smooth=smooth,
  280. tokenize=tokenize,
  281. lowercase=lowercase,
  282. weights=weights,
  283. )
  284. def _squad(preds: Union[dict[str, str], list[dict[str, str]]], target: SQUAD_TARGETS_TYPE) -> dict[str, Tensor]:
  285. """Wrapper for deprecated import.
  286. >>> preds = [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}]
  287. >>> target = [{"answers": {"answer_start": [97], "text": ["1976"]},"id": "56e10a3be3433e1400422b22"}]
  288. >>> _squad(preds, target)
  289. {'exact_match': tensor(100.), 'f1': tensor(100.)}
  290. """
  291. _deprecated_root_import_func("squad", "text")
  292. return squad(preds=preds, target=target)
  293. def _translation_edit_rate(
  294. preds: Union[str, Sequence[str]],
  295. target: Sequence[Union[str, Sequence[str]]],
  296. normalize: bool = False,
  297. no_punctuation: bool = False,
  298. lowercase: bool = True,
  299. asian_support: bool = False,
  300. return_sentence_level_score: bool = False,
  301. ) -> Union[Tensor, tuple[Tensor, List[Tensor]]]:
  302. """Wrapper for deprecated import.
  303. >>> preds = ['the cat is on the mat']
  304. >>> target = [['there is a cat on the mat', 'a cat is on the mat']]
  305. >>> _translation_edit_rate(preds, target)
  306. tensor(0.1538)
  307. """
  308. _deprecated_root_import_func("translation_edit_rate", "text")
  309. return translation_edit_rate(
  310. preds=preds,
  311. target=target,
  312. normalize=normalize,
  313. no_punctuation=no_punctuation,
  314. lowercase=lowercase,
  315. asian_support=asian_support,
  316. return_sentence_level_score=return_sentence_level_score,
  317. )
  318. def _word_error_rate(preds: Union[str, list[str]], target: Union[str, list[str]]) -> Tensor:
  319. """Wrapper for deprecated import.
  320. >>> preds = ["this is the prediction", "there is an other sample"]
  321. >>> target = ["this is the reference", "there is another one"]
  322. >>> _word_error_rate(preds=preds, target=target)
  323. tensor(0.5000)
  324. """
  325. _deprecated_root_import_func("word_error_rate", "text")
  326. return word_error_rate(preds=preds, target=target)
  327. def _word_information_lost(preds: Union[str, list[str]], target: Union[str, list[str]]) -> Tensor:
  328. """Wrapper for deprecated import.
  329. >>> preds = ["this is the prediction", "there is an other sample"]
  330. >>> target = ["this is the reference", "there is another one"]
  331. >>> _word_information_lost(preds, target)
  332. tensor(0.6528)
  333. """
  334. _deprecated_root_import_func("word_information_lost", "text")
  335. return word_information_lost(preds=preds, target=target)
  336. def _word_information_preserved(preds: Union[str, list[str]], target: Union[str, list[str]]) -> Tensor:
  337. """Wrapper for deprecated import.
  338. >>> preds = ["this is the prediction", "there is an other sample"]
  339. >>> target = ["this is the reference", "there is another one"]
  340. >>> _word_information_preserved(preds, target)
  341. tensor(0.3472)
  342. """
  343. _deprecated_root_import_func("word_information_preserved", "text")
  344. return word_information_preserved(preds=preds, target=target)