sacre_bleu.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # referenced from
  15. # Library Name: torchtext
  16. # Authors: torchtext authors and @sluks
  17. # Date: 2020-07-18
  18. # Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score
  19. from collections.abc import Sequence
  20. from typing import Any, Optional, Union
  21. from torch import Tensor
  22. from torchmetrics.functional.text.bleu import _bleu_score_update
  23. from torchmetrics.functional.text.sacre_bleu import _SacreBLEUTokenizer, _TokenizersLiteral
  24. from torchmetrics.text.bleu import BLEUScore
  25. from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
  26. from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
  27. if not _MATPLOTLIB_AVAILABLE:
  28. __doctest_skip__ = ["SacreBLEUScore.plot"]
  29. class SacreBLEUScore(BLEUScore):
  30. """Calculate `BLEU score`_ of machine translated text with one or more references.
  31. This implementation follows the behaviour of `SacreBLEU`_. The SacreBLEU implementation differs from the NLTK BLEU
  32. implementation in tokenization techniques.
  33. As input to ``forward`` and ``update`` the metric accepts the following input:
  34. - ``preds`` (:class:`~Sequence`): An iterable of machine translated corpus
  35. - ``target`` (:class:`~Sequence`): An iterable of iterables of reference corpus
  36. As output of ``forward`` and ``compute`` the metric returns the following output:
  37. - ``sacre_bleu`` (:class:`~torch.Tensor`): A tensor with the SacreBLEU Score
  38. .. note::
  39. In the original SacreBLEU, references are passed as a list of reference sets (grouped by reference index).
  40. In TorchMetrics, references are passed grouped per prediction (each prediction has its own list of references).
  41. For example::
  42. # Predictions
  43. preds = ['The dog bit the man.', "It wasn't surprising.", 'The man had just bitten him.']
  44. # Original SacreBLEU:
  45. refs = [
  46. ['The dog bit the man.', 'It was not unexpected.', 'The man bit him first.'], # First set
  47. ['The dog had bit the man.', 'No one was surprised.', 'The man had bitten the dog.'], # Second set
  48. ]
  49. # TorchMetrics SacreBLEU:
  50. target = [
  51. ['The dog bit the man.', 'The dog had bit the man.'], # References for first prediction
  52. ['It was not unexpected.', 'No one was surprised.'], # References for second prediction
  53. ['The man bit him first.', 'The man had bitten the dog.'], # References for third prediction
  54. ]
  55. Args:
  56. n_gram: Gram value ranged from 1 to 4
  57. smooth: Whether to apply smoothing, see `SacreBLEU`_
  58. tokenize: Tokenization technique to be used. Choose between ``'none'``, ``'13a'``, ``'zh'``, ``'intl'``,
  59. ``'char'``, ``'ja-mecab'``, ``'ko-mecab'``, ``'flores101'`` and ``'flores200'``.
  60. lowercase: If ``True``, BLEU score over lowercased text is calculated.
  61. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  62. weights:
  63. Weights used for unigrams, bigrams, etc. to calculate BLEU score.
  64. If not provided, uniform weights are used.
  65. Raises:
  66. ValueError:
  67. If ``tokenize`` not one of 'none', '13a', 'zh', 'intl' or 'char'
  68. ValueError:
  69. If ``tokenize`` is set to 'intl' and `regex` is not installed
  70. ValueError:
  71. If a length of a list of weights is not ``None`` and not equal to ``n_gram``.
  72. Example:
  73. >>> from torchmetrics.text import SacreBLEUScore
  74. >>> preds = ['the cat is on the mat']
  75. >>> target = [['there is a cat on the mat', 'a cat is on the mat']]
  76. >>> sacre_bleu = SacreBLEUScore()
  77. >>> sacre_bleu(preds, target)
  78. tensor(0.7598)
  79. Additional References:
  80. - Automatic Evaluation of Machine Translation Quality Using Longest Common Subsequence
  81. and Skip-Bigram Statistics by Chin-Yew Lin and Franz Josef Och `Machine Translation Evolution`_
  82. """
  83. is_differentiable: bool = False
  84. higher_is_better: bool = True
  85. full_state_update: bool = True
  86. plot_lower_bound: float = 0.0
  87. plot_upper_bound: float = 1.0
  88. def __init__(
  89. self,
  90. n_gram: int = 4,
  91. smooth: bool = False,
  92. tokenize: _TokenizersLiteral = "13a",
  93. lowercase: bool = False,
  94. weights: Optional[Sequence[float]] = None,
  95. **kwargs: Any,
  96. ) -> None:
  97. super().__init__(n_gram=n_gram, smooth=smooth, weights=weights, **kwargs)
  98. self.tokenizer = _SacreBLEUTokenizer(tokenize, lowercase)
  99. def update(self, preds: Sequence[str], target: Sequence[Sequence[str]]) -> None:
  100. """Update state with predictions and targets."""
  101. self.preds_len, self.target_len = _bleu_score_update(
  102. preds,
  103. target,
  104. self.numerator,
  105. self.denominator,
  106. self.preds_len,
  107. self.target_len,
  108. self.n_gram,
  109. self.tokenizer,
  110. )
  111. def plot(
  112. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  113. ) -> _PLOT_OUT_TYPE:
  114. """Plot a single or multiple values from the metric.
  115. Args:
  116. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  117. If no value is provided, will automatically call `metric.compute` and plot that result.
  118. ax: An matplotlib axis object. If provided will add plot to that axis
  119. Returns:
  120. Figure and Axes object
  121. Raises:
  122. ModuleNotFoundError:
  123. If `matplotlib` is not installed
  124. .. plot::
  125. :scale: 75
  126. >>> # Example plotting a single value
  127. >>> from torchmetrics.text import SacreBLEUScore
  128. >>> metric = SacreBLEUScore()
  129. >>> preds = ['the cat is on the mat']
  130. >>> target = [['there is a cat on the mat', 'a cat is on the mat']]
  131. >>> metric.update(preds, target)
  132. >>> fig_, ax_ = metric.plot()
  133. .. plot::
  134. :scale: 75
  135. >>> # Example plotting multiple values
  136. >>> from torchmetrics.text import SacreBLEUScore
  137. >>> metric = SacreBLEUScore()
  138. >>> preds = ['the cat is on the mat']
  139. >>> target = [['there is a cat on the mat', 'a cat is on the mat']]
  140. >>> values = [ ]
  141. >>> for _ in range(10):
  142. ... values.append(metric(preds, target))
  143. >>> fig_, ax_ = metric.plot(values)
  144. """
  145. return self._plot(val, ax)