bleu.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # referenced from
  15. # Library Name: torchtext
  16. # Authors: torchtext authors and @sluks
  17. # Date: 2020-07-18
  18. # Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score
  19. from collections import Counter
  20. from collections.abc import Sequence
  21. from typing import Callable, Optional, Union
  22. import torch
  23. from torch import Tensor, tensor
  24. def _count_ngram(ngram_input_list: Sequence[str], n_gram: int) -> Counter:
  25. """Count how many times each word appears in a given text with ngram.
  26. Args:
  27. ngram_input_list: A list of translated text or reference texts
  28. n_gram: gram value ranged 1 to 4
  29. Return:
  30. ngram_counter: a collections.Counter object of ngram
  31. """
  32. ngram_counter: Counter = Counter()
  33. for i in range(1, n_gram + 1):
  34. for j in range(len(ngram_input_list) - i + 1):
  35. ngram_key = tuple(ngram_input_list[j : (i + j)])
  36. ngram_counter[ngram_key] += 1
  37. return ngram_counter
  38. def _tokenize_fn(sentence: str) -> Sequence[str]:
  39. """Tokenizes sentence into list of words.
  40. Args:
  41. sentence: A sentence separated by white space.
  42. Return:
  43. List of words
  44. """
  45. return sentence.split()
  46. def _bleu_score_update(
  47. preds: Sequence[str],
  48. target: Sequence[Sequence[str]],
  49. numerator: Tensor,
  50. denominator: Tensor,
  51. preds_len: Tensor,
  52. target_len: Tensor,
  53. n_gram: int = 4,
  54. tokenizer: Callable[[str], Sequence[str]] = _tokenize_fn,
  55. ) -> tuple[Tensor, Tensor]:
  56. """Update and returns variables required to compute the BLEU score.
  57. Args:
  58. preds: An iterable of machine translated corpus
  59. target: An iterable of iterables of reference corpus
  60. numerator: Numerator of precision score (true positives)
  61. denominator: Denominator of precision score (true positives + false positives)
  62. preds_len: count of words in a candidate prediction
  63. target_len: count of words in a reference translation
  64. target: count of words in a reference translation
  65. n_gram: gram value ranged 1 to 4
  66. tokenizer: A function that turns sentence into list of words
  67. """
  68. target_: Sequence[Sequence[Sequence[str]]] = [[tokenizer(line) if line else [] for line in t] for t in target]
  69. preds_: Sequence[Sequence[str]] = [tokenizer(line) if line else [] for line in preds]
  70. for pred, targets in zip(preds_, target_):
  71. preds_len += len(pred)
  72. target_len_list = [len(tgt) for tgt in targets]
  73. target_len_diff = [abs(len(pred) - x) for x in target_len_list]
  74. target_len += target_len_list[target_len_diff.index(min(target_len_diff))]
  75. preds_counter: Counter = _count_ngram(pred, n_gram)
  76. target_counter: Counter = Counter()
  77. for tgt in targets:
  78. target_counter |= _count_ngram(tgt, n_gram)
  79. ngram_counter_clip = preds_counter & target_counter
  80. for counter_clip in ngram_counter_clip:
  81. numerator[len(counter_clip) - 1] += ngram_counter_clip[counter_clip]
  82. for counter in preds_counter:
  83. denominator[len(counter) - 1] += preds_counter[counter]
  84. return preds_len, target_len
  85. def _bleu_score_compute(
  86. preds_len: Tensor,
  87. target_len: Tensor,
  88. numerator: Tensor,
  89. denominator: Tensor,
  90. n_gram: int,
  91. weights: Sequence[float],
  92. smooth: bool,
  93. ) -> Tensor:
  94. """Compute the BLEU score.
  95. Args:
  96. preds_len: count of words in a candidate translation
  97. target_len: count of words in a reference translation
  98. numerator: Numerator of precision score (true positives)
  99. denominator: Denominator of precision score (true positives + false positives)
  100. n_gram: gram value ranged 1 to 4
  101. weights: Weights used for unigrams, bigrams, etc. to calculate BLEU score.
  102. smooth: Whether to apply smoothing
  103. """
  104. device = numerator.device
  105. if min(numerator) == 0.0:
  106. return tensor(0.0, device=device)
  107. if smooth:
  108. precision_scores = torch.div(
  109. torch.add(numerator, torch.ones(n_gram, device=device)),
  110. torch.add(denominator, torch.ones(n_gram, device=device)),
  111. )
  112. precision_scores[0] = numerator[0] / denominator[0]
  113. else:
  114. precision_scores = numerator / denominator
  115. log_precision_scores = tensor(weights, device=device) * torch.log(precision_scores)
  116. geometric_mean = torch.exp(torch.sum(log_precision_scores))
  117. brevity_penalty = tensor(1.0, device=device) if preds_len > target_len else torch.exp(1 - (target_len / preds_len))
  118. return brevity_penalty * geometric_mean
  119. def bleu_score(
  120. preds: Union[str, Sequence[str]],
  121. target: Sequence[Union[str, Sequence[str]]],
  122. n_gram: int = 4,
  123. smooth: bool = False,
  124. weights: Optional[Sequence[float]] = None,
  125. ) -> Tensor:
  126. """Calculate `BLEU score`_ of machine translated text with one or more references.
  127. Args:
  128. preds: An iterable of machine translated corpus
  129. target: An iterable of iterables of reference corpus
  130. n_gram: Gram value ranged from 1 to 4
  131. smooth: Whether to apply smoothing - see [2]
  132. weights:
  133. Weights used for unigrams, bigrams, etc. to calculate BLEU score.
  134. If not provided, uniform weights are used.
  135. Return:
  136. Tensor with BLEU Score
  137. Raises:
  138. ValueError: If ``preds`` and ``target`` corpus have different lengths.
  139. ValueError: If a length of a list of weights is not ``None`` and not equal to ``n_gram``.
  140. Example:
  141. >>> from torchmetrics.functional.text import bleu_score
  142. >>> preds = ['the cat is on the mat']
  143. >>> target = [['there is a cat on the mat', 'a cat is on the mat']]
  144. >>> bleu_score(preds, target)
  145. tensor(0.7598)
  146. References:
  147. [1] BLEU: a Method for Automatic Evaluation of Machine Translation by Papineni,
  148. Kishore, Salim Roukos, Todd Ward, and Wei-Jing Zhu `BLEU`_
  149. [2] Automatic Evaluation of Machine Translation Quality Using Longest Common Subsequence
  150. and Skip-Bigram Statistics by Chin-Yew Lin and Franz Josef Och `Machine Translation Evolution`_
  151. """
  152. preds_ = [preds] if isinstance(preds, str) else preds
  153. target_ = [[tgt] if isinstance(tgt, str) else tgt for tgt in target]
  154. if len(preds_) != len(target_):
  155. raise ValueError(f"Corpus has different size {len(preds_)} != {len(target_)}")
  156. if weights is not None and len(weights) != n_gram:
  157. raise ValueError(f"List of weights has different weights than `n_gram`: {len(weights)} != {n_gram}")
  158. if weights is None:
  159. weights = [1.0 / n_gram] * n_gram
  160. numerator = torch.zeros(n_gram)
  161. denominator = torch.zeros(n_gram)
  162. preds_len = tensor(0.0)
  163. target_len = tensor(0.0)
  164. preds_len, target_len = _bleu_score_update(
  165. preds_, target_, numerator, denominator, preds_len, target_len, n_gram, _tokenize_fn
  166. )
  167. return _bleu_score_compute(preds_len, target_len, numerator, denominator, n_gram, weights, smooth)