chrf.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  15. # Copyright 2017 Maja Popovic
  16. # The code is derived from https://github.com/m-popovic/chrF/blob/6d3c384/chrF%2B%2B.py
  17. # The original author and copyright holder have agreed to relicense the derived code under the Apache License,
  18. # Version 2.0 (the "License")
  19. # Reference to the approval: https://github.com/Lightning-AI/torchmetrics/pull/2701#issuecomment-2316891785
  20. from collections import defaultdict
  21. from collections.abc import Sequence
  22. from itertools import chain
  23. from typing import List, Optional, Union
  24. import torch
  25. from torch import Tensor, tensor
  26. from torchmetrics.functional.text.helper import _validate_inputs
# Lower bound used for the F-score denominator so that the division never sees an exact zero.
_EPS_SMOOTHING = tensor(1e-16)
# Taken from https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/metrics/chrf.py
# Characters treated as punctuation when splitting them off the first/last position of a word.
_PUNCTUATIONS = set("!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~")
  30. def _prepare_n_grams_dicts(
  31. n_char_order: int, n_word_order: int
  32. ) -> tuple[
  33. dict[int, Tensor], dict[int, Tensor], dict[int, Tensor], dict[int, Tensor], dict[int, Tensor], dict[int, Tensor]
  34. ]:
  35. """Prepare dictionaries with default zero values for total ref, hypothesis and matching character and word n-grams.
  36. Args:
  37. n_char_order: A character n-gram order.
  38. n_word_order: A word n-gram order.
  39. Return:
  40. Dictionaries with default zero values for total reference, hypothesis and matching character and word
  41. n-grams.
  42. """
  43. total_preds_char_n_grams: dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_char_order)}
  44. total_preds_word_n_grams: dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_word_order)}
  45. total_target_char_n_grams: dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_char_order)}
  46. total_target_word_n_grams: dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_word_order)}
  47. total_matching_char_n_grams: dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_char_order)}
  48. total_matching_word_n_grams: dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_word_order)}
  49. return (
  50. total_preds_char_n_grams,
  51. total_preds_word_n_grams,
  52. total_target_char_n_grams,
  53. total_target_word_n_grams,
  54. total_matching_char_n_grams,
  55. total_matching_word_n_grams,
  56. )
  57. def _get_characters(sentence: str, whitespace: bool) -> list[str]:
  58. """Split sentence into individual characters.
  59. Args:
  60. sentence: An input sentence to split.
  61. whitespace: An indication whether to keep whitespaces during character n-gram extraction.
  62. Return:
  63. A list of separated characters.
  64. """
  65. if whitespace:
  66. return list(sentence)
  67. return list(sentence.strip().replace(" ", ""))
  68. def _separate_word_and_punctuation(word: str) -> list[str]:
  69. """Separates out punctuation from beginning and end of words for chrF.
  70. Adapted from https://github.com/m-popovic/chrF and
  71. https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/metrics/chrf.py.
  72. Args:
  73. word: An input word to be separated from a punctuation if present.
  74. Return:
  75. A list of a single word or a separated word and punctuation.
  76. """
  77. if len(word) == 1:
  78. return [word]
  79. if word[-1] in _PUNCTUATIONS:
  80. return [word[:-1], word[-1]]
  81. if word[0] in _PUNCTUATIONS:
  82. return [word[0], word[1:]]
  83. return [word]
  84. def _get_words_and_punctuation(sentence: str) -> list[str]:
  85. """Separates out punctuation from beginning and end of words for chrF for all words in the sentence.
  86. Args:
  87. sentence: An input sentence to split
  88. Return:
  89. An aggregated list of separated words and punctuation.
  90. """
  91. return list(chain.from_iterable(_separate_word_and_punctuation(word) for word in sentence.strip().split()))
  92. def _ngram_counts(char_or_word_list: list[str], n_gram_order: int) -> dict[int, dict[tuple[str, ...], Tensor]]:
  93. """Calculate n-gram counts.
  94. Args:
  95. char_or_word_list: A list of characters of words
  96. n_gram_order: The largest number of n-gram.
  97. Return:
  98. A dictionary of dictionaries with a counts of given n-grams.
  99. """
  100. ngrams: dict[int, dict[tuple[str, ...], Tensor]] = defaultdict(lambda: defaultdict(lambda: tensor(0.0)))
  101. for n in range(1, n_gram_order + 1):
  102. for ngram in (tuple(char_or_word_list[i : i + n]) for i in range(len(char_or_word_list) - n + 1)):
  103. ngrams[n][ngram] += tensor(1)
  104. return ngrams
  105. def _get_n_grams_counts_and_total_ngrams(
  106. sentence: str, n_char_order: int, n_word_order: int, lowercase: bool, whitespace: bool
  107. ) -> tuple[
  108. dict[int, dict[tuple[str, ...], Tensor]],
  109. dict[int, dict[tuple[str, ...], Tensor]],
  110. dict[int, Tensor],
  111. dict[int, Tensor],
  112. ]:
  113. """Get n-grams and total n-grams.
  114. Args:
  115. sentence: An input sentence
  116. n_char_order: A character n-gram order.
  117. n_word_order: A word n-gram order.
  118. lowercase: An indication whether to enable case-insensitivity.
  119. whitespace: An indication whether to keep whitespaces during character n-gram extraction.
  120. Return:
  121. char_n_grams_counts: A dictionary of dictionaries with sentence character n-grams.
  122. word_n_grams_counts: A dictionary of dictionaries with sentence word n-grams.
  123. total_char_n_grams: A dictionary containing a total number of sentence character n-grams.
  124. total_word_n_grams: A dictionary containing a total number of sentence word n-grams.
  125. """
  126. def _char_and_word_ngrams_counts(
  127. sentence: str, n_char_order: int, n_word_order: int, lowercase: bool
  128. ) -> tuple[dict[int, dict[tuple[str, ...], Tensor]], dict[int, dict[tuple[str, ...], Tensor]]]:
  129. """Get a dictionary of dictionaries with a counts of given n-grams."""
  130. if lowercase:
  131. sentence = sentence.lower()
  132. char_n_grams_counts = _ngram_counts(_get_characters(sentence, whitespace), n_char_order)
  133. word_n_grams_counts = _ngram_counts(_get_words_and_punctuation(sentence), n_word_order)
  134. return char_n_grams_counts, word_n_grams_counts
  135. def _get_total_ngrams(n_grams_counts: dict[int, dict[tuple[str, ...], Tensor]]) -> dict[int, Tensor]:
  136. """Get total sum of n-grams over n-grams w.r.t n."""
  137. total_n_grams: dict[int, Tensor] = defaultdict(lambda: tensor(0.0))
  138. for n in n_grams_counts:
  139. total_n_grams[n] = sum(n_grams_counts[n].values()).detach().clone() # type: ignore
  140. return total_n_grams
  141. char_n_grams_counts, word_n_grams_counts = _char_and_word_ngrams_counts(
  142. sentence, n_char_order, n_word_order, lowercase
  143. )
  144. total_char_n_grams = _get_total_ngrams(char_n_grams_counts)
  145. total_word_n_grams = _get_total_ngrams(word_n_grams_counts)
  146. return char_n_grams_counts, word_n_grams_counts, total_char_n_grams, total_word_n_grams
  147. def _get_ngram_matches(
  148. hyp_n_grams_counts: dict[int, dict[tuple[str, ...], Tensor]],
  149. ref_n_grams_counts: dict[int, dict[tuple[str, ...], Tensor]],
  150. ) -> dict[int, Tensor]:
  151. """Get a number of n-gram matches between reference and hypothesis n-grams.
  152. Args:
  153. hyp_n_grams_counts: n-grams counts for hypothesis
  154. ref_n_grams_counts: n-grams counts for reference
  155. Return:
  156. matching_n_grams
  157. """
  158. matching_n_grams: dict[int, Tensor] = defaultdict(lambda: tensor(0.0))
  159. for n in hyp_n_grams_counts:
  160. min_n_grams = [
  161. torch.min(ref_n_grams_counts[n][n_gram], hyp_n_grams_counts[n][n_gram]) for n_gram in hyp_n_grams_counts[n]
  162. ]
  163. matching_n_grams[n] = sum(min_n_grams).detach().clone() # type: ignore
  164. return matching_n_grams
  165. def _sum_over_dicts(total_n_grams: dict[int, Tensor], n_grams: dict[int, Tensor]) -> dict[int, Tensor]:
  166. """Aggregate total n-grams to keep corpus-level statistics.
  167. Args:
  168. total_n_grams: A dictionary containing a total corpus-level number of n-grams.
  169. n_grams: A dictionary containing a sentence-level number of n-grams.
  170. Return:
  171. A dictionary containing a total corpus-level number of n-grams.
  172. """
  173. for n in n_grams:
  174. total_n_grams[n] += n_grams[n]
  175. return total_n_grams
  176. def _calculate_fscore(
  177. matching_char_n_grams: dict[int, Tensor],
  178. matching_word_n_grams: dict[int, Tensor],
  179. hyp_char_n_grams: dict[int, Tensor],
  180. hyp_word_n_grams: dict[int, Tensor],
  181. ref_char_n_grams: dict[int, Tensor],
  182. ref_word_n_grams: dict[int, Tensor],
  183. n_order: float,
  184. beta: float,
  185. ) -> Tensor:
  186. """Calculate sentence-level chrF/chrF++ score.
  187. For given hypothesis and reference statistics (either sentence-level or corpus-level)
  188. the chrF/chrF++ score is returned.
  189. Args:
  190. matching_char_n_grams:
  191. A total number of matching character n-grams between the best matching reference and hypothesis.
  192. matching_word_n_grams:
  193. A total number of matching word n-grams between the best matching reference and hypothesis.
  194. hyp_char_n_grams: A total number of hypothesis character n-grams.
  195. hyp_word_n_grams: A total number of hypothesis word n-grams.
  196. ref_char_n_grams: A total number of reference character n-grams.
  197. ref_word_n_grams: A total number of reference word n-grams.
  198. n_order: A sum of character and word n-gram order.
  199. beta: A parameter determining an importance of recall w.r.t. precision. If `beta=1`, their importance is equal.
  200. Return:
  201. A chrF/chrF++ score. This function is universal both for sentence-level and corpus-level calculation.
  202. """
  203. def _get_n_gram_fscore(
  204. matching_n_grams: dict[int, Tensor], ref_n_grams: dict[int, Tensor], hyp_n_grams: dict[int, Tensor], beta: float
  205. ) -> dict[int, Tensor]:
  206. """Get n-gram level f-score."""
  207. precision: dict[int, Tensor] = {
  208. n: matching_n_grams[n] / hyp_n_grams[n] if hyp_n_grams[n] > 0 else tensor(0.0) for n in matching_n_grams
  209. }
  210. recall: dict[int, Tensor] = {
  211. n: matching_n_grams[n] / ref_n_grams[n] if ref_n_grams[n] > 0 else tensor(0.0) for n in matching_n_grams
  212. }
  213. denominator: dict[int, Tensor] = {
  214. n: torch.max(beta**2 * precision[n] + recall[n], _EPS_SMOOTHING) for n in matching_n_grams
  215. }
  216. f_score: dict[int, Tensor] = {
  217. n: (1 + beta**2) * precision[n] * recall[n] / denominator[n] for n in matching_n_grams
  218. }
  219. return f_score
  220. char_n_gram_f_score = _get_n_gram_fscore(matching_char_n_grams, ref_char_n_grams, hyp_char_n_grams, beta)
  221. word_n_gram_f_score = _get_n_gram_fscore(matching_word_n_grams, ref_word_n_grams, hyp_word_n_grams, beta)
  222. return (sum(char_n_gram_f_score.values()) + sum(word_n_gram_f_score.values())) / tensor(n_order)
  223. def _calculate_sentence_level_chrf_score(
  224. targets: list[str],
  225. pred_char_n_grams_counts: dict[int, dict[tuple[str, ...], Tensor]],
  226. pred_word_n_grams_counts: dict[int, dict[tuple[str, ...], Tensor]],
  227. pred_char_n_grams: dict[int, Tensor],
  228. pred_word_n_grams: dict[int, Tensor],
  229. n_char_order: int,
  230. n_word_order: int,
  231. n_order: float,
  232. beta: float,
  233. lowercase: bool,
  234. whitespace: bool,
  235. ) -> tuple[Tensor, dict[int, Tensor], dict[int, Tensor], dict[int, Tensor], dict[int, Tensor]]:
  236. """Calculate the best sentence-level chrF/chrF++ score.
  237. For a given pre-processed hypothesis, all references are evaluated and score and statistics
  238. for the best matching reference is returned.
  239. Args:
  240. targets: An iterable of references.
  241. pred_char_n_grams_counts: A dictionary of dictionaries with hypothesis character n-grams.
  242. pred_word_n_grams_counts: A dictionary of dictionaries with hypothesis word n-grams.
  243. pred_char_n_grams: A total number of hypothesis character n-grams.
  244. pred_word_n_grams: A total number of hypothesis word n-grams.
  245. n_char_order: A character n-gram order.
  246. n_word_order: A word n-gram order.
  247. n_order: A sum of character and word n-gram order.
  248. beta: A parameter determining an importance of recall w.r.t. precision. If `beta=1`, their importance is equal.
  249. lowercase: An indication whether to enable case-insensitivity.
  250. whitespace: An indication whether to keep whitespaces during character n-gram extraction.
  251. Return:
  252. Return chrF/chrF++ score and statistics for the best matching hypothesis and reference.
  253. f_score: A sentence-level chrF/chrF++ score.
  254. matching_char_n_grams:
  255. A total number of matching character n-grams between the best matching reference and hypothesis.
  256. matching_word_n_grams:
  257. A total number of matching word n-grams between the best matching reference and hypothesis.
  258. target_char_n_grams: A total number of reference character n-grams.
  259. target_word_n_grams: A total number of reference word n-grams.
  260. """
  261. best_f_score = tensor(0.0)
  262. best_matching_char_n_grams: dict[int, Tensor] = defaultdict(lambda: tensor(0.0))
  263. best_matching_word_n_grams: dict[int, Tensor] = defaultdict(lambda: tensor(0.0))
  264. best_target_char_n_grams: dict[int, Tensor] = defaultdict(lambda: tensor(0.0))
  265. best_target_word_n_grams: dict[int, Tensor] = defaultdict(lambda: tensor(0.0))
  266. for target in targets:
  267. (
  268. target_char_n_grams_counts,
  269. target_word_n_grams_counts,
  270. target_char_n_grams,
  271. target_word_n_grams,
  272. ) = _get_n_grams_counts_and_total_ngrams(target, n_char_order, n_word_order, lowercase, whitespace)
  273. matching_char_n_grams = _get_ngram_matches(target_char_n_grams_counts, pred_char_n_grams_counts)
  274. matching_word_n_grams = _get_ngram_matches(target_word_n_grams_counts, pred_word_n_grams_counts)
  275. f_score = _calculate_fscore(
  276. matching_char_n_grams,
  277. matching_word_n_grams,
  278. pred_char_n_grams,
  279. pred_word_n_grams,
  280. target_char_n_grams,
  281. target_word_n_grams,
  282. n_order,
  283. beta,
  284. )
  285. if f_score > best_f_score:
  286. best_f_score = f_score
  287. best_matching_char_n_grams = matching_char_n_grams
  288. best_matching_word_n_grams = matching_word_n_grams
  289. best_target_char_n_grams = target_char_n_grams
  290. best_target_word_n_grams = target_word_n_grams
  291. return (
  292. best_f_score,
  293. best_matching_char_n_grams,
  294. best_matching_word_n_grams,
  295. best_target_char_n_grams,
  296. best_target_word_n_grams,
  297. )
  298. def _chrf_score_update(
  299. preds: Union[str, Sequence[str]],
  300. target: Union[Sequence[str], Sequence[Sequence[str]]],
  301. total_preds_char_n_grams: dict[int, Tensor],
  302. total_preds_word_n_grams: dict[int, Tensor],
  303. total_target_char_n_grams: dict[int, Tensor],
  304. total_target_word_n_grams: dict[int, Tensor],
  305. total_matching_char_n_grams: dict[int, Tensor],
  306. total_matching_word_n_grams: dict[int, Tensor],
  307. n_char_order: int,
  308. n_word_order: int,
  309. n_order: float,
  310. beta: float,
  311. lowercase: bool,
  312. whitespace: bool,
  313. sentence_chrf_score: Optional[List[Tensor]] = None,
  314. ) -> tuple[
  315. dict[int, Tensor],
  316. dict[int, Tensor],
  317. dict[int, Tensor],
  318. dict[int, Tensor],
  319. dict[int, Tensor],
  320. dict[int, Tensor],
  321. Optional[List[Tensor]],
  322. ]:
  323. """Update function for chrf score.
  324. Args:
  325. preds: An iterable of hypothesis corpus.
  326. target: An iterable of iterables of reference corpus.
  327. total_preds_char_n_grams: A dictionary containing a total number of hypothesis character n-grams.
  328. total_preds_word_n_grams: A dictionary containing a total number of hypothesis word n-grams.
  329. total_target_char_n_grams: A dictionary containing a total number of reference character n-grams.
  330. total_target_word_n_grams: A dictionary containing a total number of reference word n-grams.
  331. total_matching_char_n_grams:
  332. A dictionary containing a total number of matching character n-grams between references and hypotheses.
  333. total_matching_word_n_grams:
  334. A dictionary containing a total number of total matching word n-grams between references and hypotheses.
  335. n_char_order: A character n-gram order.
  336. n_word_order: A word n-gram order.
  337. n_order: Sum of character and word n-gram order.
  338. beta: A parameter determining an importance of recall w.r.t. precision. If `beta=1`, their importance is equal.
  339. lowercase: An indication whether to enable case-insensitivity.
  340. whitespace: An indication whether to keep whitespaces during character n-gram extraction.
  341. sentence_chrf_score: A list of sentence-level chrF/chrF++ scores.
  342. Return:
  343. total_target_char_n_grams: number of reference character n-grams.
  344. total_target_word_n_grams: number of reference word n-grams.
  345. total_preds_char_n_grams: number of hypothesis character n-grams.
  346. total_preds_word_n_grams: number of hypothesis word n-grams.
  347. total_matching_char_n_grams: number of matching character n-grams between references and hypotheses.
  348. total_matching_word_n_grams: number of total matching word n-grams between references and hypotheses.
  349. sentence_chrf_score: A list of sentence-level chrF/chrF++ scores.
  350. Raises:
  351. ValueError:
  352. If length of ``preds`` and ``target`` differs.
  353. """
  354. target_corpus, preds = _validate_inputs(target, preds)
  355. for pred, targets in zip(preds, target_corpus):
  356. (
  357. pred_char_n_grams_counts,
  358. pred_word_n_grams_counts,
  359. pred_char_n_grams,
  360. pred_word_n_grams,
  361. ) = _get_n_grams_counts_and_total_ngrams(pred, n_char_order, n_word_order, lowercase, whitespace)
  362. total_preds_char_n_grams = _sum_over_dicts(total_preds_char_n_grams, pred_char_n_grams)
  363. total_preds_word_n_grams = _sum_over_dicts(total_preds_word_n_grams, pred_word_n_grams)
  364. (
  365. sentence_level_f_score,
  366. matching_char_n_grams,
  367. matching_word_n_grams,
  368. target_char_n_grams,
  369. target_word_n_grams,
  370. ) = _calculate_sentence_level_chrf_score(
  371. targets, # type: ignore
  372. pred_char_n_grams_counts,
  373. pred_word_n_grams_counts,
  374. pred_char_n_grams,
  375. pred_word_n_grams,
  376. n_char_order,
  377. n_word_order,
  378. n_order,
  379. beta,
  380. lowercase,
  381. whitespace,
  382. )
  383. if sentence_chrf_score is not None:
  384. sentence_chrf_score.append(sentence_level_f_score.unsqueeze(0))
  385. total_target_char_n_grams = _sum_over_dicts(total_target_char_n_grams, target_char_n_grams)
  386. total_target_word_n_grams = _sum_over_dicts(total_target_word_n_grams, target_word_n_grams)
  387. total_matching_char_n_grams = _sum_over_dicts(total_matching_char_n_grams, matching_char_n_grams)
  388. total_matching_word_n_grams = _sum_over_dicts(total_matching_word_n_grams, matching_word_n_grams)
  389. return (
  390. total_preds_char_n_grams,
  391. total_preds_word_n_grams,
  392. total_target_char_n_grams,
  393. total_target_word_n_grams,
  394. total_matching_char_n_grams,
  395. total_matching_word_n_grams,
  396. sentence_chrf_score,
  397. )
  398. def _chrf_score_compute(
  399. total_preds_char_n_grams: dict[int, Tensor],
  400. total_preds_word_n_grams: dict[int, Tensor],
  401. total_target_char_n_grams: dict[int, Tensor],
  402. total_target_word_n_grams: dict[int, Tensor],
  403. total_matching_char_n_grams: dict[int, Tensor],
  404. total_matching_word_n_grams: dict[int, Tensor],
  405. n_order: float,
  406. beta: float,
  407. ) -> Tensor:
  408. """Compute chrF/chrF++ score based on pre-computed target, prediction and matching character and word n-grams.
  409. Args:
  410. total_preds_char_n_grams: number of hypothesis character n-grams.
  411. total_preds_word_n_grams: number of hypothesis word n-grams.
  412. total_target_char_n_grams: number of reference character n-grams.
  413. total_target_word_n_grams: number of reference word n-grams.
  414. total_matching_char_n_grams: number of matching character n-grams between references and hypotheses.
  415. total_matching_word_n_grams: number of total matching word n-grams between references and hypotheses.
  416. n_order: A sum of character and word n-gram order.
  417. beta:
  418. A parameter determining an importance of recall w.r.t. precision. If `beta=1`, their importance is equal.
  419. Return:
  420. A corpus-level chrF/chrF++ score.
  421. """
  422. return _calculate_fscore(
  423. total_matching_char_n_grams,
  424. total_matching_word_n_grams,
  425. total_preds_char_n_grams,
  426. total_preds_word_n_grams,
  427. total_target_char_n_grams,
  428. total_target_word_n_grams,
  429. n_order,
  430. beta,
  431. )
  432. def chrf_score(
  433. preds: Union[str, Sequence[str]],
  434. target: Sequence[Union[str, Sequence[str]]],
  435. n_char_order: int = 6,
  436. n_word_order: int = 2,
  437. beta: float = 2.0,
  438. lowercase: bool = False,
  439. whitespace: bool = False,
  440. return_sentence_level_score: bool = False,
  441. ) -> Union[Tensor, tuple[Tensor, Tensor]]:
  442. """Calculate `chrF score`_ of machine translated text with one or more references.
  443. This implementation supports both chrF score computation introduced in [1] and chrF++ score introduced in
  444. `chrF++ score`_. This implementation follows the implementations from https://github.com/m-popovic/chrF and
  445. https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/metrics/chrf.py.
  446. Args:
  447. preds: An iterable of hypothesis corpus.
  448. target: An iterable of iterables of reference corpus.
  449. n_char_order:
  450. A character n-gram order. If `n_char_order=6`, the metrics refers to the official chrF/chrF++.
  451. n_word_order:
  452. A word n-gram order. If `n_word_order=2`, the metric refers to the official chrF++. If `n_word_order=0`, the
  453. metric is equivalent to the original chrF.
  454. beta:
  455. A parameter determining an importance of recall w.r.t. precision. If `beta=1`, their importance is equal.
  456. lowercase: An indication whether to enable case-insensitivity.
  457. whitespace: An indication whether to keep whitespaces during character n-gram extraction.
  458. return_sentence_level_score: An indication whether a sentence-level chrF/chrF++ score to be returned.
  459. Return:
  460. A corpus-level chrF/chrF++ score.
  461. (Optionally) A list of sentence-level chrF/chrF++ scores if `return_sentence_level_score=True`.
  462. Raises:
  463. ValueError:
  464. If ``n_char_order`` is not an integer greater than or equal to 1.
  465. ValueError:
  466. If ``n_word_order`` is not an integer greater than or equal to 0.
  467. ValueError:
  468. If ``beta`` is smaller than 0.
  469. Example:
  470. >>> from torchmetrics.functional.text import chrf_score
  471. >>> preds = ['the cat is on the mat']
  472. >>> target = [['there is a cat on the mat', 'a cat is on the mat']]
  473. >>> chrf_score(preds, target)
  474. tensor(0.8640)
  475. References:
  476. [1] chrF: character n-gram F-score for automatic MT evaluation by Maja Popović `chrF score`_
  477. [2] chrF++: words helping character n-grams by Maja Popović `chrF++ score`_
  478. """
  479. if not isinstance(n_char_order, int) or n_char_order < 1:
  480. raise ValueError("Expected argument `n_char_order` to be an integer greater than or equal to 1.")
  481. if not isinstance(n_word_order, int) or n_word_order < 0:
  482. raise ValueError("Expected argument `n_word_order` to be an integer greater than or equal to 0.")
  483. if beta < 0:
  484. raise ValueError("Expected argument `beta` to be greater than 0.")
  485. n_order = float(n_char_order + n_word_order)
  486. (
  487. total_preds_char_n_grams,
  488. total_preds_word_n_grams,
  489. total_target_char_n_grams,
  490. total_target_word_n_grams,
  491. total_matching_char_n_grams,
  492. total_matching_word_n_grams,
  493. ) = _prepare_n_grams_dicts(n_char_order, n_word_order)
  494. sentence_chrf_score: Optional[List[Tensor]] = [] if return_sentence_level_score else None
  495. (
  496. total_preds_char_n_grams,
  497. total_preds_word_n_grams,
  498. total_target_char_n_grams,
  499. total_target_word_n_grams,
  500. total_matching_char_n_grams,
  501. total_matching_word_n_grams,
  502. sentence_chrf_score,
  503. ) = _chrf_score_update(
  504. preds,
  505. target,
  506. total_preds_char_n_grams,
  507. total_preds_word_n_grams,
  508. total_target_char_n_grams,
  509. total_target_word_n_grams,
  510. total_matching_char_n_grams,
  511. total_matching_word_n_grams,
  512. n_char_order,
  513. n_word_order,
  514. n_order,
  515. beta,
  516. lowercase,
  517. whitespace,
  518. sentence_chrf_score,
  519. )
  520. chrf_f_score = _chrf_score_compute(
  521. total_preds_char_n_grams,
  522. total_preds_word_n_grams,
  523. total_target_char_n_grams,
  524. total_target_word_n_grams,
  525. total_matching_char_n_grams,
  526. total_matching_word_n_grams,
  527. n_order,
  528. beta,
  529. )
  530. if sentence_chrf_score:
  531. return chrf_f_score, torch.cat(sentence_chrf_score)
  532. return chrf_f_score